make local name resolution case-insensitive

fixes #34
This commit is contained in:
Michael Stapelberg 2019-07-20 12:07:30 +02:00
parent 975f05d7ac
commit 36995097b9
2 changed files with 47 additions and 31 deletions

View File

@ -38,6 +38,10 @@ import (
var log = teelogger.NewConsole() var log = teelogger.NewConsole()
// lcHostname is a string type used for lower-cased hostnames so that the
// DHCP-based local name resolution can be made case-insensitive.
type lcHostname string
type Server struct { type Server struct {
Mux *dns.ServeMux Mux *dns.ServeMux
@ -53,9 +57,9 @@ type Server struct {
mu sync.Mutex mu sync.Mutex
hostname, ip string hostname, ip string
hostsByName map[string]string hostsByName map[lcHostname]string
hostsByIP map[string]string hostsByIP map[string]string
subnames map[string]map[string]net.IP // hostname → subname → ip subnames map[lcHostname]map[string]net.IP // hostname → subname → ip
upstreamMu sync.RWMutex upstreamMu sync.RWMutex
upstream []string upstream []string
@ -78,7 +82,7 @@ func NewServer(addr, domain string) *Server {
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostname: hostname, hostname: hostname,
ip: ip, ip: ip,
subnames: make(map[string]map[string]net.IP), subnames: make(map[lcHostname]map[string]net.IP),
} }
server.prom.registry = prometheus.NewRegistry() server.prom.registry = prometheus.NewRegistry()
@ -118,15 +122,16 @@ func NewServer(addr, domain string) *Server {
} }
func (s *Server) initHostsLocked() { func (s *Server) initHostsLocked() {
s.hostsByName = make(map[string]string) s.hostsByName = make(map[lcHostname]string)
s.hostsByIP = make(map[string]string) s.hostsByIP = make(map[string]string)
if s.hostname != "" && s.ip != "" { if s.hostname != "" && s.ip != "" {
s.hostsByName[s.hostname] = s.ip lower := strings.ToLower(s.hostname)
s.hostsByName[lcHostname(lower)] = s.ip
if rev, err := dns.ReverseAddr(s.ip); err == nil { if rev, err := dns.ReverseAddr(s.ip); err == nil {
s.hostsByIP[rev] = s.hostname s.hostsByIP[rev] = s.hostname
} }
s.Mux.HandleFunc(s.hostname+".", s.subnameHandler(s.hostname)) s.Mux.HandleFunc(lower+".", s.subnameHandler(s.hostname))
s.Mux.HandleFunc(s.hostname+"."+s.domain+".", s.subnameHandler(s.hostname)) s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(s.hostname))
} }
} }
@ -179,7 +184,7 @@ func (s *Server) probeUpstreamLatency() {
func (s *Server) hostByName(n string) (string, bool) { func (s *Server) hostByName(n string) (string, bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
r, ok := s.hostsByName[n] r, ok := s.hostsByName[lcHostname(strings.ToLower(n))]
return r, ok return r, ok
} }
@ -193,7 +198,7 @@ func (s *Server) hostByIP(n string) (string, bool) {
func (s *Server) subname(hostname, host string) (net.IP, bool) { func (s *Server) subname(hostname, host string) (net.IP, bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
r, ok := s.subnames[hostname][host] r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][host]
return r, ok return r, ok
} }
@ -227,10 +232,11 @@ func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, err, http.StatusForbidden) http.Error(w, err, http.StatusForbidden)
return return
} }
subnames, ok := s.subnames[hostname] lower := strings.ToLower(hostname)
subnames, ok := s.subnames[lcHostname(lower)]
if !ok { if !ok {
subnames = make(map[string]net.IP) subnames = make(map[string]net.IP)
s.subnames[hostname] = subnames s.subnames[lcHostname(lower)] = subnames
} }
subnames[host] = ip subnames[host] = ip
w.Write([]byte("ok\n")) w.Write([]byte("ok\n"))
@ -248,15 +254,16 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) {
if l.Hostname == "" { if l.Hostname == "" {
continue continue
} }
if _, ok := s.hostsByName[l.Hostname]; ok { lower := strings.ToLower(l.Hostname)
if _, ok := s.hostsByName[lcHostname(lower)]; ok {
continue // dont overwrite e.g. the hostname entry continue // dont overwrite e.g. the hostname entry
} }
s.hostsByName[l.Hostname] = l.Addr.String() s.hostsByName[lcHostname(lower)] = l.Addr.String()
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil { if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
s.hostsByIP[rev] = l.Hostname s.hostsByIP[rev] = l.Hostname
} }
s.Mux.HandleFunc(l.Hostname+".", s.subnameHandler(l.Hostname)) s.Mux.HandleFunc(lower+".", s.subnameHandler(l.Hostname))
s.Mux.HandleFunc(l.Hostname+"."+s.domain+".", s.subnameHandler(l.Hostname)) s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(l.Hostname))
} }
} }
@ -315,7 +322,7 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
if q.Qclass != dns.ClassINET { if q.Qclass != dns.ClassINET {
return nil, nil return nil, nil
} }
if q.Name == "localhost." { if strings.ToLower(q.Name) == "localhost." {
if q.Qtype == dns.TypeAAAA { if q.Qtype == dns.TypeAAAA {
return dns.NewRR(q.Name + " 3600 IN AAAA ::1") return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
} }
@ -428,8 +435,8 @@ func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error)
name := strings.TrimSuffix(q.Name, "."+hostname+".") name := strings.TrimSuffix(q.Name, "."+hostname+".")
name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".") name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".")
if q.Name == hostname+"." || if lower := strings.ToLower(q.Name); lower == hostname+"." ||
q.Name == hostname+"."+s.domain+"." { lower == hostname+"."+s.domain+"." {
host, _ := s.hostByName(hostname) host, _ := s.hostByName(hostname)
if q.Qtype == dns.TypeA { if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A " + host) return dns.NewRR(q.Name + " 3600 IN A " + host)

View File

@ -215,15 +215,22 @@ func TestHostname(t *testing.T) {
s := NewServer("127.0.0.2:0", "lan") s := NewServer("127.0.0.2:0", "lan")
s.SetLeases([]dhcp4d.Lease{ s.SetLeases([]dhcp4d.Lease{
{ {
Hostname: hostname, Hostname: strings.ToUpper(hostname),
Addr: net.IP{192, 168, 42, 23}, Addr: net.IP{192, 168, 42, 23},
}, },
}) })
t.Run("A", func(t *testing.T) { t.Run("A", func(t *testing.T) {
for _, hostname := range []string{
hostname,
strings.ToUpper(hostname),
} {
t.Run(hostname, func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(hostname+".lan.", dns.TypeA) m.SetQuestion(hostname+".lan.", dns.TypeA)
log.Printf("before ServeDNS")
s.Mux.ServeDNS(r, m) s.Mux.ServeDNS(r, m)
log.Printf("after ServeDNS")
if got, want := len(r.response.Answer), 1; got != want { if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
} }
@ -235,6 +242,8 @@ func TestHostname(t *testing.T) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want) t.Fatalf("unexpected response IP: got %v, want %v", got, want)
} }
}) })
}
})
t.Run("PTR", func(t *testing.T) { t.Run("PTR", func(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)