241 lines
6.9 KiB
Go
241 lines
6.9 KiB
Go
// Copyright 2025 Joshua Rich <joshua.rich@gmail.com>.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
package net
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jsimonetti/rtnetlink"
|
|
|
|
"github.com/joshuar/go-hass-agent/internal/components/logging"
|
|
"github.com/joshuar/go-hass-agent/internal/components/preferences"
|
|
"github.com/joshuar/go-hass-agent/internal/hass/sensor"
|
|
"github.com/joshuar/go-hass-agent/internal/linux"
|
|
)
|
|
|
|
// netStatsWorker is the object used for tracking network stats sensors. It
|
|
// holds a netlink connection and a map of links with their stats sensors.
|
|
type netStatsWorker struct {
|
|
statsSensors map[string]map[netStatsType]*netStatsSensor
|
|
nlconn *rtnetlink.Conn
|
|
prefs *WorkerPrefs
|
|
delta time.Duration
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// updateTotals takes the total Rx/Tx bytes and updates the total sensors.
|
|
func (w *netStatsWorker) updateTotals(totalBytesRx, totalBytesTx uint64) {
|
|
stats := &rtnetlink.LinkStats64{
|
|
RXBytes: totalBytesRx,
|
|
TXBytes: totalBytesTx,
|
|
}
|
|
for _, sensorType := range sensorList {
|
|
w.statsSensors[totalsName][sensorType].update(totalsName, sensorType, stats, w.delta)
|
|
}
|
|
}
|
|
|
|
func (w *netStatsWorker) UpdateDelta(delta time.Duration) {
|
|
w.delta = delta
|
|
}
|
|
|
|
func (w *netStatsWorker) Sensors(_ context.Context) ([]sensor.Entity, error) {
|
|
// Get all links.
|
|
links, err := w.getLinks()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error accessing netlink: %w", err)
|
|
}
|
|
|
|
// Get link stats, filtering "uninteresting" links.
|
|
stats := w.getLinkStats(links)
|
|
sensors := make([]sensor.Entity, 0, len(stats)*4+4)
|
|
|
|
w.mu.Lock()
|
|
|
|
// Counters for totals.
|
|
var totalBytesRx, totalBytesTx uint64
|
|
|
|
// For each link, update the link sensors with the new stats.
|
|
for _, link := range stats {
|
|
name := link.name
|
|
stats := link.stats
|
|
totalBytesRx += stats.RXBytes
|
|
totalBytesTx += stats.TXBytes
|
|
|
|
// Skip ignored devices.
|
|
if slices.ContainsFunc(w.prefs.IgnoredDevices, func(e string) bool {
|
|
return strings.HasPrefix(name, e)
|
|
}) {
|
|
continue
|
|
}
|
|
|
|
if _, ok := w.statsSensors[name]; ok { // Existing link/sensors, update.
|
|
for sensorType := range w.statsSensors[name] {
|
|
w.statsSensors[name][sensorType].update(name, sensorType, stats, w.delta)
|
|
}
|
|
} else { // New link, add to tracking map.
|
|
w.statsSensors[name] = generateSensors(name, stats)
|
|
}
|
|
// Create a list of sensors.
|
|
for _, s := range w.statsSensors[name] {
|
|
sensors = append(sensors, s.Entity)
|
|
}
|
|
}
|
|
|
|
if len(stats) > 0 {
|
|
// Update the totals sensors based on the counters.
|
|
w.updateTotals(totalBytesRx, totalBytesTx)
|
|
// Append the total sensors to the list of sensors.
|
|
for _, s := range w.statsSensors[totalsName] {
|
|
sensors = append(sensors, s.Entity)
|
|
}
|
|
}
|
|
|
|
w.mu.Unlock()
|
|
|
|
return sensors, nil
|
|
}
|
|
|
|
func (w *netStatsWorker) PreferencesID() string {
|
|
return preferencesID
|
|
}
|
|
|
|
func (w *netStatsWorker) DefaultPreferences() WorkerPrefs {
|
|
prefs := WorkerPrefs{
|
|
UpdateInterval: rateInterval.String(),
|
|
IgnoredDevices: defaultIgnoredDevices,
|
|
}
|
|
|
|
return prefs
|
|
}
|
|
|
|
// getLinks returns all available links on this device. If a problem occurred, a
|
|
// non-nil error is returned.
|
|
func (w *netStatsWorker) getLinks() ([]rtnetlink.LinkMessage, error) {
|
|
links, err := w.nlconn.Link.List()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not retrieve list of devices from netlink: %w", err)
|
|
}
|
|
|
|
return links, nil
|
|
}
|
|
|
|
// getLinkStats collates the network stats for all links. It filters
|
|
// links to those with stats, not in an exclusion list and currently active.
|
|
func (w *netStatsWorker) getLinkStats(links []rtnetlink.LinkMessage) []linkStats {
|
|
allLinkStats := make([]linkStats, 0, len(links))
|
|
|
|
for _, msg := range links {
|
|
if msg.Attributes == nil {
|
|
continue
|
|
}
|
|
|
|
// Ignore loopback.
|
|
if msg.Attributes.Name == loopbackDeviceName {
|
|
continue
|
|
}
|
|
|
|
// Ignore devices that are not currently active.
|
|
if *msg.Attributes.Carrier == 0 {
|
|
continue
|
|
}
|
|
|
|
name := msg.Attributes.Name
|
|
stats := msg.Attributes.Stats64
|
|
|
|
if stats32 := msg.Attributes.Stats; stats == nil && stats32 != nil {
|
|
stats = &rtnetlink.LinkStats64{
|
|
RXPackets: uint64(stats32.RXPackets),
|
|
TXPackets: uint64(stats32.TXPackets),
|
|
RXBytes: uint64(stats32.RXBytes),
|
|
TXBytes: uint64(stats32.TXBytes),
|
|
RXErrors: uint64(stats32.RXErrors),
|
|
TXErrors: uint64(stats32.TXErrors),
|
|
RXDropped: uint64(stats32.RXDropped),
|
|
TXDropped: uint64(stats32.TXDropped),
|
|
Multicast: uint64(stats32.Multicast),
|
|
Collisions: uint64(stats32.Collisions),
|
|
RXLengthErrors: uint64(stats32.RXLengthErrors),
|
|
RXOverErrors: uint64(stats32.RXOverErrors),
|
|
RXCRCErrors: uint64(stats32.RXCRCErrors),
|
|
RXFrameErrors: uint64(stats32.RXFrameErrors),
|
|
RXFIFOErrors: uint64(stats32.RXFIFOErrors),
|
|
RXMissedErrors: uint64(stats32.RXMissedErrors),
|
|
TXAbortedErrors: uint64(stats32.TXAbortedErrors),
|
|
TXCarrierErrors: uint64(stats32.TXCarrierErrors),
|
|
TXFIFOErrors: uint64(stats32.TXFIFOErrors),
|
|
TXHeartbeatErrors: uint64(stats32.TXHeartbeatErrors),
|
|
TXWindowErrors: uint64(stats32.TXWindowErrors),
|
|
RXCompressed: uint64(stats32.RXCompressed),
|
|
TXCompressed: uint64(stats32.TXCompressed),
|
|
RXNoHandler: uint64(stats32.RXNoHandler),
|
|
RXOtherhostDropped: 0,
|
|
}
|
|
}
|
|
|
|
if stats != nil {
|
|
allLinkStats = append(allLinkStats, linkStats{
|
|
name: name,
|
|
stats: stats,
|
|
})
|
|
}
|
|
}
|
|
|
|
return allLinkStats
|
|
}
|
|
|
|
// NewNetStatsWorker sets up a sensor worker that tracks network stats.
|
|
func NewNetStatsWorker(ctx context.Context) (*linux.PollingSensorWorker, error) {
|
|
conn, err := rtnetlink.Dial(nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not connect to netlink: %w", err)
|
|
}
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
|
|
if err = conn.Close(); err != nil {
|
|
logging.FromContext(ctx).Debug("Could not close netlink connection.",
|
|
slog.String("worker", netRatesWorkerID),
|
|
slog.Any("error", err))
|
|
}
|
|
}()
|
|
|
|
ratesWorker := &netStatsWorker{
|
|
statsSensors: make(map[string]map[netStatsType]*netStatsSensor),
|
|
nlconn: conn,
|
|
}
|
|
ratesWorker.statsSensors[totalsName] = generateSensors(totalsName, nil)
|
|
|
|
ratesWorker.prefs, err = preferences.LoadWorker(ctx, ratesWorker)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not load preferences: %w", err)
|
|
}
|
|
|
|
//nolint:nilnil
|
|
if ratesWorker.prefs.IsDisabled() {
|
|
return nil, nil
|
|
}
|
|
|
|
pollInterval, err := time.ParseDuration(ratesWorker.prefs.UpdateInterval)
|
|
if err != nil {
|
|
logging.FromContext(ctx).Warn("Invalid polling interval, using default",
|
|
slog.String("worker", netRatesWorkerID),
|
|
slog.String("given_interval", ratesWorker.prefs.UpdateInterval),
|
|
slog.String("default_interval", rateInterval.String()))
|
|
|
|
pollInterval = rateInterval
|
|
}
|
|
|
|
worker := linux.NewPollingSensorWorker(netRatesWorkerID, pollInterval, rateJitter)
|
|
worker.PollingSensorType = ratesWorker
|
|
|
|
return worker, nil
|
|
}
|