Skip to content

Commit

Permalink
feat: #36 migrate to the same place (private communication with @empt…
Browse files Browse the repository at this point in the history
  • Loading branch information
cmp0xff committed Jul 27, 2024
1 parent e634cf9 commit 9b4bacb
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 32 deletions.
1 change: 0 additions & 1 deletion hamilflow/models/discrete/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion hamilflow/models/discrete/d0/__init__.py

This file was deleted.

Empty file.
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from functools import cached_property
from typing import Any, Collection, Mapping, Sequence
from typing import Mapping, Sequence, cast

import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from pandas.api.types import is_scalar
from scipy.fft import ifft

from ...harmonic_oscillator import SimpleHarmonicOscillator
from ..d0.free_particle import FreeParticle
from .free_particle import FreeParticle
from .harmonic_oscillator import ComplexSimpleHarmonicOscillator


class HarmonicOscillatorsChain:
Expand Down Expand Up @@ -43,36 +41,46 @@ class HarmonicOscillatorsChain:

def __init__(
self,
omega: float,
independent_initial_conditions: Sequence[Mapping[str, float | int]],
omega: float | int,
initial_conditions: Sequence[
Mapping[str, float | int | tuple[float | int, float | int]]
],
odd_dof: bool,
) -> None:
self.omega = omega
self.n_independant_wave_dof = len(independent_initial_conditions) - 1
self.n_independant_csho_dof = len(initial_conditions) - 1
self.odd_dof = odd_dof

self.free_mode = FreeParticle(independent_initial_conditions[0])
self.free_mode = FreeParticle(
cast(Mapping[str, float | int], initial_conditions[0])
)

r_wave_modes_ic = independent_initial_conditions[1:]
self.r_wave_modes = [
self._sho_factory(k, ic["amp"], ic["phi"])
r_wave_modes_ic = initial_conditions[1:]
self.independent_csho_modes = [
self._sho_factory(
k,
cast(tuple[float | int, float | int], ic["amp"]),
cast(tuple[float | int, float | int] | None, ic.get("phi")),
)
for k, ic in enumerate(r_wave_modes_ic, 1)
]

def _sho_factory(
self, k: int, amp: float | int, phi: float | int
) -> SimpleHarmonicOscillator:
return SimpleHarmonicOscillator(
self,
k: int,
amp: tuple[float | int, float | int],
phi: tuple[float | int, float | int] | None = None,
) -> ComplexSimpleHarmonicOscillator:
return ComplexSimpleHarmonicOscillator(
dict(
omega=2 * self.omega * np.sin(np.pi * k / self.n_dof),
real=not self.odd_dof and k == self.n_independant_wave_dof,
),
dict(x0=amp, phi=phi),
dict(x0=amp) | (dict(phi=phi) if phi else {}),
)

@cached_property
def n_dof(self) -> int:
return self.n_independant_wave_dof * 2 + self.odd_dof
return self.n_independant_csho_dof * 2 + self.odd_dof

@cached_property
def definition(
Expand All @@ -82,29 +90,39 @@ def definition(
float
| int
| dict[str, dict[str, int | float | list[int | float]]]
| list[dict[str, dict[str, int | float | bool]]],
| list[dict[str, dict[str, float | int | tuple[float | int, float | int]]]],
]:
"""model params and initial conditions defined as a dictionary."""
return dict(
omega=self.omega,
n_dof=self.n_dof,
free_mode=self.free_mode.definition,
r_wave_modes=[rwm.definition for rwm in self.r_wave_modes],
independent_csho_modes=[
rwm.definition for rwm in self.independent_csho_modes
],
)

def _z(
self, t: float | int | Sequence[float | int]
) -> tuple[np.ndarray, np.ndarray]:
to_concat = [self.free_mode._x(t).reshape(1, -1)]

r_waves = np.array([rwm._x(t) for rwm in self.r_wave_modes], copy=False)
to_concat.extend(
(r_waves, r_waves[:, ::-1].conj())
if self.odd_dof
else (r_waves[:, :-1], r_waves[:, -1], r_waves[:, -1::-1].conj())
)

travelling_waves = np.concatenate(to_concat)
t = np.array(t, copy=False).reshape(-1)
all_travelling_waves = [self.free_mode._x(t).reshape(1, -1)]

if self.independent_csho_modes:
independent_cshos = np.array(
[o._z(t) for o in self.independent_csho_modes], copy=False
)
all_travelling_waves.extend(
(independent_cshos, independent_cshos[::-1].conj())
if self.odd_dof
else (
independent_cshos[:-1],
independent_cshos[[-1]],
independent_cshos[-1::-1].conj(),
)
)

travelling_waves = np.concatenate(all_travelling_waves)
original_zs = ifft(travelling_waves, axis=0, norm="ortho")
return original_zs, travelling_waves

Expand Down

0 comments on commit 9b4bacb

Please sign in to comment.