Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions test/diffusion/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared fixtures and constants for diffusion tests."""

import random

import pytest
import torch

# =============================================================================
# Shared Constants
# =============================================================================

GLOBAL_SEED = 42

CPU_TOLERANCES = {"atol": 1e-5, "rtol": 1e-5}
GPU_TOLERANCES = {"atol": 1e-2, "rtol": 5e-2}


# =============================================================================
# Shared Fixtures
# =============================================================================


@pytest.fixture
def deterministic_settings():
"""Set deterministic settings for reproducibility, then restore old state."""
old_cudnn_deterministic = torch.backends.cudnn.deterministic
old_cudnn_benchmark = torch.backends.cudnn.benchmark
old_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
old_cudnn_tf32 = torch.backends.cudnn.allow_tf32
old_random_state = random.getstate()

try:
random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(GLOBAL_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
yield
finally:
torch.backends.cudnn.deterministic = old_cudnn_deterministic
torch.backends.cudnn.benchmark = old_cudnn_benchmark
torch.backends.cuda.matmul.allow_tf32 = old_matmul_tf32
torch.backends.cudnn.allow_tf32 = old_cudnn_tf32
random.setstate(old_random_state)


@pytest.fixture
def tolerances(device):
"""Return tolerances based on the device (CPU vs GPU)."""
if device == "cpu":
return CPU_TOLERANCES
return GPU_TOLERANCES
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
29 changes: 28 additions & 1 deletion test/diffusion/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for diffusion preconditioner tests."""
"""Helper functions for diffusion tests."""

from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
Expand All @@ -28,6 +28,33 @@
DATA_DIR = Path(__file__).parent / "data"


def make_input(
shape: Tuple[int, ...],
seed: int = 42,
device: str = "cpu",
) -> torch.Tensor:
"""
Create a deterministic input tensor using a separate Generator.

Parameters
----------
shape : Tuple[int, ...]
Shape of the output tensor.
seed : int
Random seed for deterministic generation.
device : str
Device to place the tensor on.

Returns
-------
torch.Tensor
A normally-distributed random tensor with the given shape.
"""
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
return torch.randn(*shape, generator=gen).to(device)


def instantiate_model_deterministic(
cls,
seed: int = 0,
Expand Down
Loading