Skip to content

Commit

Permalink
make lists part more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 19, 2024
1 parent 383f05b commit 33d60f2
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 42 deletions.
50 changes: 27 additions & 23 deletions docs/lists.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ On the Rust side, this is where things get a bit scary. I don't know of any exis
convenience function to do this, so I'll provide you with one here:

```rust
fn binary_amortized_elementwise_float<'a, F>(
fn binary_amortized_elementwise<'a, T, K, F>(
ca: &'a ListChunked,
weights: &'a ListChunked,
mut f: F,
) -> Float64Chunked
) -> ChunkedArray<T>
where
F: FnMut(&Series, &Series) -> Option<f64>,
T: PolarsDataType,
T::Array: ArrayFromIter<Option<K>>,
F: FnMut(&Series, &Series) -> Option<K> + Copy,
{
unsafe {
ca.amortized_iter()
Expand All @@ -73,36 +75,38 @@ where
}
}
```
Don't worry about understanding it.
Some of its details (such as `.as_ref()` to get a `Series` out of an `UnstableSeries`) are arguably
implementation details. Hopefully a more generic version of a utility like this can be added to
Polars itself, so that plugin authors won't need to worry about it. But for now, let's just use it to
implement a weighted mean.
Polars itself, so that plugin authors won't need to worry about it.

We're just going to accept two inputs (values and weights), multiply them together, and divide by
the sum of the weights:
Let's concern ourselves with just using it to accomplish our task!
We just need to write a function which accepts two `Series`, computes their dot product, and then
divides by the sum of the weights:

```rust
#[polars_expr(output_type=Float64)]
fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
let ca = inputs[0].list()?;
let values = inputs[0].list()?;
let weights = &inputs[1].list()?;

let out = binary_amortized_elementwise_float(ca, weights, |values, weights| {
let values = values.i64().unwrap();
let weights = weights.f64().unwrap();
let out_inner: Float64Chunked = binary_elementwise(
values,
weights,
|lhs: Option<i64>, rhs: Option<f64>| match (lhs, rhs) {
(Some(lhs), Some(rhs)) => Some(lhs as f64 * rhs),
let out: Float64Chunked =
binary_amortized_elementwise(values, weights, |values_inner, weights_inner| {
let values_inner = values_inner.i64().unwrap();
let weights_inner = weights_inner.f64().unwrap();
let out_inner: Float64Chunked = binary_elementwise(
values_inner,
weights_inner,
|opt_value: Option<i64>, opt_weight: Option<f64>| match (opt_value, opt_weight) {
(Some(value), Some(weight)) => Some(value as f64 * weight),
_ => None,
},
);
match (out_inner.sum(), weights_inner.sum()) {
(Some(weighted_sum), Some(weights_sum)) => Some(weighted_sum / weights_sum),
_ => None,
},
);
match (out_inner.sum(), weights.sum()) {
(Some(sum), Some(weights_sum)) => Some(sum / weights_sum),
_ => None,
}
});
}
});
Ok(out.into_series())
}
```
Expand Down
41 changes: 22 additions & 19 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,15 @@ fn snowball_stem(inputs: &[Series]) -> PolarsResult<Series> {
Ok(out.into_series())
}

fn binary_amortized_elementwise_float<'a, F>(
fn binary_amortized_elementwise<'a, T, K, F>(
ca: &'a ListChunked,
weights: &'a ListChunked,
mut f: F,
) -> Float64Chunked
) -> ChunkedArray<T>
where
F: FnMut(&Series, &Series) -> Option<f64>,
T: PolarsDataType,
T::Array: ArrayFromIter<Option<K>>,
F: FnMut(&Series, &Series) -> Option<K> + Copy,
{
unsafe {
ca.amortized_iter()
Expand All @@ -222,24 +224,25 @@ where

#[polars_expr(output_type=Float64)]
fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
let ca = inputs[0].list()?;
let values = inputs[0].list()?;
let weights = &inputs[1].list()?;

let out = binary_amortized_elementwise_float(ca, weights, |values, weights| {
let values = values.i64().unwrap();
let weights = weights.f64().unwrap();
let out_inner: Float64Chunked = binary_elementwise(
values,
weights,
|lhs: Option<i64>, rhs: Option<f64>| match (lhs, rhs) {
(Some(lhs), Some(rhs)) => Some(lhs as f64 * rhs),
let out: Float64Chunked =
binary_amortized_elementwise(values, weights, |values_inner, weights_inner| {
let values_inner = values_inner.i64().unwrap();
let weights_inner = weights_inner.f64().unwrap();
let out_inner: Float64Chunked = binary_elementwise(
values_inner,
weights_inner,
|opt_value: Option<i64>, opt_weight: Option<f64>| match (opt_value, opt_weight) {
(Some(value), Some(weight)) => Some(value as f64 * weight),
_ => None,
},
);
match (out_inner.sum(), weights_inner.sum()) {
(Some(weighted_sum), Some(weights_sum)) => Some(weighted_sum / weights_sum),
_ => None,
},
);
match (out_inner.sum(), weights.sum()) {
(Some(sum), Some(weights_sum)) => Some(sum / weights_sum),
_ => None,
}
});
}
});
Ok(out.into_series())
}

0 comments on commit 33d60f2

Please sign in to comment.