Skip to content

Commit

Permalink
chore(typing): #69
Browse files Browse the repository at this point in the history
  • Loading branch information
cmp0xff committed Jul 28, 2024
1 parent 4c1897e commit 075efb1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 19 deletions.
29 changes: 13 additions & 16 deletions hamilflow/models/harmonic_oscillator_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from scipy.fft import ifft

from .free_particle import FreeParticle
Expand Down Expand Up @@ -41,35 +42,31 @@ class HarmonicOscillatorsChain:

def __init__(
self,
omega: float | int,
initial_conditions: Sequence[
Mapping[str, float | int | tuple[float | int, float | int]]
],
omega: float,
initial_conditions: Sequence[Mapping[str, float | tuple[float, float]]],
odd_dof: bool,
) -> None:
self.omega = omega
self.n_independant_csho_dof = len(initial_conditions) - 1
self.odd_dof = odd_dof

self.free_mode = FreeParticle(
cast(Mapping[str, float | int], initial_conditions[0])
)
self.free_mode = FreeParticle(cast(Mapping[str, float], initial_conditions[0]))

r_wave_modes_ic = initial_conditions[1:]
self.independent_csho_modes = [
self._sho_factory(
k,
cast(tuple[float | int, float | int], ic["amp"]),
cast(tuple[float | int, float | int] | None, ic.get("phi")),
cast(tuple[float, float], ic["amp"]),
cast(tuple[float, float] | None, ic.get("phi")),
)
for k, ic in enumerate(r_wave_modes_ic, 1)
]

def _sho_factory(
self,
k: int,
amp: tuple[float | int, float | int],
phi: tuple[float | int, float | int] | None = None,
amp: tuple[float, float],
phi: tuple[float, float] | None = None,
) -> ComplexSimpleHarmonicOscillator:
return ComplexSimpleHarmonicOscillator(
dict(
Expand All @@ -89,8 +86,8 @@ def definition(
str,
float
| int
| dict[str, dict[str, int | float | list[int | float]]]
| list[dict[str, dict[str, float | int | tuple[float | int, float | int]]]],
| dict[str, dict[str, float | list[float]]]
| list[dict[str, dict[str, float | tuple[float, float]]]],
]:
"""model params and initial conditions defined as a dictionary."""
return dict(
Expand All @@ -103,7 +100,7 @@ def definition(
)

def _z(
self, t: float | int | Sequence[float | int]
self, t: "Sequence[float] | ArrayLike[float]"
) -> tuple[np.ndarray, np.ndarray]:
t = np.array(t, copy=False).reshape(-1)
all_travelling_waves = [self.free_mode._x(t).reshape(1, -1)]
Expand All @@ -127,13 +124,13 @@ def _z(
return original_zs, travelling_waves

def _x(
self, t: float | int | Sequence[float | int]
self, t: "Sequence[float] | ArrayLike[float]"
) -> tuple[np.ndarray, np.ndarray]:
original_xs, travelling_waves = self._z(t)

return np.real(original_xs), travelling_waves

def __call__(self, t: float | int | Sequence[float | int]) -> pd.DataFrame:
def __call__(self, t: "Sequence[float] | ArrayLike[float]") -> pd.DataFrame:
"""Generate time series data for the harmonic oscillator chain.
Returns float(s) representing the displacement at the given time(s).
Expand Down
6 changes: 3 additions & 3 deletions tests/test_models/test_harmonic_oscillator_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def free_mode(self, request: pytest.FixtureRequest) -> dict[str, int]:

@pytest.fixture(
params=chain.from_iterable(
product(_possible_wave_modes, repeat=r) for r in range(2)
product(_possible_wave_modes, repeat=r) for r in range(3)
)
)
def wave_modes(self, request: pytest.FixtureRequest) -> list[dict[str, int]]:
Expand All @@ -45,7 +45,7 @@ def test_init(
self,
omega: int,
free_mode: Mapping[str, int],
wave_modes: Iterable[dict[str, int]],
wave_modes: Iterable[Mapping[str, int]],
odd_dof: bool,
) -> None:
assert HarmonicOscillatorsChain(omega, [free_mode, *wave_modes], odd_dof)
Expand All @@ -54,7 +54,7 @@ def test_real(
self,
omega: int,
free_mode: Mapping[str, int],
wave_modes: Iterable[dict[str, int]],
wave_modes: Iterable[Mapping[str, int]],
odd_dof: bool,
times: int | Sequence[int],
) -> None:
Expand Down

0 comments on commit 075efb1

Please sign in to comment.