237 lines
5.7 KiB
Go
Raw Normal View History

2018-06-25 20:12:51 +02:00
// Package dns implements a DNS forwarder.
2018-05-27 17:30:42 +02:00
package dns
import (
"log"
"net"
"net/http"
"os"
2018-05-27 17:30:42 +02:00
"strings"
"sync"
2018-05-27 17:30:42 +02:00
"time"
"router7/internal/dhcp4d"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/time/rate"
2018-05-27 17:30:42 +02:00
)
type Server struct {
*dns.Server
client *dns.Client
domain string
upstream string
sometimes *rate.Limiter
prom struct {
registry *prometheus.Registry
queries prometheus.Counter
upstream *prometheus.CounterVec
questions prometheus.Histogram
}
mu sync.Mutex
hostname, ip string
hostsByName map[string]string
hostsByIP map[string]string
2018-05-27 17:30:42 +02:00
}
func NewServer(addr, domain string) *Server {
hostname, _ := os.Hostname()
ip, _, _ := net.SplitHostPort(addr)
2018-05-27 17:30:42 +02:00
server := &Server{
Server: &dns.Server{Addr: addr, Net: "udp"},
client: &dns.Client{},
domain: domain,
upstream: "8.8.8.8:53",
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostname: hostname,
ip: ip,
2018-05-27 17:30:42 +02:00
}
server.prom.registry = prometheus.NewRegistry()
server.prom.queries = prometheus.NewCounter(prometheus.CounterOpts{
Name: "dns_queries",
Help: "Number of DNS queries received",
})
server.prom.registry.MustRegister(server.prom.queries)
server.prom.upstream = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "dns_upstream",
Help: "Which upstream answered which DNS query",
},
[]string{"upstream"},
)
server.prom.registry.MustRegister(server.prom.upstream)
server.prom.questions = prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "dns_questions",
Help: "Number of questions in each DNS request",
Buckets: prometheus.LinearBuckets(0, 1, 10),
})
server.prom.registry.MustRegister(server.prom.questions)
server.prom.registry.MustRegister(prometheus.NewGoCollector())
server.initHostsLocked()
2018-05-27 17:30:42 +02:00
dns.HandleFunc(".", server.handleRequest)
return server
}
// initHostsLocked replaces both host tables with fresh maps and, when the
// server's own hostname and IP are known, seeds them with the server itself.
//
// It does not take s.mu; SetLeases calls it with the mutex held.
func (s *Server) initHostsLocked() {
	s.hostsByName = map[string]string{}
	s.hostsByIP = map[string]string{}

	if s.hostname == "" || s.ip == "" {
		return
	}
	s.hostsByName[s.hostname] = s.ip

	// Without a valid reverse name only the forward entry is kept.
	reverseName, err := dns.ReverseAddr(s.ip)
	if err != nil {
		return
	}
	s.hostsByIP[reverseName] = s.hostname
}
// hostByName looks up n in the name table while holding s.mu. It returns the
// stored address and whether an entry exists.
func (s *Server) hostByName(n string) (addr string, found bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	addr, found = s.hostsByName[n]
	return addr, found
}
// hostByIP looks up n (a reverse lookup name, the key format used by the
// hostsByIP table) while holding s.mu. It returns the stored host name and
// whether an entry exists.
func (s *Server) hostByIP(n string) (host string, found bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	host, found = s.hostsByIP[n]
	return host, found
}
// PrometheusHandler returns an http.Handler that serves the metrics of this
// server's own Prometheus registry, using default handler options.
func (s *Server) PrometheusHandler() http.Handler {
	var opts promhttp.HandlerOpts // zero value: defaults
	return promhttp.HandlerFor(s.prom.registry, opts)
}
2018-05-27 17:30:42 +02:00
func (s *Server) SetLeases(leases []dhcp4d.Lease) {
s.mu.Lock()
defer s.mu.Unlock()
s.initHostsLocked()
2018-05-27 17:30:42 +02:00
for _, l := range leases {
s.hostsByName[l.Hostname] = l.Addr.String()
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
s.hostsByIP[rev] = l.Hostname
}
}
}
// mustParseCIDR parses s as CIDR notation and returns the network it
// describes. It panics when s cannot be parsed, so it is only meant for
// constant inputs.
func mustParseCIDR(s string) *net.IPNet {
	_, network, parseErr := net.ParseCIDR(s)
	if parseErr == nil {
		return network
	}
	panic(parseErr)
}
// localNets lists the networks whose reverse lookups are answered locally.
var localNets = func() []*net.IPNet {
	cidrs := []string{
		// loopback: https://tools.ietf.org/html/rfc3330#section-2
		"127.0.0.0/8",
		// loopback: https://tools.ietf.org/html/rfc3513#section-2.4
		"::1/128",
		// reserved: https://tools.ietf.org/html/rfc1918#section-3
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		nets = append(nets, mustParseCIDR(cidr))
	}
	return nets
}()
// reverse reverses the order of the elements of ss in place.
func reverse(ss []string) {
	for lo, hi := 0, len(ss)-1; lo < hi; lo, hi = lo+1, hi-1 {
		ss[lo], ss[hi] = ss[hi], ss[lo]
	}
}
// isLocalInAddrArpa reports whether q is an in-addr.arpa. reverse lookup name
// whose address lies within one of localNets.
//
// The labels in front of the ".in-addr.arpa." suffix are put back into
// address order and parsed as an IP; names that do not carry the suffix or do
// not yield a parsable address are not local.
func isLocalInAddrArpa(q string) bool {
	const suffix = ".in-addr.arpa."
	if !strings.HasSuffix(q, suffix) {
		return false
	}

	labels := strings.Split(strings.TrimSuffix(q, suffix), ".")
	reverse(labels)

	addr := net.ParseIP(strings.Join(labels, "."))
	if addr == nil {
		return false
	}

	for _, network := range localNets {
		if network.Contains(addr) {
			return true
		}
	}
	return false
}
2018-05-27 17:30:42 +02:00
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))
2018-05-27 17:30:42 +02:00
if len(r.Question) == 1 { // TODO: answer all questions we can answer
q := r.Question[0]
if q.Qtype == dns.TypeA && q.Qclass == dns.ClassINET {
name := strings.TrimSuffix(q.Name, ".")
name = strings.TrimSuffix(name, "."+s.domain)
if !strings.Contains(name, ".") {
s.prom.upstream.WithLabelValues("local").Inc()
if host, ok := s.hostByName(name); ok {
2018-05-27 17:30:42 +02:00
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
2018-05-27 17:30:42 +02:00
}
}
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
if isLocalInAddrArpa(q.Name) {
s.prom.upstream.WithLabelValues("local").Inc()
if host, ok := s.hostByIP(q.Name); ok {
2018-05-27 17:30:42 +02:00
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
if err != nil {
log.Fatal(err)
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, rr)
w.WriteMsg(m)
return
}
// Send an authoritative NXDOMAIN for local names:
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(m)
return
2018-05-27 17:30:42 +02:00
}
}
}
in, _, err := s.client.Exchange(r, s.upstream)
if err != nil {
if s.sometimes.Allow() {
log.Printf("resolving %v failed: %v", r.Question, err)
}
return // DNS has no reply for resolving errors
}
w.WriteMsg(in)
s.prom.upstream.WithLabelValues("DNS").Inc()
2018-05-27 17:30:42 +02:00
}