-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add smoothing spline baseline #299
base: main
Are you sure you want to change the base?
Changes from all commits
52b3da3
fdc862d
2af78a0
381fba8
14a4374
32c3e81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from lumicks.pylake.channel import Slice, Continuous | ||
from csaps import csaps | ||
from sklearn.model_selection import RepeatedKFold | ||
from sklearn.metrics import mean_squared_error | ||
|
||
|
||
def unique_sorted(trap_position, force): | ||
"""Sort and remove duplicates trap_position data to prepare for fit smoothing spline. | ||
Parameters | ||
---------- | ||
trap_position : lumicks.pylake.Slice | ||
Trap mirror position | ||
force : lumicks.pylake.Slice | ||
Force data | ||
""" | ||
|
||
x = trap_position.data | ||
u, c = np.unique(x, return_counts=True) | ||
m = np.isin(x, [u[c < 2]]) | ||
ind = np.argsort(x[m]) | ||
|
||
return x[m][ind], force.data[m][ind] | ||
|
||
|
||
def optimize_smoothing_factor( | ||
trap_position, | ||
force, | ||
smoothing_factors, | ||
n_repeats, | ||
plot_smoothingfactor_mse, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Considering that the function has |
||
): | ||
"""Find optimal smoothing factor by choosing smoothing factor with lowest mse on test data | ||
Parameters | ||
---------- | ||
trap_position : lumicks.pylake.Slice | ||
Trap mirror position data | ||
force : lumicks.pylake.Slice | ||
Force data | ||
smoothing_factors : np.array float | ||
Array of smoothing factors used for optimization fit, 0 <= smoothing_factor <= 1 | ||
n_repeats: int | ||
number of times to repeat cross validation | ||
plot_smoothingfactor_mse: bool | ||
plot mse on test data vs smoothing factors used for optimization | ||
""" | ||
|
||
mse_test_vals = np.zeros(len(smoothing_factors)) | ||
x_sorted, y_sorted = unique_sorted(trap_position, force) | ||
Comment on lines
+49
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One thing I find a bit surprising is that you choose to pass trap position and force to this function, despite already having |
||
for i, smooth in enumerate(smoothing_factors): | ||
mse_test_array = np.zeros(n_repeats * 2) | ||
|
||
rkf = RepeatedKFold(n_splits=2, n_repeats=n_repeats) | ||
for k, (train_index, test_index) in enumerate(rkf.split(x_sorted)): | ||
x_train, x_test = x_sorted[train_index], x_sorted[test_index] | ||
y_train, y_test = y_sorted[train_index], y_sorted[test_index] | ||
|
||
smoothing_result_train = csaps(x_train, y_train, smooth=smooth) | ||
f_test = smoothing_result_train(x_test) | ||
mse_test_array[k] = mean_squared_error(y_test, f_test) | ||
|
||
mse_test_vals[i] = np.mean(mse_test_array) | ||
if plot_smoothingfactor_mse: | ||
plot_mse_smoothing_factors(smoothing_factors, mse_test_vals) | ||
|
||
return smoothing_factors[np.argmin(mse_test_vals)] | ||
|
||
|
||
def plot_mse_smoothing_factors(smoothing_factors, mse_test_vals): | ||
plt.figure() | ||
plt.plot( | ||
np.log(1 - smoothing_factors), | ||
mse_test_vals, | ||
label=f"optimal s= {smoothing_factors[np.argmin(mse_test_vals)]:0.6f}", | ||
) | ||
plt.ylabel("mse test") | ||
plt.xticks(np.log(1 - smoothing_factors), smoothing_factors) | ||
plt.xlabel("smoothing factor") | ||
plt.legend() | ||
plt.show() | ||
|
||
|
||
class ForceBaseLine: | ||
def __init__(self, model, trap_data, force): | ||
"""Force baseline | ||
|
||
Parameters | ||
---------- | ||
model : callable | ||
Model which returns the baseline at specified points. | ||
trap_data : lumicks.pylake.Slice | ||
Trap mirror position data | ||
force : lumicks.pylake.Slice | ||
Force data | ||
""" | ||
self._model = model | ||
self._trap_data = trap_data | ||
self._force = force | ||
|
||
def valid_range(self): | ||
return (np.min(self._trap_data.data), np.max(self._trap_data.data)) | ||
|
||
def correct_data(self, force, trap_position): | ||
if not np.array_equal(force.timestamps, trap_position.timestamps): | ||
raise RuntimeError("Provided force and trap position timestamps should match") | ||
|
||
return Slice( | ||
Continuous( | ||
force.data - self._model(trap_position.data), | ||
force._src.start, | ||
force._src.dt, | ||
), | ||
labels={ | ||
"title": force.labels.get("title", "Baseline Corrected Force"), | ||
"y": "Baseline Corrected Force (pN)", | ||
}, | ||
calibration=force._calibration, | ||
) | ||
|
||
def plot(self): | ||
plt.scatter(self._trap_data.data, self._force.data, s=2) | ||
plt.plot(self._trap_data.data, self._model(self._trap_data.data), "k") | ||
plt.xlabel("Mirror position") | ||
plt.ylabel(self._force.labels["y"]) | ||
plt.title("Force baseline") | ||
|
||
def plot_residual(self): | ||
plt.scatter(self._trap_data.data, self._force.data - self._model(self._trap_data.data), s=2) | ||
plt.xlabel("Mirror position") | ||
plt.ylabel(f"Residual {self._force.labels['y']}") | ||
plt.title("Fit residual") | ||
|
||
@classmethod | ||
def polynomial_baseline(cls, trap_position, force, degree=7, downsampling_factor=None): | ||
"""Generate a polynomial baseline from data | ||
|
||
Parameters | ||
---------- | ||
trap_position : lumicks.pylake.Slice | ||
Trap mirror position data | ||
force : lumicks.pylake.Slice | ||
Force data | ||
degree : int | ||
Polynomial degree | ||
downsampling_factor : int | ||
Factor by which to downsample before baseline determination | ||
""" | ||
if not np.array_equal(force.timestamps, trap_position.timestamps): | ||
raise RuntimeError("Provided force and trap position timestamps should match") | ||
|
||
if downsampling_factor: | ||
trap_position, force = ( | ||
ch.downsampled_by(downsampling_factor) for ch in (trap_position, force) | ||
) | ||
|
||
model = np.poly1d(np.polyfit(trap_position.data, force.data, deg=degree)) | ||
return cls(model, trap_position, force)\ | ||
|
||
|
||
@classmethod | ||
def smoothingspline_baseline( | ||
cls, | ||
trap_position, | ||
force, | ||
smoothing_factor=None, | ||
downsampling_factor=None, | ||
smoothing_factors=np.array([0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999]), | ||
n_repeats=10, | ||
plot_smoothingfactor_mse=False, | ||
): | ||
"""Generate a smoothing spline baseline from data. | ||
Items of xdata in smoothing spline must satisfy: x1 < x2 < ... < xN, | ||
therefore the trap_position data is sorted and duplicates are removed | ||
Parameters | ||
---------- | ||
trap_position : lumicks.pylake.Slice | ||
Trap mirror position data | ||
force : lumicks.pylake.Slice | ||
Force data | ||
smoothing_factor : float | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder whether it would make sense to merge If the input is just a value or has only one element, you use it as fixed value and otherwise you perform the optimization. You can use |
||
Smoothing factor for smoothing spline, 0 <= smoothing_factor <= 1 | ||
downsampling_factor : int | ||
Factor by which to downsample before baseline determination | ||
smoothing_factors : np.array float | ||
Array of smoothing factors used for optimization fit, 0 <= smoothing_factor <= 1 | ||
n_repeats: int | ||
number of times to repeat cross validation | ||
plot_smoothingfactor_mse: bool | ||
plot mse on test data vs smoothing factors used for optimization | ||
""" | ||
if not np.array_equal(force.timestamps, trap_position.timestamps): | ||
raise RuntimeError( | ||
"Provided force and trap position timestamps should match" | ||
) | ||
|
||
if downsampling_factor: | ||
trap_position, force = ( | ||
ch.downsampled_by(downsampling_factor) for ch in (trap_position, force) | ||
) | ||
|
||
x_sorted, y_sorted = unique_sorted(trap_position, force) | ||
|
||
if smoothing_factor: | ||
model = csaps(x_sorted, y_sorted, smooth=smoothing_factor) | ||
else: | ||
smoothing_factor = optimize_smoothing_factor( | ||
trap_position, | ||
force, | ||
smoothing_factors=smoothing_factors, | ||
n_repeats=n_repeats, | ||
plot_smoothingfactor_mse=plot_smoothingfactor_mse, | ||
) | ||
model = csaps(x_sorted, y_sorted, smooth=smoothing_factor) | ||
|
||
return cls(model, trap_position, force) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One issue with calling the constructor like this is that you don't actually put the data that you used for fitting in now (you fitted to |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import warnings | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from lumicks.pylake.channel import Slice | ||
|
||
|
||
class DistanceCalibration: | ||
def __init__(self, trap_position, camera_distance, degree=1): | ||
"""Map the trap position to the camera tracking distance using a linear fit. | ||
|
||
Parameters | ||
---------- | ||
trap_position : lumicks.pylake.Slice | ||
Trap position. | ||
camera_distance : lumicks.pylake.Slice | ||
Camera distance as determined by Bluelake. | ||
NOTE: The distance data should already have the bead diameter subtracted by Bluelake. | ||
degree : int | ||
Polynomial degree. | ||
""" | ||
trap_position, camera_distance = trap_position.downsampled_like(camera_distance) | ||
mask = camera_distance.data != 0 | ||
missed_frames = np.sum(1 - mask) | ||
if missed_frames > 0: | ||
warnings.warn( | ||
RuntimeWarning( | ||
"There were frames with missing video tracking: " | ||
f"{missed_frames} data point(s) were omitted." | ||
) | ||
) | ||
self.position, self.distance = trap_position.data[mask], camera_distance.data[mask] | ||
coeffs = np.polyfit(self.position, self.distance, degree) | ||
self._model = np.poly1d(coeffs) | ||
|
||
def __call__(self, trap_position): | ||
return Slice( | ||
trap_position._src._with_data(self._model(trap_position.data)), | ||
labels={"title": "Piezo distance", "y": "Distance [um]"}, | ||
) | ||
|
||
def valid_range(self): | ||
return (np.min(self.position), np.max(self.position)) | ||
|
||
def __str__(self): | ||
powers = np.flip(np.arange(self._model.order + 1)) | ||
return "".join( | ||
f"{' + ' if coeff > 0 else ' - '}" | ||
f"{abs(coeff):.4f}" | ||
f"{'' if power == 0 else ' x' if power == 1 else f' x^{power}'}" | ||
for power, coeff in zip(powers, self._model.coeffs) | ||
).strip() | ||
|
||
def __repr__(self): | ||
return f"DistanceCalibration({str(self)})" | ||
|
||
def plot(self): | ||
"""Plot the calibration fit""" | ||
plt.scatter(self.position, self.distance, s=2, label="data") | ||
plt.plot(self.position, self._model(self.position), "k", label=f"${str(self)}$") | ||
plt.xlabel("Mirror position") | ||
plt.ylabel("Camera Distance [um]") | ||
plt.tight_layout() | ||
plt.legend() | ||
|
||
def plot_residual(self): | ||
"""Plot the residual of the calibration fit""" | ||
plt.scatter(self.position, self._model(self.position) - self.distance, s=2) | ||
plt.ylabel("Residual [um]") | ||
plt.xlabel("Mirror position") | ||
plt.tight_layout() | ||
plt.legend() | ||
|
||
@classmethod | ||
def from_file(cls, calibration_file, degree=1): | ||
"""Use a reference measurement to calibrate trap mirror position to bead-bead distance. | ||
|
||
Parameters | ||
---------- | ||
calibration_file : pylake.File | ||
degree : int | ||
Polynomial order. | ||
""" | ||
return cls(calibration_file["Trap position"]["1X"], calibration_file.distance1, degree) | ||
|
||
|
||
class PiezoTrackingCalibration: | ||
def __init__( | ||
self, | ||
trap_calibration, | ||
baseline_force1, | ||
baseline_force2, | ||
signs=(1, -1), | ||
): | ||
"""Set up piezo tracking calibration | ||
|
||
trap_calibration : pylake.DistanceCalibration | ||
Calibration from trap position to trap to trap distance. | ||
baseline_force1 : pylake.ForceBaseline | ||
Baseline for force1 | ||
baseline_force2 : pylake.ForceBaseline | ||
Baseline for force2 | ||
signs : tuple(float, float) | ||
Sign convention for forces (e.g. (1, -1) indicates that force2 is negative). | ||
""" | ||
if len(signs) != 2: | ||
raise ValueError( | ||
"Argument `signs` should be a tuple of two floats reflecting the sign for each " | ||
"channel." | ||
) | ||
for sign in signs: | ||
if abs(sign) != 1: | ||
raise ValueError("Each sign should be either -1 or 1.") | ||
|
||
self.trap_calibration = trap_calibration | ||
self.baseline_force1 = baseline_force1 | ||
self.baseline_force2 = baseline_force2 | ||
self._signs = signs | ||
|
||
def valid_range(self): | ||
"""Returns the mirror position range in which the piezo tracking is valid""" | ||
calibration_items = (self.trap_calibration, self.baseline_force1, self.baseline_force2) | ||
return np.min(np.stack([r.valid_range() for r in calibration_items]), axis=0) | ||
|
||
def piezo_track(self, trap_position, force1, force2, trim=True, downsampling_factor=None): | ||
"""Obtain piezo distance and baseline corrected forces | ||
|
||
Parameters | ||
---------- | ||
trap_position : pylake.channel.Slice | ||
Trap position. | ||
force1 : pylake.channel.Slice | ||
First force channel to use for piezo tracking. | ||
force2 : pylake.channel.Slice | ||
Second force channel to use for piezo tracking. | ||
trim : bool | ||
Trim regions outside the calibration range. | ||
downsampling_factor : Optional[int] | ||
Downsampling factor. | ||
""" | ||
if downsampling_factor: | ||
trap_position, force1, force2 = ( | ||
x.downsampled_by(downsampling_factor) for x in (trap_position, force1, force2) | ||
) | ||
|
||
trap_trap_dist = self.trap_calibration(trap_position) | ||
bead_displacements = 1e-3 * sum( | ||
sign * force / force.calibration[0]["kappa (pN/nm)"] | ||
for force, sign in zip((force1, force2), self._signs) | ||
) | ||
|
||
piezo_distance = trap_trap_dist - bead_displacements | ||
corrected_force1 = self.baseline_force1.correct_data(force1, trap_position) | ||
corrected_force2 = self.baseline_force2.correct_data(force2, trap_position) | ||
|
||
if trim: | ||
valid_range = self.valid_range() | ||
valid_mask = np.logical_and( | ||
valid_range[0] <= trap_position.data, trap_position.data <= valid_range[1] | ||
) | ||
piezo_distance, corrected_force1, corrected_force2 = ( | ||
piezo_distance[valid_mask], | ||
corrected_force1[valid_mask], | ||
corrected_force2[valid_mask], | ||
) | ||
|
||
return piezo_distance, corrected_force1, corrected_force2 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import pytest | ||
import numpy as np | ||
from lumicks.pylake.channel import Continuous, TimeSeries, Slice | ||
from lumicks.pylake.fitting.models import inverted_odijk | ||
from lumicks.pylake.calibration import ForceCalibration | ||
|
||
|
||
def reference_baseline(): | ||
return np.poly1d([22.58134638112967, -601.6397764983628, 4007.686647006411]) | ||
|
||
|
||
@pytest.fixture() | ||
def poly_baseline_data(): | ||
baseline_model = reference_baseline() | ||
|
||
# Baseline correction should be able to deal with non-equidistant arbitrarily sorted points | ||
# with duplicates. | ||
trap_position_baseline = np.hstack( | ||
( | ||
np.arange(13.35, 12.95, -0.0000025), | ||
np.arange(13.35, 13.25, -0.0000005), | ||
np.ones(100000) * 12.95, | ||
) | ||
) | ||
|
||
trap = Slice( | ||
Continuous(trap_position_baseline, 1573123558595351600, int(1e9 / 78125)), | ||
labels={"title": "Trap position", "y": "y"}, | ||
) | ||
force = Slice( | ||
Continuous( | ||
baseline_model(trap_position_baseline), | ||
1573123558595351600, | ||
int(1e9 / 78125), | ||
), | ||
labels={"title": "force", "y": "Force (pN)"}, | ||
) | ||
|
||
return trap, force | ||
|
||
|
||
@pytest.fixture() | ||
def piezo_tracking_test_data(poly_baseline_data): | ||
baseline = reference_baseline() | ||
baseline_trap_position, baseline_force = poly_baseline_data | ||
trap2_ref = 9.15 | ||
|
||
# Positional calibration data | ||
# The "true" camera distance is given by trap position - reference point. | ||
ds_factor = 10 | ||
distance_ds = baseline_trap_position.downsampled_by(ds_factor) - trap2_ref | ||
old_dt = baseline_trap_position.timestamps[1] - baseline_trap_position.timestamps[0] | ||
camera_dist = Slice( | ||
TimeSeries(distance_ds.data, distance_ds.timestamps + (ds_factor // 2) * old_dt) | ||
) | ||
|
||
# Tether experiment data | ||
sample_rate = 78 | ||
dt = int(1e9 / sample_rate) | ||
tether_length_um = np.hstack( | ||
(np.arange(0.65, 0.7, 0.08 / sample_rate), np.arange(0.7, 0.785, 0.02 / sample_rate)) | ||
) | ||
wlc_force = inverted_odijk("tether")( | ||
tether_length_um, {"tether/Lp": 60, "tether/Lc": 0.75, "tether/St": 1400, "kT": 4.11} | ||
) | ||
|
||
stiffness = 0.15 | ||
stiffness_um = stiffness * 1e3 # pN/um (0.15 pN/nm) | ||
|
||
""" | ||
If we assume that the baseline force leads to a real displacement, then our function for the | ||
trap position becomes implicit, since the displacement depends on the baseline which in turn | ||
depends on the trap position: | ||
trap_trap_distance = tether_length + 2 * bead_radius + 2 * displacement_um | ||
And displacement_um is given by (wlc_force + baseline(trap_position)) / stiffness | ||
So we solve the following to obtain the trap position: | ||
displacement = 2 * (wlc_force + baseline(trap_position)) / stiffness | ||
0 = tether_length + 2 * bead_radius + displacement - (trap_position - trap2_ref) | ||
""" | ||
from scipy.optimize import minimize_scalar | ||
|
||
bead_radius = 1 # 1 micron beads | ||
trap_position = [] | ||
for tether_dist, force in zip(tether_length_um, wlc_force): | ||
|
||
def implicit_trap_position_equation(x): | ||
trap_trap_dist = x - trap2_ref | ||
displacement = 2 * (force + baseline(x)) / stiffness_um | ||
return (tether_dist + 2 * bead_radius + displacement - trap_trap_dist) ** 2 | ||
|
||
trap_position.append(minimize_scalar(implicit_trap_position_equation, [12.95, 13.35]).x) | ||
|
||
trap_position = Slice(Continuous(np.array(trap_position), 0, dt)) | ||
|
||
# Add our baseline force (assumption is that the baseline force leads to a real displacement) | ||
force_pn = wlc_force + baseline(trap_position.data) | ||
|
||
calibration = ForceCalibration( | ||
"Stop time (ns)", [{"Stop time (ns)": 1, "kappa (pN/nm)": stiffness}] | ||
) | ||
|
||
force_1x = Slice(Continuous(force_pn, 0, dt), calibration=calibration) | ||
force_2x = Slice(Continuous(-force_pn, 0, dt), calibration=calibration) | ||
|
||
return { | ||
"correct_distance": tether_length_um, | ||
"trap_position": trap_position, | ||
"force_without_baseline": wlc_force, | ||
"force_1x": force_1x, | ||
"force_2x": force_2x, | ||
"baseline_trap_position": baseline_trap_position, | ||
"baseline_force": baseline_force, | ||
"camera_dist": camera_dist - 2 * bead_radius, # Bluelake subtracts the radii already | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import numpy as np | ||
from matplotlib.testing.decorators import cleanup | ||
from lumicks.pylake.piezo_tracking.baseline import ForceBaseLine | ||
|
||
|
||
def test_baseline(poly_baseline_data): | ||
trap, force = poly_baseline_data | ||
baseline = ForceBaseLine.polynomial_baseline(trap, force, degree=2) | ||
np.testing.assert_allclose(baseline.valid_range(), [12.95, 13.35]) | ||
np.testing.assert_allclose( | ||
baseline.correct_data(force, trap).data, np.zeros(force.data.shape), atol=1e-6 | ||
) | ||
|
||
|
||
def test_baseline_downsampled(poly_baseline_data): | ||
trap, force = poly_baseline_data | ||
|
||
baseline = ForceBaseLine.polynomial_baseline(trap, force, degree=2, downsampling_factor=500) | ||
np.testing.assert_allclose(baseline.valid_range(), [12.95, 13.349626]) | ||
np.testing.assert_allclose( | ||
baseline.correct_data(force, trap).data, np.zeros(force.data.shape), atol=1e-4 | ||
) | ||
np.testing.assert_allclose(baseline._trap_data, trap.downsampled_by(500)) | ||
np.testing.assert_allclose(baseline._force, force.downsampled_by(500)) | ||
|
||
|
||
@cleanup | ||
def test_baseline_plots(poly_baseline_data): | ||
trap, force = poly_baseline_data | ||
|
||
baseline = ForceBaseLine.polynomial_baseline(trap, force, degree=2) | ||
baseline.plot() | ||
baseline.plot_residual() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import pytest | ||
import numpy as np | ||
from matplotlib.testing.decorators import cleanup | ||
from lumicks.pylake.channel import Slice, Continuous, TimeSeries | ||
from lumicks.pylake.piezo_tracking.piezo_tracking import ( | ||
DistanceCalibration, | ||
PiezoTrackingCalibration, | ||
) | ||
from lumicks.pylake.piezo_tracking.baseline import ForceBaseLine | ||
|
||
|
||
def trap_pos_camera_distance(): | ||
dt = int(1e9 / 78125) | ||
trap_pos = Slice(Continuous(np.arange(2.0, 8.0, 0.001), dt=dt, start=1592916040906356300)) | ||
trap_pos_ds = trap_pos.downsampled_by(1000) | ||
camera_dist = Slice(TimeSeries(2 * trap_pos_ds.data + 1, trap_pos_ds.timestamps + 500 * dt)) | ||
|
||
return trap_pos, camera_dist | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"sampled_positions", | ||
[ | ||
Slice(Continuous(np.array([0, 0.5, 1.5, 2.5, 3.5]), dt=1, start=1)), | ||
Slice(TimeSeries(np.array([0, 0.5, 1.5, 2.5, 3.5]), np.array([0, 0.5, 1.5, 2.5, 3.5]))), | ||
], | ||
) | ||
def test_distance_calibration(sampled_positions): | ||
distance_calibration = DistanceCalibration(*trap_pos_camera_distance(), 1) | ||
np.testing.assert_allclose(distance_calibration.valid_range(), (3.4995, 7.4995)) | ||
assert str(distance_calibration) == "+ 2.0000 x + 1.0000" | ||
assert repr(distance_calibration) == "DistanceCalibration(+ 2.0000 x + 1.0000)" | ||
|
||
# Test evaluation of the calibration | ||
calibrated_slice = distance_calibration(sampled_positions) | ||
np.testing.assert_allclose(calibrated_slice.data, [1.0, 2.0, 4.0, 6.0, 8.0], atol=1e-12) | ||
np.testing.assert_allclose(calibrated_slice.timestamps, sampled_positions.timestamps) | ||
assert calibrated_slice.labels["title"] == "Piezo distance" | ||
assert calibrated_slice.labels["y"] == "Distance [um]" | ||
|
||
|
||
def test_lost_tracking(): | ||
trap_pos, camera_dist = trap_pos_camera_distance() | ||
trap_pos.data[4001] = 1e6 # Put a bad sample here, so we can detect that it gets discarded | ||
camera_dist.data[4] = 0 # Template lost => this should result in the bad sample being discarded | ||
with pytest.warns(RuntimeWarning, match="There were frames with missing video tracking"): | ||
distance_calibration = DistanceCalibration(trap_pos, camera_dist, 1) | ||
|
||
# Test evaluation of the calibration (this should be ok, since we discarded the sample) | ||
sampled_positions = Slice(Continuous(np.array([0, 0.5, 1.5, 2.5, 3.5]), dt=1, start=1)) | ||
calibrated_slice = distance_calibration(sampled_positions) | ||
np.testing.assert_allclose(calibrated_slice.data, [1.0, 2.0, 4.0, 6.0, 8.0], atol=1e-12) | ||
np.testing.assert_allclose(calibrated_slice.timestamps, sampled_positions.timestamps) | ||
assert calibrated_slice.labels["title"] == "Piezo distance" | ||
assert calibrated_slice.labels["y"] == "Distance [um]" | ||
|
||
|
||
def test_from_file(): | ||
trap_pos, camera_dist = trap_pos_camera_distance() | ||
|
||
class MockPiezo: | ||
def __init__(self): | ||
self.dict = {"Trap position": {"1X": trap_pos}} | ||
|
||
def __getitem__(self, item): | ||
return self.dict[item] | ||
|
||
@property | ||
def distance1(self): | ||
return camera_dist | ||
|
||
distance_calibration = DistanceCalibration.from_file(MockPiezo()) | ||
|
||
# Test evaluation of the calibration | ||
sampled_positions = Slice(Continuous(np.array([0, 0.5, 1.5, 2.5, 3.5]), dt=1, start=1)) | ||
calibrated_slice = distance_calibration(sampled_positions) | ||
np.testing.assert_allclose(calibrated_slice.data, [1.0, 2.0, 4.0, 6.0, 8.0], atol=1e-12) | ||
np.testing.assert_allclose(calibrated_slice.timestamps, sampled_positions.timestamps) | ||
assert calibrated_slice.labels["title"] == "Piezo distance" | ||
assert calibrated_slice.labels["y"] == "Distance [um]" | ||
|
||
|
||
@cleanup | ||
def test_plots(): | ||
distance_calibration = DistanceCalibration(*trap_pos_camera_distance(), 1) | ||
distance_calibration.plot() | ||
distance_calibration.plot_residual() | ||
|
||
|
||
def test_piezo_invalid_signs(): | ||
with pytest.raises( | ||
ValueError, | ||
match="Argument `signs` should be a tuple of two floats reflecting the sign for each " | ||
"channel.", | ||
): | ||
PiezoTrackingCalibration(None, None, None, (1, 1, 1)) | ||
|
||
with pytest.raises(ValueError, match="Each sign should be either -1 or 1."): | ||
PiezoTrackingCalibration(None, None, None, (1, 2)) | ||
|
||
|
||
def test_piezotracking(piezo_tracking_test_data): | ||
data = piezo_tracking_test_data | ||
|
||
# Calibrate using the trap position | ||
distance_calibration = DistanceCalibration(data["baseline_trap_position"], data["camera_dist"]) | ||
|
||
# Estimate the baselines | ||
baseline_1 = ForceBaseLine.polynomial_baseline( | ||
data["baseline_trap_position"], data["baseline_force"], degree=2 | ||
) | ||
baseline_2 = ForceBaseLine.polynomial_baseline( | ||
data["baseline_trap_position"], data["baseline_force"], degree=2 | ||
) | ||
|
||
# Perform the piezo tracking | ||
piezo_calibration = PiezoTrackingCalibration(distance_calibration, baseline_1, baseline_2) | ||
|
||
piezo_distance, corrected_force1, corrected_force2 = piezo_calibration.piezo_track( | ||
data["trap_position"], data["force_1x"], data["force_2x"], trim=False | ||
) | ||
|
||
np.testing.assert_allclose(corrected_force1.data, data["force_without_baseline"], rtol=1e-6) | ||
np.testing.assert_allclose(piezo_distance.data, data["correct_distance"], rtol=1e-6) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could consider simplifying this to:
Saves you a search and a sort (note that the result from
np.unique
is already sorted).Given that you return the raw numpy arrays rather than slices, I would also consider just taking raw numpy arrays as input and extracting the data where its called (rather than passing a slice).