From 08249aec6aa77b1d02b1a2d37492a879eb7b14e9 Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Mon, 25 Jun 2018 20:24:02 +0200 Subject: [PATCH] dns: resolve localhost locally --- internal/dns/dns.go | 37 ++++++++++++++++++++++++++ internal/dns/dns_test.go | 57 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/internal/dns/dns.go b/internal/dns/dns.go index e19db7a..f855cbf 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -175,10 +175,36 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { s.prom.questions.Observe(float64(len(r.Question))) if len(r.Question) == 1 { // TODO: answer all questions we can answer q := r.Question[0] + if q.Qtype == dns.TypeAAAA && q.Qclass == dns.ClassINET { + if q.Name == "localhost." { + s.prom.upstream.WithLabelValues("local").Inc() + rr, err := dns.NewRR(q.Name + " 3600 IN AAAA ::1") + if err != nil { + log.Fatal(err) + } + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, rr) + w.WriteMsg(m) + return + } + } if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET { name := strings.TrimSuffix(q.Name, ".") name = strings.TrimSuffix(name, "."+s.domain) + if q.Name == "localhost." { + s.prom.upstream.WithLabelValues("local").Inc() + rr, err := dns.NewRR(q.Name + " 3600 IN A 127.0.0.1") + if err != nil { + log.Fatal(err) + } + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, rr) + w.WriteMsg(m) + return + } if !strings.Contains(name, ".") { s.prom.upstream.WithLabelValues("local").Inc() if host, ok := s.hostByName(name); ok { @@ -214,6 +240,17 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) return } + if strings.HasSuffix(q.Name, "127.in-addr.arpa.") { + rr, err := dns.NewRR(q.Name + " 3600 IN PTR localhost.") + if err != nil { + log.Fatal(err) + } + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, rr) + w.WriteMsg(m) + return + } // Send an authoritative NXDOMAIN for local names: m := new(dns.Msg) m.SetReply(r) diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index 1a79b2b..6df406d 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -72,7 +72,7 @@ func TestDHCP(t *testing.T) { if _, ok := a.(*dns.A); !ok { t.Fatalf("unexpected response type: got %T, want dns.A", a) } - if got, want := a.(*dns.A).A.To4(), (net.IP{192, 168, 42, 23}); !bytes.Equal(got, want) { + if got, want := a.(*dns.A).A, net.ParseIP("192.168.42.23"); !got.Equal(want) { t.Fatalf("unexpected response IP: got %v, want %v", got, want) } }) @@ -107,7 +107,7 @@ func TestHostname(t *testing.T) { if _, ok := a.(*dns.A); !ok { t.Fatalf("unexpected response type: got %T, want dns.A", a) } - if got, want := a.(*dns.A).A.To4(), (net.IP{127, 0, 0, 2}); !bytes.Equal(got, want) { + if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.2"); !got.Equal(want) { t.Fatalf("unexpected response IP: got %v, want %v", got, want) } }) @@ -129,6 +129,59 @@ func TestHostname(t *testing.T) { }) } +func TestLocalhost(t *testing.T) { + r := &recorder{} + s := NewServer("127.0.0.2:0", "lan") + + t.Run("A", func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("localhost.", dns.TypeA) + s.handleRequest(r, m) + if got, want := len(r.response.Answer), 1; got != want { + t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) + } + a := r.response.Answer[0] + if _, ok := a.(*dns.A); !ok { + t.Fatalf("unexpected response type: got %T, want dns.A", a) + } + if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.1"); !got.Equal(want) { + t.Fatalf("unexpected response IP: got %v, want %v", got, want) + } + }) + + t.Run("AAAA", func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("localhost.", dns.TypeAAAA) + s.handleRequest(r, m) + if got, want := len(r.response.Answer), 1; got != want { + t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) + } + a := r.response.Answer[0] + if _, ok := a.(*dns.AAAA); !ok { + t.Fatalf("unexpected response type: got %T, want dns.A", a) + } + if got, want := a.(*dns.AAAA).AAAA, (net.ParseIP("::1")); !bytes.Equal(got, want) { + t.Fatalf("unexpected response IP: got %v, want %v", got, want) + } + }) + + t.Run("PTR", func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR) + s.handleRequest(r, m) + if got, want := len(r.response.Answer), 1; got != want { + t.Fatalf("unexpected number of answers: got %d, want %d", got, want) + } + a := r.response.Answer[0] + if _, ok := a.(*dns.PTR); !ok { + t.Fatalf("unexpected response type: got %T, want dns.PTR", a) + } + if got, want := a.(*dns.PTR).Ptr, "localhost."; got != want { + t.Fatalf("unexpected response record: got %q, want %q", got, want) + } + }) +} + func TestDHCPReverse(t *testing.T) { for _, test := range []struct { ip net.IP