506 lines
13 KiB
Go
Raw Normal View History

2018-06-28 13:39:48 +02:00
// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
2018-06-25 20:12:51 +02:00
// Package dns implements a DNS forwarder.
2018-05-27 17:30:42 +02:00
package dns
import (
"errors"
"fmt"
"math"
"net"
"net/http"
"os"
"sort"
2018-05-27 17:30:42 +02:00
"strings"
"sync"
2018-05-27 17:30:42 +02:00
"time"
2018-07-09 08:54:04 +02:00
"github.com/rtr7/router7/internal/dhcp4d"
"github.com/rtr7/router7/internal/teelogger"
2018-05-27 17:30:42 +02:00
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/time/rate"
2018-05-27 17:30:42 +02:00
)
var log = teelogger.NewConsole()
// lcHostname is a string type used for lower-cased hostnames so that the
// DHCP-based local name resolution can be made case-insensitive.
type lcHostname string
2018-05-27 17:30:42 +02:00
type Server struct {
2018-06-26 09:32:34 +02:00
Mux *dns.ServeMux
client *dns.Client
domain string
sometimes *rate.Limiter
prom struct {
registry *prometheus.Registry
queries prometheus.Counter
upstream *prometheus.CounterVec
questions prometheus.Histogram
}
mu sync.Mutex
hostname, ip string
hostsByName map[lcHostname]string
hostsByIP map[string]string
subnames map[lcHostname]map[string]net.IP // hostname → subname → ip
upstreamMu sync.RWMutex
upstream []string
2018-05-27 17:30:42 +02:00
}
func NewServer(addr, domain string) *Server {
hostname, _ := os.Hostname()
ip, _, _ := net.SplitHostPort(addr)
2018-05-27 17:30:42 +02:00
server := &Server{
Mux: dns.NewServeMux(),
client: &dns.Client{},
domain: domain,
upstream: []string{
// https://developers.google.com/speed/public-dns/docs/using#google_public_dns_ip_addresses
"8.8.8.8:53",
"8.8.4.4:53",
"[2001:4860:4860::8888]:53",
"[2001:4860:4860::8844]:53",
},
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostname: hostname,
ip: ip,
subnames: make(map[lcHostname]map[string]net.IP),
2018-05-27 17:30:42 +02:00
}
server.prom.registry = prometheus.NewRegistry()
server.prom.queries = prometheus.NewCounter(prometheus.CounterOpts{
Name: "dns_queries",
Help: "Number of DNS queries received",
})
server.prom.registry.MustRegister(server.prom.queries)
server.prom.upstream = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "dns_upstream",
Help: "Which upstream answered which DNS query",
},
[]string{"upstream"},
)
server.prom.registry.MustRegister(server.prom.upstream)
server.prom.questions = prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "dns_questions",
Help: "Number of questions in each DNS request",
Buckets: prometheus.LinearBuckets(0, 1, 10),
})
server.prom.registry.MustRegister(server.prom.questions)
server.prom.registry.MustRegister(prometheus.NewGoCollector())
server.initHostsLocked()
2018-06-26 09:32:34 +02:00
server.Mux.HandleFunc(".", server.handleRequest)
server.Mux.HandleFunc("lan.", server.handleInternal)
server.Mux.HandleFunc("localhost.", server.handleInternal)
go func() {
for range time.Tick(10 * time.Second) {
server.probeUpstreamLatency()
}
}()
2018-05-27 17:30:42 +02:00
return server
}
func (s *Server) initHostsLocked() {
s.hostsByName = make(map[lcHostname]string)
s.hostsByIP = make(map[string]string)
if s.hostname != "" && s.ip != "" {
lower := strings.ToLower(s.hostname)
s.hostsByName[lcHostname(lower)] = s.ip
if rev, err := dns.ReverseAddr(s.ip); err == nil {
s.hostsByIP[rev] = s.hostname
}
s.Mux.HandleFunc(lower+".", s.subnameHandler(s.hostname))
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(s.hostname))
}
}
type measurement struct {
upstream string
rtt time.Duration
}
func (m measurement) String() string {
return fmt.Sprintf("{upstream: %s, rtt: %v}", m.upstream, m.rtt)
}
func (s *Server) probeUpstreamLatency() {
upstreams := s.upstreams()
results := make([]measurement, len(upstreams))
var wg sync.WaitGroup
for idx, u := range upstreams {
wg.Add(1)
go func(idx int, u string) {
defer wg.Done()
// resolve a most-definitely cached record
m := new(dns.Msg)
m.SetQuestion("google.ch.", dns.TypeA)
start := time.Now()
_, _, err := s.client.Exchange(m, u)
rtt := time.Since(start)
if err != nil {
// including unresponsive upstreams in results makes the update
// code simpler:
results[idx] = measurement{u, time.Duration(math.MaxInt64)}
return
}
results[idx] = measurement{u, rtt}
}(idx, u)
}
wg.Wait()
// Re-order by resolving latency:
sort.Slice(results, func(i, j int) bool {
return results[i].rtt < results[j].rtt
})
log.Printf("probe results: %v", results)
for idx, result := range results {
upstreams[idx] = result.upstream
}
s.upstreamMu.Lock()
defer s.upstreamMu.Unlock()
s.upstream = upstreams
}
func (s *Server) hostByName(n string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.hostsByName[lcHostname(strings.ToLower(n))]
return r, ok
}
func (s *Server) hostByIP(n string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.hostsByIP[n]
return r, ok
}
func (s *Server) subname(hostname, host string) (net.IP, bool) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][host]
return r, ok
}
func (s *Server) PrometheusHandler() http.Handler {
return promhttp.HandlerFor(s.prom.registry, promhttp.HandlerOpts{})
}
func (s *Server) DyndnsHandler(w http.ResponseWriter, r *http.Request) {
host := r.FormValue("host")
ip := net.ParseIP(r.FormValue("ip"))
if ip == nil {
http.Error(w, "invalid ip", http.StatusBadRequest)
return
}
s.mu.Lock()
defer s.mu.Unlock()
remote, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, fmt.Sprintf("net.SplitHostPort(%q): %v", r.RemoteAddr, err), http.StatusBadRequest)
return
}
rev, err := dns.ReverseAddr(remote)
if err != nil {
http.Error(w, fmt.Sprintf("dns.ReverseAddr(%v): %v", remote, err), http.StatusBadRequest)
return
}
hostname, ok := s.hostsByIP[rev]
if !ok {
err := fmt.Sprintf("connection without corresponding DHCP lease: %v", rev)
http.Error(w, err, http.StatusForbidden)
return
}
lower := strings.ToLower(hostname)
subnames, ok := s.subnames[lcHostname(lower)]
if !ok {
subnames = make(map[string]net.IP)
s.subnames[lcHostname(lower)] = subnames
}
subnames[host] = ip
w.Write([]byte("ok\n"))
}
2018-05-27 17:30:42 +02:00
func (s *Server) SetLeases(leases []dhcp4d.Lease) {
s.mu.Lock()
defer s.mu.Unlock()
s.initHostsLocked()
now := time.Now()
{
// defensive copy
slice := make([]dhcp4d.Lease, len(leases))
copy(slice, leases)
leases = slice
}
// First entry wins, so we order by expiration descendingly to put the
// newest entry for any given name into s.hostsByName.
sort.Slice(leases, func(i, j int) bool {
return !leases[i].Expiry.Before(leases[j].Expiry)
})
2018-05-27 17:30:42 +02:00
for _, l := range leases {
if l.Expired(now) {
continue
}
if l.Hostname == "" {
continue
}
lower := strings.ToLower(l.Hostname)
if _, ok := s.hostsByName[lcHostname(lower)]; ok {
continue // dont overwrite e.g. the hostname entry
}
s.hostsByName[lcHostname(lower)] = l.Addr.String()
2018-05-27 17:30:42 +02:00
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
s.hostsByIP[rev] = l.Hostname
}
s.Mux.HandleFunc(lower+".", s.subnameHandler(lower))
s.Mux.HandleFunc(lower+"."+s.domain+".", s.subnameHandler(lower))
2018-05-27 17:30:42 +02:00
}
}
func mustParseCIDR(s string) *net.IPNet {
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
return ipnet
}
var (
localNets = []*net.IPNet{
// loopback: https://tools.ietf.org/html/rfc3330#section-2
mustParseCIDR("127.0.0.0/8"),
// loopback: https://tools.ietf.org/html/rfc3513#section-2.4
mustParseCIDR("::1/128"),
// reversed: https://tools.ietf.org/html/rfc1918#section-3
mustParseCIDR("10.0.0.0/8"),
mustParseCIDR("172.16.0.0/12"),
mustParseCIDR("192.168.0.0/16"),
}
)
func reverse(ss []string) {
last := len(ss) - 1
for i := 0; i < len(ss)/2; i++ {
ss[i], ss[last-i] = ss[last-i], ss[i]
}
}
func isLocalInAddrArpa(q string) bool {
if !strings.HasSuffix(q, ".in-addr.arpa.") {
return false
}
parts := strings.Split(strings.TrimSuffix(q, ".in-addr.arpa."), ".")
reverse(parts)
ip := net.ParseIP(strings.Join(parts, "."))
if ip == nil {
return false
}
var local bool
for _, l := range localNets {
if l.Contains(ip) {
local = true
break
}
}
return local
}
var errEmpty = errors.New("no answers")
func (s *Server) resolve(q dns.Question) (rr dns.RR, err error) {
2018-06-26 09:32:34 +02:00
if q.Qclass != dns.ClassINET {
return nil, nil
}
if strings.ToLower(q.Name) == "localhost." {
2018-06-26 09:32:34 +02:00
if q.Qtype == dns.TypeAAAA {
return dns.NewRR(q.Name + " 3600 IN AAAA ::1")
}
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A 127.0.0.1")
}
}
if q.Qtype == dns.TypeA ||
q.Qtype == dns.TypeAAAA ||
q.Qtype == dns.TypeMX {
2018-06-26 09:32:34 +02:00
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain)
if host, ok := s.hostByName(name); ok {
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
return nil, errEmpty
2018-06-26 09:32:34 +02:00
}
}
if q.Qtype == dns.TypePTR {
if host, ok := s.hostByIP(q.Name); ok {
return dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
}
if strings.HasSuffix(q.Name, "127.in-addr.arpa.") {
return dns.NewRR(q.Name + " 3600 IN PTR localhost.")
}
}
return nil, nil
}
func (s *Server) handleInternal(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
2018-06-26 09:32:34 +02:00
s.prom.upstream.WithLabelValues("local").Inc()
if len(r.Question) != 1 { // TODO: answer all questions we can answer
return
}
rr, err := s.resolve(r.Question[0])
if err != nil {
if err == errEmpty {
m := new(dns.Msg)
m.SetReply(r)
w.WriteMsg(m)
return
}
2018-06-26 09:32:34 +02:00
log.Fatal(err)
}
if rr != nil {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
}
func (s *Server) upstreams() []string {
s.upstreamMu.RLock()
defer s.upstreamMu.RUnlock()
result := make([]string, len(s.upstream))
copy(result, s.upstream)
return result
}
2018-06-26 09:32:34 +02:00
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
2018-05-27 17:30:42 +02:00
if len(r.Question) == 1 { // TODO: answer all questions we can answer
q := r.Question[0]
2018-06-26 09:32:34 +02:00
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET && isLocalInAddrArpa(q.Name) {
s.handleInternal(w, r)
return
2018-05-27 17:30:42 +02:00
}
}
2018-06-26 09:32:34 +02:00
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
s.prom.upstream.WithLabelValues("DNS").Inc()
for idx, u := range s.upstreams() {
in, _, err := s.client.Exchange(r, u)
if err != nil {
if s.sometimes.Allow() {
log.Printf("resolving %v failed: %v", r.Question, err)
}
continue // fall back to next-slower upstream
2018-05-27 17:30:42 +02:00
}
w.WriteMsg(in)
if idx > 0 {
// re-order this upstream to the front of s.upstream.
s.upstreamMu.Lock()
s.upstream = append(append([]string{u}, s.upstream[:idx]...), s.upstream[idx+1:]...)
s.upstreamMu.Unlock()
}
return
2018-05-27 17:30:42 +02:00
}
// DNS has no reply for resolving errors
2018-05-27 17:30:42 +02:00
}
func (s *Server) resolveSubname(hostname string, q dns.Question) (dns.RR, error) {
if q.Qclass != dns.ClassINET {
return nil, nil
}
if q.Qtype == dns.TypeA ||
q.Qtype == dns.TypeAAAA ||
q.Qtype == dns.TypeMX {
name := strings.TrimSuffix(q.Name, "."+hostname+".")
name = strings.TrimSuffix(name, "."+hostname+"."+s.domain+".")
if lower := strings.ToLower(q.Name); lower == hostname+"." ||
lower == hostname+"."+s.domain+"." {
host, ok := s.hostByName(hostname)
if !ok {
// The corresponding DHCP lease might have expired, but this
// handler is still installed on the mux.
return nil, nil // NXDOMAIN
}
if q.Qtype == dns.TypeA {
return dns.NewRR(q.Name + " 3600 IN A " + host)
}
return nil, errEmpty
}
if ip, ok := s.subname(hostname, name); ok {
if q.Qtype == dns.TypeA && ip.To4() != nil {
return dns.NewRR(q.Name + " 3600 IN A " + ip.String())
}
if q.Qtype == dns.TypeAAAA && ip.To4() == nil {
return dns.NewRR(q.Name + " 3600 IN AAAA " + ip.String())
}
return nil, errEmpty
}
}
return nil, nil
}
func (s *Server) subnameHandler(hostname string) func(w dns.ResponseWriter, r *dns.Msg) {
return func(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) != 1 { // TODO: answer all questions we can answer
return
}
rr, err := s.resolveSubname(hostname, r.Question[0])
if err != nil {
if err == errEmpty {
m := new(dns.Msg)
m.SetReply(r)
w.WriteMsg(m)
return
}
log.Fatalf("question %#v: %v", r.Question[0], err)
}
if rr != nil {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
}
}