diff --git a/cmd/captured/captured.go b/cmd/captured/captured.go index 9e9a0ab..22dc974 100644 --- a/cmd/captured/captured.go +++ b/cmd/captured/captured.go @@ -1,8 +1,19 @@ package main import ( + "container/ring" + "context" "flag" "log" + "sync" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcapgo" + + "golang.org/x/sync/errgroup" + + _ "net/http/pprof" ) var ( @@ -11,8 +22,78 @@ var ( "path to a PEM-encoded RSA, DSA or ECDSA private key (create using e.g. ssh-keygen -f /perm/breakglass.host_key -N '' -t rsa)") ) +func capturePackets(ctx context.Context) (chan gopacket.Packet, error) { + packets := make(chan gopacket.Packet) + for _, ifname := range []string{"uplink0", "lan0"} { + handle, err := pcapgo.OpenEthernet(ifname) + if err != nil { + return nil, err + } + + if err := handle.SetBPF(instructions); err != nil { + return nil, err + } + + pkgsrc := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet) + go func() { + defer handle.Close() + for packet := range pkgsrc.Packets() { + select { + case packets <- packet: + case <-ctx.Done(): + return + } + } + }() + } + return packets, nil +} + +type packetRingBuffer struct { + sync.Mutex + r *ring.Ring +} + +func newPacketRingBuffer(size int) *packetRingBuffer { + return &packetRingBuffer{ + r: ring.New(size), + } +} + +func (prb *packetRingBuffer) writePacket(p gopacket.Packet) { + prb.Lock() + defer prb.Unlock() + prb.r.Value = p + prb.r = prb.r.Next() +} + +func (prb *packetRingBuffer) packetsLocked() []gopacket.Packet { + packets := make([]gopacket.Packet, 0, prb.r.Len()) + prb.r.Do(func(x interface{}) { + if x != nil { + packets = append(packets, x.(gopacket.Packet)) + } + }) + return packets +} + func logic() error { - return listenAndServe() + prb := newPacketRingBuffer(5000) + + var eg errgroup.Group + eg.Go(func() error { return listenAndServe(prb) }) + eg.Go(func() error { + packets, err := capturePackets(context.Background()) + if err != nil { + return err + } + for packet := range packets { + prb.writePacket(packet) + } + return nil + }) + + return eg.Wait() } func main() { diff --git a/cmd/captured/ssh.go b/cmd/captured/ssh.go index 11ca384..3fa1df3 100644 --- a/cmd/captured/ssh.go +++ b/cmd/captured/ssh.go @@ -8,13 +8,12 @@ import ( "net" "github.com/gokrazy/gokrazy" - "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" "golang.org/x/crypto/ssh" ) -func handleChannel(newChannel ssh.NewChannel) { +func handleChannel(newChannel ssh.NewChannel, prb *packetRingBuffer) { if t := newChannel.ChannelType(); t != "session" { newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %q", t)) return @@ -30,7 +29,7 @@ func handleChannel(newChannel ssh.NewChannel) { go func(channel ssh.Channel, requests <-chan *ssh.Request) { s := session{channel: channel} for req := range requests { - if err := s.request(req); err != nil { + if err := s.request(req, prb); err != nil { errmsg := []byte(err.Error()) // Append a trailing newline; the error message is // displayed as-is by ssh(1). @@ -49,7 +48,7 @@ type session struct { channel ssh.Channel } -func (s *session) request(req *ssh.Request) error { +func (s *session) request(req *ssh.Request, prb *packetRingBuffer) error { switch req.Type { case "exec": if got, want := len(req.Payload), 4; got < want { @@ -65,34 +64,22 @@ func (s *session) request(req *ssh.Request) error { return err } - packets := make(chan gopacket.Packet) - for _, ifname := range []string{"uplink0", "lan0"} { - handle, err := pcapgo.OpenEthernet(ifname) - //handle, err := pcap.OpenLive("uplink0", 1600, false /* promisc */, pcap.BlockForever) - if err != nil { - return err - } - - if err := handle.SetBPF(instructions); err != nil { - //if err := handle.SetBPFFilter("icmp6 or (udp and (port 67 or port 68 or port 546 or port 547))"); err != nil { - return err - } - - pkgsrc := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet) - go func() { - defer handle.Close() - for packet := range pkgsrc.Packets() { - select { - case packets <- packet: - case <-ctx.Done(): - return - } - } - }() + prb.Lock() + packets, err := capturePackets(ctx) + buffered := prb.packetsLocked() + prb.Unlock() + if err != nil { + return err } req.Reply(true, nil) + for _, packet := range buffered { + if err := pcapw.WritePacket(packet.Metadata().CaptureInfo, packet.Data()); err != nil { + return fmt.Errorf("pcap.WritePacket(): %v", err) + } + } + for packet := range packets { if err := pcapw.WritePacket(packet.Metadata().CaptureInfo, packet.Data()); err != nil { return fmt.Errorf("pcap.WritePacket(): %v", err) @@ -117,7 +104,7 @@ func loadHostKey(path string) (ssh.Signer, error) { return ssh.ParsePrivateKey(b) } -func listenAndServe() error { +func listenAndServe(prb *packetRingBuffer) error { config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { return nil, nil // authorize all users @@ -149,7 +136,7 @@ func listenAndServe() error { go ssh.DiscardRequests(reqs) for newChannel := range chans { - handleChannel(newChannel) + handleChannel(newChannel, prb) } }(conn) }