Skip to content

Commit

Permalink
refactor: add HermiteInterpolator
Browse files Browse the repository at this point in the history
  • Loading branch information
SoulMelody committed Jan 25, 2025
1 parent 26aaa1c commit cb8aaff
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 96 deletions.
47 changes: 0 additions & 47 deletions libresvip/plugins/acep/ace_curve_utils.py

This file was deleted.

3 changes: 0 additions & 3 deletions libresvip/plugins/acep/ace_studio_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,6 @@ def parse_track(self, ace_track: AcepTrack) -> Optional[Track]:
ace_note_list.extend(ace_notes)

def merge_curves(src: AcepParamCurveList, dst: AcepParamCurveList) -> None:
for curve in src.root:
if curve.curve_type == "anchor":
curve.points2values()
ace_curves = [
curve
for curve in src.root
Expand Down
40 changes: 18 additions & 22 deletions libresvip/plugins/acep/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
NamedTuple,
Optional,
Union,
cast,
)

from more_itertools import batched, minmax
Expand All @@ -30,9 +31,8 @@
from libresvip.model.base import BaseModel
from libresvip.model.point import PointList
from libresvip.utils.audio import audio_path_validator
from libresvip.utils.music_math import linear_interpolation
from libresvip.utils.music_math import HermiteInterpolator

from .ace_curve_utils import interpolate_hermite
from .enums import AcepLyricsLanguage
from .singers import DEFAULT_SEED, DEFAULT_SINGER, DEFAULT_SINGER_ID

Expand Down Expand Up @@ -69,26 +69,22 @@ def serialize_points(
) -> list[float]:
return list(chain.from_iterable(points.root))

def points2values(self) -> None:
if self.curve_type == "anchor" and self.points is not None:
if len(self.points.root) > 2:
self.offset = math.floor(self.points.root[0].pos)
self.values = interpolate_hermite(
[point.pos for point in self.points.root],
[point.value for point in self.points.root],
list(
range(
self.offset,
math.ceil(self.points.root[-1].pos) + 1,
)
),
)
elif len(self.points.root) == 2:
self.offset = math.floor(self.points.root[0].pos)
self.values = [
linear_interpolation(pos, self.points.root[0], self.points.root[-1])
for pos in range(self.offset, math.ceil(self.points.root[-1].pos) + 1)
]
@model_validator(mode="after")
def points2values(self) -> Self:
if self.curve_type == "anchor" and self.points is not None and len(self.points.root):
interpolator = HermiteInterpolator(
cast(list[tuple[float, float]], self.points.root),
)
self.offset = math.floor(self.points.root[0].pos)
self.values = interpolator.interpolate(
list(
range(
self.offset,
math.ceil(self.points.root[-1].pos) + 1,
)
),
)
return self

def transform(self, value_transform: Callable[[float], float]) -> AcepParamCurve:
return self.model_copy(
Expand Down
59 changes: 35 additions & 24 deletions libresvip/plugins/tlp/tunelab_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import math
import operator
from typing import cast

import more_itertools
import portion
Expand All @@ -20,7 +21,7 @@
Track,
)
from libresvip.model.point import Point
from libresvip.utils.music_math import db_to_float, ratio_to_db
from libresvip.utils.music_math import HermiteInterpolator, db_to_float, ratio_to_db
from libresvip.utils.translation import gettext_lazy as _

