diff --git a/internal/dns/dns.go b/internal/dns/dns.go index f8dd456..05e1a16 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -3,7 +3,9 @@ package dns import ( "log" "net" + "os" "strings" + "sync" "time" "router7/internal/dhcp4d" @@ -15,29 +17,61 @@ import ( type Server struct { *dns.Server - client *dns.Client - domain string - upstream string - sometimes *rate.Limiter - hostsByName map[string]string - hostsByIP map[string]string + client *dns.Client + domain string + upstream string + sometimes *rate.Limiter + + mu sync.Mutex + hostname, ip string + hostsByName map[string]string + hostsByIP map[string]string } func NewServer(addr, domain string) *Server { + hostname, _ := os.Hostname() + ip, _, _ := net.SplitHostPort(addr) server := &Server{ - Server: &dns.Server{Addr: addr, Net: "udp"}, - client: &dns.Client{}, - domain: domain, - upstream: "8.8.8.8:53", - sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second - hostsByName: make(map[string]string), - hostsByIP: make(map[string]string), + Server: &dns.Server{Addr: addr, Net: "udp"}, + client: &dns.Client{}, + domain: domain, + upstream: "8.8.8.8:53", + sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second + hostname: hostname, + ip: ip, } + server.initHostsLocked() dns.HandleFunc(".", server.handleRequest) return server } +func (s *Server) initHostsLocked() { + s.hostsByName = make(map[string]string) + s.hostsByIP = make(map[string]string) + if s.hostname != "" && s.ip != "" { + s.hostsByName[s.hostname] = s.ip + s.hostsByIP[s.ip] = s.hostname + } +} + +func (s *Server) hostByName(n string) (string, bool) { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.hostsByName[n] + return r, ok +} + +func (s *Server) hostByIP(n string) (string, bool) { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.hostsByIP[n] + return r, ok +} + func (s *Server) SetLeases(leases []dhcp4d.Lease) { + s.mu.Lock() + defer s.mu.Unlock() + s.initHostsLocked() for _, l := range leases { s.hostsByName[l.Hostname] = l.Addr.String() if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil { @@ -90,7 +124,6 @@ func isLocalInAddrArpa(q string) bool { return local } -// TODO: is handleRequest called in more than one goroutine at a time? // TODO: require search domains to be present, then use HandleFunc("lan.", internalName) func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 1 { // TODO: answer all questions we can answer @@ -100,7 +133,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { name = strings.TrimSuffix(name, "."+s.domain) if !strings.Contains(name, ".") { - if host, ok := s.hostsByName[name]; ok { + if host, ok := s.hostByName(name); ok { rr, err := dns.NewRR(q.Name + " 3600 IN A " + host) if err != nil { log.Fatal(err) @@ -115,7 +148,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { } if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET { if isLocalInAddrArpa(q.Name) { - if host, ok := s.hostsByIP[q.Name]; ok { + if host, ok := s.hostByIP(q.Name); ok { rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain) if err != nil { log.Fatal(err) diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index 63d2356..36546aa 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -3,6 +3,7 @@ package dns import ( "bytes" "net" + "os" "router7/internal/dhcp4d" "testing" @@ -74,6 +75,29 @@ func TestDHCP(t *testing.T) { } } +func TestHostname(t *testing.T) { + hostname, err := os.Hostname() + if err != nil { + t.Skipf("os.Hostname: %v", err) + } + + r := &recorder{} + s := NewServer("127.0.0.2:0", "lan") + m := new(dns.Msg) + m.SetQuestion(hostname+".", dns.TypeA) + s.handleRequest(r, m) + if got, want := len(r.response.Answer), 1; got != want { + t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want) + } + a := r.response.Answer[0] + if _, ok := a.(*dns.A); !ok { + t.Fatalf("unexpected response type: got %T, want dns.A", a) + } + if got, want := a.(*dns.A).A.To4(), (net.IP{127, 0, 0, 2}); !bytes.Equal(got, want) { + t.Fatalf("unexpected response IP: got %v, want %v", got, want) + } +} + func TestDHCPReverse(t *testing.T) { for _, test := range []struct { ip net.IP