122 lines
2.5 KiB
Go
122 lines
2.5 KiB
Go
package radius
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"sort"
|
|
"sync"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"goauthentik.io/internal/config"
|
|
"goauthentik.io/internal/outpost/ak"
|
|
"goauthentik.io/internal/outpost/radius/metrics"
|
|
|
|
"layeh.com/radius"
|
|
)
|
|
|
|
type ProviderInstance struct {
|
|
ClientNetworks []*net.IPNet
|
|
SharedSecret []byte
|
|
MFASupport bool
|
|
|
|
appSlug string
|
|
flowSlug string
|
|
providerId int32
|
|
s *RadiusServer
|
|
log *log.Entry
|
|
}
|
|
|
|
type RadiusServer struct {
|
|
s radius.PacketServer
|
|
log *log.Entry
|
|
ac *ak.APIController
|
|
|
|
providers []*ProviderInstance
|
|
}
|
|
|
|
func NewServer(ac *ak.APIController) *RadiusServer {
|
|
rs := &RadiusServer{
|
|
log: log.WithField("logger", "authentik.outpost.radius"),
|
|
ac: ac,
|
|
providers: []*ProviderInstance{},
|
|
}
|
|
rs.s = radius.PacketServer{
|
|
Handler: rs,
|
|
SecretSource: rs,
|
|
Addr: config.Get().Listen.Radius,
|
|
}
|
|
return rs
|
|
}
|
|
|
|
type cpp struct {
|
|
c *net.IPNet
|
|
p *ProviderInstance
|
|
}
|
|
|
|
func (rs *RadiusServer) RADIUSSecret(ctx context.Context, remoteAddr net.Addr) ([]byte, error) {
|
|
matchedPrefixes := []cpp{}
|
|
|
|
host, _, err := net.SplitHostPort(remoteAddr.String())
|
|
if err != nil {
|
|
rs.log.WithError(err).Warning("Failed to get remote IP")
|
|
return nil, err
|
|
}
|
|
ip := net.ParseIP(host)
|
|
|
|
// Check all networks we have and remember all matching prefixes
|
|
for _, p := range rs.providers {
|
|
for _, cidr := range p.ClientNetworks {
|
|
if cidr.Contains(ip) {
|
|
matchedPrefixes = append(matchedPrefixes, cpp{
|
|
c: cidr,
|
|
p: p,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
if len(matchedPrefixes) < 1 {
|
|
rs.log.WithField("ip", ip.String()).Warning("No matching prefixes found")
|
|
return []byte{}, nil
|
|
}
|
|
// Sort matched cidrs by prefix length
|
|
sort.Slice(matchedPrefixes, func(i, j int) bool {
|
|
_, bi := matchedPrefixes[i].c.Mask.Size()
|
|
_, bj := matchedPrefixes[j].c.Mask.Size()
|
|
return bi < bj
|
|
})
|
|
candidate := matchedPrefixes[0]
|
|
rs.log.WithField("ip", ip.String()).WithField("cidr", candidate.c.String()).Debug("Matched CIDR")
|
|
return candidate.p.SharedSecret, nil
|
|
}
|
|
|
|
func (rs *RadiusServer) Start() error {
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
metrics.RunServer()
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
err := rs.StartRadiusServer()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
return nil
|
|
}
|
|
|
|
func (rs *RadiusServer) Stop() error {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
err := rs.s.Shutdown(ctx)
|
|
cancel()
|
|
return err
|
|
}
|
|
|
|
func (rs *RadiusServer) TimerFlowCacheExpiry(context.Context) {}
|
|
|
|
func (rs *RadiusServer) Type() string {
|
|
return "radius"
|
|
}
|