diff --git a/jsonpath_ng/jsonpath.py b/jsonpath_ng/jsonpath.py index e6e287c..334b291 100644 --- a/jsonpath_ng/jsonpath.py +++ b/jsonpath_ng/jsonpath.py @@ -643,8 +643,8 @@ class Index(JSONPath): NOTE: For the concrete syntax of `[*]`, the abstract syntax is a Slice() with no parameters (equiv to `[:]` """ - def __init__(self, index): - self.index = index + def __init__(self, *indices): + self.indices = indices def find(self, datum): return self._find_base(datum, create=False) @@ -658,10 +658,12 @@ def _find_base(self, datum, create): if datum.value == {}: datum.value = _create_list_key(datum.value) self._pad_value(datum.value) - if datum.value and len(datum.value) > self.index: - return [DatumInContext(datum.value[self.index], path=self, context=datum)] - else: - return [] + rv = [] + for index in self.indices: + # invalid indices do not crash, return [] instead + if datum.value and len(datum.value) > index: + rv += [DatumInContext(datum.value[index], path=self, context=datum)] + return rv def update(self, data, val): return self._update_base(data, val, create=False) @@ -675,28 +677,40 @@ def _update_base(self, data, val, create): data = _create_list_key(data) self._pad_value(data) if hasattr(val, '__call__'): - data[self.index] = val.__call__(data[self.index], data, self.index) - elif len(data) > self.index: - data[self.index] = val + for index in self.indices: + val.__call__(data[index], data, index) + else: + for index in self.indices: + if len(data) > index: + try: + if isinstance(val, list): + # allows somelist[5,1,2] = [some_value, another_value, third_value] + data[index] = val.pop(0) + else: + data[index] = val + except Exception as e: + raise e return data def filter(self, fn, data): - if fn(data[self.index]): - data.pop(self.index) # relies on mutation :( + for index in self.indices: + if fn(data[index]): + data.pop(index) # relies on mutation :( return data def __eq__(self, other): - return isinstance(other, Index) and self.index == other.index + return isinstance(other, Index) and sorted(self.indices) == sorted(other.indices) def __str__(self): - return '[%i]' % self.index + return '[%i]' % self.indices def __repr__(self): - return '%s(index=%r)' % (self.__class__.__name__, self.index) + return '%s(indices=%r)' % (self.__class__.__name__, self.indices) def _pad_value(self, value): - if len(value) <= self.index: - pad = self.index - len(value) + 1 + _max = max(self.indices) + if len(value) <= _max: + pad = _max - len(value) + 1 value += [{} for __ in range(pad)] def __hash__(self): diff --git a/jsonpath_ng/parser.py b/jsonpath_ng/parser.py index 3c6f37b..8619ba1 100644 --- a/jsonpath_ng/parser.py +++ b/jsonpath_ng/parser.py @@ -116,7 +116,7 @@ def p_jsonpath_root(self, p): def p_jsonpath_idx(self, p): "jsonpath : '[' idx ']'" - p[0] = p[2] + p[0] = Index(*p[2]) def p_jsonpath_slice(self, p): "jsonpath : '[' slice ']'" @@ -132,7 +132,7 @@ def p_jsonpath_child_fieldbrackets(self, p): def p_jsonpath_child_idxbrackets(self, p): "jsonpath : jsonpath '[' idx ']'" - p[0] = Child(p[1], p[3]) + p[0] = Child(p[1], Index(*p[3])) def p_jsonpath_child_slicebrackets(self, p): "jsonpath : jsonpath '[' slice ']'" @@ -161,7 +161,11 @@ def p_fields_comma(self, p): def p_idx(self, p): "idx : NUMBER" - p[0] = Index(p[1]) + p[0] = [p[1]] + + def p_idx_comma(self, p): + "idx : idx ',' idx " + p[0] = p[1] + p[3] def p_slice_any(self, p): "slice : '*'" diff --git a/tests/test_jsonpath.py b/tests/test_jsonpath.py index 20d4f11..f97d2ba 100644 --- a/tests/test_jsonpath.py +++ b/tests/test_jsonpath.py @@ -72,6 +72,8 @@ def test_datumincontext_in_context_nested(): # ------- # ("[0]", ["foo", "bar", "baz"], "test", ["test", "bar", "baz"]), + ("[0, 1]", ["foo", "bar", "baz"], "test", ["test", "test", "baz"]), + ("[0, 1]", ["foo", "bar", "baz"], ["test", "test 1"], ["test", "test 1", "baz"]), # # Slices # ------ @@ -156,7 +158,8 @@ def test_datumincontext_in_context_nested(): @parsers def test_update(parse, expression, data, update_value, expected_value): data_copy = copy.deepcopy(data) - result = parse(expression).update(data_copy, update_value) + update_value_copy = copy.deepcopy(update_value) + result = parse(expression).update(data_copy, update_value_copy) assert result == expected_value