Skip to content

Commit 270fcbf

Browse files
committed
Add tests and fixes for update
1 parent 1235def commit 270fcbf

File tree

3 files changed

+95
-72
lines changed

3 files changed

+95
-72
lines changed

jsonpath_rw/jsonpath.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def find(self, data):
2626
raise NotImplementedError()
2727

2828
def update(self, data, val):
29-
"Returns `data` with the specified path replaced by `val`"
29+
"""
30+
Returns `data` with the specified path replaced by `val`. Only updates
31+
if the specified path exists.
32+
"""
33+
3034
raise NotImplementedError()
3135

3236
def child(self, child):
@@ -340,8 +344,30 @@ def is_singular():
340344
return False
341345

342346
def update(self, data, val):
343-
for datum in self.left.find(data):
344-
self.right.update(datum.value, val)
347+
# Get all left matches into a list
348+
left_matches = self.left.find(data)
349+
if not isinstance(left_matches, list):
350+
left_matches = [left_matches]
351+
352+
def update_recursively(data):
353+
# Update only mutable values corresponding to JSON types
354+
if not (isinstance(data, list) or isinstance(data, dict)):
355+
return
356+
357+
self.right.update(data, val)
358+
359+
# Manually do the * or [*] to avoid coercion and recurse just the right-hand pattern
360+
if isinstance(data, list):
361+
for i in range(0, len(data)):
362+
update_recursively(data[i])
363+
364+
elif isinstance(data, dict):
365+
for field in data.keys():
366+
update_recursively(data[field])
367+
368+
for submatch in left_matches:
369+
update_recursively(submatch.value)
370+
345371
return data
346372

347373
def __str__(self):
@@ -432,7 +458,8 @@ def find(self, datum):
432458

433459
def update(self, data, val):
434460
for field in self.reified_fields(DatumInContext.wrap(data)):
435-
data[field] = val
461+
if field in data:
462+
data[field] = val
436463
return data
437464

438465
def __str__(self):
@@ -466,7 +493,8 @@ def find(self, datum):
466493
return []
467494

468495
def update(self, data, val):
469-
data[self.index] = val
496+
if len(data) > self.index:
497+
data[self.index] = val
470498
return data
471499

472500
def __eq__(self, other):

tests/test_jsonpath.py

+62
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,65 @@ def test_descendants_auto_id(self):
290290
} },
291291
['foo.baz',
292292
'foo.bing.baz'] )])
293+
294+
def check_update_cases(self, test_cases):
295+
for original, expr_str, value, expected in test_cases:
296+
print('parse(%r).update(%r, %r) =?= %r'
297+
% (expr_str, original, value, expected))
298+
expr = parse(expr_str)
299+
actual = expr.update(original, value)
300+
assert actual == expected
301+
302+
def test_update_root(self):
303+
self.check_update_cases([
304+
('foo', '$', 'bar', 'bar')
305+
])
306+
307+
def test_update_this(self):
308+
self.check_update_cases([
309+
('foo', '`this`', 'bar', 'bar')
310+
])
311+
312+
def test_update_fields(self):
313+
self.check_update_cases([
314+
({'foo': 1}, 'foo', 5, {'foo': 5}),
315+
({'foo': 1, 'bar': 2}, '$.*', 3, {'foo': 3, 'bar': 3})
316+
])
317+
318+
def test_update_child(self):
319+
self.check_update_cases([
320+
({'foo': 'bar'}, '$.foo', 'baz', {'foo': 'baz'}),
321+
({'foo': {'bar': 1}}, 'foo.bar', 'baz', {'foo': {'bar': 'baz'}})
322+
])
323+
324+
def test_update_where(self):
325+
self.check_update_cases([
326+
({'foo': {'bar': {'baz': 1}}, 'bar': {'baz': 2}},
327+
'*.bar where baz', 5, {'foo': {'bar': 5}, 'bar': {'baz': 2}})
328+
])
329+
330+
def test_update_descendants_where(self):
331+
self.check_update_cases([
332+
({'foo': {'bar': 1, 'flag': 1}, 'baz': {'bar': 2}},
333+
'(* where flag) .. bar', 3,
334+
{'foo': {'bar': 3, 'flag': 1}, 'baz': {'bar': 2}})
335+
])
336+
337+
def test_update_descendants(self):
338+
self.check_update_cases([
339+
({'somefield': 1}, '$..somefield', 42, {'somefield': 42}),
340+
({'outer': {'nestedfield': 1}}, '$..nestedfield', 42, {'outer': {'nestedfield': 42}}),
341+
({'outs': {'bar': 1, 'ins': {'bar': 9}}, 'outs2': {'bar': 2}},
342+
'$..bar', 42,
343+
{'outs': {'bar': 42, 'ins': {'bar': 42}}, 'outs2': {'bar': 42}})
344+
])
345+
346+
def test_update_index(self):
347+
self.check_update_cases([
348+
(['foo', 'bar', 'baz'], '[0]', 'test', ['test', 'bar', 'baz'])
349+
])
350+
351+
def test_update_slice(self):
352+
self.check_update_cases([
353+
(['foo', 'bar', 'baz'], '[0:2]', 'test', ['test', 'test', 'baz'])
354+
])

tests/test_update.py

-67
This file was deleted.

0 commit comments

Comments
 (0)