# evaluate.py
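"""Evaluate a trained Fog Removal GAN on the test split.

Restores the latest generator checkpoint from <model_name>/checkpoints, runs the
generator over the test set, reports PSNR/SSIM statistics, and optionally writes
the dehazed outputs to <model_name>/test_results.

Example invocation (the model name and data path below are placeholders, not
files shipped with the repo):

    python evaluate.py --model_name my_fog_gan --data_root ./data --save_results
"""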
import tensorflow as tf
import argparse
from pathlib import Path
from model import FogRemovalGAN
from training import FogDataset


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate Fog Removal GAN on test set')
    parser.add_argument('--model_name', type=str, required=True,
                        help='Name of the trained model')
    parser.add_argument('--data_root', type=str, required=True,
                        help='Root directory containing test folder')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for testing')
    parser.add_argument('--save_results', action='store_true',
                        help='Save processed images')
    # Add other model architecture parameters that match your training
    parser.add_argument('--growth_rate', type=int, default=12)
    parser.add_argument('--layers', type=int, default=4)
    parser.add_argument('--D_filters', type=int, default=64)
    return parser.parse_args()


def evaluate_model(args):
    # Create dataset
    dataset = FogDataset(
        root_dir=args.data_root,
        batch_size=args.batch_size,
        image_size=(256, 256)  # adjust if using a different size
    )
    test_dataset = dataset.get_dataset('test')

    # Initialize model
    model = FogRemovalGAN(args)

    # Load weights
    checkpoint_dir = Path(args.model_name) / 'checkpoints'
    checkpoint = tf.train.Checkpoint(generator=model.generator)
    latest_checkpoint = tf.train.latest_checkpoint(str(checkpoint_dir))
    if latest_checkpoint:
        checkpoint.restore(latest_checkpoint).expect_partial()
        print(f"Loaded weights from {latest_checkpoint}")
    else:
        raise ValueError(f"No checkpoint found in {checkpoint_dir}")
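    # To restore a specific checkpoint instead of the latest one, you could point
    # restore() at an explicit prefix (sketch only; the 'ckpt-10' name assumes the
    # training script saved checkpoints with that naming scheme):
    #   checkpoint.restore(str(checkpoint_dir / 'ckpt-10')).expect_partial()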

    # Initialize metrics
    test_metrics = {
        'test_psnr': tf.keras.metrics.Mean(),
        'test_ssim': tf.keras.metrics.Mean()
    }

    # Process test set (Progbar accepts None as the target when the dataset's
    # cardinality is unknown)
    total_test_steps = tf.data.experimental.cardinality(test_dataset).numpy()
    if total_test_steps < 0:
        total_test_steps = None
    progress_bar = tf.keras.utils.Progbar(total_test_steps,
                                          stateful_metrics=['PSNR', 'SSIM'])

    # Create the output directory if results are being saved
    if args.save_results:
        output_dir = Path(args.model_name) / 'test_results'
        output_dir.mkdir(parents=True, exist_ok=True)

    # Evaluate each image
    all_psnrs = []
    all_ssims = []
    for step, (input_image, target_image) in enumerate(test_dataset):
        # Generate output
        generated_image = model.generator(input_image, training=False)

        # Convert images from [-1, 1] to the [0, 255] range
        target_image = (target_image + 1) * 127.5
        generated_image = (generated_image + 1) * 127.5

        # Calculate metrics (reduced to a scalar per batch so the float() casts
        # below also work when batch_size > 1)
        psnr = tf.reduce_mean(tf.image.psnr(target_image, generated_image, max_val=255.0))
        ssim = tf.reduce_mean(tf.image.ssim(target_image, generated_image, max_val=255.0))
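        # For reference, tf.image.psnr computes 10 * log10(max_val^2 / MSE) per
        # image. A quick manual cross-check (sketch, not part of the evaluation path):
        #   mse = tf.reduce_mean(tf.square(target_image - generated_image))
        #   manual_psnr = 10.0 * tf.math.log(255.0 ** 2 / mse) / tf.math.log(10.0)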
        # Store per-step metrics for the summary statistics below
        all_psnrs.append(float(psnr))
        all_ssims.append(float(ssim))

        # Update running averages
        test_metrics['test_psnr'].update_state(psnr)
        test_metrics['test_ssim'].update_state(ssim)

        # Update progress bar
        progress_bar.update(
            step + 1,
            [('PSNR', float(psnr)),
             ('SSIM', float(ssim))]
        )

        # Save results if requested (only the first image of each batch is written)
        if args.save_results:
            output_image = tf.cast(tf.clip_by_value(generated_image[0], 0.0, 255.0), tf.uint8)
            tf.io.write_file(
                str(output_dir / f'output_{step:04d}.png'),
                tf.image.encode_png(output_image)
            )

    # Calculate and print final results
    results = {
        'Average PSNR': float(test_metrics['test_psnr'].result()),
        'Average SSIM': float(test_metrics['test_ssim'].result()),
        'Std PSNR': float(tf.math.reduce_std(all_psnrs)),
        'Std SSIM': float(tf.math.reduce_std(all_ssims)),
        'Min PSNR': float(tf.math.reduce_min(all_psnrs)),
        'Min SSIM': float(tf.math.reduce_min(all_ssims)),
        'Max PSNR': float(tf.math.reduce_max(all_psnrs)),
        'Max SSIM': float(tf.math.reduce_max(all_ssims))
    }

    print("\nTest Set Evaluation Results:")
    for metric_name, value in results.items():
        print(f"{metric_name}: {value:.4f}")

    return results


if __name__ == '__main__':
    args = parse_args()
    evaluate_model(args)
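
# Minimal programmatic usage sketch (assumes the same checkpoint/data layout the
# CLI expects; 'my_fog_gan' and './data' are placeholder paths):
#
#   from argparse import Namespace
#   from evaluate import evaluate_model
#   results = evaluate_model(Namespace(
#       model_name='my_fog_gan', data_root='./data', batch_size=1,
#       save_results=False, growth_rate=12, layers=4, D_filters=64))
#   print(results['Average PSNR'])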