from .model import (
Expand Down Expand Up @@ -208,34 +209,44 @@ def parse_pitch(
) -> list[Point]:
points: list[Point] = [Point.start_point()]
for pitch_part in pitch:
for is_first, is_last, tlp_point in more_itertools.mark_ends(pitch_part.root):
pitch_pos = int(tlp_point.pos) + offset
if is_first:
points.append(
Point(
x=pitch_pos + self.first_bar_length,
y=-100,
)
)
pitch_secs = self.synchronizer.get_actual_secs_from_ticks(pitch_pos)
pitch_value = tlp_point.value
if math.isnan(pitch_value):
for anchor_group in more_itertools.split_at(
pitch_part.root, lambda x: math.isnan(x.value)
):
if len(anchor_group) < 2:
continue
if (vibrato_value := vibrato_base_interval_dict.get(pitch_secs)) is not None:
vibrato_value *= vibrato_envelope_interval_dict.get(pitch_secs, 1)
pitch_value += vibrato_value
points.append(
Point(
x=pitch_pos + self.first_bar_length,
y=round(pitch_value * 100),
)
interpolator = HermiteInterpolator(
points=cast(list[tuple[float, float]], anchor_group)
)
xs = list(
more_itertools.numeric_range(anchor_group[0].pos, anchor_group[-1].pos + 1, 5)
)
if is_last:
ys = interpolator.interpolate(xs)
for is_first, is_last, i in more_itertools.mark_ends(range(len(xs))):
pitch_pos = int(xs[i]) + offset
if is_first:
points.append(
Point(
x=pitch_pos + self.first_bar_length,
y=-100,
)
)
pitch_secs = self.synchronizer.get_actual_secs_from_ticks(pitch_pos)
pitch_value = ys[i]
if (vibrato_value := vibrato_base_interval_dict.get(pitch_secs)) is not None:
vibrato_value *= vibrato_envelope_interval_dict.get(pitch_secs, 1)
pitch_value += vibrato_value
points.append(
Point(
x=points[-1].x,
y=-100,
x=pitch_pos + self.first_bar_length,
y=round(pitch_value * 100),
)
)
if is_last:
points.append(
Point(
x=points[-1].x,
y=-100,
)
)
points.append(Point.end_point())
return points
59 changes: 59 additions & 0 deletions libresvip/utils/music_math.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import functools
import math
import re
Expand Down Expand Up @@ -187,6 +188,64 @@ def _inner_interpolate(
)


@dataclasses.dataclass
class HermiteInterpolator:
# from https://github.com/LiuYunPlayer/TuneLab/blob/master/TuneLab.Base/Science/HermiteInterpolation.cs

points: list[tuple[float, float]]

@staticmethod
def f1(t1: float, t2: float) -> float:
return (1 + 2 * t1) * t2**2

@staticmethod
def f3(t: float, d: float) -> float:
return t**2 * d

@staticmethod
def slope(p1: tuple[float, float], p2: tuple[float, float]) -> float:
return (p2[1] - p1[1]) / (p2[0] - p1[0])

def slope_at(self, point_index: int) -> float:
if point_index in [0, len(self.points) - 1]:
return 0
point = self.points[point_index]
last_k = self.slope(point, self.points[point_index - 1])
next_k = self.slope(point, self.points[point_index + 1])
kk = last_k * next_k
return 0 if kk <= 0 else 2 / (1 / last_k + 1 / next_k)

def interpolate(self, xs: list[float]) -> list[float]:
if len(self.points) < 2:
return [self.points[0][1] if len(self.points) == 1 else 0] * len(xs)
elif len(self.points) == 2:
return [linear_interpolation(x, self.points[0], self.points[1]) for x in xs]
point_index = 0
ys = []
for x in xs:
while point_index < len(self.points) and self.points[point_index][0] < x:
point_index += 1
if point_index == 0:
ys.append(self.points[0][1])
elif point_index == len(self.points):
ys.append(self.points[-1][1])
last_point = self.points[point_index - 1]
last_delta = self.slope_at(point_index - 1)
next_point = self.points[point_index]
next_delta = self.slope_at(point_index)
delta_1 = x - last_point[0]
delta_2 = x - next_point[0]
t1 = delta_1 / (next_point[0] - last_point[0])
t2 = delta_2 / (last_point[0] - next_point[0])
ys.append(
self.f1(t1, t2) * last_point[1]
+ self.f1(t2, t1) * next_point[1]
+ self.f3(t2, delta_1) * last_delta
+ self.f3(t1, delta_2) * next_delta
)
return ys


def db_to_float(db: float, using_amplitude: bool = True) -> float:
"""
Converts the input db to a float, which represents the equivalent
Expand Down

0 comments on commit cb8aaff

Please sign in to comment.