Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions docs/examples/litestar_extension_migrations_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Example demonstrating how to use Litestar extension migrations with SQLSpec.

This example shows how to configure SQLSpec to include Litestar's session table
migrations, which will create dialect-specific tables when you run migrations.
"""

from pathlib import Path

from litestar import Litestar

from sqlspec.adapters.sqlite.config import SqliteConfig
from sqlspec.extensions.litestar.plugin import SQLSpec
from sqlspec.extensions.litestar.store import SQLSpecSessionStore
from sqlspec.migrations.commands import MigrationCommands

# Configure database with extension migrations enabled
db_config = SqliteConfig(
pool_config={"database": "app.db"},
migration_config={
"script_location": "migrations",
"version_table_name": "ddl_migrations",
# Enable Litestar extension migrations
"include_extensions": ["litestar"],
},
)

# Create SQLSpec plugin with session store
sqlspec_plugin = SQLSpec(db_config)

# Configure session store to use the database
session_store = SQLSpecSessionStore(
config=db_config,
table_name="litestar_sessions", # Matches migration table name
)

# Create Litestar app with SQLSpec and sessions
app = Litestar(plugins=[sqlspec_plugin], stores={"sessions": session_store})


def run_migrations() -> None:
"""Run database migrations including extension migrations.

This will:
1. Create your project's migrations (from migrations/ directory)
2. Create Litestar extension migrations (session table with dialect-specific types)
"""
commands = MigrationCommands(db_config)

# Initialize migrations directory if it doesn't exist
migrations_dir = Path("migrations")
if not migrations_dir.exists():
commands.init("migrations")

# Run all migrations including extension migrations
# The session table will be created with:
# - JSONB for PostgreSQL
# - JSON for MySQL/MariaDB
# - TEXT for SQLite
commands.upgrade()

# Check current version
current = commands.current(verbose=True)
print(f"Current migration version: {current}")


if __name__ == "__main__":
# Run migrations before starting the app
run_migrations()

# Start the application
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
166 changes: 166 additions & 0 deletions docs/examples/litestar_session_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Example showing how to use SQLSpec session backend with Litestar."""

from typing import Any

from litestar import Litestar, get, post
from litestar.config.session import SessionConfig
from litestar.connection import Request
from litestar.datastructures import State

from sqlspec.adapters.sqlite.config import SqliteConfig
from sqlspec.extensions.litestar import SQLSpec, SQLSpecSessionBackend, SQLSpecSessionConfig

# Configure SQLSpec with SQLite database
# Include Litestar extension migrations to automatically create session tables
sqlite_config = SqliteConfig(
pool_config={"database": "sessions.db"},
migration_config={
"script_location": "migrations",
"version_table_name": "sqlspec_migrations",
"include_extensions": ["litestar"], # Include Litestar session table migrations
},
)

# Create SQLSpec plugin
sqlspec_plugin = SQLSpec(sqlite_config)

# Create session backend using SQLSpec
# Note: The session table will be created automatically when you run migrations
# Example: sqlspec migrations upgrade --head
session_backend = SQLSpecSessionBackend(
config=SQLSpecSessionConfig(
table_name="litestar_sessions",
session_id_column="session_id",
data_column="data",
expires_at_column="expires_at",
created_at_column="created_at",
)
)

# Configure session middleware
session_config = SessionConfig(
backend=session_backend,
cookie_https_only=False, # Set to True in production
cookie_secure=False, # Set to True in production with HTTPS
cookie_domain="localhost",
cookie_path="/",
cookie_max_age=3600,
cookie_same_site="lax",
cookie_http_only=True,
session_cookie_name="sqlspec_session",
)


@get("/")
async def index() -> dict[str, str]:
"""Homepage route."""
return {"message": "SQLSpec Session Example"}


@get("/login")
async def login_form() -> str:
"""Simple login form."""
return """
<html>
<body>
<h2>Login</h2>
<form method="post" action="/login">
<input type="text" name="username" placeholder="Username" required>
<input type="password" name="password" placeholder="Password" required>
<button type="submit">Login</button>
</form>
</body>
</html>
"""


@post("/login")
async def login(data: dict[str, str], request: "Request[Any, Any, Any]") -> dict[str, str]:
"""Handle login and create session."""
username = data.get("username")
password = data.get("password")

# Simple authentication (use proper auth in production)
if username == "admin" and password == "secret":
# Store user data in session
request.set_session(
{"user_id": 1, "username": username, "login_time": "2024-01-01T12:00:00Z", "roles": ["admin", "user"]}
)
return {"message": f"Welcome, {username}!"}

return {"error": "Invalid credentials"}


@get("/profile")
async def profile(request: "Request[Any, Any, Any]") -> dict[str, str]:
"""User profile route - requires session."""
session_data = request.session

if not session_data or "user_id" not in session_data:
return {"error": "Not logged in"}

return {
"user_id": session_data["user_id"],
"username": session_data["username"],
"login_time": session_data["login_time"],
"roles": session_data["roles"],
}


@post("/logout")
async def logout(request: "Request[Any, Any, Any]") -> dict[str, str]:
"""Logout and clear session."""
request.clear_session()
return {"message": "Logged out successfully"}


@get("/admin/sessions")
async def admin_sessions(request: "Request[Any, Any, Any]", state: State) -> dict[str, any]:
"""Admin route to view all active sessions."""
session_data = request.session

