captured: implement a packet ring buffer

So that when you connect with Wireshark, you’ll see the most recent
packets (takes up to 7 MB of RAM).
This commit is contained in:
Michael Stapelberg 2018-06-17 17:47:26 +02:00
parent 2c302d976d
commit bb6b901b90
2 changed files with 99 additions and 31 deletions

View File

@ -1,8 +1,19 @@
package main
import (
"container/ring"
"context"
"flag"
"log"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcapgo"
"golang.org/x/sync/errgroup"
_ "net/http/pprof"
)
var (
@ -11,8 +22,78 @@ var (
"path to a PEM-encoded RSA, DSA or ECDSA private key (create using e.g. ssh-keygen -f /perm/breakglass.host_key -N '' -t rsa)")
)
func capturePackets(ctx context.Context) (chan gopacket.Packet, error) {
packets := make(chan gopacket.Packet)
for _, ifname := range []string{"uplink0", "lan0"} {
handle, err := pcapgo.OpenEthernet(ifname)
if err != nil {
return nil, err
}
if err := handle.SetBPF(instructions); err != nil {
return nil, err
}
pkgsrc := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet)
go func() {
defer handle.Close()
for packet := range pkgsrc.Packets() {
select {
case packets <- packet:
case <-ctx.Done():
return
}
}
}()
}
return packets, nil
}
type packetRingBuffer struct {
sync.Mutex
r *ring.Ring
}
func newPacketRingBuffer(size int) *packetRingBuffer {
return &packetRingBuffer{
r: ring.New(size),
}
}
func (prb *packetRingBuffer) writePacket(p gopacket.Packet) {
prb.Lock()
defer prb.Unlock()
prb.r.Value = p
prb.r = prb.r.Next()
}
func (prb *packetRingBuffer) packetsLocked() []gopacket.Packet {
packets := make([]gopacket.Packet, 0, prb.r.Len())
prb.r.Do(func(x interface{}) {
if x != nil {
packets = append(packets, x.(gopacket.Packet))
}
})
return packets
}
func logic() error {
return listenAndServe()
prb := newPacketRingBuffer(5000)
var eg errgroup.Group
eg.Go(func() error { return listenAndServe(prb) })
eg.Go(func() error {
packets, err := capturePackets(context.Background())
if err != nil {
return err
}
for packet := range packets {
prb.writePacket(packet)
}
return nil
})
return eg.Wait()
}
func main() {

View File

@ -8,13 +8,12 @@ import (
"net"
"github.com/gokrazy/gokrazy"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcapgo"
"golang.org/x/crypto/ssh"
)
func handleChannel(newChannel ssh.NewChannel) {
func handleChannel(newChannel ssh.NewChannel, prb *packetRingBuffer) {
if t := newChannel.ChannelType(); t != "session" {
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %q", t))
return
@ -30,7 +29,7 @@ func handleChannel(newChannel ssh.NewChannel) {
go func(channel ssh.Channel, requests <-chan *ssh.Request) {
s := session{channel: channel}
for req := range requests {
if err := s.request(req); err != nil {
if err := s.request(req, prb); err != nil {
errmsg := []byte(err.Error())
// Append a trailing newline; the error message is
// displayed as-is by ssh(1).
@ -49,7 +48,7 @@ type session struct {
channel ssh.Channel
}
func (s *session) request(req *ssh.Request) error {
func (s *session) request(req *ssh.Request, prb *packetRingBuffer) error {
switch req.Type {
case "exec":
if got, want := len(req.Payload), 4; got < want {
@ -65,34 +64,22 @@ func (s *session) request(req *ssh.Request) error {
return err
}
packets := make(chan gopacket.Packet)
for _, ifname := range []string{"uplink0", "lan0"} {
handle, err := pcapgo.OpenEthernet(ifname)
//handle, err := pcap.OpenLive("uplink0", 1600, false /* promisc */, pcap.BlockForever)
if err != nil {
return err
}
if err := handle.SetBPF(instructions); err != nil {
//if err := handle.SetBPFFilter("icmp6 or (udp and (port 67 or port 68 or port 546 or port 547))"); err != nil {
return err
}
pkgsrc := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet)
go func() {
defer handle.Close()
for packet := range pkgsrc.Packets() {
select {
case packets <- packet:
case <-ctx.Done():
return
}
}
}()
prb.Lock()
packets, err := capturePackets(ctx)
buffered := prb.packetsLocked()
prb.Unlock()
if err != nil {
return err
}
req.Reply(true, nil)
for _, packet := range buffered {
if err := pcapw.WritePacket(packet.Metadata().CaptureInfo, packet.Data()); err != nil {
return fmt.Errorf("pcap.WritePacket(): %v", err)
}
}
for packet := range packets {
if err := pcapw.WritePacket(packet.Metadata().CaptureInfo, packet.Data()); err != nil {
return fmt.Errorf("pcap.WritePacket(): %v", err)
@ -117,7 +104,7 @@ func loadHostKey(path string) (ssh.Signer, error) {
return ssh.ParsePrivateKey(b)
}
func listenAndServe() error {
func listenAndServe(prb *packetRingBuffer) error {
config := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
return nil, nil // authorize all users
@ -149,7 +136,7 @@ func listenAndServe() error {
go ssh.DiscardRequests(reqs)
for newChannel := range chans {
handleChannel(newChannel)
handleChannel(newChannel, prb)
}
}(conn)
}