diff --git a/cmd/dnsd/dnsd.go b/cmd/dnsd/dnsd.go index 67daa2c..36f2011 100644 --- a/cmd/dnsd/dnsd.go +++ b/cmd/dnsd/dnsd.go @@ -96,6 +96,7 @@ func logic() error { log.Printf("cannot resolve DHCP hostnames: %v", err) } http.Handle("/metrics", srv.PrometheusHandler()) + http.HandleFunc("/dyndns", srv.DyndnsHandler) if err := updateListeners(srv.Mux); err != nil { return err } diff --git a/internal/dns/dns.go b/internal/dns/dns.go index ee0b13f..c404884 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -17,6 +17,7 @@ package dns import ( "errors" + "fmt" "log" "net" "net/http" @@ -51,6 +52,7 @@ type Server struct { hostname, ip string hostsByName map[string]string hostsByIP map[string]string + subnames map[string]map[string]net.IP // hostname → subname → ip } func NewServer(addr, domain string) *Server { @@ -64,6 +66,7 @@ func NewServer(addr, domain string) *Server { sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second hostname: hostname, ip: ip, + subnames: make(map[string]map[string]net.IP), } server.prom.registry = prometheus.NewRegistry() @@ -105,6 +108,8 @@ func (s *Server) initHostsLocked() { if rev, err := dns.ReverseAddr(s.ip); err == nil { s.hostsByIP[rev] = s.hostname } + s.Mux.HandleFunc(s.hostname+".", s.subnameHandler(s.hostname)) + s.Mux.HandleFunc(s.hostname+"."+s.domain+".", s.subnameHandler(s.hostname)) } } @@ -122,10 +127,52 @@ func (s *Server) hostByIP(n string) (string, bool) { return r, ok } +func (s *Server) subname(hostname, host string) (net.IP, bool) { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.subnames[hostname][host] + return r, ok +} + func (s *Server) PrometheusHandler() http.Handler { return promhttp.HandlerFor(s.prom.registry, promhttp.HandlerOpts{}) } +func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) { + host := r.FormValue("host") + ip := net.ParseIP(r.FormValue("ip")) + if ip == nil { + http.Error(w, "invalid ip", http.StatusBadRequest) + return + } + + s.mu.Lock() + defer s.mu.Unlock() + remote, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + http.Error(w, fmt.Sprintf("net.SplitHostPort(%q): %v", r.RemoteAddr, err), http.StatusBadRequest) + return + } + rev, err := dns.ReverseAddr(remote) + if err != nil { + http.Error(w, fmt.Sprintf("dns.ReverseAddr(%v): %v", remote, err), http.StatusBadRequest) + return + } + hostname, ok := s.hostsByIP[rev] + if !ok { + err := fmt.Sprintf("connection without corresponding DHCP lease: %v", rev) + http.Error(w, err, http.StatusForbidden) + return + } + subnames, ok := s.subnames[hostname] + if !ok { + subnames = make(map[string]net.IP) + s.subnames[hostname] = subnames + } + subnames[host] = ip + w.Write([]byte("ok\n")) +} + func (s *Server) SetLeases(leases []dhcp4d.Lease) { s.mu.Lock() defer s.mu.Unlock() @@ -135,6 +182,9 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) { if l.Expired(now) { continue } + if l.Hostname == "" { + continue + } if _, ok := s.hostsByName[l.Hostname]; ok { continue // don’t overwrite e.g. the hostname entry } @@ -142,6 +192,8 @@ func (s *Server) SetLeases(leases []dhcp4d.Lease) { if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil { s.hostsByIP[rev] = l.Hostname } + s.Mux.HandleFunc(l.Hostname+".", s.subnameHandler(l.Hostname)) + s.Mux.HandleFunc(l.Hostname+"."+s.domain+".", s.subnameHandler(l.Hostname)) } } @@ -284,3 +336,66 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { } w.WriteMsg(in) } + +func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) { + if q.Qclass != dns.ClassINET { + return nil, nil + } + if q.Qtype == dns.TypeA || + q.Qtype == dns.TypeAAAA || + q.Qtype == dns.TypeMX { + name := strings.TrimSuffix(q.Name, "."+hostname+".") + name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".") + + if q.Name == hostname+"." || + q.Name == hostname+"."+s.domain+"." { + host, _ := s.hostByName(hostname) + if q.Qtype == dns.TypeA { + return dns.NewRR(q.Name + " 3600 IN A " + host) + } + return nil, sentinelEmpty + } + + if ip, ok := s.subname(hostname, name); ok { + if q.Qtype == dns.TypeA && ip.To4() != nil { + return dns.NewRR(q.Name + " 3600 IN A " + ip.String()) + } + if q.Qtype == dns.TypeAAAA && ip.To4() == nil { + return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.String()) + } + return nil, sentinelEmpty + } + } + return nil, nil +} + +func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *dns.Msg) { + return func(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) != 1 { // TODO: answer all questions we can answer + return + } + + rr, err := s.resolveSubname(hostname, r.Question[0]) + if err != nil { + if err == sentinelEmpty { + m := new(dns.Msg) + m.SetReply(r) + w.WriteMsg(m) + return + } + log.Fatal(err) + } + if rr != nil { + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, rr) + w.WriteMsg(m) + return + } + // Send an authoritative NXDOMAIN for local names: + m := new(dns.Msg) + m.SetReply(r) + m.SetRcode(r, dns.RcodeNameError) + w.WriteMsg(m) + } +} diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index 809fdef..ed51f71 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -16,8 +16,13 @@ package dns import ( "bytes" + "io/ioutil" "net" + "net/http" + "net/http/httptest" + "net/url" "os" + "strings" "testing" "time" @@ -339,3 +344,97 @@ func TestDHCPReverse(t *testing.T) { } // TODO: multiple questions + +func TestSubname(t *testing.T) { + r := &recorder{} + s := NewServer("127.0.0.2:0", "lan") + s.SetLeases([]dhcp4d.Lease{ + { + Hostname: "testtarget", + Addr: net.IP{192, 168, 42, 23}, + }, + }) + + resolveTestTarget := func(t *testing.T, name string, want net.IP) { + m := new(dns.Msg) + typ := dns.TypeA + if want.To4() == nil { + typ = dns.TypeAAAA + } + m.SetQuestion(name, typ) + s.Mux.ServeDNS(r, m) + if r.response == nil { + t.Fatalf("nil response") + } + if got, want := len(r.response.Answer), 1; got != want { + t.Fatalf("unexpected number of answers: got %d, want %d", got, want) + } + a := r.response.Answer[0] + if typ == dns.TypeA { + if _, ok := a.(*dns.A); !ok { + t.Fatalf("unexpected response type: got %T, want dns.A", a) + } + if got := a.(*dns.A).A; !got.Equal(want) { + t.Fatalf("unexpected response IP: got %v, want %v", got, want) + } + } else { + if _, ok := a.(*dns.AAAA); !ok { + t.Fatalf("unexpected response type: got %T, want dns.A", a) + } + if got := a.(*dns.AAAA).AAAA; !got.Equal(want) { + t.Fatalf("unexpected response IP: got %v, want %v", got, want) + } + } + } + + t.Run("testtarget.lan.", func(t *testing.T) { + resolveTestTarget(t, "testtarget.lan.", net.ParseIP("192.168.42.23")) + }) + + t.Run("sub.testtarget.lan.", func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("notfound.lan.", dns.TypeA) + s.Mux.ServeDNS(r, m) + if got, want := r.response.Rcode, dns.RcodeNameError; got != want { + t.Fatalf("unexpected rcode: got %v, want %v", got, want) + } + }) + + setSubname := func(ip, remoteAddr string) { + val := url.Values{ + "host": []string{"sub"}, + "ip": []string{ip}, + } + req := httptest.NewRequest("POST", "/dyndns", strings.NewReader(val.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.RemoteAddr = remoteAddr + rec := httptest.NewRecorder() + s.DyndnsHandler(rec, req) + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusOK; got != want { + body, _ := ioutil.ReadAll(resp.Body) + t.Fatalf("POST /dyndns: unexpected HTTP status: got %v, want %v (%q)", resp.Status, want, string(body)) + } + } + const ip = "fdf5:3606:2a21:1341:b26e:bfff:fe30:504b" + setSubname(ip, "192.168.42.23:1234") + + for _, name := range []string{ + "sub.testtarget.lan.", + "sub.testtarget.", + } { + t.Run(name+" (after dyndns)", func(t *testing.T) { + resolveTestTarget(t, name, net.ParseIP(ip)) + }) + } + + t.Run("Hostname", func(t *testing.T) { + hostname, err := os.Hostname() + if err != nil { + t.Skipf("os.Hostname: %v", err) + } + resolveTestTarget(t, hostname+".lan.", net.ParseIP("127.0.0.2")) + setSubname(ip, "127.0.0.2:1234") + resolveTestTarget(t, "sub."+hostname+".lan.", net.ParseIP(ip)) + }) +}