Skip to content

Commit

Permalink
Merge branch 'rust-ml:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-montes authored Nov 27, 2023
2 parents 5eefd38 + 00e59f6 commit aa56bf4
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 5 deletions.
22 changes: 18 additions & 4 deletions algorithms/linfa-clustering/src/appx_dbscan/cells_grid/cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@ use linfa::Float;
use linfa_nn::distance::{Distance, L2Dist};
use ndarray::{Array1, ArrayView1, ArrayView2, ArrayViewMut1};
use partitions::PartitionVec;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A point in a D dimensional euclidean space that memorizes its
/// status: 'core' or 'non core'
pub struct StatusPoint {
Expand All @@ -16,10 +23,7 @@ pub struct StatusPoint {

impl StatusPoint {
pub fn new(point_index: usize) -> StatusPoint {
StatusPoint {
point_index,
is_core: false,
}
StatusPoint { point_index, is_core: false }
}

pub fn is_core(&self) -> bool {
Expand All @@ -32,6 +36,11 @@ impl StatusPoint {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Informations regarding the cell used in various stages of the approximate DBSCAN
/// algorithm if it is a core cell
pub struct CoreCellInfo<F: Float> {
Expand All @@ -42,6 +51,11 @@ pub struct CoreCellInfo<F: Float> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A cell from a grid that partitions the D dimensional euclidean space.
pub struct Cell<F: Float> {
/// The index of the intervals of the D dimensional axes where this cell lies
Expand Down
7 changes: 7 additions & 0 deletions algorithms/linfa-clustering/src/appx_dbscan/cells_grid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use linfa::Float;
use linfa_nn::{distance::L2Dist, NearestNeighbour};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use partitions::PartitionVec;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use cell::{Cell, StatusPoint};

Expand All @@ -16,6 +18,11 @@ pub type CellVector<F> = PartitionVec<Cell<F>>;
pub type CellTable = HashMap<Array1<i64>, usize>;

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct CellsGrid<F: Float> {
table: CellTable,
cells: CellVector<F>,
Expand Down
12 changes: 12 additions & 0 deletions algorithms/linfa-clustering/src/appx_dbscan/counting_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,28 @@ use crate::appx_dbscan::AppxDbscanValidParams;
use linfa::Float;
use linfa_nn::distance::{Distance, L2Dist};
use ndarray::{Array1, Array2, ArrayView1, Axis};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub enum IntersectionType {
FullyCovered,
Disjoint,
Intersecting,
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Tree structure that divides the space in nested cells to perform approximate range counting
/// Each member of this structure is a node in the tree
pub struct TreeStructure<F: Float> {
Expand Down
7 changes: 7 additions & 0 deletions algorithms/linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@ use linfa_nn::{
CommonNearestNeighbour, NearestNeighbour, NearestNeighbourIndex,
};
use ndarray::{Array1, ArrayBase, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::collections::VecDeque;

use linfa::Float;
use linfa::{traits::Transformer, DatasetBase};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
/// clusters together points which are close together with enough neighbors
/// labelled points which are sparsely neighbored as noise. As points may be
Expand Down
33 changes: 32 additions & 1 deletion algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ impl<F: Float> GaussianMixtureModel<F> {
)?;
self.means = means;
self.weights = weights / F::cast(n_samples);
self.covariances = covariances;
// GmmCovarType = Full()
self.precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
self.precisions_chol = Self::compute_precisions_cholesky_full(&self.covariances)?;
Ok(())
}

Expand Down Expand Up @@ -488,7 +489,9 @@ mod tests {
use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis};
use ndarray_rand::rand::prelude::ThreadRng;
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::Normal;
use ndarray_rand::rand_distr::{Distribution, StandardNormal};
use ndarray_rand::RandomExt;

#[test]
fn autotraits() {
Expand Down Expand Up @@ -570,6 +573,34 @@ mod tests {
);
}

#[test]
fn test_gmm_covariances() {
let rng = rand_xoshiro::Xoshiro256Plus::seed_from_u64(123);

let data_0 = ndarray::Array::random((500,), Normal::new(0., 0.5).unwrap());
let data_1 = ndarray::Array::random((500,), Normal::new(1., 0.5).unwrap());
let data_2 = ndarray::Array::random((500,), Normal::new(2., 0.5).unwrap());
let data = ndarray::concatenate![ndarray::Axis(0), data_0, data_1, data_2];

let data_2d = data.insert_axis(ndarray::Axis(1)).to_owned();
let dataset = linfa::DatasetBase::from(data_2d);

let gmm = GaussianMixtureModel::params(3)
.n_runs(1)
.tolerance(1e-4)
.with_rng(rng)
.max_n_iterations(500)
.fit(&dataset)
.expect("GMM fit");

// expected results from scikit-learn 1.3.1
let expected = array![[[0.22564062]], [[0.26204446]], [[0.23393885]]];
let expected = Array::from_iter(expected.iter().cloned());
let actual = gmm.covariances();
let actual = Array::from_iter(actual.iter().cloned());
assert_abs_diff_eq!(expected, actual, epsilon = 1e-1);
}

fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
let mut y = Array2::zeros(x.dim());
Zip::from(&mut y).and(x).for_each(|yi, &xi| {
Expand Down
5 changes: 5 additions & 0 deletions algorithms/linfa-clustering/src/k_means/hyperparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ pub struct KMeansValidParams<F: Float, R: Rng, D: Distance<F>> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// An helper struct used to construct a set of [valid hyperparameters](KMeansParams) for
/// the [K-means algorithm](crate::KMeans) (using the builder pattern).
pub struct KMeansParams<F: Float, R: Rng, D: Distance<F>>(KMeansValidParams<F, R, D>);
Expand Down
10 changes: 10 additions & 0 deletions algorithms/linfa-clustering/src/optics/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ pub struct Optics;
/// This struct represents a data point in the dataset with it's associated distances obtained from
/// the OPTICS analysis
#[derive(Debug, Clone)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct Sample<F> {
/// Index of the observation in the dataset
index: usize,
Expand Down Expand Up @@ -103,6 +108,11 @@ impl<F: Float> Ord for Sample<F> {
/// that of the dataset instead ordering based on the clustering structure worked out during
/// analysis.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct OpticsAnalysis<F: Float> {
/// A list of the samples in the dataset sorted and with their reachability and core distances
/// computed.
Expand Down
5 changes: 5 additions & 0 deletions algorithms/linfa-clustering/src/optics/hyperparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ impl<F: Float, D, N> OpticsValidParams<F, D, N> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct OpticsParams<F, D, N>(OpticsValidParams<F, D, N>);

impl<F: Float, D, N> OpticsParams<F, D, N> {
Expand Down

0 comments on commit aa56bf4

Please sign in to comment.