if not session_data or "admin" not in session_data.get("roles", []):
return {"error": "Admin access required"}

# Get session backend from state
backend = session_backend
session_ids = await backend.get_all_session_ids()

return {
"active_sessions": len(session_ids),
"session_ids": session_ids[:10], # Limit to first 10 for display
}


@post("/admin/cleanup")
async def cleanup_sessions(request: "Request[Any, Any, Any]", state: State) -> dict[str, str]:
"""Admin route to clean up expired sessions."""
session_data = request.session

if not session_data or "admin" not in session_data.get("roles", []):
return {"error": "Admin access required"}

# Clean up expired sessions
backend = session_backend
await backend.delete_expired_sessions()

return {"message": "Expired sessions cleaned up"}


# Create Litestar application
app = Litestar(
route_handlers=[index, login_form, login, profile, logout, admin_sessions, cleanup_sessions],
plugins=[sqlspec_plugin],
session_config=session_config,
debug=True,
)


if __name__ == "__main__":
import uvicorn

print("Starting SQLSpec Session Example...")
print("Visit http://localhost:8000 to view the application")
print("Login with username 'admin' and password 'secret'")

uvicorn.run(app, host="0.0.0.0", port=8000)
5 changes: 4 additions & 1 deletion sqlspec/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,8 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
FSSPEC_INSTALLED = bool(find_spec("fsspec"))
OBSTORE_INSTALLED = bool(find_spec("obstore"))
PGVECTOR_INSTALLED = bool(find_spec("pgvector"))

UUID_UTILS_INSTALLED = bool(find_spec("uuid_utils"))
NANOID_INSTALLED = bool(find_spec("fastnanoid"))

__all__ = (
"AIOSQL_INSTALLED",
Expand All @@ -617,6 +618,7 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
"FSSPEC_INSTALLED",
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"NANOID_INSTALLED",
"OBSTORE_INSTALLED",
"OPENTELEMETRY_INSTALLED",
"PGVECTOR_INSTALLED",
Expand All @@ -625,6 +627,7 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
"PYDANTIC_INSTALLED",
"UNSET",
"UNSET_STUB",
"UUID_UTILS_INSTALLED",
"AiosqlAsyncProtocol",
"AiosqlParamType",
"AiosqlProtocol",
Expand Down
14 changes: 12 additions & 2 deletions sqlspec/adapters/asyncmy/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ class AsyncmyExceptionHandler:
async def __aenter__(self) -> None:
return None

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "Optional[bool]":
if exc_type is None:
return
return None

if issubclass(exc_type, asyncmy.errors.IntegrityError):
e = exc_val
Expand All @@ -102,6 +102,15 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
raise SQLSpecError(msg) from e
if issubclass(exc_type, asyncmy.errors.OperationalError):
e = exc_val
# Handle specific MySQL errors that are expected in migrations
if hasattr(e, "args") and len(e.args) >= 1 and isinstance(e.args[0], int):
error_code = e.args[0]
# Error 1061: Duplicate key name (index already exists)
# Error 1091: Can't DROP index that doesn't exist
if error_code in {1061, 1091}:
# These are acceptable during migrations - log and continue
logger.warning("AsyncMy MySQL expected migration error (ignoring): %s", e)
return True # Suppress the exception by returning True
msg = f"AsyncMy MySQL operational error: {e}"
raise SQLSpecError(msg) from e
if issubclass(exc_type, asyncmy.errors.DatabaseError):
Expand All @@ -120,6 +129,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
raise SQLParsingError(msg) from e
msg = f"Unexpected async database operation error: {e}"
raise SQLSpecError(msg) from e
return None


class AsyncmyDriver(AsyncDriverAdapterBase):
Expand Down
3 changes: 2 additions & 1 deletion sqlspec/adapters/asyncpg/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
PostgreSQL COPY operation support, and transaction management.
"""

import datetime
import re
from typing import TYPE_CHECKING, Any, Final, Optional

Expand Down Expand Up @@ -36,7 +37,7 @@
supported_parameter_styles={ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_PYFORMAT},
default_execution_parameter_style=ParameterStyle.NUMERIC,
supported_execution_parameter_styles={ParameterStyle.NUMERIC},
type_coercion_map={},
type_coercion_map={datetime.datetime: lambda x: x, datetime.date: lambda x: x, datetime.time: lambda x: x},
has_native_list_expansion=True,
needs_static_script_compilation=False,
preserve_parameter_format=True,
Expand Down
3 changes: 2 additions & 1 deletion sqlspec/adapters/oracledb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlspec.core.statement import StatementConfig
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
from sqlspec.exceptions import SQLParsingError, SQLSpecError
from sqlspec.utils.serializers import to_json

if TYPE_CHECKING:
from contextlib import AbstractAsyncContextManager, AbstractContextManager
Expand All @@ -38,7 +39,7 @@
supported_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON, ParameterStyle.QMARK},
default_execution_parameter_style=ParameterStyle.POSITIONAL_COLON,
supported_execution_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.POSITIONAL_COLON},
type_coercion_map={},
type_coercion_map={dict: to_json, list: to_json},
has_native_list_expansion=False,
needs_static_script_compilation=True,
preserve_parameter_format=True,
Expand Down
Loading
Loading