diff --git a/cmd/dnsd/dnsd.go b/cmd/dnsd/dnsd.go index f7122dd..93249bb9 100644 --- a/cmd/dnsd/dnsd.go +++ b/cmd/dnsd/dnsd.go @@ -6,12 +6,14 @@ import ( "flag" "io/ioutil" "log" + "net" "net/http" "os" "os/signal" "syscall" "github.com/gokrazy/gokrazy" + miekgdns "github.com/miekg/dns" "router7/internal/dhcp4d" "router7/internal/dns" @@ -21,18 +23,33 @@ import ( _ "net/http/pprof" ) +var ( + httpListeners = multilisten.NewPool() + dnsListeners = multilisten.NewPool() +) + func updateListeners() error { hosts, err := gokrazy.PrivateInterfaceAddrs() if err != nil { return err } - if net1, err := multilisten.IPv6Net1("/perm"); err == nil { - hosts = append(hosts, net1) - } - return multilisten.ListenAndServe(hosts, "8053", http.DefaultServeMux) + httpListeners.ListenAndServe(hosts, func(host string) multilisten.Listener { + return &http.Server{Addr: net.JoinHostPort(host, "8053")} + }) + + dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener { + return &listenerAdapter{&miekgdns.Server{Addr: net.JoinHostPort(host, "53"), Net: "udp"}} + }) + return nil } +type listenerAdapter struct { + *miekgdns.Server +} + +func (a *listenerAdapter) Close() error { return a.Shutdown() } + func logic() error { // TODO: set correct upstream DNS resolver(s) ip, err := netconfig.LinkAddress("/perm", "lan0") @@ -59,17 +76,15 @@ func logic() error { updateListeners() ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGUSR1) - go func() { - for range ch { - if err := updateListeners(); err != nil { - log.Printf("updateListeners: %v", err) - } - if err := readLeases(); err != nil { - log.Printf("readLeases: %v", err) - } + for range ch { + if err := updateListeners(); err != nil { + log.Printf("updateListeners: %v", err) } - }() - return srv.ListenAndServe() + if err := readLeases(); err != nil { + log.Printf("readLeases: %v", err) + } + } + return nil } func main() { diff --git a/cmd/netconfigd/netconfigd.go b/cmd/netconfigd/netconfigd.go index 43e9ccd..07dc78f 100644 --- a/cmd/netconfigd/netconfigd.go +++ b/cmd/netconfigd/netconfigd.go @@ -3,6 +3,7 @@ package main import ( "flag" + "net" "net/http" "os" "os/signal" @@ -89,6 +90,8 @@ func init() { } } +var httpListeners = multilisten.NewPool() + func updateListeners() error { hosts, err := gokrazy.PrivateInterfaceAddrs() if err != nil { @@ -98,7 +101,10 @@ func updateListeners() error { hosts = append(hosts, net1) } - return multilisten.ListenAndServe(hosts, "8066", http.DefaultServeMux) + httpListeners.ListenAndServe(hosts, func(host string) multilisten.Listener { + return &http.Server{Addr: net.JoinHostPort(host, "8066")} + }) + return nil } func logic() error { diff --git a/integration/dns/dns_test.go b/integration/dns/dns_test.go index 48bff6e..63f7479 100644 --- a/integration/dns/dns_test.go +++ b/integration/dns/dns_test.go @@ -8,10 +8,14 @@ import ( "testing" "router7/internal/dns" + + miekgdns "github.com/miekg/dns" ) func TestDNS(t *testing.T) { - go dns.NewServer("localhost:4453", "lan").ListenAndServe() + dns.NewServer("localhost:4453", "lan") + s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp"} + go s.ListenAndServe() const port = 4453 dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1") dig.Stderr = os.Stderr diff --git a/internal/dns/dns.go b/internal/dns/dns.go index f855cbf..f504259 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -19,7 +19,6 @@ import ( ) type Server struct { - *dns.Server client *dns.Client domain string upstream string @@ -41,7 +40,6 @@ func NewServer(addr, domain string) *Server { hostname, _ := os.Hostname() ip, _, _ := net.SplitHostPort(addr) server := &Server{ - Server: &dns.Server{Addr: addr, Net: "udp"}, client: &dns.Client{}, domain: domain, upstream: "8.8.8.8:53", @@ -170,6 +168,7 @@ func isLocalInAddrArpa(q string) bool { } // TODO: require search domains to be present, then use HandleFunc("lan.", internalName) +// TODO: add test for non-A records on internal names, they should not go upstream func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { s.prom.queries.Inc() s.prom.questions.Observe(float64(len(r.Question))) diff --git a/internal/multilisten/multilisten.go b/internal/multilisten/multilisten.go new file mode 100644 index 0000000..2213699 --- /dev/null +++ b/internal/multilisten/multilisten.go @@ -0,0 +1,81 @@ +// Package multilisten implements listening on multiple addresses at once. +package multilisten + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "path/filepath" + "router7/internal/dhcp6" + "sync" +) + +type Listener interface { + ListenAndServe() error + Close() error +} + +type Pool struct { + mu sync.Mutex + listeners map[string]Listener +} + +func NewPool() *Pool { + return &Pool{ + listeners: make(map[string]Listener), + } +} + +func (p *Pool) ListenAndServe(hosts []string, listenerFor func(host string) Listener) { + p.mu.Lock() + defer p.mu.Unlock() + vanished := make(map[string]bool) + for host := range p.listeners { + vanished[host] = false + } + for _, host := range hosts { + if _, ok := p.listeners[host]; ok { + // confirm found + delete(vanished, host) + } else { + log.Printf("now listening on %s", host) + // add a new listener + ln := listenerFor(host) + p.listeners[host] = ln + go func(host string, ln Listener) { + err := ln.ListenAndServe() + log.Printf("listener for %q died: %v", host, err) + p.mu.Lock() + defer p.mu.Unlock() + delete(p.listeners, host) + }(host, ln) + } + } + for host := range vanished { + log.Printf("no longer listening on %s", host) + p.listeners[host].Close() + delete(p.listeners, host) + } +} + +// IPv6Net1 returns the IP address which router7 picks from the IPv6 prefix for +// itself, e.g. address 2a02:168:4a00::1 for prefix 2a02:168:4a00::/48. +func IPv6Net1(dir string) (string, error) { + b, err := ioutil.ReadFile(filepath.Join(dir, "dhcp6/wire/lease.json")) + if err != nil { + return "", err + } + var got dhcp6.Config + if err := json.Unmarshal(b, &got); err != nil { + return "", err + } + + for _, prefix := range got.Prefixes { + // pick the first address of the prefix, e.g. address 2a02:168:4a00::1 + // for prefix 2a02:168:4a00::/48 + prefix.IP[len(prefix.IP)-1] = 1 + return prefix.IP.String(), nil + } + return "", fmt.Errorf("no DHCPv6 prefix obtained") +}