diff --git a/examples/pydantic_models/example_006_boolean_fields.py b/examples/pydantic_models/example_006_boolean_fields.py new file mode 100644 index 0000000..5ac8048 --- /dev/null +++ b/examples/pydantic_models/example_006_boolean_fields.py @@ -0,0 +1,31 @@ +"""Example demonstrating boolean field handling in pydantic models.""" + +from typing import Annotated + +import pydantic +import typer + +import pydantic_typer + + +class Settings(pydantic.BaseModel): + """Settings with boolean fields.""" + + enable_feature: Annotated[ + bool, pydantic.Field(description="Enable the feature.") + ] = True + debug_mode: Annotated[bool, pydantic.Field(description="Enable debug mode.")] = ( + False + ) + verbose: bool = False + + +def main(settings: Settings): + """Main function that uses settings with boolean fields.""" + typer.echo(f"enable_feature={settings.enable_feature}") + typer.echo(f"debug_mode={settings.debug_mode}") + typer.echo(f"verbose={settings.verbose}") + + +if __name__ == "__main__": + pydantic_typer.run(main) diff --git a/examples/pydantic_types/example_011_annotated_validators.py b/examples/pydantic_types/example_011_annotated_validators.py new file mode 100644 index 0000000..3b3ed5c --- /dev/null +++ b/examples/pydantic_types/example_011_annotated_validators.py @@ -0,0 +1,151 @@ +"""Examples of Annotated validators (BeforeValidator, AfterValidator, etc.).""" + +from datetime import timedelta +from typing import Annotated, Any + +import typer +from pydantic import ( + AfterValidator, + BaseModel, + BeforeValidator, + EmailStr, + Field, + HttpUrl, +) + +from pydantic_typer import Typer + + +# BeforeValidator example with timedelta +def parse_timedelta(value: Any) -> timedelta: + """Parse timedelta from seconds (int/float/str) or ISO 8601 format.""" + if isinstance(value, timedelta): + return value + if isinstance(value, (int, float)): + return timedelta(seconds=value) + if isinstance(value, str): + # Try to parse as a number first + try: + return timedelta(seconds=float(value)) + except ValueError: + # If it fails, it might be ISO 8601 format + # Let pydantic's default parser handle it + pass + return value + + +FlexibleTimedelta = Annotated[timedelta, BeforeValidator(parse_timedelta)] + + +class Settings(BaseModel): + """Settings with a flexible timedelta field.""" + + min_duration: Annotated[ + FlexibleTimedelta, + Field(description="The minimum testing duration"), + ] = timedelta(seconds=5) + + +# AfterValidator example with temperature +def validate_temperature(value: float) -> float: + """Validate temperature is in reasonable range (AfterValidator).""" + if value < -273.15: + raise ValueError("Temperature cannot be below absolute zero (-273.15°C)") + if value > 1000: + raise ValueError("Temperature is unreasonably high (max 1000°C)") + return value + + +Temperature = Annotated[float, AfterValidator(validate_temperature)] + + +class TemperatureConfig(BaseModel): + """Config with temperature validation using AfterValidator.""" + + temp: Annotated[ + Temperature, + Field(description="Temperature in Celsius"), + ] = 20.0 + + +# Combined BeforeValidator and AfterValidator examples +def normalize_url(value: Any) -> str: + """Add https:// scheme if missing (BeforeValidator for HttpUrl).""" + if isinstance(value, str) and not value.startswith(("http://", "https://")): + return f"https://{value}" + return value + + +def validate_domain(value: HttpUrl) -> HttpUrl: + """Validate URL has allowed domain (AfterValidator).""" + allowed_domains = ["example.com", "test.com", "localhost"] + if value.host and not any( + value.host.endswith(domain) for domain in allowed_domains + ): + raise ValueError(f"Domain must be one of: {', '.join(allowed_domains)}") + return value + + +FlexibleHttpUrl = Annotated[ + HttpUrl, BeforeValidator(normalize_url), AfterValidator(validate_domain) +] + + +class WebConfig(BaseModel): + """Config with flexible URL handling.""" + + api_url: Annotated[ + FlexibleHttpUrl, + Field(description="API endpoint URL"), + ] + + +def normalize_email(value: Any) -> str: + """Normalize email to lowercase (BeforeValidator for EmailStr).""" + if isinstance(value, str): + return value.lower().strip() + return value + + +NormalizedEmail = Annotated[EmailStr, BeforeValidator(normalize_email)] + + +class UserConfig(BaseModel): + """Config with email normalization.""" + + email: Annotated[ + NormalizedEmail, + Field(description="User email address"), + ] = "user@example.com" + + +# CLI applications +app = Typer() + + +@app.command(name="settings") +def settings_command(settings: Settings): + """Process settings with flexible timedelta.""" + typer.echo(f"Duration: {settings.min_duration}") + + +@app.command(name="temperature") +def temperature_command(config: TemperatureConfig): + """Process temperature configuration.""" + typer.echo(f"Temperature: {config.temp}°C") + + +@app.command(name="web") +def web_command(config: WebConfig): + """Process web configuration with URL validation.""" + typer.echo(f"API URL: {config.api_url}") + + +@app.command(name="user") +def user_command(config: UserConfig): + """Process user configuration with email normalization.""" + typer.echo(f"Email: {config.email}") + + +if __name__ == "__main__": + app() diff --git a/examples/pydantic_types/example_012_field_validator.py b/examples/pydantic_types/example_012_field_validator.py new file mode 100644 index 0000000..d28a0c2 --- /dev/null +++ b/examples/pydantic_types/example_012_field_validator.py @@ -0,0 +1,138 @@ +"""Examples of field_validator decorators with pydantic-typer.""" + +from datetime import timedelta +from ipaddress import IPv4Address +from typing import Annotated, Any + +import typer +from pydantic import BaseModel, Field, HttpUrl, field_validator + +from pydantic_typer import Typer + + +class SettingsWithFieldValidator(BaseModel): + """Settings using field_validator decorator.""" + + min_duration: Annotated[ + timedelta, + Field(description="The minimum testing duration"), + ] = timedelta(seconds=5) + + @field_validator("min_duration", mode="before") + @classmethod + def parse_min_duration(cls, v: Any) -> timedelta: + """Parse timedelta from seconds or ISO 8601 format.""" + if isinstance(v, timedelta): + return v + if isinstance(v, (int, float)): + return timedelta(seconds=v) + if isinstance(v, str): + try: + return timedelta(seconds=float(v)) + except ValueError: + pass + return v + + +class ConfigWithAfterValidator(BaseModel): + """Config using field_validator with mode='after'.""" + + temperature: Annotated[ + float, + Field(description="Temperature in Celsius"), + ] = 20.0 + + @field_validator("temperature", mode="after") + @classmethod + def validate_temperature_range(cls, v: float) -> float: + """Ensure temperature is in valid range after conversion.""" + if v < -273.15: + raise ValueError("Temperature cannot be below absolute zero (-273.15°C)") + if v > 1000: + raise ValueError("Temperature is unreasonably high (max 1000°C)") + return v + + +class ServerConfig(BaseModel): + """Server configuration with HttpUrl and field validators.""" + + api_url: Annotated[ + HttpUrl, + Field(description="API server URL"), + ] + + @field_validator("api_url", mode="before") + @classmethod + def normalize_url(cls, v: Any) -> str: + """Add https:// if scheme is missing.""" + if isinstance(v, str) and not v.startswith(("http://", "https://")): + return f"https://{v}" + return v + + @field_validator("api_url", mode="after") + @classmethod + def validate_port(cls, v: HttpUrl) -> HttpUrl: + """Ensure port is >= 1024 (non-privileged) if specified.""" + if v.port and v.port < 1024: + raise ValueError(f"Port {v.port} is reserved (must be >= 1024)") + return v + + +class NetworkConfig(BaseModel): + """Network configuration with IPv4Address and validators.""" + + server_ip: Annotated[ + IPv4Address, + Field(description="Server IP address"), + ] = IPv4Address("127.0.0.1") + + @field_validator("server_ip", mode="before") + @classmethod + def parse_ip(cls, v: Any) -> str: + """Allow 'localhost' as alias for 127.0.0.1.""" + if isinstance(v, str): + if v.lower() == "localhost": + return "127.0.0.1" + # Strip whitespace + return v.strip() + return v + + @field_validator("server_ip", mode="after") + @classmethod + def validate_not_zero(cls, v: IPv4Address) -> IPv4Address: + """Ensure IP is not 0.0.0.0.""" + if str(v) == "0.0.0.0": + raise ValueError("IP address cannot be 0.0.0.0") + return v + + +# CLI applications +app = Typer() + + +@app.command(name="settings") +def settings_command(settings: SettingsWithFieldValidator): + """Process settings with field validator.""" + typer.echo(f"Duration: {settings.min_duration}") + + +@app.command(name="temperature") +def temperature_command(config: ConfigWithAfterValidator): + """Process temperature configuration.""" + typer.echo(f"Temperature: {config.temperature}°C") + + +@app.command(name="server") +def server_command(config: ServerConfig): + """Process server configuration.""" + typer.echo(f"API: {config.api_url}") + + +@app.command(name="network") +def network_command(config: NetworkConfig): + """Process network configuration.""" + typer.echo(f"Server IP: {config.server_ip}") + + +if __name__ == "__main__": + app() diff --git a/examples/pydantic_types/example_013_model_validator.py b/examples/pydantic_types/example_013_model_validator.py new file mode 100644 index 0000000..3468f6c --- /dev/null +++ b/examples/pydantic_types/example_013_model_validator.py @@ -0,0 +1,191 @@ +"""Examples of model_validator decorators with pydantic-typer.""" + +from typing import Annotated, Any + +import typer +from pydantic import BaseModel, EmailStr, Field, HttpUrl, SecretStr, model_validator + +from pydantic_typer import Typer + + +class RangeConfig(BaseModel): + """Config with range validation using model_validator.""" + + min_value: Annotated[ + int, + Field(description="Minimum value"), + ] = 0 + + max_value: Annotated[ + int, + Field(description="Maximum value"), + ] = 100 + + @model_validator(mode="before") + @classmethod + def validate_before(cls, data: Any) -> Any: + """Validate and normalize data before field validation.""" + if isinstance(data, dict): + # Ensure min_value is not negative + if "min_value" in data and int(data["min_value"]) < 0: + data["min_value"] = 0 + return data + + @model_validator(mode="after") + def validate_after(self) -> "RangeConfig": + """Validate that min_value < max_value after field validation.""" + if self.min_value >= self.max_value: + raise ValueError( + f"min_value ({self.min_value}) must be less than max_value ({self.max_value})" + ) + return self + + +class UserProfile(BaseModel): + """Profile with multiple model validators.""" + + username: Annotated[str, Field(description="Username")] + email: Annotated[str, Field(description="Email address")] + age: Annotated[int, Field(description="Age")] = 0 + + @model_validator(mode="before") + @classmethod + def normalize_before(cls, data: Any) -> Any: + """Normalize data before field validation.""" + if isinstance(data, dict): + # Lowercase username + if "username" in data: + data["username"] = data["username"].lower() + # Lowercase email + if "email" in data: + data["email"] = data["email"].lower() + return data + + @model_validator(mode="after") + def validate_after(self) -> "UserProfile": + """Validate profile after field validation.""" + # Ensure email matches username domain + if "@" in self.email: + email_user = self.email.split("@")[0] + if email_user != self.username: + raise ValueError("Email username must match profile username") + return self + + +class ConfigWithDefaults(BaseModel): + """Config that uses model_validator to set computed defaults.""" + + base_path: Annotated[str, Field(description="Base path")] = "/tmp" + cache_path: Annotated[str, Field(description="Cache path")] = "" + log_path: Annotated[str, Field(description="Log path")] = "" + + @model_validator(mode="after") + def set_defaults(self) -> "ConfigWithDefaults": + """Set default paths based on base_path if not provided.""" + if not self.cache_path: + self.cache_path = f"{self.base_path}/cache" + if not self.log_path: + self.log_path = f"{self.base_path}/logs" + return self + + +class ServiceConfig(BaseModel): + """Service configuration with HttpUrl, EmailStr and model validators.""" + + api_url: Annotated[HttpUrl, Field(description="API endpoint")] + webhook_url: Annotated[HttpUrl, Field(description="Webhook endpoint")] + admin_email: Annotated[EmailStr, Field(description="Admin email")] + + @model_validator(mode="before") + @classmethod + def normalize_urls(cls, data: Any) -> Any: + """Normalize URLs by adding https:// if scheme is missing.""" + if isinstance(data, dict): + for key in ["api_url", "webhook_url"]: + if key in data and isinstance(data[key], str): + if not data[key].startswith(("http://", "https://")): + data[key] = f"https://{data[key]}" + # Normalize email to lowercase + if "admin_email" in data and isinstance(data["admin_email"], str): + data["admin_email"] = data["admin_email"].lower().strip() + return data + + @model_validator(mode="after") + def validate_same_domain(self) -> "ServiceConfig": + """Ensure API and webhook URLs are on the same domain.""" + if self.api_url.host != self.webhook_url.host: + raise ValueError( + f"API URL ({self.api_url.host}) and webhook URL ({self.webhook_url.host}) " + "must be on the same domain" + ) + return self + + +class DatabaseConfig(BaseModel): + """Database configuration with SecretStr and validators.""" + + host: Annotated[str, Field(description="Database host")] = "localhost" + port: Annotated[int, Field(description="Database port")] = 5432 + username: Annotated[str, Field(description="Database username")] + password: Annotated[SecretStr, Field(description="Database password")] + + @model_validator(mode="before") + @classmethod + def set_defaults(cls, data: Any) -> Any: + """Set default host based on username if not provided.""" + if isinstance(data, dict): + # If username suggests local development, ensure localhost + if "username" in data and data["username"] in ["dev", "test"]: + if "host" not in data or not data["host"]: + data["host"] = "localhost" + return data + + @model_validator(mode="after") + def validate_security(self) -> "DatabaseConfig": + """Validate security requirements for production.""" + # If host is not localhost, password must be strong + if self.host != "localhost": + pwd = self.password.get_secret_value() + if len(pwd) < 8: + raise ValueError( + "Production database password must be at least 8 characters" + ) + return self + + +# CLI applications +app = Typer() + + +@app.command(name="range") +def range_command(config: RangeConfig): + """Process range configuration.""" + typer.echo(f"Range: {config.min_value} to {config.max_value}") + + +@app.command(name="profile") +def profile_command(profile: UserProfile): + """Process user profile.""" + typer.echo(f"Profile: {profile.username} ({profile.email})") + + +@app.command(name="paths") +def paths_command(config: ConfigWithDefaults): + """Process path configuration.""" + typer.echo(f"Paths: {config.cache_path}, {config.log_path}") + + +@app.command(name="service") +def service_command(config: ServiceConfig): + """Process service configuration.""" + typer.echo(f"API: {config.api_url}, Webhook: {config.webhook_url}") + + +@app.command(name="database") +def database_command(config: DatabaseConfig): + """Process database configuration.""" + typer.echo(f"DB: {config.username}@{config.host}:{config.port}") + + +if __name__ == "__main__": + app() diff --git a/examples/pydantic_types/example_014_optional_pydantic_types.py b/examples/pydantic_types/example_014_optional_pydantic_types.py new file mode 100644 index 0000000..020f545 --- /dev/null +++ b/examples/pydantic_types/example_014_optional_pydantic_types.py @@ -0,0 +1,38 @@ +"""Example demonstrating Optional fields with Pydantic types in models. + +This example tests the handling of Union types (Optional is Union[T, None]) +with Pydantic-specific types like DirectoryPath that need to be converted to str +for Click/Typer while preserving the proper option names. +""" + +import pydantic_typer +import typer +from pydantic import BaseModel, DirectoryPath, FilePath + + +class Config(BaseModel): + """Configuration with optional Pydantic type fields.""" + + input_dir: DirectoryPath | None = None + """Optional directory path for input files.""" + + output_file: FilePath | None = None + """Optional file path for output.""" + + name: str = "default" + """A regular string field for comparison.""" + + +def main(config: Config): + """Process configuration with optional Pydantic type fields.""" + typer.echo( + f"Input dir: {config.input_dir} (type: {type(config.input_dir).__name__})" + ) + typer.echo( + f"Output file: {config.output_file} (type: {type(config.output_file).__name__})" + ) + typer.echo(f"Name: {config.name} (type: {type(config.name).__name__})") + + +if __name__ == "__main__": + pydantic_typer.run(main) diff --git a/src/pydantic_typer/main.py b/src/pydantic_typer/main.py index 316d6eb..1bd6efd 100644 --- a/src/pydantic_typer/main.py +++ b/src/pydantic_typer/main.py @@ -4,13 +4,18 @@ import inspect import re from functools import wraps -from typing import Any, Callable, get_args, get_origin +from typing import Any, Callable, get_args, get_origin, get_type_hints import click import pydantic from typer import BadParameter, Option from typer import Typer as TyperBase -from typer.main import CommandFunctionType, get_click_param, get_params_from_function, lenient_issubclass +from typer.main import ( + CommandFunctionType, + get_click_param, + get_params_from_function, + lenient_issubclass, +) from typer.models import OptionInfo, ParameterInfo from typer.utils import ( AnnotatedParamWithDefaultValueError, @@ -29,6 +34,10 @@ def _flatten_pydantic_model( model: pydantic.BaseModel, ancestors: list[str], ancestor_typer_param=None ) -> dict[str, inspect.Parameter]: + # Get full type hints to preserve Annotated metadata like BeforeValidator, AfterValidator, etc. + # Using include_extras=True ensures validators attached via Annotated are preserved + type_hints = get_type_hints(model, include_extras=True) + pydantic_parameters = {} for field_name, field in model.model_fields.items(): qualifier = [*ancestors, field_name] @@ -39,30 +48,62 @@ def _flatten_pydantic_model( pydantic_parameters.update(params) else: default = ( - field.default if field.default is not pydantic.fields._Unset else ... # noqa: SLF001 + field.default + if field.default is not pydantic.fields._Unset + else ... # noqa: SLF001 ) # Pydantic stores annotations in field.metadata. # If the field is already annotated with a typer.Option or typer.Argument, use that. - existing_typer_params = [meta for meta in field.metadata if isinstance(meta, ParameterInfo)] + existing_typer_params = [ + meta for meta in field.metadata if isinstance(meta, ParameterInfo) + ] typer_param: ParameterInfo if existing_typer_params: typer_param = existing_typer_params[0] if isinstance(typer_param, OptionInfo) and not typer_param.param_decls: # If the the option was not named manually, use the default naming scheme - typer_param.param_decls = (f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}",) + typer_param.param_decls = ( + f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}", + ) elif ancestor_typer_param: typer_param = ancestor_typer_param else: - typer_param = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}") + typer_param = Option(..., f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}") # Copy Field metadata to Option, fixes https://github.com/pypae/pydantic-typer/issues/2 if field.description and not typer_param.help: typer_param.help = field.description + # Get the full annotation from type hints to preserve validators (BeforeValidator, AfterValidator, etc.) + # Falls back to field.annotation if type hints are not available + full_field_annotation = type_hints.get(field_name, field.annotation) + + # If full_field_annotation is Annotated, we need to extract just the base type + # and the validators, stripping out any FieldInfo or ParameterInfo that pydantic adds + # (since we add our own FieldInfo metadata via typer_param) + if get_origin(full_field_annotation) is Annotated: + args = get_args(full_field_annotation) + # Filter out FieldInfo and ParameterInfo from metadata since we add our own metadata + filtered_metadata = tuple( + arg + for arg in args[1:] + if not isinstance(arg, (pydantic.fields.FieldInfo, ParameterInfo)) + ) + # Reconstruct Annotated with just the type and non-FieldInfo metadata (validators, etc.) + # Use Annotated.__class_getitem__ for Python 3.8 compatibility + if filtered_metadata: + field_annotation = Annotated.__class_getitem__( + (args[0],) + filtered_metadata + ) + else: + field_annotation = args[0] + else: + field_annotation = full_field_annotation + pydantic_parameters[sub_name] = inspect.Parameter( sub_name, inspect.Parameter.KEYWORD_ONLY, - annotation=Annotated[field.annotation, typer_param, qualifier], + annotation=Annotated[field_annotation, typer_param, qualifier], default=default, ) return pydantic_parameters @@ -90,14 +131,17 @@ def enable_pydantic(callback: CommandFunctionType) -> CommandFunctionType: pydantic_roots = {} other_parameters = {} for name, parameter in original_signature.parameters.items(): - base_annotation, typer_annotations = _split_annotation_from_typer_annotations(parameter.annotation) + base_annotation, typer_annotations = _split_annotation_from_typer_annotations( + parameter.annotation + ) typer_param = typer_annotations[0] if typer_annotations else None if lenient_issubclass(base_annotation, pydantic.BaseModel): - params = _flatten_pydantic_model(parameter.annotation, [name], typer_param) + params = _flatten_pydantic_model(base_annotation, [name], typer_param) pydantic_parameters.update(params) pydantic_roots[name] = base_annotation elif get_origin(base_annotation) in (list, tuple) and any( - lenient_issubclass(arg, pydantic.BaseModel) for arg in get_args(base_annotation) + lenient_issubclass(arg, pydantic.BaseModel) + for arg in get_args(base_annotation) ): msg = f"Type not yet supported: {base_annotation}, see https://github.com/pypae/pydantic-typer/issues/6" raise RuntimeError(msg) @@ -119,7 +163,9 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] kwarg_value = kwargs[kwarg_name] converted_kwargs.pop(kwarg_name) annotation = pydantic_parameters[kwarg_name].annotation - _, qualifier = annotation.__metadata__ + # The last two items in metadata are always typer_param and qualifier + # There may be additional items before them (validators, etc.) + *_, typer_param, qualifier = annotation.__metadata__ for part in reversed(qualifier): kwarg_value = {part: kwarg_value} raw_pydantic_objects = deep_update(raw_pydantic_objects, kwarg_value) @@ -129,11 +175,15 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] wrapper.__signature__ = extended_signature # type: ignore # Copy annotations to make forward references work in Python <= 3.9 - wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()} + wrapper.__annotations__ = { + k: v.annotation for k, v in extended_signature.parameters.items() + } return wrapper def _recursive_replace_annotation(original_annotation, type_to_replace, replacement): + from typing import Union + if original_annotation == type_to_replace: return replacement @@ -142,7 +192,35 @@ def _recursive_replace_annotation(original_annotation, type_to_replace, replacem # This is a pydantic type with extra information, such as: # typing.Annotated[pydantic_core._pydantic_core.Url, UrlConstraints(max_length=2083, allowed_schemes=['http', 'https'], host_required=None, default_host=None, default_port=None, default_path=None)] return replacement - if origin in (Annotated, tuple, list): + if origin is Annotated: + args = get_args(original_annotation) + # Recursively process the base type (first arg) + base_type = args[0] + metadata = args[1:] + updated_base_type = _recursive_replace_annotation(base_type, type_to_replace, replacement) + + if updated_base_type != base_type: + # Type was replaced, only preserve ParameterInfo metadata (like OptionInfo), + # discard Pydantic-specific metadata (like PathType, UrlConstraints, etc.) + preserved_metadata = tuple(m for m in metadata if isinstance(m, (ParameterInfo, list))) + if preserved_metadata: + return Annotated.__class_getitem__((updated_base_type,) + preserved_metadata) + return updated_base_type + return original_annotation + if origin is Union: + # Handle Union types (including Optional which is Union[T, None]) + args = get_args(original_annotation) + updated_args = [] + changed = False + for arg in args: + updated_arg = _recursive_replace_annotation(arg, type_to_replace, replacement) + updated_args.append(updated_arg) + if updated_arg != arg: + changed = True + if changed: + return Union[tuple(updated_args)] + return original_annotation + if origin in (tuple, list): args = get_args(original_annotation) updated_args = [] for arg in args: @@ -152,7 +230,9 @@ def _recursive_replace_annotation(original_annotation, type_to_replace, replacem updated_args.append(arg) else: - updated_args.append(_recursive_replace_annotation(arg, type_to_replace, replacement)) + updated_args.append( + _recursive_replace_annotation(arg, type_to_replace, replacement) + ) return origin[tuple(updated_args)] return original_annotation @@ -174,7 +254,25 @@ def _parse_error_type(error_message: str) -> type | None: ParseStr = object() -def enable_pydantic_type_validation(callback: CommandFunctionType) -> CommandFunctionType: +def _is_flattened_model_param(annotation: Any) -> bool: + """ + Check if a parameter is from a flattened Pydantic model. + + Flattened model parameters have a qualifier (list) in their metadata, + added by enable_pydantic(). For these parameters, we should skip TypeAdapter + validation and let the model construction handle it (including field_validator decorators). + """ + if get_origin(annotation) is not Annotated: + return False + + metadata = get_args(annotation)[1:] # Skip the base type + # Check if any metadata item is a list (the qualifier) + return any(isinstance(item, list) for item in metadata) + + +def enable_pydantic_type_validation( + callback: CommandFunctionType, +) -> CommandFunctionType: """ A decorator that ensures Pydantic validation is applied to parameters of Typer commands, including those with types not natively supported by Typer. @@ -211,6 +309,30 @@ def enable_pydantic_type_validation(callback: CommandFunctionType) -> CommandFun if lenient_issubclass(param.annotation, click.Context): # click.Context should not be modified continue + + # Special handling for boolean fields from flattened pydantic models + # Typer treats bool with defaults as flags (--flag/--no-flag), but we want to support + # passing boolean values like: --settings.use_flag False + if _is_flattened_model_param(original_parameter.annotation): + # Check if the base type is bool + annotation = original_parameter.annotation + if get_origin(annotation) is Annotated: + base_type, *metadata = get_args(annotation) + if base_type is bool: + # Replace bool with str and preserve existing metadata + # This allows: --settings.use_flag False instead of requiring --no-settings.use_flag + # The boolean parsing will be handled by Pydantic when the model is constructed + updated_annotation = Annotated.__class_getitem__( + (str,) + tuple(metadata) + ) + updated_parameters[param_name] = inspect.Parameter( + param_name, + kind=original_parameter.kind, + default=original_parameter.default, + annotation=updated_annotation, + ) + continue + # We don't know wheter to use pydantic or typer to parse a param without checking if typer supports it. try: get_click_param(param) @@ -227,30 +349,62 @@ def enable_pydantic_type_validation(callback: CommandFunctionType) -> CommandFun str, ) - updated_parameter = inspect.Parameter( - param_name, - kind=original_parameter.kind, - default=original_parameter.default, - annotation=Annotated[updated_annotation, ParsePython], - ) + # Preserve existing metadata (like OptionInfo) from the original annotation + # but replace the base type with str + if get_origin(updated_annotation) is Annotated: + # updated_annotation already has Annotated with metadata from _recursive_replace_annotation + # We just need to add ParsePython marker + args = get_args(updated_annotation) + updated_parameter = inspect.Parameter( + param_name, + kind=original_parameter.kind, + default=original_parameter.default, + annotation=Annotated.__class_getitem__( + (args[0],) + args[1:] + (ParsePython,) + ), + ) + else: + updated_parameter = inspect.Parameter( + param_name, + kind=original_parameter.kind, + default=original_parameter.default, + annotation=Annotated[updated_annotation, ParsePython], + ) updated_parameters[param_name] = updated_parameter except AssertionError as e: # Assertion error is raised for union and list types with complex sub-types, # which we support by using str and parsing that with pydantic. - if "List types with complex sub-types are not currently supported" in e.args: + if ( + "List types with complex sub-types are not currently supported" + in e.args + ): # TODO we don't support complex list types yet either. # Do not modify param, will be raised again by typer. continue if "Typer Currently doesn't support Union types" in e.args: - updated_parameters[param_name] = inspect.Parameter( - param_name, - kind=original_parameter.kind, - default=original_parameter.default, - annotation=Annotated[str, ParseStr], - ) + # Preserve existing metadata (like OptionInfo) from the original annotation + if get_origin(original_parameter.annotation) is Annotated: + args = get_args(original_parameter.annotation) + # Keep all original metadata and add ParseStr marker + updated_parameters[param_name] = inspect.Parameter( + param_name, + kind=original_parameter.kind, + default=original_parameter.default, + annotation=Annotated.__class_getitem__( + (str,) + args[1:] + (ParseStr,) + ), + ) + else: + updated_parameters[param_name] = inspect.Parameter( + param_name, + kind=original_parameter.kind, + default=original_parameter.default, + annotation=Annotated[str, ParseStr], + ) new_signature = inspect.Signature( - parameters=list(updated_parameters.values()), return_annotation=original_signature.return_annotation + parameters=list(updated_parameters.values()), + return_annotation=original_signature.return_annotation, ) @copy_type(callback) @@ -262,6 +416,13 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] updated_annotation = new_signature.parameters[name].annotation if annotation == updated_annotation: continue + + # Skip TypeAdapter validation for flattened model parameters. + # These will be validated when the full model is constructed in enable_pydantic(), + # which properly applies field_validator decorators. + if _is_flattened_model_param(annotation): + continue + # We only need to parse parameters where we changed the annotation try: type_adapter = pydantic.TypeAdapter(annotation) @@ -282,7 +443,9 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] wrapper.__signature__ = new_signature # type: ignore # Copy annotations to make forward references work in Python <= 3.9 - wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()} + wrapper.__annotations__ = { + k: v.annotation for k, v in new_signature.parameters.items() + } return wrapper diff --git a/unit_tests/pydantic_models/test_006_boolean_fields.py b/unit_tests/pydantic_models/test_006_boolean_fields.py new file mode 100644 index 0000000..4552ee1 --- /dev/null +++ b/unit_tests/pydantic_models/test_006_boolean_fields.py @@ -0,0 +1,116 @@ +"""Tests for boolean field handling in pydantic models.""" + +import importlib + +import pytest +from typer.testing import CliRunner + +import pydantic_typer + +runner = CliRunner() + + +@pytest.fixture +def mod(): + """Load the boolean fields example module.""" + return importlib.import_module( + "examples.pydantic_models.example_006_boolean_fields" + ) + + +@pytest.fixture +def app(mod): + """Create a Typer app with the example main function.""" + app = pydantic_typer.Typer() + app.command()(mod.main) + return app + + +def test_help(app): + """Test that help text is generated correctly for boolean fields.""" + result = runner.invoke(app, ["--help"]) + assert "Enable the feature." in result.output + assert "Enable debug mode." in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_false_value(app): + """Test passing False as a value to a boolean field.""" + result = runner.invoke(app, ["--settings.enable_feature", "False"]) + assert "enable_feature=False" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_true_value(app): + """Test passing True as a value to a boolean field.""" + result = runner.invoke(app, ["--settings.debug_mode", "True"]) + assert "debug_mode=True" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_lowercase_false(app): + """Test passing lowercase 'false' as a value to a boolean field.""" + result = runner.invoke(app, ["--settings.enable_feature", "false"]) + assert "enable_feature=False" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_lowercase_true(app): + """Test passing lowercase 'true' as a value to a boolean field.""" + result = runner.invoke(app, ["--settings.debug_mode", "true"]) + assert "debug_mode=True" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_default_true(app): + """Test that default value (True) is used when field is not provided.""" + result = runner.invoke(app, []) + assert "enable_feature=True" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_default_false(app): + """Test that default value (False) is used when field is not provided.""" + result = runner.invoke(app, []) + assert "debug_mode=False" in result.output + assert result.exit_code == 0 + + +def test_multiple_boolean_fields(app): + """Test setting multiple boolean fields at once.""" + result = runner.invoke( + app, + [ + "--settings.enable_feature", + "False", + "--settings.debug_mode", + "True", + "--settings.verbose", + "True", + ], + ) + assert "enable_feature=False" in result.output + assert "debug_mode=True" in result.output + assert "verbose=True" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_numeric_1(app): + """Test passing numeric 1 as a value (should be parsed as True by pydantic).""" + result = runner.invoke(app, ["--settings.debug_mode", "1"]) + assert "debug_mode=True" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_numeric_0(app): + """Test passing numeric 0 as a value (should be parsed as False by pydantic).""" + result = runner.invoke(app, ["--settings.enable_feature", "0"]) + assert "enable_feature=False" in result.output + assert result.exit_code == 0 + + +def test_boolean_field_with_invalid_value(app): + """Test passing an invalid value to a boolean field.""" + result = runner.invoke(app, ["--settings.enable_feature", "invalid"]) + assert result.exit_code != 0 + # Pydantic should raise a validation error diff --git a/unit_tests/pydantic_types/test_006_pydantic_types.py b/unit_tests/pydantic_types/test_006_pydantic_types.py index affd5da..bc7a7a8 100644 --- a/unit_tests/pydantic_types/test_006_pydantic_types.py +++ b/unit_tests/pydantic_types/test_006_pydantic_types.py @@ -20,12 +20,14 @@ def test_help(): def test_valid_input(): result = runner.invoke(app, ["2", "https://google.com"]) assert "2 " in result.output - assert "https://google.com/ " in result.output + assert "https://google.com/ " in result.output def test_invalid_url(): result = runner.invoke(app, ["2", "ftp://ftp.google.com"]) - assert "Invalid value for url: URL scheme should be 'http' or 'https'" in result.output + assert ( + "Invalid value for url: URL scheme should be 'http' or 'https'" in result.output + ) def test_script(): diff --git a/unit_tests/pydantic_types/test_011_annotated_validators.py b/unit_tests/pydantic_types/test_011_annotated_validators.py new file mode 100644 index 0000000..4532b9f --- /dev/null +++ b/unit_tests/pydantic_types/test_011_annotated_validators.py @@ -0,0 +1,183 @@ +"""Test that Annotated validators (BeforeValidator, AfterValidator, etc.) are preserved and work correctly.""" + +from datetime import timedelta + +from pydantic import HttpUrl +from typer.testing import CliRunner + +from examples.pydantic_types.example_011_annotated_validators import ( + Settings, + Temperature, + TemperatureConfig, + UserConfig, + WebConfig, + app, +) +from pydantic_typer import Typer + +runner = CliRunner() + + +class TestBeforeValidator: + """Test BeforeValidator functionality with Annotated types.""" + + def test_before_validator_with_int_input(self): + """Test that BeforeValidator works with integer string input.""" + app = Typer() + + @app.command() + def main(settings: Settings): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=10) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, ["--settings.min_duration", "10"]) + assert result.exit_code == 0 + assert "Duration: 0:00:10" in result.stdout + + def test_before_validator_with_float_input(self): + """Test that BeforeValidator works with float string input.""" + app = Typer() + + @app.command() + def main(settings: Settings): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=10.5) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, ["--settings.min_duration", "10.5"]) + assert result.exit_code == 0 + assert "Duration: 0:00:10.500000" in result.stdout + + def test_before_validator_with_iso8601_input(self): + """Test that BeforeValidator still allows ISO 8601 format.""" + app = Typer() + + @app.command() + def main(settings: Settings): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=30) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, ["--settings.min_duration", "PT30S"]) + assert result.exit_code == 0 + assert "Duration: 0:00:30" in result.stdout + + def test_before_validator_with_default(self): + """Test that default value works correctly.""" + app = Typer() + + @app.command() + def main(settings: Settings): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=5) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, []) + assert result.exit_code == 0 + assert "Duration: 0:00:05" in result.stdout + + +class TestAfterValidator: + """Test AfterValidator functionality with Annotated types.""" + + def test_after_validator_success(self): + """Test that AfterValidator works for valid values.""" + app = Typer() + + @app.command() + def main(config: TemperatureConfig): + assert isinstance(config.temp, float) + assert config.temp == 25.5 + print(f"Temperature: {config.temp}°C") + + result = runner.invoke(app, ["--config.temp", "25.5"]) + assert result.exit_code == 0 + assert "Temperature: 25.5°C" in result.stdout + + def test_after_validator_error_low(self): + """Test that AfterValidator catches values that are too low.""" + app = Typer() + + @app.command() + def main(config: TemperatureConfig): + print(f"Temperature: {config.temp}°C") + + result = runner.invoke(app, ["--config.temp", "-300"]) + assert result.exit_code != 0 + assert "Temperature cannot be below absolute zero (-273.15°C)" in str( + result.exception + ) + + def test_after_validator_error_high(self): + """Test that AfterValidator catches values that are too high.""" + app = Typer() + + @app.command() + def main(config: TemperatureConfig): + print(f"Temperature: {config.temp}°C") + + result = runner.invoke(app, ["--config.temp", "1500"]) + assert result.exit_code != 0 + assert "Temperature is unreasonably high (max 1000°C)" in str(result.exception) + + +class TestCombinedValidators: + """Test combined BeforeValidator and AfterValidator functionality.""" + + def test_httpurl_with_validators(self): + """Test that HttpUrl works with BeforeValidator and AfterValidator.""" + app = Typer() + + @app.command() + def main(config: WebConfig): + assert isinstance(config.api_url, HttpUrl) + assert str(config.api_url) == "https://api.example.com/" + print(f"API URL: {config.api_url}") + + # BeforeValidator adds https://, AfterValidator checks domain + result = runner.invoke(app, ["--config.api_url", "api.example.com"]) + assert result.exit_code == 0 + assert "https://api.example.com" in result.stdout + + def test_httpurl_validator_domain_error(self): + """Test that AfterValidator rejects invalid domains for HttpUrl.""" + app = Typer() + + @app.command() + def main(config: WebConfig): + print(f"API URL: {config.api_url}") + + result = runner.invoke(app, ["--config.api_url", "https://invalid.com"]) + assert result.exit_code != 0 + assert "Domain must be one of: example.com, test.com, localhost" in str( + result.exception + ) + + def test_emailstr_with_before_validator(self): + """Test that EmailStr works with BeforeValidator for normalization.""" + app = Typer() + + @app.command() + def main(config: UserConfig): + assert isinstance(config.email, str) + assert config.email == "john.doe@example.com" + print(f"Email: {config.email}") + + # BeforeValidator normalizes to lowercase + result = runner.invoke(app, ["--config.email", " JOHN.DOE@EXAMPLE.COM "]) + assert result.exit_code == 0 + assert "john.doe@example.com" in result.stdout + + def test_emailstr_validation_error(self): + """Test that EmailStr validation still works after BeforeValidator.""" + app = Typer() + + @app.command() + def main(config: UserConfig): + print(f"Email: {config.email}") + + result = runner.invoke(app, ["--config.email", "not-an-email"]) + assert result.exit_code != 0 + assert result.exception and ("email" in str(result.exception).lower()) + diff --git a/unit_tests/pydantic_types/test_012_field_validator.py b/unit_tests/pydantic_types/test_012_field_validator.py new file mode 100644 index 0000000..ea762cc --- /dev/null +++ b/unit_tests/pydantic_types/test_012_field_validator.py @@ -0,0 +1,203 @@ +"""Test that field_validator decorators work with pydantic-typer.""" + +from datetime import timedelta +from ipaddress import IPv4Address + +from pydantic import HttpUrl +from typer.testing import CliRunner + +from examples.pydantic_types.example_012_field_validator import ( + ConfigWithAfterValidator, + NetworkConfig, + ServerConfig, + SettingsWithFieldValidator, +) +from pydantic_typer import Typer + +runner = CliRunner() + + +class TestFieldValidatorBefore: + """Test field_validator with mode='before'.""" + + def test_field_validator_with_int_input(self): + """Test that field_validator works with integer string input.""" + app = Typer() + + @app.command() + def main(settings: SettingsWithFieldValidator): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=10) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, ["--settings.min_duration", "10"]) + assert result.exit_code == 0 + assert "Duration: 0:00:10" in result.stdout + + def test_field_validator_with_float_input(self): + """Test that field_validator works with float string input.""" + app = Typer() + + @app.command() + def main(settings: SettingsWithFieldValidator): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=10.5) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, ["--settings.min_duration", "10.5"]) + assert result.exit_code == 0 + assert "Duration: 0:00:10.500000" in result.stdout + + def test_field_validator_with_iso8601_input(self): + """Test that field_validator still allows ISO 8601 format.""" + app = Typer() + + @app.command() + def main(settings: SettingsWithFieldValidator): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=30) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, ["--settings.min_duration", "PT30S"]) + assert result.exit_code == 0 + assert "Duration: 0:00:30" in result.stdout + + def test_field_validator_with_default(self): + """Test that default value works correctly with field_validator.""" + app = Typer() + + @app.command() + def main(settings: SettingsWithFieldValidator): + assert isinstance(settings.min_duration, timedelta) + assert settings.min_duration == timedelta(seconds=5) + print(f"Duration: {settings.min_duration}") + + result = runner.invoke(app, []) + assert result.exit_code == 0 + assert "Duration: 0:00:05" in result.stdout + + +class TestFieldValidatorAfter: + """Test field_validator with mode='after'.""" + + def test_field_validator_after_mode_success(self): + """Test that field_validator with mode='after' works for valid values.""" + app = Typer() + + @app.command() + def main(config: ConfigWithAfterValidator): + assert config.temperature == 25.5 + print(f"Temperature: {config.temperature}°C") + + result = runner.invoke(app, ["--config.temperature", "25.5"]) + assert result.exit_code == 0 + assert "Temperature: 25.5°C" in result.stdout + + def test_field_validator_after_mode_error_low(self): + """Test that field_validator with mode='after' catches values that are too low.""" + app = Typer() + + @app.command() + def main(config: ConfigWithAfterValidator): + print(f"Temperature: {config.temperature}°C") + + result = runner.invoke(app, ["--config.temperature", "-300"]) + assert result.exit_code != 0 + assert "Temperature cannot be below absolute zero (-273.15°C)" in str( + result.exception + ) + + def test_field_validator_after_mode_error_high(self): + """Test that field_validator with mode='after' catches values that are too high.""" + app = Typer() + + @app.command() + def main(config: ConfigWithAfterValidator): + print(f"Temperature: {config.temperature}°C") + + result = runner.invoke(app, ["--config.temperature", "1500"]) + assert result.exit_code != 0 + assert "Temperature is unreasonably high (max 1000°C)" in str(result.exception) + + +class TestFieldValidatorWithPydanticTypes: + """Test field_validator with advanced Pydantic types.""" + + def test_httpurl_with_field_validator_before(self): + """Test that HttpUrl works with field_validator mode='before'.""" + app = Typer() + + @app.command() + def main(config: ServerConfig): + assert isinstance(config.api_url, HttpUrl) + assert str(config.api_url) == "https://api.example.com:8080/" + print(f"API: {config.api_url}") + + # Before validator adds https:// scheme + result = runner.invoke(app, ["--config.api_url", "api.example.com:8080"]) + assert result.exit_code == 0 + assert "https://api.example.com:8080" in result.stdout + + def test_httpurl_with_field_validator_after_error(self): + """Test that field_validator mode='after' catches privileged ports.""" + app = Typer() + + @app.command() + def main(config: ServerConfig): + print(f"API: {config.api_url}") + + result = runner.invoke(app, ["--config.api_url", "https://api.example.com:80"]) + assert result.exit_code != 0 + assert "Port 80 is reserved (must be >= 1024)" in str(result.exception) + + def test_ipv4address_with_field_validator(self): + """Test that IPv4Address works with field_validator.""" + app = Typer() + + @app.command() + def main(config: NetworkConfig): + assert isinstance(config.server_ip, IPv4Address) + assert str(config.server_ip) == "127.0.0.1" + print(f"Server IP: {config.server_ip}") + + # Before validator converts localhost to 127.0.0.1 + result = runner.invoke(app, ["--config.server_ip", "localhost"]) + assert result.exit_code == 0 + assert "127.0.0.1" in result.stdout + + def test_ipv4address_validation(self): + """Test that IPv4Address validation works.""" + app = Typer() + + @app.command() + def main(config: NetworkConfig): + print(f"Server IP: {config.server_ip}") + + # Valid IP + result = runner.invoke(app, ["--config.server_ip", "192.168.1.1"]) + assert result.exit_code == 0 + assert "192.168.1.1" in result.stdout + + def test_ipv4address_after_validator_error(self): + """Test that field_validator mode='after' rejects 0.0.0.0.""" + app = Typer() + + @app.command() + def main(config: NetworkConfig): + print(f"Server IP: {config.server_ip}") + + result = runner.invoke(app, ["--config.server_ip", "0.0.0.0"]) + assert result.exit_code != 0 + assert "IP address cannot be 0.0.0.0" in str(result.exception) + + def test_ipv4address_invalid(self): + """Test that invalid IP addresses are rejected.""" + app = Typer() + + @app.command() + def main(config: NetworkConfig): + print(f"Server IP: {config.server_ip}") + + result = runner.invoke(app, ["--config.server_ip", "999.999.999.999"]) + assert result.exit_code != 0 + diff --git a/unit_tests/pydantic_types/test_013_model_validator.py b/unit_tests/pydantic_types/test_013_model_validator.py new file mode 100644 index 0000000..82dd0cc --- /dev/null +++ b/unit_tests/pydantic_types/test_013_model_validator.py @@ -0,0 +1,284 @@ +"""Test that model_validator decorators work with pydantic-typer.""" + +from pydantic import SecretStr +from typer.testing import CliRunner + +from examples.pydantic_types.example_013_model_validator import ( + ConfigWithDefaults, + DatabaseConfig, + RangeConfig, + ServiceConfig, + UserProfile, +) +from pydantic_typer import Typer + +runner = CliRunner() + + +class TestModelValidatorBefore: + """Test model_validator with mode='before'.""" + + def test_model_validator_before(self): + """Test that model_validator with mode='before' works.""" + app = Typer() + + @app.command() + def main(config: RangeConfig): + # min_value should be normalized to 0 (not -5) + assert config.min_value == 0 + assert config.max_value == 50 + print(f"Range: {config.min_value} to {config.max_value}") + + result = runner.invoke( + app, ["--config.min_value", "-5", "--config.max_value", "50"] + ) + assert result.exit_code == 0 + assert "Range: 0 to 50" in result.stdout + + +class TestModelValidatorAfter: + """Test model_validator with mode='after'.""" + + def test_model_validator_after_success(self): + """Test that model_validator with mode='after' works for valid data.""" + app = Typer() + + @app.command() + def main(config: RangeConfig): + assert config.min_value == 10 + assert config.max_value == 50 + print(f"Range: {config.min_value} to {config.max_value}") + + result = runner.invoke( + app, ["--config.min_value", "10", "--config.max_value", "50"] + ) + assert result.exit_code == 0 + assert "Range: 10 to 50" in result.stdout + + def test_model_validator_after_error(self): + """Test that model_validator with mode='after' catches validation errors.""" + app = Typer() + + @app.command() + def main(config: RangeConfig): + print(f"Range: {config.min_value} to {config.max_value}") + + # min_value >= max_value should fail + result = runner.invoke( + app, ["--config.min_value", "100", "--config.max_value", "50"] + ) + assert result.exit_code != 0 + assert "min_value" in str(result.exception) and "max_value" in str( + result.exception + ) + + +class TestModelValidatorCombined: + """Test combined before and after model validators.""" + + def test_multiple_model_validators(self): + """Test that multiple model validators work together.""" + app = Typer() + + @app.command() + def main(profile: UserProfile): + # Username and email should be lowercased + assert profile.username == "john" + assert profile.email == "john@example.com" + assert profile.age == 25 + print(f"Profile: {profile.username} ({profile.email})") + + result = runner.invoke( + app, + [ + "--profile.username", + "JOHN", + "--profile.email", + "JOHN@EXAMPLE.COM", + "--profile.age", + "25", + ], + ) + assert result.exit_code == 0 + assert "Profile: john (john@example.com)" in result.stdout + + def test_model_validator_error(self): + """Test that model validator validation errors are properly reported.""" + app = Typer() + + @app.command() + def main(profile: UserProfile): + print(f"Profile: {profile.username}") + + # Email username doesn't match profile username + result = runner.invoke( + app, + ["--profile.username", "john", "--profile.email", "jane@example.com"], + ) + assert result.exit_code != 0 + assert "Email username must match profile username" in str(result.exception) + + def test_model_validator_computed_defaults(self): + """Test that model_validator can compute default values.""" + app = Typer() + + @app.command() + def main(config: ConfigWithDefaults): + assert config.base_path == "/var/data" + assert config.cache_path == "/var/data/cache" + assert config.log_path == "/var/data/logs" + print(f"Paths: {config.cache_path}, {config.log_path}") + + result = runner.invoke(app, ["--config.base_path", "/var/data"]) + assert result.exit_code == 0 + assert "/var/data/cache" in result.stdout + assert "/var/data/logs" in result.stdout + + def test_model_validator_with_explicit_paths(self): + """Test that explicit paths override computed defaults.""" + app = Typer() + + @app.command() + def main(config: ConfigWithDefaults): + assert config.base_path == "/var/data" + assert config.cache_path == "/custom/cache" + assert config.log_path == "/var/data/logs" + print(f"Paths: {config.cache_path}, {config.log_path}") + + result = runner.invoke( + app, + [ + "--config.base_path", + "/var/data", + "--config.cache_path", + "/custom/cache", + ], + ) + assert result.exit_code == 0 + assert "/custom/cache" in result.stdout + assert "/var/data/logs" in result.stdout + + +class TestModelValidatorWithPydanticTypes: + """Test model_validator with advanced Pydantic types.""" + + def test_model_validator_with_httpurl(self): + """Test that model_validator works with HttpUrl for normalization.""" + app = Typer() + + @app.command() + def main(config: ServiceConfig): + assert str(config.api_url) == "https://api.example.com/" + assert str(config.webhook_url) == "https://api.example.com/webhook" + assert config.admin_email == "admin@example.com" + print(f"API: {config.api_url}, Webhook: {config.webhook_url}") + + result = runner.invoke( + app, + [ + "--config.api_url", + "api.example.com", + "--config.webhook_url", + "api.example.com/webhook", + "--config.admin_email", + " ADMIN@EXAMPLE.COM ", + ], + ) + assert result.exit_code == 0 + assert "https://api.example.com" in result.stdout + + def test_model_validator_domain_mismatch_error(self): + """Test that model_validator catches domain mismatches.""" + app = Typer() + + @app.command() + def main(config: ServiceConfig): + print(f"API: {config.api_url}") + + result = runner.invoke( + app, + [ + "--config.api_url", + "https://api.example.com", + "--config.webhook_url", + "https://other.example.com/webhook", + "--config.admin_email", + "admin@example.com", + ], + ) + assert result.exit_code != 0 + assert "must be on the same domain" in str(result.exception) + + def test_model_validator_with_secretstr(self): + """Test that model_validator works with SecretStr.""" + app = Typer() + + @app.command() + def main(config: DatabaseConfig): + assert config.host == "localhost" + assert config.username == "dev" + assert isinstance(config.password, SecretStr) + assert config.password.get_secret_value() == "devpass" + print(f"DB: {config.username}@{config.host}:{config.port}") + + result = runner.invoke( + app, + [ + "--config.username", + "dev", + "--config.password", + "devpass", + ], + ) + assert result.exit_code == 0 + assert "dev@localhost:5432" in result.stdout + + def test_model_validator_production_password_error(self): + """Test that model_validator enforces strong passwords for production.""" + app = Typer() + + @app.command() + def main(config: DatabaseConfig): + print(f"DB: {config.username}@{config.host}") + + # Short password for remote host should fail + result = runner.invoke( + app, + [ + "--config.host", + "prod.example.com", + "--config.username", + "admin", + "--config.password", + "short", + ], + ) + assert result.exit_code != 0 + assert "Production database password must be at least 8 characters" in str( + result.exception + ) + + def test_model_validator_production_password_valid(self): + """Test that strong passwords pass validation for production.""" + app = Typer() + + @app.command() + def main(config: DatabaseConfig): + assert config.host == "prod.example.com" + assert config.password.get_secret_value() == "StrongP@ssw0rd!" + print(f"DB: {config.username}@{config.host}") + + result = runner.invoke( + app, + [ + "--config.host", + "prod.example.com", + "--config.username", + "admin", + "--config.password", + "StrongP@ssw0rd!", + ], + ) + assert result.exit_code == 0 + assert "admin@prod.example.com" in result.stdout + diff --git a/unit_tests/pydantic_types/test_014_optional_pydantic_types.py b/unit_tests/pydantic_types/test_014_optional_pydantic_types.py new file mode 100644 index 0000000..f64ba2e --- /dev/null +++ b/unit_tests/pydantic_types/test_014_optional_pydantic_types.py @@ -0,0 +1,123 @@ +"""Tests for Optional fields with Pydantic types in models.""" + +import subprocess +import sys +import tempfile +from pathlib import Path + +import pydantic_typer +from examples.pydantic_types import example_014_optional_pydantic_types as mod +from typer.testing import CliRunner + +runner = CliRunner() + +app = pydantic_typer.Typer() +app.command()(mod.main) + + +def test_help(): + """Test that help displays correctly with proper option names.""" + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + # Check that the options are named correctly (not ---pydantic-config-input-dir) + assert "--config.input_dir" in result.stdout + assert "--config.output_file" in result.stdout + assert "--config.name" in result.stdout + + +def test_with_none_values(): + """Test with default None values.""" + result = runner.invoke(app, ["--config.name", "test"]) + assert result.exit_code == 0 + assert "Input dir: None (type: NoneType)" in result.stdout + assert "Output file: None (type: NoneType)" in result.stdout + assert "Name: test (type: str)" in result.stdout + + +def test_with_directory_path(): + """Test with an actual directory path.""" + with tempfile.TemporaryDirectory() as tmpdir: + result = runner.invoke( + app, ["--config.input_dir", tmpdir, "--config.name", "with_dir"] + ) + assert result.exit_code == 0 + assert f"Input dir: {tmpdir}" in result.stdout + assert "PosixPath" in result.stdout or "WindowsPath" in result.stdout + assert "Output file: None" in result.stdout + + +def test_with_file_path(): + """Test with an actual file path.""" + with tempfile.NamedTemporaryFile(delete=False) as tmpfile: + tmpfile_path = tmpfile.name + try: + result = runner.invoke( + app, + ["--config.output_file", tmpfile_path, "--config.name", "with_file"], + ) + assert result.exit_code == 0 + assert f"Output file: {tmpfile_path}" in result.stdout + assert "PosixPath" in result.stdout or "WindowsPath" in result.stdout + assert "Input dir: None" in result.stdout + finally: + Path(tmpfile_path).unlink() + + +def test_with_both_paths(): + """Test with both directory and file paths.""" + with tempfile.TemporaryDirectory() as tmpdir: + with tempfile.NamedTemporaryFile(delete=False) as tmpfile: + tmpfile_path = tmpfile.name + try: + result = runner.invoke( + app, + [ + "--config.input_dir", + tmpdir, + "--config.output_file", + tmpfile_path, + "--config.name", + "both_paths", + ], + ) + assert result.exit_code == 0 + assert f"Input dir: {tmpdir}" in result.stdout + assert f"Output file: {tmpfile_path}" in result.stdout + assert "Name: both_paths" in result.stdout + finally: + Path(tmpfile_path).unlink() + + +def test_invalid_directory(): + """Test that invalid directory path is rejected.""" + result = runner.invoke( + app, + [ + "--config.input_dir", + "/nonexistent/directory/path", + "--config.name", + "invalid", + ], + ) + assert result.exit_code != 0 + + +def test_invalid_file(): + """Test that invalid file path is rejected.""" + result = runner.invoke( + app, + ["--config.output_file", "/nonexistent/file.txt", "--config.name", "invalid"], + ) + assert result.exit_code != 0 + + +def test_script(): + """Test running the example as a script.""" + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + check=False, + ) + assert "Usage" in result.stdout + assert "--config.input_dir" in result.stdout