From 304f6f959de9dcdc6a6c0b594e54ceed35522ecf Mon Sep 17 00:00:00 2001 From: Pierre Paleo Date: Mon, 15 May 2017 12:00:13 +0200 Subject: [PATCH] Add group L1/L2 threshold. Also fix the max number of decomp levels --- Makefile | 2 +- TODO.txt | 4 +-- src/common.cu | 92 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/common.h | 5 ++- src/wt.cu | 2 +- 5 files changed, 100 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index d013d83..10b383f 100755 --- a/Makefile +++ b/Makefile @@ -41,4 +41,4 @@ libpdwtd.so: $(PDWTCORE) clean: - rm -rf build + rm -rf build demo libpdwt*.so diff --git a/TODO.txt b/TODO.txt index 62717ca..e75105d 100644 --- a/TODO.txt +++ b/TODO.txt @@ -6,9 +6,9 @@ TODO LIST - Modify inverse haar (and related) to avoid extra memcpy - Nonseparable SWT re-computes the inverse filters at each inversion (for memory requirements). See if there is a workaround. - Update filters coefficients in filters.cpp, to be compatible with PyWavelets. -- ISWT: +- SWT: - Either return all app coeffs (can be interesing for multi-resolution) - - Or only store appcoeffs at last scale (for eg. Matlab only uses the last appcoeff for iswt) + - Or only store appcoeffs at last scale (for eg. Matlab only uses the last appcoeff for iswt). This is what is done by default. - Exceptions to fail gracefully (especially from the Python side) - More threshold types - Compute norm related to threshold (eg. 
soft/L1, group-lasso, etc) diff --git a/src/common.cu b/src/common.cu index 0c880af..8bcd67e 100644 --- a/src/common.cu +++ b/src/common.cu @@ -137,7 +137,65 @@ __global__ void w_kern_proj_linf_1d(DTYPE* c_d, DTYPE beta, int Nr, int Nc) { } +/// group soft thresholding the detail coefficients (2D) +/// If do_thresh_appcoeffs, the appcoeff (A) is only used at the last scale: +/// - At any other scale, c_a == NULL +/// - At the last scale, c_a != NULL (i.e. its size is the same as c_h, c_v, c_d) +/// Must be launched with block size (Nc, Nr) : the size of the current coefficient vector +__global__ void w_kern_group_soft_thresh(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs) { + int gidx = threadIdx.x + blockIdx.x*blockDim.x; + int gidy = threadIdx.y + blockIdx.y*blockDim.y; + if (gidx < Nc && gidy < Nr) { + int tid = gidy*Nc + gidx; + DTYPE val_h = 0.0f, val_v = 0.0f, val_d = 0.0f, val_a = 0.0f; + DTYPE norm = 0, res = 0; + + val_h = c_h[tid]; + val_v = c_v[tid]; + val_d = c_d[tid]; + norm = val_h*val_h + val_v*val_v + val_d*val_d; + + if (c_a != NULL) { // SWT + val_a = c_a[tid]; + norm += val_a*val_a; + } + norm = sqrtf(norm); + if (norm == 0) res = 0; + else res = max(1 - beta/norm, 0.0); + c_h[tid] *= res; + c_v[tid] *= res; + c_d[tid] *= res; + if (c_a != NULL) c_a[tid] *= res; + } +} +/// group soft thresholding of the coefficients (1D) +/// If do_thresh_appcoeffs, the appcoeff (A) is only used at the last scale: +/// - At any other scale, c_a == NULL +/// - At the last scale, c_a != NULL (i.e. its size is the same as c_d) +/// Must be launched with block size (Nc, Nr) : the size of the current coefficient vector +__global__ void w_kern_group_soft_thresh_1d(DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs) { + int gidx = threadIdx.x + blockIdx.x*blockDim.x; + int gidy = threadIdx.y + blockIdx.y*blockDim.y; + if (gidx < Nc && gidy < Nr) { + int tid = gidy*Nc + gidx; + DTYPE val_d = 0.0f, val_a = 
0.0f; + DTYPE norm = 0, res = 0; + + val_d = c_d[tid]; + norm = val_d*val_d; // does not make much sense to use DWT_1D + group_soft_thresh (use soft_thresh) + + if (c_a != NULL) { // SWT + val_a = c_a[tid]; + norm += val_a*val_a; + } + norm = sqrtf(norm); + if (norm == 0) res = 0; + else res = max(1 - beta/norm, 0.0); + c_d[tid] *= res; + if (c_a != NULL) c_a[tid] *= res; + } +} /// Circular shift of the image (2D and 1D) @@ -250,6 +308,40 @@ void w_call_proj_linf(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh } +void w_call_group_soft_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize) { + int tpb = 16; // Threads per block + dim3 n_threads_per_block = dim3(tpb, tpb, 1); + dim3 n_blocks; + int Nr = winfos.Nr, Nc = winfos.Nc, do_swt = winfos.do_swt, nlevels = winfos.nlevels, ndims = winfos.ndims; + int Nr2 = Nr, Nc2 = Nc; + if (!do_swt) { + if (ndims > 1) w_div2(&Nr2); + w_div2(&Nc2); + } + //~ if (do_thresh_appcoeffs) { + //~ DTYPE beta2 = beta; + //~ if (normalize > 0) { // beta2 = beta/sqrt(2)^nlevels + //~ int nlevels2 = nlevels/2; + //~ beta2 /= (1 << nlevels2); + //~ if (nlevels2 *2 != nlevels) beta2 /= SQRT_2; + //~ } + //~ n_blocks = dim3(w_iDivUp(Nc2, tpb), w_iDivUp(Nr2, tpb), 1); + //~ w_kern_soft_thresh_appcoeffs<<<n_blocks, n_threads_per_block>>>(d_coeffs[0], beta2, Nr2, Nc2); + //~ } + for (int i = 0; i < nlevels; i++) { + if (!do_swt) { + if (ndims > 1) w_div2(&Nr); + w_div2(&Nc); + } + if (normalize > 0) beta /= SQRT_2; + n_blocks = dim3(w_iDivUp(Nc, tpb), w_iDivUp(Nr, tpb), 1); + if (ndims > 1) w_kern_group_soft_thresh<<<n_blocks, n_threads_per_block>>>(d_coeffs[3*i+1], d_coeffs[3*i+2], d_coeffs[3*i+3], ((do_thresh_appcoeffs && i == nlevels-1) ? d_coeffs[0]: NULL), beta, Nr, Nc, do_thresh_appcoeffs); + else w_kern_group_soft_thresh_1d<<<n_blocks, n_threads_per_block>>>(d_coeffs[i+1], ((do_thresh_appcoeffs && i == nlevels-1) ? 
d_coeffs[0]: NULL), beta, Nr, Nc, do_thresh_appcoeffs); + } +} + + + void w_shrink(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs) { diff --git a/src/common.h b/src/common.h index 1cf4a76..ead55c8 100644 --- a/src/common.h +++ b/src/common.h @@ -44,15 +44,18 @@ __global__ void w_kern_proj_linf(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE beta, __global__ void w_kern_proj_linf_1d(DTYPE* c_d, DTYPE beta, int Nr, int Nc); __global__ void w_kern_proj_linf_appcoeffs(DTYPE* c_a, DTYPE beta, int Nr, int Nc); - __global__ void w_kern_hard_thresh(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE beta, int Nr, int Nc); __global__ void w_kern_hard_thresh_appcoeffs(DTYPE* c_a, DTYPE beta, int Nr, int Nc); +__global__ void w_kern_group_soft_thresh(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs); +__global__ void w_kern_group_soft_thresh_1d(DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs); + __global__ void w_kern_circshift(DTYPE* d_image, DTYPE* d_out, int Nr, int Nc, int sr, int sc); void w_call_soft_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize); void w_call_hard_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize); void w_call_proj_linf(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs); +void w_call_group_soft_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize); void w_shrink(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs); void w_call_circshift(DTYPE* d_image, DTYPE* d_image2, w_info winfos, int sr, int sc, int inplace = 1); diff --git a/src/wt.cu b/src/wt.cu index e14282d..a10c404 100755 --- a/src/wt.cu +++ b/src/wt.cu @@ -156,7 +156,7 @@ Wavelets::Wavelets( int N; if (ndim == 2) N = min(Nr, Nc); else N = Nc; - int wmaxlev = w_ilog2(N/hlen); + int wmaxlev = w_ilog2(N/(hlen-1)); // TODO: remove this limitation if (levels 
> wmaxlev) { printf("Warning: required level (%d) is greater than the maximum possible level for %s (%d) on a %dx%d image.\n", winfos.nlevels, wname, wmaxlev, winfos.Nc, winfos.Nr);