From 92d995bf79fe048b922a230536749061a2f61285 Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Tue, 1 Jan 2019 17:21:50 +0100 Subject: [PATCH] dns: return empty reply for non-A queries for DNS hostnames instead of NXDOMAIN, which is incorrect --- internal/dns/dns.go | 20 +++++++++++++++++--- internal/dns/dns_test.go | 12 ++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 5ba0c70..ee0b13f 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -16,6 +16,7 @@ package dns import ( + "errors" "log" "net" "net/http" @@ -193,7 +194,9 @@ func isLocalInAddrArpa(q string) bool { return local } -func (s *Server) resolve(q dns.Question) (dns.RR, error) { +var sentinelEmpty = errors.New("no answers") + +func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) { if q.Qclass != dns.ClassINET { return nil, nil } @@ -205,11 +208,16 @@ func (s *Server) resolve(q dns.Question) (dns.RR, error) { return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1") } } - if q.Qtype == dns.TypeA { + if q.Qtype == dns.TypeA || + q.Qtype == dns.TypeAAAA || + q.Qtype == dns.TypeMX { name := strings.TrimSuffix(q.Name, ".") name = strings.TrimSuffix(name, "."+s.domain) if host, ok := s.hostByName(name); ok { - return dns.NewRR(q.Name + " 3600 IN A " + host) + if q.Qtype == dns.TypeA { + return dns.NewRR(q.Name + " 3600 IN A " + host) + } + return nil, sentinelEmpty } } if q.Qtype == dns.TypePTR { @@ -232,6 +240,12 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) { } rr, err := s.resolve(r.Question[0]) if err != nil { + if err == sentinelEmpty { + m := new(dns.Msg) + m.SetReply(r) + w.WriteMsg(m) + return + } log.Fatal(err) } if rr != nil { diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index c590102..809fdef 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -208,6 +208,18 @@ func TestHostnameDHCP(t *testing.T) { t.Fatalf("unexpected response record: got %q, want %q", got, want) } }) + + t.Run("AAAA", func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(hostname+".lan.", dns.TypeAAAA) + s.Mux.ServeDNS(r, m) + if got, want := r.response.MsgHdr.Rcode, dns.RcodeSuccess; got != want { + t.Fatalf("unexpected rcode: got %v, want %v", got, want) + } + if got, want := len(r.response.Answer), 0; got != want { + t.Fatalf("unexpected number of answers: got %d, want %d", got, want) + } + }) } func TestLocalhost(t *testing.T) {