dns: fallback only once, i.e. prefer the working server next time
This commit is contained in:
parent
ccaf6ad452
commit
a05f027765
@ -39,7 +39,6 @@ type Server struct {
|
||||
|
||||
client *dns.Client
|
||||
domain string
|
||||
upstream []string
|
||||
sometimes *rate.Limiter
|
||||
prom struct {
|
||||
registry *prometheus.Registry
|
||||
@ -53,6 +52,9 @@ type Server struct {
|
||||
hostsByName map[string]string
|
||||
hostsByIP map[string]string
|
||||
subnames map[string]map[string]net.IP // hostname → subname → ip
|
||||
|
||||
upstreamMu sync.RWMutex
|
||||
upstream []string
|
||||
}
|
||||
|
||||
func NewServer(addr, domain string) *Server {
|
||||
@ -320,6 +322,14 @@ func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func (s *Server) upstreams() []string {
|
||||
s.upstreamMu.RLock()
|
||||
defer s.upstreamMu.RUnlock()
|
||||
result := make([]string, len(s.upstream))
|
||||
copy(result, s.upstream)
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 1 { // TODO: answer all questions we can answer
|
||||
q := r.Question[0]
|
||||
@ -333,7 +343,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
s.prom.questions.Observe(float64(len(r.Question)))
|
||||
s.prom.upstream.WithLabelValues("DNS").Inc()
|
||||
|
||||
for _, u := range s.upstream {
|
||||
for idx, u := range s.upstreams() {
|
||||
in, _, err := s.client.Exchange(r, u)
|
||||
if err != nil {
|
||||
if s.sometimes.Allow() {
|
||||
@ -342,7 +352,13 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
continue // fall back to next-slower upstream
|
||||
}
|
||||
w.WriteMsg(in)
|
||||
break
|
||||
if idx > 0 {
|
||||
// re-order this upstream to the front of s.upstream.
|
||||
s.upstreamMu.Lock()
|
||||
s.upstream = append(append([]string{u}, s.upstream[:idx]...), s.upstream[idx+1:]...)
|
||||
s.upstreamMu.Unlock()
|
||||
}
|
||||
return
|
||||
}
|
||||
// DNS has no reply for resolving errors
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -77,27 +78,58 @@ func TestResolveFallback(t *testing.T) {
|
||||
s := NewServer("localhost:0", "lan")
|
||||
s.upstream = []string{
|
||||
"266.266.266.266:53",
|
||||
}
|
||||
{
|
||||
pc, err := net.ListenPacket("udp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go dns.ActivateAndServe(nil, pc, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
}))
|
||||
s.upstream = append(s.upstream, pc.LocalAddr().String())
|
||||
})),
|
||||
}
|
||||
|
||||
if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func dnsServerAddr(t *testing.T, h dns.Handler) string {
|
||||
t.Helper()
|
||||
|
||||
pc, err := net.ListenPacket("udp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go dns.ActivateAndServe(nil, pc, h)
|
||||
return pc.LocalAddr().String()
|
||||
}
|
||||
|
||||
func TestResolveFallbackOnce(t *testing.T) {
|
||||
s := NewServer("localhost:0", "lan")
|
||||
var slowHits uint32
|
||||
s.upstream = []string{
|
||||
dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
atomic.AddUint32(&slowHits, 1)
|
||||
// trigger fallback by sending no reply
|
||||
})),
|
||||
dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, rr)
|
||||
w.WriteMsg(m)
|
||||
})),
|
||||
"266.266.266.266:53",
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if got, want := atomic.LoadUint32(&slowHits), uint32(1); got != want {
|
||||
t.Errorf("slow upstream server hits = %d, wanted %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDHCP(t *testing.T) {
|
||||
r := &recorder{}
|
||||
s := NewServer("localhost:0", "lan")
|
||||
|
Loading…
x
Reference in New Issue
Block a user