Skip to content

Commit 9407000

Browse files
committed
add custom yaml tags initial support
1 parent 6bdb908 commit 9407000

8 files changed

+192
-2
lines changed

nested_diff/cli.py

+93-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import sys
2020

2121
import nested_diff
22+
import nested_diff.handlers
2223

2324

2425
class App:
@@ -220,6 +221,7 @@ def load(self, file_):
220221
fmt = self.args.ifmt
221222

222223
fmt_opts = self._decode_fmt_opts(self.args.ifmt_opts)
224+
223225
return self.get_loader(fmt, **fmt_opts).load(file_)
224226

225227
@staticmethod
@@ -550,15 +552,105 @@ def __init__(self, **kwargs):
550552

551553
import yaml
552554

555+
from yaml.nodes import (
556+
ScalarNode as YamlScalarNode,
557+
SequenceNode as YamlSequenceNode,
558+
MappingNode as YamlMappingNode,
559+
)
560+
553561
try:
554562
from yaml import CSafeLoader as YamlLoader
555563
except ImportError:
556564
from yaml import SafeLoader as YamlLoader
557565

566+
self.opts = self.get_opts(kwargs)
558567
self.yaml = yaml
559568
self.yaml_loader = YamlLoader
560-
self.opts = self.get_opts(kwargs)
569+
570+
def _default_constructor(loader, tag_suffix, node): # noqa: ARG001
571+
tag = node.tag
572+
573+
if isinstance(node, YamlScalarNode):
574+
value = node.value
575+
elif isinstance(node, YamlSequenceNode):
576+
node.tag = 'tag:yaml.org,2002:seq'
577+
value = loader.construct_sequence(node, deep=True)
578+
elif isinstance(node, YamlMappingNode):
579+
node.tag = 'tag:yaml.org,2002:map'
580+
value = loader.construct_mapping(node, deep=True)
581+
582+
return YamlNode(tag, value)
583+
584+
self.yaml_loader.add_multi_constructor(None, _default_constructor)
561585

562586
def decode(self, data):
563587
"""Parse YAML string."""
564588
return self.yaml.load(data, Loader=self.yaml_loader, **self.opts)
589+
590+
591+
class YamlNode:
592+
"""Wrapper to represent YAML node."""
593+
594+
def __init__(self, tag, value):
595+
"""Initialize wrapper.
596+
597+
Args:
598+
tag: YAML node tag.
599+
value: YAML node value.
600+
601+
"""
602+
self.tag = tag
603+
self.value = value
604+
605+
def __repr__(self):
606+
"""Repr for YAML node wrapper."""
607+
return f"YamlNode(tag='{self.tag}', value={self.value!r})"
608+
609+
610+
class YamlNodeHandler(nested_diff.handlers.TypeHandler):
611+
"""YamlNode handler."""
612+
613+
extension_id = 'nested_diff.YamlNode'
614+
handled_type = YamlNode
615+
616+
def diff(self, differ, a, b):
617+
"""Calculate diff for two YamlNode objects.
618+
619+
Args:
620+
differ: nested_diff.Differ object.
621+
a: First node to diff.
622+
b: Second node to diff.
623+
624+
Returns:
625+
Tuple: equality flag and nested diff.
626+
627+
"""
628+
equal, _ = differ.diff(a.tag, b.tag)
629+
630+
if not equal:
631+
return equal, {'N': b, 'O': a, 'E': self.extension_id}
632+
633+
equal, diff = differ.diff(a.value, b.value)
634+
635+
if diff:
636+
diff = {
637+
'D': {'value': diff},
638+
'E': self.extension_id,
639+
'tag': a.tag,
640+
}
641+
642+
return equal, diff
643+
644+
def generate_formatted_diff(self, formatter, diff, depth):
645+
"""Generate formatted YamlNode diff."""
646+
if 'D' in diff:
647+
yield from formatter.generate_string(diff['tag'], 'D', depth)
648+
yield from formatter.generate_diff(diff['D']['value'], depth + 1)
649+
650+
return
651+
652+
yield from formatter.generate_string(diff['O'].tag, 'O', depth)
653+
yield from formatter.generate_value(diff['O'].value, 'O', depth + 1)
654+
655+
yield from formatter.generate_string(diff['N'].tag, 'N', depth)
656+
yield from formatter.generate_value(diff['N'].value, 'N', depth + 1)

