diff --git a/hamilflow/models/harmonic_oscillator_chain.py b/hamilflow/models/harmonic_oscillator_chain.py index da3ece9..dccd551 100644 --- a/hamilflow/models/harmonic_oscillator_chain.py +++ b/hamilflow/models/harmonic_oscillator_chain.py @@ -36,22 +36,29 @@ def __init__( :param initial_conditions: a sequence of initial conditions on the Fourier modes. The first element in the sequence is that of the zero mode, taking a position and a velocity. Rest of the elements are that of the independent travelling waves, taking two amplitudes and two initial phases. - :param odd_dof: The system will have `2 * len(initial_conditions) - int(odd_dof)` degrees of freedom. + :param odd_dof: The system will have `2 * len(initial_conditions) + int(odd_dof) - 2` degrees of freedom. """ + self.n_dof = 2 * len(initial_conditions) + odd_dof - 2 + if not odd_dof: + prefix = "For even degrees of freedom, " + if self.n_dof == 0: + raise ValueError(prefix + "at least 1 travelling wave is needed") + amp = cast(tuple[float, float], initial_conditions[-1]["amp"]) + if amp[0] != amp[1]: + msg = "k == N // 2 must have equal positive and negative amplitudes." + raise ValueError(prefix + msg) self.omega = omega - self.n_independant_csho_dof = len(initial_conditions) - 1 self.odd_dof = odd_dof self.free_mode = FreeParticle(cast(Mapping[str, float], initial_conditions[0])) - r_wave_modes_ic = initial_conditions[1:] self.independent_csho_modes = [ self._sho_factory( k, cast(tuple[float, float], ic["amp"]), cast(tuple[float, float] | None, ic.get("phi")), ) - for k, ic in enumerate(r_wave_modes_ic, 1) + for k, ic in enumerate(initial_conditions[1:], 1) ] def _sho_factory( @@ -67,10 +74,6 @@ def _sho_factory( dict(x0=amp) | (dict(phi=phi) if phi else {}), ) - @cached_property - def n_dof(self) -> int: - return self.n_independant_csho_dof * 2 + self.odd_dof - @cached_property def definition( self, @@ -107,7 +110,7 @@ def _z( else ( independent_cshos[:-1], independent_cshos[[-1]], - independent_cshos[-1::-1].conj(), + independent_cshos[-2::-1].conj(), ) ) diff --git a/tests/test_models/test_harmonic_oscillator_chain.py b/tests/test_models/test_harmonic_oscillator_chain.py index 3f0c3d8..35a4f6c 100644 --- a/tests/test_models/test_harmonic_oscillator_chain.py +++ b/tests/test_models/test_harmonic_oscillator_chain.py @@ -30,13 +30,21 @@ def free_mode(self, request: pytest.FixtureRequest) -> dict[str, int]: product(_possible_wave_modes, repeat=r) for r in range(3) ) ) - def wave_modes(self, request: pytest.FixtureRequest) -> list[dict[str, int]]: + def wave_modes( + self, request: pytest.FixtureRequest + ) -> list[dict[str, tuple[int, int]]]: return request.param @pytest.fixture(params=(False, True)) def odd_dof(self, request: pytest.FixtureRequest) -> bool: return request.param + @pytest.fixture() + def legal_wave_modes_and_odd_def( + self, wave_modes: Iterable[Mapping[str, tuple[int, int]]], odd_dof: bool + ) -> tuple[Iterable[Mapping[str, tuple[int, int]]], bool]: + return wave_modes if odd_dof else chain(wave_modes, [dict(amp=(1, 1))]), odd_dof + @pytest.fixture(params=(0, 1, (0, 1))) def times(self, request: pytest.FixtureRequest) -> int | tuple[int]: return request.param @@ -45,19 +53,46 @@ def test_init( self, omega: int, free_mode: Mapping[str, int], - wave_modes: Iterable[Mapping[str, int]], - odd_dof: bool, + legal_wave_modes_and_odd_def: tuple[ + Iterable[Mapping[str, tuple[int, int]]], bool + ], ) -> None: + wave_modes, odd_dof = legal_wave_modes_and_odd_def assert HarmonicOscillatorsChain(omega, [free_mode, *wave_modes], odd_dof) - def test_real( + @pytest.fixture() + def hoc_and_zs( self, omega: int, free_mode: Mapping[str, int], - wave_modes: Iterable[Mapping[str, int]], - odd_dof: bool, + legal_wave_modes_and_odd_def: tuple[ + Iterable[Mapping[str, tuple[int, int]]], bool + ], times: int | Sequence[int], - ) -> None: + ) -> tuple[HarmonicOscillatorsChain, np.ndarray, np.ndarray]: + wave_modes, odd_dof = legal_wave_modes_and_odd_def hoc = HarmonicOscillatorsChain(omega, [free_mode, *wave_modes], odd_dof) - original_zs, _ = hoc._z(times) + return (hoc, *hoc._z(times)) + + def test_real( + self, hoc_and_zs: tuple[HarmonicOscillatorsChain, np.ndarray, np.ndarray] + ) -> None: + _, original_zs, _ = hoc_and_zs assert np.all(original_zs.imag == 0.0) + + def test_dof( + self, hoc_and_zs: tuple[HarmonicOscillatorsChain, np.ndarray, np.ndarray] + ) -> None: + hoc, original_zs, _ = hoc_and_zs + assert original_zs.shape[0] == hoc.n_dof + + @pytest.mark.parametrize("wave_mode", [None, *_possible_wave_modes[1:]]) + def test_raise( + self, + omega: int, + free_mode: Mapping[str, int] | None, + wave_mode: Mapping[str, int], + ) -> None: + ics = [free_mode, *([wave_mode] if wave_mode else [])] + with pytest.raises(ValueError): + HarmonicOscillatorsChain(omega, ics, False)