|
| 1 | +"""Jinja2 extensions.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import re |
| 6 | +import uuid |
| 7 | +from base64 import b64decode, b64encode |
| 8 | +from collections.abc import Iterator |
| 9 | +from datetime import datetime |
| 10 | +from functools import reduce |
| 11 | +from hashlib import new as new_hash |
| 12 | +from json import dumps as to_json, loads as from_json |
| 13 | +from ntpath import ( |
| 14 | + basename as win_basename, |
| 15 | + dirname as win_dirname, |
| 16 | + splitdrive as win_splitdrive, |
| 17 | +) |
| 18 | +from os.path import expanduser, expandvars, realpath, relpath, splitext |
| 19 | +from pathlib import Path |
| 20 | +from posixpath import basename, dirname |
| 21 | +from random import Random |
| 22 | +from shlex import quote |
| 23 | +from time import gmtime, localtime, strftime |
| 24 | +from typing import ( |
| 25 | + TYPE_CHECKING, |
| 26 | + Any, |
| 27 | + Callable, |
| 28 | + Final, |
| 29 | + Literal, |
| 30 | + Sequence, |
| 31 | + TypeVar, |
| 32 | + overload, |
| 33 | +) |
| 34 | +from warnings import warn |
| 35 | + |
| 36 | +import yaml |
| 37 | +from jinja2 import Environment, Undefined, UndefinedError, pass_environment |
| 38 | +from jinja2.ext import Extension |
| 39 | +from jinja2.filters import do_groupby |
| 40 | + |
| 41 | +from .tools import cast_to_bool |
| 42 | + |
| 43 | +if TYPE_CHECKING: |
| 44 | + from typing_extensions import TypeGuard |
| 45 | + |
| 46 | +_T = TypeVar("_T") |
| 47 | + |
| 48 | +_UUID_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_DNS, "https://github.com/copier-org/copier") |
| 49 | + |
| 50 | + |
| 51 | +def _is_sequence(obj: object) -> TypeGuard[Sequence[Any]]: |
| 52 | + return hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)) |
| 53 | + |
| 54 | + |
| 55 | +def _do_b64decode(value: str) -> str: |
| 56 | + return b64decode(value).decode() |
| 57 | + |
| 58 | + |
| 59 | +def _do_b64encode(value: str) -> str: |
| 60 | + return b64encode(value.encode()).decode() |
| 61 | + |
| 62 | + |
| 63 | +def _do_bool(value: Any) -> bool | None: |
| 64 | + return None if value is None else cast_to_bool(value) |
| 65 | + |
| 66 | + |
| 67 | +def _do_hash(value: str, algorithm: str) -> str: |
| 68 | + hasher = new_hash(algorithm) |
| 69 | + hasher.update(value.encode()) |
| 70 | + return hasher.hexdigest() |
| 71 | + |
| 72 | + |
| 73 | +def _do_sha1(value: str) -> str: |
| 74 | + return _do_hash(value, "sha1") |
| 75 | + |
| 76 | + |
| 77 | +def _do_md5(value: str) -> str: |
| 78 | + return _do_hash(value, "md5") |
| 79 | + |
| 80 | + |
| 81 | +def _do_mandatory(value: _T, msg: str | None = None) -> _T: |
| 82 | + if isinstance(value, Undefined): |
| 83 | + # See https://jinja.palletsprojects.com/en/3.1.x/api/#jinja2.Undefined._undefined_name |
| 84 | + raise UndefinedError( |
| 85 | + msg |
| 86 | + or f'Mandatory variable `{value._undefined_name or "<unknown>"}` is undefined' |
| 87 | + ) |
| 88 | + return value |
| 89 | + |
| 90 | + |
| 91 | +def _do_to_uuid(name: str, namespace: str | uuid.UUID = _UUID_NAMESPACE) -> str: |
| 92 | + if not isinstance(namespace, uuid.UUID): |
| 93 | + namespace = uuid.UUID(namespace) |
| 94 | + return str(uuid.uuid5(namespace, name)) |
| 95 | + |
| 96 | + |
| 97 | +def _do_to_yaml(value: Any, *args: Any, **kwargs: Any) -> str: |
| 98 | + kwargs.setdefault("allow_unicode", True) |
| 99 | + return yaml.dump(value, *args, **kwargs) # type: ignore[no-any-return] |
| 100 | + |
| 101 | + |
| 102 | +def _do_from_yaml(value: str) -> Any: |
| 103 | + return yaml.load(value, Loader=yaml.SafeLoader) |
| 104 | + |
| 105 | + |
| 106 | +def _do_from_yaml_all(value: str) -> Iterator[Any]: |
| 107 | + return yaml.load_all(value, Loader=yaml.SafeLoader) |
| 108 | + |
| 109 | + |
| 110 | +def _do_strftime(format: str, seconds: float | None = None, utc: bool = False) -> str: |
| 111 | + return strftime(format, gmtime(seconds) if utc else localtime(seconds)) |
| 112 | + |
| 113 | + |
| 114 | +def _do_to_datetime(date_string: str, format: str = "%Y-%m-%d %H:%M:%S") -> datetime: |
| 115 | + return datetime.strptime(date_string, format) |
| 116 | + |
| 117 | + |
| 118 | +def _do_ternary(condition: bool | None, true: Any, false: Any, none: Any = None) -> Any: |
| 119 | + if condition is None: |
| 120 | + return none |
| 121 | + return true if condition else false |
| 122 | + |
| 123 | + |
| 124 | +def _do_to_nice_json(value: Any, /, **kwargs: Any) -> str: |
| 125 | + kwargs.setdefault("skipkeys", False) |
| 126 | + kwargs.setdefault("ensure_ascii", True) |
| 127 | + kwargs.setdefault("check_circular", True) |
| 128 | + kwargs.setdefault("allow_nan", True) |
| 129 | + kwargs.setdefault("indent", 4) |
| 130 | + kwargs.setdefault("sort_keys", True) |
| 131 | + return to_json(value, **kwargs) |
| 132 | + |
| 133 | + |
| 134 | +def _do_to_nice_yaml(value: Any, *args: Any, **kwargs: Any) -> str: |
| 135 | + kwargs.setdefault("allow_unicode", True) |
| 136 | + kwargs.setdefault("indent", 4) |
| 137 | + return yaml.dump(value, *args, **kwargs) # type: ignore[no-any-return] |
| 138 | + |
| 139 | + |
| 140 | +def _do_shuffle(seq: Sequence[_T], seed: str | None = None) -> list[_T]: |
| 141 | + seq = list(seq) |
| 142 | + Random(seed).shuffle(seq) |
| 143 | + return seq |
| 144 | + |
| 145 | + |
| 146 | +@overload |
| 147 | +def _do_random(stop: int, start: int, step: int, seed: str | None) -> int: ... |
| 148 | + |
| 149 | + |
| 150 | +@overload |
| 151 | +def _do_random(stop: Sequence[_T], start: None, step: None, seed: str | None) -> _T: ... |
| 152 | + |
| 153 | + |
| 154 | +def _do_random( |
| 155 | + stop: int | Sequence[_T], |
| 156 | + start: int | None = None, |
| 157 | + step: int | None = None, |
| 158 | + seed: str | None = None, |
| 159 | +) -> int | _T: |
| 160 | + rng = Random(seed) |
| 161 | + |
| 162 | + if isinstance(stop, int): |
| 163 | + if start is None: |
| 164 | + start = 0 |
| 165 | + if step is None: |
| 166 | + step = 1 |
| 167 | + return rng.randrange(start, stop, step) |
| 168 | + |
| 169 | + for arg_name, arg_value in [("start", start), ("stop", stop)]: |
| 170 | + if arg_value is None: |
| 171 | + raise TypeError(f'"{arg_name}" can only be used when "stop" is an integer') |
| 172 | + return rng.choice(stop) |
| 173 | + |
| 174 | + |
| 175 | +def _do_flatten( |
| 176 | + seq: Sequence[Any], levels: int | None = None, skip_nulls: bool = True |
| 177 | +) -> Sequence[Any]: |
| 178 | + if levels is not None: |
| 179 | + if levels < 1: |
| 180 | + return seq |
| 181 | + levels -= 1 |
| 182 | + result: list[Any] = [] |
| 183 | + for item in seq: |
| 184 | + if _is_sequence(item): |
| 185 | + result.extend(_do_flatten(item, levels, skip_nulls)) |
| 186 | + elif not skip_nulls or item is not None: |
| 187 | + result.append(item) |
| 188 | + return result |
| 189 | + |
| 190 | + |
| 191 | +def _do_fileglob(pattern: str) -> Sequence[str]: |
| 192 | + return [str(path) for path in Path(".").glob(pattern) if path.is_file()] |
| 193 | + |
| 194 | + |
| 195 | +def _do_random_mac(prefix: str, seed: str | None = None) -> str: |
| 196 | + parts = prefix.lower().strip(":").split(":") |
| 197 | + if len(parts) > 5: |
| 198 | + raise ValueError(f"Invalid MAC address prefix {prefix}: too many parts") |
| 199 | + for part in parts: |
| 200 | + if not re.match(r"[a-f0-9]{2}", part): |
| 201 | + raise ValueError( |
| 202 | + f"Invalid MAC address prefix {prefix}: {part} is not a hexadecimal byte" |
| 203 | + ) |
| 204 | + rng = Random(seed) |
| 205 | + return ":".join( |
| 206 | + parts + [f"{rng.randint(0, 255):02x}" for _ in range(6 - len(parts))] |
| 207 | + ) |
| 208 | + |
| 209 | + |
| 210 | +def _do_regex_escape( |
| 211 | + pattern: str, re_type: Literal["python", "posix_basic"] = "python" |
| 212 | +) -> str: |
| 213 | + if re_type == "python": |
| 214 | + return re.escape(pattern) |
| 215 | + raise NotImplementedError(f"Regex type {re_type} not implemented") |
| 216 | + |
| 217 | + |
| 218 | +def _do_regex_search( |
| 219 | + value: str, |
| 220 | + pattern: str, |
| 221 | + *args: str, |
| 222 | + ignorecase: bool = False, |
| 223 | + multiline: bool = False, |
| 224 | +) -> str | list[str] | None: |
| 225 | + groups: list[str | int] = [] |
| 226 | + for arg in args: |
| 227 | + if match := re.match(r"^\\g<(\S+)>$", arg): |
| 228 | + groups.append(match.group(1)) |
| 229 | + elif match := re.match(r"^\\(\d+)$", arg): |
| 230 | + groups.append(int(match.group(1))) |
| 231 | + else: |
| 232 | + raise ValueError("Invalid backref format") |
| 233 | + |
| 234 | + flags = 0 |
| 235 | + if ignorecase: |
| 236 | + flags |= re.IGNORECASE |
| 237 | + if multiline: |
| 238 | + flags |= re.MULTILINE |
| 239 | + |
| 240 | + return (match := re.search(pattern, value, flags)) and ( |
| 241 | + list(result) if isinstance((result := match.group(*groups)), tuple) else result |
| 242 | + ) |
| 243 | + |
| 244 | + |
| 245 | +def _do_regex_replace( |
| 246 | + value: str, |
| 247 | + pattern: str, |
| 248 | + replacement: str, |
| 249 | + *, |
| 250 | + ignorecase: bool = False, |
| 251 | +) -> str: |
| 252 | + return re.sub(pattern, replacement, value, flags=re.I if ignorecase else 0) |
| 253 | + |
| 254 | + |
| 255 | +def _do_regex_findall( |
| 256 | + value: str, |
| 257 | + pattern: str, |
| 258 | + *, |
| 259 | + ignorecase: bool = False, |
| 260 | + multiline: bool = False, |
| 261 | +) -> list[str]: |
| 262 | + flags = 0 |
| 263 | + if ignorecase: |
| 264 | + flags |= re.IGNORECASE |
| 265 | + if multiline: |
| 266 | + flags |= re.MULTILINE |
| 267 | + return re.findall(pattern, value, flags) |
| 268 | + |
| 269 | + |
| 270 | +def _do_type_debug(value: object) -> str: |
| 271 | + return value.__class__.__name__ |
| 272 | + |
| 273 | + |
| 274 | +@pass_environment |
| 275 | +def _do_extract( |
| 276 | + environment: Environment, |
| 277 | + key: Any, |
| 278 | + container: Any, |
| 279 | + *, |
| 280 | + morekeys: Any | Sequence[Any] | None = None, |
| 281 | +) -> Any | Undefined: |
| 282 | + keys: list[Any] |
| 283 | + if morekeys is None: |
| 284 | + keys = [key] |
| 285 | + elif _is_sequence(morekeys): |
| 286 | + keys = [key, *morekeys] |
| 287 | + else: |
| 288 | + keys = [key, morekeys] |
| 289 | + return reduce(environment.getitem, keys, container) |
| 290 | + |
| 291 | + |
| 292 | +class CopierExtension(Extension): |
| 293 | + """Jinja2 extension for Copier.""" |
| 294 | + |
| 295 | + # NOTE: mypy disallows `Callable[[Any, ...], Any]` |
| 296 | + _filters: Final[dict[str, Callable[..., Any]]] = { |
| 297 | + "ans_groupby": do_groupby, |
| 298 | + "ans_random": _do_random, |
| 299 | + "b64decode": _do_b64decode, |
| 300 | + "b64encode": _do_b64encode, |
| 301 | + "basename": basename, |
| 302 | + "bool": _do_bool, |
| 303 | + "checksum": _do_sha1, |
| 304 | + "dirname": dirname, |
| 305 | + "expanduser": expanduser, |
| 306 | + "expandvars": expandvars, |
| 307 | + "extract": _do_extract, |
| 308 | + "fileglob": _do_fileglob, |
| 309 | + "flatten": _do_flatten, |
| 310 | + "from_json": from_json, |
| 311 | + "from_yaml": _do_from_yaml, |
| 312 | + "from_yaml_all": _do_from_yaml_all, |
| 313 | + "hash": _do_hash, |
| 314 | + "mandatory": _do_mandatory, |
| 315 | + "md5": _do_md5, |
| 316 | + "quote": quote, |
| 317 | + "random_mac": _do_random_mac, |
| 318 | + "realpath": realpath, |
| 319 | + "regex_escape": _do_regex_escape, |
| 320 | + "regex_findall": _do_regex_findall, |
| 321 | + "regex_replace": _do_regex_replace, |
| 322 | + "regex_search": _do_regex_search, |
| 323 | + "relpath": relpath, |
| 324 | + "sha1": _do_sha1, |
| 325 | + "shuffle": _do_shuffle, |
| 326 | + "splitext": splitext, |
| 327 | + "strftime": _do_strftime, |
| 328 | + "ternary": _do_ternary, |
| 329 | + "to_datetime": _do_to_datetime, |
| 330 | + "to_json": to_json, |
| 331 | + "to_nice_json": _do_to_nice_json, |
| 332 | + "to_nice_yaml": _do_to_nice_yaml, |
| 333 | + "to_uuid": _do_to_uuid, |
| 334 | + "to_yaml": _do_to_yaml, |
| 335 | + "type_debug": _do_type_debug, |
| 336 | + "win_basename": win_basename, |
| 337 | + "win_dirname": win_dirname, |
| 338 | + "win_splitdrive": win_splitdrive, |
| 339 | + } |
| 340 | + |
| 341 | + def __init__(self, environment: Environment) -> None: |
| 342 | + super().__init__(environment) |
| 343 | + for k, v in self._filters.items(): |
| 344 | + if k in environment.filters: |
| 345 | + warn( |
| 346 | + f'A filter named "{k}" already exists in the Jinja2 environment', |
| 347 | + category=RuntimeWarning, |
| 348 | + stacklevel=2, |
| 349 | + ) |
| 350 | + else: |
| 351 | + environment.filters[k] = v |
0 commit comments