|
7 | 7 | import shutil |
8 | 8 | import sys |
9 | 9 | from importlib.util import module_from_spec, spec_from_file_location |
| 10 | +from pathlib import Path |
10 | 11 | from tempfile import mkdtemp |
11 | 12 | from types import ModuleType |
12 | 13 | from typing import Any, Dict, List, Tuple, Type |
13 | 14 | from uuid import uuid4 |
14 | 15 |
|
from pydantic import VERSION, BaseModel, Extra, create_model

# Pydantic major-version flag: v2 reworked generics, model config, and the
# schema-generation API, so several code paths below branch on this.
# (Idiom fix: `True if X else False` collapsed to the boolean expression.)
V2 = VERSION.startswith("2")

if not V2:
    # pydantic v1 only: GenericModel lives in pydantic.generics and is absent
    # both from v2 and from very old v1 releases, hence the guarded import.
    try:
        from pydantic.generics import GenericModel
    except ImportError:
        GenericModel = None
21 | 25 |
|
22 | 26 | logger = logging.getLogger("pydantic2ts") |
23 | 27 |
|
24 | 28 |
|
| 29 | +DEBUG = os.environ.get("DEBUG", False) |
| 30 | + |
| 31 | + |
25 | 32 | def import_module(path: str) -> ModuleType: |
26 | 33 | """ |
27 | 34 | Helper which allows modules to be specified by either dotted path notation or by filepath. |
def is_concrete_pydantic_model(obj) -> bool:
    """
    Return true if an object is a concrete subclass of pydantic's BaseModel.
    'concrete' meaning that it's not a GenericModel.
    """
    # Non-classes and the BaseModel base itself are never concrete models.
    if not inspect.isclass(obj) or obj is BaseModel:
        return False
    if not V2 and GenericModel and issubclass(obj, GenericModel):
        # pydantic v1 generic: concrete only once fully parametrized.
        return bool(obj.__concrete__)
    metadata = getattr(obj, "__pydantic_generic_metadata__", None)
    if V2 and metadata:
        # pydantic v2 generic: concrete when no unbound type parameters remain.
        return not bool(metadata["parameters"])
    return issubclass(obj, BaseModel)
72 | 82 |
|
@@ -141,7 +151,7 @@ def clean_schema(schema: Dict[str, Any]) -> None: |
141 | 151 | del schema["description"] |
142 | 152 |
|
143 | 153 |
|
144 | | -def generate_json_schema(models: List[Type[BaseModel]]) -> str: |
| 154 | +def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str: |
145 | 155 | """ |
146 | 156 | Create a top-level '_Master_' model with references to each of the actual models. |
147 | 157 | Generate the schema for this model, which will include the schemas for all the |
@@ -178,6 +188,43 @@ def generate_json_schema(models: List[Type[BaseModel]]) -> str: |
178 | 188 | m.Config.extra = x |
179 | 189 |
|
180 | 190 |
|
def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
    """
    Create a top-level '_Master_' model with references to each of the actual models.
    Generate the schema for this model, which will include the schemas for all the
    nested models. Then clean up the schema.

    One weird thing we do is we temporarily override the 'extra' setting in models,
    changing it to 'forbid' UNLESS it was explicitly set to 'allow'. This prevents
    '[k: string]: any' from being added to every interface. This change is reverted
    once the schema has been generated.

    :param models: concrete pydantic (v2) model classes to include in the schema.
    :return: the cleaned JSON schema as a pretty-printed JSON string.
    """
    # Remember each model's original 'extra' setting (None when unset) so it
    # can be restored afterwards.
    model_extras = [m.model_config.get("extra") for m in models]

    try:
        for m in models:
            if m.model_config.get("extra") != "allow":
                m.model_config["extra"] = "forbid"

        # The master model simply references every target model, so a single
        # schema generation pass covers all of them (and their dependencies).
        master_model: BaseModel = create_model(
            "_Master_", **{m.__name__: (m, ...) for m in models}
        )
        master_model.model_config["extra"] = "forbid"
        # json_schema_extra callables mutate each generated schema in place;
        # staticmethod() keeps pydantic from binding it as a method.
        master_model.model_config["json_schema_extra"] = staticmethod(clean_schema)

        schema: dict = master_model.model_json_schema(mode="serialization")

        for d in schema.get("$defs", {}).values():
            clean_schema(d)

        return json.dumps(schema, indent=2)

    finally:
        # Restore every model's original 'extra' setting. Bug fix: the previous
        # code only restored values that were explicitly set, so models whose
        # 'extra' was unset (None) were permanently left as 'forbid' — now the
        # temporary key is removed instead.
        for m, x in zip(models, model_extras):
            if x is not None:
                m.model_config["extra"] = x
            else:
                m.model_config.pop("extra", None)
| 227 | + |
181 | 228 | def generate_typescript_defs( |
182 | 229 | module: str, output: str, exclude: Tuple[str] = (), json2ts_cmd: str = "json2ts" |
183 | 230 | ) -> None: |
@@ -205,13 +252,20 @@ def generate_typescript_defs( |
205 | 252 |
|
206 | 253 | logger.info("Generating JSON schema from pydantic models...") |
207 | 254 |
|
208 | | - schema = generate_json_schema(models) |
| 255 | + schema = generate_json_schema_v2(models) if V2 else generate_json_schema_v1(models) |
| 256 | + |
209 | 257 | schema_dir = mkdtemp() |
210 | 258 | schema_file_path = os.path.join(schema_dir, "schema.json") |
211 | 259 |
|
212 | 260 | with open(schema_file_path, "w") as f: |
213 | 261 | f.write(schema) |
214 | 262 |
|
| 263 | + if DEBUG: |
| 264 | + debug_schema_file_path = Path(module).parent / "schema_debug.json" |
| 265 | + # raise ValueError(module) |
| 266 | + with open(debug_schema_file_path, "w") as f: |
| 267 | + f.write(schema) |
| 268 | + |
215 | 269 | logger.info("Converting JSON schema to typescript definitions...") |
216 | 270 |
|
217 | 271 | json2ts_exit_code = os.system( |
|
0 commit comments