126 lines
4.4 KiB
Go
126 lines
4.4 KiB
Go
package application
|
|
|
|
import (
|
|
"encoding/base32"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/gorilla/securecookie"
|
|
"github.com/mitchellh/mapstructure"
|
|
"goauthentik.io/api/v3"
|
|
)
|
|
|
|
// OAuthState is the claim set carried in the signed OAuth `state` parameter.
// The json tags give the claim names used when the state is encoded as a JWT;
// the mapstructure tags let the same names be decoded from a generic claims map.
type OAuthState struct {
	// Issuer names the outpost that created the state.
	Issuer string `json:"iss" mapstructure:"iss"`
	// SessionID is the ID of the session the state was created for.
	SessionID string `json:"sid" mapstructure:"sid"`
	// State is a random nonce.
	State string `json:"state" mapstructure:"state"`
	// Redirect is the forward target recorded when the state was created.
	Redirect string `json:"redirect" mapstructure:"redirect"`
}
|
|
|
|
func (oas *OAuthState) GetExpirationTime() (*jwt.NumericDate, error) { return nil, nil }
|
|
func (oas *OAuthState) GetIssuedAt() (*jwt.NumericDate, error) { return nil, nil }
|
|
func (oas *OAuthState) GetNotBefore() (*jwt.NumericDate, error) { return nil, nil }
|
|
func (oas *OAuthState) GetIssuer() (string, error) { return oas.Issuer, nil }
|
|
func (oas *OAuthState) GetSubject() (string, error) { return oas.State, nil }
|
|
func (oas *OAuthState) GetAudience() (jwt.ClaimStrings, error) { return nil, nil }
|
|
|
|
var (
	// base32RawStdEncoding is the standard RFC 4648 base32 alphabet with the
	// trailing '=' padding disabled. In this file it encodes generated
	// session IDs.
	base32RawStdEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)
)
|
|
|
|
// Validate that the given redirect parameter (?rd=...) is valid and can be used
|
|
// For proxy/forward_single this checks that if the `rd` param has a Hostname (and is a full URL)
|
|
// the hostname matches what's configured, or no hostname must be given
|
|
// For forward_domain this checks if the domain of the URL in `rd` ends with the configured domain
|
|
func (a *Application) checkRedirectParam(r *http.Request) (string, bool) {
|
|
rd := r.URL.Query().Get(redirectParam)
|
|
if rd == "" {
|
|
return "", false
|
|
}
|
|
u, err := url.Parse(rd)
|
|
if err != nil {
|
|
a.log.WithError(err).Warning("Failed to parse redirect URL")
|
|
return "", false
|
|
}
|
|
// Check to make sure we only redirect to allowed places
|
|
if a.Mode() == api.PROXYMODE_PROXY || a.Mode() == api.PROXYMODE_FORWARD_SINGLE {
|
|
ext, err := url.Parse(a.proxyConfig.ExternalHost)
|
|
if err != nil {
|
|
return "", false
|
|
}
|
|
// Either hostname needs to match the configured domain, or host name must be empty for just a path
|
|
if u.Host == "" {
|
|
u.Host = ext.Host
|
|
u.Scheme = ext.Scheme
|
|
}
|
|
if u.Host != ext.Host {
|
|
a.log.WithField("url", u.String()).WithField("ext", ext.String()).Warning("redirect URI did not contain external host")
|
|
return "", false
|
|
}
|
|
} else {
|
|
if !strings.HasSuffix(u.Host, *a.proxyConfig.CookieDomain) {
|
|
a.log.WithField("host", u.Host).WithField("dom", *a.proxyConfig.CookieDomain).Warning("redirect URI Host was not included in cookie domain")
|
|
return "", false
|
|
}
|
|
}
|
|
return u.String(), true
|
|
}
|
|
|
|
func (a *Application) createState(r *http.Request, fwd string) (string, error) {
|
|
s, _ := a.sessions.Get(r, a.SessionName())
|
|
if s.ID == "" {
|
|
// Ensure session has an ID
|
|
s.ID = base32RawStdEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
|
|
}
|
|
st := &OAuthState{
|
|
Issuer: fmt.Sprintf("goauthentik.io/outpost/%s", a.proxyConfig.GetClientId()),
|
|
State: base64.RawURLEncoding.EncodeToString(securecookie.GenerateRandomKey(32)),
|
|
SessionID: s.ID,
|
|
Redirect: fwd,
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, st)
|
|
tokenString, err := token.SignedString([]byte(a.proxyConfig.GetCookieSecret()))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return tokenString, nil
|
|
}
|
|
|
|
func (a *Application) stateFromRequest(r *http.Request) *OAuthState {
|
|
stateJwt := r.URL.Query().Get("state")
|
|
token, err := jwt.Parse(stateJwt, func(token *jwt.Token) (interface{}, error) {
|
|
// Don't forget to validate the alg is what you expect:
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(a.proxyConfig.GetCookieSecret()), nil
|
|
})
|
|
if err != nil {
|
|
a.log.WithError(err).Warning("failed to parse state jwt")
|
|
return nil
|
|
}
|
|
iss, err := token.Claims.GetIssuer()
|
|
if err != nil {
|
|
a.log.WithError(err).Warning("state jwt without issuer")
|
|
return nil
|
|
}
|
|
if iss != fmt.Sprintf("goauthentik.io/outpost/%s", a.proxyConfig.GetClientId()) {
|
|
a.log.WithField("issuer", iss).Warning("invalid state jwt issuer")
|
|
return nil
|
|
}
|
|
claims := &OAuthState{}
|
|
err = mapstructure.Decode(token.Claims, &claims)
|
|
if err != nil {
|
|
a.log.WithError(err).Warning("failed to mapdecode")
|
|
return nil
|
|
}
|
|
s, _ := a.sessions.Get(r, a.SessionName())
|
|
if claims.SessionID != s.ID {
|
|
a.log.WithField("is", claims.SessionID).WithField("should", s.ID).Warning("mismatched session ID")
|
|
return nil
|
|
}
|
|
return claims
|
|
}
|