torch-fidelity is a PyTorch library providing epsilon-exact implementations of generative model evaluation metrics: Inception Score (ISC), Frechet Inception Distance (FID), Kernel Inception Distance (KID), Precision/Recall/F-score (PRC), Perceptual Path Length (PPL), and Monge Inception Distance (MIND). The library prioritizes numerical fidelity with reference TensorFlow implementations.
Current version: 0.4.0 (torch_fidelity/version.py)
torch_fidelity/  # Main package (~4100 lines, 29 modules)
    metrics.py  # Orchestration: calculate_metrics() entry point
    metric_isc.py  # Inception Score
    metric_fid.py  # Frechet Inception Distance
    metric_kid.py  # Kernel Inception Distance (poly/rbf kernels)
    metric_prc.py  # Precision, Recall, F-score
    metric_ppl.py  # Perceptual Path Length
    metric_mind.py  # Monge Inception Distance (sliced 2-Wasserstein)
    feature_extractor_base.py  # Abstract base for feature extractors
    feature_extractor_inceptionv3.py  # InceptionV3 (TF-compatible weights)
    feature_extractor_clip.py  # CLIP feature extraction
    feature_extractor_vgg16.py  # VGG16 feature extraction
    feature_extractor_dinov2.py  # DinoV2 (4 variants)
    generative_model_base.py  # Abstract base for generative models
    generative_model_modulewrapper.py  # PyTorch module wrapper
    generative_model_onnx.py  # ONNX/JIT model support
    sample_similarity_base.py  # Abstract base for similarity metrics
    sample_similarity_lpips.py  # LPIPS implementation
    registry.py  # Plugin registration system
    defaults.py  # All configuration defaults (~60 parameters)
    deprecations.py  # Deprecated parameter handling
    datasets.py  # Dataset wrappers (CIFAR-10/100, STL-10)
    noise.py  # Noise generators (normal, uniform, unit sphere)
    helpers.py  # vassert, vprint, get_kwarg utilities
    utils.py  # Feature extraction, caching, dataset handling
    utils_torch.py  # torch.compile support
    utils_torchvision.py  # TorchVision integration
    interpolate_compat_tensorflow.py  # TF-compatible bilinear interpolation
    fidelity.py  # CLI entry point
    version.py  # Version string
    __init__.py  # Public API exports
tests/  # Test suite (42 test files)
    __init__.py  # TimeTrackingTestCase base class
    run_tests.sh  # Docker-based full test runner
    torch_pure/  # Pure PyTorch tests (batching, feature extractors, misc)
    tf1/  # TF1 reference implementation comparison tests
    clip/  # CLIP feature extractor tests
    prc_ppl_reference/  # PRC and PPL reference tests
    torch_versions_ge_1_13_0/  # PyTorch version compatibility tests
    sphinx_doc/  # Documentation build tests
    aws/  # AWS test harness
examples/  # Example training integrations
    sngan_cifar10.py  # SNGAN training with metric evaluation
doc/  # Sphinx documentation source
.circleci/  # CI configuration
    config.yml  # CircleCI pipeline
    smoke_tests.py  # CI smoke tests
pip install -e .
# or
pip install numpy pillow torch torchvision tqdm

# The package provides a `fidelity` console script
fidelity --input1 /path/to/images1 --input2 /path/to/images2 --fid --isc --kid
# Or run as module
python -m torch_fidelity.fidelity --input1 ... --input2 ... --fid

Smoke tests (CI-style, CPU-only):
CUDA_VISIBLE_DEVICES="" PYTHONPATH=. python .circleci/smoke_tests.py

Full test suite (requires Docker + GPU):
tests/run_tests.sh  # runs all suites except tf1
tests/run_tests.sh --with-tf1  # includes tf1 (requires pre-Ampere GPU)

The full suite uses Docker containers (NGC PyTorch base images) and runs six test flavors sequentially:
- `torch_versions_ge_1_13_0` (CUDA, strict warnings): Backward compatibility testing. Dynamically installs torch 1.13.1, 2.0.1, 2.1.1, and latest, running the full metrics pipeline (ISC/FID/KID/PRC/MIND with Inception-v3, CLIP, DINOv2) against each version. Verifies metric values stay within tight tolerances across versions. The Docker image strips all pre-installed torch packages and installs/uninstalls them per test. Minimum version is 1.13.0 due to CUDA sm_86 (Ampere) support requirement.
- `tf1` (CUDA, no strict warnings): Skipped by default (enable with `--with-tf1`). Legacy TensorFlow 1.14 numerical precision comparison. Uses an old NGC base image (`pytorch:19.02-py3`) with CUDA 10.0 to cross-validate against the original TF reference implementations. Requires a pre-Ampere GPU (V100, T4, etc.); it fails on Ampere+ (sm_86+) due to missing cuBLAS kernels in CUDA 10.0. Warnings are not treated as errors because TF1/legacy code generates many deprecation warnings. These tests were validated up to and including v0.3.0; numerical correctness of v0.4.0 is ensured via the other test suites.
- `torch_pure` (CUDA, strict warnings): Core functionality tests (~135+ tests). Covers batch size independence (verifies metrics are identical across batch sizes in fp32/fp64), `torch.compile()` compatibility, all feature extractors (InceptionV3, VGG16, CLIP, DINOv2), all metrics, FID statistics edge cases, PRC convention correctness, generative model input (ISC/PPL with SNGAN), and LPIPS reference comparison (VGG16-based, validates against NVIDIA's pretrained module). Uses the latest torch from the NGC base image.
- `clip` (CUDA, strict warnings): CLIP feature extractor integration tests. Validates CLIP-based metric computation against the OpenAI CLIP model (pinned commit). Has a separate Docker image because it requires CLIP's dependencies (ftfy, regex, setuptools, clean-fid).
- `prc_ppl_reference` (CUDA, strict warnings): Reference implementation comparison for PRC and PPL metrics. Uses an older NGC base image (`pytorch:21.02-py3`) to compare against the original StyleGAN2-ADA reference implementation, ensuring epsilon-exact numerical agreement.
- `sphinx_doc` (CPU): Documentation build test. Uses the `sphinxdoc/sphinx` base image to verify the Sphinx documentation builds without errors. Runs shell-based tests (`test_*.sh`).
Individual test discovery:
python -W error -m unittest discover -s tests/<flavor> -t . -p 'test_*.py'

Formatting:
black --line-length 120 .

Documentation build:
cd doc/sphinx && make html

- Formatter: Black with line-length 120 (configured in `pyproject.toml`)
- Python version: >= 3.6
- Assertions: Use `vassert(condition, message)` from `helpers.py` instead of bare `assert`
- Verbose output: Use `vprint(verbose, message)`, which prints to stderr
- Configuration access: Use `get_kwarg("name", kwargs)` to read parameters with defaults from `defaults.py`
- Deprecations: Handled via `process_deprecations()` and the `DEPRECATIONS` dict in `deprecations.py`
- Test base class: All tests extend `TimeTrackingTestCase`, which tracks timing and clears the CUDA cache
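The helper conventions listed above can be illustrated with a minimal, self-contained sketch. The three functions below are simplified stand-ins written for this guide, not the code in `helpers.py`; the `DEFAULTS` table, the `run` function, and the choice of `ValueError` are assumptions made for illustration.

```python
import sys

# Hypothetical defaults table standing in for defaults.py (values are illustrative).
DEFAULTS = {"batch_size": 64, "verbose": True}


def vassert(condition, message):
    # Unlike a bare `assert`, this check is not stripped under `python -O`.
    # Raising ValueError is an assumption of this sketch.
    if not condition:
        raise ValueError(message)


def vprint(verbose, message):
    # Progress text goes to stderr so stdout stays machine-readable.
    if verbose:
        print(message, file=sys.stderr)


def get_kwarg(name, kwargs):
    # Caller-supplied values win; otherwise fall back to the defaults table.
    return kwargs.get(name, DEFAULTS[name])


def run(**kwargs):
    # Hypothetical entry point showing the three helpers used together.
    batch_size = get_kwarg("batch_size", kwargs)
    vassert(batch_size > 0, "batch_size must be positive")
    vprint(get_kwarg("verbose", kwargs), f"Using batch_size={batch_size}")
    return batch_size
```

Calling `run()` falls back to the default batch size, while `run(batch_size=0)` fails with a readable message instead of a bare `AssertionError`.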
The registry (registry.py) supports five extension points:
- `register_dataset(name, fn_create)` - Custom datasets
- `register_feature_extractor(name, cls)` - Must subclass `FeatureExtractorBase`
- `register_sample_similarity(name, cls)` - Must subclass `SampleSimilarityBase`
- `register_noise_source(name, fn_generate)` - Noise generators
- `register_interpolation(name, fn_interpolate)` - Interpolation methods
Pre-registered components are in registry.py lines 143-199.
Input → Feature Extraction (cached) → Metric Computation
- Unary metrics (ISC, PPL): require only `input1`
- Binary metrics (FID, KID, PRC, MIND): require `input1` and `input2`
- PRC convention: `input1` = generated (evaluated), `input2` = real (reference). Precision = fraction of generated samples in real manifold; recall = fraction of real samples in generated manifold
- Feature extraction results are cached to disk when `cache=True`
- FID has a shortcut path when statistics are cached but features are not
- When using default feature extractors, ISC/FID/KID/MIND use InceptionV3 and PRC uses VGG16; if both groups are requested, two separate feature extraction passes run automatically
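The arity and default-extractor rules in the list above can be sketched as a small planner. The `plan` function and its data tables are hypothetical and exist only to make the rules concrete; the library's own orchestration in `metrics.py` may be organized differently.

```python
# Metric arity and default extractors as described in the list above.
UNARY = {"isc", "ppl"}
BINARY = {"fid", "kid", "prc", "mind"}
DEFAULT_EXTRACTOR = {
    "isc": "inception-v3-compat",
    "fid": "inception-v3-compat",
    "kid": "inception-v3-compat",
    "mind": "inception-v3-compat",
    "prc": "vgg16",
}


def plan(metrics, input1=None, input2=None):
    """Hypothetical planner: validate inputs and group metrics by extractor."""
    metrics = set(metrics)
    unknown = metrics - UNARY - BINARY
    if unknown:
        raise ValueError(f"Unknown metrics: {sorted(unknown)}")
    if input1 is None:
        raise ValueError("input1 is required by every metric")
    if metrics & BINARY and input2 is None:
        raise ValueError(f"input2 is required by {sorted(metrics & BINARY)}")
    passes = {}
    for metric in sorted(metrics):
        # PPL is not listed with a default feature extractor, so it has no entry.
        extractor = DEFAULT_EXTRACTOR.get(metric)
        if extractor is not None:
            passes.setdefault(extractor, []).append(metric)
    return passes
```

For example, requesting FID, KID, and PRC together yields two groups, one per extractor, which corresponds to the two extraction passes mentioned above.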
| Name | Class | Default For |
|---|---|---|
| `inception-v3-compat` | `FeatureExtractorInceptionV3` | ISC, FID, KID, MIND |
| `vgg16` | `FeatureExtractorVGG16` | PRC |
| `clip-vit-b-32` (and other CLIP variants) | `FeatureExtractorCLIP` | - |
| `dinov2-vit-{s,b,l,g}-14` | `FeatureExtractorDinoV2` | - |
The input1/input2 parameters accept:
- Registered dataset names (e.g., `cifar10-train`, `stl10-test`)
- Directory paths containing images
- ONNX/PTH model file paths
- `torch.utils.data.Dataset` instances
- `GenerativeModelBase` instances
- Numerical fidelity: The InceptionV3 implementation uses TF-compatible weights and a custom bilinear interpolation (`interpolate_compat_tensorflow.py`) to match TensorFlow output to machine precision
- kwargs-based API: All configuration flows through `**kwargs` checked against `defaults.py`; no dataclass or typed config objects
- No scipy dependency: Matrix square root for FID is implemented in pure PyTorch
- Caching: Multi-level caching (features and FID statistics) to avoid redundant computation
- Verbose to stderr: All progress output goes to stderr so stdout can be parsed as JSON
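For context on why FID needs a matrix square root at all: the Fréchet distance between two Gaussians fitted to the two feature sets has the standard closed form below. The symbols $\mu_i$ and $\Sigma_i$ (feature mean and covariance of input $i$) are notation chosen here, not names taken from the library.

```latex
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2} \right)
```

The $(\Sigma_1 \Sigma_2)^{1/2}$ term is the part that reference implementations typically delegate to `scipy.linalg.sqrtm`.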
- CircleCI: Python 3.11.7, large resource class, CPU-only smoke tests
- Runs on every push and weekly (Mondays) on master
- Smoke tests validate all metrics against known reference values with tight tolerances
- Tests use `psutil` for memory monitoring
- Smoke test pattern: all tests run via `_run_fidelity_command()`, which wraps `subprocess.run`. Tests must not import `torch` or library internals directly; instead, run `python3 -m torch_fidelity.fidelity` or `python3 -c "..."` as a subprocess and assert on JSON output
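The subprocess-and-JSON pattern can be sketched as follows. `run_and_parse_json` is a hypothetical helper, not the actual `_run_fidelity_command()`, and the child process is a stand-in that only mimics the stream discipline (progress on stderr, JSON on stdout) so the example runs without the library installed; the key `example_metric` is made up.

```python
import json
import subprocess
import sys


def run_and_parse_json(cmd):
    """Hypothetical helper: run a command and return its stdout parsed as JSON."""
    result = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,  # decode bytes to str
        check=True,  # raise CalledProcessError on a non-zero exit code
    )
    # Progress and warnings are expected on stderr, so only stdout is parsed.
    return json.loads(result.stdout)


# Stand-in child process: human-readable progress on stderr, a JSON object on stdout.
CHILD = (
    "import json, sys; "
    "print('Extracting features...', file=sys.stderr); "
    "print(json.dumps({'example_metric': 1.5}))"
)

metrics = run_and_parse_json([sys.executable, "-c", CHILD])
assert abs(metrics["example_metric"] - 1.5) < 1e-6
```

Because the test process never imports `torch` itself, a crash or memory spike in the child shows up as a failed subprocess rather than taking down the test runner.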
- The `feature_extractor_compile` option is experimental and may affect numerical precision
- KID can produce negative values (this is mathematically expected, because KID is an unbiased estimate of a squared distance)
- Lossy image formats (jpg/jpeg) trigger warnings since they affect metric precision
- When modifying feature extractors, ensure the output layer names remain consistent as they are used as cache keys
- The InceptionV3 implementation intentionally differs from torchvision's to maintain TF compatibility
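The note on negative KID values can be made concrete with a toy computation. The sketch below is a pure-Python illustration of the unbiased squared-MMD estimator with the cubic polynomial kernel from the KID definition; it is not the code in `metric_kid.py`, and it omits the subset sampling and averaging a full KID computation performs. The function names and the toy features are assumptions.

```python
def poly_kernel(x, y, degree=3, coef0=1.0):
    # Cubic polynomial kernel with gamma = 1 / d, where d is the feature dimension.
    d = len(x)
    dot = sum(a * b for a, b in zip(x, y))
    return (dot / d + coef0) ** degree


def mmd2_unbiased(xs, ys):
    """Unbiased estimate of squared MMD between two feature sets."""
    m, n = len(xs), len(ys)
    # Within-set sums skip the diagonal (i == j); this makes the estimator unbiased.
    k_xx = sum(poly_kernel(xs[i], xs[j]) for i in range(m) for j in range(m) if i != j)
    k_yy = sum(poly_kernel(ys[i], ys[j]) for i in range(n) for j in range(n) if i != j)
    k_xy = sum(poly_kernel(x, y) for x in xs for y in ys)
    return k_xx / (m * (m - 1)) + k_yy / (n * (n - 1)) - 2.0 * k_xy / (m * n)


# Two identical toy "feature sets" with one-dimensional features.
features = [[1.0], [-1.0]]
value = mmd2_unbiased(features, features)
print(value)  # -8.0
```

Dropping the diagonal terms removes the bias but also the guarantee of non-negativity, so small or identical sample sets can land below zero.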
See CONTRIBUTING.md.