Skip to content

Commit

Permalink
feat(model): #63 truely complex simple harmonic oscillators (#64)
Browse files Browse the repository at this point in the history
* feat: #63 make real oscillators real again

* feat: #63 make complex oscillators finally complex

* chore: #63 typing, import

* feat(refactor): #63 private communication with @emptymalei

* chore: #63 typing, import

* chore(typing): #63

* fix: #63 edge cases

* fix(comment): #63 #64 (comment)
  • Loading branch information
cmp0xff authored Jul 27, 2024
1 parent 1347914 commit a783c4c
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 57 deletions.
9 changes: 5 additions & 4 deletions hamilflow/models/brownian_motion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import cached_property
from typing import Mapping

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -33,8 +34,8 @@ class BrownianMotionSystem(BaseModel):
:cvar delta_t: time granunality of the motion
"""

sigma: float = Field(ge=0)
delta_t: float = Field(ge=0, default=1.0)
sigma: float = Field(ge=0.0)
delta_t: float = Field(ge=0.0, default=1.0)

@computed_field # type: ignore[misc]
@cached_property
Expand Down Expand Up @@ -141,8 +142,8 @@ class BrownianMotion:

def __init__(
self,
system: dict[str, float],
initial_condition: dict[str, float] | None = None,
system: Mapping[str, float],
initial_condition: Mapping[str, float] | None = None,
):
initial_condition = initial_condition or {}
self.system = BrownianMotionSystem.model_validate(system)
Expand Down
File renamed without changes.
124 changes: 83 additions & 41 deletions hamilflow/models/harmonic_oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,18 @@
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator

try:
from typing import Self
except ImportError:
from typing_extensions import Self
from pydantic import BaseModel, Field, computed_field, field_validator


class HarmonicOscillatorSystem(BaseModel):
"""The params for the harmonic oscillator
:cvar omega: angular frequency of the harmonic oscillator
:cvar zeta: damping ratio
:cvar real: use real solution (only supported for the undamped case)
"""

omega: float
zeta: float = 0.0

real: bool = Field(default=True)
omega: float = Field()
zeta: float = Field(default=0.0)

@computed_field # type: ignore[misc]
@cached_property
Expand Down Expand Up @@ -61,13 +53,6 @@ def check_zeta_non_negative(cls, v: float) -> float:

return v

@model_validator(mode="after")
def check_real_zeta(self) -> Self:
if not self.real and self.zeta != 0.0:
raise NotImplementedError("real = False only implemented for zeta = 0.0")

return self


class HarmonicOscillatorIC(BaseModel):
"""The initial condition for a harmonic oscillator
Expand All @@ -77,9 +62,9 @@ class HarmonicOscillatorIC(BaseModel):
:cvar phi: initial phase
"""

x0: float = 1.0
v0: float = 0.0
phi: float = 0.0
x0: float = Field(default=1.0)
v0: float = Field(default=0.0)
phi: float = Field(default=0.0)


class HarmonicOscillatorBase(ABC):
Expand All @@ -92,15 +77,15 @@ class HarmonicOscillatorBase(ABC):

def __init__(
self,
system: Mapping[str, float | int | bool],
system: Mapping[str, float | int],
initial_condition: Mapping[str, float | int] | None = None,
) -> None:
initial_condition = initial_condition or {}
self.system = HarmonicOscillatorSystem.model_validate(system)
self.initial_condition = HarmonicOscillatorIC.model_validate(initial_condition)

@cached_property
def definition(self) -> dict[str, dict[str, float | int | bool]]:
def definition(self) -> dict[str, dict[str, float | int]]:
"""model params and initial conditions defined as a dictionary."""
return {
"system": self.system.model_dump(),
Expand Down Expand Up @@ -144,17 +129,13 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase):
The mass behaves like a simple harmonic oscillator.
In general, the solution to a real simple harmonic oscillator is
In general, the solution to a simple harmonic oscillator is
$$
x(t) = A \cos(\omega t + \phi),
$$
where $\omega$ is the angular frequency, $\phi$ is the initial phase, and $A$ is the amplitude.
The complex solution is
$$
x(t) = A \exp(-\mathbb{i} (\omega t + \phi)).
$$
To use this generator,
Expand All @@ -172,7 +153,7 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase):

def __init__(
self,
system: Mapping[str, float | int | bool],
system: Mapping[str, float | int],
initial_condition: Mapping[str, float | int] | None = None,
) -> None:
super().__init__(system, initial_condition)
Expand All @@ -181,24 +162,15 @@ def __init__(
f"System is not a Simple Harmonic Oscillator: {self.system}"
)

def _f(self, phase: float | int | Sequence[float | int]) -> np.ndarray:
np_phase = np.array(phase, copy=False)
return np.cos(np_phase) if self.system.real else np.exp(-1j * np_phase)

def _x(self, t: float | int | Sequence[float | int]) -> np.ndarray:
r"""Solution to simple harmonic oscillators:
$$
x(t) = x_0 \cos(\omega t + \phi)
x(t) = x_0 \cos(\omega t + \phi).
$$
if real, or
$$
x(t) = x_0 \exp(-\mathbb{i} (\omega t + \phi))
$$
if not real.
"""
return self.initial_condition.x0 * self._f(
self.system.omega * t + self.initial_condition.phi
return self.initial_condition.x0 * np.cos(
self.system.omega * np.array(t, copy=False) + self.initial_condition.phi
)


Expand Down Expand Up @@ -334,6 +306,7 @@ def _x_over_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike:

def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike:
r"""Solution to damped harmonic oscillators."""
t = np.array(t, copy=False)
if self.system.type == "under_damped":
x = self._x_under_damped(t)
elif self.system.type == "over_damped":
Expand All @@ -346,3 +319,72 @@ def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike:
)

return x


class ComplexSimpleHarmonicOscillatorIC(BaseModel):
"""The initial condition for a complex harmonic oscillator
:cvar x0: the initial displacements
:cvar phi: initial phases
"""

x0: tuple[float | int, float | int] = Field()
phi: tuple[float | int, float | int] = Field(default=(0, 0))


class ComplexSimpleHarmonicOscillator:
r"""Generate time series data for a complex simple harmonic oscillator.
:param system: all the params that defines the complex harmonic oscillator.
:param initial_condition: the initial condition of the complex harmonic oscillator.
"""

def __init__(
self,
system: Mapping[str, float | int],
initial_condition: Mapping[str, tuple[float | int, float | int]],
) -> None:
self.system = HarmonicOscillatorSystem.model_validate(system)
self.initial_condition = ComplexSimpleHarmonicOscillatorIC.model_validate(
initial_condition
)
if self.system.type != "simple":
raise ValueError(
f"System is not a Simple Harmonic Oscillator: {self.system}"
)

@cached_property
def definition(
self,
) -> dict[str, dict[str, float | int | tuple[float | int, float | int]]]:
"""model params and initial conditions defined as a dictionary."""

return dict(
system=self.system.model_dump(),
initial_condition=self.initial_condition.model_dump(),
)

def _z(self, t: float | int | Sequence[float | int]) -> ArrayLike:
r"""Solution to complex simple harmonic oscillators:
$$
x(t) = x_+ \exp(-\mathbb{i} (\omega t + \phi_+)) + x_- \exp(+\mathbb{i} (\omega t + \phi_-)).
$$
"""
t = np.array(t, copy=False)
omega = self.system.omega
x0, phi = self.initial_condition.x0, self.initial_condition.phi
phases = -omega * t - phi[0], omega * t + phi[1]
return x0[0] * np.exp(1j * phases[0]) + x0[1] * np.exp(1j * phases[1])

def __call__(self, t: float | int | Sequence[float | int]) -> pd.DataFrame:
"""Generate time series data for the harmonic oscillator.
Returns a list of floats representing the displacement at each time.
:param t: time(s).
"""
t = [t] if not isinstance(t, Sequence) else t
data = self._z(t)

return pd.DataFrame({"t": t, "z": data})
2 changes: 1 addition & 1 deletion hamilflow/models/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class PendulumSystem(BaseModel):
parameter
"""

