Skip to content

Commit

Permalink
feat(pytest): #36
Browse files Browse the repository at this point in the history
  • Loading branch information
cmp0xff committed Jul 27, 2024
1 parent 9b4bacb commit 6d28912
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions tests/test_models/test_harmonic_oscillator_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from itertools import chain, product
from typing import Iterable, Mapping, Sequence

import numpy as np
import pytest

from hamilflow.models.harmonic_oscillator_chain import HarmonicOscillatorsChain

_wave_params: list[tuple[tuple[int, int], ...]] = [
((0, 0),),
((1, 0),),
((0, 1), (0, 1)),
]
_possible_wave_modes: list[dict[str, tuple[int, int]]] = [
dict(zip(("amp", "phi"), param, strict=False)) for param in _wave_params
]


class TestHarmonicOscillatorChain:
@pytest.fixture(params=(1, 2))
def omega(self, request: pytest.FixtureRequest) -> int:
return request.param

@pytest.fixture(params=((0, 0), (0, 1), (1, 0), (1, 1)))
def free_mode(self, request: pytest.FixtureRequest) -> dict[str, int]:
return dict(zip(("x0", "v0"), request.param))

@pytest.fixture(
params=chain.from_iterable(
product(_possible_wave_modes, repeat=r) for r in range(2)
)
)
def wave_modes(self, request: pytest.FixtureRequest) -> list[dict[str, int]]:
return request.param

@pytest.fixture(params=(False, True))
def odd_dof(self, request: pytest.FixtureRequest) -> bool:
return request.param

@pytest.fixture(params=(0, 1, (0, 1)))
def times(self, request: pytest.FixtureRequest) -> int | tuple[int]:
return request.param

def test_init(
self,
omega: int,
free_mode: Mapping[str, int],
wave_modes: Iterable[dict[str, int]],
odd_dof: bool,
) -> None:
assert HarmonicOscillatorsChain(omega, [free_mode, *wave_modes], odd_dof)

def test_real(
self,
omega: int,
free_mode: Mapping[str, int],
wave_modes: Iterable[dict[str, int]],
odd_dof: bool,
times: int | Sequence[int],
) -> None:
hoc = HarmonicOscillatorsChain(omega, [free_mode, *wave_modes], odd_dof)
original_zs, _ = hoc._z(times)
assert np.all(original_zs.imag == 0.0)

0 comments on commit 6d28912

Please sign in to comment.