diff --git a/gokrazy.go b/gokrazy.go index c06f19f..6170203 100644 --- a/gokrazy.go +++ b/gokrazy.go @@ -17,6 +17,7 @@ import ( "os/exec" "os/signal" "strings" + "syscall" "time" "github.com/gokrazy/gokrazy/internal/iface" @@ -145,6 +146,30 @@ func Supervise(commands []*exec.Cmd) error { return fmt.Errorf("updating listeners: %v", err) } + if nl, err := listenNetlink(); err == nil { + go func() { + for { + msgs, err := nl.ReadMsgs() + if err != nil { + log.Printf("netlink.ReadMsgs: %v", err) + return + } + + for _, m := range msgs { + if m.Header.Type != syscall.RTM_NEWADDR && + m.Header.Type != syscall.RTM_DELADDR { + continue + } + if err := updateListeners("80"); err != nil { + log.Printf("updating listeners: %v", err) + } + } + } + }() + } else { + log.Printf("cannot listen for new IP addresses: %v", err) + } + go func() { c := make(chan os.Signal, 1) signal.Notify(c, unix.SIGHUP) diff --git a/netlink.go b/netlink.go new file mode 100644 index 0000000..f9da69f --- /dev/null +++ b/netlink.go @@ -0,0 +1,49 @@ +package gokrazy + +import ( + "fmt" + "os" + "syscall" +) + +type netlinkListener struct { + fd int + buf []byte +} + +func listenNetlink() (*netlinkListener, error) { + fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_ROUTE) + if err != nil { + return nil, fmt.Errorf("socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE): %v", err) + } + + saddr := &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Groups: (1 << (syscall.RTNLGRP_IPV4_IFADDR - 1)) | + (1 << (syscall.RTNLGRP_IPV6_IFADDR - 1)), + } + + if err := syscall.Bind(fd, saddr); err != nil { + return nil, fmt.Errorf("bind: %v", err) + } + + return &netlinkListener{ + fd: fd, + // use the page size as buffer size, like libnl + buf: make([]byte, os.Getpagesize()), + }, nil +} + +func (l *netlinkListener) ReadMsgs() ([]syscall.NetlinkMessage, error) { + n, err := syscall.Read(l.fd, l.buf) + if err != nil { + return nil, fmt.Errorf("Read: %v", err) + } + + msgs, err := syscall.ParseNetlinkMessage(l.buf[:n]) + if err != nil { + return nil, fmt.Errorf("ParseNetlinkMessage: %v", err) + } + + return msgs, nil +}