nftables: use stateful object counters
This way, we can atomically get and reset them. fixes https://github.com/rtr7/router7/issues/3
This commit is contained in:
parent
ad779c3665
commit
b03596f1c5
@ -21,11 +21,11 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/gokrazy/gokrazy"
|
"github.com/gokrazy/gokrazy"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
@ -46,23 +46,41 @@ func init() {
|
|||||||
for _, metric := range []struct {
|
for _, metric := range []struct {
|
||||||
name string
|
name string
|
||||||
labels prometheus.Labels
|
labels prometheus.Labels
|
||||||
table *nftables.Table
|
obj *nftables.CounterObj
|
||||||
chain *nftables.Chain
|
packets, bytes uint64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "filter_forward",
|
name: "filter_forward",
|
||||||
labels: prometheus.Labels{"family": "ipv4"},
|
labels: prometheus.Labels{"family": "ipv4"},
|
||||||
table: &nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"},
|
obj: &nftables.CounterObj{
|
||||||
chain: &nftables.Chain{Name: "forward"},
|
Table: &nftables.Table{Family: nftables.TableFamilyIPv4, Name: "filter"},
|
||||||
|
Name: "fwded",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "filter_forward",
|
name: "filter_forward",
|
||||||
labels: prometheus.Labels{"family": "ipv6"},
|
labels: prometheus.Labels{"family": "ipv6"},
|
||||||
table: &nftables.Table{Family: nftables.TableFamilyIPv6, Name: "filter"},
|
obj: &nftables.CounterObj{
|
||||||
chain: &nftables.Chain{Name: "forward"},
|
Table: &nftables.Table{Family: nftables.TableFamilyIPv6, Name: "filter"},
|
||||||
|
Name: "fwded",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
metric := metric // copy
|
metric := metric // copy
|
||||||
|
var mu sync.Mutex
|
||||||
|
updateCounter := func() {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
objs, err := c.GetObjReset(metric.obj)
|
||||||
|
if err != nil ||
|
||||||
|
len(objs) != 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if co, ok := objs[0].(*nftables.CounterObj); ok {
|
||||||
|
metric.packets += co.Packets
|
||||||
|
metric.bytes += co.Bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
promauto.NewCounterFunc(
|
promauto.NewCounterFunc(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Subsystem: "nftables",
|
Subsystem: "nftables",
|
||||||
@ -71,16 +89,8 @@ func init() {
|
|||||||
ConstLabels: metric.labels,
|
ConstLabels: metric.labels,
|
||||||
},
|
},
|
||||||
func() float64 {
|
func() float64 {
|
||||||
rules, err := c.GetRule(metric.table, metric.chain)
|
updateCounter()
|
||||||
if err != nil ||
|
return float64(metric.packets)
|
||||||
len(rules) != 1 ||
|
|
||||||
len(rules[0].Exprs) != 1 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok {
|
|
||||||
return float64(ce.Packets)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
})
|
})
|
||||||
promauto.NewCounterFunc(
|
promauto.NewCounterFunc(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
@ -90,16 +100,8 @@ func init() {
|
|||||||
ConstLabels: metric.labels,
|
ConstLabels: metric.labels,
|
||||||
},
|
},
|
||||||
func() float64 {
|
func() float64 {
|
||||||
rules, err := c.GetRule(metric.table, metric.chain)
|
updateCounter()
|
||||||
if err != nil ||
|
return float64(metric.bytes)
|
||||||
len(rules) != 1 ||
|
|
||||||
len(rules[0].Exprs) != 1 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok {
|
|
||||||
return float64(ce.Bytes)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ import (
|
|||||||
"github.com/rtr7/router7/internal/netconfig"
|
"github.com/rtr7/router7/internal/netconfig"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const goldenInterfaces = `
|
const goldenInterfaces = `
|
||||||
@ -128,7 +128,7 @@ func TestNetconfig(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
netconfig.DefaultCounter = expr.Counter{Packets: 23, Bytes: 42}
|
netconfig.DefaultCounterObj = &nftables.CounterObj{Packets: 23, Bytes: 42}
|
||||||
if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil {
|
if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil {
|
||||||
t.Fatalf("netconfig.Apply: %v", err)
|
t.Fatalf("netconfig.Apply: %v", err)
|
||||||
}
|
}
|
||||||
@ -136,7 +136,7 @@ func TestNetconfig(t *testing.T) {
|
|||||||
// Apply twice to ensure the absence of errors when dealing with
|
// Apply twice to ensure the absence of errors when dealing with
|
||||||
// already-configured interfaces, addresses, routes, … (and ensure
|
// already-configured interfaces, addresses, routes, … (and ensure
|
||||||
// nftables rules are replaced, not appendend to).
|
// nftables rules are replaced, not appendend to).
|
||||||
netconfig.DefaultCounter = expr.Counter{Packets: 0, Bytes: 0}
|
netconfig.DefaultCounterObj = &nftables.CounterObj{Packets: 0, Bytes: 0}
|
||||||
if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil {
|
if err := netconfig.Apply(tmp, filepath.Join(tmp, "root")); err != nil {
|
||||||
t.Fatalf("netconfig.Apply: %v", err)
|
t.Fatalf("netconfig.Apply: %v", err)
|
||||||
}
|
}
|
||||||
@ -248,15 +248,23 @@ func TestNetconfig(t *testing.T) {
|
|||||||
` }`,
|
` }`,
|
||||||
`}`,
|
`}`,
|
||||||
`table ip filter {`,
|
`table ip filter {`,
|
||||||
|
` counter fwded {`,
|
||||||
|
` packets 23 bytes 42`,
|
||||||
|
` }`,
|
||||||
|
``,
|
||||||
` chain forward {`,
|
` chain forward {`,
|
||||||
` type filter hook forward priority 0; policy accept;`,
|
` type filter hook forward priority 0; policy accept;`,
|
||||||
` counter packets 23 bytes 42`,
|
` counter name "fwded"`,
|
||||||
` }`,
|
` }`,
|
||||||
`}`,
|
`}`,
|
||||||
`table ip6 filter {`,
|
`table ip6 filter {`,
|
||||||
|
` counter fwded {`,
|
||||||
|
` packets 23 bytes 42`,
|
||||||
|
` }`,
|
||||||
|
``,
|
||||||
` chain forward {`,
|
` chain forward {`,
|
||||||
` type filter hook forward priority 0; policy accept;`,
|
` type filter hook forward priority 0; policy accept;`,
|
||||||
` counter packets 23 bytes 42`,
|
` counter name "fwded"`,
|
||||||
` }`,
|
` }`,
|
||||||
`}`,
|
`}`,
|
||||||
}
|
}
|
||||||
|
@ -470,39 +470,44 @@ func applyPortForwardings(dir string, c *nftables.Conn, nat *nftables.Table, pre
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultCounter is overridden while testing
|
// DefaultCounterObj is overridden while testing
|
||||||
var DefaultCounter expr.Counter
|
var DefaultCounterObj = &nftables.CounterObj{}
|
||||||
|
|
||||||
func getCounter(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain) expr.Counter {
|
func getCounterObj(c *nftables.Conn, o *nftables.CounterObj) *nftables.CounterObj {
|
||||||
rules, err := c.GetRule(table, chain)
|
objs, err := c.GetObj(o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DefaultCounter
|
o.Bytes = DefaultCounterObj.Bytes
|
||||||
|
o.Packets = DefaultCounterObj.Packets
|
||||||
|
return o
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// TODO: remove this workaround once travis has workers with a newer kernel
|
// TODO: remove this workaround once travis has workers with a newer kernel
|
||||||
// than its current Ubuntu trusty kernel (Linux 4.4.0):
|
// than its current Ubuntu trusty kernel (Linux 4.4.0):
|
||||||
var filtered []*nftables.Rule
|
var filtered []nftables.Obj
|
||||||
for _, rule := range rules {
|
for _, obj := range objs {
|
||||||
if rule.Table.Name != table.Name ||
|
co, ok := obj.(*nftables.CounterObj)
|
||||||
rule.Chain.Name != chain.Name {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
filtered = append(filtered, rule)
|
if co.Table.Name != o.Table.Name {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
rules = filtered
|
filtered = append(filtered, obj)
|
||||||
}
|
}
|
||||||
if got, want := len(rules), 1; got != want {
|
objs = filtered
|
||||||
log.Printf("could not carry counter values: unexpected number of rules in table %v, chain %v: got %d, want %d", table.Name, chain.Name, got, want)
|
|
||||||
return DefaultCounter
|
|
||||||
}
|
}
|
||||||
if got, want := len(rules[0].Exprs), 1; got != want {
|
if got, want := len(objs), 1; got != want {
|
||||||
log.Printf("could not carry counter values: unexpected number of exprs in rule 0 in table %v, chain %v: got %d, want %d", table.Name, chain.Name, got, want)
|
log.Printf("could not carry counter values: unexpected number of objects in table %v: got %d, want %d", o.Table.Name, got, want)
|
||||||
return DefaultCounter
|
o.Bytes = DefaultCounterObj.Bytes
|
||||||
|
o.Packets = DefaultCounterObj.Packets
|
||||||
|
return o
|
||||||
}
|
}
|
||||||
if ce, ok := rules[0].Exprs[0].(*expr.Counter); ok {
|
if co, ok := objs[0].(*nftables.CounterObj); ok {
|
||||||
return *ce
|
return co
|
||||||
}
|
}
|
||||||
return DefaultCounter
|
o.Bytes = DefaultCounterObj.Bytes
|
||||||
|
o.Packets = DefaultCounterObj.Packets
|
||||||
|
return o
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyFirewall(dir string) error {
|
func applyFirewall(dir string) error {
|
||||||
@ -571,14 +576,22 @@ func applyFirewall(dir string) error {
|
|||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
})
|
})
|
||||||
|
|
||||||
counter := getCounter(c, filter, forward)
|
counterObj := getCounterObj(c, &nftables.CounterObj{
|
||||||
|
Table: filter,
|
||||||
|
Name: "fwded",
|
||||||
|
})
|
||||||
|
counter := c.AddObj(counterObj).(*nftables.CounterObj)
|
||||||
|
|
||||||
|
const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix
|
||||||
c.AddRule(&nftables.Rule{
|
c.AddRule(&nftables.Rule{
|
||||||
Table: filter,
|
Table: filter,
|
||||||
Chain: forward,
|
Chain: forward,
|
||||||
Exprs: []expr.Any{
|
Exprs: []expr.Any{
|
||||||
// [ counter pkts 23 bytes 42 ]
|
// [ counter name fwded ]
|
||||||
&counter,
|
&expr.Objref{
|
||||||
|
Type: NFT_OBJECT_COUNTER,
|
||||||
|
Name: counter.Name,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user