From 89e1276ad4fed3ac91a4c138568c647fb0c1becc Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Tue, 26 Jun 2018 09:32:34 +0200 Subject: [PATCH] dns: simplify resolving code --- cmd/dnsd/dnsd.go | 12 ++- integration/dns/dns_test.go | 4 +- internal/dns/dns.go | 161 ++++++++++++++++-------------------- internal/dns/dns_test.go | 22 ++--- 4 files changed, 91 insertions(+), 108 deletions(-) diff --git a/cmd/dnsd/dnsd.go b/cmd/dnsd/dnsd.go index 93249bb9..cd135e1 100644 --- a/cmd/dnsd/dnsd.go +++ b/cmd/dnsd/dnsd.go @@ -28,7 +28,7 @@ var ( dnsListeners = multilisten.NewPool() ) -func updateListeners() error { +func updateListeners(mux *miekgdns.ServeMux) error { hosts, err := gokrazy.PrivateInterfaceAddrs() if err != nil { return err @@ -39,7 +39,11 @@ func updateListeners() error { }) dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener { - return &listenerAdapter{&miekgdns.Server{Addr: net.JoinHostPort(host, "53"), Net: "udp"}} + return &listenerAdapter{&miekgdns.Server{ + Addr: net.JoinHostPort(host, "53"), + Net: "udp", + Handler: mux, + }} }) return nil } @@ -73,11 +77,11 @@ func logic() error { log.Printf("cannot resolve DHCP hostnames: %v", err) } http.Handle("/metrics", srv.PrometheusHandler()) - updateListeners() + updateListeners(srv.Mux) ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGUSR1) for range ch { - if err := updateListeners(); err != nil { + if err := updateListeners(srv.Mux); err != nil { log.Printf("updateListeners: %v", err) } if err := readLeases(); err != nil { diff --git a/integration/dns/dns_test.go b/integration/dns/dns_test.go index 63f7479..846a7b7 100644 --- a/integration/dns/dns_test.go +++ b/integration/dns/dns_test.go @@ -13,8 +13,8 @@ import ( ) func TestDNS(t *testing.T) { - dns.NewServer("localhost:4453", "lan") - s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp"} + srv := dns.NewServer("localhost:4453", "lan") + s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp", Handler: srv.Mux} go s.ListenAndServe() const port = 4453 dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1") diff --git a/internal/dns/dns.go b/internal/dns/dns.go index f504259..dfc22ed 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -19,6 +19,8 @@ import ( ) type Server struct { + Mux *dns.ServeMux + client *dns.Client domain string upstream string @@ -40,6 +42,7 @@ func NewServer(addr, domain string) *Server { hostname, _ := os.Hostname() ip, _, _ := net.SplitHostPort(addr) server := &Server{ + Mux: dns.NewServeMux(), client: &dns.Client{}, domain: domain, upstream: "8.8.8.8:53", @@ -73,7 +76,9 @@ func NewServer(addr, domain string) *Server { server.prom.registry.MustRegister(prometheus.NewGoCollector()) server.initHostsLocked() - dns.HandleFunc(".", server.handleRequest) + server.Mux.HandleFunc(".", server.handleRequest) + server.Mux.HandleFunc("lan.", server.handleInternal) + server.Mux.HandleFunc("localhost.", server.handleInternal) return server } @@ -167,98 +172,73 @@ func isLocalInAddrArpa(q string) bool { return local } -// TODO: require search domains to be present, then use HandleFunc("lan.", internalName) -// TODO: add test for non-A records on internal names, they should not go upstream -func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { - s.prom.queries.Inc() - s.prom.questions.Observe(float64(len(r.Question))) - if len(r.Question) == 1 { // TODO: answer all questions we can answer - q := r.Question[0] - if q.Qtype == dns.TypeAAAA && q.Qclass == dns.ClassINET { - if q.Name == "localhost." { - s.prom.upstream.WithLabelValues("local").Inc() - rr, err := dns.NewRR(q.Name + " 3600 IN AAAA ::1") - if err != nil { - log.Fatal(err) - } - m := new(dns.Msg) - m.SetReply(r) - m.Answer = append(m.Answer, rr) - w.WriteMsg(m) - return - } +func (s *Server) resolve(q dns.Question) (dns.RR, error) { + if q.Qclass != dns.ClassINET { + return nil, nil + } + if q.Name == "localhost." { + if q.Qtype == dns.TypeAAAA { + return dns.NewRR(q.Name + " 3600 IN AAAA ::1") } - if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET { - name := strings.TrimSuffix(q.Name, ".") - name = strings.TrimSuffix(name, "."+s.domain) - - if q.Name == "localhost." { - s.prom.upstream.WithLabelValues("local").Inc() - rr, err := dns.NewRR(q.Name + " 3600 IN A 127.0.0.1") - if err != nil { - log.Fatal(err) - } - m := new(dns.Msg) - m.SetReply(r) - m.Answer = append(m.Answer, rr) - w.WriteMsg(m) - return - } - if !strings.Contains(name, ".") { - s.prom.upstream.WithLabelValues("local").Inc() - if host, ok := s.hostByName(name); ok { - rr, err := dns.NewRR(q.Name + " 3600 IN A " + host) - if err != nil { - log.Fatal(err) - } - m := new(dns.Msg) - m.SetReply(r) - m.Answer = append(m.Answer, rr) - w.WriteMsg(m) - return - } - // Send an authoritative NXDOMAIN for local names: - m := new(dns.Msg) - m.SetReply(r) - m.SetRcode(r, dns.RcodeNameError) - w.WriteMsg(m) - return - } - } - if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET { - if isLocalInAddrArpa(q.Name) { - s.prom.upstream.WithLabelValues("local").Inc() - if host, ok := s.hostByIP(q.Name); ok { - rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain) - if err != nil { - log.Fatal(err) - } - m := new(dns.Msg) - m.SetReply(r) - m.Answer = append(m.Answer, rr) - w.WriteMsg(m) - return - } - if strings.HasSuffix(q.Name, "127.in-addr.arpa.") { - rr, err := dns.NewRR(q.Name + " 3600 IN PTR localhost.") - if err != nil { - log.Fatal(err) - } - m := new(dns.Msg) - m.SetReply(r) - m.Answer = append(m.Answer, rr) - w.WriteMsg(m) - return - } - // Send an authoritative NXDOMAIN for local names: - m := new(dns.Msg) - m.SetReply(r) - m.SetRcode(r, dns.RcodeNameError) - w.WriteMsg(m) - return - } + if q.Qtype == dns.TypeA { + return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1") } } + if q.Qtype == dns.TypeA { + name := strings.TrimSuffix(q.Name, ".") + name = strings.TrimSuffix(name, "."+s.domain) + if host, ok := s.hostByName(name); ok { + return dns.NewRR(q.Name + " 3600 IN A " + host) + } + } + if q.Qtype == dns.TypePTR { + if host, ok := s.hostByIP(q.Name); ok { + return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain) + } + if strings.HasSuffix(q.Name, "127.in-addr.arpa.") { + return dns.NewRR(q.Name + " 3600 IN PTR localhost.") + } + } + return nil, nil +} + +func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) { + s.prom.queries.Inc() + s.prom.questions.Observe(float64(len(r.Question))) + s.prom.upstream.WithLabelValues("local").Inc() + if len(r.Question) != 1 { // TODO: answer all questions we can answer + return + } + rr, err := s.resolve(r.Question[0]) + if err != nil { + log.Fatal(err) + } + if rr != nil { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, rr) + w.WriteMsg(m) + return + } + // Send an authoritative NXDOMAIN for local names: + m := new(dns.Msg) + m.SetReply(r) + m.SetRcode(r, dns.RcodeNameError) + w.WriteMsg(m) +} + +func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 1 { // TODO: answer all questions we can answer + q := r.Question[0] + if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET && isLocalInAddrArpa(q.Name) { + s.handleInternal(w, r) + return + } + } + + s.prom.queries.Inc() + s.prom.questions.Observe(float64(len(r.Question))) + s.prom.upstream.WithLabelValues("DNS").Inc() in, _, err := s.client.Exchange(r, s.upstream) if err != nil { @@ -268,5 +248,4 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { return // DNS has no reply for resolving errors } w.WriteMsg(in) - s.prom.upstream.WithLabelValues("DNS").Inc() } diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index 6df406d..0506b11 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -33,7 +33,7 @@ func TestNXDOMAIN(t *testing.T) { s := NewServer("localhost:0", "lan") m := new(dns.Msg) m.SetQuestion("foo.invalid.", dns.TypeA) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := r.response.MsgHdr.Rcode, dns.RcodeNameError; got != want { t.Fatalf("unexpected rcode: got %v, want %v", got, want) } @@ -45,7 +45,7 @@ func TestResolveError(t *testing.T) { s.upstream = "266.266.266.266:53" m := new(dns.Msg) m.SetQuestion("foo.invalid.", dns.TypeA) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if r.response != nil { t.Fatalf("r.response unexpectedly not nil: %v", r.response) } @@ -64,7 +64,7 @@ func TestDHCP(t *testing.T) { t.Run("xps.lan.", func(t *testing.T) { m := new(dns.Msg) m.SetQuestion("xps.lan.", dns.TypeA) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := len(r.response.Answer), 1; got != want { t.Fatalf("unexpected number of answers: got %d, want %d", got, want) } @@ -80,7 +80,7 @@ func TestDHCP(t *testing.T) { t.Run("notfound.lan.", func(t *testing.T) { m := new(dns.Msg) m.SetQuestion("notfound.lan.", dns.TypeA) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := r.response.Rcode, dns.RcodeNameError; got != want { t.Fatalf("unexpected rcode: got %v, want %v", got, want) } @@ -99,7 +99,7 @@ func TestHostname(t *testing.T) { t.Run("A", func(t *testing.T) { m := new(dns.Msg) m.SetQuestion(hostname+".lan.", dns.TypeA) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := len(r.response.Answer), 1; got != want { t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) } @@ -115,7 +115,7 @@ func TestHostname(t *testing.T) { t.Run("PTR", func(t *testing.T) { m := new(dns.Msg) m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypePTR) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := len(r.response.Answer), 1; got != want { t.Fatalf("unexpected number of answers: got %d, want %d", got, want) } @@ -136,7 +136,7 @@ func TestLocalhost(t *testing.T) { t.Run("A", func(t *testing.T) { m := new(dns.Msg) m.SetQuestion("localhost.", dns.TypeA) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := len(r.response.Answer), 1; got != want { t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) } @@ -152,7 +152,7 @@ func TestLocalhost(t *testing.T) { t.Run("AAAA", func(t *testing.T) { m := new(dns.Msg) m.SetQuestion("localhost.", dns.TypeAAAA) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := len(r.response.Answer), 1; got != want { t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) } @@ -168,7 +168,7 @@ func TestLocalhost(t *testing.T) { t.Run("PTR", func(t *testing.T) { m := new(dns.Msg) m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := len(r.response.Answer), 1; got != want { t.Fatalf("unexpected number of answers: got %d, want %d", got, want) } @@ -218,7 +218,7 @@ func TestDHCPReverse(t *testing.T) { }) m := new(dns.Msg) m.SetQuestion(test.question, dns.TypePTR) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := len(r.response.Answer), 1; got != want { t.Fatalf("unexpected number of answers: got %d, want %d", got, want) } @@ -237,7 +237,7 @@ func TestDHCPReverse(t *testing.T) { s := NewServer("localhost:0", "lan") m := new(dns.Msg) m.SetQuestion("254.255.31.172.in-addr.arpa.", dns.TypePTR) - s.handleRequest(r, m) + s.Mux.ServeDNS(r, m) if got, want := r.response.Rcode, dns.RcodeNameError; got != want { t.Fatalf("unexpected rcode: got %v, want %v", got, want) }