DNS changes

This commit is contained in:
lordwelch 2020-06-14 10:56:54 -07:00
parent 8ba14148d7
commit a5420430ab

View File

@ -42,6 +42,11 @@ var log = teelogger.NewConsole()
// DHCP-based local name resolution can be made case-insensitive.
type lcHostname string
type DNSIP struct {
IPv6 net.IP
IPv4 net.IP
}
type Server struct {
Mux *dns.ServeMux
@ -59,7 +64,7 @@ type Server struct {
hostname, ip string
hostsByName map[lcHostname]string
hostsByIP map[string]string
subnames map[lcHostname]map[string]net.IP // hostname → subname → ip
subnames map[lcHostname]map[string]DNSIP // hostname → subname → ip
upstreamMu sync.RWMutex
upstream []string
@ -74,6 +79,10 @@ func NewServer(addr, domain string) *Server {
domain: domain,
upstream: []string{
// https://developers.google.com/speed/public-dns/docs/using#google_public_dns_ip_addresses
"1.1.1.1:53",
"1.0.0.1:53",
"2606:4700:4700::1111:53",
"2606:4700:4700::1001:53",
"8.8.8.8:53",
"8.8.4.4:53",
"[2001:4860:4860::8888]:53",
@ -82,7 +91,7 @@ func NewServer(addr, domain string) *Server {
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostname: hostname,
ip: ip,
subnames: make(map[lcHostname]map[string]net.IP),
subnames: make(map[lcHostname]map[string]DNSIP),
}
server.prom.registry = prometheus.NewRegistry()
@ -111,8 +120,8 @@ func NewServer(addr, domain string) *Server {
server.prom.registry.MustRegister(prometheus.NewGoCollector())
server.initHostsLocked()
server.Mux.HandleFunc(".", server.handleRequest)
server.Mux.HandleFunc("lan.", server.handleInternal)
server.Mux.HandleFunc(domain+".", server.handleInternal)
server.Mux.HandleFunc(domain+".", server.subnameHandler(domain))
server.Mux.HandleFunc("lan.", server.subnameHandler(domain))
server.Mux.HandleFunc("localhost.", server.handleInternal)
go func() {
for range time.Tick(10 * time.Second) {
@ -125,14 +134,20 @@ func NewServer(addr, domain string) *Server {
func (s *Server) initHostsLocked() {
s.hostsByName = make(map[lcHostname]string)
s.hostsByIP = make(map[string]string)
s.subnames[lcHostname(s.domain)] = make(map[string]DNSIP)
if s.hostname != "" && s.ip != "" {
lower := strings.ToLower(s.hostname)
s.hostsByName[lcHostname(lower)] = s.ip
if rev, err := dns.ReverseAddr(s.ip); err == nil {
s.hostsByIP[rev] = s.hostname
}
s.Mux.HandleFunc(lower+".", s.subnameHandler(s.hostname))
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(s.hostname))
subnames := s.subnames[lcHostname(s.domain)]
ip := net.ParseIP(s.ip)
if ip.To4() != nil {
subnames[lower] = DNSIP{IPv4: ip}
} else {
subnames[lower] = DNSIP{IPv6: ip}
}
}
}
@ -196,10 +211,12 @@ func (s *Server) hostByIP(n string) (string, bool) {
return r, ok
}
func (s *Server) subname(hostname, host string) (net.IP, bool) {
func (s *Server) subname(hostname, host string) (DNSIP, bool) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][host]
// // log.Println(s.subnames)
r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][strings.ToLower(host)]
// log.Println("returning", r, ok)
return r, ok
}
@ -236,10 +253,20 @@ func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) {
lower := strings.ToLower(hostname)
subnames, ok := s.subnames[lcHostname(lower)]
if !ok {
subnames = make(map[string]net.IP)
subnames = make(map[string]DNSIP)
s.subnames[lcHostname(lower)] = subnames
}
subnames[host] = ip
if ip.To4() != nil {
subnames[host] = DNSIP{
IPv4: ip,
IPv6: subnames[host].IPv6,
}
} else {
subnames[host] = DNSIP{
IPv4: subnames[host].IPv4,
IPv6: ip,
}
}
w.Write([]byte("ok\n"))
}
@ -271,11 +298,27 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) {
continue // dont overwrite e.g. the hostname entry
}
s.hostsByName[lcHostname(lower)] = l.Addr.String()
subnames, ok := s.subnames[lcHostname(s.domain)]
if !ok {
subnames = make(map[string]DNSIP)
s.subnames[lcHostname(s.domain)] = subnames
}
if l.Addr.To4() != nil {
subnames[lower] = DNSIP{
IPv4: l.Addr,
IPv6: subnames[lower].IPv6,
}
} else {
subnames[lower] = DNSIP{
IPv4: subnames[lower].IPv4,
IPv6: l.Addr,
}
}
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
s.hostsByIP[rev] = l.Hostname
}
s.Mux.HandleFunc(lower+".", s.subnameHandler(lower))
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(lower))
}
}
@ -330,10 +373,7 @@ func isLocalInAddrArpa(q string) bool {
var errEmpty = errors.New("no answers")
func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
if q.Qclass != dns.ClassINET {
return nil, nil
}
func (s *Server) resolveLocal(q dns.Question) (rr dns.RR, err error) {
if strings.ToLower(q.Name) == "localhost." {
if q.Qtype == dns.TypeAAAA {
return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
@ -342,18 +382,6 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
}
}
if q.Qtype == dns.TypeA ||
q.Qtype == dns.TypeAAAA ||
q.Qtype == dns.TypeMX {
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain)
if host, ok := s.hostByName(name); ok {
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
return nil, errEmpty
}
}
if q.Qtype == dns.TypePTR {
if host, ok := s.hostByIP(q.Name); ok {
return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
@ -366,13 +394,11 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
}
func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("local").Inc()
s.promInc("local", r)
if len(r.Question) != 1 { // TODO: answer all questions we can answer
return
}
rr, err := s.resolve(r.Question[0])
rr, err := s.resolveLocal(r.Question[0])
if err != nil {
if err == errEmpty {
m := new(dns.Msg)
@ -380,7 +406,7 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
return
}
log.Fatal(err)
log.Fatalf("question %#v: %v", r.Question[0], err)
}
if rr != nil {
m := new(dns.Msg)
@ -389,7 +415,7 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
// Send an authoritative NXDOMAIN for local:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
@ -413,9 +439,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
}
}
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("DNS").Inc()
s.promInc("DNS", r)
for idx, u := range s.upstreams() {
in, _, err := s.client.Exchange(r, u)
@ -437,36 +461,23 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
// DNS has no reply for resolving errors
}
func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) {
func (s *Server) resolveSubname(domain string, q dns.Question) (dns.RR, error) {
// log.Println("relolving subname of", domain, q.Name)
if q.Qclass != dns.ClassINET {
return nil, nil
}
if q.Qtype == dns.TypeA ||
q.Qtype == dns.TypeAAAA ||
q.Qtype == dns.TypeMX {
name := strings.TrimSuffix(q.Name, "."+hostname+".")
name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".")
if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA /*|| q.Qtype == dns.TypeMX*/ {
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+domain)
// log.Println("name to search", name)
if lower := strings.ToLower(q.Name); lower == hostname+"." ||
lower == hostname+"."+s.domain+"." {
host, ok := s.hostByName(hostname)
if !ok {
// The corresponding DHCP lease might have expired, but this
// handler is still installed on the mux.
return nil, nil // NXDOMAIN
}
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
return nil, errEmpty
}
if ip, ok := s.subname(domain, name); ok {
if ip, ok := s.subname(hostname, name); ok {
if q.Qtype == dns.TypeA && ip.To4() != nil {
return dns.NewRR(q.Name + " 3600 IN A " + ip.String())
if q.Qtype == dns.TypeA && ip.IPv4.To4() != nil {
return dns.NewRR(q.Name + " 3600 IN A " + ip.IPv4.String())
}
if q.Qtype == dns.TypeAAAA && ip.To4() == nil {
return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.String())
if q.Qtype == dns.TypeAAAA && ip.IPv6.To4() == nil && ip.IPv6 != nil {
return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.IPv6.String())
}
return nil, errEmpty
}
@ -474,14 +485,24 @@ func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error)
return nil, nil
}
func (s *Server) promInc(label string, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues(label).Inc()
}
func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *dns.Msg) {
return func(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) != 1 { // TODO: answer all questions we can answer
s.promInc("local", r)
return
}
rr, err := s.resolveSubname(hostname, r.Question[0])
// log.Println("handle subname", hostname, r.Question[0].Name, rr, err)
if err != nil {
s.promInc("local", r)
if err == errEmpty {
m := new(dns.Msg)
m.SetReply(r)
@ -491,16 +512,24 @@ func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *d
log.Fatalf("question %#v: %v", r.Question[0], err)
}
if rr != nil {
s.promInc("local", r)
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
if r.Question[0].Qtype == dns.TypePTR || !strings.Contains(strings.TrimSuffix(r.Question[0].Name, "."), ".") || strings.HasSuffix(r.Question[0].Name, ".lan.") {
s.promInc("local", r)
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
}
s.handleRequest(w, r)
}
}