Skip to content

Commit bd54f91

Browse files
committed
Implemented solving.
1 parent c21de94 commit bd54f91

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

chspy/_chspy.py

+99
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,53 @@ def extrema_from_anchors(anchors,beginning=None,end=None,target=None):
269269

270270
return extrema
271271

272+
def solve_from_anchors(anchors,i,value,beginning=None,end=None):
273+
"""
274+
Finds the times at which a component of the Hermite interpolant for the anchors assumes a given value and the derivatives at those points (allowing to distinguish upwards and downwards threshold crossings).
275+
276+
Parameters
277+
----------
278+
i : integer
279+
The index of the component.
280+
value : float
281+
The value that shall be assumed
282+
beginning : float or `None`
283+
Beginning of the time interval for which positions are returned. If `None`, the time of the first anchor is used.
284+
end : float or `None`
285+
End of the time interval for which positions are returned. If `None`, the time of the last anchor is used.
286+
287+
Returns
288+
-------
289+
positions : list of pairs of floats
290+
Each pair consists of a time where `value` is assumed and the derivative (of `component`) at that time.
291+
"""
292+
293+
q = (anchors[1].time-anchors[0].time)
294+
retransform = lambda x: q*x+anchors[0].time
295+
a = anchors[0].state[i]
296+
b = anchors[0].diff[i] * q
297+
c = anchors[1].state[i]
298+
d = anchors[1].diff[i] * q
299+
300+
left_x = 0 if beginning is None else (beginning-anchors[0].time)/q
301+
right_x = 1 if end is None else (end -anchors[0].time)/q
302+
303+
candidates = np.roots([
304+
2*a + b - 2*c + d,
305+
-3*a - 2*b + 3*c - d,
306+
b,
307+
a - value,
308+
])
309+
310+
solutions = sorted(
311+
retransform(candidate.real)
312+
for candidate in candidates
313+
if np.isreal(candidate) and left_x<=candidate<=right_x
314+
)
315+
316+
return [ (t,interpolate_diff(t,i,anchors)) for t in solutions ]
317+
318+
272319
class CubicHermiteSpline(list):
273320
"""
274321
Class for a cubic Hermite Spline of one variable (time) with `n` values. This behaves like a list with additional functionalities and checks. Note that the times of the anchors must always be in ascending order.
@@ -590,6 +637,9 @@ def extrema(self,beginning=None,end=None):
590637
beginning = self[ 0].time if beginning is None else beginning
591638
end = self[-1].time if end is None else end
592639

640+
if not self[0].time <= beginning < end <= self[-1].time:
641+
raise ValueError("Beginning and end must in the time interval spanned by the anchors.")
642+
593643
extrema = Extrema(self.n)
594644

595645
for i in range(self.last_index_before(beginning),len(self)-1):
@@ -604,6 +654,55 @@ def extrema(self,beginning=None,end=None):
604654
)
605655

606656
return extrema
657+
658+
def solve(self,i,value,beginning=None,end=None):
659+
"""
660+
Finds the times at which a component of the spline assumes a given value and the derivatives at those points (allowing to distinguish upwards and downwards threshold crossings). This will not work well if the spline is constantly at the given value for some interval.
661+
662+
Parameters
663+
----------
664+
i : integer
665+
The index of the component.
666+
value : float
667+
The value that shall be assumed
668+
beginning : float or `None`
669+
Beginning of the time interval for which solutions are returned. If `None`, the time of the first anchor is used.
670+
end : float or `None`
671+
End of the time interval for which solutions are returned. If `None`, the time of the last anchor is used.
672+
673+
Returns
674+
-------
675+
positions : list of pairs of floats
676+
Each pair consists of a time where `value` is assumed and the derivative (of `component`) at that time.
677+
"""
678+
679+
beginning = self[ 0].time if beginning is None else beginning
680+
end = self[-1].time if end is None else end
681+
682+
if not self[0].time <= beginning < end <= self[-1].time:
683+
raise ValueError("Beginning and end must in the time interval spanned by the anchors.")
684+
685+
extrema = Extrema(self.n)
686+
687+
sols = []
688+
689+
for j in range(self.last_index_before(beginning),len(self)-1):
690+
if self[j].time>end:
691+
break
692+
693+
new_sols = solve_from_anchors(
694+
anchors = ( self[j], self[j+1] ),
695+
i = i,
696+
value = value,
697+
beginning = max( beginning, self[j ].time ),
698+
end = min( end , self[j+1].time ),
699+
)
700+
701+
if sols and new_sols and sols[-1][0]==new_sols[0][0]:
702+
del new_sols[0]
703+
sols.extend(new_sols)
704+
705+
return sols
607706

608707
def norm(self, delay, indices):
609708
"""

tests/test_hermite_spline.py

+24
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,30 @@ def test_multiple_anchors(self):
247247
assert_allclose( result.arg_min, times[np.argmin(values,axis=0)], atol=1e-3 )
248248
assert_allclose( result.arg_max, times[np.argmax(values,axis=0)], atol=1e-3 )
249249

250+
class TestSolving(unittest.TestCase):
251+
def test_random_function(self):
252+
roots = np.sort(np.random.normal(size=5))
253+
value = np.random.normal()
254+
t = symengine.Symbol("t")
255+
function = np.prod([t-root for root in roots]) + value
256+
257+
i = 1
258+
spline = CubicHermiteSpline(n=3)
259+
spline.from_function(
260+
[10,function,10],
261+
times_of_interest = ( min(roots)-0.01, max(roots)+0.01 ),
262+
max_anchors = 1000,
263+
tol = 7,
264+
)
265+
266+
solutions = spline.solve(i=i,value=value)
267+
sol_times = [ sol[0] for sol in solutions ]
268+
assert_allclose( spline.get_state(sol_times)[:,i], value )
269+
assert_allclose( [sol[0] for sol in solutions], roots, atol=1e-3 )
270+
for time,diff in solutions:
271+
true_diff = float(function.diff(t).subs({t:time}))
272+
self.assertAlmostEqual( true_diff, diff, places=5 )
273+
250274
class TimeSeriesTest(unittest.TestCase):
251275
def test_comparison(self):
252276
interval = (-3,10)

0 commit comments

Comments
 (0)