parent
975f05d7ac
commit
36995097b9
@ -38,6 +38,10 @@ import (
|
|||||||
|
|
||||||
var log = teelogger.NewConsole()
|
var log = teelogger.NewConsole()
|
||||||
|
|
||||||
|
// lcHostname is a string type used for lower-cased hostnames so that the
|
||||||
|
// DHCP-based local name resolution can be made case-insensitive.
|
||||||
|
type lcHostname string
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Mux *dns.ServeMux
|
Mux *dns.ServeMux
|
||||||
|
|
||||||
@ -53,9 +57,9 @@ type Server struct {
|
|||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
hostname, ip string
|
hostname, ip string
|
||||||
hostsByName map[string]string
|
hostsByName map[lcHostname]string
|
||||||
hostsByIP map[string]string
|
hostsByIP map[string]string
|
||||||
subnames map[string]map[string]net.IP // hostname → subname → ip
|
subnames map[lcHostname]map[string]net.IP // hostname → subname → ip
|
||||||
|
|
||||||
upstreamMu sync.RWMutex
|
upstreamMu sync.RWMutex
|
||||||
upstream []string
|
upstream []string
|
||||||
@ -78,7 +82,7 @@ func NewServer(addr, domain string) *Server {
|
|||||||
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
|
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
|
||||||
hostname: hostname,
|
hostname: hostname,
|
||||||
ip: ip,
|
ip: ip,
|
||||||
subnames: make(map[string]map[string]net.IP),
|
subnames: make(map[lcHostname]map[string]net.IP),
|
||||||
}
|
}
|
||||||
server.prom.registry = prometheus.NewRegistry()
|
server.prom.registry = prometheus.NewRegistry()
|
||||||
|
|
||||||
@ -118,15 +122,16 @@ func NewServer(addr, domain string) *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) initHostsLocked() {
|
func (s *Server) initHostsLocked() {
|
||||||
s.hostsByName = make(map[string]string)
|
s.hostsByName = make(map[lcHostname]string)
|
||||||
s.hostsByIP = make(map[string]string)
|
s.hostsByIP = make(map[string]string)
|
||||||
if s.hostname != "" && s.ip != "" {
|
if s.hostname != "" && s.ip != "" {
|
||||||
s.hostsByName[s.hostname] = s.ip
|
lower := strings.ToLower(s.hostname)
|
||||||
|
s.hostsByName[lcHostname(lower)] = s.ip
|
||||||
if rev, err := dns.ReverseAddr(s.ip); err == nil {
|
if rev, err := dns.ReverseAddr(s.ip); err == nil {
|
||||||
s.hostsByIP[rev] = s.hostname
|
s.hostsByIP[rev] = s.hostname
|
||||||
}
|
}
|
||||||
s.Mux.HandleFunc(s.hostname+".", s.subnameHandler(s.hostname))
|
s.Mux.HandleFunc(lower+".", s.subnameHandler(s.hostname))
|
||||||
s.Mux.HandleFunc(s.hostname+"."+s.domain+".", s.subnameHandler(s.hostname))
|
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(s.hostname))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,7 +184,7 @@ func (s *Server) probeUpstreamLatency() {
|
|||||||
func (s *Server) hostByName(n string) (string, bool) {
|
func (s *Server) hostByName(n string) (string, bool) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
r, ok := s.hostsByName[n]
|
r, ok := s.hostsByName[lcHostname(strings.ToLower(n))]
|
||||||
return r, ok
|
return r, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,7 +198,7 @@ func (s *Server) hostByIP(n string) (string, bool) {
|
|||||||
func (s *Server) subname(hostname, host string) (net.IP, bool) {
|
func (s *Server) subname(hostname, host string) (net.IP, bool) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
r, ok := s.subnames[hostname][host]
|
r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][host]
|
||||||
return r, ok
|
return r, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,10 +232,11 @@ func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, err, http.StatusForbidden)
|
http.Error(w, err, http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
subnames, ok := s.subnames[hostname]
|
lower := strings.ToLower(hostname)
|
||||||
|
subnames, ok := s.subnames[lcHostname(lower)]
|
||||||
if !ok {
|
if !ok {
|
||||||
subnames = make(map[string]net.IP)
|
subnames = make(map[string]net.IP)
|
||||||
s.subnames[hostname] = subnames
|
s.subnames[lcHostname(lower)] = subnames
|
||||||
}
|
}
|
||||||
subnames[host] = ip
|
subnames[host] = ip
|
||||||
w.Write([]byte("ok\n"))
|
w.Write([]byte("ok\n"))
|
||||||
@ -248,15 +254,16 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) {
|
|||||||
if l.Hostname == "" {
|
if l.Hostname == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, ok := s.hostsByName[l.Hostname]; ok {
|
lower := strings.ToLower(l.Hostname)
|
||||||
|
if _, ok := s.hostsByName[lcHostname(lower)]; ok {
|
||||||
continue // don’t overwrite e.g. the hostname entry
|
continue // don’t overwrite e.g. the hostname entry
|
||||||
}
|
}
|
||||||
s.hostsByName[l.Hostname] = l.Addr.String()
|
s.hostsByName[lcHostname(lower)] = l.Addr.String()
|
||||||
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
|
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
|
||||||
s.hostsByIP[rev] = l.Hostname
|
s.hostsByIP[rev] = l.Hostname
|
||||||
}
|
}
|
||||||
s.Mux.HandleFunc(l.Hostname+".", s.subnameHandler(l.Hostname))
|
s.Mux.HandleFunc(lower+".", s.subnameHandler(l.Hostname))
|
||||||
s.Mux.HandleFunc(l.Hostname+"."+s.domain+".", s.subnameHandler(l.Hostname))
|
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(l.Hostname))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -315,7 +322,7 @@ func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
|
|||||||
if q.Qclass != dns.ClassINET {
|
if q.Qclass != dns.ClassINET {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if q.Name == "localhost." {
|
if strings.ToLower(q.Name) == "localhost." {
|
||||||
if q.Qtype == dns.TypeAAAA {
|
if q.Qtype == dns.TypeAAAA {
|
||||||
return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
|
return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
|
||||||
}
|
}
|
||||||
@ -428,8 +435,8 @@ func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error)
|
|||||||
name := strings.TrimSuffix(q.Name, "."+hostname+".")
|
name := strings.TrimSuffix(q.Name, "."+hostname+".")
|
||||||
name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".")
|
name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".")
|
||||||
|
|
||||||
if q.Name == hostname+"." ||
|
if lower := strings.ToLower(q.Name); lower == hostname+"." ||
|
||||||
q.Name == hostname+"."+s.domain+"." {
|
lower == hostname+"."+s.domain+"." {
|
||||||
host, _ := s.hostByName(hostname)
|
host, _ := s.hostByName(hostname)
|
||||||
if q.Qtype == dns.TypeA {
|
if q.Qtype == dns.TypeA {
|
||||||
return dns.NewRR(q.Name + " 3600 IN A " + host)
|
return dns.NewRR(q.Name + " 3600 IN A " + host)
|
||||||
|
@ -215,24 +215,33 @@ func TestHostname(t *testing.T) {
|
|||||||
s := NewServer("127.0.0.2:0", "lan")
|
s := NewServer("127.0.0.2:0", "lan")
|
||||||
s.SetLeases([]dhcp4d.Lease{
|
s.SetLeases([]dhcp4d.Lease{
|
||||||
{
|
{
|
||||||
Hostname: hostname,
|
Hostname: strings.ToUpper(hostname),
|
||||||
Addr: net.IP{192, 168, 42, 23},
|
Addr: net.IP{192, 168, 42, 23},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("A", func(t *testing.T) {
|
t.Run("A", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
for _, hostname := range []string{
|
||||||
m.SetQuestion(hostname+".lan.", dns.TypeA)
|
hostname,
|
||||||
s.Mux.ServeDNS(r, m)
|
strings.ToUpper(hostname),
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
} {
|
||||||
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
t.Run(hostname, func(t *testing.T) {
|
||||||
}
|
m := new(dns.Msg)
|
||||||
a := r.response.Answer[0]
|
m.SetQuestion(hostname+".lan.", dns.TypeA)
|
||||||
if _, ok := a.(*dns.A); !ok {
|
log.Printf("before ServeDNS")
|
||||||
t.Fatalf("unexpected response type: got %T, want dns.A", a)
|
s.Mux.ServeDNS(r, m)
|
||||||
}
|
log.Printf("after ServeDNS")
|
||||||
if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.2"); !got.Equal(want) {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
|
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||||
|
}
|
||||||
|
a := r.response.Answer[0]
|
||||||
|
if _, ok := a.(*dns.A); !ok {
|
||||||
|
t.Fatalf("unexpected response type: got %T, want dns.A", a)
|
||||||
|
}
|
||||||
|
if got, want := a.(*dns.A).A, net.ParseIP("127.0.0.2"); !got.Equal(want) {
|
||||||
|
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user