diff --git a/models/auth/two_factor.go b/models/auth/two_factor.go new file mode 100644 index 0000000000..e8f19c33cc --- /dev/null +++ b/models/auth/two_factor.go @@ -0,0 +1,21 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later +package auth + +import ( + "context" +) + +// HasTwoFactorByUID returns true if the user has TOTP or WebAuthn enabled for +// their account. +func HasTwoFactorByUID(ctx context.Context, userID int64) (bool, error) { + hasTOTP, err := HasTOTPByUID(ctx, userID) + if err != nil { + return false, err + } + if hasTOTP { + return true, nil + } + + return HasWebAuthnRegistrationsByUID(ctx, userID) +} diff --git a/models/auth/two_factor_test.go b/models/auth/two_factor_test.go new file mode 100644 index 0000000000..36e0404ae2 --- /dev/null +++ b/models/auth/two_factor_test.go @@ -0,0 +1,34 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later +package auth + +import ( + "testing" + + "forgejo.org/models/unittest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHasTwoFactorByUID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + t.Run("No twofactor", func(t *testing.T) { + ok, err := HasTwoFactorByUID(t.Context(), 2) + require.NoError(t, err) + assert.False(t, ok) + }) + + t.Run("WebAuthn credential", func(t *testing.T) { + ok, err := HasTwoFactorByUID(t.Context(), 32) + require.NoError(t, err) + assert.True(t, ok) + }) + + t.Run("TOTP", func(t *testing.T) { + ok, err := HasTwoFactorByUID(t.Context(), 24) + require.NoError(t, err) + assert.True(t, ok) + }) +} diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go index f4c1e36b3f..edff471836 100644 --- a/models/auth/twofactor.go +++ b/models/auth/twofactor.go @@ -139,9 +139,9 @@ func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) { return twofa, nil } -// HasTwoFactorByUID returns the two-factor authentication token associated with -// the user, if any. -func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) { +// HasTOTPByUID returns the TOTP authentication token associated with +// the user, if the user has TOTP enabled for their account. +func HasTOTPByUID(ctx context.Context, uid int64) (bool, error) { return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{}) } diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index f53a0197cb..f6d214d24c 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -249,7 +249,7 @@ func prepareUserInfo(ctx *context.Context) *user_model.User { } ctx.Data["Sources"] = sources - hasTOTP, err := auth.HasTwoFactorByUID(ctx, u.ID) + hasTOTP, err := auth.HasTOTPByUID(ctx, u.ID) if err != nil { ctx.ServerError("auth.HasTwoFactorByUID", err) return nil diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index 755126b8e0..64006eeae8 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -242,7 +242,7 @@ func SignInPost(ctx *context.Context) { // If this user is enrolled in 2FA TOTP, we can't sign the user in just yet. // Instead, redirect them to the 2FA authentication page. - hasTOTPtwofa, err := auth.HasTwoFactorByUID(ctx, u.ID) + hasTOTPtwofa, err := auth.HasTOTPByUID(ctx, u.ID) if err != nil { ctx.ServerError("UserSignIn", err) return diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index 9566652751..fbf03ca475 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -155,15 +155,14 @@ func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, r // If this user is enrolled in 2FA, we can't sign the user in just yet. // Instead, redirect them to the 2FA authentication page. // We deliberately ignore the skip local 2fa setting here because we are linking to a previous user here - _, err := auth.GetTwoFactorByUID(ctx, u.ID) + hasTwoFactor, err := auth.HasTwoFactorByUID(ctx, u.ID) if err != nil { - if !auth.IsErrTwoFactorNotEnrolled(err) { - ctx.ServerError("UserLinkAccount", err) - return - } + ctx.ServerError("UserLinkAccount", err) + return + } - err = externalaccount.LinkAccountToUser(ctx, u, gothUser) - if err != nil { + if !hasTwoFactor { + if err := externalaccount.LinkAccountToUser(ctx, u, gothUser); err != nil { ctx.ServerError("UserLinkAccount", err) return } diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index c55397da88..aa599bd252 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -1243,12 +1243,11 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model needs2FA := false if !source.Cfg.(*oauth2.Source).SkipLocalTwoFA { - _, err := auth.GetTwoFactorByUID(ctx, u.ID) - if err != nil && !auth.IsErrTwoFactorNotEnrolled(err) { + needs2FA, err = auth.HasTwoFactorByUID(ctx, u.ID) + if err != nil { ctx.ServerError("UserSignIn", err) return } - needs2FA = err == nil } oauth2Source := source.Cfg.(*oauth2.Source) diff --git a/routers/web/auth/webauthn.go b/routers/web/auth/webauthn.go index ac69e03389..3da6199b6e 100644 --- a/routers/web/auth/webauthn.go +++ b/routers/web/auth/webauthn.go @@ -36,7 +36,7 @@ func WebAuthn(ctx *context.Context) { return } - hasTwoFactor, err := auth.HasTwoFactorByUID(ctx, ctx.Session.Get("twofaUid").(int64)) + hasTwoFactor, err := auth.HasTOTPByUID(ctx, ctx.Session.Get("twofaUid").(int64)) if err != nil { ctx.ServerError("HasTwoFactorByUID", err) return diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index 9acc6ab6f0..8b801cfebd 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -55,7 +55,7 @@ func DeleteAccountLink(ctx *context.Context) { } func loadSecurityData(ctx *context.Context) { - enrolled, err := auth_model.HasTwoFactorByUID(ctx, ctx.Doer.ID) + enrolled, err := auth_model.HasTOTPByUID(ctx, ctx.Doer.ID) if err != nil { ctx.ServerError("SettingsTwoFactor", err) return diff --git a/services/mailer/mail.go b/services/mailer/mail.go index c0a37d9fb2..4269686f2d 100644 --- a/services/mailer/mail.go +++ b/services/mailer/mail.go @@ -689,7 +689,7 @@ func SendRemovedSecurityKey(ctx context.Context, u *user_model.User, securityKey if err != nil { return err } - hasTOTP, err := auth_model.HasTwoFactorByUID(ctx, u.ID) + hasTOTP, err := auth_model.HasTOTPByUID(ctx, u.ID) if err != nil { return err } diff --git a/tests/integration/oauth_test.go b/tests/integration/oauth_test.go index 1f7aca90ad..2097afcfae 100644 --- a/tests/integration/oauth_test.go +++ b/tests/integration/oauth_test.go @@ -1456,3 +1456,69 @@ func TestSignUpViaOAuthDefaultRestricted(t *testing.T) { unittest.AssertExistsIf(t, true, &user_model.User{Name: "gitlab-user"}, "is_restricted = true") } + +func TestSignUpViaOAuthLinking2FA(t *testing.T) { + defer tests.PrepareTestEnv(t)() + defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)() + defer test.MockVariableValue(&setting.OAuth2Client.AccountLinking, setting.OAuth2AccountLinkingAuto)() + + // Fake that user 2 is enrolled into WebAuthn. + t.Cleanup(func() { + unittest.AssertSuccessfulDelete(t, &auth_model.WebAuthnCredential{UserID: 2}) + }) + unittest.AssertSuccessfulInsert(t, &auth_model.WebAuthnCredential{UserID: 2}) + + gitlabName := "gitlab" + addAuthSource(t, authSourcePayloadGitLabCustom(gitlabName)) + userGitLabUserID := "BB(4)=107" + + defer mockCompleteUserAuth(func(res http.ResponseWriter, req *http.Request) (goth.User, error) { + return goth.User{ + Provider: gitlabName, + UserID: userGitLabUserID, + NickName: "user2", + Email: "user2@example.com", + }, nil + })() + req := NewRequest(t, "GET", fmt.Sprintf("/user/oauth2/%s/callback?code=XYZ&state=XYZ", gitlabName)) + resp := MakeRequest(t, req, http.StatusSeeOther) + + // Make sure the user has to go through 2FA after linking. + assert.Equal(t, "/user/webauthn", test.RedirectURL(resp)) +} + +func TestSignUpViaOAuth2FA(t *testing.T) { + defer tests.PrepareTestEnv(t)() + defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)() + defer test.MockVariableValue(&setting.OAuth2Client.AccountLinking, setting.OAuth2AccountLinkingAuto)() + + gitlabName := "gitlab" + addAuthSource(t, authSourcePayloadGitLabCustom(gitlabName)) + userGitLabUserID := "BB(3)=21" + + defer mockCompleteUserAuth(func(res http.ResponseWriter, req *http.Request) (goth.User, error) { + return goth.User{ + Provider: gitlabName, + UserID: userGitLabUserID, + NickName: "user2", + Email: "user2@example.com", + }, nil + })() + req := NewRequest(t, "GET", fmt.Sprintf("/user/oauth2/%s/callback?code=XYZ&state=XYZ", gitlabName)) + resp := MakeRequest(t, req, http.StatusSeeOther) + + // Make sure the user can login normally and is linked. + assert.Equal(t, "/", test.RedirectURL(resp)) + + // Fake that user 2 is enrolled into WebAuthn. + t.Cleanup(func() { + unittest.AssertSuccessfulDelete(t, &auth_model.WebAuthnCredential{UserID: 2}) + }) + unittest.AssertSuccessfulInsert(t, &auth_model.WebAuthnCredential{UserID: 2}) + + req = NewRequest(t, "GET", fmt.Sprintf("/user/oauth2/%s/callback?code=XYZ&state=XYZ", gitlabName)) + resp = MakeRequest(t, req, http.StatusSeeOther) + + // Make sure user has to go through 2FA. + assert.Equal(t, "/user/webauthn", test.RedirectURL(resp)) +}