diff --git a/auth/oidc.go b/auth/oidc.go index 9ba2558c..458c9388 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -22,7 +22,7 @@ type OIDCAuth struct { OIDC *OAuthClient `json:"oidc" yaml:"oidc"` } -// Auth is requested by the identity provider in the callback phase of an oauth code flow. +// Auth is executed when the identity provider enters the callback phase of an oauth code flow. func (a OIDCAuth) Auth(r *http.Request, usr users.Store, _ *settings.Settings, srv *settings.Server) (*users.User, error) { cookie, _ := r.Cookie("auth") if cookie != nil { @@ -69,7 +69,7 @@ func (o *OAuthClient) InitClient() { } } -// InitAuthFlow triggers oidc authentication flow +// InitAuthFlow triggers the oidc authentication flow. func (o *OAuthClient) InitAuthFlow(w http.ResponseWriter, r *http.Request) { o.InitClient() state := fmt.Sprintf("%x", rand.Uint32()) @@ -87,11 +87,12 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv * code := r.URL.Query().Get("code") stateQuery := r.URL.Query().Get("state") - stateCookie, _ := r.Cookie("state") + stateCookie, err := r.Cookie("state") - if code == "" || stateQuery == "" || stateQuery != stateCookie.Value { + // Validate state + if code == "" || stateQuery == "" || err != nil || stateQuery != stateCookie.Value { log.Fatal("Invalid request") - return nil, os.ErrInvalid + return nil, os.ErrPermission } // Exchange code for token @@ -100,15 +101,13 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv * log.Fatal(err) return nil, err } - log.Println("oidc got token") // Parse id token rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { log.Fatal("Invalid token") - return nil, nil + return nil, os.ErrPermission } - log.Println("oidc parsed token") // Verify id token idToken, err := o.Verifier.Verify(context.Background(), rawIDToken) @@ -116,7 +115,6 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv * log.Fatal("oidc verify failed") return nil, err } - log.Println("oidc verified token") // Extract claims var claims struct { @@ -137,7 +135,7 @@ func (o *OAuthClient) HandleAuthCallback(r *http.Request, usr users.Store, srv * return nil, os.ErrPermission } u.AuthSource = "oidc" - log.Println("oidc success (user, claims) ", u, claims) + log.Println("oidc success (user, claims) ", u.Username, claims) return u, nil }