Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 15 additions & 95 deletions nanshe_ipython.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@
"source": [
"from nanshe_workflow.par import halo_block_parallel\n",
"\n",
"from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data\n",
"from nanshe_workflow.imp2 import extract_f0, wavelet_transform, normalize_data\n",
"\n",
"from nanshe_workflow.par import halo_block_generate_dictionary_parallel\n",
"from nanshe_workflow.imp import block_postprocess_data_parallel\n",
Expand Down Expand Up @@ -993,11 +993,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Normalize Data\n",
"### Project\n",
"\n",
"* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n",
"* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).\n",
"* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)."
"* `proj_type` (`str`): type of projection to take."
]
},
{
Expand All @@ -1007,12 +1006,11 @@
"outputs": [],
"source": [
"block_frames = 40\n",
"block_space = 300\n",
"norm_frames = 100\n",
"proj_type = \"max\"\n",
"\n",
"\n",
"with get_executor(client) as executor:\n",
" dask_io_remove(data_basename + postfix_norm + zarr_ext, executor)\n",
" dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n",
"\n",
"\n",
" with open_zarr(data_basename + postfix_wt + zarr_ext, \"r\") as f:\n",
Expand All @@ -1027,106 +1025,28 @@
" da_imgs_flt.dtype.itemsize >= 4):\n",
" da_imgs_flt = da_imgs_flt.astype(np.float32)\n",
"\n",
" da_imgs_flt_mins = da_imgs_flt.min(\n",
" axis=tuple(irange(1, da_imgs_flt.ndim)),\n",
" keepdims=True\n",
" )\n",
"\n",
" da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins\n",
"\n",
" da_result = renormalized_images(da_imgs_flt_shift)\n",
" da_result = da_imgs\n",
" if proj_type == \"max\":\n",
" da_result = da_result.max(axis=0, keepdims=True)\n",
" elif proj_type == \"std\":\n",
" da_result = da_result.std(axis=0, keepdims=True)\n",
"\n",
" # Store denoised data\n",
" dask_store_zarr(data_basename + postfix_norm + zarr_ext, [\"images\"], [da_result], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_norm + zarr_ext, executor)\n",
"\n",
"\n",
"if __IPYTHON__:\n",
" result_image_stack = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n",
"\n",
" mplsv = plt.figure(FigureClass=MPLViewer)\n",
" mplsv.set_images(\n",
" result_image_stack,\n",
" vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n",
" vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dictionary Learning\n",
"\n",
"* `n_components` (`int`): number of basis images in the dictionary.\n",
"* `batchsize` (`int`): minibatch size to use.\n",
"* `iters` (`int`): number of iterations to run before getting dictionary.\n",
"* `lambda1` (`float`): weight for L<sup>1</sup> sparisty enforcement on sparse code.\n",
"* `lambda2` (`float`): weight for L<sup>2</sup> sparisty enforcement on sparse code.\n",
"\n",
"<br>\n",
"* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n",
"* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_components = 50\n",
"batchsize = 256\n",
"iters = 100\n",
"lambda1 = 0.2\n",
"lambda2 = 0.0\n",
"\n",
"block_frames = 51\n",
"norm_frames = 100\n",
" dask_store_zarr(data_basename + postfix_dict + zarr_ext, [\"images\"], [da_result], executor)\n",
"\n",
"\n",
"with get_executor(client) as executor:\n",
" dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n",
"\n",
"\n",
"result = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n",
"block_shape = (block_frames,) + result.shape[1:]\n",
"with open_zarr(data_basename + postfix_dict + zarr_ext, \"w\") as f2:\n",
" new_result = f2.create_dataset(\"images\", shape=(n_components,) + result.shape[1:], dtype=result.dtype, chunks=True)\n",
"\n",
" result = par_generate_dictionary(block_shape)(\n",
" result,\n",
" n_components=n_components,\n",
" out=new_result,\n",
" **{\"sklearn.decomposition.dict_learning_online\" : {\n",
" \"n_jobs\" : 1,\n",
" \"n_iter\" : iters,\n",
" \"batch_size\" : batchsize,\n",
" \"alpha\" : lambda1\n",
" }\n",
" }\n",
" )\n",
"\n",
" result_j = f2.create_dataset(\"images_j\", shape=new_result.shape, dtype=numpy.uint16, chunks=True)\n",
" par_norm_layer(num_frames=norm_frames)(result, out=result_j)\n",
"\n",
"\n",
"with get_executor(client) as executor:\n",
" zip_zarr(data_basename + postfix_dict + zarr_ext, executor)\n",
"\n",
"\n",
"if __IPYTHON__:\n",
" result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")\n",
" result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")[...][...]\n",
"\n",
" mplsv = plt.figure(FigureClass=MPLViewer)\n",
" mplsv.set_images(\n",
" result_image_stack,\n",
" vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n",
" vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n",
" )\n",
" mplsv.time_nav.stime.label.set_text(\"Basis Image\")"
" vmin=result_image_stack.min(),\n",
" vmax=result_image_stack.max()\n",
" )"
]
},
{
Expand Down