diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 9f07b7c..0fefd29 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -227,10 +227,10 @@ func (s *Server) hostByIP(n string) (string, bool) { return r, ok } -func (s *Server) subname(hostname, host string) (IP, bool) { +func (s *Server) subname(domain, host string) (IP, bool) { s.mu.Lock() defer s.mu.Unlock() - r, ok := s.subnames[lcHostname(strings.ToLower(hostname))][lcHostname(strings.ToLower(host))] + r, ok := s.subnames[lcHostname(strings.ToLower(domain))][lcHostname(strings.ToLower(host))] return r, ok } @@ -544,6 +544,18 @@ func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) { } continue // fall back to next-slower upstream } + if len(in.Answer) > 1 { + if in.Answer[0].Header().Rrtype == dns.TypeCNAME { + for _, rr := range in.Answer { + if rr.Header().Rrtype == dns.TypeA { + if newRR, err := s.resolveSubname("", dns.Question{strings.ToLower(rr.Header().Name), dns.TypeA, dns.ClassINET}); err != nil { + in.Answer[len(in.Answer)-1] = newRR + } + } + } + } + + } w.WriteMsg(in) if idx > 0 { // re-order this upstream to the front of s.upstream. @@ -587,13 +599,13 @@ func (s *Server) promInc(label string, r *dns.Msg) { s.prom.upstream.WithLabelValues(label).Inc() } -func (s *Server) subnameHandler(hostname lcHostname) func(w dns.ResponseWriter, r *dns.Msg) { +func (s *Server) subnameHandler(domain lcHostname) func(w dns.ResponseWriter, r *dns.Msg) { return func(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) != 1 { // TODO: answer all questions we can answer s.promInc("local", r) return } - rr, err := s.resolveSubname(string(hostname), r.Question[0]) + rr, err := s.resolveSubname(string(domain), r.Question[0]) if err != nil { s.promInc("local", r)