diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 1656981b..1c9eab0a 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -84,7 +84,7 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth env: CGO_ENABLED: 0 @@ -130,7 +130,7 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth env: CGO_ENABLED: 0 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 32e5de19..ea69097d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -60,7 +60,7 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth env: CGO_ENABLED: 0 @@ -103,7 +103,7 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth env: CGO_ENABLED: 0 diff --git a/Dockerfile b/Dockerfile index 6b6cee1a..4724f6d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,9 +38,9 @@ COPY ./internal ./internal COPY --from=frontend-builder /frontend/dist ./internal/assets/dist RUN CGO_ENABLED=0 go build -ldflags "-s -w \ - -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \ - -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \ - -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth + -X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \ + -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ + -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth # Runner FROM alpine:3.23 AS runner diff --git a/Dockerfile.distroless b/Dockerfile.distroless index 8626028c..00d04107 100644 --- a/Dockerfile.distroless +++ b/Dockerfile.distroless @@ -40,9 +40,9 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist RUN mkdir -p data RUN CGO_ENABLED=0 go build -ldflags "-s -w \ - -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \ - -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \ - -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth + -X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \ + -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ + -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth # Runner FROM gcr.io/distroless/static-debian12:latest AS runner diff --git a/Makefile b/Makefile index 7f4e393e..616fd994 100644 --- a/Makefile +++ b/Makefile @@ -37,9 +37,9 @@ webui: clean-webui # Build the binary binary: webui CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \ - -X github.com/tinyauthapp/tinyauth/internal/config.Version=${TAG_NAME} \ - -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \ - -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \ + -X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \ + -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ + -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \ -o ${BIN_NAME} ./cmd/tinyauth # Build for amd64 diff --git a/cmd/tinyauth/generate_totp.go b/cmd/tinyauth/generate_totp.go index 22102c15..8819922e 100644 --- a/cmd/tinyauth/generate_totp.go +++ b/cmd/tinyauth/generate_totp.go @@ -73,7 +73,7 @@ func generateTotpCmd() *cli.Command { docker = true } - if user.TotpSecret != "" { + if user.TOTPSecret != "" { return fmt.Errorf("user already has a TOTP secret") } @@ -102,14 +102,14 @@ func generateTotpCmd() *cli.Command { qrterminal.GenerateWithConfig(key.URL(), config) - user.TotpSecret = secret + user.TOTPSecret = secret // If using docker escape re-escape it if docker { user.Password = strings.ReplaceAll(user.Password, "$", "$$") } - tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") + tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") return nil }, diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go index cc7c7261..f5bbb19f 100644 --- a/cmd/tinyauth/tinyauth.go +++ b/cmd/tinyauth/tinyauth.go @@ -5,7 +5,7 @@ import ( "charm.land/huh/v2" "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/loaders" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -14,7 +14,7 @@ import ( ) func main() { - tConfig := config.NewDefaultConfiguration() + tConfig := model.NewDefaultConfiguration() loaders := []cli.ResourceLoader{ &loaders.FileLoader{}, @@ -108,11 +108,11 @@ func main() { } } -func runCmd(cfg config.Config) error { +func runCmd(cfg model.Config) error { logger := tlog.NewLogger(cfg.Log) logger.Init() - tlog.App.Info().Str("version", config.Version).Msg("Starting tinyauth") + tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth") app := bootstrap.NewBootstrapApp(cfg) diff --git a/cmd/tinyauth/verify_user.go b/cmd/tinyauth/verify_user.go index 5ab7aeee..5058b606 100644 --- a/cmd/tinyauth/verify_user.go +++ b/cmd/tinyauth/verify_user.go @@ -95,7 +95,7 @@ func verifyUserCmd() *cli.Command { return fmt.Errorf("password is incorrect: %w", err) } - if user.TotpSecret == "" { + if user.TOTPSecret == "" { if tCfg.Totp != "" { tlog.App.Warn().Msg("User does not have TOTP secret") } @@ -103,7 +103,7 @@ func verifyUserCmd() *cli.Command { return nil } - ok := totp.Validate(tCfg.Totp, user.TotpSecret) + ok := totp.Validate(tCfg.Totp, user.TOTPSecret) if !ok { return fmt.Errorf("TOTP code incorrect") diff --git a/cmd/tinyauth/version.go b/cmd/tinyauth/version.go index 5bd2d9ac..4bd49924 100644 --- a/cmd/tinyauth/version.go +++ b/cmd/tinyauth/version.go @@ -3,9 +3,8 @@ package main import ( "fmt" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/model" ) func versionCmd() *cli.Command { @@ -15,9 +14,9 @@ func versionCmd() *cli.Command { Configuration: nil, Resources: nil, Run: func(_ []string) error { - fmt.Printf("Version: %s\n", config.Version) - fmt.Printf("Commit Hash: %s\n", config.CommitHash) - fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp) + fmt.Printf("Version: %s\n", model.Version) + fmt.Printf("Commit Hash: %s\n", model.CommitHash) + fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp) return nil }, } diff --git a/gen/gen_env.go b/gen/gen_env.go index 881888a9..36354fff 100644 --- a/gen/gen_env.go +++ b/gen/gen_env.go @@ -10,7 +10,7 @@ import ( "reflect" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type EnvEntry struct { @@ -20,7 +20,7 @@ type EnvEntry struct { } func generateExampleEnv() { - cfg := config.NewDefaultConfiguration() + cfg := model.NewDefaultConfiguration() entries := make([]EnvEntry, 0) root := reflect.TypeOf(cfg).Elem() diff --git a/gen/gen_md.go b/gen/gen_md.go index ae8f0f19..0dcf3822 100644 --- a/gen/gen_md.go +++ b/gen/gen_md.go @@ -10,7 +10,7 @@ import ( "reflect" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type MarkdownEntry struct { @@ -21,7 +21,7 @@ type MarkdownEntry struct { } func generateMarkdown() { - cfg := config.NewDefaultConfiguration() + cfg := model.NewDefaultConfiguration() entries := make([]MarkdownEntry, 0) root := reflect.TypeOf(cfg).Elem() diff --git a/go.mod b/go.mod index d0c5a515..6b8906ed 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,6 @@ require ( github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.50.0 golang.org/x/oauth2 v0.36.0 - gotest.tools/v3 v3.5.2 k8s.io/apimachinery v0.32.2 k8s.io/client-go v0.32.2 modernc.org/sqlite v1.49.1 @@ -133,6 +132,7 @@ require ( google.golang.org/protobuf v1.36.11 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools/v3 v3.5.2 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect modernc.org/libc v1.72.0 // indirect diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 3879c05e..fc86a7ab 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -12,15 +12,15 @@ import ( "strings" "time" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) type BootstrapApp struct { - config config.Config + config model.Config context struct { appUrl string uuid string @@ -29,15 +29,15 @@ type BootstrapApp struct { csrfCookieName string redirectCookieName string oauthSessionCookieName string - users []config.User - oauthProviders map[string]config.OAuthServiceConfig + localUsers *[]model.LocalUser + oauthProviders map[string]model.OAuthServiceConfig configuredProviders []controller.Provider - oidcClients []config.OIDCClientConfig + oidcClients []model.OIDCClientConfig } services Services } -func NewBootstrapApp(config config.Config) *BootstrapApp { +func NewBootstrapApp(config model.Config) *BootstrapApp { return &BootstrapApp{ config: config, } @@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error { return err } - app.context.users = users + app.context.localUsers = users // Setup OAuth providers app.context.oauthProviders = app.config.OAuth.Providers @@ -88,7 +88,7 @@ func (app *BootstrapApp) Setup() error { for id, provider := range app.context.oauthProviders { if provider.Name == "" { - if name, ok := config.OverrideProviders[id]; ok { + if name, ok := model.OverrideProviders[id]; ok { provider.Name = name } else { provider.Name = utils.Capitalize(id) @@ -115,14 +115,14 @@ func (app *BootstrapApp) Setup() error { // Cookie names app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) cookieId := strings.Split(app.context.uuid, "-")[0] - app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) - app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) - app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) - app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId) + app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) + app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) + app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) + app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) // Dumps tlog.App.Trace().Interface("config", app.config).Msg("Config dump") - tlog.App.Trace().Interface("users", app.context.users).Msg("Users dump") + tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump") tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump") tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain") tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name") @@ -171,7 +171,7 @@ func (app *BootstrapApp) Setup() error { }) } - if services.authService.LdapAuthConfigured() { + if services.authService.LDAPAuthConfigured() { configuredProviders = append(configuredProviders, controller.Provider{ Name: "LDAP", ID: "ldap", @@ -244,7 +244,7 @@ func (app *BootstrapApp) heartbeatRoutine() { var body heartbeat body.UUID = app.context.uuid - body.Version = config.Version + body.Version = model.Version bodyJson, err := json.Marshal(body) @@ -257,7 +257,7 @@ func (app *BootstrapApp) heartbeatRoutine() { Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond } - heartbeatURL := config.ApiServer + "/v1/instances/heartbeat" + heartbeatURL := model.APIServer + "/v1/instances/heartbeat" for range ticker.C { tlog.App.Debug().Msg("Sending heartbeat") diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 91d36ac2..53cb8504 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -4,9 +4,9 @@ import ( "fmt" "slices" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/middleware" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/gin-gonic/gin" ) @@ -14,7 +14,7 @@ import ( var DEV_MODES = []string{"main", "test", "development"} func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { - if !slices.Contains(DEV_MODES, config.Version) { + if !slices.Contains(DEV_MODES, model.Version) { gin.SetMode(gin.ReleaseMode) } @@ -30,7 +30,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { } contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.context.cookieDomain, + CookieDomain: app.context.cookieDomain, + SessionCookieName: app.context.sessionCookieName, }, app.services.authService, app.services.oauthBrokerService) err := contextMiddleware.Init() @@ -98,7 +99,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { proxyController.SetupRoutes() userController := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: app.context.cookieDomain, + CookieDomain: app.context.cookieDomain, + SessionCookieName: app.context.sessionCookieName, }, apiRouter, app.services.authService) userController.SetupRoutes() diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 91e2b50b..fc2357bc 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -22,14 +22,14 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er services := Services{} ldapService := service.NewLdapService(service.LdapServiceConfig{ - Address: app.config.Ldap.Address, - BindDN: app.config.Ldap.BindDN, - BindPassword: app.config.Ldap.BindPassword, - BaseDN: app.config.Ldap.BaseDN, - Insecure: app.config.Ldap.Insecure, - SearchFilter: app.config.Ldap.SearchFilter, - AuthCert: app.config.Ldap.AuthCert, - AuthKey: app.config.Ldap.AuthKey, + Address: app.config.LDAP.Address, + BindDN: app.config.LDAP.BindDN, + BindPassword: app.config.LDAP.BindPassword, + BaseDN: app.config.LDAP.BaseDN, + Insecure: app.config.LDAP.Insecure, + SearchFilter: app.config.LDAP.SearchFilter, + AuthCert: app.config.LDAP.AuthCert, + AuthKey: app.config.LDAP.AuthKey, }) err := ldapService.Init() @@ -89,7 +89,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er services.oauthBrokerService = oauthBrokerService authService := service.NewAuthService(service.AuthServiceConfig{ - Users: app.context.users, + LocalUsers: app.context.localUsers, OauthWhitelist: app.config.OAuth.Whitelist, SessionExpiry: app.config.Auth.SessionExpiry, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, @@ -99,7 +99,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er LoginMaxRetries: app.config.Auth.LoginMaxRetries, SessionCookieName: app.context.sessionCookieName, IP: app.config.Auth.IP, - LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL, + LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, }, services.ldapService, queries, services.oauthBrokerService) err = authService.Init() diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index da53303b..3362d0de 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -4,7 +4,7 @@ import ( "fmt" "net/url" - "github.com/tinyauthapp/tinyauth/internal/utils" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/gin-gonic/gin" @@ -19,7 +19,7 @@ type UserContextResponse struct { Email string `json:"email"` Provider string `json:"provider"` OAuth bool `json:"oauth"` - TotpPending bool `json:"totpPending"` + TOTPPending bool `json:"totpPending"` OAuthName string `json:"oauthName"` } @@ -76,30 +76,31 @@ func (controller *ContextController) SetupRoutes() { } func (controller *ContextController) userContextHandler(c *gin.Context) { - context, err := utils.GetContext(c) - - userContext := UserContextResponse{ - Status: 200, - Message: "Success", - IsLoggedIn: context.IsLoggedIn, - Username: context.Username, - Name: context.Name, - Email: context.Email, - Provider: context.Provider, - OAuth: context.OAuth, - TotpPending: context.TotpPending, - OAuthName: context.OAuthName, - } + context, err := new(model.UserContext).NewFromGin(c) if err != nil { tlog.App.Debug().Err(err).Msg("No user context found in request") - userContext.Status = 401 - userContext.Message = "Unauthorized" - userContext.IsLoggedIn = false - c.JSON(200, userContext) + c.JSON(200, UserContextResponse{ + Status: 401, + Message: "Unauthorized", + IsLoggedIn: false, + }) return } + userContext := UserContextResponse{ + Status: 200, + Message: "Success", + IsLoggedIn: context.Authenticated, + Username: context.GetUsername(), + Name: context.GetName(), + Email: context.GetEmail(), + Provider: context.ProviderName(), + OAuth: context.IsOAuth(), + TOTPPending: context.TOTPPending(), + OAuthName: context.OAuthName(), + } + c.JSON(200, userContext) } diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 2329425b..12a8e22b 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -7,11 +7,11 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/stretchr/testify/assert" ) func TestContextController(t *testing.T) { @@ -79,12 +79,16 @@ func TestContextController(t *testing.T) { description: "Ensure user context returns when authorized", middlewares: []gin.HandlerFunc{ func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "johndoe", - Name: "John Doe", - Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), - Provider: "local", - IsLoggedIn: true, + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "johndoe", + Name: "John Doe", + Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + }, + }, }) }, }, diff --git a/internal/controller/controller.go b/internal/controller/controller.go new file mode 100644 index 00000000..a1ca59ba --- /dev/null +++ b/internal/controller/controller.go @@ -0,0 +1,12 @@ +package controller + +type UnauthorizedQuery struct { + Username string `url:"username"` + Resource string `url:"resource"` + GroupErr bool `url:"groupErr"` + IP string `url:"ip"` +} + +type RedirectQuery struct { + RedirectURI string `url:"redirect_uri"` +} diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 4133b849..439c57dc 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" @@ -176,7 +175,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted") tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted") - queries, err := query.Values(config.UnauthorizedQuery{ + queries, err := query.Values(UnauthorizedQuery{ Username: user.Email, }) @@ -236,7 +235,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { tlog.App.Error().Err(err).Msg("Failed to create session cookie") @@ -244,6 +243,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) if controller.isOidcRequest(oauthPendingSession.CallbackParams) { @@ -259,7 +260,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } if oauthPendingSession.CallbackParams.RedirectURI != "" { - queries, err := query.Values(config.RedirectQuery{ + queries, err := query.Values(RedirectQuery{ RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, }) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 8a08fd69..5e3f75f5 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -111,14 +112,14 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - userContext, err := utils.GetContext(c) + userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "") return } - if !userContext.IsLoggedIn { + if !userContext.Authenticated { controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "") return } @@ -151,7 +152,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { } // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too. - sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID)) + sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID)) code := utils.GenerateString(32) // Before storing the code, delete old session @@ -170,7 +171,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { // We also need a snapshot of the user that authorized this (skip if no openid scope) if slices.Contains(strings.Fields(req.Scope), "openid") { - err = controller.oidc.StoreUserinfo(c, sub, userContext, req) + err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) if err != nil { tlog.App.Error().Err(err).Msg("Failed to insert user info into database") diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index a09697bf..150540fc 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -12,14 +12,14 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestOIDCController(t *testing.T) { @@ -27,7 +27,7 @@ func TestOIDCController(t *testing.T) { tempDir := t.TempDir() oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]config.OIDCClientConfig{ + Clients: map[string]model.OIDCClientConfig{ "test": { ClientID: "some-client-id", ClientSecret: "some-client-secret", @@ -44,12 +44,16 @@ func TestOIDCController(t *testing.T) { controllerCfg := controller.OIDCControllerConfig{} simpleCtx := func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "test", - Name: "Test User", - Email: "test@example.com", - IsLoggedIn: true, - Provider: "local", + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "test", + Name: "Test User", + Email: "test@example.com", + }, + }, }) c.Next() } @@ -848,7 +852,7 @@ func TestOIDCController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(config.Config{}) + app := bootstrap.NewBootstrapApp(model.Config{}) db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 724c6f6f..7cd01969 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -8,7 +8,7 @@ import ( "regexp" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -103,7 +103,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { clientIP := c.ClientIP() - if controller.auth.IsBypassedIP(acls.IP, clientIP) { + if controller.auth.IsBypassedIP(clientIP, acls) { controller.setHeaders(c, acls) c.JSON(200, gin.H{ "status": 200, @@ -112,7 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path) + authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) if err != nil { tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") @@ -130,8 +130,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if !controller.auth.CheckIP(acls.IP, clientIP) { - queries, err := query.Values(config.UnauthorizedQuery{ + if !controller.auth.CheckIP(clientIP, acls) { + queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], IP: clientIP, }) @@ -157,28 +157,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - var userContext config.UserContext - - context, err := utils.GetContext(c) + userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Debug().Msg("No user context found in request, treating as not logged in") - userContext = config.UserContext{ - IsLoggedIn: false, + tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated") + userContext = &model.UserContext{ + Authenticated: false, } - } else { - userContext = context } tlog.App.Trace().Interface("context", userContext).Msg("User context from request") - if userContext.IsLoggedIn { - userAllowed := controller.auth.IsUserAllowed(c, userContext, acls) + if userContext.Authenticated { + userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) if !userAllowed { - tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") + tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") - queries, err := query.Values(config.UnauthorizedQuery{ + queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], }) @@ -188,10 +184,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if userContext.OAuth { - queries.Set("username", userContext.Email) + if userContext.IsOAuth() { + queries.Set("username", userContext.GetEmail()) } else { - queries.Set("username", userContext.Username) + queries.Set("username", userContext.GetUsername()) } redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) @@ -209,19 +205,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if userContext.OAuth || userContext.Provider == "ldap" { + if userContext.IsOAuth() || userContext.IsLDAP() { var groupOK bool - if userContext.OAuth { - groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups) + if userContext.IsOAuth() { + groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls) } else { - groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups) + groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls) } if !groupOK { - tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") + tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") - queries, err := query.Values(config.UnauthorizedQuery{ + queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], GroupErr: true, }) @@ -232,10 +228,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if userContext.OAuth { - queries.Set("username", userContext.Email) + if userContext.IsOAuth() { + queries.Set("username", userContext.GetEmail()) } else { - queries.Set("username", userContext.Username) + queries.Set("username", userContext.GetUsername()) } redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) @@ -254,17 +250,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } } - c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) - c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) - c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) + c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername())) + c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName())) + c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail())) - if userContext.Provider == "ldap" { - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups)) - } else if userContext.Provider != "local" { - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) + if userContext.IsLDAP() { + c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ","))) } - c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub)) + if userContext.IsOAuth() { + c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ","))) + c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub)) + } controller.setHeaders(c, acls) @@ -275,7 +272,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(config.RedirectQuery{ + queries, err := query.Values(RedirectQuery{ RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), }) @@ -299,9 +296,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { c.Redirect(http.StatusTemporaryRedirect, redirectURL) } -func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) { +func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { c.Header("Authorization", c.Request.Header.Get("Authorization")) + if acls == nil { + return + } + headers := utils.ParseHeaders(acls.Response.Headers) for key, value := range headers { @@ -313,7 +314,7 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) { if acls.Response.BasicAuth.Username != "" && basicPassword != "" { tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) } } diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 8efbd31c..66c24a5e 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -6,14 +6,14 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestProxyController(t *testing.T) { @@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) { tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ - Users: []config.User{ + LocalUsers: &[]model.LocalUser{ { Username: "testuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password @@ -29,7 +29,7 @@ func TestProxyController(t *testing.T) { { Username: "totpuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", + TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", }, }, SessionExpiry: 10, // 10 seconds, useful for testing @@ -43,28 +43,28 @@ func TestProxyController(t *testing.T) { AppURL: "https://tinyauth.example.com", } - acls := map[string]config.App{ + acls := map[string]model.App{ "app_path_allow": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "path-allow.example.com", }, - Path: config.AppPath{ + Path: model.AppPath{ Allow: "/allowed", }, }, "app_user_allow": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "user-allow.example.com", }, - Users: config.AppUsers{ + Users: model.AppUsers{ Allow: "testuser", }, }, "ip_bypass": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "ip-bypass.example.com", }, - IP: config.AppIP{ + IP: model.AppIP{ Bypass: []string{"10.10.10.10"}, }, }, @@ -74,24 +74,32 @@ func TestProxyController(t *testing.T) { Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` simpleCtx := func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "testuser", - Name: "Testuser", - Email: "testuser@example.com", - IsLoggedIn: true, - Provider: "local", + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + }, }) c.Next() } simpleCtxTotp := func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "Totpuser", - Email: "totpuser@example.com", - IsLoggedIn: true, - Provider: "local", - TotpEnabled: true, + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "totpuser", + Name: "Totpuser", + Email: "totpuser@example.com", + }, + TOTPEnabled: true, + }, }) c.Next() } @@ -391,9 +399,9 @@ func TestProxyController(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) + oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(config.Config{}) + app := bootstrap.NewBootstrapApp(model.Config{}) db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 187b33b9..3d5b2215 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -1,10 +1,12 @@ package controller import ( + "errors" "fmt" + "net/http" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" @@ -24,7 +26,8 @@ type TotpRequest struct { } type UserControllerConfig struct { - CookieDomain string + CookieDomain string + SessionCookieName string } type UserController struct { @@ -77,20 +80,28 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - userSearch := controller.auth.SearchUser(req.Username) + search, err := controller.auth.SearchUser(req.Username) - if userSearch.Type == "unknown" { - tlog.App.Warn().Str("username", req.Username).Msg("User not found") - controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "user not found") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + tlog.App.Warn().Str("username", req.Username).Msg("User not found") + controller.auth.RecordLoginAttempt(req.Username, false) + tlog.AuditLoginFailure(c, req.Username, "username", "user not found") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", }) return } - if !controller.auth.VerifyUser(userSearch, req.Password) { + if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { tlog.App.Warn().Str("username", req.Username).Msg("Invalid password") controller.auth.RecordLoginAttempt(req.Username, false) tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") @@ -106,30 +117,26 @@ func (controller *UserController) loginHandler(c *gin.Context) { controller.auth.RecordLoginAttempt(req.Username, true) - var localUser *config.User - if userSearch.Type == "local" { - user := controller.auth.GetLocalUser(userSearch.Username) - localUser = &user - } + var localUser *model.LocalUser - if userSearch.Type == "local" && localUser != nil { - user := *localUser + if search.Type == model.UserLocal { + localUser = controller.auth.GetLocalUser(req.Username) - if user.TotpSecret != "" { + if localUser.TOTPSecret != "" { tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") - name := user.Attributes.Name + name := localUser.Attributes.Name if name == "" { - name = utils.Capitalize(user.Username) + name = utils.Capitalize(localUser.Username) } - email := user.Attributes.Email + email := localUser.Attributes.Email if email == "" { - email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain) + email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain) } - err := controller.auth.CreateSessionCookie(c, &repository.Session{ - Username: user.Username, + cookie, err := controller.auth.CreateSession(c, repository.Session{ + Username: localUser.Username, Name: name, Email: email, Provider: "local", @@ -145,6 +152,8 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "TOTP required", @@ -161,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { Provider: "local", } - if userSearch.Type == "local" && localUser != nil { + if search.Type == model.UserLocal { if localUser.Attributes.Name != "" { sessionCookie.Name = localUser.Attributes.Name } @@ -170,13 +179,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - if userSearch.Type == "ldap" { + if search.Type == model.UserLDAP { sessionCookie.Provider = "ldap" } tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { tlog.App.Error().Err(err).Msg("Failed to create session cookie") @@ -187,6 +196,8 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", @@ -196,13 +207,47 @@ func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) { tlog.App.Debug().Msg("Logout request received") - controller.auth.DeleteSessionCookie(c) + uuid, err := c.Cookie(controller.config.SessionCookieName) - context, err := utils.GetContext(c) - if err == nil && context.IsLoggedIn { - tlog.AuditLogout(c, context.Username, context.Provider) + if err != nil { + if errors.Is(err, http.ErrNoCookie) { + tlog.App.Warn().Msg("No session cookie found on logout request") + c.JSON(200, gin.H{ + "status": 200, + "message": "Logout successful", + }) + return + } + tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + cookie, err := controller.auth.DeleteSession(c, uuid) + + if err != nil { + tlog.App.Error().Err(err).Msg("Error deleting session on logout") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + context, err := new(model.UserContext).NewFromGin(c) + + if err == nil { + tlog.AuditLogout(c, context.GetUsername(), context.ProviderName()) + } else { + tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username") + tlog.AuditLogout(c, "unknown", "unknown") } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "Logout successful", @@ -222,7 +267,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - context, err := utils.GetContext(c) + context, err := new(model.UserContext).NewFromGin(c) if err != nil { tlog.App.Error().Err(err).Msg("Failed to get user context") @@ -233,7 +278,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - if !context.TotpPending { + if !context.TOTPPending() { tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") c.JSON(401, gin.H{ "status": 401, @@ -242,12 +287,12 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt") + tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") - isLocked, remaining := controller.auth.IsAccountLocked(context.Username) + isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) if isLocked { - tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts") + tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -257,14 +302,23 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - user := controller.auth.GetLocalUser(context.Username) + user := controller.auth.GetLocalUser(context.GetUsername()) - ok := totp.Validate(req.Code, user.TotpSecret) + if user == nil { + tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + ok := totp.Validate(req.Code, user.TOTPSecret) if !ok { - tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code") - controller.auth.RecordLoginAttempt(context.Username, false) - tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code") + tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code") + controller.auth.RecordLoginAttempt(context.GetUsername(), false) + tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -272,10 +326,10 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful") - tlog.AuditLoginSuccess(c, context.Username, "totp") + tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful") + tlog.AuditLoginSuccess(c, context.GetUsername(), "totp") - controller.auth.RecordLoginAttempt(context.Username, true) + controller.auth.RecordLoginAttempt(context.GetUsername(), true) sessionCookie := repository.Session{ Username: user.Username, @@ -293,7 +347,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { tlog.App.Error().Err(err).Msg("Failed to create session cookie") @@ -304,6 +358,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 65ef15ef..18544c43 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -10,14 +10,14 @@ import ( "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestUserController(t *testing.T) { @@ -25,7 +25,7 @@ func TestUserController(t *testing.T) { tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ - Users: []config.User{ + LocalUsers: &[]model.LocalUser{ { Username: "testuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password @@ -33,12 +33,12 @@ func TestUserController(t *testing.T) { { Username: "totpuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", + TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", }, { Username: "attruser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - Attributes: config.UserAttributes{ + Attributes: model.UserAttributes{ Name: "Alice Smith", Email: "alice@example.com", }, @@ -46,8 +46,8 @@ func TestUserController(t *testing.T) { { Username: "attrtotpuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - Attributes: config.UserAttributes{ + TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", + Attributes: model.UserAttributes{ Name: "Bob Jones", Email: "bob@example.com", }, @@ -61,7 +61,54 @@ func TestUserController(t *testing.T) { } userControllerCfg := controller.UserControllerConfig{ - CookieDomain: "example.com", + CookieDomain: "example.com", + SessionCookieName: "tinyauth-session", + } + + totpCtx := func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: false, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "totpuser", + Name: "Totpuser", + Email: "totpuser@example.com", + }, + TOTPPending: true, + TOTPEnabled: true, + }, + }) + } + + totpAttrCtx := func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: false, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "attrtotpuser", + Name: "Bob Jones", + Email: "bob@example.com", + }, + TOTPPending: true, + TOTPEnabled: true, + }, + }) + } + + simpleCtx := func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Test User", + Email: "testuser@example.com", + }, + }, + }) } type testCase struct { @@ -94,7 +141,7 @@ func TestUserController(t *testing.T) { assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) assert.Equal(t, "example.com", cookie.Domain) - assert.Equal(t, 10, cookie.MaxAge) + assert.Equal(t, 9, cookie.MaxAge) }, }, { @@ -183,12 +230,14 @@ func TestUserController(t *testing.T) { assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) assert.Equal(t, "example.com", cookie.Domain) - assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions + assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions }, }, { description: "Should be able to logout", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { // First login to get a session cookie loginReq := controller.LoginRequest{ @@ -204,9 +253,10 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + cookies := recorder.Result().Cookies() + assert.Len(t, cookies, 1) - cookie := recorder.Result().Cookies()[0] + cookie := cookies[0] assert.Equal(t, "tinyauth-session", cookie.Name) // Now logout using the session cookie @@ -217,17 +267,20 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + cookies = recorder.Result().Cookies() + assert.Len(t, cookies, 1) - logoutCookie := recorder.Result().Cookies()[0] - assert.Equal(t, "tinyauth-session", logoutCookie.Name) - assert.Equal(t, "", logoutCookie.Value) - assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie + cookie = cookies[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + assert.Equal(t, "", cookie.Value) + assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie }, }, { description: "Should be able to login with totp", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + totpCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) assert.NoError(t, err) @@ -253,12 +306,14 @@ func TestUserController(t *testing.T) { assert.Equal(t, "tinyauth-session", totpCookie.Name) assert.True(t, totpCookie.HttpOnly) assert.Equal(t, "example.com", totpCookie.Domain) - assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time + assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time }, }, { description: "Totp should rate limit on multiple invalid attempts", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + totpCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { for range 3 { totpReq := controller.TotpRequest{ @@ -328,7 +383,9 @@ func TestUserController(t *testing.T) { }, { description: "TOTP completion uses name and email from user attributes", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + totpAttrCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) require.NoError(t, err) @@ -349,9 +406,9 @@ func TestUserController(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) + oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(config.Config{}) + app := bootstrap.NewBootstrapApp(model.Config{}) db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) @@ -379,33 +436,6 @@ func TestUserController(t *testing.T) { authService.ClearRateLimitsTestingOnly() } - setTotpMiddlewareOverrides := map[string]config.UserContext{ - "Should be able to login with totp": { - Username: "totpuser", - Name: "Totpuser", - Email: "totpuser@example.com", - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }, - "Totp should rate limit on multiple invalid attempts": { - Username: "totpuser", - Name: "Totpuser", - Email: "totpuser@example.com", - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }, - "TOTP completion uses name and email from user attributes": { - Username: "attrtotpuser", - Name: "Bob Jones", - Email: "bob@example.com", - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }, - } - for _, test := range tests { beforeEach() t.Run(test.description, func(t *testing.T) { @@ -415,15 +445,6 @@ func TestUserController(t *testing.T) { router.Use(middleware) } - // Gin is stupid and doesn't allow setting a middleware after the groups - // so we need to do some stupid overrides here - if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok { - ctx := ctx - router.Use(func(c *gin.Context) { - c.Set("context", &ctx) - }) - } - group := router.Group("/api") gin.SetMode(gin.TestMode) diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 7d8d05f5..7dcf2bdc 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -8,14 +8,14 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestWellKnownController(t *testing.T) { @@ -23,7 +23,7 @@ func TestWellKnownController(t *testing.T) { tempDir := t.TempDir() oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]config.OIDCClientConfig{ + Clients: map[string]model.OIDCClientConfig{ "test": { ClientID: "some-client-id", ClientSecret: "some-client-secret", @@ -101,7 +101,7 @@ func TestWellKnownController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(config.Config{}) + app := bootstrap.NewBootstrapApp(model.Config{}) db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 651d9d85..a5773dbd 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -1,10 +1,13 @@ package middleware import ( + "context" + "fmt" + "net/http" "strings" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -33,7 +36,8 @@ var ( ) type ContextMiddlewareConfig struct { - CookieDomain string + CookieDomain string + SessionCookieName string } type ContextMiddleware struct { @@ -61,200 +65,191 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return } - cookie, err := m.auth.GetSessionCookie(c) + uuid, err := c.Cookie(m.config.SessionCookieName) - if err != nil { - tlog.App.Debug().Err(err).Msg("No valid session cookie found") - goto basic - } - - if cookie.TotpPending { - c.Set("context", &config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }) - c.Next() - return - } - - switch cookie.Provider { - case "local", "ldap": - userSearch := m.auth.SearchUser(cookie.Username) + if err == nil { + userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid) - if userSearch.Type == "unknown" { - tlog.App.Debug().Msg("User from session cookie not found") - m.auth.DeleteSessionCookie(c) - goto basic - } + if err == nil { + if cookie != nil { + http.SetCookie(c.Writer, cookie) + } - if userSearch.Type != cookie.Provider { - tlog.App.Warn().Msg("User type from session cookie does not match user search type") - m.auth.DeleteSessionCookie(c) + tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername()) + c.Set("context", userContext) c.Next() return + } else { + tlog.App.Error().Msgf("Error authenticating session cookie: %v", err) } + } - var ldapGroups []string - var localAttributes config.UserAttributes - - if cookie.Provider == "ldap" { - ldapUser, err := m.auth.GetLdapUser(userSearch.Username) + username, password, ok := c.Request.BasicAuth() - if err != nil { - tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details") - c.Next() - return - } + if ok { + userContext, headers, err := m.basicAuth(username, password) - ldapGroups = ldapUser.Groups + if err != nil { + tlog.App.Error().Msgf("Error authenticating basic auth: %v", err) + c.Next() + return } - if cookie.Provider == "local" { - localUser := m.auth.GetLocalUser(cookie.Username) - localAttributes = localUser.Attributes + for k, v := range headers { + c.Header(k, v) } - m.auth.RefreshSessionCookie(c) - c.Set("context", &config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - IsLoggedIn: true, - LdapGroups: strings.Join(ldapGroups, ","), - Attributes: localAttributes, - }) + c.Set("context", userContext) c.Next() return - default: - _, exists := m.broker.GetService(cookie.Provider) + } - if !exists { - tlog.App.Debug().Msg("OAuth provider from session cookie not found") - m.auth.DeleteSessionCookie(c) - goto basic - } + c.Next() + } +} - if !m.auth.IsEmailWhitelisted(cookie.Email) { - tlog.App.Debug().Msg("Email from session cookie not whitelisted") - m.auth.DeleteSessionCookie(c) - goto basic - } +func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) { + session, err := m.auth.GetSession(ctx, uuid) - m.auth.RefreshSessionCookie(c) - c.Set("context", &config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - OAuthGroups: cookie.OAuthGroups, - OAuthName: cookie.OAuthName, - OAuthSub: cookie.OAuthSub, - IsLoggedIn: true, - OAuth: true, - }) - c.Next() - return + if err != nil { + return nil, nil, fmt.Errorf("error retrieving session: %w", err) + } + + userContext, err := new(model.UserContext).NewFromSession(session) + + if err != nil { + return nil, nil, fmt.Errorf("error creating user context from session: %w", err) + } + + if userContext.Provider == model.ProviderLocal && + userContext.Local.TOTPPending { + userContext.Local.TOTPEnabled = true + return userContext, nil, nil + } + + switch userContext.Provider { + case model.ProviderLocal: + user := m.auth.GetLocalUser(userContext.Local.Username) + + if user == nil { + return nil, nil, fmt.Errorf("local user not found") } - basic: - basic := m.auth.GetBasicAuth(c) + userContext.Local.Attributes = user.Attributes - if basic == nil { - tlog.App.Debug().Msg("No basic auth provided") - c.Next() - return + if userContext.Local.Attributes.Name == "" { + userContext.Local.Attributes.Name = utils.Capitalize(user.Username) } - locked, remaining := m.auth.IsAccountLocked(basic.Username) + if userContext.Local.Attributes.Email == "" { + userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain) + } + case model.ProviderLDAP: + search, err := m.auth.SearchUser(userContext.LDAP.Username) - if locked { - tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining) - c.Writer.Header().Add("x-tinyauth-lock-locked", "true") - c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) - c.Next() - return + if err != nil { + return nil, nil, fmt.Errorf("error searching for ldap user: %w", err) + } + + if search.Type != model.UserLDAP { + return nil, nil, fmt.Errorf("user from session cookie is not ldap") } - userSearch := m.auth.SearchUser(basic.Username) + user, err := m.auth.GetLDAPUser(search.Username) - if userSearch.Type == "unknown" || userSearch.Type == "error" { - m.auth.RecordLoginAttempt(basic.Username, false) - tlog.App.Debug().Msg("User from basic auth not found") - c.Next() - return + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) } - if !m.auth.VerifyUser(userSearch, basic.Password) { - m.auth.RecordLoginAttempt(basic.Username, false) - tlog.App.Debug().Msg("Invalid password for basic auth user") - c.Next() - return + userContext.LDAP.Groups = user.Groups + userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username) + userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain) + case model.ProviderOAuth: + _, exists := m.broker.GetService(userContext.OAuth.ID) + + if !exists { + return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID) } - m.auth.RecordLoginAttempt(basic.Username, true) + if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { + m.auth.DeleteSession(ctx, uuid) + return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) + } + } - switch userSearch.Type { - case "local": - tlog.App.Debug().Msg("Basic auth user is local") + cookie, err := m.auth.RefreshSession(ctx, uuid) - user := m.auth.GetLocalUser(basic.Username) + if err != nil { + return nil, nil, fmt.Errorf("error refreshing session: %w", err) + } - if user.TotpSecret != "" { - tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth") - return - } + return userContext, cookie, nil +} - name := utils.Capitalize(user.Username) - if user.Attributes.Name != "" { - name = user.Attributes.Name - } - email := utils.CompileUserEmail(user.Username, m.config.CookieDomain) - if user.Attributes.Email != "" { - email = user.Attributes.Email - } +func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) { + headers := make(map[string]string) + userContext := new(model.UserContext) + locked, remaining := m.auth.IsAccountLocked(username) - c.Set("context", &config.UserContext{ - Username: user.Username, - Name: name, - Email: email, - Provider: "local", - IsLoggedIn: true, - IsBasicAuth: true, - Attributes: user.Attributes, - }) - c.Next() - return - case "ldap": - tlog.App.Debug().Msg("Basic auth user is LDAP") + if locked { + tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) + headers["x-tinyauth-lock-locked"] = "true" + headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) + return nil, headers, nil + } - ldapUser, err := m.auth.GetLdapUser(basic.Username) + search, err := m.auth.SearchUser(username) - if err != nil { - tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details") - c.Next() - return - } + if err != nil { + return nil, nil, fmt.Errorf("error searching for user: %w", err) + } - c.Set("context", &config.UserContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain), - Provider: "ldap", - IsLoggedIn: true, - LdapGroups: strings.Join(ldapUser.Groups, ","), - IsBasicAuth: true, - }) - c.Next() - return + err = m.auth.CheckUserPassword(*search, password) + + if err != nil { + m.auth.RecordLoginAttempt(username, false) + return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err) + } + + m.auth.RecordLoginAttempt(username, true) + + switch search.Type { + case model.UserLocal: + user := m.auth.GetLocalUser(username) + + if user.TOTPSecret != "" { + return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username) } - c.Next() + userContext.Local = &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain), + }, + Attributes: user.Attributes, + } + userContext.Provider = model.ProviderLocal + case model.UserLDAP: + user, err := m.auth.GetLDAPUser(username) + + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) + } + + userContext.LDAP = &model.LDAPContext{ + BaseContext: model.BaseContext{ + Username: username, + Name: utils.Capitalize(username), + Email: utils.CompileUserEmail(username, m.config.CookieDomain), + }, + Groups: user.Groups, + } + userContext.Provider = model.ProviderLDAP } + + userContext.Authenticated = true + return userContext, nil, nil } func (m *ContextMiddleware) isIgnorePath(path string) bool { diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go new file mode 100644 index 00000000..6e91a585 --- /dev/null +++ b/internal/middleware/context_middleware_test.go @@ -0,0 +1,330 @@ +package middleware_test + +import ( + "context" + "encoding/base64" + "net/http" + "net/http/httptest" + "path" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/bootstrap" + "github.com/tinyauthapp/tinyauth/internal/middleware" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/utils/tlog" +) + +func TestContextMiddleware(t *testing.T) { + tlog.NewTestLogger().Init() + tempDir := t.TempDir() + + authServiceCfg := service.AuthServiceConfig{ + LocalUsers: &[]model.LocalUser{ + { + Username: "testuser", + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password + }, + { + Username: "totpuser", + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password + TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", + }, + }, + SessionExpiry: 10, // 10 seconds, useful for testing + CookieDomain: "example.com", + LoginTimeout: 10, // 10 seconds, useful for testing + LoginMaxRetries: 3, + SessionCookieName: "tinyauth-session", + } + + middlewareCfg := middleware.ContextMiddlewareConfig{ + CookieDomain: "example.com", + SessionCookieName: "tinyauth-session", + } + + basicAuthHeader := func(username, password string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) + } + + seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) { + t.Helper() + _, err := queries.CreateSession(context.Background(), params) + require.NoError(t, err) + } + + type runArgs struct { + do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) + queries *repository.Queries + } + + type testCase struct { + description string + run func(t *testing.T, args runArgs) + } + + tests := []testCase{ + { + description: "Skip path bypasses auth processing", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/healthz", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "No credentials yields no context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Valid session cookie sets authenticated local context", + run: func(t *testing.T, args runArgs) { + uuid := "session-valid-local" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "testuser", + Provider: "local", + Expiry: time.Now().Add(10 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, model.ProviderLocal, userCtx.Provider) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + require.NotNil(t, userCtx.Local) + assert.False(t, userCtx.Local.TOTPEnabled) + }, + }, + { + description: "Session cookie with totp pending sets unauthenticated context with totp enabled", + run: func(t *testing.T, args runArgs) { + uuid := "session-totp-pending" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "totpuser", + Provider: "local", + TotpPending: true, + Expiry: time.Now().Add(60 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, "totpuser", userCtx.GetUsername()) + assert.False(t, userCtx.Authenticated) + require.NotNil(t, userCtx.Local) + assert.True(t, userCtx.Local.TOTPPending) + assert.True(t, userCtx.Local.TOTPEnabled) + }, + }, + { + description: "Unknown session cookie yields no context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"}) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Session for missing local user yields no context", + run: func(t *testing.T, args runArgs) { + uuid := "session-deleted-user" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "ghostuser", + Provider: "local", + Expiry: time.Now().Add(10 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Expired session cookie yields no context", + run: func(t *testing.T, args runArgs) { + uuid := "session-expired" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "testuser", + Provider: "local", + Expiry: time.Now().Add(-1 * time.Second).Unix(), + CreatedAt: time.Now().Add(-10 * time.Second).Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Valid basic auth sets authenticated local context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, model.ProviderLocal, userCtx.Provider) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + }, + }, + { + description: "Invalid basic auth password yields no context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword")) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Basic auth is rejected for users with totp", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("totpuser", "password")) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Locked account on basic auth sets lock headers", + run: func(t *testing.T, args runArgs) { + for range 3 { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword")) + args.do(req) + } + + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, recorder := args.do(req) + + assert.Nil(t, userCtx) + assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked")) + assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset")) + }, + }, + { + description: "Cookie auth takes precedence over basic auth", + run: func(t *testing.T, args runArgs) { + uuid := "session-precedence" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "testuser", + Provider: "local", + Expiry: time.Now().Add(10 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + req.Header.Set("Authorization", basicAuthHeader("totpuser", "password")) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + }, + }, + { + description: "Ensure fallback to basic auth when cookie is missing", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + }, + }, + } + + oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) + + app := bootstrap.NewBootstrapApp(model.Config{}) + + db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + require.NoError(t, err) + + queries := repository.New(db) + + ldap := service.NewLdapService(service.LdapServiceConfig{}) + err = ldap.Init() + require.NoError(t, err) + + broker := service.NewOAuthBrokerService(oauthBrokerCfgs) + err = broker.Init() + require.NoError(t, err) + + authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + err = authService.Init() + require.NoError(t, err) + + contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker) + err = contextMiddleware.Init() + require.NoError(t, err) + + for _, test := range tests { + authService.ClearRateLimitsTestingOnly() + t.Run(test.description, func(t *testing.T) { + gin.SetMode(gin.TestMode) + + do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) { + var captured *model.UserContext + router := gin.New() + router.Use(contextMiddleware.Middleware()) + handler := func(c *gin.Context) { + if val, exists := c.Get("context"); exists { + captured, _ = val.(*model.UserContext) + } + } + router.GET("/api/test", handler) + router.GET("/api/healthz", handler) + + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + return captured, recorder + } + + test.run(t, runArgs{do: do, queries: queries}) + }) + } + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) + }) +} diff --git a/internal/config/config.go b/internal/model/config.go similarity index 87% rename from internal/config/config.go rename to internal/model/config.go index e364b458..d0bb9d15 100644 --- a/internal/config/config.go +++ b/internal/model/config.go @@ -1,4 +1,4 @@ -package config +package model // Default configuration func NewDefaultConfiguration() *Config { @@ -29,7 +29,7 @@ func NewDefaultConfiguration() *Config { BackgroundImage: "/background.jpg", WarningsEnabled: true, }, - Ldap: LdapConfig{ + LDAP: LDAPConfig{ Insecure: false, SearchFilter: "(uid=%s)", GroupCacheTTL: 900, // 15 minutes @@ -63,20 +63,6 @@ func NewDefaultConfiguration() *Config { } } -// Version information, set at build time - -var Version = "development" -var CommitHash = "development" -var BuildTimestamp = "0000-00-00T00:00:00Z" - -// Cookie name templates - -var SessionCookieName = "tinyauth-session" -var CSRFCookieName = "tinyauth-csrf" -var RedirectCookieName = "tinyauth-redirect" -var OAuthSessionCookieName = "tinyauth-oauth" - -// Main app config type Config struct { AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"` Database DatabaseConfig `description:"Database configuration." yaml:"database"` @@ -88,7 +74,7 @@ type Config struct { OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"` OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"` UI UIConfig `description:"UI customization." yaml:"ui"` - Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"` + LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"` Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"` LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"` Log LogConfig `description:"Logging configuration." yaml:"log"` @@ -177,7 +163,7 @@ type UIConfig struct { WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"` } -type LdapConfig struct { +type LDAPConfig struct { Address string `description:"LDAP server address." yaml:"address"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` @@ -210,20 +196,6 @@ type ExperimentalConfig struct { ConfigFile string `description:"Path to config file." yaml:"-"` } -// Config loader options - -const DefaultNamePrefix = "TINYAUTH_" - -// OAuth/OIDC config - -type Claims struct { - Sub string `json:"sub"` - Name string `json:"name"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` - Groups any `json:"groups"` -} - type OAuthServiceConfig struct { ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` @@ -246,60 +218,6 @@ type OIDCClientConfig struct { Name string `description:"Client name in UI." yaml:"name"` } -var OverrideProviders = map[string]string{ - "google": "Google", - "github": "GitHub", -} - -// User/session related stuff - -type User struct { - Username string - Password string - TotpSecret string - Attributes UserAttributes -} - -type LdapUser struct { - DN string - Groups []string -} - -type UserSearch struct { - Username string - Type string // local, ldap or unknown -} - -type UserContext struct { - Username string - Name string - Email string - IsLoggedIn bool - IsBasicAuth bool - OAuth bool - Provider string - TotpPending bool - OAuthGroups string - TotpEnabled bool - OAuthName string - OAuthSub string - LdapGroups string - Attributes UserAttributes -} - -// API responses and queries - -type UnauthorizedQuery struct { - Username string `url:"username"` - Resource string `url:"resource"` - GroupErr bool `url:"groupErr"` - IP string `url:"ip"` -} - -type RedirectQuery struct { - RedirectURI string `url:"redirect_uri"` -} - // ACLs type Apps struct { @@ -355,7 +273,3 @@ type AppPath struct { Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"` Block string `description:"Comma-separated list of blocked paths." yaml:"block"` } - -// API server - -var ApiServer = "https://api.tinyauth.app" diff --git a/internal/model/constants.go b/internal/model/constants.go new file mode 100644 index 00000000..d9e85e57 --- /dev/null +++ b/internal/model/constants.go @@ -0,0 +1,23 @@ +package model + +const DefaultNamePrefix = "TINYAUTH_" + +const APIServer = "https://api.tinyauth.app" + +type Claims struct { + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups any `json:"groups"` +} + +var OverrideProviders = map[string]string{ + "google": "Google", + "github": "GitHub", +} + +const SessionCookieName = "tinyauth-session" +const CSRFCookieName = "tinyauth-csrf" +const RedirectCookieName = "tinyauth-redirect" +const OAuthSessionCookieName = "tinyauth-oauth" diff --git a/internal/model/context.go b/internal/model/context.go new file mode 100644 index 00000000..7df204de --- /dev/null +++ b/internal/model/context.go @@ -0,0 +1,251 @@ +package model + +import ( + "errors" + "strings" + + "github.com/gin-gonic/gin" + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +type ProviderType int + +const ( + ProviderLocal ProviderType = iota + ProviderBasicAuth + ProviderOAuth + ProviderLDAP +) + +type UserContext struct { + Authenticated bool + Provider ProviderType + Local *LocalContext + OAuth *OAuthContext + LDAP *LDAPContext +} + +type BaseContext struct { + Username string + Name string + Email string +} + +type LocalContext struct { + BaseContext + TOTPPending bool + TOTPEnabled bool + Attributes UserAttributes +} + +type OAuthContext struct { + BaseContext + Groups []string + Sub string + DisplayName string + ID string +} + +type LDAPContext struct { + BaseContext + Groups []string +} + +func (c *UserContext) IsAuthenticated() bool { + return c.Authenticated +} + +func (c *UserContext) IsLocal() bool { + return c.Provider == ProviderLocal && c.Local != nil +} + +func (c *UserContext) IsOAuth() bool { + return c.Provider == ProviderOAuth && c.OAuth != nil +} + +func (c *UserContext) IsLDAP() bool { + return c.Provider == ProviderLDAP && c.LDAP != nil +} + +func (c *UserContext) IsBasicAuth() bool { + return c.Provider == ProviderBasicAuth && c.Local != nil +} + +func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { + userContextValue, exists := ginctx.Get("context") + + if !exists { + return nil, errors.New("failed to get user context") + } + + userContext, ok := userContextValue.(*UserContext) + + if !ok || userContext == nil { + return nil, errors.New("invalid user context type") + } + + if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil { + return nil, errors.New("incomplete user context") + } + + *c = *userContext + return c, nil +} + +// Compatability layer until we get an excuse to drop in database migrations +func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) { + *c = UserContext{ + Authenticated: !session.TotpPending, + } + + switch session.Provider { + case "local": + c.Provider = ProviderLocal + c.Local = &LocalContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + TOTPPending: session.TotpPending, + } + case "ldap": + c.Provider = ProviderLDAP + c.LDAP = &LDAPContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + } + // By default we assume an unkown name which is oauth + default: + c.Provider = ProviderOAuth + c.OAuth = &OAuthContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + Groups: func() []string { + if session.OAuthGroups == "" { + return nil + } + return strings.Split(session.OAuthGroups, ",") + }(), + Sub: session.OAuthSub, + DisplayName: session.OAuthName, + ID: session.Provider, + } + } + + return c, nil +} + +func (c *UserContext) GetUsername() string { + switch c.Provider { + case ProviderLocal: + if c.Local == nil { + return "" + } + return c.Local.Username + case ProviderLDAP: + if c.LDAP == nil { + return "" + } + return c.LDAP.Username + case ProviderBasicAuth: + if c.Local == nil { + return "" + } + return c.Local.Username + case ProviderOAuth: + if c.OAuth == nil { + return "" + } + return c.OAuth.Username + default: + return "" + } +} + +func (c *UserContext) GetEmail() string { + switch c.Provider { + case ProviderLocal: + if c.Local == nil { + return "" + } + return c.Local.Email + case ProviderLDAP: + if c.LDAP == nil { + return "" + } + return c.LDAP.Email + case ProviderBasicAuth: + if c.Local == nil { + return "" + } + return c.Local.Email + case ProviderOAuth: + if c.OAuth == nil { + return "" + } + return c.OAuth.Email + default: + return "" + } +} + +func (c *UserContext) GetName() string { + switch c.Provider { + case ProviderLocal: + if c.Local == nil { + return "" + } + return c.Local.Name + case ProviderLDAP: + if c.LDAP == nil { + return "" + } + return c.LDAP.Name + case ProviderBasicAuth: + if c.Local == nil { + return "" + } + return c.Local.Name + case ProviderOAuth: + if c.OAuth == nil { + return "" + } + return c.OAuth.Name + default: + return "" + } +} + +func (c *UserContext) ProviderName() string { + switch c.Provider { + case ProviderBasicAuth, ProviderLocal: + return "local" + case ProviderLDAP: + return "ldap" + case ProviderOAuth: + return c.OAuth.DisplayName // compatability + default: + return "unknown" + } +} + +func (c *UserContext) TOTPPending() bool { + if c.Provider == ProviderLocal && c.Local != nil { + return c.Local.TOTPPending + } + return false +} + +func (c *UserContext) OAuthName() string { + if c.Provider == ProviderOAuth && c.OAuth != nil { + return c.OAuth.DisplayName + } + return "" +} diff --git a/internal/model/context_test.go b/internal/model/context_test.go new file mode 100644 index 00000000..b45b9210 --- /dev/null +++ b/internal/model/context_test.go @@ -0,0 +1,276 @@ +package model_test + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func TestContext(t *testing.T) { + newGinCtx := func(value any, set bool) *gin.Context { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + if set { + c.Set("context", value) + } + return c + } + + tests := []struct { + description string + context *model.UserContext + run func(*testing.T, *model.UserContext) any + expected any + }{ + { + description: "IsAuthenticated reflects Authenticated field", + context: &model.UserContext{Authenticated: true}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() }, + expected: true, + }, + { + description: "IsLocal returns true for ProviderLocal", + context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() }, + expected: true, + }, + { + description: "IsOAuth returns true for ProviderOAuth", + context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() }, + expected: true, + }, + { + description: "IsLDAP returns true for ProviderLDAP", + context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() }, + expected: true, + }, + { + description: "IsBasicAuth returns true for ProviderBasicAuth", + context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() }, + expected: true, + }, + { + description: "NewFromSession local session is authenticated and ProviderLocal", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "alice", Email: "alice@example.com", Name: "Alice", + Provider: "local", + }) + require.NoError(t, err) + return [2]any{got.Provider, got.Authenticated} + }, + expected: [2]any{model.ProviderLocal, true}, + }, + { + description: "NewFromSession local session with TotpPending is not authenticated", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "bob", Provider: "local", TotpPending: true, + }) + require.NoError(t, err) + return got.Authenticated + }, + expected: false, + }, + { + description: "NewFromSession ldap session is ProviderLDAP", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "carol", Provider: "ldap", + }) + require.NoError(t, err) + return got.Provider + }, + expected: model.ProviderLDAP, + }, + { + description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "dave", Provider: "github", + OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub", + }) + require.NoError(t, err) + return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups} + }, + expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}}, + }, + { + description: "Local getters return BaseContext fields", + context: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"alice", "alice@example.com", "Alice"}, + }, + { + description: "BasicAuth getters fall back to local fields", + context: &model.UserContext{ + Provider: model.ProviderBasicAuth, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"bob", "bob@example.com", "Bob"}, + }, + { + description: "LDAP getters return LDAP fields", + context: &model.UserContext{ + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"carol", "carol@example.com", "Carol"}, + }, + { + description: "OAuth getters return OAuth fields", + context: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"dave", "dave@example.com", "Dave"}, + }, + { + description: "ProviderName returns 'local' for ProviderLocal", + context: &model.UserContext{Provider: model.ProviderLocal}, + run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, + expected: "local", + }, + { + description: "ProviderName returns 'local' for ProviderBasicAuth", + context: &model.UserContext{Provider: model.ProviderBasicAuth}, + run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, + expected: "local", + }, + { + description: "ProviderName returns 'ldap' for ProviderLDAP", + context: &model.UserContext{Provider: model.ProviderLDAP}, + run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, + expected: "ldap", + }, + { + description: "ProviderName returns OAuth DisplayName for ProviderOAuth", + context: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{DisplayName: "GitHub"}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, + expected: "GitHub", + }, + { + description: "TOTPPending returns true when local context is pending", + context: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{TOTPPending: true}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, + expected: true, + }, + { + description: "TOTPPending returns false when local context is not pending", + context: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{TOTPPending: false}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, + expected: false, + }, + { + description: "TOTPPending returns false for non-local providers", + context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, + expected: false, + }, + { + description: "OAuthName returns DisplayName for ProviderOAuth", + context: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{DisplayName: "Google"}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, + expected: "Google", + }, + { + description: "OAuthName returns empty string for non-oauth providers", + context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, + expected: "", + }, + { + description: "NewFromGin populates context from gin value", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + stored := &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}}, + } + got, err := c.NewFromGin(newGinCtx(stored, true)) + require.NoError(t, err) + return [2]any{got.Authenticated, got.GetUsername()} + }, + expected: [2]any{true, "alice"}, + }, + { + description: "NewFromGin returns error when context value is missing", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + _, err := c.NewFromGin(newGinCtx(nil, false)) + return err.Error() + }, + expected: "failed to get user context", + }, + { + description: "NewFromGin returns error when context value has wrong type", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + _, err := c.NewFromGin(newGinCtx("not a user context", true)) + return err.Error() + }, + expected: "invalid user context type", + }, + { + description: "NewFromGin returns an error when context doesn't include user information", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + _, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true)) + return err.Error() + }, + expected: "incomplete user context", + }, + { + description: "Getters should not panic if provider context is empty", + context: &model.UserContext{Provider: model.ProviderLocal}, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"", "", ""}, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.expected, test.run(t, test.context)) + }) + } +} diff --git a/internal/model/users.go b/internal/model/users.go new file mode 100644 index 00000000..48826fda --- /dev/null +++ b/internal/model/users.go @@ -0,0 +1,25 @@ +package model + +type UserSearchType int + +const ( + UserLocal UserSearchType = iota + UserLDAP +) + +type LDAPUser struct { + DN string + Groups []string +} + +type LocalUser struct { + Username string + Password string + TOTPSecret string + Attributes UserAttributes +} + +type UserSearch struct { + Username string + Type UserSearchType +} diff --git a/internal/model/version.go b/internal/model/version.go new file mode 100644 index 00000000..cd8bc138 --- /dev/null +++ b/internal/model/version.go @@ -0,0 +1,5 @@ +package model + +var Version = "development" +var CommitHash = "development" +var BuildTimestamp = "0000-00-00T00:00:00Z" diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index d054b5f1..fd57bf39 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -1,23 +1,22 @@ package service import ( - "errors" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) type LabelProvider interface { - GetLabels(appDomain string) (config.App, error) + GetLabels(appDomain string) (*model.App, error) } type AccessControlsService struct { labelProvider LabelProvider - static map[string]config.App + static map[string]model.App } -func NewAccessControlsService(labelProvider LabelProvider, static map[string]config.App) *AccessControlsService { +func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { return &AccessControlsService{ labelProvider: labelProvider, static: static, @@ -28,26 +27,29 @@ func (acls *AccessControlsService) Init() error { return nil // No initialization needed } -func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) { +func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { + var appAcls *model.App for app, config := range acls.static { if config.Config.Domain == domain { tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") - return config, nil + appAcls = &config + break // If we find a match by domain, we can stop searching } if strings.SplitN(domain, ".", 2)[0] == app { tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") - return config, nil + appAcls = &config + break // If we find a match by app name, we can stop searching } } - return config.App{}, errors.New("no results") + return appAcls } -func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) { +func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { // First check in the static config - app, err := acls.lookupStaticACLs(domain) + app := acls.lookupStaticACLs(domain) - if err == nil { + if app != nil { tlog.App.Debug().Msg("Using ACls from static configuration") return app, nil } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 0311229d..cad25608 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -5,12 +5,13 @@ import ( "database/sql" "errors" "fmt" + "net/http" "regexp" "strings" "sync" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -29,6 +30,10 @@ const MaxOAuthPendingSessions = 256 const OAuthCleanupCount = 16 const MaxLoginAttemptRecords = 256 +var ( + ErrUserNotFound = errors.New("user not found") +) + // slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // parameters and pass them to the authorize page if needed type OAuthURLParams struct { @@ -68,7 +73,7 @@ type Lockdown struct { } type AuthServiceConfig struct { - Users []config.User + LocalUsers *[]model.LocalUser OauthWhitelist []string SessionExpiry int SessionMaxLifetime int @@ -77,7 +82,7 @@ type AuthServiceConfig struct { LoginTimeout int LoginMaxRetries int SessionCookieName string - IP config.IPConfig + IP model.IPConfig LDAPGroupsCacheTTL int } @@ -106,7 +111,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi ldap: ldap, queries: queries, oauthBroker: oauthBroker, -} + } } func (auth *AuthService) Init() error { @@ -114,79 +119,73 @@ func (auth *AuthService) Init() error { return nil } -func (auth *AuthService) SearchUser(username string) config.UserSearch { - if auth.GetLocalUser(username).Username != "" { - return config.UserSearch{ +func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { + if auth.GetLocalUser(username) != nil { + return &model.UserSearch{ Username: username, - Type: "local", - } + Type: model.UserLocal, + }, nil } if auth.ldap.IsConfigured() { userDN, err := auth.ldap.GetUserDN(username) if err != nil { - tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") - return config.UserSearch{ - Type: "unknown", - } + return nil, fmt.Errorf("failed to get ldap user: %w", err) } - return config.UserSearch{ + return &model.UserSearch{ Username: userDN, - Type: "ldap", - } + Type: model.UserLDAP, + }, nil } - return config.UserSearch{ - Type: "unknown", - } + return nil, ErrUserNotFound } -func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool { +func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error { switch search.Type { - case "local": + case model.UserLocal: user := auth.GetLocalUser(search.Username) - return auth.CheckPassword(user, password) - case "ldap": + if user == nil { + return ErrUserNotFound + } + return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) + case model.UserLDAP: if auth.ldap.IsConfigured() { err := auth.ldap.Bind(search.Username, password) if err != nil { - tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") - return false + return fmt.Errorf("failed to bind to ldap user: %w", err) } err = auth.ldap.BindService(true) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication") - return false + return fmt.Errorf("failed to bind to ldap service account: %w", err) } - return true + return nil } default: - tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication") - return false + return errors.New("unknown user search type") } - - tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed") - return false + return errors.New("user authentication failed") } -func (auth *AuthService) GetLocalUser(username string) config.User { - for _, user := range auth.config.Users { +func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { + if auth.config.LocalUsers == nil { + return nil + } + for _, user := range *auth.config.LocalUsers { if user.Username == username { - return user + return &user } } - - tlog.App.Warn().Str("username", username).Msg("Local user not found") - return config.User{} + return nil } -func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { +func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { if !auth.ldap.IsConfigured() { - return config.LdapUser{}, errors.New("LDAP service not initialized") + return nil, errors.New("ldap service not configured") } auth.ldapGroupsMutex.RLock() @@ -194,7 +193,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { auth.ldapGroupsMutex.RUnlock() if exists && time.Now().Before(entry.Expires) { - return config.LdapUser{ + return &model.LDAPUser{ DN: userDN, Groups: entry.Groups, }, nil @@ -203,7 +202,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { groups, err := auth.ldap.GetUserGroups(userDN) if err != nil { - return config.LdapUser{}, err + return nil, fmt.Errorf("failed to get ldap groups: %w", err) } auth.ldapGroupsMutex.Lock() @@ -213,16 +212,12 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { } auth.ldapGroupsMutex.Unlock() - return config.LdapUser{ + return &model.LDAPUser{ DN: userDN, Groups: groups, }, nil } -func (auth *AuthService) CheckPassword(user config.User, password string) bool { - return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil -} - func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { auth.loginMutex.RLock() defer auth.loginMutex.RUnlock() @@ -291,11 +286,11 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool { return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) } -func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error { +func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { uuid, err := uuid.NewRandom() if err != nil { - return err + return nil, fmt.Errorf("failed to generate session uuid: %w", err) } var expiry int @@ -306,6 +301,8 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se expiry = auth.config.SessionExpiry } + expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) + session := repository.CreateSessionParams{ UUID: uuid.String(), Username: data.Username, @@ -314,34 +311,36 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se Provider: data.Provider, TotpPending: data.TotpPending, OAuthGroups: data.OAuthGroups, - Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(), + Expiry: expiresAt.Unix(), CreatedAt: time.Now().Unix(), OAuthName: data.OAuthName, OAuthSub: data.OAuthSub, } - _, err = auth.queries.CreateSession(c, session) + _, err = auth.queries.CreateSession(ctx, session) if err != nil { - return err - } - - c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - - return nil + return nil, fmt.Errorf("failed to create session entry: %w", err) + } + + return &http.Cookie{ + Name: auth.config.SessionCookieName, + Value: session.UUID, + Path: "/", + Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Expires: expiresAt, + MaxAge: int(time.Until(expiresAt).Seconds()), + Secure: auth.config.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil } -func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { - cookie, err := c.Cookie(auth.config.SessionCookieName) +func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { + session, err := auth.queries.GetSession(ctx, uuid) if err != nil { - return err - } - - session, err := auth.queries.GetSession(c, cookie) - - if err != nil { - return err + return nil, fmt.Errorf("failed to retrieve session: %w", err) } currentTime := time.Now().Unix() @@ -355,12 +354,12 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { } if session.Expiry-currentTime > refreshThreshold { - return nil + return nil, nil } newExpiry := session.Expiry + refreshThreshold - _, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{ + _, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{ Username: session.Username, Email: session.Email, Name: session.Name, @@ -374,122 +373,123 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { }) if err != nil { - return err - } - - c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed") + return nil, fmt.Errorf("failed to update session expiry: %w", err) + } + + return &http.Cookie{ + Name: auth.config.SessionCookieName, + Value: session.UUID, + Path: "/", + Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), + MaxAge: int(newExpiry - currentTime), + Secure: auth.config.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil - return nil } -func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { - cookie, err := c.Cookie(auth.config.SessionCookieName) - - if err != nil { - return err - } - - err = auth.queries.DeleteSession(c, cookie) +func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { + err := auth.queries.DeleteSession(ctx, uuid) if err != nil { - return err - } - - c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - - return nil + tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") + } + + return &http.Cookie{ + Name: auth.config.SessionCookieName, + Value: "", + Path: "/", + Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Expires: time.Now(), + MaxAge: -1, + Secure: auth.config.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil } -func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) { - cookie, err := c.Cookie(auth.config.SessionCookieName) - - if err != nil { - return repository.Session{}, err - } - - session, err := auth.queries.GetSession(c, cookie) +func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) { + session, err := auth.queries.GetSession(ctx, uuid) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return repository.Session{}, fmt.Errorf("session not found") + return nil, errors.New("session not found") } - return repository.Session{}, err + return nil, err } currentTime := time.Now().Unix() if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { - err = auth.queries.DeleteSession(c, cookie) + err = auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime") + return nil, fmt.Errorf("failed to delete expired session: %w", err) } - return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded") + return nil, fmt.Errorf("session max lifetime exceeded") } } if currentTime > session.Expiry { - err = auth.queries.DeleteSession(c, cookie) + err = auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete expired session") + return nil, fmt.Errorf("failed to delete expired session: %w", err) } - return repository.Session{}, fmt.Errorf("session expired") + return nil, fmt.Errorf("session expired") } - return repository.Session{ - UUID: session.UUID, - Username: session.Username, - Email: session.Email, - Name: session.Name, - Provider: session.Provider, - TotpPending: session.TotpPending, - OAuthGroups: session.OAuthGroups, - OAuthName: session.OAuthName, - OAuthSub: session.OAuthSub, - }, nil + return &session, nil } func (auth *AuthService) LocalAuthConfigured() bool { - return len(auth.config.Users) > 0 + return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0 } -func (auth *AuthService) LdapAuthConfigured() bool { +func (auth *AuthService) LDAPAuthConfigured() bool { return auth.ldap.IsConfigured() } -func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool { - if context.OAuth { +func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { + return true + } + + if context.Provider == model.ProviderOAuth { tlog.App.Debug().Msg("Checking OAuth whitelist") - return utils.CheckFilter(acls.OAuth.Whitelist, context.Email) + return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) } if acls.Users.Block != "" { tlog.App.Debug().Msg("Checking blocked users") - if utils.CheckFilter(acls.Users.Block, context.Username) { + if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { return false } } tlog.App.Debug().Msg("Checking users") - return utils.CheckFilter(acls.Users.Allow, context.Username) + return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) } -func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool { - if requiredGroups == "" { +func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { return true } - for id := range config.OverrideProviders { - if context.Provider == id { - tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider") - return true - } + if !context.IsOAuth() { + tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") + return false + } + + if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { + tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check") + return true } - for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") { - if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") + for _, userGroup := range context.OAuth.Groups { + if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { + tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") return true } } @@ -498,14 +498,19 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte return false } -func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool { - if requiredGroups == "" { +func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { return true } - for userGroup := range strings.SplitSeq(context.LdapGroups, ",") { - if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") + if !context.IsLDAP() { + tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") + return false + } + + for _, userGroup := range context.LDAP.Groups { + if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { + tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") return true } } @@ -514,10 +519,14 @@ func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContex return false } -func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) { +func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) { + if acls == nil { + return true, nil + } + // Check for block list - if path.Block != "" { - regex, err := regexp.Compile(path.Block) + if acls.Path.Block != "" { + regex, err := regexp.Compile(acls.Path.Block) if err != nil { return true, err @@ -529,8 +538,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e } // Check for allow list - if path.Allow != "" { - regex, err := regexp.Compile(path.Allow) + if acls.Path.Allow != "" { + regex, err := regexp.Compile(acls.Path.Allow) if err != nil { return true, err @@ -544,22 +553,14 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e return true, nil } -func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { - username, password, ok := c.Request.BasicAuth() - if !ok { - tlog.App.Debug().Msg("No basic auth provided") - return nil - } - return &config.User{ - Username: username, - Password: password, +func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { + if acls == nil { + return true } -} -func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool { // Merge the global and app IP filter - blockedIps := append(auth.config.IP.Block, acls.Block...) - allowedIPs := append(auth.config.IP.Allow, acls.Allow...) + blockedIps := append(auth.config.IP.Block, acls.IP.Block...) + allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...) for _, blocked := range blockedIps { res, err := utils.FilterIP(blocked, ip) @@ -594,8 +595,12 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool { return true } -func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool { - for _, bypassed := range acls.Bypass { +func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { + if acls == nil { + return false + } + + for _, bypassed := range acls.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") @@ -674,21 +679,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T return token, nil } -func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { +func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) { session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { - return config.Claims{}, err + return nil, err } if session.Token == nil { - return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId) + return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) } userinfo, err := (*session.Service).GetUserinfo(session.Token) if err != nil { - return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err) + return nil, fmt.Errorf("failed to get userinfo: %w", err) } return userinfo, nil diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 97179242..c5f95dd4 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -4,7 +4,7 @@ import ( "context" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -51,56 +51,48 @@ func (docker *DockerService) Init() error { } func (docker *DockerService) getContainers() ([]container.Summary, error) { - containers, err := docker.client.ContainerList(docker.context, container.ListOptions{}) - if err != nil { - return nil, err - } - return containers, nil + return docker.client.ContainerList(docker.context, container.ListOptions{}) } func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) { - inspect, err := docker.client.ContainerInspect(docker.context, containerId) - if err != nil { - return container.InspectResponse{}, err - } - return inspect, nil + return docker.client.ContainerInspect(docker.context, containerId) } -func (docker *DockerService) GetLabels(appDomain string) (config.App, error) { +func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { if !docker.isConnected { tlog.App.Debug().Msg("Docker not connected, returning empty labels") - return config.App{}, nil + return nil, nil } containers, err := docker.getContainers() if err != nil { - return config.App{}, err + return nil, err } for _, ctr := range containers { inspect, err := docker.inspectContainer(ctr.ID) if err != nil { - return config.App{}, err + return nil, err } - labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps") + labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps") if err != nil { - return config.App{}, err + return nil, err } for appName, appLabels := range labels.Apps { if appLabels.Config.Domain == appDomain { tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") - return appLabels, nil + return &appLabels, nil } if strings.SplitN(appDomain, ".", 2)[0] == appName { tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") - return appLabels, nil + return &appLabels, nil } } } tlog.App.Debug().Msg("No matching container found, returning empty labels") - return config.App{}, nil + return nil, nil } diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 6e11eac1..9c5ad427 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -32,7 +32,7 @@ type ingressAppKey struct { type ingressApp struct { domain string appName string - app config.App + app model.App } type KubernetesService struct { @@ -89,36 +89,38 @@ func (k *KubernetesService) removeIngress(namespace, name string) { } } -func (k *KubernetesService) getByDomain(domain string) (config.App, bool) { +func (k *KubernetesService) getByDomain(domain string) *model.App { k.mu.RLock() defer k.mu.RUnlock() if appKey, ok := k.domainIndex[domain]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { - for _, app := range apps { + for i := range apps { + app := &apps[i] if app.domain == domain && app.appName == appKey.appName { - return app.app, true + return &app.app } } } } - return config.App{}, false + return nil } -func (k *KubernetesService) getByAppName(appName string) (config.App, bool) { +func (k *KubernetesService) getByAppName(appName string) *model.App { k.mu.RLock() defer k.mu.RUnlock() if appKey, ok := k.appNameIndex[appName]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { - for _, app := range apps { + for i := range apps { + app := &apps[i] if app.appName == appName { - return app.app, true + return &app.app } } } } - return config.App{}, false + return nil } func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { @@ -129,7 +131,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { k.removeIngress(namespace, name) return } - labels, err := decoders.DecodeLabels[config.Apps](annotations, "apps") + labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") if err != nil { tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") k.removeIngress(namespace, name) @@ -280,24 +282,25 @@ func (k *KubernetesService) Init() error { return nil } -func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) { +func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { if !k.started { tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") - return config.App{}, nil + return nil, nil } // First check cache - if app, found := k.getByDomain(appDomain); found { + app := k.getByDomain(appDomain) + if app != nil { tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") return app, nil } appName := strings.SplitN(appDomain, ".", 2)[0] - if app, found := k.getByAppName(appName); found { + app = k.getByAppName(appName) + if app != nil { tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") return app, nil } tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") - return config.App{}, nil + return nil, nil } - diff --git a/internal/service/kubernetes_service_test.go b/internal/service/kubernetes_service_test.go index 1cd75b6a..c7b39ead 100644 --- a/internal/service/kubernetes_service_test.go +++ b/internal/service/kubernetes_service_test.go @@ -3,11 +3,11 @@ package service import ( "testing" - "github.com/tinyauthapp/tinyauth/internal/config" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" ) func TestKubernetesService(t *testing.T) { @@ -20,69 +20,69 @@ func TestKubernetesService(t *testing.T) { { description: "Cache by domain returns app and misses unknown domain", run: func(t *testing.T, svc *KubernetesService) { - app := config.App{Config: config.AppConfig{Domain: "foo.example.com"}} + app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}} svc.addIngressApps("default", "my-ingress", []ingressApp{ {domain: "foo.example.com", appName: "foo", app: app}, }) - got, ok := svc.getByDomain("foo.example.com") - require.True(t, ok) + got := svc.getByDomain("foo.example.com") + require.NotNil(t, got) assert.Equal(t, "foo.example.com", got.Config.Domain) - _, ok = svc.getByDomain("notfound.example.com") - assert.False(t, ok) + got = svc.getByDomain("notfound.example.com") + assert.Nil(t, got) }, }, { description: "Cache by app name returns app and misses unknown name", run: func(t *testing.T, svc *KubernetesService) { - app := config.App{Config: config.AppConfig{Domain: "bar.example.com"}} + app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}} svc.addIngressApps("default", "my-ingress", []ingressApp{ {domain: "bar.example.com", appName: "bar", app: app}, }) - got, ok := svc.getByAppName("bar") - require.True(t, ok) + got := svc.getByAppName("bar") + require.NotNil(t, got) assert.Equal(t, "bar.example.com", got.Config.Domain) - _, ok = svc.getByAppName("notfound") - assert.False(t, ok) + got = svc.getByAppName("notfound") + assert.Nil(t, got) }, }, { description: "RemoveIngress clears domain and app name entries", run: func(t *testing.T, svc *KubernetesService) { - app := config.App{Config: config.AppConfig{Domain: "baz.example.com"}} + app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}} svc.addIngressApps("default", "my-ingress", []ingressApp{ {domain: "baz.example.com", appName: "baz", app: app}, }) svc.removeIngress("default", "my-ingress") - _, ok := svc.getByDomain("baz.example.com") - assert.False(t, ok) - _, ok = svc.getByAppName("baz") - assert.False(t, ok) + got := svc.getByDomain("baz.example.com") + assert.Nil(t, got) + got = svc.getByAppName("baz") + assert.Nil(t, got) }, }, { description: "AddIngressApps replaces stale entries for the same ingress", run: func(t *testing.T, svc *KubernetesService) { - old := config.App{Config: config.AppConfig{Domain: "old.example.com"}} + old := model.App{Config: model.AppConfig{Domain: "old.example.com"}} svc.addIngressApps("default", "my-ingress", []ingressApp{ {domain: "old.example.com", appName: "old", app: old}, }) - updated := config.App{Config: config.AppConfig{Domain: "new.example.com"}} + updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}} svc.addIngressApps("default", "my-ingress", []ingressApp{ {domain: "new.example.com", appName: "new", app: updated}, }) - _, ok := svc.getByDomain("old.example.com") - assert.False(t, ok) + got := svc.getByDomain("old.example.com") + assert.Nil(t, got) - got, ok := svc.getByDomain("new.example.com") - require.True(t, ok) + got = svc.getByDomain("new.example.com") + require.NotNil(t, got) assert.Equal(t, "new.example.com", got.Config.Domain) }, }, @@ -91,7 +91,7 @@ func TestKubernetesService(t *testing.T) { run: func(t *testing.T, svc *KubernetesService) { svc.started = true - app := config.App{Config: config.AppConfig{Domain: "hit.example.com"}} + app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}} svc.addIngressApps("default", "ing", []ingressApp{ {domain: "hit.example.com", appName: "hit", app: app}, }) @@ -108,7 +108,7 @@ func TestKubernetesService(t *testing.T) { got, err := svc.GetLabels("notfound.example.com") require.NoError(t, err) - assert.Equal(t, config.App{}, got) + assert.Nil(t, got) }, }, { @@ -116,7 +116,7 @@ func TestKubernetesService(t *testing.T) { run: func(t *testing.T, svc *KubernetesService) { svc.started = true - app := config.App{Config: config.AppConfig{Domain: "myapp.internal.example.com"}} + app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}} svc.addIngressApps("default", "ing", []ingressApp{ {domain: "myapp.internal.example.com", appName: "myapp", app: app}, }) @@ -131,7 +131,7 @@ func TestKubernetesService(t *testing.T) { run: func(t *testing.T, svc *KubernetesService) { got, err := svc.GetLabels("anything.example.com") require.NoError(t, err) - assert.Equal(t, config.App{}, got) + assert.Nil(t, got) }, }, { @@ -147,8 +147,8 @@ func TestKubernetesService(t *testing.T) { svc.updateFromItem(&item) - got, ok := svc.getByDomain("myapp.example.com") - require.True(t, ok) + got := svc.getByDomain("myapp.example.com") + require.NotNil(t, got) assert.Equal(t, "myapp.example.com", got.Config.Domain) assert.Equal(t, "alice", got.Users.Allow) }, @@ -156,7 +156,7 @@ func TestKubernetesService(t *testing.T) { { description: "UpdateFromItem with no annotations removes existing cache entries", run: func(t *testing.T, svc *KubernetesService) { - app := config.App{Config: config.AppConfig{Domain: "todelete.example.com"}} + app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}} svc.addIngressApps("default", "test-ingress", []ingressApp{ {domain: "todelete.example.com", appName: "todelete", app: app}, }) @@ -167,8 +167,8 @@ func TestKubernetesService(t *testing.T) { svc.updateFromItem(&item) - _, ok := svc.getByDomain("todelete.example.com") - assert.False(t, ok) + got := svc.getByDomain("todelete.example.com") + assert.Nil(t, got) }, }, } diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 610a8821..15823c47 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,7 +1,7 @@ package service import ( - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "slices" @@ -15,20 +15,20 @@ type OAuthServiceImpl interface { NewRandom() string GetAuthURL(state string, verifier string) string GetToken(code string, verifier string) (*oauth2.Token, error) - GetUserinfo(token *oauth2.Token) (config.Claims, error) + GetUserinfo(token *oauth2.Token) (*model.Claims, error) } type OAuthBrokerService struct { services map[string]OAuthServiceImpl - configs map[string]config.OAuthServiceConfig + configs map[string]model.OAuthServiceConfig } -var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{ +var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ "github": newGitHubOAuthService, "google": newGoogleOAuthService, } -func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { +func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { return &OAuthBrokerService{ services: make(map[string]OAuthServiceImpl), configs: configs, diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go index 45d03f74..96e2a034 100644 --- a/internal/service/oauth_extractors.go +++ b/internal/service/oauth_extractors.go @@ -8,12 +8,13 @@ import ( "net/http" "strconv" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type GithubEmailResponse []struct { - Email string `json:"email"` - Primary bool `json:"primary"` + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` } type GithubUserInfoResponse struct { @@ -22,32 +23,32 @@ type GithubUserInfoResponse struct { ID int `json:"id"` } -func defaultExtractor(client *http.Client, url string) (config.Claims, error) { - return simpleReq[config.Claims](client, url, nil) +func defaultExtractor(client *http.Client, url string) (*model.Claims, error) { + return simpleReq[model.Claims](client, url, nil) } -func githubExtractor(client *http.Client, url string) (config.Claims, error) { - var user config.Claims +func githubExtractor(client *http.Client, url string) (*model.Claims, error) { + var user model.Claims userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ "accept": "application/vnd.github+json", }) if err != nil { - return config.Claims{}, err + return nil, err } userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{ "accept": "application/vnd.github+json", }) if err != nil { - return config.Claims{}, err + return nil, err } - if len(userEmails) == 0 { - return user, errors.New("no emails found") + if len(*userEmails) == 0 { + return nil, errors.New("no emails found") } - for _, email := range userEmails { + for _, email := range *userEmails { if email.Primary { user.Email = email.Email break @@ -56,22 +57,31 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) { // Use first available email if no primary email was found if user.Email == "" { - user.Email = userEmails[0].Email + for _, email := range *userEmails { + if email.Verified { + user.Email = email.Email + break + } + } + } + + if user.Email == "" { + return nil, errors.New("no verified email found") } user.PreferredUsername = userInfo.Login user.Name = userInfo.Name user.Sub = strconv.Itoa(userInfo.ID) - return user, nil + return &user, nil } -func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) { +func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) { var decodedRes T req, err := http.NewRequest("GET", url, nil) if err != nil { - return decodedRes, err + return nil, err } for key, value := range headers { @@ -80,23 +90,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string res, err := client.Do(req) if err != nil { - return decodedRes, err + return nil, err } defer res.Body.Close() if res.StatusCode < 200 || res.StatusCode >= 300 { - return decodedRes, fmt.Errorf("request failed with status: %s", res.Status) + return nil, fmt.Errorf("request failed with status: %s", res.Status) } body, err := io.ReadAll(res.Body) if err != nil { - return decodedRes, err + return nil, err } err = json.Unmarshal(body, &decodedRes) if err != nil { - return decodedRes, err + return nil, err } - return decodedRes, nil + return &decodedRes, nil } diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go index df23be5e..ef21fa60 100644 --- a/internal/service/oauth_presets.go +++ b/internal/service/oauth_presets.go @@ -1,11 +1,11 @@ package service import ( - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2/endpoints" ) -func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { +func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { scopes := []string{"openid", "email", "profile"} config.Scopes = scopes config.AuthURL = endpoints.Google.AuthURL @@ -14,7 +14,7 @@ func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { return NewOAuthService(config, "google") } -func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService { +func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { scopes := []string{"read:user", "user:email"} config.Scopes = scopes config.AuthURL = endpoints.GitHub.AuthURL diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 4ef118ea..11b0be9c 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -6,21 +6,21 @@ import ( "net/http" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2" ) -type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error) +type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type OAuthService struct { - serviceCfg config.OAuthServiceConfig + serviceCfg model.OAuthServiceConfig config *oauth2.Config ctx context.Context userinfoExtractor UserinfoExtractor id string } -func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService { +func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -78,7 +78,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) } -func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) { +func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) { client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1ac138ae..1e1c1986 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -22,7 +22,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-jose/go-jose/v4" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -68,27 +68,27 @@ type ClaimSet struct { } type UserinfoResponse struct { - Sub string `json:"sub"` - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale string `json:"locale,omitempty"` - Email string `json:"email,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` - Groups []string `json:"groups,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - PhoneNumber string `json:"phone_number,omitempty"` - PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` - Address *config.AddressClaim `json:"address,omitempty"` - UpdatedAt int64 `json:"updated_at"` + Sub string `json:"sub"` + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Gender string `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale string `json:"locale,omitempty"` + Email string `json:"email,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Groups []string `json:"groups,omitempty"` + EmailVerified bool `json:"email_verified,omitempty"` + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` + Address *model.AddressClaim `json:"address,omitempty"` + UpdatedAt int64 `json:"updated_at"` } type TokenResponse struct { @@ -112,7 +112,7 @@ type AuthorizeRequest struct { } type OIDCServiceConfig struct { - Clients map[string]config.OIDCClientConfig + Clients map[string]model.OIDCClientConfig PrivateKeyPath string PublicKeyPath string Issuer string @@ -122,7 +122,7 @@ type OIDCServiceConfig struct { type OIDCService struct { config OIDCServiceConfig queries *repository.Queries - clients map[string]config.OIDCClientConfig + clients map[string]model.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey issuer string @@ -255,7 +255,7 @@ func (service *OIDCService) Init() error { } // We will reorganize the client into a map with the client ID as the key - service.clients = make(map[string]config.OIDCClientConfig) + service.clients = make(map[string]model.OIDCClientConfig) for id, client := range service.config.Clients { client.ID = id @@ -283,7 +283,7 @@ func (service *OIDCService) GetIssuer() string { return service.issuer } -func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) { +func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) { client, ok := service.clients[id] return client, ok } @@ -367,43 +367,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r return err } -func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error { - addressJSON, err := json.Marshal(userContext.Attributes.Address) - if err != nil { - return err - } - +func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error { userInfoParams := repository.CreateOidcUserInfoParams{ Sub: sub, - Name: userContext.Name, - Email: userContext.Email, - PreferredUsername: userContext.Username, + Name: userContext.GetName(), + Email: userContext.GetEmail(), + PreferredUsername: userContext.GetUsername(), UpdatedAt: time.Now().Unix(), - GivenName: userContext.Attributes.GivenName, - FamilyName: userContext.Attributes.FamilyName, - MiddleName: userContext.Attributes.MiddleName, - Nickname: userContext.Attributes.Nickname, - Profile: userContext.Attributes.Profile, - Picture: userContext.Attributes.Picture, - Website: userContext.Attributes.Website, - Gender: userContext.Attributes.Gender, - Birthdate: userContext.Attributes.Birthdate, - Zoneinfo: userContext.Attributes.Zoneinfo, - Locale: userContext.Attributes.Locale, - PhoneNumber: userContext.Attributes.PhoneNumber, - Address: string(addressJSON), + } + + if userContext.IsLocal() { + addressJSON, err := json.Marshal(userContext.Local.Attributes.Address) + if err != nil { + return err + } + userInfoParams.GivenName = userContext.Local.Attributes.GivenName + userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName + userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName + userInfoParams.Nickname = userContext.Local.Attributes.Nickname + userInfoParams.Profile = userContext.Local.Attributes.Profile + userInfoParams.Picture = userContext.Local.Attributes.Picture + userInfoParams.Website = userContext.Local.Attributes.Website + userInfoParams.Gender = userContext.Local.Attributes.Gender + userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate + userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo + userInfoParams.Locale = userContext.Local.Attributes.Locale + userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber + userInfoParams.Address = string(addressJSON) } // Tinyauth will pass through the groups it got from an LDAP or an OIDC server - if userContext.Provider == "ldap" { - userInfoParams.Groups = userContext.LdapGroups + if userContext.IsLDAP() { + userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",") } - if userContext.OAuth && len(userContext.OAuthGroups) > 0 { - userInfoParams.Groups = userContext.OAuthGroups + if userContext.IsOAuth() { + userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",") } - _, err = service.queries.CreateOidcUserInfo(c, userInfoParams) + _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) return err } @@ -445,7 +447,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client return oidcCode, nil } -func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { +func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() @@ -511,7 +513,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user return token, nil } -func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { +func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { user, err := service.GetUserinfo(c, codeEntry.Sub) if err != nil { @@ -585,7 +587,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri return TokenResponse{}, err } - idToken, err := service.generateIDToken(config.OIDCClientConfig{ + idToken, err := service.generateIDToken(model.OIDCClientConfig{ ClientID: entry.ClientID, }, user, entry.Scope, entry.Nonce) @@ -714,7 +716,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope } if slices.Contains(scopes, "address") { - var addr config.AddressClaim + var addr model.AddressClaim if err := json.Unmarshal([]byte(user.Address), &addr); err == nil { userInfo.Address = &addr } diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go index 222ad626..394df4be 100644 --- a/internal/service/oidc_service_test.go +++ b/internal/service/oidc_service_test.go @@ -7,13 +7,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" ) func newTestUser() repository.OidcUserinfo { - addr := config.AddressClaim{ + addr := model.AddressClaim{ Formatted: "123 Main St", StreetAddress: "123 Main St", Locality: "Springfield", diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 55665ee0..e7206bd8 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -7,10 +7,8 @@ import ( "net/url" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/gin-gonic/gin" "github.com/weppos/publicsuffix-go/publicsuffix" ) @@ -73,22 +71,6 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) { return res } -func GetContext(c *gin.Context) (config.UserContext, error) { - userContextValue, exists := c.Get("context") - - if !exists { - return config.UserContext{}, errors.New("no user context in request") - } - - userContext, ok := userContextValue.(*config.UserContext) - - if !ok { - return config.UserContext{}, errors.New("invalid user context in request") - } - - return *userContext, nil -} - func IsRedirectSafe(redirectURL string, domain string) bool { if redirectURL == "" { return false diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index a44c08d3..46dacafc 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -3,11 +3,8 @@ package utils_test import ( "testing" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "gotest.tools/v3/assert" ) func TestGetRootDomain(t *testing.T) { @@ -15,14 +12,14 @@ func TestGetRootDomain(t *testing.T) { domain := "http://sub.tinyauth.app" expected := "tinyauth.app" result, err := utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // Domain with multiple subdomains domain = "http://b.c.tinyauth.app" expected = "c.tinyauth.app" result, err = utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // Invalid domain (only TLD) @@ -44,14 +41,14 @@ func TestGetRootDomain(t *testing.T) { domain = "https://sub.tinyauth.app/path" expected = "tinyauth.app" result, err = utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // URL with port domain = "http://sub.tinyauth.app:8080" expected = "tinyauth.app" result, err = utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // Domain managed by ICANN @@ -98,57 +95,35 @@ func TestFilter(t *testing.T) { testFunc := func(n int) bool { return n%2 == 0 } expected := []int{2, 4} result := utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with no matches slice = []int{1, 3, 5} testFunc = func(n int) bool { return n%2 == 0 } expected = []int{} result = utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with all matches slice = []int{2, 4, 6} testFunc = func(n int) bool { return n%2 == 0 } expected = []int{2, 4, 6} result = utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with empty slice slice = []int{} testFunc = func(n int) bool { return n%2 == 0 } expected = []int{} result = utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with different type (string) sliceStr := []string{"apple", "banana", "cherry"} testFuncStr := func(s string) bool { return len(s) > 5 } expectedStr := []string{"banana", "cherry"} resultStr := utils.Filter(sliceStr, testFuncStr) - assert.DeepEqual(t, expectedStr, resultStr) -} - -func TestGetContext(t *testing.T) { - // Setup - gin.SetMode(gin.TestMode) - c, _ := gin.CreateTestContext(nil) - - // Normal case - c.Set("context", &config.UserContext{Username: "testuser"}) - result, err := utils.GetContext(c) - assert.NilError(t, err) - assert.Equal(t, "testuser", result.Username) - - // Case with no context - c.Set("context", nil) - _, err = utils.GetContext(c) - assert.Error(t, err, "invalid user context in request") - - // Case with invalid context type - c.Set("context", "invalid type") - _, err = utils.GetContext(c) - assert.Error(t, err, "invalid user context in request") + assert.Equal(t, expectedStr, resultStr) } func TestIsRedirectSafe(t *testing.T) { @@ -158,50 +133,50 @@ func TestIsRedirectSafe(t *testing.T) { // Case with no subdomain redirectURL := "http://example.com/welcome" result := utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with different domain redirectURL = "http://malicious.com/phishing" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with subdomain redirectURL = "http://sub.example.com/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with sub-subdomain redirectURL = "http://a.b.example.com/home" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with empty redirect URL redirectURL = "" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with invalid URL redirectURL = "http://[::1]:namedport" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with URL having port redirectURL = "http://sub.example.com:8080/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with URL having different subdomain redirectURL = "http://another.example.com/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with URL having different TLD redirectURL = "http://example.org/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with malicious domain redirectURL = "https://malicious-example.com/yoyo" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) } diff --git a/internal/utils/decoders/label_decoder_test.go b/internal/utils/decoders/label_decoder_test.go index bf5d49fd..9048e7bc 100644 --- a/internal/utils/decoders/label_decoder_test.go +++ b/internal/utils/decoders/label_decoder_test.go @@ -3,42 +3,41 @@ package decoders_test import ( "testing" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - - "gotest.tools/v3/assert" ) func TestDecodeLabels(t *testing.T) { // Variables - expected := config.Apps{ - Apps: map[string]config.App{ + expected := model.Apps{ + Apps: map[string]model.App{ "foo": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "example.com", }, - Users: config.AppUsers{ + Users: model.AppUsers{ Allow: "user1,user2", Block: "user3", }, - OAuth: config.AppOAuth{ + OAuth: model.AppOAuth{ Whitelist: "somebody@example.com", Groups: "group3", }, - IP: config.AppIP{ + IP: model.AppIP{ Allow: []string{"10.71.0.1/24", "10.71.0.2"}, Block: []string{"10.10.10.10", "10.0.0.0/24"}, Bypass: []string{"192.168.1.1"}, }, - Response: config.AppResponse{ + Response: model.AppResponse{ Headers: []string{"X-Foo=Bar", "X-Baz=Qux"}, - BasicAuth: config.AppBasicAuth{ + BasicAuth: model.AppBasicAuth{ Username: "admin", Password: "password", PasswordFile: "/path/to/passwordfile", }, }, - Path: config.AppPath{ + Path: model.AppPath{ Allow: "/public", Block: "/private", }, @@ -63,7 +62,7 @@ func TestDecodeLabels(t *testing.T) { } // Test - result, err := decoders.DecodeLabels[config.Apps](test, "apps") - assert.NilError(t, err) - assert.DeepEqual(t, expected, result) + result, err := decoders.DecodeLabels[model.Apps](test, "apps") + assert.NoError(t, err) + assert.Equal(t, expected, result) } diff --git a/internal/utils/fs_utils_test.go b/internal/utils/fs_utils_test.go index 54033ba5..68154419 100644 --- a/internal/utils/fs_utils_test.go +++ b/internal/utils/fs_utils_test.go @@ -4,24 +4,25 @@ import ( "os" "testing" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestReadFile(t *testing.T) { // Setup file, err := os.Create("/tmp/tinyauth_test_file") - assert.NilError(t, err) + require.NoError(t, err) _, err = file.WriteString("file content\n") - assert.NilError(t, err) + require.NoError(t, err) err = file.Close() - assert.NilError(t, err) + require.NoError(t, err) defer os.Remove("/tmp/tinyauth_test_file") // Normal case content, err := ReadFile("/tmp/tinyauth_test_file") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "file content\n", content) // Non-existing file diff --git a/internal/utils/label_utils_test.go b/internal/utils/label_utils_test.go index 1d1554bb..7da1947d 100644 --- a/internal/utils/label_utils_test.go +++ b/internal/utils/label_utils_test.go @@ -3,9 +3,8 @@ package utils_test import ( "testing" + "github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestParseHeaders(t *testing.T) { @@ -18,7 +17,7 @@ func TestParseHeaders(t *testing.T) { "X-Custom-Header": "Value", "Another-Header": "AnotherValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Case insensitivity and trimming headers = []string{ @@ -29,7 +28,7 @@ func TestParseHeaders(t *testing.T) { "X-Custom-Header": "Value", "Another-Header": "AnotherValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Invalid headers (missing '=', empty key/value) headers = []string{ @@ -39,7 +38,7 @@ func TestParseHeaders(t *testing.T) { " = ", } expected = map[string]string{} - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Headers with unsafe characters headers = []string{ @@ -52,7 +51,7 @@ func TestParseHeaders(t *testing.T) { "Another-Header": "AnotherValue", "Good-Header": "GoodValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Header with spaces in key (should be ignored) headers = []string{ @@ -62,7 +61,7 @@ func TestParseHeaders(t *testing.T) { expected = map[string]string{ "Valid-Header": "ValidValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) } func TestSanitizeHeader(t *testing.T) { diff --git a/internal/utils/loaders/loader_env.go b/internal/utils/loaders/loader_env.go index f441ddda..c09ad828 100644 --- a/internal/utils/loaders/loader_env.go +++ b/internal/utils/loaders/loader_env.go @@ -4,21 +4,20 @@ import ( "fmt" "os" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/env" + "github.com/tinyauthapp/tinyauth/internal/model" ) type EnvLoader struct{} func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) { - vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration) + vars := env.FindPrefixedEnvVars(os.Environ(), model.DefaultNamePrefix, cmd.Configuration) if len(vars) == 0 { return false, nil } - if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil { + if err := env.Decode(vars, model.DefaultNamePrefix, cmd.Configuration); err != nil { return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err) } diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index 1b8d8e9f..abfdbfe8 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -41,7 +41,7 @@ func ParseSecretFile(contents string) string { return "" } -func GetBasicAuth(username string, password string) string { +func EncodeBasicAuth(username string, password string) string { auth := username + ":" + password return base64.StdEncoding.EncodeToString([]byte(auth)) } diff --git a/internal/utils/security_utils_test.go b/internal/utils/security_utils_test.go index 48c37335..6feac4ca 100644 --- a/internal/utils/security_utils_test.go +++ b/internal/utils/security_utils_test.go @@ -4,21 +4,21 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestGetSecret(t *testing.T) { // Setup file, err := os.Create("/tmp/tinyauth_test_secret") - assert.NilError(t, err) + require.NoError(t, err) _, err = file.WriteString(" secret \n") - assert.NilError(t, err) + require.NoError(t, err) err = file.Close() - assert.NilError(t, err) + require.NoError(t, err) defer os.Remove("/tmp/tinyauth_test_secret") // Get from config @@ -55,50 +55,50 @@ func TestParseSecretFile(t *testing.T) { assert.Equal(t, "", utils.ParseSecretFile(content)) } -func TestGetBasicAuth(t *testing.T) { +func TestEncodeBasicAuth(t *testing.T) { // Normal case username := "user" password := "pass" expected := "dXNlcjpwYXNz" // base64 of "user:pass" - assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) // Empty username username = "" password = "pass" expected = "OnBhc3M=" // base64 of ":pass" - assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) // Empty password username = "user" password = "" expected = "dXNlcjo=" // base64 of "user:" - assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) } func TestFilterIP(t *testing.T) { // Exact match IPv4 ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, true, ok) // Non-match IPv4 ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, false, ok) // CIDR match IPv4 ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, true, ok) // CIDR match IPv4 with '-' instead of '/' ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, true, ok) // CIDR non-match IPv4 ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, false, ok) // Invalid CIDR @@ -145,5 +145,5 @@ func TestGenerateUUID(t *testing.T) { // Different output for different input id3 := utils.GenerateUUID("differentstring") - assert.Assert(t, id1 != id3) + assert.NotEqual(t, id2, id3) } diff --git a/internal/utils/string_utils_test.go b/internal/utils/string_utils_test.go index 1db3bf17..30c192c9 100644 --- a/internal/utils/string_utils_test.go +++ b/internal/utils/string_utils_test.go @@ -3,9 +3,8 @@ package utils_test import ( "testing" + "github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestCapitalize(t *testing.T) { diff --git a/internal/utils/tlog/log_wrapper.go b/internal/utils/tlog/log_wrapper.go index e3220e40..ffdfcf91 100644 --- a/internal/utils/tlog/log_wrapper.go +++ b/internal/utils/tlog/log_wrapper.go @@ -7,7 +7,7 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type Logger struct { @@ -22,7 +22,7 @@ var ( App zerolog.Logger ) -func NewLogger(cfg config.LogConfig) *Logger { +func NewLogger(cfg model.LogConfig) *Logger { baseLogger := log.With(). Timestamp(). Caller(). @@ -44,24 +44,24 @@ func NewLogger(cfg config.LogConfig) *Logger { } func NewSimpleLogger() *Logger { - return NewLogger(config.LogConfig{ + return NewLogger(model.LogConfig{ Level: "info", Json: false, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true}, - App: config.LogStreamConfig{Enabled: true}, - Audit: config.LogStreamConfig{Enabled: false}, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, }, }) } func NewTestLogger() *Logger { - return NewLogger(config.LogConfig{ + return NewLogger(model.LogConfig{ Level: "trace", - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true}, - App: config.LogStreamConfig{Enabled: true}, - Audit: config.LogStreamConfig{Enabled: true}, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, }, }) } @@ -72,7 +72,7 @@ func (l *Logger) Init() { App = l.App } -func createLogger(component string, streamCfg config.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger { +func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger { if !streamCfg.Enabled { return zerolog.Nop() } diff --git a/internal/utils/tlog/log_wrapper_test.go b/internal/utils/tlog/log_wrapper_test.go index 2db9e2a6..41609f53 100644 --- a/internal/utils/tlog/log_wrapper_test.go +++ b/internal/utils/tlog/log_wrapper_test.go @@ -5,75 +5,75 @@ import ( "encoding/json" "testing" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/rs/zerolog" - "gotest.tools/v3/assert" ) func TestNewLogger(t *testing.T) { - cfg := config.LogConfig{ + cfg := model.LogConfig{ Level: "debug", Json: true, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true, Level: "info"}, - App: config.LogStreamConfig{Enabled: true, Level: ""}, - Audit: config.LogStreamConfig{Enabled: false, Level: ""}, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true, Level: "info"}, + App: model.LogStreamConfig{Enabled: true, Level: ""}, + Audit: model.LogStreamConfig{Enabled: false, Level: ""}, }, } logger := tlog.NewLogger(cfg) - assert.Assert(t, logger != nil) - assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel) - assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel) - assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) + assert.NotNil(t, logger) + assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel()) + assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel()) + assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) } func TestNewSimpleLogger(t *testing.T) { logger := tlog.NewSimpleLogger() - assert.Assert(t, logger != nil) - assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel) - assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel) - assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) + assert.NotNil(t, logger) + assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel()) + assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel()) + assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) } func TestLoggerInit(t *testing.T) { logger := tlog.NewSimpleLogger() logger.Init() - assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled) + assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel()) } func TestLoggerWithDisabledStreams(t *testing.T) { - cfg := config.LogConfig{ + cfg := model.LogConfig{ Level: "info", Json: false, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: false}, - App: config.LogStreamConfig{Enabled: false}, - Audit: config.LogStreamConfig{Enabled: false}, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: false}, + Audit: model.LogStreamConfig{Enabled: false}, }, } logger := tlog.NewLogger(cfg) - assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled) - assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled) - assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) + assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel()) + assert.Equal(t, zerolog.Disabled, logger.App.GetLevel()) + assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) } func TestLogStreamField(t *testing.T) { var buf bytes.Buffer - cfg := config.LogConfig{ + cfg := model.LogConfig{ Level: "info", Json: true, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true}, - App: config.LogStreamConfig{Enabled: true}, - Audit: config.LogStreamConfig{Enabled: true}, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, }, } @@ -86,7 +86,7 @@ func TestLogStreamField(t *testing.T) { var logEntry map[string]interface{} err := json.Unmarshal(buf.Bytes(), &logEntry) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "http", logEntry["log_stream"]) assert.Equal(t, "test message", logEntry["message"]) diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go index d80c655d..420c7e31 100644 --- a/internal/utils/user_utils.go +++ b/internal/utils/user_utils.go @@ -6,14 +6,14 @@ import ( "net/mail" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) -func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttributes) ([]config.User, error) { - var users []config.User +func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { + var users []model.LocalUser if len(usersStr) == 0 { - return []config.User{}, nil + return &users, nil } for _, user := range usersStr { @@ -22,22 +22,22 @@ func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttribut } parsed, err := ParseUser(strings.TrimSpace(user)) if err != nil { - return []config.User{}, err + return nil, err } if attrs, ok := userAttributes[parsed.Username]; ok { parsed.Attributes = attrs } - users = append(users, parsed) + users = append(users, *parsed) } - return users, nil + return &users, nil } -func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) { +func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { var usersStr []string if len(usersCfg) == 0 && usersPath == "" { - return []config.User{}, nil + return nil, nil } if len(usersCfg) > 0 { @@ -48,7 +48,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]con contents, err := ReadFile(usersPath) if err != nil { - return []config.User{}, err + return nil, err } lines := strings.SplitSeq(contents, "\n") @@ -65,7 +65,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]con return ParseUsers(usersStr, userAttributes) } -func ParseUser(userStr string) (config.User, error) { +func ParseUser(userStr string) (*model.LocalUser, error) { if strings.Contains(userStr, "$$") { userStr = strings.ReplaceAll(userStr, "$$", "$") } @@ -73,27 +73,27 @@ func ParseUser(userStr string) (config.User, error) { parts := strings.SplitN(userStr, ":", 4) if len(parts) < 2 || len(parts) > 3 { - return config.User{}, errors.New("invalid user format") + return nil, errors.New("invalid user format") } for i, part := range parts { trimmed := strings.TrimSpace(part) if trimmed == "" { - return config.User{}, errors.New("invalid user format") + return nil, errors.New("invalid user format") } parts[i] = trimmed } - user := config.User{ + user := model.LocalUser{ Username: parts[0], Password: parts[1], } if len(parts) == 3 { - user.TotpSecret = parts[2] + user.TOTPSecret = parts[2] } - return user, nil + return &user, nil } func CompileUserEmail(username string, domain string) string { diff --git a/internal/utils/user_utils_test.go b/internal/utils/user_utils_test.go index dcbb75cf..973be918 100644 --- a/internal/utils/user_utils_test.go +++ b/internal/utils/user_utils_test.go @@ -4,74 +4,76 @@ import ( "os" "testing" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestGetUsers(t *testing.T) { + tmpDir := t.TempDir() + hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G" // Setup - file, err := os.Create("/tmp/tinyauth_users_test.txt") - assert.NilError(t, err) + file, err := os.Create(tmpDir + "/tinyauth_users_test.txt") + require.NoError(t, err) _, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose - assert.NilError(t, err) + require.NoError(t, err) err = file.Close() - assert.NilError(t, err) - defer os.Remove("/tmp/tinyauth_users_test.txt") + require.NoError(t, err) + defer os.Remove(tmpDir + "/tinyauth_users_test.txt") - noAttrs := map[string]config.UserAttributes{} + noAttrs := map[string]model.UserAttributes{} // Test file only - users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs) - - assert.NilError(t, err) + users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs) - assert.Equal(t, 2, len(users)) + assert.NoError(t, err) + assert.NotNil(t, users) + assert.Len(t, *users, 2) - assert.Equal(t, "user1", users[0].Username) - assert.Equal(t, hash, users[0].Password) - assert.Equal(t, "user2", users[1].Username) - assert.Equal(t, hash, users[1].Password) + assert.Equal(t, "user1", (*users)[0].Username) + assert.Equal(t, hash, (*users)[0].Password) + assert.Equal(t, "user2", (*users)[1].Username) + assert.Equal(t, hash, (*users)[1].Password) // Test inline config only users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs) - assert.NilError(t, err) + assert.NoError(t, err) - assert.Equal(t, 2, len(users)) - assert.Equal(t, "user3", users[0].Username) - assert.Equal(t, "user4", users[1].Username) + assert.Len(t, *users, 2) + assert.Equal(t, "user3", (*users)[0].Username) + assert.Equal(t, "user4", (*users)[1].Username) // Test both - users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs) + users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs) - assert.NilError(t, err) + assert.NoError(t, err) - assert.Equal(t, 3, len(users)) + assert.Len(t, *users, 3) usernames := map[string]bool{} - for _, u := range users { + for _, u := range *users { usernames[u.Username] = true } - assert.Assert(t, usernames["user1"]) - assert.Assert(t, usernames["user2"]) - assert.Assert(t, usernames["user5"]) + assert.True(t, usernames["user1"]) + assert.True(t, usernames["user2"]) + assert.True(t, usernames["user5"]) // Test attributes applied from userAttributes map - attrs := map[string]config.UserAttributes{ + attrs := map[string]model.UserAttributes{ "user1": {Name: "User One", Email: "user1@example.com"}, } - users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs) + users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs) - assert.NilError(t, err) - assert.Equal(t, 2, len(users)) + assert.NoError(t, err) + assert.Len(t, *users, 2) - for _, u := range users { + for _, u := range *users { if u.Username == "user1" { assert.Equal(t, "User One", u.Attributes.Name) assert.Equal(t, "user1@example.com", u.Attributes.Email) @@ -84,16 +86,14 @@ func TestGetUsers(t *testing.T) { // Test empty users, err = utils.GetUsers([]string{}, "", noAttrs) - assert.NilError(t, err) - - assert.Equal(t, 0, len(users)) + assert.NoError(t, err) + assert.Nil(t, users) // Test non-existent file - users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs) + users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs) assert.ErrorContains(t, err, "no such file or directory") - - assert.Equal(t, 0, len(users)) + assert.Nil(t, users) } func TestParseUser(t *testing.T) { @@ -102,38 +102,38 @@ func TestParseUser(t *testing.T) { // Valid user without TOTP user, err := utils.ParseUser("user1:" + hash) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user1", user.Username) assert.Equal(t, hash, user.Password) - assert.Equal(t, "", user.TotpSecret) + assert.Equal(t, "", user.TOTPSecret) // Valid user with TOTP user, err = utils.ParseUser("user2:" + hash + ":ABCDEF") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user2", user.Username) assert.Equal(t, hash, user.Password) - assert.Equal(t, "ABCDEF", user.TotpSecret) + assert.Equal(t, "ABCDEF", user.TOTPSecret) // Valid user with $$ in password user, err = utils.ParseUser("user3:pa$$word123") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user3", user.Username) assert.Equal(t, "pa$word123", user.Password) - assert.Equal(t, "", user.TotpSecret) + assert.Equal(t, "", user.TOTPSecret) // User with spaces user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user4", user.Username) assert.Equal(t, "password123", user.Password) - assert.Equal(t, "TOTPSECRET", user.TotpSecret) + assert.Equal(t, "TOTPSECRET", user.TOTPSecret) // Invalid users _, err = utils.ParseUser("user1") // Missing password