omega0: float = Field(gt=0, frozen=True)
omega0: float = Field(gt=0.0, frozen=True)


class PendulumIC(BaseModel):
Expand Down
File renamed without changes.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal
from pydantic import ValidationError

from hamilflow.models.discrete.d0.free_particle import FreeParticle, FreeParticleIC
from hamilflow.models.free_particle import FreeParticle, FreeParticleIC


class TestFreeParticleIC:
Expand Down
51 changes: 42 additions & 9 deletions tests/test_models/test_harmonic_oscillator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Mapping, Sequence

import numpy as np
import pandas as pd
import pytest
from pydantic import ValidationError
from numpy.testing import assert_array_equal

from hamilflow.models.harmonic_oscillator import (
ComplexSimpleHarmonicOscillator,
ComplexSimpleHarmonicOscillatorIC,
DampedHarmonicOscillator,
HarmonicOscillatorSystem,
SimpleHarmonicOscillator,
Expand Down Expand Up @@ -139,13 +144,41 @@ def test_criticaldamped_harmonic_oscillator(omega, zeta, expected):
pd.testing.assert_frame_equal(df, pd.DataFrame(expected))


class TestHarmonicOscillatorSystem:
@pytest.mark.parametrize("omega", [-1, 1])
def test_complex(self, omega: int) -> None:
HarmonicOscillatorSystem(omega=omega, real=False)
class TestComplexHarmonicOscillatorIC:
@pytest.mark.parametrize("kwargs", [dict(x0=(1, 2), phi=(2, 3)), dict(x0=(1, 2))])
def test_ic(self, kwargs: Mapping[str, tuple[int, int]]) -> None:
assert ComplexSimpleHarmonicOscillatorIC(**kwargs)


class TestComplexHarmonicOscillator:
def test_complex(self) -> None:
assert ComplexSimpleHarmonicOscillator(
dict(omega=3), dict(x0=(1, 2), phi=(2, 3))
)

@pytest.mark.parametrize("omega", [-1, 1])
@pytest.mark.parametrize("zeta", [0.5, 1.0, 1.5])
def test_raise_complex(self, omega: int, zeta: float) -> None:
with pytest.raises(NotImplementedError):
HarmonicOscillatorSystem(omega=omega, zeta=zeta, real=False)
def test_raise(self, zeta: float) -> None:
with pytest.raises(ValueError):
ComplexSimpleHarmonicOscillator(
dict(omega=3, zeta=zeta), dict(x0=(2, 3), phi=(3, 4))
)

@pytest.fixture(params=(1, (1,), [1, 2], np.array([2, 3, 5, 7, 11])))
def times(self, request: pytest.FixtureRequest) -> int | Sequence[int]:
return request.param

@pytest.mark.parametrize("omega", [3, 5])
@pytest.mark.parametrize("x0", [2, 4])
@pytest.mark.parametrize("phi", [1, 6])
def test_degenerate_real(
self, omega: int, x0: int, phi: int, times: int | Sequence[int]
) -> None:
csho = ComplexSimpleHarmonicOscillator(
dict(omega=omega), dict(x0=(x0, x0), phi=(phi, phi))
)
sho = SimpleHarmonicOscillator(dict(omega=omega), dict(x0=2 * x0, phi=phi))
z = csho._z(times)
x = sho._x(times)

assert np.all(z.imag == 0.0)
assert_array_equal(z.real, x)

0 comments on commit a783c4c

Please sign in to comment.