Skip to content

Commit ee14463

Browse files
author
uyplayer
committed
add helper
1 parent fc19337 commit ee14463

File tree

6 files changed

+227
-21
lines changed

6 files changed

+227
-21
lines changed

Cargo.lock

+2-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ path = "src/main.rs"
2020

2121
[dependencies]
2222
polars = { version = "^0.32.1", features = ["lazy","describe"] }
23-
23+
rand = "0.8.5"

src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@
1919
2020

2121
mod lorentzian_classification;
22-
pub use lorentzian_classification::{rational_quadratic,rational_quadratic_tv,gaussian,gaussian_tv};
22+
pub use lorentzian_classification::{rational_quadratic,rational_quadratic_tv,gaussian,gaussian_tv,normalizer,rescale,rma_indicator,Settings,Filters,KernelFilter,Direction};
23+
2324

src/lorentzian_classification/helper.rs

+145-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,148 @@
77
* @Dir: tech_analysis / src/lorentzian_classification
88
* @Project_Name: tech_analysis
99
* @Description:
10-
*/
10+
*/
11+
12+
//! Helper functions for data manipulation
13+
14+
use polars::export::arrow::array::{Float64Array};
15+
use polars::prelude::*;
16+
17+
18+
/// Normalizes the values of the input series to a given range.
19+
///
20+
/// # Arguments
21+
/// * `src` - The input series
22+
/// * `min_val` - The minimum value of the range to normalize to
23+
/// * `max_val` - The maximum value of the range to normalize to
24+
///
25+
/// # Returns
26+
/// The normalized series.s
27+
pub fn normalizer<'a>(src: &'a Series, min_val: f64, max_val: f64) -> Result<Series, Box<dyn std::error::Error>> {
28+
let array = src.to_arrow(0);
29+
let vec_values = match array.as_any().downcast_ref::<Float64Array>() {
30+
Some(float_array) => {
31+
let values: &[f64] = float_array.values();
32+
let vec_values: Vec<f64> = values.to_vec();
33+
vec_values
34+
}
35+
None => return Err("Failed to downcast to Float64Array or Int64Array".into()),
36+
};
37+
let actual_min_val = vec_values
38+
.iter()
39+
.min_by(|x, y| x.partial_cmp(y).unwrap())
40+
.ok_or("Failed to find the minimum value in vec_values")?;
41+
42+
let actual_max_val = vec_values
43+
.iter()
44+
.max_by(|x, y| x.partial_cmp(y).unwrap())
45+
.ok_or("Failed to find the maximum value in vec_values")?;
46+
47+
let scaled_values: Vec<f64> = vec_values
48+
.iter()
49+
.map(|&x| (x - actual_min_val) / (actual_max_val - actual_min_val) * (max_val - min_val) + min_val)
50+
.collect();
51+
52+
Ok(Series::new("date", scaled_values))
53+
54+
55+
}
56+
57+
/// Rescales the values of the input series from one bounded range to another bounded range.
58+
///
59+
/// # Arguments
60+
/// * `src` - The input series
61+
/// * `old_min` - The minimum value of the range to rescale from
62+
/// * `old_max` - The maximum value of the range to rescale from
63+
/// * `new_min` - The minimum value of the range to rescale to
64+
/// * `new_max` - The maximum value of the range to rescale to
65+
///
66+
/// # Returns
67+
/// The rescaled series
68+
pub fn rescale<'a>(src: &'a Series, old_min: f64, old_max: f64, new_min: f64, new_max: f64) -> Result<Series, Box<dyn std::error::Error>> {
69+
let array = src.to_arrow(0);
70+
let vec_values = match array.as_any().downcast_ref::<Float64Array>() {
71+
Some(float_array) => {
72+
let values: &[f64] = float_array.values();
73+
let vec_values: Vec<f64> = values.to_vec();
74+
vec_values
75+
}
76+
None => return Err("Failed to downcast to Float64Array or Int64Array".into()),
77+
};
78+
let epsilon = 10e-10;
79+
let vec_values = vec_values.iter()
80+
.map(|x| new_min + (new_max - new_min) * (x - old_min) / f64::max(old_max - old_min, epsilon))
81+
.collect::<Vec<f64>>();
82+
Ok(Series::new("date", vec_values))
83+
}
84+
85+
86+
/// Computes the Rolling Moving Average (RMA) of the input series and then calculates the Exponential
87+
/// Weighted Moving Average (EWMA) using the RMA values.
88+
///
89+
/// # Arguments
90+
///
91+
/// * `src` - The input series.
92+
/// * `length` - The length of the rolling window.
93+
///
94+
/// # Returns
95+
///
96+
/// The series containing the EWMA values.
97+
pub fn rma_indicator(src: &Series, length: i32)->Result<Series, Box<dyn std::error::Error>> {
98+
99+
let rolling = src.rolling_mean( length as usize)?;
100+
let mut ewma: Vec<Option<f64>> = vec![];
101+
let alpha = 2.0 / (length + 1) as f64;
102+
103+
let mut prev_ema = None;
104+
for opt in rolling.into_iter() {
105+
if let Some(val) = opt {
106+
match prev_ema {
107+
Some(prev) => {
108+
let ema = alpha * val + (1.0 - alpha) * prev;
109+
ewma.push(Some(ema));
110+
prev_ema = Some(ema);
111+
}
112+
None => {
113+
ewma.push(Some(val));
114+
prev_ema = Some(val);
115+
}
116+
}
117+
} else {
118+
ewma.push(None);
119+
prev_ema = None;
120+
}
121+
}
122+
Ok(Series::new("data", &ewma))
123+
124+
}
125+
126+
#[cfg(test)]
127+
mod tests {
128+
use rand::Rng;
129+
use super::*;
130+
131+
#[test]
132+
fn test_normalizer()->Result<(), Box<dyn std::error::Error>>{
133+
let mut rng = rand::thread_rng();
134+
let random_data: Vec<f64> = (0..1000).map(|_| rng.gen_range(1.0..2000.0)).collect();
135+
let src = Series::new("data", random_data);
136+
let res = normalizer(&src, -2.0, 2.0)?;
137+
eprintln!("{:?}", res);
138+
Ok(())
139+
}
140+
#[test]
141+
fn test_rescale()->Result<(), Box<dyn std::error::Error>>{
142+
let mut rng = rand::thread_rng();
143+
let random_data: Vec<f64> = (0..1000).map(|_| rng.gen_range(1.0..2000.0)).collect();
144+
let old_min = 1.0;
145+
let old_max = 5.0;
146+
let new_min = 0.0;
147+
let new_max = 1.0;
148+
let src = Series::new("data", random_data);
149+
let res = rescale(&src, old_min, old_max, new_min, new_max)?;
150+
eprintln!("{:?}", res);
151+
Ok(())
152+
}
153+
154+
}

