Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions src/lib/openjp2/dwt.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
#if (defined(__AVX2__) || defined(__AVX512F__))
#include <immintrin.h>
#endif
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif

#if defined(__GNUC__)
#pragma GCC poison malloc calloc realloc free
Expand All @@ -73,7 +76,7 @@
/** Number of int32 values in a AVX2 register */
#define VREG_INT_COUNT 8
#else
/** Number of int32 values in a SSE2 register */
/** Number of int32 values in a SSE2 or NEON register */
#define VREG_INT_COUNT 4
#endif

Expand Down Expand Up @@ -699,7 +702,7 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
#endif
}

#if (defined(__SSE2__) || defined(__AVX2__) || defined(__AVX512F__)) && !defined(STANDARD_SLOW_VERSION)
#if (defined(__ARM_NEON) || defined(__SSE2__) || defined(__AVX2__) || defined(__AVX512F__)) && !defined(STANDARD_SLOW_VERSION)

/* Conveniency macros to improve the readability of the formulas */
#if defined(__AVX512F__)
Expand All @@ -722,6 +725,16 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
#define ADD(x,y) _mm256_add_epi32((x),(y))
#define SUB(x,y) _mm256_sub_epi32((x),(y))
#define SAR(x,y) _mm256_srai_epi32((x),(y))
#elif defined(__ARM_NEON)
#define VREG int32x4_t
#define LOAD_CST(x) vdupq_n_s32(x)
#define LOAD(x) vld1q_s32((const int32_t*)(x))
#define LOADU(x) vld1q_s32((const int32_t*)(x))
#define STORE(x,y) vst1q_s32((int32_t*)(x),(y))
#define STOREU(x,y) vst1q_s32((int32_t*)(x),(y))
#define ADD(x,y) vaddq_s32((x),(y))
#define SUB(x,y) vsubq_s32((x),(y))
#define SAR(x,y) vshrq_n_s32((x),(y))
#else
#define VREG __m128i
#define LOAD_CST(x) _mm_set1_epi32(x)
Expand Down Expand Up @@ -755,9 +768,9 @@ void opj_idwt53_v_final_memcpy(OPJ_INT32* tiledp_col,
}
}

/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
* 16 in AVX2, when top-most pixel is on even coordinate */
static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2 and NEON,
* or 16 in AVX2, when top-most pixel is on even coordinate */
static void opj_idwt53_v_cas0_mcols_SIMD(
OPJ_INT32* tmp,
const OPJ_INT32 sn,
const OPJ_INT32 len,
Expand Down Expand Up @@ -862,9 +875,9 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
}


/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
* 16 in AVX2, when top-most pixel is on odd coordinate */
static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2 and NEON,
* or 16 in AVX2, when top-most pixel is on odd coordinate */
static void opj_idwt53_v_cas1_mcols_SIMD(
OPJ_INT32* tmp,
const OPJ_INT32 sn,
const OPJ_INT32 len,
Expand Down Expand Up @@ -1104,11 +1117,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
if (dwt->cas == 0) {
/* If len == 1, unmodified value */

#if (defined(__SSE2__) || defined(__AVX2__))
#if (defined(__ARM_NEON) || defined(__SSE2__) || defined(__AVX2__))
if (len > 1 && nb_cols == PARALLEL_COLS_53) {
/* Same as below general case, except that thanks to SSE2/AVX2 */
/* Same as below general case, except that thanks to SIMD */
/* we can efficiently process 8/16 columns in parallel */
opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
opj_idwt53_v_cas0_mcols_SIMD(dwt->mem, sn, len, tiledp_col, stride);
return;
}
#endif
Expand Down Expand Up @@ -1147,11 +1160,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
return;
}

#if (defined(__SSE2__) || defined(__AVX2__))
#if (defined(__ARM_NEON) || defined(__SSE2__) || defined(__AVX2__))
if (len > 2 && nb_cols == PARALLEL_COLS_53) {
/* Same as below general case, except that thanks to SSE2/AVX2 */
/* Same as below general case, except that thanks to SIMD */
/* we can efficiently process 8/16 columns in parallel */
opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
opj_idwt53_v_cas1_mcols_SIMD(dwt->mem, sn, len, tiledp_col, stride);
return;
}
#endif
Expand Down
Loading