Skip to content

Commit

Permalink
Add group L1/L2 threshold. Also fix the max number of decomp levels
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrepaleo committed May 15, 2017
1 parent a1fd1c9 commit 304f6f9
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ libpdwtd.so: $(PDWTCORE)


clean:
rm -rf build
rm -rf build demo libpdwt*.so
4 changes: 2 additions & 2 deletions TODO.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ TODO LIST
- Modify inverse haar (and related) to avoid extra memcpy
- Nonseparable SWT re-computes the inverse filters at each inversion (for memory requirements). See if there is a workaround.
- Update filters coefficients in filters.cpp, to be compatible with PyWavelets.
- ISWT:
- SWT:
- Either return all app coeffs (can be interesting for multi-resolution)
- Or only store appcoeffs at last scale (for eg. Matlab only uses the last appcoeff for iswt)
- Or only store appcoeffs at last scale (for eg. Matlab only uses the last appcoeff for iswt). This is what is done by default.
- Exceptions to fail gracefully (especially from the Python side)
- More threshold types
- Compute norm related to threshold (eg. soft/L1, group-lasso, etc)
Expand Down
92 changes: 92 additions & 0 deletions src/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,65 @@ __global__ void w_kern_proj_linf_1d(DTYPE* c_d, DTYPE beta, int Nr, int Nc) {
}


/// Group soft thresholding of the detail coefficients (2D).
///
/// For every pixel, the coefficients (H, V, D) - plus A when c_a is given - are
/// treated as one group: the group is scaled in-place by max(1 - beta/||g||_2, 0),
/// where ||g||_2 is the Euclidean norm of the group at that pixel.
///
/// The approximation coefficients (A) are only part of the group at the last scale:
///   - at any other scale, the caller passes c_a == NULL
///   - at the last scale, the caller may pass c_a != NULL
///     (its size must then be the same as c_h, c_v, c_d)
///
/// Must be launched with a 2D grid covering at least (Nc, Nr) threads, Nr x Nc being
/// the size of the current coefficient arrays; out-of-range threads do nothing.
///
/// @param c_h, c_v, c_d        detail coefficients, Nr*Nc elements each, modified in-place
/// @param c_a                  approximation coefficients (Nr*Nc elements) or NULL
/// @param beta                 threshold
/// @param Nr, Nc               number of rows / columns of the coefficient arrays
/// @param do_thresh_appcoeffs  not read by the kernel (the c_a pointer decides); kept for the interface
__global__ void w_kern_group_soft_thresh(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs) {
    int gidx = threadIdx.x + blockIdx.x*blockDim.x;
    int gidy = threadIdx.y + blockIdx.y*blockDim.y;
    if (gidx >= Nc || gidy >= Nr) return;
    int tid = gidy*Nc + gidx;

    // Load each coefficient once; the loaded values are reused for the in-place update
    DTYPE val_h = c_h[tid];
    DTYPE val_v = c_v[tid];
    DTYPE val_d = c_d[tid];
    DTYPE val_a = 0;
    DTYPE norm = val_h*val_h + val_v*val_v + val_d*val_d;
    if (c_a != NULL) { // the approximation coefficient belongs to the group
        val_a = c_a[tid];
        norm += val_a*val_a;
    }
    // Overloaded sqrt()/max(): computation stays in DTYPE precision
    // (no truncation if DTYPE is double, no promotion to double if DTYPE is float)
    norm = sqrt(norm);

    DTYPE res = 0; // an all-zero group stays zero (also avoids dividing by zero)
    if (norm != 0) res = max((DTYPE) 1 - beta/norm, (DTYPE) 0);

    c_h[tid] = val_h * res;
    c_v[tid] = val_v * res;
    c_d[tid] = val_d * res;
    if (c_a != NULL) c_a[tid] = val_a * res;
}

/// Group soft thresholding of the coefficients (1D).
///
/// For every sample, the detail coefficient D - plus A when c_a is given - forms one
/// group, which is scaled in-place by max(1 - beta/||g||_2, 0).
/// Without c_a the group is the single value D, so this reduces to plain soft
/// thresholding (prefer the soft threshold in that case).
///
/// The approximation coefficients (A) are only part of the group at the last scale:
///   - at any other scale, the caller passes c_a == NULL
///   - at the last scale, the caller may pass c_a != NULL
///     (its size must then be the same as c_d)
///
/// Must be launched with a 2D grid covering at least (Nc, Nr) threads, Nr x Nc being
/// the size of the current coefficient arrays; out-of-range threads do nothing.
///
/// @param c_d                  detail coefficients, Nr*Nc elements, modified in-place
/// @param c_a                  approximation coefficients (Nr*Nc elements) or NULL
/// @param beta                 threshold
/// @param Nr, Nc               number of rows / columns of the coefficient arrays
/// @param do_thresh_appcoeffs  not read by the kernel (the c_a pointer decides); kept for the interface
__global__ void w_kern_group_soft_thresh_1d(DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs) {
    int gidx = threadIdx.x + blockIdx.x*blockDim.x;
    int gidy = threadIdx.y + blockIdx.y*blockDim.y;
    if (gidx >= Nc || gidy >= Nr) return;
    int tid = gidy*Nc + gidx;

    // Load each coefficient once; the loaded values are reused for the in-place update
    DTYPE val_d = c_d[tid];
    DTYPE val_a = 0;
    DTYPE norm = val_d*val_d;
    if (c_a != NULL) { // the approximation coefficient belongs to the group
        val_a = c_a[tid];
        norm += val_a*val_a;
    }
    // Overloaded sqrt()/max(): computation stays in DTYPE precision
    // (no truncation if DTYPE is double, no promotion to double if DTYPE is float)
    norm = sqrt(norm);

    DTYPE res = 0; // an all-zero group stays zero (also avoids dividing by zero)
    if (norm != 0) res = max((DTYPE) 1 - beta/norm, (DTYPE) 0);

    c_d[tid] = val_d * res;
    if (c_a != NULL) c_a[tid] = val_a * res;
}


