diff --git a/cmd/netconfigd/netconfigd.go b/cmd/netconfigd/netconfigd.go index f336231..74fb2ad 100644 --- a/cmd/netconfigd/netconfigd.go +++ b/cmd/netconfigd/netconfigd.go @@ -21,11 +21,11 @@ import ( "net/http" "os" "os/signal" + "sync" "syscall" "github.com/gokrazy/gokrazy" "github.com/google/nftables" - "github.com/google/nftables/expr" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -44,25 +44,43 @@ var ( func init() { var c nftables.Conn for _, metric := range []struct { - name string - labels prometheus.Labels - table *nftables.Table - chain *nftables.Chain + name string + labels prometheus.Labels + obj *nftables.CounterObj + packets, bytes uint64 }{ { name: "filter_forward", labels: prometheus.Labels{"family": "ipv4"}, - table: &nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"}, - chain: &nftables.Chain{Name: "forward"}, + obj: &nftables.CounterObj{ + Table: &nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"}, + Name: "fwded", + }, }, { name: "filter_forward", labels: prometheus.Labels{"family": "ipv6"}, - table: &nftables.Table{Family: nftables.TableFamilyIPv6, Name: "filter"}, - chain: &nftables.Chain{Name: "forward"}, + obj: &nftables.CounterObj{ + Table: &nftables.Table{Family: nftables.TableFamilyIPv6, Name: "filter"}, + Name: "fwded", + }, }, } { metric := metric // copy + var mu sync.Mutex + updateCounter := func() { + mu.Lock() + defer mu.Unlock() + objs, err := c.GetObjReset(metric.obj) + if err != nil || + len(objs) != 1 { + return + } + if co, ok := objs[0].(*nftables.CounterObj); ok { + metric.packets += co.Packets + metric.bytes += co.Bytes + } + } promauto.NewCounterFunc( prometheus.CounterOpts{ Subsystem: "nftables", @@ -71,16 +89,8 @@ func init() { ConstLabels: metric.labels, }, func() float64 { - rules, err := c.GetRule(metric.table, metric.chain) - if err != nil || - len(rules) != 1 || - len(rules[0].Exprs) != 1 { - return 0 - } - if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok { - return float64(ce.Packets) - } - return 0 + updateCounter() + return float64(metric.packets) }) promauto.NewCounterFunc( prometheus.CounterOpts{ @@ -90,16 +100,8 @@ func init() { ConstLabels: metric.labels, }, func() float64 { - rules, err := c.GetRule(metric.table, metric.chain) - if err != nil || - len(rules) != 1 || - len(rules[0].Exprs) != 1 { - return 0 - } - if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok { - return float64(ce.Bytes) - } - return 0 + updateCounter() + return float64(metric.bytes) }) } } diff --git a/integration/netconfig/netconfig_test.go b/integration/netconfig/netconfig_test.go index 973886d..24ffe91 100644 --- a/integration/netconfig/netconfig_test.go +++ b/integration/netconfig/netconfig_test.go @@ -27,7 +27,7 @@ import ( "github.com/rtr7/router7/internal/netconfig" "github.com/google/go-cmp/cmp" - "github.com/google/nftables/expr" + "github.com/google/nftables" ) const goldenInterfaces = ` @@ -128,7 +128,7 @@ func TestNetconfig(t *testing.T) { t.Fatal(err) } - netconfig.DefaultCounter = expr.Counter{Packets: 23, Bytes: 42} + netconfig.DefaultCounterObj = &nftables.CounterObj{Packets: 23, Bytes: 42} if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil { t.Fatalf("netconfig.Apply: %v", err) } @@ -136,7 +136,7 @@ func TestNetconfig(t *testing.T) { // Apply twice to ensure the absence of errors when dealing with // already-configured interfaces, addresses, routes, … (and ensure // nftables rules are replaced, not appendend to). - netconfig.DefaultCounter = expr.Counter{Packets: 0, Bytes: 0} + netconfig.DefaultCounterObj = &nftables.CounterObj{Packets: 0, Bytes: 0} if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil { t.Fatalf("netconfig.Apply: %v", err) } @@ -248,15 +248,23 @@ func TestNetconfig(t *testing.T) { ` }`, `}`, `table ip filter {`, + ` counter fwded {`, + ` packets 23 bytes 42`, + ` }`, + ``, ` chain forward {`, ` type filter hook forward priority 0; policy accept;`, - ` counter packets 23 bytes 42`, + ` counter name "fwded"`, ` }`, `}`, `table ip6 filter {`, + ` counter fwded {`, + ` packets 23 bytes 42`, + ` }`, + ``, ` chain forward {`, ` type filter hook forward priority 0; policy accept;`, - ` counter packets 23 bytes 42`, + ` counter name "fwded"`, ` }`, `}`, } diff --git a/internal/netconfig/netconfig.go b/internal/netconfig/netconfig.go index 01b653d..1e28553 100644 --- a/internal/netconfig/netconfig.go +++ b/internal/netconfig/netconfig.go @@ -470,39 +470,44 @@ func applyPortForwardings(dir string, c *nftables.Conn, nat *nftables.Table, pre return nil } -// DefaultCounter is overridden while testing -var DefaultCounter expr.Counter +// DefaultCounterObj is overridden while testing +var DefaultCounterObj = &nftables.CounterObj{} -func getCounter(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain) expr.Counter { - rules, err := c.GetRule(table, chain) +func getCounterObj(c *nftables.Conn, o *nftables.CounterObj) *nftables.CounterObj { + objs, err := c.GetObj(o) if err != nil { - return DefaultCounter + o.Bytes = DefaultCounterObj.Bytes + o.Packets = DefaultCounterObj.Packets + return o } { // TODO: remove this workaround once travis has workers with a newer kernel // than its current Ubuntu trusty kernel (Linux 4.4.0): - var filtered []*nftables.Rule - for _, rule := range rules { - if rule.Table.Name != table.Name || - rule.Chain.Name != chain.Name { + var filtered []nftables.Obj + for _, obj := range objs { + co, ok := obj.(*nftables.CounterObj) + if !ok { continue } - filtered = append(filtered, rule) + if co.Table.Name != o.Table.Name { + continue + } + filtered = append(filtered, obj) } - rules = filtered + objs = filtered } - if got, want := len(rules), 1; got != want { - log.Printf("could not carry counter values: unexpected number of rules in table %v, chain %v: got %d, want %d", table.Name, chain.Name, got, want) - return DefaultCounter + if got, want := len(objs), 1; got != want { + log.Printf("could not carry counter values: unexpected number of objects in table %v: got %d, want %d", o.Table.Name, got, want) + o.Bytes = DefaultCounterObj.Bytes + o.Packets = DefaultCounterObj.Packets + return o } - if got, want := len(rules[0].Exprs), 1; got != want { - log.Printf("could not carry counter values: unexpected number of exprs in rule 0 in table %v, chain %v: got %d, want %d", table.Name, chain.Name, got, want) - return DefaultCounter + if co, ok := objs[0].(*nftables.CounterObj); ok { + return co } - if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok { - return *ce - } - return DefaultCounter + o.Bytes = DefaultCounterObj.Bytes + o.Packets = DefaultCounterObj.Packets + return o } func applyFirewall(dir string) error { @@ -571,14 +576,22 @@ func applyFirewall(dir string) error { Type: nftables.ChainTypeFilter, }) - counter := getCounter(c, filter, forward) + counterObj := getCounterObj(c, &nftables.CounterObj{ + Table: filter, + Name: "fwded", + }) + counter := c.AddObj(counterObj).(*nftables.CounterObj) + const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix c.AddRule(&nftables.Rule{ Table: filter, Chain: forward, Exprs: []expr.Any{ - // [ counter pkts 23 bytes 42 ] - &counter, + // [ counter name fwded ] + &expr.Objref{ + Type: NFT_OBJECT_COUNTER, + Name: counter.Name, + }, }, }) }