dns: resolve own hostname, lock for concurrency
This commit is contained in:
parent
02c7fa7e0d
commit
93eaab99cb
@ -3,7 +3,9 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"router7/internal/dhcp4d"
|
"router7/internal/dhcp4d"
|
||||||
@ -15,29 +17,61 @@ import (
|
|||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
*dns.Server
|
*dns.Server
|
||||||
client *dns.Client
|
client *dns.Client
|
||||||
domain string
|
domain string
|
||||||
upstream string
|
upstream string
|
||||||
sometimes *rate.Limiter
|
sometimes *rate.Limiter
|
||||||
hostsByName map[string]string
|
|
||||||
hostsByIP map[string]string
|
mu sync.Mutex
|
||||||
|
hostname, ip string
|
||||||
|
hostsByName map[string]string
|
||||||
|
hostsByIP map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(addr, domain string) *Server {
|
func NewServer(addr, domain string) *Server {
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
ip, _, _ := net.SplitHostPort(addr)
|
||||||
server := &Server{
|
server := &Server{
|
||||||
Server: &dns.Server{Addr: addr, Net: "udp"},
|
Server: &dns.Server{Addr: addr, Net: "udp"},
|
||||||
client: &dns.Client{},
|
client: &dns.Client{},
|
||||||
domain: domain,
|
domain: domain,
|
||||||
upstream: "8.8.8.8:53",
|
upstream: "8.8.8.8:53",
|
||||||
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
|
sometimes: rate.NewLimiter(rate.Every(1*time.Second), 1), // at most once per second
|
||||||
hostsByName: make(map[string]string),
|
hostname: hostname,
|
||||||
hostsByIP: make(map[string]string),
|
ip: ip,
|
||||||
}
|
}
|
||||||
|
server.initHostsLocked()
|
||||||
dns.HandleFunc(".", server.handleRequest)
|
dns.HandleFunc(".", server.handleRequest)
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) initHostsLocked() {
|
||||||
|
s.hostsByName = make(map[string]string)
|
||||||
|
s.hostsByIP = make(map[string]string)
|
||||||
|
if s.hostname != "" && s.ip != "" {
|
||||||
|
s.hostsByName[s.hostname] = s.ip
|
||||||
|
s.hostsByIP[s.ip] = s.hostname
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) hostByName(n string) (string, bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
r, ok := s.hostsByName[n]
|
||||||
|
return r, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) hostByIP(n string) (string, bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
r, ok := s.hostsByIP[n]
|
||||||
|
return r, ok
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) SetLeases(leases []dhcp4d.Lease) {
|
func (s *Server) SetLeases(leases []dhcp4d.Lease) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.initHostsLocked()
|
||||||
for _, l := range leases {
|
for _, l := range leases {
|
||||||
s.hostsByName[l.Hostname] = l.Addr.String()
|
s.hostsByName[l.Hostname] = l.Addr.String()
|
||||||
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
|
if rev, err := dns.ReverseAddr(l.Addr.String()); err == nil {
|
||||||
@ -90,7 +124,6 @@ func isLocalInAddrArpa(q string) bool {
|
|||||||
return local
|
return local
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: is handleRequest called in more than one goroutine at a time?
|
|
||||||
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
|
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
|
||||||
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 1 { // TODO: answer all questions we can answer
|
if len(r.Question) == 1 { // TODO: answer all questions we can answer
|
||||||
@ -100,7 +133,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
name = strings.TrimSuffix(name, "."+s.domain)
|
name = strings.TrimSuffix(name, "."+s.domain)
|
||||||
|
|
||||||
if !strings.Contains(name, ".") {
|
if !strings.Contains(name, ".") {
|
||||||
if host, ok := s.hostsByName[name]; ok {
|
if host, ok := s.hostByName(name); ok {
|
||||||
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
|
rr, err := dns.NewRR(q.Name + " 3600 IN A " + host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
@ -115,7 +148,7 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
|
if q.Qtype == dns.TypePTR && q.Qclass == dns.ClassINET {
|
||||||
if isLocalInAddrArpa(q.Name) {
|
if isLocalInAddrArpa(q.Name) {
|
||||||
if host, ok := s.hostsByIP[q.Name]; ok {
|
if host, ok := s.hostByIP(q.Name); ok {
|
||||||
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
|
rr, err := dns.NewRR(q.Name + " 3600 IN PTR " + host + "." + s.domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"router7/internal/dhcp4d"
|
"router7/internal/dhcp4d"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -74,6 +75,29 @@ func TestDHCP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHostname(t *testing.T) {
|
||||||
|
hostname, err := os.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("os.Hostname: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &recorder{}
|
||||||
|
s := NewServer("127.0.0.2:0", "lan")
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion(hostname+".", dns.TypeA)
|
||||||
|
s.handleRequest(r, m)
|
||||||
|
if got, want := len(r.response.Answer), 1; got != want {
|
||||||
|
t.Fatalf("unexpected number of answers for %v: got %d, want %d", m.Question, got, want)
|
||||||
|
}
|
||||||
|
a := r.response.Answer[0]
|
||||||
|
if _, ok := a.(*dns.A); !ok {
|
||||||
|
t.Fatalf("unexpected response type: got %T, want dns.A", a)
|
||||||
|
}
|
||||||
|
if got, want := a.(*dns.A).A.To4(), (net.IP{127, 0, 0, 2}); !bytes.Equal(got, want) {
|
||||||
|
t.Fatalf("unexpected response IP: got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDHCPReverse(t *testing.T) {
|
func TestDHCPReverse(t *testing.T) {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
ip net.IP
|
ip net.IP
|
||||||
|
Loading…
x
Reference in New Issue
Block a user