dns: simplify resolving code

This commit is contained in:
Michael Stapelberg 2018-06-26 09:32:34 +02:00
parent 8e95e25442
commit 89e1276ad4
4 changed files with 91 additions and 108 deletions

View File

@ -28,7 +28,7 @@ var (
dnsListeners = multilisten.NewPool() dnsListeners = multilisten.NewPool()
) )
func updateListeners() error { func updateListeners(mux *miekgdns.ServeMux) error {
hosts, err := gokrazy.PrivateInterfaceAddrs() hosts, err := gokrazy.PrivateInterfaceAddrs()
if err != nil { if err != nil {
return err return err
@ -39,7 +39,11 @@ func updateListeners() error {
}) })
dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener { dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
return &listenerAdapter{&miekgdns.Server{Addr: net.JoinHostPort(host, "53"), Net: "udp"}} return &listenerAdapter{&miekgdns.Server{
Addr: net.JoinHostPort(host, "53"),
Net: "udp",
Handler: mux,
}}
}) })
return nil return nil
} }
@ -73,11 +77,11 @@ func logic() error {
log.Printf("cannot resolve DHCP hostnames: %v", err) log.Printf("cannot resolve DHCP hostnames: %v", err)
} }
http.Handle("/metrics", srv.PrometheusHandler()) http.Handle("/metrics", srv.PrometheusHandler())
updateListeners() updateListeners(srv.Mux)
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGUSR1) signal.Notify(ch, syscall.SIGUSR1)
for range ch { for range ch {
if err := updateListeners(); err != nil { if err := updateListeners(srv.Mux); err != nil {
log.Printf("updateListeners: %v", err) log.Printf("updateListeners: %v", err)
} }
if err := readLeases(); err != nil { if err := readLeases(); err != nil {

View File

@ -13,8 +13,8 @@ import (
) )
func TestDNS(t *testing.T) { func TestDNS(t *testing.T) {
dns.NewServer("localhost:4453", "lan") srv := dns.NewServer("localhost:4453", "lan")
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp"} s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp", Handler: srv.Mux}
go s.ListenAndServe() go s.ListenAndServe()
const port = 4453 const port = 4453
dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1") dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1")

View File

@ -19,6 +19,8 @@ import (
) )
type Server struct { type Server struct {
Mux *dns.ServeMux
client *dns.Client client *dns.Client
domain string domain string
upstream string upstream string
@ -40,6 +42,7 @@ func NewServer(addr, domain string) *Server {
hostname, _ := os.Hostname() hostname, _ := os.Hostname()
ip, _, _ := net.SplitHostPort(addr) ip, _, _ := net.SplitHostPort(addr)
server := &Server{ server := &Server{
Mux: dns.NewServeMux(),
client: &dns.Client{}, client: &dns.Client{},
domain: domain, domain: domain,
upstream: "8.8.8.8:53", upstream: "8.8.8.8:53",
@ -73,7 +76,9 @@ func NewServer(addr, domain string) *Server {
server.prom.registry.MustRegister(prometheus.NewGoCollector()) server.prom.registry.MustRegister(prometheus.NewGoCollector())
server.initHostsLocked() server.initHostsLocked()
dns.HandleFunc(".", server.handleRequest) server.Mux.HandleFunc(".", server.handleRequest)
server.Mux.HandleFunc("lan.", server.handleInternal)
server.Mux.HandleFunc("localhost.", server.handleInternal)
return server return server
} }
@ -167,98 +172,73 @@ func isLocalInAddrArpa(q string) bool {
return local return local
} }
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName) func (s *Server) resolve(q dns.Question) (dns.RR, error) {
// TODO: add test for non-A records on internal names, they should not go upstream if q.Qclass != dns.ClassINET {
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { return nil, nil
s.prom.queries.Inc() }
s.prom.questions.Observe(float64(len(r.Question))) if q.Name == "localhost." {
if len(r.Question) == 1 { // TODO: answer all questions we can answer if q.Qtype == dns.TypeAAAA {
q := r.Question[0] return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
if q.Qtype == dns.TypeAAAA && q.Qclass == dns.ClassINET {
if q.Name == "localhost." {
s.prom.upstream.WithLabelValues("local").Inc()
rr, err := dns.NewRR(q.Name + " 3600 IN AAAA ::1")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
} }
if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET { if q.Qtype == dns.TypeA {
name := strings.TrimSuffix(q.Name, ".") return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
name = strings.TrimSuffix(name, "."+s.domain)
if q.Name == "localhost." {
s.prom.upstream.WithLabelValues("local").Inc()
rr, err := dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
if !strings.Contains(name, ".") {
s.prom.upstream.WithLabelValues("local").Inc()
if host, ok := s.hostByName(name); ok {
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
}
}
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
if isLocalInAddrArpa(q.Name) {
s.prom.upstream.WithLabelValues("local").Inc()
if host, ok := s.hostByIP(q.Name); ok {
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
rr, err := dns.NewRR(q.Name + " 3600 IN PTR localhost.")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
}
} }
} }
if q.Qtype == dns.TypeA {
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain)
if host, ok := s.hostByName(name); ok {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
}
if q.Qtype == dns.TypePTR {
if host, ok := s.hostByIP(q.Name); ok {
return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
}
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
return dns.NewRR(q.Name + " 3600 IN PTR localhost.")
}
}
return nil, nil
}
func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("local").Inc()
if len(r.Question) != 1 { // TODO: answer all questions we can answer
return
}
rr, err := s.resolve(r.Question[0])
if err != nil {
log.Fatal(err)
}
if rr != nil {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
}
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 1 { // TODO: answer all questions we can answer
q := r.Question[0]
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET && isLocalInAddrArpa(q.Name) {
s.handleInternal(w, r)
return
}
}
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("DNS").Inc()
in, _, err := s.client.Exchange(r, s.upstream) in, _, err := s.client.Exchange(r, s.upstream)
if err != nil { if err != nil {
@ -268,5 +248,4 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
return // DNS has no reply for resolving errors return // DNS has no reply for resolving errors
} }
w.WriteMsg(in) w.WriteMsg(in)
s.prom.upstream.WithLabelValues("DNS").Inc()
} }

View File

@ -33,7 +33,7 @@ func TestNXDOMAIN(t *testing.T) {
s := NewServer("localhost:0", "lan") s := NewServer("localhost:0", "lan")
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("foo.invalid.", dns.TypeA) m.SetQuestion("foo.invalid.", dns.TypeA)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := r.response.MsgHdr.Rcode, dns.RcodeNameError; got != want { if got, want := r.response.MsgHdr.Rcode, dns.RcodeNameError; got != want {
t.Fatalf("unexpected rcode: got %v, want %v", got, want) t.Fatalf("unexpected rcode: got %v, want %v", got, want)
} }
@ -45,7 +45,7 @@ func TestResolveError(t *testing.T) {
s.upstream = "266.266.266.266:53" s.upstream = "266.266.266.266:53"
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("foo.invalid.", dns.TypeA) m.SetQuestion("foo.invalid.", dns.TypeA)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if r.response != nil { if r.response != nil {
t.Fatalf("r.response unexpectedly not nil: %v", r.response) t.Fatalf("r.response unexpectedly not nil: %v", r.response)
} }
@ -64,7 +64,7 @@ func TestDHCP(t *testing.T) {
t.Run("xps.lan.", func(t *testing.T) { t.Run("xps.lan.", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("xps.lan.", dns.TypeA) m.SetQuestion("xps.lan.", dns.TypeA)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want) t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
} }
@ -80,7 +80,7 @@ func TestDHCP(t *testing.T) {
t.Run("notfound.lan.", func(t *testing.T) { t.Run("notfound.lan.", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("notfound.lan.", dns.TypeA) m.SetQuestion("notfound.lan.", dns.TypeA)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := r.response.Rcode, dns.RcodeNameError; got != want { if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
t.Fatalf("unexpected rcode: got %v, want %v", got, want) t.Fatalf("unexpected rcode: got %v, want %v", got, want)
} }
@ -99,7 +99,7 @@ func TestHostname(t *testing.T) {
t.Run("A", func(t *testing.T) { t.Run("A", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(hostname+".lan.", dns.TypeA) m.SetQuestion(hostname+".lan.", dns.TypeA)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
} }
@ -115,7 +115,7 @@ func TestHostname(t *testing.T) {
t.Run("PTR", func(t *testing.T) { t.Run("PTR", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypePTR) m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypePTR)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want) t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
} }
@ -136,7 +136,7 @@ func TestLocalhost(t *testing.T) {
t.Run("A", func(t *testing.T) { t.Run("A", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("localhost.", dns.TypeA) m.SetQuestion("localhost.", dns.TypeA)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
} }
@ -152,7 +152,7 @@ func TestLocalhost(t *testing.T) {
t.Run("AAAA", func(t *testing.T) { t.Run("AAAA", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("localhost.", dns.TypeAAAA) m.SetQuestion("localhost.", dns.TypeAAAA)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
} }
@ -168,7 +168,7 @@ func TestLocalhost(t *testing.T) {
t.Run("PTR", func(t *testing.T) { t.Run("PTR", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR) m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want) t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
} }
@ -218,7 +218,7 @@ func TestDHCPReverse(t *testing.T) {
}) })
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(test.question, dns.TypePTR) m.SetQuestion(test.question, dns.TypePTR)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want) t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
} }
@ -237,7 +237,7 @@ func TestDHCPReverse(t *testing.T) {
s := NewServer("localhost:0", "lan") s := NewServer("localhost:0", "lan")
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("254.255.31.172.in-addr.arpa.", dns.TypePTR) m.SetQuestion("254.255.31.172.in-addr.arpa.", dns.TypePTR)
s.handleRequest(r, m) s.Mux.ServeDNS(r, m)
if got, want := r.response.Rcode, dns.RcodeNameError; got != want { if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
t.Fatalf("unexpected rcode: got %v, want %v", got, want) t.Fatalf("unexpected rcode: got %v, want %v", got, want)
} }