dns: fallback only once, i.e. prefer the working server next time

This commit is contained in:
Michael Stapelberg 2019-02-19 08:43:56 +01:00
parent ccaf6ad452
commit a05f027765
2 changed files with 61 additions and 13 deletions

View File

@ -39,7 +39,6 @@ type Server struct {
client *dns.Client
domain string
upstream []string
sometimes *rate.Limiter
prom struct {
registry *prometheus.Registry
@ -53,6 +52,9 @@ type Server struct {
hostsByName map[string]string
hostsByIP map[string]string
subnames map[string]map[string]net.IP // hostname → subname → ip
upstreamMu sync.RWMutex
upstream []string
}
func NewServer(addr, domain string) *Server {
@ -320,6 +322,14 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
}
func (s *Server) upstreams() []string {
s.upstreamMu.RLock()
defer s.upstreamMu.RUnlock()
result := make([]string, len(s.upstream))
copy(result, s.upstream)
return result
}
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 1 { // TODO: answer all questions we can answer
q := r.Question[0]
@ -333,7 +343,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("DNS").Inc()
for _, u := range s.upstream {
for idx, u := range s.upstreams() {
in, _, err := s.client.Exchange(r, u)
if err != nil {
if s.sometimes.Allow() {
@ -342,7 +352,13 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
continue // fall back to next-slower upstream
}
w.WriteMsg(in)
break
if idx > 0 {
// re-order this upstream to the front of s.upstream.
s.upstreamMu.Lock()
s.upstream = append(append([]string{u}, s.upstream[:idx]...), s.upstream[idx+1:]...)
s.upstreamMu.Unlock()
}
return
}
// DNS has no reply for resolving errors
}

View File

@ -24,6 +24,7 @@ import (
"net/url"
"os"
"strings"
"sync/atomic"
"testing"
"time"
@ -77,27 +78,58 @@ func TestResolveFallback(t *testing.T) {
s := NewServer("localhost:0", "lan")
s.upstream = []string{
"266.266.266.266:53",
}
{
pc, err := net.ListenPacket("udp", "localhost:0")
if err != nil {
t.Fatal(err)
}
go dns.ActivateAndServe(nil, pc, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
}))
s.upstream = append(s.upstream, pc.LocalAddr().String())
})),
}
if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
t.Fatal(err)
}
}
func dnsServerAddr(t *testing.T, h dns.Handler) string {
t.Helper()
pc, err := net.ListenPacket("udp", "localhost:0")
if err != nil {
t.Fatal(err)
}
go dns.ActivateAndServe(nil, pc, h)
return pc.LocalAddr().String()
}
func TestResolveFallbackOnce(t *testing.T) {
s := NewServer("localhost:0", "lan")
var slowHits uint32
s.upstream = []string{
dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
atomic.AddUint32(&slowHits, 1)
// trigger fallback by sending no reply
})),
dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
})),
"266.266.266.266:53",
}
for i := 0; i < 2; i++ {
if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
t.Fatal(err)
}
}
if got, want := atomic.LoadUint32(&slowHits), uint32(1); got != want {
t.Errorf("slow upstream server hits = %d, wanted %d", got, want)
}
}
func TestDHCP(t *testing.T) {
r := &recorder{}
s := NewServer("localhost:0", "lan")