@@ -15,7 +15,7 @@ class LS_Stream {
   }
   double Predict()
   {
-    pred=MathUtils::dot(x.data(),w.data(),n);
+    pred=slmath::dot(x.get_span(),w);
     return pred;
   }
   virtual void Update(double val)=0;
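
The new call site passes a span instead of a raw pointer/length pair. A minimal sketch of a compatible slmath::dot, assuming a std::span-based signature (the actual declaration is not shown in this diff):

```cpp
#include <cstddef>
#include <span>

namespace slmath {
// Hypothetical sketch: dot product over two spans.
// The real slmath::dot signature is an assumption here.
inline double dot(std::span<const double> x, std::span<const double> w)
{
  double sum=0.0;
  for (std::size_t i=0;i<x.size();i++)
    sum+=x[i]*w[i];   // assumes w.size() >= x.size()
  return sum;
}
} // namespace slmath
```

With a shape like this, `slmath::dot(x.get_span(),w)` also compiles when `w` is a `std::vector<double>`, since the vector converts implicitly to `std::span<const double>` in C++20.
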
@@ -58,7 +58,6 @@ void update_w_avx(double* w, const double* mutab, const double* x, double wgrad,
 
 class NLMS_Stream : public LS_Stream
 {
-  const double eps_pow=1.0;
   public:
   NLMS_Stream(int n,double mu,double mu_decay=1.0,double pow_decay=0.8)
   :LS_Stream(n),mutab(n),powtab(n),mu(mu)
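
The constructor body is elided by the hunk, but the mu_decay and pow_decay parameters, together with the sum_powtab used in Update() below, suggest geometrically decaying per-tap tables. A hypothetical sketch of such an initialization (an assumption, not the repository's actual code):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (assumption): fills geometrically decaying
// per-tap tables and returns the sum of the power weights, which
// would correspond to sum_powtab in NLMS_Stream.
double init_decay_tables(std::vector<double>& mutab,
                         std::vector<double>& powtab,
                         double mu_decay, double pow_decay)
{
  double m=1.0,p=1.0,sum=0.0;
  for (std::size_t i=0;i<mutab.size();i++) {
    mutab[i]=m;   // per-tap step-size scale
    powtab[i]=p;  // per-tap power weight
    sum+=p;
    m*=mu_decay;  // decay toward older taps
    p*=pow_decay;
  }
  return sum;
}
```
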
@@ -71,44 +70,37 @@ class NLMS_Stream : public LS_Stream
     }
   }
 
-  #if defined(USE_AVX256)
   double calc_spow(const double *x,const double *powtab,std::size_t n)
   {
     double spow=0.0;
+
     std::size_t i=0;
-    if (n>=4) {
-      __m256d sum_vec = _mm256_setzero_pd();
-      for (; i + 4 <= n; i += 4) {
-        __m256d x_vec = _mm256_loadu_pd(&x[i]);
-        __m256d pow_vec = _mm256_load_pd(&powtab[i]);
-        __m256d x_squared = _mm256_mul_pd(x_vec, x_vec);
-        sum_vec = _mm256_fmadd_pd(pow_vec, x_squared, sum_vec);
-      }
 
-      alignas(32) double buffer[4];
-      _mm256_store_pd(buffer, sum_vec);
-      spow = buffer[0] + buffer[1] + buffer[2] + buffer[3];
+    if constexpr (SACGlobalCfg::USE_AVX2) {
+      if (n>=4) {
+        __m256d sum_vec = _mm256_setzero_pd();
+        for (; i + 4 <= n; i += 4) {
+          __m256d x_vec = _mm256_loadu_pd(&x[i]);
+          __m256d pow_vec = _mm256_load_pd(&powtab[i]);
+          __m256d x_squared = _mm256_mul_pd(x_vec, x_vec);
+          sum_vec = _mm256_fmadd_pd(pow_vec, x_squared, sum_vec);
+        }
+
+        alignas(32) double buffer[4];
+        _mm256_store_pd(buffer, sum_vec);
+        spow = buffer[0] + buffer[1] + buffer[2] + buffer[3];
+      }
     }
 
     for (;i<n;i++)
       spow += powtab[i] * (x[i] * x[i]);
     return spow;
   }
-  #else
-  double calc_spow(const double *x,const double *powtab,std::size_t n)
-  {
-    double spow=0.0;
-    for (std::size_t i=0;i<n;i++) {
-      spow+=powtab[i]*(x[i]*x[i]);
-    }
-    return spow;
-  }
-  #endif
 
   void Update(double val) override
   {
     const double spow=calc_spow(x.data(),powtab.data(),n);
-    const double wgrad=mu*(val-pred)*sum_powtab/(eps_pow+spow);
+    const double wgrad=mu*(val-pred)*sum_powtab/(spow+SACGlobalCfg::NLMS_POW_EPS);
     for (int i=0;i<n;i++) {
       w[i]+=mutab[i]*(wgrad*x[i]);
     }
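
The removed #else branch is not lost: when the if constexpr guard is off, i stays 0 and the scalar tail loop computes the entire sum. Note that if constexpr is not a substitute for #if here, since the discarded intrinsics must still compile, so the translation unit still needs AVX2/FMA enabled. Kept as a standalone function, the deleted scalar version doubles as a reference for validating the AVX2 path:

```cpp
#include <cstddef>

// Scalar reference for the weighted signal power
// spow = sum_i powtab[i] * x[i]^2 (the former #else branch).
double calc_spow_ref(const double *x, const double *powtab, std::size_t n)
{
  double spow=0.0;
  for (std::size_t i=0;i<n;i++)
    spow+=powtab[i]*(x[i]*x[i]);
  return spow;
}
```
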
@@ -135,7 +127,7 @@ class LADADA_Stream : public LS_Stream
     for (int i=0;i<n;i++) {
       double const grad=serr*x[i];
       eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; // accumulate gradients
-      double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
+      double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
       w[i]+=mu*g;
     }
     x.push(val);
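
This and the following hunks replace scattered magic epsilons with named SACGlobalCfg constants. A hypothetical sketch of that config; the member names are attested by the diff, while the layout and values (taken from the removed eps_pow=1.0 and the old 1E-5 literal) are assumptions:

```cpp
// Hypothetical sketch; only the member names appear in the diff.
struct SACGlobalCfg {
  static constexpr bool   USE_AVX2     = true;  // compile-time SIMD toggle
  static constexpr double NLMS_POW_EPS = 1.0;   // was eps_pow in NLMS_Stream
  static constexpr double LMS_ADA_EPS  = 1e-5;  // was the literal 1E-5
};
```
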
@@ -159,7 +151,7 @@ class LMSADA_Stream : public LS_Stream
     for (int i=0;i<n;i++) {
       double const grad=err*x[i]-nu*MathUtils::sgn(w[i]);
       eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; // accumulate gradients
-      double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
+      double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
       w[i]+=mu*g;
     }
     x.push(val);
@@ -180,7 +172,7 @@ class LMS {
   double Predict(const vec1D &inp)
   {
     x=inp;
-    pred=slmath::dot_scalar(x,w);
+    pred=slmath::dot(x,w);
     return pred;
   }
   virtual void Update(double)=0;
@@ -204,7 +196,7 @@ class LMS_ADA : public LMS
       double const grad=err*x[i] - nu*MathUtils::sgn(w[i]); // gradient + l1-regularization
 
       eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; // accumulate gradients
-      double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
+      double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
       w[i]+=mu*g;
     }
   }
@@ -226,7 +218,7 @@ class LAD_ADA : public LMS
     for (int i=0;i<n;i++) {
       double const grad=serr*x[i];
       eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; // accumulate gradients
-      double scaled_grad=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
+      double scaled_grad=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
       w[i]+=mu*scaled_grad;
     }
   }
@@ -264,7 +256,7 @@ class HBR_ADA : public LMS
     for (int i=0;i<n;i++) {
       double const grad=grad_loss*x[i];
       eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; // accumulate gradients
-      const double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
+      const double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
       w[i]+=mu*g;
     }
 
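
LADADA_Stream, LMSADA_Stream, LMS_ADA, LAD_ADA and HBR_ADA all share the same RMSprop-style step that these hunks touch; with the new named epsilon it reads:

```math
E[g^2]_i \leftarrow \beta\,E[g^2]_i + (1-\beta)\,g_i^2,
\qquad
w_i \leftarrow w_i + \mu\,\frac{g_i}{\sqrt{E[g^2]_i}+\varepsilon},
\qquad
\varepsilon = \texttt{SACGlobalCfg::LMS\_ADA\_EPS}
```
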
@@ -304,7 +296,7 @@ class LMS_ADAM : public LMS
       double n_hat=beta2*S[i]/(1.0-power_beta2);*/
       double m_hat=M[i]/(1.0-power_beta1);
       double n_hat=S[i]/(1.0-power_beta2);
-      w[i]+=mu*m_hat/(sqrt(n_hat)+1E-5);
+      w[i]+=mu*m_hat/(sqrt(n_hat)+SACGlobalCfg::LMS_ADA_EPS);
     }
   }
 private:
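
For reference, the bias-corrected Adam step this loop implements, with power_beta1 and power_beta2 standing for the running powers of the decay rates:

```math
\hat{m}_i = \frac{M_i}{1-\beta_1^{t}},
\qquad
\hat{n}_i = \frac{S_i}{1-\beta_2^{t}},
\qquad
w_i \leftarrow w_i + \mu\,\frac{\hat{m}_i}{\sqrt{\hat{n}_i}+\varepsilon}
```
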