nested_diff/diff_tool.py

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def diff(self, a, b, **kwargs):
5757
diff_opts['R'] = int(diff_opts['R'])
5858

5959
differ = nested_diff.Differ(**diff_opts)
60+
differ.set_handler(nested_diff.cli.YamlNodeHandler())
6061

6162
if self.args.text_ctx >= 0:
6263
differ.set_handler(
@@ -286,6 +287,7 @@ def __init__(
286287

287288
fmt_class = self.get_formatter_class(base_class, values=values)
288289
self.encoder = fmt_class(**self.get_opts(kwargs))
290+
self.encoder.set_handler(nested_diff.cli.YamlNodeHandler())
289291

290292
if header is None:
291293
if hasattr(self.encoder, 'get_page_header'):

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ requires-python = '>=3.7'
4444

4545
[project.optional-dependencies]
4646
cli = [
47-
'pyyaml',
47+
'pyyaml >= 5.3',
4848
'tomli >= 1.1.0 ; python_version < "3.11"',
4949
'tomli-w >= 1.0.0'
5050
]

tests/cli/shared.custom_tags.a.yaml

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
changed_value: !changed_value old_value
3+
changed_tag: !changed_tag changed_tag_value
4+
unchanged_value: !unchanged_value unchanged_value
5+
removed: !removed
6+
7+
recursive_mapping: !mapping
8+
child: !mapping
9+
var: bar
10+
var: foo
11+
12+
sequence: !sequence
13+
- foo
14+
- bar
15+
- baz

tests/cli/shared.custom_tags.b.yaml

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
added: !added
3+
changed_value: !changed_value new_value
4+
changed_tag: !Changed_tag changed_tag_value
5+
unchanged_value: !unchanged_value unchanged_value
6+
7+
recursive_mapping: !mapping
8+
child: !mapping
9+
var: baz
10+
var: foo
11+
12+
sequence: !sequence
13+
- foo
14+
- baz
15+
- bar

tests/cli/test_cli.py

+8
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ def test_get_loader_unsupported_fmt():
110110
cli.App(args=()).get_loader('garbage')
111111

112112

113+
def test_loader_yaml_custom_tags():
114+
loaded = cli.YamlLoader().decode('!custom_tag')
115+
116+
assert isinstance(loaded, cli.YamlNode)
117+
assert loaded.tag == '!custom_tag'
118+
assert loaded.value == ''
119+
120+
113121
def test_run():
114122
with pytest.raises(NotImplementedError):
115123
cli.App(args=()).run()

tests/cli/test_diff_tool.py

+17
Original file line numberDiff line numberDiff line change
@@ -748,3 +748,20 @@ def test_values_yaml_multiline_strings(capsys, expected, rpath):
748748
assert exit_code == 0
749749

750750
assert captured.out == expected
751+
752+
753+
def test_yaml_custom_tags(capsys, expected, rpath):
754+
exit_code = nested_diff.diff_tool.App(
755+
args=(
756+
rpath('shared.custom_tags.a.yaml'),
757+
rpath('shared.custom_tags.b.yaml'),
758+
'-U',
759+
'1',
760+
),
761+
).run()
762+
763+
captured = capsys.readouterr()
764+
assert captured.err == ''
765+
assert exit_code == 1
766+
767+
assert captured.out == expected
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
+ {'added'}
2+
+ YamlNode(tag='!added', value='')
3+
{'changed_tag'}
4+
# <YamlNode>
5+
- !changed_tag
6+
- 'changed_tag_value'
7+
+ !Changed_tag
8+
+ 'changed_tag_value'
9+
{'changed_value'}
10+
# <YamlNode>
11+
!changed_value
12+
- 'old_value'
13+
+ 'new_value'
14+
{'recursive_mapping'}
15+
# <YamlNode>
16+
!mapping
17+
{'child'}
18+
# <YamlNode>
19+
!mapping
20+
{'var'}
21+
- 'bar'
22+
+ 'baz'
23+
{'var'}
24+
'foo'
25+
- {'removed'}
26+
- YamlNode(tag='!removed', value='')
27+
{'sequence'}
28+
# <YamlNode>
29+
!sequence
30+
[0]
31+
'foo'
32+
+ [1]
33+
+ 'baz'
34+
[2]
35+
'bar'
36+
- [3]
37+
- 'baz'
38+
{'unchanged_value'}
39+
# <YamlNode>
40+
!unchanged_value
41+
'unchanged_value'

0 commit comments

Comments
 (0)