diff --git a/internal/dns/dns.go b/internal/dns/dns.go
index 3a03379..f6921a4 100644
--- a/internal/dns/dns.go
+++ b/internal/dns/dns.go
@@ -18,15 +18,17 @@ package dns
 import (
 	"errors"
 	"fmt"
-	"log"
+	"math"
 	"net"
 	"net/http"
 	"os"
+	"sort"
 	"strings"
 	"sync"
 	"time"
 
 	"github.com/rtr7/router7/internal/dhcp4d"
+	"github.com/rtr7/router7/internal/teelogger"
 
 	"github.com/miekg/dns"
 	"github.com/prometheus/client_golang/prometheus"
@@ -34,6 +36,8 @@ import (
 	"golang.org/x/time/rate"
 )
 
+var log = teelogger.NewConsole()
+
 type Server struct {
 	Mux *dns.ServeMux
 
@@ -105,6 +109,13 @@ func NewServer(addr, domain string) *Server {
 	server.Mux.HandleFunc(".", server.handleRequest)
 	server.Mux.HandleFunc("lan.", server.handleInternal)
 	server.Mux.HandleFunc("localhost.", server.handleInternal)
+	// Re-rank the upstreams by latency every 10 seconds. This goroutine has no
+	// exit path and runs for the lifetime of the process.
+	go func() {
+		for range time.Tick(10 * time.Second) {
+			server.probeUpstreamLatency()
+		}
+	}()
 	return server
 }
 
@@ -121,6 +132,63 @@ func (s *Server) initHostsLocked() {
 	}
 }
 
+// measurement records the outcome of one latency probe against an upstream.
+type measurement struct {
+	upstream string
+	rtt      time.Duration // math.MaxInt64 marks an upstream that failed to answer
+}
+
+// String formats m for log output.
+func (m measurement) String() string {
+	return fmt.Sprintf("{upstream: %s, rtt: %v}", m.upstream, m.rtt)
+}
+
+// probeUpstreamLatency queries all upstreams concurrently and stores them in
+// s.upstream ordered by ascending round-trip time, so that the fastest
+// responsive upstream is listed first.
+//
+// NOTE(review): this assumes s.upstreams() returns a private copy and that
+// overwriting changes made to s.upstream while a probe is in flight is
+// acceptable; confirm against the other writers of s.upstream.
+func (s *Server) probeUpstreamLatency() {
+	upstreams := s.upstreams()
+	results := make([]measurement, len(upstreams))
+	var wg sync.WaitGroup
+	for idx, u := range upstreams {
+		wg.Add(1)
+		go func(idx int, u string) {
+			defer wg.Done()
+			// resolve a most-definitely cached record
+			m := new(dns.Msg)
+			m.SetQuestion("google.ch.", dns.TypeA)
+			start := time.Now()
+			_, _, err := s.client.Exchange(m, u)
+			rtt := time.Since(start)
+			if err != nil {
+				// including unresponsive upstreams in results makes the update
+				// code simpler:
+				results[idx] = measurement{upstream: u, rtt: time.Duration(math.MaxInt64)}
+				return
+			}
+			results[idx] = measurement{upstream: u, rtt: rtt}
+		}(idx, u)
+	}
+	wg.Wait()
+	// Re-order by resolving latency. The sort is stable so that upstreams with
+	// equal latency (notably all unresponsive ones) keep their previous
+	// relative order instead of being shuffled arbitrarily.
+	sort.SliceStable(results, func(i, j int) bool {
+		return results[i].rtt < results[j].rtt
+	})
+	log.Printf("probe results: %v", results)
+	for idx, result := range results {
+		upstreams[idx] = result.upstream
+	}
+	s.upstreamMu.Lock()
+	defer s.upstreamMu.Unlock()
+	s.upstream = upstreams
+}
+
 func (s *Server) hostByName(n string) (string, bool) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
diff --git a/internal/dns/dns_test.go b/internal/dns/dns_test.go
index d7ee0c7..62ea4dd 100644
--- a/internal/dns/dns_test.go
+++ b/internal/dns/dns_test.go
@@ -79,11 +79,7 @@ func TestResolveFallback(t *testing.T) {
 	s.upstream = []string{
 		"266.266.266.266:53",
 		dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
-			rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
-			m := new(dns.Msg)
-			m.SetReply(r)
-			m.Answer = append(m.Answer, rr)
-			w.WriteMsg(m)
+			reply(w, r, " 3600 IN A 127.0.0.1")
 		})),
 	}
 	if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
