diff --git a/internal/dns/dns.go b/internal/dns/dns.go index fba1cd0..2383747 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -42,6 +42,11 @@ var log = teelogger.NewConsole() // DHCP-based local name resolution can be made case-insensitive. type lcHostname string +type DNSIP struct { + IPv6 net.IP + IPv4 net.IP +} + type Server struct { Mux *dns.ServeMux @@ -59,7 +64,7 @@ type Server struct { hostname, ip string hostsByName map[lcHostname]string hostsByIP map[string]string - subnames map[lcHostname]map[string]net.IP // hostname → subname → ip + subnames map[lcHostname]map[string]DNSIP // hostname → subname → ip upstreamMu sync.RWMutex upstream []string @@ -74,6 +79,10 @@ func NewServer(addr, domain string) *Server { domain: domain, upstream: []string{ // https://developers.google.com/speed/public-dns/docs/using#google_public_dns_ip_addresses + "1.1.1.1:53", + "1.0.0.1:53", + "2606:4700:4700::1111:53", + "2606:4700:4700::1001:53", "8.8.8.8:53", "8.8.4.4:53", "[2001:4860:4860::8888]:53", @@ -82,7 +91,7 @@ func NewServer(addr, domain string) *Server { sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second hostname: hostname, ip: ip, - subnames: make(map[lcHostname]map[string]net.IP), + subnames: make(map[lcHostname]map[string]DNSIP), } server.prom.registry = prometheus.NewRegistry() @@ -111,8 +120,8 @@ func NewServer(addr, domain string) *Server { server.prom.registry.MustRegister(prometheus.NewGoCollector()) server.initHostsLocked() server.Mux.HandleFunc(".", server.handleRequest) - server.Mux.HandleFunc("lan.", server.handleInternal) - server.Mux.HandleFunc(domain+".", server.handleInternal) + server.Mux.HandleFunc(domain+".", server.subnameHandler(domain)) + server.Mux.HandleFunc("lan.", server.subnameHandler(domain)) server.Mux.HandleFunc("localhost.", server.handleInternal) go func() { for range time.Tick(10 * time.Second) { @@ -125,14 +134,20 @@ func NewServer(addr, domain string) *Server { func (s *Server) initHostsLocked() { s.hostsByName = make(map[lcHostname]string) s.hostsByIP = make(map[string]string) + s.subnames[lcHostname(s.domain)] = make(map[string]DNSIP) if s.hostname != "" && s.ip != "" { lower := strings.ToLower(s.hostname) s.hostsByName[lcHostname(lower)] = s.ip if rev, err := dns.ReverseAddr(s.ip); err == nil { s.hostsByIP[rev] = s.hostname } - s.Mux.HandleFunc(lower+".", s.subnameHandler(s.hostname)) - s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(s.hostname)) + subnames := s.subnames[lcHostname(s.domain)] + ip := net.ParseIP(s.ip) + if ip.To4() != nil { + subnames[lower] = DNSIP{IPv4: ip} + } else { + subnames[lower] = DNSIP{IPv6: ip} + } } } @@ -196,10 +211,12 @@ func (s *Server) hostByIP(n string) (string, bool) { return r, ok } -func (s *Server) subname(hostname, host string) (net.IP, bool) { +func (s *Server) subname(hostname, host string) (DNSIP, bool) { s.mu.Lock() defer s.mu.Unlock() - r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][host] + // // log.Println(s.subnames) + r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][strings.ToLower(host)] + // log.Println("returning", r, ok) return r, ok } @@ -236,10 +253,20 @@ func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) { lower := strings.ToLower(hostname) subnames, ok := s.subnames[lcHostname(lower)] if !ok { - subnames = make(map[string]net.IP) + subnames = make(map[string]DNSIP) s.subnames[lcHostname(lower)] = subnames } - subnames[host] = ip + if ip.To4() != nil { + subnames[host] = DNSIP{ + IPv4: ip, + IPv6: subnames[host].IPv6, + } + } else { + subnames[host] = DNSIP{ + IPv4: subnames[host].IPv4, + IPv6: ip, + } + } w.Write([]byte("ok\n")) } @@ -271,11 +298,27 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) { continue // don’t overwrite e.g. the hostname entry } s.hostsByName[lcHostname(lower)] = l.Addr.String() + + subnames, ok := s.subnames[lcHostname(s.domain)] + if !ok { + subnames = make(map[string]DNSIP) + s.subnames[lcHostname(s.domain)] = subnames + } + if l.Addr.To4() != nil { + subnames[lower] = DNSIP{ + IPv4: l.Addr, + IPv6: subnames[lower].IPv6, + } + } else { + subnames[lower] = DNSIP{ + IPv4: subnames[lower].IPv4, + IPv6: l.Addr, + } + } + if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil { s.hostsByIP[rev] = l.Hostname } - s.Mux.HandleFunc(lower+".", s.subnameHandler(lower)) - s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(lower)) } } @@ -330,10 +373,7 @@ func isLocalInAddrArpa(q string) bool { var errEmpty = errors.New("no answers") -func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) { - if q.Qclass != dns.ClassINET { - return nil, nil - } +func (s *Server) resolveLocal(q dns.Question) (rr dns.RR, err error) { if strings.ToLower(q.Name) == "localhost." { if q.Qtype == dns.TypeAAAA { return dns.NewRR(q.Name + " 3600 IN AAAA ::1") @@ -342,18 +382,6 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) { return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1") } } - if q.Qtype == dns.TypeA || - q.Qtype == dns.TypeAAAA || - q.Qtype == dns.TypeMX { - name := strings.TrimSuffix(q.Name, ".") - name = strings.TrimSuffix(name, "."+s.domain) - if host, ok := s.hostByName(name); ok { - if q.Qtype == dns.TypeA { - return dns.NewRR(q.Name + " 3600 IN A " + host) - } - return nil, errEmpty - } - } if q.Qtype == dns.TypePTR { if host, ok := s.hostByIP(q.Name); ok { return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain) @@ -366,13 +394,11 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) { } func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) { - s.prom.queries.Inc() - s.prom.questions.Observe(float64(len(r.Question))) - s.prom.upstream.WithLabelValues("local").Inc() + s.promInc("local", r) if len(r.Question) != 1 { // TODO: answer all questions we can answer return } - rr, err := s.resolve(r.Question[0]) + rr, err := s.resolveLocal(r.Question[0]) if err != nil { if err == errEmpty { m := new(dns.Msg) @@ -380,7 +406,7 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) return } - log.Fatal(err) + log.Fatalf("question %#v: %v", r.Question[0], err) } if rr != nil { m := new(dns.Msg) @@ -389,7 +415,7 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) return } - // Send an authoritative NXDOMAIN for local names: + // Send an authoritative NXDOMAIN for local: m := new(dns.Msg) m.SetReply(r) m.SetRcode(r, dns.RcodeNameError) @@ -413,9 +439,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { } } - s.prom.queries.Inc() - s.prom.questions.Observe(float64(len(r.Question))) - s.prom.upstream.WithLabelValues("DNS").Inc() + s.promInc("DNS", r) for idx, u := range s.upstreams() { in, _, err := s.client.Exchange(r, u) @@ -437,36 +461,23 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { // DNS has no reply for resolving errors } -func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) { +func (s *Server) resolveSubname(domain string, q dns.Question) (dns.RR, error) { + // log.Println("relolving subname of", domain, q.Name) if q.Qclass != dns.ClassINET { return nil, nil } - if q.Qtype == dns.TypeA || - q.Qtype == dns.TypeAAAA || - q.Qtype == dns.TypeMX { - name := strings.TrimSuffix(q.Name, "."+hostname+".") - name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".") + if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA /*|| q.Qtype == dns.TypeMX*/ { + name := strings.TrimSuffix(q.Name, ".") + name = strings.TrimSuffix(name, "."+domain) + // log.Println("name to search", name) - if lower := strings.ToLower(q.Name); lower == hostname+"." || - lower == hostname+"."+s.domain+"." { - host, ok := s.hostByName(hostname) - if !ok { - // The corresponding DHCP lease might have expired, but this - // handler is still installed on the mux. - return nil, nil // NXDOMAIN - } - if q.Qtype == dns.TypeA { - return dns.NewRR(q.Name + " 3600 IN A " + host) - } - return nil, errEmpty - } + if ip, ok := s.subname(domain, name); ok { - if ip, ok := s.subname(hostname, name); ok { - if q.Qtype == dns.TypeA && ip.To4() != nil { - return dns.NewRR(q.Name + " 3600 IN A " + ip.String()) + if q.Qtype == dns.TypeA && ip.IPv4.To4() != nil { + return dns.NewRR(q.Name + " 3600 IN A " + ip.IPv4.String()) } - if q.Qtype == dns.TypeAAAA && ip.To4() == nil { - return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.String()) + if q.Qtype == dns.TypeAAAA && ip.IPv6.To4() == nil && ip.IPv6 != nil { + return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.IPv6.String()) } return nil, errEmpty } @@ -474,14 +485,24 @@ func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) return nil, nil } +func (s *Server) promInc(label string, r *dns.Msg) { + s.prom.queries.Inc() + s.prom.questions.Observe(float64(len(r.Question))) + s.prom.upstream.WithLabelValues(label).Inc() +} + func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *dns.Msg) { return func(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) != 1 { // TODO: answer all questions we can answer + s.promInc("local", r) return } - rr, err := s.resolveSubname(hostname, r.Question[0]) + + // log.Println("handle subname", hostname, r.Question[0].Name, rr, err) if err != nil { + s.promInc("local", r) if err == errEmpty { m := new(dns.Msg) m.SetReply(r) @@ -491,16 +512,24 @@ func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *d log.Fatalf("question %#v: %v", r.Question[0], err) } if rr != nil { + s.promInc("local", r) m := new(dns.Msg) m.SetReply(r) m.Answer = append(m.Answer, rr) w.WriteMsg(m) return } + // Send an authoritative NXDOMAIN for local names: - m := new(dns.Msg) - m.SetReply(r) - m.SetRcode(r, dns.RcodeNameError) - w.WriteMsg(m) + if r.Question[0].Qtype == dns.TypePTR || !strings.Contains(strings.TrimSuffix(r.Question[0].Name, "."), ".") || strings.HasSuffix(r.Question[0].Name, ".lan.") { + s.promInc("local", r) + m := new(dns.Msg) + m.SetReply(r) + m.SetRcode(r, dns.RcodeNameError) + w.WriteMsg(m) + return + } + + s.handleRequest(w, r) } }