DNS changes

This commit is contained in:
lordwelch 2020-06-14 10:56:54 -07:00
parent 8ba14148d7
commit a5420430ab

View File

@ -42,6 +42,11 @@ var log = teelogger.NewConsole()
// DHCP-based local name resolution can be made case-insensitive. // DHCP-based local name resolution can be made case-insensitive.
type lcHostname string type lcHostname string
type DNSIP struct {
IPv6 net.IP
IPv4 net.IP
}
type Server struct { type Server struct {
Mux *dns.ServeMux Mux *dns.ServeMux
@ -59,7 +64,7 @@ type Server struct {
hostname, ip string hostname, ip string
hostsByName map[lcHostname]string hostsByName map[lcHostname]string
hostsByIP map[string]string hostsByIP map[string]string
subnames map[lcHostname]map[string]net.IP // hostname → subname → ip subnames map[lcHostname]map[string]DNSIP // hostname → subname → ip
upstreamMu sync.RWMutex upstreamMu sync.RWMutex
upstream []string upstream []string
@ -74,6 +79,10 @@ func NewServer(addr, domain string) *Server {
domain: domain, domain: domain,
upstream: []string{ upstream: []string{
// https://developers.google.com/speed/public-dns/docs/using#google_public_dns_ip_addresses // https://developers.google.com/speed/public-dns/docs/using#google_public_dns_ip_addresses
"1.1.1.1:53",
"1.0.0.1:53",
"2606:4700:4700::1111:53",
"2606:4700:4700::1001:53",
"8.8.8.8:53", "8.8.8.8:53",
"8.8.4.4:53", "8.8.4.4:53",
"[2001:4860:4860::8888]:53", "[2001:4860:4860::8888]:53",
@ -82,7 +91,7 @@ func NewServer(addr, domain string) *Server {
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostname: hostname, hostname: hostname,
ip: ip, ip: ip,
subnames: make(map[lcHostname]map[string]net.IP), subnames: make(map[lcHostname]map[string]DNSIP),
} }
server.prom.registry = prometheus.NewRegistry() server.prom.registry = prometheus.NewRegistry()
@ -111,8 +120,8 @@ func NewServer(addr, domain string) *Server {
server.prom.registry.MustRegister(prometheus.NewGoCollector()) server.prom.registry.MustRegister(prometheus.NewGoCollector())
server.initHostsLocked() server.initHostsLocked()
server.Mux.HandleFunc(".", server.handleRequest) server.Mux.HandleFunc(".", server.handleRequest)
server.Mux.HandleFunc("lan.", server.handleInternal) server.Mux.HandleFunc(domain+".", server.subnameHandler(domain))
server.Mux.HandleFunc(domain+".", server.handleInternal) server.Mux.HandleFunc("lan.", server.subnameHandler(domain))
server.Mux.HandleFunc("localhost.", server.handleInternal) server.Mux.HandleFunc("localhost.", server.handleInternal)
go func() { go func() {
for range time.Tick(10 * time.Second) { for range time.Tick(10 * time.Second) {
@ -125,14 +134,20 @@ func NewServer(addr, domain string) *Server {
func (s *Server) initHostsLocked() { func (s *Server) initHostsLocked() {
s.hostsByName = make(map[lcHostname]string) s.hostsByName = make(map[lcHostname]string)
s.hostsByIP = make(map[string]string) s.hostsByIP = make(map[string]string)
s.subnames[lcHostname(s.domain)] = make(map[string]DNSIP)
if s.hostname != "" && s.ip != "" { if s.hostname != "" && s.ip != "" {
lower := strings.ToLower(s.hostname) lower := strings.ToLower(s.hostname)
s.hostsByName[lcHostname(lower)] = s.ip s.hostsByName[lcHostname(lower)] = s.ip
if rev, err := dns.ReverseAddr(s.ip); err == nil { if rev, err := dns.ReverseAddr(s.ip); err == nil {
s.hostsByIP[rev] = s.hostname s.hostsByIP[rev] = s.hostname
} }
s.Mux.HandleFunc(lower+".", s.subnameHandler(s.hostname)) subnames := s.subnames[lcHostname(s.domain)]
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(s.hostname)) ip := net.ParseIP(s.ip)
if ip.To4() != nil {
subnames[lower] = DNSIP{IPv4: ip}
} else {
subnames[lower] = DNSIP{IPv6: ip}
}
} }
} }
@ -196,10 +211,12 @@ func (s *Server) hostByIP(n string) (string, bool) {
return r, ok return r, ok
} }
func (s *Server) subname(hostname, host string) (net.IP, bool) { func (s *Server) subname(hostname, host string) (DNSIP, bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][host] // // log.Println(s.subnames)
r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][strings.ToLower(host)]
// log.Println("returning", r, ok)
return r, ok return r, ok
} }
@ -236,10 +253,20 @@ func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) {
lower := strings.ToLower(hostname) lower := strings.ToLower(hostname)
subnames, ok := s.subnames[lcHostname(lower)] subnames, ok := s.subnames[lcHostname(lower)]
if !ok { if !ok {
subnames = make(map[string]net.IP) subnames = make(map[string]DNSIP)
s.subnames[lcHostname(lower)] = subnames s.subnames[lcHostname(lower)] = subnames
} }
subnames[host] = ip if ip.To4() != nil {
subnames[host] = DNSIP{
IPv4: ip,
IPv6: subnames[host].IPv6,
}
} else {
subnames[host] = DNSIP{
IPv4: subnames[host].IPv4,
IPv6: ip,
}
}
w.Write([]byte("ok\n")) w.Write([]byte("ok\n"))
} }
@ -271,11 +298,27 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) {
continue // dont overwrite e.g. the hostname entry continue // dont overwrite e.g. the hostname entry
} }
s.hostsByName[lcHostname(lower)] = l.Addr.String() s.hostsByName[lcHostname(lower)] = l.Addr.String()
subnames, ok := s.subnames[lcHostname(s.domain)]
if !ok {
subnames = make(map[string]DNSIP)
s.subnames[lcHostname(s.domain)] = subnames
}
if l.Addr.To4() != nil {
subnames[lower] = DNSIP{
IPv4: l.Addr,
IPv6: subnames[lower].IPv6,
}
} else {
subnames[lower] = DNSIP{
IPv4: subnames[lower].IPv4,
IPv6: l.Addr,
}
}
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil { if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
s.hostsByIP[rev] = l.Hostname s.hostsByIP[rev] = l.Hostname
} }
s.Mux.HandleFunc(lower+".", s.subnameHandler(lower))
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(lower))
} }
} }
@ -330,10 +373,7 @@ func isLocalInAddrArpa(q string) bool {
var errEmpty = errors.New("no answers") var errEmpty = errors.New("no answers")
func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) { func (s *Server) resolveLocal(q dns.Question) (rr dns.RR, err error) {
if q.Qclass != dns.ClassINET {
return nil, nil
}
if strings.ToLower(q.Name) == "localhost." { if strings.ToLower(q.Name) == "localhost." {
if q.Qtype == dns.TypeAAAA { if q.Qtype == dns.TypeAAAA {
return dns.NewRR(q.Name + " 3600 IN AAAA ::1") return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
@ -342,18 +382,6 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1") return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
} }
} }
if q.Qtype == dns.TypeA ||
q.Qtype == dns.TypeAAAA ||
q.Qtype == dns.TypeMX {
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain)
if host, ok := s.hostByName(name); ok {
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
return nil, errEmpty
}
}
if q.Qtype == dns.TypePTR { if q.Qtype == dns.TypePTR {
if host, ok := s.hostByIP(q.Name); ok { if host, ok := s.hostByIP(q.Name); ok {
return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain) return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
@ -366,13 +394,11 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
} }
func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) { func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc() s.promInc("local", r)
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("local").Inc()
if len(r.Question) != 1 { // TODO: answer all questions we can answer if len(r.Question) != 1 { // TODO: answer all questions we can answer
return return
} }
rr, err := s.resolve(r.Question[0]) rr, err := s.resolveLocal(r.Question[0])
if err != nil { if err != nil {
if err == errEmpty { if err == errEmpty {
m := new(dns.Msg) m := new(dns.Msg)
@ -380,7 +406,7 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
return return
} }
log.Fatal(err) log.Fatalf("question %#v: %v", r.Question[0], err)
} }
if rr != nil { if rr != nil {
m := new(dns.Msg) m := new(dns.Msg)
@ -389,7 +415,7 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
return return
} }
// Send an authoritative NXDOMAIN for local names: // Send an authoritative NXDOMAIN for local:
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError) m.SetRcode(r, dns.RcodeNameError)
@ -413,9 +439,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
s.prom.queries.Inc() s.promInc("DNS", r)
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("DNS").Inc()
for idx, u := range s.upstreams() { for idx, u := range s.upstreams() {
in, _, err := s.client.Exchange(r, u) in, _, err := s.client.Exchange(r, u)
@ -437,36 +461,23 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
// DNS has no reply for resolving errors // DNS has no reply for resolving errors
} }
func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) { func (s *Server) resolveSubname(domain string, q dns.Question) (dns.RR, error) {
// log.Println("relolving subname of", domain, q.Name)
if q.Qclass != dns.ClassINET { if q.Qclass != dns.ClassINET {
return nil, nil return nil, nil
} }
if q.Qtype == dns.TypeA || if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA /*|| q.Qtype == dns.TypeMX*/ {
q.Qtype == dns.TypeAAAA || name := strings.TrimSuffix(q.Name, ".")
q.Qtype == dns.TypeMX { name = strings.TrimSuffix(name, "."+domain)
name := strings.TrimSuffix(q.Name, "."+hostname+".") // log.Println("name to search", name)
name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".")
if lower := strings.ToLower(q.Name); lower == hostname+"." || if ip, ok := s.subname(domain, name); ok {
lower == hostname+"."+s.domain+"." {
host, ok := s.hostByName(hostname)
if !ok {
// The corresponding DHCP lease might have expired, but this
// handler is still installed on the mux.
return nil, nil // NXDOMAIN
}
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
return nil, errEmpty
}
if ip, ok := s.subname(hostname, name); ok { if q.Qtype == dns.TypeA && ip.IPv4.To4() != nil {
if q.Qtype == dns.TypeA && ip.To4() != nil { return dns.NewRR(q.Name + " 3600 IN A " + ip.IPv4.String())
return dns.NewRR(q.Name + " 3600 IN A " + ip.String())
} }
if q.Qtype == dns.TypeAAAA && ip.To4() == nil { if q.Qtype == dns.TypeAAAA && ip.IPv6.To4() == nil && ip.IPv6 != nil {
return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.String()) return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.IPv6.String())
} }
return nil, errEmpty return nil, errEmpty
} }
@ -474,14 +485,24 @@ func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error)
return nil, nil return nil, nil
} }
func (s *Server) promInc(label string, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues(label).Inc()
}
func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *dns.Msg) { func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *dns.Msg) {
return func(w dns.ResponseWriter, r *dns.Msg) { return func(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) != 1 { // TODO: answer all questions we can answer if len(r.Question) != 1 { // TODO: answer all questions we can answer
s.promInc("local", r)
return return
} }
rr, err := s.resolveSubname(hostname, r.Question[0]) rr, err := s.resolveSubname(hostname, r.Question[0])
// log.Println("handle subname", hostname, r.Question[0].Name, rr, err)
if err != nil { if err != nil {
s.promInc("local", r)
if err == errEmpty { if err == errEmpty {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
@ -491,16 +512,24 @@ func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *d
log.Fatalf("question %#v: %v", r.Question[0], err) log.Fatalf("question %#v: %v", r.Question[0], err)
} }
if rr != nil { if rr != nil {
s.promInc("local", r)
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Answer = append(m.Answer, rr) m.Answer = append(m.Answer, rr)
w.WriteMsg(m) w.WriteMsg(m)
return return
} }
// Send an authoritative NXDOMAIN for local names: // Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg) if r.Question[0].Qtype == dns.TypePTR || !strings.Contains(strings.TrimSuffix(r.Question[0].Name, "."), ".") || strings.HasSuffix(r.Question[0].Name, ".lan.") {
m.SetReply(r) s.promInc("local", r)
m.SetRcode(r, dns.RcodeNameError) m := new(dns.Msg)
w.WriteMsg(m) m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
}
s.handleRequest(w, r)
} }
} }