Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .benchmark import Benchmark

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔🤔could we have a config driven interface? :P

26 changes: 26 additions & 0 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import abc
import deepinv as dinv
from typing import Any


class Benchmark(abc.ABC):
r"""
Abstract base class for benchmarks

All of the benchmarks should inherit this class and implement the `run` method.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should have dataset, metrics in the init here


@abc.abstractmethod
def run(
self,
model: dinv.models.Denoiser | dinv.models.Reconstructor,
*,
device: torch.device | str = torch.device("cpu")
) -> Any:
"""Run the benchmark on the given model

:param dinv.models.Denoiser | dinv.models.Reconstructor model: The model to benchmark
:param torch.device | str device: The device to run the benchmark on (default: `"cpu"`)
:return: (`Any`) The result of the benchmark
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could write the function here with deepinv.test and avoid re-writing the for-loop for each benchmark

pass
82 changes: 82 additions & 0 deletions benchmarks/denoising/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations
import torch
import torch.utils.data
import deepinv as dinv
import torchvision.transforms as transforms
import numpy as np
from tqdm.auto import tqdm
from typing import Any
import pandas as pd # noqa: TID253
import benchmarks


class DenoisingBenchmark(benchmarks.Benchmark):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure we need to redefine a new class,
why not something like

Benchmark(dataset=dinv.datasets.Urban100HR("data/Urban100", download=True, transform=transforms.ToTensor()), device=device, )

r"""
Benchmark for Gaussian Denoising on Urban100 dataset

.. note::

The noise standard deviation is set to 25/255 for images normalized between 0 and 1.
"""

@staticmethod
def run(
model: dinv.models.Denoiser,
*,
device: torch.device | str = torch.device("cpu"),
) -> Any:
"""Run the benchmark on the given model"""
dataset = dinv.datasets.Urban100HR(
"data/Urban100", download=True, transform=transforms.ToTensor()
)

rng = torch.Generator(device)
physics = dinv.physics.Denoising(
dinv.physics.GaussianNoise(sigma=25 / 255, rng=rng)
).to(device)

psnr_fn = dinv.metric.PSNR(min_pixel=0.0, max_pixel=1.0).to(device)

dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, num_workers=1, prefetch_factor=1
)

psnrs = []
model = model.to(device).eval()
for k, x in enumerate(tqdm(dataloader)):
x = x.to(device)
y = physics(x, seed=k)

with torch.no_grad():
x_hat = model(y, physics.noise_model.sigma)

# Clip and quantize
x_hat = x_hat.mul(255.0).round().div(255.0).clamp(0.0, 1.0)
x = x.mul(255.0).round().div(255.0).clamp(0.0, 1.0)

psnr = psnr_fn(x_hat, x).item()
psnrs.append(psnr)
if k >= 1:
break

return np.mean(psnrs), np.std(psnrs)


if __name__ == "__main__":
benchmark = DenoisingBenchmark()
models = [dinv.models.DRUNet(), dinv.models.Restormer()]

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

rows = []
for model in models:
model_name = type(model).__name__
psnr_avg, psnr_std = benchmark.run(model, device=device)
rows.append(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could move this logic to the base class since all benchmarks will use this

{"model_name": model_name, "psnr_avg": psnr_avg, "psnr_std": psnr_std}
)

df = pd.DataFrame(rows)
out_path = "./denoising.csv"
df.to_csv(out_path, index=False)
print(f"Benchmark results saved to {out_path}")