Skip to content

Commit

Permalink
feat(test): #76 unit test for tau of u
Browse files Browse the repository at this point in the history
  • Loading branch information
cmp0xff committed Aug 15, 2024
1 parent 9db24a0 commit 54f3a2b
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def tau_of_u_hyperbolic(ecc: float, u: "npt.ArrayLike") -> "npt.ArrayLike":
)


def tau_of_u_prime(e: float, u: "npt.ArrayLike") -> "npt.ArrayLike":
return 1 / (1 + u) ** 2 / np.sqrt(e**2 - u**2)
def tau_of_u_prime(ecc: float, u: "npt.ArrayLike") -> "npt.ArrayLike":
return -1 / (1 + u) ** 2 / np.sqrt(ecc**2 - u**2)


def solve_u_of_tau(
Expand Down
2 changes: 1 addition & 1 deletion hamilflow/models/kepler_problem/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from pydantic import BaseModel, Field, field_validator

from .math import (
from .dynamics import (
acos_with_shift,
solve_u_of_tau,
tau_of_u_elliptic,
Expand Down
44 changes: 44 additions & 0 deletions tests/models/kepler_problem/test_dynamics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from functools import partial
from typing import TYPE_CHECKING

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
from scipy.integrate import quad

from hamilflow.models.kepler_problem.dynamics import (
tau_of_u_elliptic,
tau_of_u_hyperbolic,
tau_of_u_parabolic,
tau_of_u_prime,
)

if TYPE_CHECKING:
from numpy import typing as npt

_EPS = 0.05


@pytest.mark.parametrize("ecc", [1 / 3, 1 / 2, 5 / 7, 1.0, 12 / 11, 27 / 13])
def test_tau_of_u(ecc: float) -> None:
def integrand(u: float) -> float:
return tau_of_u_prime(ecc, u)

if 0 < ecc < 1:
tau_of_u = tau_of_u_elliptic
cosqr = 1 - ecc**2
const = -(ecc + np.arcsin(ecc) / np.sqrt(cosqr)) / cosqr
elif ecc == 1:
tau_of_u = tau_of_u_parabolic
const = 2 / 3
elif ecc > 1:
tau_of_u = tau_of_u_hyperbolic
cosqr = ecc**2 - 1
const = (ecc - np.arccosh(ecc) / np.sqrt(cosqr)) / cosqr
else:
raise ValueError(f"Expected ecc > 0, got {ecc}")

u_s = np.linspace(max(-1, -ecc) + _EPS, ecc - _EPS, 5)
rets = [quad(integrand, 0, u) for u in u_s]
integrals = np.array([ret[0] for ret in rets]) + const
assert_array_almost_equal(integrals, tau_of_u(ecc, u_s))

0 comments on commit 54f3a2b

Please sign in to comment.