Skip to content

Commit c45cf12

Browse files
committed
refactor: drop jinja2-ansible-filters and reimplement Jinja2 filters
1 parent a4d38b1 commit c45cf12

7 files changed

+1165
-40
lines changed

copier/jinja_ext.py

+351
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
"""Jinja2 extensions."""
2+
3+
from __future__ import annotations
4+
5+
import re
6+
import uuid
7+
from base64 import b64decode, b64encode
8+
from collections.abc import Iterator
9+
from datetime import datetime
10+
from functools import reduce
11+
from hashlib import new as new_hash
12+
from json import dumps as to_json, loads as from_json
13+
from ntpath import (
14+
basename as win_basename,
15+
dirname as win_dirname,
16+
splitdrive as win_splitdrive,
17+
)
18+
from os.path import expanduser, expandvars, realpath, relpath, splitext
19+
from pathlib import Path
20+
from posixpath import basename, dirname
21+
from random import Random
22+
from shlex import quote
23+
from time import gmtime, localtime, strftime
24+
from typing import (
25+
TYPE_CHECKING,
26+
Any,
27+
Callable,
28+
Final,
29+
Literal,
30+
Sequence,
31+
TypeVar,
32+
overload,
33+
)
34+
from warnings import warn
35+
36+
import yaml
37+
from jinja2 import Environment, Undefined, UndefinedError, pass_environment
38+
from jinja2.ext import Extension
39+
from jinja2.filters import do_groupby
40+
41+
from .tools import cast_to_bool
42+
43+
if TYPE_CHECKING:
44+
from typing_extensions import TypeGuard
45+
46+
_T = TypeVar("_T")
47+
48+
_UUID_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_DNS, "https://github.com/copier-org/copier")
49+
50+
51+
def _is_sequence(obj: object) -> TypeGuard[Sequence[Any]]:
52+
return hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes))
53+
54+
55+
def _do_b64decode(value: str) -> str:
56+
return b64decode(value).decode()
57+
58+
59+
def _do_b64encode(value: str) -> str:
60+
return b64encode(value.encode()).decode()
61+
62+
63+
def _do_bool(value: Any) -> bool | None:
64+
return None if value is None else cast_to_bool(value)
65+
66+
67+
def _do_hash(value: str, algorithm: str) -> str:
68+
hasher = new_hash(algorithm)
69+
hasher.update(value.encode())
70+
return hasher.hexdigest()
71+
72+
73+
def _do_sha1(value: str) -> str:
74+
return _do_hash(value, "sha1")
75+
76+
77+
def _do_md5(value: str) -> str:
78+
return _do_hash(value, "md5")
79+
80+
81+
def _do_mandatory(value: _T, msg: str | None = None) -> _T:
82+
if isinstance(value, Undefined):
83+
# See https://jinja.palletsprojects.com/en/3.1.x/api/#jinja2.Undefined._undefined_name
84+
raise UndefinedError(
85+
msg
86+
or f'Mandatory variable `{value._undefined_name or "<unknown>"}` is undefined'
87+
)
88+
return value
89+
90+
91+
def _do_to_uuid(name: str, namespace: str | uuid.UUID = _UUID_NAMESPACE) -> str:
92+
if not isinstance(namespace, uuid.UUID):
93+
namespace = uuid.UUID(namespace)
94+
return str(uuid.uuid5(namespace, name))
95+
96+
97+
def _do_to_yaml(value: Any, *args: Any, **kwargs: Any) -> str:
98+
kwargs.setdefault("allow_unicode", True)
99+
return yaml.dump(value, *args, **kwargs) # type: ignore[no-any-return]
100+
101+
102+
def _do_from_yaml(value: str) -> Any:
103+
return yaml.load(value, Loader=yaml.SafeLoader)
104+
105+
106+
def _do_from_yaml_all(value: str) -> Iterator[Any]:
107+
return yaml.load_all(value, Loader=yaml.SafeLoader)
108+
109+
110+
def _do_strftime(format: str, seconds: float | None = None, utc: bool = False) -> str:
111+
return strftime(format, gmtime(seconds) if utc else localtime(seconds))
112+
113+
114+
def _do_to_datetime(date_string: str, format: str = "%Y-%m-%d %H:%M:%S") -> datetime:
115+
return datetime.strptime(date_string, format)
116+
117+
118+
def _do_ternary(condition: bool | None, true: Any, false: Any, none: Any = None) -> Any:
119+
if condition is None:
120+
return none
121+
return true if condition else false
122+
123+
124+
def _do_to_nice_json(value: Any, /, **kwargs: Any) -> str:
125+
kwargs.setdefault("skipkeys", False)
126+
kwargs.setdefault("ensure_ascii", True)
127+
kwargs.setdefault("check_circular", True)
128+
kwargs.setdefault("allow_nan", True)
129+
kwargs.setdefault("indent", 4)
130+
kwargs.setdefault("sort_keys", True)
131+
return to_json(value, **kwargs)
132+
133+
134+
def _do_to_nice_yaml(value: Any, *args: Any, **kwargs: Any) -> str:
135+
kwargs.setdefault("allow_unicode", True)
136+
kwargs.setdefault("indent", 4)
137+
return yaml.dump(value, *args, **kwargs) # type: ignore[no-any-return]
138+
139+
140+
def _do_shuffle(seq: Sequence[_T], seed: str | None = None) -> list[_T]:
141+
seq = list(seq)
142+
Random(seed).shuffle(seq)
143+
return seq
144+
145+
146+
@overload
147+
def _do_random(stop: int, start: int, step: int, seed: str | None) -> int: ...
148+
149+
150+
@overload
151+
def _do_random(stop: Sequence[_T], start: None, step: None, seed: str | None) -> _T: ...
152+
153+
154+
def _do_random(
155+
stop: int | Sequence[_T],
156+
start: int | None = None,
157+
step: int | None = None,
158+
seed: str | None = None,
159+
) -> int | _T:
160+
rng = Random(seed)
161+
162+
if isinstance(stop, int):
163+
if start is None:
164+
start = 0
165+
if step is None:
166+
step = 1
167+
return rng.randrange(start, stop, step)
168+
169+
for arg_name, arg_value in [("start", start), ("stop", stop)]:
170+
if arg_value is None:
171+
raise TypeError(f'"{arg_name}" can only be used when "stop" is an integer')
172+
return rng.choice(stop)
173+
174+
175+
def _do_flatten(
176+
seq: Sequence[Any], levels: int | None = None, skip_nulls: bool = True
177+
) -> Sequence[Any]:
178+
if levels is not None:
179+
if levels < 1:
180+
return seq
181+
levels -= 1
182+
result: list[Any] = []
183+
for item in seq:
184+
if _is_sequence(item):
185+
result.extend(_do_flatten(item, levels, skip_nulls))
186+
elif not skip_nulls or item is not None:
187+
result.append(item)
188+
return result
189+
190+
191+
def _do_fileglob(pattern: str) -> Sequence[str]:
192+
return [str(path) for path in Path(".").glob(pattern) if path.is_file()]
193+
194+
195+
def _do_random_mac(prefix: str, seed: str | None = None) -> str:
196+
parts = prefix.lower().strip(":").split(":")
197+
if len(parts) > 5:
198+
raise ValueError(f"Invalid MAC address prefix {prefix}: too many parts")
199+
for part in parts:
200+
if not re.match(r"[a-f0-9]{2}", part):
201+
raise ValueError(
202+
f"Invalid MAC address prefix {prefix}: {part} is not a hexadecimal byte"
203+
)
204+
rng = Random(seed)
205+
return ":".join(
206+
parts + [f"{rng.randint(0, 255):02x}" for _ in range(6 - len(parts))]
207+
)
208+
209+
210+
def _do_regex_escape(
211+
pattern: str, re_type: Literal["python", "posix_basic"] = "python"
212+
) -> str:
213+
if re_type == "python":
214+
return re.escape(pattern)
215+
raise NotImplementedError(f"Regex type {re_type} not implemented")
216+
217+
218+
def _do_regex_search(
219+
value: str,
220+
pattern: str,
221+
*args: str,
222+
ignorecase: bool = False,
223+
multiline: bool = False,
224+
) -> str | list[str] | None:
225+
groups: list[str | int] = []
226+
for arg in args:
227+
if match := re.match(r"^\\g<(\S+)>$", arg):
228+
groups.append(match.group(1))
229+
elif match := re.match(r"^\\(\d+)$", arg):
230+
groups.append(int(match.group(1)))
231+
else:
232+
raise ValueError("Invalid backref format")
233+
234+
flags = 0
235+
if ignorecase:
236+
flags |= re.IGNORECASE
237+
if multiline:
238+
flags |= re.MULTILINE
239+
240+
return (match := re.search(pattern, value, flags)) and (
241+
list(result) if isinstance((result := match.group(*groups)), tuple) else result
242+
)
243+
244+
245+
def _do_regex_replace(
246+
value: str,
247+
pattern: str,
248+
replacement: str,
249+
*,
250+
ignorecase: bool = False,
251+
) -> str:
252+
return re.sub(pattern, replacement, value, flags=re.I if ignorecase else 0)
253+
254+
255+
def _do_regex_findall(
256+
value: str,
257+
pattern: str,
258+
*,
259+
ignorecase: bool = False,
260+
multiline: bool = False,
261+
) -> list[str]:
262+
flags = 0
263+
if ignorecase:
264+
flags |= re.IGNORECASE
265+
if multiline:
266+
flags |= re.MULTILINE
267+
return re.findall(pattern, value, flags)
268+
269+
270+
def _do_type_debug(value: object) -> str:
271+
return value.__class__.__name__
272+
273+
274+
@pass_environment
275+
def _do_extract(
276+
environment: Environment,
277+
key: Any,
278+
container: Any,
279+
*,
280+
morekeys: Any | Sequence[Any] | None = None,
281+
) -> Any | Undefined:
282+
keys: list[Any]
283+
if morekeys is None:
284+
keys = [key]
285+
elif _is_sequence(morekeys):
286+
keys = [key, *morekeys]
287+
else:
288+
keys = [key, morekeys]
289+
return reduce(environment.getitem, keys, container)
290+
291+
292+
class CopierExtension(Extension):
293+
"""Jinja2 extension for Copier."""
294+
295+
# NOTE: mypy disallows `Callable[[Any, ...], Any]`
296+
_filters: Final[dict[str, Callable[..., Any]]] = {
297+
"ans_groupby": do_groupby,
298+
"ans_random": _do_random,
299+
"b64decode": _do_b64decode,
300+
"b64encode": _do_b64encode,
301+
"basename": basename,
302+
"bool": _do_bool,
303+
"checksum": _do_sha1,
304+
"dirname": dirname,
305+
"expanduser": expanduser,
306+
"expandvars": expandvars,
307+
"extract": _do_extract,
308+
"fileglob": _do_fileglob,
309+
"flatten": _do_flatten,
310+
"from_json": from_json,
311+
"from_yaml": _do_from_yaml,
312+
"from_yaml_all": _do_from_yaml_all,
313+
"hash": _do_hash,
314+
"mandatory": _do_mandatory,
315+
"md5": _do_md5,
316+
"quote": quote,
317+
"random_mac": _do_random_mac,
318+
"realpath": realpath,
319+
"regex_escape": _do_regex_escape,
320+
"regex_findall": _do_regex_findall,
321+
"regex_replace": _do_regex_replace,
322+
"regex_search": _do_regex_search,
323+
"relpath": relpath,
324+
"sha1": _do_sha1,
325+
"shuffle": _do_shuffle,
326+
"splitext": splitext,
327+
"strftime": _do_strftime,
328+
"ternary": _do_ternary,
329+
"to_datetime": _do_to_datetime,
330+
"to_json": to_json,
331+
"to_nice_json": _do_to_nice_json,
332+
"to_nice_yaml": _do_to_nice_yaml,
333+
"to_uuid": _do_to_uuid,
334+
"to_yaml": _do_to_yaml,
335+
"type_debug": _do_type_debug,
336+
"win_basename": win_basename,
337+
"win_dirname": win_dirname,
338+
"win_splitdrive": win_splitdrive,
339+
}
340+
341+
def __init__(self, environment: Environment) -> None:
342+
super().__init__(environment)
343+
for k, v in self._filters.items():
344+
if k in environment.filters:
345+
warn(
346+
f'A filter named "{k}" already exists in the Jinja2 environment',
347+
category=RuntimeWarning,
348+
stacklevel=2,
349+
)
350+
else:
351+
environment.filters[k] = v

copier/main.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
UnsafeTemplateError,
4545
UserMessageError,
4646
)
47+
from .jinja_ext import CopierExtension
4748
from .subproject import Subproject
4849
from .template import Task, Template
4950
from .tools import (
@@ -547,9 +548,7 @@ def jinja_env(self) -> SandboxedEnvironment:
547548
"""
548549
paths = [str(self.template.local_abspath)]
549550
loader = FileSystemLoader(paths)
550-
default_extensions = [
551-
"jinja2_ansible_filters.AnsibleCoreFiltersExtension",
552-
]
551+
default_extensions = [CopierExtension]
553552
extensions = default_extensions + list(self.template.jinja_extensions)
554553
# We want to minimize the risk of hidden malware in the templates
555554
# so we use the SandboxedEnvironment instead of the regular one.

0 commit comments

Comments
 (0)