src/lorentzian_classification/mod.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@
1111

1212

1313
mod kernel;
14-
pub use kernel::{rational_quadratic,rational_quadratic_tv,gaussian,gaussian_tv};
14+
mod types;
1515
mod helper;
1616

17+
pub use kernel::{rational_quadratic,rational_quadratic_tv,gaussian,gaussian_tv};
18+
pub use types::{Settings,Filters,KernelFilter,Direction};
19+
pub use helper::{normalizer,rescale,rma_indicator};
20+
21+
22+
23+
24+
25+
26+
27+

src/lorentzian_classification/types.rs

+65-16
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,34 @@
99
* @Description:
1010
*/
1111

12+
1213
//! all type declared here using in lorentzian classification
1314
1415

1516

17+
18+
19+
/// A trait for checking the configuration.
20+
///
21+
/// This trait provides a method for checking the configuration. The associated type `Output` represents
22+
/// the type of the output of the configuration check.
23+
pub trait ConfigCheck {
24+
/// The type of the output of the configuration check.
25+
type Output;
26+
27+
/// Checks the configuration.
28+
///
29+
/// # Returns
30+
///
31+
/// The result of the configuration check.
32+
fn configuration_check(&self) -> Self::Output;
33+
}
34+
35+
36+
1637
// settings
1738
/// Settings struct representing settings for a certain functionality for classification.
18-
# [derive(Debug)]
39+
# [derive(Debug,Clone,Copy,Eq, PartialEq)]
1940
pub struct Settings<'a>{
2041
/// The data source for the functionality.
2142
pub source:&'a str,
@@ -39,28 +60,27 @@ pub struct Settings<'a>{
3960
}
4061
// check setting params
4162
/// checking Settings params validation
42-
impl<'a> Settings<'a> {
43-
fn check_settings(&self) {
63+
impl ConfigCheck for Settings<'_>{
64+
type Output = ();
65+
fn configuration_check(&self) -> Self::Output {
4466
// Access each member to trigger compile-time checks
4567
let source = ["close","open","high","low","volume","vol"];
4668
assert!(source.contains(&self.source));
4769
if self.neighbors_count <= 0 {
48-
panic!(" neighbors_count must be bigger than zero ")
70+
panic!(" neighbors_count must be bigger than zero ");
4971
}
5072
if self.max_bars_back <= 0 {
51-
panic!(" max_bars_back must be bigger than zero ")
73+
panic!(" max_bars_back must be bigger than zero ");
5274
}
5375
if self.ema_period <= 1 {
54-
panic!(" ema_period must be bigger than one ")
76+
panic!(" ema_period must be bigger than one ");
5577
}
5678
if self.sma_period <= 1 {
57-
panic!(" sma_period must be bigger than one ")
79+
panic!(" sma_period must be bigger than one ");
5880
}
5981
}
60-
}
61-
62-
// test for settings
6382

83+
}
6484

