diff --git a/.env.example b/.env.example index d9dcdbdb2..010e21b59 100644 --- a/.env.example +++ b/.env.example @@ -5,6 +5,7 @@ COMPOSE_PROJECT_NAME=tracecat # --- Shared URL env vars --- PUBLIC_APP_URL=http://localhost PUBLIC_API_URL=http://localhost/api +SAML_SP_ACS_URL=${PUBLIC_API_URL}/auth/saml/acs INTERNAL_API_URL=http://api:8000 # -- Caddy env vars --- @@ -61,13 +62,24 @@ NEXT_PUBLIC_API_URL=${PUBLIC_API_URL} NEXT_SERVER_API_URL=${INTERNAL_API_URL} # --- Authentication --- -# One or more comma-separated values from `basic`, `google_oauth` + +# One or more comma-separated values from `basic`, `google_oauth`, `saml` TRACECAT__AUTH_TYPES=basic,google_oauth +# Initial admin user +TRACECAT__SETUP_ADMIN_EMAIL= +TRACECAT__SETUP_ADMIN_PASSWORD= + +# OAuth OAUTH_CLIENT_ID= OAUTH_CLIENT_SECRET= USER_AUTH_SECRET=your-auth-secret +# SAML SSO settings +SAML_IDP_ENTITY_ID= +SAML_IDP_REDIRECT_URL= +SAML_IDP_CERTIFICATE= +SAML_IDP_METADATA_URL= # --- Temporal --- TEMPORAL__CLUSTER_URL=temporal:7233 diff --git a/Dockerfile b/Dockerfile index a730d6a85..76c126b1c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ EXPOSE $PORT # Install necessary packages RUN apt-get update && \ - apt-get install -y acl git && \ + apt-get install -y acl git xmlsec1 && \ rm -rf /var/lib/apt/lists/* # Copy and run the script to install additional packages diff --git a/Dockerfile.dev b/Dockerfile.dev index c06a61758..a51e89a4b 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -8,6 +8,9 @@ ENV PORT=8000 # Expose the application port EXPOSE $PORT +# Install xmlsec1 +RUN apt-get update && apt-get install -y xmlsec1 + # Set the working directory inside the container WORKDIR /app diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index a75c4b92b..6baac3b21 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -38,6 +38,15 @@ services: OAUTH_CLIENT_ID: ${OAUTH_CLIENT_ID} OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET} USER_AUTH_SECRET: ${USER_AUTH_SECRET} + # Initial admin user + TRACECAT__SETUP_ADMIN_EMAIL: ${TRACECAT__SETUP_ADMIN_EMAIL} + TRACECAT__SETUP_ADMIN_PASSWORD: ${TRACECAT__SETUP_ADMIN_PASSWORD} + # SAML SSO + SAML_IDP_ENTITY_ID: ${SAML_IDP_ENTITY_ID} + SAML_IDP_REDIRECT_URL: ${SAML_IDP_REDIRECT_URL} + SAML_IDP_CERTIFICATE: ${SAML_IDP_CERTIFICATE} + SAML_IDP_METADATA_URL: ${SAML_IDP_METADATA_URL} + SAML_SP_ACS_URL: ${SAML_SP_ACS_URL} # Temporal TEMPORAL__CLUSTER_URL: ${TEMPORAL__CLUSTER_URL} TEMPORAL__CLUSTER_QUEUE: ${TEMPORAL__CLUSTER_QUEUE} @@ -155,7 +164,7 @@ services: image: temporalio/ui:${TEMPORAL__UI_VERSION} container_name: temporal_ui ports: - - 8080:8080 + - 9090:8080 environment: - TEMPORAL_ADDRESS=temporal:7233 - TEMPORAL_CORS_ORIGINS=http://localhost:3000 diff --git a/docker-compose.yml b/docker-compose.yml index fe6e98312..d1bc33f54 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -40,6 +40,15 @@ services: OAUTH_CLIENT_SECRET: ${OAUTH_CLIENT_SECRET} USER_AUTH_SECRET: ${USER_AUTH_SECRET} RUN_MIGRATIONS: "true" + # Initial admin user + TRACECAT__SETUP_ADMIN_EMAIL: ${TRACECAT__SETUP_ADMIN_EMAIL} + TRACECAT__SETUP_ADMIN_PASSWORD: ${TRACECAT__SETUP_ADMIN_PASSWORD} + # SAML SSO + SAML_IDP_ENTITY_ID: ${SAML_IDP_ENTITY_ID} + SAML_IDP_REDIRECT_URL: ${SAML_IDP_REDIRECT_URL} + SAML_IDP_CERTIFICATE: ${SAML_IDP_CERTIFICATE} + SAML_IDP_METADATA_URL: ${SAML_IDP_METADATA_URL} + SAML_SP_ACS_URL: ${SAML_SP_ACS_URL} # Temporal TEMPORAL__CLUSTER_URL: ${TEMPORAL__CLUSTER_URL} TEMPORAL__CLUSTER_QUEUE: ${TEMPORAL__CLUSTER_QUEUE} diff --git a/frontend/src/app/auth/oauth/callback/route.ts b/frontend/src/app/auth/oauth/callback/route.ts index b37ce6883..f9d95123a 100644 --- a/frontend/src/app/auth/oauth/callback/route.ts +++ b/frontend/src/app/auth/oauth/callback/route.ts @@ -3,8 +3,6 @@ import { NextRequest, NextResponse } from "next/server" import { buildUrl, getDomain } from "@/lib/ss-utils" /** - * Wrapper around the FastAPI endpoint /auth/oauth/callback, - * which adds back a redirect to the main app. * @param request * @returns */ diff --git a/frontend/src/app/auth/saml/acs/route.ts b/frontend/src/app/auth/saml/acs/route.ts new file mode 100644 index 000000000..a6aaa5918 --- /dev/null +++ b/frontend/src/app/auth/saml/acs/route.ts @@ -0,0 +1,51 @@ +import { NextRequest, NextResponse } from "next/server" + +import { buildUrl, getDomain } from "@/lib/ss-utils" + +/** + * @param request + * @returns + */ +export async function POST(request: NextRequest) { + console.log("POST /auth/saml/acs", request.nextUrl.toString()) + + // Parse the form data from the request + const formData = await request.formData() + const samlResponse = formData.get('SAMLResponse') + + if (!samlResponse) { + console.error("No SAML response found in the request") + return NextResponse.redirect(new URL("/auth/error", getDomain(request))) + } + + // Prepare the request to the FastAPI backend + const backendUrl = new URL(buildUrl("/auth/saml/acs")) + const backendFormData = new FormData() + backendFormData.append('SAMLResponse', samlResponse) + + // Forward the request to the FastAPI backend + const backendResponse = await fetch(backendUrl.toString(), { + method: 'POST', + body: backendFormData, + }) + + if (!backendResponse.ok) { + console.error("Error from backend:", await backendResponse.text()) + return NextResponse.redirect(new URL("/auth/error", getDomain(request))) + } + + const setCookieHeader = backendResponse.headers.get("set-cookie") + + if (!setCookieHeader) { + console.error("No set-cookie header found in response") + return NextResponse.redirect(new URL("/auth/error", getDomain(request))) + } + + console.log("Redirecting to / with GET") + const redirectUrl = new URL("/", getDomain(request)) + const redirectResponse = NextResponse.redirect(redirectUrl, { + status: 303 // Force GET request + }) + redirectResponse.headers.set("set-cookie", setCookieHeader) + return redirectResponse +} diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index 5bd60068c..3c4a1a5d1 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -452,6 +452,18 @@ export const $Body_auth_reset_reset_password = { title: 'Body_auth-reset:reset_password' } as const; +export const $Body_auth_sso_acs = { + properties: { + SAMLResponse: { + type: 'string', + title: 'Samlresponse' + } + }, + type: 'object', + required: ['SAMLResponse'], + title: 'Body_auth-sso_acs' +} as const; + export const $Body_auth_verify_request_token = { properties: { email: { @@ -1980,6 +1992,18 @@ export const $RunContext = { description: 'This is the runtime context model for a workflow run. Passed into activities.' } as const; +export const $SAMLDatabaseLoginResponse = { + properties: { + redirect_url: { + type: 'string', + title: 'Redirect Url' + } + }, + type: 'object', + required: ['redirect_url'], + title: 'SAMLDatabaseLoginResponse' +} as const; + export const $Schedule = { properties: { owner_id: { diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 3e96a9f3d..9ccaca33a 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { PublicIncomingWebhookData, PublicIncomingWebhookResponse, PublicIncomingWebhookWaitData, PublicIncomingWebhookWaitResponse, WorkspacesListWorkspacesResponse, WorkspacesCreateWorkspaceData, WorkspacesCreateWorkspaceResponse, WorkspacesSearchWorkspacesData, WorkspacesSearchWorkspacesResponse, WorkspacesGetWorkspaceData, WorkspacesGetWorkspaceResponse, WorkspacesUpdateWorkspaceData, WorkspacesUpdateWorkspaceResponse, WorkspacesDeleteWorkspaceData, WorkspacesDeleteWorkspaceResponse, WorkspacesListWorkspaceMembershipsData, WorkspacesListWorkspaceMembershipsResponse, WorkspacesCreateWorkspaceMembershipData, WorkspacesCreateWorkspaceMembershipResponse, WorkspacesGetWorkspaceMembershipData, WorkspacesGetWorkspaceMembershipResponse, WorkspacesDeleteWorkspaceMembershipData, WorkspacesDeleteWorkspaceMembershipResponse, WorkflowsListWorkflowsData, WorkflowsListWorkflowsResponse, WorkflowsCreateWorkflowData, WorkflowsCreateWorkflowResponse, WorkflowsGetWorkflowData, WorkflowsGetWorkflowResponse, WorkflowsUpdateWorkflowData, WorkflowsUpdateWorkflowResponse, WorkflowsDeleteWorkflowData, WorkflowsDeleteWorkflowResponse, WorkflowsCommitWorkflowData, WorkflowsCommitWorkflowResponse, WorkflowsExportWorkflowData, WorkflowsExportWorkflowResponse, WorkflowsGetWorkflowDefinitionData, WorkflowsGetWorkflowDefinitionResponse, WorkflowsCreateWorkflowDefinitionData, WorkflowsCreateWorkflowDefinitionResponse, TriggersCreateWebhookData, TriggersCreateWebhookResponse, TriggersGetWebhookData, TriggersGetWebhookResponse, TriggersUpdateWebhookData, TriggersUpdateWebhookResponse, WorkflowExecutionsListWorkflowExecutionsData, WorkflowExecutionsListWorkflowExecutionsResponse, WorkflowExecutionsCreateWorkflowExecutionData, WorkflowExecutionsCreateWorkflowExecutionResponse, WorkflowExecutionsGetWorkflowExecutionData, WorkflowExecutionsGetWorkflowExecutionResponse, WorkflowExecutionsListWorkflowExecutionEventHistoryData, WorkflowExecutionsListWorkflowExecutionEventHistoryResponse, WorkflowExecutionsCancelWorkflowExecutionData, WorkflowExecutionsCancelWorkflowExecutionResponse, WorkflowExecutionsTerminateWorkflowExecutionData, WorkflowExecutionsTerminateWorkflowExecutionResponse, ActionsListActionsData, ActionsListActionsResponse, ActionsCreateActionData, ActionsCreateActionResponse, ActionsGetActionData, ActionsGetActionResponse, ActionsUpdateActionData, ActionsUpdateActionResponse, ActionsDeleteActionData, ActionsDeleteActionResponse, SecretsSearchSecretsData, SecretsSearchSecretsResponse, SecretsListSecretsData, SecretsListSecretsResponse, SecretsCreateSecretData, SecretsCreateSecretResponse, SecretsGetSecretByNameData, SecretsGetSecretByNameResponse, SecretsUpdateSecretByIdData, SecretsUpdateSecretByIdResponse, SecretsDeleteSecretByIdData, SecretsDeleteSecretByIdResponse, SchedulesListSchedulesData, SchedulesListSchedulesResponse, SchedulesCreateScheduleData, SchedulesCreateScheduleResponse, SchedulesGetScheduleData, SchedulesGetScheduleResponse, SchedulesUpdateScheduleData, SchedulesUpdateScheduleResponse, SchedulesDeleteScheduleData, SchedulesDeleteScheduleResponse, SchedulesSearchSchedulesData, SchedulesSearchSchedulesResponse, ValidationValidateWorkflowData, ValidationValidateWorkflowResponse, UsersSearchUserData, UsersSearchUserResponse, RegistryRepositoriesSyncRegistryRepositoriesData, RegistryRepositoriesSyncRegistryRepositoriesResponse, RegistryRepositoriesListRegistryRepositoriesResponse, RegistryRepositoriesCreateRegistryRepositoryData, RegistryRepositoriesCreateRegistryRepositoryResponse, RegistryRepositoriesGetRegistryRepositoryData, RegistryRepositoriesGetRegistryRepositoryResponse, RegistryRepositoriesUpdateRegistryRepositoryData, RegistryRepositoriesUpdateRegistryRepositoryResponse, RegistryRepositoriesDeleteRegistryRepositoryData, RegistryRepositoriesDeleteRegistryRepositoryResponse, RegistryActionsListRegistryActionsResponse, RegistryActionsCreateRegistryActionData, RegistryActionsCreateRegistryActionResponse, RegistryActionsGetRegistryActionData, RegistryActionsGetRegistryActionResponse, RegistryActionsUpdateRegistryActionData, RegistryActionsUpdateRegistryActionResponse, RegistryActionsDeleteRegistryActionData, RegistryActionsDeleteRegistryActionResponse, RegistryActionsRunRegistryActionData, RegistryActionsRunRegistryActionResponse, RegistryActionsValidateRegistryActionData, RegistryActionsValidateRegistryActionResponse, UsersUsersCurrentUserResponse, UsersUsersPatchCurrentUserData, UsersUsersPatchCurrentUserResponse, UsersUsersUserData, UsersUsersUserResponse, UsersUsersPatchUserData, UsersUsersPatchUserResponse, UsersUsersDeleteUserData, UsersUsersDeleteUserResponse, AuthAuthDatabaseLoginData, AuthAuthDatabaseLoginResponse, AuthAuthDatabaseLogoutResponse, AuthRegisterRegisterData, AuthRegisterRegisterResponse, AuthResetForgotPasswordData, AuthResetForgotPasswordResponse, AuthResetResetPasswordData, AuthResetResetPasswordResponse, AuthVerifyRequestTokenData, AuthVerifyRequestTokenResponse, AuthVerifyVerifyData, AuthVerifyVerifyResponse, AuthOauthGoogleDatabaseAuthorizeData, AuthOauthGoogleDatabaseAuthorizeResponse, AuthOauthGoogleDatabaseCallbackData, AuthOauthGoogleDatabaseCallbackResponse, PublicCheckHealthResponse } from './types.gen'; +import type { PublicIncomingWebhookData, PublicIncomingWebhookResponse, PublicIncomingWebhookWaitData, PublicIncomingWebhookWaitResponse, WorkspacesListWorkspacesResponse, WorkspacesCreateWorkspaceData, WorkspacesCreateWorkspaceResponse, WorkspacesSearchWorkspacesData, WorkspacesSearchWorkspacesResponse, WorkspacesGetWorkspaceData, WorkspacesGetWorkspaceResponse, WorkspacesUpdateWorkspaceData, WorkspacesUpdateWorkspaceResponse, WorkspacesDeleteWorkspaceData, WorkspacesDeleteWorkspaceResponse, WorkspacesListWorkspaceMembershipsData, WorkspacesListWorkspaceMembershipsResponse, WorkspacesCreateWorkspaceMembershipData, WorkspacesCreateWorkspaceMembershipResponse, WorkspacesGetWorkspaceMembershipData, WorkspacesGetWorkspaceMembershipResponse, WorkspacesDeleteWorkspaceMembershipData, WorkspacesDeleteWorkspaceMembershipResponse, WorkflowsListWorkflowsData, WorkflowsListWorkflowsResponse, WorkflowsCreateWorkflowData, WorkflowsCreateWorkflowResponse, WorkflowsGetWorkflowData, WorkflowsGetWorkflowResponse, WorkflowsUpdateWorkflowData, WorkflowsUpdateWorkflowResponse, WorkflowsDeleteWorkflowData, WorkflowsDeleteWorkflowResponse, WorkflowsCommitWorkflowData, WorkflowsCommitWorkflowResponse, WorkflowsExportWorkflowData, WorkflowsExportWorkflowResponse, WorkflowsGetWorkflowDefinitionData, WorkflowsGetWorkflowDefinitionResponse, WorkflowsCreateWorkflowDefinitionData, WorkflowsCreateWorkflowDefinitionResponse, TriggersCreateWebhookData, TriggersCreateWebhookResponse, TriggersGetWebhookData, TriggersGetWebhookResponse, TriggersUpdateWebhookData, TriggersUpdateWebhookResponse, WorkflowExecutionsListWorkflowExecutionsData, WorkflowExecutionsListWorkflowExecutionsResponse, WorkflowExecutionsCreateWorkflowExecutionData, WorkflowExecutionsCreateWorkflowExecutionResponse, WorkflowExecutionsGetWorkflowExecutionData, WorkflowExecutionsGetWorkflowExecutionResponse, WorkflowExecutionsListWorkflowExecutionEventHistoryData, WorkflowExecutionsListWorkflowExecutionEventHistoryResponse, WorkflowExecutionsCancelWorkflowExecutionData, WorkflowExecutionsCancelWorkflowExecutionResponse, WorkflowExecutionsTerminateWorkflowExecutionData, WorkflowExecutionsTerminateWorkflowExecutionResponse, ActionsListActionsData, ActionsListActionsResponse, ActionsCreateActionData, ActionsCreateActionResponse, ActionsGetActionData, ActionsGetActionResponse, ActionsUpdateActionData, ActionsUpdateActionResponse, ActionsDeleteActionData, ActionsDeleteActionResponse, SecretsSearchSecretsData, SecretsSearchSecretsResponse, SecretsListSecretsData, SecretsListSecretsResponse, SecretsCreateSecretData, SecretsCreateSecretResponse, SecretsGetSecretByNameData, SecretsGetSecretByNameResponse, SecretsUpdateSecretByIdData, SecretsUpdateSecretByIdResponse, SecretsDeleteSecretByIdData, SecretsDeleteSecretByIdResponse, SchedulesListSchedulesData, SchedulesListSchedulesResponse, SchedulesCreateScheduleData, SchedulesCreateScheduleResponse, SchedulesGetScheduleData, SchedulesGetScheduleResponse, SchedulesUpdateScheduleData, SchedulesUpdateScheduleResponse, SchedulesDeleteScheduleData, SchedulesDeleteScheduleResponse, SchedulesSearchSchedulesData, SchedulesSearchSchedulesResponse, ValidationValidateWorkflowData, ValidationValidateWorkflowResponse, UsersSearchUserData, UsersSearchUserResponse, RegistryRepositoriesSyncRegistryRepositoriesData, RegistryRepositoriesSyncRegistryRepositoriesResponse, RegistryRepositoriesListRegistryRepositoriesResponse, RegistryRepositoriesCreateRegistryRepositoryData, RegistryRepositoriesCreateRegistryRepositoryResponse, RegistryRepositoriesGetRegistryRepositoryData, RegistryRepositoriesGetRegistryRepositoryResponse, RegistryRepositoriesUpdateRegistryRepositoryData, RegistryRepositoriesUpdateRegistryRepositoryResponse, RegistryRepositoriesDeleteRegistryRepositoryData, RegistryRepositoriesDeleteRegistryRepositoryResponse, RegistryActionsListRegistryActionsResponse, RegistryActionsCreateRegistryActionData, RegistryActionsCreateRegistryActionResponse, RegistryActionsGetRegistryActionData, RegistryActionsGetRegistryActionResponse, RegistryActionsUpdateRegistryActionData, RegistryActionsUpdateRegistryActionResponse, RegistryActionsDeleteRegistryActionData, RegistryActionsDeleteRegistryActionResponse, RegistryActionsRunRegistryActionData, RegistryActionsRunRegistryActionResponse, RegistryActionsValidateRegistryActionData, RegistryActionsValidateRegistryActionResponse, UsersUsersCurrentUserResponse, UsersUsersPatchCurrentUserData, UsersUsersPatchCurrentUserResponse, UsersUsersUserData, UsersUsersUserResponse, UsersUsersPatchUserData, UsersUsersPatchUserResponse, UsersUsersDeleteUserData, UsersUsersDeleteUserResponse, AuthAuthDatabaseLoginData, AuthAuthDatabaseLoginResponse, AuthAuthDatabaseLogoutResponse, AuthRegisterRegisterData, AuthRegisterRegisterResponse, AuthResetForgotPasswordData, AuthResetForgotPasswordResponse, AuthResetResetPasswordData, AuthResetResetPasswordResponse, AuthVerifyRequestTokenData, AuthVerifyRequestTokenResponse, AuthVerifyVerifyData, AuthVerifyVerifyResponse, AuthOauthGoogleDatabaseAuthorizeData, AuthOauthGoogleDatabaseAuthorizeResponse, AuthOauthGoogleDatabaseCallbackData, AuthOauthGoogleDatabaseCallbackResponse, AuthSamlDatabaseLoginResponse, AuthSsoAcsData, AuthSsoAcsResponse, PublicCheckHealthResponse } from './types.gen'; /** * Incoming Webhook @@ -1639,6 +1639,34 @@ export const authOauthGoogleDatabaseCallback = (data: AuthOauthGoogleDatabaseCal } }); }; +/** + * Saml:Database.Login + * @returns SAMLDatabaseLoginResponse Successful Response + * @throws ApiError + */ +export const authSamlDatabaseLogin = (): CancelablePromise => { return __request(OpenAPI, { + method: 'GET', + url: '/auth/saml/login' +}); }; + +/** + * Sso Acs + * Handle the SAML SSO response from the IdP post-authentication. + * @param data The data for the request. + * @param data.formData + * @returns unknown Successful Response + * @throws ApiError + */ +export const authSsoAcs = (data: AuthSsoAcsData): CancelablePromise => { return __request(OpenAPI, { + method: 'POST', + url: '/auth/saml/acs', + formData: data.formData, + mediaType: 'application/x-www-form-urlencoded', + errors: { + 422: 'Validation Error' + } +}); }; + /** * Check Health * @returns string Successful Response diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index f6983cf6f..b7eb53303 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -185,6 +185,10 @@ export type Body_auth_reset_reset_password = { password: string; }; +export type Body_auth_sso_acs = { + SAMLResponse: string; +}; + export type Body_auth_verify_request_token = { email: string; }; @@ -741,6 +745,10 @@ export type RunContext = { environment: string; }; +export type SAMLDatabaseLoginResponse = { + redirect_url: string; +}; + export type Schedule = { owner_id: string; created_at: string; @@ -1765,6 +1773,14 @@ export type AuthOauthGoogleDatabaseCallbackData = { export type AuthOauthGoogleDatabaseCallbackResponse = unknown; +export type AuthSamlDatabaseLoginResponse = SAMLDatabaseLoginResponse; + +export type AuthSsoAcsData = { + formData: Body_auth_sso_acs; +}; + +export type AuthSsoAcsResponse = unknown; + export type PublicCheckHealthResponse = { [key: string]: (string); }; @@ -2913,6 +2929,31 @@ export type $OpenApiTs = { }; }; }; + '/auth/saml/login': { + get: { + res: { + /** + * Successful Response + */ + 200: SAMLDatabaseLoginResponse; + }; + }; + }; + '/auth/saml/acs': { + post: { + req: AuthSsoAcsData; + res: { + /** + * Successful Response + */ + 200: unknown; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; '/health': { get: { res: { diff --git a/frontend/src/components/auth/saml.tsx b/frontend/src/components/auth/saml.tsx new file mode 100644 index 000000000..af492ac89 --- /dev/null +++ b/frontend/src/components/auth/saml.tsx @@ -0,0 +1,39 @@ +"use client" + +import { ComponentPropsWithoutRef, useState } from "react" + +import { Button } from "@/components/ui/button" +import { Icons } from "@/components/icons" +import { authSamlDatabaseLogin } from "@/client" + +type SamlSSOButtonProps = ComponentPropsWithoutRef +export function SamlSSOButton(props: SamlSSOButtonProps) { + const [isLoading, setIsLoading] = useState(false) + const handleClick = async () => { + try { + setIsLoading(true) + // Call api/auth/saml/login + const { redirect_url } = await authSamlDatabaseLogin() + window.location.href = redirect_url + } catch (error) { + console.error("Error authorizing with SAML", error) + } finally { + setIsLoading(false) + } + } + return ( + + ) +} diff --git a/frontend/src/components/auth/sign-in.tsx b/frontend/src/components/auth/sign-in.tsx index 30eef23af..5be810d2f 100644 --- a/frontend/src/components/auth/sign-in.tsx +++ b/frontend/src/components/auth/sign-in.tsx @@ -31,6 +31,7 @@ import { import { Input } from "@/components/ui/input" import { toast } from "@/components/ui/use-toast" import { GoogleOAuthButton } from "@/components/auth/oauth-buttons" +import { SamlSSOButton } from "@/components/auth/saml" import { Icons } from "@/components/icons" export function SignIn({ className }: React.HTMLProps) { @@ -72,6 +73,9 @@ export function SignIn({ className }: React.HTMLProps) { {authConfig.authTypes.includes("google_oauth") && ( )} + {authConfig.authTypes.includes("saml") && ( + + )} {/* */} {authConfig.authTypes.includes("basic") && ( diff --git a/frontend/src/components/icons.tsx b/frontend/src/components/icons.tsx index da19d94cd..8c5834193 100644 --- a/frontend/src/components/icons.tsx +++ b/frontend/src/components/icons.tsx @@ -2,6 +2,7 @@ import { Blend, Bolt, BoxesIcon, + Building2Icon, Cpu, Globe, Mail, @@ -153,6 +154,7 @@ export const Icons = { ), + saml: (props: IconProps) => , } export function getFlairSize(size: "sm" | "md" | "lg"): string { diff --git a/pyproject.toml b/pyproject.toml index 390c6a942..618997566 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "pyarrow==16.1.0", "pydantic==2.6.1", "python-slugify==8.0.4", + "pysaml2==7.5.0", "sqlmodel==0.0.18", "tenacity==8.3.0", "temporalio==1.6.0", diff --git a/tracecat/api/app.py b/tracecat/api/app.py index 55aae0ea2..0b25d85b2 100644 --- a/tracecat/api/app.py +++ b/tracecat/api/app.py @@ -16,7 +16,7 @@ from tracecat.api.routers.users import router as users_router from tracecat.api.routers.validation import router as validation_router from tracecat.auth.constants import AuthType -from tracecat.auth.schemas import UserCreate, UserRead, UserUpdate +from tracecat.auth.models import UserCreate, UserRead, UserUpdate from tracecat.auth.users import ( auth_backend, fastapi_users, @@ -296,6 +296,11 @@ def create_app(**kwargs) -> FastAPI: prefix="/auth", tags=["auth"], ) + if AuthType.SAML in config.TRACECAT__AUTH_TYPES: + from tracecat.auth.saml import router as saml_router + + logger.info("SAML auth type enabled") + app.include_router(saml_router) # Development endpoints if config.TRACECAT__APP_ENV == "development": diff --git a/tracecat/api/routers/users.py b/tracecat/api/routers/users.py index 4c5085801..6a9d10cdd 100644 --- a/tracecat/api/routers/users.py +++ b/tracecat/api/routers/users.py @@ -6,7 +6,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from tracecat.auth.credentials import authenticate_user_access_level -from tracecat.auth.schemas import UserRead +from tracecat.auth.models import UserRead from tracecat.db.engine import get_async_session from tracecat.db.schemas import User from tracecat.logger import logger diff --git a/tracecat/auth/credentials.py b/tracecat/auth/credentials.py index ba4e44cfb..a96a44964 100644 --- a/tracecat/auth/credentials.py +++ b/tracecat/auth/credentials.py @@ -21,7 +21,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from tracecat import config -from tracecat.auth.schemas import UserRole +from tracecat.auth.models import UserRole from tracecat.auth.users import ( current_active_user, is_unprivileged, diff --git a/tracecat/auth/schemas.py b/tracecat/auth/models.py similarity index 100% rename from tracecat/auth/schemas.py rename to tracecat/auth/models.py diff --git a/tracecat/auth/saml.py b/tracecat/auth/saml.py new file mode 100644 index 000000000..d1481594e --- /dev/null +++ b/tracecat/auth/saml.py @@ -0,0 +1,284 @@ +import tempfile +import xml.etree.ElementTree as ET +from contextlib import contextmanager +from dataclasses import asdict, dataclass +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status +from fastapi_users.exceptions import UserAlreadyExists +from pydantic import BaseModel +from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT +from saml2.client import Saml2Client +from saml2.config import Config as Saml2Config + +from tracecat.auth.users import AuthBackendStrategyDep, UserManagerDep, auth_backend +from tracecat.config import ( + SAML_IDP_CERTIFICATE, + SAML_IDP_ENTITY_ID, + SAML_IDP_METADATA_URL, + SAML_IDP_REDIRECT_URL, + SAML_SP_ACS_URL, + TRACECAT__PUBLIC_API_URL, + XMLSEC_BINARY_PATH, +) +from tracecat.logger import logger + +router = APIRouter(prefix="/auth/saml", tags=["auth"]) + + +class SAMLDatabaseLoginResponse(BaseModel): + redirect_url: str + + +@dataclass +class SAMLAttribute: + """Represents a SAML attribute with its name, format, and value""" + + name: str + value: str + name_format: str = "" + + +class SAMLParser: + """Parser for SAML AttributeStatement responses""" + + NAMESPACES = { + "saml2": "urn:oasis:names:tc:SAML:2.0:assertion", + "xs": "http://www.w3.org/2001/XMLSchema", + "xsi": "http://www.w3.org/2001/XMLSchema-instance", + } + + def __init__(self, xml_string: str): + """Initialize parser with SAML XML string""" + self.xml_string = xml_string.strip() + self.attributes = None # Store lazily parsed attributes + + def _register_namespaces(self): + """Register namespaces for proper XML handling""" + for prefix, uri in self.NAMESPACES.items(): + ET.register_namespace(prefix, uri) + + def _extract_attribute(self, attribute_elem: ET.Element) -> SAMLAttribute: + """Extract a single SAML attribute from an XML element""" + name = attribute_elem.get("Name", "") + name_format = attribute_elem.get("NameFormat", "") + + # Get the attribute value + value_elem = attribute_elem.find("saml2:AttributeValue", self.NAMESPACES) + value = value_elem.text if value_elem is not None else "" + + return SAMLAttribute(name=name, value=value, name_format=name_format) + + def get_attribute_value(self, attribute_name: str) -> str: + """Helper method to easily get an attribute value""" + if self.attributes is None: + self.attributes = self.parse_to_dict() + return self.attributes.get(attribute_name, {}).get("value", "") + + def parse_to_dict(self) -> dict[str, Any]: + """Parse SAML XML and return attributes as a dictionary""" + self._register_namespaces() + try: + root = ET.fromstring(self.xml_string) + except ET.ParseError as e: + logger.error(f"SAML response parsing failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to parse SAML response", + ) from None + + # Find AttributeStatement + attr_statement = root.find(".//saml2:AttributeStatement", self.NAMESPACES) + if attr_statement is None: + logger.error("SAML response failed: AttributeStatement not found") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid SAML response" + ) from None + + # Process all attributes + attributes = {} + for attr_elem in attr_statement.findall("saml2:Attribute", self.NAMESPACES): + saml_attr = self._extract_attribute(attr_elem) + attributes[saml_attr.name] = asdict(saml_attr) + + return attributes + + +@contextmanager +def generate_saml_metadata_file(): + """Generate a temporary SAML metadata file.""" + + if not SAML_IDP_ENTITY_ID: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="SAML SSO entity ID has not been configured.", + ) + + if not SAML_IDP_REDIRECT_URL: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="SAML SSO redirect URL has not been configured.", + ) + + if not SAML_IDP_CERTIFICATE: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="SAML SSO certificate has not been configured.", + ) + + # Create the root element + root = ET.Element( + "EntityDescriptor", + { + "xmlns": "urn:oasis:names:tc:SAML:2.0:metadata", + "xmlns:ds": "http://www.w3.org/2000/09/xmldsig#", + "entityID": SAML_IDP_ENTITY_ID, + }, + ) + + # Create IDPSSODescriptor element + idp_sso_descriptor = ET.SubElement( + root, + "IDPSSODescriptor", + {"protocolSupportEnumeration": "urn:oasis:names:tc:SAML:2.0:protocol"}, + ) + + # Add KeyDescriptor + key_descriptor = ET.SubElement(idp_sso_descriptor, "KeyDescriptor", use="signing") + key_info = ET.SubElement(key_descriptor, "ds:KeyInfo") + x509_data = ET.SubElement(key_info, "ds:X509Data") + x509_certificate = ET.SubElement(x509_data, "ds:X509Certificate") + x509_certificate.text = SAML_IDP_CERTIFICATE + + # Add NameIDFormat + name_id_format = ET.SubElement(idp_sso_descriptor, "NameIDFormat") + name_id_format.text = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent" + + # Add SingleSignOnService + ET.SubElement( + idp_sso_descriptor, + "SingleSignOnService", + {"Binding": BINDING_HTTP_REDIRECT, "Location": SAML_IDP_REDIRECT_URL}, + ) + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w+", suffix=".xml") as tmp_file: + # Write the XML to the temporary file + tree = ET.ElementTree(root) + tree.write(tmp_file, encoding="unicode", xml_declaration=True) + tmp_file.flush() + tmp_file_path = tmp_file.name + yield tmp_file_path + + +def create_saml_client() -> Saml2Client: + saml_settings = { + "strict": True, + # The global unique identifier for this service provider + "entityid": TRACECAT__PUBLIC_API_URL, + "xmlsec_binary": XMLSEC_BINARY_PATH, + # Service provider settings + "service": { + "sp": { + "name": "tracecat_saml_sp", + "description": "Tracecat SAML SSO Service Provider", + "endpoints": { + "assertion_consumer_service": [ + (SAML_SP_ACS_URL, BINDING_HTTP_POST), + ], + }, + # Security settings + "allow_unsolicited": True, # If true, it allows the IdP to initiate the authentication + "authn_requests_signed": False, # Don't need to sign authn requests because we don't control the IdP + "want_assertions_signed": True, # We require the IdP to sign the assertions + "want_response_signed": False, + }, + }, + } + + if SAML_IDP_METADATA_URL is None: + with generate_saml_metadata_file() as tmp_metadata_path: + # Add the local metadata file to the settings + saml_settings["metadata"] = {"local": [tmp_metadata_path]} + config = Saml2Config() + config.load(saml_settings) + else: + # Save the cert to a temporary file + with tempfile.NamedTemporaryFile(mode="w+", suffix=".crt") as tmp_file: + tmp_file.write(SAML_IDP_CERTIFICATE) + tmp_file.flush() + saml_settings["metadata"] = { + "remote": [ + { + "url": SAML_IDP_METADATA_URL, + "cert": tmp_file.name, # Path to cert + } + ] + } + config = Saml2Config() + config.load(saml_settings) + + client = Saml2Client(config) + return client + + +SamlClientDep = Annotated[Saml2Client, Depends(create_saml_client)] + + +@router.get("/login", name=f"saml:{auth_backend.name}.login") +async def login(client: SamlClientDep) -> SAMLDatabaseLoginResponse: + _, info = client.prepare_for_authenticate() + try: + headers = info["headers"] + # Select the IdP URL to send the AuthN request to + redirect_url = next(v for k, v in headers if k == "Location") + except (KeyError, StopIteration): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Redirect URL not found in the SAML response.", + ) from None + # Return the redirect URL + return SAMLDatabaseLoginResponse(redirect_url=redirect_url) + + +@router.post("/acs") +async def sso_acs( + request: Request, + *, + saml_response: str = Form(..., alias="SAMLResponse"), + user_manager: UserManagerDep, + strategy: AuthBackendStrategyDep, + client: SamlClientDep, +) -> Response: + """Handle the SAML SSO response from the IdP post-authentication.""" + + # Get email in the SAML response from the IdP + authn_response = client.parse_authn_request_response( + saml_response, BINDING_HTTP_POST + ) + parser = SAMLParser(str(authn_response)) + email = parser.get_attribute_value("email") + + # Try to get the user from the database + try: + user = await user_manager.saml_callback( + email=email, + associate_by_email=True, # Assuming we want to associate by email + is_verified_by_default=True, # Assuming SAML-authenticated users are verified by default + ) + except UserAlreadyExists: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="User already exists", + ) from None + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Bad credentials", + ) from None + + # Authenticate + response = await auth_backend.login(strategy, user) + await user_manager.on_after_login(user, request, response) + return response diff --git a/tracecat/auth/users.py b/tracecat/auth/users.py index 5453de65e..34ee09260 100644 --- a/tracecat/auth/users.py +++ b/tracecat/auth/users.py @@ -1,6 +1,8 @@ import contextlib +import os import uuid from collections.abc import AsyncGenerator, Awaitable +from typing import Annotated from fastapi import APIRouter, Depends, Request, Response, status from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models @@ -14,14 +16,14 @@ DatabaseStrategy, ) from fastapi_users.db import SQLAlchemyUserDatabase -from fastapi_users.exceptions import UserAlreadyExists +from fastapi_users.exceptions import UserAlreadyExists, UserNotExists from fastapi_users.openapi import OpenAPIResponseType from sqlalchemy.ext.asyncio import AsyncSession as SQLAlchemyAsyncSession from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession as SQLModelAsyncSession from tracecat import config -from tracecat.auth.schemas import UserCreate, UserRole +from tracecat.auth.models import UserCreate, UserRole from tracecat.db.adapter import ( SQLModelAccessTokenDatabaseAsync, SQLModelUserDatabaseAsync, @@ -58,6 +60,39 @@ async def on_after_request_verify( f"Verification requested for user {user.id}. Verification token: {token}" ) + async def saml_callback( + self, + *, + email: str, + associate_by_email: bool = True, + is_verified_by_default: bool = True, + ) -> User: + """ + Handle the callback after a successful SAML authentication. + + :param email: Email of the user from SAML response. + :param associate_by_email: If True, associate existing user with the same email. Defaults to True. + :param is_verified_by_default: If True, set is_verified flag for new users. Defaults to True. + :return: A user. + """ + try: + user = await self.get_by_email(email) + if not associate_by_email: + raise UserAlreadyExists() + except UserNotExists: + # Create account + password = self.password_helper.generate() + user_dict = { + "email": email, + "hashed_password": self.password_helper.hash(password), + "is_verified": is_verified_by_default, + } + user = await self.user_db.create(user_dict) + await self.on_after_register(user) + + self.logger.info(f"User {user.id} authenticated via SAML.") + return user + async def get_user_db(session: SQLAlchemyAsyncSession = Depends(get_async_session)): yield SQLAlchemyUserDatabase(session, User, OAuthAccount) @@ -110,6 +145,11 @@ def get_database_strategy( get_strategy=get_database_strategy, ) +AuthBackendStrategyDep = Annotated[ + Strategy[models.UP, models.ID], Depends(auth_backend.get_strategy) +] +UserManagerDep = Annotated[UserManager, Depends(get_user_manager)] + class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]): def get_logout_router( @@ -193,11 +233,16 @@ async def list_users(*, session: SQLModelAsyncSession) -> list[User]: def default_admin_user() -> UserCreate: + if not (email := os.getenv("TRACECAT__SETUP_ADMIN_EMAIL")): + raise ValueError("TRACECAT__SETUP_ADMIN_EMAIL is not set") + if not (password := os.getenv("TRACECAT__SETUP_ADMIN_PASSWORD")): + raise ValueError("TRACECAT__SETUP_ADMIN_PASSWORD is not set") + return UserCreate( - email="admin@domain.com", - first_name="Admin", + email=email, + first_name="Root", last_name="User", - password="password", + password=password, is_superuser=True, is_verified=True, role=UserRole.ADMIN, diff --git a/tracecat/config.py b/tracecat/config.py index fb9912108..6e31629a0 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -86,6 +86,16 @@ ) USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "") +# SAML SSO +SAML_IDP_ENTITY_ID = os.environ.get("SAML_IDP_ENTITY_ID") +SAML_IDP_REDIRECT_URL = os.environ.get("SAML_IDP_REDIRECT_URL") +SAML_IDP_CERTIFICATE = os.environ.get("SAML_IDP_CERTIFICATE") +SAML_IDP_METADATA_URL = os.environ.get("SAML_IDP_METADATA_URL") +SAML_SP_ACS_URL = os.environ.get( + "SAML_SP_ACS_URL", "http://localhost/api/auth/saml/acs" +) +XMLSEC_BINARY_PATH = os.environ.get("XMLSEC_BINARY_PATH", "/usr/bin/xmlsec1") + # === CORS config === # # NOTE: If you are using Tracecat self-hosted, please replace with your # own domain by setting the comma separated TRACECAT__ALLOW_ORIGINS env var. diff --git a/tracecat/db/schemas.py b/tracecat/db/schemas.py index 4f40bf07f..3d9da7782 100644 --- a/tracecat/db/schemas.py +++ b/tracecat/db/schemas.py @@ -12,7 +12,7 @@ from sqlmodel import UUID, Field, Relationship, SQLModel, UniqueConstraint from tracecat import config -from tracecat.auth.schemas import UserRole +from tracecat.auth.models import UserRole from tracecat.db.adapter import ( SQLModelBaseAccessToken, SQLModelBaseOAuthAccount, diff --git a/tracecat/workspaces/models.py b/tracecat/workspaces/models.py index e6075e124..f17402c58 100644 --- a/tracecat/workspaces/models.py +++ b/tracecat/workspaces/models.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, EmailStr, Field from tracecat import config -from tracecat.auth.schemas import UserRole +from tracecat.auth.models import UserRole from tracecat.identifiers import OwnerID, UserID, WorkspaceID # === Workspace === #