-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcaid_contract.py
More file actions
292 lines (239 loc) · 12.1 KB
/
Copy pathcaid_contract.py
File metadata and controls
292 lines (239 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
from __future__ import annotations
import json
from copy import deepcopy
from pathlib import Path
from typing import Any
# Version of the CAID JSON contract accepted and produced by this module.
SCHEMA_VERSION = 1

# Top-level keys that every CAID artifact must carry.
ARTIFACT_REQUIRED_KEYS = frozenset(
    (
        "schema_version",
        "artifact_id",
        "producer",
        "created_at",
        "feature_tree",
        "parameters",
        "simulation_tags",
    )
)

# Top-level keys that every CAID parameter patch must carry.
PATCH_REQUIRED_KEYS = frozenset(
    ("schema_version", "artifact_id", "source", "parameter_patches")
)

# Accepted values for a simulation tag's "kind" field.
SIMULATION_TAG_KINDS = frozenset(("body", "joint", "geom", "site", "parameter"))

# Python types allowed for design-parameter values and patch values
# (bool is listed explicitly even though it is an int subclass).
PARAMETER_VALUE_TYPES = (bool, int, float, str)
# Public API of this module, in a fixed order: schema constants first,
# then the error type, then the public helper functions (alphabetical).
__all__ = [
    # Schema constants.
    "SCHEMA_VERSION",
    "ARTIFACT_REQUIRED_KEYS",
    "PATCH_REQUIRED_KEYS",
    # Error type raised on contract violations.
    "ContractError",
    # Public helpers.
    "apply_parameter_patch",
    "apply_patch_to_simulation_params",
    "get_parameter",
    "load_artifact",
    "make_parameter_patch",
    "make_patch_from_identification",
    "resolve_parameter_name",
    "simulation_target_for_parameter",
    "write_json",
]
class ContractError(ValueError):
    """Signal that a CAID artifact or patch breaks the integration contract.

    Derives from ``ValueError``, so code that already catches ``ValueError``
    also sees contract violations.
    """
def load_artifact(source: str | Path | dict[str, Any]) -> dict[str, Any]:
    """Load a CAID artifact and validate it against the contract.

    Args:
        source: Path (``str`` or ``Path``) to a UTF-8 JSON file, or an
            already-parsed dict. A dict is deep-copied, so the caller's
            object is never returned or mutated.

    Returns:
        The validated artifact dictionary.

    Raises:
        ContractError: If the source type is unsupported, the JSON payload is
            not an object, or the artifact violates the contract. File and
            JSON-decoding errors propagate unchanged.
    """
    loaded = _load_json(source)
    _require_artifact(loaded)
    return loaded
def get_parameter(artifact: dict[str, Any], name: str) -> dict[str, Any]:
    """Return the design-parameter record stored under *name*.

    The whole artifact is validated before the lookup. The returned dict is
    the artifact's own record (not a copy).

    Raises:
        ContractError: If the artifact is invalid, *name* is not a known
            parameter, or its record is not a dict with a ``value`` key.
    """
    _require_artifact(artifact)
    try:
        record = artifact["parameters"][name]
    except KeyError as missing:
        raise ContractError(f"Unknown design parameter '{name}'.") from missing
    well_formed = isinstance(record, dict) and "value" in record
    if not well_formed:
        raise ContractError(f"Design parameter '{name}' is malformed.")
    return record
def resolve_parameter_name(artifact: dict[str, Any], name_or_target: str) -> str:
    """Map a design-parameter name or a simulation target to a parameter name.

    Lookup order:
      1. If *name_or_target* is itself a key of ``artifact["parameters"]``,
         it is returned unchanged.
      2. Otherwise, the first ``kind == "parameter"`` simulation tag whose
         ``target`` equals *name_or_target* and whose ``name`` is a known
         design parameter supplies the result.

    Raises:
        ContractError: If the artifact is invalid, *name_or_target* is not a
            string, or nothing matches.
    """
    _require_artifact(artifact)
    # Check the type before any dict membership test. An unhashable value
    # (e.g. a list decoded from an identification JSON) would otherwise escape
    # as a bare TypeError instead of a ContractError.
    if not isinstance(name_or_target, str):
        raise ContractError(
            "Design parameter or simulation target must be a string, "
            f"got {type(name_or_target).__name__}."
        )
    parameters = artifact["parameters"]
    if name_or_target in parameters:
        return name_or_target
    for tag in artifact.get("simulation_tags", []):
        if tag.get("kind") == "parameter" and tag.get("target") == name_or_target:
            name = tag.get("name")
            if name in parameters:
                return name
    raise ContractError(f"Unknown design parameter or simulation target '{name_or_target}'.")
def simulation_target_for_parameter(artifact: dict[str, Any], name: str) -> str:
    """Return the simulation-parameter key driven by design parameter *name*.

    The first ``kind == "parameter"`` simulation tag whose ``name`` matches
    decides the target. If no such tag exists, *name* itself is returned.

    Raises:
        ContractError: If the artifact is invalid, *name* is unknown, or the
            matching tag's target is not a non-empty string.
    """
    _require_artifact(artifact)
    get_parameter(artifact, name)  # raises ContractError if *name* is unknown
    matching_tags = (
        tag
        for tag in artifact.get("simulation_tags", [])
        if tag.get("kind") == "parameter" and tag.get("name") == name
    )
    first = next(matching_tags, None)
    if first is None:
        return name
    target = first.get("target")
    if not isinstance(target, str) or not target:
        raise ContractError(f"Simulation tag for '{name}' is missing a target.")
    return target
def make_parameter_patch(
    artifact: dict[str, Any],
    name: str,
    value: bool | int | float | str,
    *,
    reason: str | None = None,
    source: str = "simcorrect",
) -> dict[str, Any]:
    """Build a single-parameter CAID patch against *artifact*.

    The parameter's current value is recorded as ``old_value``. The appliers
    check it via ``_require_current_value`` and refuse the patch if the
    artifact's value has changed since the patch was made.

    Args:
        artifact: A CAID artifact (validated here).
        name: Design-parameter name to change.
        value: New value; must be one of the contract's value types.
        reason: Optional human-readable justification (string or None).
        source: Non-empty identifier of the patch producer.

    Returns:
        A patch dictionary that satisfies the patch contract.

    Raises:
        ContractError: If *name* is unknown, or the resulting patch would be
            rejected by the contract (unsupported *value* type, non-string
            *reason*, empty *source*).
    """
    parameter = get_parameter(artifact, name)
    patch = {
        "schema_version": SCHEMA_VERSION,
        "artifact_id": artifact["artifact_id"],
        "source": source,
        "parameter_patches": [
            {
                "name": name,
                "old_value": parameter["value"],
                "value": value,
                "reason": reason,
            }
        ],
    }
    # Validate on construction so an invalid patch is reported at the call
    # site, not later when it is applied or after write_json has persisted it.
    _require_patch(patch)
    return patch
def make_patch_from_identification(
    artifact: dict[str, Any],
    identification: dict[str, Any],
    *,
    source: str = "simcorrect",
) -> dict[str, Any]:
    """Turn a parameter-identification result into a CAID patch.

    *identification* must provide ``identified_parameter`` (a design-parameter
    name or a simulation target, resolved with ``resolve_parameter_name``) and
    ``proposed_value``. The optional ``method`` key only feeds the patch's
    reason text (default ``"parameter_identification"``).

    Raises:
        ContractError: If a required key is missing, the parameter cannot be
            resolved, or the patch cannot be built.
    """
    required = ("identified_parameter", "proposed_value")
    absent = [key for key in required if key not in identification]
    if absent:
        raise ContractError(f"Identification result missing required key(s): {', '.join(absent)}.")
    identified = identification["identified_parameter"]
    resolved_name = resolve_parameter_name(artifact, identified)
    method = identification.get("method", "parameter_identification")
    return make_parameter_patch(
        artifact,
        resolved_name,
        identification["proposed_value"],
        reason=f"{method} identified {identified}.",
        source=source,
    )
def apply_parameter_patch(artifact: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
    """Return a deep copy of *artifact* with *patch* applied.

    The input artifact is not modified. Items are applied in order, and each
    ``old_value`` is checked against the value left by earlier items.

    Raises:
        ContractError: If the artifact or patch is invalid, the patch targets
            a different ``artifact_id``, a parameter name is unknown, or an
            ``old_value`` does not match.
    """
    _require_artifact(artifact)
    _require_patch(patch)
    _require_patch_targets_artifact(artifact, patch)
    patched = deepcopy(artifact)
    for entry in patch["parameter_patches"]:
        record = get_parameter(patched, entry["name"])
        _require_current_value(record, entry)
        record["value"] = entry["value"]
    return patched
def apply_patch_to_simulation_params(
    artifact: dict[str, Any],
    patch: dict[str, Any],
    params: dict[str, Any],
) -> dict[str, Any]:
    """Return a shallow copy of simulation *params* with *patch* applied.

    Each patched design parameter is mapped to its simulation key with
    ``simulation_target_for_parameter``. That key must already exist in
    *params*. Neither *params* nor *artifact* is modified.

    ``old_value`` checks use the same sequential rule as
    ``apply_parameter_patch``. Each item is compared with the value left by
    earlier items of this patch, falling back to the artifact's stored value.

    Raises:
        ContractError: If the artifact or patch is invalid, the patch targets
            a different ``artifact_id``, a parameter is unknown, an
            ``old_value`` does not match, or a target is missing from *params*.
    """
    _require_artifact(artifact)
    _require_patch(patch)
    _require_patch_targets_artifact(artifact, patch)
    updated = dict(params)
    # Design-parameter values after the items processed so far. Without this,
    # a patch that changes one parameter twice (a->b, then b->c) passes
    # apply_parameter_patch but would fail the second old_value check here.
    current_values: dict[str, Any] = {}
    for item in patch["parameter_patches"]:
        name = item["name"]
        parameter = get_parameter(artifact, name)
        current = current_values.get(name, parameter["value"])
        _require_current_value({"value": current}, item)
        target = simulation_target_for_parameter(artifact, name)
        if target not in updated:
            raise ContractError(f"Simulation parameter '{target}' is not present in current params.")
        updated[target] = item["value"]
        current_values[name] = item["value"]
    return updated
def write_json(payload: dict[str, Any], path: str | Path) -> None:
    """Write *payload* to *path* as UTF-8 JSON.

    Keys are sorted and indented by two spaces, and the file ends with a
    newline, so repeated writes of the same payload are byte-identical.
    Any existing file at *path* is overwritten.
    """
    text = json.dumps(payload, sort_keys=True, indent=2)
    Path(path).write_text(f"{text}\n", encoding="utf-8")
def _load_json(source: str | Path | dict[str, Any]) -> dict[str, Any]:
if isinstance(source, dict):
return deepcopy(source)
if not isinstance(source, (str, Path)):
raise ContractError("CAID JSON source must be a path or object.")
payload = json.loads(Path(source).read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise ContractError("CAID JSON payload must be an object.")
return payload
def _require_version(payload: dict[str, Any]) -> None:
    """Raise ContractError unless ``payload["schema_version"]`` is SCHEMA_VERSION.

    A JSON ``true`` must not pass as version 1. Python's ``bool`` is an
    ``int`` subclass and ``True == 1``, so a plain equality test would
    accept it. Booleans are therefore rejected explicitly.

    Raises:
        ContractError: If the version is missing, a bool, or a different value.
    """
    version = payload.get("schema_version")
    if isinstance(version, bool) or version != SCHEMA_VERSION:
        raise ContractError(f"Unsupported CAID schema_version: {version!r}.")
def _require_artifact(payload: dict[str, Any]) -> None:
    """Validate *payload* against the CAID artifact contract.

    Checks run in this order and stop at the first violation:
      1. payload is a dict carrying every key in ARTIFACT_REQUIRED_KEYS;
      2. schema_version is supported;
      3. artifact_id is a non-empty string;
      4. producer and feature_tree are well-formed;
      5. parameters is a dict of well-formed parameter records;
      6. simulation_tags is a list of well-formed tags.

    Raises:
        ContractError: Describing the first violation found.
    """
    _require_object(payload, "CAID artifact")
    _require_keys(payload, ARTIFACT_REQUIRED_KEYS, "CAID artifact")
    _require_version(payload)

    artifact_id = payload.get("artifact_id")
    if not (isinstance(artifact_id, str) and artifact_id):
        raise ContractError("CAID artifact must contain a non-empty artifact_id.")

    _require_producer(payload["producer"])
    _require_feature_tree(payload["feature_tree"])

    parameters = payload.get("parameters")
    if not isinstance(parameters, dict):
        raise ContractError("CAID artifact must contain a parameters object.")
    for param_name, param_record in parameters.items():
        _require_parameter(param_name, param_record)

    tags = payload["simulation_tags"]
    if not isinstance(tags, list):
        raise ContractError("CAID artifact simulation_tags must be a list.")
    for tag in tags:
        _require_simulation_tag(tag)
def _require_patch(payload: dict[str, Any]) -> None:
    """Validate *payload* against the CAID patch contract.

    The payload must be a dict with every key in PATCH_REQUIRED_KEYS, a
    supported schema_version, non-empty string ``artifact_id`` and
    ``source``, and a non-empty ``parameter_patches`` list of valid items.

    Raises:
        ContractError: Describing the first violation found.
    """
    _require_object(payload, "CAID patch")
    _require_keys(payload, PATCH_REQUIRED_KEYS, "CAID patch")
    _require_version(payload)

    string_fields = (
        ("artifact_id", "CAID patch must contain a non-empty artifact_id."),
        ("source", "CAID patch must contain a non-empty source."),
    )
    for field, message in string_fields:
        field_value = payload.get(field)
        if not isinstance(field_value, str) or not field_value:
            raise ContractError(message)

    items = payload.get("parameter_patches")
    if not (isinstance(items, list) and items):
        raise ContractError("CAID patch must contain at least one parameter patch.")
    for entry in items:
        _require_patch_item(entry)
def _require_patch_item(item: Any) -> None:
    """Validate one entry of a patch's ``parameter_patches`` list.

    An item must be a dict with a non-empty string ``name`` and a ``value``
    of one of PARAMETER_VALUE_TYPES. ``old_value`` (same types) and
    ``reason`` (string) are optional, and either may be None.

    Raises:
        ContractError: Describing the first violation found.
    """
    if not isinstance(item, dict):
        raise ContractError("Each parameter patch must be an object.")
    name = item.get("name")
    if not (isinstance(name, str) and name):
        raise ContractError("Each parameter patch must contain a non-empty name.")
    if "value" not in item:
        raise ContractError(f"Parameter patch '{item.get('name')}' is missing value.")
    if not isinstance(item["value"], PARAMETER_VALUE_TYPES):
        raise ContractError(f"Parameter patch '{item['name']}' has unsupported value type.")
    old_value = item.get("old_value")
    if old_value is not None and not isinstance(old_value, PARAMETER_VALUE_TYPES):
        raise ContractError(f"Parameter patch '{item['name']}' has unsupported old_value type.")
    reason = item.get("reason")
    if reason is not None and not isinstance(reason, str):
        raise ContractError(f"Parameter patch '{item['name']}' reason must be a string when present.")
def _require_patch_targets_artifact(artifact: dict[str, Any], patch: dict[str, Any]) -> None:
if patch.get("artifact_id") != artifact.get("artifact_id"):
raise ContractError("Patch artifact_id does not match artifact.")
def _require_current_value(parameter: dict[str, Any], item: dict[str, Any]) -> None:
if item.get("old_value") is not None and item["old_value"] != parameter["value"]:
raise ContractError(
f"Patch for '{item['name']}' expected old value {item['old_value']!r}, "
f"but artifact has {parameter['value']!r}."
)
def _require_object(payload: Any, label: str) -> None:
if not isinstance(payload, dict):
raise ContractError(f"{label} must be an object.")
def _require_keys(payload: dict[str, Any], required: frozenset[str], label: str) -> None:
missing = sorted(required - payload.keys())
if missing:
raise ContractError(f"{label} missing required key(s): {', '.join(missing)}.")
def _require_producer(producer: Any) -> None:
if not isinstance(producer, dict):
raise ContractError("CAID artifact producer must be an object.")
if not isinstance(producer.get("name"), str) or not producer["name"]:
raise ContractError("CAID artifact producer must contain a non-empty name.")
if not isinstance(producer.get("version"), str) or not producer["version"]:
raise ContractError("CAID artifact producer must contain a non-empty version.")
def _require_feature_tree(feature_tree: Any) -> None:
if not isinstance(feature_tree, dict):
raise ContractError("CAID artifact feature_tree must be an object.")
if not isinstance(feature_tree.get("root_id"), str) or not feature_tree["root_id"]:
raise ContractError("CAID artifact feature_tree must contain a non-empty root_id.")
if not isinstance(feature_tree.get("nodes"), dict):
raise ContractError("CAID artifact feature_tree must contain a nodes object.")
def _require_parameter(name: Any, parameter: Any) -> None:
    """Validate one entry of an artifact's ``parameters`` mapping.

    *name* (the mapping key) must be a non-empty string equal to the record's
    own ``name`` field. The record must be a dict with a ``value`` of one of
    PARAMETER_VALUE_TYPES. ``unit``, ``role`` and ``feature_id`` are optional
    strings and may also be None.

    Raises:
        ContractError: Describing the first violation found.
    """
    if not (isinstance(name, str) and name):
        raise ContractError("CAID artifact parameter keys must be non-empty strings.")
    if not isinstance(parameter, dict):
        raise ContractError(f"Design parameter '{name}' must be an object.")
    declared_name = parameter.get("name")
    if declared_name != name:
        raise ContractError(f"Design parameter key '{name}' does not match parameter name '{parameter.get('name')}'.")
    if "value" not in parameter:
        raise ContractError(f"Design parameter '{name}' is missing value.")
    if not isinstance(parameter["value"], PARAMETER_VALUE_TYPES):
        raise ContractError(f"Design parameter '{name}' has unsupported value type.")
    for field in ("unit", "role", "feature_id"):
        field_value = parameter.get(field)
        if field_value is not None and not isinstance(field_value, str):
            raise ContractError(f"Design parameter '{name}' field '{field}' must be a string when present.")
def _require_simulation_tag(tag: Any) -> None:
    """Validate one entry of an artifact's ``simulation_tags`` list.

    A tag must be a dict with a non-empty string ``name``, a ``kind`` drawn
    from SIMULATION_TAG_KINDS and a non-empty string ``target``. The optional
    ``metadata`` must be a dict when present.

    Raises:
        ContractError: Describing the first violation found.
    """
    if not isinstance(tag, dict):
        raise ContractError("CAID artifact simulation tags must be objects.")
    if not isinstance(tag.get("name"), str) or not tag["name"]:
        raise ContractError("CAID artifact simulation tag must contain a non-empty name.")
    kind = tag.get("kind")
    # Check the type before the frozenset lookup. An unhashable kind (e.g. a
    # JSON list or object) would otherwise raise TypeError from `in` instead
    # of a ContractError.
    if not isinstance(kind, str) or kind not in SIMULATION_TAG_KINDS:
        raise ContractError(f"CAID artifact simulation tag has unsupported kind: {kind!r}.")
    if not isinstance(tag.get("target"), str) or not tag["target"]:
        raise ContractError("CAID artifact simulation tag must contain a non-empty target.")
    if "metadata" in tag and not isinstance(tag["metadata"], dict):
        raise ContractError("CAID artifact simulation tag metadata must be an object when present.")