dns: fallback to next upstream upon failure

This commit is contained in:
Michael Stapelberg 2019-02-19 08:32:00 +01:00
parent abeddabbb7
commit ccaf6ad452
2 changed files with 94 additions and 48 deletions

View File

@ -39,7 +39,7 @@ type Server struct {
client *dns.Client client *dns.Client
domain string domain string
upstream string upstream []string
sometimes *rate.Limiter sometimes *rate.Limiter
prom struct { prom struct {
registry *prometheus.Registry registry *prometheus.Registry
@ -62,7 +62,13 @@ func NewServer(addr, domain string) *Server {
Mux: dns.NewServeMux(), Mux: dns.NewServeMux(),
client: &dns.Client{}, client: &dns.Client{},
domain: domain, domain: domain,
upstream: "8.8.8.8:53", upstream: []string{
// https://developers.google.com/speed/public-dns/docs/using#google_public_dns_ip_addresses
"8.8.8.8:53",
"8.8.4.4:53",
"[2001:4860:4860::8888]:53",
"[2001:4860:4860::8844]:53",
},
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostname: hostname, hostname: hostname,
ip: ip, ip: ip,
@ -327,14 +333,18 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
s.prom.questions.Observe(float64(len(r.Question))) s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("DNS").Inc() s.prom.upstream.WithLabelValues("DNS").Inc()
in, _, err := s.client.Exchange(r, s.upstream) for _, u := range s.upstream {
in, _, err := s.client.Exchange(r, u)
if err != nil { if err != nil {
if s.sometimes.Allow() { if s.sometimes.Allow() {
log.Printf("resolving %v failed: %v", r.Question, err) log.Printf("resolving %v failed: %v", r.Question, err)
} }
return // DNS has no reply for resolving errors continue // fall back to next-slower upstream
} }
w.WriteMsg(in) w.WriteMsg(in)
break
}
// DNS has no reply for resolving errors
} }
func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) { func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) {

View File

@ -16,6 +16,7 @@ package dns
import ( import (
"bytes" "bytes"
"fmt"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
@ -63,7 +64,7 @@ func TestNXDOMAIN(t *testing.T) {
func TestResolveError(t *testing.T) { func TestResolveError(t *testing.T) {
r := &recorder{} r := &recorder{}
s := NewServer("localhost:0", "lan") s := NewServer("localhost:0", "lan")
s.upstream = "266.266.266.266:53" s.upstream = []string{"266.266.266.266:53"}
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("foo.invalid.", dns.TypeA) m.SetQuestion("foo.invalid.", dns.TypeA)
s.Mux.ServeDNS(r, m) s.Mux.ServeDNS(r, m)
@ -72,6 +73,31 @@ func TestResolveError(t *testing.T) {
} }
} }
func TestResolveFallback(t *testing.T) {
s := NewServer("localhost:0", "lan")
s.upstream = []string{
"266.266.266.266:53",
}
{
pc, err := net.ListenPacket("udp", "localhost:0")
if err != nil {
t.Fatal(err)
}
go dns.ActivateAndServe(nil, pc, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
}))
s.upstream = append(s.upstream, pc.LocalAddr().String())
}
if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
t.Fatal(err)
}
}
func TestDHCP(t *testing.T) { func TestDHCP(t *testing.T) {
r := &recorder{} r := &recorder{}
s := NewServer("localhost:0", "lan") s := NewServer("localhost:0", "lan")
@ -343,6 +369,40 @@ func TestDHCPReverse(t *testing.T) {
} }
func resolveTestTarget(s *Server, name string, want net.IP) error {
m := new(dns.Msg)
typ := dns.TypeA
if want.To4() == nil {
typ = dns.TypeAAAA
}
m.SetQuestion(name, typ)
r := &recorder{}
s.Mux.ServeDNS(r, m)
if r.response == nil {
return fmt.Errorf("nil response")
}
if got, want := len(r.response.Answer), 1; got != want {
return fmt.Errorf("unexpected number of answers: got %d, want %d", got, want)
}
a := r.response.Answer[0]
if typ == dns.TypeA {
if _, ok := a.(*dns.A); !ok {
return fmt.Errorf("unexpected response type: got %T, want dns.A", a)
}
if got := a.(*dns.A).A; !got.Equal(want) {
return fmt.Errorf("unexpected response IP: got %v, want %v", got, want)
}
} else {
if _, ok := a.(*dns.AAAA); !ok {
return fmt.Errorf("unexpected response type: got %T, want dns.A", a)
}
if got := a.(*dns.AAAA).AAAA; !got.Equal(want) {
return fmt.Errorf("unexpected response IP: got %v, want %v", got, want)
}
}
return nil
}
// TODO: multiple questions // TODO: multiple questions
func TestSubname(t *testing.T) { func TestSubname(t *testing.T) {
@ -355,40 +415,10 @@ func TestSubname(t *testing.T) {
}, },
}) })
resolveTestTarget := func(t *testing.T, name string, want net.IP) {
m := new(dns.Msg)
typ := dns.TypeA
if want.To4() == nil {
typ = dns.TypeAAAA
}
m.SetQuestion(name, typ)
s.Mux.ServeDNS(r, m)
if r.response == nil {
t.Fatalf("nil response")
}
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
}
a := r.response.Answer[0]
if typ == dns.TypeA {
if _, ok := a.(*dns.A); !ok {
t.Fatalf("unexpected response type: got %T, want dns.A", a)
}
if got := a.(*dns.A).A; !got.Equal(want) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
}
} else {
if _, ok := a.(*dns.AAAA); !ok {
t.Fatalf("unexpected response type: got %T, want dns.A", a)
}
if got := a.(*dns.AAAA).AAAA; !got.Equal(want) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
}
}
}
t.Run("testtarget.lan.", func(t *testing.T) { t.Run("testtarget.lan.", func(t *testing.T) {
resolveTestTarget(t, "testtarget.lan.", net.ParseIP("192.168.42.23")) if err := resolveTestTarget(s, "testtarget.lan.", net.ParseIP("192.168.42.23")); err != nil {
t.Fatal(err)
}
}) })
t.Run("sub.testtarget.lan.", func(t *testing.T) { t.Run("sub.testtarget.lan.", func(t *testing.T) {
@ -424,7 +454,9 @@ func TestSubname(t *testing.T) {
"sub.testtarget.", "sub.testtarget.",
} { } {
t.Run(name+" (after dyndns)", func(t *testing.T) { t.Run(name+" (after dyndns)", func(t *testing.T) {
resolveTestTarget(t, name, net.ParseIP(ip)) if err := resolveTestTarget(s, name, net.ParseIP(ip)); err != nil {
t.Fatal(err)
}
}) })
} }
@ -433,8 +465,12 @@ func TestSubname(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("os.Hostname: %v", err) t.Skipf("os.Hostname: %v", err)
} }
resolveTestTarget(t, hostname+".lan.", net.ParseIP("127.0.0.2")) if err := resolveTestTarget(s, hostname+".lan.", net.ParseIP("127.0.0.2")); err != nil {
t.Fatal(err)
}
setSubname(ip, "127.0.0.2:1234") setSubname(ip, "127.0.0.2:1234")
resolveTestTarget(t, "sub."+hostname+".lan.", net.ParseIP(ip)) if err := resolveTestTarget(s, "sub."+hostname+".lan.", net.ParseIP(ip)); err != nil {
t.Fatal(err)
}
}) })
} }