dns: simplify resolving code
This commit is contained in:
parent
8e95e25442
commit
89e1276ad4
@ -28,7 +28,7 @@ var (
|
||||
dnsListeners = multilisten.NewPool()
|
||||
)
|
||||
|
||||
func updateListeners() error {
|
||||
func updateListeners(mux *miekgdns.ServeMux) error {
|
||||
hosts, err := gokrazy.PrivateInterfaceAddrs()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -39,7 +39,11 @@ func updateListeners() error {
|
||||
})
|
||||
|
||||
dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
|
||||
return &listenerAdapter{&miekgdns.Server{Addr: net.JoinHostPort(host, "53"), Net: "udp"}}
|
||||
return &listenerAdapter{&miekgdns.Server{
|
||||
Addr: net.JoinHostPort(host, "53"),
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
}}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@ -73,11 +77,11 @@ func logic() error {
|
||||
log.Printf("cannot resolve DHCP hostnames: %v", err)
|
||||
}
|
||||
http.Handle("/metrics", srv.PrometheusHandler())
|
||||
updateListeners()
|
||||
updateListeners(srv.Mux)
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
for range ch {
|
||||
if err := updateListeners(); err != nil {
|
||||
if err := updateListeners(srv.Mux); err != nil {
|
||||
log.Printf("updateListeners: %v", err)
|
||||
}
|
||||
if err := readLeases(); err != nil {
|
||||
|
@ -13,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
func TestDNS(t *testing.T) {
|
||||
dns.NewServer("localhost:4453", "lan")
|
||||
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp"}
|
||||
srv := dns.NewServer("localhost:4453", "lan")
|
||||
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp", Handler: srv.Mux}
|
||||
go s.ListenAndServe()
|
||||
const port = 4453
|
||||
dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1")
|
||||
|
@ -19,6 +19,8 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Mux *dns.ServeMux
|
||||
|
||||
client *dns.Client
|
||||
domain string
|
||||
upstream string
|
||||
@ -40,6 +42,7 @@ func NewServer(addr, domain string) *Server {
|
||||
hostname, _ := os.Hostname()
|
||||
ip, _, _ := net.SplitHostPort(addr)
|
||||
server := &Server{
|
||||
Mux: dns.NewServeMux(),
|
||||
client: &dns.Client{},
|
||||
domain: domain,
|
||||
upstream: "8.8.8.8:53",
|
||||
@ -73,7 +76,9 @@ func NewServer(addr, domain string) *Server {
|
||||
|
||||
server.prom.registry.MustRegister(prometheus.NewGoCollector())
|
||||
server.initHostsLocked()
|
||||
dns.HandleFunc(".", server.handleRequest)
|
||||
server.Mux.HandleFunc(".", server.handleRequest)
|
||||
server.Mux.HandleFunc("lan.", server.handleInternal)
|
||||
server.Mux.HandleFunc("localhost.", server.handleInternal)
|
||||
return server
|
||||
}
|
||||
|
||||
@ -167,98 +172,73 @@ func isLocalInAddrArpa(q string) bool {
|
||||
return local
|
||||
}
|
||||
|
||||
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
|
||||
// TODO: add test for non-A records on internal names, they should not go upstream
|
||||
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
s.prom.queries.Inc()
|
||||
s.prom.questions.Observe(float64(len(r.Question)))
|
||||
if len(r.Question) == 1 { // TODO: answer all questions we can answer
|
||||
q := r.Question[0]
|
||||
if q.Qtype == dns.TypeAAAA && q.Qclass == dns.ClassINET {
|
||||
if q.Name == "localhost." {
|
||||
s.prom.upstream.WithLabelValues("local").Inc()
|
||||
rr, err := dns.NewRR(q.Name + " 3600 IN AAAA ::1")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
func (s *Server) resolve(q dns.Question) (dns.RR, error) {
|
||||
if q.Qclass != dns.ClassINET {
|
||||
return nil, nil
|
||||
}
|
||||
if q.Name == "localhost." {
|
||||
if q.Qtype == dns.TypeAAAA {
|
||||
return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
|
||||
}
|
||||
if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET {
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
name = strings.TrimSuffix(name, "."+s.domain)
|
||||
|
||||
if q.Name == "localhost." {
|
||||
s.prom.upstream.WithLabelValues("local").Inc()
|
||||
rr, err := dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
if !strings.Contains(name, ".") {
|
||||
s.prom.upstream.WithLabelValues("local").Inc()
|
||||
if host, ok := s.hostByName(name); ok {
|
||||
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
// Send an authoritative NXDOMAIN for local names:
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.SetRcode(r, dns.RcodeNameError)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
}
|
||||
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
|
||||
if isLocalInAddrArpa(q.Name) {
|
||||
s.prom.upstream.WithLabelValues("local").Inc()
|
||||
if host, ok := s.hostByIP(q.Name); ok {
|
||||
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
|
||||
rr, err := dns.NewRR(q.Name + " 3600 IN PTR localhost.")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
// Send an authoritative NXDOMAIN for local names:
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.SetRcode(r, dns.RcodeNameError)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
if q.Qtype == dns.TypeA {
|
||||
return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
|
||||
}
|
||||
}
|
||||
if q.Qtype == dns.TypeA {
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
name = strings.TrimSuffix(name, "."+s.domain)
|
||||
if host, ok := s.hostByName(name); ok {
|
||||
return dns.NewRR(q.Name + " 3600 IN A " + host)
|
||||
}
|
||||
}
|
||||
if q.Qtype == dns.TypePTR {
|
||||
if host, ok := s.hostByIP(q.Name); ok {
|
||||
return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
|
||||
}
|
||||
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
|
||||
return dns.NewRR(q.Name + " 3600 IN PTR localhost.")
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
|
||||
s.prom.queries.Inc()
|
||||
s.prom.questions.Observe(float64(len(r.Question)))
|
||||
s.prom.upstream.WithLabelValues("local").Inc()
|
||||
if len(r.Question) != 1 { // TODO: answer all questions we can answer
|
||||
return
|
||||
}
|
||||
rr, err := s.resolve(r.Question[0])
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if rr != nil {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
// Send an authoritative NXDOMAIN for local names:
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.SetRcode(r, dns.RcodeNameError)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 1 { // TODO: answer all questions we can answer
|
||||
q := r.Question[0]
|
||||
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET && isLocalInAddrArpa(q.Name) {
|
||||
s.handleInternal(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.prom.queries.Inc()
|
||||
s.prom.questions.Observe(float64(len(r.Question)))
|
||||
s.prom.upstream.WithLabelValues("DNS").Inc()
|
||||
|
||||
in, _, err := s.client.Exchange(r, s.upstream)
|
||||
if err != nil {
|
||||
@ -268,5 +248,4 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return // DNS has no reply for resolving errors
|
||||
}
|
||||
w.WriteMsg(in)
|
||||
s.prom.upstream.WithLabelValues("DNS").Inc()
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func TestNXDOMAIN(t *testing.T) {
|
||||
s := NewServer("localhost:0", "lan")
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("foo.invalid.", dns.TypeA)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := r.response.MsgHdr.Rcode, dns.RcodeNameError; got != want {
|
||||
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
||||
}
|
||||
@ -45,7 +45,7 @@ func TestResolveError(t *testing.T) {
|
||||
s.upstream = "266.266.266.266:53"
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("foo.invalid.", dns.TypeA)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if r.response != nil {
|
||||
t.Fatalf("r.response unexpectedly not nil: %v", r.response)
|
||||
}
|
||||
@ -64,7 +64,7 @@ func TestDHCP(t *testing.T) {
|
||||
t.Run("xps.lan.", func(t *testing.T) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("xps.lan.", dns.TypeA)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := len(r.response.Answer), 1; got != want {
|
||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||
}
|
||||
@ -80,7 +80,7 @@ func TestDHCP(t *testing.T) {
|
||||
t.Run("notfound.lan.", func(t *testing.T) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("notfound.lan.", dns.TypeA)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
|
||||
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
||||
}
|
||||
@ -99,7 +99,7 @@ func TestHostname(t *testing.T) {
|
||||
t.Run("A", func(t *testing.T) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(hostname+".lan.", dns.TypeA)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := len(r.response.Answer), 1; got != want {
|
||||
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||
}
|
||||
@ -115,7 +115,7 @@ func TestHostname(t *testing.T) {
|
||||
t.Run("PTR", func(t *testing.T) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := len(r.response.Answer), 1; got != want {
|
||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||
}
|
||||
@ -136,7 +136,7 @@ func TestLocalhost(t *testing.T) {
|
||||
t.Run("A", func(t *testing.T) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("localhost.", dns.TypeA)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := len(r.response.Answer), 1; got != want {
|
||||
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||
}
|
||||
@ -152,7 +152,7 @@ func TestLocalhost(t *testing.T) {
|
||||
t.Run("AAAA", func(t *testing.T) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("localhost.", dns.TypeAAAA)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := len(r.response.Answer), 1; got != want {
|
||||
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||
}
|
||||
@ -168,7 +168,7 @@ func TestLocalhost(t *testing.T) {
|
||||
t.Run("PTR", func(t *testing.T) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := len(r.response.Answer), 1; got != want {
|
||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||
}
|
||||
@ -218,7 +218,7 @@ func TestDHCPReverse(t *testing.T) {
|
||||
})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(test.question, dns.TypePTR)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := len(r.response.Answer), 1; got != want {
|
||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||
}
|
||||
@ -237,7 +237,7 @@ func TestDHCPReverse(t *testing.T) {
|
||||
s := NewServer("localhost:0", "lan")
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("254.255.31.172.in-addr.arpa.", dns.TypePTR)
|
||||
s.handleRequest(r, m)
|
||||
s.Mux.ServeDNS(r, m)
|
||||
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
|
||||
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user