diff --git a/cmd/web.go b/cmd/web.go index bc344db540..da6c987ff8 100644 --- a/cmd/web.go +++ b/cmd/web.go @@ -142,10 +142,8 @@ func runWeb(ctx *cli.Context) error { return err } } - installCtx, cancel := context.WithCancel(graceful.GetManager().HammerContext()) - c := install.Routes(installCtx) + c := install.Routes() err := listen(c, false) - cancel() if err != nil { log.Critical("Unable to open listener for installer. Is Gitea already running?") graceful.GetManager().DoGracefulShutdown() diff --git a/modules/context/context.go b/modules/context/context.go index 1e15081479..9e351432c4 100644 --- a/modules/context/context.go +++ b/modules/context/context.go @@ -68,12 +68,12 @@ func (ctx *Context) TrHTMLEscapeArgs(msg string, args ...string) string { return ctx.Locale.Tr(msg, trArgs...) } -type contextKeyType struct{} +type webContextKeyType struct{} -var contextKey interface{} = contextKeyType{} +var WebContextKey = webContextKeyType{} -func GetContext(req *http.Request) *Context { - ctx, _ := req.Context().Value(contextKey).(*Context) +func GetWebContext(req *http.Request) *Context { + ctx, _ := req.Context().Value(WebContextKey).(*Context) return ctx } @@ -86,7 +86,7 @@ type ValidateContext struct { func GetValidateContext(req *http.Request) (ctx *ValidateContext) { if ctxAPI, ok := req.Context().Value(apiContextKey).(*APIContext); ok { ctx = &ValidateContext{Base: ctxAPI.Base} - } else if ctxWeb, ok := req.Context().Value(contextKey).(*Context); ok { + } else if ctxWeb, ok := req.Context().Value(WebContextKey).(*Context); ok { ctx = &ValidateContext{Base: ctxWeb.Base} } else { panic("invalid context, expect either APIContext or Context") @@ -135,7 +135,7 @@ func Contexter() func(next http.Handler) http.Handler { ctx.PageData = map[string]any{} ctx.Data["PageData"] = ctx.PageData - ctx.Base.AppendContextValue(contextKey, ctx) + ctx.Base.AppendContextValue(WebContextKey, ctx) ctx.Base.AppendContextValueFunc(git.RepositoryContextKey, func() any { return ctx.Repo.GitRepo }) ctx.Csrf = PrepareCSRFProtector(csrfOpts, ctx) diff --git a/modules/context/package.go b/modules/context/package.go index b1fd7088dd..8052032787 100644 --- a/modules/context/package.go +++ b/modules/context/package.go @@ -150,7 +150,7 @@ func PackageContexter() func(next http.Handler) http.Handler { } defer baseCleanUp() - ctx.Base.AppendContextValue(contextKey, ctx) + ctx.Base.AppendContextValue(WebContextKey, ctx) next.ServeHTTP(ctx.Resp, ctx.Req) }) } diff --git a/modules/web/handler.go b/modules/web/handler.go index 5013bac93f..c8aebd9051 100644 --- a/modules/web/handler.go +++ b/modules/web/handler.go @@ -22,7 +22,7 @@ type ResponseStatusProvider interface { // TODO: decouple this from the context package, let the context package register these providers var argTypeProvider = map[reflect.Type]func(req *http.Request) ResponseStatusProvider{ reflect.TypeOf(&context.APIContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetAPIContext(req) }, - reflect.TypeOf(&context.Context{}): func(req *http.Request) ResponseStatusProvider { return context.GetContext(req) }, + reflect.TypeOf(&context.Context{}): func(req *http.Request) ResponseStatusProvider { return context.GetWebContext(req) }, reflect.TypeOf(&context.PrivateContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetPrivateContext(req) }, } diff --git a/routers/install/install.go b/routers/install/install.go index 89b91a5a48..4635cd7cb6 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -59,7 +59,7 @@ func Contexter() func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { base, baseCleanUp := context.NewBaseContext(resp, req) - ctx := context.Context{ + ctx := &context.Context{ Base: base, Flash: &middleware.Flash{}, Render: rnd, @@ -67,6 +67,7 @@ func Contexter() func(next http.Handler) http.Handler { } defer baseCleanUp() + ctx.AppendContextValue(context.WebContextKey, ctx) ctx.Data.MergeFrom(middleware.CommonTemplateContextData()) ctx.Data.MergeFrom(middleware.ContextData{ "locale": ctx.Locale, diff --git a/routers/install/routes.go b/routers/install/routes.go index 52c07cfa26..f09a22b1e6 100644 --- a/routers/install/routes.go +++ b/routers/install/routes.go @@ -4,7 +4,6 @@ package install import ( - goctx "context" "fmt" "html" "net/http" @@ -18,7 +17,7 @@ import ( ) // Routes registers the installation routes -func Routes(ctx goctx.Context) *web.Route { +func Routes() *web.Route { base := web.NewRoute() base.Use(common.ProtocolMiddlewares()...) base.RouteMethods("/assets/*", "GET, HEAD", public.AssetsHandlerFunc("/assets/")) diff --git a/routers/install/routes_test.go b/routers/install/routes_test.go index e3d2a42467..fcbd052977 100644 --- a/routers/install/routes_test.go +++ b/routers/install/routes_test.go @@ -1,24 +1,41 @@ -// Copyright 2021 The Gitea Authors. All rights reserved. +// Copyright 2023 The Gitea Authors. All rights reserved. // SPDX-License-Identifier: MIT package install import ( - "context" + "net/http/httptest" + "path/filepath" "testing" + "code.gitea.io/gitea/models/unittest" + "github.com/stretchr/testify/assert" ) func TestRoutes(t *testing.T) { - // TODO: this test seems not really testing the handlers - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - base := Routes(ctx) - assert.NotNil(t, base) - r := base.R.Routes()[1] - routes := r.SubRoutes.Routes()[0] - assert.EqualValues(t, "/", routes.Pattern) - assert.Nil(t, routes.SubRoutes) - assert.Len(t, routes.Handlers, 2) + r := Routes() + assert.NotNil(t, r) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + r.ServeHTTP(w, req) + assert.EqualValues(t, 200, w.Code) + assert.Contains(t, w.Body.String(), `class="page-content install"`) + + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/no-such", nil) + r.ServeHTTP(w, req) + assert.EqualValues(t, 404, w.Code) + + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/assets/img/gitea.svg", nil) + r.ServeHTTP(w, req) + assert.EqualValues(t, 200, w.Code) +} + +func TestMain(m *testing.M) { + unittest.MainTest(m, &unittest.TestOptions{ + GiteaRootPath: filepath.Join("..", ".."), + }) } diff --git a/routers/web/web.go b/routers/web/web.go index c230d33398..395fc9425f 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -1405,7 +1405,7 @@ func registerRoutes(m *web.Route) { } m.NotFound(func(w http.ResponseWriter, req *http.Request) { - ctx := context.GetContext(req) + ctx := context.GetWebContext(req) ctx.NotFound("", nil) }) } diff --git a/services/auth/auth.go b/services/auth/auth.go index 905c776e58..c7fdc56cbe 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -92,7 +92,7 @@ func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore middleware.SetLocaleCookie(resp, user.Language, 0) // Clear whatever CSRF has right now, force to generate a new one - if ctx := gitea_context.GetContext(req); ctx != nil { + if ctx := gitea_context.GetWebContext(req); ctx != nil { ctx.Csrf.DeleteCookie(ctx) } }