Skip to content

Commit

Permalink
fix: #36 even degrees of freedom
Browse files Browse the repository at this point in the history
  • Loading branch information
cmp0xff committed Jul 28, 2024
1 parent 7ceefc8 commit d04c771
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
21 changes: 12 additions & 9 deletions hamilflow/models/harmonic_oscillator_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,29 @@ def __init__(
:param initial_conditions: a sequence of initial conditions on the Fourier modes.
The first element in the sequence is that of the zero mode, taking a position and a velocity.
Rest of the elements are that of the independent travelling waves, taking two amplitudes and two initial phases.
:param odd_dof: The system will have `2 * len(initial_conditions) - int(odd_dof)` degrees of freedom.
:param odd_dof: The system will have `2 * len(initial_conditions) + int(odd_dof) - 2` degrees of freedom.
"""
self.n_dof = 2 * len(initial_conditions) + odd_dof - 2
if not odd_dof:
prefix = "For even degrees of freedom, "
if self.n_dof == 0:
raise ValueError(prefix + "at least 1 travelling wave is needed")
amp = cast(tuple[float, float], initial_conditions[-1]["amp"])
if amp[0] != amp[1]:
msg = "k == N // 2 must have equal positive and negative amplitudes."
raise ValueError(prefix + msg)
self.omega = omega
self.n_independant_csho_dof = len(initial_conditions) - 1
self.odd_dof = odd_dof

self.free_mode = FreeParticle(cast(Mapping[str, float], initial_conditions[0]))

r_wave_modes_ic = initial_conditions[1:]
self.independent_csho_modes = [
self._sho_factory(
k,
cast(tuple[float, float], ic["amp"]),
cast(tuple[float, float] | None, ic.get("phi")),
)
for k, ic in enumerate(r_wave_modes_ic, 1)
for k, ic in enumerate(initial_conditions[1:], 1)
]

def _sho_factory(
Expand All @@ -67,10 +74,6 @@ def _sho_factory(
dict(x0=amp) | (dict(phi=phi) if phi else {}),
)

@cached_property
def n_dof(self) -> int:
return self.n_independant_csho_dof * 2 + self.odd_dof

@cached_property
def definition(
self,
Expand Down Expand Up @@ -107,7 +110,7 @@ def _z(
else (
independent_cshos[:-1],
independent_cshos[[-1]],
independent_cshos[-1::-1].conj(),
independent_cshos[-2::-1].conj(),
)
)

Expand Down
51 changes: 43 additions & 8 deletions tests/test_models/test_harmonic_oscillator_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,21 @@ def free_mode(self, request: pytest.FixtureRequest) -> dict[str, int]:
product(_possible_wave_modes, repeat=r) for r in range(3)
)
)
def wave_modes(self, request: pytest.FixtureRequest) -> list[dict[str, int]]:
def wave_modes(
self, request: pytest.FixtureRequest
) -> list[dict[str, tuple[int, int]]]:
return request.param

@pytest.fixture(params=(False, True))
def odd_dof(self, request: pytest.FixtureRequest) -> bool:
return request.param

@pytest.fixture()
def legal_wave_modes_and_odd_def(
self, wave_modes: Iterable[Mapping[str, tuple[int, int]]], odd_dof: bool
) -> tuple[Iterable[Mapping[str, tuple[int, int]]], bool]:
return wave_modes if odd_dof else chain(wave_modes, [dict(amp=(1, 1))]), odd_dof

@pytest.fixture(params=(0, 1, (0, 1)))
def times(self, request: pytest.FixtureRequest) -> int | tuple[int]:
return request.param
Expand All @@ -45,19 +53,46 @@ def test_init(
self,
omega: int,
free_mode: Mapping[str, int],
wave_modes: Iterable[Mapping[str, int]],
odd_dof: bool,
legal_wave_modes_and_odd_def: tuple[
Iterable[Mapping[str, tuple[int, int]]], bool
],
) -> None:
wave_modes, odd_dof = legal_wave_modes_and_odd_def
assert HarmonicOscillatorsChain(omega, [free_mode, *wave_modes], odd_dof)

def test_real(
@pytest.fixture()
def hoc_and_zs(
self,
omega: int,
free_mode: Mapping[str, int],
wave_modes: Iterable[Mapping[str, int]],
odd_dof: bool,
legal_wave_modes_and_odd_def: tuple[
Iterable[Mapping[str, tuple[int, int]]], bool
],
times: int | Sequence[int],
) -> None:
) -> tuple[HarmonicOscillatorsChain, np.ndarray, np.ndarray]:
wave_modes, odd_dof = legal_wave_modes_and_odd_def
hoc = HarmonicOscillatorsChain(omega, [free_mode, *wave_modes], odd_dof)
original_zs, _ = hoc._z(times)
return (hoc, *hoc._z(times))

def test_real(
self, hoc_and_zs: tuple[HarmonicOscillatorsChain, np.ndarray, np.ndarray]
) -> None:
_, original_zs, _ = hoc_and_zs
assert np.all(original_zs.imag == 0.0)

def test_dof(
self, hoc_and_zs: tuple[HarmonicOscillatorsChain, np.ndarray, np.ndarray]
) -> None:
hoc, original_zs, _ = hoc_and_zs
assert original_zs.shape[0] == hoc.n_dof

@pytest.mark.parametrize("wave_mode", [None, *_possible_wave_modes[1:]])
def test_raise(
self,
omega: int,
free_mode: Mapping[str, int] | None,
wave_mode: Mapping[str, int],
) -> None:
ics = [free_mode, *([wave_mode] if wave_mode else [])]
with pytest.raises(ValueError):
HarmonicOscillatorsChain(omega, ics, False)

0 comments on commit d04c771

Please sign in to comment.