6585
// filter setting
6686
/// a set of filters struct used for classification.
@@ -77,8 +97,9 @@ pub struct Filters{
7797
pub adx_threshold: i32,
7898
}
7999

80-
impl Filters {
81-
fn check_filters(&self) {
100+
impl ConfigCheck for Filters {
101+
type Output =();
102+
fn configuration_check(&self)->Self::Output {
82103
// Access each member to trigger compile-time checks
83104
if !(self.regime_threshold >= -10.0 && self.regime_threshold <= 10.0) {
84105
panic!("regime_threshold must be between -10.0 and 10.0");
@@ -107,6 +128,27 @@ pub struct KernelFilter {
107128
pub crossover_lag: i32,
108129
}
109130

131+
impl ConfigCheck for KernelFilter{
132+
133+
type Output = ();
134+
fn configuration_check(&self) -> Self::Output {
135+
136+
if self.look_back_window < 0 {
137+
panic!("look_back_window must be greater tha 0");
138+
}
139+
if self.relative_weight < 0.0 {
140+
panic!("relative_weight must be greater tha 0");
141+
}
142+
if self.regression_level < 0.0 {
143+
panic!("regression_level must be greater tha 0");
144+
}
145+
if self.crossover_lag < 0 {
146+
panic!("regression_level must be greater tha 0");
147+
}
148+
149+
}
150+
151+
}
110152

111153
// market trend direction
112154
pub enum Direction{
@@ -115,8 +157,9 @@ pub enum Direction{
115157
NEUTRAL = 0,
116158
}
117159

160+
// unit tes
118161
#[cfg(test)]
119-
mod test{
162+
mod tests {
120163
use super::*;
121164

122165
#[test]
@@ -132,7 +175,7 @@ mod test{
132175
use_sma_filter: true,
133176
sma_period: 20,
134177
};
135-
settings.check_settings();
178+
settings.configuration_check();
136179
}
137180
#[test]
138181
fn test_filters() {
@@ -143,23 +186,29 @@ mod test{
143186
regime_threshold: 0.0,
144187
adx_threshold: 10,
145188
};
146-
filters.check_filters();
189+
filters.configuration_check();
147190
}
148191
#[test]
149192
fn test_kernel_filter(){
150-
let _ = KernelFilter{
193+
let kernel = KernelFilter{
151194
show_kernel_estimate: false,
152195
use_kernel_smoothing: false,
153196
look_back_window: 0,
154197
relative_weight: 0.0,
155198
regression_level: 0.0,
156199
crossover_lag: 0,
157200
};
201+
kernel.configuration_check();
158202
}
159203
#[test]
160204
fn test_direction(){
161205
let _ = Direction::LONG;
162206
let _ = Direction::SHORT;
163207
let _ = Direction::NEUTRAL;
164208
}
209+
210+
#[test]
211+
fn test_finish(){
212+
eprintln!("finished");
213+
}
165214
}

0 commit comments

Comments
 (0)