Skip to content

Commit 25c5c36

Browse files
committed
use constant for axis 0
1 parent a8e0861 commit 25c5c36

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

src/interp1d/strategies/cubic_spline.rs

+29-27
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use crate::{
1616

1717
use super::{Strategy, StrategyBuilder};
1818

19+
const AX0: Axis = Axis(0);
20+
1921
#[derive(Debug)]
2022
pub struct CubicSpline;
2123
impl<Sd, Sx, D> StrategyBuilder<Sd, Sx, D> for CubicSpline
@@ -130,14 +132,14 @@ impl CubicSpline {
130132
// RHS vector
131133
let mut rhs: Array<Sd::Elem, D> = Array::zeros(dim.clone());
132134

133-
let mut inner_rhs = rhs.slice_axis_mut(Axis(0), Slice::from(1..-1));
134-
Zip::from(inner_rhs.axis_iter_mut(Axis(0)))
135+
let mut inner_rhs = rhs.slice_axis_mut(AX0, Slice::from(1..-1));
136+
Zip::from(inner_rhs.axis_iter_mut(AX0))
135137
.and(x.windows(3))
136-
.and(data.axis_windows(Axis(0), 3))
138+
.and(data.axis_windows(AX0, 3))
137139
.for_each(|rhs, x, data| {
138-
let y_left = data.index_axis(Axis(0), 0);
139-
let y_mid = data.index_axis(Axis(0), 1);
140-
let y_right = data.index_axis(Axis(0), 2);
140+
let y_left = data.index_axis(AX0, 0);
141+
let y_mid = data.index_axis(AX0, 1);
142+
let y_right = data.index_axis(AX0, 2);
141143
let x_left = x[0];
142144
let x_mid = x[1];
143145
let x_right = x[2];
@@ -152,19 +154,19 @@ impl CubicSpline {
152154
);
153155
});
154156

155-
let rhs_0 = rhs.index_axis_mut(Axis(0), 0);
156-
let data_0 = data.index_axis(Axis(0), 0);
157-
let data_1 = data.index_axis(Axis(0), 1);
157+
let rhs_0 = rhs.index_axis_mut(AX0, 0);
158+
let data_0 = data.index_axis(AX0, 0);
159+
let data_1 = data.index_axis(AX0, 1);
158160
Zip::from(rhs_0)
159161
.and(data_0)
160162
.and(data_1)
161163
.for_each(|rhs_0, &y_0, &y_1| {
162164
*rhs_0 = three * (y_1 - y_0) / (x_1 - x_0).pow(two);
163165
});
164166

165-
let rhs_n = rhs.index_axis_mut(Axis(0), len - 1);
166-
let data_n = data.index_axis(Axis(0), len - 1);
167-
let data_n1 = data.index_axis(Axis(0), len - 2);
167+
let rhs_n = rhs.index_axis_mut(AX0, len - 1);
168+
let data_n = data.index_axis(AX0, len - 1);
169+
let data_n1 = data.index_axis(AX0, len - 2);
168170
Zip::from(rhs_n)
169171
.and(data_n)
170172
.and(data_n1)
@@ -174,12 +176,12 @@ impl CubicSpline {
174176

175177
// now solving With Thomas algorithm
176178

177-
let mut rhs_left = rhs.index_axis(Axis(0), 0).into_owned();
179+
let mut rhs_left = rhs.index_axis(AX0, 0).into_owned();
178180
for i in 1..len {
179181
let w = a_low[i] / a_mid[i - 1];
180182
a_mid[i] -= w * a_up[i - 1];
181183

182-
let rhs = rhs.index_axis_mut(Axis(0), i);
184+
let rhs = rhs.index_axis_mut(AX0, i);
183185
Zip::from(rhs)
184186
.and(rhs_left.view_mut())
185187
.for_each(|rhs, rhs_left| {
@@ -190,17 +192,17 @@ impl CubicSpline {
190192
}
191193

192194
let mut k = Array::zeros(dim);
193-
Zip::from(k.index_axis_mut(Axis(0), len - 1))
194-
.and(rhs.index_axis(Axis(0), len - 1))
195+
Zip::from(k.index_axis_mut(AX0, len - 1))
196+
.and(rhs.index_axis(AX0, len - 1))
195197
.for_each(|k, &rhs| {
196198
*k = rhs / a_mid[len - 1];
197199
});
198200

199-
let mut k_right = k.index_axis(Axis(0), len - 1).into_owned();
201+
let mut k_right = k.index_axis(AX0, len - 1).into_owned();
200202
for i in (0..len - 1).rev() {
201-
Zip::from(k.index_axis_mut(Axis(0), i))
203+
Zip::from(k.index_axis_mut(AX0, i))
202204
.and(k_right.view_mut())
203-
.and(rhs.index_axis(Axis(0), i))
205+
.and(rhs.index_axis(AX0, i))
204206
.for_each(|k, k_right, &rhs| {
205207
let new_k = (rhs - a_up[i] * *k_right) / a_mid[i];
206208
*k = new_k;
@@ -211,12 +213,12 @@ impl CubicSpline {
211213
let mut c_a = Array::zeros(a_b_dim.clone());
212214
let mut c_b = Array::zeros(a_b_dim);
213215
for index in 0..len - 1 {
214-
Zip::from(c_a.index_axis_mut(Axis(0), index))
215-
.and(c_b.index_axis_mut(Axis(0), index))
216-
.and(k.index_axis(Axis(0), index))
217-
.and(k.index_axis(Axis(0), index + 1))
218-
.and(data.index_axis(Axis(0), index))
219-
.and(data.index_axis(Axis(0), index + 1))
216+
Zip::from(c_a.index_axis_mut(AX0, index))
217+
.and(c_b.index_axis_mut(AX0, index))
218+
.and(k.index_axis(AX0, index))
219+
.and(k.index_axis(AX0, index + 1))
220+
.and(data.index_axis(AX0, index))
221+
.and(data.index_axis(AX0, index + 1))
220222
.for_each(|c_a, c_b, &k, &k_right, &y, &y_right| {
221223
*c_a = k * (x[index + 1] - x[index]) - (y_right - y);
222224
*c_b = (y_right - y) - k_right * (x[index + 1] - x[index]);
@@ -260,8 +262,8 @@ where
260262
let idx = interp.get_left_index(x);
261263
let (x_left, data_left) = interp.get_point(idx);
262264
let (x_right, data_right) = interp.get_point(idx + 1);
263-
let a_left = self.a.index_axis(Axis(0), idx);
264-
let b_left = self.b.index_axis(Axis(0), idx);
265+
let a_left = self.a.index_axis(AX0, idx);
266+
let b_left = self.b.index_axis(AX0, idx);
265267
let one: Sd::Elem = cast(1.0).unwrap_or_else(|| unimplemented!());
266268

267269
let t = (x - x_left) / (x_right - x_left);

0 commit comments

Comments
 (0)