From 7d278289f08c3ee23778f1f1d01cc078992e9b65 Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Tue, 23 Oct 2018 09:56:07 +0200 Subject: [PATCH] captured: directly call NextPacket() to prevent hanging reads Using Packets() spawns off a separate goroutine which calls NextPacket in a loop until io.EOF is returned. This goroutine will stick around after Close() returned, resulting in only the first wireshark connection working. --- cmd/captured/captured.go | 12 +++++++++--- cmd/captured/ssh.go | 11 ++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/cmd/captured/captured.go b/cmd/captured/captured.go index 851c0d4..ef9ab0d 100644 --- a/cmd/captured/captured.go +++ b/cmd/captured/captured.go @@ -20,6 +20,7 @@ import ( "container/ring" "context" "flag" + "fmt" "log" "os" "os/signal" @@ -47,17 +48,22 @@ func capturePackets(ctx context.Context) (chan gopacket.Packet, error) { for _, ifname := range []string{"uplink0", "lan0"} { handle, err := pcapgo.NewEthernetHandle(ifname) if err != nil { - return nil, err + return nil, fmt.Errorf("pcapgo.NewEthernetHandle(%v): %v", ifname, err) } if err := handle.SetBPF(instructions); err != nil { - return nil, err + return nil, fmt.Errorf("SetBPF: %v", err) } pkgsrc := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet) go func() { defer handle.Close() - for packet := range pkgsrc.Packets() { + for { + packet, err := pkgsrc.NextPacket() + if err != nil { + log.Printf("NextPacket: %v", err) + return + } select { case packets <- packet: case <-ctx.Done(): diff --git a/cmd/captured/ssh.go b/cmd/captured/ssh.go index f5602db..dc9d66c 100644 --- a/cmd/captured/ssh.go +++ b/cmd/captured/ssh.go @@ -61,20 +61,25 @@ type session struct { channel ssh.Channel } -func (s *session) request(req *ssh.Request, prb *packetRingBuffer) error { +func (s *session) request(req *ssh.Request, prb *packetRingBuffer) (err error) { switch req.Type { case "exec": if got, want := len(req.Payload), 4; got < want { return fmt.Errorf("exec request payload too short: got %d, want >= %d", got, want) } log.Printf("exec, wantReply %v, payload %q", req.WantReply, string(req.Payload[4:])) + defer func() { + if err != nil { + log.Printf("exec done: %v", err) + } + }() ctx, canc := context.WithCancel(context.Background()) defer canc() pcapw := pcapgo.NewWriter(s.channel) if err := pcapw.WriteFileHeader(1600, layers.LinkTypeEthernet); err != nil { - return err + return fmt.Errorf("pcapw.WriteFileHeader: %v", err) } prb.Lock() @@ -82,7 +87,7 @@ func (s *session) request(req *ssh.Request, prb *packetRingBuffer) error { buffered := prb.packetsLocked() prb.Unlock() if err != nil { - return err + return fmt.Errorf("capturePackets: %v", err) } req.Reply(true, nil)