dns: resolve own hostname, lock for concurrency

This commit is contained in:
Michael Stapelberg 2018-06-14 20:42:53 +02:00
parent 02c7fa7e0d
commit 93eaab99cb
2 changed files with 73 additions and 16 deletions

View File

@ -3,7 +3,9 @@ package dns
import (
"log"
"net"
"os"
"strings"
"sync"
"time"
"router7/internal/dhcp4d"
@ -15,29 +17,61 @@ import (
type Server struct {
*dns.Server
client *dns.Client
domain string
upstream string
sometimes *rate.Limiter
hostsByName map[string]string
hostsByIP map[string]string
client *dns.Client
domain string
upstream string
sometimes *rate.Limiter
mu sync.Mutex
hostname, ip string
hostsByName map[string]string
hostsByIP map[string]string
}
func NewServer(addr, domain string) *Server {
hostname, _ := os.Hostname()
ip, _, _ := net.SplitHostPort(addr)
server := &Server{
Server: &dns.Server{Addr: addr, Net: "udp"},
client: &dns.Client{},
domain: domain,
upstream: "8.8.8.8:53",
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostsByName: make(map[string]string),
hostsByIP: make(map[string]string),
Server: &dns.Server{Addr: addr, Net: "udp"},
client: &dns.Client{},
domain: domain,
upstream: "8.8.8.8:53",
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
hostname: hostname,
ip: ip,
}
server.initHostsLocked()
dns.HandleFunc(".", server.handleRequest)
return server
}
func (s *Server) initHostsLocked() {
s.hostsByName = make(map[string]string)
s.hostsByIP = make(map[string]string)
if s.hostname != "" && s.ip != "" {
s.hostsByName[s.hostname] = s.ip
s.hostsByIP[s.ip] = s.hostname
}
}
func (s *Server) hostByName(n string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.hostsByName[n]
return r, ok
}
func (s *Server) hostByIP(n string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.hostsByIP[n]
return r, ok
}
func (s *Server) SetLeases(leases []dhcp4d.Lease) {
s.mu.Lock()
defer s.mu.Unlock()
s.initHostsLocked()
for _, l := range leases {
s.hostsByName[l.Hostname] = l.Addr.String()
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
@ -90,7 +124,6 @@ func isLocalInAddrArpa(q string) bool {
return local
}
// TODO: is handleRequest called in more than one goroutine at a time?
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 1 { // TODO: answer all questions we can answer
@ -100,7 +133,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
name = strings.TrimSuffix(name, "."+s.domain)
if !strings.Contains(name, ".") {
if host, ok := s.hostsByName[name]; ok {
if host, ok := s.hostByName(name); ok {
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
if err != nil {
log.Fatal(err)
@ -115,7 +148,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
}
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
if isLocalInAddrArpa(q.Name) {
if host, ok := s.hostsByIP[q.Name]; ok {
if host, ok := s.hostByIP(q.Name); ok {
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
if err != nil {
log.Fatal(err)

View File

@ -3,6 +3,7 @@ package dns
import (
"bytes"
"net"
"os"
"router7/internal/dhcp4d"
"testing"
@ -74,6 +75,29 @@ func TestDHCP(t *testing.T) {
}
}
func TestHostname(t *testing.T) {
hostname, err := os.Hostname()
if err != nil {
t.Skipf("os.Hostname: %v", err)
}
r := &recorder{}
s := NewServer("127.0.0.2:0", "lan")
m := new(dns.Msg)
m.SetQuestion(hostname+".", dns.TypeA)
s.handleRequest(r, m)
if got, want := len(r.response.Answer), 1; got != want {
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
}
a := r.response.Answer[0]
if _, ok := a.(*dns.A); !ok {
t.Fatalf("unexpected response type: got %T, want dns.A", a)
}
if got, want := a.(*dns.A).A.To4(), (net.IP{127, 0, 0, 2}); !bytes.Equal(got, want) {
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
}
}
func TestDHCPReverse(t *testing.T) {
for _, test := range []struct {
ip net.IP