Make SSPI auth mockable (#27036)

Before, the SSPI auth is only complied for Windows, it's difficult to
test and it breaks a lot.

Now, make the SSPI auth mockable and testable.
This commit is contained in:
wxiaoguang 2023-09-18 07:32:56 +08:00 committed by GitHub
parent 47b878858a
commit 8531ca0837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 72 additions and 76 deletions

View File

@ -705,7 +705,10 @@ func buildAuthGroup() *auth.Group {
if setting.Service.EnableReverseProxyAuthAPI {
group.Add(&auth.ReverseProxy{})
}
specialAdd(group)
if setting.IsWindows && auth_model.IsSSPIEnabled() {
group.Add(&auth.SSPI{}) // it MUST be the last, see the comment of SSPI
}
return group
}

View File

@ -1,10 +0,0 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
//go:build !windows
package v1
import auth_service "code.gitea.io/gitea/services/auth"
func specialAdd(group *auth_service.Group) {}

View File

@ -1,19 +0,0 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package v1
import (
"code.gitea.io/gitea/models/auth"
auth_service "code.gitea.io/gitea/services/auth"
)
// specialAdd registers the SSPI auth method as the last method in the list.
// The SSPI plugin is expected to be executed last, as it returns 401 status code if negotiation
// fails (or if negotiation should continue), which would prevent other authentication methods
// to execute at all.
func specialAdd(group *auth_service.Group) {
if auth.IsSSPIEnabled() {
group.Add(&auth_service.SSPI{})
}
}

View File

@ -1,10 +0,0 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
//go:build !windows
package web
import auth_service "code.gitea.io/gitea/services/auth"
func specialAdd(group *auth_service.Group) {}

View File

@ -1,19 +0,0 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package web
import (
"code.gitea.io/gitea/models/auth"
auth_service "code.gitea.io/gitea/services/auth"
)
// specialAdd registers the SSPI auth method as the last method in the list.
// The SSPI plugin is expected to be executed last, as it returns 401 status code if negotiation
// fails (or if negotiation should continue), which would prevent other authentication methods
// to execute at all.
func specialAdd(group *auth_service.Group) {
if auth.IsSSPIEnabled() {
group.Add(&auth_service.SSPI{})
}
}

View File

@ -8,6 +8,7 @@ import (
"net/http"
"strings"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/models/unit"
"code.gitea.io/gitea/modules/context"
@ -92,7 +93,10 @@ func buildAuthGroup() *auth_service.Group {
if setting.Service.EnableReverseProxyAuth {
group.Add(&auth_service.ReverseProxy{})
}
specialAdd(group)
if setting.IsWindows && auth_model.IsSSPIEnabled() {
group.Add(&auth_service.SSPI{}) // it MUST be the last, see the comment of SSPI
}
return group
}

View File

@ -22,19 +22,21 @@ import (
"code.gitea.io/gitea/services/auth/source/sspi"
gouuid "github.com/google/uuid"
"github.com/quasoft/websspi"
)
const (
tplSignIn base.TplName = "user/auth/signin"
)
type SSPIAuth interface {
AppendAuthenticateHeader(w http.ResponseWriter, data string)
Authenticate(r *http.Request, w http.ResponseWriter) (userInfo *SSPIUserInfo, outToken string, err error)
}
var (
// sspiAuth is a global instance of the websspi authentication package,
// which is used to avoid acquiring the server credential handle on
// every request
sspiAuth *websspi.Authenticator
sspiAuthOnce sync.Once
sspiAuth SSPIAuth // a global instance of the websspi authenticator to avoid acquiring the server credential handle on every request
sspiAuthOnce sync.Once
sspiAuthErrInit error
// Ensure the struct implements the interface.
_ Method = &SSPI{}
@ -42,8 +44,9 @@ var (
// SSPI implements the SingleSignOn interface and authenticates requests
// via the built-in SSPI module in Windows for SPNEGO authentication.
// On successful authentication returns a valid user object.
// Returns nil if authentication fails.
// The SSPI plugin is expected to be executed last, as it returns 401 status code if negotiation
// fails (or if negotiation should continue), which would prevent other authentication methods
// to execute at all.
type SSPI struct{}
// Name represents the name of auth method
@ -56,15 +59,10 @@ func (s *SSPI) Name() string {
// If negotiation should continue or authentication fails, immediately returns a 401 HTTP
// response code, as required by the SPNEGO protocol.
func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore, sess SessionStore) (*user_model.User, error) {
var errInit error
sspiAuthOnce.Do(func() {
config := websspi.NewConfig()
sspiAuth, errInit = websspi.New(config)
})
if errInit != nil {
return nil, errInit
sspiAuthOnce.Do(func() { sspiAuthErrInit = sspiAuthInit() })
if sspiAuthErrInit != nil {
return nil, sspiAuthErrInit
}
if !s.shouldAuthenticate(req) {
return nil, nil
}

View File

@ -0,0 +1,30 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
//go:build !windows
package auth
import (
"errors"
"net/http"
)
type SSPIUserInfo struct {
Username string // Name of user, usually in the form DOMAIN\User
Groups []string // The global groups the user is a member of
}
type sspiAuthMock struct{}
func (s sspiAuthMock) AppendAuthenticateHeader(w http.ResponseWriter, data string) {
}
func (s sspiAuthMock) Authenticate(r *http.Request, w http.ResponseWriter) (userInfo *SSPIUserInfo, outToken string, err error) {
return nil, "", errors.New("not implemented")
}
func sspiAuthInit() error {
sspiAuth = &sspiAuthMock{} // TODO: we can mock the SSPI auth in tests
return nil
}

View File

@ -0,0 +1,19 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
//go:build windows
package auth
import (
"github.com/quasoft/websspi"
)
type SSPIUserInfo = websspi.UserInfo
func sspiAuthInit() error {
var err error
config := websspi.NewConfig()
sspiAuth, err = websspi.New(config)
return err
}