Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,30 @@ This project uses TensorFlow2.0 for image classification tasks.

## How to use
### Requirements
+ **Python 3.x** (My Python version is 3.6.8)<br/>
+ **TensorFlow version: 2.0.0-beta1**<br/>
+ **Python 3.x** (My Python version is 3.8.0)<br/>
+ **TensorFlow version: 2.11.0 <br/>
+ The file directory of the dataset should look like this: <br/>
```
${dataset_root}
|——train
| |——class_name_0
| |——class_name_1
| |——class_name_2
| |——class_name_3
| |——class_dir_0
| | |——image_1.jpg
| | |——image_2.jpg
| | |——image_3.jpg
| | ...
| |——class_dir_1
| |——class_dir_2
| |——class_dir_3
|——valid
| |——class_name_0
| |——class_name_1
| |——class_name_2
| |——class_name_3
| |——class_dir_0
| |——class_dir_1
| |——class_dir_2
| |——class_dir_3
|——test
|——class_name_0
|——class_name_1
|——class_name_2
|——class_name_3
|——class_dir_0
|——class_dir_1
|——class_dir_2
|——class_dir_3
```

### Train
Expand All @@ -40,3 +44,4 @@ The structure of the network is defined in `model_definition.py`, you can change
## References
1. AlexNet : http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
2. VGG : https://arxiv.org/abs/1409.1556
3. Keras : https://keras.io/api/applications/
23 changes: 13 additions & 10 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# some training parameters
EPOCHS = 50
BATCH_SIZE = 8
NUM_CLASSES = 5
image_height = 224
image_width = 224
EPOCHS = 100
BATCH_SIZE = 128
NUM_CLASSES = 6
image_height = 128
image_width = 128
channels = 3
model_dir = "image_classification_model.h5"
train_dir = "dataset/train"
valid_dir = "dataset/valid"
test_dir = "dataset/test"
test_image_path = ""

model_save_name = "EfficientNetB2"
model_dir = "trained_models/salmon_crop_128/"+model_save_name+"/" # = save_dir

train_dir = "/home/mirap/0_DATABASE/IMAS_Salmon/7_SalmonTest/train"
valid_dir = "/home/mirap/0_DATABASE/IMAS_Salmon/7_SalmonTest/valid"
test_dir = "/home/mirap/0_DATABASE/IMAS_Salmon/7_SalmonTest/test"
test_image_path = "/home/mirap/0_DATABASE/IMAS_Salmon/7_SalmonTest/test/5/5_108.jpg"
34 changes: 26 additions & 8 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
import tensorflow as tf
import config
from prepare_data import get_datasets
from sklearn.metrics import classification_report
import numpy as np

train_generator, valid_generator, test_generator, \
train_num, valid_num, test_num= get_datasets()

def eval_model(new_model):
# Load data
train_generator, valid_generator, test_generator, \
train_num, valid_num, test_num= get_datasets()

# Load the model
new_model = tf.keras.models.load_model(config.model_dir)
# Get the accuracy on the test set
loss, acc = new_model.evaluate_generator(test_generator,
steps=test_num // config.BATCH_SIZE)
print("The accuracy on test set is: {:6.3f}%".format(acc*100))
# Get the accuracy on the test set
loss, acc, auc, precision, recall = new_model.evaluate(test_generator,
batch_size=config.BATCH_SIZE,
steps=test_num // config.BATCH_SIZE)
print("result of ",config.model_dir)
print("The accuracy on test set is: {:6.3f}%".format(acc*100))
print("The auc on test set is: {:6.3f}%".format(auc*100))
print("The precision on test set is: {:6.3f}%".format(precision*100))
print("The recall on test set is: {:6.3f}%".format(recall*100))

# Evaluate per class
lables_array = test_generator.classes
predictions = new_model.predict(test_generator)
predictions = np.argmax(predictions, axis=1)
print(classification_report(lables_array, predictions))

if __name__ == '__main__':
# Load the model
new_model = tf.keras.models.load_model(config.model_dir+config.model_save_name+".h5")
eval_model(new_model)
1 change: 0 additions & 1 deletion log/README.md

This file was deleted.

Loading