Skip to content

Commit

Permalink
setup.py + cmd line script
Browse files Browse the repository at this point in the history
  • Loading branch information
hlgirard committed Apr 26, 2019
1 parent e1b6397 commit 3f8050e
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 26 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include models *.h5 *.json
13 changes: 13 additions & 0 deletions bin/crystalml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env pythonw
"""Command line entry point: process a folder of images with crystalml."""
import os
import argparse


def build_parser():
    """Build the argument parser for the command line script.

    Returns
    -------
    argparse.ArgumentParser
        Parser with a single optional positional argument, `directory`,
        which falls back to the current working directory when omitted.
    """
    ap = argparse.ArgumentParser()
    # nargs="?" is what makes the positional optional: without it argparse
    # always requires a value and the default below is never used.
    ap.add_argument("directory", nargs="?", default=os.getcwd(), help="Path of the directory")
    return ap


def main(argv=None):
    """Parse the command line and process the selected directory.

    Parameters
    ----------
    argv: list of str, optional
        Arguments to parse; when None, argparse reads them from `sys.argv`.
    """
    args = build_parser().parse_args(argv)

    # Imported after parsing so that usage errors and `--help` are reported
    # before the processing module and its dependencies are loaded.
    from src.crystal_processing.process_image_folder import process_image_folder

    process_image_folder(args.directory)


if __name__ == "__main__":
    main()
13 changes: 0 additions & 13 deletions bin/process_kinetics_folder

This file was deleted.

Empty file added models/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ opencv-python==4.0.0.21
pandas==0.23.4
Pillow==6.0.0
plotly==3.6.1
setuptools
scikit-image==0.15.0
scipy==1.1.0
seaborn==0.9.0
tensorboard==1.13.1
tensorflow==1.13.1
tqdm==4.28.1
35 changes: 35 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Packaging script for the crystalml distribution."""
from setuptools import setup, find_packages

# Read the README explicitly as UTF-8 so the build does not depend on the
# default encoding of the machine running it.
with open("README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name='crystalml',
    version='0.0.1',
    description='Integrated tool to measure the nucleation rate of protein crystals. ',
    long_description=long_description,
    long_description_content_type="text/markdown",
    url='https://github.com/hlgirard/CrystalML',
    author='Henri-Louis Girard',
    author_email='[email protected]',
    license='GPLv3',
    packages=find_packages(exclude=['tests']),
    install_requires=[
        'tensorflow',
        'matplotlib',
        'numpy',
        'opencv-python',
        'pandas',
        'pillow',
        'plotly',
        'scikit-image',
        # setuptools is a runtime dependency: the package uses pkg_resources
        # to locate the bundled model files.
        'setuptools',
        'scipy',
        'seaborn',
        'tensorboard',
        # joblib is imported by process_image_folder for parallel processing.
        'joblib',
    ],
    scripts=[
        'bin/crystalml',
    ],
    zip_safe=False,
    # Ship the non-Python files selected by MANIFEST.in.
    include_package_data=True,
)
27 changes: 15 additions & 12 deletions src/crystal_processing/process_image_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,37 @@
import re
import logging
from joblib import Parallel, delayed
import pkg_resources
import logging

import numpy as np
import pandas as pd

from tensorflow.keras.models import model_from_json
logging.getLogger('tensorflow').disabled = True

from src.data.utils import select_rectangle, get_date_taken, open_grey_scale_image
from src.data.segment_droplets import crop, segment, extract_indiv_droplets
from src.visualization.image_processing_overlay import save_overlay_image
from src.visualization.process_plotting import plot_crystal_data

def load_model(model_name):
    '''Load a model bundled in the `models` package along with its highest-numbered weights file

    Parameters
    ----------
    model_name: string
        File name of the model JSON description inside the `models` package
        (for example "cnn-simple-model.json")

    Returns
    -------
    model
        Model built by `model_from_json` with the selected weights loaded

    Raises
    ------
    FileNotFoundError
        If the `models` package holds no numbered `.h5` file whose name starts
        with the model's base name
    '''
    # TODO: Move this methods to the models utils package

    model_basename = model_name.split('.')[0]
    model_path = pkg_resources.resource_filename('models', model_name)

    ## Load model from JSON
    with open(model_path, 'r') as json_file:
        loaded_model_json = json_file.read()

    model = model_from_json(loaded_model_json)

    ## Load weights into model
    # Candidate weights are ranked by the first integer found in their file name.
    # Files without any digit cannot be ranked and are skipped instead of crashing.
    numbered_weights = []
    for filename in pkg_resources.resource_listdir('models', '.'):
        if not (filename.startswith(model_basename) and filename.endswith('.h5')):
            continue
        number = re.search(r'\d+', filename)
        if number is None:
            continue
        numbered_weights.append((int(number.group(0)), filename))

    if not numbered_weights:
        raise FileNotFoundError(
            "No weights file (.h5) found in the 'models' package for model '%s'" % model_basename)

    # Stable sort on the number only, then take the last entry (highest number)
    numbered_weights.sort(key=lambda entry: entry[0])
    latest_weights = numbered_weights[-1][1]

    logging.info("Loading model weights: %s", latest_weights)
    model.load_weights(pkg_resources.resource_filename('models', latest_weights))

    return model

Expand Down Expand Up @@ -92,7 +96,7 @@ def process_image(image_path, crop_box, model, save_overlay = False):

return (date_taken, num_drops, num_clear, num_crystal)

def process_image_batch(image_list, crop_box, model_path, save_overlay = False):
def process_image_batch(image_list, crop_box, model_name, save_overlay = False):
'''Process a batch of images and return a list of results
Parameters
Expand All @@ -101,7 +105,7 @@ def process_image_batch(image_list, crop_box, model_path, save_overlay = False):
List of paths to the image to process
crop_box: (minRow, maxRow, minCol, maxCol)
Cropping box to select the region of interest
model_path: string
model_name: string
File name of the model JSON in the `models` package, passed to `load_model`
save_overlay: bool, optional
Save an image with green / red overlays for drops containing crystals / empty to `image_path / overlay`
Expand All @@ -114,7 +118,7 @@ def process_image_batch(image_list, crop_box, model_path, save_overlay = False):
'''

# Instantiate the model
model = load_model(model_path)
model = load_model(model_name)

# Process the data
data = []
Expand All @@ -138,7 +142,7 @@ def process_image_folder(directory, crop_box=None, show_plot=False, save_overlay
print(f"Number of batches: {num_batches}")

# Define the model file name (looked up inside the `models` package)
model_path = "models/cnn-simple-model.json"
model_name = "cnn-simple-model.json"

# Obtain crop box from user if not passed as argument
if not crop_box:
Expand All @@ -148,10 +152,10 @@ def process_image_folder(directory, crop_box=None, show_plot=False, save_overlay
# Process all images from directory in parallel
if num_batches == 0:
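# Process serially -- NOTE(review): range(num_batches) is empty when num_batches == 0, so this branch yields no data; confirm the intended condition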
data = [process_image_batch(image_list[i*batch_size:min([(i+1)*batch_size, num_images])], crop_box, model_path, save_overlay)
data = [process_image_batch(image_list[i*batch_size:min([(i+1)*batch_size, num_images])], crop_box, model_name, save_overlay)
for i in range(num_batches)]
else:
data = Parallel(n_jobs=-2, verbose=10)(delayed(process_image_batch)(image_list[i*batch_size:min([(i+1)*batch_size, num_images])], crop_box, model_path, save_overlay)
data = Parallel(n_jobs=-2, verbose=10)(delayed(process_image_batch)(image_list[i*batch_size:min([(i+1)*batch_size, num_images])], crop_box, model_name, save_overlay)
for i in range(num_batches))

flat_data = [item for sublist in data for item in sublist]
Expand All @@ -165,7 +169,6 @@ def process_image_folder(directory, crop_box=None, show_plot=False, save_overlay
if show_plot:
plot_crystal_data(df, directory)


if __name__ == "__main__":
folder = "notebooks/example_data"

Expand Down

0 comments on commit 3f8050e

Please sign in to comment.