diff --git a/pyproject.toml b/pyproject.toml index eb708baf..216789cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "numpy", "jsonschema>=2.6.0", "pydantic>=2.7.1", - "mat3ra-esse>=2025.4.15.post0", + "mat3ra-esse>=2025.7.1-0", "mat3ra-utils>=2024.5.15.post0", ] diff --git a/src/py/mat3ra/code/array_with_ids.py b/src/py/mat3ra/code/array_with_ids.py index 087052c0..7d305d0d 100644 --- a/src/py/mat3ra/code/array_with_ids.py +++ b/src/py/mat3ra/code/array_with_ids.py @@ -74,7 +74,7 @@ def filter_by_indices(self, indices: Union[List[int], int]): self.values = [self.values[i] for i in range(len(self.values)) if i in index_set] self.ids = [self.ids[i] for i in range(len(self.ids)) if i in index_set] - def filter_by_ids(self, ids: Union[List[int], int], invert: bool = False): + def filter_by_ids(self, ids: Union[List[int], int], invert: bool = False, reset_ids: bool = False): if isinstance(ids, int): ids = [ids] if not invert: @@ -84,6 +84,8 @@ def filter_by_ids(self, ids: Union[List[int], int], invert: bool = False): keep_indices = [index for index, id_ in enumerate(self.ids) if id_ in ids_set] self.values = [self.values[index] for index in keep_indices] self.ids = [self.ids[index] for index in keep_indices] + if reset_ids: + self.ids = list(range(len(self.values))) def __eq__(self, other: object) -> bool: return isinstance(other, ArrayWithIds) and self.values == other.values and self.ids == other.ids diff --git a/src/py/mat3ra/code/vector.py b/src/py/mat3ra/code/vector.py index 7afd30c4..ac534760 100644 --- a/src/py/mat3ra/code/vector.py +++ b/src/py/mat3ra/code/vector.py @@ -1,7 +1,7 @@ from typing import List import numpy as np -from mat3ra.esse.models.core.abstract.point import PointSchema as Vector3DSchema +from mat3ra.esse.models.core.abstract.vector_3d import Vector3dSchema as Vector3DSchema from mat3ra.utils.mixins import RoundNumericValuesMixin from pydantic import model_serializer diff --git a/tests/py/unit/test_array_with_ids.py b/tests/py/unit/test_array_with_ids.py index 676a24e7..80fbf01d 100644 --- a/tests/py/unit/test_array_with_ids.py +++ b/tests/py/unit/test_array_with_ids.py @@ -208,6 +208,32 @@ def test_filter_by_ids(): ) +def test_filter_by_ids_reset_ids(): + """Test that reset_ids parameter resets IDs to consecutive integers starting from 0""" + instance = ArrayWithIds(**ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE) + # Filter to keep only first and third elements (ids 2 and 6) + instance.filter_by_ids([2, 6], reset_ids=True) + + # Values should be filtered correctly + assert instance.values == [ + ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][0], + ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][2], + ] + # IDs should be reset to consecutive integers starting from 0 + assert instance.ids == [0, 1] + + # Test with invert=True + instance = ArrayWithIds(**ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE) + # Filter to exclude first element (id 2), keep second and third (ids 4, 6) + instance.filter_by_ids([2], invert=True, reset_ids=True) + + assert instance.values == [ + ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][1], + ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][2], + ] + assert instance.ids == [0, 1] + + def test_filter_by_ids_invert(): instance = ArrayWithIds(**ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG) instance.filter_by_ids([0, 2], invert=True)