chore: improve state validation + cleanup

This commit is contained in:
Marcell FÜLÖP 2023-02-19 11:59:43 +00:00
parent 1621683fdb
commit e89d343e23

View File

@ -22,7 +22,7 @@ type OIDCAuth struct {
OIDC *OAuthClient `json:"oidc" yaml:"oidc"` OIDC *OAuthClient `json:"oidc" yaml:"oidc"`
} }
// Auth is requested by the identity provider in the callback phase of an oauth code flow. // Auth is executed when the identity provider enters the callback phase of an oauth code flow.
func (a OIDCAuth) Auth(r *http.Request, usr users.Store, _ *settings.Settings, srv *settings.Server) (*users.User, error) { func (a OIDCAuth) Auth(r *http.Request, usr users.Store, _ *settings.Settings, srv *settings.Server) (*users.User, error) {
cookie, _ := r.Cookie("auth") cookie, _ := r.Cookie("auth")
if cookie != nil { if cookie != nil {
@ -69,7 +69,7 @@ func (o *OAuthClient) InitClient() {
} }
} }
// InitAuthFlow triggers oidc authentication flow // InitAuthFlow triggers the oidc authentication flow.
func (o *OAuthClient) InitAuthFlow(w http.ResponseWriter, r *http.Request) { func (o *OAuthClient) InitAuthFlow(w http.ResponseWriter, r *http.Request) {
o.InitClient() o.InitClient()
state := fmt.Sprintf("%x", rand.Uint32()) state := fmt.Sprintf("%x", rand.Uint32())
@ -87,11 +87,12 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv *
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
stateQuery := r.URL.Query().Get("state") stateQuery := r.URL.Query().Get("state")
stateCookie, _ := r.Cookie("state") stateCookie, err := r.Cookie("state")
if code == "" || stateQuery == "" || stateQuery != stateCookie.Value { // Validate state
if code == "" || stateQuery == "" || err != nil || stateQuery != stateCookie.Value {
log.Fatal("Invalid request") log.Fatal("Invalid request")
return nil, os.ErrInvalid return nil, os.ErrPermission
} }
// Exchange code for token // Exchange code for token
@ -100,15 +101,13 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv *
log.Fatal(err) log.Fatal(err)
return nil, err return nil, err
} }
log.Println("oidc got token")
// Parse id token // Parse id token
rawIDToken, ok := oauth2Token.Extra("id_token").(string) rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok { if !ok {
log.Fatal("Invalid token") log.Fatal("Invalid token")
return nil, nil return nil, os.ErrPermission
} }
log.Println("oidc parsed token")
// Verify id token // Verify id token
idToken, err := o.Verifier.Verify(context.Background(), rawIDToken) idToken, err := o.Verifier.Verify(context.Background(), rawIDToken)
@ -116,7 +115,6 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv *
log.Fatal("oidc verify failed") log.Fatal("oidc verify failed")
return nil, err return nil, err
} }
log.Println("oidc verified token")
// Extract claims // Extract claims
var claims struct { var claims struct {
@ -137,7 +135,7 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv *
return nil, os.ErrPermission return nil, os.ErrPermission
} }
u.AuthSource = "oidc" u.AuthSource = "oidc"
log.Println("oidc success (user, claims) ", u, claims) log.Println("oidc success (user, claims) ", u.Username, claims)
return u, nil return u, nil
} }