dns: listen on all private IP addresses

This commit is contained in:
Michael Stapelberg 2018-06-26 08:52:04 +02:00
parent 08249aec6a
commit 10df129c1f
5 changed files with 123 additions and 18 deletions

View File

@ -6,12 +6,14 @@ import (
"flag"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/gokrazy/gokrazy"
miekgdns "github.com/miekg/dns"
"router7/internal/dhcp4d"
"router7/internal/dns"
@ -21,18 +23,33 @@ import (
_ "net/http/pprof"
)
var (
httpListeners = multilisten.NewPool()
dnsListeners = multilisten.NewPool()
)
func updateListeners() error {
hosts, err := gokrazy.PrivateInterfaceAddrs()
if err != nil {
return err
}
if net1, err := multilisten.IPv6Net1("/perm"); err == nil {
hosts = append(hosts, net1)
}
return multilisten.ListenAndServe(hosts, "8053", http.DefaultServeMux)
httpListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
return &http.Server{Addr: net.JoinHostPort(host, "8053")}
})
dnsListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
return &listenerAdapter{&miekgdns.Server{Addr: net.JoinHostPort(host, "53"), Net: "udp"}}
})
return nil
}
type listenerAdapter struct {
*miekgdns.Server
}
func (a *listenerAdapter) Close() error { return a.Shutdown() }
func logic() error {
// TODO: set correct upstream DNS resolver(s)
ip, err := netconfig.LinkAddress("/perm", "lan0")
@ -59,17 +76,15 @@ func logic() error {
updateListeners()
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGUSR1)
go func() {
for range ch {
if err := updateListeners(); err != nil {
log.Printf("updateListeners: %v", err)
}
if err := readLeases(); err != nil {
log.Printf("readLeases: %v", err)
}
for range ch {
if err := updateListeners(); err != nil {
log.Printf("updateListeners: %v", err)
}
}()
return srv.ListenAndServe()
if err := readLeases(); err != nil {
log.Printf("readLeases: %v", err)
}
}
return nil
}
func main() {

View File

@ -3,6 +3,7 @@ package main
import (
"flag"
"net"
"net/http"
"os"
"os/signal"
@ -89,6 +90,8 @@ func init() {
}
}
var httpListeners = multilisten.NewPool()
func updateListeners() error {
hosts, err := gokrazy.PrivateInterfaceAddrs()
if err != nil {
@ -98,7 +101,10 @@ func updateListeners() error {
hosts = append(hosts, net1)
}
return multilisten.ListenAndServe(hosts, "8066", http.DefaultServeMux)
httpListeners.ListenAndServe(hosts, func(host string) multilisten.Listener {
return &http.Server{Addr: net.JoinHostPort(host, "8066")}
})
return nil
}
func logic() error {

View File

@ -8,10 +8,14 @@ import (
"testing"
"router7/internal/dns"
miekgdns "github.com/miekg/dns"
)
func TestDNS(t *testing.T) {
go dns.NewServer("localhost:4453", "lan").ListenAndServe()
dns.NewServer("localhost:4453", "lan")
s := &miekgdns.Server{Addr: "localhost:4453", Net: "udp"}
go s.ListenAndServe()
const port = 4453
dig := exec.Command("dig", "-p", strconv.Itoa(port), "+timeout=1", "+short", "-x", "8.8.8.8", "@127.0.0.1")
dig.Stderr = os.Stderr

View File

@ -19,7 +19,6 @@ import (
)
type Server struct {
*dns.Server
client *dns.Client
domain string
upstream string
@ -41,7 +40,6 @@ func NewServer(addr, domain string) *Server {
hostname, _ := os.Hostname()
ip, _, _ := net.SplitHostPort(addr)
server := &Server{
Server: &dns.Server{Addr: addr, Net: "udp"},
client: &dns.Client{},
domain: domain,
upstream: "8.8.8.8:53",
@ -170,6 +168,7 @@ func isLocalInAddrArpa(q string) bool {
}
// TODO: require search domains to be present, then use HandleFunc("lan.", internalName)
// TODO: add test for non-A records on internal names, they should not go upstream
func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
s.prom.queries.Inc()
s.prom.questions.Observe(float64(len(r.Question)))

View File

@ -0,0 +1,81 @@
// Package multilisten implements listening on multiple addresses at once.
package multilisten
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"path/filepath"
"router7/internal/dhcp6"
"sync"
)
type Listener interface {
ListenAndServe() error
Close() error
}
type Pool struct {
mu sync.Mutex
listeners map[string]Listener
}
func NewPool() *Pool {
return &Pool{
listeners: make(map[string]Listener),
}
}
func (p *Pool) ListenAndServe(hosts []string, listenerFor func(host string) Listener) {
p.mu.Lock()
defer p.mu.Unlock()
vanished := make(map[string]bool)
for host := range p.listeners {
vanished[host] = false
}
for _, host := range hosts {
if _, ok := p.listeners[host]; ok {
// confirm found
delete(vanished, host)
} else {
log.Printf("now listening on %s", host)
// add a new listener
ln := listenerFor(host)
p.listeners[host] = ln
go func(host string, ln Listener) {
err := ln.ListenAndServe()
log.Printf("listener for %q died: %v", host, err)
p.mu.Lock()
defer p.mu.Unlock()
delete(p.listeners, host)
}(host, ln)
}
}
for host := range vanished {
log.Printf("no longer listening on %s", host)
p.listeners[host].Close()
delete(p.listeners, host)
}
}
// IPv6Net1 returns the IP address which router7 picks from the IPv6 prefix for
// itself, e.g. address 2a02:168:4a00::1 for prefix 2a02:168:4a00::/48.
func IPv6Net1(dir string) (string, error) {
b, err := ioutil.ReadFile(filepath.Join(dir, "dhcp6/wire/lease.json"))
if err != nil {
return "", err
}
var got dhcp6.Config
if err := json.Unmarshal(b, &got); err != nil {
return "", err
}
for _, prefix := range got.Prefixes {
// pick the first address of the prefix, e.g. address 2a02:168:4a00::1
// for prefix 2a02:168:4a00::/48
prefix.IP[len(prefix.IP)-1] = 1
return prefix.IP.String(), nil
}
return "", fmt.Errorf("no DHCPv6 prefix obtained")
}