Commit 91de9a6: matrix mul
NoNaeAbC committed Jan 25, 2021 (1 parent: 8b9d44e)
Showing 5 changed files with 262 additions and 53 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
*.s
build
cmake-build-debug
*.html
*.js
9 changes: 5 additions & 4 deletions CMakeLists.txt
@@ -5,9 +5,10 @@ find_package(Lua REQUIRED)

set(CMAKE_CXX_STANDARD 20)

add_executable(mathlib main.cpp aml_lua_binding.cpp aml_lua_binding.h)
add_executable(mathlib main.cpp)
add_executable(mathliblua aml_lua_binding.cpp aml_lua_binding.h testlua.cpp)

target_link_libraries(mathlib lua)
target_link_libraries(mathliblua lua)


if (NOT CMAKE_BUILD_TYPE)
@@ -18,5 +19,5 @@ set(CMAKE_C_COMPILER "clang")
set(CMAKE_CXX_COMPILER "clang++")

set(CMAKE_CXX_FLAGS "-Wall -Wextra")
set(CMAKE_CXX_FLAGS_DEBUG "-g -DDEBUG -O3 -march=skylake -ffast-math -static-libstdc++ -DX86_64")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=skylake -ffast-math -static-libstdc++ -DX86_64")
set(CMAKE_CXX_FLAGS_DEBUG "-g -DDEBUG -O3 -march=native -ffast-math")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native -ffast-math")
238 changes: 217 additions & 21 deletions amathlib.h
@@ -5,9 +5,10 @@
#ifndef MATH_LIB_A_MATH_LIB_H
#define MATH_LIB_A_MATH_LIB_H

#include <stdint.h>
#define DEBUG_TO_INDEX(row, column) ((((column) - 1) * 4) + ((row) - 1))

#include <cstdint>

#ifdef X86_64
//#define USE_AVX512
//#define USE_AVX
//#define USE_SSE
@@ -70,26 +71,51 @@

#endif //NDEBUG

#if defined(DEBUG)

#if defined(USE_AVX512F)
#define USE_AVX512
#endif
#if defined(USE_AVX512)
#define USE_AVX512F
#endif
#if defined(USE_FMA)
#define USE_AVX2
#endif
#if defined(USE_AVX2)
#define USE_AVX
#endif
#if defined(USE_AVX)
#define USE_SSE42
#endif
#if defined(USE_SSE42)
#define USE_SSE41
#endif
#if defined(USE_SSE41)
#define USE_SSE3
#endif
#if defined(USE_SSE3)
#define USE_SSE2
#endif
#if defined(USE_SSE2)
#define USE_SSE1
#endif
#if defined(USE_SSE1)
#define USE_SSE
#endif

#include <immintrin.h>
#endif // DEBUG
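
The block above chains the SIMD feature macros downward: defining any one tier also defines every older tier, so later code only has to test for the weakest feature it actually needs. Below is a compile-time sketch of that contract; the include path is hypothetical, and the defines mirror the old -DDEBUG -DX86_64 compile flags plus one SIMD tier. It is an illustration, not part of the commit.

// Sketch: with the cascade active, USE_AVX2 implies USE_AVX,
// USE_SSE42, USE_SSE41, USE_SSE3, USE_SSE2, USE_SSE1 and USE_SSE.
#define DEBUG
#define X86_64
#define USE_AVX2
#include "amathlib.h" // hypothetical include path

#if !defined(USE_SSE2) || !defined(USE_SSE)
#error "USE_AVX2 should have implied the whole SSE tier chain"
#endif

int main() { return 0; }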

#if defined(USE_AVX512)
#include <immintrin.h>
#endif


#if defined(USE_AVX)

#include <immintrin.h>

#define USE_SSE

#endif
#ifdef USE_SSE

#include <emmintrin.h>

#endif
#endif

#if defined(__EMSCRIPTEN__)
@@ -102,10 +128,8 @@
#endif

#if defined(__ARM_NEON)

#include <arm_neon.h>
#define USE_NEON

#endif

#include <cmath>
@@ -296,8 +320,8 @@ class VectorU8_4D {
VectorU8_4D(uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
v.c[0] = a;
v.c[1] = b;
v.c[3] = c;
v.c[4] = d;
v.c[2] = c;
v.c[3] = d;
}
};

@@ -309,7 +333,7 @@ class VectorDouble4D {
doublevec4 v{};

inline double operator[](uint32_t position) {
return v.c[position];//TODO maybe error handling
return v.c[position];
}

inline void operator+=(VectorDouble4D vec2) {
@@ -333,7 +357,7 @@ class VectorDouble4D {
VectorDouble4D ret;
ret.v.avx = _mm256_add_pd(v.avx, vec2.v.avx);
return ret;
#elif defined(USE_SSE) // SSE2
#elif defined(USE_SSE2)
VectorDouble4D ret;
ret.v.sse[0] = _mm_add_pd(v.sse[0], vec2.v.sse[0]);
ret.v.sse[1] = _mm_add_pd(v.sse[1], vec2.v.sse[1]);
@@ -351,10 +375,11 @@ class VectorDouble4D {
VectorDouble4D ret(a);
ret.v.avx = _mm256_add_pd(v.avx, ret.v.avx);
return ret;
#elif defined(USE_SSE) // SSE2
#elif defined(USE_SSE2)
VectorDouble4D ret(a);
ret.v.sse[0] = _mm_add_pd(v.sse[0], ret.v.sse[0]);
ret.v.sse[1] = _mm_add_pd(v.sse[1], ret.v.sse[1]);
return ret;
#else
VectorDouble4D ret(v.c[0] + a, v.c[1] + a, v.c[2] + a, v.c[3] + a);
return ret;
@@ -441,7 +466,6 @@ class VectorDouble4D {
}

inline void normalize() {
//TODO check length==0
double vecLength = 1 / length();
v.c[0] *= vecLength;
v.c[1] *= vecLength;
@@ -638,6 +662,7 @@ class VectorDouble4D {
};
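
Note that normalize() shown above multiplies by the reciprocal of length() with no guard, so a zero-length vector produces infinities. A guarded variant could look like the following sketch; it is a free-function illustration that assumes v and length() are publicly accessible, not the library's API.

// Sketch (assumption, not in this commit): skip the division when
// the vector has zero length instead of producing infinities.
inline void normalizeSafe(VectorDouble4D &vec) {
    double len = vec.length();
    if (len == 0.0) {
        return; // leave the zero vector untouched
    }
    double inv = 1.0 / len;
    vec.v.c[0] *= inv;
    vec.v.c[1] *= inv;
    vec.v.c[2] *= inv;
    vec.v.c[3] *= inv;
}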

class MatrixDouble4X4 {
public:
doublemat4x4 m;

inline VectorDouble4D operator[](uint32_t column) {
@@ -650,11 +675,182 @@ class MatrixDouble4X4 {
ret.v.c[2] = m.c[column * 4 + 2];
ret.v.c[3] = m.c[column * 4 + 3];
#endif
return ret;//TODO maybe error handling
return ret;
}

inline MatrixDouble4X4 *identity() {
m = (doublemat4x4) {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0};
return this;
}

inline MatrixDouble4X4 operator*(MatrixDouble4X4 b) {
MatrixDouble4X4 ret;
#if defined(USE_FMA)
/*
 * Column-major multiply: output column j is a linear combination of
 * this matrix's columns, weighted by the entries of b's column j:
 *   ret.col[j] = m.col[0]*b.c[4j+0] + m.col[1]*b.c[4j+1]
 *              + m.col[2]*b.c[4j+2] + m.col[3]*b.c[4j+3]
 * Each b entry is broadcast across a 256-bit register and
 * accumulated with fused multiply-adds.
 */
__m256d O0 = _mm256_set1_pd(b.m.c[0]);
__m256d O1 = _mm256_set1_pd(b.m.c[1]);
__m256d O2 = _mm256_set1_pd(b.m.c[2]);
__m256d O3 = _mm256_set1_pd(b.m.c[3]);

ret.m.avx[0] = _mm256_mul_pd(m.avx[0], O0);
ret.m.avx[0] = _mm256_fmadd_pd(m.avx[1], O1, ret.m.avx[0]);
ret.m.avx[0] = _mm256_fmadd_pd(m.avx[2], O2, ret.m.avx[0]);
ret.m.avx[0] = _mm256_fmadd_pd(m.avx[3], O3, ret.m.avx[0]);

__m256d O4 = _mm256_set1_pd(b.m.c[4]);
__m256d O5 = _mm256_set1_pd(b.m.c[5]);
__m256d O6 = _mm256_set1_pd(b.m.c[6]);
__m256d O7 = _mm256_set1_pd(b.m.c[7]);

ret.m.avx[1] = _mm256_mul_pd(m.avx[0], O4);
ret.m.avx[1] = _mm256_fmadd_pd(m.avx[1], O5, ret.m.avx[1]);
ret.m.avx[1] = _mm256_fmadd_pd(m.avx[2], O6, ret.m.avx[1]);
ret.m.avx[1] = _mm256_fmadd_pd(m.avx[3], O7, ret.m.avx[1]);

__m256d O8 = _mm256_set1_pd(b.m.c[8]);
__m256d O9 = _mm256_set1_pd(b.m.c[9]);
__m256d O10 = _mm256_set1_pd(b.m.c[10]);
__m256d O11 = _mm256_set1_pd(b.m.c[11]);

ret.m.avx[2] = _mm256_mul_pd(m.avx[0], O8);
ret.m.avx[2] = _mm256_fmadd_pd(m.avx[1], O9, ret.m.avx[2]);
ret.m.avx[2] = _mm256_fmadd_pd(m.avx[2], O10, ret.m.avx[2]);
ret.m.avx[2] = _mm256_fmadd_pd(m.avx[3], O11, ret.m.avx[2]);

__m256d O12 = _mm256_set1_pd(b.m.c[12]);
__m256d O13 = _mm256_set1_pd(b.m.c[13]);
__m256d O14 = _mm256_set1_pd(b.m.c[14]);
__m256d O15 = _mm256_set1_pd(b.m.c[15]);

ret.m.avx[3] = _mm256_mul_pd(m.avx[0], O12);
ret.m.avx[3] = _mm256_fmadd_pd(m.avx[1], O13, ret.m.avx[3]);
ret.m.avx[3] = _mm256_fmadd_pd(m.avx[2], O14, ret.m.avx[3]);
ret.m.avx[3] = _mm256_fmadd_pd(m.avx[3], O15, ret.m.avx[3]);

#elif defined(USE_SSE2)

// output column 0: rows 0..1, then rows 2..3
ret.m.sse[0] = _mm_mul_pd(m.sse[0], _mm_set1_pd(b.m.c[0]));
__m128d cache = _mm_mul_pd(m.sse[2], _mm_set1_pd(b.m.c[1]));
ret.m.sse[0] = _mm_add_pd(cache, ret.m.sse[0]);
cache = _mm_mul_pd(m.sse[4], _mm_set1_pd(b.m.c[2]));
ret.m.sse[0] = _mm_add_pd(cache, ret.m.sse[0]);
cache = _mm_mul_pd(m.sse[6], _mm_set1_pd(b.m.c[3]));
ret.m.sse[0] = _mm_add_pd(cache, ret.m.sse[0]);

ret.m.sse[1] = _mm_mul_pd(m.sse[1], _mm_set1_pd(b.m.c[0]));
cache = _mm_mul_pd(m.sse[3], _mm_set1_pd(b.m.c[1]));
ret.m.sse[1] = _mm_add_pd(cache, ret.m.sse[1]);
cache = _mm_mul_pd(m.sse[5], _mm_set1_pd(b.m.c[2]));
ret.m.sse[1] = _mm_add_pd(cache, ret.m.sse[1]);
cache = _mm_mul_pd(m.sse[7], _mm_set1_pd(b.m.c[3]));
ret.m.sse[1] = _mm_add_pd(cache, ret.m.sse[1]);

// output column 1
ret.m.sse[2] = _mm_mul_pd(m.sse[0], _mm_set1_pd(b.m.c[4]));
cache = _mm_mul_pd(m.sse[2], _mm_set1_pd(b.m.c[5]));
ret.m.sse[2] = _mm_add_pd(cache, ret.m.sse[2]);
cache = _mm_mul_pd(m.sse[4], _mm_set1_pd(b.m.c[6]));
ret.m.sse[2] = _mm_add_pd(cache, ret.m.sse[2]);
cache = _mm_mul_pd(m.sse[6], _mm_set1_pd(b.m.c[7]));
ret.m.sse[2] = _mm_add_pd(cache, ret.m.sse[2]);

ret.m.sse[3] = _mm_mul_pd(m.sse[1], _mm_set1_pd(b.m.c[4]));
cache = _mm_mul_pd(m.sse[3], _mm_set1_pd(b.m.c[5]));
ret.m.sse[3] = _mm_add_pd(cache, ret.m.sse[3]);
cache = _mm_mul_pd(m.sse[5], _mm_set1_pd(b.m.c[6]));
ret.m.sse[3] = _mm_add_pd(cache, ret.m.sse[3]);
cache = _mm_mul_pd(m.sse[7], _mm_set1_pd(b.m.c[7]));
ret.m.sse[3] = _mm_add_pd(cache, ret.m.sse[3]);

// output column 2
ret.m.sse[4] = _mm_mul_pd(m.sse[0], _mm_set1_pd(b.m.c[8]));
cache = _mm_mul_pd(m.sse[2], _mm_set1_pd(b.m.c[9]));
ret.m.sse[4] = _mm_add_pd(cache, ret.m.sse[4]);
cache = _mm_mul_pd(m.sse[4], _mm_set1_pd(b.m.c[10]));
ret.m.sse[4] = _mm_add_pd(cache, ret.m.sse[4]);
cache = _mm_mul_pd(m.sse[6], _mm_set1_pd(b.m.c[11]));
ret.m.sse[4] = _mm_add_pd(cache, ret.m.sse[4]);

ret.m.sse[5] = _mm_mul_pd(m.sse[1], _mm_set1_pd(b.m.c[8]));
cache = _mm_mul_pd(m.sse[3], _mm_set1_pd(b.m.c[9]));
ret.m.sse[5] = _mm_add_pd(cache, ret.m.sse[5]);
cache = _mm_mul_pd(m.sse[5], _mm_set1_pd(b.m.c[10]));
ret.m.sse[5] = _mm_add_pd(cache, ret.m.sse[5]);
cache = _mm_mul_pd(m.sse[7], _mm_set1_pd(b.m.c[11]));
ret.m.sse[5] = _mm_add_pd(cache, ret.m.sse[5]);

// output column 3
ret.m.sse[6] = _mm_mul_pd(m.sse[0], _mm_set1_pd(b.m.c[12]));
cache = _mm_mul_pd(m.sse[2], _mm_set1_pd(b.m.c[13]));
ret.m.sse[6] = _mm_add_pd(cache, ret.m.sse[6]);
cache = _mm_mul_pd(m.sse[4], _mm_set1_pd(b.m.c[14]));
ret.m.sse[6] = _mm_add_pd(cache, ret.m.sse[6]);
cache = _mm_mul_pd(m.sse[6], _mm_set1_pd(b.m.c[15]));
ret.m.sse[6] = _mm_add_pd(cache, ret.m.sse[6]);

ret.m.sse[7] = _mm_mul_pd(m.sse[1], _mm_set1_pd(b.m.c[12]));
cache = _mm_mul_pd(m.sse[3], _mm_set1_pd(b.m.c[13]));
ret.m.sse[7] = _mm_add_pd(cache, ret.m.sse[7]);
cache = _mm_mul_pd(m.sse[5], _mm_set1_pd(b.m.c[14]));
ret.m.sse[7] = _mm_add_pd(cache, ret.m.sse[7]);
cache = _mm_mul_pd(m.sse[7], _mm_set1_pd(b.m.c[15]));
ret.m.sse[7] = _mm_add_pd(cache, ret.m.sse[7]);
#else
ret.m.c[0] = m.c[0] * b.m.c[0] + m.c[4] * b.m.c[1] + m.c[8] * b.m.c[2] + m.c[12] * b.m.c[3];// c11 = a11 * b11 + a12 * b21 + ...
ret.m.c[1] = m.c[1] * b.m.c[0] + m.c[5] * b.m.c[1] + m.c[9] * b.m.c[2] + m.c[13] * b.m.c[3];// c21 = a21 * b11 + a22 * b21 + ...
ret.m.c[2] = m.c[2] * b.m.c[0] + m.c[6] * b.m.c[1] + m.c[10] * b.m.c[2] + m.c[14] * b.m.c[3];
ret.m.c[3] = m.c[3] * b.m.c[0] + m.c[7] * b.m.c[1] + m.c[11] * b.m.c[2] + m.c[15] * b.m.c[3];
ret.m.c[4] = m.c[0] * b.m.c[4] + m.c[4] * b.m.c[5] + m.c[8] * b.m.c[6] + m.c[12] * b.m.c[7];// c12 = a11 * b12 + a12 * b22 + ...
ret.m.c[5] = m.c[1] * b.m.c[4] + m.c[5] * b.m.c[5] + m.c[9] * b.m.c[6] + m.c[13] * b.m.c[7];
ret.m.c[6] = m.c[2] * b.m.c[4] + m.c[6] * b.m.c[5] + m.c[10] * b.m.c[6] + m.c[14] * b.m.c[7];
ret.m.c[7] = m.c[3] * b.m.c[4] + m.c[7] * b.m.c[5] + m.c[11] * b.m.c[6] + m.c[15] * b.m.c[7];
ret.m.c[8] = m.c[0] * b.m.c[8] + m.c[4] * b.m.c[9] + m.c[8] * b.m.c[10] + m.c[12] * b.m.c[11];
ret.m.c[9] = m.c[1] * b.m.c[8] + m.c[5] * b.m.c[9] + m.c[9] * b.m.c[10] + m.c[13] * b.m.c[11];
ret.m.c[10] = m.c[2] * b.m.c[8] + m.c[6] * b.m.c[9] + m.c[10] * b.m.c[10] + m.c[14] * b.m.c[11];
ret.m.c[11] = m.c[3] * b.m.c[8] + m.c[7] * b.m.c[9] + m.c[11] * b.m.c[10] + m.c[15] * b.m.c[11];
ret.m.c[12] = m.c[0] * b.m.c[12] + m.c[4] * b.m.c[13] + m.c[8] * b.m.c[14] + m.c[12] * b.m.c[15];
ret.m.c[13] = m.c[1] * b.m.c[12] + m.c[5] * b.m.c[13] + m.c[9] * b.m.c[14] + m.c[13] * b.m.c[15];
ret.m.c[14] = m.c[2] * b.m.c[12] + m.c[6] * b.m.c[13] + m.c[10] * b.m.c[14] + m.c[14] * b.m.c[15];
ret.m.c[15] = m.c[3] * b.m.c[12] + m.c[7] * b.m.c[13] + m.c[11] * b.m.c[14] + m.c[15] * b.m.c[15];
#endif
return ret;
}

inline VectorDouble4D operator*(VectorDouble4D b) {
VectorDouble4D ret;
ret.v.c[0] = m.c[0] * b.v.c[0] + m.c[4] * b.v.c[1] + m.c[8] * b.v.c[2] + m.c[12] * b.v.c[3];
ret.v.c[1] = m.c[1] * b.v.c[0] + m.c[5] * b.v.c[1] + m.c[9] * b.v.c[2] + m.c[13] * b.v.c[3];
ret.v.c[2] = m.c[2] * b.v.c[0] + m.c[6] * b.v.c[1] + m.c[10] * b.v.c[2] + m.c[14] * b.v.c[3];
ret.v.c[3] = m.c[3] * b.v.c[0] + m.c[7] * b.v.c[1] + m.c[11] * b.v.c[2] + m.c[15] * b.v.c[3];
return ret;
}
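
The matrix-vector product above is scalar on every build. Under the same column-major layout it could reuse the broadcast-and-FMA scheme of the matrix-matrix path; the following free-function sketch assumes a USE_FMA build and public members, and is not part of this commit.

// Sketch (assumption): ret = a.col0*b[0] + a.col1*b[1]
//                           + a.col2*b[2] + a.col3*b[3]
inline VectorDouble4D mulMatVecFma(MatrixDouble4X4 a, VectorDouble4D b) {
    VectorDouble4D ret;
    __m256d acc = _mm256_mul_pd(a.m.avx[0], _mm256_set1_pd(b.v.c[0]));
    acc = _mm256_fmadd_pd(a.m.avx[1], _mm256_set1_pd(b.v.c[1]), acc);
    acc = _mm256_fmadd_pd(a.m.avx[2], _mm256_set1_pd(b.v.c[2]), acc);
    acc = _mm256_fmadd_pd(a.m.avx[3], _mm256_set1_pd(b.v.c[3]), acc);
    ret.v.avx = acc;
    return ret;
}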

inline MatrixDouble4X4() {
m = (doublemat4x4) {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
}

inline MatrixDouble4X4(VectorDouble4D a, VectorDouble4D b, VectorDouble4D c, VectorDouble4D d) {
m.c[0] = a.v.c[0];
m.c[1] = a.v.c[1];
m.c[2] = a.v.c[2];
m.c[3] = a.v.c[3];
m.c[4] = b.v.c[0];
m.c[5] = b.v.c[1];
m.c[6] = b.v.c[2];
m.c[7] = b.v.c[3];
m.c[8] = c.v.c[0];
m.c[9] = c.v.c[1];
m.c[10] = c.v.c[2];
m.c[11] = c.v.c[3];
m.c[12] = d.v.c[0];
m.c[13] = d.v.c[1];
m.c[14] = d.v.c[2];
m.c[15] = d.v.c[3];
}
};
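
A minimal usage sketch for the new multiply (assuming amathlib.h is on the include path): multiplying by the identity must reproduce the original matrix exactly, since the weights are only 0.0 and 1.0, which makes it a cheap smoke test for every SIMD path above.

#include "amathlib.h" // hypothetical include path
#include <cassert>

int main() {
    MatrixDouble4X4 a(VectorDouble4D(1.0, 2.0, 3.0, 4.0),
                      VectorDouble4D(5.0, 6.0, 7.0, 8.0),
                      VectorDouble4D(9.0, 10.0, 11.0, 12.0),
                      VectorDouble4D(13.0, 14.0, 15.0, 16.0));
    MatrixDouble4X4 i;
    i.identity();
    MatrixDouble4X4 c = a * i; // should equal a
    for (uint32_t col = 0; col < 4; col++) {
        for (uint32_t row = 0; row < 4; row++) {
            assert(c[col][row] == a[col][row]);
        }
    }
    return 0;
}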

@@ -663,11 +859,11 @@ class VectorDouble8D {
doublevec8 v;

inline double operator[](uint32_t position) {
return v.c[position];//TODO maybe error handling
return v.c[position];
}

inline void operator+=(VectorDouble8D vec2) {
#if defined(USE_AVX512) // AVX512F OR KNCNI
#if defined(USE_AVX512F) || defined(KNCNI)
v.avx512 = _mm512_add_pd(v.avx512, vec2.v.avx512);
#elif defined(USE_AVX)
v.avx[0] = _mm256_add_pd(v.avx[0], vec2.v.avx[0]);
