Skip to content

Commit 0089bef

Browse files
committed
unify dot-product, style fixes
1 parent 3e3b603 commit 0089bef

File tree

13 files changed

+102
-61
lines changed

13 files changed

+102
-61
lines changed

src/cmdline.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
#include <format>
2+
#include <cstring>
13
#include "cmdline.h"
24
#include "common/utils.h"
35
#include "common/timer.h"
46
#include "file/sac.h"
5-
#include <cstring>
7+
68

79
CmdLine::CmdLine()
810
:mode(ENCODE)

src/common/histbuf.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class RollBuffer2 {
7676
return buf[pos + index];
7777
}
7878

79-
const std::span<T> get_span() const {
79+
const std::span<const T> get_span() const {
8080
return std::span<const T>{buf.data() + pos,n};
8181
}
8282
const T* data() const {

src/common/math.h

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "../global.h"
55
#include <cassert>
66
#include <cmath>
7+
#include <immintrin.h>
78

89
namespace slmath
910
{
@@ -58,21 +59,60 @@ namespace slmath
5859
vec2D G;
5960
};
6061

61-
inline double dot_scalar(const vec1D &v1,const vec1D &v2)
62+
63+
inline double dot(span_cf64 x,span_cf64 y)
6264
{
63-
assert(v1.size()==v2.size());
64-
double sum=0.0;
65-
for (std::size_t i=0;i<v1.size();++i)
66-
sum+=v1[i]*v2[i];
67-
return sum;
65+
assert(x.size()==y.size());
66+
const std::size_t n=x.size();
67+
double total=0.0;
68+
std::size_t i=0;
69+
70+
if constexpr(SACGlobalCfg::USE_AVX2) {
71+
if constexpr(SACGlobalCfg::UNROLL_AVX2) {
72+
if (n>=8)
73+
{
74+
__m256d sum1 = _mm256_setzero_pd();
75+
__m256d sum2 = _mm256_setzero_pd();
76+
for (;i + 8 <= n;i += 8)
77+
{
78+
__m256d vx1 = _mm256_loadu_pd(&x[i]);
79+
__m256d vy1 = _mm256_loadu_pd(&y[i]);
80+
sum1 = _mm256_fmadd_pd(vx1, vy1, sum1);
81+
__m256d vx2 = _mm256_loadu_pd(&x[i + 4]);
82+
__m256d vy2 = _mm256_loadu_pd(&y[i + 4]);
83+
sum2 = _mm256_fmadd_pd(vx2, vy2, sum2);
84+
}
85+
sum1 = _mm256_add_pd(sum1, sum2);
86+
alignas(32) double buffer[4];
87+
_mm256_store_pd(buffer, sum1);
88+
total = buffer[0] + buffer[1] + buffer[2] + buffer[3];
89+
}
90+
} else if (n>=4)
91+
{
92+
__m256d sum = _mm256_setzero_pd();
93+
for (;i + 4 <= n;i += 4)
94+
{
95+
__m256d vx = _mm256_loadu_pd(&x[i]);
96+
__m256d vy = _mm256_loadu_pd(&y[i]);
97+
sum = _mm256_fmadd_pd(vx, vy, sum);
98+
}
99+
alignas(32) double buffer[4];
100+
_mm256_store_pd(buffer, sum);
101+
total = buffer[0] + buffer[1] + buffer[2] + buffer[3];
102+
}
103+
}
104+
105+
for (;i<n;i++)
106+
total+=x[i]*y[i];
107+
return total;
68108
}
69109

70110
// vector = matrix * vector
71111
inline vec1D mul(const vec2D &m,const vec1D &v)
72112
{
73113
vec1D v_out(m.size());
74114
for (std::size_t i=0;i<m.size();i++)
75-
v_out[i]=slmath::dot_scalar(m[i],v);
115+
v_out[i]=slmath::dot(m[i],v);
76116
return v_out;
77117
}
78118

src/common/utils.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ namespace StrUtils {
150150

151151
namespace MathUtils {
152152

153-
#if defined(USE_AVX512)
153+
154+
/*#if defined(USE_AVX512)
154155
inline double dot(const double* x,const double* y, std::size_t n)
155156
{
156157
__m512d sum = _mm512_setzero_pd();
@@ -222,7 +223,7 @@ inline double dot(const double* x,const double* y, std::size_t n)
222223
sum+=x[i]*y[i];
223224
return sum;
224225
}
225-
#endif
226+
#endif*/
226227

227228
inline double calc_loglik_L1(double abs_e,double b)
228229
{

src/file/wav.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
#include "wav.h"
2-
#include "../common/utils.h"
31
#include <iostream>
2+
#include <format>
3+
#include "../common/utils.h"
4+
#include "wav.h"
45

56
int word_align(int numbytes)
67
{

src/global.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include <fstream>
88
#include <sstream>
99
#include <iomanip>
10-
#include <chrono>
1110
#include <vector>
1211
#include <span>
1312

@@ -23,8 +22,14 @@ using span_i32=std::span<int32_t>;
2322
using span_ci32=std::span<const int32_t>;
2423
using span_cf64=std::span<const double>;
2524

26-
#define USE_AVX256
27-
//#define UNROLL_AVX256
28-
//#define USE_AVX512
25+
struct SACGlobalCfg {
26+
static constexpr bool USE_AVX2=true;
27+
static constexpr bool UNROLL_AVX2=true;
28+
static constexpr double NLMS_POW_EPS=1.0;
29+
static constexpr double LMS_ADA_EPS=1E-5;
30+
static constexpr bool LMS_MIX_INIT=true;// increase stability
31+
static constexpr bool LMS_MIX_CLAMPW=true;
32+
static constexpr bool RLS_ALC=true; //adaptive lambda control
33+
};
2934

3035
#endif

src/opt/cma.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
#include <format>
2+
#include "../common/math.h"
13
#include "cma.h"
24
#include "ssc.h"
3-
#include "../common/math.h"
45

56
OptCMA::OptCMA(const CMACfg &cfg,const box_const &parambox,bool verbose)
67
:Opt(parambox),cfg(cfg),p(ndim),

src/opt/dds.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
#include <format>
12
#include "dds.h"
23
#include "ssc.h"
34

5+
46
OptDDS::OptDDS(const DDSCfg &cfg,const box_const &parambox,bool verbose)
57
:Opt(parambox),cfg(cfg),
68
verbose(verbose)

src/opt/de.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <cassert>
2+
#include <format>
23
#include "de.h"
34
#include "../common/utils.h"
45

src/pred/lms.h

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class LS_Stream {
1515
}
1616
double Predict()
1717
{
18-
pred=MathUtils::dot(x.data(),w.data(),n);
18+
pred=slmath::dot(x.get_span(),w);
1919
return pred;
2020
}
2121
virtual void Update(double val)=0;
@@ -58,7 +58,6 @@ void update_w_avx(double* w, const double* mutab, const double* x, double wgrad,
5858

5959
class NLMS_Stream : public LS_Stream
6060
{
61-
const double eps_pow=1.0;
6261
public:
6362
NLMS_Stream(int n,double mu,double mu_decay=1.0,double pow_decay=0.8)
6463
:LS_Stream(n),mutab(n),powtab(n),mu(mu)
@@ -71,44 +70,37 @@ class NLMS_Stream : public LS_Stream
7170
}
7271
}
7372

74-
#if defined(USE_AVX256)
7573
double calc_spow(const double *x,const double *powtab,std::size_t n)
7674
{
7775
double spow=0.0;
76+
7877
std::size_t i=0;
79-
if (n>=4) {
80-
__m256d sum_vec = _mm256_setzero_pd();
81-
for (; i + 4 <= n; i += 4) {
82-
__m256d x_vec = _mm256_loadu_pd(&x[i]);
83-
__m256d pow_vec = _mm256_load_pd(&powtab[i]);
84-
__m256d x_squared = _mm256_mul_pd(x_vec, x_vec);
85-
sum_vec = _mm256_fmadd_pd(pow_vec, x_squared, sum_vec);
86-
}
8778

88-
alignas(32) double buffer[4];
89-
_mm256_store_pd(buffer, sum_vec);
90-
spow = buffer[0] + buffer[1] + buffer[2] + buffer[3];
79+
if constexpr(SACGlobalCfg::USE_AVX2) {
80+
if (n>=4) {
81+
__m256d sum_vec = _mm256_setzero_pd();
82+
for (; i + 4 <= n; i += 4) {
83+
__m256d x_vec = _mm256_loadu_pd(&x[i]);
84+
__m256d pow_vec = _mm256_load_pd(&powtab[i]);
85+
__m256d x_squared = _mm256_mul_pd(x_vec, x_vec);
86+
sum_vec = _mm256_fmadd_pd(pow_vec, x_squared, sum_vec);
87+
}
88+
89+
alignas(32) double buffer[4];
90+
_mm256_store_pd(buffer, sum_vec);
91+
spow = buffer[0] + buffer[1] + buffer[2] + buffer[3];
92+
}
9193
}
9294

9395
for (;i<n;i++)
9496
spow += powtab[i] * (x[i] * x[i]);
9597
return spow;
9698
}
97-
#else
98-
double calc_spow(const double *x,const double *powtab,std::size_t n)
99-
{
100-
double spow=0.0;
101-
for (std::size_t i=0;i<n;i++) {
102-
spow+=powtab[i]*(x[i]*x[i]);
103-
}
104-
return spow;
105-
}
106-
#endif
10799

108100
void Update(double val) override
109101
{
110102
const double spow=calc_spow(x.data(),powtab.data(),n);
111-
const double wgrad=mu*(val-pred)*sum_powtab/(eps_pow+spow);
103+
const double wgrad=mu*(val-pred)*sum_powtab/(spow+SACGlobalCfg::NLMS_POW_EPS);
112104
for (int i=0;i<n;i++) {
113105
w[i]+=mutab[i]*(wgrad*x[i]);
114106
}
@@ -135,7 +127,7 @@ class LADADA_Stream : public LS_Stream
135127
for (int i=0;i<n;i++) {
136128
double const grad=serr*x[i];
137129
eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; //accumulate gradients
138-
double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
130+
double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
139131
w[i]+=mu*g;
140132
}
141133
x.push(val);
@@ -159,7 +151,7 @@ class LMSADA_Stream : public LS_Stream
159151
for (int i=0;i<n;i++) {
160152
double const grad=err*x[i]-nu*MathUtils::sgn(w[i]);
161153
eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; //accumulate gradients
162-
double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
154+
double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
163155
w[i]+=mu*g;
164156
}
165157
x.push(val);
@@ -180,7 +172,7 @@ class LMS {
180172
double Predict(const vec1D &inp)
181173
{
182174
x=inp;
183-
pred=slmath::dot_scalar(x,w);
175+
pred=slmath::dot(x,w);
184176
return pred;
185177
}
186178
virtual void Update(double)=0;
@@ -204,7 +196,7 @@ class LMS_ADA : public LMS
204196
double const grad=err*x[i] - nu*MathUtils::sgn(w[i]); // gradient + l1-regularization
205197

206198
eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; //accumulate gradients
207-
double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
199+
double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
208200
w[i]+=mu*g;
209201
}
210202
}
@@ -226,7 +218,7 @@ class LAD_ADA : public LMS
226218
for (int i=0;i<n;i++) {
227219
double const grad=serr*x[i];
228220
eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; //accumulate gradients
229-
double scaled_grad=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
221+
double scaled_grad=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
230222
w[i]+=mu*scaled_grad;
231223
}
232224
}
@@ -264,7 +256,7 @@ class HBR_ADA : public LMS
264256
for (int i=0;i<n;i++) {
265257
double const grad=grad_loss*x[i];
266258
eg[i]=beta*eg[i]+(1.0-beta)*grad*grad; //accumulate gradients
267-
const double g=grad*1.0/(sqrt(eg[i])+1E-5);// update weights
259+
const double g=grad*1.0/(sqrt(eg[i])+SACGlobalCfg::LMS_ADA_EPS);// update weights
268260
w[i]+=mu*g;
269261
}
270262

@@ -304,7 +296,7 @@ class LMS_ADAM : public LMS
304296
double n_hat=beta2*S[i]/(1.0-power_beta2);*/
305297
double m_hat=M[i]/(1.0-power_beta1);
306298
double n_hat=S[i]/(1.0-power_beta2);
307-
w[i]+=mu*m_hat/(sqrt(n_hat)+1E-5);
299+
w[i]+=mu*m_hat/(sqrt(n_hat)+SACGlobalCfg::LMS_ADA_EPS);
308300
}
309301
}
310302
private:

0 commit comments

Comments
 (0)