diff --git a/internal/dns/dns.go b/internal/dns/dns.go index c404884..2d63a03 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -39,7 +39,7 @@ type Server struct { client *dns.Client domain string - upstream string + upstream []string sometimes *rate.Limiter prom struct { registry *prometheus.Registry @@ -59,10 +59,16 @@ func NewServer(addr, domain string) *Server { hostname, _ := os.Hostname() ip, _, _ := net.SplitHostPort(addr) server := &Server{ - Mux: dns.NewServeMux(), - client: &dns.Client{}, - domain: domain, - upstream: "8.8.8.8:53", + Mux: dns.NewServeMux(), + client: &dns.Client{}, + domain: domain, + upstream: []string{ + // https://developers.google.com/speed/public-dns/docs/using#google_public_dns_ip_addresses + "8.8.8.8:53", + "8.8.4.4:53", + "[2001:4860:4860::8888]:53", + "[2001:4860:4860::8844]:53", + }, sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second hostname: hostname, ip: ip, @@ -327,14 +333,18 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { s.prom.questions.Observe(float64(len(r.Question))) s.prom.upstream.WithLabelValues("DNS").Inc() - in, _, err := s.client.Exchange(r, s.upstream) - if err != nil { - if s.sometimes.Allow() { - log.Printf("resolving %v failed: %v", r.Question, err) + for _, u := range s.upstream { + in, _, err := s.client.Exchange(r, u) + if err != nil { + if s.sometimes.Allow() { + log.Printf("resolving %v failed: %v", r.Question, err) + } + continue // fall back to next-slower upstream } - return // DNS has no reply for resolving errors + w.WriteMsg(in) + break } - w.WriteMsg(in) + // DNS has no reply for resolving errors } func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) { diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index ed51f71..fc9031c 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -16,6 +16,7 @@ package dns import ( "bytes" + "fmt" "io/ioutil" "net" "net/http" @@ -63,7 +64,7 @@ func TestNXDOMAIN(t *testing.T) { func TestResolveError(t *testing.T) { r := &recorder{} s := NewServer("localhost:0", "lan") - s.upstream = "266.266.266.266:53" + s.upstream = []string{"266.266.266.266:53"} m := new(dns.Msg) m.SetQuestion("foo.invalid.", dns.TypeA) s.Mux.ServeDNS(r, m) @@ -72,6 +73,31 @@ func TestResolveError(t *testing.T) { } } +func TestResolveFallback(t *testing.T) { + s := NewServer("localhost:0", "lan") + s.upstream = []string{ + "266.266.266.266:53", + } + { + pc, err := net.ListenPacket("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + go dns.ActivateAndServe(nil, pc, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1") + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, rr) + w.WriteMsg(m) + })) + s.upstream = append(s.upstream, pc.LocalAddr().String()) + } + + if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil { + t.Fatal(err) + } +} + func TestDHCP(t *testing.T) { r := &recorder{} s := NewServer("localhost:0", "lan") @@ -343,6 +369,40 @@ func TestDHCPReverse(t *testing.T) { } +func resolveTestTarget(s *Server, name string, want net.IP) error { + m := new(dns.Msg) + typ := dns.TypeA + if want.To4() == nil { + typ = dns.TypeAAAA + } + m.SetQuestion(name, typ) + r := &recorder{} + s.Mux.ServeDNS(r, m) + if r.response == nil { + return fmt.Errorf("nil response") + } + if got, want := len(r.response.Answer), 1; got != want { + return fmt.Errorf("unexpected number of answers: got %d, want %d", got, want) + } + a := r.response.Answer[0] + if typ == dns.TypeA { + if _, ok := a.(*dns.A); !ok { + return fmt.Errorf("unexpected response type: got %T, want dns.A", a) + } + if got := a.(*dns.A).A; !got.Equal(want) { + return fmt.Errorf("unexpected response IP: got %v, want %v", got, want) + } + } else { + if _, ok := a.(*dns.AAAA); !ok { + return fmt.Errorf("unexpected response type: got %T, want dns.A", a) + } + if got := a.(*dns.AAAA).AAAA; !got.Equal(want) { + return fmt.Errorf("unexpected response IP: got %v, want %v", got, want) + } + } + return nil +} + // TODO: multiple questions func TestSubname(t *testing.T) { @@ -355,40 +415,10 @@ func TestSubname(t *testing.T) { }, }) - resolveTestTarget := func(t *testing.T, name string, want net.IP) { - m := new(dns.Msg) - typ := dns.TypeA - if want.To4() == nil { - typ = dns.TypeAAAA - } - m.SetQuestion(name, typ) - s.Mux.ServeDNS(r, m) - if r.response == nil { - t.Fatalf("nil response") - } - if got, want := len(r.response.Answer), 1; got != want { - t.Fatalf("unexpected number of answers: got %d, want %d", got, want) - } - a := r.response.Answer[0] - if typ == dns.TypeA { - if _, ok := a.(*dns.A); !ok { - t.Fatalf("unexpected response type: got %T, want dns.A", a) - } - if got := a.(*dns.A).A; !got.Equal(want) { - t.Fatalf("unexpected response IP: got %v, want %v", got, want) - } - } else { - if _, ok := a.(*dns.AAAA); !ok { - t.Fatalf("unexpected response type: got %T, want dns.A", a) - } - if got := a.(*dns.AAAA).AAAA; !got.Equal(want) { - t.Fatalf("unexpected response IP: got %v, want %v", got, want) - } - } - } - t.Run("testtarget.lan.", func(t *testing.T) { - resolveTestTarget(t, "testtarget.lan.", net.ParseIP("192.168.42.23")) + if err := resolveTestTarget(s, "testtarget.lan.", net.ParseIP("192.168.42.23")); err != nil { + t.Fatal(err) + } }) t.Run("sub.testtarget.lan.", func(t *testing.T) { @@ -424,7 +454,9 @@ func TestSubname(t *testing.T) { "sub.testtarget.", } { t.Run(name+" (after dyndns)", func(t *testing.T) { - resolveTestTarget(t, name, net.ParseIP(ip)) + if err := resolveTestTarget(s, name, net.ParseIP(ip)); err != nil { + t.Fatal(err) + } }) } @@ -433,8 +465,12 @@ func TestSubname(t *testing.T) { if err != nil { t.Skipf("os.Hostname: %v", err) } - resolveTestTarget(t, hostname+".lan.", net.ParseIP("127.0.0.2")) + if err := resolveTestTarget(s, hostname+".lan.", net.ParseIP("127.0.0.2")); err != nil { + t.Fatal(err) + } setSubname(ip, "127.0.0.2:1234") - resolveTestTarget(t, "sub."+hostname+".lan.", net.ParseIP(ip)) + if err := resolveTestTarget(s, "sub."+hostname+".lan.", net.ParseIP(ip)); err != nil { + t.Fatal(err) + } }) }