From 36995097b9de20e0dfbb41e1274735a4d4e1a2fe Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Sat, 20 Jul 2019 12:07:30 +0200 Subject: [PATCH] make local name resolution case-insensitive fixes #34 --- internal/dns/dns.go | 43 +++++++++++++++++++++++----------------- internal/dns/dns_test.go | 35 ++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/internal/dns/dns.go b/internal/dns/dns.go index f6921a4..78bbd09 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -38,6 +38,10 @@ import ( var log = teelogger.NewConsole() +// lcHostname is a string type used for lower-cased hostnames so that the +// DHCP-based local name resolution can be made case-insensitive. +type lcHostname string + type Server struct { Mux *dns.ServeMux @@ -53,9 +57,9 @@ type Server struct { mu sync.Mutex hostname, ip string - hostsByName map[string]string + hostsByName map[lcHostname]string hostsByIP map[string]string - subnames map[string]map[string]net.IP // hostname → subname → ip + subnames map[lcHostname]map[string]net.IP // hostname → subname → ip upstreamMu sync.RWMutex upstream []string @@ -78,7 +82,7 @@ func NewServer(addr, domain string) *Server { sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second hostname: hostname, ip: ip, - subnames: make(map[string]map[string]net.IP), + subnames: make(map[lcHostname]map[string]net.IP), } server.prom.registry = prometheus.NewRegistry() @@ -118,15 +122,16 @@ func NewServer(addr, domain string) *Server { } func (s *Server) initHostsLocked() { - s.hostsByName = make(map[string]string) + s.hostsByName = make(map[lcHostname]string) s.hostsByIP = make(map[string]string) if s.hostname != "" && s.ip != "" { - s.hostsByName[s.hostname] = s.ip + lower := strings.ToLower(s.hostname) + s.hostsByName[lcHostname(lower)] = s.ip if rev, err := dns.ReverseAddr(s.ip); err == nil { s.hostsByIP[rev] = s.hostname } - s.Mux.HandleFunc(s.hostname+".", s.subnameHandler(s.hostname)) - s.Mux.HandleFunc(s.hostname+"."+s.domain+".", s.subnameHandler(s.hostname)) + s.Mux.HandleFunc(lower+".", s.subnameHandler(s.hostname)) + s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(s.hostname)) } } @@ -179,7 +184,7 @@ func (s *Server) probeUpstreamLatency() { func (s *Server) hostByName(n string) (string, bool) { s.mu.Lock() defer s.mu.Unlock() - r, ok := s.hostsByName[n] + r, ok := s.hostsByName[lcHostname(strings.ToLower(n))] return r, ok } @@ -193,7 +198,7 @@ func (s *Server) hostByIP(n string) (string, bool) { func (s *Server) subname(hostname, host string) (net.IP, bool) { s.mu.Lock() defer s.mu.Unlock() - r, ok := s.subnames[hostname][host] + r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][host] return r, ok } @@ -227,10 +232,11 @@ func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, err, http.StatusForbidden) return } - subnames, ok := s.subnames[hostname] + lower := strings.ToLower(hostname) + subnames, ok := s.subnames[lcHostname(lower)] if !ok { subnames = make(map[string]net.IP) - s.subnames[hostname] = subnames + s.subnames[lcHostname(lower)] = subnames } subnames[host] = ip w.Write([]byte("ok\n")) @@ -248,15 +254,16 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) { if l.Hostname == "" { continue } - if _, ok := s.hostsByName[l.Hostname]; ok { + lower := strings.ToLower(l.Hostname) + if _, ok := s.hostsByName[lcHostname(lower)]; ok { continue // don’t overwrite e.g. the hostname entry } - s.hostsByName[l.Hostname] = l.Addr.String() + s.hostsByName[lcHostname(lower)] = l.Addr.String() if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil { s.hostsByIP[rev] = l.Hostname } - s.Mux.HandleFunc(l.Hostname+".", s.subnameHandler(l.Hostname)) - s.Mux.HandleFunc(l.Hostname+"."+s.domain+".", s.subnameHandler(l.Hostname)) + s.Mux.HandleFunc(lower+".", s.subnameHandler(l.Hostname)) + s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(l.Hostname)) } } @@ -315,7 +322,7 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) { if q.Qclass != dns.ClassINET { return nil, nil } - if q.Name == "localhost." { + if strings.ToLower(q.Name) == "localhost." { if q.Qtype == dns.TypeAAAA { return dns.NewRR(q.Name + " 3600 IN AAAA ::1") } @@ -428,8 +435,8 @@ func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) name := strings.TrimSuffix(q.Name, "."+hostname+".") name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".") - if q.Name == hostname+"." || - q.Name == hostname+"."+s.domain+"." { + if lower := strings.ToLower(q.Name); lower == hostname+"." || + lower == hostname+"."+s.domain+"." { host, _ := s.hostByName(hostname) if q.Qtype == dns.TypeA { return dns.NewRR(q.Name + " 3600 IN A " + host) diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index 62ea4dd..e1ecd95 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -215,24 +215,33 @@ func TestHostname(t *testing.T) { s := NewServer("127.0.0.2:0", "lan") s.SetLeases([]dhcp4d.Lease{ { - Hostname: hostname, + Hostname: strings.ToUpper(hostname), Addr: net.IP{192, 168, 42, 23}, }, }) t.Run("A", func(t *testing.T) { - m := new(dns.Msg) - m.SetQuestion(hostname+".lan.", dns.TypeA) - s.Mux.ServeDNS(r, m) - if got, want := len(r.response.Answer), 1; got != want { - t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) - } - a := r.response.Answer[0] - if _, ok := a.(*dns.A); !ok { - t.Fatalf("unexpected response type: got %T, want dns.A", a) - } - if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.2"); !got.Equal(want) { - t.Fatalf("unexpected response IP: got %v, want %v", got, want) + for _, hostname := range []string{ + hostname, + strings.ToUpper(hostname), + } { + t.Run(hostname, func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(hostname+".lan.", dns.TypeA) + log.Printf("before ServeDNS") + s.Mux.ServeDNS(r, m) + log.Printf("after ServeDNS") + if got, want := len(r.response.Answer), 1; got != want { + t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) + } + a := r.response.Answer[0] + if _, ok := a.(*dns.A); !ok { + t.Fatalf("unexpected response type: got %T, want dns.A", a) + } + if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.2"); !got.Equal(want) { + t.Fatalf("unexpected response IP: got %v, want %v", got, want) + } + }) } })