@@ -111,11 +107,7 @@ func TestResolveFallbackOnce(t *testing.T) {
 			// trigger fallback by sending no reply
 		})),
 		dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
-			rr, _ := dns.NewRR(r.Question[0].Name + " 3600 IN A 127.0.0.1")
-			m := new(dns.Msg)
-			m.SetReply(r)
-			m.Answer = append(m.Answer, rr)
-			w.WriteMsg(m)
+			reply(w, r, " 3600 IN A 127.0.0.1")
 		})),
 		"266.266.266.266:53",
 	}
@@ -130,6 +122,47 @@ func TestResolveFallbackOnce(t *testing.T) {
 	}
 }
 
+// reply answers r on w with a single resource record parsed from the question
+// name followed by response.
+func reply(w dns.ResponseWriter, r *dns.Msg, response string) {
+	rr, _ := dns.NewRR(r.Question[0].Name + response)
+	m := new(dns.Msg)
+	m.SetReply(r)
+	m.Answer = append(m.Answer, rr)
+	w.WriteMsg(m)
+}
+
+// TestResolveLatencySteering verifies that, after a latency probe, queries no
+// longer reach the slow upstream that was listed first.
+func TestResolveLatencySteering(t *testing.T) {
+	s := NewServer("localhost:0", "lan")
+	var slowHits uint32
+	s.upstream = []string{
+		dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
+			atomic.AddUint32(&slowHits, 1)
+			time.Sleep(10 * time.Millisecond)
+			reply(w, r, " 3600 IN A 127.0.0.1")
+		})),
+		dnsServerAddr(t, dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
+			reply(w, r, " 3600 IN A 127.0.0.1")
+		})),
+		"266.266.266.266:53",
+	}
+
+	if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
+		t.Fatal(err)
+	}
+	s.probeUpstreamLatency()
+	if err := resolveTestTarget(s, "google.ch.", net.ParseIP("127.0.0.1")); err != nil {
+		t.Fatal(err)
+	}
+
+	want := uint32(2) // one for resolving, one for probing
+	if got := atomic.LoadUint32(&slowHits); got != want {
+		t.Errorf("slow upstream server hits = %d, wanted %d", got, want)
+	}
+}
+
 func TestDHCP(t *testing.T) {
 	r := &recorder{}
 	s := NewServer("localhost:0", "lan")
@@ -140,22 +173,11 @@ func TestDHCP(t *testing.T) {
 		},
 	})
 
-	resolveTestTarget := func(t *testing.T) {
-		m := new(dns.Msg)
-		m.SetQuestion("testtarget.lan.", dns.TypeA)
-		s.Mux.ServeDNS(r, m)
-		if got, want := len(r.response.Answer), 1; got != want {
-			t.Fatalf("unexpected number of answers: got %d, want %d", got, want)
+	t.Run("testtarget.lan.", func(t *testing.T) {
+		if err := resolveTestTarget(s, "testtarget.lan.", net.ParseIP("192.168.42.23")); err != nil {
+			t.Fatal(err)
 		}
-		a := r.response.Answer[0]
-		if _, ok := a.(*dns.A); !ok {
-			t.Fatalf("unexpected response type: got %T, want dns.A", a)
-		}
-		if got, want := a.(*dns.A).A, net.ParseIP("192.168.42.23"); !got.Equal(want) {
-			t.Fatalf("unexpected response IP: got %v, want %v", got, want)
-		}
-	}
-	t.Run("testtarget.lan.", resolveTestTarget)
+	})
 
 	expired := time.Now().Add(-1 * time.Second)
 	s.SetLeases([]dhcp4d.Lease{
@@ -171,7 +193,11 @@ func TestDHCP(t *testing.T) {
 		},
 	})
 
-	t.Run("testtarget.lan. (expired)", resolveTestTarget)
+	t.Run("testtarget.lan. (expired)", func(t *testing.T) {
+		if err := resolveTestTarget(s, "testtarget.lan.", net.ParseIP("192.168.42.23")); err != nil {
+			t.Fatal(err)
+		}
+	})
 
 	t.Run("notfound.lan.", func(t *testing.T) {
 		m := new(dns.Msg)