From a05f027765478080b22b3fd4eb764337c8a39707 Mon Sep 17 00:00:00 2001
From: Michael Stapelberg
Date: Tue, 19 Feb 2019 08:43:56 +0100
Subject: [PATCH] dns: fallback only once, i.e. prefer the working server next time

The working upstream is only moved to the front if s.upstream was not
reordered by a concurrent request in the meantime, so that no upstream
is duplicated or dropped.

---
 internal/dns/dns.go      | 26 ++++++++++++++++++---
 internal/dns/dns_test.go | 52 ++++++++++++++++++++++++++++++++--------
 2 files changed, 65 insertions(+), 13 deletions(-)

diff --git a/internal/dns/dns.go b/internal/dns/dns.go
index 2d63a03..3a03379 100644
--- a/internal/dns/dns.go
+++ b/internal/dns/dns.go
@@ -39,7 +39,6 @@ type Server struct {
 
 	client    *dns.Client
 	domain    string
-	upstream  []string
 	sometimes *rate.Limiter
 	prom      struct {
 		registry  *prometheus.Registry
@@ -53,6 +52,9 @@ type Server struct {
 	hostsByName map[string]string
 	hostsByIP   map[string]string
 	subnames    map[string]map[string]net.IP // hostname → subname → ip
+
+	upstreamMu sync.RWMutex
+	upstream   []string
 }
 
 func NewServer(addr, domain string) *Server {
@@ -320,6 +322,14 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
 	w.WriteMsg(m)
 }
 
+func (s *Server) upstreams() []string {
+	s.upstreamMu.RLock()
+	defer s.upstreamMu.RUnlock()
+	result := make([]string, len(s.upstream))
+	copy(result, s.upstream)
+	return result
+}
+
 func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
 	if len(r.Question) == 1 { // TODO: answer all questions we can answer
 		q := r.Question[0]
@@ -333,7 +343,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
 	s.prom.questions.Observe(float64(len(r.Question)))
 	s.prom.upstream.WithLabelValues("DNS").Inc()
 
-	for _, u := range s.upstream {
+	for idx, u := range s.upstreams() {
 		in, _, err := s.client.Exchange(r, u)
 		if err != nil {
 			if s.sometimes.Allow() {
@@ -342,7 +352,17 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
 			continue // fall back to next-slower upstream
 		}
 		w.WriteMsg(in)
-		break
+		if idx > 0 {
+			// re-order this upstream to the front of s.upstream.
+			s.upstreamMu.Lock()
+			// idx refers to the snapshot taken by s.upstreams(); a concurrent
+			// request may have reordered s.upstream since, so only move u if
+			// it is still at idx. Otherwise we would duplicate one upstream
+			// and drop another.
+			if idx < len(s.upstream) && s.upstream[idx] == u {
+				s.upstream = append(append([]string{u}, s.upstream[:idx]...), s.upstream[idx+1:]...)
+			}
+			s.upstreamMu.Unlock()
+		}
+		return
 	}
 	// DNS has no reply for resolving errors
 }
diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go
index fc9031c..d7ee0c7 100644
--- a/internal/dns/dns_test.go
+++ b/internal/dns/dns_test.go
@@ -24,6 +24,7 @@ import (
 	"net/url"
 	"os"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -77,27 +78,58 @@ func TestResolveFallback(t *testing.T) {
 	s := NewServer("localhost:0", "lan")
 	s.upstream = []string{
 		"266.266.266.266:53",
-	}
-	{
-		pc, err := net.ListenPacket("udp", "localhost:0")
-		if err != nil {
-			t.Fatal(err)
-		}
-		go dns.ActivateAndServe(nil, pc, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
+		dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
 			rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
 			m := new(dns.Msg)
 			m.SetReply(r)
 			m.Answer = append(m.Answer, rr)
 			w.WriteMsg(m)
-		}))
-		s.upstream = append(s.upstream, pc.LocalAddr().String())
+		})),
 	}
-
 	if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
 		t.Fatal(err)
 	}
 }
 
+func dnsServerAddr(t *testing.T, h dns.Handler) string {
+	t.Helper()
+
+	pc, err := net.ListenPacket("udp", "localhost:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	go dns.ActivateAndServe(nil, pc, h)
+	return pc.LocalAddr().String()
+}
+
+func TestResolveFallbackOnce(t *testing.T) {
+	s := NewServer("localhost:0", "lan")
+	var slowHits uint32
+	s.upstream = []string{
+		dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
+			atomic.AddUint32(&slowHits, 1)
+			// trigger fallback by sending no reply
+		})),
+		dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
+			rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
+			m := new(dns.Msg)
+			m.SetReply(r)
+			m.Answer = append(m.Answer, rr)
+			w.WriteMsg(m)
+		})),
+		"266.266.266.266:53",
+	}
+
+	for i := 0; i < 2; i++ {
+		if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
+			t.Fatal(err)
+		}
+	}
+	if got, want := atomic.LoadUint32(&slowHits), uint32(1); got != want {
+		t.Errorf("slow upstream server hits = %d, wanted %d", got, want)
+	}
+}
+
 func TestDHCP(t *testing.T) {
 	r := &recorder{}
 	s := NewServer("localhost:0", "lan")