/// Circular shift of the image (2D and 1D)
Expand Down Expand Up @@ -250,6 +308,40 @@ void w_call_proj_linf(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh
}


/// Apply the group soft thresholding to all the decomposition levels.
///
/// Layout of d_coeffs, as indexed below:
///   - d_coeffs[0]                  : approximation coefficients
///   - 2D: d_coeffs[3*i+1 .. 3*i+3] : the three detail arrays of level i
///   - 1D: d_coeffs[i+1]            : detail coefficients of level i
///
/// @param d_coeffs             device arrays of coefficients, modified in-place
/// @param beta                 threshold; if normalize > 0 it is divided by SQRT_2 at each level
/// @param winfos               transform description (Nr, Nc, nlevels, ndims, do_swt are read)
/// @param do_thresh_appcoeffs  if nonzero, the approximation coefficients join the group at the last level
/// @param normalize            if > 0, scale the threshold by 1/SQRT_2 per level
void w_call_group_soft_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize) {
    const int tpb = 16; // Threads per block (in each dimension)
    const dim3 n_threads_per_block = dim3(tpb, tpb, 1);
    const int do_swt = winfos.do_swt, nlevels = winfos.nlevels, ndims = winfos.ndims;
    int Nr = winfos.Nr, Nc = winfos.Nc;

    for (int i = 0; i < nlevels; i++) {
        // Size of the coefficients at the current level. With SWT the size is kept
        // at every level; otherwise w_div2 updates it (presumably halving - defined elsewhere).
        if (!do_swt) {
            if (ndims > 1) w_div2(&Nr);
            w_div2(&Nc);
        }
        if (normalize > 0) beta /= SQRT_2;

        // The approximation coefficients are only grouped with the details of the last level
        DTYPE* c_a = (do_thresh_appcoeffs && i == nlevels-1) ? d_coeffs[0] : NULL;

        dim3 n_blocks = dim3(w_iDivUp(Nc, tpb), w_iDivUp(Nr, tpb), 1);
        if (ndims > 1) {
            w_kern_group_soft_thresh<<<n_blocks, n_threads_per_block>>>(d_coeffs[3*i+1], d_coeffs[3*i+2], d_coeffs[3*i+3], c_a, beta, Nr, Nc, do_thresh_appcoeffs);
        }
        else {
            w_kern_group_soft_thresh_1d<<<n_blocks, n_threads_per_block>>>(d_coeffs[i+1], c_a, beta, Nr, Nc, do_thresh_appcoeffs);
        }
    }
}





void w_shrink(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs) {
Expand Down
5 changes: 4 additions & 1 deletion src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,18 @@ __global__ void w_kern_proj_linf(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE beta,
__global__ void w_kern_proj_linf_1d(DTYPE* c_d, DTYPE beta, int Nr, int Nc);
__global__ void w_kern_proj_linf_appcoeffs(DTYPE* c_a, DTYPE beta, int Nr, int Nc);


__global__ void w_kern_hard_thresh(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE beta, int Nr, int Nc);
__global__ void w_kern_hard_thresh_appcoeffs(DTYPE* c_a, DTYPE beta, int Nr, int Nc);

/// Group soft thresholding (2D): scales (c_h, c_v, c_d) - and c_a when it is not NULL - in-place, per pixel.
__global__ void w_kern_group_soft_thresh(DTYPE* c_h, DTYPE* c_v, DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs);
/// Group soft thresholding (1D): scales c_d - and c_a when it is not NULL - in-place, per sample.
__global__ void w_kern_group_soft_thresh_1d(DTYPE* c_d, DTYPE* c_a, DTYPE beta, int Nr, int Nc, int do_thresh_appcoeffs);

__global__ void w_kern_circshift(DTYPE* d_image, DTYPE* d_out, int Nr, int Nc, int sr, int sc);

void w_call_soft_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize);
void w_call_hard_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize);
void w_call_proj_linf(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs);
/// Launches the group soft thresholding kernels on every decomposition level of d_coeffs (in-place).
void w_call_group_soft_thresh(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs, int normalize);
void w_shrink(DTYPE** d_coeffs, DTYPE beta, w_info winfos, int do_thresh_appcoeffs);
void w_call_circshift(DTYPE* d_image, DTYPE* d_image2, w_info winfos, int sr, int sc, int inplace = 1);

Expand Down
2 changes: 1 addition & 1 deletion src/wt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ Wavelets::Wavelets(
int N;
if (ndim == 2) N = min(Nr, Nc);
else N = Nc;
int wmaxlev = w_ilog2(N/hlen);
int wmaxlev = w_ilog2(N/(hlen-1));
// TODO: remove this limitation
if (levels > wmaxlev) {
printf("Warning: required level (%d) is greater than the maximum possible level for %s (%d) on a %dx%d image.\n", winfos.nlevels, wname, wmaxlev, winfos.Nc, winfos.Nr);
Expand Down

0 comments on commit 304f6f9

Please sign in to comment.