diff --git a/models/auth/two_factor.go b/models/auth/two_factor.go new file mode 100644 index 0000000000..b7fe1ba9c1 --- /dev/null +++ b/models/auth/two_factor.go @@ -0,0 +1,21 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT +package auth + +import ( + "context" +) + +// HasTwoFactorByUID returns true if the user has TOTP or WebAuthn enabled for +// their account. +func HasTwoFactorByUID(ctx context.Context, userID int64) (bool, error) { + hasTOTP, err := HasTOTPByUID(ctx, userID) + if err != nil { + return false, err + } + if hasTOTP { + return true, nil + } + + return HasWebAuthnRegistrationsByUID(ctx, userID) +} diff --git a/models/auth/two_factor_test.go b/models/auth/two_factor_test.go new file mode 100644 index 0000000000..cdfc1365d1 --- /dev/null +++ b/models/auth/two_factor_test.go @@ -0,0 +1,35 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later +package auth + +import ( + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHasTwoFactorByUID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + t.Run("No twofactor", func(t *testing.T) { + ok, err := HasTwoFactorByUID(db.DefaultContext, 2) + require.NoError(t, err) + assert.False(t, ok) + }) + + t.Run("WebAuthn credential", func(t *testing.T) { + ok, err := HasTwoFactorByUID(db.DefaultContext, 32) + require.NoError(t, err) + assert.True(t, ok) + }) + + t.Run("TOTP", func(t *testing.T) { + ok, err := HasTwoFactorByUID(db.DefaultContext, 24) + require.NoError(t, err) + assert.True(t, ok) + }) +} diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go index d0c341a192..be7f5a858a 100644 --- a/models/auth/twofactor.go +++ b/models/auth/twofactor.go @@ -146,9 +146,9 @@ func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) { return twofa, nil } -// HasTwoFactorByUID returns the two-factor authentication token associated with -// the user, if any. -func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) { +// HasTOTPByUID returns the TOTP authentication token associated with +// the user, if the user has TOTP enabled for their account. +func HasTOTPByUID(ctx context.Context, uid int64) (bool, error) { return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{}) } diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index 15bd667a4f..a94d5ec8f4 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -247,7 +247,7 @@ func prepareUserInfo(ctx *context.Context) *user_model.User { } ctx.Data["Sources"] = sources - hasTOTP, err := auth.HasTwoFactorByUID(ctx, u.ID) + hasTOTP, err := auth.HasTOTPByUID(ctx, u.ID) if err != nil { ctx.ServerError("auth.HasTwoFactorByUID", err) return nil diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index def5f5b8f7..27d4b27bc1 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -240,7 +240,7 @@ func SignInPost(ctx *context.Context) { // If this user is enrolled in 2FA TOTP, we can't sign the user in just yet. // Instead, redirect them to the 2FA authentication page. - hasTOTPtwofa, err := auth.HasTwoFactorByUID(ctx, u.ID) + hasTOTPtwofa, err := auth.HasTOTPByUID(ctx, u.ID) if err != nil { ctx.ServerError("UserSignIn", err) return diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index 9b0141c14e..86170d78fc 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -163,15 +163,14 @@ func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, r // If this user is enrolled in 2FA, we can't sign the user in just yet. // Instead, redirect them to the 2FA authentication page. // We deliberately ignore the skip local 2fa setting here because we are linking to a previous user here - _, err := auth.GetTwoFactorByUID(ctx, u.ID) + hasTwoFactor, err := auth.HasTwoFactorByUID(ctx, u.ID) if err != nil { - if !auth.IsErrTwoFactorNotEnrolled(err) { - ctx.ServerError("UserLinkAccount", err) - return - } + ctx.ServerError("UserLinkAccount", err) + return + } - err = externalaccount.LinkAccountToUser(ctx, u, gothUser) - if err != nil { + if !hasTwoFactor { + if err := externalaccount.LinkAccountToUser(ctx, u, gothUser); err != nil { ctx.ServerError("UserLinkAccount", err) return } diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 79751450b5..6914e48f47 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -1124,14 +1124,14 @@ func updateAvatarIfNeed(ctx *context.Context, url string, u *user_model.User) { func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model.User, gothUser goth.User) { updateAvatarIfNeed(ctx, gothUser.AvatarURL, u) + var err error needs2FA := false if !source.Cfg.(*oauth2.Source).SkipLocalTwoFA { - _, err := auth.GetTwoFactorByUID(ctx, u.ID) - if err != nil && !auth.IsErrTwoFactorNotEnrolled(err) { + needs2FA, err = auth.HasTwoFactorByUID(ctx, u.ID) + if err != nil { ctx.ServerError("UserSignIn", err) return } - needs2FA = err == nil } oauth2Source := source.Cfg.(*oauth2.Source) diff --git a/routers/web/auth/webauthn.go b/routers/web/auth/webauthn.go index 1079f44a08..f417258c0c 100644 --- a/routers/web/auth/webauthn.go +++ b/routers/web/auth/webauthn.go @@ -36,7 +36,7 @@ func WebAuthn(ctx *context.Context) { return } - hasTwoFactor, err := auth.HasTwoFactorByUID(ctx, ctx.Session.Get("twofaUid").(int64)) + hasTwoFactor, err := auth.HasTOTPByUID(ctx, ctx.Session.Get("twofaUid").(int64)) if err != nil { ctx.ServerError("HasTwoFactorByUID", err) return diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index 8d6859ab87..80509c8dbd 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -55,7 +55,7 @@ func DeleteAccountLink(ctx *context.Context) { } func loadSecurityData(ctx *context.Context) { - enrolled, err := auth_model.HasTwoFactorByUID(ctx, ctx.Doer.ID) + enrolled, err := auth_model.HasTOTPByUID(ctx, ctx.Doer.ID) if err != nil { ctx.ServerError("SettingsTwoFactor", err) return diff --git a/tests/integration/oauth_test.go b/tests/integration/oauth_test.go index 14ea9c119e..b761babe0c 100644 --- a/tests/integration/oauth_test.go +++ b/tests/integration/oauth_test.go @@ -12,6 +12,7 @@ import ( "testing" auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/json" @@ -635,3 +636,69 @@ func TestOAuth_GrantApplicationOAuth(t *testing.T) { resp = ctx.MakeRequest(t, req, http.StatusSeeOther) assert.Contains(t, test.RedirectURL(resp), "error=access_denied&error_description=the+request+is+denied") } + +func TestSignUpViaOAuthLinking2FA(t *testing.T) { + defer tests.PrepareTestEnv(t)() + defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)() + defer test.MockVariableValue(&setting.OAuth2Client.AccountLinking, setting.OAuth2AccountLinkingAuto)() + + // Fake that user 2 is enrolled into WebAuthn. + t.Cleanup(func() { + require.NoError(t, db.DeleteBeans(db.DefaultContext, &auth_model.WebAuthnCredential{UserID: 2})) + }) + unittest.AssertSuccessfulInsert(t, &auth_model.WebAuthnCredential{UserID: 2}) + + gitlabName := "gitlab" + addAuthSource(t, authSourcePayloadGitLabCustom(gitlabName)) + userGitLabUserID := "107" + + defer mockCompleteUserAuth(func(res http.ResponseWriter, req *http.Request) (goth.User, error) { + return goth.User{ + Provider: gitlabName, + UserID: userGitLabUserID, + NickName: "user2", + Email: "user2@example.com", + }, nil + })() + req := NewRequest(t, "GET", fmt.Sprintf("/user/oauth2/%s/callback?code=XYZ&state=XYZ", gitlabName)) + resp := MakeRequest(t, req, http.StatusSeeOther) + + // Make sure the user has to go through 2FA after linking. + assert.Equal(t, "/user/webauthn", test.RedirectURL(resp)) +} + +func TestSignUpViaOAuth2FA(t *testing.T) { + defer tests.PrepareTestEnv(t)() + defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)() + defer test.MockVariableValue(&setting.OAuth2Client.AccountLinking, setting.OAuth2AccountLinkingAuto)() + + gitlabName := "gitlab" + addAuthSource(t, authSourcePayloadGitLabCustom(gitlabName)) + userGitLabUserID := "21" + + defer mockCompleteUserAuth(func(res http.ResponseWriter, req *http.Request) (goth.User, error) { + return goth.User{ + Provider: gitlabName, + UserID: userGitLabUserID, + NickName: "user2", + Email: "user2@example.com", + }, nil + })() + req := NewRequest(t, "GET", fmt.Sprintf("/user/oauth2/%s/callback?code=XYZ&state=XYZ", gitlabName)) + resp := MakeRequest(t, req, http.StatusSeeOther) + + // Make sure the user can login normally and is linked. + assert.Equal(t, "/", test.RedirectURL(resp)) + + // Fake that user 2 is enrolled into WebAuthn. + t.Cleanup(func() { + require.NoError(t, db.DeleteBeans(db.DefaultContext, &auth_model.WebAuthnCredential{UserID: 2})) + }) + unittest.AssertSuccessfulInsert(t, &auth_model.WebAuthnCredential{UserID: 2}) + + req = NewRequest(t, "GET", fmt.Sprintf("/user/oauth2/%s/callback?code=XYZ&state=XYZ", gitlabName)) + resp = MakeRequest(t, req, http.StatusSeeOther) + + // Make sure user has to go through 2FA. + assert.Equal(t, "/user/webauthn", test.RedirectURL(resp)) +}