From 66942bd4f734e86f7c95ba3a00e45e35f1abf1cd Mon Sep 17 00:00:00 2001 From: Michael Stapelberg Date: Tue, 19 Feb 2019 09:19:32 +0100 Subject: [PATCH] dns: steer traffic toward the most response upstream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There is not much of a difference between IPv4 and IPv6 on my fiber7 link, but in other networks, there might well be. For my connection, this commit results in hitting different upstreams all the time because the list isn’t stable. But, this is the opposite of a problem: we are spreading the DNS query load over all configured IPs, as good netizen do. --- internal/dns/dns.go | 57 ++++++++++++++++++++++++++++++- internal/dns/dns_test.go | 74 ++++++++++++++++++++++++++-------------- 2 files changed, 104 insertions(+), 27 deletions(-) diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 3a03379..f6921a4 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -18,15 +18,17 @@ package dns import ( "errors" "fmt" - "log" + "math" "net" "net/http" "os" + "sort" "strings" "sync" "time" "github.com/rtr7/router7/internal/dhcp4d" + "github.com/rtr7/router7/internal/teelogger" "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" @@ -34,6 +36,8 @@ import ( "golang.org/x/time/rate" ) +var log = teelogger.NewConsole() + type Server struct { Mux *dns.ServeMux @@ -105,6 +109,11 @@ func NewServer(addr, domain string) *Server { server.Mux.HandleFunc(".", server.handleRequest) server.Mux.HandleFunc("lan.", server.handleInternal) server.Mux.HandleFunc("localhost.", server.handleInternal) + go func() { + for range time.Tick(10 * time.Second) { + server.probeUpstreamLatency() + } + }() return server } @@ -121,6 +130,52 @@ func (s *Server) initHostsLocked() { } } +type measurement struct { + upstream string + rtt time.Duration +} + +func (m measurement) String() string { + return fmt.Sprintf("{upstream: %s, rtt: %v}", m.upstream, m.rtt) +} + +func (s *Server) probeUpstreamLatency() { + upstreams := s.upstreams() + results := make([]measurement, len(upstreams)) + var wg sync.WaitGroup + for idx, u := range upstreams { + wg.Add(1) + go func(idx int, u string) { + defer wg.Done() + // resolve a most-definitely cached record + m := new(dns.Msg) + m.SetQuestion("google.ch.", dns.TypeA) + start := time.Now() + _, _, err := s.client.Exchange(m, u) + rtt := time.Since(start) + if err != nil { + // including unresponsive upstreams in results makes the update + // code simpler: + results[idx] = measurement{u, time.Duration(math.MaxInt64)} + return + } + results[idx] = measurement{u, rtt} + }(idx, u) + } + wg.Wait() + // Re-order by resolving latency: + sort.Slice(results, func(i, j int) bool { + return results[i].rtt < results[j].rtt + }) + log.Printf("probe results: %v", results) + for idx, result := range results { + upstreams[idx] = result.upstream + } + s.upstreamMu.Lock() + defer s.upstreamMu.Unlock() + s.upstream = upstreams +} + func (s *Server) hostByName(n string) (string, bool) { s.mu.Lock() defer s.mu.Unlock() diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go index d7ee0c7..62ea4dd 100644 --- a/internal/dns/dns_test.go +++ b/internal/dns/dns_test.go @@ -79,11 +79,7 @@ func TestResolveFallback(t *testing.T) { s.upstream = []string{ "266.266.266.266:53", dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { - rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1") - m := new(dns.Msg) - m.SetReply(r) - m.Answer = append(m.Answer, rr) - w.WriteMsg(m) + reply(w, r, " 3600 IN A 127.0.0.1") })), } if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil { @@ -111,11 +107,7 @@ func TestResolveFallbackOnce(t *testing.T) { // trigger fallback by sending no reply })), dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { - rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1") - m := new(dns.Msg) - m.SetReply(r) - m.Answer = append(m.Answer, rr) - w.WriteMsg(m) + reply(w, r, " 3600 IN A 127.0.0.1") })), "266.266.266.266:53", } @@ -130,6 +122,43 @@ func TestResolveFallbackOnce(t *testing.T) { } } +func reply(w dns.ResponseWriter, r *dns.Msg, response string) { + rr, _ := dns.NewRR(r.Question[0].Name + response) + m := new(dns.Msg) + m.SetReply(r) + m.Answer = append(m.Answer, rr) + w.WriteMsg(m) +} + +func TestResolveLatencySteering(t *testing.T) { + s := NewServer("localhost:0", "lan") + var slowHits uint32 + s.upstream = []string{ + dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + atomic.AddUint32(&slowHits, 1) + time.Sleep(10 * time.Millisecond) + reply(w, r, " 3600 IN A 127.0.0.1") + })), + dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + reply(w, r, " 3600 IN A 127.0.0.1") + })), + "266.266.266.266:53", + } + + if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil { + t.Fatal(err) + } + s.probeUpstreamLatency() + if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil { + t.Fatal(err) + } + + want := uint32(2) // one for resolving, one for probing + if got := atomic.LoadUint32(&slowHits); got != want { + t.Errorf("slow upstream server hits = %d, wanted %d", got, want) + } +} + func TestDHCP(t *testing.T) { r := &recorder{} s := NewServer("localhost:0", "lan") @@ -140,22 +169,11 @@ func TestDHCP(t *testing.T) { }, }) - resolveTestTarget := func(t *testing.T) { - m := new(dns.Msg) - m.SetQuestion("testtarget.lan.", dns.TypeA) - s.Mux.ServeDNS(r, m) - if got, want := len(r.response.Answer), 1; got != want { - t.Fatalf("unexpected number of answers: got %d, want %d", got, want) + t.Run("testtarget.lan.", func(t *testing.T) { + if err := resolveTestTarget(s, "testtarget.lan.", net.ParseIP("192.168.42.23")); err != nil { + t.Fatal(err) } - a := r.response.Answer[0] - if _, ok := a.(*dns.A); !ok { - t.Fatalf("unexpected response type: got %T, want dns.A", a) - } - if got, want := a.(*dns.A).A, net.ParseIP("192.168.42.23"); !got.Equal(want) { - t.Fatalf("unexpected response IP: got %v, want %v", got, want) - } - } - t.Run("testtarget.lan.", resolveTestTarget) + }) expired := time.Now().Add(-1 * time.Second) s.SetLeases([]dhcp4d.Lease{ @@ -171,7 +189,11 @@ func TestDHCP(t *testing.T) { }, }) - t.Run("testtarget.lan. (expired)", resolveTestTarget) + t.Run("testtarget.lan. (expired)", func(t *testing.T) { + if err := resolveTestTarget(s, "testtarget.lan.", net.ParseIP("192.168.42.23")); err != nil { + t.Fatal(err) + } + }) t.Run("notfound.lan.", func(t *testing.T) { m := new(dns.Msg)