diff --git a/hamilflow/models/discrete/d0/free_particle.py b/hamilflow/models/discrete/d0/free_particle.py index 8512bd0..2e7e3e8 100644 --- a/hamilflow/models/discrete/d0/free_particle.py +++ b/hamilflow/models/discrete/d0/free_particle.py @@ -23,11 +23,11 @@ class FreeParticleIC(BaseModel): @model_validator(mode="after") def check_dimensions_match(self) -> Self: - assert ( - len(self.x0) == len(cast(Sequence, self.v0)) - if isinstance(self.x0, Sequence) - else not isinstance(self.v0, Sequence) - ) + if (x0_seq := isinstance(self.x0, Sequence)) != isinstance(self.v0, Sequence): + raise TypeError("x0 and v0 needs both to be scalars or Sequences") + elif x0_seq and len(cast(Sequence, self.x0)) != len(cast(Sequence, self.v0)): + raise ValueError("Sequences x0 and v0 needs to have the same length") + return self diff --git a/tests/test_models/discrete/d0/test_free_particle.py b/tests/test_models/discrete/d0/test_free_particle.py index b2a6731..574bee0 100644 --- a/tests/test_models/discrete/d0/test_free_particle.py +++ b/tests/test_models/discrete/d0/test_free_particle.py @@ -15,9 +15,13 @@ def test_constructor( ) -> None: assert FreeParticleIC(x0=x0, v0=v0) - @pytest.mark.parametrize(("x0", "v0"), [(1, (2,)), ((1,), (2, 3))]) - def test_raise(self, x0: int | Sequence[int], v0: int | Sequence[int]) -> None: - with pytest.raises(ValidationError): + @pytest.mark.parametrize( + ("x0", "v0", "expected"), [(1, (2,), TypeError), ((1,), (2, 3), ValueError)] + ) + def test_raise( + self, x0: int | Sequence[int], v0: Sequence[int], expected: type[Exception] + ) -> None: + with pytest.raises(expected): FreeParticleIC(x0=x0, v0=v0)