dns: resolve localhost locally

This commit is contained in:
Michael Stapelberg 2018-06-25 20:24:02 +02:00
parent 60de127991
commit 08249aec6a
2 changed files with 92 additions and 2 deletions

View File

@ -175,10 +175,36 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
s.prom.questions.Observe(float64(len(r.Question))) s.prom.questions.Observe(float64(len(r.Question)))
if len(r.Question) == 1 { // TODO: answer all questions we can answer if len(r.Question) == 1 { // TODO: answer all questions we can answer
q := r.Question[0] q := r.Question[0]
if q.Qtype == dns.TypeAAAA && q.Qclass == dns.ClassINET {
if q.Name == "localhost." {
s.prom.upstream.WithLabelValues("local").Inc()
rr, err := dns.NewRR(q.Name + " 3600 IN AAAA ::1")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
}
if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET { if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET {
name := strings.TrimSuffix(q.Name, ".") name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain) name = strings.TrimSuffix(name, "."+s.domain)
if q.Name == "localhost." {
s.prom.upstream.WithLabelValues("local").Inc()
rr, err := dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
if !strings.Contains(name, ".") { if !strings.Contains(name, ".") {
s.prom.upstream.WithLabelValues("local").Inc() s.prom.upstream.WithLabelValues("local").Inc()
if host, ok := s.hostByName(name); ok { if host, ok := s.hostByName(name); ok {
@ -214,6 +240,17 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
return return
} }
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
rr, err := dns.NewRR(q.Name + " 3600 IN PTR localhost.")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names: // Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)

View File

@ -72,7 +72,7 @@ func TestDHCP(t *testing.T) {
if _, ok := a.(*dns.A); !ok { if _, ok := a.(*dns.A); !ok {
t.Fatalf("unexpected response type: got %T, want dns.A", a) t.Fatalf("unexpected response type: got %T, want dns.A", a)
} }
if got, want := a.(*dns.A).A.To4(), (net.IP{192, 168, 42, 23}); !bytes.Equal(got, want) { if got, want := a.(*dns.A).A, net.ParseIP("192.168.42.23"); !got.Equal(want) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want) t.Fatalf("unexpected response IP: got %v, want %v", got, want)
} }
}) })
@ -107,7 +107,7 @@ func TestHostname(t *testing.T) {
if _, ok := a.(*dns.A); !ok { if _, ok := a.(*dns.A); !ok {
t.Fatalf("unexpected response type: got %T, want dns.A", a) t.Fatalf("unexpected response type: got %T, want dns.A", a)
} }
if got, want := a.(*dns.A).A.To4(), (net.IP{127, 0, 0, 2}); !bytes.Equal(got, want) { if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.2"); !got.Equal(want) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want) t.Fatalf("unexpected response IP: got %v, want %v", got, want)
} }
}) })
@ -129,6 +129,59 @@ func TestHostname(t *testing.T) {
}) })
} }
func TestLocalhost(t *testing.T) {
r := &recorder{}
s := NewServer("127.0.0.2:0", "lan")
t.Run("A", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("localhost.", dns.TypeA)
s.handleRequest(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
}
a := r.response.Answer[0]
if _, ok := a.(*dns.A); !ok {
t.Fatalf("unexpected response type: got %T, want dns.A", a)
}
if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.1"); !got.Equal(want) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
}
})
t.Run("AAAA", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("localhost.", dns.TypeAAAA)
s.handleRequest(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
}
a := r.response.Answer[0]
if _, ok := a.(*dns.AAAA); !ok {
t.Fatalf("unexpected response type: got %T, want dns.A", a)
}
if got, want := a.(*dns.AAAA).AAAA, (net.ParseIP("::1")); !bytes.Equal(got, want) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
}
})
t.Run("PTR", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR)
s.handleRequest(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
}
a := r.response.Answer[0]
if _, ok := a.(*dns.PTR); !ok {
t.Fatalf("unexpected response type: got %T, want dns.PTR", a)
}
if got, want := a.(*dns.PTR).Ptr, "localhost."; got != want {
t.Fatalf("unexpected response record: got %q, want %q", got, want)
}
})
}
func TestDHCPReverse(t *testing.T) { func TestDHCPReverse(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
ip net.IP ip net.IP