Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"numpy",
"jsonschema>=2.6.0",
"pydantic>=2.7.1",
"mat3ra-esse>=2025.4.15.post0",
"mat3ra-esse>=2025.7.1-0",
"mat3ra-utils>=2024.5.15.post0",
]

Expand Down
4 changes: 3 additions & 1 deletion src/py/mat3ra/code/array_with_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def filter_by_indices(self, indices: Union[List[int], int]):
self.values = [self.values[i] for i in range(len(self.values)) if i in index_set]
self.ids = [self.ids[i] for i in range(len(self.ids)) if i in index_set]

def filter_by_ids(self, ids: Union[List[int], int], invert: bool = False):
def filter_by_ids(self, ids: Union[List[int], int], invert: bool = False, reset_ids: bool = False):
if isinstance(ids, int):
ids = [ids]
if not invert:
Expand All @@ -84,6 +84,8 @@ def filter_by_ids(self, ids: Union[List[int], int], invert: bool = False):
keep_indices = [index for index, id_ in enumerate(self.ids) if id_ in ids_set]
self.values = [self.values[index] for index in keep_indices]
self.ids = [self.ids[index] for index in keep_indices]
if reset_ids:
self.ids = list(range(len(self.values)))

def __eq__(self, other: object) -> bool:
return isinstance(other, ArrayWithIds) and self.values == other.values and self.ids == other.ids
Expand Down
2 changes: 1 addition & 1 deletion src/py/mat3ra/code/vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List

import numpy as np
from mat3ra.esse.models.core.abstract.point import PointSchema as Vector3DSchema
from mat3ra.esse.models.core.abstract.vector_3d import Vector3dSchema as Vector3DSchema
from mat3ra.utils.mixins import RoundNumericValuesMixin
from pydantic import model_serializer

Expand Down
26 changes: 26 additions & 0 deletions tests/py/unit/test_array_with_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,32 @@ def test_filter_by_ids():
)


def test_filter_by_ids_reset_ids():
"""Test that reset_ids parameter resets IDs to consecutive integers starting from 0"""
instance = ArrayWithIds(**ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE)
# Filter to keep only first and third elements (ids 2 and 6)
instance.filter_by_ids([2, 6], reset_ids=True)

# Values should be filtered correctly
assert instance.values == [
ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][0],
ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][2],
]
# IDs should be reset to consecutive integers starting from 0
assert instance.ids == [0, 1]

# Test with invert=True
instance = ArrayWithIds(**ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE)
# Filter to exclude first element (id 2), keep second and third (ids 4, 6)
instance.filter_by_ids([2], invert=True, reset_ids=True)

assert instance.values == [
ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][1],
ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG_NON_CONSECUTIVE["values"][2],
]
assert instance.ids == [0, 1]


def test_filter_by_ids_invert():
instance = ArrayWithIds(**ARRAY_WITH_IDS_ARRAYS_OF_FLOAT_VALUES_CONFIG)
instance.filter_by_ids([0, 2], invert=True)
Expand Down
Loading