dns: simplify resolving code

This commit is contained in:
Michael Stapelberg 2018-06-26 09:32:34 +02:00
parent 8e95e25442
commit 89e1276ad4
4 changed files with 91 additions and 108 deletions

View File

@ -28,7 +28,7 @@ var (
dnsListeners = multilisten.NewPool()
)
func updateListeners() error {
func updateListeners(mux *miekgdns.ServeMux) error {
hosts, err := gokrazy.PrivateInterfaceAddrs()
if err != nil {
return err
@ -39,7 +39,11 @@ func updateListeners() error {
})
dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
return &listenerAdapter{&miekgdns.Server{Addr: net.JoinHostPort(host, "53"), Net: "udp"}}
return &listenerAdapter{&miekgdns.Server{
Addr: net.JoinHostPort(host, "53"),
Net: "udp",
Handler: mux,
}}
})
return nil
}
@ -73,11 +77,11 @@ func logic() error {
log.Printf("cannot resolve DHCP hostnames: %v", err)
}
http.Handle("/metrics", srv.PrometheusHandler())
updateListeners()
updateListeners(srv.Mux)
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGUSR1)
for range ch {
if err := updateListeners(); err != nil {
if err := updateListeners(srv.Mux); err != nil {
log.Printf("updateListeners: %v", err)
}
if err := readLeases(); err != nil {

View File

@ -13,8 +13,8 @@ import (
)
func TestDNS(t *testing.T) {
dns.NewServer("localhost:4453", "lan")
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp"}
srv := dns.NewServer("localhost:4453", "lan")
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp", Handler: srv.Mux}
go s.ListenAndServe()
const port = 4453
dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1")

View File

@ -19,6 +19,8 @@ import (
)
type Server struct {
Mux *dns.ServeMux
client *dns.Client
domain string
upstream string
@ -40,6 +42,7 @@ func NewServer(addr, domain string) *Server {
hostname, _ := os.Hostname()
ip, _, _ := net.SplitHostPort(addr)
server := &Server{
Mux: dns.NewServeMux(),
client: &dns.Client{},
domain: domain,
upstream: "8.8.8.8:53",
@ -73,7 +76,9 @@ func NewServer(addr, domain string) *Server {
server.prom.registry.MustRegister(prometheus.NewGoCollector())
server.initHostsLocked()
dns.HandleFunc(".", server.handleRequest)
server.Mux.HandleFunc(".", server.handleRequest)
server.Mux.HandleFunc("lan.", server.handleInternal)
server.Mux.HandleFunc("localhost.", server.handleInternal)
return server
}
@ -167,98 +172,73 @@ func isLocalInAddrArpa(q string) bool {
return local
}
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
// TODO: add test for non-A records on internal names, they should not go upstream
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
if len(r.Question) == 1 { // TODO: answer all questions we can answer
q := r.Question[0]
if q.Qtype == dns.TypeAAAA && q.Qclass == dns.ClassINET {
if q.Name == "localhost." {
s.prom.upstream.WithLabelValues("local").Inc()
rr, err := dns.NewRR(q.Name + " 3600 IN AAAA ::1")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
func (s *Server) resolve(q dns.Question) (dns.RR, error) {
if q.Qclass != dns.ClassINET {
return nil, nil
}
if q.Name == "localhost." {
if q.Qtype == dns.TypeAAAA {
return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
}
if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET {
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain)
if q.Name == "localhost." {
s.prom.upstream.WithLabelValues("local").Inc()
rr, err := dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
if !strings.Contains(name, ".") {
s.prom.upstream.WithLabelValues("local").Inc()
if host, ok := s.hostByName(name); ok {
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
}
}
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
if isLocalInAddrArpa(q.Name) {
s.prom.upstream.WithLabelValues("local").Inc()
if host, ok := s.hostByIP(q.Name); ok {
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
rr, err := dns.NewRR(q.Name + " 3600 IN PTR localhost.")
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
}
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
}
}
if q.Qtype == dns.TypeA {
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain)
if host, ok := s.hostByName(name); ok {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
}
if q.Qtype == dns.TypePTR {
if host, ok := s.hostByIP(q.Name); ok {
return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
}
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
return dns.NewRR(q.Name + " 3600 IN PTR localhost.")
}
}
return nil, nil
}
func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("local").Inc()
if len(r.Question) != 1 { // TODO: answer all questions we can answer
return
}
rr, err := s.resolve(r.Question[0])
if err != nil {
log.Fatal(err)
}
if rr != nil {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
}
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 1 { // TODO: answer all questions we can answer
q := r.Question[0]
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET && isLocalInAddrArpa(q.Name) {
s.handleInternal(w, r)
return
}
}
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("DNS").Inc()
in, _, err := s.client.Exchange(r, s.upstream)
if err != nil {
@ -268,5 +248,4 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
return // DNS has no reply for resolving errors
}
w.WriteMsg(in)
s.prom.upstream.WithLabelValues("DNS").Inc()
}

View File

@ -33,7 +33,7 @@ func TestNXDOMAIN(t *testing.T) {
s := NewServer("localhost:0", "lan")
m := new(dns.Msg)
m.SetQuestion("foo.invalid.", dns.TypeA)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := r.response.MsgHdr.Rcode, dns.RcodeNameError; got != want {
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
}
@ -45,7 +45,7 @@ func TestResolveError(t *testing.T) {
s.upstream = "266.266.266.266:53"
m := new(dns.Msg)
m.SetQuestion("foo.invalid.", dns.TypeA)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if r.response != nil {
t.Fatalf("r.response unexpectedly not nil: %v", r.response)
}
@ -64,7 +64,7 @@ func TestDHCP(t *testing.T) {
t.Run("xps.lan.", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("xps.lan.", dns.TypeA)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
}
@ -80,7 +80,7 @@ func TestDHCP(t *testing.T) {
t.Run("notfound.lan.", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("notfound.lan.", dns.TypeA)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
}
@ -99,7 +99,7 @@ func TestHostname(t *testing.T) {
t.Run("A", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion(hostname+".lan.", dns.TypeA)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
}
@ -115,7 +115,7 @@ func TestHostname(t *testing.T) {
t.Run("PTR", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypePTR)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
}
@ -136,7 +136,7 @@ func TestLocalhost(t *testing.T) {
t.Run("A", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("localhost.", dns.TypeA)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
}
@ -152,7 +152,7 @@ func TestLocalhost(t *testing.T) {
t.Run("AAAA", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("localhost.", dns.TypeAAAA)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
}
@ -168,7 +168,7 @@ func TestLocalhost(t *testing.T) {
t.Run("PTR", func(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
}
@ -218,7 +218,7 @@ func TestDHCPReverse(t *testing.T) {
})
m := new(dns.Msg)
m.SetQuestion(test.question, dns.TypePTR)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
}
@ -237,7 +237,7 @@ func TestDHCPReverse(t *testing.T) {
s := NewServer("localhost:0", "lan")
m := new(dns.Msg)
m.SetQuestion("254.255.31.172.in-addr.arpa.", dns.TypePTR)
s.handleRequest(r, m)
s.Mux.ServeDNS(r, m)
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
}