From 2b3cf0bf613cb6f22b86ff03054cc0b39ef55adb Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Wed, 27 Jun 2018 19:44:39 +0200 Subject: [PATCH] captured: use multilisten --- cmd/captured/captured.go | 57 +++++++++++++++++------ cmd/captured/ssh.go | 81 ++++++++++++++++++--------------- internal/netconfig/netconfig.go | 9 ++-- 3 files changed, 92 insertions(+), 55 deletions(-) diff --git a/cmd/captured/captured.go b/cmd/captured/captured.go index 4f3032a..d4cea0e 100644 --- a/cmd/captured/captured.go +++ b/cmd/captured/captured.go @@ -7,14 +7,18 @@ import ( "context" "flag" "log" + "os" + "os/signal" "sync" + "syscall" + "router7/internal/multilisten" + + "github.com/gokrazy/gokrazy" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" - "golang.org/x/sync/errgroup" - _ "net/http/pprof" ) @@ -79,23 +83,48 @@ func (prb *packetRingBuffer) packetsLocked() []gopacket.Packet { return packets } +var sshListeners = multilisten.NewPool() + +func updateListeners(srv *server) error { + hosts, err := gokrazy.PrivateInterfaceAddrs() + if err != nil { + return err + } + + sshListeners.ListenAndServe(hosts, func(host string) multilisten.Listener { + return srv.listenerFor(host) + }) + return nil +} + func logic() error { prb := newPacketRingBuffer(50000) + srv, err := newServer(prb) + if err != nil { + return err + } + if err := updateListeners(srv); err != nil { + return err + } - var eg errgroup.Group - eg.Go(func() error { return listenAndServe(prb) }) - eg.Go(func() error { - packets, err := capturePackets(context.Background()) - if err != nil { - return err + go func() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGUSR1) + for range ch { + if err := updateListeners(srv); err != nil { + log.Printf("updateListeners: %v", err) + } } - for packet := range packets { - prb.writePacket(packet) - } - return nil - }) + }() - return eg.Wait() + packets, err := capturePackets(context.Background()) + if err != nil { + return err + } + for packet := range packets { + prb.writePacket(packet) + } + return nil } func main() { diff --git a/cmd/captured/ssh.go b/cmd/captured/ssh.go index 3fa1df3..29010e6 100644 --- a/cmd/captured/ssh.go +++ b/cmd/captured/ssh.go @@ -7,7 +7,6 @@ import ( "log" "net" - "github.com/gokrazy/gokrazy" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" "golang.org/x/crypto/ssh" @@ -104,7 +103,12 @@ func loadHostKey(path string) (ssh.Signer, error) { return ssh.ParsePrivateKey(b) } -func listenAndServe(prb *packetRingBuffer) error { +type server struct { + config *ssh.ServerConfig + prb *packetRingBuffer +} + +func newServer(prb *packetRingBuffer) (*server, error) { config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { return nil, nil // authorize all users @@ -113,53 +117,56 @@ func listenAndServe(prb *packetRingBuffer) error { signer, err := loadHostKey(*hostKeyPath) if err != nil { - return err + return nil, err } config.AddHostKey(signer) - accept := func(listener net.Listener) { - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("accept: %v", err) - continue - } + return &server{ + config: config, + prb: prb, + }, nil +} - go func(conn net.Conn) { - _, chans, reqs, err := ssh.NewServerConn(conn, config) - if err != nil { - log.Printf("handshake: %v", err) - return - } +func (s *server) listenerFor(host string) *serverListener { + return &serverListener{srv: s, host: host} +} - // discard all out of band requests - go ssh.DiscardRequests(reqs) +type serverListener struct { + srv *server + host string + ln net.Listener +} - for newChannel := range chans { - handleChannel(newChannel, prb) - } - }(conn) - } - } - - addrs, err := gokrazy.PrivateInterfaceAddrs() +func (sl *serverListener) ListenAndServe() error { + ln, err := net.Listen("tcp", net.JoinHostPort(sl.host, "5022")) if err != nil { return err } - - for _, addr := range addrs { - hostport := net.JoinHostPort(addr, "5022") - listener, err := net.Listen("tcp", hostport) + sl.ln = ln + for { + conn, err := ln.Accept() if err != nil { return err } - fmt.Printf("listening on %s\n", hostport) - go accept(listener) + + go func(conn net.Conn) { + _, chans, reqs, err := ssh.NewServerConn(conn, sl.srv.config) + if err != nil { + log.Printf("handshake: %v", err) + return + } + + // discard all out of band requests + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + handleChannel(newChannel, sl.srv.prb) + } + }(conn) } - - fmt.Printf("host key fingerprint: %s\n", ssh.FingerprintSHA256(signer.PublicKey())) - - select {} - return nil } + +func (sl *serverListener) Close() error { + return sl.ln.Close() +} diff --git a/internal/netconfig/netconfig.go b/internal/netconfig/netconfig.go index ddd2a76..d0760eb 100644 --- a/internal/netconfig/netconfig.go +++ b/internal/netconfig/netconfig.go @@ -600,10 +600,11 @@ func Apply(dir, root string) error { } for _, process := range []string{ - "dyndns", // depends on the public IPv4 address - "dnsd", // listens on private IPv4/IPv6 - "diagd", // listens on private IPv4/IPv6 - "backupd", // listens on private IPv4/IPv6 + "dyndns", // depends on the public IPv4 address + "dnsd", // listens on private IPv4/IPv6 + "diagd", // listens on private IPv4/IPv6 + "backupd", // listens on private IPv4/IPv6 + "captured", // listens on private IPv4/IPv6 } { if err := notify.Process("/user/"+process, syscall.SIGUSR1); err != nil { log.Printf("notifying %s: %v", process, err)