dns: simplify resolving code
This commit is contained in:
parent
8e95e25442
commit
89e1276ad4
@ -28,7 +28,7 @@ var (
|
|||||||
dnsListeners = multilisten.NewPool()
|
dnsListeners = multilisten.NewPool()
|
||||||
)
|
)
|
||||||
|
|
||||||
func updateListeners() error {
|
func updateListeners(mux *miekgdns.ServeMux) error {
|
||||||
hosts, err := gokrazy.PrivateInterfaceAddrs()
|
hosts, err := gokrazy.PrivateInterfaceAddrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -39,7 +39,11 @@ func updateListeners() error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
|
dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
|
||||||
return &listenerAdapter{&miekgdns.Server{Addr: net.JoinHostPort(host, "53"), Net: "udp"}}
|
return &listenerAdapter{&miekgdns.Server{
|
||||||
|
Addr: net.JoinHostPort(host, "53"),
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
}}
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -73,11 +77,11 @@ func logic() error {
|
|||||||
log.Printf("cannot resolve DHCP hostnames: %v", err)
|
log.Printf("cannot resolve DHCP hostnames: %v", err)
|
||||||
}
|
}
|
||||||
http.Handle("/metrics", srv.PrometheusHandler())
|
http.Handle("/metrics", srv.PrometheusHandler())
|
||||||
updateListeners()
|
updateListeners(srv.Mux)
|
||||||
ch := make(chan os.Signal, 1)
|
ch := make(chan os.Signal, 1)
|
||||||
signal.Notify(ch, syscall.SIGUSR1)
|
signal.Notify(ch, syscall.SIGUSR1)
|
||||||
for range ch {
|
for range ch {
|
||||||
if err := updateListeners(); err != nil {
|
if err := updateListeners(srv.Mux); err != nil {
|
||||||
log.Printf("updateListeners: %v", err)
|
log.Printf("updateListeners: %v", err)
|
||||||
}
|
}
|
||||||
if err := readLeases(); err != nil {
|
if err := readLeases(); err != nil {
|
||||||
|
@ -13,8 +13,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestDNS(t *testing.T) {
|
func TestDNS(t *testing.T) {
|
||||||
dns.NewServer("localhost:4453", "lan")
|
srv := dns.NewServer("localhost:4453", "lan")
|
||||||
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp"}
|
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp", Handler: srv.Mux}
|
||||||
go s.ListenAndServe()
|
go s.ListenAndServe()
|
||||||
const port = 4453
|
const port = 4453
|
||||||
dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1")
|
dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1")
|
||||||
|
@ -19,6 +19,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
|
Mux *dns.ServeMux
|
||||||
|
|
||||||
client *dns.Client
|
client *dns.Client
|
||||||
domain string
|
domain string
|
||||||
upstream string
|
upstream string
|
||||||
@ -40,6 +42,7 @@ func NewServer(addr, domain string) *Server {
|
|||||||
hostname, _ := os.Hostname()
|
hostname, _ := os.Hostname()
|
||||||
ip, _, _ := net.SplitHostPort(addr)
|
ip, _, _ := net.SplitHostPort(addr)
|
||||||
server := &Server{
|
server := &Server{
|
||||||
|
Mux: dns.NewServeMux(),
|
||||||
client: &dns.Client{},
|
client: &dns.Client{},
|
||||||
domain: domain,
|
domain: domain,
|
||||||
upstream: "8.8.8.8:53",
|
upstream: "8.8.8.8:53",
|
||||||
@ -73,7 +76,9 @@ func NewServer(addr, domain string) *Server {
|
|||||||
|
|
||||||
server.prom.registry.MustRegister(prometheus.NewGoCollector())
|
server.prom.registry.MustRegister(prometheus.NewGoCollector())
|
||||||
server.initHostsLocked()
|
server.initHostsLocked()
|
||||||
dns.HandleFunc(".", server.handleRequest)
|
server.Mux.HandleFunc(".", server.handleRequest)
|
||||||
|
server.Mux.HandleFunc("lan.", server.handleInternal)
|
||||||
|
server.Mux.HandleFunc("localhost.", server.handleInternal)
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,83 +172,48 @@ func isLocalInAddrArpa(q string) bool {
|
|||||||
return local
|
return local
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
|
func (s *Server) resolve(q dns.Question) (dns.RR, error) {
|
||||||
// TODO: add test for non-A records on internal names, they should not go upstream
|
if q.Qclass != dns.ClassINET {
|
||||||
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
return nil, nil
|
||||||
s.prom.queries.Inc()
|
}
|
||||||
s.prom.questions.Observe(float64(len(r.Question)))
|
|
||||||
if len(r.Question) == 1 { // TODO: answer all questions we can answer
|
|
||||||
q := r.Question[0]
|
|
||||||
if q.Qtype == dns.TypeAAAA && q.Qclass == dns.ClassINET {
|
|
||||||
if q.Name == "localhost." {
|
if q.Name == "localhost." {
|
||||||
s.prom.upstream.WithLabelValues("local").Inc()
|
if q.Qtype == dns.TypeAAAA {
|
||||||
rr, err := dns.NewRR(q.Name + " 3600 IN AAAA ::1")
|
return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
}
|
||||||
m := new(dns.Msg)
|
if q.Qtype == dns.TypeA {
|
||||||
m.SetReply(r)
|
return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
|
||||||
m.Answer = append(m.Answer, rr)
|
|
||||||
w.WriteMsg(m)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET {
|
if q.Qtype == dns.TypeA {
|
||||||
name := strings.TrimSuffix(q.Name, ".")
|
name := strings.TrimSuffix(q.Name, ".")
|
||||||
name = strings.TrimSuffix(name, "."+s.domain)
|
name = strings.TrimSuffix(name, "."+s.domain)
|
||||||
|
|
||||||
if q.Name == "localhost." {
|
|
||||||
s.prom.upstream.WithLabelValues("local").Inc()
|
|
||||||
rr, err := dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, rr)
|
|
||||||
w.WriteMsg(m)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !strings.Contains(name, ".") {
|
|
||||||
s.prom.upstream.WithLabelValues("local").Inc()
|
|
||||||
if host, ok := s.hostByName(name); ok {
|
if host, ok := s.hostByName(name); ok {
|
||||||
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
|
return dns.NewRR(q.Name + " 3600 IN A " + host)
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, rr)
|
|
||||||
w.WriteMsg(m)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Send an authoritative NXDOMAIN for local names:
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.SetRcode(r, dns.RcodeNameError)
|
|
||||||
w.WriteMsg(m)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
|
if q.Qtype == dns.TypePTR {
|
||||||
if isLocalInAddrArpa(q.Name) {
|
|
||||||
s.prom.upstream.WithLabelValues("local").Inc()
|
|
||||||
if host, ok := s.hostByIP(q.Name); ok {
|
if host, ok := s.hostByIP(q.Name); ok {
|
||||||
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
|
return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Answer = append(m.Answer, rr)
|
|
||||||
w.WriteMsg(m)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
|
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
|
||||||
rr, err := dns.NewRR(q.Name + " 3600 IN PTR localhost.")
|
return dns.NewRR(q.Name + " 3600 IN PTR localhost.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
s.prom.queries.Inc()
|
||||||
|
s.prom.questions.Observe(float64(len(r.Question)))
|
||||||
|
s.prom.upstream.WithLabelValues("local").Inc()
|
||||||
|
if len(r.Question) != 1 { // TODO: answer all questions we can answer
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rr, err := s.resolve(r.Question[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if rr != nil {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.Answer = append(m.Answer, rr)
|
m.Answer = append(m.Answer, rr)
|
||||||
@ -255,10 +225,20 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.SetRcode(r, dns.RcodeNameError)
|
m.SetRcode(r, dns.RcodeNameError)
|
||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 1 { // TODO: answer all questions we can answer
|
||||||
|
q := r.Question[0]
|
||||||
|
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET && isLocalInAddrArpa(q.Name) {
|
||||||
|
s.handleInternal(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
s.prom.queries.Inc()
|
||||||
|
s.prom.questions.Observe(float64(len(r.Question)))
|
||||||
|
s.prom.upstream.WithLabelValues("DNS").Inc()
|
||||||
|
|
||||||
in, _, err := s.client.Exchange(r, s.upstream)
|
in, _, err := s.client.Exchange(r, s.upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -268,5 +248,4 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return // DNS has no reply for resolving errors
|
return // DNS has no reply for resolving errors
|
||||||
}
|
}
|
||||||
w.WriteMsg(in)
|
w.WriteMsg(in)
|
||||||
s.prom.upstream.WithLabelValues("DNS").Inc()
|
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ func TestNXDOMAIN(t *testing.T) {
|
|||||||
s := NewServer("localhost:0", "lan")
|
s := NewServer("localhost:0", "lan")
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("foo.invalid.", dns.TypeA)
|
m.SetQuestion("foo.invalid.", dns.TypeA)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := r.response.MsgHdr.Rcode, dns.RcodeNameError; got != want {
|
if got, want := r.response.MsgHdr.Rcode, dns.RcodeNameError; got != want {
|
||||||
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
@ -45,7 +45,7 @@ func TestResolveError(t *testing.T) {
|
|||||||
s.upstream = "266.266.266.266:53"
|
s.upstream = "266.266.266.266:53"
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("foo.invalid.", dns.TypeA)
|
m.SetQuestion("foo.invalid.", dns.TypeA)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if r.response != nil {
|
if r.response != nil {
|
||||||
t.Fatalf("r.response unexpectedly not nil: %v", r.response)
|
t.Fatalf("r.response unexpectedly not nil: %v", r.response)
|
||||||
}
|
}
|
||||||
@ -64,7 +64,7 @@ func TestDHCP(t *testing.T) {
|
|||||||
t.Run("xps.lan.", func(t *testing.T) {
|
t.Run("xps.lan.", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("xps.lan.", dns.TypeA)
|
m.SetQuestion("xps.lan.", dns.TypeA)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
@ -80,7 +80,7 @@ func TestDHCP(t *testing.T) {
|
|||||||
t.Run("notfound.lan.", func(t *testing.T) {
|
t.Run("notfound.lan.", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("notfound.lan.", dns.TypeA)
|
m.SetQuestion("notfound.lan.", dns.TypeA)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
|
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
|
||||||
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
@ -99,7 +99,7 @@ func TestHostname(t *testing.T) {
|
|||||||
t.Run("A", func(t *testing.T) {
|
t.Run("A", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion(hostname+".lan.", dns.TypeA)
|
m.SetQuestion(hostname+".lan.", dns.TypeA)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||||
}
|
}
|
||||||
@ -115,7 +115,7 @@ func TestHostname(t *testing.T) {
|
|||||||
t.Run("PTR", func(t *testing.T) {
|
t.Run("PTR", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypePTR)
|
m.SetQuestion("2.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
@ -136,7 +136,7 @@ func TestLocalhost(t *testing.T) {
|
|||||||
t.Run("A", func(t *testing.T) {
|
t.Run("A", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("localhost.", dns.TypeA)
|
m.SetQuestion("localhost.", dns.TypeA)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||||
}
|
}
|
||||||
@ -152,7 +152,7 @@ func TestLocalhost(t *testing.T) {
|
|||||||
t.Run("AAAA", func(t *testing.T) {
|
t.Run("AAAA", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("localhost.", dns.TypeAAAA)
|
m.SetQuestion("localhost.", dns.TypeAAAA)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||||
}
|
}
|
||||||
@ -168,7 +168,7 @@ func TestLocalhost(t *testing.T) {
|
|||||||
t.Run("PTR", func(t *testing.T) {
|
t.Run("PTR", func(t *testing.T) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
m.SetQuestion("1.0.0.127.in-addr.arpa.", dns.TypePTR)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
@ -218,7 +218,7 @@ func TestDHCPReverse(t *testing.T) {
|
|||||||
})
|
})
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion(test.question, dns.TypePTR)
|
m.SetQuestion(test.question, dns.TypePTR)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := len(r.response.Answer), 1; got != want {
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
@ -237,7 +237,7 @@ func TestDHCPReverse(t *testing.T) {
|
|||||||
s := NewServer("localhost:0", "lan")
|
s := NewServer("localhost:0", "lan")
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion("254.255.31.172.in-addr.arpa.", dns.TypePTR)
|
m.SetQuestion("254.255.31.172.in-addr.arpa.", dns.TypePTR)
|
||||||
s.handleRequest(r, m)
|
s.Mux.ServeDNS(r, m)
|
||||||
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
|
if got, want := r.response.Rcode, dns.RcodeNameError; got != want {
|
||||||
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
t.Fatalf("unexpected rcode: got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user