diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index fed55c9..0000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,16 +0,0 @@ -[bumpversion] -current_version = 0.0.8.dev0 -parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\.((?P[a-z]*)(?P\d*)))? -serialize = - {major}.{minor}.{patch}.{release}{devbuild} - {major}.{minor}.{patch} - -[bumpversion:part:release] -optional_value = rel -values = - dev - rel - -[bumpversion:file:aicsmlsegment/version.py] -search = {current_version} -replace = {new_version} diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..8f54c3b Binary files /dev/null and b/.github/ISSUE_TEMPLATE/bug_report.md differ diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..3997765 Binary files /dev/null and b/.github/ISSUE_TEMPLATE/feature_request.md differ diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..664b90e Binary files /dev/null and b/.github/PULL_REQUEST_TEMPLATE.md differ diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml new file mode 100644 index 0000000..6d867a2 --- /dev/null +++ b/.github/workflows/build-docs.yml @@ -0,0 +1,34 @@ +name: Documentation + +on: + push: + branches: + - main + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2.3.1 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Install Dependencies + run: | + pip install --upgrade pip + pip install .[dev] + - name: Generate Docs + run: | + make gen-docs + touch docs/_build/html/.nojekyll + - name: Publish Docs + uses: JamesIves/github-pages-deploy-action@3.7.1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + BASE_BRANCH: main # The branch the action should deploy from. + BRANCH: gh-pages # The branch the action should deploy to. + FOLDER: docs/_build/html/ # The folder the action should deploy. + diff --git a/.github/workflows/build-main.yml b/.github/workflows/build-main.yml new file mode 100644 index 0000000..8a05c1a --- /dev/null +++ b/.github/workflows/build-main.yml @@ -0,0 +1,79 @@ +name: Build Main + +on: + push: + branches: + - main + schedule: + # + # https://pubs.opengroup.org/onlinepubs/9699919799/utilities/crontab.html#tag_20_25_07 + # Run every Monday at 18:00:00 UTC (Monday at 10:00:00 PST) + - cron: '0 18 * * 1' + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: [3.7, 3.8, 3.9] + os: [ubuntu-latest, windows-latest] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Test with pytest + run: | + pytest --cov-report xml --cov=test aicsmlsegment/tests/ + - name: Upload codecov + uses: codecov/codecov-action@v1 + + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v1 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Lint with flake8 + run: | + flake8 test --count --verbose --show-source --statistics + - name: Check with black + run: | + black --check test + + publish: + if: "contains(github.event.head_commit.message, 'Bump version')" + needs: [test, lint] + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v1 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel + - name: Build Package + run: | + python setup.py sdist bdist_wheel + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@master + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml new file mode 100644 index 0000000..d7774e1 --- /dev/null +++ b/.github/workflows/test-and-lint.yml @@ -0,0 +1,47 @@ +name: Test and Lint + +on: pull_request + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: [3.7, 3.8, 3.9] + os: [ubuntu-latest, windows-latest] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Test with pytest + run: | + pytest aicsmlsegment/tests/ + - name: Upload codecov + uses: codecov/codecov-action@v1 + + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v1 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Lint with flake8 + run: | + flake8 aicsmlsegment --count --verbose --show-source --statistics + - name: Check with black + run: | + black --check aicsmlsegment diff --git a/.gitignore b/.gitignore index f61c7b9..099766b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,19 +1,121 @@ -.vscode/ -.gradle/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# OS generated files +.DS_Store + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/test.*rst + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# Dask +dask-worker-space + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# VSCode + .idea/ *.iml .*.swp .*.swo *~ -*.ipynb_checkpoints + # Generated by build -build/ -dist/ -venv/ -.eggs/ -*.egg-info -**/__pycache__/ -.pytest_cache/ activate -.coverage diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..c2b3dea --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,73 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting any of the maintainers of this project and +we will attempt to resolve the issues with respect and dignity. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/Jenkinsfile b/Jenkinsfile deleted file mode 100644 index 2ef60b4..0000000 --- a/Jenkinsfile +++ /dev/null @@ -1,139 +0,0 @@ -pipeline { - parameters { booleanParam(name: 'create_release', defaultValue: false, - description: 'If true, create a release artifact and publish to ' + - 'the artifactory release PyPi or public PyPi.') } - options { - timeout(time: 1, unit: 'HOURS') - } - agent { - node { - label "python-gradle" - } - } - environment { - PATH = "/home/jenkins/.local/bin:$PATH" - REQUESTS_CA_BUNDLE = "/etc/ssl/certs" - } - stages { - stage ("create virtualenv") { - steps { - this.notifyBB("INPROGRESS") - sh "./gradlew -i cleanAll installCIDependencies" - } - } - - stage ("bump version pre-build") { - when { - expression { return params.create_release } - } - steps { - // This will drop the dev suffix if we are releasing - // X.Y.Z.devN -> X.Y.Z - sh "./gradlew -i bumpVersionRelease" - } - } - - stage ("test/build distribution") { - steps { - sh "./gradlew -i build" - } - } - - stage ("report on tests") { - steps { - junit "build/test_report.xml" - - cobertura autoUpdateHealth: false, - autoUpdateStability: false, - coberturaReportFile: 'build/coverage.xml', - failUnhealthy: false, - failUnstable: false, - maxNumberOfBuilds: 0, - onlyStable: false, - sourceEncoding: 'ASCII', - zoomCoverageChart: false - - - } - } - - stage ("publish release") { - when { - branch 'master' - expression { return params.create_release } - } - steps { - sh "./gradlew -i publishRelease" - sh "./gradlew -i gitTagCommitPush" - sh "./gradlew -i bumpVersionPostRelease gitCommitPush" - } - } - - stage ("publish snapshot") { - when { - branch 'master' - not { expression { return params.create_release } } - } - steps { - sh "./gradlew -i publishSnapshot" - script { - def ignoreAuthors = ["jenkins", "Jenkins User", "Jenkins Builder"] - if (!ignoreAuthors.contains(gitAuthor())) { - sh "./gradlew -i bumpVersionDev gitCommitPush" - } - } - } - } - - } - post { - always { - notifyBuildOnSlack(currentBuild.result, currentBuild.previousBuild?.result) - this.notifyBB(currentBuild.result) - } - cleanup { - deleteDir() - } - } -} - -def notifyBB(String state) { - // on success, result is null - state = state ?: "SUCCESS" - - if (state == "SUCCESS" || state == "FAILURE") { - currentBuild.result = state - } - - notifyBitbucket commitSha1: "${GIT_COMMIT}", - credentialsId: 'aea50792-dda8-40e4-a683-79e8c83e72a6', - disableInprogressNotification: false, - considerUnstableAsSuccess: true, - ignoreUnverifiedSSLPeer: false, - includeBuildNumberInKey: false, - prependParentProjectKey: false, - projectKey: 'SW', - stashServerBaseUrl: 'https://aicsbitbucket.corp.alleninstitute.org' -} - -def notifyBuildOnSlack(String buildStatus = 'STARTED', String priorStatus) { - // build status of null means successful - buildStatus = buildStatus ?: 'SUCCESS' - - // Override default values based on build status - if (buildStatus != 'SUCCESS') { - slackSend ( - color: '#FF0000', - message: "${buildStatus}: '${env.JOB_NAME} [${env.BUILD_NUMBER}]' (${env.BUILD_URL})" - ) - } else if (priorStatus != 'SUCCESS') { - slackSend ( - color: '#00FF00', - message: "BACK_TO_NORMAL: '${env.JOB_NAME} [${env.BUILD_NUMBER}]' (${env.BUILD_URL})" - ) - } -} - -def gitAuthor() { - sh(returnStdout: true, script: 'git log -1 --format=%an').trim() -} diff --git a/LICENSE.txt b/LICENSE similarity index 100% rename from LICENSE.txt rename to LICENSE diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..c9c8459 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,10 @@ +include CONTRIBUTING.md +include LICENSE +include README.md + +recursive-include tests * +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] + +recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif +graft aicsmlsegment/data diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..996a769 --- /dev/null +++ b/Makefile @@ -0,0 +1,59 @@ +.PHONY: clean build docs help +.DEFAULT_GOAL := help + +define BROWSER_PYSCRIPT +import os, webbrowser, sys + +try: + from urllib import pathname2url +except: + from urllib.request import pathname2url + +webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) +endef +export BROWSER_PYSCRIPT + +define PRINT_HELP_PYSCRIPT +import re, sys + +for line in sys.stdin: + match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) + if match: + target, help = match.groups() + print("%-20s %s" % (target, help)) +endef +export PRINT_HELP_PYSCRIPT + +BROWSER := python -c "$$BROWSER_PYSCRIPT" + +help: + @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) + +clean: ## clean all build, python, and testing files + rm -fr build/ + rm -fr dist/ + rm -fr .eggs/ + find . -name '*.egg-info' -exec rm -fr {} + + find . -name '*.egg' -exec rm -f {} + + find . -name '*.pyc' -exec rm -f {} + + find . -name '*.pyo' -exec rm -f {} + + find . -name '*~' -exec rm -f {} + + find . -name '__pycache__' -exec rm -fr {} + + rm -fr .tox/ + rm -fr .coverage + rm -fr coverage.xml + rm -fr htmlcov/ + rm -fr .pytest_cache + +build: ## run tox / run tests and lint + tox + +gen-docs: ## generate Sphinx HTML documentation, including API docs + rm -f docs/aicsmlsegment*.rst + rm -f docs/modules.rst + sphinx-apidoc -o docs/ aicsmlsegment **/tests/ + $(MAKE) -C docs html + +docs: ## generate Sphinx HTML documentation, including API docs, and serve to browser + make gen-docs + $(BROWSER) docs/_build/html/index.html diff --git a/README.md b/README.md index 6b0498e..1790fd4 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,28 @@ The Allen Cell Structure Segmenter is a Python-based open source toolkit develop **Note: This repository has only the code for the "Iterative Deep Learning Workflow". The classic part can be found at [https://github.com/AllenCell/aics-segmentation](https://github.com/AllenCell/aics-segmentation)** +*********************************************************************** +**June 2021 Update**: We have refactored and modernizaed the deep learning code base +to be more powerful. Examples of new features are: +- utilize [pytorch-lightning](https://www.pytorchlightning.ai/) to perform more sophisticated training + (e.g., 16-bit training, stachastic weight averaging, various learning + rate schedulers, multi-GPU training) +- support more baseline models from [MONAI](https://monai.io/) +- utilize [TorchIO](https://github.com/fepegar/torchio) for more efficient 3D data augmentation +- new inference function utilizing weighted blending to avoid stiching + effect when applying the model on a large image +- support tensorboard for visualizing and tracking experiments + +More details on how the new code is organized can be found here: [code overview](./docs/code_overview.md) +*********************************************************************** + +# Link to [Documentations and Tutorials](./docs/overview.md) + ## Installation: 0. prerequisite: -To perform training/prediction of the deep learning models in this package, we assume an [NVIDIA GPU](https://www.nvidia.com/en-us/deep-learning-ai/developer/) has been set up properly on a Linux operating system, either on a local machine or on a remote computation cluster. Make sure to check if your GPU supports at least CUDA 8.0 (CUDA 9.0 and up is preferred): [NVIDIA Driver check](https://www.nvidia.com/Download/index.aspx?lang=en-us). +To perform training/prediction of the deep learning models in this package, we assume an [NVIDIA GPU](https://www.nvidia.com/en-us/deep-learning-ai/developer/) has been set up properly on a Linux operating system, either on a local machine or on a remote computation cluster. Make sure to check if your GPU supports at least CUDA 10.0: [NVIDIA Driver check](https://www.nvidia.com/Download/index.aspx?lang=en-us). The GPUs we used to develop and test our package are two types: (1) GeForce GTX 1080 Ti GPU (about 11GB GPU memory), (2) Titan Xp GPU (about 12GB GPU memory), (3) Tesla V100 for PCIe (with about 33GB memory). These cover common chips for personal workstations and data centers. @@ -20,7 +37,7 @@ The GPUs we used to develop and test our package are two types: (1) GeForce GTX 1. create a conda environment: ```bash -conda create --name mlsegmenter python=3.7 +conda create --name mlsegmenter python=3.8 ``` (For how to install conda, see [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html#installing-conda-on-a-system-that-has-other-python-installations-or-packages)) @@ -37,8 +54,8 @@ conda activate mlsegmenter Go to [PyTorch website](https://pytorch.org/get-started/locally/), and find the right installation command for you. -* we use version 1.0 (which is the stable version at the time of our development) -* we use Linux (OS), Conda (package), python 3.6 (Language), CUDA=10.0 (Question about CUDA? see [setup CUDA](./docs/check_cuda.md)). +* we use version 1.8.1 (which is the stable version at the time of our release) +* we use Linux (OS), Conda (package), python 3.8 (Language), CUDA=10.2 (Question about CUDA? see [setup CUDA](./docs/check_cuda.md)). ***Make sure you use either the automatically generated command on PyTorch website, or the command recommended on PyTorch website for installing [older version](https://pytorch.org/get-started/previous-versions/)*** @@ -56,6 +73,3 @@ The `-e` flag when doing `pip install` will allow users to modify any the source ## Level of Support We are offering it to the community AS IS; we have used the toolkit within our organization. We are not able to provide guarantees of support. However, we welcome feedback and submission of issues. Users are encouraged to sign up on our [Allen Cell Discussion Forum](https://forum.allencell.org/) for community quesitons and comments. - - -# Link to [Documentations and Tutorials](./docs/overview.md) \ No newline at end of file diff --git a/aicsmlsegment/DataLoader3D/Universal_Loader.py b/aicsmlsegment/DataLoader3D/Universal_Loader.py deleted file mode 100644 index 788086a..0000000 --- a/aicsmlsegment/DataLoader3D/Universal_Loader.py +++ /dev/null @@ -1,364 +0,0 @@ -import numpy as np -import os -from PIL import Image -import random - -from torch import from_numpy -from aicsimageio import imread -from random import shuffle -import time -from torchvision.transforms import ToTensor -from torch.utils.data import Dataset - - -# CODE for generic loader -# No augmentation = NOAUG,simply load data and convert to tensor -# Augmentation code: -# RR = Rotate by a random degree from 1 to 180 -# R4 = Rotate by 0, 90, 180, 270 -# FH = Flip Horizantally -# FV = Flip Vertically -# FD = Flip Depth (i.e., along z dim) -# SS = Size Scaling by a ratio between -0.1 to 0.1 (TODO) -# IJ = Intensity Jittering (TODO) -# DD = Dense Deformation (TODO) - - -class RR_FH_M0(Dataset): - - def __init__(self, filenames, num_patch, size_in, size_out): - - self.img = [] - self.gt = [] - self.cmap = [] - - padding = [(x-y)//2 for x,y in zip(size_in, size_out)] - total_in_count = size_in[0] * size_in[1] * size_in[2] - total_out_count = size_out[0] * size_out[1] * size_out[2] - - num_data = len(filenames) - shuffle(filenames) - num_patch_per_img = np.zeros((num_data,), dtype=int) - if num_data >= num_patch: - # all one - num_patch_per_img[:num_patch]=1 - else: - basic_num = num_patch // num_data - # assign each image the same number of patches to extract - num_patch_per_img[:] = basic_num - - # assign one more patch to the first few images to achieve the total patch number - num_patch_per_img[:(num_patch-basic_num*num_data)] = num_patch_per_img[:(num_patch-basic_num*num_data)] + 1 - - for img_idx, fn in enumerate(filenames): - - if len(self.img)==num_patch: - break - - label = np.squeeze(imread(fn+'_GT.ome.tif')) - label = np.expand_dims(label, axis=0) - - input_img = np.squeeze(imread(fn+'.ome.tif')) - if len(input_img.shape) == 3: - # add channel dimension - input_img = np.expand_dims(input_img, axis=0) - elif len(input_img.shape) == 4: - # assume number of channel < number of Z, make sure channel dim comes first - if input_img.shape[0] > input_img.shape[1]: - input_img = np.transpose(input_img, (1, 0, 2, 3)) - - costmap = np.squeeze(imread(fn+'_CM.ome.tif')) - - img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'constant') - raw = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') - - cost_scale = costmap.max() - if cost_scale<1: ## this should not happen, but just in case - cost_scale = 1 - - deg = random.randrange(1,180) - flip_flag = random.random() - - for zz in range(label.shape[1]): - - for ci in range(label.shape[0]): - labi = label[ci,zz,:,:] - labi_pil = Image.fromarray(np.uint8(labi)) - new_labi_pil = labi_pil.rotate(deg,resample=Image.NEAREST) - if flip_flag<0.5: - new_labi_pil = new_labi_pil.transpose(Image.FLIP_LEFT_RIGHT) - new_labi = np.array(new_labi_pil.convert('L')) - label[ci,zz,:,:] = new_labi.astype(int) - - cmap = costmap[zz,:,:] - cmap_pil = Image.fromarray(np.uint8(255*(cmap/cost_scale))) - new_cmap_pil = cmap_pil.rotate(deg,resample=Image.NEAREST) - if flip_flag<0.5: - new_cmap_pil = new_cmap_pil.transpose(Image.FLIP_LEFT_RIGHT) - new_cmap = np.array(new_cmap_pil.convert('L')) - costmap[zz,:,:] = cost_scale*(new_cmap/255.0) - - for zz in range(raw.shape[1]): - for ci in range(raw.shape[0]): - str_im = raw[ci,zz,:,:] - str_im_pil = Image.fromarray(np.uint8(str_im*255)) - new_str_im_pil = str_im_pil.rotate(deg,resample=Image.BICUBIC) - if flip_flag<0.5: - new_str_im_pil = new_str_im_pil.transpose(Image.FLIP_LEFT_RIGHT) - new_str_image = np.array(new_str_im_pil.convert('L')) - raw[ci,zz,:,:] = (new_str_image.astype(float))/255.0 - new_patch_num = 0 - - while new_patch_num < num_patch_per_img[img_idx]: - - pz = random.randint(0, label.shape[1] - size_out[0]) - py = random.randint(0, label.shape[2] - size_out[1]) - px = random.randint(0, label.shape[3] - size_out[2]) - - - # check if this is a good crop - ref_patch_cmap = costmap[pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]] - - # confirmed good crop - (self.img).append(raw[:,pz:pz+size_in[0],py:py+size_in[1],px:px+size_in[2]] ) - (self.gt).append(label[:,pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]]) - (self.cmap).append(ref_patch_cmap) - - new_patch_num += 1 - - def __getitem__(self, index): - - image_tensor = from_numpy(self.img[index].astype(float)) - cmap_tensor = from_numpy(self.cmap[index].astype(float)) - - label_tensor = [] - if self.gt[index].shape[0]>0: - for zz in range(self.gt[index].shape[0]): - label_tensor.append(from_numpy(self.gt[index][zz,:,:,:].astype(float)).float()) - else: - label_tensor.append(from_numpy(self.gt[index].astype(float)).float()) - - return image_tensor.float(), label_tensor, cmap_tensor.float() - - def __len__(self): - return len(self.img) - -class RR_FH_M0C(Dataset): - - def __init__(self, filenames, num_patch, size_in, size_out): - - self.img = [] - self.gt = [] - self.cmap = [] - - padding = [(x-y)//2 for x,y in zip(size_in, size_out)] - - num_data = len(filenames) - shuffle(filenames) - - num_trial_round = 0 - while len(self.img) < num_patch: - - # to avoid dead loop - num_trial_round = num_trial_round + 1 - if num_trial_round > 2: - break - - num_patch_to_obtain = num_patch - len(self.img) - num_patch_per_img = np.zeros((num_data,), dtype=int) - if num_data >= num_patch_to_obtain: - # all one - num_patch_per_img[:num_patch_to_obtain]=1 - else: - basic_num = num_patch_to_obtain // num_data - # assign each image the same number of patches to extract - num_patch_per_img[:] = basic_num - - # assign one more patch to the first few images to achieve the total patch number - num_patch_per_img[:(num_patch_to_obtain-basic_num*num_data)] = num_patch_per_img[:(num_patch_to_obtain-basic_num*num_data)] + 1 - - - for img_idx, fn in enumerate(filenames): - - if len(self.img)==num_patch: - break - - label = np.squeeze(imread(fn+'_GT.ome.tif')) - label = np.expand_dims(label, axis=0) - - input_img = np.squeeze(imread(fn+'.ome.tif')) - if len(input_img.shape) == 3: - # add channel dimension - input_img = np.expand_dims(input_img, axis=0) - elif len(input_img.shape) == 4: - # assume number of channel < number of Z, make sure channel dim comes first - if input_img.shape[0] > input_img.shape[1]: - input_img = np.transpose(input_img, (1, 0, 2, 3)) - - costmap = np.squeeze(imread(fn+'_CM.ome.tif')) - - img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'constant') - raw = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') - - cost_scale = costmap.max() - if cost_scale<1: ## this should not happen, but just in case - cost_scale = 1 - - deg = random.randrange(1,180) - flip_flag = random.random() - - for zz in range(label.shape[1]): - - for ci in range(label.shape[0]): - labi = label[ci,zz,:,:] - labi_pil = Image.fromarray(np.uint8(labi)) - new_labi_pil = labi_pil.rotate(deg,resample=Image.NEAREST) - if flip_flag<0.5: - new_labi_pil = new_labi_pil.transpose(Image.FLIP_LEFT_RIGHT) - new_labi = np.array(new_labi_pil.convert('L')) - label[ci,zz,:,:] = new_labi.astype(int) - - cmap = costmap[zz,:,:] - cmap_pil = Image.fromarray(np.uint8(255*(cmap/cost_scale))) - new_cmap_pil = cmap_pil.rotate(deg,resample=Image.NEAREST) - if flip_flag<0.5: - new_cmap_pil = new_cmap_pil.transpose(Image.FLIP_LEFT_RIGHT) - new_cmap = np.array(new_cmap_pil.convert('L')) - costmap[zz,:,:] = cost_scale*(new_cmap/255.0) - - for zz in range(raw.shape[1]): - for ci in range(raw.shape[0]): - str_im = raw[ci,zz,:,:] - str_im_pil = Image.fromarray(np.uint8(str_im*255)) - new_str_im_pil = str_im_pil.rotate(deg,resample=Image.BICUBIC) - if flip_flag<0.5: - new_str_im_pil = new_str_im_pil.transpose(Image.FLIP_LEFT_RIGHT) - new_str_image = np.array(new_str_im_pil.convert('L')) - raw[ci,zz,:,:] = (new_str_image.astype(float))/255.0 - - new_patch_num = 0 - num_fail = 0 - while new_patch_num < num_patch_per_img[img_idx]: - - pz = random.randint(0, label.shape[1] - size_out[0]) - py = random.randint(0, label.shape[2] - size_out[1]) - px = random.randint(0, label.shape[3] - size_out[2]) - - - # check if this is a good crop - ref_patch_cmap = costmap[pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]] - if np.count_nonzero(ref_patch_cmap>1e-5) < 1000: #enough valida samples - num_fail = num_fail + 1 - if num_fail > 50: - break - continue - - - # confirmed good crop - (self.img).append(raw[:,pz:pz+size_in[0],py:py+size_in[1],px:px+size_in[2]] ) - (self.gt).append(label[:,pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]]) - (self.cmap).append(ref_patch_cmap) - - new_patch_num += 1 - - def __getitem__(self, index): - - image_tensor = from_numpy(self.img[index].astype(float)) - cmap_tensor = from_numpy(self.cmap[index].astype(float)) - - label_tensor = [] - if self.gt[index].shape[0]>0: - for zz in range(self.gt[index].shape[0]): - label_tensor.append(from_numpy(self.gt[index][zz,:,:,:].astype(float)).float()) - else: - label_tensor.append(from_numpy(self.gt[index].astype(float)).float()) - - return image_tensor.float(), label_tensor, cmap_tensor.float() - - def __len__(self): - return len(self.img) - -class NOAUG_M(Dataset): - - def __init__(self, filenames, num_patch, size_in, size_out): - - self.img = [] - self.gt = [] - self.cmap = [] - - padding = [(x-y)//2 for x,y in zip(size_in, size_out)] - total_in_count = size_in[0] * size_in[1] * size_in[2] - total_out_count = size_out[0] * size_out[1] * size_out[2] - - num_data = len(filenames) - shuffle(filenames) - num_patch_per_img = np.zeros((num_data,), dtype=int) - if num_data >= num_patch: - # all one - num_patch_per_img[:num_patch]=1 - else: - basic_num = num_patch // num_data - # assign each image the same number of patches to extract - num_patch_per_img[:] = basic_num - - # assign one more patch to the first few images to achieve the total patch number - num_patch_per_img[:(num_patch-basic_num*num_data)] = num_patch_per_img[:(num_patch-basic_num*num_data)] + 1 - - - for img_idx, fn in enumerate(filenames): - - label = np.squeeze(imread(fn+'_GT.ome.tif')) - label = np.expand_dims(label, axis=0) - - input_img = np.squeeze(imread(fn+'.ome.tif')) - if len(input_img.shape) == 3: - # add channel dimension - input_img = np.expand_dims(input_img, axis=0) - elif len(input_img.shape) == 4: - # assume number of channel < number of Z, make sure channel dim comes first - if input_img.shape[0] > input_img.shape[1]: - input_img = np.transpose(input_img, (1, 0, 2, 3)) - - costmap = np.squeeze(imread(fn+'_CM.ome.tif')) - - img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'symmetric') - raw = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') - - new_patch_num = 0 - - while new_patch_num < num_patch_per_img[img_idx]: - - pz = random.randint(0, label.shape[1] - size_out[0]) - py = random.randint(0, label.shape[2] - size_out[1]) - px = random.randint(0, label.shape[3] - size_out[2]) - - - ## check if this is a good crop - ref_patch_cmap = costmap[pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]] - - - # confirmed good crop - (self.img).append(raw[:,pz:pz+size_in[0],py:py+size_in[1],px:px+size_in[2]] ) - (self.gt).append(label[:,pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]]) - (self.cmap).append(ref_patch_cmap) - - new_patch_num += 1 - - def __getitem__(self, index): - - image_tensor = from_numpy(self.img[index].astype(float)) - cmap_tensor = from_numpy(self.cmap[index].astype(float)) - - #if self.gt[index].shape[0]>1: - label_tensor = [] - for zz in range(self.gt[index].shape[0]): - tmp_tensor = from_numpy(self.gt[index][zz,:,:,:].astype(float)) - label_tensor.append(tmp_tensor.float()) - #else: - # label_tensor = from_numpy(self.gt[index].astype(float)) - # label_tensor = label_tensor.float() - - return image_tensor.float(), label_tensor, cmap_tensor.float() - - def __len__(self): - return len(self.img) \ No newline at end of file diff --git a/aicsmlsegment/DataUtils/DataMod.py b/aicsmlsegment/DataUtils/DataMod.py new file mode 100644 index 0000000..fbc7973 --- /dev/null +++ b/aicsmlsegment/DataUtils/DataMod.py @@ -0,0 +1,261 @@ +from aicsmlsegment.DataUtils.Universal_Loader import ( + UniversalDataset, + TestDataset, + load_img, +) +import random +from glob import glob +from torch.utils.data import DataLoader +import pytorch_lightning +from aicsmlsegment.Model import get_loss_criterion +import numpy as np +import torch +from math import ceil +from typing import Dict + + +def init_worker(worker_id: int): + """ + Divides the testing images equally among workers + + Parameters + ---------- + worker_id: int + id of worker, used to assign start and end images + + Return: None + """ + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + # divide images among all workers + per_worker = int(ceil(len(dataset.filenames) / float(worker_info.num_workers))) + dataset.start = worker_info.id * per_worker + dataset.end = min(len(dataset.filenames), (worker_info.id + 1) * per_worker) - 1 + + +class DataModule(pytorch_lightning.LightningDataModule): + def __init__(self, config: Dict, train: bool = True): + """ + Initialize Datamodule variable based on config + + Parameters + ---------- + config: Dict + a top level configuration object describing which images to load, + how to load them, and what transforms to apply + + Return: None + """ + super().__init__() + self.config = config + + try: # monai + self.nchannel = self.config["model"]["nchannel"] + except KeyError: # custom model + self.nchannel = self.config["model"]["in_channels"] + + if train: + self.loader_config = config["loader"] + + name = self.loader_config["name"] + if name not in ["default", "focus"]: + print("other loaders are under construction") + quit() + if name == "focus": + self.check_crop = True + else: + self.check_crop = False + self.transforms = [] + if "Transforms" in self.loader_config: + self.transforms = self.loader_config["Transforms"] + + _, self.accepts_costmap = get_loss_criterion(config) + + self.init_only = False + if self.loader_config["epoch_shuffle"] is not None: + self.init_only = True + + model_config = config["model"] + if "unet_xy" in config["model"]["name"]: + self.size_in = model_config["size_in"] + self.size_out = model_config["size_out"] + self.nchannel = model_config["nchannel"] + + else: + self.size_in = model_config["patch_size"] + self.size_out = self.size_in + self.nchannel = model_config["in_channels"] + + def prepare_data(self): + pass + + def setup(self, stage: str): + """ + Set up identical train/val splits across gpus. Since all image in batches must + be the same size, if random splits are selected in the config, the loader will + try 10 random splits until all of the validation images are the same size. + + Parameters + ---------- + stage: str + either "fit" or not + + Return: None + """ + if stage == "fit": # no setup is required for testing + # load settings # + config = self.config + + # get validation and training filenames from input dir from config + validation_config = config["validation"] + loader_config = config["loader"] + if validation_config["metric"] is not None: + + if type(loader_config["datafolder"]) == str: + loader_config["datafolder"] = [loader_config["datafolder"]] + + filenames = [] + for folder in loader_config["datafolder"]: + fns = glob(folder + "/*_GT.ome.tif") + fns.sort() + filenames += fns + + total_num = len(filenames) + LeaveOut = validation_config["leaveout"] + + all_same_size = False + rand = False + it = 0 + max_it = 10 + while not all_same_size and it < max_it: + if rand and it > 0: + print("Validation images not all same size. Reshuffling...") + elif not rand and it > 0: + print( + "Validation images must be the same size. Please choose" + " different validation img indices" + ) + quit() + + if len(LeaveOut) == 1: + if LeaveOut[0] > 0 and LeaveOut[0] < 1: + num_train = int(np.floor((1 - LeaveOut[0]) * total_num)) + shuffled_idx = np.arange(total_num) + # make sure validation sets are same across gpus + random.seed(0) + random.shuffle(shuffled_idx) + train_idx = shuffled_idx[:num_train] + valid_idx = shuffled_idx[num_train:] + rand = True + else: + valid_idx = [int(LeaveOut[0])] + train_idx = list( + set(range(total_num)) - set(map(int, LeaveOut)) + ) + elif LeaveOut: + valid_idx = list(map(int, LeaveOut)) + train_idx = list(set(range(total_num)) - set(valid_idx)) + + img_shapes = [ + load_img( + filenames[fn][:-11], + img_type="label", + n_channel=self.nchannel, + shape_only=True, + ) + for fn in valid_idx + ] + all_same_size = img_shapes.count(img_shapes[0]) == len(img_shapes) + it += 1 + if loader_config["batch_size"] == 1: + all_same_size = True + if it == max_it: + assert ( + all_same_size + ), "Could not find val images with all same size, please try again." + valid_filenames = [] + train_filenames = [] + # remove file extensions from filenames + for fi, fn in enumerate(valid_idx): + valid_filenames.append(filenames[fn][:-11]) + for fi, fn in enumerate(train_idx): + train_filenames.append(filenames[fn][:-11]) + + self.valid_filenames = valid_filenames + self.train_filenames = train_filenames + + else: + print("need validation in config file") + quit() + + def train_dataloader(self): + """ + Returns the train dataloader from the train filenames with the specified + transforms. + + Parameters:None + Return: DataLoader train_set_loader + """ + loader_config = self.loader_config + train_set_loader = DataLoader( + UniversalDataset( + self.train_filenames, + loader_config["PatchPerBuffer"], + self.size_in, + self.size_out, + self.nchannel, + use_costmap=self.accepts_costmap, + transforms=self.transforms, + patchize=True, + check_crop=self.check_crop, + init_only=self.init_only, # first call of train_dataloader is just to get dataset params if init_only is true # noqa E501 + ), + batch_size=loader_config["batch_size"], + shuffle=True, + num_workers=loader_config["NumWorkers"], + pin_memory=True, + ) + return train_set_loader + + def val_dataloader(self): + """ + Returns the validation dataloader from the validation filenames with + no transforms + + Parameters: None + Return: DataLoader val_set_loader + """ + loader_config = self.loader_config + val_set_loader = DataLoader( + UniversalDataset( + self.valid_filenames, + loader_config["PatchPerBuffer"], + self.size_in, + self.size_out, + self.nchannel, + transforms=[], # no transforms for validation data + use_costmap=self.accepts_costmap, + patchize=False, # validate on entire image + ), + batch_size=loader_config["batch_size"], + shuffle=False, + num_workers=loader_config["NumWorkers"], + pin_memory=True, + ) + return val_set_loader + + def test_dataloader(self): + """ + Returns the test dataloader + Parameters: None + Return: DataLoader test_set_loader + """ + test_set_loader = DataLoader( + TestDataset(self.config), + batch_size=1, + shuffle=False, + num_workers=self.config["NumWorkers"], + pin_memory=True, + worker_init_fn=init_worker, + ) + return test_set_loader diff --git a/aicsmlsegment/DataUtils/Universal_Loader.py b/aicsmlsegment/DataUtils/Universal_Loader.py new file mode 100644 index 0000000..e7b7381 --- /dev/null +++ b/aicsmlsegment/DataUtils/Universal_Loader.py @@ -0,0 +1,966 @@ +import numpy as np +import random +from torch import from_numpy +import torch +from aicsimageio import AICSImage +from aicsmlsegment.utils import ( + image_normalization, +) +from random import shuffle +from torch.utils.data import Dataset, IterableDataset +from scipy.ndimage import zoom +from torchio.transforms import RandomAffine, RandomBiasField, RandomNoise +from monai.transforms import RandShiftIntensity +from typing import Dict, List, Sequence, Tuple + + +# CODE for generic loader +# No augmentation = NOAUG,simply load data and convert to tensor +# Augmentation code: +# RR = Rotate by a random degree from 1 to 180 +# R4 = Rotate by 0, 90, 180, 270 +# RF = Random flip +# FH = Flip Horizontally +# FV = Flip Vertically +# FD = Flip Depth (i.e., along z dim) +# SS = Size Scaling by a ratio between -0.1 to 0.1 (TODO) +# IJ = Intensity Jittering (TODO) +# DD = Dense Deformation (TODO) + + +def minmax(img: np.ndarray) -> np.ndarray: + """ + Performs minmax normalization on an image + + Parameters + ---------- + img: numpy array + + Return: minmaxed numpy array + """ + return (img - img.min()) / (img.max() - img.min()) + + +def resize(img: np.ndarray, config: Dict, min_max: bool = False) -> np.ndarray: + """ + Resize an image based on the provided config. + + Parameters + ---------- + img: 4d CZYX order numpy array + config: user-provided configuration file with "ResizeRatio" provided + min_max: whether to conduct minmax normalization on each channel independently + + Return: resized + minmaxed img if specified + """ + if len(config["ResizeRatio"]) > 0 and config["ResizeRatio"] != [ + 1.0, + 1.0, + 1.0, + ]: + # don't resize if resize ratio is all 1s + # note that struct_img is only a view of img, so changes made on + # struct_img also affects img + assert len(img.shape) == 4, f"Expected 4D image, got {len(img.shape)}-D array" + img = zoom( + img, + ( + 1, + config["ResizeRatio"][0], + config["ResizeRatio"][1], + config["ResizeRatio"][2], + ), + order=2, + mode="reflect", + ) + if min_max: + for ch_idx in range(img.shape[0]): + struct_img = img[ch_idx, :, :, :] + img[ch_idx, :, :, :] = minmax(struct_img) + return img + + +def undo_resize(img: np.ndarray, config: Dict): + """ + Undo Resizing an image based on the provided config. + + Parameters + ---------- + img: 5d NCZYX order numpy array + config: user-provided configuration file with "ResizeRatio" provided + + Return: float 32 numpy array img resized to its original dimensions + """ + if len(config["ResizeRatio"]) > 0 and config["ResizeRatio"] != [1.0, 1.0, 1.0]: + img = zoom( + img, + ( + 1.0, + 1.0, + 1 / config["ResizeRatio"][0], + 1 / config["ResizeRatio"][1], + 1 / config["ResizeRatio"][2], + ), + order=2, + mode="reflect", + ) + return img.astype(np.float32) + + +def swap(ll: List, index1: int, index2: int) -> List: + """ + Swap index1 and index2 of list L + + Parameters + ---------- + l: list + index1, index2: integer indices to swap + + Return: List l with index1 and index2 swapped + """ + temp = ll[index1] + ll[index1] = ll[index2] + ll[index2] = temp + return ll + + +def validate_shape( + img_shape: Tuple[int], + n_channel: int = 1, + timelapse: bool = False, +) -> Tuple[Dict, Tuple[Sequence[int]]]: + """ + General function to load and rearrange the dimensions of 3D images + input: + img_shape: STCZYX shape of image + n_channel: number of channels expted in image + timelapse: whether image is a timelapse + + output: + load_dict: dictionary to be passed to AICSImage.get_image_data + containing out_orientation and specific channel indices + correct_shape: tuple rearranged img_shape + """ + img_shape = list(img_shape) + load_order = ["S", "T", "C", "Z", "Y", "X"] + expected_channel_idx = 2 + # all dimensions that could be channel dimension + real_channel_idx = [i for i, elem in enumerate(img_shape) if elem == n_channel] + keep_channels = ["C"] + if expected_channel_idx not in real_channel_idx: + assert ( + len(real_channel_idx) > 0 + ), f"The specified channel dim is wrong, no other dims have size {n_channel}" + # if nchannels is 1, doesn't matter which other size-1 dim we swap it with + assert n_channel == 1 or len(real_channel_idx) == 1, ( + "Index of channel dimension is incorrect and there are multiple candidate " + f"channel dimensions. Please check your image metadata. {img_shape}" + ) + + # change load order and image shape to reflect new index of channel dimension + real_channel_idx = real_channel_idx[-1] + keep_channels.append(load_order[real_channel_idx]) + swap(load_order, real_channel_idx, expected_channel_idx) + swap(img_shape, real_channel_idx, expected_channel_idx) + + load_dict = {"out_orientation": ""} + correct_shape = [] + for s, load in zip(img_shape, load_order): + if s == 1 and load not in keep_channels: + # specify e.g. S=0 for aicsimagio + load_dict[load] = 0 + else: + load_dict["out_orientation"] += load + correct_shape.append(s) + if timelapse: + assert ( + correct_shape[1] > 1 + ), "Image is not a timelapse, please check your image metadata" + + return load_dict, tuple(correct_shape) + + +def load_img( + filename: str, + img_type: str, + n_channel: int = 1, + input_ch: int = None, + shape_only: bool = False, +) -> List[np.ndarray]: + """ + General function to load and rearrange the dimensions of 3D images + input: + filename: name of image to be loaded + img_type: one of "label", "input", or "costmap" determining the file extension + n_channel: number of channels expected by model + input_ch: channel to extract from image during loading for testing + shape_only: whether to only return validated shape of an image + output: + img: list of np.ndarray(s) containing image data. + """ + extension_dict = { + "label": "_GT.ome.tif", + "input": ".ome.tif", + "costmap": "_CM.ome.tif", + "test": "", + "timelapse": "", + } + reader = AICSImage(filename + extension_dict[img_type]) + args_dict, correct_shape = validate_shape( + reader.shape, n_channel, img_type == "timelapse" + ) + if shape_only: + return correct_shape + if img_type != "timelapse": + img = reader.get_image_data(**args_dict) + if img_type == "costmap": + img = np.squeeze(img, 0) # remove channel dimension + if img_type == "test": + # return as list so we can iterate through it in test dataloader + img = img.astype(float) + img = [img[input_ch, :, :, :]] + else: + img = [ + reader.get_image_data(**args_dict, T=tt) for tt in range(correct_shape[1]) + ] + return img + + +class UniversalDataset(Dataset): + """ + Multipurpose dataset for training and validation. Randomly crops images, labels, + and costmaps into user-specified number of patches. Users can specify which + augmentations to apply. + """ + + def __init__( + self, + filenames: Sequence[str], + num_patch: int, + size_in: Sequence[int], + size_out: Sequence[int], + n_channel: int, + use_costmap: bool = True, + transforms: Sequence[str] = [], + patchize: bool = True, + check_crop: bool = False, + init_only: bool = False, + ): + """ + input: + filenames: path to images + num_patch: number of random crops to be produced + size_in: size of input to model + size_out: size of output from model + n_channel: number of iput channels expected by model + transforms: list of strings specifying transforms + patchize: whether to divide image into patches + check_crop: whether to check + """ + self.patchize = patchize + self.img = [] + self.gt = [] + self.cmap = [] + self.transforms = transforms + self.parameters = { + "filenames": filenames, + "num_patch": num_patch, + "size_in": size_in, + "size_out": size_out, + "n_channel": n_channel, + "use_costmap": use_costmap, + "transforms": transforms, + "patchize": patchize, + "check_crop": check_crop, + } + self.num_patch = num_patch + self.init_only = init_only + if init_only: + num_patch = 1 + num_data = len(filenames) + shuffle(filenames) + self.filenames = None + if not patchize: + print("Validating on", filenames) + self.filenames = filenames + num_patch_per_img = np.zeros((num_data,), dtype=int) + if num_data >= num_patch: + # take one patch from each image + num_patch_per_img[:num_patch] = 1 + else: # assign how many patches to take form each img + basic_num = num_patch // num_data + # assign each image the same number of patches to extract + num_patch_per_img[:] = basic_num + + # assign 1 more patch to the first few images to get the total patch number + num_patch_per_img[: (num_patch - basic_num * num_data)] = ( + num_patch_per_img[: (num_patch - basic_num * num_data)] + 1 + ) + + padding = [(x - y) // 2 for x, y in zip(size_in, size_out)] + + # extract patches from images until num_patch reached + for img_idx, fn in enumerate(filenames): + # if we're not dividing into patches, don't break before transforming imgs + if patchize and len(self.img) == num_patch: + break + + label = load_img(fn, img_type="label", n_channel=n_channel) + input_img = load_img(fn, img_type="input", n_channel=n_channel) + if use_costmap: + costmap = load_img(fn, img_type="costmap", n_channel=n_channel) + else: + costmap = np.zeros((1)) + + img_pad0 = np.pad( + input_img, + ((0, 0), (0, 0), (padding[1], padding[1]), (padding[2], padding[2])), + "symmetric", + ) + raw = np.pad( + img_pad0, ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)), "constant" + ) + + if "RF" in transforms: + # random flip + flip_flag = random.random() + if flip_flag < 0.5: + raw = np.flip( + raw, axis=-1 + ).copy() # avoid negative stride error when converting to tensor + if use_costmap: + costmap = np.flip(costmap, axis=-1).copy() + label = np.flip(label, axis=-1).copy() + + if "RR" in transforms: + # random rotation + deg = random.randrange(1, 180) + trans = RandomAffine( + scales=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0), + degrees=(0, 0, 0, 0, deg, deg), + default_pad_value=0, + image_interpolation="bspline", + center="image", + ) + + # rotate the raw image + out_img = trans(np.transpose(raw, (0, 3, 2, 1))) + raw = np.transpose(out_img, (0, 3, 2, 1)) + + trans_label = RandomAffine( + scales=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0), + degrees=(0, 0, 0, 0, deg, deg), + default_pad_value=0, + image_interpolation="nearest", + center="image", + ) + # rotate label and costmap + out_label = trans_label(np.transpose(label, (0, 3, 2, 1))) + label = np.transpose(out_label, (0, 3, 2, 1)) + if use_costmap: + out_map = trans_label( + np.transpose(np.expand_dims(costmap, axis=0), (0, 3, 2, 1)) + ) + costmap = np.transpose(out_map[0, :, :, :], (2, 1, 0)) + if "RBF" in transforms: + random_bias_field = RandomBiasField() + raw = random_bias_field(raw) + if "RN" in transforms: + random_noise = RandomNoise() + raw = random_noise(raw) + + if "RI" in transforms: + random_intensity = RandShiftIntensity(offsets=0.15, prob=0.2) + raw = random_intensity(raw) + + if patchize: + # take specified number of patches from current image + new_patch_num = 0 + num_fail = 0 + while new_patch_num < num_patch_per_img[img_idx]: + pz = random.randint(0, label.shape[1] - size_out[0]) + py = random.randint(0, label.shape[2] - size_out[1]) + px = random.randint(0, label.shape[3] - size_out[2]) + + if use_costmap: + # check if this is a good crop + ref_patch_cmap = costmap[ + pz : pz + size_out[0], + py : py + size_out[1], + px : px + size_out[2], + ] + if check_crop: + if np.count_nonzero(ref_patch_cmap > 1e-5) < 1000: + num_fail += 1 + if num_fail > 50: + print("Failed to generate valid crops") + break + continue + + # confirmed good crop + (self.img).append( + raw[ + :, + pz : pz + size_in[0], + py : py + size_in[1], + px : px + size_in[2], + ] + ) + (self.gt).append( + label[ + :, + pz : pz + size_out[0], + py : py + size_out[1], + px : px + size_out[2], + ] + ) + if use_costmap: + (self.cmap).append(ref_patch_cmap) + else: + self.cmap.append(costmap) + + new_patch_num += 1 + else: + (self.img).append(raw) + (self.gt).append(label) + (self.cmap).append(costmap) + + def __getitem__(self, index): + if self.init_only: + return torch.zeros(0) + if self.filenames is not None: + fn = self.filenames[index] + else: + fn = "" + img_tensor = from_numpy(self.img[index].astype(float)).float() + gt_tensor = from_numpy(self.gt[index].astype(float)).float() + cmap_tensor = from_numpy(self.cmap[index].astype(float)).float() + + return (img_tensor, gt_tensor, cmap_tensor, fn) + + def __len__(self): + if self.init_only: + return self.num_patch + return len(self.img) + + def get_params(self): + return self.parameters + + +from monai.transforms import ( + MapTransform, + RandFlipd, + RandBiasFieldd, + RandGaussianNoised, + RandShiftIntensityd, + RandSpatialCropSamplesd, + RandSpatialCropd, + ToTensord, + Compose, +) + + +class LoadImageD(MapTransform): + def __init__(self, use_costmap, n_channel): + super(LoadImageD).__init__() + self.use_costmap = use_costmap + self.n_channel = n_channel + + def __call__(self, data): + img_data = {} + img_data["label"] = load_img(data, img_type="label", n_channel=self.n_channel) + img_data["img"] = load_img(data, img_type="input", n_channel=self.n_channel) + if self.use_costmap: + img_data["costmap"] = load_img( + data, img_type="costmap", n_channel=self.n_channel + ) + return img_data + + +class PadImageD(MapTransform): + def __init__(self, padding, keys): + super(PadImageD).__init__() + self.padding = padding + self.keys = keys + + def __call__(self, data): + for key in self.keys: + data[key] = np.pad( + data[key], + ( + (0, 0), + (0, 0), + (self.padding[1], self.padding[1]), + (self.padding[2], self.padding[2]), + ), + "symmetric", + ) + data[key] = np.pad( + data[key], + ((0, 0), (self.padding[0], self.padding[0]), (0, 0), (0, 0)), + "constant", + ) + return data + + +class RandomPatchesD(MapTransform): + def __init__(self, check_crop, size_in, size_out, keys, num_patch): + super(RandomPatchesD).__init__() + self.check_crop = check_crop + self.size_in = size_in + self.size_out = size_out + self.keys = keys + self.num_patch = num_patch + + def __call__(self, data): + if self.check_crop: + cropper = RandSpatialCropd(self.keys, self.size_in, random_size=False) + num_fail = 0 + n_patches = 0 + while n_patches < self.num_patch: + additional_data = cropper(data) + if np.count_nonzero(additional_data["costmap"] > 1e-5) < 1000: + num_fail += 1 + assert num_fail < 50, "Failed to generate valid crops." + else: + for key in additional_data: + data[key] += additional_data[key] + n_patches += 1 + else: + cropper = RandSpatialCropSamplesd( + self.keys, self.size_in, self.num_patch, random_size=False + ) + data = cropper(data) + # crop costmap (if available) and label to match model output size + if self.size_in != self.size_out: + for key in self.keys: + if key == "img": + continue + for i in range(len(data[key])): + data[key][i] = data[key][i][ + 0 : self.size_in[0], 0 : self.size_in[1], 0 : self.size_in[2] + ] + + return data + + +class RandomRotationD(MapTransform): + def __init__(self, use_costmap): + super(RandomRotationD).__init__() + self.use_costmap = use_costmap + + def __call__(self, data): + deg = random.randrange(1, 180) + trans = RandomAffine( + scales=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0), + degrees=(0, 0, 0, 0, deg, deg), + default_pad_value=0, + image_interpolation="bspline", + center="image", + ) + out_img = trans(np.transpose(data["img"], (0, 3, 2, 1))) + data["img"] = np.transpose(out_img, (0, 3, 2, 1)) + + trans_label = RandomAffine( + scales=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0), + degrees=(0, 0, 0, 0, deg, deg), + default_pad_value=0, + image_interpolation="nearest", + center="image", + ) + # rotate label and costmap + out_label = trans_label(np.transpose(data["label"], (0, 3, 2, 1))) + data["label"] = np.transpose(out_label, (0, 3, 2, 1)) + if self.use_costmap: + out_map = trans_label(np.transpose(data["costmap"], (0, 3, 2, 1))) + data["costmap"] = np.transpose(out_map, (0, 3, 2, 1)) + + return data + + +class UniversalDataset_redo_transforms(Dataset): + """ + Multipurpose dataset for training and validation. Randomly crops images, labels, + and costmaps into user-specified number of patches. Users can specify which + augmentations to apply. + """ + + def __init__( + self, + filenames: Sequence[str], + num_patch: int, + size_in: Sequence[int], + size_out: Sequence[int], + n_channel: int, + use_costmap: bool = True, + transforms: Sequence[str] = [], + patchize: bool = True, + check_crop: bool = False, + init_only: bool = False, + ): + """ + input: + filenames: path to images + num_patch: number of random crops to be produced + size_in: size of input to model + size_out: size of output from model + n_channel: number of iput channels expected by model + transforms: list of strings specifying transforms + patchize: whether to divide image into patches + check_crop: whether to check + """ + self.data = {"img": [], "label": []} + if use_costmap: + self.data["costmap"] = [] + self.parameters = { + "filenames": filenames, + "num_patch": num_patch, + "size_in": size_in, + "size_out": size_out, + "n_channel": n_channel, + "use_costmap": use_costmap, + "transforms": transforms, + "patchize": patchize, + "check_crop": check_crop, + } + self.init_only = init_only + if init_only: + num_patch = 1 + num_data = len(filenames) + shuffle(filenames) + self.filenames = None + if not patchize: + print("Validating on", filenames) + self.filenames = filenames + num_patch_per_img = np.zeros((num_data,), dtype=int) + print(filenames) + print(num_data, num_patch) + if num_data >= num_patch: + # take one patch from each image + num_patch_per_img[:num_patch] = 1 + else: # assign how many patches to take form each img + basic_num = num_patch // num_data + # assign each image the same number of patches to extract + num_patch_per_img[:] = basic_num + + # assign 1 more patch to the first few images to get the total patch number + num_patch_per_img[: (num_patch - basic_num * num_data)] = ( + num_patch_per_img[: (num_patch - basic_num * num_data)] + 1 + ) + + self.padding = [(x - y) // 2 for x, y in zip(size_in, size_out)] + + # basepath = "//allen/aics/assay-dev/users/Benji/transform_test/" + # extract patches from images until num_patch reached + import time + + t1 = time.time() + tsfrm = self.select_transforms(160) + for count, (fn, n_patch) in enumerate(zip(filenames, num_patch_per_img)): + if n_patch == 0 or count > num_patch: + break + aug_data = tsfrm(fn) + for key in self.data: # costmap, img, label + for i in aug_data: # each patch + self.data[key] += i[key] + print(time.time() - t1) + + def select_transforms(self, num_patch): + all_keys = ["img", "label"] + params = self.parameters + if params["use_costmap"]: + all_keys.append("costmap") + transform_fns = [ + LoadImageD( + use_costmap=params["use_costmap"], n_channel=params["n_channel"] + ), + PadImageD(self.padding, keys=["img"]), + ] + if "RF" in params["transforms"]: + flipper = RandFlipd(keys=all_keys, prob=1, spatial_axis=-1) + transform_fns.append(flipper) + if "RR" in params["transforms"]: + transform_fns.append(RandomRotationD(use_costmap=params["use_costmap"])) + if "RBF" in params["transforms"]: + transform_fns.append(RandBiasFieldd(keys=["img"], coeff_range=(0.0, 0.01))) + if "RN" in params["transforms"]: + transform_fns.append(RandGaussianNoised(keys=["img"], std=0.001)) + if "RI" in params["transforms"]: + transform_fns.append(RandShiftIntensityd(keys=["img"], offsets=0.08)) + if params["patchize"]: + transform_fns.append( + RandomPatchesD( + check_crop=params["check_crop"], + size_in=params["size_in"], + size_out=params["size_out"], + keys=all_keys, + num_patch=num_patch, + ) + ) + transform_fns.append(ToTensord(keys=all_keys)) + return Compose(transform_fns) + + def __getitem__(self, index): + if self.init_only: + return torch.zeros(0) + if self.filenames is not None: + fn = self.filenames[index] + else: + fn = "" + + if self.parameters["use_costmap"]: + costmap = self.data["costmap"][index] + else: + costmap = [] + + return (self.data["img"][index], self.data["label"][index], costmap, fn) + + def __len__(self): + if self.init_only: + return self.num_patch + return len(self.data["img"]) + + def get_params(self): + return self.parameters + + +def patchize( + img: np.ndarray, pr: Sequence[int], patch_size: Sequence[int] +) -> Tuple[List[List[int]], List[np.ndarray]]: + """ + Break an image into z * y * x patches specified by pr + + Parameters + ---------- + img: 4d CZYX order numpy array + pr: length 3 list specifying number of patches to divide in z,y,x dimensions + patch_size: inference patch size to make sure that the patches are large + enough for inference + + Return: list of [i,j,k] start points of a patch and corresponding list of + np.array imgs + """ + ijk = [] + imgs = [] + + x_max = img.shape[-1] + y_max = img.shape[-2] + z_max = img.shape[-3] + + x_patch_sz = x_max // pr[2] + y_patch_sz = y_max // pr[1] + z_patch_sz = z_max // pr[0] + + assert ( + x_patch_sz >= patch_size[2] + and y_patch_sz >= patch_size[1] + and z_patch_sz >= patch_size[0] + ), "Large image resize patches must be larger than model patch size" + assert len(img.shape) == 4, f"Expected 4D image, got {len(img.shape)}-D array" + assert len(pr) == 3, f"Expected pr to have length 3, got length {len(pr)}" + + maxs = [z_max, y_max, x_max] + patch_szs = [z_patch_sz, y_patch_sz, x_patch_sz] + + all_coords = [] + for i in range(3): + # remainder is the number of pixels per axis not evenly divided into patches + remainder = maxs[i] % pr[i] + coords = [ + # for the first *remainder* patches, we want to expand the + # patch_size by one pixel so that after *remainder* iterations, + # all pixels are included in exactly one patch. + j * patch_szs[i] + j if j < remainder + # once *remainder* pixels have been added we don't have to + # add an extra pixel to each patch's size, but we do + # have to shift the starts of the remaining patches + # by the *remainder* pixels we've already added + else j * patch_szs[i] + remainder + for j in range(pr[i] + 1) + ] + all_coords.append(coords) + + for i in range(pr[0]): # z + for j in range(pr[1]): # y + for k in range(pr[2]): # x + i_start = max(0, all_coords[0][i] - 5) + i_end = min(z_max, all_coords[0][i + 1] + 5) + + j_start = max(0, all_coords[1][j] - 30) + j_end = min(y_max, all_coords[1][j + 1] + 30) + + k_start = max(0, all_coords[2][k] - 30) + k_end = min(x_max, all_coords[2][k + 1] + 30) + temp = np.array( + img[ + :, + i_start:i_end, + j_start:j_end, + k_start:k_end, + ] + ) + ijk.append([i_start, j_start, k_start]) + imgs.append(temp) + return ijk, imgs + + +# TODO deal with timelapse images +class TestDataset(IterableDataset): + def __init__(self, config: Dict): + """ + Dataset to load, resize, normalize, and return testing images when needed for + inference + + Parameters + ---------- + config: user-provided preferences to specify how to shape and normalize images + in preparation for prediction + Return: None + """ + self.config = config + self.inf_config = config["mode"] + self.model_config = config["model"] + self.patchize_ratio = config["large_image_resize"] + self.patches_per_image = np.prod(self.patchize_ratio) + self.load_type = "test" + self.timelapse = False + + try: # monai + self.size_in = self.model_config["patch_size"] + self.size_out = self.model_config["patch_size"] + self.nchannel = self.model_config["in_channels"] + except KeyError: # unet_xy_zoom + self.size_in = self.model_config["size_in"] + self.size_out = self.model_config["size_out"] + self.nchannel = self.model_config["nchannel"] + + if self.inf_config["name"] == "file": + filenames = [self.inf_config["InputFile"]] + if "timelapse" in self.inf_config and self.inf_config["timelapse"]: + self.load_type = "timelapse" + self.timelapse = True + else: + from glob import glob + + if type(self.inf_config["InputDir"]) == str: + self.inf_config["InputDir"] = [self.inf_config["InputDir"]] + + filenames = [] + for folder in self.inf_config["InputDir"]: + fns = glob(folder + "/*" + self.inf_config["DataType"]) + fns.sort() + filenames += fns + print("Predicting on", len(filenames), "files") + + self.filenames = filenames + self.start = None + self.end = None + self.all_img_info = [] + + def pad_image(self, image: np.ndarray) -> torch.Tensor: + """ + Pad image so model output is the same size as the input. Padding is symmetric + in xy and constant in z + + Parameters + ---------- + image: 4d CZYX order numpy array + + Return: padded image + """ + if len(image.shape) == 5: + image = np.squeeze(image, axis=0) + padding = [(x - y) // 2 for x, y in zip(self.size_in, self.size_out)] + image = np.pad( + image, + ((0, 0), (0, 0), (padding[1], padding[1]), (padding[2], padding[2])), + "symmetric", + ) + image = np.pad( + image, + ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)), + "constant", + ) + image = from_numpy(image.astype(float)).float() + return image + + def patchize_wrapper( + self, + pr: Sequence[int], + fn: str, + img: np.ndarray, + patch_size: Tuple, + tt: int, + timelapse: bool, + ) -> Dict: + """ + Create dictionary with information necessary for inference + + Parameters + ---------- + fn: filename + img: 4d CZYX order numpy array + pr: length 3 list specifying number of patches to divide in z,y,x dimensions + patch_size: inference patch size to make sure that the patches are large enough + for inference + tt: timepoint + timelapse: whether image is a timelapse + + Return: Dictionary containing image filename, tensor image, shape of input + image, ijk index of patch from original image, how many patches + original image was split into, and timepoint + """ + if pr == [1, 1, 1]: + return_dicts = [ + { + "fn": fn, + "img": self.pad_image(img), + "im_shape": img.shape, + "ijk": -1, + "save_n_batches": 1, + "tt": tt if timelapse else -1, + } + ] + else: + save_n_batches = np.prod( + pr + ) # how many patches until aggregated image saved + ijk, imgs = patchize(img, pr, patch_size) + return_dicts = [] + for index, patch in zip(ijk, imgs): + return_dict = { + "fn": fn, + "img": self.pad_image(patch), + "im_shape": img.shape, + "ijk": index, + "save_n_batches": save_n_batches, + "tt": tt if timelapse else -1, + } + return_dicts.append(return_dict) + return return_dicts + + def __iter__(self): + self.current_index = self.start + return self + + def __next__(self): + if self.current_index > self.end and len(self.all_img_info) == 0: + raise StopIteration + if len(self.all_img_info) == 0: # load new image if no images have been loaded + fn = self.filenames[self.current_index] + imgs = load_img(fn, self.load_type, self.nchannel, self.config["InputCh"]) + # only one image unless timelapse + for tt, img in enumerate(imgs): + img = resize(img, self.config) + img = image_normalization(img, self.config["Normalization"]) + # generate patch info + self.all_img_info += self.patchize_wrapper( + self.patchize_ratio, + fn, + img, + self.size_in, + tt, + self.timelapse, + ) + self.current_index += 1 # next iteration load the next file + return self.all_img_info.pop() # pop patch/tp diff --git a/aicsmlsegment/Model.py b/aicsmlsegment/Model.py new file mode 100644 index 0000000..b044f40 --- /dev/null +++ b/aicsmlsegment/Model.py @@ -0,0 +1,418 @@ +import pytorch_lightning +from torch.optim import Adam +import torch +from aicsmlsegment.custom_metrics import get_metric +from aicsmlsegment.custom_loss import get_loss_criterion +from aicsmlsegment.model_utils import ( + model_inference, + apply_on_image, +) +from aicsmlsegment.DataUtils.Universal_Loader import ( + minmax, + undo_resize, + UniversalDataset, +) +from aicsmlsegment.utils import compute_iou + +import numpy as np +from skimage.io import imsave +from skimage.morphology import remove_small_objects +import os +import pathlib +from torch.utils.data import DataLoader + + +class Model(pytorch_lightning.LightningModule): + # the base class for all the models + def __init__(self, config, model_config, train): + super().__init__() + + self.args_inference = {} + + self.model_name = config["model"]["name"] + self.model_config = model_config + + if "unet_xy" in self.model_name: # custom model + import importlib + from aicsmlsegment.model_utils import weights_init as weights_init + + module = importlib.import_module( + "aicsmlsegment.NetworkArchitecture." + self.model_name + ) + init_args = { + "in_channel": model_config["nchannel"], + "n_classes": model_config["nclass"], + "test_mode": not train, + } + if self.model_name == "sdunet_xy": + init_args["loss"] = config["loss"]["name"] + + if "zoom" in self.model_name: + init_args["down_ratio"] = model_config.get("zoom_ratio", 3) + + model = getattr(module, "UNet3D") + self.model = model(**init_args).apply(weights_init) + + self.args_inference["size_in"] = model_config["size_in"] + self.args_inference["size_out"] = model_config["size_out"] + self.args_inference["nclass"] = model_config["nclass"] + + else: # monai model + if self.model_name == "segresnetvae": + from monai.networks.nets.segresnet import SegResNetVAE as model + + model_config["input_image_size"] = model_config["patch_size"] + elif self.model_name == "extended_vnet": + from aicsmlsegment.NetworkArchitecture.vnet import VNet as model + elif self.model_name == "extended_dynunet": + from aicsmlsegment.NetworkArchitecture.dynunet import DynUNet as model + + else: + import importlib + + module = importlib.import_module( + "monai.networks.nets." + self.model_name + ) + # deal with monai name scheme - module name != class name for networks + net_name = [attr for attr in dir(module) if "Net" in attr][0] + model = getattr(module, net_name) + # monai model assumes same size for input and output + self.args_inference["size_in"] = model_config["patch_size"] + self.args_inference["size_out"] = model_config["patch_size"] + del model_config["patch_size"] + self.model = model(**model_config) + + self.config = config + self.aggregate_img = None + if train: + loader_config = config["loader"] + self.datapath = loader_config["datafolder"] + self.nworkers = loader_config["NumWorkers"] + self.batchsize = loader_config["batch_size"] + self.epoch_shuffle = loader_config["epoch_shuffle"] + + validation_config = config["validation"] + self.leaveout = validation_config["leaveout"] + self.validation_period = validation_config["validate_every_n_epoch"] + + self.lr = config["learning_rate"] + self.weight_decay = config["weight_decay"] + + self.args_inference["inference_batch_size"] = loader_config["batch_size"] + self.args_inference["OutputCh"] = validation_config["OutputCh"] + + ( + self.loss_function, + self.accepts_costmap, + ) = get_loss_criterion(config) + self.metric = get_metric(config) + self.scheduler_params = config["scheduler"] + self.dataset_params = None + + else: + if config["RuntimeAug"] <= 0: + self.args_inference["RuntimeAug"] = False + else: + self.args_inference["RuntimeAug"] = True + self.args_inference["OutputCh"] = config["OutputCh"] + self.args_inference["inference_batch_size"] = config["batch_size"] + self.args_inference["mode"] = config["mode"]["name"] + self.args_inference["Threshold"] = config["Threshold"] + if config["large_image_resize"] != [1, 1, 1]: + self.aggregate_img = {} + self.count_map = {} + self.batch_count = {} + self.save_hyperparameters() + + def forward(self, x): + """ + returns raw predictions + """ + return self.model(x) + + def configure_optimizers(self): + optims = [] + scheds = [] + + scheduler_params = self.scheduler_params + + # basic optimizer + optimizer = Adam( + self.model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + optims.append(optimizer) + + if scheduler_params["name"] is not None: + if scheduler_params["name"] == "ExponentialLR": + from torch.optim.lr_scheduler import ExponentialLR + + assert scheduler_params["gamma"] > 0 + scheduler = ExponentialLR( + optims[0], + gamma=scheduler_params["gamma"], + verbose=scheduler_params["verbose"], + ) + + elif scheduler_params["name"] == "CosineAnnealingLR": + from torch.optim.lr_scheduler import CosineAnnealingLR as CALR + + assert scheduler_params["T_max"] > 0 + scheduler = CALR( + optims[0], + T_max=scheduler_params["T_max"], + verbose=scheduler_params["verbose"], + ) + + elif scheduler_params["name"] == "StepLR": + from torch.optim.lr_scheduler import StepLR + + assert scheduler_params["step_size"] > 0 + assert scheduler_params["gamma"] > 0 + scheduler = StepLR( + optims[0], + step_size=scheduler_params["step_size"], + gamma=scheduler_params["gamma"], + verbose=scheduler_params["verbose"], + ) + elif scheduler_params["name"] == "ReduceLROnPlateau": + from torch.optim.lr_scheduler import ReduceLROnPlateau + + assert 0 < scheduler_params["factor"] < 1 + assert scheduler_params["patience"] > 0 + # if patience is too short, validation metrics won't be available + # if "val" in scheduler_params["monitor"]: + # assert ( + # scheduler_params["patience"] > self.validation_period + # ), "Patience must be larger than validation frequency" + scheduler = ReduceLROnPlateau( + optims[0], + mode=scheduler_params["mode"], + factor=scheduler_params["factor"], + patience=scheduler_params["patience"], + verbose=scheduler_params["verbose"], + min_lr=0.0000001, + ) + # monitoring metric must be specified + return { + "optimizer": optims[0], + "lr_scheduler": scheduler, + "monitor": scheduler_params["monitor"], + } + elif scheduler_params["name"] == "1cycle": + from torch.optim.lr_scheduler import OneCycleLR + + scheduler = OneCycleLR( + optims[0], + max_lr=scheduler_params["max_lr"], + total_steps=scheduler_params["total_steps"], + pct_start=scheduler_params["pct_start"], + verbose=scheduler_params["verbose"], + ) + else: + print( + "The selected scheduler is not yet supported. No scheduler is used." + ) + return optims + scheds.append(scheduler) + return optims, scheds + else: + print("no scheduler is used") + return optims + + # HACK until pytorch lightning includes reload_dataloaders_every_n_epochs + def on_train_epoch_start(self): + if self.epoch_shuffle is not None: + if self.current_epoch == 0 and self.dataset_params is None: + self.dataset_params = self.train_dataloader().dataset.get_params() + + if self.current_epoch % self.epoch_shuffle == 0: + if self.global_rank == 0 and self.current_epoch > 0: + print("Reloading dataloader...") + self.DATALOADER = DataLoader( + UniversalDataset(**self.dataset_params), + batch_size=self.config["loader"]["batch_size"], + shuffle=True, + num_workers=self.config["loader"]["NumWorkers"], + pin_memory=True, + ) + self.iter_dataloader = iter(self.DATALOADER) + + def get_upsample_grid(self, desired_shape, n_targets): + x = torch.linspace(-1, 1, desired_shape[-1], device=self.device) + y = torch.linspace(-1, 1, desired_shape[-2], device=self.device) + z = torch.linspace(-1, 1, desired_shape[-3], device=self.device) + meshz, meshy, meshx = torch.meshgrid((z, y, x)) + grid = torch.stack((meshx, meshy, meshz), 3) + grid = torch.stack([grid] * n_targets) # one grid for each target in batch + return grid + + def log_and_return(self, name, value): + # sync_dist on_epoch=True ensures that results will be averaged across gpus + self.log( + name, + value, + sync_dist=True, + prog_bar=True, + on_epoch=True, + on_step=False, + ) + return {"loss": value} # return val only used in train step + + def training_step(self, batch, batch_idx): + if self.epoch_shuffle is not None: + # ignore dataloader provided by pytorch lightning + batch = next(self.iter_dataloader) + inputs = batch[0].half().to(self.device) + targets = batch[1].to(self.device) + cmap = batch[2].to(self.device) + else: + inputs = batch[0] + targets = batch[1] + cmap = batch[2] + outputs = self(inputs) + + vae_loss = 0 + if self.model_name == "segresnetvae": + # segresnetvae forward returns an additional vae loss term + outputs, vae_loss = outputs + if ( + "dynunet" in self.model_name and self.model_config["deep_supervision"] + ): # output is a stacked tensor all of same shape instead of a list + outputs = torch.unbind(outputs, dim=1) + + # from https://arxiv.org/pdf/1810.11654.pdf, + # vae_loss > 0 if model = segresnetvae + loss = self.loss_function(outputs, targets, cmap) + 0.1 * vae_loss + return self.log_and_return("epoch_train_loss", loss) + + def validation_step(self, batch, batch_idx): + input_img = batch[0] + label = batch[1] + costmap = batch[2] + # fn = batch[3] + + outputs, vae_loss = model_inference( + self.model, + input_img, + self.args_inference, + squeeze=True, + extract_output_ch=False, + model_name=self.model_name, + softmax=False, + ) + # from https://arxiv.org/pdf/1810.11654.pdf, # mode = 'validation' + val_loss = self.loss_function(outputs, label, costmap) + 0.1 * vae_loss + self.log_and_return("val_loss", val_loss) + outputs = torch.nn.Softmax(dim=1)(outputs) + val_metric = compute_iou(outputs > 0.5, label, torch.unsqueeze(costmap, dim=1)) + self.log_and_return("val_iou", val_metric) + # save first validation image result + if batch_idx == 0: + imsave( + self.config["checkpoint_dir"] + + os.sep + + "validation_results" + + os.sep + + "epoch=" + + str(self.current_epoch) + + "_loss=" + + str(round(val_loss.item(), 3)) + + "_iou=" + + str(round(val_metric, 3)) + + ".tiff", + outputs[0, 1, :, :, :].detach().cpu().numpy(), + ) + + def test_step(self, batch, batch_idx): + img = batch["img"] + fn = batch["fn"][0] + tt = batch["tt"][0] + save_n_batches = batch["save_n_batches"].detach().cpu().numpy()[0] + args_inference = self.args_inference + to_numpy = True + if self.aggregate_img is not None: + to_numpy = False # prevent excess gpu->cpu data transfer + + output_img, _ = apply_on_image( + self.model, + img, + args_inference, + squeeze=False, + to_numpy=to_numpy, + softmax=True, + model_name=self.model_name, + extract_output_ch=True, + ) + + if self.aggregate_img is not None: + # initialize the aggregate img + i, j, k = batch["ijk"][0], batch["ijk"][1], batch["ijk"][2] + if fn not in self.aggregate_img: + self.aggregate_img[fn] = torch.zeros( + batch["im_shape"], dtype=torch.float32, device=self.device + ) + self.count_map[fn] = torch.zeros( + batch["im_shape"], dtype=torch.uint8, device=self.device + ) + self.batch_count[fn] = 0 + self.aggregate_img[fn][ + :, # preserve all channels + i : i + output_img.shape[2], + j : j + output_img.shape[3], + k : k + output_img.shape[4], + ] += torch.squeeze(output_img, dim=0) + + self.count_map[fn][ + :, # preserve all channels + i : i + output_img.shape[2], + j : j + output_img.shape[3], + k : k + output_img.shape[4], + ] += 1 + self.batch_count[fn] += 1 + else: + # not aggregating an image, save every batch + self.batch_count = {fn: 1} + + # only want to perform post-processing and saving once the aggregated image + # is complete or we're not aggregating an image + if self.batch_count[fn] % save_n_batches == 0: + from aicsimageio.writers.ome_tiff_writer import OmeTiffWriter + + if self.aggregate_img is not None: + # normalize for overlapping patches + output_img = self.aggregate_img[fn] / self.count_map[fn] + output_img = output_img.detach().cpu().numpy() + + if args_inference["mode"] != "folder": + out = minmax(output_img) + out = undo_resize(out, self.config) + if args_inference["Threshold"] > 0: + out = out > args_inference["Threshold"] + out = out.astype(np.uint8) + out[out > 0] = 255 + else: + if args_inference["Threshold"] < 0: + out = minmax(output_img) + out = undo_resize(out, self.config) + out = minmax(out) + else: + out = remove_small_objects( + output_img > args_inference["Threshold"], + min_size=2, + connectivity=1, + ) + out = out.astype(np.uint8) + out[out > 0] = 255 + out = np.squeeze(out, 0) # remove N dimension + path = self.config["OutputDir"] + os.sep + pathlib.PurePosixPath(fn).stem + if tt != -1: + path = path + "_T_" + f"{tt:03}" + path += "_struct_segmentation.tiff" + with OmeTiffWriter(path, overwrite_file=True) as writer: + writer.save( + data=out, + channel_names=[self.config["segmentation_name"]], + dimension_order="CZYX", + ) diff --git a/aicsmlsegment/Net3D/uNet_original.py b/aicsmlsegment/Net3D/uNet_original.py deleted file mode 100644 index 395352c..0000000 --- a/aicsmlsegment/Net3D/uNet_original.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class UNet3D(nn.Module): - def __init__(self, in_channel, n_classes, batchnorm_flag=True): - self.in_channel = in_channel - self.n_classes = n_classes - super(UNet3D, self).__init__() - - self.ec1 = self.encoder(self.in_channel, 32, batchnorm=batchnorm_flag) - self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) - self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) - self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) - - self.pool1 = nn.MaxPool3d(2) - self.pool2 = nn.MaxPool3d(2) - self.pool3 = nn.MaxPool3d(2) - - self.up3 = nn.ConvTranspose3d(512, 512, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) - self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) - self.up2 = nn.ConvTranspose3d(256, 256, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) - self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) - self.up1 = nn.ConvTranspose3d(128, 128, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) - self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) - - self.dc0 = nn.Conv3d(64, n_classes, 1) - self.softmax = F.log_softmax - - self.numClass = n_classes - - def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, - bias=True, batchnorm=False): - if batchnorm: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm2d(out_channels, affine=False), - nn.ReLU(), - nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm2d(2*out_channels, affine=False), - nn.ReLU()) - else: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU(), - nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU()) - return layer - - - def decoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, - bias=True, batchnorm=False): - if batchnorm: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm2d(out_channels, affine=False), - nn.ReLU(), - nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm2d(out_channels, affine=False), - nn.ReLU()) - else: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU(), - nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU()) - return layer - - def forward(self, x): - - down1 = self.ec1(x) - x1 = self.pool1(down1) - down2 = self.ec2(x1) - x2 = self.pool2(down2) - down3 = self.ec3(x2) - x3 = self.pool3(down3) - - u3 = self.ec4(x3) - - d3 = torch.cat((self.up3(u3), F.pad(down3,(-4,-4,-4,-4,-4,-4))), 1) - u2 = self.dc3(d3) - d2 = torch.cat((self.up2(u2), F.pad(down2,(-16,-16,-16,-16,-16,-16))), 1) - u1 = self.dc2(d2) - d1 = torch.cat((self.up1(u1), F.pad(down1,(-40,-40,-40,-40,-40,-40))), 1) - u0 = self.dc1(d1) - out = self.dc0(u0) - - out = out.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - out = out.view(out.numel() // self.numClass, self.numClass) - out = self.softmax(out, dim=1) - - return out diff --git a/aicsmlsegment/Net3D/unet_xy.py b/aicsmlsegment/Net3D/unet_xy.py deleted file mode 100644 index 0d9ccf7..0000000 --- a/aicsmlsegment/Net3D/unet_xy.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class UNet3D(nn.Module): - def __init__(self, in_channel, n_classes, batchnorm_flag=True): - self.in_channel = in_channel - self.n_classes = n_classes - super(UNet3D, self).__init__() - - self.ec1 = self.encoder(self.in_channel, 32, batchnorm=batchnorm_flag) # in --> 64 - self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) # 64 --> 128 - self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) # 128 --> 256 - self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) # 256 -->512 - - self.pool1 = nn.MaxPool3d((1,2,2)) - self.pool2 = nn.MaxPool3d((1,2,2)) - self.pool3 = nn.MaxPool3d((1,2,2)) - - self.up3 = nn.ConvTranspose3d(512, 512, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) - self.up2 = nn.ConvTranspose3d(256, 256, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) - self.up1 = nn.ConvTranspose3d(128, 128, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) - - self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) - self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) - self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) - - self.dc0 = nn.Conv3d(64, n_classes[0], 1) - - self.up2a = nn.ConvTranspose3d(256, n_classes[2], kernel_size=(1,8,8), stride=(1,4,4), padding=0, output_padding=0, bias=True) - self.up1a = nn.ConvTranspose3d(128, n_classes[1], kernel_size=(1,4,4), stride=(1,2,2), padding=0, output_padding=0, bias=True) - - self.conv2a = nn.Conv3d(n_classes[2], n_classes[2], 3, stride=1, padding=0, bias=True) - self.conv1a = nn.Conv3d(n_classes[1], n_classes[1], 3, stride=1, padding=0, bias=True) - - self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) - self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) - - #self.conv_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[0]+n_classes[1]+n_classes[2], 3, stride=1, padding=1, bias=True) - #self.predict_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[3], 1) - - self.softmax = F.log_softmax # nn.LogSoftmax(1) - - self.final_activation = nn.Softmax(dim=1) - - self.numClass = n_classes[0] - self.numClass1 = n_classes[1] - self.numClass2 = n_classes[2] - #self.numClass_combine = n_classes[3] - - def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, - bias=True, batchnorm=False): - if batchnorm: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(out_channels, affine=False), - nn.ReLU(), - nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(2*out_channels, affine=False), - nn.ReLU()) - else: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU(), - nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU()) - return layer - - - def decoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, - bias=True, batchnorm=False): - if batchnorm: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(out_channels, affine=False), - nn.ReLU(), - nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(out_channels, affine=False), - nn.ReLU()) - else: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU(), - nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU()) - return layer - - def forward(self, x): - - down1 = self.ec1(x) - x1 = self.pool1(down1) - down2 = self.ec2(x1) - x2 = self.pool2(down2) - down3 = self.ec3(x2) - x3 = self.pool3(down3) - - u3 = self.ec4(x3) - - d3 = torch.cat((self.up3(u3), F.pad(down3,(-4,-4,-4,-4,-2,-2))), 1) - u2 = self.dc3(d3) - - d2 = torch.cat((self.up2(u2), F.pad(down2,(-16,-16,-16,-16,-6,-6))), 1) - u1 = self.dc2(d2) - - d1 = torch.cat((self.up1(u1), F.pad(down1,(-40,-40,-40,-40,-10,-10))), 1) - u0 = self.dc1(d1) - - p0 = self.dc0(u0) - - p1a = F.pad(self.predict1a(self.conv1a(self.up1a(u1))),(-2,-2,-2,-2, -1, -1)) - p2a = F.pad(self.predict2a(self.conv2a(self.up2a(u2))),(-7,-7,-7,-7,-3,-3)) - - p0_final = p0.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p0_final = p0_final.view(p0_final.numel() // self.numClass, self.numClass) - p0_final = self.softmax(p0_final, dim=1) - - p1_final = p1a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p1_final = p1_final.view(p1_final.numel() // self.numClass1, self.numClass1) - p1_final = self.softmax(p1_final, dim=1) - - p2_final = p2a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p2_final = p2_final.view(p2_final.numel() // self.numClass2, self.numClass2) - p2_final = self.softmax(p2_final, dim=1) - - ''' - p_combine0 = self.predict_final(self.conv_final(torch.cat((p0, p1a, p2a), 1))) # BCZYX - p_combine = p_combine0.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p_combine = p_combine.view(p_combine.numel() // self.numClass_combine, self.numClass_combine) - p_combine = self.softmax(p_combine) - ''' - - return [p0_final, p1_final, p2_final] diff --git a/aicsmlsegment/Net3D/unet_xy_enlarge.py b/aicsmlsegment/Net3D/unet_xy_enlarge.py deleted file mode 100644 index 15eefc7..0000000 --- a/aicsmlsegment/Net3D/unet_xy_enlarge.py +++ /dev/null @@ -1,143 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class UNet3D(nn.Module): - def __init__(self, in_channel, n_classes, down_ratio, batchnorm_flag=True): - self.in_channel = in_channel - self.n_classes = n_classes - super(UNet3D, self).__init__() - - k = down_ratio - - self.ec1 = self.encoder(self.in_channel, 32, batchnorm=batchnorm_flag) # in --> 64 - self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) # 64 --> 128 - self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) # 128 --> 256 - self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) # 256 -->512 - - self.pool0 = nn.MaxPool3d((1,k,k)) - self.pool1 = nn.MaxPool3d((1,2,2)) - self.pool2 = nn.MaxPool3d((1,2,2)) - self.pool3 = nn.MaxPool3d((1,2,2)) - - self.up3 = nn.ConvTranspose3d(512, 512, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) - self.up2 = nn.ConvTranspose3d(256, 256, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) - self.up1 = nn.ConvTranspose3d(128, 128, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) - self.up0 = nn.ConvTranspose3d(64, 64, kernel_size=(1,k,k), stride=(1,k,k), padding=0, output_padding=0, bias=True) - - self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) - self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) - self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) - self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag) - - self.predict0 = nn.Conv3d(64, n_classes[0], 1) - - self.up1a = nn.ConvTranspose3d(128, n_classes[1], kernel_size=(1,2*k,2*k), stride=(1,2*k,2*k), padding=0, output_padding=0, bias=True) - self.up2a = nn.ConvTranspose3d(256, n_classes[2], kernel_size=(1,4*k,4*k), stride=(1,4*k,4*k), padding=0, output_padding=0, bias=True) - - self.conv2a = nn.Conv3d(n_classes[2], n_classes[2], 3, stride=1, padding=0, bias=True) - self.conv1a = nn.Conv3d(n_classes[1], n_classes[1], 3, stride=1, padding=0, bias=True) - - self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) - self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) - - #self.conv_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[0]+n_classes[1]+n_classes[2], 3, stride=1, padding=1, bias=True) - #self.predict_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[3], 1) - - self.softmax = F.log_softmax # nn.LogSoftmax(1) - - self.final_activation = nn.Softmax(dim=1) - - self.numClass = n_classes[0] - self.numClass1 = n_classes[1] - self.numClass2 = n_classes[2] - - self.k = k - #self.numClass_combine = n_classes[3] - - def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, - bias=True, batchnorm=False): - if batchnorm: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(out_channels, affine=False), - nn.ReLU(), - nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(2*out_channels, affine=False), - nn.ReLU()) - else: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU(), - nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU()) - return layer - - - def decoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, - bias=True, batchnorm=False): - if batchnorm: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(out_channels, affine=False), - nn.ReLU(), - nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.BatchNorm3d(out_channels, affine=False), - nn.ReLU()) - else: - layer = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU(), - nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), - nn.ReLU()) - return layer - - def forward(self, x): - - k = self.k - - x0 = self.pool0(x) - - down1 = self.ec1(x0) - x1 = self.pool1(down1) - down2 = self.ec2(x1) - x2 = self.pool2(down2) - down3 = self.ec3(x2) - x3 = self.pool3(down3) - - u3 = self.ec4(x3) - - d3 = torch.cat((self.up3(u3), F.pad(down3,(-4,-4,-4,-4,-2,-2))), 1) - u2 = self.dc3(d3) - - d2 = torch.cat((self.up2(u2), F.pad(down2,(-16,-16,-16,-16,-6,-6))), 1) - u1 = self.dc2(d2) - - d1 = torch.cat((self.up1(u1), F.pad(down1,(-40,-40,-40,-40,-10,-10))), 1) - u0 = self.dc1(d1) - - d0 = self.up0(u0) - - predict00 = self.predict0(self.dc0(d0)) - p0_final = predict00.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p0_final = p0_final.view(p0_final.numel() // self.numClass, self.numClass) - p0_final = self.softmax(p0_final, dim=1) - - p1a = F.pad(self.predict1a(self.conv1a(self.up1a(u1))),(-2*k-1,-2*k-1,-2*k-1,-2*k-1, -3, -3)) - p1_final = p1a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p1_final = p1_final.view(p1_final.numel() // self.numClass1, self.numClass1) - p1_final = self.softmax(p1_final, dim=1) - - p2a = F.pad(self.predict2a(self.conv2a(self.up2a(u2))),(-6*k-1,-6*k-1,-6*k-1,-6*k-1,-5,-5)) ## fix +5 - p2_final = p2a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p2_final = p2_final.view(p2_final.numel() // self.numClass2, self.numClass2) - p2_final = self.softmax(p2_final, dim=1) - - ''' - p_combine0 = self.predict_final(self.conv_final(torch.cat((p0, p1a, p2a), 1))) # BCZYX - p_combine = p_combine0.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension - p_combine = p_combine.view(p_combine.numel() // self.numClass_combine, self.numClass_combine) - p_combine = self.softmax(p_combine) - ''' - - return [p0_final, p1_final, p2_final] diff --git a/aicsmlsegment/NetworkArchitecture/dynunet.py b/aicsmlsegment/NetworkArchitecture/dynunet.py new file mode 100644 index 0000000..1ee6615 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/dynunet.py @@ -0,0 +1,247 @@ +from typing import List, Optional, Sequence, Union + +import torch +import torch.nn as nn +from torch.nn.functional import interpolate + +from monai.networks.blocks.dynunet_block import ( + UnetBasicBlock, + UnetOutBlock, + UnetResBlock, + UnetUpBlock, +) +from monai.networks.nets.dynunet import DynUNetSkipLayer + + +class DynUNet(nn.Module): + """ + modified from https://docs.monai.io/en/latest/_modules/monai/networks/nets/dynunet.html#DynUNet # noqa E501 + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Sequence[Union[Sequence[int], int]], + strides: Sequence[Union[Sequence[int], int]], + upsample_kernel_size: Sequence[Union[Sequence[int], int]], + norm_name: str = "instance", + deep_supervision: bool = False, + deep_supr_num: int = 1, + res_block: bool = False, + ): + super(DynUNet, self).__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.strides = strides + self.upsample_kernel_size = upsample_kernel_size + self.norm_name = norm_name + self.conv_block = UnetResBlock if res_block else UnetBasicBlock + self.filters = [ + 2 ** (5 + i) for i in range(len(strides)) + ] # REMOVE FILTER LIMIT + self.input_block = self.get_input_block() + self.downsamples = self.get_downsamples() + self.bottleneck = self.get_bottleneck() + self.upsamples = self.get_upsamples() + self.output_block = self.get_output_block(0) + self.deep_supervision = deep_supervision + self.deep_supervision_heads = self.get_deep_supervision_heads() + self.deep_supr_num = deep_supr_num + self.apply(self.initialize_weights) + self.check_kernel_stride() + self.check_deep_supr_num() + + # initialize the typed list of supervision head outputs so that Torchscript + # can recognize what's going on + self.heads: List[torch.Tensor] = [torch.rand(1)] * ( + len(self.deep_supervision_heads) + 1 + ) + + def create_skips(index, downsamples, upsamples, superheads, bottleneck): + + if len(downsamples) != len(upsamples): + raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") + if (len(downsamples) - len(superheads)) not in (1, 0): + raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") + + if ( + len(downsamples) == 0 + ): # bottom of the network, pass the bottleneck block + return bottleneck + if index == 0: # don't associate a supervision head with self.input_block + current_head, rest_heads = nn.Identity(), superheads + elif ( + not self.deep_supervision + ): # bypass supervision heads by passing nn.Identity in place of a real one + current_head, rest_heads = nn.Identity(), superheads[1:] + else: + current_head, rest_heads = superheads[0], superheads[1:] + + # create the next layer down, this will stop at the bottleneck layer + next_layer = create_skips( + 1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck + ) + + return DynUNetSkipLayer( + index, + self.heads, + downsamples[0], + upsamples[0], + current_head, + next_layer, + ) + + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.deep_supervision_heads, + self.bottleneck, + ) + + def check_kernel_stride(self): + kernels, strides = self.kernel_size, self.strides + error_msg = ( + "length of kernel_size and strides should be the same, and no less than 3." + ) + if not (len(kernels) == len(strides) and len(kernels) >= 3): + raise AssertionError(error_msg) + + for idx in range(len(kernels)): + kernel, stride = kernels[idx], strides[idx] + if not isinstance(kernel, int): + error_msg = ( + f"length of kernel_size in block {idx} should " + "be the same as spatial_dims." + ) + if len(kernel) != self.spatial_dims: + raise AssertionError(error_msg) + if not isinstance(stride, int): + error_msg = ( + f"length of stride in block {idx} should be " + "the same as spatial_dims." + ) + if len(stride) != self.spatial_dims: + raise AssertionError(error_msg) + + def check_deep_supr_num(self): + deep_supr_num, strides = self.deep_supr_num, self.strides + num_up_layers = len(strides) - 1 + if deep_supr_num >= num_up_layers: + raise AssertionError( + "deep_supr_num should be less than the number of up sample layers." + ) + if deep_supr_num < 1: + raise AssertionError("deep_supr_num should be larger than 0.") + + def forward(self, x): + out = self.skip_layers(x) + out = self.output_block(out) + if self.training and self.deep_supervision: + out_all = [out] + feature_maps = self.heads[1 : self.deep_supr_num + 1] + for feature_map in feature_maps: + out_all.append(interpolate(feature_map, out.shape[2:])) + return torch.stack(out_all, dim=1) + return out + + def get_input_block(self): + return self.conv_block( + self.spatial_dims, + self.in_channels, + self.filters[0], + self.kernel_size[0], + self.strides[0], + self.norm_name, + ) + + def get_bottleneck(self): + return self.conv_block( + self.spatial_dims, + self.filters[-2], + self.filters[-1], + self.kernel_size[-1], + self.strides[-1], + self.norm_name, + ) + + def get_output_block(self, idx: int): + return UnetOutBlock( + self.spatial_dims, + self.filters[idx], + self.out_channels, + ) + + def get_downsamples(self): + inp, out = self.filters[:-2], self.filters[1:-1] + strides, kernel_size = self.strides[1:-1], self.kernel_size[1:-1] + return self.get_module_list(inp, out, kernel_size, strides, self.conv_block) + + def get_upsamples(self): + inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] + strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] + upsample_kernel_size = self.upsample_kernel_size[::-1] + return self.get_module_list( + inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size + ) + + def get_module_list( + self, + in_channels: List[int], + out_channels: List[int], + kernel_size: Sequence[Union[Sequence[int], int]], + strides: Sequence[Union[Sequence[int], int]], + conv_block: nn.Module, + upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None, + ): + layers = [] + if upsample_kernel_size is not None: + for in_c, out_c, kernel, stride, up_kernel in zip( + in_channels, out_channels, kernel_size, strides, upsample_kernel_size + ): + params = { + "spatial_dims": self.spatial_dims, + "in_channels": in_c, + "out_channels": out_c, + "kernel_size": kernel, + "stride": stride, + "norm_name": self.norm_name, + "upsample_kernel_size": up_kernel, + } + layer = conv_block(**params) + layers.append(layer) + else: + for in_c, out_c, kernel, stride in zip( + in_channels, out_channels, kernel_size, strides + ): + params = { + "spatial_dims": self.spatial_dims, + "in_channels": in_c, + "out_channels": out_c, + "kernel_size": kernel, + "stride": stride, + "norm_name": self.norm_name, + } + layer = conv_block(**params) + layers.append(layer) + return nn.ModuleList(layers) + + def get_deep_supervision_heads(self): + return nn.ModuleList( + [self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)] + ) + + @staticmethod + def initialize_weights(module): + name = module.__class__.__name__.lower() + if "conv3d" in name or "conv2d" in name: + nn.init.kaiming_normal_(module.weight, a=0.01) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif "norm" in name: + nn.init.normal_(module.weight, 1.0, 0.02) + nn.init.zeros_(module.bias) diff --git a/aicsmlsegment/NetworkArchitecture/sdunet_xy.py b/aicsmlsegment/NetworkArchitecture/sdunet_xy.py new file mode 100644 index 0000000..e5a06c8 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/sdunet_xy.py @@ -0,0 +1,256 @@ +import torch +import torch.nn as nn + + +class UNet3D(nn.Module): + + """ + 3D adaptation of https://arxiv.org/ftp/arxiv/papers/2004/2004.03466.pdf + This network is similar to a standard UNet, but it exchanges the 2 stacked + convolutions followed by pooling for encoding with concatenation of a series of + dilated convolutions. The Decoder unit is similar, except pooling is replaced by + upsampling and concatenation with the encoder map. + """ + + def __init__(self, in_channel, n_classes, loss, test_mode): + self.in_channel = in_channel + self.n_classes = n_classes[0] + self.loss = loss + super(UNet3D, self).__init__() + + self.input = self.conv_relu(self.in_channel, 32, 1, 1) + self.one_by_oneconv = self.oneconv(32, self.n_classes) + self.pool = nn.MaxPool3d((1, 2, 2)) + + # self.conv1 = self.conv_relu( + # 32, 32, stride=(1, 2, 2), kernel=(1, 2, 2), norm=False + # ) + # self.conv2 = self.conv_relu( + # 64, 64, stride=(1, 2, 2), kernel=(1, 2, 2), norm=False + # ) + # self.conv3 = self.conv_relu( + # 128, 128, stride=(1, 2, 2), kernel=(1, 2, 2), norm=False + # ) + # self.conv4 = self.conv_relu( + # 256, 256, stride=(1, 2, 2), kernel=(1, 2, 2), norm=False + # ) + + self.upsample_64 = self.upsample(64) + self.upsample_128 = self.upsample(128) + self.upsample_256 = self.upsample(256) + self.upsample_512 = self.upsample(512) + + self.down_32_1 = self.conv_relu(32, 32, padding=1, dilation=1) + self.down_32_3 = self.conv_relu(32, 16, padding=3, dilation=3) + self.down_32_6 = self.conv_relu(16, 8, padding=6, dilation=6) + self.down_32_9 = self.conv_relu(8, 4, padding=9, dilation=9) + self.down_32_12 = self.conv_relu(4, 4, padding=12, dilation=12) + self.down_block1_fns = [ + # self.conv1, + self.pool, + self.down_32_1, + self.down_32_3, + self.down_32_6, + self.down_32_9, + self.down_32_12, + ] + + self.down_64_1 = self.conv_relu(64, 64, padding=1, dilation=1) + self.down_64_3 = self.conv_relu(64, 32, padding=3, dilation=3) + self.down_64_6 = self.conv_relu(32, 16, padding=6, dilation=6) + self.down_64_9 = self.conv_relu(16, 8, padding=9, dilation=9) + self.down_64_12 = self.conv_relu(8, 8, padding=12, dilation=12) + self.down_block2_fns = [ + # self.conv2, + self.pool, + self.down_64_1, + self.down_64_3, + self.down_64_6, + self.down_64_9, + self.down_64_12, + ] + + self.down_128_1 = self.conv_relu(128, 128, padding=1, dilation=1) + self.down_128_3 = self.conv_relu(128, 64, padding=3, dilation=3) + self.down_128_6 = self.conv_relu(64, 32, padding=6, dilation=6) + self.down_128_9 = self.conv_relu(32, 16, padding=9, dilation=9) + self.down_128_12 = self.conv_relu(16, 16, padding=12, dilation=12) + self.down_block3_fns = [ + # self.conv3, + self.pool, + self.down_128_1, + self.down_128_3, + self.down_128_6, + self.down_128_9, + self.down_128_12, + ] + + self.down_256_1 = self.conv_relu(256, 256, padding=1, dilation=1) + self.down_256_3 = self.conv_relu(256, 128, padding=3, dilation=3) + self.down_256_6 = self.conv_relu(128, 64, padding=6, dilation=6) + self.down_256_9 = self.conv_relu(64, 32, padding=9, dilation=9) + self.down_256_12 = self.conv_relu(32, 32, padding=12, dilation=12) + self.down_block4_fns = [ + # self.conv4, + self.pool, + self.down_256_1, + self.down_256_3, + self.down_256_6, + self.down_256_9, + self.down_256_12, + ] + + self.up_768_1 = self.conv_relu(256 + 512, 128, padding=1, dilation=1) + self.up_512_3 = self.conv_relu(128, 64, padding=3, dilation=3) + self.up_512_6 = self.conv_relu(64, 32, padding=6, dilation=6) + self.up_512_9 = self.conv_relu(32, 16, padding=9, dilation=9) + self.up_512_12 = self.conv_relu(16, 16, padding=12, dilation=12) + self.up_block1_fns = [ + self.upsample_512, + self.up_768_1, + self.up_512_3, + self.up_512_6, + self.up_512_9, + self.up_512_12, + ] + + self.up_384_1 = self.conv_relu(128 + 256, 64, padding=1, dilation=1) + self.up_256_3 = self.conv_relu(64, 32, padding=3, dilation=3) + self.up_256_6 = self.conv_relu(32, 16, padding=6, dilation=6) + self.up_256_9 = self.conv_relu(16, 8, padding=9, dilation=9) + self.up_256_12 = self.conv_relu(8, 8, padding=12, dilation=12) + self.up_block2_fns = [ + self.upsample_256, + self.up_384_1, + self.up_256_3, + self.up_256_6, + self.up_256_9, + self.up_256_12, + ] + + self.up_192_1 = self.conv_relu(64 + 128, 32, padding=1, dilation=1) + self.up_128_3 = self.conv_relu(32, 16, padding=3, dilation=3) + self.up_128_6 = self.conv_relu(16, 8, padding=6, dilation=6) + self.up_128_9 = self.conv_relu(8, 4, padding=9, dilation=9) + self.up_128_12 = self.conv_relu(4, 4, padding=12, dilation=12) + self.up_block3_fns = [ + self.upsample_128, + self.up_192_1, + self.up_128_3, + self.up_128_6, + self.up_128_9, + self.up_128_12, + ] + + self.up_64_1 = self.conv_relu(64, 16, padding=1, dilation=1) + self.up_64_3 = self.conv_relu(16, 8, padding=3, dilation=3) + self.up_64_6 = self.conv_relu(8, 4, padding=6, dilation=6) + self.up_64_9 = self.conv_relu(4, 2, padding=9, dilation=9) + self.up_64_12 = self.conv_relu(2, 2, padding=12, dilation=12) + self.up_block4_fns = [ + self.upsample_64, + self.up_64_1, + self.up_64_3, + self.up_64_6, + self.up_64_9, + self.up_64_12, + ] + + def upsample(self, in_channels): + return nn.ConvTranspose3d( + in_channels, + in_channels, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + + def oneconv(self, in_channels, out_channels): + return nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + 1, + stride=1, + padding=0, + bias=True, + ), + ) + + def conv_relu( + self, + in_channels, + out_channels, + padding=0, + dilation=1, + stride=1, + kernel=3, + norm=True, + ): + if norm: + return nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel, + stride=stride, + padding=padding, + bias=True, + dilation=dilation, + ), + nn.InstanceNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + return nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel, + stride=stride, + padding=padding, + bias=True, + dilation=dilation, + ), + # nn.ReLU(), + ) + + def down_block(self, x, fns): + """ + x: input tensor + fns: list of functions to be sequentially applied to x + """ + outputs = [x] + # output of previous step as input to next step + for fn in fns: + outputs.append(fn(outputs[-1])) + # concat all output except original input + return torch.cat(outputs[-5:], dim=1) + + def up_block(self, x, cat, fns): + upsample = fns[0](x) # don't apply upsample later in fns loop + if cat is not None: + upsample = torch.cat((upsample, cat), dim=1) + outputs = [upsample] + for fn in fns[1:]: + outputs.append(fn(outputs[-1])) + return torch.cat(outputs[-5:], dim=1) + + def forward(self, x): + x = self.input(x) # 1 channel -> 32 channels + down1 = self.down_block(x, self.down_block1_fns) # 32 ch -> 64 ch, pool xy + down2 = self.down_block(down1, self.down_block2_fns) # 64->128ch, pool xy + down3 = self.down_block(down2, self.down_block3_fns) # 128->56ch, pool xy + down4 = self.down_block(down3, self.down_block4_fns) # 256->512ch, pool xy + + up1 = self.up_block(down4, down3, self.up_block1_fns) # 512->256, upsample xy + up2 = self.up_block(up1, down2, self.up_block2_fns) # 256->128, upsample xy + up3 = self.up_block(up2, down1, self.up_block3_fns) # 128->64, upsample xy + up4 = self.up_block(up3, None, self.up_block4_fns) # 64->32 upsample xy + out = self.one_by_oneconv(up4) # 32->2ch + + if self.loss == "Aux": + return [out] + return out diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy.py b/aicsmlsegment/NetworkArchitecture/unet_xy.py new file mode 100644 index 0000000..793a958 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy.py @@ -0,0 +1,239 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet3D(nn.Module): + def __init__(self, in_channel, n_classes, test_mode, batchnorm_flag=True): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + self.ec1 = self.encoder( + self.in_channel, 32, batchnorm=batchnorm_flag + ) # in --> 64 + self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) # 64 --> 128 + self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) # 128 --> 256 + self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) # 256 -->512 + + self.pool1 = nn.MaxPool3d((1, 2, 2)) + self.pool2 = nn.MaxPool3d((1, 2, 2)) + self.pool3 = nn.MaxPool3d((1, 2, 2)) + + self.up3 = nn.ConvTranspose3d( + 512, + 512, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up2 = nn.ConvTranspose3d( + 256, + 256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up1 = nn.ConvTranspose3d( + 128, + 128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + + self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) + self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) + self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) + + self.dc0 = nn.Conv3d(64, n_classes[0], 1) + + self.up2a = nn.ConvTranspose3d( + 256, + n_classes[2], + kernel_size=(1, 8, 8), + stride=(1, 4, 4), + padding=0, + output_padding=0, + bias=True, + ) + self.up1a = nn.ConvTranspose3d( + 128, + n_classes[1], + kernel_size=(1, 4, 4), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + + self.conv2a = nn.Conv3d( + n_classes[2], n_classes[2], 3, stride=1, padding=0, bias=True + ) + self.conv1a = nn.Conv3d( + n_classes[1], n_classes[1], 3, stride=1, padding=0, bias=True + ) + + self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) + self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) + + self.softmax = F.log_softmax # nn.LogSoftmax(1) + + self.final_activation = nn.Softmax(dim=1) + + self.numClass = n_classes[0] + self.numClass1 = n_classes[1] + self.numClass2 = n_classes[2] + # self.numClass_combine = n_classes[3] + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + + down1 = self.ec1(x) + x1 = self.pool1(down1) + down2 = self.ec2(x1) + x2 = self.pool2(down2) + down3 = self.ec3(x2) + x3 = self.pool3(down3) + + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), F.pad(down3, (-4, -4, -4, -4, -2, -2))), 1) + u2 = self.dc3(d3) + + d2 = torch.cat((self.up2(u2), F.pad(down2, (-16, -16, -16, -16, -6, -6))), 1) + u1 = self.dc2(d2) + + d1 = torch.cat((self.up1(u1), F.pad(down1, (-40, -40, -40, -40, -10, -10))), 1) + u0 = self.dc1(d1) + + p0 = self.dc0(u0) + if self.test_mode: + return [p0] + + p1a = F.pad( + self.predict1a(self.conv1a(self.up1a(u1))), (-2, -2, -2, -2, -1, -1) + ) + p2a = F.pad( + self.predict2a(self.conv2a(self.up2a(u2))), (-7, -7, -7, -7, -3, -3) + ) + + return [p0, p1a, p2a] diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_original.py b/aicsmlsegment/NetworkArchitecture/unet_xy_original.py new file mode 100644 index 0000000..f860c69 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_original.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet3D(nn.Module): + def __init__(self, in_channel, n_classes, test_mode, batchnorm_flag=True): + self.in_channel = in_channel + self.n_classes = n_classes + super(UNet3D, self).__init__() + + self.ec1 = self.encoder(self.in_channel, 32, batchnorm=batchnorm_flag) + self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) + self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) + self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) + + self.pool1 = nn.MaxPool3d(2) + self.pool2 = nn.MaxPool3d(2) + self.pool3 = nn.MaxPool3d(2) + + self.up3 = nn.ConvTranspose3d( + 512, 512, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True + ) + self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) + self.up2 = nn.ConvTranspose3d( + 256, 256, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True + ) + self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) + self.up1 = nn.ConvTranspose3d( + 128, 128, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True + ) + self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) + + self.dc0 = nn.Conv3d(64, n_classes, 1) + + self.numClass = n_classes + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm2d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm2d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm2d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm2d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + + down1 = self.ec1(x) + x1 = self.pool1(down1) + down2 = self.ec2(x1) + x2 = self.pool2(down2) + down3 = self.ec3(x2) + x3 = self.pool3(down3) + + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), F.pad(down3, (-4, -4, -4, -4, -4, -4))), 1) + u2 = self.dc3(d3) + d2 = torch.cat((self.up2(u2), F.pad(down2, (-16, -16, -16, -16, -16, -16))), 1) + u1 = self.dc2(d2) + d1 = torch.cat((self.up1(u1), F.pad(down1, (-40, -40, -40, -40, -40, -40))), 1) + u0 = self.dc1(d1) + out = self.dc0(u0) + + return [out] diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom.py new file mode 100644 index 0000000..0e7c583 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom.py @@ -0,0 +1,269 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet3D(nn.Module): + """ + unet_xy_zoom, see Figure 20 in https://www.biorxiv.org/content/10.1101/491035v2 + """ + + def __init__( + self, in_channel, n_classes, down_ratio, batchnorm_flag=True, test_mode=True + ): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + k = down_ratio + + self.ec1 = self.encoder( + self.in_channel, 32, batchnorm=batchnorm_flag + ) # in --> 64 + self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) # 64 --> 128 + self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) # 128 --> 256 + self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) # 256 -->512 + + self.pool0 = nn.MaxPool3d((1, k, k)) + self.pool1 = nn.MaxPool3d((1, 2, 2)) + self.pool2 = nn.MaxPool3d((1, 2, 2)) + self.pool3 = nn.MaxPool3d((1, 2, 2)) + + self.up3 = nn.ConvTranspose3d( + 512, + 512, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up2 = nn.ConvTranspose3d( + 256, + 256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up1 = nn.ConvTranspose3d( + 128, + 128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up0 = nn.ConvTranspose3d( + 64, + 64, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + output_padding=0, + bias=True, + ) + + self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) + self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) + self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) + self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag) + + self.predict0 = nn.Conv3d(64, n_classes[0], 1) + + self.up1a = nn.ConvTranspose3d( + 128, + n_classes[1], + kernel_size=(1, 2 * k, 2 * k), + stride=(1, 2 * k, 2 * k), + padding=0, + output_padding=0, + bias=True, + ) + self.up2a = nn.ConvTranspose3d( + 256, + n_classes[2], + kernel_size=(1, 4 * k, 4 * k), + stride=(1, 4 * k, 4 * k), + padding=0, + output_padding=0, + bias=True, + ) + + self.conv2a = nn.Conv3d( + n_classes[2], n_classes[2], 3, stride=1, padding=0, bias=True + ) + self.conv1a = nn.Conv3d( + n_classes[1], n_classes[1], 3, stride=1, padding=0, bias=True + ) + + self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) + self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) + + self.softmax = F.log_softmax # nn.LogSoftmax(1) + + self.final_activation = nn.Softmax(dim=1) + + self.numClass = n_classes[0] + self.numClass1 = n_classes[1] + self.numClass2 = n_classes[2] + + self.k = k + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + + k = self.k + + x0 = self.pool0(x) + + down1 = self.ec1(x0) + x1 = self.pool1(down1) + down2 = self.ec2(x1) + x2 = self.pool2(down2) + down3 = self.ec3(x2) + x3 = self.pool3(down3) + + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), F.pad(down3, (-4, -4, -4, -4, -2, -2))), 1) + u2 = self.dc3(d3) + + d2 = torch.cat((self.up2(u2), F.pad(down2, (-16, -16, -16, -16, -6, -6))), 1) + u1 = self.dc2(d2) + + d1 = torch.cat((self.up1(u1), F.pad(down1, (-40, -40, -40, -40, -10, -10))), 1) + u0 = self.dc1(d1) + + d0 = self.up0(u0) + + predict00 = self.predict0(self.dc0(d0)) + + if self.test_mode: + return [predict00] + + p1a = F.pad( + self.predict1a(self.conv1a(self.up1a(u1))), + (-2 * k - 1, -2 * k - 1, -2 * k - 1, -2 * k - 1, -3, -3), + ) + + p2a = F.pad( + self.predict2a(self.conv2a(self.up2a(u2))), + (-6 * k - 1, -6 * k - 1, -6 * k - 1, -6 * k - 1, -5, -5), + ) # fix +5 + + return [predict00, p1a, p2a] diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py new file mode 100644 index 0000000..2adb776 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py @@ -0,0 +1,262 @@ +import torch +import torch.nn as nn + + +class UNet3D(nn.Module): + def __init__( + self, in_channel, n_classes, down_ratio, test_mode=True, batchnorm_flag=True + ): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + k = down_ratio + + self.ec1 = self.encoder( + self.in_channel, 32, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # in --> 64 + self.ec2 = self.encoder( + 64, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 64 --> 128 + self.ec3 = self.encoder( + 128, 128, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 128 --> 256 + self.ec4 = self.encoder( + 256, 256, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 256 -->512 + + self.pool0 = nn.MaxPool3d((1, k, k)) + self.pool1 = nn.MaxPool3d((1, 2, 2)) + self.pool2 = nn.MaxPool3d((1, 2, 2)) + self.pool3 = nn.MaxPool3d((1, 2, 2)) + + self.up3 = nn.ConvTranspose3d( + 512, + 512, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up2 = nn.ConvTranspose3d( + 256, + 256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up1 = nn.ConvTranspose3d( + 128, + 128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up0 = nn.ConvTranspose3d( + 64, + 64, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + output_padding=0, + bias=True, + ) + + self.dc3 = self.decoder( + 256 + 512, 256, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc2 = self.decoder( + 128 + 256, 128, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc1 = self.decoder( + 64 + 128, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1)) + + self.predict0 = nn.Conv3d(64, n_classes[0], 1) + + self.up1a = nn.ConvTranspose3d( + 128, + n_classes[1], + kernel_size=(1, 2 * k, 2 * k), + stride=(1, 2 * k, 2 * k), + padding=0, + output_padding=0, + bias=True, + ) + self.up2a = nn.ConvTranspose3d( + 256, + n_classes[2], + kernel_size=(1, 4 * k, 4 * k), + stride=(1, 4 * k, 4 * k), + padding=0, + output_padding=0, + bias=True, + ) + + self.conv2a = nn.Conv3d( + n_classes[2], n_classes[2], 3, stride=1, padding=(1, 1, 1), bias=True + ) + self.conv1a = nn.Conv3d( + n_classes[1], n_classes[1], 3, stride=1, padding=(1, 1, 1), bias=True + ) + + self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) + self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) + + self.numClass = n_classes[0] + self.numClass1 = n_classes[1] + self.numClass2 = n_classes[2] + + # a property will be used when calling this model in model zoo + self.final_activation = nn.Softmax(dim=1) + + self.k = k + # self.numClass_combine = n_classes[3] + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + x0 = self.pool0(x) + down1 = self.ec1(x0) + x1 = self.pool1(down1) + down2 = self.ec2(x1) + + x2 = self.pool2(down2) + down3 = self.ec3(x2) + x3 = self.pool3(down3) + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), down3), 1) + u2 = self.dc3(d3) + d2 = torch.cat((self.up2(u2), down2), 1) + u1 = self.dc2(d2) + d1 = torch.cat((self.up1(u1), down1), 1) + u0 = self.dc1(d1) + + d0 = self.up0(u0) + + predict00 = self.predict0(self.dc0(d0)) + # predict00 = F.pad(predict00, (-30, -30, -30, -30, -5, -5)) + if self.test_mode: + return [predict00] + + p1a = self.predict1a(self.conv1a(self.up1a(u1))) + p2a = self.predict2a(self.conv2a(self.up2a(u2))) # fix +5 + return [predict00, p1a, p2a] diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_nopadz_stridedconv.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_nopadz_stridedconv.py new file mode 100644 index 0000000..41c33ab --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_nopadz_stridedconv.py @@ -0,0 +1,289 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet3D(nn.Module): + def __init__( + self, in_channel, n_classes, down_ratio, test_mode, batchnorm_flag=True + ): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + k = down_ratio + + self.ec1 = self.encoder( + self.in_channel, 32, batchnorm=batchnorm_flag, padding=(0, 1, 1) + ) # in --> 64 + self.ec2 = self.encoder( + 64, 64, batchnorm=batchnorm_flag, padding=(0, 1, 1) + ) # 64 --> 128 + self.ec3 = self.encoder( + 128, 128, batchnorm=batchnorm_flag, padding=(0, 1, 1) + ) # 128 --> 256 + self.ec4 = self.encoder( + 256, 256, batchnorm=batchnorm_flag, padding=(0, 1, 1) + ) # 256 -->512 + + self.conv0 = nn.Conv3d( + in_channels=1, + out_channels=1, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + ) + self.conv1 = nn.Conv3d( + in_channels=64, + out_channels=64, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + self.conv2 = nn.Conv3d( + in_channels=128, + out_channels=128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + self.conv3 = nn.Conv3d( + in_channels=256, + out_channels=256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + + self.up3 = nn.ConvTranspose3d( + 512, + 512, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up2 = nn.ConvTranspose3d( + 256, + 256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up1 = nn.ConvTranspose3d( + 128, + 128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up0 = nn.ConvTranspose3d( + 64, + 64, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + output_padding=0, + bias=True, + ) + + self.dc3 = self.decoder( + 256 + 512, 256, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc2 = self.decoder( + 128 + 256, 128, batchnorm=batchnorm_flag, padding=(0, 1, 1) + ) + self.dc1 = self.decoder( + 64 + 128, 64, batchnorm=batchnorm_flag, padding=(0, 1, 1) + ) + self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag, padding=(0, 1, 1)) + + self.predict0 = nn.Conv3d(64, n_classes[0], 1) + + self.up1a = nn.ConvTranspose3d( + 128, + n_classes[1], + kernel_size=(1, 2 * k, 2 * k), + stride=(1, 2 * k, 2 * k), + padding=0, + output_padding=0, + bias=True, + ) + self.up2a = nn.ConvTranspose3d( + 256, + n_classes[2], + kernel_size=(1, 4 * k, 4 * k), + stride=(1, 4 * k, 4 * k), + padding=0, + output_padding=0, + bias=True, + ) + + self.conv2a = nn.Conv3d( + n_classes[2], n_classes[2], 3, stride=1, padding=(0, 1, 1), bias=True + ) + self.conv1a = nn.Conv3d( + n_classes[1], n_classes[1], 3, stride=1, padding=(0, 1, 1), bias=True + ) + + self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) + self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) + + self.softmax = F.log_softmax # nn.LogSoftmax(1) + + self.final_activation = nn.Softmax(dim=1) + + self.numClass = n_classes[0] + self.numClass1 = n_classes[1] + self.numClass2 = n_classes[2] + + self.k = k + # self.numClass_combine = n_classes[3] + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + x0 = self.conv0(x) + down1 = self.ec1(x0) + x1 = self.conv1(down1) + down2 = self.ec2(x1) + x2 = self.conv2(down2) + down3 = self.ec3(x2) + x3 = self.conv3(down3) + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), F.pad(down3, (0, 0, 0, 0, -2, -2))), 1) + u2 = self.dc3(d3) + d2 = torch.cat((self.up2(u2), F.pad(down2, (0, 0, 0, 0, -4, -4))), 1) + u1 = self.dc2(d2) + d1 = torch.cat((self.up1(u1), F.pad(down1, (0, 0, 0, 0, -8, -8))), 1) + u0 = self.dc1(d1) + + d0 = self.up0(u0) + + predict00 = self.predict0(self.dc0(d0)) + if self.test_mode: + return [predict00] + + p1a = F.pad(self.predict1a(self.conv1a(self.up1a(u1))), (0, 0, 0, 0, -3, -3)) + + p2a = F.pad( + self.predict2a(self.conv2a(self.up2a(u2))), (0, 0, 0, 0, -5, -5) + ) # fix +5 + return [predict00, p1a, p2a] diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_single.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_single.py new file mode 100644 index 0000000..140e2bc --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_single.py @@ -0,0 +1,225 @@ +import torch +import torch.nn as nn + + +class UNet3D(nn.Module): + def __init__( + self, in_channel, n_classes, down_ratio, test_mode=True, batchnorm_flag=True + ): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + k = down_ratio + + self.ec1 = self.encoder( + self.in_channel, 32, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # in --> 64 + self.ec2 = self.encoder( + 64, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 64 --> 128 + self.ec3 = self.encoder( + 128, 128, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 128 --> 256 + self.ec4 = self.encoder( + 256, 256, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 256 -->512 + + self.pool0 = nn.MaxPool3d((1, k, k)) + self.pool1 = nn.MaxPool3d((1, 2, 2)) + self.pool2 = nn.MaxPool3d((1, 2, 2)) + self.pool3 = nn.MaxPool3d((1, 2, 2)) + + self.up3 = nn.ConvTranspose3d( + 512, + 512, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up2 = nn.ConvTranspose3d( + 256, + 256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up1 = nn.ConvTranspose3d( + 128, + 128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up0 = nn.ConvTranspose3d( + 64, + 64, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + output_padding=0, + bias=True, + ) + + self.dc3 = self.decoder( + 256 + 512, 256, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc2 = self.decoder( + 128 + 256, 128, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc1 = self.decoder( + 64 + 128, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1)) + + self.predict0 = nn.Conv3d(64, n_classes, 1) + + self.numClass = n_classes + + # a property will be used when calling this model in model zoo + self.final_activation = nn.Softmax(dim=1) + + self.k = k + # self.numClass_combine = n_classes[3] + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(2 * out_channels), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + x0 = self.pool0(x) + down1 = self.ec1(x0) + x1 = self.pool1(down1) + down2 = self.ec2(x1) + + x2 = self.pool2(down2) + down3 = self.ec3(x2) + x3 = self.pool3(down3) + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), down3), 1) + u2 = self.dc3(d3) + d2 = torch.cat((self.up2(u2), down2), 1) + u1 = self.dc2(d2) + d1 = torch.cat((self.up1(u1), down1), 1) + u0 = self.dc1(d1) + + d0 = self.up0(u0) + + predict00 = self.predict0(self.dc0(d0)) + return predict00 diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_stridedconv.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_stridedconv.py new file mode 100644 index 0000000..2543bc3 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad_stridedconv.py @@ -0,0 +1,289 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet3D(nn.Module): + def __init__( + self, in_channel, n_classes, down_ratio, test_mode, batchnorm_flag=True + ): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + k = down_ratio + + self.ec1 = self.encoder( + self.in_channel, 32, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # in --> 64 + self.ec2 = self.encoder( + 64, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 64 --> 128 + self.ec3 = self.encoder( + 128, 128, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 128 --> 256 + self.ec4 = self.encoder( + 256, 256, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) # 256 -->512 + + self.conv0 = nn.Conv3d( + in_channels=1, + out_channels=1, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + ) + self.conv1 = nn.Conv3d( + in_channels=64, + out_channels=64, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + self.conv2 = nn.Conv3d( + in_channels=128, + out_channels=128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + self.conv3 = nn.Conv3d( + in_channels=256, + out_channels=256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + + self.up3 = nn.ConvTranspose3d( + 512, + 512, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up2 = nn.ConvTranspose3d( + 256, + 256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up1 = nn.ConvTranspose3d( + 128, + 128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up0 = nn.ConvTranspose3d( + 64, + 64, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + output_padding=0, + bias=True, + ) + + self.dc3 = self.decoder( + 256 + 512, 256, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc2 = self.decoder( + 128 + 256, 128, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc1 = self.decoder( + 64 + 128, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1) + ) + self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag, padding=(1, 1, 1)) + + self.predict0 = nn.Conv3d(64, n_classes[0], 1) + + self.up1a = nn.ConvTranspose3d( + 128, + n_classes[1], + kernel_size=(1, 2 * k, 2 * k), + stride=(1, 2 * k, 2 * k), + padding=0, + output_padding=0, + bias=True, + ) + self.up2a = nn.ConvTranspose3d( + 256, + n_classes[2], + kernel_size=(1, 4 * k, 4 * k), + stride=(1, 4 * k, 4 * k), + padding=0, + output_padding=0, + bias=True, + ) + + self.conv2a = nn.Conv3d( + n_classes[2], n_classes[2], 3, stride=1, padding=(1, 1, 1), bias=True + ) + self.conv1a = nn.Conv3d( + n_classes[1], n_classes[1], 3, stride=1, padding=(1, 1, 1), bias=True + ) + + self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) + self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) + + self.softmax = F.log_softmax # nn.LogSoftmax(1) + + self.final_activation = nn.Softmax(dim=1) + + self.numClass = n_classes[0] + self.numClass1 = n_classes[1] + self.numClass2 = n_classes[2] + + self.k = k + # self.numClass_combine = n_classes[3] + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + x0 = self.conv0(x) + down1 = self.ec1(x0) + x1 = self.conv1(down1) + down2 = self.ec2(x1) + + x2 = self.conv2(down2) + down3 = self.ec3(x2) + x3 = self.conv3(down3) + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), down3), 1) + u2 = self.dc3(d3) + d2 = torch.cat((self.up2(u2), down2), 1) + u1 = self.dc2(d2) + d1 = torch.cat((self.up1(u1), down1), 1) + u0 = self.dc1(d1) + + d0 = self.up0(u0) + + predict00 = self.predict0(self.dc0(d0)) + + if self.test_mode: + return [predict00] + + p1a = self.predict1a(self.conv1a(self.up1a(u1))) + + p2a = self.predict2a(self.conv2a(self.up2a(u2))) # fix +5 + return [predict00, p1a, p2a] diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_dilated.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_dilated.py new file mode 100644 index 0000000..f134ee8 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_dilated.py @@ -0,0 +1,214 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet3D(nn.Module): + def __init__( + self, in_channel, n_classes, down_ratio, batchnorm_flag=True, test_mode=False + ): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + k = down_ratio + + self.ec1 = self.encoder( + self.in_channel, + 8, + batchnorm=batchnorm_flag, + padding=(1, k, k), + dilation=(1, k, k), + ) # in --> 16 + self.ec2 = self.encoder( + 16, 16, batchnorm=batchnorm_flag, padding=(1, 2, 2), dilation=(1, 2, 2) + ) # 16 --> 32 + self.ec3 = self.encoder( + 32, 32, batchnorm=batchnorm_flag, padding=(1, 2, 2), dilation=(1, 2, 2) + ) # 32 --> 64 + self.ec4 = self.encoder( + 64, 64, batchnorm=batchnorm_flag, padding=(1, 2, 2), dilation=(1, 2, 2) + ) # 64 -->128 + + self.dc3 = self.decoder( + 64 + 128, + 64, + batchnorm=batchnorm_flag, + padding=(1, 2, 2), + dilation=(1, 2, 2), + ) + self.dc2 = self.decoder( + 32 + 64, 32, batchnorm=batchnorm_flag, padding=(1, 2, 2), dilation=(1, 2, 2) + ) + self.dc1 = self.decoder( + 16 + 32, 16, batchnorm=batchnorm_flag, padding=(1, 2, 2), dilation=(1, 2, 2) + ) + self.dc0 = self.decoder( + 16, 16, batchnorm=batchnorm_flag, padding=(1, k, k), dilation=(1, k, k) + ) + + self.predict0 = nn.Conv3d(16, n_classes[0], 1) + + self.conv2a = nn.Conv3d( + 64, n_classes[2], 3, stride=1, padding=(1, 1, 1), bias=True + ) + self.conv1a = nn.Conv3d( + 32, n_classes[1], 3, stride=1, padding=(1, 1, 1), bias=True + ) + + self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) + self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) + + self.softmax = F.log_softmax # nn.LogSoftmax(1) + + self.final_activation = nn.Softmax(dim=1) + + self.numClass = n_classes[0] + self.numClass1 = n_classes[1] + self.numClass2 = n_classes[2] + + self.k = k + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + dilation=1, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.BatchNorm3d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + dilation=1, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + dilation=dilation, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + down1 = self.ec1(x) + down2 = self.ec2(down1) + down3 = self.ec3(down2) + u3 = self.ec4(down3) + + d3 = torch.cat((u3, down3), 1) + u2 = self.dc3(d3) + d2 = torch.cat((u2, down2), 1) + u1 = self.dc2(d2) + d1 = torch.cat((u1, down1), 1) + u0 = self.dc1(d1) + + predict00 = self.predict0(self.dc0(u0)) + if self.test_mode: + return [predict00] + + p1a = self.predict1a(self.conv1a(u1)) + p2a = self.predict2a(self.conv2a(u2)) + return [predict00, p1a, p2a] diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_stridedconv.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_stridedconv.py new file mode 100644 index 0000000..241f18d --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_stridedconv.py @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet3D(nn.Module): + """ + unet_xy_zoom, see Figure 20 in https://www.biorxiv.org/content/10.1101/491035v2 + """ + + def __init__( + self, in_channel, n_classes, down_ratio, test_mode, batchnorm_flag=True + ): + self.in_channel = in_channel + self.n_classes = n_classes + self.test_mode = test_mode + super(UNet3D, self).__init__() + + k = down_ratio + + self.ec1 = self.encoder( + self.in_channel, 32, batchnorm=batchnorm_flag + ) # in --> 64 + self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) # 64 --> 128 + self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) # 128 --> 256 + self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) # 256 -->512 + + self.conv0 = nn.Conv3d( + in_channels=1, + out_channels=1, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + ) + self.conv1 = nn.Conv3d( + in_channels=64, + out_channels=64, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + self.conv2 = nn.Conv3d( + in_channels=128, + out_channels=128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + self.conv3 = nn.Conv3d( + in_channels=256, + out_channels=256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + ) + + self.up3 = nn.ConvTranspose3d( + 512, + 512, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up2 = nn.ConvTranspose3d( + 256, + 256, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up1 = nn.ConvTranspose3d( + 128, + 128, + kernel_size=(1, 2, 2), + stride=(1, 2, 2), + padding=0, + output_padding=0, + bias=True, + ) + self.up0 = nn.ConvTranspose3d( + 64, + 64, + kernel_size=(1, k, k), + stride=(1, k, k), + padding=0, + output_padding=0, + bias=True, + ) + + self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) + self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) + self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) + self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag) + + self.predict0 = nn.Conv3d(64, n_classes[0], 1) + + self.up1a = nn.ConvTranspose3d( + 128, + n_classes[1], + kernel_size=(1, 2 * k, 2 * k), + stride=(1, 2 * k, 2 * k), + padding=0, + output_padding=0, + bias=True, + ) + self.up2a = nn.ConvTranspose3d( + 256, + n_classes[2], + kernel_size=(1, 4 * k, 4 * k), + stride=(1, 4 * k, 4 * k), + padding=0, + output_padding=0, + bias=True, + ) + + self.conv2a = nn.Conv3d( + n_classes[2], n_classes[2], 3, stride=1, padding=0, bias=True + ) + self.conv1a = nn.Conv3d( + n_classes[1], n_classes[1], 3, stride=1, padding=0, bias=True + ) + + self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) + self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) + + self.softmax = F.log_softmax # nn.LogSoftmax(1) + + self.final_activation = nn.Softmax(dim=1) + + self.numClass = n_classes[0] + self.numClass1 = n_classes[1] + self.numClass2 = n_classes[2] + + self.k = k + + def encoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(2 * out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + 2 * out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def decoder( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias=True, + batchnorm=False, + ): + if batchnorm: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.BatchNorm3d(out_channels, affine=False), + nn.ReLU(), + ) + else: + layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + nn.Conv3d( + out_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ), + nn.ReLU(), + ) + return layer + + def forward(self, x): + + k = self.k + + x0 = self.conv0(x) + + down1 = self.ec1(x0) + x1 = self.conv1(down1) + down2 = self.ec2(x1) + x2 = self.conv2(down2) + down3 = self.ec3(x2) + x3 = self.conv3(down3) + + u3 = self.ec4(x3) + + d3 = torch.cat((self.up3(u3), F.pad(down3, (-4, -4, -4, -4, -2, -2))), 1) + u2 = self.dc3(d3) + + d2 = torch.cat((self.up2(u2), F.pad(down2, (-16, -16, -16, -16, -6, -6))), 1) + u1 = self.dc2(d2) + + d1 = torch.cat((self.up1(u1), F.pad(down1, (-40, -40, -40, -40, -10, -10))), 1) + u0 = self.dc1(d1) + + d0 = self.up0(u0) + + predict00 = self.predict0(self.dc0(d0)) + if self.test_mode: + return [predict00] + + p1a = F.pad( + self.predict1a(self.conv1a(self.up1a(u1))), + (-2 * k - 1, -2 * k - 1, -2 * k - 1, -2 * k - 1, -3, -3), + ) + p2a = F.pad( + self.predict2a(self.conv2a(self.up2a(u2))), + (-6 * k - 1, -6 * k - 1, -6 * k - 1, -6 * k - 1, -5, -5), + ) # fix +5 + + return [predict00, p1a, p2a] diff --git a/aicsmlsegment/NetworkArchitecture/vnet.py b/aicsmlsegment/NetworkArchitecture/vnet.py new file mode 100644 index 0000000..813ba99 --- /dev/null +++ b/aicsmlsegment/NetworkArchitecture/vnet.py @@ -0,0 +1,101 @@ +# This implementation was adapted from MONAI: +# https://docs.monai.io/en/latest/_modules/monai/networks/nets/vnet.html +# The adapted version allows to use more feature maps + +from monai.networks.nets.vnet import ( + DownTransition, + UpTransition, + OutputTransition, + InputTransition, +) +from typing import Dict, Tuple, Union + +import torch.nn as nn + + +class VNet(nn.Module): + """ + Adapted from https://docs.monai.io/en/latest/_modules/monai/networks/nets/vnet.html + to allow more feature maps than original implementation + + V-Net: `Fully Convolutional Neural Networks for Volumetric Medical Image + Segmentation `_. Original implementation in + MONAI was adapted from `the official Caffe implementation + `_. and `another pytorch implementation + `_. + The model supports 2D or 3D inputs. + + Parameters + ------------- + spatial_dims: + spatial dimension of the input data. Defaults to 3. + in_channels: + number of input channels for the network. Defaults to 1. The value should meet + the condition that ``16 % in_channels == 0``. + out_channels: + number of output channels for the network. Defaults to 1. + act: + activation type in the network. Defaults to ``("elu", {"inplace": True})``. + dropout_prob: + dropout ratio. Defaults to 0.5. Defaults to 3. + dropout_dim: + determine the dimensions of dropout. Defaults to 3. + - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel. + - ``dropout_dim = 2``, randomly zeroes out entire 2D feature maps. + - ``dropout_dim = 3``, Randomly zeroes out entire 3D feature maps. + """ + + def __init__( + self, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 1, + act: Union[Tuple[str, Dict], str] = ("elu", {"inplace": True}), + dropout_prob: float = 0.5, + dropout_dim: int = 3, + ): + super().__init__() + + if spatial_dims not in (2, 3): + raise AssertionError("spatial_dims can only be 2 or 3.") + + self.in_tr = InputTransition(spatial_dims, in_channels, 16, act) + self.down_tr32 = DownTransition(spatial_dims, 16, 1, act) + self.down_tr64 = DownTransition(spatial_dims, 32, 2, act) + self.down_tr128 = DownTransition( + spatial_dims, 64, 3, act, dropout_prob=dropout_prob + ) + self.down_tr256 = DownTransition( + spatial_dims, 128, 2, act, dropout_prob=dropout_prob + ) + self.down_tr512 = DownTransition( + spatial_dims, 256, 2, act, dropout_prob=dropout_prob + ) + self.up_tr512 = UpTransition( + spatial_dims, 512, 512, 2, act, dropout_prob=dropout_prob + ) + self.up_tr256 = UpTransition( + spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob + ) + self.up_tr128 = UpTransition( + spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob + ) + self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act) + self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act) + self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act) + + def forward(self, x): + x = nn.MaxPool3d((1, 3, 3)) # TODO: make as variable + out16 = self.in_tr(x) + out32 = self.down_tr32(out16) + out64 = self.down_tr64(out32) + out128 = self.down_tr128(out64) + out256 = self.down_tr256(out128) + out512 = self.down_tr512(out256) + x = self.up_tr512(out512, out256) + x = self.up_tr256(out256, out128) + x = self.up_tr128(x, out64) + x = self.up_tr64(x, out32) + x = self.up_tr32(x, out16) + x = self.out_tr(x) + return x diff --git a/aicsmlsegment/__init__.py b/aicsmlsegment/__init__.py deleted file mode 100644 index 297e5d4..0000000 --- a/aicsmlsegment/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .version import MODULE_VERSION - - -def get_module_version(): - return MODULE_VERSION - - diff --git a/aicsmlsegment/Net3D/__init__.py b/aicsmlsegment/bin/__init__.py similarity index 100% rename from aicsmlsegment/Net3D/__init__.py rename to aicsmlsegment/bin/__init__.py diff --git a/aicsmlsegment/bin/curator/curator_merging.py b/aicsmlsegment/bin/curator/curator_merging.py index f2b307d..ae72898 100644 --- a/aicsmlsegment/bin/curator/curator_merging.py +++ b/aicsmlsegment/bin/curator/curator_merging.py @@ -5,18 +5,13 @@ import logging import argparse import traceback -import importlib -import pathlib import csv import pandas as pd import numpy as np -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import matplotlib from glob import glob -from random import shuffle -from scipy import stats -from skimage.io import imsave from skimage.draw import line, polygon from aicssegmentation.core.utils import histogram_otsu @@ -24,9 +19,9 @@ from aicsimageio.writers import OmeTiffWriter from aicsmlsegment.utils import input_normalization -matplotlib.use('TkAgg') +matplotlib.use("TkAgg") -#################################################################################################### +####################################################################################### # global settings ignore_img = False flag_done = False @@ -37,108 +32,145 @@ log = logging.getLogger() -logging.basicConfig(level=logging.INFO, - format='[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s') +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s", +) # # Set the default log level for other modules used by this script # logging.getLogger("labkey").setLevel(logging.ERROR) # logging.getLogger("requests").setLevel(logging.WARNING) # logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger("matplotlib").setLevel(logging.INFO) -#################################################################################################### +###################################################################################### + def draw_polygons(event): global pts, draw_img, draw_ax, draw_mask if event.button == 1: - if not (event.ydata == None or event.xdata == None): - pts.append([event.xdata,event.ydata]) - if len(pts)>1: - rr, cc = line(int(round(pts[-1][0])), int(round(pts[-1][1])), int(round(pts[-2][0])), int(round(pts[-2][1])) ) - draw_img[cc,rr,:1]=255 + if not (event.ydata is None or event.xdata is None): + pts.append([event.xdata, event.ydata]) + if len(pts) > 1: + rr, cc = line( + int(round(pts[-1][0])), + int(round(pts[-1][1])), + int(round(pts[-2][0])), + int(round(pts[-2][1])), + ) + draw_img[cc, rr, :1] = 255 draw_ax.set_data(draw_img) plt.draw() elif event.button == 3: - if len(pts)>2: + if len(pts) > 2: # draw polygon pts_array = np.asarray(pts) - rr, cc = polygon(pts_array[:,0], pts_array[:,1]) - draw_img[cc,rr,:1]=255 + rr, cc = polygon(pts_array[:, 0], pts_array[:, 1]) + draw_img[cc, rr, :1] = 255 draw_ax.set_data(draw_img) - draw_mask[cc,rr]=1 + draw_mask[cc, rr] = 1 pts.clear() plt.draw() else: - print('need at least three clicks before finishing annotation') + print("need at least three clicks before finishing annotation") + def quit_mask_drawing(event): global ignore_img - if event.key == 'd': + if event.key == "d": plt.close() - elif event.key == 'b': + elif event.key == "b": ignore_img = True plt.close() - elif event.key == 'q': + elif event.key == "q": exit() def create_merge_mask(raw_img, seg1, seg2, drawing_aim): global pts, draw_img, draw_mask, draw_ax - + offset = 20 - seg1_label = seg1 + offset # make it brighter - seg1_label[seg1_label==offset]=0 - seg1_label = seg1_label.astype(float) * (255/seg1_label.max()) + seg1_label = seg1 + offset # make it brighter + seg1_label[seg1_label == offset] = 0 + seg1_label = seg1_label.astype(float) * (255 / seg1_label.max()) seg1_label = np.round(seg1_label) seg1_label = seg1_label.astype(np.uint8) offset = 25 - seg2_label = seg2 + offset # make it brighter - seg2_label[seg2_label==offset]=0 - seg2_label = seg2_label.astype(float) * (255/seg2_label.max()) + seg2_label = seg2 + offset # make it brighter + seg2_label[seg2_label == offset] = 0 + seg2_label = seg2_label.astype(float) * (255 / seg2_label.max()) seg2_label = np.round(seg2_label) seg2_label = seg2_label.astype(np.uint8) - - bw = seg1>0 - z_profile = np.zeros((bw.shape[0],),dtype=int) + bw = seg1 > 0 + z_profile = np.zeros((bw.shape[0],), dtype=int) for zz in range(bw.shape[0]): - z_profile[zz] = np.count_nonzero(bw[zz,:,:]) - mid_frame = int(round(histogram_otsu(z_profile)*bw.shape[0])) + z_profile[zz] = np.count_nonzero(bw[zz, :, :]) + mid_frame = int(round(histogram_otsu(z_profile) * bw.shape[0])) - img = np.zeros((2*raw_img.shape[1], 3*raw_img.shape[2], 3),dtype=np.uint8) + img = np.zeros((2 * raw_img.shape[1], 3 * raw_img.shape[2], 3), dtype=np.uint8) row_index = 0 for cc in range(3): - img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], :raw_img.shape[2], cc]=np.amax(raw_img, axis=0) - img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], raw_img.shape[2]:2*raw_img.shape[2], cc]=np.amax(seg1_label, axis=0) - img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 2*raw_img.shape[2]:, cc]=np.amax(seg2_label, axis=0) - + img[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + : raw_img.shape[2], + cc, + ] = np.amax(raw_img, axis=0) + img[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + raw_img.shape[2] : 2 * raw_img.shape[2], + cc, + ] = np.amax(seg1_label, axis=0) + img[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 2 * raw_img.shape[2] :, + cc, + ] = np.amax(seg2_label, axis=0) + row_index = 1 for cc in range(3): - img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], :raw_img.shape[2], cc]=raw_img[mid_frame,:,:] - img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], raw_img.shape[2]:2*raw_img.shape[2], cc]=seg1_label[mid_frame,:,:] - img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 2*raw_img.shape[2]:, cc]=seg2_label[mid_frame,:,:] - - draw_mask = np.zeros((img.shape[0],img.shape[1]),dtype=np.uint8) + img[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + : raw_img.shape[2], + cc, + ] = raw_img[mid_frame, :, :] + img[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + raw_img.shape[2] : 2 * raw_img.shape[2], + cc, + ] = seg1_label[mid_frame, :, :] + img[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 2 * raw_img.shape[2] :, + cc, + ] = seg2_label[mid_frame, :, :] + + draw_mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8) draw_img = img.copy() # display the image for good/bad inspection fig = plt.figure() - figManager = plt.get_current_fig_manager() - figManager.full_screen_toggle() + figManager = plt.get_current_fig_manager() + figManager.full_screen_toggle() ax = fig.add_subplot(111) - ax.set_title('Interface for annotating '+drawing_aim+'. Left: raw, Middle: segmentation v1, Right: segmentation v2. \n' \ - +'Top row: max z projection, Bottom row: middle z slice. \n'\ - +'Please draw in the upper left panel \n'\ - +'Left click to add a vertex; Right click to close the current polygon \n' \ - +'Press D to finish annotating mask, Press Q to quit curation (can resume later)') + ax.set_title( + "Interface for annotating " + + drawing_aim + + ". Left: raw, Middle: segmentation v1, Right: segmentation v2. \n" + + "Top row: max z projection, Bottom row: middle z slice. \n" + + "Please draw in the upper left panel \n" + + "Left click to add a vertex; Right click to close the current polygon \n" + + "Press D to finish annotation, Press Q to quit curation (can resume later)" + ) draw_ax = ax.imshow(img) - cid = fig.canvas.mpl_connect('button_press_event', draw_polygons) - cid2 = fig.canvas.mpl_connect('key_press_event', quit_mask_drawing) + cid = fig.canvas.mpl_connect("button_press_event", draw_polygons) + cid2 = fig.canvas.mpl_connect("key_press_event", quit_mask_drawing) plt.show() fig.canvas.mpl_disconnect(cid) fig.canvas.mpl_disconnect(cid2) + class Args(object): """ Use this to define command line arguments and use them later. @@ -153,7 +185,7 @@ class Args(object): def __init__(self, log_cmdline=True): self.debug = False - self.output_dir = '.'+os.sep + self.output_dir = "." + os.sep self.struct_ch = 0 self.xy = 0.108 @@ -172,7 +204,7 @@ def __no_args_print_help(parser): This is used to print out the help if no arguments are provided. Note: - You need to remove it's usage if your script truly doesn't want arguments. - - It exits with 1 because it's an error if this is used in a script with no args. + - It exits with 1 because it's an error if this is used in a script with no args That's a non-interactive use scenario - typically you don't want help there. """ if len(sys.argv) == 1: @@ -182,18 +214,40 @@ def __no_args_print_help(parser): def __parse(self): p = argparse.ArgumentParser() # Add arguments - p.add_argument('--d', '--debug', action='store_true', dest='debug', - help='If set debug log output is enabled') - p.add_argument('--raw_path', required=True, help='path to raw images') - p.add_argument('--data_type', required=True, help='the type of raw images') - p.add_argument('--input_channel', default=0, type=int) - p.add_argument('--seg1_path', required=True, help='path to segmentation results v1') - p.add_argument('--seg2_path', required=True, help='path to segmentation results v2') - p.add_argument('--train_path', required=True, help='path to output training data') - p.add_argument('--mask_path', help='[optional] the output directory for merging masks') - p.add_argument('--ex_mask_path', help='[optional] the output directory for excluding masks') - p.add_argument('--csv_name', required=True, help='the csv file to save the sorting results') - p.add_argument('--Normalization', required=True, type=int, help='the normalization recipe to use') + p.add_argument( + "--d", + "--debug", + action="store_true", + dest="debug", + help="If set debug log output is enabled", + ) + p.add_argument("--raw_path", required=True, help="path to raw images") + p.add_argument("--data_type", required=True, help="the type of raw images") + p.add_argument("--input_channel", default=0, type=int) + p.add_argument( + "--seg1_path", required=True, help="path to segmentation results v1" + ) + p.add_argument( + "--seg2_path", required=True, help="path to segmentation results v2" + ) + p.add_argument( + "--train_path", required=True, help="path to output training data" + ) + p.add_argument( + "--mask_path", help="[optional] the output directory for merging masks" + ) + p.add_argument( + "--ex_mask_path", help="[optional] the output directory for excluding masks" + ) + p.add_argument( + "--csv_name", required=True, help="the csv file to save the sorting results" + ) + p.add_argument( + "--Normalization", + required=True, + type=int, + help="the normalization recipe to use", + ) self.__no_args_print_help(p) p.parse_args(namespace=self) @@ -210,28 +264,42 @@ def show_info(self): ############################################################################### -class Executor(object): +class Executor(object): def __init__(self, args): if os.path.exists(args.csv_name): - print('the csv file for saving sorting results exists, sorting will be resumed') + print("the csv file for saving sorting results exists, sorting resuming") else: - print('no existing csv found, start a new sorting ') - if not args.data_type.startswith('.'): - args.data_type = '.' + args.data_type + print("no existing csv found, start a new sorting ") + if not args.data_type.startswith("."): + args.data_type = "." + args.data_type - filenames = glob(args.raw_path + os.sep +'*' + args.data_type) + filenames = glob(args.raw_path + os.sep + "*" + args.data_type) filenames.sort() - with open(args.csv_name, 'w') as csvfile: - filewriter = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) - filewriter.writerow(['raw','seg1','seg2','score','merging_mask','excluding_mask']) + with open(args.csv_name, "w") as csvfile: + filewriter = csv.writer( + csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL + ) + filewriter.writerow( + ["raw", "seg1", "seg2", "score", "merging_mask", "excluding_mask"] + ) for _, fn in enumerate(filenames): - seg1_fn = args.seg1_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_struct_segmentation.tiff' - seg2_fn = args.seg2_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_struct_segmentation.tiff' + seg1_fn = ( + args.seg1_path + + os.sep + + os.path.basename(fn)[: -1 * len(args.data_type)] + + "_struct_segmentation.tiff" + ) + seg2_fn = ( + args.seg2_path + + os.sep + + os.path.basename(fn)[: -1 * len(args.data_type)] + + "_struct_segmentation.tiff" + ) assert os.path.exists(seg1_fn) assert os.path.exists(seg2_fn) - filewriter.writerow([fn, seg1_fn , seg2_fn , None, None, None]) + filewriter.writerow([fn, seg1_fn, seg2_fn, None, None, None]) def execute(self, args): @@ -241,102 +309,145 @@ def execute(self, args): for index, row in df.iterrows(): - if not np.isnan(row['score']) and (row['score']==1 or row['score']==0): + if not np.isnan(row["score"]) and (row["score"] == 1 or row["score"] == 0): continue - reader = AICSImage(row['raw']) + reader = AICSImage(row["raw"]) struct_img = reader.get_image_data("ZYX", S=0, T=0, C=args.input_channel) - raw_img = (struct_img- struct_img.min() + 1e-8)/(struct_img.max() - struct_img.min() + 1e-8) + raw_img = (struct_img - struct_img.min() + 1e-8) / ( + struct_img.max() - struct_img.min() + 1e-8 + ) raw_img = 255 * raw_img raw_img = raw_img.astype(np.uint8) - seg1 = np.squeeze(imread(row['seg1'])) > 0.01 - seg2 = np.squeeze(imread(row['seg2'])) > 0.01 - - create_merge_mask(raw_img, seg1.astype(np.uint8), seg2.astype(np.uint8), 'merging_mask') + seg1 = np.squeeze(imread(row["seg1"])) > 0.01 + seg2 = np.squeeze(imread(row["seg2"])) > 0.01 + + create_merge_mask( + raw_img, seg1.astype(np.uint8), seg2.astype(np.uint8), "merging_mask" + ) if ignore_img: - df['score'].iloc[index]=0 + df["score"].iloc[index] = 0 else: - df['score'].iloc[index]=1 - - mask_fn = args.mask_path + os.sep + os.path.basename(row['raw'])[:-5] + '_mask.tiff' + df["score"].iloc[index] = 1 + + mask_fn = ( + args.mask_path + + os.sep + + os.path.basename(row["raw"])[:-5] + + "_mask.tiff" + ) crop_mask = np.zeros(seg1.shape, dtype=np.uint8) for zz in range(crop_mask.shape[0]): - crop_mask[zz,:,:] = draw_mask[:crop_mask.shape[1],:crop_mask.shape[2]] + crop_mask[zz, :, :] = draw_mask[ + : crop_mask.shape[1], : crop_mask.shape[2] + ] crop_mask = crop_mask.astype(np.uint8) - crop_mask[crop_mask>0]=255 + crop_mask[crop_mask > 0] = 255 with OmeTiffWriter(mask_fn) as writer: writer.save(crop_mask) - df['merging_mask'].iloc[index]=mask_fn - - need_mask = input('Do you need to add an excluding mask for this image, enter y or n: ') - if need_mask == 'y': - create_merge_mask(raw_img, seg1.astype(np.uint8), seg2.astype(np.uint8), 'excluding mask') - - mask_fn = args.ex_mask_path + os.sep + os.path.basename(row['raw'])[:-5] + '_mask.tiff' + df["merging_mask"].iloc[index] = mask_fn + + need_mask = input( + "Do you need to add an excluding mask for this image, y/n: " + ) + if need_mask == "y": + create_merge_mask( + raw_img, + seg1.astype(np.uint8), + seg2.astype(np.uint8), + "excluding mask", + ) + + mask_fn = ( + args.ex_mask_path + + os.sep + + os.path.basename(row["raw"])[:-5] + + "_mask.tiff" + ) crop_mask = np.zeros(seg1.shape, dtype=np.uint8) for zz in range(crop_mask.shape[0]): - crop_mask[zz,:,:] = draw_mask[:crop_mask.shape[1],:crop_mask.shape[2]] + crop_mask[zz, :, :] = draw_mask[ + : crop_mask.shape[1], : crop_mask.shape[2] + ] crop_mask = crop_mask.astype(np.uint8) - crop_mask[crop_mask>0]=255 + crop_mask[crop_mask > 0] = 255 with OmeTiffWriter(mask_fn) as writer: writer.save(crop_mask) - df['excluding_mask'].iloc[index]=mask_fn - + df["excluding_mask"].iloc[index] = mask_fn df.to_csv(args.csv_name, index=False) - ######################################### # generate training data: # (we want to do this step after "sorting" - # (is mainly because we want to get the sorting + # (is mainly because we want to get the sorting # step as smooth as possible, even though # this may waster i/o time on reloading images) # ####################################### - print('finish merging, start building the training data ...') - existing_files = glob(args.train_path+os.sep+'img_*.ome.tif') + print("finish merging, start building the training data ...") + existing_files = glob(args.train_path + os.sep + "img_*.ome.tif") print(len(existing_files)) - training_data_count = len(existing_files)//3 + training_data_count = len(existing_files) // 3 for index, row in df.iterrows(): - if row['score']==1: + if row["score"] == 1: training_data_count += 1 # load raw image - reader = AICSImage(row['raw']) - img = reader.get_image_data("CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32) + reader = AICSImage(row["raw"]) + img = reader.get_image_data( + "CZYX", S=0, T=0, C=[args.input_channel] + ).astype(np.float32) struct_img = input_normalization(img, args) - struct_img= struct_img[0,:,:,:] + struct_img = struct_img[0, :, :, :] - seg1 = np.squeeze(imread(row['seg1'])) > 0.01 - seg2 = np.squeeze(imread(row['seg2'])) > 0.01 + seg1 = np.squeeze(imread(row["seg1"])) > 0.01 + seg2 = np.squeeze(imread(row["seg2"])) > 0.01 + + if os.path.isfile(str(row["merging_mask"])): + mask = np.squeeze(imread(row["merging_mask"])) + seg1[mask > 0] = 0 + seg2[mask == 0] = 0 + seg1 = np.logical_or(seg1, seg2) - if os.path.isfile(str(row['merging_mask'])): - mask = np.squeeze(imread(row['merging_mask'])) - seg1[mask>0]=0 - seg2[mask==0]=0 - seg1 = np.logical_or(seg1,seg2) - cmap = np.ones(seg1.shape, dtype=np.float32) - if os.path.isfile(str(row['excluding_mask'])): - ex_mask = np.squeeze(imread(row['excluding_mask'])) > 0.01 - cmap[ex_mask>0]=0 - - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '.ome.tif') as writer: + if os.path.isfile(str(row["excluding_mask"])): + ex_mask = np.squeeze(imread(row["excluding_mask"])) > 0.01 + cmap[ex_mask > 0] = 0 + + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + ".ome.tif" + ) as writer: writer.save(struct_img) seg1 = seg1.astype(np.uint8) - seg1[seg1>0]=1 - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_GT.ome.tif') as writer: + seg1[seg1 > 0] = 1 + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + "_GT.ome.tif" + ) as writer: writer.save(seg1) - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_CM.ome.tif') as writer: + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + "_CM.ome.tif" + ) as writer: writer.save(cmap) - print('training data is ready') + print("training data is ready") def main(): @@ -362,4 +473,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/aicsmlsegment/bin/curator/curator_sorting.py b/aicsmlsegment/bin/curator/curator_sorting.py index 3cc4a32..8c4155c 100644 --- a/aicsmlsegment/bin/curator/curator_sorting.py +++ b/aicsmlsegment/bin/curator/curator_sorting.py @@ -5,28 +5,22 @@ import logging import argparse import traceback -import importlib -import pathlib import csv import pandas as pd import numpy as np -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import matplotlib from glob import glob -from random import shuffle -from scipy import stats -from skimage.io import imsave from skimage.draw import line, polygon -from scipy import ndimage as ndi from aicssegmentation.core.utils import histogram_otsu from aicsimageio import AICSImage, imread from aicsimageio.writers import OmeTiffWriter from aicsmlsegment.utils import input_normalization -matplotlib.use('TkAgg') +matplotlib.use("TkAgg") -#################################################################################################### +##################################################################################### # global settings button = 0 flag_done = False @@ -37,112 +31,158 @@ log = logging.getLogger() -logging.basicConfig(level=logging.INFO, - format='[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s') +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s", +) # # Set the default log level for other modules used by this script # logging.getLogger("labkey").setLevel(logging.ERROR) # logging.getLogger("requests").setLevel(logging.WARNING) # logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger("matplotlib").setLevel(logging.INFO) -#################################################################################################### +###################################################################################### + def quit_curation(event): - if event.key == 'q': + if event.key == "q": exit() + def gt_sorting_callback(event): global button - while(1): + while 1: button = event.button if button == 3: - print('You selected this image as GOOD') + print("You selected this image as GOOD") break elif button == 1: - print('You selected this image as BAD') + print("You selected this image as BAD") break plt.close() + def draw_polygons(event): global pts, draw_img, draw_ax, draw_mask if event.button == 1: - if not (event.ydata == None or event.xdata == None): - pts.append([event.xdata,event.ydata]) - if len(pts)>1: - rr, cc = line(int(round(pts[-1][0])), int(round(pts[-1][1])), int(round(pts[-2][0])), int(round(pts[-2][1])) ) - draw_img[cc,rr,:1]=255 + if not (event.ydata is None or event.xdata is None): + pts.append([event.xdata, event.ydata]) + if len(pts) > 1: + rr, cc = line( + int(round(pts[-1][0])), + int(round(pts[-1][1])), + int(round(pts[-2][0])), + int(round(pts[-2][1])), + ) + draw_img[cc, rr, :1] = 255 draw_ax.set_data(draw_img) plt.draw() elif event.button == 3: - if len(pts)>2: + if len(pts) > 2: # draw polygon pts_array = np.asarray(pts) - rr, cc = polygon(pts_array[:,0], pts_array[:,1]) - draw_img[cc,rr,:1]=255 + rr, cc = polygon(pts_array[:, 0], pts_array[:, 1]) + draw_img[cc, rr, :1] = 255 draw_ax.set_data(draw_img) - draw_mask[cc,rr]=1 + draw_mask[cc, rr] = 1 pts.clear() plt.draw() else: - print('need at least three clicks before finishing annotation') + print("need at least three clicks before finishing annotation") + def quit_mask_drawing(event): - if event.key == 'd': + if event.key == "d": plt.close() - elif event.key == 'q': + elif event.key == "q": exit() + def gt_sorting(raw_img, seg): global button - bw = seg>0 - z_profile = np.zeros((bw.shape[0],),dtype=int) + bw = seg > 0 + z_profile = np.zeros((bw.shape[0],), dtype=int) for zz in range(bw.shape[0]): - z_profile[zz] = np.count_nonzero(bw[zz,:,:]) - mid_frame = histogram_otsu(z_profile)*bw.shape[0] + z_profile[zz] = np.count_nonzero(bw[zz, :, :]) + mid_frame = histogram_otsu(z_profile) * bw.shape[0] print("trying to find the best Z to display ...") print(f"the raw image has z profile {z_profile}") print(f"find best Z = {mid_frame}") mid_frame = int(round(mid_frame)) - #create 2x4 mosaic - out = np.zeros((2*raw_img.shape[1], 4*raw_img.shape[2], 3),dtype=np.uint8) - row_index=0 + # create 2x4 mosaic + out = np.zeros((2 * raw_img.shape[1], 4 * raw_img.shape[2], 3), dtype=np.uint8) + row_index = 0 im = raw_img - - for cc in range(3): - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 0*raw_img.shape[2]:1*raw_img.shape[2], cc]=im[mid_frame-4,:,:] - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 1*raw_img.shape[2]:2*raw_img.shape[2], cc]=im[mid_frame,:,:] - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 2*raw_img.shape[2]:3*raw_img.shape[2], cc]=im[mid_frame+4,:,:] - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 3*raw_img.shape[2]:4*raw_img.shape[2], cc]=np.amax(im, axis=0) - row_index=1 + for cc in range(3): + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 0 * raw_img.shape[2] : 1 * raw_img.shape[2], + cc, + ] = im[mid_frame - 4, :, :] + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 1 * raw_img.shape[2] : 2 * raw_img.shape[2], + cc, + ] = im[mid_frame, :, :] + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 2 * raw_img.shape[2] : 3 * raw_img.shape[2], + cc, + ] = im[mid_frame + 4, :, :] + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 3 * raw_img.shape[2] : 4 * raw_img.shape[2], + cc, + ] = np.amax(im, axis=0) + + row_index = 1 offset = 20 - im = seg + offset # make it brighter - im[im==offset]=0 - im = im.astype(float) * (255/im.max()) + im = seg + offset # make it brighter + im[im == offset] = 0 + im = im.astype(float) * (255 / im.max()) im = np.round(im) im = im.astype(np.uint8) for cc in range(3): - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 0*raw_img.shape[2]:1*raw_img.shape[2], cc]=im[mid_frame-4,:,:] - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 1*raw_img.shape[2]:2*raw_img.shape[2], cc]=im[mid_frame,:,:] - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 2*raw_img.shape[2]:3*raw_img.shape[2], cc]=im[mid_frame+4,:,:] - out[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 3*raw_img.shape[2]:4*raw_img.shape[2], cc]=np.amax(im, axis=0) + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 0 * raw_img.shape[2] : 1 * raw_img.shape[2], + cc, + ] = im[mid_frame - 4, :, :] + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 1 * raw_img.shape[2] : 2 * raw_img.shape[2], + cc, + ] = im[mid_frame, :, :] + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 2 * raw_img.shape[2] : 3 * raw_img.shape[2], + cc, + ] = im[mid_frame + 4, :, :] + out[ + row_index * raw_img.shape[1] : (row_index + 1) * raw_img.shape[1], + 3 * raw_img.shape[2] : 4 * raw_img.shape[2], + cc, + ] = np.amax(im, axis=0) # display the image for good/bad inspection fig = plt.figure() - figManager = plt.get_current_fig_manager() - figManager.full_screen_toggle() + figManager = plt.get_current_fig_manager() + figManager.full_screen_toggle() ax = fig.add_subplot(111) ax.imshow(out) - ax.set_title('Interface for Sorting. Left click = BAD. Right click = GOOD \n' - + 'Press Q to quit the current curation (can be resumed later)\n' \ - + 'Columns left to right: 4 z slice below middle z slice, middle z slice, \n' \ - + '4 z slice above middle z slice, max z projection \n' - + 'Top row: raw image; bottom row: segmentation. \n ') - #plt.tight_layout() - cid = fig.canvas.mpl_connect('button_press_event', gt_sorting_callback) - cid2 = fig.canvas.mpl_connect('key_press_event', quit_curation) + ax.set_title( + "Interface for Sorting. Left click = BAD. Right click = GOOD \n" + + "Press Q to quit the current curation (can be resumed later)\n" + + "Columns left to right: 4 z slice below middle z slice, middle z slice, \n" + + "4 z slice above middle z slice, max z projection \n" + + "Top row: raw image; bottom row: segmentation. \n " + ) + # plt.tight_layout() + cid = fig.canvas.mpl_connect("button_press_event", gt_sorting_callback) + cid2 = fig.canvas.mpl_connect("key_press_event", quit_curation) plt.show() fig.canvas.mpl_disconnect(cid) fig.canvas.mpl_disconnect(cid2) @@ -154,52 +194,57 @@ def gt_sorting(raw_img, seg): button = 0 return score + def create_mask(raw_img, seg): global pts, draw_img, draw_mask, draw_ax - bw = seg>0 - z_profile = np.zeros((bw.shape[0],),dtype=int) + bw = seg > 0 + z_profile = np.zeros((bw.shape[0],), dtype=int) for zz in range(bw.shape[0]): - z_profile[zz] = np.count_nonzero(bw[zz,:,:]) - - mid_frame = histogram_otsu(z_profile)*bw.shape[0] + z_profile[zz] = np.count_nonzero(bw[zz, :, :]) + + mid_frame = histogram_otsu(z_profile) * bw.shape[0] print("trying to find the best Z to display ...") print(f"the raw image has z profile {z_profile}") print(f"find best Z = {mid_frame}") mid_frame = int(round(mid_frame)) offset = 20 - seg_label = seg + offset # make it brighter - seg_label[seg_label==offset]=0 - seg_label = seg_label.astype(float) * (255/seg_label.max()) + seg_label = seg + offset # make it brighter + seg_label[seg_label == offset] = 0 + seg_label = seg_label.astype(float) * (255 / seg_label.max()) seg_label = np.round(seg_label) seg_label = seg_label.astype(np.uint8) - img = np.zeros((raw_img.shape[1], 3*raw_img.shape[2], 3),dtype=np.uint8) + img = np.zeros((raw_img.shape[1], 3 * raw_img.shape[2], 3), dtype=np.uint8) for cc in range(3): - img[:, :raw_img.shape[2], cc]=raw_img[mid_frame,:,:] - img[:, raw_img.shape[2]:2*raw_img.shape[2], cc]=seg_label[mid_frame,:,:] - img[:, 2*raw_img.shape[2]:, cc]=np.amax(seg_label, axis=0) + img[:, : raw_img.shape[2], cc] = raw_img[mid_frame, :, :] + img[:, raw_img.shape[2] : 2 * raw_img.shape[2], cc] = seg_label[mid_frame, :, :] + img[:, 2 * raw_img.shape[2] :, cc] = np.amax(seg_label, axis=0) - draw_mask = np.zeros((img.shape[0],img.shape[1]),dtype=np.uint8) + draw_mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8) draw_img = img.copy() # display the image for good/bad inspection fig = plt.figure() - figManager = plt.get_current_fig_manager() - figManager.full_screen_toggle() + figManager = plt.get_current_fig_manager() + figManager.full_screen_toggle() ax = fig.add_subplot(111) - ax.set_title('Interface for annotating excluding mask. \n' \ - +'Left: Middle z slice of raw. Middle: Middle z slice of segmentation. Right: Max z projection of segmentation \n' \ - +'Please draw in the left panel \n' \ - +'Left click to add a vertex; Right click to close the current polygon \n' \ - +'Press D to finish annotating mask, Press Q to quit curation (can resume later)') + ax.set_title( + "Interface for annotating excluding mask. \n" + + "Left: Middle z slice of raw. Middle: Middle z slice of segmentation. " + + "Right: Max z projection of segmentation \n" + + "Please draw in the left panel \n" + + "Left click to add a vertex; Right click to close the current polygon \n" + + "Press D to finish annotation, Press Q to quit curation (can resume later)" + ) draw_ax = ax.imshow(img) - cid = fig.canvas.mpl_connect('button_press_event', draw_polygons) - cid2 = fig.canvas.mpl_connect('key_press_event', quit_mask_drawing) + cid = fig.canvas.mpl_connect("button_press_event", draw_polygons) + cid2 = fig.canvas.mpl_connect("key_press_event", quit_mask_drawing) plt.show() fig.canvas.mpl_disconnect(cid) fig.canvas.mpl_disconnect(cid2) + class Args(object): """ Use this to define command line arguments and use them later. @@ -214,7 +259,7 @@ class Args(object): def __init__(self, log_cmdline=True): self.debug = False - self.output_dir = '.'+os.sep + self.output_dir = "." + os.sep self.struct_ch = 0 self.xy = 0.108 @@ -233,7 +278,7 @@ def __no_args_print_help(parser): This is used to print out the help if no arguments are provided. Note: - You need to remove it's usage if your script truly doesn't want arguments. - - It exits with 1 because it's an error if this is used in a script with no args. + - It exits with 1 because it's an error if this is used in a script with no args That's a non-interactive use scenario - typically you don't want help there. """ if len(sys.argv) == 1: @@ -243,16 +288,30 @@ def __no_args_print_help(parser): def __parse(self): p = argparse.ArgumentParser() # Add arguments - p.add_argument('--d', '--debug', action='store_true', dest='debug', - help='If set debug log output is enabled') - p.add_argument('--raw_path', required=True, help='path to raw images') - p.add_argument('--data_type', required=True, help='the type of raw images') - p.add_argument('--input_channel', default=0, type=int) - p.add_argument('--seg_path', required=True, help='path to segmentation results') - p.add_argument('--train_path', required=True, help='path to output training data') - p.add_argument('--mask_path', help='[optional] the output directory for masks') - p.add_argument('--csv_name', required=True, help='the csv file to save the sorting results') - p.add_argument('--Normalization', required=True, type=int, help='the normalization recipe to use') + p.add_argument( + "--d", + "--debug", + action="store_true", + dest="debug", + help="If set debug log output is enabled", + ) + p.add_argument("--raw_path", required=True, help="path to raw images") + p.add_argument("--data_type", required=True, help="the type of raw images") + p.add_argument("--input_channel", default=0, type=int) + p.add_argument("--seg_path", required=True, help="path to segmentation results") + p.add_argument( + "--train_path", required=True, help="path to output training data" + ) + p.add_argument("--mask_path", help="[optional] the output directory for masks") + p.add_argument( + "--csv_name", required=True, help="the csv file to save the sorting results" + ) + p.add_argument( + "--Normalization", + required=True, + type=int, + help="the normalization recipe to use", + ) self.__no_args_print_help(p) p.parse_args(namespace=self) @@ -269,26 +328,33 @@ def show_info(self): ############################################################################### -class Executor(object): +class Executor(object): def __init__(self, args): if os.path.exists(args.csv_name): - print('the csv file for saving sorting results exists, sorting will be resumed') + print("the csv file for saving sorting results exists, sorting resuming") else: - print('no existing csv found, start a new sorting ') - if not args.data_type.startswith('.'): - args.data_type = '.' + args.data_type + print("no existing csv found, start a new sorting ") + if not args.data_type.startswith("."): + args.data_type = "." + args.data_type - filenames = glob(args.raw_path + os.sep +'*' + args.data_type) + filenames = glob(args.raw_path + os.sep + "*" + args.data_type) filenames.sort() - with open(args.csv_name, 'w') as csvfile: - filewriter = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) - filewriter.writerow(['raw','seg','score','mask']) + with open(args.csv_name, "w") as csvfile: + filewriter = csv.writer( + csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL + ) + filewriter.writerow(["raw", "seg", "score", "mask"]) for _, fn in enumerate(filenames): - seg_fn = args.seg_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_struct_segmentation.tiff' + seg_fn = ( + args.seg_path + + os.sep + + os.path.basename(fn)[: -1 * len(args.data_type)] + + "_struct_segmentation.tiff" + ) assert os.path.exists(seg_fn) - filewriter.writerow([fn, seg_fn , None , None]) + filewriter.writerow([fn, seg_fn, None, None]) def execute(self, args): @@ -298,84 +364,116 @@ def execute(self, args): for index, row in df.iterrows(): - if not np.isnan(row['score']) and (row['score']==1 or row['score']==0): + if not np.isnan(row["score"]) and (row["score"] == 1 or row["score"] == 0): continue - reader = AICSImage(row['raw']) + reader = AICSImage(row["raw"]) struct_img = reader.get_image_data("ZYX", S=0, T=0, C=args.input_channel) - struct_img[struct_img>5000] = struct_img.min() # adjust contrast - raw_img = (struct_img- struct_img.min() + 1e-8)/(struct_img.max() - struct_img.min() + 1e-8) + struct_img[struct_img > 5000] = struct_img.min() # adjust contrast + raw_img = (struct_img - struct_img.min() + 1e-8) / ( + struct_img.max() - struct_img.min() + 1e-8 + ) raw_img = 255 * raw_img raw_img = raw_img.astype(np.uint8) - seg = np.squeeze(imread(row['seg'])) + seg = np.squeeze(imread(row["seg"])) score = gt_sorting(raw_img, seg) if score == 1: - df['score'].iloc[index]=1 - need_mask = input('Do you need to add a mask for this image, enter y or n: ') - if need_mask == 'y': + df["score"].iloc[index] = 1 + need_mask = input( + "Do you need to add a mask for this image, enter y or n: " + ) + if need_mask == "y": create_mask(raw_img, seg.astype(np.uint8)) - mask_fn = args.mask_path + os.sep + os.path.basename(row['raw'])[:-5] + '_mask.tiff' + mask_fn = ( + args.mask_path + + os.sep + + os.path.basename(row["raw"])[:-5] + + "_mask.tiff" + ) crop_mask = np.zeros(seg.shape, dtype=np.uint8) for zz in range(crop_mask.shape[0]): - crop_mask[zz,:,:] = draw_mask[:crop_mask.shape[1],:crop_mask.shape[2]] + crop_mask[zz, :, :] = draw_mask[ + : crop_mask.shape[1], : crop_mask.shape[2] + ] crop_mask = crop_mask.astype(np.uint8) - crop_mask[crop_mask>0]=255 + crop_mask[crop_mask > 0] = 255 with OmeTiffWriter(mask_fn) as writer: writer.save(crop_mask) - df['mask'].iloc[index]=mask_fn + df["mask"].iloc[index] = mask_fn else: - df['score'].iloc[index]=0 + df["score"].iloc[index] = 0 df.to_csv(args.csv_name, index=False) ######################################### # generate training data: # (we want to do this step after "sorting" - # (is mainly because we want to get the sorting + # (is mainly because we want to get the sorting # step as smooth as possible, even though # this may waster i/o time on reloading images) # ####################################### - print('finish merging, start building the training data ...') + print("finish merging, start building the training data ...") - existing_files = glob(args.train_path+os.sep+'img_*.ome.tif') + existing_files = glob(args.train_path + os.sep + "img_*.ome.tif") print(len(existing_files)) - training_data_count = len(existing_files)//3 - + training_data_count = len(existing_files) // 3 + for index, row in df.iterrows(): - if row['score']==1: + if row["score"] == 1: training_data_count += 1 # load raw image - reader = AICSImage(row['raw']) - img = reader.get_image_data("CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32) + reader = AICSImage(row["raw"]) + img = reader.get_image_data( + "CZYX", S=0, T=0, C=[args.input_channel] + ).astype(np.float32) struct_img = input_normalization(img, args) - struct_img= struct_img[0,:,:,:] + struct_img = struct_img[0, :, :, :] # load segmentation gt - seg = np.squeeze(imread(row['seg'])) > 0.01 + seg = np.squeeze(imread(row["seg"])) > 0.01 seg = seg.astype(np.uint8) - seg[seg>0]=1 + seg[seg > 0] = 1 cmap = np.ones(seg.shape, dtype=np.float32) - if os.path.isfile(str(row['mask'])): + if os.path.isfile(str(row["mask"])): # load segmentation gt - mask = np.squeeze(imread(row['mask'])) - cmap[mask>0]=0 - - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '.ome.tif') as writer: + mask = np.squeeze(imread(row["mask"])) + cmap[mask > 0] = 0 + + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + ".ome.tif" + ) as writer: writer.save(struct_img) - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_GT.ome.tif') as writer: + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + "_GT.ome.tif" + ) as writer: writer.save(seg) - - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_CM.ome.tif') as writer: + + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + "_CM.ome.tif" + ) as writer: writer.save(cmap) - print('training data is ready') + print("training data is ready") + def main(): dbg = False diff --git a/aicsmlsegment/bin/curator/curator_takeall.py b/aicsmlsegment/bin/curator/curator_takeall.py index 04bb39a..e2b4a08 100644 --- a/aicsmlsegment/bin/curator/curator_takeall.py +++ b/aicsmlsegment/bin/curator/curator_takeall.py @@ -5,26 +5,15 @@ import logging import argparse import traceback -import importlib -import pathlib -import csv -import pandas as pd import numpy as np -import matplotlib.pyplot as plt from glob import glob -from random import shuffle -from scipy import stats -from skimage.io import imsave -from skimage.draw import line, polygon -from scipy import ndimage as ndi -from aicssegmentation.core.utils import histogram_otsu from aicsimageio import AICSImage, imread from aicsimageio.writers import OmeTiffWriter from aicsmlsegment.utils import input_normalization -#################################################################################################### +###################################################################################### # global settings button = 0 flag_done = False @@ -35,15 +24,18 @@ log = logging.getLogger() -logging.basicConfig(level=logging.INFO, - format='[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s') +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s", +) # # Set the default log level for other modules used by this script # logging.getLogger("labkey").setLevel(logging.ERROR) # logging.getLogger("requests").setLevel(logging.WARNING) # logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger("matplotlib").setLevel(logging.INFO) -#################################################################################################### +###################################################################################### + class Args(object): """ @@ -59,7 +51,7 @@ class Args(object): def __init__(self, log_cmdline=True): self.debug = False - self.output_dir = '.' + os.sep + self.output_dir = "." + os.sep self.struct_ch = 0 self.xy = 0.108 @@ -78,7 +70,7 @@ def __no_args_print_help(parser): This is used to print out the help if no arguments are provided. Note: - You need to remove it's usage if your script truly doesn't want arguments. - - It exits with 1 because it's an error if this is used in a script with no args. + - It exits with 1 because it's an error if this is used in a script with no args That's a non-interactive use scenario - typically you don't want help there. """ if len(sys.argv) == 1: @@ -88,15 +80,24 @@ def __no_args_print_help(parser): def __parse(self): p = argparse.ArgumentParser() # Add arguments - p.add_argument('--d', '--debug', action='store_true', dest='debug', - help='If set debug log output is enabled') - p.add_argument('--raw_path', required=True, help='path to raw images') - p.add_argument('--data_type', required=True, help='the type of raw images') - p.add_argument('--input_channel', default=0, type=int) - p.add_argument('--seg_path', required=True, help='path to segmentation results') - p.add_argument('--train_path', required=True, help='path to output training data') - p.add_argument('--mask_path', help='[optional] the output directory for masks') - p.add_argument('--Normalization', default=0, help='the normalization method to use') + p.add_argument( + "--d", + "--debug", + action="store_true", + dest="debug", + help="If set debug log output is enabled", + ) + p.add_argument("--raw_path", required=True, help="path to raw images") + p.add_argument("--data_type", required=True, help="the type of raw images") + p.add_argument("--input_channel", default=0, type=int) + p.add_argument("--seg_path", required=True, help="path to segmentation results") + p.add_argument( + "--train_path", required=True, help="path to output training data" + ) + p.add_argument("--mask_path", help="[optional] the output directory for masks") + p.add_argument( + "--Normalization", default=0, help="the normalization method to use" + ) self.__no_args_print_help(p) p.parse_args(namespace=self) @@ -113,52 +114,82 @@ def show_info(self): ############################################################################### -class Executor(object): +class Executor(object): def __init__(self, args): pass def execute(self, args): - if not args.data_type.startswith('.'): - args.data_type = '.' + args.data_type + if not args.data_type.startswith("."): + args.data_type = "." + args.data_type - filenames = glob(args.raw_path + os.sep +'*' + args.data_type) + filenames = glob(args.raw_path + os.sep + "*" + args.data_type) filenames.sort() - existing_files = glob(args.train_path+os.sep+'img_*.ome.tif') + existing_files = glob(args.train_path + os.sep + "img_*.ome.tif") print(len(existing_files)) - training_data_count = len(existing_files)//3 + training_data_count = len(existing_files) // 3 for _, fn in enumerate(filenames): - + training_data_count += 1 - + # load raw reader = AICSImage(fn) - struct_img = reader.get_image_data("CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32) - struct_img = input_normalization(img, args) + struct_img = reader.get_image_data( + "CZYX", S=0, T=0, C=[args.input_channel] + ).astype(np.float32) + struct_img = input_normalization(struct_img, args) # load seg - seg_fn = args.seg_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_struct_segmentation.tiff' + seg_fn = ( + args.seg_path + + os.sep + + os.path.basename(fn)[: -1 * len(args.data_type)] + + "_struct_segmentation.tiff" + ) seg = np.squeeze(imread(seg_fn)) > 0.01 seg = seg.astype(np.uint8) - seg[seg>0]=1 + seg[seg > 0] = 1 # excluding mask cmap = np.ones(seg.shape, dtype=np.float32) - mask_fn = args.mask_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_mask.tiff' + mask_fn = ( + args.mask_path + + os.sep + + os.path.basename(fn)[: -1 * len(args.data_type)] + + "_mask.tiff" + ) if os.path.isfile(mask_fn): mask = np.squeeze(imread(mask_fn)) - cmap[mask==0]=0 - - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '.ome.tif') as writer: + cmap[mask == 0] = 0 + + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + ".ome.tif" + ) as writer: writer.save(struct_img) - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_GT.ome.tif') as writer: + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + "_GT.ome.tif" + ) as writer: writer.save(seg) - - with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_CM.ome.tif') as writer: + + with OmeTiffWriter( + args.train_path + + os.sep + + "img_" + + f"{training_data_count:03}" + + "_CM.ome.tif" + ) as writer: writer.save(cmap) @@ -185,4 +216,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/aicsmlsegment/bin/predict.py b/aicsmlsegment/bin/predict.py index 59c5d79..a4fa765 100644 --- a/aicsmlsegment/bin/predict.py +++ b/aicsmlsegment/bin/predict.py @@ -1,184 +1,62 @@ #!/usr/bin/env python -import sys import argparse -import logging -import traceback -import os -import pathlib -import numpy as np +from aicsmlsegment.utils import load_config, create_unique_run_directory +from aicsmlsegment.Model import Model +from aicsmlsegment.DataUtils.DataMod import DataModule +import pytorch_lightning +import torch.autograd.profiler as profiler -from skimage.morphology import remove_small_objects -from skimage.io import imsave -from aicsimageio import AICSImage -from scipy.ndimage import zoom - -from aicsmlsegment.utils import load_config, load_single_image, input_normalization, image_normalization -from aicsmlsegment.utils import get_logger -from aicsmlsegment.model_utils import build_model, load_checkpoint, model_inference, apply_on_image def main(): - - parser = argparse.ArgumentParser() - parser.add_argument('--config', required=True) - args = parser.parse_args() - - config = load_config(args.config) - - # declare the model - model = build_model(config) - - # load the trained model instance - model_path = config['model_path'] - print(f'Loading model from {model_path}...') - load_checkpoint(model_path, model) - - # extract the parameters for running the model inference - args_inference=lambda:None - args_inference.size_in = config['size_in'] - args_inference.size_out = config['size_out'] - args_inference.OutputCh = config['OutputCh'] - args_inference.nclass = config['nclass'] - if config['RuntimeAug'] <=0: - args_inference.RuntimeAug = False - else: - args_inference.RuntimeAug = True - - # run - inf_config = config['mode'] - if inf_config['name'] == 'file': - fn = inf_config['InputFile'] - data_reader = AICSImage(fn) - - if inf_config['timelapse']: - assert data_reader.shape[1] > 1, "not a timelapse, check you data" - - for tt in range(data_reader.shape[1]): - # Assume: dimensions = TCZYX - img = data_reader.get_image_data("CZYX", S=0, T=tt, C=config['InputCh']).astype(float) - img = image_normalization(img, config['Normalization']) - - if len(config['ResizeRatio'])>0: - img = zoom(img, (1, config['ResizeRatio'][0], config['ResizeRatio'][1], config['ResizeRatio'][2]), order=2, mode='reflect') - for ch_idx in range(img.shape[0]): - struct_img = img[ch_idx,:,:,:] - struct_img = (struct_img - struct_img.min())/(struct_img.max() - struct_img.min()) - img[ch_idx,:,:,:] = struct_img - - # apply the model - output_img = apply_on_image(model, img, model.final_activation, args_inference) - - # extract the result and write the output - if len(config['OutputCh']) == 2: - out = output_img[0] - out = (out - out.min()) / (out.max()-out.min()) - if len(config['ResizeRatio'])>0: - out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') - out = out.astype(np.float32) - if config['Threshold']>0: - out = out > config['Threshold'] - out = out.astype(np.uint8) - out[out>0]=255 - imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_T_'+ f'{tt:03}' +'_struct_segmentation.tiff', out) - else: - for ch_idx in range(len(config['OutputCh'])//2): - out = output_img[ch_idx] - out = (out - out.min()) / (out.max()-out.min()) - if len(config['ResizeRatio'])>0: - out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') - out = out.astype(np.float32) - if config['Threshold']>0: - out = out > config['Threshold'] - out = out.astype(np.uint8) - out[out>0]=255 - imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_T_'+ f'{tt:03}' +'_seg_'+ str(config['OutputCh'][2*ch_idx])+'.tiff',out) - else: - img = data_reader.get_image_data("CZYX", S=0, T=0, C=config['InputCh']).astype(float) - img = image_normalization(img, config['Normalization']) - - if len(config['ResizeRatio'])>0: - img = zoom(img, (1, config['ResizeRatio'][0], config['ResizeRatio'][1], config['ResizeRatio'][2]), order=2, mode='reflect') - for ch_idx in range(img.shape[0]): - struct_img = img[ch_idx,:,:,:] # note that struct_img is only a view of img, so changes made on struct_img also affects img - struct_img = (struct_img - struct_img.min())/(struct_img.max() - struct_img.min()) - img[ch_idx,:,:,:] = struct_img - - # apply the model - output_img = apply_on_image(model, img, model.final_activation, args_inference) - - # extract the result and write the output - if len(config['OutputCh']) == 2: - out = output_img[0] - out = (out - out.min()) / (out.max()-out.min()) - if len(config['ResizeRatio'])>0: - out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') - out = out.astype(np.float32) - if config['Threshold']>0: - out = out > config['Threshold'] - out = out.astype(np.uint8) - out[out>0]=255 - imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem +'_struct_segmentation.tiff', out) - else: - for ch_idx in range(len(config['OutputCh'])//2): - out = output_img[ch_idx] - out = (out - out.min()) / (out.max()-out.min()) - if len(config['ResizeRatio'])>0: - out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') - out = out.astype(np.float32) - if config['Threshold']>0: - out = out > config['Threshold'] - out = out.astype(np.uint8) - out[out>0]=255 - imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem +'_seg_'+ str(config['OutputCh'][2*ch_idx])+'.tiff', out) - print(f'Image {fn} has been segmented') - - elif inf_config['name'] == 'folder': - from glob import glob - filenames = glob(inf_config['InputDir'] + '/*' + inf_config['DataType']) - filenames.sort() #(reverse=True) - print('files to be processed:') - print(filenames) - - for _, fn in enumerate(filenames): - - # load data - data_reader = AICSImage(fn) - img = data_reader.get_image_data('CZYX', S=0, T=0, C=config['InputCh']).astype(float) - if len(config['ResizeRatio'])>0: - img = zoom(img, (1,config['ResizeRatio'][0], config['ResizeRatio'][1], config['ResizeRatio'][2]), order=2, mode='reflect') - img = image_normalization(img, config['Normalization']) - - # apply the model - output_img = apply_on_image(model, img, model.final_activation, args_inference) - - # extract the result and write the output - if len(config['OutputCh'])==2: - if config['Threshold']<0: - out = output_img[0] - out = (out - out.min()) / (out.max()-out.min()) - if len(config['ResizeRatio'])>0: - out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') - out = out.astype(np.float32) - out = (out - out.min()) / (out.max()-out.min()) - else: - out = remove_small_objects(output_img[0] > config['Threshold'], min_size=2, connectivity=1) - out = out.astype(np.uint8) - out[out>0]=255 - imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_struct_segmentation.tiff', out) - else: - for ch_idx in range(len(config['OutputCh'])//2): - if config['Threshold']<0: - out = output_img[ch_idx] - out = (out - out.min()) / (out.max()-out.min()) - out = out.astype(np.float32) - else: - out = output_img[ch_idx] > config['Threshold'] - out = out.astype(np.uint8) - out[out>0]=255 - imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_seg_'+ str(config['OutputCh'][2*ch_idx])+'.ome.tif', out) - - print(f'Image {fn} has been segmented') - -if __name__ == '__main__': - - main() \ No newline at end of file + with profiler.profile(profile_memory=True) as prof: + + # load config + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True) + args = parser.parse_args() + config, model_config = load_config(args.config, train=False) + + # load the trained model instance + model_path = config["model_path"] + print(f"Loading model from {model_path}...") + try: + model = Model.load_from_checkpoint( + model_path, config=config, model_config=model_config, train=False + ) + except KeyError: # backwards compatibility for old .pytorch checkpoints + from aicsmlsegment.model_utils import load_checkpoint + + model = Model(config, model_config, train=False) + load_checkpoint(model_path, model) + + # set up GPU + gpu_config = config["gpus"] + if gpu_config < -1: + print("Number of GPUs must be -1 or > 0") + quit() + + # prepare output directory + output_dir = create_unique_run_directory(config, train=False) + config["OutputDir"] = output_dir + + print(config) + + # ddp is the default unless only one gpu is requested + accelerator = config["dist_backend"] + trainer = pytorch_lightning.Trainer( + gpus=gpu_config, + num_sanity_val_steps=0, + distributed_backend=accelerator, + precision=config["precision"], + ) + data_module = DataModule(config, train=False) + with profiler.record_function("inference"): + trainer.test(model, datamodule=data_module) + + # print usage profile + print(prof.key_averages().table()) + + +if __name__ == "__main__": + main() diff --git a/aicsmlsegment/bin/train.py b/aicsmlsegment/bin/train.py index 9c81181..6d927af 100644 --- a/aicsmlsegment/bin/train.py +++ b/aicsmlsegment/bin/train.py @@ -1,44 +1,139 @@ -#!/usr/bin/env python - -import sys +import pytorch_lightning import argparse -import logging -import traceback +from aicsmlsegment.utils import load_config, get_logger, create_unique_run_directory +from aicsmlsegment.Model import Model +from aicsmlsegment.DataUtils.DataMod import DataModule +from pytorch_lightning.callbacks import ModelCheckpoint + -from aicsmlsegment.utils import load_config +def main(config=None, model_config=None): -from aicsmlsegment.training_utils import BasicFolderTrainer, get_loss_criterion, build_optimizer, get_train_dataloader -from aicsmlsegment.utils import get_logger -from aicsmlsegment.model_utils import get_number_of_learnable_parameters, build_model, load_checkpoint + ######### + # only for debugging + ######### + if config is None: + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True) + args = parser.parse_args() + # create logger + config, model_config = load_config(args.config, train=True) + logger = get_logger("ModelTrainer") -def main(): + # load a specified saved model + if config["resume"] is not None: + print(f"Loading checkpoint '{config['resume']}'...") + try: + model = Model.load_from_checkpoint( + config["resume"], config=config, model_config=model_config, train=True + ) + except KeyError: # backwards compatibility + from aicsmlsegment.model_utils import load_checkpoint - parser = argparse.ArgumentParser() - parser.add_argument('--config', required=True) - args = parser.parse_args() + model = Model(config, model_config, train=True) + load_checkpoint(config["resume"], model) + else: + print("Training new model from scratch") + model = Model(config, model_config, train=True) - # create logger - logger = get_logger('ModelTrainer') - config = load_config(args.config) + checkpoint_dir = create_unique_run_directory(config, train=True) + config["checkpoint_dir"] = checkpoint_dir logger.info(config) - # Create model - model = build_model(config) + # save model checkpoint every n epochs + MC = ModelCheckpoint( + dirpath=checkpoint_dir, + filename="checkpoint_{epoch}", + period=config["save_every_n_epoch"], + save_top_k=-1, + ) + callbacks = [MC] - # Log the number of learnable parameters - logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}') + callbacks_config = config["callbacks"] - # check if resuming - if config['resume'] is not None: - print(f"Loading checkpoint '{config['resume']}'...") - load_checkpoint(config['resume'], model) + # it is possible to use early stopping by adding callback config + # in configuration yaml + if callbacks_config["name"] == "EarlyStopping": + es = pytorch_lightning.callbacks.EarlyStopping( + monitor=callbacks_config["monitor"], + min_delta=callbacks_config["min_delta"], + patience=callbacks_config["patience"], + verbose=callbacks_config["verbose"], + mode=callbacks_config["verbose"], + ) + callbacks.append(es) + + # it is possible to use stachastic weight averaging by adding + # a "SWA" option in configuration yaml + if config["SWA"] is not None: + assert ( + config["scheduler"]["name"] != "ReduceLROnPlateau" + ), "ReduceLROnPlateau scheduler is not currently compatible with SWA" + swa = pytorch_lightning.callbacks.StochasticWeightAveraging( + swa_epoch_start=config["SWA"]["swa_start"], + swa_lrs=config["SWA"]["swa_lr"], + annealing_epochs=config["SWA"]["annealing_epochs"], + annealing_strategy=config["SWA"]["annealing_strategy"], + ) + callbacks.append(swa) + + # gpu setting + gpu_config = config["gpus"] + if gpu_config < -1: + print("Number of GPUs must be -1 or > 0") + quit() + + # ddp is the default unless only one gpu is requested + accelerator = config["dist_backend"] + plugins = None + if accelerator == "ddp": + from pytorch_lightning.plugins import DDPPlugin + + # reduces multi-gpu model memory, removes unecessary backwards pass + plugins = ["ddp_sharded", DDPPlugin(find_unused_parameters=False)] + + # it is possible to use tensorboard to track the experiment by adding + # a "tensorboard" option in the configuration yaml + if config["tensorboard"] is not None: + from pytorch_lightning.callbacks import LearningRateMonitor, GPUStatsMonitor + from pytorch_lightning.loggers import TensorBoardLogger + + logger = TensorBoardLogger(config["tensorboard"]) + GPU = GPUStatsMonitor(intra_step_time=True, inter_step_time=True) + LR = LearningRateMonitor(logging_interval="epoch") + callbacks += [GPU, LR] else: - print('start a new training') + from pytorch_lightning.loggers import CSVLogger + + logger = CSVLogger(save_dir=checkpoint_dir) + + # define the model trainer + trainer = pytorch_lightning.Trainer( + gpus=gpu_config, + max_epochs=config["epochs"], + check_val_every_n_epoch=config["validation"]["validate_every_n_epoch"], + num_sanity_val_steps=0, + callbacks=callbacks, + reload_dataloaders_every_epoch=False, # check https://github.com/PyTorchLightning/pytorch-lightning/pull/5043 for updates on pull request # noqa E501 + # reload_dataloaders_every_n_epoch = config['loader']['epoch_shuffle'] + distributed_backend=accelerator, + logger=logger, + precision=config["precision"], + plugins=plugins, + ) + + # define the data module + data_module = DataModule(config) + + # starts training + trainer.fit(model, data_module) + + # after training is done, print the best model + print( + "The best performing checkpoint is", + MC.best_model_path, + ) - # run the training - trainer = BasicFolderTrainer(model, config, logger=logger) - trainer.train() -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/aicsmlsegment/custom_loss.py b/aicsmlsegment/custom_loss.py index cbd3e74..27d3900 100644 --- a/aicsmlsegment/custom_loss.py +++ b/aicsmlsegment/custom_loss.py @@ -1,85 +1,358 @@ - -from torch.autograd import Variable, Function -import torch.optim as optim +from torch.autograd import Variable import torch.nn as nn +import torch.functional as F import torch import numpy as np +from typing import Dict + +SUPPORTED_LOSSES = { + # MONAI + "Dice": { + "source": "monai.losses", + "args": ["softmax", "include_background"], + "wrapper_args": {"n_label_ch": 2, "accepts_costmap": False}, + }, + "GeneralizedDice": { + "source": "monai.losses", + "args": ["softmax", "include_background"], + "wrapper_args": {"n_label_ch": 2, "accepts_costmap": False}, + }, + "Focal": { + "source": "monai.losses", + "args": [], + "wrapper_args": {"n_label_ch": 2, "accepts_costmap": False}, + }, + "MaskedDice": { + "source": "monai.losses", + "args": ["softmax", "include_background"], + "wrapper_args": { + "n_label_ch": 2, + "accepts_costmap": True, + "cmap_unsqueeze": True, + }, + }, + # TORCH + "MSE": { + "source": "torch.nn", + "args": [], + "wrapper_args": {"n_label_ch": 2, "accepts_costmap": False}, + }, + "CrossEntropy": { + "source": "torch.nn", + "args": [], + "wrapper_args": { + "n_label_ch": 1, + "accepts_costmap": False, + "to_long": True, + "label_squeeze": True, + }, + }, + # CUSTOM + "PixelWiseCrossEntropy": { + "source": "aicsmlsegment.custom_loss", + "args": [], + "costmap": True, + "wrapper_args": {"n_label_ch": 2, "accepts_costmap": True}, + }, + "ElementAngularMSE": { + "source": "aicsmlsegment.custom_loss", + "args": [], + "wrapper_args": {"n_label_ch": 2, "accepts_costmap": True}, + }, + "MaskedMSE": { + "source": "aicsmlsegment.custom_loss", + "args": [], + "wrapper_args": {"n_label_ch": 2, "accepts_costmap": True}, + }, + "MaskedCrossEntropy": { + "source": "aicsmlsegment.custom_loss", + "args": [], + "wrapper_args": { + "n_label_ch": 1, + "accepts_costmap": True, + "cmap_unsqueeze": True, + "label_squeeze": True, + "to_long": True, + }, + }, + "MultiAuxillaryCrossEntropy": { + "source": "aicsmlsegment.custom_loss", + "args": ["weight", "num_class"], + "wrapper_args": { + "n_label_ch": 1, + "accepts_costmap": True, + "cmap_unsqueeze": True, + "label_squeeze": True, + "to_long": True, + }, + }, +} + + +def get_loss_criterion(config: Dict): + """ + Returns the loss function based on provided configuration + + Parameters + ---------- + config: Dict + a top level configuration object containing the 'loss' key + + Return: + ------------- + an instance of the loss function and whether it accepts a costmap and loss weights + """ + import importlib + + name = config["loss"]["name"] + # backwards compatibility + if name == "Aux": + name = "MultiAuxillaryCrossEntropy" + + loss_names = [name] + if "+" in name: + loss_names = name.split("+") + losses = [] + costmap = [] + for ln in loss_names: + assert ln in SUPPORTED_LOSSES, ( + f"Invalid loss: {ln}. Supported losses: " + f"{[key for key in SUPPORTED_LOSSES]} or combinations as 'l1+l2'" + ) + loss_info = SUPPORTED_LOSSES[ln] + + init_args = loss_info["args"] + + module = importlib.import_module(loss_info["source"]) + module = getattr(module, ln + "Loss") + args = {} + if "softmax" in init_args: + args["softmax"] = True + if "num_task" in init_args: + args["num_task"] = len(config["model"]["nclass"]) + if "weight" in init_args: + args["weight"] = config["loss"]["loss_weight"] + if "num_class" in init_args: + args["num_class"] = config["model"]["nclass"] + if "include_background" in init_args: + args["include_background"] = False + loss = module(**args) + wrapped_loss = LossWrapper(loss, **loss_info["wrapper_args"]) + losses.append(wrapped_loss) + costmap.append(loss_info["wrapper_args"]["accepts_costmap"]) + + if len(losses) == 2: + from aicsmlsegment.custom_loss import CombinedLoss + + return CombinedLoss(*losses), np.any(costmap) + + else: + return losses[0], np.any(costmap) + + +class LossWrapper(torch.nn.Module): + def __init__( + self, + loss, + n_label_ch: int, + accepts_costmap: bool, + to_long: bool = False, + cmap_unsqueeze: bool = False, + label_squeeze: bool = False, + ): + """ + Standardize how unnormalized logits transformed for each loss function + + Parameters + ---------- + loss: loss function + n_label_ch: number of channels that are expected for label by loss function + accepts_costmap: whether loss function expects a costmap + to_long: whether to convert label to long tensor + cmap_unsqueeze: whether to add channel dimension to costmap + label_squeeze: whether to remove channel dimension from label + + Return: wrapped loss function + """ + super(LossWrapper, self).__init__() + + self.loss = loss + self.n_label_ch = n_label_ch + self.cmap_unsqueeze = cmap_unsqueeze + self.label_squeeze = label_squeeze + self.accepts_costmap = accepts_costmap + self.to_long = to_long + + def forward(self, input, target, cmap=None): + if self.n_label_ch == 2: + target = torch.squeeze(torch.stack([1 - target, target], dim=1), dim=2) + if self.cmap_unsqueeze: + cmap = torch.unsqueeze(cmap, dim=1) + if self.to_long: + target = target.long() + if self.label_squeeze: + target = torch.squeeze(target, dim=1) + # THIS HAPPENS ON LAST OUTPUT OF extended_dynunet, not sure why + if type(input) == tuple: + input = input[0] + if self.accepts_costmap and cmap is not None: + loss = self.loss(input, target, cmap) + else: + loss = self.loss(input, target) + return loss + + +class CombinedLoss(torch.nn.Module): + def __init__(self, loss1, loss2): + super(CombinedLoss, self).__init__() + self.loss1 = loss1 + self.loss2 = loss2 + + def forward(self, input, target, cmap=None): + return self.loss1(input, target, cmap) + self.loss2(input, target, cmap) + + +class MaskedCrossEntropyLoss(torch.nn.Module): + def __init__(self): + super(MaskedCrossEntropyLoss, self).__init__() + self.loss = torch.nn.NLLLoss(reduction="none") + self.log_softmax = torch.nn.LogSoftmax(dim=1) + + def forward(self, input, target, cmap): + """ + expects input, target, cmap in NCZYX with input channels=2, target_channels=1 + """ + loss = self.loss(self.log_softmax(input), target) + loss = torch.mean(torch.mul(loss.view(loss.numel()), cmap.view(cmap.numel()))) + return loss + + +class MultiAuxillaryCrossEntropyLoss(torch.nn.Module): + def __init__(self, weight, num_class): + super(MultiAuxillaryCrossEntropyLoss, self).__init__() + self.weight = weight + self.loss_fn = MaskedCrossEntropyLoss() + + def forward(self, input, target, cmap): + if not isinstance(input, list): # custom model validation + input = [input] + total_loss = self.weight[0] * self.loss_fn(input[0], target, cmap) + for n in np.arange(1, len(input)): + total_loss += self.weight[n] * self.loss_fn(input[n], target, cmap) + + return total_loss + class ElementNLLLoss(torch.nn.Module): - def __init__(self, num_class): - super(ElementNLLLoss,self).__init__() - self.num_class = num_class - - def forward(self, input, target, weight): + def __init__(self, num_class): + super(ElementNLLLoss, self).__init__() + self.num_class = num_class + + def forward(self, input, target, weight): + target_np = target.detach().cpu().data.numpy() + target_np = target_np.astype(np.uint8) + + row_num = target_np.shape[0] + mask = np.zeros((row_num, self.num_class)) + mask[np.arange(row_num), target_np] = 1 - target_np = target.cpu().data.numpy() - target_np = target_np.astype(np.uint8) + class_x = torch.masked_select( + input, Variable(torch.from_numpy(mask).cuda().bool()) + ) - row_num = target_np.shape[0] - mask = np.zeros((row_num,self.num_class )) - mask[np.arange(row_num), target_np]=1 - class_x = torch.masked_select(input, Variable(torch.from_numpy(mask).cuda().bool())) + out = torch.mul(class_x, weight) + loss = torch.mean(torch.neg(out), 0) - out = torch.mul(class_x,weight) - loss = torch.mean(torch.neg(out),0) + return loss - return loss class MultiAuxillaryElementNLLLoss(torch.nn.Module): - def __init__(self,num_task, weight, num_class): - super(MultiAuxillaryElementNLLLoss,self).__init__() - self.num_task = num_task - self.weight = weight + def __init__(self, num_task, weight, num_class): + super(MultiAuxillaryElementNLLLoss, self).__init__() + self.num_task = num_task + self.weight = weight + + self.criteria_list = [] + for n in range(self.num_task): + self.criteria_list.append(ElementNLLLoss(num_class[n])) + + def forward(self, input, target, cmap): - self.criteria_list = [] - for nn in range(self.num_task): - self.criteria_list.append(ElementNLLLoss(num_class[nn])) - - def forward(self, input, target, cmap): + total_loss = self.weight[0] * self.criteria_list[0]( + input[0], target.view(target.numel()), cmap.view(cmap.numel()) + ) - total_loss = self.weight[0]*self.criteria_list[0](input[0], target.view(target.numel()), cmap.view(cmap.numel()) ) + for n in np.arange(1, self.num_task): + total_loss = total_loss + self.weight[n] * self.criteria_list[n]( + input[n], target.view(target.numel()), cmap.view(cmap.numel()) + ) - for nn in np.arange(1,self.num_task): - total_loss = total_loss + self.weight[nn]*self.criteria_list[nn](input[nn], target.view(target.numel()), cmap.view(cmap.numel()) ) + return total_loss - return total_loss class MultiTaskElementNLLLoss(torch.nn.Module): - def __init__(self, weight, num_class): - super(MultiTaskElementNLLLoss,self).__init__() - self.num_task = len(num_class) - self.weight = weight + def __init__(self, weight, num_class): + super(MultiTaskElementNLLLoss, self).__init__() + self.num_task = len(num_class) + self.weight = weight + + self.criteria_list = [] + for n in range(self.num_task): + self.criteria_list.append(ElementNLLLoss(num_class[n])) + + def forward(self, input, target, cmap): + + assert len(target) == self.num_task and len(input) == self.num_task + + total_loss = self.weight[0] * self.criteria_list[0]( + input[0], target[0].view(target[0].numel()), cmap.view(cmap.numel()) + ) - self.criteria_list = [] - for nn in range(self.num_task): - self.criteria_list.append(ElementNLLLoss(num_class[nn])) - - def forward(self, input, target, cmap): + for n in np.arange(1, self.num_task): + total_loss = total_loss + self.weight[n] * self.criteria_list[n]( + input[n], target[n].view(target[n].numel()), cmap.view(cmap.numel()) + ) - assert len(target) == self.num_task and len(input) == self.num_task + return total_loss - total_loss = self.weight[0]*self.criteria_list[0](input[0], target[0].view(target[0].numel()), cmap.view(cmap.numel()) ) - for nn in np.arange(1,self.num_task): - total_loss = total_loss + self.weight[nn]*self.criteria_list[nn](input[nn], target[nn].view(target[nn].numel()), cmap.view(cmap.numel()) ) +class MaskedMSELoss(torch.nn.Module): + def __init__(self): + super(MaskedMSELoss, self).__init__() + + def forward(self, input, target, weight): + return ( + torch.sum(torch.mul((input - target) ** 2, weight)) + / torch.gt(weight, 0).data.nelement() + ) - return total_loss class ElementAngularMSELoss(torch.nn.Module): - def __init__(self): - super(ElementAngularMSELoss,self).__init__() - - def forward(self, input, target, weight): - - #((input - target) ** 2).sum() / input.data.nelement() - - return torch.sum( torch.mul( torch.acos(torch.sum(torch.mul(input,target),dim=1))**2, weight) )/ torch.gt(weight,0).data.nelement() - -def compute_per_channel_dice(input, target, epsilon=1e-5, ignore_index=None, weight=None): + def __init__(self): + super(ElementAngularMSELoss, self).__init__() + + def forward(self, input, target, weight): + + # ((input - target) ** 2).sum() / input.data.nelement() + return ( + torch.sum( + torch.mul( + torch.acos(torch.sum(torch.mul(input, target), dim=1)) ** 2, weight + ) + ) + / torch.gt(weight, 0).data.nelement() + ) + + +def compute_per_channel_dice( + input, target, epsilon=1e-5, ignore_index=None, weight=None +): # assumes that input is a normalized probability # input and target shapes must match - assert input.size() == target.size(), "'input' and 'target' must have the same shape" + assert ( + input.size() == target.size() + ), "'input' and 'target' must have the same shape" # mask ignore_index if present if ignore_index is not None: @@ -99,7 +372,7 @@ def compute_per_channel_dice(input, target, epsilon=1e-5, ignore_index=None, wei intersect = weight * intersect denominator = (input + target).sum(-1) - return 2. * intersect / denominator.clamp(min=epsilon) + return 2.0 * intersect / denominator.clamp(min=epsilon) class DiceLoss(nn.Module): @@ -107,17 +380,25 @@ class DiceLoss(nn.Module): Additionally allows per-class weights to be provided. """ - def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True, - skip_last_target=False): + def __init__( + self, + epsilon=1e-5, + weight=None, + ignore_index=None, + sigmoid_normalization=True, + skip_last_target=False, + ): super(DiceLoss, self).__init__() self.epsilon = epsilon - self.register_buffer('weight', weight) + self.register_buffer("weight", weight) self.ignore_index = ignore_index - # The output from the network during training is assumed to be un-normalized probabilities and we would - # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data, - # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. - # However if one would like to apply Softmax in order to get the proper probability distribution from the - # output, just specify sigmoid_normalization=False. + # The output from the network during training is assumed to be un-normalized + # probabilities and we would like to normalize the logits. Since Dice + # (or soft Dice in this case) is usually used for binary data, normalizing + # the channels with Sigmoid is the default choice even for multi-class + # segmentation problems. However if one would like to apply Softmax in order + # to get the proper probability distribution from the output, just specify + # sigmoid_normalization=False. if sigmoid_normalization: self.normalization = nn.Sigmoid() else: @@ -136,20 +417,29 @@ def forward(self, input, target): if self.skip_last_target: target = target[:, :-1, ...] - per_channel_dice = compute_per_channel_dice(input, target, epsilon=self.epsilon, ignore_index=self.ignore_index, - weight=weight) + per_channel_dice = compute_per_channel_dice( + input, + target, + epsilon=self.epsilon, + ignore_index=self.ignore_index, + weight=weight, + ) # Average the Dice score across all channels/classes - return torch.mean(1. - per_channel_dice) + return torch.mean(1.0 - per_channel_dice) class GeneralizedDiceLoss(nn.Module): - """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf + """ + Computes Generalized Dice Loss (GDL) as described in + https://arxiv.org/pdf/1707.03237.pdf """ - def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True): + def __init__( + self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True + ): super(GeneralizedDiceLoss, self).__init__() self.epsilon = epsilon - self.register_buffer('weight', weight) + self.register_buffer("weight", weight) self.ignore_index = ignore_index if sigmoid_normalization: self.normalization = nn.Sigmoid() @@ -160,7 +450,9 @@ def forward(self, input, target): # get probabilities from logits input = self.normalization(input) - assert input.size() == target.size(), "'input' and 'target' must have the same shape" + assert ( + input.size() == target.size() + ), "'input' and 'target' must have the same shape" # mask ignore_index if present if self.ignore_index is not None: @@ -175,7 +467,9 @@ def forward(self, input, target): target = target.float() target_sum = target.sum(-1) - class_weights = Variable(1. / (target_sum * target_sum).clamp(min=self.epsilon), requires_grad=False) + class_weights = Variable( + 1.0 / (target_sum * target_sum).clamp(min=self.epsilon), requires_grad=False + ) intersect = (input * target).sum(-1) * class_weights if self.weight is not None: @@ -184,16 +478,18 @@ def forward(self, input, target): denominator = (input + target).sum(-1) * class_weights - return torch.mean(1. - 2. * intersect / denominator.clamp(min=self.epsilon)) + return torch.mean(1.0 - 2.0 * intersect / denominator.clamp(min=self.epsilon)) class WeightedCrossEntropyLoss(nn.Module): - """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf + """ + WeightedCrossEntropyLoss (WCE) as described + in https://arxiv.org/pdf/1707.03237.pdf """ def __init__(self, weight=None, ignore_index=-1): super(WeightedCrossEntropyLoss, self).__init__() - self.register_buffer('weight', weight) + self.register_buffer("weight", weight) self.ignore_index = ignore_index def forward(self, input, target): @@ -201,14 +497,16 @@ def forward(self, input, target): if self.weight is not None: weight = Variable(self.weight, requires_grad=False) class_weights = class_weights * weight - return F.cross_entropy(input, target, weight=class_weights, ignore_index=self.ignore_index) + return F.cross_entropy( + input, target, weight=class_weights, ignore_index=self.ignore_index + ) @staticmethod def _class_weights(input): # normalize the input first input = F.softmax(input, _stacklevel=5) flattened = flatten(input) - nominator = (1. - flattened).sum(-1) + nominator = (1.0 - flattened).sum(-1) denominator = flattened.sum(-1) class_weights = Variable(nominator / denominator, requires_grad=False) return class_weights @@ -216,12 +514,16 @@ def _class_weights(input): class BCELossWrapper: """ - Wrapper around BCE loss functions allowing to pass 'ignore_index' as well as 'skip_last_target' option. + Wrapper around BCE loss functions allowing to pass 'ignore_index' + as well as 'skip_last_target' option. """ def __init__(self, loss_criterion, ignore_index=-1, skip_last_target=False): - if hasattr(loss_criterion, 'ignore_index'): - raise RuntimeError(f"Cannot wrap {type(loss_criterion)}. Use 'ignore_index' attribute instead") + if hasattr(loss_criterion, "ignore_index"): + raise RuntimeError( + f"Cannot wrap {type(loss_criterion)}. " + "Use 'ignore_index' attribute instead" + ) self.loss_criterion = loss_criterion self.ignore_index = ignore_index self.skip_last_target = skip_last_target @@ -247,7 +549,7 @@ def __call__(self, input, target): class PixelWiseCrossEntropyLoss(nn.Module): def __init__(self, class_weights=None, ignore_index=None): super(PixelWiseCrossEntropyLoss, self).__init__() - self.register_buffer('class_weights', class_weights) + self.register_buffer("class_weights", class_weights) self.ignore_index = ignore_index self.log_softmax = nn.LogSoftmax(dim=1) @@ -255,15 +557,21 @@ def forward(self, input, target, weights): assert target.size() == weights.size() # normalize the input log_probabilities = self.log_softmax(input) - # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW) - target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index) + + # standard CrossEntropyLoss requires the target to be (NxDxHxW), + # so we need to expand it to (NxCxDxHxW) + target = expand_as_one_hot( + target[:, 0, :, :, :], C=input.size()[1], ignore_index=self.ignore_index + ) # expand weights - weights = weights.unsqueeze(0) + # weights = weights.unsqueeze(0) weights = weights.expand_as(input) # mask ignore_index if present if self.ignore_index is not None: - mask = Variable(target.data.ne(self.ignore_index).float(), requires_grad=False) + mask = Variable( + target.data.ne(self.ignore_index).float(), requires_grad=False + ) log_probabilities = log_probabilities * mask target = target * mask @@ -276,7 +584,6 @@ def forward(self, input, target, weights): class_weights = Variable(class_weights, requires_grad=False) # add class_weights to each channel weights = class_weights + weights - # compute the losses result = -weights * target * log_probabilities # average the losses @@ -299,7 +606,8 @@ def flatten(tensor): def expand_as_one_hot(input, C, ignore_index=None): """ - Converts NxDxHxW label image to NxCxDxHxW, where each label is stored in a separate channel + Converts NxDxHxW label image to NxCxDxHxW, where each label is stored in + a separate channel :param input: 4D input image (NxDxHxW) :param C: number of channels/labels :param ignore_index: ignore index to be kept during the expansion @@ -329,4 +637,5 @@ def expand_as_one_hot(input, C, ignore_index=None): return result else: # scatter to get the one-hot tensor - return torch.zeros(shape).to(input.device).scatter_(1, src, 1) \ No newline at end of file + src = src.type(torch.int64) + return torch.zeros(shape).to(input.device).scatter_(1, src, 1) diff --git a/aicsmlsegment/custom_metrics.py b/aicsmlsegment/custom_metrics.py index 55191ef..10026f6 100644 --- a/aicsmlsegment/custom_metrics.py +++ b/aicsmlsegment/custom_metrics.py @@ -1,15 +1,59 @@ import numpy as np import torch from skimage import measure -from aicsmlsegment.custom_loss import MultiAuxillaryElementNLLLoss, compute_per_channel_dice, expand_as_one_hot +from aicsmlsegment.custom_loss import ( + compute_per_channel_dice, + expand_as_one_hot, +) + +SUPPORTED_METRICS = [ + "default", + "Dice", +] + + +def get_metric(config): + """ + Returns the metric function based on provided configuration + + Parameters + ---------- + config: Dict + a top level configuration object containing the 'validation' key + + Return: + ------------- + an instance of the validation metric function + """ + validation_config = config["validation"] + metric = validation_config["metric"] + + # validate the name of selected metric + assert ( + metric in SUPPORTED_METRICS + ), f"Invalid metric: {metric}. Supported metrics are: {SUPPORTED_METRICS}" + + if metric == "Dice": + from monai.metrics import DiceMetric + + return DiceMetric + elif metric == "default" or metric == "IOU": + return MeanIoU() + elif metric == "AveragePrecision": + return AveragePrecision() + class DiceCoefficient: """Computes Dice Coefficient. Generalized to multiple channels by computing per-channel Dice Score - (as described in https://arxiv.org/pdf/1707.03237.pdf) and theTn simply taking the average. + (as described in https://arxiv.org/pdf/1707.03237.pdf) and theTn simply taking the + average. + Input is expected to be probabilities instead of logits. - This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets). - DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss. + This metric is mostly useful when channels contain the same semantic class (e.g. + affinities computed with different offsets). + DO NOT USE this metric when training with DiceLoss, otherwise the results will be + biased towards the loss. """ def __init__(self, epsilon=1e-5, ignore_index=None): @@ -19,11 +63,16 @@ def __init__(self, epsilon=1e-5, ignore_index=None): def __call__(self, input, target): """ :param input: 5D probability maps torch tensor (NxCxDxHxW) - :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot + :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be + expanded to 5D as one-hot :return: Soft Dice Coefficient averaged over all channels/classes """ # Average across channels in order to get the final score - return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon, ignore_index=self.ignore_index)) + return torch.mean( + compute_per_channel_dice( + input, target, epsilon=self.epsilon, ignore_index=self.ignore_index + ) + ) class MeanIoU: @@ -33,7 +82,8 @@ class MeanIoU: def __init__(self, skip_channels=(), ignore_index=None): """ - :param skip_channels: list/tuple of channels to be ignored from the IoU computation + :param skip_channels: list/tuple of channels to be ignored from the IoU + computation :param ignore_index: id of the label to be ignored from IoU computation """ self.ignore_index = ignore_index @@ -42,12 +92,15 @@ def __init__(self, skip_channels=(), ignore_index=None): def __call__(self, input, target): """ :param input: 5D probability maps torch float tensor (NxCxDxHxW) - :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot + :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be + expanded to 5D as one-hot :return: intersection over union averaged over all channels """ n_classes = input.size()[1] if target.dim() == 4: - target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index) + target = expand_as_one_hot( + target, C=n_classes, ignore_index=self.ignore_index + ) # batch dim must be 1 input = input[0] @@ -78,8 +131,8 @@ def __call__(self, input, target): def _binarize_predictions(self, input): """ - Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the - same size as the input tensor. + Puts 1 for the class/channel with the highest probability and 0 in other + channels. Returns byte tensor of the same size as the input tensor. """ _, max_index = torch.max(input, dim=0, keepdim=True) return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1) @@ -88,21 +141,34 @@ def _jaccard_index(self, prediction, target): """ Computes IoU for a given target and prediction tensors """ - return torch.sum(prediction & target).float() / torch.sum(prediction | target).float() + return ( + torch.sum(prediction & target).float() + / torch.sum(prediction | target).float() + ) class AveragePrecision: """ - Computes Average Precision given boundary prediction and ground truth instance segmentation. + Computes Average Precision given boundary prediction and ground truth instance + segmentation. """ - def __init__(self, threshold=0.4, iou_range=(0.5, 1.0), ignore_index=-1, min_instance_size=None, - use_last_target=False): + def __init__( + self, + threshold=0.4, + iou_range=(0.5, 1.0), + ignore_index=-1, + min_instance_size=None, + use_last_target=False, + ): """ - :param threshold: probability value at which the input is going to be thresholded - :param iou_range: compute ROC curve for the the range of IoU values: range(min,max,0.05) + :param threshold: probability value at which the input is going to be + thresholded + :param iou_range: compute ROC curve for the the range of IoU values: + range(min,max,0.05) :param ignore_index: label to be ignored during computation - :param min_instance_size: minimum size of the predicted instances to be considered + :param min_instance_size: minimum size of the predicted instances to be + considered :param use_last_target: if True use the last target channel to compute AP """ self.threshold = threshold @@ -116,8 +182,10 @@ def __init__(self, threshold=0.4, iou_range=(0.5, 1.0), ignore_index=-1, min_ins def __call__(self, input, target): """ - :param input: 5D probability maps torch float tensor (NxCxDxHxW) / or 4D numpy.ndarray - :param target: 4D or 5D ground truth instance segmentation torch long tensor / or 3D numpy.ndarray + :param input: 5D probability maps torch float tensor (NxCxDxHxW) / or + 4D numpy.ndarray + :param target: 4D or 5D ground truth instance segmentation torch long tensor / + or 3D numpy.ndarray :return: highest average precision among channels """ if isinstance(input, torch.Tensor): @@ -139,7 +207,8 @@ def __call__(self, input, target): if isinstance(target, np.ndarray): assert target.ndim == 3 - # filter small instances from the target and get ground truth label set (without 'ignore_index') + # filter small instances from the target and get ground truth label set + # (without 'ignore_index') target, target_instances = self._filter_instances(target) per_channel_ap = [] @@ -148,17 +217,18 @@ def __call__(self, input, target): predictions = input[c] # threshold probability maps predictions = predictions > self.threshold - # for connected component analysis we need to treat boundary signal as background + # for connected component analysis we need to treat boundary signal as + # background # assign 0-label to boundary mask predictions = np.logical_not(predictions).astype(np.uint8) - # run connected components on the predicted mask; consider only 1-connectivity + # run connected components on the predicted mask; consider only + # 1-connectivity predicted = measure.label(predictions, background=0, connectivity=1) ap = self._calculate_average_precision(predicted, target, target_instances) per_channel_ap.append(ap) # get maximum average precision across channels - max_ap, c_index = np.max(per_channel_ap), np.argmax(per_channel_ap) - #LOGGER.info(f'Max average precision: {max_ap}, channel: {c_index}') + max_ap, _ = np.max(per_channel_ap), np.argmax(per_channel_ap) return max_ap def _calculate_average_precision(self, predicted, target, target_instances): @@ -172,17 +242,19 @@ def _calculate_average_precision(self, predicted, target, target_instances): # see: https://www.jeremyjordan.me/evaluating-image-segmentation-models/ e.g. for i in range(len(precision) - 2, -1, -1): precision[i] = max(precision[i], precision[i + 1]) - # compute the area under precision recall curve by simple integration of piece-wise constant function + # compute the area under precision recall curve by simple integration of + # piece-wise constant function ap = 0.0 for i in range(1, len(recall)): - ap += ((recall[i] - recall[i - 1]) * precision[i]) + ap += (recall[i] - recall[i - 1]) * precision[i] return ap def _roc_curve(self, predicted, target, target_instances): ROC = [] predicted, predicted_instances = self._filter_instances(predicted) - # compute precision/recall curve points for various IoU values from a given range + # compute precision/recall curve points for various IoU values from a given + # range for min_iou in np.arange(self.iou_range[0], self.iou_range[1], 0.1): # initialize false negatives set false_negatives = set(target_instances) @@ -192,7 +264,9 @@ def _roc_curve(self, predicted, target, target_instances): true_positives = set() for pred_label in predicted_instances: - target_label = self._find_overlapping_target(pred_label, predicted, target, min_iou) + target_label = self._find_overlapping_target( + pred_label, predicted, target, min_iou + ) if target_label is not None: # update TP, FP and FN if target_label == self.ignore_index: @@ -218,8 +292,8 @@ def _roc_curve(self, predicted, target, target_instances): def _find_overlapping_target(self, predicted_label, predicted, target, min_iou): """ - Return ground truth label which overlaps by at least 'min_iou' with a given input label 'p_label' - or None if such ground truth label does not exist. + Return ground truth label which overlaps by at least 'min_iou' with a given + input label 'p_label' or None if such ground truth label does not exist. """ mask_predicted = predicted == predicted_label overlapping_labels = target[mask_predicted] @@ -227,8 +301,8 @@ def _find_overlapping_target(self, predicted_label, predicted, target, min_iou): # retrieve the biggest overlapping label target_label_ind = np.argmax(counts) target_label = labels[target_label_ind] - # return target label if IoU greater than 'min_iou'; since we're starting from 0.5 IoU there might be - # only one target label that fulfill this criterion + # return target label if IoU greater than 'min_iou'; since we're starting from + # 0.5 IoU there might be only one target label that fulfill this criterion mask_target = target == target_label # return target_label if IoU > min_iou if self._iou(mask_predicted, mask_target) > min_iou: @@ -246,9 +320,11 @@ def _iou(prediction, target): def _filter_instances(self, input): """ - Filters instances smaller than 'min_instance_size' by overriding them with 'ignore_index' + Filters instances smaller than 'min_instance_size' by overriding them with + 'ignore_index' :param input: input instance segmentation - :return: tuple: (instance segmentation with small instances filtered, set of unique labels without the 'ignore_index') + :return: tuple: (instance segmentation with small instances filtered, set of + unique labels without the 'ignore_index') """ if self.min_instance_size is not None: labels, counts = np.unique(input, return_counts=True) @@ -259,4 +335,4 @@ def _filter_instances(self, input): labels = set(np.unique(input)) labels.discard(self.ignore_index) - return input, labels \ No newline at end of file + return input, labels diff --git a/aicsmlsegment/fnet_prediction_torch.py b/aicsmlsegment/fnet_prediction_torch.py new file mode 100644 index 0000000..0bbcb9d --- /dev/null +++ b/aicsmlsegment/fnet_prediction_torch.py @@ -0,0 +1,145 @@ +from scipy.signal import triang +from typing import Union, List +import numpy as np +import torch +from typing import Sequence, Tuple + + +def _get_weights(shape: Sequence[int]) -> Tuple[np.ndarray, Tuple[int]]: + """ + Get triangular weights + + Parameters + ---------- + shape: CZYX shape + + Return: 1ZYX weights np.ndarray, CZYX shape + """ + shape_in = shape + shape = shape[1:] + weights = 1 + for idx_d in range(len(shape)): + slicey = [np.newaxis] * len(shape) + slicey[idx_d] = slice(None) + size = shape[idx_d] + weights = weights * triang(size)[tuple(slicey)] + return weights, shape_in + + +def _predict_piecewise_recurse( + predictor, + ar_in: np.ndarray, + dims_max: Union[int, List[int]], + overlaps: Union[int, List[int]], + mode: str = "fast", + **predict_kwargs, +): + """Performs piecewise prediction recursively.""" + if tuple(ar_in.shape[1:]) == tuple(dims_max[1:]): + ar_in = torch.unsqueeze(ar_in, dim=0) + ar_out = predictor.forward(ar_in, **predict_kwargs) + if isinstance(ar_out, list): + ar_out = ar_out[0] + ar_out = torch.squeeze( + ar_out, dim=0 + ) # remove N dimension so that multichannel outputs can be used + if mode != "fast": + ar_out = ar_out.detach().cpu() + weights, shape_in = _get_weights(ar_out.shape) + weights = torch.as_tensor(weights, dtype=ar_out.dtype, device=ar_out.device) + ar_weight = torch.broadcast_to(weights, shape_in) + return ar_out * ar_weight, weights + dim = None + # Find first dim where input > max + for idx_d in range(1, ar_in.ndim): + if ar_in.shape[idx_d] > dims_max[idx_d]: + dim = idx_d + break + # Size of channel dim is unknown until after first prediction + shape_out = [None] + list(ar_in.shape[1:]) + ar_out = None + ar_weight = None + offset = 0 + done = False + while not done: + slices = [slice(None)] * len(ar_in.shape) + end = offset + dims_max[dim] + slices[dim] = slice(offset, end) + slices = tuple(slices) + ar_in_sub = ar_in[slices] + pred_sub, pred_weight_sub = _predict_piecewise_recurse( + predictor, ar_in_sub, dims_max, overlaps, mode=mode, **predict_kwargs + ) + if ar_out is None or ar_weight is None: + shape_out[0] = pred_sub.shape[0] # Set channel dim for output + ar_out = torch.zeros( + shape_out, dtype=pred_sub.dtype, device=pred_sub.device + ) + ar_weight = torch.zeros( + shape_out[1:], dtype=pred_weight_sub.dtype, device=pred_sub.device + ) + ar_out[slices] += pred_sub + ar_weight[slices[1:]] += pred_weight_sub + offset += dims_max[dim] - overlaps[dim] + if end == ar_in.shape[dim]: + done = True + elif offset + dims_max[dim] > ar_in.shape[dim]: + offset = ar_in.shape[dim] - dims_max[dim] + return ar_out, ar_weight + + +def predict_piecewise( + predictor, + tensor_in: torch.Tensor, + dims_max: Union[int, List[int]] = 64, + overlaps: Union[int, List[int]] = 0, + mode: str = "fast", + **predict_kwargs, +) -> torch.Tensor: + """Performs piecewise prediction and combines results. + Parameters + ---------- + predictor + An object with a predict() method. + tensor_in + Tensor to be input into predictor piecewise. Should be 3d or 4d with + with the first dimension channel. + dims_max + Specifies dimensions of each sub prediction. + overlaps + Specifies overlap along each dimension for sub predictions. + **predict_kwargs + Kwargs to pass to predict method. + Returns + ------- + torch.Tensor + Prediction with size tensor_in.size(). + """ + assert isinstance(tensor_in, torch.Tensor) + assert len(tensor_in.size()) > 2 + shape_in = tuple(tensor_in.size()) + n_dim = len(shape_in) + if isinstance(dims_max, int): + dims_max = [dims_max] * n_dim + for idx_d in range(1, n_dim): + if dims_max[idx_d] > shape_in[idx_d]: + dims_max[idx_d] = shape_in[idx_d] + if isinstance(overlaps, int): + overlaps = [overlaps] * n_dim + assert len(dims_max) == len(overlaps) == n_dim + # Remove restrictions on channel dimension. + dims_max[0] = None + overlaps[0] = None + ar_out, ar_weight = _predict_piecewise_recurse( + predictor, + tensor_in, + dims_max=dims_max, + overlaps=overlaps, + mode=mode, + **predict_kwargs, + ) + + weight_corrected = torch.unsqueeze(ar_out / ar_weight, dim=0) + if mode != "fast": + weight_corrected = weight_corrected.float() + return weight_corrected diff --git a/aicsmlsegment/model_utils.py b/aicsmlsegment/model_utils.py index 87c0a4b..8802bac 100644 --- a/aicsmlsegment/model_utils.py +++ b/aicsmlsegment/model_utils.py @@ -1,183 +1,244 @@ import numpy as np import torch -from torch.autograd import Variable -import time -import logging -import os -import shutil -import sys +from pathlib import Path, PurePosixPath +from aicsmlsegment.multichannel_sliding_window import sliding_window_inference +from aicsmlsegment.fnet_prediction_torch import predict_piecewise -from aicsmlsegment.utils import get_logger -SUPPORTED_MODELS = ['unet_xy_zoom', 'unet_xy'] - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Conv3d') != -1: - torch.nn.init.kaiming_normal_(m.weight) - m.bias.data.zero_() - -def apply_on_image(model, input_img, softmax, args): +def flip(img: np.ndarray, axis: int) -> torch.Tensor: + """ + Inputs: + img: image to be flipped + axis: axis along which to flip image. Should be indexed from the channel + dimension + Outputs: + (1,C,Z,Y,X)-shaped tensor + flip input img along axis + """ + out_img = img.detach().clone() + for ch_idx in range(out_img.shape[0]): + str_im = out_img[ch_idx, :, :, :] + out_img[ch_idx, :, :, :] = torch.flip(str_im, dims=[axis]) + + return torch.unsqueeze(out_img, dim=0) + + +def apply_on_image( + model, + input_img: torch.Tensor, + args: dict, + squeeze: bool, + to_numpy: bool, + softmax: bool, + model_name, + extract_output_ch: bool, +) -> np.ndarray: + """ + Highest level API to perform inference on an input image through a model with + or without runtime augmentation. If runtime augmentation is selected (via + "RuntimeAug" in config yaml file), perform inference on both original image + and flipped images (3 version flipping along X, Y, Z) and average results. + + Inputs: + model: pytorch model with a forward method + model_name: the name of the model + input_img: tensor that model should be run on + args: Object containing inference arguments + RuntimeAug: boolean, if True inference is run on each of 4 flips + and final output is averaged across each of these augmentations + SizeOut: size of sliding window for inference + OutputCh: channel to extract label from + squeeze: boolean, if true removes the batch dimension in the output image + to_numpy: boolean, if true converts output to a numpy array and send to cpu + + Returns: 4 or 5 dimensional numpy array or tensor with result of model.forward + on input_img + """ - if not args.RuntimeAug: - return model_inference(model, input_img, softmax, args) + if not args["RuntimeAug"]: + return model_inference( + model, + input_img, + args, + model_name=model_name, + squeeze=squeeze, + to_numpy=to_numpy, + extract_output_ch=extract_output_ch, + softmax=softmax, + ) else: - from PIL import Image - print('doing runtime augmentation') - - input_img_aug = input_img.copy() - for ch_idx in range(input_img_aug.shape[0]): - str_im = input_img_aug[ch_idx,:,:,:] - input_img_aug[ch_idx,:,:,:] = np.flip(str_im, axis=2) - - out1 = model_inference(model, input_img_aug, softmax, args) - - input_img_aug = [] - input_img_aug = input_img.copy() - for ch_idx in range(input_img_aug.shape[0]): - str_im = input_img_aug[ch_idx,:,:,:] - input_img_aug[ch_idx,:,:,:] = np.flip(str_im, axis=1) - - out2 = model_inference(model, input_img_aug, softmax, args) - - input_img_aug = [] - input_img_aug = input_img.copy() - for ch_idx in range(input_img_aug.shape[0]): - str_im = input_img_aug[ch_idx,:,:,:] - input_img_aug[ch_idx,:,:,:] = np.flip(str_im, axis=0) - - out3 = model_inference(model, input_img_aug, softmax, args) - - out0 = model_inference(model, input_img, softmax, args) - - for ch_idx in range(len(out0)): - out0[ch_idx] = 0.25*(out0[ch_idx] + np.flip(out1[ch_idx], axis=3) + np.flip(out2[ch_idx], axis=2) + np.flip(out3[ch_idx], axis=1)) - - return out0 - -def model_inference(model, input_img, softmax, args): - - model.eval() - - if args.size_in == args.size_out: - img_pad = np.np.expand_dims(input_img, axis=0) # add batch dimension - else: # zero padding on input image - padding = [(x-y)//2 for x,y in zip(args.size_in, args.size_out)] - img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'symmetric')#'constant') - img_pad = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') - - output_img = [] - for ch_idx in range(len(args.OutputCh)//2): - output_img.append(np.zeros(input_img.shape)) - - # loop through the image patch by patch - num_step_z = int(np.floor(input_img.shape[1]/args.size_out[0])+1) - num_step_y = int(np.floor(input_img.shape[2]/args.size_out[1])+1) - num_step_x = int(np.floor(input_img.shape[3]/args.size_out[2])+1) - - with torch.no_grad(): - for ix in range(num_step_x): - if ix= 2: + args["OutputCh"] = args["OutputCh"][1] + result = result[:, args["OutputCh"], :, :, :] + if not squeeze: + result = torch.unsqueeze(result, dim=1) + if to_numpy: + result = result.detach().cpu().numpy() + return result, vae_loss - def log_info(message): - if logger is not None: - logger.info(message) - if not os.path.exists(checkpoint_dir): - log_info( - f"Checkpoint directory does not exists. Creating {checkpoint_dir}") - os.mkdir(checkpoint_dir) - - file_path = checkpoint_dir + os.sep + 'checkpoint_epoch_' + str(state['epoch']) + '.pytorch' - log_info(f"Saving checkpoint at epoch {state['epoch']} to '{file_path}'") - torch.save(state, file_path) +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv3d") != -1: + torch.nn.init.kaiming_normal_(m.weight) + m.bias.data.zero_() def load_checkpoint(checkpoint_path, model): - """Loads model from a given checkpoint_path + """Loads model from a given checkpoint_path, included for backwards compatibility Args: checkpoint_path (string): path to the checkpoint to be loaded model (torch.nn.Module): model into which the parameters are to be copied Returns: state """ + import os + if not os.path.exists(checkpoint_path): raise IOError(f"Checkpoint '{checkpoint_path}' does not exist") - #device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - if 'model_state_dict' in state: - model.load_state_dict(state['model_state_dict']) + # device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + + if "model_state_dict" in state: + try: + model.load_state_dict(state["model_state_dict"]) + except RuntimeError: + # HACK all keys need "model." appended to them sometimes + new_state_dict = {} + for key in state["model_state_dict"]: + new_state_dict["model." + key] = state["model_state_dict"][key] + model.load_state_dict(new_state_dict) else: model.load_state_dict(state) - #TODO: add an option to load training status + # TODO: add an option to load training status return state - - -def get_number_of_learnable_parameters(model): - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) - return sum([np.prod(p.size()) for p in model_parameters]) - -def build_model(config): - - assert 'model' in config, 'Could not find model configuration' - model_config = config['model'] - name = model_config['name'] - assert name in SUPPORTED_MODELS, f'Invalid model: {name}. Supported models: {SUPPORTED_MODELS}' - - if name == 'unet_xy': - from aicsmlsegment.Net3D.unet_xy import UNet3D as DNN - model = DNN(config['nchannel'], config['nclass']) - elif name =='unet_xy_zoom': - from aicsmlsegment.Net3D.unet_xy_enlarge import UNet3D as DNN - model = DNN(config['nchannel'], config['nclass'], model_config.get('zoom_ratio',3)) - - model = model.apply(weights_init) - print('model initialization succeeds !') - model = model.to(config['device']) - return model \ No newline at end of file diff --git a/aicsmlsegment/multichannel_sliding_window.py b/aicsmlsegment/multichannel_sliding_window.py new file mode 100644 index 0000000..95c8ca6 --- /dev/null +++ b/aicsmlsegment/multichannel_sliding_window.py @@ -0,0 +1,192 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, List, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F + +from monai.data.utils import ( + compute_importance_map, + dense_patch_slices, + get_valid_patch_size, +) +from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple + +__all__ = ["sliding_window_inference"] + + +def sliding_window_inference( + inputs: torch.Tensor, + roi_size: Union[Sequence[int], int], + out_size, + original_image_size, + model_name, + sw_batch_size: int, + predictor: Callable[..., torch.Tensor], + overlap: float = 0.25, + mode: Union[BlendMode, str] = BlendMode.CONSTANT, + sigma_scale: Union[Sequence[float], float] = 0.125, + padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + cval: float = 0.0, + sw_device: Union[torch.device, str, None] = None, + device: Union[torch.device, str, None] = None, + *args: Any, + **kwargs: Any, +) -> torch.Tensor: + """ + modified from https://docs.monai.io/en/latest/_modules/monai/inferers/utils.html#sliding_window_inference # noqa E501 + + """ + num_spatial_dims = len(inputs.shape) - 2 + if overlap < 0 or overlap >= 1: + raise AssertionError("overlap must be >= 0 and < 1.") + + # determine image spatial size and batch size + # Note: all input images must have the same image size and batch size + image_size_ = list(inputs.shape[2:]) + batch_size = inputs.shape[0] + + if device is None: + device = inputs.device + if sw_device is None: + sw_device = inputs.device + + roi_size = fall_back_tuple(roi_size, image_size_) + # in case that image size is smaller than roi size + image_size = tuple( + max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims) + ) + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + inputs = F.pad( + inputs, pad=pad_size, mode=PytorchPadMode(padding_mode).value, value=cval + ) + + # CHANGED + scan_interval = _get_scan_interval( + original_image_size, out_size, num_spatial_dims, overlap + ) + + # Store all slices in list + slices = dense_patch_slices(image_size, roi_size, scan_interval) + num_win = len(slices) # number of windows per image + total_slices = num_win * batch_size # total number of window + ########################### + # EDIT # + ############################# + + out_slices = dense_patch_slices(original_image_size, out_size, scan_interval) + # Create window-level importance map + importance_map = compute_importance_map( + get_valid_patch_size(original_image_size, out_size), + mode=mode, + sigma_scale=sigma_scale, + device=device, + ) + + # Perform predictions + output_image, count_map = torch.tensor(0.0, device=device), torch.tensor( + 0.0, device=device + ) + _initialized = False + vae_loss = 0 + for slice_g in range(0, total_slices, sw_batch_size): + slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) + + # coordinates of patch in input image + unravel_slice = [ + [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + + list(slices[idx % num_win]) + for idx in slice_range + ] + + # coordinates of patch in output image + unravel_slice_out = [ + [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + + list(out_slices[idx % num_win]) + for idx in slice_range + ] + + window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to( + sw_device + ) + seg_prob = predictor(window_data, *args, **kwargs) + + # old models output a list of three predictions + if "unet_xy" in model_name and isinstance(seg_prob, list): + seg_prob = seg_prob[0] + elif model_name == "dynunet": + seg_prob = seg_prob[0] + elif model_name == "segresnetvae": # segresnetvae + seg_prob, loss = seg_prob + if loss: + vae_loss += loss + + seg_prob = seg_prob.to(device) # batched patch segmentation + if not _initialized: # init. buffer at the first iteration + output_classes = seg_prob.shape[1] + output_shape = [batch_size, output_classes] + list(original_image_size) + # allocate memory to store the full output and the count for + # overlapping parts + output_image = torch.zeros(output_shape, dtype=torch.float32, device=device) + count_map = torch.zeros(output_shape, dtype=torch.float32, device=device) + _initialized = True + + # store the result in the proper location of the full output. Apply weights + # from importance map. + for idx, original_idx in zip(slice_range, unravel_slice_out): + output_image[original_idx] += importance_map * seg_prob[idx - slice_g] + count_map[original_idx] += importance_map + + # account for any overlapping sections + output_image = output_image / count_map + + final_slicing: List[slice] = [] + for sp in range(num_spatial_dims): + slice_dim = slice( + pad_size[sp * 2], + original_image_size[num_spatial_dims - sp - 1] + pad_size[sp * 2], + ) + final_slicing.insert(0, slice_dim) + while len(final_slicing) < len(output_image.shape): + final_slicing.insert(0, slice(None)) + + return output_image[final_slicing], vae_loss + + +def _get_scan_interval( + image_size: Sequence[int], + roi_size: Sequence[int], + num_spatial_dims: int, + overlap: float, +) -> Tuple[int, ...]: + """ + Compute scan interval according to the image size, roi size and overlap. + Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0, + use 1 instead to make sure sliding window works. + + """ + if len(image_size) != num_spatial_dims: + raise ValueError("image coord different from spatial dims.") + if len(roi_size) != num_spatial_dims: + raise ValueError("roi coord different from spatial dims.") + + scan_interval = [] + for i in range(num_spatial_dims): + if roi_size[i] == image_size[i]: + scan_interval.append(int(roi_size[i])) + else: + interval = int(roi_size[i] * (1 - overlap)) + scan_interval.append(interval if interval > 0 else 1) + return tuple(scan_interval) diff --git a/aicsmlsegment/tests/__init__.py b/aicsmlsegment/tests/__init__.py index e69de29..12dd4f2 100644 --- a/aicsmlsegment/tests/__init__.py +++ b/aicsmlsegment/tests/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +"""Unit test package for segmenter_model_zoo.""" diff --git a/aicsmlsegment/tests/data/example_values.json b/aicsmlsegment/tests/data/example_values.json new file mode 100644 index 0000000..a95357b --- /dev/null +++ b/aicsmlsegment/tests/data/example_values.json @@ -0,0 +1,4 @@ +{ + "start_val": 1, + "next_val": 2 +} diff --git a/aicsmlsegment/tests/dummy_test.py b/aicsmlsegment/tests/dummy_test.py deleted file mode 100644 index e721873..0000000 --- a/aicsmlsegment/tests/dummy_test.py +++ /dev/null @@ -1,7 +0,0 @@ -from time import sleep -from datetime import timedelta - - -def test_dummy(): - assert(True) - diff --git a/aicsmlsegment/tests/test_function.py b/aicsmlsegment/tests/test_function.py new file mode 100644 index 0000000..71d37e1 --- /dev/null +++ b/aicsmlsegment/tests/test_function.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +A simple example of a test file using a function. +NOTE: All test file names must have one of the two forms. +- `test_.py` +- '_test.py' + +Docs: https://docs.pytest.org/en/latest/ + https://docs.pytest.org/en/latest/goodpractices.html#conventions-for-python-test-discovery +""" +import numpy as np + + +# This test just checks to see if the raw step instantiates and runs +def test_dummy(n=3): + arr = np.ones((3, 3), dtype=np.uint8) + assert arr.shape[0] == n diff --git a/aicsmlsegment/training_utils.py b/aicsmlsegment/training_utils.py index f4dfb4e..105ca0e 100644 --- a/aicsmlsegment/training_utils.py +++ b/aicsmlsegment/training_utils.py @@ -1,317 +1,33 @@ -#import torch.nn.functional as F -#from torch import nn as nn -#from torch.autograd import Variable -import logging -import os - -import numpy as np -import torch -from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.data import DataLoader -from torch.autograd import Variable -import torch.optim as optim -import importlib -import random -from glob import glob -from tqdm import tqdm - -from aicsimageio import imread - -from aicsmlsegment.custom_loss import MultiAuxillaryElementNLLLoss -from aicsmlsegment.custom_metrics import DiceCoefficient, MeanIoU, AveragePrecision -from aicsmlsegment.model_utils import load_checkpoint, save_checkpoint, model_inference -from aicsmlsegment.utils import compute_iou, get_logger, load_single_image, input_normalization - -SUPPORTED_LOSSES = ['Aux'] - -def build_optimizer(config, model): - learning_rate = config['learning_rate'] - weight_decay = config['weight_decay'] - optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) - return optimizer - -def get_loss_criterion(config): - """ - Returns the loss function based on provided configuration - :param config: (dict) a top level configuration object containing the 'loss' key - :return: an instance of the loss function - """ - assert 'loss' in config, 'Could not find loss function configuration' - loss_config = config['loss'] - name = loss_config['name'] - assert name in SUPPORTED_LOSSES, f'Invalid loss: {name}. Supported losses: {SUPPORTED_LOSSES}' - - #ignore_index = loss_config.get('ignore_index', None) - - #TODO: add more loss functions - if name == 'Aux': - return MultiAuxillaryElementNLLLoss(3, loss_config['loss_weight'], config['nclass']) - - -def get_train_dataloader(config): - assert 'loader' in config, 'Could not loader configuration' - name = config['loader']['name'] - if name == 'default': - from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader - return train_loader - else: - print('other loaders are under construction') - quit() - -def shuffle_split_filenames(datafolder, leaveout): - print('prepare the data ... ...') - filenames = glob(datafolder + '/*_GT.ome.tif') - filenames.sort() - total_num = len(filenames) - if len(leaveout)==1: - if leaveout[0]>0 and leaveout[0]<1: - num_train = int(np.floor((1-leaveout[0]) * total_num)) - shuffled_idx = np.arange(total_num) - random.shuffle(shuffled_idx) - train_idx = shuffled_idx[:num_train] - valid_idx = shuffled_idx[num_train:] - else: - valid_idx = [int(leaveout[0])] - train_idx = list(set(range(total_num)) - set(map(int, leaveout))) - elif leaveout: - valid_idx = list(map(int, leaveout)) - train_idx = list(set(range(total_num)) - set(valid_idx)) - - valid_filenames = [] - train_filenames = [] - for _, fn in enumerate(valid_idx): - valid_filenames.append(filenames[fn][:-11]) - for _, fn in enumerate(train_idx): - train_filenames.append(filenames[fn][:-11]) - - return train_filenames, valid_filenames - -class BasicFolderTrainer: - """basic version of trainer. - Args: - model: model to be trained - optimizer (nn.optim.Optimizer): optimizer used for training - loss_criterion (callable): loss function - loaders (dict): 'train' and 'val' loaders - checkpoint_dir (string): dir for saving checkpoints and tensorboard logs - """ - - def __init__(self, model, config, logger=None): - - if logger is None: - self.logger = get_logger('ModelTrainer', level=logging.DEBUG) - else: - self.logger = logger - - device = config['device'] - self.logger.info(f"Sending the model to '{device}'") - self.model = model.to(device) - self.logger.debug(model) - - #self.optimizer = optimizer - #self.scheduler = lr_scheduler - #self.loss_criterion = loss_criterion - self.device = device - #self.loaders = loaders - self.config = config - - - def train(self): - - ### load settings ### - config = self.config #TODO, fix this - model = self.model - - # define loss - #TODO, add more loss - loss_config = config['loss'] - if loss_config['name']=='Aux': - criterion = MultiAuxillaryElementNLLLoss(3,loss_config['loss_weight'], config['nclass']) - else: - print('do not support other loss yet') - quit() - - # dataloader - validation_config = config['validation'] - loader_config = config['loader'] - args_inference=lambda:None - if validation_config['metric'] is not None: - print('prepare the data ... ...') - filenames = glob(loader_config['datafolder'] + '/*_GT.ome.tif') - filenames.sort() - total_num = len(filenames) - LeaveOut = validation_config['leaveout'] - if len(LeaveOut)==1: - if LeaveOut[0]>0 and LeaveOut[0]<1: - num_train = int(np.floor((1-LeaveOut[0]) * total_num)) - shuffled_idx = np.arange(total_num) - random.shuffle(shuffled_idx) - train_idx = shuffled_idx[:num_train] - valid_idx = shuffled_idx[num_train:] - else: - valid_idx = [int(LeaveOut[0])] - train_idx = list(set(range(total_num)) - set(map(int, LeaveOut))) - elif LeaveOut: - valid_idx = list(map(int, LeaveOut)) - train_idx = list(set(range(total_num)) - set(valid_idx)) - - valid_filenames = [] - train_filenames = [] - for fi, fn in enumerate(valid_idx): - valid_filenames.append(filenames[fn][:-11]) - for fi, fn in enumerate(train_idx): - train_filenames.append(filenames[fn][:-11]) - - args_inference.size_in = config['size_in'] - args_inference.size_out = config['size_out'] - args_inference.OutputCh = validation_config['OutputCh'] - args_inference.nclass = config['nclass'] - - else: - #TODO, update here - print('need validation') - quit() - - if loader_config['name']=='default': - from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader - train_set_loader = DataLoader(train_loader(train_filenames, loader_config['PatchPerBuffer'], config['size_in'], config['size_out']), num_workers=loader_config['NumWorkers'], batch_size=loader_config['batch_size'], shuffle=True) - elif loader_config['name']=='focus': - from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader - train_set_loader = DataLoader(train_loader(train_filenames, loader_config['PatchPerBuffer'], config['size_in'], config['size_out']), num_workers=loader_config['NumWorkers'], batch_size=loader_config['batch_size'], shuffle=True) - else: - print('other loader not support yet') - quit() - - num_iterations = 0 - num_epoch = 0 #TODO: load num_epoch from checkpoint - - start_epoch = num_epoch - for _ in range(start_epoch, config['epochs']+1): - - # sets the model in training mode - model.train() - - optimizer = None - optimizer = optim.Adam(model.parameters(),lr = config['learning_rate'], weight_decay = config['weight_decay']) - - # check if re-load on training data in needed - if num_epoch>0 and num_epoch % loader_config['epoch_shuffle'] ==0: - print('shuffling data') - train_set_loader = None - train_set_loader = DataLoader(train_loader(train_filenames, loader_config['PatchPerBuffer'], config['size_in'], config['size_out']), num_workers=loader_config['NumWorkers'], batch_size=loader_config['batch_size'], shuffle=True) - - # Training starts ... - epoch_loss = [] - - for i, current_batch in tqdm(enumerate(train_set_loader)): - - inputs = Variable(current_batch[0].cuda()) - targets = current_batch[1] - outputs = model(inputs) - - if len(targets)>1: - for zidx in range(len(targets)): - targets[zidx] = Variable(targets[zidx].cuda()) - else: - targets = Variable(targets[0].cuda()) - - optimizer.zero_grad() - if len(current_batch)==3: # input + target + cmap - cmap = Variable(current_batch[2].cuda()) - loss = criterion(outputs, targets, cmap) - else: # input + target - loss = criterion(outputs,targets) - - loss.backward() - optimizer.step() - - epoch_loss.append(loss.data.item()) - num_iterations += 1 - - average_training_loss = sum(epoch_loss) / len(epoch_loss) - - # validation - if num_epoch % validation_config['validate_every_n_epoch'] ==0: - validation_loss = np.zeros((len(validation_config['OutputCh'])//2,)) - model.eval() - - for img_idx, fn in enumerate(valid_filenames): - - # target - label = np.squeeze(imread(fn+'_GT.ome.tif')) - label = np.expand_dims(label, axis=0) - - # input image - input_img = np.squeeze(imread(fn+'.ome.tif')) - if len(input_img.shape) == 3: - # add channel dimension - input_img = np.expand_dims(input_img, axis=0) - elif len(input_img.shape) == 4: - # assume number of channel < number of Z, make sure channel dim comes first - if input_img.shape[0] > input_img.shape[1]: - input_img = np.transpose(input_img, (1, 0, 2, 3)) - - # cmap tensor - costmap = np.squeeze(imread(fn+'_CM.ome.tif')) - - # output - outputs = model_inference(model, input_img, model.final_activation, args_inference) - - assert len(validation_config['OutputCh'])//2 == len(outputs) - - for vi in range(len(outputs)): - if label.shape[0]==1: # the same label for all output - validation_loss[vi] += compute_iou(outputs[vi][0,:,:,:]>0.5, label[0,:,:,:]==validation_config['OutputCh'][2*vi+1], costmap) - else: - validation_loss[vi] += compute_iou(outputs[vi][0,:,:,:]>0.5, label[vi,:,:,:]==validation_config['OutputCh'][2*vi+1], costmap) - - average_validation_loss = validation_loss / len(valid_filenames) - print(f'Epoch: {num_epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}') - else: - print(f'Epoch: {num_epoch}, Training Loss: {average_training_loss}') - - - if num_epoch % config['save_every_n_epoch'] == 0: - save_checkpoint({ - 'epoch': num_epoch, - 'num_iterations': num_iterations, - 'model_state_dict': model.state_dict(), - #'best_val_score': self.best_val_score, - 'optimizer_state_dict': optimizer.state_dict(), - 'device': str(self.device), - }, checkpoint_dir=config['checkpoint_dir'], logger=self.logger) - num_epoch += 1 - - # TODO: add validation step - - def _log_lr(self): - lr = self.optimizer.param_groups[0]['lr'] - self.writer.add_scalar('learning_rate', lr, self.num_iterations) - - def _log_stats(self, phase, loss_avg, eval_score_avg): - tag_value = { - f'{phase}_loss_avg': loss_avg, - f'{phase}_eval_score_avg': eval_score_avg - } - - for tag, value in tag_value.items(): - self.writer.add_scalar(tag, value, self.num_iterations) - - def _log_params(self): - self.logger.info('Logging model parameters and gradients') - for name, value in self.model.named_parameters(): - self.writer.add_histogram(name, value.data.cpu().numpy(), - self.num_iterations) - self.writer.add_histogram(name + '/grad', - value.grad.data.cpu().numpy(), - self.num_iterations) - - def _log_images(self, input, target, prediction): - sources = { - 'inputs': input.data.cpu().numpy(), - 'targets': target.data.cpu().numpy(), - 'predictions': prediction.data.cpu().numpy() - } - for name, batch in sources.items(): - for tag, image in self._images_from_batch(name, batch): - self.writer.add_image(tag, image, self.num_iterations, dataformats='HW') \ No newline at end of file +def _log_lr(self): + lr = self.optimizer.param_groups[0]["lr"] + self.writer.add_scalar("learning_rate", lr, self.num_iterations) + + +def _log_stats(self, phase, loss_avg, eval_score_avg): + tag_value = { + f"{phase}_loss_avg": loss_avg, + f"{phase}_eval_score_avg": eval_score_avg, + } + + for tag, value in tag_value.items(): + self.writer.add_scalar(tag, value, self.num_iterations) + + +def _log_params(self): + self.logger.info("Logging model parameters and gradients") + for name, value in self.model.named_parameters(): + self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations) + self.writer.add_histogram( + name + "/grad", value.grad.data.cpu().numpy(), self.num_iterations + ) + + +def _log_images(self, input, target, prediction): + sources = { + "inputs": input.data.cpu().numpy(), + "targets": target.data.cpu().numpy(), + "predictions": prediction.data.cpu().numpy(), + } + for name, batch in sources.items(): + for tag, image in self._images_from_batch(name, batch): + self.writer.add_image(tag, image, self.num_iterations, dataformats="HW") diff --git a/aicsmlsegment/utils.py b/aicsmlsegment/utils.py index 4dff6c7..6349cf2 100644 --- a/aicsmlsegment/utils.py +++ b/aicsmlsegment/utils.py @@ -4,28 +4,348 @@ from typing import List from aicsimageio import AICSImage from scipy.ndimage import zoom -import os from scipy import ndimage as ndi from scipy import stats -import argparse - import yaml +import torch +from monai.networks.layers import Norm, Act +import os +import datetime + + +REQUIRED_CONFIG_FIELDS = { + True: { + "model": ["name"], + "checkpoint_dir": None, + "learning_rate": None, + "weight_decay": None, + "epochs": None, + "save_every_n_epoch": None, + "loss": ["name", "loss_weight"], + "loader": [ + "name", + "datafolder", + "batch_size", + "PatchPerBuffer", + "NumWorkers", + "Transforms", + ], + "validation": ["metric", "leaveout", "OutputCh", "validate_every_n_epoch"], + }, + False: { + "model": ["name"], + "model_path": None, + "OutputCh": None, + "OutputDir": None, + "InputCh": None, + "ResizeRatio": None, + "Normalization": None, + "Threshold": None, + "RuntimeAug": None, + "batch_size": None, + "mode": ["name"], + "NumWorkers": None, + }, +} +OPTIONAL_CONFIG_FIELDS = { + True: { + "resume": None, + "scheduler": ["name", "verbose"], + "gpus": None, + "dist_backend": None, + "callbacks": ["name"], + "SWA": ["swa_start", "swa_lr", "annealing_epochs", "annealing_strategy"], + "tensorboard": None, + "precision": None, + "loader": [ + "epoch_shuffle", + ], + }, + False: { + "gpus": None, + "dist_backend": None, + "model": ["norm", "act", "features", "dropout"], + "large_image_resize": None, + "precision": None, + "segmentation_name": None, + }, +} + +GPUS = torch.cuda.device_count() +DEFAULT_CONFIG = { + "SWA": None, + "resume": None, + "scheduler": {"name": None}, + "gpus": GPUS, + "dist_backend": "ddp" if GPUS > 1 else None, + "tensorboard": None, + "callbacks": {"name": None}, + "precision": 32, + "large_image_resize": [1, 1, 1], + "epoch_shuffle": None, + "segmentation_name": "segmentation", +} + +MODEL_PARAMETERS = { + "basic_unet": { + "Optional": [ + "features", + "act", + "norm", + "dropout", + ], + "Required": ["dimensions", "in_channels", "out_channels", "patch_size"], + }, + "unet_xy": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out"], + }, + "unet_xy_zoom": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], + }, + "unet_xy_zoom_0pad": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], + }, + "unet_xy_zoom_0pad_stridedconv": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], + }, + "unet_xy_zoom_0pad_nopadz_stridedconv": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], + }, + "unet_xy_zoom_stridedconv": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], + }, + "unet_xy_zoom_dilated": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], + }, + "sdunet_xy": { + "Optional": [], + "Required": ["nchannel", "nclass", "size_in", "size_out"], + }, + "unet": { + "Optional": [ + "kernel_size", + "up_kernel_size", + "num_res_units", + "act", + "norm", + "dropout", + ], + "Required": [ + "dimensions", + "in_channels", + "out_channels", + "channels", + "strides", + "patch_size", + ], + }, + "dynunet": { + "Optional": ["norm_name", "deep_supr_num", "res_block", "deep_supervision"], + "Required": [ + "spatial_dims", + "in_channels", + "out_channels", + "kernel_size", + "strides", + "upsample_kernel_size", + "patch_size", + ], + }, + "extended_dynunet": { + "Optional": ["norm_name", "deep_supr_num", "res_block", "deep_supervision"], + "Required": [ + "spatial_dims", + "in_channels", + "out_channels", + "kernel_size", + "strides", + "upsample_kernel_size", + "patch_size", + ], + }, + "segresnet": { + "Optional": [ + "dropout_prob", + "norm_name", + "num_groups", + "use_conv_final", + "blocks_down", + "blocks_up", + "upsample_mode", + "init_filters", + ], + "Required": [ + "patch_size", + "spatial_dims", + "in_channels", + "out_channels", + ], + }, + "segresnetvae": { + "Optional": [ + "vae_estimate_std", + "vae_default_std", + "vae_nz", + "init_filters", + "dropout_prob", + "norm_name", + "num_groups", + "use_conv_final", + "blocks_down", + "blocks_up", + "upsample_mode", + ], + "Required": ["patch_size", "spatial_dims", "in_channels", "out_channels"], + }, + "extended_vnet": { + "Optional": ["act", "dropout_prob", "dropout_dim"], + "Required": [ + "spatial_dims", + "in_channels", + "out_channels", + "patch_size", + ], + }, +} + +ACTIVATIONS = { + "LeakyReLU": Act.LEAKYRELU, + "PReLU": Act.PRELU, + "ReLU": Act.RELU, + "ReLU6": Act.RELU6, +} + +NORMALIZATIONS = { + "batch": Norm.BATCH, + "instance": Norm.INSTANCE, +} + + +def get_model_configurations(config): + model_config = config["model"] + model_parameters = {} + + assert model_config["name"] in MODEL_PARAMETERS, ( + f"{model_config['name']} is not supported, supported model names " + f"are {list(MODEL_PARAMETERS.keys())}" + ) + all_parameters = MODEL_PARAMETERS[model_config["name"]] + + # allow users to overwrite specific parameters + for param in all_parameters["Optional"]: + # if optional parameters are not specified, skip them to use monai defaults + if param in model_config and not model_config[param] is None: + if param == "norm": + try: + model_parameters[param] = NORMALIZATIONS[model_config[param]] + except KeyError: + print(f"{model_config[param]} is not an acceptable normalization.") + quit() + elif param == "act": + try: + model_parameters[param] = ACTIVATIONS[model_config[param]] + except KeyError: + print(f"{model_config[param]} is not an acceptable activation.") + quit() + else: + model_parameters[param] = model_config[param] + # find parameters that must be included + for param in all_parameters["Required"]: + assert ( + param in model_config + ), f"{param} is required for model {model_config['name']}" + model_parameters[param] = model_config[param] + + return model_parameters + + +def validate_config(config, train): + # make sure that all required elements are in the config file + for key in REQUIRED_CONFIG_FIELDS[train]: + assert ( + key in config and not config[key] is None + ), f"{key} is required in the config file" + if REQUIRED_CONFIG_FIELDS[train][key]: + for key2 in REQUIRED_CONFIG_FIELDS[train][key]: + assert ( + key2 in config[key] and config[key][key2] is not None + ), f"{key2} is required in {key} configuration." + + # check for optional elements and replace them with defaults if not provided + for key in OPTIONAL_CONFIG_FIELDS[train]: + if key not in config or config[key] is None: + config[key] = DEFAULT_CONFIG[key] + + if GPUS == 1: + config["dist_backend"] = None + + model_config = get_model_configurations(config) + + return config, model_config + -def load_config(config_path): - import torch - config = _load_config_yaml(config_path) - # Get a device to train on - device_name = config.get('device', 'cuda:0') - device = torch.device(device_name if torch.cuda.is_available() else 'cpu') - config['device'] = device - return config +def load_config(config_file, train): + config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader) + config, model_config = validate_config(config, train) + config["date"] = datetime.datetime.now().strftime("%b %d, %Y %H:%M") + return config, model_config -def _load_config_yaml(config_file): - return yaml.load(open(config_file, 'r')) +def create_unique_run_directory(config, train): + # directory to check for config in + subdir_names = {True: "/run_", False: "/prediction_"} + if train: + dir_name = config["checkpoint_dir"] + else: + dir_name = config["OutputDir"] + if os.path.exists(dir_name): + subfolders = [x for x in os.walk(dir_name)][0][1] + # only look at run_ or prediction_ folders + run_numbers = [ + int(sub.split("_")[1]) + for sub in subfolders + if subdir_names[train][1:] in sub + ] + if len(subfolders) > 0: + most_recent_run_number = max(run_numbers) + most_recent_run_dir = ( + dir_name + subdir_names[train] + str(most_recent_run_number) + ) + most_recent_config, _ = load_config( + most_recent_run_dir + "/config.yaml", + train=train, + ) + # HACK - this will combine runs with the same config files that are run + # within a minute of one another. multi gpu case - don't create a new + # run folder on non-rank 0 gpu + if most_recent_config == config: + return most_recent_run_dir + else: + most_recent_run_number = 0 + else: + os.makedirs(dir_name) + most_recent_run_number = 0 + + new_run_folder_name = subdir_names[train] + str(most_recent_run_number + 1) + os.makedirs(dir_name + new_run_folder_name) + if train: + os.makedirs(dir_name + new_run_folder_name + "/validation_results") + + with open(dir_name + new_run_folder_name + "/config.yaml", "w") as config_file: + yaml.dump(config, config_file, default_flow_style=False) + return dir_name + new_run_folder_name + def get_samplers(num_training_data, validation_ratio, my_seed): from torch.utils.data import sampler as torch_sampler + indices = list(range(num_training_data)) split = int(np.floor(validation_ratio * num_training_data)) @@ -39,180 +359,237 @@ def get_samplers(num_training_data, validation_ratio, my_seed): return train_sampler, valid_sampler + def simple_norm(img, a, b, m_high=-1, m_low=-1): idx = np.ones(img.shape, dtype=bool) - if m_high>0: - idx = np.logical_and(idx, img0: - idx = np.logical_and(idx, img>m_low) + if m_high > 0: + idx = np.logical_and(idx, img < m_high) + if m_low > 0: + idx = np.logical_and(idx, img > m_low) img_valid = img[idx] - m,s = stats.norm.fit(img_valid.flat) - strech_min = max(m - a*s, img.min()) - strech_max = min(m + b*s, img.max()) - img[img>strech_max]=strech_max - img[img strech_max] = strech_max + img[img < strech_min] = strech_min + img = (img - strech_min + 1e-8) / (strech_max - strech_min + 1e-8) return img + +def get_adjusted_min_max(img, a, b, m_high=-1, m_low=-1): + idx = np.ones(img.shape, dtype=bool) + if m_high > 0: + idx = np.logical_and(idx, img < m_high) + if m_low > 0: + idx = np.logical_and(idx, img > m_low) + img_valid = img[idx] + m, s = stats.norm.fit(img_valid.flat) + strech_min = max(m - a * s, img.min()) + strech_max = min(m + b * s, img.max()) + return strech_min, strech_max + + def background_sub(img, r): - struct_img_smooth = ndi.gaussian_filter(img, sigma=r, mode='nearest', truncate=3.0) + struct_img_smooth = ndi.gaussian_filter(img, sigma=r, mode="nearest", truncate=3.0) struct_img_smooth_sub = img - struct_img_smooth - struct_img = (struct_img_smooth_sub - struct_img_smooth_sub.min())/(struct_img_smooth_sub.max()-struct_img_smooth_sub.min()) + struct_img = (struct_img_smooth_sub - struct_img_smooth_sub.min()) / ( + struct_img_smooth_sub.max() - struct_img_smooth_sub.min() + ) return struct_img -def input_normalization(img, args): +def input_normalization(img, args): nchannel = img.shape[0] args.Normalization = int(args.Normalization) for ch_idx in range(nchannel): - struct_img = img[ch_idx,:,:,:] # note that struct_img is only a view of img, so changes made on struct_img also affects img - if args.Normalization == 0: # min-max normalization - struct_img = (struct_img - struct_img.min() + 1e-8)/(struct_img.max() - struct_img.min() + 1e-7) - elif args.Normalization == 1: # mem: DO NOT CHANGE (FIXED FOR CAAX PRODUCTION) - m,s = stats.norm.fit(struct_img.flat) - strech_min = max(m - 2*s, struct_img.min()) - strech_max = min(m + 11 *s, struct_img.max()) - struct_img[struct_img>strech_max]=strech_max - struct_img[struct_img strech_max] = strech_max + struct_img[struct_img < strech_min] = strech_min + struct_img = (struct_img - strech_min + 1e-8) / ( + strech_max - strech_min + 1e-8 + ) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 2: # nuc + # struct_img = simple_norm(struct_img, 2.5, 10, 1000, 300) struct_img = simple_norm(struct_img, 2.5, 10) - img[ch_idx,:,:,:] = struct_img[:,:,:] + img[ch_idx, :, :, :] = struct_img[:, :, :] elif args.Normalization == 4: struct_img = simple_norm(struct_img, 1, 15) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 7: # cardio_wga + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 7: # cardio_wga struct_img = simple_norm(struct_img, 1, 6) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 10: # lamin hipsc, DO NOT CHANGE (FIXED FOR LAMNB1 PRODUCTION) - img_valid = struct_img[struct_img>4000] - m,s = stats.norm.fit(img_valid.flat) - m,s = stats.norm.fit(struct_img.flat) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif ( + args.Normalization == 10 + ): # lamin hipsc, DO NOT CHANGE (FIXED FOR LAMNB1 PRODUCTION) + img_valid = struct_img[struct_img > 4000] + m, s = stats.norm.fit(img_valid.flat) + m, s = stats.norm.fit(struct_img.flat) strech_min = struct_img.min() - strech_max = min(m + 25 *s, struct_img.max()) - struct_img[struct_img>strech_max]=strech_max - struct_img = (struct_img- strech_min + 1e-8)/(strech_max - strech_min + 1e-8) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 12: # nuc - struct_img = background_sub(struct_img,50) + strech_max = min(m + 25 * s, struct_img.max()) + struct_img[struct_img > strech_max] = strech_max + struct_img = (struct_img - strech_min + 1e-8) / ( + strech_max - strech_min + 1e-8 + ) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 12: # nuc + struct_img = background_sub(struct_img, 50) struct_img = simple_norm(struct_img, 2.5, 10) - img[ch_idx,:,:,:] = struct_img[:,:,:] - print('subtracted background') - elif args.Normalization == 11: - struct_img = background_sub(struct_img,50) - #struct_img = simple_norm(struct_img, 2.5, 10) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 13: # cellmask - #struct_img[struct_img>10000] = struct_img.min() - struct_img = background_sub(struct_img,50) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 11: + struct_img = background_sub(struct_img, 50) + # struct_img = simple_norm(struct_img, 2.5, 10) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 13: # cellmask + # struct_img[struct_img>10000] = struct_img.min() + struct_img = background_sub(struct_img, 50) struct_img = simple_norm(struct_img, 2, 11) - img[ch_idx,:,:,:] = struct_img[:,:,:] + img[ch_idx, :, :, :] = struct_img[:, :, :] elif args.Normalization == 14: struct_img = simple_norm(struct_img, 1, 10) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 15: # lamin - struct_img[struct_img>4000] = struct_img.min() - struct_img = background_sub(struct_img,50) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 16: # lamin/h2b - struct_img = background_sub(struct_img,50) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 15: # lamin + struct_img[struct_img > 4000] = struct_img.min() + struct_img = background_sub(struct_img, 50) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 16: # lamin/h2b + struct_img = background_sub(struct_img, 50) struct_img = simple_norm(struct_img, 1.5, 6) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 17: # lamin - struct_img = background_sub(struct_img,50) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 17: # lamin + struct_img = background_sub(struct_img, 50) struct_img = simple_norm(struct_img, 1, 10) - img[ch_idx,:,:,:] = struct_img[:,:,:] - elif args.Normalization == 18: # h2b - struct_img = background_sub(struct_img,50) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 18: # h2b + struct_img = background_sub(struct_img, 50) struct_img = simple_norm(struct_img, 1.5, 10) - img[ch_idx,:,:,:] = struct_img[:,:,:] + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 19: # EMT lamin + struct_img[struct_img > 4000] = struct_img.min() + struct_img = simple_norm(struct_img, 1, 15) + img[ch_idx, :, :, :] = struct_img[:, :, :] + elif args.Normalization == 22: + print("No Normalization") + img[ch_idx, :, :, :] = struct_img[:, :, :] else: - print('no normalization recipe found') + print("no normalization recipe found") quit() return img -def image_normalization(img, config): +def image_normalization(img, config): if type(config) is dict: - ops = config['ops'] + ops = config["ops"] nchannel = img.shape[0] assert len(ops) == nchannel for ch_idx in range(nchannel): - ch_ops = ops[ch_idx]['ch'] - struct_img = img[ch_idx,:,:,:] + ch_ops = ops[ch_idx]["ch"] + struct_img = img[ch_idx, :, :, :] for transform in ch_ops: - if transform['name'] == 'background_sub': - struct_img = background_sub(struct_img, transform['sigma']) - elif transform['name'] =='auto_contrast': - param = transform['param'] - if len(param)==2: + if transform["name"] == "background_sub": + struct_img = background_sub(struct_img, transform["sigma"]) + elif transform["name"] == "auto_contrast": + param = transform["param"] + if len(param) == 2: struct_img = simple_norm(struct_img, param[0], param[1]) - elif len(param)==4: - struct_img = simple_norm(struct_img, param[0], param[1], param[2], param[3]) - else: - print('bad paramter for auto contrast') + elif len(param) == 4: + struct_img = simple_norm( + struct_img, param[0], param[1], param[2], param[3] + ) + else: + print("bad paramter for auto contrast") quit() - else: - print(transform['name']) - print('other normalization methods are not supported yet') + else: + print(transform["name"]) + print("other normalization methods are not supported yet") quit() - - img[ch_idx,:,:,:] = struct_img[:,:,:] + + img[ch_idx, :, :, :] = struct_img[:, :, :] else: - args_norm = lambda:None + args_norm = lambda: None # noqa E731 args_norm.Normalization = config img = input_normalization(img, args_norm) return img + def load_single_image(args, fn, time_flag=False): if time_flag: - img = fn[:,args.InputCh,:,:] + img = fn[:, args.InputCh, :, :] img = img.astype(float) - img = np.transpose(img, axes=(1,0,2,3)) + img = np.transpose(img, axes=(1, 0, 2, 3)) else: data_reader = AICSImage(fn) if isinstance(args.InputCh, List): channel_list = args.InputCh else: channel_list = [args.InputCh] - img = data_reader.get_image_data('CZYX', S=0, T=0, C=channel_list) + img = data_reader.get_image_data("CZYX", S=0, T=0, C=channel_list) # normalization - if args.mode == 'train': + if args.mode == "train": for ch_idx in range(args.nchannel): - struct_img = img[ch_idx,:,:,:] # note that struct_img is only a view of img, so changes made on struct_img also affects img - struct_img = (struct_img - struct_img.min() )/(struct_img.max() - struct_img.min()) + struct_img = img[ch_idx, :, :, :] + # note that struct_img is only a view of img, so changes + # made on struct_img also affects img + struct_img = (struct_img - struct_img.min()) / ( + struct_img.max() - struct_img.min() + ) elif not args.Normalization == 0: img = input_normalization(img, args) - + # rescale - if len(args.ResizeRatio)>0: - img = zoom(img, (1, args.ResizeRatio[0], args.ResizeRatio[1], args.ResizeRatio[2]), order=1) + if len(args.ResizeRatio) > 0: + img = zoom( + img, + (1, args.ResizeRatio[0], args.ResizeRatio[1], args.ResizeRatio[2]), + order=1, + ) return img def compute_iou(prediction, gt, cmap): - + if type(prediction) == torch.Tensor: + prediction = prediction.detach().cpu().numpy() + if type(gt) == torch.Tensor: + gt = gt.detach().cpu().numpy() + if type(cmap) == torch.Tensor: + cmap = cmap.detach().cpu().numpy() + if prediction.shape[1] == 2: # take foreground channel + prediction = prediction[:, 1, :, :, :] + if len(prediction.shape) == 4: + prediction = np.expand_dims(prediction, axis=1) + if len(cmap.shape) == 3: # cost function doesn't take cmap + cmap = np.ones_like(prediction) area_i = np.logical_and(prediction, gt) - area_i[cmap==0]=False + area_i[cmap == 0] = False area_u = np.logical_or(prediction, gt) - area_u[cmap==0]=False + area_u[cmap == 0] = False return np.count_nonzero(area_i) / np.count_nonzero(area_u) + def get_logger(name, level=logging.INFO): logger = logging.getLogger(name) logger.setLevel(level) # Logging to console stream_handler = logging.StreamHandler(sys.stdout) formatter = logging.Formatter( - '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s') + "%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s" + ) stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) diff --git a/aicsmlsegment/version.py b/aicsmlsegment/version.py deleted file mode 100644 index 60592a2..0000000 --- a/aicsmlsegment/version.py +++ /dev/null @@ -1,16 +0,0 @@ -# Autogenerated file - do NOT edit this by hand -MODULE_VERSION = "0.0.8.dev0" - -# For snapshot, X.Y.Z.devN -> X.Y.Z.devN+1 -# bumpversion devbuild -# -# For release, X.Y.Z.devN -> X.Y.Z -# bumpversion release -# DO NOT CALL release on consecutive calls -# DO NOT CALL release on 0.0.0.devN -# -# For preparing for next development cycle after release -# bumpversion patch (X.Y.Z -> X.Y.Z+1.dev0) -# bumpversion minor (X.Y.Z -> X.Y+1.Z.dev0) -# bumpversion major (X.Y.Z -> X+1.Y.Z.dev0) -# diff --git a/build.gradle b/build.gradle deleted file mode 100644 index 71534b6..0000000 --- a/build.gradle +++ /dev/null @@ -1,38 +0,0 @@ -buildscript { - def buildScriptPlugins = ['scripts/common/buildscript-5.gradle'] - println "> Applying script plugins in buildscripts:" - for (scriptPlugin in buildScriptPlugins) { - def pluginPath = "${scriptPluginPrefix}${scriptPlugin}${scriptPluginSuffix}${scriptPluginTag}" - println "${pluginPath}" - apply from: pluginPath, to: buildscript - } -} - -////////////////////////////////////////////////////////////////////////////////////////////////////// - -def scriptPlugins = ['scripts/common/gradle-version-5.gradle', - 'scripts/common/common-5.gradle', - 'scripts/python/build.gradle', - 'scripts/python/version.gradle', - 'scripts/python/publish.gradle'] -println "> Applying script plugins:" -for (scriptPlugin in scriptPlugins) { - def pluginPath = "${scriptPluginPrefix}${scriptPlugin}${scriptPluginSuffix}${scriptPluginTag}" - println "${pluginPath}" - apply from: pluginPath -} - - -// Add the environment variable to gradle for coverage report -// Do not add this to setup.cfg since it will break IDE tools -py.env.put("PYTEST_ADDOPTS", "--cov=${rootProject.name} --cov-config=setup.cfg --cov-report=html --cov-report=xml --cov-report=term") - - -////////////////////////////////////////////////////////////////////////////////////////////////////// -py.uploadToPyPi = true -project.group = "org.alleninstitute.aics.pypi" -description = "AICS ML segmentation" -// Project version will be managed outside of gradle in accordance with PEP 440 -// ("https://www.python.org/dev/peps/pep-0440/") - -////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/configs/all_predict_options.yaml b/configs/all_predict_options.yaml new file mode 100644 index 0000000..bbd97f3 --- /dev/null +++ b/configs/all_predict_options.yaml @@ -0,0 +1,45 @@ +################### +# MODEL +################### +model: + name: basic_unet + patch_size: [32, 256, 256] + dimensions: 3 + in_channels: 1 + out_channels: 2 + features : [32, 64, 128, 256, 512, 64] + +model_path: 'path/to/model/checkpoint' + +##################### +# PREDICTION +##################### +OutputCh: 1 # which channel should be taken from the prediction +precision: 16 #16 or 32, whether to use 16-bit or 32-bit model weights. +batch_size: 1 +inference_mode: 'fast' +gpus: -1 # -1 to use all available gpus, otherwise a positive integer +dist_backend: ddp # either blank or ddp, whether to use ddp or not for multi-GPU training +NumWorkers: 8 # number of workers to spawn for data loading. More workers increases memory usage +segmentation_name: "name_of_segmentation_type" # name of segmentation, written to output image metadata + +###################### +# DATA +###################### +OutputDir: 'path/to/save/results' +InputCh: [1] # channel to take from input image +large_image_resize: [1,1,1] # if single images are too large to fit into GPU, specifies number of chunks to split each image into +ResizeRatio: [1,1,1] #ratio to resize image +Normalization: 18 # normalization 'recipe', should match training data +Threshold: -1 # whether to binarize output images. Either -1 or an integer between 0 and 1 +RuntimeAug: False # run prediction on four flipped versions of the image - increases quality, takes ~4x longer + +mode: + name: file # whether to predict on individual file or entire folder of images + InputFile: "path/to/input/input/img" + timelapse: False +# OR # +# mode: +# name: folder +# InputDir: "path/to/input/input/img/directory" +# DataType: '.tif' \ No newline at end of file diff --git a/configs/all_train_options.yaml b/configs/all_train_options.yaml new file mode 100644 index 0000000..81d2a38 --- /dev/null +++ b/configs/all_train_options.yaml @@ -0,0 +1,70 @@ +################### +# MODEL +################### +model: + name: basic_unet + patch_size: [32, 256, 256] + dimensions: 3 + in_channels: 1 + out_channels: 2 + features : [32, 64, 128, 256, 512, 64] + +checkpoint_dir: '/path/to/save/directory' +resume: 'path/to/model/checkpoint' +precision: 16 #16 or 32, whether to use 16-bit or 32-bit model weights. + +##################### +# TRAINING +##################### +learning_rate: 0.00001 +weight_decay: 0.005 +loss: + name: Aux + loss_weight: [1, 1, 1] # weights on each auxilliary loss + ignore_index: null + +scheduler: + name: ExponentialLR #name of learning rate scheduler. A full list of scheduler is in aicsmlsegment/Model.py + gamma: 0.85 + verbose: True + +SWA: # Configuration for Stochastic Weight Averaging + swa_start: 1 # if > 1, epoch when to start SWA, if between 0 and 1, percentage of epochs to start SWA + swa_lr: 0.001 # learning rate to ramp up to at end of SWA + annealing_epochs: 3 #number of epochs to ramp from learning rate to swa_lr + annealing_strategy: cos # cos or linear, whether to ramp up to swa_lr linearly or cosine + +epochs: 400 +save_every_n_epoch: 50 + +callbacks: + name: EarlyStopping + monitor: val_loss # one of val_loss, train_loss, val_iou - the metric to base early stopping decision on + min_delta: 0.01 # minimum change in metric to prevent early stopping + patience: 10 # how many epochs to wait for a change greater than min_delta before stopping training + verbose: True # whether to print updates to commandline + mode: min #whether the monitor value should be minimized or maximized + +gpus: -1 # -1 to use all available gpus, otherwise a positive integer +dist_backend: ddp # either blank or ddp, whether to use ddp or not for multi-GPU training + +tensorboard: "path/to/logdir" # log directory where tensorboard should look for tensorboard files + +###################### +# DATA +###################### +loader: + name: default + datafolder: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_training_data_iter_1/' + batch_size: 8 + PatchPerBuffer: 160 + epoch_shuffle: 5 + NumWorkers: 1 + Transforms: ['RR'] # list containing any of RR (random rotation), RF (random flip), RN (random noise), RI (random intensity shift), RBF (random bias field) + +validation: + metric: default + leaveout: [0] + OutputCh: [0, 1, 1, 1, 2, 1] + validate_every_n_epoch: 25 + diff --git a/configs/monai_train_config.yaml b/configs/monai_train_config.yaml new file mode 100644 index 0000000..0c38b98 --- /dev/null +++ b/configs/monai_train_config.yaml @@ -0,0 +1,39 @@ +model: + name: basic_unet + patch_size: [32, 256, 256] + dimensions: 3 + in_channels: 1 + out_channels: 2 + features : [32, 64, 128, 256, 512, 64] + +checkpoint_dir: '/checkpoints/' +resume: null + +learning_rate: 0.00001 +weight_decay: 0.005 +epochs: 1000 +save_every_n_epoch: 200 +loss: + name: GeneralizedDice+CrossEntropy + loss_weight: [1, 1, 1] + ignore_index: null + +gpus: -1 +dist_backend: ddp +tensorboard: '/lightning_logs' +precision: 32 + +loader: + name: default + datafolder: ['/folder1/', '/folder2/'] + batch_size: 4 + PatchPerBuffer: 100 + epoch_shuffle: 50 + NumWorkers: 1 + Transforms: ["RR", 'RF', 'RI', 'RN'] + +validation: + metric: default + leaveout: [0, 3, 5] + OutputCh: 1 + validate_every_n_epoch: 200 \ No newline at end of file diff --git a/configs/unet_xy_zoom_0pad.yaml b/configs/unet_xy_zoom_0pad.yaml new file mode 100644 index 0000000..91de160 --- /dev/null +++ b/configs/unet_xy_zoom_0pad.yaml @@ -0,0 +1,44 @@ +model: + name: unet_xy_zoom_0pad + zoom_ratio: 3 + nchannel: 1 + nclass: [2,2,2] + size_in: [65, 384, 384] + size_out: [65, 384, 384] + +checkpoint_dir: 'path/to/save/checkpoints/' + +resume: "path/to/previous/checkpoints/file.ckpt" +learning_rate: 0.00001 + +scheduler: + name: ExponentialLR + gamma: 0.91 + verbose: True + +weight_decay: 0.005 +epochs: 400 +save_every_n_epoch: 2 +loss: + name: Aux + loss_weight: [1, 1, 1] + ignore_index: null +tensorboard: '/path/to/lightning_logs' +gpus: -1 +dist_backend: ddp +precision: 16 + +loader: + name: default + datafolder: [/path/1/, /path/2/] + batch_size: 2 + PatchPerBuffer: 360 + epoch_shuffle: 10 + NumWorkers: 6 + Transforms: ['RR', 'RF', 'RN'] + +validation: + metric: default + leaveout: [0.02] + OutputCh: [0, 1, 1, 1, 2, 1] + validate_every_n_epoch: 1 diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..74855a7 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = python -msphinx +SPHINXPROJ = aicsmlsegment +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/bb1.md b/docs/bb1.md index 7e4d470..e898470 100644 --- a/docs/bb1.md +++ b/docs/bb1.md @@ -40,7 +40,7 @@ batch_processing \ ### Understanding model output -The actual prediction from a deep learning based segmentation model is not binary. The value of each voxel is a real number between 0 and 1. To make it binary, we usually apply a cutoff value, i.e., the `Threshold` parameter in the [configuration file](./doc_pred_yaml.md). For each model, a different cutoff value may be needed. To determine a proper cutoff value, you can use `-1` for `Threshold` on sample images and open the output in ImageJ (with [bio-formats importer](https://imagej.net/Bio-Formats#Bio-Formats_Importer)) and try out different threshold values. Then, you can set `Threshold` as the new value and run on all images. Now, the results will be binary. +The actual prediction from a deep learning based segmentation model is not binary. The value of each voxel is a real number between 0 and 1. To make it binary, we usually apply a cutoff value, i.e., the `Threshold` parameter in the [configuration file](./doc_pred_yaml.md). For each model, a different cutoff value may be needed. To determine a proper cutoff value, you can use `-1` for `Threshold` on sample images and open the output in ImageJ (with [bio-formats importer](https://imagej.net/Bio-Formats#Bio-Formats_Importer)) and try out different threshold values. Then, you can set `Threshold` as the new value and run on all images. Then, the results will be binary. Another way to determine the cutoff is to collect a few images with ground truth and sweeping through different cutoff values to find the cutoff achieving the highest accuracy. ### Apply on one image diff --git a/docs/bb3.md b/docs/bb3.md index b5415f7..0957447 100644 --- a/docs/bb3.md +++ b/docs/bb3.md @@ -1,6 +1,8 @@ # Building Block 3: **Trainer** -**Trainer** is used to train deep learning-based segmentation models. The input for **Trainer** should be data prepared by **Curator** (see [documentation](./bb2.md)) and the output should be a model that can be used in **Segmenter**. +**Trainer** is used to train deep learning-based segmentation models. The input for **Trainer** are usually data prepared by **Curator** (see [documentation](./bb2.md)) and the output should be a model that can be used in **Segmenter**. + +In case you don't want to use **Curator** to generate the training data, the images should follow the naming convention: the data folder should contain `XYZ.tiff` (raw image, float, properly normalized to [0,1]), `XYZ_GT.tiff` (ground truth, uint8, 0=background, 1=foreground of class 1, can have more integers if more classes are needed), `XYZ_CM.tiff` (cost map, float, the importance of each pixel, usually 0 for areas to be excluded and 1 for normal areas). ![segmenter pic](./bb3_pic.png) diff --git a/docs/code_overview.md b/docs/code_overview.md new file mode 100644 index 0000000..e15db61 --- /dev/null +++ b/docs/code_overview.md @@ -0,0 +1,31 @@ +aicsmlsegment +. +|-- bin + |--curator + |-- curator_merging.py + |-- curator_sorting.py + |-- curator_takeall.py + |-- predict.py + |-- train.py +|-- DataUtils + |-- DataMod.py + |-- Universal_Loader.py +|-- NetworkArchitecture + |-- unet_xy_zoom.py + |-- unet_xy_zoom_0pad.py + |-- ... +|-- tests + | -- ... (dummy unit test) +|-- custom_loss.py +|-- custom_metric.py +|-- fnet_prediction_torch.py + (core inference function for new Segmenter models, which have same input and output size) +|-- model_utils.py +|-- Model.py + (core function defining a PLT model class) +|-- multichannel_sliding_window.py + (core inference function for old Segmenter models, which have different input and output size) +|-- training_utils.py +|-- utils.py +|-- version.py + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..eed83b9 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# aicsmlsegment documentation build configuration file, created by +# sphinx-quickstart on Fri Jun 9 13:47:02 2017. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another +# directory, add these directories to sys.path here. If the directory is +# relative to the documentation root, use os.path.abspath to make it +# absolute, like shown here. +# +import os +import sys + +import sphinx_rtd_theme + +import aicsmlsegment + +sys.path.insert(0, os.path.abspath("..")) + + +# -- General configuration --------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = "1.0" + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named "sphinx.ext.*") or your custom ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx.ext.mathjax", + "m2r", +] + +# Control napoleon +napoleon_google_docstring = False +napolean_include_init_with_doc = True +napoleon_use_ivar = True +napoleon_use_param = False + +# Control autodoc +autoclass_content = "both" # include init doc with class + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = { + ".rst": "restructuredtext", + ".txt": "markdown", + ".md": "markdown", +} + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = u"aics-ml-segmentation" +copyright = u"2020, AICS" +author = u"AICS" + +# The version info for the project you"re documenting, acts as replacement +# for |version| and |release|, also used in various other places throughout +# the built documents. +# +# The short X.Y version. +version = aicsmlsegment.__version__ +# The full version, including alpha/beta/rc tags. +release = aicsmlsegment.__version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_rtd_theme" + +# Theme options are theme-specific and customize the look and feel of a +# theme further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + "collapse_navigation": False, + "prev_next_buttons_location": "top", +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + + +# -- Options for HTMLHelp output --------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = "aicsmlsegmentdoc" + + +# -- Options for LaTeX output ------------------------------------------ + +latex_elements = { + # The paper size ("letterpaper" or "a4paper"). + # + # "papersize": "letterpaper", + # The font size ("10pt", "11pt" or "12pt"). + # + # "pointsize": "10pt", + # Additional stuff for the LaTeX preamble. + # + # "preamble": "", + # Latex figure (float) alignment + # + # "figure_align": "htbp", +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass +# [howto, manual, or own class]). +latex_documents = [ + ( + master_doc, + "aicsmlsegment.tex", + u"aics-ml-segmentation Documentation", + u"AICS", + "manual", + ), +] + + +# -- Options for manual page output ------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, "aicsmlsegment", u"aics-ml-segmentation Documentation", [author], 1) +] + + +# -- Options for Texinfo output ---------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + master_doc, + "aicsmlsegment", + u"aics-ml-segmentation Documentation", + author, + "aicsmlsegment", + "One line description of project.", + "Miscellaneous", + ), +] diff --git a/docs/contributing.rst b/docs/contributing.rst new file mode 100644 index 0000000..4fc5016 --- /dev/null +++ b/docs/contributing.rst @@ -0,0 +1 @@ +.. mdinclude:: ../CONTRIBUTING.md diff --git a/docs/doc_pred_yaml.md b/docs/doc_pred_yaml.md index 702c9d2..9d86178 100644 --- a/docs/doc_pred_yaml.md +++ b/docs/doc_pred_yaml.md @@ -20,15 +20,15 @@ or model: name: unet_xy ``` -see model parameters in [training configuration](./doc_train_yaml.md) +See model parameters in [training configuration](./doc_train_yaml.md) 2. input and output type (:ok:) ```yaml nchannel: 1 nclass: [2, 2, 2] -OutputCh: [0, 1] +OutputCh: 1 ``` -These are related to the model architecture and fixed by default. +These are related to the model architecture and fixed by default. See model parameters in [training configuration](./doc_train_yaml.md) 3. patch size (:computer:) @@ -44,22 +44,38 @@ model_path: '/home/model/checkpoint_epoch_300.pytorch' ``` This the place to specify which trained model to run. +### Prediction +```yaml +OutputCh: 1 +precision: 16 +batch_size: 1 +inference_mode: 'fast' +gpus: -1 +dist_backend: ddp +NumWorkers: 8 +segmentation_name: "name_of_type_of_segmentation" +``` +`OutputCh` specifies which channel should be extracted from the raw output image. `precision` specifies whether 16- or 32-bit model weights should be used. `inference_mode` has two options: `fast` or `efficient`. `fast` mode increases speed at the cost of GPU memory usage, while `efficient` mode decreases GPU memory usage at the cost of speed. Multi-GPU training using ddp is supported. If `gpus` = -1, all available GPUs will be used. Otherwise, `gpus` must be a positive integer and up to that many GPUs will be used depending on GPU availability. If multiple GPUs are used, a `dist_backend` can be specified. Currently, ddp is the only supported backend. `NumWorkers` specifies how many workers should be spawned for loading and normalizing images. Additional workers increase CPU memory usage. `segmentation_name` will be saved in the image metadata of the output image to identify the segmentation. + ### Data Info (:pushpin:) ```yaml -OutputDir: '//allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/' +mode: + name: folder + InputDir: "path/to/input/input/img" + DataType: .tiff + +OutputDir: 'path/to/save/results' InputCh: [0] ResizeRatio: [1.0, 1.0, 1.0] +large_image_resize: [1,1,1] Threshold: 0.75 RuntimeAug: False Normalization: 10 -mode: - name: folder - InputDir: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_fluorescent' - DataType: tiff + ``` -`DataType` is the type of images to be processed in `InputDir`, which the `InputCh`'th (keep the [ ]) channel of each images will be segmented. If your model is trained on images of a certain resolution and your test images are of different resolution `ResizeRatio` needs to be set as [new_z_size/old_z_size, new_y_size/old_y_size, new_x_size/old_x_size]. The acutal output is the likelihood of each voxels being the target structure. A `Threshold` between 0 and 1 needs to be set to generate the binary mask. We recommend to use 0.6 ~ 0.9. When `Threshold` is set as `-1`, the raw prediction from the model will be saved, for users to determine a proper binary cutoff. `Normalization` is the index of a list of pre-defined normalization recipes and should be the same index as generating training data (see [Curator](./bb2.md) for the full list of normalization recipes). +`DataType` is the type of images to be processed in `InputDir`, which the `InputCh`'th (keep the [ ]) channel of each image will be segmented. If your model is trained on images of a certain resolution and your test images are of different resolution `ResizeRatio` needs to be set as [new_z_size/old_z_size, new_y_size/old_y_size, new_x_size/old_x_size]. The actual output is the likelihood of each voxels being the target structure. `large_image_resize` can be set if large images cause GPU out of memory during prediction. This parameter specifies how many patches in the `ZYX` axes each image should be split into. After patch-wise prediction, the final image is reconstructed based on overlap between patches. A `Threshold` between 0 and 1 needs to be set to generate the binary mask. We recommend to use 0.6 ~ 0.9. When `Threshold` is set as `-1`, the raw prediction from the model will be saved, for users to determine a proper binary cutoff. `Normalization` is the index of a list of pre-defined normalization recipes and should be the same index as generating training data (see [Curator](./bb2.md) for the full list of normalization recipes). If `RuntimeAug` is `True`, the model will predict on the original image and three flipped versions of the image. The final prediction is then averaged across each flipped prediction. This increases prediction quality, but takes ~4x longer. \ No newline at end of file diff --git a/docs/doc_train_yaml.md b/docs/doc_train_yaml.md index 36c87ec..4469c91 100644 --- a/docs/doc_train_yaml.md +++ b/docs/doc_train_yaml.md @@ -9,6 +9,16 @@ This is a detailed description of the configuration for training a DL model in * 3. Only need to change once on a particular machine (parameters specific to each machine), marked by :computer: 4. No need to change for most problems (parameters pre-defined as a general training scheme and requires advacned deep learning knowledge to adjust), marked by :ok: +******************************************************************** +Here is a list of example yaml configuration files: +- [basic](../configs/train_config.yaml) +- [unet_xy_zoom_0pad using tensorboard + learning rate scheduler](../configs/unet_xy_zoom_0pad.yaml) +- [MONAI model](../configs/monai_train_config.yaml) +- [TBA]() +- [TBA]() +- [TBA]() +******************************************************************** + ### Model related parameters @@ -23,35 +33,34 @@ or model: name: unet_xy ``` -There may be probably more than 100 models in the literature for 3D image segmentation. The two models we implemented here are carefully designed for cell structure segmentation in 3D microscopy images. Model `unet_xy` is suitable for smaller-scale structures, like severl voxels thick (e.g., tubulin, lamin b1). Model `unet_xy_zoom` is more suitable for larger-scale structures, like more than 100 voxels in diameter (e.g., nucleus), while the `zoom_ratio` is an integer (e.g., 2 or 3) and can be estimated by average diameter of target object in voxels divided by 150. +There may be probably more than 100 models in the literature for 3D image segmentation. The two models we implemented here are carefully designed for cell structure segmentation in 3D microscopy images. Model `unet_xy` is suitable for smaller-scale structures, like severl voxels thick (e.g., tubulin, lamin b1). Model `unet_xy_zoom` is more suitable for larger-scale structures, like more than 100 voxels in diameter (e.g., nucleus), while the `zoom_ratio` is an integer (e.g., 2 or 3). Larger `zoom_ratio` allows the model to take more neighbor information into account, but reduce the resolution. -2. start from an existing model? (:warning:) +There are a few more variations of `unet_xy_zoom`: `unet_xy_zoom_0pad` (add 0 padding after all convolutions, so `size_in = size_out`), `unet_xy_zoom_dilated` (use dilated convolution), `unet_xy_zoom_stridedconv` (replace max pooling with strided convolutions), `unet_xy_zoom_0pad_stridedconv` (both 0 padding and strided convolution), etc. -```yaml -resume: null -``` +Now, we can also support most of the baseline models implemented in [MONAI](https://docs.monai.io/en/latest/networks.html#nets). All the parameters in those models can be passed in here. -When doing iterative deep learning, it may be useful to start from the model trained in the previous step. The model can be specified at `resume`. If `null`, a new model will be trained from scratch. -3. input and output type (:ok:) -```yaml -nchannel: 1 -nclass: [2, 2, 2] +2. input and output type (:ok:) +```yaml + nchannel: 1 + nclass: [2, 2, 2] ``` -These are related to the model architecture and fixed by default. We assume the input image has only one channel. +These are related to the model architecture and fixed by default. We assume the input image has only one channel. For `nclass`, it could be `[2, 2, 2]` (for `unet_xy`, `unet_xy_zoom`) or `2` (for other monai models). You can also use more classes than 2, which will also need to have corresponding values in the ground truth images. -4. patch size (:computer:) (:pushpin:) +3. patch size (:computer:) (:pushpin:) ```yaml -size_in: [50, 156, 156] -size_out: [22, 68, 68] + size_in: [50, 156, 156] + size_out: [22, 68, 68] ``` -In most situations, we cannot fit the entire image into the memory of a single GPU. These are also related to `batch_size` (an data loader parameter), which will be discussed shortly. `size_in` is the actual size of each patch fed into the model, while `size_out` is the size of the model's prediction. The prediction size is smaller than the input size is because the multiple convolution operations. The equation for calculating `size_in` and `size_out` is as follows. +In most situations, we cannot fit the entire image into the memory of a single GPU. These are also related to `batch_size` (a data loader parameter), which will be discussed shortly. `size_in` is the actual size of each patch fed into the model, while `size_out` is the size of the model's prediction. The prediction size is smaller than the input size is because the multiple convolution operations. The equation for calculating `size_in` and `size_out` is as follows. > For unet_xy, `size_in` = `[z, 8p+60, 8p+60]`, `size_out` = `[z-28, 8p-28, 8p-28]` > For unet_xy_zoom, with `zoom_ratio`=`k`, `size_in` = `[z, 8kp+60k, 8kp+60k]` and `size_out` = `[z-32, 8kp-28k-4, 8kp-28k-4]` +> For unet_xy_zoom_0pad with `zoom_ratio`=`k`, `size_in` = `size_out` and `x` and `y` must be divisible by `8k`. + Here, `p` and `z` can be any positive integers that make `size_out` has all positive values. Here are some pre-calculated values for different models on different types of GPUs. @@ -63,6 +72,15 @@ Here are some pre-calculated values for different models on different types of G | unet_xy_zoom (ratio=3) on 12GB GPU | [52, 372, 372] | [20, 104, 104] | 4 | | unet_xy_zoom (ratio=3) on 33GB GPU | [52, 420, 420] | [20, 152, 152] | 8 | +4. start from an existing model? (:warning:) + +```yaml +resume: null +``` + +When doing iterative deep learning, it may be useful to start from the model trained in the previous step. The model can be specified at `resume`. If `null`, a new model will be trained from scratch. + + 5. model directory (:warning:) ```yaml checkpoint_dir: /home/model/xyz/ @@ -70,7 +88,13 @@ resume: null ``` This is the directory to save the trained model. If you want to start this training from a previous saved model, you may add the path to `resume`. -### Training scheme realted parameters +6. Precision +```yaml +precision: 16 +``` +Optional. If not specified, defaults to 32-bit model weights. 16-bit model weights can be specified as shown to decrease model memory usage and increase batch size with the possible tradeoff of reduced model performance. + +### Training scheme related parameters 1. optimization parameters (:ok:) ```yaml learning_rate: 0.00001 @@ -80,27 +104,104 @@ loss: loss_weight: [1, 1, 1] ignore_index: null ``` +`learning_rate` specifies the initial learning rate. If a scheduler is provided, this will change over the course of training. `weight_decay` specifies the weight decay parameter for the Adam Optimizer. With the custom `unet_xy` and `unet_xy_soom` as well as some of the MONAI models, deep supervision heads are used during training, resulting in the ouput of several images on which loss is calculated. `loss weight` determines how the loss calculated on each of these heads should be weighted. Several loss functions are currently supported: Dice, Generalized Dice, Focal, Masked Dice, MSE, Element Angular MSE, Masked MSE,Cross Entropy, and Masked Cross Entropy. Any two of these losses can be combined by writing the loss `name` as `loss1+loss2`. Additionally, if deep supervision heads are used, appending "Aux" to the start of the loss `name` will result in a final loss that is a weighted sum (as specified by `loss_weight`) across the deep supervision heads. Compatibility of all losses with all models is currently under construction. + -2. training epochs (:pushpin:) +2. Learning Rate Scheduler +```yaml +scheduler: + name: ExponentialLR + gamma: 0.85 + verbose: True +``` +Supported schedulers and associated parameters are as follows: + +| Scheduler Name | Parameters | +|----------------------------------|:-------------------------------------:| +|ExponentialLR | gamma, verbose | +|CosineAnnealingLR | T_max, verbose | +|StepLR | step_size, gamma, verbose | +|ReduceLROnPlateau | mode, factor, patience, verbose | +|1cycle |max_lr, total_steps, pct_start, verbose| + +Explanation of these parameters can be found in the [Pytorch documentation](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate). Specifying a scheduler is optional. + +3. Stochastic Weight Averaging +```yaml +SWA: + swa_start: 1 + swa_lr: 0.001 + annealing_epochs: 3 + annealing_strategy: cos + +``` + +Stochastic Weight Averaging combines high learning rate and weight averaging across epochs to help model generalization by taking advantage of the shape of the loss minima found by stochastic gradient descent. A more in-depth explanation can be found in the [Pytorch Documentation] (https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/). If used, all parameters must be specified, else this option can be omitted from the config file or specified as `SWA: `. + +| Parameter | Options | Explanation | +|-----------------------------|:----------------------------------:|-------------| +|swa_start | positive integer or between 0 and 1|Epoch number or percentage of total epochs to start SWA| +|swa_lr | float > 0 |Relatively high learning rate to anneal to| +|annealing_epochs | integer > 0 |Number of epochs to attain swa_lr| +|annealing_strategy | `cos` or `linear` |Whether to perform linear or cosine ramp to swa_lr| + +4. training epochs (:pushpin:) ```yaml epochs: 400 save_every_n_epoch: 40 ``` `epochs` controls how many iterations in the training. We suggest to use values between 300 and 600. A model will be saved on every `save_every_n_epoch` epochs. +5. Early Stopping +```yaml +callbacks: + name: EarlyStopping + monitor: val_loss + min_delta: 0.01 + patience: 10 + verbose: True + mode: min +``` +Early stopping can avoid overfitting on the training set by halting training when validation performance stops improving. `monitor` specifies which metric (either `val_loss` or `val_iou`) the decision to stop should be based on. Changes in `monitor` less than `min_delta` will not count as improvement. `patience` specifies how many epochs to wait for an improvement in a metric before stopping. `verbose` controls whether to print updates to the commandline. `mode` specifies whether the monitored value should be minimized or maximized. Early stopping does not have to be specified in the config file. -### Data realted parameters +6. GPU configuration +```yaml +gpus: -1 +dist_backend: ddp +``` +Multi-GPU training using ddp is supported. If `gpus` = -1, all available gpus will be used. Otherwise, `gpus` must be a positive integer and up to that many gpus will be used depending on gpu availability. If multiple gpus are used, a `backend` can be specified. Currently, ddp is the only supported backend. + +7. Tensorboard +```yaml +tensorboard: "path/to/logdir" +``` +Tensorboard can be used to track the progression of training and compare experiments. The `tensorboard` argument is a path to a directory containing tensorboard logs. + + + +### Data related parameters ```yaml loader: name: default - datafolder: '/home/data/train/' + datafolder: ['/home/data/train/'] batch_size: 8 PatchPerBuffer: 200 epoch_shuffle: 5 NumWorkers: 1 + Transforms: ['RR'] ``` -`datafolder` (:warning:) and `PatchPerBuffer` (:pushpin:) need to be specified for each problem. `datafolder` is the directory of training data. `PatchPerBuffer` is the number of sample patches randomly drawn in each epoch, which can be set as *number of patches to draw from each data* **x** *number of training data*. `name`, `epoch_shuffle` and `NumWorkers` (:ok:) are fixed by default. `batch_size` is related to GPU memory and patch size (see values presented with patch size). +`datafolder` (:warning:) and `PatchPerBuffer` (:pushpin:) need to be specified for each problem. `datafolder` is the directory of training data. Multiple directories can be combined by specifying a list of paths. `PatchPerBuffer` is the number of sample patches randomly drawn in each epoch, which can be set as *number of patches to draw from each data* **x** *number of training data*. `name`, `epoch_shuffle` and `NumWorkers` (:ok:) are fixed by default. `batch_size` is related to GPU memory and patch size (see values presented with patch size). +`Transforms` is a list of random transformations to apply to the training data. + +| Transform Name | Description | +|----------------------|:------------------------------------------:| +|`RR` | Random rotation from 1-180 degrees | +|`RF` | 50% probability of left/right fliping image| +|`RN` | Addition of 0-mean Gaussian random noise with standard deviation ~U(0, 0.25)| +|`RI` | Random shift in intensity of up to 0.1 with probability 0.2 | +|`RBF` |Application of [Random Bias Field] (https://torchio.readthedocs.io/transforms/augmentation.html#torchio.transforms.RandomBiasField)| + ### Validation related parameter diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..78f0190 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,21 @@ +Welcome to aics-ml-segmentation's documentation! +====================================== + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Contents: + + Overview + installation + Package modules + contributing + math + +.. mdinclude:: ../README.md + +Indices and tables +================== +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 0000000..c5a78dc --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,51 @@ +.. highlight:: shell + +============ +Installation +============ + + +Stable release +-------------- + +To install aics-ml-segmentation, run this command in your terminal: + +.. code-block:: console + + $ pip install aicsmlsegment + +This is the preferred method to install aics-ml-segmentation, as it will always install the most recent stable release. + +If you don't have `pip`_ installed, this `Python installation guide`_ can guide +you through the process. + +.. _pip: https://pip.pypa.io +.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ + + +From sources +------------ + +The sources for aics-ml-segmentation can be downloaded from the `Github repo`_. + +You can either clone the public repository: + +.. code-block:: console + + $ git clone git://github.com/AllenCell/aics-ml-segmentation + +Or download the `tarball`_: + +.. code-block:: console + + $ curl -OL https://github.com/AllenCell/aics-ml-segment/tarball/main + +Once you have a copy of the source, you can install it with: + +.. code-block:: console + + $ python setup.py install + + +.. _Github repo: https://github.com/AllenCell/aics-ml-segmentation +.. _tarball: https://github.com/AllenCell/aicsmlsegment/tarball/main diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..499101b --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=python -msphinx +) +set SOURCEDIR=. +set BUILDDIR=_build +set SPHINXPROJ=test + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The Sphinx module was not found. Make sure you have Sphinx installed, + echo.then set the SPHINXBUILD environment variable to point to the full + echo.path of the 'sphinx-build' executable. Alternatively you may add the + echo.Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/overview.md b/docs/overview.md index 3dc428c..293ee31 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -9,7 +9,7 @@ The Allen Cell Structure Segmenter is implemented as two packages: [`aicssegment ## Installation: * `aicssegmentation` (classic image segmentation): [Installation instruction](https://github.com/AllenCell/aics-segmentation) (available on Linux, MacOS, Windows) -* `aicsmlsegment` (deep learning segmentation): [Installation instruction](../README.md) (requires NVIDIA GPU and Linux OS) +* `aicsmlsegment` (deep learning segmentation): [Installation instruction](../README.md) (requires NVIDIA GPU and Linux OS, Windows is also okay) ## Understanding each building block: diff --git a/gradle.properties b/gradle.properties deleted file mode 100644 index b6dd795..0000000 --- a/gradle.properties +++ /dev/null @@ -1,12 +0,0 @@ -# This points to the location of the script plugins used to provide the build tasks. -# This also provides the gradle version - run './gradlew wrapper' to set it. - -#### -scriptPluginPrefix=https://aicsbitbucket.corp.alleninstitute.org/projects/sw/repos/gradle-script-plugins/raw/ -scriptPluginSuffix=?at=refs/tags/ - -# This variable is separated in order to allow it to be overridden with an environment variable -# See: https://docs.gradle.org/current/userguide/build_environment.html#sec:project_properties -# This will be overridden by Jenkins with the version defined in the Jenkins configuration. -# The variable can also be used locally: `ORG_GRADLE_PROJECT_scriptPluginTag=2.4.3 ./gradlew tasks` -scriptPluginTag=2.4.2 diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar deleted file mode 100644 index c44b679..0000000 Binary files a/gradle/wrapper/gradle-wrapper.jar and /dev/null differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties deleted file mode 100644 index 6b3851a..0000000 --- a/gradle/wrapper/gradle-wrapper.properties +++ /dev/null @@ -1,5 +0,0 @@ -distributionBase=GRADLE_USER_HOME -distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.1-bin.zip -zipStoreBase=GRADLE_USER_HOME -zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew deleted file mode 100755 index af6708f..0000000 --- a/gradlew +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env sh - -############################################################################## -## -## Gradle start up script for UN*X -## -############################################################################## - -# Attempt to set APP_HOME -# Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi -done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null - -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m"' - -# Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" - -warn () { - echo "$*" -} - -die () { - echo - echo "$*" - echo - exit 1 -} - -# OS specific support (must be 'true' or 'false'). -cygwin=false -msys=false -darwin=false -nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; -esac - -CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar - -# Determine the Java command to use to start the JVM. -if [ -n "$JAVA_HOME" ] ; then - if [ -x "$JAVA_HOME/jre/sh/java" ] ; then - # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" - else - JAVACMD="$JAVA_HOME/bin/java" - fi - if [ ! -x "$JAVACMD" ] ; then - die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME - -Please set the JAVA_HOME variable in your environment to match the -location of your Java installation." - fi -else - JAVACMD="java" - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. - -Please set the JAVA_HOME variable in your environment to match the -location of your Java installation." -fi - -# Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi -fi - -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi - -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi - # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" - fi - i=$((i+1)) - done - case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac -fi - -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=$(save "$@") - -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" - -# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong -if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then - cd "$(dirname "$0")" -fi - -exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat deleted file mode 100644 index 6d57edc..0000000 --- a/gradlew.bat +++ /dev/null @@ -1,84 +0,0 @@ -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto init - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega diff --git a/settings.gradle b/settings.gradle deleted file mode 100644 index d47dc3b..0000000 --- a/settings.gradle +++ /dev/null @@ -1 +0,0 @@ -rootProject.name = 'aicsmlsegment' \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 052e780..01cb056 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,22 +1,31 @@ -[bdist_wheel] -universal=1 +[bumpversion] +current_version = 0.0.0 +commit = True +tag = True + +[bumpversion:file:setup.py] +search = version="{current_version}" +replace = version="{new_version}" + +[bumpversion:file:aicsmlsegment/__init__.py] +search = {current_version} +replace = {new_version} -[metadata] -license_file = LICENSE.txt +[bdist_wheel] +universal = 1 [aliases] -test=pytest +test = pytest [tool:pytest] -addopts = --junitxml=build/test_report.xml -v -norecursedirs = aicsmlsegment/tests/checkouts .egg* build dist venv .gradle aicsmlsegment.egg-info/* +collect_ignore = ['setup.py'] [flake8] -max-line-length = 130 - -[coverage:html] -directory = build/coverage_html -title = Test coverage report for aicsbatch - -[coverage:xml] -output = build/coverage.xml +exclude = + docs/ +ignore = + E203 + E402 + W291 + W503 +max-line-length = 88 diff --git a/setup.py b/setup.py index 7148d26..86a7af5 100644 --- a/setup.py +++ b/setup.py @@ -1,65 +1,107 @@ -from setuptools import setup, find_packages +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""The setup script.""" -PACKAGE_NAME = 'aicsmlsegment' +from setuptools import find_packages, setup -""" -Notes: -We get the constants MODULE_VERSION from -See (3) in following link to read about versions from a single source -https://packaging.python.org/guides/single-sourcing-package-version/#single-sourcing-the-version -""" +with open("README.md") as readme_file: + readme = readme_file.read() -MODULE_VERSION = "" -exec(open(PACKAGE_NAME + "/version.py").read()) +PACKAGE_NAME = "aicsmlsegment" +MODULE_VERSION = "0.0.7" -def readme(): - with open('README.md') as f: - return f.read() +setup_requirements = [ + "pytest-runner>=5.2", +] +test_requirements = [ + "black>=19.10b0", + "codecov>=2.1.4", + "flake8>=3.8.3", + "flake8-debugger>=3.2.1", + "pytest>=5.4.3", + "pytest-cov>=2.9.0", + "pytest-raises>=0.11", +] -test_deps = ['pytest', 'pytest-cov'] -lint_deps = ['flake8'] -all_deps = [*test_deps, *lint_deps] -extras = { - 'test_group': test_deps, - 'lint_group': lint_deps, - 'all': all_deps +dev_requirements = [ + *setup_requirements, + *test_requirements, + "bump2version>=1.0.1", + "coverage>=5.1", + "ipython>=7.15.0", + "m2r>=0.2.1", + "pytest-runner>=5.2", + "Sphinx>=2.0.0b1,<3", + "sphinx_rtd_theme>=0.4.3", + "tox>=3.15.2", + "twine>=3.1.1", + "wheel>=0.34.2", +] + +requirements = [ + "numpy>=1.15.1", + "scipy>=1.1.0", + "scikit-image", + "pandas>=0.23.4", + "aicsimageio>3.3.0", + "tqdm", + "pyyaml", + "monai>=0.4.0", + "pytorch-lightning", + "torchio", +] + +extra_requirements = { + "setup": setup_requirements, + "test": test_requirements, + "dev": dev_requirements, + "all": [ + *requirements, + *dev_requirements, + ], } -setup(name=PACKAGE_NAME, - version=MODULE_VERSION, - description='Scripts for ML structure segmentation.', - long_description=readme(), - author='AICS', - author_email='jianxuc@alleninstitute.org', - license='Allen Institute Software License', - packages=find_packages(exclude=['tests', '*.tests', '*.tests.*']), - entry_points={ - "console_scripts": [ +setup( + author="AICS", + author_email="jianxuc@alleninstitute.org", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: Free for non-commercial use", + "Natural Language :: English", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + ], + description="Scripts for ML structure segmentation.", + entry_points={ + "console_scripts": [ "dl_train={}.bin.train:main".format(PACKAGE_NAME), "dl_predict={}.bin.predict:main".format(PACKAGE_NAME), "curator_merging={}.bin.curator.curator_merging:main".format(PACKAGE_NAME), "curator_sorting={}.bin.curator.curator_sorting:main".format(PACKAGE_NAME), "curator_takeall={}.bin.curator.curator_takeall:main".format(PACKAGE_NAME), - ] - }, - install_requires=[ - 'numpy>=1.15.1', - 'scipy>=1.1.0', - 'scikit-image', - 'pandas>=0.23.4', - 'aicsimageio>3.3.0', - 'tqdm', - 'pyyaml', - #'pytorch=1.0.0' - ], - - # For test setup. This will allow JUnit XML output for Jenkins - setup_requires=['pytest-runner'], - tests_require=test_deps, - - extras_require=extras, - zip_safe=False - ) + ] + }, + install_requires=requirements, + license="Allen Institute Software License", + long_description=readme, + long_description_content_type="text/markdown", + include_package_data=True, + keywords="aicsmlsegment", + name=PACKAGE_NAME, + packages=find_packages(exclude=["tests", "*.tests", "*.tests.*"]), + python_requires=">=3.7", + setup_requires=setup_requirements, + test_suite="aicsmlsegment/tests", + tests_require=test_requirements, + extras_require=extra_requirements, + url="https://github.com/AllenCell/aics-ml-segmentation", + # Do not edit this string manually, always use bumpversion + # Details in CONTRIBUTING.rst + version=MODULE_VERSION, + zip_safe=False, +) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..6c399dd --- /dev/null +++ b/tox.ini @@ -0,0 +1,18 @@ +[tox] +skipsdist = True +envlist = py37, py38, py39, lint + +[testenv:lint] +deps = + .[test] +commands = + flake8 aicsmlsegment --count --verbose --show-source --statistics + black --check aicsmlsegment + +[testenv] +setenv = + PYTHONPATH = {toxinidir} +deps = + .[test] +commands = + pytest --basetemp={envtmpdir} --cov-report html --cov=aicsmlsegment aicsmlsegment/tests/