nftables: use stateful object counters

This way, we can atomically get and reset them.

fixes https://github.com/rtr7/router7/issues/3
This commit is contained in:
Michael Stapelberg 2018-08-08 23:15:21 +02:00
parent ad779c3665
commit b03596f1c5
3 changed files with 81 additions and 58 deletions

View File

@ -21,11 +21,11 @@ import (
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"github.com/gokrazy/gokrazy"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
@ -44,25 +44,43 @@ var (
func init() {
var c nftables.Conn
for _, metric := range []struct {
name string
labels prometheus.Labels
table *nftables.Table
chain *nftables.Chain
name string
labels prometheus.Labels
obj *nftables.CounterObj
packets, bytes uint64
}{
{
name: "filter_forward",
labels: prometheus.Labels{"family": "ipv4"},
table: &nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"},
chain: &nftables.Chain{Name: "forward"},
obj: &nftables.CounterObj{
Table: &nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"},
Name: "fwded",
},
},
{
name: "filter_forward",
labels: prometheus.Labels{"family": "ipv6"},
table: &nftables.Table{Family: nftables.TableFamilyIPv6, Name: "filter"},
chain: &nftables.Chain{Name: "forward"},
obj: &nftables.CounterObj{
Table: &nftables.Table{Family: nftables.TableFamilyIPv6, Name: "filter"},
Name: "fwded",
},
},
} {
metric := metric // copy
var mu sync.Mutex
updateCounter := func() {
mu.Lock()
defer mu.Unlock()
objs, err := c.GetObjReset(metric.obj)
if err != nil ||
len(objs) != 1 {
return
}
if co, ok := objs[0].(*nftables.CounterObj); ok {
metric.packets += co.Packets
metric.bytes += co.Bytes
}
}
promauto.NewCounterFunc(
prometheus.CounterOpts{
Subsystem: "nftables",
@ -71,16 +89,8 @@ func init() {
ConstLabels: metric.labels,
},
func() float64 {
rules, err := c.GetRule(metric.table, metric.chain)
if err != nil ||
len(rules) != 1 ||
len(rules[0].Exprs) != 1 {
return 0
}
if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok {
return float64(ce.Packets)
}
return 0
updateCounter()
return float64(metric.packets)
})
promauto.NewCounterFunc(
prometheus.CounterOpts{
@ -90,16 +100,8 @@ func init() {
ConstLabels: metric.labels,
},
func() float64 {
rules, err := c.GetRule(metric.table, metric.chain)
if err != nil ||
len(rules) != 1 ||
len(rules[0].Exprs) != 1 {
return 0
}
if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok {
return float64(ce.Bytes)
}
return 0
updateCounter()
return float64(metric.bytes)
})
}
}

View File

@ -27,7 +27,7 @@ import (
"github.com/rtr7/router7/internal/netconfig"
"github.com/google/go-cmp/cmp"
"github.com/google/nftables/expr"
"github.com/google/nftables"
)
const goldenInterfaces = `
@ -128,7 +128,7 @@ func TestNetconfig(t *testing.T) {
t.Fatal(err)
}
netconfig.DefaultCounter = expr.Counter{Packets: 23, Bytes: 42}
netconfig.DefaultCounterObj = &nftables.CounterObj{Packets: 23, Bytes: 42}
if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil {
t.Fatalf("netconfig.Apply: %v", err)
}
@ -136,7 +136,7 @@ func TestNetconfig(t *testing.T) {
// Apply twice to ensure the absence of errors when dealing with
// already-configured interfaces, addresses, routes, … (and ensure
// nftables rules are replaced, not appendend to).
netconfig.DefaultCounter = expr.Counter{Packets: 0, Bytes: 0}
netconfig.DefaultCounterObj = &nftables.CounterObj{Packets: 0, Bytes: 0}
if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil {
t.Fatalf("netconfig.Apply: %v", err)
}
@ -248,15 +248,23 @@ func TestNetconfig(t *testing.T) {
` }`,
`}`,
`table ip filter {`,
` counter fwded {`,
` packets 23 bytes 42`,
` }`,
``,
` chain forward {`,
` type filter hook forward priority 0; policy accept;`,
` counter packets 23 bytes 42`,
` counter name "fwded"`,
` }`,
`}`,
`table ip6 filter {`,
` counter fwded {`,
` packets 23 bytes 42`,
` }`,
``,
` chain forward {`,
` type filter hook forward priority 0; policy accept;`,
` counter packets 23 bytes 42`,
` counter name "fwded"`,
` }`,
`}`,
}

View File

@ -470,39 +470,44 @@ func applyPortForwardings(dir string, c *nftables.Conn, nat *nftables.Table, pre
return nil
}
// DefaultCounter is overridden while testing
var DefaultCounter expr.Counter
// DefaultCounterObj is overridden while testing
var DefaultCounterObj = &nftables.CounterObj{}
func getCounter(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain) expr.Counter {
rules, err := c.GetRule(table, chain)
func getCounterObj(c *nftables.Conn, o *nftables.CounterObj) *nftables.CounterObj {
objs, err := c.GetObj(o)
if err != nil {
return DefaultCounter
o.Bytes = DefaultCounterObj.Bytes
o.Packets = DefaultCounterObj.Packets
return o
}
{
// TODO: remove this workaround once travis has workers with a newer kernel
// than its current Ubuntu trusty kernel (Linux 4.4.0):
var filtered []*nftables.Rule
for _, rule := range rules {
if rule.Table.Name != table.Name ||
rule.Chain.Name != chain.Name {
var filtered []nftables.Obj
for _, obj := range objs {
co, ok := obj.(*nftables.CounterObj)
if !ok {
continue
}
filtered = append(filtered, rule)
if co.Table.Name != o.Table.Name {
continue
}
filtered = append(filtered, obj)
}
rules = filtered
objs = filtered
}
if got, want := len(rules), 1; got != want {
log.Printf("could not carry counter values: unexpected number of rules in table %v, chain %v: got %d, want %d", table.Name, chain.Name, got, want)
return DefaultCounter
if got, want := len(objs), 1; got != want {
log.Printf("could not carry counter values: unexpected number of objects in table %v: got %d, want %d", o.Table.Name, got, want)
o.Bytes = DefaultCounterObj.Bytes
o.Packets = DefaultCounterObj.Packets
return o
}
if got, want := len(rules[0].Exprs), 1; got != want {
log.Printf("could not carry counter values: unexpected number of exprs in rule 0 in table %v, chain %v: got %d, want %d", table.Name, chain.Name, got, want)
return DefaultCounter
if co, ok := objs[0].(*nftables.CounterObj); ok {
return co
}
if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok {
return *ce
}
return DefaultCounter
o.Bytes = DefaultCounterObj.Bytes
o.Packets = DefaultCounterObj.Packets
return o
}
func applyFirewall(dir string) error {
@ -571,14 +576,22 @@ func applyFirewall(dir string) error {
Type: nftables.ChainTypeFilter,
})
counter := getCounter(c, filter, forward)
counterObj := getCounterObj(c, &nftables.CounterObj{
Table: filter,
Name: "fwded",
})
counter := c.AddObj(counterObj).(*nftables.CounterObj)
const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix
c.AddRule(&nftables.Rule{
Table: filter,
Chain: forward,
Exprs: []expr.Any{
// [ counter pkts 23 bytes 42 ]
&counter,
// [ counter name fwded ]
&expr.Objref{
Type: NFT_OBJECT_COUNTER,
Name: counter.Name,
},
},
})
}