diff --git a/.gitignore b/.gitignore index 068aeb8..5d48f53 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,7 @@ model/musiq/musiq_spaq_ckpt-358bb6af.pth VBench video_detection watermark/gm/ckpts/model_final.pth -model_from_hf/ \ No newline at end of file +model_from_hf/ +assets +report.html +.coverage \ No newline at end of file diff --git a/detection/__init__.py b/detection/__init__.py index 8dd2c33..3b2a755 100644 --- a/detection/__init__.py +++ b/detection/__init__.py @@ -12,24 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Detection module for MarkDiffusion. +"""Detection module for watermark verification and extraction.""" -This module provides detection functionality for various watermarking algorithms. -""" - -__all__ = [ - 'base', - 'gm', - 'gs', - 'prc', - 'ri', - 'robin', - 'seal', - 'sfw', - 'tr', - 'videomark', - 'videoshield', - 'wind', -] +from .base import BaseDetector +__all__ = ["BaseDetector"] diff --git a/evaluation/__init__.py b/evaluation/__init__.py index c387b94..0ebac98 100644 --- a/evaluation/__init__.py +++ b/evaluation/__init__.py @@ -12,15 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Evaluation module for MarkDiffusion. +"""Evaluation module for watermark quality and robustness assessment.""" -This module provides tools for evaluating watermarking algorithms, -including quality analysis and detection rate calculations. -""" +from .dataset import BaseDataset -__all__ = [ - 'dataset', - 'pipelines', - 'tools', -] +__all__ = ["BaseDataset"] diff --git a/evaluation/pipelines/image_quality_analysis.py b/evaluation/pipelines/image_quality_analysis.py index 7d558e5..cc2f209 100644 --- a/evaluation/pipelines/image_quality_analysis.py +++ b/evaluation/pipelines/image_quality_analysis.py @@ -29,6 +29,20 @@ ) import lpips + +class SilentProgressBar: + """A silent progress bar wrapper that supports set_description but shows no output.""" + + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + return iter(self.iterable) + + def set_description(self, desc): + """No-op for silent mode.""" + pass + class QualityPipelineReturnType(Enum): """Return type of the image quality analysis pipeline.""" FULL = auto() @@ -120,7 +134,7 @@ def _get_progress_bar(self, iterable): """Return an iterable possibly wrapped with a progress bar.""" if self.show_progress: return tqdm(iterable, desc="Processing", leave=True) - return iterable + return SilentProgressBar(iterable) def _get_prompt(self, index: int) -> str: """Get prompt from dataset.""" diff --git a/evaluation/pipelines/video_quality_analysis.py b/evaluation/pipelines/video_quality_analysis.py index e7c7d0d..82554a0 100644 --- a/evaluation/pipelines/video_quality_analysis.py +++ b/evaluation/pipelines/video_quality_analysis.py @@ -11,6 +11,20 @@ import numpy as np from tqdm import tqdm + +class SilentProgressBar: + """A silent progress bar wrapper that supports set_description but shows no output.""" + + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + return iter(self.iterable) + + def set_description(self, desc): + """No-op for silent mode.""" + pass + class QualityPipelineReturnType(Enum): """Return type of the image quality analysis pipeline.""" FULL = auto() @@ -97,7 +111,7 @@ def _get_progress_bar(self, iterable): """Return an iterable possibly wrapped with a progress bar.""" if self.show_progress: return tqdm(iterable, 
desc="Processing", leave=True) - return iterable + return SilentProgressBar(iterable) def _get_prompt(self, index: int) -> str: """Get prompt from dataset.""" diff --git a/evaluation/tools/image_editor.py b/evaluation/tools/image_editor.py index 17d9ed0..08a1efb 100644 --- a/evaluation/tools/image_editor.py +++ b/evaluation/tools/image_editor.py @@ -212,8 +212,8 @@ def _add_salt_pepper_noise(self, img_array, amount): pepper_coords_y = np.random.randint(0, h, num_pepper) pepper_coords_x = np.random.randint(0, w, num_pepper) noisy[pepper_coords_y, pepper_coords_x] = 0 - - return noisy + + return np.clip(noisy, 0, 255).astype(np.uint8) def _add_poisson_noise(self, img_array): vals = len(np.unique(img_array)) diff --git a/evaluation/tools/video_quality_analyzer.py b/evaluation/tools/video_quality_analyzer.py index dc78683..486a7e2 100644 --- a/evaluation/tools/video_quality_analyzer.py +++ b/evaluation/tools/video_quality_analyzer.py @@ -18,6 +18,9 @@ except ImportError: BICUBIC = Image.BICUBIC +from pathlib import Path +CACHE_DIR = Path.home() / ".cache" / "markdiffusion" + def dino_transform_Image(n_px): """DINO transform for PIL Images.""" @@ -27,6 +30,35 @@ def dino_transform_Image(n_px): Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) +import requests +# def download_file(url, dest): +# resp = requests.get(url, timeout=20) +# resp.raise_for_status() +# with open(dest, "wb") as f: +# f.write(resp.content) +# print(f"Downloaded file: {dest}") + +def download_recursive(api_url, local_dir): + local_dir.mkdir(parents=True, exist_ok=True) + listing = requests.get(api_url, timeout=20).json() + if isinstance(listing, dict) and listing.get("message", "").startswith("API rate limit"): + raise RuntimeError("GitHub API rate-limited. Try again later.") + + for item in listing: + name = item["name"] + dest = local_dir / name + + if item["type"] == "file": + if not dest.exists(): + print(f"Downloading file: {item['download_url']}") + resp = requests.get(item["download_url"], timeout=20) + resp.raise_for_status() + with open(dest, "wb") as f: + f.write(resp.content) + + elif item["type"] == "dir": + dest.mkdir(exist_ok=True) + download_recursive(item["url"], dest) class VideoQualityAnalyzer: """Video quality analyzer base class.""" @@ -161,6 +193,54 @@ def analyze(self, frames: List[Image.Image]) -> float: else: return 1.0 +from contextlib import contextmanager + +@contextmanager +def isolated_import_context(code_dir, isolated_prefixes, prefix_tag=None): + """Context manager for isolated module imports to avoid conflicts with main project. + + Args: + code_dir: External code directory to add to sys.path + isolated_prefixes: List of module name prefixes to isolate (e.g., ['utils', 'networks']) + prefix_tag: Tag to prefix external modules with after loading (default: code_dir.name + '_ext_') + + Example: + with isolated_import_context(CODE_DIR, ['utils', 'networks']): + # imports here will use CODE_DIR's modules + spec = importlib.util.spec_from_file_location("entry", CODE_DIR / "main.py") + ... 
+ # after exiting, main project's 'utils' is restored + """ + import sys + + if prefix_tag is None: + prefix_tag = code_dir.name + '_ext_' + + original_path = sys.path.copy() + saved_modules = {} + + # Remove potentially conflicting modules + for prefix in isolated_prefixes: + for mod_name in list(sys.modules.keys()): + if mod_name == prefix or mod_name.startswith(prefix + '.'): + saved_modules[mod_name] = sys.modules.pop(mod_name) + + sys.path.insert(0, str(code_dir)) + + try: + yield + finally: + sys.path[:] = original_path + + # Rename external modules with prefix tag to avoid future conflicts + for prefix in isolated_prefixes: + for mod_name in list(sys.modules.keys()): + if mod_name == prefix or mod_name.startswith(prefix + '.'): + if mod_name not in saved_modules: + sys.modules[prefix_tag + mod_name] = sys.modules.pop(mod_name) + + # Restore main project modules + sys.modules.update(saved_modules) class MotionSmoothnessAnalyzer(VideoQualityAnalyzer): """Analyzer for evaluating motion smoothness in videos using AMT-S model. @@ -174,25 +254,52 @@ class MotionSmoothnessAnalyzer(VideoQualityAnalyzer): with smoother motion resulting in higher scores. """ - def __init__(self, model_path: str = "model/amt/amt-s.pth", + # Remote sources + + # Local cache dir OUTSIDE project (to avoid vendoring) + from pathlib import Path + CODE_DIR = CACHE_DIR / "amt" + GH_API = "https://api.github.com/repos/MCG-NKU/AMT/contents" + WEIGHT_URL = 'https://hf-mirror.com/lalala125/AMT/resolve/main/amt-s.pth' + WEIGHT_PATH = CACHE_DIR / "amt" / "amt-s.pth" + + def __init__(self, device: str = "cuda", niters: int = 1): """Initialize the MotionSmoothnessAnalyzer. - + Args: - model_path: Path to the AMT-S model checkpoint device: Device to run the model on ('cuda' or 'cpu') niters: Number of interpolation iterations (default: 1) """ self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.niters = niters - + # Initialize model parameters self._initialize_params() - + + # Ensure model files exist → download if needed + self._ensure_files() + # Load AMT-S model - self.model = self._load_amt_model(model_path) + self.model = self._load_amt_model(str(self.WEIGHT_PATH)) self.model.eval() self.model.to(self.device) + + def _ensure_files(self): + """Download architecture and weight files if they do not exist.""" + self.CODE_DIR.mkdir(parents=True, exist_ok=True) + + # Check if key file exists, not just directory + if not (self.CODE_DIR / "networks" / "AMT-S.py").exists(): + print("Downloading AMT architecture...") + download_recursive(self.GH_API, self.CODE_DIR) + + if not self.WEIGHT_PATH.exists(): + print("Downloading AMT-S weights...") + self._download(self.WEIGHT_URL, self.CODE_DIR) + + def _download(self, url: str, local_dir: Path): + subprocess.run(['wget', url, '-P', local_dir], check=True) def _initialize_params(self): """Initialize parameters for video processing.""" @@ -213,35 +320,41 @@ def _initialize_params(self): def _load_amt_model(self, model_path: str): """Load AMT-S model. 
- + Args: model_path: Path to the model checkpoint - + Returns: Loaded AMT-S model """ - # Import AMT-S model (note the hyphen in filename) - import sys import importlib.util - - # Load the module with hyphen in filename - spec = importlib.util.spec_from_file_location("amt_s", "model/amt/networks/AMT-S.py") - amt_s_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(amt_s_module) - Model = amt_s_module.Model - - # Create model with default parameters - model = Model( - corr_radius=3, - corr_lvls=4, - num_flows=3 - ) - - # Load checkpoint - if os.path.exists(model_path): - ckpt = torch.load(model_path, map_location="cpu", weights_only=False) - model.load_state_dict(ckpt['state_dict']) - + + isolated_prefixes = ["utils", "networks"] + + with isolated_import_context(self.CODE_DIR, isolated_prefixes, prefix_tag="_amt_ext_"): + # Load AMT-S entry module + entry_py = self.CODE_DIR / "networks" / "AMT-S.py" + spec = importlib.util.spec_from_file_location("_amt_s_entry", entry_py) + amt_s_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(amt_s_module) + Model = amt_s_module.Model + + # Also load utils functions needed later (store as instance attributes) + utils_py = self.CODE_DIR / "utils" / "utils.py" + spec_utils = importlib.util.spec_from_file_location("_amt_utils", utils_py) + utils_module = importlib.util.module_from_spec(spec_utils) + spec_utils.loader.exec_module(utils_module) + self._amt_img2tensor = utils_module.img2tensor + self._amt_tensor2img = utils_module.tensor2img + self._amt_check_dim_and_resize = utils_module.check_dim_and_resize + self._amt_InputPadder = utils_module.InputPadder + + # Create model + model = Model(corr_radius=3, corr_lvls=4, num_flows=3) + if os.path.exists(model_path): + ckpt = torch.load(model_path, map_location="cpu", weights_only=False) + model.load_state_dict(ckpt['state_dict']) + return model def _extract_frames(self, frames: List[Image.Image], start_from: int = 0) -> List[np.ndarray]: @@ -263,39 +376,36 @@ def _extract_frames(self, frames: List[Image.Image], start_from: int = 0) -> Lis def _img2tensor(self, img: np.ndarray) -> torch.Tensor: """Convert numpy image to tensor. - + Args: img: Image as numpy array (H, W, C) - + Returns: Image tensor (1, C, H, W) """ - from model.amt.utils.utils import img2tensor - return img2tensor(img) - + return self._amt_img2tensor(img) + def _tensor2img(self, tensor: torch.Tensor) -> np.ndarray: """Convert tensor to numpy image. - + Args: tensor: Image tensor (1, C, H, W) - + Returns: Image as numpy array (H, W, C) """ - from model.amt.utils.utils import tensor2img - return tensor2img(tensor) - + return self._amt_tensor2img(tensor) + def _check_dim_and_resize(self, tensor_list: List[torch.Tensor]) -> List[torch.Tensor]: """Check dimensions and resize tensors if needed. - + Args: tensor_list: List of image tensors - + Returns: List of resized tensors """ - from model.amt.utils.utils import check_dim_and_resize - return check_dim_and_resize(tensor_list) + return self._amt_check_dim_and_resize(tensor_list) def _calculate_scale(self, h: int, w: int) -> float: """Calculate scaling factor based on available VRAM. @@ -314,23 +424,21 @@ def _calculate_scale(self, h: int, w: int) -> float: def _interpolate_frames(self, inputs: List[torch.Tensor], scale: float) -> List[torch.Tensor]: """Interpolate frames using AMT-S model. 
- + Args: inputs: List of input frame tensors scale: Scaling factor for processing - + Returns: List of interpolated frame tensors """ - from model.amt.utils.utils import InputPadder - # Pad inputs padding = int(16 / scale) - padder = InputPadder(inputs[0].shape, padding) + padder = self._amt_InputPadder(inputs[0].shape, padding) inputs = padder.pad(*inputs) - + # Perform interpolation for specified iterations - for i in range(self.niters): + for _ in range(self.niters): outputs = [inputs[0]] for in_0, in_1 in zip(inputs[:-1], inputs[1:]): in_0 = in_0.to(self.device) @@ -339,7 +447,7 @@ def _interpolate_frames(self, inputs: List[torch.Tensor], scale: float) -> List[ imgt_pred = self.model(in_0, in_1, self.embt, scale_factor=scale, eval=True)['imgt_pred'] outputs += [imgt_pred.cpu(), in_1.cpu()] inputs = outputs - + # Unpad outputs outputs = padder.unpad(*outputs) return outputs @@ -427,62 +535,116 @@ def analyze(self, frames: List[Image.Image]) -> float: class DynamicDegreeAnalyzer(VideoQualityAnalyzer): """Analyzer for evaluating dynamic degree (motion intensity) in videos using RAFT optical flow. - + This analyzer measures the amount and intensity of motion in videos by: 1. Computing optical flow between consecutive frames using RAFT 2. Calculating flow magnitude for each pixel 3. Extracting top 5% highest flow magnitudes 4. Determining if video has sufficient dynamic motion based on thresholds - + The score represents whether the video contains dynamic motion (1.0) or is mostly static (0.0). """ - - def __init__(self, model_path: str = "model/raft/raft-things.pth", + from pathlib import Path + + # GitHub API endpoint to list directory content + GH_API = "https://api.github.com/repos/princeton-vl/RAFT/contents/core" + + # Local cache directory (project-external) + CODE_DIR = CACHE_DIR / "raft" / "core" + WEIGHT_DIR = CACHE_DIR / "raft" + + def __init__(self, model_name: str = "raft-things.pth", device: str = "cuda", sample_fps: int = 8): """Initialize the DynamicDegreeAnalyzer. 
- + Args: - model_path: Path to the RAFT model checkpoint + model_name: Name of the RAFT model checkpoint device: Device to run the model on ('cuda' or 'cpu') sample_fps: Target FPS for frame sampling (default: 8) """ self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.sample_fps = sample_fps - + self.weight_path = self.WEIGHT_DIR / "models" / model_name + + self._ensure_files() + # Load RAFT model - self.model = self._load_raft_model(model_path) + self.model = self._load_raft_model(str(self.weight_path)) self.model.eval() self.model.to(self.device) - + + def _ensure_files(self): + """Ensure RAFT model code + weights available locally""" + self.CODE_DIR.mkdir(parents=True, exist_ok=True) + + # Check if key file exists, not just directory + if not (self.CODE_DIR / "raft.py").exists(): + print("Downloading RAFT architecture...") + download_recursive(self.GH_API, self.CODE_DIR) + + # Download weights + if not self.weight_path.exists(): + print("Downloading RAFT weights...") + self._download_file() + + def _download_file(self): + self.WEIGHT_DIR.mkdir(parents=True, exist_ok=True) + wget_command = ['wget', '-P', str(self.WEIGHT_DIR), 'https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip'] + unzip_command = ['unzip', '-o', '-d', str(self.WEIGHT_DIR), str(self.WEIGHT_DIR / 'models.zip')] + remove_command = ['rm', '-f', str(self.WEIGHT_DIR / 'models.zip')] + + subprocess.run(wget_command, check=True) + subprocess.run(unzip_command, check=True) + subprocess.run(remove_command, check=True) + def _load_raft_model(self, model_path: str): """Load RAFT optical flow model. - + Args: model_path: Path to the model checkpoint - + Returns: Loaded RAFT model """ - from model.raft.core.raft import RAFT - from easydict import EasyDict as edict - - # Configure RAFT arguments - args = edict({ - "model": model_path, - "small": False, - "mixed_precision": False, - "alternate_corr": False - }) - - # Create and load model - model = RAFT(args) - - if os.path.exists(model_path): - ckpt = torch.load(model_path, map_location="cpu") - # Remove 'module.' prefix if present (from DataParallel) - new_ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()} - model.load_state_dict(new_ckpt) - + import importlib.util + + # RAFT internal modules that may conflict with main project + isolated_prefixes = ["utils", "update", "extractor", "corr", "raft", "alt_cuda_corr"] + + with isolated_import_context(self.CODE_DIR, isolated_prefixes, prefix_tag="_raft_ext_"): + # Load RAFT entry module + raft_py = self.CODE_DIR / "raft.py" + spec = importlib.util.spec_from_file_location("_raft_entry", raft_py) + raft_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(raft_module) + RAFT = raft_module.RAFT + + # Load InputPadder for later use + utils_py = self.CODE_DIR / "utils" / "utils.py" + spec_utils = importlib.util.spec_from_file_location("_raft_utils", utils_py) + utils_module = importlib.util.module_from_spec(spec_utils) + spec_utils.loader.exec_module(utils_module) + self._raft_InputPadder = utils_module.InputPadder + + from easydict import EasyDict as edict + + # Configure RAFT arguments + args = edict({ + "model": model_path, + "small": False, + "mixed_precision": False, + "alternate_corr": False + }) + + # Create and load model + model = RAFT(args) + + if os.path.exists(model_path): + ckpt = torch.load(model_path, map_location="cpu") + # Remove 'module.' 
prefix if present (from DataParallel) + new_ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()} + model.load_state_dict(new_ckpt) + return model def _extract_frames_for_flow(self, frames: List[Image.Image], target_fps: int = 8) -> List[torch.Tensor]: @@ -617,13 +779,12 @@ def analyze(self, frames: List[Image.Image]) -> float: with torch.no_grad(): for frame1, frame2 in zip(prepared_frames[:-1], prepared_frames[1:]): # Pad frames if necessary - from model.raft.core.utils_core.utils import InputPadder - padder = InputPadder(frame1.shape) + padder = self._raft_InputPadder(frame1.shape) frame1_padded, frame2_padded = padder.pad(frame1, frame2) - + # Compute optical flow _, flow_up = self.model(frame1_padded, frame2_padded, iters=20, test_mode=True) - + # Calculate flow magnitude score magnitude_score = self._compute_flow_magnitude(flow_up) flow_scores.append(magnitude_score) @@ -778,7 +939,7 @@ class ImagingQualityAnalyzer(VideoQualityAnalyzer): The score represents the quality of the video (higher is better). """ - def __init__(self, model_path: str = "model/musiq/musiq_spaq_ckpt-358bb6af.pth", device: str = "cuda"): + def __init__(self, model_path: str = "musiq_spaq_ckpt-358bb6af.pth", device: str = "cuda"): self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.model = self._load_musiq(model_path) self.model.to(self.device) @@ -793,15 +954,19 @@ def _load_musiq(self, model_path: str): Returns: MUSIQ model """ + model_path = CACHE_DIR / model_path + # if the model_path not exists # then makedir and wget if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) wget_command = ['wget', 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/musiq_spaq_ckpt-358bb6af.pth', '-P', os.path.dirname(model_path)] subprocess.run(wget_command, check=True) - - from pyiqa.archs.musiq_arch import MUSIQ - model = MUSIQ(pretrained_model_path=model_path) + try: + from pyiqa.archs.musiq_arch import MUSIQ + except ImportError: + raise ImportError("Please install pyiqa to use ImagingQualityAnalyzer: pip install pyiqa") + model = MUSIQ(pretrained_model_path=str(model_path)) return model diff --git a/examples/__init__.py b/examples/__init__.py deleted file mode 100644 index bd1d8c0..0000000 --- a/examples/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
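
The analyzers in `video_quality_analyzer.py` now bootstrap their own dependencies: on first construction they fetch the AMT/RAFT sources through the GitHub contents API (`download_recursive`) and the checkpoints via `wget` into `~/.cache/markdiffusion`, then import the downloaded code through `isolated_import_context`. A minimal usage sketch, assuming network access on the first run and a CUDA-capable machine; the frame filenames here are hypothetical:

```python
# Usage sketch for the self-downloading video analyzers above.
# The first run downloads code and weights into ~/.cache/markdiffusion;
# later runs hit the cache. Frame paths are placeholders.
from PIL import Image

from evaluation.tools.video_quality_analyzer import (
    DynamicDegreeAnalyzer,
    MotionSmoothnessAnalyzer,
)

frames = [Image.open(f"frame_{i:03d}.png") for i in range(16)]  # hypothetical frames

smoothness = MotionSmoothnessAnalyzer(device="cuda", niters=1)                 # pulls AMT code + amt-s.pth on first use
dynamic = DynamicDegreeAnalyzer(model_name="raft-things.pth", device="cuda")   # pulls RAFT core + models.zip on first use

print("motion smoothness:", smoothness.analyze(frames))  # higher is smoother
print("dynamic degree:", dynamic.analyze(frames))         # 1.0 = dynamic, 0.0 = mostly static
```

Keeping the cache under `~/.cache/markdiffusion` rather than inside the repository matches the PR's stated goal of avoiding vendored third-party code, and is what allows the `model/amt/*` files to be deleted later in this diff.
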
\ No newline at end of file diff --git a/examples/assess_detectability.py b/examples/assess_detectability.py deleted file mode 100644 index a468373..0000000 --- a/examples/assess_detectability.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -from watermark.auto_watermark import AutoWatermark -from evaluation.dataset import StableDiffusionPromptsDataset -from evaluation.pipelines.detection import WatermarkedMediaDetectionPipeline, UnWatermarkedMediaDetectionPipeline, DetectionPipelineReturnType -from evaluation.tools.image_editor import JPEGCompression -from evaluation.tools.success_rate_calculator import DynamicThresholdSuccessRateCalculator, FundamentalSuccessRateCalculator -from utils.diffusion_config import DiffusionConfig -from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline -import dotenv -import os - -dotenv.load_dotenv() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -model_path = os.getenv("MODEL_PATH") - -def assess_numerical_detectability(algorithm_name, labels, rules, target_fpr): - my_dataset = StableDiffusionPromptsDataset(max_samples=200) - scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) - diffusion_config = DiffusionConfig( - scheduler = scheduler, - pipe = pipe, - device = device, - image_size = (512, 512), - num_inference_steps = 50, - guidance_scale = 3.5, - gen_seed = 42, - inversion_type = "ddim" - ) - my_watermark = AutoWatermark.load(f'{algorithm_name}', - algorithm_config=f'config/{algorithm_name}.json', - diffusion_config=diffusion_config) - - pipeline1 = WatermarkedMediaDetectionPipeline(dataset=my_dataset, media_editor_list=[], - show_progress=True, return_type=DetectionPipelineReturnType.SCORES) - - pipeline2 = UnWatermarkedMediaDetectionPipeline(dataset=my_dataset, media_editor_list=[], - show_progress=True, return_type=DetectionPipelineReturnType.SCORES) - - detection_kwargs = { - "num_inference_steps": 50, - "guidance_scale": 1.0, - } - - calculator = DynamicThresholdSuccessRateCalculator(labels=labels, rule=rules, target_fpr=target_fpr) - print(calculator.calculate(pipeline1.evaluate(my_watermark, detection_kwargs=detection_kwargs), pipeline2.evaluate(my_watermark, detection_kwargs=detection_kwargs))) - -def assess_binary_detectability(algorithm_name, labels, rules, target_fpr): - my_dataset = StableDiffusionPromptsDataset(max_samples=200) - scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) - diffusion_config = DiffusionConfig( - scheduler = scheduler, - pipe = pipe, - device = device, - image_size = (512, 512), - num_inference_steps = 50, - guidance_scale = 3.5, - gen_seed = 42, - inversion_type = "exact" - ) - my_watermark = AutoWatermark.load(f'{algorithm_name}', - algorithm_config=f'config/{algorithm_name}.json', - diffusion_config=diffusion_config) - - # Use IS_WATERMARKED return type for fixed threshold evaluation - pipeline1 = WatermarkedMediaDetectionPipeline( - dataset=my_dataset, - media_editor_list=[], - show_progress=True, - detector_type="is_watermarked", - return_type=DetectionPipelineReturnType.IS_WATERMARKED - ) - - pipeline2 = UnWatermarkedMediaDetectionPipeline( - dataset=my_dataset, - media_editor_list=[], - media_source_mode="generated", - show_progress=True, - detector_type="is_watermarked", - 
return_type=DetectionPipelineReturnType.IS_WATERMARKED - ) - - # Use FundamentalSuccessRateCalculator for fixed threshold evaluation - calculator = FundamentalSuccessRateCalculator(labels=['F1', 'TPR', 'TNR', 'FPR', 'P', 'R', 'ACC']) - - detection_kwargs = { - "num_inference_steps": 50, - "guidance_scale": 1.0, - "decoder_inv": False, - "inv_order": 0 - } - - # Get detection results - watermarked_results = pipeline1.evaluate(my_watermark, detection_kwargs=detection_kwargs) - non_watermarked_results = pipeline2.evaluate(my_watermark, detection_kwargs=detection_kwargs) - print(calculator.calculate(watermarked_results, non_watermarked_results)) - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--algorithm', type=str, default='PRC') - parser.add_argument('--labels', nargs='+', default=['TPR', 'F1']) - parser.add_argument('--rules', type=str, default='best') - parser.add_argument('--target_fpr', type=float, default=0.01) - args = parser.parse_args() - - # assess_numerical_detectability(args.algorithm, args.labels, args.rules, args.target_fpr) - assess_binary_detectability(args.algorithm, args.labels, args.rules, args.target_fpr) \ No newline at end of file diff --git a/examples/assess_image_quality.py b/examples/assess_image_quality.py deleted file mode 100644 index a0d9159..0000000 --- a/examples/assess_image_quality.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright 2024 THU-BPM MarkLLM. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# ========================================================================== -# assess_quality.py -# Description: Assess the impact on text quality of a watermarking algorithm -# ========================================================================== - -import torch -import torch -from watermark.auto_watermark import AutoWatermark -from evaluation.dataset import StableDiffusionPromptsDataset, MSCOCODataset, VBenchDataset -from evaluation.pipelines.image_quality_analysis import ( - DirectImageQualityAnalysisPipeline, - ReferencedImageQualityAnalysisPipeline, - GroupImageQualityAnalysisPipeline, - RepeatImageQualityAnalysisPipeline, - ComparedImageQualityAnalysisPipeline, - QualityPipelineReturnType -) -from evaluation.tools.image_quality_analyzer import ( -NIQECalculator, -CLIPScoreCalculator, FIDCalculator, InceptionScoreCalculator, LPIPSAnalyzer, PSNRAnalyzer,SSIMAnalyzer, BRISQUEAnalyzer,VIFAnalyzer,FSIMAnalyzer) -from utils.diffusion_config import DiffusionConfig -from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline -import dotenv -import os -import dotenv -dotenv.load_dotenv() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -model_path = os.getenv("MODEL_PATH") -""" - DirectImageQualityAnalysisPipeline: PSNRAnalyzer, SSIMAnalyzer,BRISQUEAnalyzer - ReferencedImageQualityAnalysisPipeline: CLIPScoreCalculator - GroupImageQualityAnalysisPipeline: FIDCalculator, InceptionScoreCalculator - RepeatImageQualityAnalysisPipeline: LPIPSAnalyzer - ComparedImageQualityAnalysisPipeline: PSNRAnalyzer, SSIMAnalyzer - -""" - -def assess_image_quality(algorithm_name, metric, max_samples=10): - print(model_path) - scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) - diffusion_config = DiffusionConfig( - scheduler = scheduler, - pipe = pipe, - device = device, - image_size = (512, 512), - num_inference_steps = 50, - guidance_scale = 3.5, - gen_seed = 42, - inversion_type = "ddim" - ) - if metric == 'NIQE': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = DirectImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[NIQECalculator()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'CLIP-T': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = ReferencedImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[CLIPScoreCalculator(reference_source="text")], - unwatermarked_image_source='generated', - reference_image_source='generated', - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'CLIP-I': - my_dataset = MSCOCODataset(max_samples=max_samples) - pipeline = ReferencedImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[CLIPScoreCalculator(reference_source="image")], - unwatermarked_image_source='generated', - reference_image_source='natural', - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'FID': - my_dataset = MSCOCODataset(max_samples=max_samples) - pipeline = GroupImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - 
analyzers=[FIDCalculator()], - unwatermarked_image_source='generated', - reference_image_source='natural', - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - elif metric == 'IS': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = GroupImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[InceptionScoreCalculator()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'LPIPS': - my_dataset = StableDiffusionPromptsDataset(max_samples=10) - pipeline = RepeatImageQualityAnalysisPipeline(dataset=my_dataset, - prompt_per_image=20, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[LPIPSAnalyzer()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'PSNR': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = ComparedImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[PSNRAnalyzer()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'SSIM': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = ComparedImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[SSIMAnalyzer()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'BRISQUE': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = DirectImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[BRISQUEAnalyzer()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - - elif metric == 'VIF': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = ComparedImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[VIFAnalyzer()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - elif metric == 'FSIM': - my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples) - pipeline = ComparedImageQualityAnalysisPipeline(dataset=my_dataset, - watermarked_image_editor_list=[], - unwatermarked_image_editor_list=[], - analyzers=[FSIMAnalyzer()], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES) - else: - raise ValueError('Invalid metric') - - - my_watermark = AutoWatermark.load(f'{algorithm_name}', - algorithm_config=f'config/{algorithm_name}.json', - diffusion_config=diffusion_config) - print(pipeline.evaluate(my_watermark)) - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--algorithm', type=str, default='TR') - parser.add_argument('--metric', type=str, default='FID') - parser.add_argument('--max_samples', type=int, default=10) - args = parser.parse_args() - - assess_image_quality(args.algorithm, args.metric, args.max_samples) diff --git a/examples/assess_image_robustness.py b/examples/assess_image_robustness.py deleted file mode 100644 index 56d6d48..0000000 --- a/examples/assess_image_robustness.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from watermark.auto_watermark import AutoWatermark -from evaluation.dataset import StableDiffusionPromptsDataset -from evaluation.pipelines.detection import WatermarkMediaDetectionPipeline, UnWatermarkMediaDetectionPipeline, DetectionPipelineReturnType -from evaluation.tools.image_editor import JPEGCompression, Rotation, CrSc, GaussianBlurring, GaussianNoise, Brightness, Mask, Overlay, AdaptiveNoiseInjection -from evaluation.tools.success_rate_calculator import DynamicThresholdSuccessRateCalculator -from utils.diffusion_config import DiffusionConfig -from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline -import dotenv -import os - -dotenv.load_dotenv() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -model_path = os.getenv("MODEL_PATH") - -def assess_image_robustness(algorithm_name, attack_name): - my_dataset = StableDiffusionPromptsDataset(max_samples=200) - scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) - diffusion_config = DiffusionConfig( - scheduler = scheduler, - pipe = pipe, - device = device, - image_size = (512, 512), - num_inference_steps = 50, - guidance_scale = 3.5, - gen_seed = 42, - inversion_type = "ddim" - ) - - my_watermark = AutoWatermark.load(f'{algorithm_name}', - algorithm_config=f'config/{algorithm_name}.json', - diffusion_config=diffusion_config) - if attack_name == 'JPEG': - attack = JPEGCompression(quality=25) - elif attack_name == 'Rotation': - attack = Rotation(angle=75, expand=False) - elif attack_name == 'CrSc': - attack = CrSc(crop_ratio=0.75) - elif attack_name == 'Blur': - attack = GaussianBlurring(radius=8) - elif attack_name == 'Noise': - attack = GaussianNoise(sigma=0.1) - elif attack_name == 'Brightness': - attack = Brightness(factor=0.6) - elif attack_name == 'Mask': - attack = Mask(mask_ratio=0.1, num_masks=5) - elif attack_name == 'Overlay': - attack = Overlay(num_strokes=10, stroke_width=5, stroke_type='random') - elif attack_name == 'AdaptiveNoise': - attack = AdaptiveNoiseInjection(intensity=0.5, auto_select=True) - - pipline1 = WatermarkMediaDetectionPipeline(dataset=my_dataset, media_editor_list=[attack], - show_progress=True, return_type=DetectionPipelineReturnType.SCORES) - - pipline2 = UnWatermarkMediaDetectionPipeline(dataset=my_dataset, media_editor_list=[], - show_progress=True, return_type=DetectionPipelineReturnType.SCORES) - - calculator = DynamicThresholdSuccessRateCalculator(labels=['TPR', 'F1'], rule='best') - print(calculator.calculate(pipline1.evaluate(my_watermark), pipline2.evaluate(my_watermark))) - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--algorithm', type=str, default='TR') - parser.add_argument('--attack', type=str, default='JPEG') - args = parser.parse_args() - - assess_image_robustness(args.algorithm, args.attack) diff --git 
a/examples/assess_video_quality.py b/examples/assess_video_quality.py deleted file mode 100644 index 98da769..0000000 --- a/examples/assess_video_quality.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2024 THU-BPM MarkLLM. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ========================================================================== -# assess_video_quality.py -# Description: Assess the impact on video quality of a watermarking algorithm -# ========================================================================== - -import torch -import os -import dotenv -from watermark.auto_watermark import AutoWatermark -from evaluation.dataset import VBenchDataset -from evaluation.pipelines.video_quality_analysis import ( - DirectVideoQualityAnalysisPipeline, - QualityPipelineReturnType -) -from evaluation.tools.video_quality_analyzer import ( - SubjectConsistencyAnalyzer, - MotionSmoothnessAnalyzer, - DynamicDegreeAnalyzer, - BackgroundConsistencyAnalyzer, - ImagingQualityAnalyzer -) -from utils.diffusion_config import DiffusionConfig -from diffusers import DDIMScheduler, TextToVideoSDPipeline - -# Load environment variables -dotenv.load_dotenv() - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -model_path = os.getenv("T2V_MODEL_PATH") - -""" -Video Quality Analysis Pipeline and Metrics: - DirectVideoQualityAnalysisPipeline: - - SubjectConsistencyAnalyzer: Measures subject consistency across frames using DINO features - - MotionSmoothnessAnalyzer: Evaluates motion smoothness using AMT-S frame interpolation - - DynamicDegreeAnalyzer: Analyzes motion intensity using RAFT optical flow - - BackgroundConsistencyAnalyzer: Measures background consistency using CLIP features - - ImagingQualityAnalyzer: Evaluates overall imaging quality using MUSIQ -""" - -def assess_video_quality(algorithm_name: str = "VideoShield", metric: str = "subject_consistency", dimension: str = "subject_consistency"): - """ - Assess video quality using specified metric and VBench dataset. - - Args: - algorithm_name (str): Name of the watermarking algorithm - metric (str): Quality metric to evaluate ('subject_consistency', 'motion_smoothness', - 'dynamic_degree', 'background_consistency', 'imaging_quality') - dimension (str): VBench dimension to use for evaluation - """ - - # Load VBench dataset - my_dataset = VBenchDataset(max_samples=200, dimension=dimension) - - # Initialize analyzer based on metric - if metric == 'subject_consistency': - analyzer = SubjectConsistencyAnalyzer(device=device) - elif metric == 'motion_smoothness': - analyzer = MotionSmoothnessAnalyzer(device=device) - elif metric == 'dynamic_degree': - analyzer = DynamicDegreeAnalyzer(device=device) - elif metric == 'background_consistency': - analyzer = BackgroundConsistencyAnalyzer(device=device) - elif metric == 'imaging_quality': - analyzer = ImagingQualityAnalyzer(device=device) - else: - raise ValueError(f'Invalid metric: {metric}. 
Supported metrics: subject_consistency, motion_smoothness, dynamic_degree, background_consistency, imaging_quality') - - # Create video quality analysis pipeline - pipeline = DirectVideoQualityAnalysisPipeline( - dataset=my_dataset, - watermarked_video_editor_list=[], - unwatermarked_video_editor_list=[], - watermarked_frame_editor_list=[], - unwatermarked_frame_editor_list=[], - analyzers=[analyzer], - show_progress=True, - return_type=QualityPipelineReturnType.MEAN_SCORES - ) - - # Create diffusion config for video generation - if model_path is None: - raise ValueError("T2V_MODEL_PATH environment variable is not set") - - # Load video generation pipeline (placeholder - adapt based on your T2V model) - # For example, if using VideoCrafter or similar models: - try: - # This is a placeholder - replace with actual T2V pipeline initialization - # scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler") - # pipe = YourVideoGenerationPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) - scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler") - pipe = TextToVideoSDPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) - diffusion_config = DiffusionConfig( - scheduler=scheduler, - pipe=pipe, - device=device, - # Video-specific parameters - num_frames=16, # Number of frames - width=512, - height=512, - num_inference_steps=50, - guidance_scale=7.5, - gen_seed=42, - inversion_type="ddim" - ) - except Exception as e: - print(f"Warning: Could not load T2V model from {model_path}. Using default config. Error: {e}") - diffusion_config = DiffusionConfig( - device=device, - num_frames=16, - width=512, - height=512, - num_inference_steps=50, - guidance_scale=7.5, - gen_seed=42, - ) - - # Load watermark algorithm - my_watermark = AutoWatermark.load( - f'{algorithm_name}', - algorithm_config=f'config/{algorithm_name}.json', - diffusion_config=diffusion_config - ) - - # Run evaluation - print(f"Evaluating {algorithm_name} with {metric} metric on VBench {dimension} dimension...") - result = pipeline.evaluate(my_watermark) - print(f"Results: {result}") - - return result - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Assess video quality impact of watermarking algorithms') - parser.add_argument('--algorithm', type=str, default='VideoShield', - help='Watermarking algorithm name') - parser.add_argument('--metric', type=str, default='subject_consistency', - choices=['subject_consistency', 'motion_smoothness', 'dynamic_degree', - 'background_consistency', 'imaging_quality'], - help='Quality metric to evaluate') - parser.add_argument('--dimension', type=str, default='subject_consistency', - choices=['subject_consistency', 'background_consistency', 'imaging_quality', - 'motion_smoothness', 'dynamic_degree'], - help='VBench dimension to use for evaluation') - - args = parser.parse_args() - - assess_video_quality(args.algorithm, args.metric, args.dimension) diff --git a/exceptions/__init__.py b/exceptions/__init__.py index 3b89505..2de5aaf 100644 --- a/exceptions/__init__.py +++ b/exceptions/__init__.py @@ -12,13 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Custom exceptions for MarkDiffusion. - -This module defines custom exception classes used throughout the MarkDiffusion library. 
-""" +"""Custom exceptions for MarkDiffusion.""" from .exceptions import * - -__all__ = ['exceptions'] - diff --git a/inversions/base_inversion.py b/inversions/base_inversion.py index 53ff8f9..206fdfc 100644 --- a/inversions/base_inversion.py +++ b/inversions/base_inversion.py @@ -120,10 +120,10 @@ def forward_diffusion(self, ): pass - def _apply_guidance_scale(self, model_output, guidance_scale): - if guidance_scale > 1.0: - noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - return noise_pred - else: - return model_output \ No newline at end of file + # def _apply_guidance_scale(self, model_output, guidance_scale): + # if guidance_scale > 1.0: + # noise_pred_uncond, noise_pred_text = model_output.chunk(2) + # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # return noise_pred + # else: + # return model_output \ No newline at end of file diff --git a/model/amt/__init__.py b/model/amt/__init__.py deleted file mode 100644 index c9ffca6..0000000 --- a/model/amt/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - diff --git a/model/amt/amt-s.pth b/model/amt/amt-s.pth deleted file mode 100644 index dbfe53e..0000000 Binary files a/model/amt/amt-s.pth and /dev/null differ diff --git a/model/amt/networks/AMT-G.py b/model/amt/networks/AMT-G.py deleted file mode 100644 index 332ec76..0000000 --- a/model/amt/networks/AMT-G.py +++ /dev/null @@ -1,172 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from vbench.third_party.amt.networks.blocks.raft import ( - coords_grid, - BasicUpdateBlock, BidirCorrBlock -) -from vbench.third_party.amt.networks.blocks.feat_enc import ( - LargeEncoder -) -from vbench.third_party.amt.networks.blocks.ifrnet import ( - resize, - Encoder, - InitDecoder, - IntermediateDecoder -) -from vbench.third_party.amt.networks.blocks.multi_flow import ( - multi_flow_combine, - MultiFlowDecoder -) - - -class Model(nn.Module): - def __init__(self, - corr_radius=3, - corr_lvls=4, - num_flows=5, - channels=[84, 96, 112, 128], - skip_channels=84): - super(Model, self).__init__() - self.radius = corr_radius - self.corr_levels = corr_lvls - self.num_flows = num_flows - - self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.) 
- self.encoder = Encoder(channels, large=True) - self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) - self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) - self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) - self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) - - self.update4 = self._get_updateblock(112, None) - self.update3_low = self._get_updateblock(96, 2.0) - self.update2_low = self._get_updateblock(84, 4.0) - - self.update3_high = self._get_updateblock(96, None) - self.update2_high = self._get_updateblock(84, None) - - self.comb_block = nn.Sequential( - nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3), - nn.PReLU(6*self.num_flows), - nn.Conv2d(6*self.num_flows, 3, 7, 1, 3), - ) - - def _get_updateblock(self, cdim, scale_factor=None): - return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, - corr_dim=256, corr_dim2=192, fc_dim=188, - scale_factor=scale_factor, corr_levels=self.corr_levels, - radius=self.radius) - - def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): - # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 - # based on linear assumption - t1_scale = 1. / embt - t0_scale = 1. / (1. - embt) - if downsample != 1: - inv = 1 / downsample - flow0 = inv * resize(flow0, scale_factor=inv) - flow1 = inv * resize(flow1, scale_factor=inv) - - corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) - corr = torch.cat([corr0, corr1], dim=1) - flow = torch.cat([flow0, flow1], dim=1) - return corr, flow - - def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): - mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) - img0 = img0 - mean_ - img1 = img1 - mean_ - img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 - img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 - b, _, h, w = img0_.shape - coord = coords_grid(b, h // 8, w // 8, img0.device) - - fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] - corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) - - # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] - # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] - f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) - f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) - - ######################################### the 4th decoder ######################################### - up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) - corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, - up_flow0_4, up_flow1_4, - embt, downsample=1) - - # residue update with lookup corr - delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) - delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) - up_flow0_4 = up_flow0_4 + delta_flow0_4 - up_flow1_4 = up_flow1_4 + delta_flow1_4 - ft_3_ = ft_3_ + delta_ft_3_ - - ######################################### the 3rd decoder ######################################### - up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) - corr_3, flow_3 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_3, up_flow1_3, - embt, downsample=2) - - # residue update with lookup corr - delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3) - delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) - up_flow0_3 = up_flow0_3 + delta_flow0_3 - up_flow1_3 = up_flow1_3 + delta_flow1_3 - ft_2_ = ft_2_ + 
delta_ft_2_ - - # residue update with lookup corr (hr) - corr_3 = resize(corr_3, scale_factor=2.0) - up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) - delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) - ft_2_ += delta_ft_2_ - up_flow0_3 += delta_up_flow_3[:, 0:2] - up_flow1_3 += delta_up_flow_3[:, 2:4] - - ######################################### the 2nd decoder ######################################### - up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) - corr_2, flow_2 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_2, up_flow1_2, - embt, downsample=4) - - # residue update with lookup corr - delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2) - delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) - up_flow0_2 = up_flow0_2 + delta_flow0_2 - up_flow1_2 = up_flow1_2 + delta_flow1_2 - ft_1_ = ft_1_ + delta_ft_1_ - - # residue update with lookup corr (hr) - corr_2 = resize(corr_2, scale_factor=4.0) - up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) - delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) - ft_1_ += delta_ft_1_ - up_flow0_2 += delta_up_flow_2[:, 0:2] - up_flow1_2 += delta_up_flow_2[:, 2:4] - - ######################################### the 1st decoder ######################################### - up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) - - if scale_factor != 1.0: - up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - mask = resize(mask, scale_factor=(1.0/scale_factor)) - img_res = resize(img_res, scale_factor=(1.0/scale_factor)) - - # Merge multiple predictions - imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, - mask, img_res, mean_) - imgt_pred = torch.clamp(imgt_pred, 0, 1) - - if eval: - return { 'imgt_pred': imgt_pred, } - else: - up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) - up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) - return { - 'imgt_pred': imgt_pred, - 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], - 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], - 'ft_pred': [ft_1_, ft_2_, ft_3_], - } diff --git a/model/amt/networks/AMT-L.py b/model/amt/networks/AMT-L.py deleted file mode 100644 index 551fac5..0000000 --- a/model/amt/networks/AMT-L.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import torch.nn as nn -from vbench.third_party.amt.networks.blocks.raft import ( - coords_grid, - BasicUpdateBlock, BidirCorrBlock -) -from vbench.third_party.amt.networks.blocks.feat_enc import ( - BasicEncoder, -) -from vbench.third_party.amt.networks.blocks.ifrnet import ( - resize, - Encoder, - InitDecoder, - IntermediateDecoder -) -from vbench.third_party.amt.networks.blocks.multi_flow import ( - multi_flow_combine, - MultiFlowDecoder -) - -class Model(nn.Module): - def __init__(self, - corr_radius=3, - corr_lvls=4, - num_flows=5, - channels=[48, 64, 72, 128], - skip_channels=48 - ): - super(Model, self).__init__() - self.radius = corr_radius - self.corr_levels = corr_lvls - self.num_flows = num_flows - - self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.) 
- self.encoder = Encoder([48, 64, 72, 128], large=True) - - self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) - self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) - self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) - self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) - - self.update4 = self._get_updateblock(72, None) - self.update3 = self._get_updateblock(64, 2.0) - self.update2 = self._get_updateblock(48, 4.0) - - self.comb_block = nn.Sequential( - nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3), - nn.PReLU(6*self.num_flows), - nn.Conv2d(6*self.num_flows, 3, 7, 1, 3), - ) - - def _get_updateblock(self, cdim, scale_factor=None): - return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48, - corr_dim=256, corr_dim2=160, fc_dim=124, - scale_factor=scale_factor, corr_levels=self.corr_levels, - radius=self.radius) - - def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): - # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 - # based on linear assumption - t1_scale = 1. / embt - t0_scale = 1. / (1. - embt) - if downsample != 1: - inv = 1 / downsample - flow0 = inv * resize(flow0, scale_factor=inv) - flow1 = inv * resize(flow1, scale_factor=inv) - - corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) - corr = torch.cat([corr0, corr1], dim=1) - flow = torch.cat([flow0, flow1], dim=1) - return corr, flow - - def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): - mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) - img0 = img0 - mean_ - img1 = img1 - mean_ - img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 - img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 - b, _, h, w = img0_.shape - coord = coords_grid(b, h // 8, w // 8, img0.device) - - fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] - corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) - - # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] - # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] - f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) - f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) - - ######################################### the 4th decoder ######################################### - up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) - corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, - up_flow0_4, up_flow1_4, - embt, downsample=1) - - # residue update with lookup corr - delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) - delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) - up_flow0_4 = up_flow0_4 + delta_flow0_4 - up_flow1_4 = up_flow1_4 + delta_flow1_4 - ft_3_ = ft_3_ + delta_ft_3_ - - ######################################### the 3rd decoder ######################################### - up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) - corr_3, flow_3 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_3, up_flow1_3, - embt, downsample=2) - - # residue update with lookup corr - delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) - delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) - up_flow0_3 = up_flow0_3 + delta_flow0_3 - up_flow1_3 = up_flow1_3 + delta_flow1_3 - ft_2_ = ft_2_ + delta_ft_2_ - - ######################################### the 2nd decoder ######################################### - 
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) - corr_2, flow_2 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_2, up_flow1_2, - embt, downsample=4) - - # residue update with lookup corr - delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) - delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) - up_flow0_2 = up_flow0_2 + delta_flow0_2 - up_flow1_2 = up_flow1_2 + delta_flow1_2 - ft_1_ = ft_1_ + delta_ft_1_ - - ######################################### the 1st decoder ######################################### - up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) - - if scale_factor != 1.0: - up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - mask = resize(mask, scale_factor=(1.0/scale_factor)) - img_res = resize(img_res, scale_factor=(1.0/scale_factor)) - - # Merge multiple predictions - imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, - mask, img_res, mean_) - imgt_pred = torch.clamp(imgt_pred, 0, 1) - - if eval: - return { 'imgt_pred': imgt_pred, } - else: - up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) - up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) - return { - 'imgt_pred': imgt_pred, - 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], - 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], - 'ft_pred': [ft_1_, ft_2_, ft_3_], - } - diff --git a/model/amt/networks/AMT-S.py b/model/amt/networks/AMT-S.py deleted file mode 100644 index 3d8bbcc..0000000 --- a/model/amt/networks/AMT-S.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import torch.nn as nn -from model.amt.networks.blocks.raft import ( - SmallUpdateBlock, - coords_grid, - BidirCorrBlock -) -from model.amt.networks.blocks.feat_enc import ( - SmallEncoder -) -from model.amt.networks.blocks.ifrnet import ( - resize, - Encoder, - InitDecoder, - IntermediateDecoder -) -from model.amt.networks.blocks.multi_flow import ( - multi_flow_combine, - MultiFlowDecoder -) - -class Model(nn.Module): - def __init__(self, - corr_radius=3, - corr_lvls=4, - num_flows=3, - channels=[20, 32, 44, 56], - skip_channels=20): - super(Model, self).__init__() - self.radius = corr_radius - self.corr_levels = corr_lvls - self.num_flows = num_flows - self.channels = channels - self.skip_channels = skip_channels - - self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) 
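        # NOTE: two encoders by design: feat_encoder yields the 1/8-resolution
        # matching features (84 channels here) consumed by BidirCorrBlock, while
        # the pyramid Encoder below produces the four context features
        # (1/2 ... 1/16 resolution) that feed the decoders.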
- self.encoder = Encoder(channels) - - self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) - self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) - self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) - self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) - - self.update4 = self._get_updateblock(44) - self.update3 = self._get_updateblock(32, 2) - self.update2 = self._get_updateblock(20, 4) - - self.comb_block = nn.Sequential( - nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1), - nn.PReLU(6*num_flows), - nn.Conv2d(6*num_flows, 3, 3, 1, 1), - ) - - def _get_updateblock(self, cdim, scale_factor=None): - return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64, - fc_dim=68, scale_factor=scale_factor, - corr_levels=self.corr_levels, radius=self.radius) - - def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): - # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 - # based on linear assumption - t1_scale = 1. / embt - t0_scale = 1. / (1. - embt) - if downsample != 1: - inv = 1 / downsample - flow0 = inv * resize(flow0, scale_factor=inv) - flow1 = inv * resize(flow1, scale_factor=inv) - - corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) - corr = torch.cat([corr0, corr1], dim=1) - flow = torch.cat([flow0, flow1], dim=1) - return corr, flow - - def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): - mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) - img0 = img0 - mean_ - img1 = img1 - mean_ - img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 - img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 - b, _, h, w = img0_.shape - coord = coords_grid(b, h // 8, w // 8, img0.device) - - fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 84, H//8, W//8] - corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) - - # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] - # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] - f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) - f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) - - ######################################### the 4th decoder ######################################### - up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) - corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, - up_flow0_4, up_flow1_4, - embt, downsample=1) - - # residue update with lookup corr - delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) - delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) - up_flow0_4 = up_flow0_4 + delta_flow0_4 - up_flow1_4 = up_flow1_4 + delta_flow1_4 - ft_3_ = ft_3_ + delta_ft_3_ - - ######################################### the 3rd decoder ######################################### - up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) - corr_3, flow_3 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_3, up_flow1_3, - embt, downsample=2) - - # residue update with lookup corr - delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) - delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) - up_flow0_3 = up_flow0_3 + delta_flow0_3 - up_flow1_3 = up_flow1_3 + delta_flow1_3 - ft_2_ = ft_2_ + delta_ft_2_ - - ######################################### the 2nd decoder ######################################### - up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, 
up_flow0_3, up_flow1_3) - corr_2, flow_2 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_2, up_flow1_2, - embt, downsample=4) - - # residue update with lookup corr - delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) - delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) - up_flow0_2 = up_flow0_2 + delta_flow0_2 - up_flow1_2 = up_flow1_2 + delta_flow1_2 - ft_1_ = ft_1_ + delta_ft_1_ - - ######################################### the 1st decoder ######################################### - up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) - - if scale_factor != 1.0: - up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - mask = resize(mask, scale_factor=(1.0/scale_factor)) - img_res = resize(img_res, scale_factor=(1.0/scale_factor)) - - # Merge multiple predictions - imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, - mask, img_res, mean_) - imgt_pred = torch.clamp(imgt_pred, 0, 1) - - if eval: - return { 'imgt_pred': imgt_pred, } - else: - up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) - up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) - return { - 'imgt_pred': imgt_pred, - 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], - 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], - 'ft_pred': [ft_1_, ft_2_, ft_3_], - } diff --git a/model/amt/networks/IFRNet.py b/model/amt/networks/IFRNet.py deleted file mode 100644 index dbb7a69..0000000 --- a/model/amt/networks/IFRNet.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
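Both AMT model variants above share _corr_scale_lookup, whose two-line comment is terse. Under the stated linear-motion assumption, the intermediate-time flows are stretched by 1/embt and 1/(1 - embt) so they span the full frame interval before the bidirectional correlation lookup, and at the finer decoder stages both the grid and the flow magnitudes are first shrunk to match the 1/8-resolution correlation volume. A minimal standalone sketch of just that scaling, with illustrative numbers (the tensors and the local resize helper mirror the deleted code but are not part of it):

import torch
import torch.nn.functional as F

def resize(x, scale_factor):
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)

embt = 0.25                    # target frame sits a quarter of the way from img0 to img1
t1_scale = 1.0 / embt          # 4.0: stretches the t->1 flow toward a full-interval flow
t0_scale = 1.0 / (1.0 - embt)  # ~1.33: stretches the t->0 flow the other way

flow0 = torch.randn(1, 2, 32, 32)  # flow from t toward frame 0, at the decoder's resolution
downsample = 4                     # e.g. the 2nd-decoder stage
inv = 1.0 / downsample
# Shrinking the grid and the displacement values together keeps the flow
# expressed in pixels of the smaller 1/8-resolution correlation grid.
flow0_lookup = inv * resize(flow0, scale_factor=inv)
print(flow0_lookup.shape)          # torch.Size([1, 2, 8, 8])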
- - -import torch -import torch.nn as nn -from vbench.third_party.amt.utils.flow_utils import warp -from vbench.third_party.amt.networks.blocks.ifrnet import ( - convrelu, resize, - ResBlock, -) - - -class Encoder(nn.Module): - def __init__(self): - super(Encoder, self).__init__() - self.pyramid1 = nn.Sequential( - convrelu(3, 32, 3, 2, 1), - convrelu(32, 32, 3, 1, 1) - ) - self.pyramid2 = nn.Sequential( - convrelu(32, 48, 3, 2, 1), - convrelu(48, 48, 3, 1, 1) - ) - self.pyramid3 = nn.Sequential( - convrelu(48, 72, 3, 2, 1), - convrelu(72, 72, 3, 1, 1) - ) - self.pyramid4 = nn.Sequential( - convrelu(72, 96, 3, 2, 1), - convrelu(96, 96, 3, 1, 1) - ) - - def forward(self, img): - f1 = self.pyramid1(img) - f2 = self.pyramid2(f1) - f3 = self.pyramid3(f2) - f4 = self.pyramid4(f3) - return f1, f2, f3, f4 - - -class Decoder4(nn.Module): - def __init__(self): - super(Decoder4, self).__init__() - self.convblock = nn.Sequential( - convrelu(192+1, 192), - ResBlock(192, 32), - nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True) - ) - - def forward(self, f0, f1, embt): - b, c, h, w = f0.shape - embt = embt.repeat(1, 1, h, w) - f_in = torch.cat([f0, f1, embt], 1) - f_out = self.convblock(f_in) - return f_out - - -class Decoder3(nn.Module): - def __init__(self): - super(Decoder3, self).__init__() - self.convblock = nn.Sequential( - convrelu(220, 216), - ResBlock(216, 32), - nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True) - ) - - def forward(self, ft_, f0, f1, up_flow0, up_flow1): - f0_warp = warp(f0, up_flow0) - f1_warp = warp(f1, up_flow1) - f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) - f_out = self.convblock(f_in) - return f_out - - -class Decoder2(nn.Module): - def __init__(self): - super(Decoder2, self).__init__() - self.convblock = nn.Sequential( - convrelu(148, 144), - ResBlock(144, 32), - nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True) - ) - - def forward(self, ft_, f0, f1, up_flow0, up_flow1): - f0_warp = warp(f0, up_flow0) - f1_warp = warp(f1, up_flow1) - f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) - f_out = self.convblock(f_in) - return f_out - - -class Decoder1(nn.Module): - def __init__(self): - super(Decoder1, self).__init__() - self.convblock = nn.Sequential( - convrelu(100, 96), - ResBlock(96, 32), - nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True) - ) - - def forward(self, ft_, f0, f1, up_flow0, up_flow1): - f0_warp = warp(f0, up_flow0) - f1_warp = warp(f1, up_flow1) - f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) - f_out = self.convblock(f_in) - return f_out - - -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - self.encoder = Encoder() - self.decoder4 = Decoder4() - self.decoder3 = Decoder3() - self.decoder2 = Decoder2() - self.decoder1 = Decoder1() - - def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): - mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) - img0 = img0 - mean_ - img1 = img1 - mean_ - - img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 - img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 - - f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) - f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) - - out4 = self.decoder4(f0_4, f1_4, embt) - up_flow0_4 = out4[:, 0:2] - up_flow1_4 = out4[:, 2:4] - ft_3_ = out4[:, 4:] - - out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) - up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0) - up_flow1_3 = out3[:, 2:4] + 2.0 * 
resize(up_flow1_4, scale_factor=2.0) - ft_2_ = out3[:, 4:] - - out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) - up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0) - up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0) - ft_1_ = out2[:, 4:] - - out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) - up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0) - up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) - up_mask_1 = torch.sigmoid(out1[:, 4:5]) - up_res_1 = out1[:, 5:] - - if scale_factor != 1.0: - up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor)) - up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor)) - - img0_warp = warp(img0, up_flow0_1) - img1_warp = warp(img1, up_flow1_1) - imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ - imgt_pred = imgt_merge + up_res_1 - imgt_pred = torch.clamp(imgt_pred, 0, 1) - - if eval: - return { 'imgt_pred': imgt_pred, } - else: - return { - 'imgt_pred': imgt_pred, - 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], - 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], - 'ft_pred': [ft_1_, ft_2_, ft_3_], - 'img0_warp': img0_warp, - 'img1_warp': img1_warp - } diff --git a/model/amt/networks/__init__.py b/model/amt/networks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/model/amt/networks/blocks/__init__.py b/model/amt/networks/blocks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/model/amt/networks/blocks/feat_enc.py b/model/amt/networks/blocks/feat_enc.py deleted file mode 100644 index 3805bd3..0000000 --- a/model/amt/networks/blocks/feat_enc.py +++ /dev/null @@ -1,343 +0,0 @@ -import torch -import torch.nn as nn - - -class BottleneckBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1): - super(BottleneckBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) - self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) - self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(planes//4) - self.norm2 = nn.BatchNorm2d(planes//4) - self.norm3 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm4 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(planes//4) - self.norm2 = nn.InstanceNorm2d(planes//4) - self.norm3 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm4 = nn.InstanceNorm2d(planes) - - elif norm_fn == 'none': - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - self.norm3 = nn.Sequential() - if not stride == 1: - self.norm4 = nn.Sequential() - - if stride == 1: - self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) - - - def 
forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - y = self.relu(self.norm3(self.conv3(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x+y) - - -class ResidualBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1): - super(ResidualBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(planes) - self.norm2 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(planes) - self.norm2 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm3 = nn.InstanceNorm2d(planes) - - elif norm_fn == 'none': - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - if not stride == 1: - self.norm3 = nn.Sequential() - - if stride == 1: - self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) - - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x+y) - - -class SmallEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): - super(SmallEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) - - elif self.norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(32) - - elif self.norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(32) - - elif self.norm_fn == 'none': - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 32 - self.layer1 = self._make_layer(32, stride=1) - self.layer2 = self._make_layer(64, stride=2) - self.layer3 = self._make_layer(96, stride=2) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - - def forward(self, x): - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.conv2(x) - - if 
self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x - -class BasicEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): - super(BasicEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) - - elif self.norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(64) - - elif self.norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(64) - - elif self.norm_fn == 'none': - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 64 - self.layer1 = self._make_layer(64, stride=1) - self.layer2 = self._make_layer(72, stride=2) - self.layer3 = self._make_layer(128, stride=2) - - # output convolution - self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - - def forward(self, x): - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x - -class LargeEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): - super(LargeEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) - - elif self.norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(64) - - elif self.norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(64) - - elif self.norm_fn == 'none': - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 64 - self.layer1 = self._make_layer(64, stride=1) - self.layer2 = self._make_layer(112, stride=2) - self.layer3 = self._make_layer(160, stride=2) - self.layer3_2 = self._make_layer(160, stride=1) - - # output convolution - self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = 
ResidualBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - - def forward(self, x): - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer3_2(x) - - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x diff --git a/model/amt/networks/blocks/ifrnet.py b/model/amt/networks/blocks/ifrnet.py deleted file mode 100644 index f3c5995..0000000 --- a/model/amt/networks/blocks/ifrnet.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn -import torch.nn.functional as F -from model.amt.utils.flow_utils import warp - - -def resize(x, scale_factor): - return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) - -def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): - return nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), - nn.PReLU(out_channels) - ) - -class ResBlock(nn.Module): - def __init__(self, in_channels, side_channels, bias=True): - super(ResBlock, self).__init__() - self.side_channels = side_channels - self.conv1 = nn.Sequential( - nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), - nn.PReLU(in_channels) - ) - self.conv2 = nn.Sequential( - nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), - nn.PReLU(side_channels) - ) - self.conv3 = nn.Sequential( - nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), - nn.PReLU(in_channels) - ) - self.conv4 = nn.Sequential( - nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), - nn.PReLU(side_channels) - ) - self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) - self.prelu = nn.PReLU(in_channels) - - def forward(self, x): - out = self.conv1(x) - - res_feat = out[:, :-self.side_channels, ...] - side_feat = out[:, -self.side_channels:, :, :] - side_feat = self.conv2(side_feat) - out = self.conv3(torch.cat([res_feat, side_feat], 1)) - - res_feat = out[:, :-self.side_channels, ...] 
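        # Same split as above: the trailing side_channels slice gets its own
        # convolution (conv4 this time) before conv5 rejoins the two groups.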
- side_feat = out[:, -self.side_channels:, :, :] - side_feat = self.conv4(side_feat) - out = self.conv5(torch.cat([res_feat, side_feat], 1)) - - out = self.prelu(x + out) - return out - -class Encoder(nn.Module): - def __init__(self, channels, large=False): - super(Encoder, self).__init__() - self.channels = channels - prev_ch = 3 - for idx, ch in enumerate(channels, 1): - k = 7 if large and idx == 1 else 3 - p = 3 if k == 7 else 1 - self.register_module(f'pyramid{idx}', - nn.Sequential( - convrelu(prev_ch, ch, k, 2, p), - convrelu(ch, ch, 3, 1, 1) - )) - prev_ch = ch - - def forward(self, in_x): - fs = [] - for idx in range(len(self.channels)): - out_x = getattr(self, f'pyramid{idx+1}')(in_x) - fs.append(out_x) - in_x = out_x - return fs - -class InitDecoder(nn.Module): - def __init__(self, in_ch, out_ch, skip_ch) -> None: - super().__init__() - self.convblock = nn.Sequential( - convrelu(in_ch*2+1, in_ch*2), - ResBlock(in_ch*2, skip_ch), - nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True) - ) - def forward(self, f0, f1, embt): - h, w = f0.shape[2:] - embt = embt.repeat(1, 1, h, w) - out = self.convblock(torch.cat([f0, f1, embt], 1)) - flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) - ft_ = out[:, 4:, ...] - return flow0, flow1, ft_ - -class IntermediateDecoder(nn.Module): - def __init__(self, in_ch, out_ch, skip_ch) -> None: - super().__init__() - self.convblock = nn.Sequential( - convrelu(in_ch*3+4, in_ch*3), - ResBlock(in_ch*3, skip_ch), - nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True) - ) - def forward(self, ft_, f0, f1, flow0_in, flow1_in): - f0_warp = warp(f0, flow0_in) - f1_warp = warp(f1, flow1_in) - f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) - out = self.convblock(f_in) - flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) - ft_ = out[:, 4:, ...] - flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) - flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) - return flow0, flow1, ft_ diff --git a/model/amt/networks/blocks/multi_flow.py b/model/amt/networks/blocks/multi_flow.py deleted file mode 100644 index 78aaa45..0000000 --- a/model/amt/networks/blocks/multi_flow.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.nn as nn -from model.amt.utils.flow_utils import warp -from model.amt.networks.blocks.ifrnet import ( - convrelu, resize, - ResBlock, -) - - -def multi_flow_combine(comb_block, img0, img1, flow0, flow1, - mask=None, img_res=None, mean=None): - ''' - A parallel implementation of multiple flow field warping - comb_block: An nn.Sequential object. - img shape: [b, c, h, w] - flow shape: [b, 2*num_flows, h, w] - mask (opt): - If 'mask' is None, the function conducts a simple average. - img_res (opt): - If 'img_res' is None, the function adds zero instead. - mean (opt): - If 'mean' is None, the function adds zero instead. 
- ''' - b, c, h, w = flow0.shape - num_flows = c // 2 - flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - - mask = mask.reshape(b, num_flows, 1, h, w - ).reshape(-1, 1, h, w) if mask is not None else None - img_res = img_res.reshape(b, num_flows, 3, h, w - ).reshape(-1, 3, h, w) if img_res is not None else 0 - img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) - img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) - mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 - ) if mean is not None else 0 - - img0_warp = warp(img0, flow0) - img1_warp = warp(img1, flow1) - img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res - img_warps = img_warps.reshape(b, num_flows, 3, h, w) - imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) - return imgt_pred - - -class MultiFlowDecoder(nn.Module): - def __init__(self, in_ch, skip_ch, num_flows=3): - super(MultiFlowDecoder, self).__init__() - self.num_flows = num_flows - self.convblock = nn.Sequential( - convrelu(in_ch*3+4, in_ch*3), - ResBlock(in_ch*3, skip_ch), - nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) - ) - - def forward(self, ft_, f0, f1, flow0, flow1): - n = self.num_flows - f0_warp = warp(f0, flow0) - f1_warp = warp(f1, flow1) - out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) - delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) - mask = torch.sigmoid(mask) - - flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 - ).repeat(1, self.num_flows, 1, 1) - flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 - ).repeat(1, self.num_flows, 1, 1) - - return flow0, flow1, mask, img_res diff --git a/model/amt/networks/blocks/raft.py b/model/amt/networks/blocks/raft.py deleted file mode 100644 index 9fb85ad..0000000 --- a/model/amt/networks/blocks/raft.py +++ /dev/null @@ -1,207 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def resize(x, scale_factor): - return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) - - -def bilinear_sampler(img, coords, mask=False): - """ Wrapper for grid_sample, uses pixel coordinates """ - H, W = img.shape[-2:] - xgrid, ygrid = coords.split([1,1], dim=-1) - xgrid = 2*xgrid/(W-1) - 1 - ygrid = 2*ygrid/(H-1) - 1 - - grid = torch.cat([xgrid, ygrid], dim=-1) - img = F.grid_sample(img, grid, align_corners=True) - - if mask: - mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) - return img, mask.float() - - return img - - -def coords_grid(batch, ht, wd, device): - coords = torch.meshgrid(torch.arange(ht, device=device), - torch.arange(wd, device=device), - indexing='ij') - coords = torch.stack(coords[::-1], dim=0).float() - return coords[None].repeat(batch, 1, 1, 1) - - -class SmallUpdateBlock(nn.Module): - def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, - corr_levels=4, radius=3, scale_factor=None): - super(SmallUpdateBlock, self).__init__() - cor_planes = corr_levels * (2 * radius + 1) **2 - self.scale_factor = scale_factor - - self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) - self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) - self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) - self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1) - - self.gru = nn.Sequential( - nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), 
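            # NOTE: despite the attribute name, this "gru" is a plain
            # conv + LeakyReLU stack; the name mirrors the update-block layout
            # of RAFT, which uses an actual ConvGRU at this point.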
- nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - ) - - self.feat_head = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, cdim, 3, padding=1), - ) - - self.flow_head = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, 4, 3, padding=1), - ) - - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - def forward(self, net, flow, corr): - net = resize(net, 1 / self.scale_factor - ) if self.scale_factor is not None else net - cor = self.lrelu(self.convc1(corr)) - flo = self.lrelu(self.convf1(flow)) - flo = self.lrelu(self.convf2(flo)) - cor_flo = torch.cat([cor, flo], dim=1) - inp = self.lrelu(self.conv(cor_flo)) - inp = torch.cat([inp, flow, net], dim=1) - - out = self.gru(inp) - delta_net = self.feat_head(out) - delta_flow = self.flow_head(out) - - if self.scale_factor is not None: - delta_net = resize(delta_net, scale_factor=self.scale_factor) - delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) - - return delta_net, delta_flow - - -class BasicUpdateBlock(nn.Module): - def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, - fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1): - super(BasicUpdateBlock, self).__init__() - cor_planes = corr_levels * (2 * radius + 1) **2 - - self.scale_factor = scale_factor - self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) - self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) - self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) - self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) - self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1) - - self.gru = nn.Sequential( - nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - ) - - self.feat_head = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, cdim, 3, padding=1), - ) - - self.flow_head = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1), - ) - - self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) - - def forward(self, net, flow, corr): - net = resize(net, 1 / self.scale_factor - ) if self.scale_factor is not None else net - cor = self.lrelu(self.convc1(corr)) - cor = self.lrelu(self.convc2(cor)) - flo = self.lrelu(self.convf1(flow)) - flo = self.lrelu(self.convf2(flo)) - cor_flo = torch.cat([cor, flo], dim=1) - inp = self.lrelu(self.conv(cor_flo)) - inp = torch.cat([inp, flow, net], dim=1) - - out = self.gru(inp) - delta_net = self.feat_head(out) - delta_flow = self.flow_head(out) - - if self.scale_factor is not None: - delta_net = resize(delta_net, scale_factor=self.scale_factor) - delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) - return delta_net, delta_flow - - -class BidirCorrBlock: - def __init__(self, fmap1, fmap2, num_levels=4, radius=4): - self.num_levels = num_levels - self.radius = radius - self.corr_pyramid = [] - self.corr_pyramid_T = [] - - corr = BidirCorrBlock.corr(fmap1, fmap2) - batch, h1, w1, dim, h2, w2 = corr.shape - corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) - - corr = corr.reshape(batch*h1*w1, dim, h2, w2) - corr_T = corr_T.reshape(batch*h2*w2, dim, h1, 
w1) - - self.corr_pyramid.append(corr) - self.corr_pyramid_T.append(corr_T) - - for _ in range(self.num_levels-1): - corr = F.avg_pool2d(corr, 2, stride=2) - corr_T = F.avg_pool2d(corr_T, 2, stride=2) - self.corr_pyramid.append(corr) - self.corr_pyramid_T.append(corr_T) - - def __call__(self, coords0, coords1): - r = self.radius - coords0 = coords0.permute(0, 2, 3, 1) - coords1 = coords1.permute(0, 2, 3, 1) - assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" - batch, h1, w1, _ = coords0.shape - - out_pyramid = [] - out_pyramid_T = [] - for i in range(self.num_levels): - corr = self.corr_pyramid[i] - corr_T = self.corr_pyramid_T[i] - - dx = torch.linspace(-r, r, 2*r+1, device=coords0.device) - dy = torch.linspace(-r, r, 2*r+1, device=coords0.device) - delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) - delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) - - centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i - centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i - coords_lvl_0 = centroid_lvl_0 + delta_lvl - coords_lvl_1 = centroid_lvl_1 + delta_lvl - - corr = bilinear_sampler(corr, coords_lvl_0) - corr_T = bilinear_sampler(corr_T, coords_lvl_1) - corr = corr.view(batch, h1, w1, -1) - corr_T = corr_T.view(batch, h1, w1, -1) - out_pyramid.append(corr) - out_pyramid_T.append(corr_T) - - out = torch.cat(out_pyramid, dim=-1) - out_T = torch.cat(out_pyramid_T, dim=-1) - return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() - - @staticmethod - def corr(fmap1, fmap2): - batch, dim, ht, wd = fmap1.shape - fmap1 = fmap1.view(batch, dim, ht*wd) - fmap2 = fmap2.view(batch, dim, ht*wd) - - corr = torch.matmul(fmap1.transpose(1,2), fmap2) - corr = corr.view(batch, ht, wd, 1, ht, wd) - return corr / torch.sqrt(torch.tensor(dim).float()) \ No newline at end of file diff --git a/model/amt/utils/__init__.py b/model/amt/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/model/amt/utils/build_utils.py b/model/amt/utils/build_utils.py deleted file mode 100644 index 6e0c5f5..0000000 --- a/model/amt/utils/build_utils.py +++ /dev/null @@ -1,16 +0,0 @@ -import importlib -import os -import sys -CUR_DIR = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.join(CUR_DIR, "../")) - - -def base_build_fn(module, cls, params): - return getattr(importlib.import_module( - module, package=None), cls)(**params) - - -def build_from_cfg(config): - module, cls = config['name'].rsplit(".", 1) - params = config.get('params', {}) - return base_build_fn(module, cls, params) diff --git a/model/amt/utils/dist_utils.py b/model/amt/utils/dist_utils.py deleted file mode 100644 index 6337f99..0000000 --- a/model/amt/utils/dist_utils.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import torch - - -def get_world_size(): - """Find OMPI world size without calling mpi functions - :rtype: int - """ - if os.environ.get('PMI_SIZE') is not None: - return int(os.environ.get('PMI_SIZE') or 1) - elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: - return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) - else: - return torch.cuda.device_count() - - -def get_global_rank(): - """Find OMPI world rank without calling mpi functions - :rtype: int - """ - if os.environ.get('PMI_RANK') is not None: - return int(os.environ.get('PMI_RANK') or 0) - elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: - return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) - else: - return 0 - 
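The helpers in this module all follow one lookup chain: prefer Intel MPI's PMI_* variables, then Open MPI's OMPI_* ones, then a local fallback, so single-process runs work without any launcher. A condensed sketch of that chain (the _env_int helper is illustrative, not part of the module):

import os

def _env_int(names, default):
    # Return the first of `names` that is set in the environment, parsed as int.
    for name in names:
        value = os.environ.get(name)
        if value is not None:
            return int(value or default)
    return default

# Mirrors get_global_rank(): rank 0 unless an MPI launcher exported one.
rank = _env_int(["PMI_RANK", "OMPI_COMM_WORLD_RANK"], 0)
print(rank)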
- -def get_local_rank(): - """Find OMPI local rank without calling mpi functions - :rtype: int - """ - if os.environ.get('MPI_LOCALRANKID') is not None: - return int(os.environ.get('MPI_LOCALRANKID') or 0) - elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: - return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) - else: - return 0 - - -def get_master_ip(): - if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: - return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] - elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: - return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') - else: - return "127.0.0.1" - diff --git a/model/amt/utils/flow_utils.py b/model/amt/utils/flow_utils.py deleted file mode 100644 index 4adb7e9..0000000 --- a/model/amt/utils/flow_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np -import torch -from PIL import ImageFile -import torch.nn.functional as F -ImageFile.LOAD_TRUNCATED_IMAGES = True - - -def warp(img, flow): - B, _, H, W = flow.shape - xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) - yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) - grid = torch.cat([xx, yy], 1).to(img) - flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) - grid_ = (grid + flow_).permute(0, 2, 3, 1) - output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) - return output - - -def make_colorwheel(): - """ - Generates a color wheel for optical flow visualization as presented in: - Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) - URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf - Code follows the original C++ source code of Daniel Scharstein. - Code follows the Matlab source code of Deqing Sun. - Returns: - np.ndarray: Color wheel - """ - - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - colorwheel = np.zeros((ncols, 3)) - col = 0 - - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) - col = col+RY - # YG - colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) - colorwheel[col:col+YG, 1] = 255 - col = col+YG - # GC - colorwheel[col:col+GC, 1] = 255 - colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) - col = col+GC - # CB - colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) - colorwheel[col:col+CB, 2] = 255 - col = col+CB - # BM - colorwheel[col:col+BM, 2] = 255 - colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) - col = col+BM - # MR - colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) - colorwheel[col:col+MR, 0] = 255 - return colorwheel - -def flow_uv_to_colors(u, v, convert_to_bgr=False): - """ - Applies the flow color wheel to (possibly clipped) flow components u and v. 
- According to the C++ source code of Daniel Scharstein - According to the Matlab source code of Deqing Sun - Args: - u (np.ndarray): Input horizontal flow of shape [H,W] - v (np.ndarray): Input vertical flow of shape [H,W] - convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. - Returns: - np.ndarray: Flow visualization image of shape [H,W,3] - """ - flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) - colorwheel = make_colorwheel() # shape [55x3] - ncols = colorwheel.shape[0] - rad = np.sqrt(np.square(u) + np.square(v)) - a = np.arctan2(-v, -u)/np.pi - fk = (a+1) / 2*(ncols-1) - k0 = np.floor(fk).astype(np.int32) - k1 = k0 + 1 - k1[k1 == ncols] = 0 - f = fk - k0 - for i in range(colorwheel.shape[1]): - tmp = colorwheel[:,i] - col0 = tmp[k0] / 255.0 - col1 = tmp[k1] / 255.0 - col = (1-f)*col0 + f*col1 - idx = (rad <= 1) - col[idx] = 1 - rad[idx] * (1-col[idx]) - col[~idx] = col[~idx] * 0.75 # out of range - # Note the 2-i => BGR instead of RGB - ch_idx = 2-i if convert_to_bgr else i - flow_image[:,:,ch_idx] = np.floor(255 * col) - return flow_image - -def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): - """ - Expects a two dimensional flow image of shape. - Args: - flow_uv (np.ndarray): Flow UV image of shape [H,W,2] - clip_flow (float, optional): Clip maximum of flow values. Defaults to None. - convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. - Returns: - np.ndarray: Flow visualization image of shape [H,W,3] - """ - assert flow_uv.ndim == 3, 'input flow must have three dimensions' - assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' - if clip_flow is not None: - flow_uv = np.clip(flow_uv, 0, clip_flow) - u = flow_uv[:,:,0] - v = flow_uv[:,:,1] - rad = np.sqrt(np.square(u) + np.square(v)) - rad_max = np.max(rad) - epsilon = 1e-5 - u = u / (rad_max + epsilon) - v = v / (rad_max + epsilon) - return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/model/amt/utils/utils.py b/model/amt/utils/utils.py deleted file mode 100644 index 0473226..0000000 --- a/model/amt/utils/utils.py +++ /dev/null @@ -1,297 +0,0 @@ -import re -import sys -import torch -import random -import numpy as np -from PIL import ImageFile -import torch.nn.functional as F -from imageio import imread, imwrite -ImageFile.LOAD_TRUNCATED_IMAGES = True - - -class AverageMeter(): - def __init__(self): - self.reset() - - def reset(self): - self.val = 0. - self.avg = 0. - self.sum = 0. 
- self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -class AverageMeterGroups: - def __init__(self) -> None: - self.meter_dict = dict() - - def update(self, dict, n=1): - for name, val in dict.items(): - if self.meter_dict.get(name) is None: - self.meter_dict[name] = AverageMeter() - self.meter_dict[name].update(val, n) - - def reset(self, name=None): - if name is None: - for v in self.meter_dict.values(): - v.reset() - else: - meter = self.meter_dict.get(name) - if meter is not None: - meter.reset() - - def avg(self, name): - meter = self.meter_dict.get(name) - if meter is not None: - return meter.avg - - -class InputPadder: - """ Pads images such that dimensions are divisible by divisor """ - def __init__(self, dims, divisor=16): - self.ht, self.wd = dims[-2:] - pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor - pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor - self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] - - def pad(self, *inputs): - if len(inputs) == 1: - return F.pad(inputs[0], self._pad, mode='replicate') - else: - return [F.pad(x, self._pad, mode='replicate') for x in inputs] - - def unpad(self, *inputs): - if len(inputs) == 1: - return self._unpad(inputs[0]) - else: - return [self._unpad(x) for x in inputs] - - def _unpad(self, x): - ht, wd = x.shape[-2:] - c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] - return x[..., c[0]:c[1], c[2]:c[3]] - - -def img2tensor(img): - if img.shape[-1] > 3: - img = img[:,:,:3] - return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0 - - -def tensor2img(img_t): - return (img_t * 255.).detach( - ).squeeze(0).permute(1, 2, 0).cpu().numpy( - ).clip(0, 255).astype(np.uint8) - -def seed_all(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def read(file): - if file.endswith('.float3'): return readFloat(file) - elif file.endswith('.flo'): return readFlow(file) - elif file.endswith('.ppm'): return readImage(file) - elif file.endswith('.pgm'): return readImage(file) - elif file.endswith('.png'): return readImage(file) - elif file.endswith('.jpg'): return readImage(file) - elif file.endswith('.pfm'): return readPFM(file)[0] - else: raise Exception('don\'t know how to read %s' % file) - - -def write(file, data): - if file.endswith('.float3'): return writeFloat(file, data) - elif file.endswith('.flo'): return writeFlow(file, data) - elif file.endswith('.ppm'): return writeImage(file, data) - elif file.endswith('.pgm'): return writeImage(file, data) - elif file.endswith('.png'): return writeImage(file, data) - elif file.endswith('.jpg'): return writeImage(file, data) - elif file.endswith('.pfm'): return writePFM(file, data) - else: raise Exception('don\'t know how to write %s' % file) - - -def readPFM(file): - file = open(file, 'rb') - - color = None - width = None - height = None - scale = None - endian = None - - header = file.readline().rstrip() - if header.decode("ascii") == 'PF': - color = True - elif header.decode("ascii") == 'Pf': - color = False - else: - raise Exception('Not a PFM file.') - - dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) - if dim_match: - width, height = list(map(int, dim_match.groups())) - else: - raise Exception('Malformed PFM header.') - - scale = float(file.readline().decode("ascii").rstrip()) - if scale < 0: - endian = '<' - scale = -scale - else: - 
endian = '>' - - data = np.fromfile(file, endian + 'f') - shape = (height, width, 3) if color else (height, width) - - data = np.reshape(data, shape) - data = np.flipud(data) - return data, scale - - -def writePFM(file, image, scale=1): - file = open(file, 'wb') - - color = None - - if image.dtype.name != 'float32': - raise Exception('Image dtype must be float32.') - - image = np.flipud(image) - - if len(image.shape) == 3 and image.shape[2] == 3: - color = True - elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: - color = False - else: - raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') - - file.write(('PF\n' if color else 'Pf\n').encode()) - file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) - - endian = image.dtype.byteorder - - if endian == '<' or endian == '=' and sys.byteorder == 'little': - scale = -scale - - file.write('%f\n'.encode() % scale) - - image.tofile(file) - - -def readFlow(name): - if name.endswith('.pfm') or name.endswith('.PFM'): - return readPFM(name)[0][:,:,0:2] - - f = open(name, 'rb') - - header = f.read(4) - if header.decode("utf-8") != 'PIEH': - raise Exception('Flow file header does not contain PIEH') - - width = np.fromfile(f, np.int32, 1).squeeze() - height = np.fromfile(f, np.int32, 1).squeeze() - - flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) - - return flow.astype(np.float32) - - -def readImage(name): - if name.endswith('.pfm') or name.endswith('.PFM'): - data = readPFM(name)[0] - if len(data.shape) == 3: - return data[:,:,0:3] - else: - return data - return imread(name) - - -def writeImage(name, data): - if name.endswith('.pfm') or name.endswith('.PFM'): - return writePFM(name, data, 1) - return imwrite(name, data) - - -def writeFlow(name, flow): - f = open(name, 'wb') - f.write('PIEH'.encode('utf-8')) - np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) - flow = flow.astype(np.float32) - flow.tofile(f) - - -def readFloat(name): - f = open(name, 'rb') - - if(f.readline().decode("utf-8")) != 'float\n': - raise Exception('float file %s did not contain keyword' % name) - - dim = int(f.readline()) - - dims = [] - count = 1 - for i in range(0, dim): - d = int(f.readline()) - dims.append(d) - count *= d - - dims = list(reversed(dims)) - - data = np.fromfile(f, np.float32, count).reshape(dims) - if dim > 2: - data = np.transpose(data, (2, 1, 0)) - data = np.transpose(data, (1, 0, 2)) - - return data - - -def writeFloat(name, data): - f = open(name, 'wb') - - dim = len(data.shape) - if dim > 3: - raise Exception('bad float file dimension: %d' % dim) - - f.write(('float\n').encode('ascii')) - f.write(('%d\n' % dim).encode('ascii')) - - if dim == 1: - f.write(('%d\n' % data.shape[0]).encode('ascii')) - else: - f.write(('%d\n' % data.shape[1]).encode('ascii')) - f.write(('%d\n' % data.shape[0]).encode('ascii')) - for i in range(2, dim): - f.write(('%d\n' % data.shape[i]).encode('ascii')) - - data = data.astype(np.float32) - if dim == 2: - data.tofile(f) - - else: - np.transpose(data, (2, 0, 1)).tofile(f) - - -def check_dim_and_resize(tensor_list): - shape_list = [] - for t in tensor_list: - shape_list.append(t.shape[2:]) - - if len(set(shape_list)) > 1: - desired_shape = shape_list[0] - print(f'Inconsistent size of input video frames. 
All frames will be resized to {desired_shape}') - - resize_tensor_list = [] - for t in tensor_list: - resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')) - - tensor_list = resize_tensor_list - - return tensor_list - diff --git a/model/raft/__init__.py b/model/raft/__init__.py deleted file mode 100644 index c9ffca6..0000000 --- a/model/raft/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - diff --git a/model/raft/core/__init__.py b/model/raft/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/model/raft/core/corr.py b/model/raft/core/corr.py deleted file mode 100644 index 3839ba8..0000000 --- a/model/raft/core/corr.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.nn.functional as F -from .utils_core.utils import bilinear_sampler, coords_grid - -try: - import alt_cuda_corr -except: - # alt_cuda_corr is not compiled - pass - - -class CorrBlock: - def __init__(self, fmap1, fmap2, num_levels=4, radius=4): - self.num_levels = num_levels - self.radius = radius - self.corr_pyramid = [] - - # all pairs correlation - corr = CorrBlock.corr(fmap1, fmap2) - - batch, h1, w1, dim, h2, w2 = corr.shape - corr = corr.reshape(batch*h1*w1, dim, h2, w2) - - self.corr_pyramid.append(corr) - for i in range(self.num_levels-1): - corr = F.avg_pool2d(corr, 2, stride=2) - self.corr_pyramid.append(corr) - - def __call__(self, coords): - r = self.radius - coords = coords.permute(0, 2, 3, 1) - batch, h1, w1, _ = coords.shape - - out_pyramid = [] - for i in range(self.num_levels): - corr = self.corr_pyramid[i] - dx = torch.linspace(-r, r, 2*r+1, device=coords.device) - dy = torch.linspace(-r, r, 2*r+1, device=coords.device) - delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) - - centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i - delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) - coords_lvl = centroid_lvl + delta_lvl - - corr = bilinear_sampler(corr, coords_lvl) - corr = corr.view(batch, h1, w1, -1) - out_pyramid.append(corr) - - out = torch.cat(out_pyramid, dim=-1) - return out.permute(0, 3, 1, 2).contiguous().float() - - @staticmethod - def corr(fmap1, fmap2): - batch, dim, ht, wd = fmap1.shape - fmap1 = fmap1.view(batch, dim, ht*wd) - fmap2 = fmap2.view(batch, dim, ht*wd) - - corr = torch.matmul(fmap1.transpose(1,2), fmap2) - corr = corr.view(batch, ht, wd, 1, ht, wd) - return corr / torch.sqrt(torch.tensor(dim).float()) - - -class AlternateCorrBlock: - def __init__(self, fmap1, fmap2, num_levels=4, radius=4): - self.num_levels = num_levels - self.radius = radius - - self.pyramid = [(fmap1, fmap2)] - for i in range(self.num_levels): - fmap1 = F.avg_pool2d(fmap1, 2, stride=2) - fmap2 = F.avg_pool2d(fmap2, 2, stride=2) - self.pyramid.append((fmap1, fmap2)) - - def __call__(self, coords): - coords = coords.permute(0, 2, 3, 1) - B, H, W, _ = coords.shape - dim = self.pyramid[0][0].shape[1] - - corr_list = [] - for i in range(self.num_levels): 
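            # NOTE: unlike CorrBlock above, no correlation pyramid is stored;
            # each level asks the alt_cuda_corr extension for the windowed
            # correlation on the fly, trading extra compute for a much smaller
            # memory footprint.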
- r = self.radius - fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() - fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() - - coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() - corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) - corr_list.append(corr.squeeze(1)) - - corr = torch.stack(corr_list, dim=1) - corr = corr.reshape(B, -1, H, W) - return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/model/raft/core/datasets.py b/model/raft/core/datasets.py deleted file mode 100644 index 8e07e5d..0000000 --- a/model/raft/core/datasets.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Data loading based on https://github.com/NVIDIA/flownet2-pytorch - -import numpy as np -import torch -import torch.utils.data as data -import torch.nn.functional as F - -import os -import math -import random -from glob import glob -import os.path as osp - -from utils_core import frame_utils -from utils_core.augmentor import FlowAugmentor, SparseFlowAugmentor - - -class FlowDataset(data.Dataset): - def __init__(self, aug_params=None, sparse=False): - self.augmentor = None - self.sparse = sparse - if aug_params is not None: - if sparse: - self.augmentor = SparseFlowAugmentor(**aug_params) - else: - self.augmentor = FlowAugmentor(**aug_params) - - self.is_test = False - self.init_seed = False - self.flow_list = [] - self.image_list = [] - self.extra_info = [] - - def __getitem__(self, index): - - if self.is_test: - img1 = frame_utils.read_gen(self.image_list[index][0]) - img2 = frame_utils.read_gen(self.image_list[index][1]) - img1 = np.array(img1).astype(np.uint8)[..., :3] - img2 = np.array(img2).astype(np.uint8)[..., :3] - img1 = torch.from_numpy(img1).permute(2, 0, 1).float() - img2 = torch.from_numpy(img2).permute(2, 0, 1).float() - return img1, img2, self.extra_info[index] - - if not self.init_seed: - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - torch.manual_seed(worker_info.id) - np.random.seed(worker_info.id) - random.seed(worker_info.id) - self.init_seed = True - - index = index % len(self.image_list) - valid = None - if self.sparse: - flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) - else: - flow = frame_utils.read_gen(self.flow_list[index]) - - img1 = frame_utils.read_gen(self.image_list[index][0]) - img2 = frame_utils.read_gen(self.image_list[index][1]) - - flow = np.array(flow).astype(np.float32) - img1 = np.array(img1).astype(np.uint8) - img2 = np.array(img2).astype(np.uint8) - - # grayscale images - if len(img1.shape) == 2: - img1 = np.tile(img1[...,None], (1, 1, 3)) - img2 = np.tile(img2[...,None], (1, 1, 3)) - else: - img1 = img1[..., :3] - img2 = img2[..., :3] - - if self.augmentor is not None: - if self.sparse: - img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) - else: - img1, img2, flow = self.augmentor(img1, img2, flow) - - img1 = torch.from_numpy(img1).permute(2, 0, 1).float() - img2 = 
torch.from_numpy(img2).permute(2, 0, 1).float() - flow = torch.from_numpy(flow).permute(2, 0, 1).float() - - if valid is not None: - valid = torch.from_numpy(valid) - else: - valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) - - return img1, img2, flow, valid.float() - - - def __rmul__(self, v): - self.flow_list = v * self.flow_list - self.image_list = v * self.image_list - return self - - def __len__(self): - return len(self.image_list) - - -class MpiSintel(FlowDataset): - def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): - super(MpiSintel, self).__init__(aug_params) - flow_root = osp.join(root, split, 'flow') - image_root = osp.join(root, split, dstype) - - if split == 'test': - self.is_test = True - - for scene in os.listdir(image_root): - image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) - for i in range(len(image_list)-1): - self.image_list += [ [image_list[i], image_list[i+1]] ] - self.extra_info += [ (scene, i) ] # scene and frame_id - - if split != 'test': - self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) - - -class FlyingChairs(FlowDataset): - def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): - super(FlyingChairs, self).__init__(aug_params) - - images = sorted(glob(osp.join(root, '*.ppm'))) - flows = sorted(glob(osp.join(root, '*.flo'))) - assert (len(images)//2 == len(flows)) - - split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) - for i in range(len(flows)): - xid = split_list[i] - if (split=='training' and xid==1) or (split=='validation' and xid==2): - self.flow_list += [ flows[i] ] - self.image_list += [ [images[2*i], images[2*i+1]] ] - - -class FlyingThings3D(FlowDataset): - def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): - super(FlyingThings3D, self).__init__(aug_params) - - for cam in ['left']: - for direction in ['into_future', 'into_past']: - image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) - image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) - - flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) - flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) - - for idir, fdir in zip(image_dirs, flow_dirs): - images = sorted(glob(osp.join(idir, '*.png')) ) - flows = sorted(glob(osp.join(fdir, '*.pfm')) ) - for i in range(len(flows)-1): - if direction == 'into_future': - self.image_list += [ [images[i], images[i+1]] ] - self.flow_list += [ flows[i] ] - elif direction == 'into_past': - self.image_list += [ [images[i+1], images[i]] ] - self.flow_list += [ flows[i+1] ] - - -class KITTI(FlowDataset): - def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): - super(KITTI, self).__init__(aug_params, sparse=True) - if split == 'testing': - self.is_test = True - - root = osp.join(root, split) - images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) - images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) - - for img1, img2 in zip(images1, images2): - frame_id = img1.split('/')[-1] - self.extra_info += [ [frame_id] ] - self.image_list += [ [img1, img2] ] - - if split == 'training': - self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) - - -class HD1K(FlowDataset): - def __init__(self, aug_params=None, root='datasets/HD1k'): - super(HD1K, self).__init__(aug_params, sparse=True) - - seq_ix = 0 - while 1: - flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) - images = 
sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
-
-            if len(flows) == 0:
-                break
-
-            for i in range(len(flows)-1):
-                self.flow_list += [flows[i]]
-                self.image_list += [ [images[i], images[i+1]] ]
-
-            seq_ix += 1
-
-
-def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
-    """ Create the data loader for the corresponding training set """
-
-    if args.stage == 'chairs':
-        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
-        train_dataset = FlyingChairs(aug_params, split='training')
-
-    elif args.stage == 'things':
-        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
-        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
-        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
-        train_dataset = clean_dataset + final_dataset
-
-    elif args.stage == 'sintel':
-        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
-        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
-        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
-        sintel_final = MpiSintel(aug_params, split='training', dstype='final')
-
-        if TRAIN_DS == 'C+T+K+S+H':
-            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
-            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
-            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
-
-        elif TRAIN_DS == 'C+T+K/S':
-            train_dataset = 100*sintel_clean + 100*sintel_final + things
-
-    elif args.stage == 'kitti':
-        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
-        train_dataset = KITTI(aug_params, split='training')
-
-    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
-        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
-
-    print('Training with %d image pairs' % len(train_dataset))
-    return train_loader
-
diff --git a/model/raft/core/extractor.py b/model/raft/core/extractor.py
deleted file mode 100644
index 9a9c759..0000000
--- a/model/raft/core/extractor.py
+++ /dev/null
@@ -1,267 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class ResidualBlock(nn.Module):
-    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
-        super(ResidualBlock, self).__init__()
-
-        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
-        self.relu = nn.ReLU(inplace=True)
-
-        num_groups = planes // 8
-
-        if norm_fn == 'group':
-            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-            if not stride == 1:
-                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-
-        elif norm_fn == 'batch':
-            self.norm1 = nn.BatchNorm2d(planes)
-            self.norm2 = nn.BatchNorm2d(planes)
-            if not stride == 1:
-                self.norm3 = nn.BatchNorm2d(planes)
-
-        elif norm_fn == 'instance':
-            self.norm1 = nn.InstanceNorm2d(planes)
-            self.norm2 = nn.InstanceNorm2d(planes)
-            if not stride == 1:
-                self.norm3 = nn.InstanceNorm2d(planes)
-
-        elif norm_fn == 'none':
-            self.norm1 = nn.Sequential()
-            self.norm2 = nn.Sequential()
-            if not stride == 1:
-                self.norm3 = nn.Sequential()
-
-        if stride == 1:
-            self.downsample = None
-
-        else:
-            self.downsample = nn.Sequential(
-                nn.Conv2d(in_planes, planes,
kernel_size=1, stride=stride), self.norm3) - - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x+y) - - - -class BottleneckBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1): - super(BottleneckBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) - self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) - self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(planes//4) - self.norm2 = nn.BatchNorm2d(planes//4) - self.norm3 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm4 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(planes//4) - self.norm2 = nn.InstanceNorm2d(planes//4) - self.norm3 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm4 = nn.InstanceNorm2d(planes) - - elif norm_fn == 'none': - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - self.norm3 = nn.Sequential() - if not stride == 1: - self.norm4 = nn.Sequential() - - if stride == 1: - self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) - - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - y = self.relu(self.norm3(self.conv3(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x+y) - -class BasicEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): - super(BasicEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) - - elif self.norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(64) - - elif self.norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(64) - - elif self.norm_fn == 'none': - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 64 - self.layer1 = self._make_layer(64, stride=1) - self.layer2 = self._make_layer(96, stride=2) - self.layer3 = self._make_layer(128, stride=2) - - # output convolution - self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - 
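The encoder being removed here is RAFT's standard feature stem: one stride-2 input convolution followed by two stride-2 residual stages, so features come out at 1/8 of the input resolution, which is the grid the correlation volume and flow field live on. A minimal shape check, as a sketch that assumes the deleted `BasicEncoder` above were still importable, would look like:

```python
import torch

# Sketch only: BasicEncoder is the (now removed) RAFT feature extractor above.
enc = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)
frame = torch.randn(1, 3, 368, 496)   # H and W divisible by 8
fmap = enc(frame)
# stride-2 stem + two stride-2 residual stages => 1/8-resolution features
assert fmap.shape == (1, 256, 368 // 8, 496 // 8)
```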
- def forward(self, x): - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x - - -class SmallEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): - super(SmallEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) - - elif self.norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(32) - - elif self.norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(32) - - elif self.norm_fn == 'none': - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 32 - self.layer1 = self._make_layer(32, stride=1) - self.layer2 = self._make_layer(64, stride=2) - self.layer3 = self._make_layer(96, stride=2) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - - def forward(self, x): - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x diff --git a/model/raft/core/raft.py b/model/raft/core/raft.py deleted file mode 100644 index 1d7404b..0000000 --- a/model/raft/core/raft.py +++ /dev/null @@ -1,144 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .update import BasicUpdateBlock, SmallUpdateBlock -from .extractor import BasicEncoder, SmallEncoder -from .corr import CorrBlock, AlternateCorrBlock -from .utils_core.utils import bilinear_sampler, coords_grid, upflow8 - -try: - autocast = torch.cuda.amp.autocast -except: - # dummy autocast for PyTorch < 1.6 - class autocast: - def __init__(self, enabled): - pass - def __enter__(self): - pass - def __exit__(self, *args): - pass - - -class RAFT(nn.Module): - def __init__(self, args): - super(RAFT, self).__init__() - self.args = args - - if args.small: - self.hidden_dim = hdim = 96 - self.context_dim = cdim = 64 - args.corr_levels = 4 - args.corr_radius = 3 - - else: - self.hidden_dim = hdim = 128 - self.context_dim = cdim = 128 - args.corr_levels = 4 - args.corr_radius = 4 - - if 'dropout' not in self.args: 
- self.args.dropout = 0 - - if 'alternate_corr' not in self.args: - self.args.alternate_corr = False - - # feature network, context network, and update block - if args.small: - self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) - self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) - self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) - - else: - self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) - self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) - self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) - - def freeze_bn(self): - for m in self.modules(): - if isinstance(m, nn.BatchNorm2d): - m.eval() - - def initialize_flow(self, img): - """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" - N, C, H, W = img.shape - coords0 = coords_grid(N, H//8, W//8, device=img.device) - coords1 = coords_grid(N, H//8, W//8, device=img.device) - - # optical flow computed as difference: flow = coords1 - coords0 - return coords0, coords1 - - def upsample_flow(self, flow, mask): - """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ - N, _, H, W = flow.shape - mask = mask.view(N, 1, 9, 8, 8, H, W) - mask = torch.softmax(mask, dim=2) - - up_flow = F.unfold(8 * flow, [3,3], padding=1) - up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) - - up_flow = torch.sum(mask * up_flow, dim=2) - up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) - return up_flow.reshape(N, 2, 8*H, 8*W) - - - def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): - """ Estimate optical flow between pair of frames """ - - image1 = 2 * (image1 / 255.0) - 1.0 - image2 = 2 * (image2 / 255.0) - 1.0 - - image1 = image1.contiguous() - image2 = image2.contiguous() - - hdim = self.hidden_dim - cdim = self.context_dim - - # run the feature network - with autocast(enabled=self.args.mixed_precision): - fmap1, fmap2 = self.fnet([image1, image2]) - - fmap1 = fmap1.float() - fmap2 = fmap2.float() - if self.args.alternate_corr: - corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) - else: - corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) - - # run the context network - with autocast(enabled=self.args.mixed_precision): - cnet = self.cnet(image1) - net, inp = torch.split(cnet, [hdim, cdim], dim=1) - net = torch.tanh(net) - inp = torch.relu(inp) - - coords0, coords1 = self.initialize_flow(image1) - - if flow_init is not None: - coords1 = coords1 + flow_init - - flow_predictions = [] - for itr in range(iters): - coords1 = coords1.detach() - corr = corr_fn(coords1) # index correlation volume - - flow = coords1 - coords0 - with autocast(enabled=self.args.mixed_precision): - net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) - - # F(t+1) = F(t) + \Delta(t) - coords1 = coords1 + delta_flow - - # upsample predictions - if up_mask is None: - flow_up = upflow8(coords1 - coords0) - else: - flow_up = self.upsample_flow(coords1 - coords0, up_mask) - - flow_predictions.append(flow_up) - - if test_mode: - return coords1 - coords0, flow_up - - return flow_predictions diff --git a/model/raft/core/update.py b/model/raft/core/update.py deleted file mode 100644 index f940497..0000000 --- a/model/raft/core/update.py +++ /dev/null @@ -1,139 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class FlowHead(nn.Module): - def __init__(self, 
input_dim=128, hidden_dim=256): - super(FlowHead, self).__init__() - self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) - self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - return self.conv2(self.relu(self.conv1(x))) - -class ConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192+128): - super(ConvGRU, self).__init__() - self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - - def forward(self, h, x): - hx = torch.cat([h, x], dim=1) - - z = torch.sigmoid(self.convz(hx)) - r = torch.sigmoid(self.convr(hx)) - q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) - - h = (1-z) * h + z * q - return h - -class SepConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192+128): - super(SepConvGRU, self).__init__() - self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - - self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - - - def forward(self, h, x): - # horizontal - hx = torch.cat([h, x], dim=1) - z = torch.sigmoid(self.convz1(hx)) - r = torch.sigmoid(self.convr1(hx)) - q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) - h = (1-z) * h + z * q - - # vertical - hx = torch.cat([h, x], dim=1) - z = torch.sigmoid(self.convz2(hx)) - r = torch.sigmoid(self.convr2(hx)) - q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) - h = (1-z) * h + z * q - - return h - -class SmallMotionEncoder(nn.Module): - def __init__(self, args): - super(SmallMotionEncoder, self).__init__() - cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 - self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) - self.convf1 = nn.Conv2d(2, 64, 7, padding=3) - self.convf2 = nn.Conv2d(64, 32, 3, padding=1) - self.conv = nn.Conv2d(128, 80, 3, padding=1) - - def forward(self, flow, corr): - cor = F.relu(self.convc1(corr)) - flo = F.relu(self.convf1(flow)) - flo = F.relu(self.convf2(flo)) - cor_flo = torch.cat([cor, flo], dim=1) - out = F.relu(self.conv(cor_flo)) - return torch.cat([out, flow], dim=1) - -class BasicMotionEncoder(nn.Module): - def __init__(self, args): - super(BasicMotionEncoder, self).__init__() - cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 - self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) - self.convc2 = nn.Conv2d(256, 192, 3, padding=1) - self.convf1 = nn.Conv2d(2, 128, 7, padding=3) - self.convf2 = nn.Conv2d(128, 64, 3, padding=1) - self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) - - def forward(self, flow, corr): - cor = F.relu(self.convc1(corr)) - cor = F.relu(self.convc2(cor)) - flo = F.relu(self.convf1(flow)) - flo = F.relu(self.convf2(flo)) - - cor_flo = torch.cat([cor, flo], dim=1) - out = F.relu(self.conv(cor_flo)) - return torch.cat([out, flow], dim=1) - -class SmallUpdateBlock(nn.Module): - def __init__(self, args, hidden_dim=96): - super(SmallUpdateBlock, self).__init__() - self.encoder = SmallMotionEncoder(args) - self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) - self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 
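The update blocks pair a motion encoder with a convolutional GRU: gates `z` (update) and `r` (reset) are sigmoid convolutions over the concatenated `[h, x]`, the candidate `q` is a tanh convolution over `[r*h, x]`, and the new hidden state is `h = (1-z)*h + z*q`, so the hidden state keeps its channel count and spatial size across refinement iterations. A sketch of that shape contract, assuming the `ConvGRU` defined above:

```python
import torch

# Sketch only: exercises the ConvGRU above with RAFT's small-model sizes.
gru = ConvGRU(hidden_dim=96, input_dim=82 + 64)
h = torch.randn(1, 96, 46, 62)        # hidden state on the 1/8 grid
x = torch.randn(1, 82 + 64, 46, 62)   # motion features concatenated with context
for _ in range(3):                    # iterative refinement: shape is invariant
    h = gru(h, x)
assert h.shape == (1, 96, 46, 62)
```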
-
-    def forward(self, net, inp, corr, flow):
-        motion_features = self.encoder(flow, corr)
-        inp = torch.cat([inp, motion_features], dim=1)
-        net = self.gru(net, inp)
-        delta_flow = self.flow_head(net)
-
-        return net, None, delta_flow
-
-class BasicUpdateBlock(nn.Module):
-    def __init__(self, args, hidden_dim=128, input_dim=128):
-        super(BasicUpdateBlock, self).__init__()
-        self.args = args
-        self.encoder = BasicMotionEncoder(args)
-        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
-        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
-
-        self.mask = nn.Sequential(
-            nn.Conv2d(128, 256, 3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(256, 64*9, 1, padding=0))
-
-    def forward(self, net, inp, corr, flow, upsample=True):
-        motion_features = self.encoder(flow, corr)
-        inp = torch.cat([inp, motion_features], dim=1)
-
-        net = self.gru(net, inp)
-        delta_flow = self.flow_head(net)
-
-        # scale mask to balance gradients
-        mask = .25 * self.mask(net)
-        return net, mask, delta_flow
-
-
diff --git a/model/raft/core/utils_core/__init__.py b/model/raft/core/utils_core/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/model/raft/core/utils_core/augmentor.py b/model/raft/core/utils_core/augmentor.py
deleted file mode 100644
index e81c4f2..0000000
--- a/model/raft/core/utils_core/augmentor.py
+++ /dev/null
@@ -1,246 +0,0 @@
-import numpy as np
-import random
-import math
-from PIL import Image
-
-import cv2
-cv2.setNumThreads(0)
-cv2.ocl.setUseOpenCL(False)
-
-import torch
-from torchvision.transforms import ColorJitter
-import torch.nn.functional as F
-
-
-class FlowAugmentor:
-    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
-
-        # spatial augmentation params
-        self.crop_size = crop_size
-        self.min_scale = min_scale
-        self.max_scale = max_scale
-        self.spatial_aug_prob = 0.8
-        self.stretch_prob = 0.8
-        self.max_stretch = 0.2
-
-        # flip augmentation params
-        self.do_flip = do_flip
-        self.h_flip_prob = 0.5
-        self.v_flip_prob = 0.1
-
-        # photometric augmentation params
-        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
-        self.asymmetric_color_aug_prob = 0.2
-        self.eraser_aug_prob = 0.5
-
-    def color_transform(self, img1, img2):
-        """ Photometric augmentation """
-
-        # asymmetric
-        if np.random.rand() < self.asymmetric_color_aug_prob:
-            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
-            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
-
-        # symmetric
-        else:
-            image_stack = np.concatenate([img1, img2], axis=0)
-            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
-            img1, img2 = np.split(image_stack, 2, axis=0)
-
-        return img1, img2
-
-    def eraser_transform(self, img1, img2, bounds=[50, 100]):
-        """ Occlusion augmentation """
-
-        ht, wd = img1.shape[:2]
-        if np.random.rand() < self.eraser_aug_prob:
-            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
-            for _ in range(np.random.randint(1, 3)):
-                x0 = np.random.randint(0, wd)
-                y0 = np.random.randint(0, ht)
-                dx = np.random.randint(bounds[0], bounds[1])
-                dy = np.random.randint(bounds[0], bounds[1])
-                img2[y0:y0+dy, x0:x0+dx, :] = mean_color
-
-        return img1, img2
-
-    def spatial_transform(self, img1, img2, flow):
-        # randomly sample scale
-        ht, wd = img1.shape[:2]
-        min_scale = np.maximum(
-            (self.crop_size[0] + 8) / float(ht),
-            (self.crop_size[1] + 8) / float(wd))
-
-        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
-        scale_x = scale
-        scale_y = scale
-        if
np.random.rand() < self.stretch_prob: - scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) - scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) - - scale_x = np.clip(scale_x, min_scale, None) - scale_y = np.clip(scale_y, min_scale, None) - - if np.random.rand() < self.spatial_aug_prob: - # rescale the images - img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - flow = flow * [scale_x, scale_y] - - if self.do_flip: - if np.random.rand() < self.h_flip_prob: # h-flip - img1 = img1[:, ::-1] - img2 = img2[:, ::-1] - flow = flow[:, ::-1] * [-1.0, 1.0] - - if np.random.rand() < self.v_flip_prob: # v-flip - img1 = img1[::-1, :] - img2 = img2[::-1, :] - flow = flow[::-1, :] * [1.0, -1.0] - - y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) - x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) - - img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - - return img1, img2, flow - - def __call__(self, img1, img2, flow): - img1, img2 = self.color_transform(img1, img2) - img1, img2 = self.eraser_transform(img1, img2) - img1, img2, flow = self.spatial_transform(img1, img2, flow) - - img1 = np.ascontiguousarray(img1) - img2 = np.ascontiguousarray(img2) - flow = np.ascontiguousarray(flow) - - return img1, img2, flow - -class SparseFlowAugmentor: - def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): - # spatial augmentation params - self.crop_size = crop_size - self.min_scale = min_scale - self.max_scale = max_scale - self.spatial_aug_prob = 0.8 - self.stretch_prob = 0.8 - self.max_stretch = 0.2 - - # flip augmentation params - self.do_flip = do_flip - self.h_flip_prob = 0.5 - self.v_flip_prob = 0.1 - - # photometric augmentation params - self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) - self.asymmetric_color_aug_prob = 0.2 - self.eraser_aug_prob = 0.5 - - def color_transform(self, img1, img2): - image_stack = np.concatenate([img1, img2], axis=0) - image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) - img1, img2 = np.split(image_stack, 2, axis=0) - return img1, img2 - - def eraser_transform(self, img1, img2): - ht, wd = img1.shape[:2] - if np.random.rand() < self.eraser_aug_prob: - mean_color = np.mean(img2.reshape(-1, 3), axis=0) - for _ in range(np.random.randint(1, 3)): - x0 = np.random.randint(0, wd) - y0 = np.random.randint(0, ht) - dx = np.random.randint(50, 100) - dy = np.random.randint(50, 100) - img2[y0:y0+dy, x0:x0+dx, :] = mean_color - - return img1, img2 - - def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): - ht, wd = flow.shape[:2] - coords = np.meshgrid(np.arange(wd), np.arange(ht)) - coords = np.stack(coords, axis=-1) - - coords = coords.reshape(-1, 2).astype(np.float32) - flow = flow.reshape(-1, 2).astype(np.float32) - valid = valid.reshape(-1).astype(np.float32) - - coords0 = coords[valid>=1] - flow0 = flow[valid>=1] - - ht1 = int(round(ht * fy)) - wd1 = int(round(wd * fx)) - - coords1 = coords0 * [fx, fy] - flow1 = flow0 * [fx, fy] - - xx = np.round(coords1[:,0]).astype(np.int32) - yy = np.round(coords1[:,1]).astype(np.int32) - - v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < 
ht1)
-        xx = xx[v]
-        yy = yy[v]
-        flow1 = flow1[v]
-
-        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
-        valid_img = np.zeros([ht1, wd1], dtype=np.int32)
-
-        flow_img[yy, xx] = flow1
-        valid_img[yy, xx] = 1
-
-        return flow_img, valid_img
-
-    def spatial_transform(self, img1, img2, flow, valid):
-        # randomly sample scale
-
-        ht, wd = img1.shape[:2]
-        min_scale = np.maximum(
-            (self.crop_size[0] + 1) / float(ht),
-            (self.crop_size[1] + 1) / float(wd))
-
-        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
-        scale_x = np.clip(scale, min_scale, None)
-        scale_y = np.clip(scale, min_scale, None)
-
-        if np.random.rand() < self.spatial_aug_prob:
-            # rescale the images
-            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
-            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
-            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
-
-        if self.do_flip:
-            if np.random.rand() < 0.5: # h-flip
-                img1 = img1[:, ::-1]
-                img2 = img2[:, ::-1]
-                flow = flow[:, ::-1] * [-1.0, 1.0]
-                valid = valid[:, ::-1]
-
-        margin_y = 20
-        margin_x = 50
-
-        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
-        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
-
-        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
-        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
-
-        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
-        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
-        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
-        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
-        return img1, img2, flow, valid
-
-
-    def __call__(self, img1, img2, flow, valid):
-        img1, img2 = self.color_transform(img1, img2)
-        img1, img2 = self.eraser_transform(img1, img2)
-        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
-
-        img1 = np.ascontiguousarray(img1)
-        img2 = np.ascontiguousarray(img2)
-        flow = np.ascontiguousarray(flow)
-        valid = np.ascontiguousarray(valid)
-
-        return img1, img2, flow, valid
diff --git a/model/raft/core/utils_core/flow_viz.py b/model/raft/core/utils_core/flow_viz.py
deleted file mode 100644
index dcee65e..0000000
--- a/model/raft/core/utils_core/flow_viz.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
-
-
-# MIT License
-#
-# Copyright (c) 2018 Tom Runia
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to conditions.
-#
-# Author: Tom Runia
-# Date Created: 2018-08-03
-
-import numpy as np
-
-def make_colorwheel():
-    """
-    Generates a color wheel for optical flow visualization as presented in:
-        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
-        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
-
-    Code follows the original C++ source code of Daniel Scharstein.
-    Code follows the Matlab source code of Deqing Sun.
-
-    Returns:
-        np.ndarray: Color wheel
-    """
-
-    RY = 15
-    YG = 6
-    GC = 4
-    CB = 11
-    BM = 13
-    MR = 6
-
-    ncols = RY + YG + GC + CB + BM + MR
-    colorwheel = np.zeros((ncols, 3))
-    col = 0
-
-    # RY
-    colorwheel[0:RY, 0] = 255
-    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
-    col = col+RY
-    # YG
-    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
-    colorwheel[col:col+YG, 1] = 255
-    col = col+YG
-    # GC
-    colorwheel[col:col+GC, 1] = 255
-    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
-    col = col+GC
-    # CB
-    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
-    colorwheel[col:col+CB, 2] = 255
-    col = col+CB
-    # BM
-    colorwheel[col:col+BM, 2] = 255
-    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
-    col = col+BM
-    # MR
-    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
-    colorwheel[col:col+MR, 0] = 255
-    return colorwheel
-
-
-def flow_uv_to_colors(u, v, convert_to_bgr=False):
-    """
-    Applies the flow color wheel to (possibly clipped) flow components u and v.
-
-    According to the C++ source code of Daniel Scharstein
-    According to the Matlab source code of Deqing Sun
-
-    Args:
-        u (np.ndarray): Input horizontal flow of shape [H,W]
-        v (np.ndarray): Input vertical flow of shape [H,W]
-        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
-
-    Returns:
-        np.ndarray: Flow visualization image of shape [H,W,3]
-    """
-    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
-    colorwheel = make_colorwheel()  # shape [55x3]
-    ncols = colorwheel.shape[0]
-    rad = np.sqrt(np.square(u) + np.square(v))
-    a = np.arctan2(-v, -u)/np.pi
-    fk = (a+1) / 2*(ncols-1)
-    k0 = np.floor(fk).astype(np.int32)
-    k1 = k0 + 1
-    k1[k1 == ncols] = 0
-    f = fk - k0
-    for i in range(colorwheel.shape[1]):
-        tmp = colorwheel[:,i]
-        col0 = tmp[k0] / 255.0
-        col1 = tmp[k1] / 255.0
-        col = (1-f)*col0 + f*col1
-        idx = (rad <= 1)
-        col[idx] = 1 - rad[idx] * (1-col[idx])
-        col[~idx] = col[~idx] * 0.75   # out of range
-        # Note the 2-i => BGR instead of RGB
-        ch_idx = 2-i if convert_to_bgr else i
-        flow_image[:,:,ch_idx] = np.floor(255 * col)
-    return flow_image
-
-
-def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
-    """
-    Expects a two-dimensional flow field of shape [H,W,2].
-
-    Args:
-        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
-        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
-        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
- - Returns: - np.ndarray: Flow visualization image of shape [H,W,3] - """ - assert flow_uv.ndim == 3, 'input flow must have three dimensions' - assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' - if clip_flow is not None: - flow_uv = np.clip(flow_uv, 0, clip_flow) - u = flow_uv[:,:,0] - v = flow_uv[:,:,1] - rad = np.sqrt(np.square(u) + np.square(v)) - rad_max = np.max(rad) - epsilon = 1e-5 - u = u / (rad_max + epsilon) - v = v / (rad_max + epsilon) - return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/model/raft/core/utils_core/frame_utils.py b/model/raft/core/utils_core/frame_utils.py deleted file mode 100644 index 6c49113..0000000 --- a/model/raft/core/utils_core/frame_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -import numpy as np -from PIL import Image -from os.path import * -import re - -import cv2 -cv2.setNumThreads(0) -cv2.ocl.setUseOpenCL(False) - -TAG_CHAR = np.array([202021.25], np.float32) - -def readFlow(fn): - """ Read .flo file in Middlebury format""" - # Code adapted from: - # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy - - # WARNING: this will work on little-endian architectures (eg Intel x86) only! - # print 'fn = %s'%(fn) - with open(fn, 'rb') as f: - magic = np.fromfile(f, np.float32, count=1) - if 202021.25 != magic: - print('Magic number incorrect. Invalid .flo file') - return None - else: - w = np.fromfile(f, np.int32, count=1) - h = np.fromfile(f, np.int32, count=1) - # print 'Reading %d x %d flo file\n' % (w, h) - data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) - # Reshape data into 3D array (columns, rows, bands) - # The reshape here is for visualization, the original code is (w,h,2) - return np.resize(data, (int(h), int(w), 2)) - -def readPFM(file): - file = open(file, 'rb') - - color = None - width = None - height = None - scale = None - endian = None - - header = file.readline().rstrip() - if header == b'PF': - color = True - elif header == b'Pf': - color = False - else: - raise Exception('Not a PFM file.') - - dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) - if dim_match: - width, height = map(int, dim_match.groups()) - else: - raise Exception('Malformed PFM header.') - - scale = float(file.readline().rstrip()) - if scale < 0: # little-endian - endian = '<' - scale = -scale - else: - endian = '>' # big-endian - - data = np.fromfile(file, endian + 'f') - shape = (height, width, 3) if color else (height, width) - - data = np.reshape(data, shape) - data = np.flipud(data) - return data - -def writeFlow(filename,uv,v=None): - """ Write optical flow to file. - - If v is None, uv is assumed to contain both u and v channels, - stacked in depth. - Original code by Deqing Sun, adapted from Daniel Scharstein. 
- """ - nBands = 2 - - if v is None: - assert(uv.ndim == 3) - assert(uv.shape[2] == 2) - u = uv[:,:,0] - v = uv[:,:,1] - else: - u = uv - - assert(u.shape == v.shape) - height,width = u.shape - f = open(filename,'wb') - # write the header - f.write(TAG_CHAR) - np.array(width).astype(np.int32).tofile(f) - np.array(height).astype(np.int32).tofile(f) - # arrange into matrix form - tmp = np.zeros((height, width*nBands)) - tmp[:,np.arange(width)*2] = u - tmp[:,np.arange(width)*2 + 1] = v - tmp.astype(np.float32).tofile(f) - f.close() - - -def readFlowKITTI(filename): - flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) - flow = flow[:,:,::-1].astype(np.float32) - flow, valid = flow[:, :, :2], flow[:, :, 2] - flow = (flow - 2**15) / 64.0 - return flow, valid - -def readDispKITTI(filename): - disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 - valid = disp > 0.0 - flow = np.stack([-disp, np.zeros_like(disp)], -1) - return flow, valid - - -def writeFlowKITTI(filename, uv): - uv = 64.0 * uv + 2**15 - valid = np.ones([uv.shape[0], uv.shape[1], 1]) - uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) - cv2.imwrite(filename, uv[..., ::-1]) - - -def read_gen(file_name, pil=False): - ext = splitext(file_name)[-1] - if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': - return Image.open(file_name) - elif ext == '.bin' or ext == '.raw': - return np.load(file_name) - elif ext == '.flo': - return readFlow(file_name).astype(np.float32) - elif ext == '.pfm': - flow = readPFM(file_name).astype(np.float32) - if len(flow.shape) == 2: - return flow - else: - return flow[:, :, :-1] - return [] \ No newline at end of file diff --git a/model/raft/core/utils_core/utils.py b/model/raft/core/utils_core/utils.py deleted file mode 100644 index 15cdc6f..0000000 --- a/model/raft/core/utils_core/utils.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2025 THU-BPM MarkDiffusion. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import torch -import torch.nn.functional as F -import numpy as np -from scipy import interpolate - - -class InputPadder: - """ Pads images such that dimensions are divisible by 8 """ - def __init__(self, dims, mode='sintel'): - self.ht, self.wd = dims[-2:] - pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 - pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 - if mode == 'sintel': - self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] - else: - self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] - - def pad(self, *inputs): - return [F.pad(x, self._pad, mode='replicate') for x in inputs] - - def unpad(self,x): - ht, wd = x.shape[-2:] - c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] - return x[..., c[0]:c[1], c[2]:c[3]] - -def forward_interpolate(flow): - flow = flow.detach().cpu().numpy() - dx, dy = flow[0], flow[1] - - ht, wd = dx.shape - x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) - - x1 = x0 + dx - y1 = y0 + dy - - x1 = x1.reshape(-1) - y1 = y1.reshape(-1) - dx = dx.reshape(-1) - dy = dy.reshape(-1) - - valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) - x1 = x1[valid] - y1 = y1[valid] - dx = dx[valid] - dy = dy[valid] - - flow_x = interpolate.griddata( - (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) - - flow_y = interpolate.griddata( - (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) - - flow = np.stack([flow_x, flow_y], axis=0) - return torch.from_numpy(flow).float() - - -def bilinear_sampler(img, coords, mode='bilinear', mask=False): - """ Wrapper for grid_sample, uses pixel coordinates """ - H, W = img.shape[-2:] - xgrid, ygrid = coords.split([1,1], dim=-1) - xgrid = 2*xgrid/(W-1) - 1 - ygrid = 2*ygrid/(H-1) - 1 - - grid = torch.cat([xgrid, ygrid], dim=-1) - img = F.grid_sample(img, grid, align_corners=True) - - if mask: - mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) - return img, mask.float() - - return img - - -def coords_grid(batch, ht, wd, device): - coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) - coords = torch.stack(coords[::-1], dim=0).float() - return coords[None].repeat(batch, 1, 1, 1) - - -def upflow8(flow, mode='bilinear'): - new_size = (8 * flow.shape[2], 8 * flow.shape[3]) - return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/model/raft/raft-things.pth b/model/raft/raft-things.pth deleted file mode 100644 index dbe6f9f..0000000 Binary files a/model/raft/raft-things.pth and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index 1d49248..7cb5041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "setuptools.build_meta" [project] name = "markdiffusion" -version = "0.1.0" +dynamic = ["version"] description = "An Open-Source Toolkit for Generative Watermarking of Latent Diffusion Models" readme = {file = "README.md", content-type = "text/markdown"} license = {text = "Apache-2.0"} diff --git a/test/README.md b/test/README.md index be8a5a1..98b3024 100644 --- a/test/README.md +++ b/test/README.md @@ -75,6 +75,16 @@ Test dependencies include: #### Run directly with pytest ```bash +# Test the whole project and report coverage +pytest test -v \ + --cov=. 
\ + --cov-report=html \ + --cov-report=term-missing \ + --html=report.html + +# Test all pipelines +pytest test/test_pipelines.py -v + # Test all algorithms and modules pytest test/test_watermark_algorithms.py -v diff --git a/test/conftest.py b/test/conftest.py index 05c6c8d..8af97b2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -31,10 +31,9 @@ # ============================================================================ # Default model paths (can be overridden via pytest options) -# DEFAULT_IMAGE_MODEL_PATH = "huanzi05/stable-diffusion-2-1-base" -# DEFAULT_VIDEO_MODEL_PATH = "ali-vilab/text-to-video-ms-1.7b" -DEFAULT_VIDEO_MODEL_PATH = "/mnt/ckpt/text-to-video-ms-1.7b" -DEFAULT_IMAGE_MODEL_PATH = "/mnt/ckpt/stable-diffusion-2-1-base" +DEFAULT_IMAGE_MODEL_PATH = "/home/harry/models/stable-diffusion-2-1-base" +DEFAULT_VIDEO_MODEL_PATH = "/home/harry/models/text-to-video-ms-1.7b" + # Test prompts TEST_PROMPT_IMAGE = "A beautiful sunset over the ocean" TEST_PROMPT_VIDEO = "A cinematic timelapse of city lights at night" @@ -232,6 +231,8 @@ def video_pipeline(device, video_model_path): def image_diffusion_config(device, image_pipeline): """Create diffusion config for image generation.""" pipe, scheduler = image_pipeline + # Use the pipeline's dtype to ensure consistency + pipe_dtype = getattr(pipe, 'dtype', torch.float32) return DiffusionConfig( scheduler=scheduler, pipe=pipe, @@ -240,7 +241,8 @@ def image_diffusion_config(device, image_pipeline): num_inference_steps=NUM_INFERENCE_STEPS, guidance_scale=GUIDANCE_SCALE, gen_seed=GEN_SEED, - inversion_type="ddim" + inversion_type="ddim", + dtype=pipe_dtype ) @@ -248,6 +250,8 @@ def image_diffusion_config(device, image_pipeline): def video_diffusion_config(device, video_pipeline): """Create diffusion config for video generation.""" pipe, scheduler = video_pipeline + # Explicitly set dtype to match the pipeline's dtype to avoid Half/Float mismatch + pipe_dtype = torch.float16 if device == 'cuda' else torch.float32 return DiffusionConfig( scheduler=scheduler, pipe=pipe, @@ -257,7 +261,8 @@ def video_diffusion_config(device, video_pipeline): guidance_scale=GUIDANCE_SCALE, gen_seed=GEN_SEED, inversion_type="ddim", - num_frames=NUM_FRAMES + num_frames=NUM_FRAMES, + dtype=pipe_dtype ) @@ -310,7 +315,7 @@ def all_image_editors(): Brightness(), Mask(), Overlay(), - # AdaptiveNoiseInjection() + AdaptiveNoiseInjection() ] diff --git a/test/test_dataset.py b/test/test_dataset.py new file mode 100644 index 0000000..4a694a2 --- /dev/null +++ b/test/test_dataset.py @@ -0,0 +1,519 @@ +# Copyright 2025 THU-BPM MarkDiffusion. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for dataset classes in MarkDiffusion. 
+ +Tests cover: +- BaseDataset: Base class functionality +- StableDiffusionPromptsDataset: Prompt-only dataset +- MSCOCODataset: Image-caption dataset +- VBenchDataset: Video benchmark dataset +""" + +import pytest +from unittest.mock import Mock, MagicMock, patch +from PIL import Image +import pandas as pd + +from evaluation.dataset import ( + BaseDataset, + StableDiffusionPromptsDataset, + MSCOCODataset, + VBenchDataset, +) + + +# ============================================================================ +# Tests for BaseDataset +# ============================================================================ + +class TestBaseDataset: + """Tests for BaseDataset class.""" + + def test_initialization_with_default_max_samples(self): + """Test BaseDataset initializes with default max_samples.""" + dataset = BaseDataset() + assert dataset.max_samples == 200 + assert dataset.prompts == [] + assert dataset.references == [] + + def test_initialization_with_custom_max_samples(self): + """Test BaseDataset initializes with custom max_samples.""" + dataset = BaseDataset(max_samples=50) + assert dataset.max_samples == 50 + + def test_num_samples_property(self): + """Test num_samples returns length of prompts.""" + dataset = BaseDataset() + assert dataset.num_samples == 0 + + dataset.prompts = ["prompt1", "prompt2", "prompt3"] + assert dataset.num_samples == 3 + + def test_num_references_property(self): + """Test num_references returns length of references.""" + dataset = BaseDataset() + assert dataset.num_references == 0 + + dataset.references = [Mock(), Mock()] + assert dataset.num_references == 2 + + def test_len_method(self): + """Test __len__ returns num_samples.""" + dataset = BaseDataset() + dataset.prompts = ["a", "b", "c", "d"] + assert len(dataset) == 4 + + def test_get_prompt(self): + """Test get_prompt returns correct prompt at index.""" + dataset = BaseDataset() + dataset.prompts = ["first", "second", "third"] + + assert dataset.get_prompt(0) == "first" + assert dataset.get_prompt(1) == "second" + assert dataset.get_prompt(2) == "third" + + def test_get_prompt_index_error(self): + """Test get_prompt raises IndexError for invalid index.""" + dataset = BaseDataset() + dataset.prompts = ["only_one"] + + with pytest.raises(IndexError): + dataset.get_prompt(5) + + def test_get_reference(self): + """Test get_reference returns correct reference at index.""" + dataset = BaseDataset() + mock_images = [Mock(spec=Image.Image), Mock(spec=Image.Image)] + dataset.references = mock_images + + assert dataset.get_reference(0) == mock_images[0] + assert dataset.get_reference(1) == mock_images[1] + + def test_get_reference_index_error(self): + """Test get_reference raises IndexError for invalid index.""" + dataset = BaseDataset() + dataset.references = [] + + with pytest.raises(IndexError): + dataset.get_reference(0) + + def test_getitem_without_references(self): + """Test __getitem__ returns only prompt when no references.""" + dataset = BaseDataset() + dataset.prompts = ["prompt1", "prompt2"] + + assert dataset[0] == "prompt1" + assert dataset[1] == "prompt2" + + def test_getitem_with_references(self): + """Test __getitem__ returns (prompt, reference) tuple when references exist.""" + dataset = BaseDataset() + dataset.prompts = ["prompt1", "prompt2"] + mock_images = [Mock(spec=Image.Image), Mock(spec=Image.Image)] + dataset.references = mock_images + + result = dataset[0] + assert isinstance(result, tuple) + assert result[0] == "prompt1" + assert result[1] == mock_images[0] + + def 
test_load_data_is_noop(self): + """Test _load_data does nothing in base class.""" + dataset = BaseDataset() + dataset._load_data() # Should not raise + assert dataset.prompts == [] + assert dataset.references == [] + + +# ============================================================================ +# Tests for StableDiffusionPromptsDataset +# ============================================================================ + +class TestStableDiffusionPromptsDataset: + """Tests for StableDiffusionPromptsDataset class.""" + + @patch('evaluation.dataset.load_dataset') + def test_initialization(self, mock_load_dataset): + """Test dataset initializes correctly.""" + # Setup mock + mock_data = {"Prompt": ["prompt1", "prompt2", "prompt3"]} + mock_load_dataset.return_value = mock_data + + dataset = StableDiffusionPromptsDataset(max_samples=2) + + assert dataset.max_samples == 2 + assert dataset.split == "test" + assert dataset.shuffle is False + mock_load_dataset.assert_called_once() + + @patch('evaluation.dataset.load_dataset') + def test_name_property(self, mock_load_dataset): + """Test name property returns correct name.""" + mock_load_dataset.return_value = {"Prompt": []} + dataset = StableDiffusionPromptsDataset(max_samples=1) + assert dataset.name == "Stable Diffusion Prompts" + + @patch('evaluation.dataset.load_dataset') + def test_prompts_loaded(self, mock_load_dataset): + """Test prompts are loaded from dataset.""" + test_prompts = ["A cat sitting on a mat", "A dog running in park", "A bird flying"] + mock_load_dataset.return_value = {"Prompt": test_prompts} + + dataset = StableDiffusionPromptsDataset(max_samples=3) + + assert len(dataset.prompts) == 3 + assert dataset.prompts == test_prompts + + @patch('evaluation.dataset.load_dataset') + def test_max_samples_limit(self, mock_load_dataset): + """Test max_samples limits the number of prompts loaded.""" + test_prompts = ["p1", "p2", "p3", "p4", "p5"] + mock_load_dataset.return_value = {"Prompt": test_prompts} + + dataset = StableDiffusionPromptsDataset(max_samples=2) + + assert len(dataset.prompts) == 2 + assert dataset.prompts == ["p1", "p2"] + + @patch('evaluation.dataset.load_dataset') + def test_shuffle_option(self, mock_load_dataset): + """Test shuffle option is passed to dataset.""" + mock_dataset = MagicMock() + mock_dataset.__getitem__ = lambda self, key: ["p1", "p2"] + mock_dataset.shuffle.return_value = mock_dataset + mock_load_dataset.return_value = mock_dataset + + dataset = StableDiffusionPromptsDataset(max_samples=2, shuffle=True) + + assert dataset.shuffle is True + mock_dataset.shuffle.assert_called_once() + + @patch('evaluation.dataset.load_dataset') + def test_custom_split(self, mock_load_dataset): + """Test custom split option.""" + mock_load_dataset.return_value = {"Prompt": ["p1"]} + + dataset = StableDiffusionPromptsDataset(max_samples=1, split="train") + + assert dataset.split == "train" + mock_load_dataset.assert_called_with( + "dataset/stable_diffusion_prompts", split="train" + ) + + @patch('evaluation.dataset.load_dataset') + def test_no_references(self, mock_load_dataset): + """Test that StableDiffusionPromptsDataset has no references.""" + mock_load_dataset.return_value = {"Prompt": ["p1", "p2"]} + + dataset = StableDiffusionPromptsDataset(max_samples=2) + + assert dataset.num_references == 0 + assert dataset.references == [] + + +# ============================================================================ +# Tests for MSCOCODataset +# ============================================================================ + 
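One convention to keep in mind for these tests (and the MS-COCO ones below): each mock patches the name where it is looked up — `evaluation.dataset.load_dataset`, `evaluation.dataset.pd.read_parquet` — not where it is defined, since the module under test holds its own reference after import. A minimal sketch of the pattern, with hypothetical prompt values:

```python
from unittest.mock import patch

from evaluation.dataset import StableDiffusionPromptsDataset

# Patch the symbol inside the module under test, not in its defining library.
with patch('evaluation.dataset.load_dataset') as fake_load:
    fake_load.return_value = {"Prompt": ["a sample prompt"]}
    ds = StableDiffusionPromptsDataset(max_samples=1)
    assert ds.get_prompt(0) == "a sample prompt"
```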
+class TestMSCOCODataset: + """Tests for MSCOCODataset class.""" + + @patch('evaluation.dataset.pd.read_parquet') + @patch('evaluation.dataset.tqdm') + def test_initialization(self, mock_tqdm, mock_read_parquet): + """Test dataset initializes correctly.""" + # Setup mock DataFrame + mock_df = pd.DataFrame({ + 'TEXT': ['caption1', 'caption2'], + 'URL': ['http://example.com/1.jpg', 'http://example.com/2.jpg'] + }) + mock_read_parquet.return_value = mock_df + mock_tqdm.return_value = range(2) + + with patch.object(MSCOCODataset, '_load_image_from_url', return_value=Mock(spec=Image.Image)): + dataset = MSCOCODataset(max_samples=2) + + assert dataset.max_samples == 2 + assert dataset.shuffle is False + + @patch('evaluation.dataset.pd.read_parquet') + def test_name_property(self, mock_read_parquet): + """Test name property returns correct name.""" + mock_read_parquet.return_value = pd.DataFrame({'TEXT': [], 'URL': []}) + + with patch.object(MSCOCODataset, '_load_data'): + dataset = MSCOCODataset.__new__(MSCOCODataset) + dataset.max_samples = 0 + dataset.prompts = [] + dataset.references = [] + dataset.shuffle = False + + assert dataset.name == "MS-COCO 2017" + + @patch('evaluation.dataset.requests.get') + def test_load_image_from_url_success(self, mock_get): + """Test _load_image_from_url successfully loads an image.""" + # Create a mock response with image data + mock_response = Mock() + mock_response.raise_for_status = Mock() + # Create a simple PNG image bytes + from io import BytesIO + img = Image.new('RGB', (100, 100), color='red') + img_bytes = BytesIO() + img.save(img_bytes, format='PNG') + mock_response.content = img_bytes.getvalue() + mock_get.return_value = mock_response + + # Create dataset instance without loading data + dataset = BaseDataset.__new__(MSCOCODataset) + result = dataset._load_image_from_url("http://example.com/image.jpg") + + assert isinstance(result, Image.Image) + mock_get.assert_called_once_with("http://example.com/image.jpg") + + @patch('evaluation.dataset.requests.get') + def test_load_image_from_url_failure(self, mock_get, capsys): + """Test _load_image_from_url returns None on failure.""" + mock_get.side_effect = Exception("Connection error") + + dataset = BaseDataset.__new__(MSCOCODataset) + result = dataset._load_image_from_url("http://example.com/bad.jpg") + + assert result is None + captured = capsys.readouterr() + assert "Load image from url failed" in captured.out + + @patch('evaluation.dataset.pd.read_parquet') + @patch('evaluation.dataset.tqdm') + def test_shuffle_option(self, mock_tqdm, mock_read_parquet): + """Test shuffle option shuffles the DataFrame.""" + mock_df = MagicMock(spec=pd.DataFrame) + mock_df.iloc = MagicMock() + mock_df.sample.return_value.reset_index.return_value = mock_df + mock_read_parquet.return_value = mock_df + mock_tqdm.return_value = [] + + with patch.object(MSCOCODataset, '_load_image_from_url'): + dataset = MSCOCODataset(max_samples=0, shuffle=True) + + assert dataset.shuffle is True + mock_df.sample.assert_called_once_with(frac=1) + + +# ============================================================================ +# Tests for VBenchDataset +# ============================================================================ + +class TestVBenchDataset: + """Tests for VBenchDataset class.""" + + @patch('builtins.open') + def test_initialization(self, mock_open): + """Test dataset initializes correctly.""" + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + 
mock_file.readlines.return_value = ["prompt1\n", "prompt2\n", "prompt3\n"] + mock_open.return_value = mock_file + + dataset = VBenchDataset(max_samples=2, dimension="subject_consistency") + + assert dataset.max_samples == 2 + assert dataset.dimension == "subject_consistency" + assert dataset.shuffle is False + + @patch('builtins.open') + def test_name_property(self, mock_open): + """Test name property returns correct name.""" + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.readlines.return_value = [] + mock_open.return_value = mock_file + + dataset = VBenchDataset(max_samples=0, dimension="test") + + assert dataset.name == "VBench" + + @patch('builtins.open') + def test_prompts_loaded(self, mock_open): + """Test prompts are loaded from file.""" + test_prompts = ["A man walking\n", "A car driving\n", "A plane flying\n"] + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.readlines.return_value = test_prompts + mock_open.return_value = mock_file + + dataset = VBenchDataset(max_samples=3, dimension="test") + + assert len(dataset.prompts) == 3 + assert dataset.prompts == ["A man walking", "A car driving", "A plane flying"] + + @patch('builtins.open') + def test_max_samples_limit(self, mock_open): + """Test max_samples limits the number of prompts.""" + test_prompts = ["p1\n", "p2\n", "p3\n", "p4\n", "p5\n"] + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.readlines.return_value = test_prompts + mock_open.return_value = mock_file + + dataset = VBenchDataset(max_samples=2, dimension="test") + + assert len(dataset.prompts) == 2 + + @patch('builtins.open') + @patch('evaluation.dataset.random.shuffle') + def test_shuffle_option(self, mock_shuffle, mock_open): + """Test shuffle option shuffles the prompts.""" + test_prompts = ["p1\n", "p2\n", "p3\n"] + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.readlines.return_value = test_prompts + mock_open.return_value = mock_file + + dataset = VBenchDataset(max_samples=3, dimension="test", shuffle=True) + + assert dataset.shuffle is True + mock_shuffle.assert_called_once() + + @patch('builtins.open') + def test_file_path_format(self, mock_open): + """Test correct file path is used based on dimension.""" + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.readlines.return_value = [] + mock_open.return_value = mock_file + + dataset = VBenchDataset(max_samples=0, dimension="motion_smoothness") + + mock_open.assert_called_with( + "dataset/vbench/prompts_per_dimension/motion_smoothness.txt", "r" + ) + + @patch('builtins.open') + def test_no_references(self, mock_open): + """Test that VBenchDataset has no references.""" + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.readlines.return_value = ["p1\n"] + mock_open.return_value = mock_file + + dataset = VBenchDataset(max_samples=1, dimension="test") + + assert dataset.num_references == 0 + + def test_file_not_found(self): + """Test FileNotFoundError when dimension file doesn't exist.""" + with pytest.raises(FileNotFoundError): + VBenchDataset(max_samples=1, 
dimension="nonexistent_dimension") + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestDatasetIntegration: + """Integration tests for dataset functionality.""" + + def test_base_dataset_iteration(self): + """Test iterating over BaseDataset.""" + dataset = BaseDataset() + dataset.prompts = ["p1", "p2", "p3"] + + collected = [] + for i in range(len(dataset)): + collected.append(dataset[i]) + + assert collected == ["p1", "p2", "p3"] + + def test_dataset_with_references_iteration(self): + """Test iterating over dataset with references.""" + dataset = BaseDataset() + dataset.prompts = ["p1", "p2"] + mock_images = [Mock(spec=Image.Image), Mock(spec=Image.Image)] + dataset.references = mock_images + + for i in range(len(dataset)): + prompt, ref = dataset[i] + assert prompt == dataset.prompts[i] + assert ref == mock_images[i] + + @patch('evaluation.dataset.load_dataset') + def test_stable_diffusion_dataset_as_base_dataset(self, mock_load_dataset): + """Test StableDiffusionPromptsDataset works as BaseDataset.""" + mock_load_dataset.return_value = {"Prompt": ["test_prompt"]} + + dataset = StableDiffusionPromptsDataset(max_samples=1) + + # Should have all BaseDataset functionality + assert isinstance(dataset, BaseDataset) + assert len(dataset) == dataset.num_samples + assert dataset.get_prompt(0) == "test_prompt" + + +# ============================================================================ +# Edge Case Tests +# ============================================================================ + +class TestDatasetEdgeCases: + """Test edge cases for dataset classes.""" + + def test_empty_dataset(self): + """Test behavior with empty dataset.""" + dataset = BaseDataset(max_samples=0) + assert len(dataset) == 0 + assert dataset.num_samples == 0 + assert dataset.num_references == 0 + + def test_single_item_dataset(self): + """Test dataset with single item.""" + dataset = BaseDataset() + dataset.prompts = ["only_prompt"] + + assert len(dataset) == 1 + assert dataset[0] == "only_prompt" + assert dataset.get_prompt(0) == "only_prompt" + + def test_large_max_samples(self): + """Test with very large max_samples.""" + dataset = BaseDataset(max_samples=1000000) + assert dataset.max_samples == 1000000 + + def test_negative_index(self): + """Test negative indexing behavior.""" + dataset = BaseDataset() + dataset.prompts = ["first", "second", "third"] + + # Python lists support negative indexing + assert dataset[-1] == "third" + assert dataset[-2] == "second" + + @patch('evaluation.dataset.load_dataset') + def test_unicode_prompts(self, mock_load_dataset): + """Test handling of unicode prompts.""" + unicode_prompts = ["日本語プロンプト", "中文提示", "🎨 emoji art"] + mock_load_dataset.return_value = {"Prompt": unicode_prompts} + + dataset = StableDiffusionPromptsDataset(max_samples=3) + + assert dataset.prompts == unicode_prompts + assert dataset.get_prompt(0) == "日本語プロンプト" diff --git a/test/test_exceptions.py b/test/test_exceptions.py new file mode 100644 index 0000000..2222a46 --- /dev/null +++ b/test/test_exceptions.py @@ -0,0 +1,521 @@ +# Copyright 2025 THU-BPM MarkDiffusion. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for custom exceptions in MarkDiffusion. + +Tests cover all exception classes defined in exceptions/exceptions.py. +""" + +import pytest + +from exceptions.exceptions import ( + LengthMismatchError, + InvalidTextSourceModeError, + AlgorithmNameMismatchError, + InvalidDirectAnalyzerTypeError, + InvalidReferencedAnalyzerTypeError, + InvalidAnswerError, + TypeMismatchException, + ConfigurationError, + OpenAIModelConfigurationError, + DiversityValueError, + CodeExecutionError, + InvalidDetectModeError, + InvalidWatermarkModeError, +) + + +# ============================================================================ +# Tests for LengthMismatchError +# ============================================================================ + +class TestLengthMismatchError: + """Tests for LengthMismatchError exception.""" + + def test_message_format(self): + """Test that error message contains expected and actual values.""" + error = LengthMismatchError(expected=10, actual=5) + message = str(error) + assert "Expected length: 10" in message + assert "but got 5" in message + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(LengthMismatchError): + raise LengthMismatchError(expected=100, actual=50) + + def test_inheritance(self): + """Test that LengthMismatchError inherits from Exception.""" + error = LengthMismatchError(10, 5) + assert isinstance(error, Exception) + + def test_different_values(self): + """Test with different expected and actual values.""" + test_cases = [(0, 1), (100, 0), (1000, 999), (1, 1000000)] + for expected, actual in test_cases: + error = LengthMismatchError(expected, actual) + assert str(expected) in str(error) + assert str(actual) in str(error) + + +# ============================================================================ +# Tests for InvalidTextSourceModeError +# ============================================================================ + +class TestInvalidTextSourceModeError: + """Tests for InvalidTextSourceModeError exception.""" + + def test_message_format(self): + """Test that error message contains the invalid mode.""" + error = InvalidTextSourceModeError("invalid_mode") + message = str(error) + assert "'invalid_mode' is not a valid text source mode" in message + assert "natural" in message + assert "generated" in message + + def test_inheritance(self): + """Test that InvalidTextSourceModeError inherits from ValueError.""" + error = InvalidTextSourceModeError("test") + assert isinstance(error, ValueError) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(InvalidTextSourceModeError): + raise InvalidTextSourceModeError("bad_mode") + + def test_various_invalid_modes(self): + """Test with various invalid mode names.""" + invalid_modes = ["invalid", "", "Natural", "GENERATED", "random", "123"] + for mode in invalid_modes: + error = InvalidTextSourceModeError(mode) + assert f"'{mode}'" in str(error) + + +# ============================================================================ +# Tests for AlgorithmNameMismatchError +# 
============================================================================ + +class TestAlgorithmNameMismatchError: + """Tests for AlgorithmNameMismatchError exception.""" + + def test_message_format(self): + """Test that error message contains expected and actual algorithm names.""" + error = AlgorithmNameMismatchError(expected="TR", actual="GS") + message = str(error) + assert "TR" in message + assert "GS" in message + assert "does not match" in message + + def test_inheritance(self): + """Test that AlgorithmNameMismatchError inherits from ValueError.""" + error = AlgorithmNameMismatchError("A", "B") + assert isinstance(error, ValueError) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(AlgorithmNameMismatchError): + raise AlgorithmNameMismatchError(expected="Expected", actual="Actual") + + def test_various_algorithm_names(self): + """Test with various algorithm name combinations.""" + test_cases = [ + ("TR", "GS"), + ("VideoShield", "VideoMark"), + ("PRC", "ROBIN"), + ("SFW", "SEAL"), + ] + for expected, actual in test_cases: + error = AlgorithmNameMismatchError(expected, actual) + assert expected in str(error) + assert actual in str(error) + + +# ============================================================================ +# Tests for InvalidDirectAnalyzerTypeError +# ============================================================================ + +class TestInvalidDirectAnalyzerTypeError: + """Tests for InvalidDirectAnalyzerTypeError exception.""" + + def test_default_message(self): + """Test default error message.""" + error = InvalidDirectAnalyzerTypeError() + assert "DirectTextQualityAnalyzer" in str(error) + + def test_custom_message(self): + """Test custom error message.""" + custom_msg = "Custom analyzer error message" + error = InvalidDirectAnalyzerTypeError(custom_msg) + assert str(error) == custom_msg + + def test_inheritance(self): + """Test that InvalidDirectAnalyzerTypeError inherits from Exception.""" + error = InvalidDirectAnalyzerTypeError() + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(InvalidDirectAnalyzerTypeError): + raise InvalidDirectAnalyzerTypeError() + + +# ============================================================================ +# Tests for InvalidReferencedAnalyzerTypeError +# ============================================================================ + +class TestInvalidReferencedAnalyzerTypeError: + """Tests for InvalidReferencedAnalyzerTypeError exception.""" + + def test_default_message(self): + """Test default error message.""" + error = InvalidReferencedAnalyzerTypeError() + assert "ReferencedTextQualityAnalyzer" in str(error) + + def test_custom_message(self): + """Test custom error message.""" + custom_msg = "Custom referenced analyzer error" + error = InvalidReferencedAnalyzerTypeError(custom_msg) + assert str(error) == custom_msg + + def test_inheritance(self): + """Test that InvalidReferencedAnalyzerTypeError inherits from Exception.""" + error = InvalidReferencedAnalyzerTypeError() + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(InvalidReferencedAnalyzerTypeError): + raise InvalidReferencedAnalyzerTypeError() + + +# ============================================================================ +# Tests for InvalidAnswerError +# 
============================================================================ + +class TestInvalidAnswerError: + """Tests for InvalidAnswerError exception.""" + + def test_message_format(self): + """Test that error message contains the invalid answer.""" + error = InvalidAnswerError("bad_answer") + assert "Invalid answer: bad_answer" in str(error) + + def test_inheritance(self): + """Test that InvalidAnswerError inherits from ValueError.""" + error = InvalidAnswerError("test") + assert isinstance(error, ValueError) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(InvalidAnswerError): + raise InvalidAnswerError("invalid") + + def test_various_answer_types(self): + """Test with various answer types.""" + answers = ["string_answer", 123, None, "", ["list"], {"dict": "value"}] + for answer in answers: + error = InvalidAnswerError(answer) + assert "Invalid answer" in str(error) + + +# ============================================================================ +# Tests for TypeMismatchException +# ============================================================================ + +class TestTypeMismatchException: + """Tests for TypeMismatchException exception.""" + + def test_message_with_types(self): + """Test error message with expected and found types.""" + error = TypeMismatchException(expected_type=int, found_type=str) + message = str(error) + assert "int" in message + assert "str" in message + + def test_custom_message(self): + """Test custom error message overrides default.""" + custom_msg = "Custom type mismatch message" + error = TypeMismatchException(int, str, custom_msg) + assert str(error) == custom_msg + + def test_attributes_stored(self): + """Test that expected_type and found_type are stored.""" + error = TypeMismatchException(expected_type=list, found_type=dict) + assert error.expected_type == list + assert error.found_type == dict + + def test_inheritance(self): + """Test that TypeMismatchException inherits from Exception.""" + error = TypeMismatchException(int, str) + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(TypeMismatchException): + raise TypeMismatchException(int, str) + + def test_various_type_combinations(self): + """Test with various type combinations.""" + type_pairs = [ + (int, str), + (list, tuple), + (dict, list), + (float, int), + (str, bytes), + ] + for expected, found in type_pairs: + error = TypeMismatchException(expected, found) + assert expected.__name__ in str(error) + assert found.__name__ in str(error) + + +# ============================================================================ +# Tests for ConfigurationError +# ============================================================================ + +class TestConfigurationError: + """Tests for ConfigurationError exception.""" + + def test_message_stored(self): + """Test that message is stored correctly.""" + error = ConfigurationError("Test configuration error") + assert error.message == "Test configuration error" + assert str(error) == "Test configuration error" + + def test_inheritance(self): + """Test that ConfigurationError inherits from Exception.""" + error = ConfigurationError("test") + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(ConfigurationError): + raise ConfigurationError("Config error") + + def test_various_messages(self): + """Test 
with various error messages.""" + messages = [ + "Missing required field", + "Invalid value for parameter", + "Configuration file not found", + "", + ] + for msg in messages: + error = ConfigurationError(msg) + assert error.message == msg + + +# ============================================================================ +# Tests for OpenAIModelConfigurationError +# ============================================================================ + +class TestOpenAIModelConfigurationError: + """Tests for OpenAIModelConfigurationError exception.""" + + def test_message_format(self): + """Test error message format.""" + error = OpenAIModelConfigurationError("Invalid API key") + assert str(error) == "Invalid API key" + + def test_inheritance(self): + """Test that OpenAIModelConfigurationError inherits from Exception.""" + error = OpenAIModelConfigurationError("test") + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(OpenAIModelConfigurationError): + raise OpenAIModelConfigurationError("API configuration error") + + +# ============================================================================ +# Tests for DiversityValueError +# ============================================================================ + +class TestDiversityValueError: + """Tests for DiversityValueError exception.""" + + def test_message_format(self): + """Test that error message contains diversity type and valid values.""" + error = DiversityValueError("lexical") + message = str(error) + assert "lexical" in message + assert "0, 20, 40, 60, 80, 100" in message + + def test_inheritance(self): + """Test that DiversityValueError inherits from Exception.""" + error = DiversityValueError("test") + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(DiversityValueError): + raise DiversityValueError("semantic") + + def test_various_diversity_types(self): + """Test with various diversity type names.""" + diversity_types = ["lexical", "semantic", "syntactic", "custom"] + for dtype in diversity_types: + error = DiversityValueError(dtype) + assert dtype in str(error) + + +# ============================================================================ +# Tests for CodeExecutionError +# ============================================================================ + +class TestCodeExecutionError: + """Tests for CodeExecutionError exception.""" + + def test_default_message(self): + """Test default error message.""" + error = CodeExecutionError() + assert error.message == "Error during code execution" + assert "Error during code execution" in str(error) + + def test_custom_message(self): + """Test custom error message.""" + custom_msg = "Specific code execution failure" + error = CodeExecutionError(custom_msg) + assert error.message == custom_msg + assert str(error) == custom_msg + + def test_inheritance(self): + """Test that CodeExecutionError inherits from Exception.""" + error = CodeExecutionError() + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(CodeExecutionError): + raise CodeExecutionError() + + +# ============================================================================ +# Tests for InvalidDetectModeError +# ============================================================================ + +class TestInvalidDetectModeError: + """Tests for 
InvalidDetectModeError exception.""" + + def test_mode_stored(self): + """Test that mode is stored correctly.""" + error = InvalidDetectModeError("bad_mode") + assert error.mode == "bad_mode" + + def test_default_message(self): + """Test default error message format.""" + error = InvalidDetectModeError("test_mode") + assert error.message == "Invalid detect mode configuration" + assert "test_mode" in str(error) + + def test_custom_message(self): + """Test custom error message.""" + error = InvalidDetectModeError("mode", "Custom detect error") + assert error.message == "Custom detect error" + assert "mode" in str(error) + + def test_inheritance(self): + """Test that InvalidDetectModeError inherits from Exception.""" + error = InvalidDetectModeError("test") + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(InvalidDetectModeError): + raise InvalidDetectModeError("invalid") + + +# ============================================================================ +# Tests for InvalidWatermarkModeError +# ============================================================================ + +class TestInvalidWatermarkModeError: + """Tests for InvalidWatermarkModeError exception.""" + + def test_mode_stored(self): + """Test that mode is stored correctly.""" + error = InvalidWatermarkModeError("bad_mode") + assert error.mode == "bad_mode" + + def test_default_message(self): + """Test default error message format.""" + error = InvalidWatermarkModeError("test_mode") + assert error.message == "Invalid watermark mode configuration" + assert "test_mode" in str(error) + + def test_custom_message(self): + """Test custom error message.""" + error = InvalidWatermarkModeError("mode", "Custom watermark error") + assert error.message == "Custom watermark error" + assert "mode" in str(error) + + def test_inheritance(self): + """Test that InvalidWatermarkModeError inherits from Exception.""" + error = InvalidWatermarkModeError("test") + assert isinstance(error, Exception) + + def test_raises_correctly(self): + """Test that exception can be raised and caught.""" + with pytest.raises(InvalidWatermarkModeError): + raise InvalidWatermarkModeError("invalid") + + +# ============================================================================ +# Integration Tests - Exception Handling Patterns +# ============================================================================ + +class TestExceptionHandlingPatterns: + """Test common exception handling patterns.""" + + def test_catch_value_errors(self): + """Test that ValueError subclasses can be caught as ValueError.""" + value_error_exceptions = [ + InvalidTextSourceModeError("test"), + AlgorithmNameMismatchError("A", "B"), + InvalidAnswerError("test"), + ] + for exc in value_error_exceptions: + with pytest.raises(ValueError): + raise exc + + def test_exception_chaining(self): + """Test exception chaining works correctly.""" + try: + try: + raise ValueError("Original error") + except ValueError as e: + raise ConfigurationError("Config failed") from e + except ConfigurationError as e: + assert e.__cause__ is not None + assert isinstance(e.__cause__, ValueError) + + def test_exception_in_context_manager(self): + """Test exceptions work in context managers.""" + class TestContext: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is LengthMismatchError: + return True # Suppress the exception + return False + + with TestContext(): + raise 
LengthMismatchError(10, 5) # Should be suppressed + + with pytest.raises(ConfigurationError): + with TestContext(): + raise ConfigurationError("Not suppressed") diff --git a/test/test_image_editor.py b/test/test_image_editor.py new file mode 100644 index 0000000..9cce9d7 --- /dev/null +++ b/test/test_image_editor.py @@ -0,0 +1,824 @@ +# Copyright 2025 THU-BPM MarkDiffusion. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for image editor classes in MarkDiffusion. + +Tests cover: +- ImageEditor: Base class +- JPEGCompression: JPEG quality compression +- Rotation: Image rotation +- CrSc: Crop and scale +- GaussianBlurring: Gaussian blur filter +- GaussianNoise: Add Gaussian noise +- Brightness: Brightness adjustment +- Mask: Random rectangular masks +- Overlay: Random stroke overlays +- AdaptiveNoiseInjection: Adaptive noise based on image features +""" + +import pytest +import numpy as np +from PIL import Image +import os +import tempfile + +from evaluation.tools.image_editor import ( + ImageEditor, + JPEGCompression, + Rotation, + CrSc, + GaussianBlurring, + GaussianNoise, + Brightness, + Mask, + Overlay, + AdaptiveNoiseInjection, +) + + +# ============================================================================ +# Fixtures +# ============================================================================ + +@pytest.fixture +def sample_rgb_image(): + """Create a sample RGB test image.""" + return Image.new("RGB", (256, 256), color="red") + + +@pytest.fixture +def sample_gradient_image(): + """Create a gradient image for testing edge detection.""" + arr = np.zeros((256, 256, 3), dtype=np.uint8) + for i in range(256): + arr[:, i, :] = i # Horizontal gradient + return Image.fromarray(arr) + + +@pytest.fixture +def sample_complex_image(): + """Create a more complex image with varied content.""" + arr = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8) + return Image.fromarray(arr) + + +@pytest.fixture +def sample_dark_image(): + """Create a dark image for testing adaptive noise.""" + arr = np.full((256, 256, 3), 30, dtype=np.uint8) + return Image.fromarray(arr) + + +@pytest.fixture +def sample_bright_image(): + """Create a bright image for testing.""" + arr = np.full((256, 256, 3), 220, dtype=np.uint8) + return Image.fromarray(arr) + + +# ============================================================================ +# Tests for ImageEditor Base Class +# ============================================================================ + +class TestImageEditor: + """Tests for ImageEditor base class.""" + + def test_initialization(self): + """Test base class can be instantiated.""" + editor = ImageEditor() + assert editor is not None + + def test_edit_method_exists(self): + """Test edit method exists.""" + editor = ImageEditor() + assert hasattr(editor, 'edit') + + def test_edit_returns_none_by_default(self, sample_rgb_image): + """Test base edit method returns None.""" + editor = ImageEditor() + result = editor.edit(sample_rgb_image) + assert result is None + + +# 
============================================================================ +# Tests for JPEGCompression +# ============================================================================ + +class TestJPEGCompression: + """Tests for JPEGCompression editor.""" + + def test_default_quality(self): + """Test default quality is 95.""" + editor = JPEGCompression() + assert editor.quality == 95 + + def test_custom_quality(self): + """Test custom quality setting.""" + editor = JPEGCompression(quality=50) + assert editor.quality == 50 + + def test_edit_returns_image(self, sample_rgb_image): + """Test edit returns a PIL Image.""" + editor = JPEGCompression(quality=75) + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + def test_edit_preserves_size(self, sample_rgb_image): + """Test edit preserves image size.""" + editor = JPEGCompression(quality=75) + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_low_quality_changes_image(self, sample_complex_image): + """Test low quality compression changes the image.""" + editor = JPEGCompression(quality=10) + result = editor.edit(sample_complex_image) + + # Convert to arrays and compare + original_arr = np.array(sample_complex_image) + result_arr = np.array(result) + + # Should be different due to compression artifacts + assert not np.array_equal(original_arr, result_arr) + + def test_high_quality_preserves_more(self, sample_complex_image): + """Test high quality compression preserves more detail.""" + low_editor = JPEGCompression(quality=10) + high_editor = JPEGCompression(quality=95) + + original_arr = np.array(sample_complex_image) + low_result = np.array(low_editor.edit(sample_complex_image)) + high_result = np.array(high_editor.edit(sample_complex_image)) + + low_diff = np.mean(np.abs(original_arr.astype(float) - low_result.astype(float))) + high_diff = np.mean(np.abs(original_arr.astype(float) - high_result.astype(float))) + + # High quality should have less difference + assert high_diff < low_diff + + def test_temp_file_cleanup(self, sample_rgb_image): + """Test temporary file is cleaned up.""" + editor = JPEGCompression(quality=75) + editor.edit(sample_rgb_image) + assert not os.path.exists("temp.jpg") + + def test_various_quality_levels(self, sample_rgb_image): + """Test various quality levels.""" + for quality in [1, 25, 50, 75, 100]: + editor = JPEGCompression(quality=quality) + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + +# ============================================================================ +# Tests for Rotation +# ============================================================================ + +class TestRotation: + """Tests for Rotation editor.""" + + def test_default_parameters(self): + """Test default rotation parameters.""" + editor = Rotation() + assert editor.angle == 30 + assert editor.expand is False + + def test_custom_angle(self): + """Test custom rotation angle.""" + editor = Rotation(angle=45) + assert editor.angle == 45 + + def test_custom_expand(self): + """Test custom expand parameter.""" + editor = Rotation(angle=30, expand=True) + assert editor.expand is True + + def test_edit_returns_image(self, sample_rgb_image): + """Test edit returns a PIL Image.""" + editor = Rotation(angle=45) + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + def test_no_expand_preserves_size(self, sample_rgb_image): + """Test rotation without expand preserves size.""" + editor = 
Rotation(angle=45, expand=False) + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_expand_changes_size(self, sample_rgb_image): + """Test rotation with expand may change size.""" + editor = Rotation(angle=45, expand=True) + result = editor.edit(sample_rgb_image) + # Rotated image with expand should be larger + orig_w, orig_h = sample_rgb_image.size + new_w, new_h = result.size + assert new_w >= orig_w or new_h >= orig_h + + def test_zero_rotation(self, sample_rgb_image): + """Test zero rotation doesn't change image significantly.""" + editor = Rotation(angle=0) + result = editor.edit(sample_rgb_image) + + original_arr = np.array(sample_rgb_image) + result_arr = np.array(result) + + np.testing.assert_array_equal(original_arr, result_arr) + + def test_360_rotation(self, sample_rgb_image): + """Test 360 degree rotation returns similar image.""" + editor = Rotation(angle=360) + result = editor.edit(sample_rgb_image) + + original_arr = np.array(sample_rgb_image) + result_arr = np.array(result) + + np.testing.assert_array_equal(original_arr, result_arr) + + def test_negative_angle(self, sample_rgb_image): + """Test negative rotation angle.""" + editor = Rotation(angle=-45) + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + +# ============================================================================ +# Tests for CrSc (Crop and Scale) +# ============================================================================ + +class TestCrSc: + """Tests for CrSc (Crop and Scale) editor.""" + + def test_default_crop_ratio(self): + """Test default crop ratio is 0.8.""" + editor = CrSc() + assert editor.crop_ratio == 0.8 + + def test_custom_crop_ratio(self): + """Test custom crop ratio.""" + editor = CrSc(crop_ratio=0.5) + assert editor.crop_ratio == 0.5 + + def test_edit_returns_image(self, sample_rgb_image): + """Test edit returns a PIL Image.""" + editor = CrSc(crop_ratio=0.8) + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + def test_edit_preserves_size(self, sample_rgb_image): + """Test edit preserves original size after scaling back.""" + editor = CrSc(crop_ratio=0.8) + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_center_crop(self, sample_gradient_image): + """Test that crop is centered.""" + editor = CrSc(crop_ratio=0.5) + result = editor.edit(sample_gradient_image) + + # Result should be scaled back, but content should be from center + assert result.size == sample_gradient_image.size + + def test_various_crop_ratios(self, sample_rgb_image): + """Test various crop ratios.""" + for ratio in [0.1, 0.3, 0.5, 0.7, 0.9, 1.0]: + editor = CrSc(crop_ratio=ratio) + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_crop_ratio_one(self, sample_rgb_image): + """Test crop ratio of 1.0 (no crop).""" + editor = CrSc(crop_ratio=1.0) + result = editor.edit(sample_rgb_image) + + # Should be very similar to original + original_arr = np.array(sample_rgb_image) + result_arr = np.array(result) + + # Allow small differences due to resize + assert np.allclose(original_arr, result_arr, atol=1) + + +# ============================================================================ +# Tests for GaussianBlurring +# ============================================================================ + +class TestGaussianBlurring: + """Tests for GaussianBlurring editor.""" + + def test_default_radius(self): + """Test default 
blur radius is 2."""
+        editor = GaussianBlurring()
+        assert editor.radius == 2
+
+    def test_custom_radius(self):
+        """Test custom blur radius."""
+        editor = GaussianBlurring(radius=5)
+        assert editor.radius == 5
+
+    def test_edit_returns_image(self, sample_rgb_image):
+        """Test edit returns a PIL Image."""
+        editor = GaussianBlurring(radius=2)
+        result = editor.edit(sample_rgb_image)
+        assert isinstance(result, Image.Image)
+
+    def test_edit_preserves_size(self, sample_rgb_image):
+        """Test edit preserves image size."""
+        editor = GaussianBlurring(radius=2)
+        result = editor.edit(sample_rgb_image)
+        assert result.size == sample_rgb_image.size
+
+    def test_blur_reduces_variance(self, sample_complex_image):
+        """Test blur reduces overall pixel variance of a noisy image."""
+        editor = GaussianBlurring(radius=5)
+        result = editor.edit(sample_complex_image)
+
+        original_arr = np.array(sample_complex_image).astype(float)
+        result_arr = np.array(result).astype(float)
+
+        # Averaging neighbouring pixels lowers per-pixel noise, so the
+        # overall variance of a random-noise image should drop after blurring
+        orig_var = np.var(original_arr)
+        result_var = np.var(result_arr)
+
+        assert result_var < orig_var
+
+    def test_larger_radius_more_blur(self, sample_complex_image):
+        """Test larger radius produces more blur."""
+        small_blur = GaussianBlurring(radius=1)
+        large_blur = GaussianBlurring(radius=10)
+
+        small_result = np.array(small_blur.edit(sample_complex_image))
+        large_result = np.array(large_blur.edit(sample_complex_image))
+
+        # Larger blur should have lower variance
+        assert np.var(large_result) < np.var(small_result)
+
+
+# ============================================================================
+# Tests for GaussianNoise
+# ============================================================================
+
+class TestGaussianNoise:
+    """Tests for GaussianNoise editor."""
+
+    def test_default_sigma(self):
+        """Test default sigma is 25.0."""
+        editor = GaussianNoise()
+        assert editor.sigma == 25.0
+
+    def test_custom_sigma(self):
+        """Test custom sigma."""
+        editor = GaussianNoise(sigma=50.0)
+        assert editor.sigma == 50.0
+
+    def test_edit_returns_image(self, sample_rgb_image):
+        """Test edit returns a PIL Image."""
+        editor = GaussianNoise(sigma=25.0)
+        result = editor.edit(sample_rgb_image)
+        assert isinstance(result, Image.Image)
+
+    def test_edit_preserves_size(self, sample_rgb_image):
+        """Test edit preserves image size."""
+        editor = GaussianNoise(sigma=25.0)
+        result = editor.edit(sample_rgb_image)
+        assert result.size == sample_rgb_image.size
+
+    def test_noise_changes_image(self, sample_rgb_image):
+        """Test noise changes the image."""
+        editor = GaussianNoise(sigma=25.0)
+        result = editor.edit(sample_rgb_image)
+
+        original_arr = np.array(sample_rgb_image)
+        result_arr = np.array(result)
+
+        assert not np.array_equal(original_arr, result_arr)
+
+    def test_higher_sigma_more_noise(self, sample_rgb_image):
+        """Test higher sigma produces more noise."""
+        low_noise = GaussianNoise(sigma=10.0)
+        high_noise = GaussianNoise(sigma=100.0)
+
+        original_arr = np.array(sample_rgb_image).astype(float)
+        low_result = np.array(low_noise.edit(sample_rgb_image)).astype(float)
+        high_result = np.array(high_noise.edit(sample_rgb_image)).astype(float)
+
+        low_diff = np.mean(np.abs(original_arr - low_result))
+        high_diff = np.mean(np.abs(original_arr - high_result))
+
+        assert high_diff > low_diff
+
+    def test_output_clipped_to_valid_range(self, sample_rgb_image):
+        """Test output values are in valid [0, 255] range."""
+        editor = GaussianNoise(sigma=100.0)
+
result = editor.edit(sample_rgb_image) + result_arr = np.array(result) + + assert result_arr.min() >= 0 + assert result_arr.max() <= 255 + + def test_zero_sigma_preserves_image(self, sample_rgb_image): + """Test zero sigma doesn't add noise.""" + editor = GaussianNoise(sigma=0.0) + result = editor.edit(sample_rgb_image) + + original_arr = np.array(sample_rgb_image) + result_arr = np.array(result) + + np.testing.assert_array_equal(original_arr, result_arr) + + +# ============================================================================ +# Tests for Brightness +# ============================================================================ + +class TestBrightness: + """Tests for Brightness editor.""" + + def test_default_factor(self): + """Test default brightness factor is 1.2.""" + editor = Brightness() + assert editor.factor == 1.2 + + def test_custom_factor(self): + """Test custom brightness factor.""" + editor = Brightness(factor=0.5) + assert editor.factor == 0.5 + + def test_edit_returns_image(self, sample_rgb_image): + """Test edit returns a PIL Image.""" + editor = Brightness(factor=1.5) + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + def test_edit_preserves_size(self, sample_rgb_image): + """Test edit preserves image size.""" + editor = Brightness(factor=1.5) + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_factor_one_preserves_image(self, sample_rgb_image): + """Test factor of 1.0 preserves image.""" + editor = Brightness(factor=1.0) + result = editor.edit(sample_rgb_image) + + original_arr = np.array(sample_rgb_image) + result_arr = np.array(result) + + np.testing.assert_array_equal(original_arr, result_arr) + + def test_higher_factor_increases_brightness(self, sample_dark_image): + """Test higher factor increases brightness.""" + editor = Brightness(factor=2.0) + result = editor.edit(sample_dark_image) + + original_mean = np.mean(np.array(sample_dark_image)) + result_mean = np.mean(np.array(result)) + + assert result_mean > original_mean + + def test_lower_factor_decreases_brightness(self, sample_bright_image): + """Test lower factor decreases brightness.""" + editor = Brightness(factor=0.5) + result = editor.edit(sample_bright_image) + + original_mean = np.mean(np.array(sample_bright_image)) + result_mean = np.mean(np.array(result)) + + assert result_mean < original_mean + + +# ============================================================================ +# Tests for Mask +# ============================================================================ + +class TestMask: + """Tests for Mask editor.""" + + def test_default_parameters(self): + """Test default mask parameters.""" + editor = Mask() + assert editor.mask_ratio == 0.1 + assert editor.num_masks == 5 + + def test_custom_parameters(self): + """Test custom mask parameters.""" + editor = Mask(mask_ratio=0.2, num_masks=10) + assert editor.mask_ratio == 0.2 + assert editor.num_masks == 10 + + def test_edit_returns_image(self, sample_rgb_image): + """Test edit returns a PIL Image.""" + editor = Mask() + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + def test_edit_preserves_size(self, sample_rgb_image): + """Test edit preserves image size.""" + editor = Mask() + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_masks_add_black_regions(self, sample_bright_image): + """Test masks add black regions.""" + editor = Mask(num_masks=10) + result = 
editor.edit(sample_bright_image) + + result_arr = np.array(result) + + # Should have some black pixels (all zeros) + black_pixels = np.all(result_arr == 0, axis=2) + assert np.any(black_pixels) + + def test_original_not_modified(self, sample_rgb_image): + """Test original image is not modified.""" + original_arr = np.array(sample_rgb_image).copy() + + editor = Mask() + editor.edit(sample_rgb_image) + + current_arr = np.array(sample_rgb_image) + np.testing.assert_array_equal(original_arr, current_arr) + + +# ============================================================================ +# Tests for Overlay +# ============================================================================ + +class TestOverlay: + """Tests for Overlay editor.""" + + def test_default_parameters(self): + """Test default overlay parameters.""" + editor = Overlay() + assert editor.num_strokes == 10 + assert editor.stroke_width == 5 + assert editor.stroke_type == 'random' + + def test_custom_parameters(self): + """Test custom overlay parameters.""" + editor = Overlay(num_strokes=20, stroke_width=10, stroke_type='black') + assert editor.num_strokes == 20 + assert editor.stroke_width == 10 + assert editor.stroke_type == 'black' + + def test_edit_returns_image(self, sample_rgb_image): + """Test edit returns a PIL Image.""" + editor = Overlay() + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + def test_edit_preserves_size(self, sample_rgb_image): + """Test edit preserves image size.""" + editor = Overlay() + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_overlay_changes_image(self, sample_rgb_image): + """Test overlay changes the image.""" + editor = Overlay(num_strokes=20) + result = editor.edit(sample_rgb_image) + + original_arr = np.array(sample_rgb_image) + result_arr = np.array(result) + + assert not np.array_equal(original_arr, result_arr) + + def test_black_stroke_type(self, sample_bright_image): + """Test black stroke type adds black pixels.""" + editor = Overlay(num_strokes=20, stroke_width=10, stroke_type='black') + result = editor.edit(sample_bright_image) + + result_arr = np.array(result) + black_pixels = np.all(result_arr == 0, axis=2) + assert np.any(black_pixels) + + def test_white_stroke_type(self, sample_dark_image): + """Test white stroke type adds white pixels.""" + editor = Overlay(num_strokes=20, stroke_width=10, stroke_type='white') + result = editor.edit(sample_dark_image) + + result_arr = np.array(result) + white_pixels = np.all(result_arr == 255, axis=2) + assert np.any(white_pixels) + + def test_original_not_modified(self, sample_rgb_image): + """Test original image is not modified.""" + original_arr = np.array(sample_rgb_image).copy() + + editor = Overlay() + editor.edit(sample_rgb_image) + + current_arr = np.array(sample_rgb_image) + np.testing.assert_array_equal(original_arr, current_arr) + + +# ============================================================================ +# Tests for AdaptiveNoiseInjection +# ============================================================================ + +class TestAdaptiveNoiseInjection: + """Tests for AdaptiveNoiseInjection editor.""" + + def test_default_parameters(self): + """Test default parameters.""" + editor = AdaptiveNoiseInjection() + assert editor.intensity == 0.5 + assert editor.auto_select is True + + def test_custom_parameters(self): + """Test custom parameters.""" + editor = AdaptiveNoiseInjection(intensity=0.8, auto_select=False) + assert editor.intensity == 0.8 
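+        # auto_select=False pins the editor to its mixed-noise fallback
+        # instead of feature-based noise-type selection (exercised below in
+        # test_auto_select_false_uses_mixed_noise)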
+ assert editor.auto_select is False + + def test_edit_returns_image(self, sample_rgb_image): + """Test edit returns a PIL Image.""" + editor = AdaptiveNoiseInjection() + result = editor.edit(sample_rgb_image) + assert isinstance(result, Image.Image) + + def test_edit_preserves_size(self, sample_rgb_image): + """Test edit preserves image size.""" + editor = AdaptiveNoiseInjection() + result = editor.edit(sample_rgb_image) + assert result.size == sample_rgb_image.size + + def test_noise_changes_image(self, sample_rgb_image): + """Test noise injection changes the image.""" + editor = AdaptiveNoiseInjection(intensity=0.5) + result = editor.edit(sample_rgb_image) + + original_arr = np.array(sample_rgb_image) + result_arr = np.array(result) + + assert not np.array_equal(original_arr, result_arr) + + def test_analyze_image_features(self, sample_complex_image): + """Test _analyze_image_features returns expected keys.""" + editor = AdaptiveNoiseInjection() + img_arr = np.array(sample_complex_image).astype(np.float32) + + features = editor._analyze_image_features(img_arr) + + assert 'brightness_mean' in features + assert 'brightness_std' in features + assert 'edge_density' in features + assert 'texture_complexity' in features + + def test_select_noise_type_dark_image(self, sample_dark_image): + """Test noise type selection for dark image.""" + editor = AdaptiveNoiseInjection() + img_arr = np.array(sample_dark_image).astype(np.float32) + + features = editor._analyze_image_features(img_arr) + noise_type = editor._select_noise_type(features) + + # Dark images should use gaussian noise + assert noise_type == 'gaussian' + + def test_auto_select_false_uses_mixed_noise(self, sample_rgb_image): + """Test auto_select=False uses mixed noise.""" + editor = AdaptiveNoiseInjection(auto_select=False) + result = editor.edit(sample_rgb_image) + + assert isinstance(result, Image.Image) + + def test_add_gaussian_noise(self, sample_rgb_image): + """Test _add_gaussian_noise method.""" + editor = AdaptiveNoiseInjection() + img_arr = np.array(sample_rgb_image).astype(np.float32) + + noisy = editor._add_gaussian_noise(img_arr, sigma=25) + + assert noisy.shape == img_arr.shape + assert noisy.dtype == np.uint8 + assert not np.array_equal(noisy, img_arr.astype(np.uint8)) + + def test_add_salt_pepper_noise(self, sample_rgb_image): + """Test _add_salt_pepper_noise method.""" + editor = AdaptiveNoiseInjection() + img_arr = np.array(sample_rgb_image).astype(np.float32) + + noisy = editor._add_salt_pepper_noise(img_arr, amount=0.1) + + assert noisy.shape == img_arr.shape + # Should have some extreme values (0 or 255) + assert np.any(noisy == 0) or np.any(noisy == 255) + + def test_add_poisson_noise(self, sample_rgb_image): + """Test _add_poisson_noise method.""" + editor = AdaptiveNoiseInjection() + img_arr = np.array(sample_rgb_image).astype(np.float32) + + noisy = editor._add_poisson_noise(img_arr) + + assert noisy.shape == img_arr.shape + assert noisy.dtype == np.uint8 + + def test_add_speckle_noise(self, sample_rgb_image): + """Test _add_speckle_noise method.""" + editor = AdaptiveNoiseInjection() + img_arr = np.array(sample_rgb_image).astype(np.float32) + + noisy = editor._add_speckle_noise(img_arr, variance=0.5) + + assert noisy.shape == img_arr.shape + assert noisy.dtype == np.uint8 + + def test_output_clipped_to_valid_range(self, sample_rgb_image): + """Test output values are in valid [0, 255] range.""" + editor = AdaptiveNoiseInjection(intensity=1.0) + result = editor.edit(sample_rgb_image) + result_arr = 
np.array(result) + + assert result_arr.min() >= 0 + assert result_arr.max() <= 255 + + def test_grayscale_feature_analysis(self): + """Test feature analysis works with grayscale-like input.""" + editor = AdaptiveNoiseInjection() + + # 2D array (grayscale) + gray_arr = np.random.randint(0, 256, (256, 256)).astype(np.float32) + features = editor._analyze_image_features(gray_arr) + + assert 'brightness_mean' in features + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestEditorChaining: + """Test chaining multiple editors.""" + + def test_chain_multiple_editors(self, sample_rgb_image): + """Test applying multiple editors in sequence.""" + editors = [ + JPEGCompression(quality=75), + Rotation(angle=15), + GaussianBlurring(radius=2), + Brightness(factor=1.1), + ] + + result = sample_rgb_image + for editor in editors: + result = editor.edit(result) + + assert isinstance(result, Image.Image) + assert result.size[0] > 0 and result.size[1] > 0 + + def test_all_editors_work_together(self, sample_complex_image): + """Test all editors can process the same image.""" + editors = [ + JPEGCompression(quality=90), + Rotation(angle=10), + CrSc(crop_ratio=0.9), + GaussianBlurring(radius=1), + GaussianNoise(sigma=10), + Brightness(factor=1.05), + Mask(num_masks=2), + Overlay(num_strokes=5), + AdaptiveNoiseInjection(intensity=0.3), + ] + + for editor in editors: + result = editor.edit(sample_complex_image) + assert isinstance(result, Image.Image) + assert result.size[0] > 0 and result.size[1] > 0 + + +class TestEditorConsistency: + """Test editor consistency and reproducibility.""" + + def test_same_params_same_result_deterministic(self, sample_rgb_image): + """Test deterministic editors produce same result.""" + # JPEGCompression, Rotation, CrSc, Brightness are deterministic + editor = Rotation(angle=45) + + result1 = editor.edit(sample_rgb_image) + result2 = editor.edit(sample_rgb_image) + + np.testing.assert_array_equal(np.array(result1), np.array(result2)) + + def test_random_editors_produce_different_results(self, sample_rgb_image): + """Test random editors may produce different results.""" + editor = GaussianNoise(sigma=50) + + result1 = editor.edit(sample_rgb_image) + result2 = editor.edit(sample_rgb_image) + + # Very unlikely to be identical + assert not np.array_equal(np.array(result1), np.array(result2)) diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..3949d26 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,1488 @@ +# Copyright 2025 THU-BPM MarkDiffusion. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the utils module. 
+ +This file contains tests for: +- utils/utils.py: General utility functions +- utils/callbacks.py: Callback classes for diffusion models +- utils/pipeline_utils.py: Pipeline type detection utilities +- utils/media_utils.py: Media conversion utilities +- utils/diffusion_config.py: Diffusion configuration +""" + +import pytest +import torch +import numpy as np +import json +import os +import tempfile +import shutil +from pathlib import Path +from PIL import Image +from unittest.mock import Mock, MagicMock, patch + + +# ============================================================================ +# Tests for utils/utils.py +# ============================================================================ + +class TestInheritDocstring: + """Tests for inherit_docstring decorator.""" + + def test_inherit_docstring_from_base_class(self): + """Test that docstrings are inherited from base classes.""" + from utils.utils import inherit_docstring + + class Base: + def method(self): + """Base method docstring.""" + pass + + @inherit_docstring + class Derived(Base): + def method(self): + pass + + assert Derived.method.__doc__ == "Base method docstring." + + def test_no_override_existing_docstring(self): + """Test that existing docstrings are not overridden.""" + from utils.utils import inherit_docstring + + class Base: + def method(self): + """Base method docstring.""" + pass + + @inherit_docstring + class Derived(Base): + def method(self): + """Derived method docstring.""" + pass + + assert Derived.method.__doc__ == "Derived method docstring." + + def test_no_docstring_in_base(self): + """Test behavior when base class has no docstring.""" + from utils.utils import inherit_docstring + + class Base: + def method(self): + pass + + @inherit_docstring + class Derived(Base): + def method(self): + pass + + assert Derived.method.__doc__ is None + + def test_multiple_inheritance(self): + """Test docstring inheritance with multiple base classes.""" + from utils.utils import inherit_docstring + + class Base1: + def method(self): + """Base1 method docstring.""" + pass + + class Base2: + def method(self): + """Base2 method docstring.""" + pass + + @inherit_docstring + class Derived(Base1, Base2): + def method(self): + pass + + # Should inherit from first base class + assert Derived.method.__doc__ == "Base1 method docstring." 
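+
+    # The assertions above pin down the decorator's contract: fill in a
+    # missing docstring, walk the MRO in order, and never override an
+    # existing docstring. For reference, a minimal sketch satisfying that
+    # contract could look like the following (hypothetical; the actual
+    # utils.utils implementation may differ):
+    #
+    #     def inherit_docstring(cls):
+    #         for name, member in vars(cls).items():
+    #             if callable(member) and member.__doc__ is None:
+    #                 for base in cls.__mro__[1:]:
+    #                     parent = getattr(base, name, None)
+    #                     if parent is not None and parent.__doc__:
+    #                         member.__doc__ = parent.__doc__
+    #                         break
+    #         return cls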
+ + +class TestLoadConfigFile: + """Tests for load_config_file function.""" + + def test_load_valid_json(self, tmp_path): + """Test loading a valid JSON configuration file.""" + from utils.utils import load_config_file + + config_data = {"key": "value", "number": 42, "nested": {"a": 1}} + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps(config_data)) + + result = load_config_file(str(config_file)) + assert result == config_data + + def test_load_nonexistent_file(self, capsys): + """Test loading a nonexistent file returns None.""" + from utils.utils import load_config_file + + result = load_config_file("/nonexistent/path/config.json") + assert result is None + + captured = capsys.readouterr() + assert "does not exist" in captured.out + + def test_load_invalid_json(self, tmp_path, capsys): + """Test loading an invalid JSON file returns None.""" + from utils.utils import load_config_file + + config_file = tmp_path / "invalid.json" + config_file.write_text("{invalid json content") + + result = load_config_file(str(config_file)) + assert result is None + + captured = capsys.readouterr() + assert "Error decoding JSON" in captured.out + + def test_load_empty_json(self, tmp_path): + """Test loading an empty JSON object.""" + from utils.utils import load_config_file + + config_file = tmp_path / "empty.json" + config_file.write_text("{}") + + result = load_config_file(str(config_file)) + assert result == {} + + +class TestLoadJsonAsList: + """Tests for load_json_as_list function.""" + + def test_load_jsonl_file(self, tmp_path): + """Test loading a JSONL file (one JSON object per line).""" + from utils.utils import load_json_as_list + + data = [ + {"id": 1, "name": "first"}, + {"id": 2, "name": "second"}, + {"id": 3, "name": "third"}, + ] + jsonl_file = tmp_path / "data.jsonl" + with open(jsonl_file, "w") as f: + for item in data: + f.write(json.dumps(item) + "\n") + + result = load_json_as_list(str(jsonl_file)) + assert result == data + + def test_load_empty_jsonl(self, tmp_path): + """Test loading an empty JSONL file.""" + from utils.utils import load_json_as_list + + jsonl_file = tmp_path / "empty.jsonl" + jsonl_file.write_text("") + + result = load_json_as_list(str(jsonl_file)) + assert result == [] + + +class TestCreateDirectoryForFile: + """Tests for create_directory_for_file function.""" + + def test_create_directory(self, tmp_path): + """Test creating a directory for a file path.""" + from utils.utils import create_directory_for_file + + file_path = tmp_path / "new_dir" / "subdir" / "file.txt" + create_directory_for_file(str(file_path)) + + assert (tmp_path / "new_dir" / "subdir").exists() + assert (tmp_path / "new_dir" / "subdir").is_dir() + + def test_existing_directory(self, tmp_path): + """Test that existing directories don't cause errors.""" + from utils.utils import create_directory_for_file + + existing_dir = tmp_path / "existing" + existing_dir.mkdir() + file_path = existing_dir / "file.txt" + + # Should not raise an error + create_directory_for_file(str(file_path)) + assert existing_dir.exists() + + +class TestSetRandomSeed: + """Tests for set_random_seed function.""" + + def test_reproducibility_torch(self): + """Test that torch random is reproducible with same seed.""" + from utils.utils import set_random_seed + + set_random_seed(42) + tensor1 = torch.randn(10) + + set_random_seed(42) + tensor2 = torch.randn(10) + + assert torch.allclose(tensor1, tensor2) + + def test_reproducibility_numpy(self): + """Test that numpy random is reproducible with same 
seed.""" + from utils.utils import set_random_seed + + set_random_seed(42) + arr1 = np.random.randn(10) + + set_random_seed(42) + arr2 = np.random.randn(10) + + np.testing.assert_array_almost_equal(arr1, arr2) + + def test_reproducibility_python_random(self): + """Test that Python random is reproducible with same seed.""" + from utils.utils import set_random_seed + import random + + set_random_seed(42) + val1 = [random.random() for _ in range(10)] + + set_random_seed(42) + val2 = [random.random() for _ in range(10)] + + assert val1 == val2 + + def test_different_seeds_produce_different_results(self): + """Test that different seeds produce different results.""" + from utils.utils import set_random_seed + + set_random_seed(42) + tensor1 = torch.randn(10) + + set_random_seed(123) + tensor2 = torch.randn(10) + + assert not torch.allclose(tensor1, tensor2) + + +# ============================================================================ +# Tests for utils/callbacks.py +# ============================================================================ + +class TestDenoisingLatentsCollector: + """Tests for DenoisingLatentsCollector class.""" + + def test_init_default_parameters(self): + """Test default initialization parameters.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector() + assert collector.save_every_n_steps == 1 + assert collector.to_cpu is True + assert collector.data == [] + assert collector._call_count == 0 + + def test_init_custom_parameters(self): + """Test custom initialization parameters.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector(save_every_n_steps=5, to_cpu=False) + assert collector.save_every_n_steps == 5 + assert collector.to_cpu is False + + def test_call_saves_latents(self): + """Test that __call__ saves latents correctly.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector() + latents = torch.randn(1, 4, 64, 64) + + collector(step=0, timestep=1000, latents=latents) + + assert len(collector.data) == 1 + assert collector.data[0]["step"] == 0 + assert collector.data[0]["timestep"] == 1000 + assert collector.data[0]["call_count"] == 1 + assert collector.data[0]["latents"].shape == latents.shape + + def test_call_respects_save_every_n_steps(self): + """Test that latents are saved every n steps.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector(save_every_n_steps=2) + latents = torch.randn(1, 4, 64, 64) + + for i in range(5): + collector(step=i, timestep=1000 - i * 100, latents=latents) + + # Should save at call 2 and 4 (1-indexed) + assert len(collector.data) == 2 + assert collector.data[0]["call_count"] == 2 + assert collector.data[1]["call_count"] == 4 + + def test_call_moves_to_cpu(self): + """Test that latents are moved to CPU when to_cpu=True.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector(to_cpu=True) + latents = torch.randn(1, 4, 64, 64) + + collector(step=0, timestep=1000, latents=latents) + + assert collector.data[0]["latents"].device == torch.device("cpu") + + def test_latents_list_property(self): + """Test latents_list property returns list of latents.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector() + + for i in range(3): + collector(step=i, timestep=1000 - i * 100, latents=torch.randn(1, 4, 64, 64)) + + latents_list = collector.latents_list + assert 
len(latents_list) == 3 + assert all(isinstance(l, torch.Tensor) for l in latents_list) + + def test_timesteps_list_property(self): + """Test timesteps_list property returns list of timesteps.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector() + timesteps = [1000, 800, 600] + + for i, ts in enumerate(timesteps): + collector(step=i, timestep=ts, latents=torch.randn(1, 4, 64, 64)) + + assert collector.timesteps_list == timesteps + + def test_get_latents_at_step(self): + """Test get_latents_at_step returns correct latents.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector() + latents_0 = torch.randn(1, 4, 64, 64) + latents_1 = torch.randn(1, 4, 64, 64) + + collector(step=0, timestep=1000, latents=latents_0) + collector(step=1, timestep=800, latents=latents_1) + + result = collector.get_latents_at_step(0) + assert torch.allclose(result, latents_0.cpu()) + + def test_get_latents_at_step_not_found(self): + """Test get_latents_at_step raises ValueError for missing step.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector() + collector(step=0, timestep=1000, latents=torch.randn(1, 4, 64, 64)) + + with pytest.raises(ValueError, match="No latents found for step"): + collector.get_latents_at_step(999) + + def test_clear(self): + """Test clear method resets collector state.""" + from utils.callbacks import DenoisingLatentsCollector + + collector = DenoisingLatentsCollector() + + for i in range(3): + collector(step=i, timestep=1000, latents=torch.randn(1, 4, 64, 64)) + + collector.clear() + + assert collector.data == [] + assert collector._call_count == 0 + + +# ============================================================================ +# Tests for utils/pipeline_utils.py +# ============================================================================ + +class TestPipelineUtils: + """Tests for pipeline utility functions.""" + + def test_pipeline_type_constants(self): + """Test pipeline type constants are defined correctly.""" + from utils.pipeline_utils import ( + PIPELINE_TYPE_IMAGE, + PIPELINE_TYPE_TEXT_TO_VIDEO, + PIPELINE_TYPE_IMAGE_TO_VIDEO, + ) + + assert PIPELINE_TYPE_IMAGE == "image" + assert PIPELINE_TYPE_TEXT_TO_VIDEO == "t2v" + assert PIPELINE_TYPE_IMAGE_TO_VIDEO == "i2v" + + def test_get_pipeline_type_unknown(self): + """Test get_pipeline_type returns None for unknown pipeline.""" + from utils.pipeline_utils import get_pipeline_type + + mock_pipeline = Mock() + result = get_pipeline_type(mock_pipeline) + assert result is None + + def test_is_video_pipeline_with_mock(self): + """Test is_video_pipeline with mocked pipelines.""" + from utils.pipeline_utils import is_video_pipeline, get_pipeline_type + + # Mock image pipeline + mock_image_pipe = Mock() + with patch("utils.pipeline_utils.get_pipeline_type", return_value="image"): + assert is_video_pipeline(mock_image_pipe) is False + + # Mock video pipeline + mock_video_pipe = Mock() + with patch("utils.pipeline_utils.get_pipeline_type", return_value="t2v"): + assert is_video_pipeline(mock_video_pipe) is True + + def test_is_image_pipeline_with_mock(self): + """Test is_image_pipeline with mocked pipelines.""" + from utils.pipeline_utils import is_image_pipeline + + mock_pipe = Mock() + with patch("utils.pipeline_utils.get_pipeline_type", return_value="image"): + assert is_image_pipeline(mock_pipe) is True + + with patch("utils.pipeline_utils.get_pipeline_type", return_value="t2v"): + assert 
is_image_pipeline(mock_pipe) is False + + def test_is_t2v_pipeline_with_mock(self): + """Test is_t2v_pipeline with mocked pipelines.""" + from utils.pipeline_utils import is_t2v_pipeline + + mock_pipe = Mock() + with patch("utils.pipeline_utils.get_pipeline_type", return_value="t2v"): + assert is_t2v_pipeline(mock_pipe) is True + + with patch("utils.pipeline_utils.get_pipeline_type", return_value="i2v"): + assert is_t2v_pipeline(mock_pipe) is False + + def test_is_i2v_pipeline_with_mock(self): + """Test is_i2v_pipeline with mocked pipelines.""" + from utils.pipeline_utils import is_i2v_pipeline + + mock_pipe = Mock() + with patch("utils.pipeline_utils.get_pipeline_type", return_value="i2v"): + assert is_i2v_pipeline(mock_pipe) is True + + with patch("utils.pipeline_utils.get_pipeline_type", return_value="t2v"): + assert is_i2v_pipeline(mock_pipe) is False + + def test_get_pipeline_requirements_image(self): + """Test get_pipeline_requirements for image pipeline.""" + from utils.pipeline_utils import get_pipeline_requirements, PIPELINE_TYPE_IMAGE + + result = get_pipeline_requirements(PIPELINE_TYPE_IMAGE) + assert result["required_params"] == [] + assert "height" in result["optional_params"] + assert "width" in result["optional_params"] + + def test_get_pipeline_requirements_t2v(self): + """Test get_pipeline_requirements for text-to-video pipeline.""" + from utils.pipeline_utils import get_pipeline_requirements, PIPELINE_TYPE_TEXT_TO_VIDEO + + result = get_pipeline_requirements(PIPELINE_TYPE_TEXT_TO_VIDEO) + assert "num_frames" in result["required_params"] + assert "fps" in result["optional_params"] + + def test_get_pipeline_requirements_i2v(self): + """Test get_pipeline_requirements for image-to-video pipeline.""" + from utils.pipeline_utils import get_pipeline_requirements, PIPELINE_TYPE_IMAGE_TO_VIDEO + + result = get_pipeline_requirements(PIPELINE_TYPE_IMAGE_TO_VIDEO) + assert "input_image" in result["required_params"] + assert "num_frames" in result["required_params"] + + def test_get_pipeline_requirements_unknown(self): + """Test get_pipeline_requirements for unknown pipeline type.""" + from utils.pipeline_utils import get_pipeline_requirements + + result = get_pipeline_requirements("unknown") + assert result["required_params"] == [] + assert result["optional_params"] == [] + + +# ============================================================================ +# Tests for utils/media_utils.py +# ============================================================================ + +class TestTorchToNumpy: + """Tests for torch_to_numpy function.""" + + def test_image_tensor_conversion(self): + """Test conversion of 4D image tensor.""" + from utils.media_utils import torch_to_numpy + + # Create tensor in range [-1, 1] + tensor = torch.randn(1, 3, 64, 64).clamp(-1, 1) + result = torch_to_numpy(tensor) + + assert isinstance(result, np.ndarray) + assert result.shape == (1, 64, 64, 3) # B, H, W, C + assert result.min() >= 0 and result.max() <= 1 + + def test_video_tensor_conversion(self): + """Test conversion of 5D video tensor.""" + from utils.media_utils import torch_to_numpy + + # Create tensor in range [-1, 1] + tensor = torch.randn(1, 3, 8, 64, 64).clamp(-1, 1) + result = torch_to_numpy(tensor) + + assert isinstance(result, np.ndarray) + assert result.shape == (1, 8, 64, 64, 3) # B, F, H, W, C + + def test_unsupported_dimension(self): + """Test that unsupported dimensions raise ValueError.""" + from utils.media_utils import torch_to_numpy + + tensor = torch.randn(3, 64, 64) # 3D tensor + with 
pytest.raises(ValueError, match="Unsupported tensor dimension"): + torch_to_numpy(tensor) + + +class TestPilToTorch: + """Tests for pil_to_torch function.""" + + def test_basic_conversion(self): + """Test basic PIL to torch conversion.""" + from utils.media_utils import pil_to_torch + + img = Image.new("RGB", (64, 64), color="red") + tensor = pil_to_torch(img) + + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 64, 64) + + def test_normalized_range(self): + """Test that normalized output is in [-1, 1] range.""" + from utils.media_utils import pil_to_torch + + img = Image.new("RGB", (64, 64), color="white") + tensor = pil_to_torch(img, normalize=True) + + # White pixels should be close to 1.0 after normalization + assert tensor.max() <= 1.0 + assert tensor.min() >= -1.0 + + def test_unnormalized_range(self): + """Test that unnormalized output is in [0, 1] range.""" + from utils.media_utils import pil_to_torch + + img = Image.new("RGB", (64, 64), color="white") + tensor = pil_to_torch(img, normalize=False) + + assert tensor.max() <= 1.0 + assert tensor.min() >= 0.0 + + +class TestNumpyToPil: + """Tests for numpy_to_pil function.""" + + def test_uint8_array(self): + """Test conversion of uint8 numpy array.""" + from utils.media_utils import numpy_to_pil + + arr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) + img = numpy_to_pil(arr) + + assert isinstance(img, Image.Image) + assert img.size == (64, 64) + + def test_float_array(self): + """Test conversion of float numpy array in [0, 1] range.""" + from utils.media_utils import numpy_to_pil + + arr = np.random.rand(64, 64, 3).astype(np.float32) + img = numpy_to_pil(arr) + + assert isinstance(img, Image.Image) + assert img.size == (64, 64) + + +class TestCv2ToPil: + """Tests for cv2_to_pil function.""" + + def test_uint8_array(self): + """Test conversion of uint8 cv2 array.""" + from utils.media_utils import cv2_to_pil + + arr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) + img = cv2_to_pil(arr) + + assert isinstance(img, Image.Image) + assert img.size == (64, 64) + + def test_float_array(self): + """Test conversion of float cv2 array.""" + from utils.media_utils import cv2_to_pil + + arr = np.random.rand(64, 64, 3).astype(np.float32) + img = cv2_to_pil(arr) + + assert isinstance(img, Image.Image) + + +class TestPilToCv2: + """Tests for pil_to_cv2 function.""" + + def test_basic_conversion(self): + """Test basic PIL to cv2 conversion.""" + from utils.media_utils import pil_to_cv2 + + img = Image.new("RGB", (64, 64), color="red") + arr = pil_to_cv2(img) + + assert isinstance(arr, np.ndarray) + assert arr.shape == (64, 64, 3) + assert arr.dtype == np.float64 + assert arr.max() <= 1.0 and arr.min() >= 0.0 + + +class TestTransformToModelFormat: + """Tests for transform_to_model_format function.""" + + def test_single_pil_image(self): + """Test transformation of single PIL image.""" + from utils.media_utils import transform_to_model_format + + img = Image.new("RGB", (64, 64), color="blue") + tensor = transform_to_model_format(img) + + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 64, 64) + # Check normalization to [-1, 1] + assert tensor.min() >= -1.0 and tensor.max() <= 1.0 + + def test_single_pil_image_with_resize(self): + """Test transformation with resize.""" + from utils.media_utils import transform_to_model_format + + img = Image.new("RGB", (128, 128), color="blue") + tensor = transform_to_model_format(img, target_size=64) + + assert tensor.shape == (3, 64, 64) + + def 
test_list_of_pil_images(self): + """Test transformation of list of PIL images.""" + from utils.media_utils import transform_to_model_format + + frames = [Image.new("RGB", (64, 64), color="red") for _ in range(4)] + tensor = transform_to_model_format(frames) + + assert tensor.shape == (4, 3, 64, 64) + + def test_list_of_numpy_arrays(self): + """Test transformation of list of numpy arrays.""" + from utils.media_utils import transform_to_model_format + + frames = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(4)] + tensor = transform_to_model_format(frames) + + assert tensor.shape == (4, 3, 64, 64) + + def test_single_numpy_frame(self): + """Test transformation of single numpy frame (3D array).""" + from utils.media_utils import transform_to_model_format + + frame = np.random.rand(64, 64, 3).astype(np.float32) + tensor = transform_to_model_format(frame) + + assert tensor.shape == (3, 64, 64) + + def test_numpy_video_array(self): + """Test transformation of numpy video array (4D).""" + from utils.media_utils import transform_to_model_format + + video = np.random.rand(8, 64, 64, 3).astype(np.float32) + tensor = transform_to_model_format(video) + + assert tensor.shape == (8, 3, 64, 64) + + def test_unsupported_type(self): + """Test that unsupported types raise ValueError.""" + from utils.media_utils import transform_to_model_format + + with pytest.raises(ValueError, match="Unsupported media type"): + transform_to_model_format("not_valid_input") + + def test_mixed_frame_types_raises_error(self): + """Test that mixed frame types raise ValueError.""" + from utils.media_utils import transform_to_model_format + + frames = [ + Image.new("RGB", (64, 64)), + np.random.rand(64, 64, 3), + ] + with pytest.raises(ValueError, match="All frames must be either"): + transform_to_model_format(frames) + + +class TestConvertVideoFramesToImages: + """Tests for convert_video_frames_to_images function.""" + + def test_numpy_frames(self): + """Test conversion of numpy frames to PIL images.""" + from utils.media_utils import convert_video_frames_to_images + + frames = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(4)] + result = convert_video_frames_to_images(frames) + + assert len(result) == 4 + assert all(isinstance(img, Image.Image) for img in result) + + def test_pil_frames(self): + """Test that PIL frames pass through unchanged.""" + from utils.media_utils import convert_video_frames_to_images + + frames = [Image.new("RGB", (64, 64)) for _ in range(4)] + result = convert_video_frames_to_images(frames) + + assert len(result) == 4 + assert all(isinstance(img, Image.Image) for img in result) + + def test_unsupported_frame_type(self): + """Test that unsupported frame types raise ValueError.""" + from utils.media_utils import convert_video_frames_to_images + + frames = ["not_a_frame"] + with pytest.raises(ValueError, match="Unsupported frame type"): + convert_video_frames_to_images(frames) + + +class TestSaveVideoFrames: + """Tests for save_video_frames function.""" + + def test_save_numpy_frames(self, tmp_path): + """Test saving numpy frames to disk.""" + from utils.media_utils import save_video_frames + + frames = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(4)] + save_dir = str(tmp_path) + save_video_frames(frames, save_dir) + + saved_files = list(tmp_path.glob("*.png")) + assert len(saved_files) == 4 + + def test_save_pil_frames(self, tmp_path): + """Test saving PIL frames to disk.""" + from utils.media_utils import save_video_frames + + frames = [Image.new("RGB", (64, 
64), color="red") for _ in range(4)] + save_dir = str(tmp_path) + save_video_frames(frames, save_dir) + + saved_files = list(tmp_path.glob("*.png")) + assert len(saved_files) == 4 + + def test_frame_naming(self, tmp_path): + """Test that frames are named with zero-padded indices.""" + from utils.media_utils import save_video_frames + + frames = [Image.new("RGB", (64, 64)) for _ in range(3)] + save_dir = str(tmp_path) + save_video_frames(frames, save_dir) + + assert (tmp_path / "00.png").exists() + assert (tmp_path / "01.png").exists() + assert (tmp_path / "02.png").exists() + + +# ============================================================================ +# Tests for utils/diffusion_config.py +# ============================================================================ + +class TestDiffusionConfig: + """Tests for DiffusionConfig class.""" + + @pytest.fixture + def mock_image_pipeline(self): + """Create a mock image pipeline.""" + from diffusers import StableDiffusionPipeline + + mock_pipe = MagicMock(spec=StableDiffusionPipeline) + return mock_pipe + + @pytest.fixture + def mock_scheduler(self): + """Create a mock scheduler.""" + from diffusers import DPMSolverMultistepScheduler + + mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler) + return mock_scheduler + + def test_default_parameters(self, mock_image_pipeline, mock_scheduler): + """Test DiffusionConfig with default parameters.""" + from utils.diffusion_config import DiffusionConfig + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cpu", + ) + + assert config.guidance_scale == 7.5 + assert config.num_images == 1 + assert config.num_inference_steps == 50 + assert config.image_size == (512, 512) + assert config.dtype == torch.float16 + assert config.gen_seed == 0 + assert config.inversion_type == "ddim" + + def test_custom_parameters(self, mock_image_pipeline, mock_scheduler): + """Test DiffusionConfig with custom parameters.""" + from utils.diffusion_config import DiffusionConfig + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cuda", + guidance_scale=10.0, + num_inference_steps=30, + image_size=(256, 256), + gen_seed=42, + ) + + assert config.guidance_scale == 10.0 + assert config.num_inference_steps == 30 + assert config.image_size == (256, 256) + assert config.gen_seed == 42 + + def test_invalid_inversion_type(self, mock_image_pipeline, mock_scheduler): + """Test that invalid inversion type raises AssertionError.""" + from utils.diffusion_config import DiffusionConfig + + with pytest.raises(AssertionError, match="Invalid inversion type"): + DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cpu", + inversion_type="invalid", + ) + + def test_num_inversion_steps_defaults_to_inference_steps( + self, mock_image_pipeline, mock_scheduler + ): + """Test num_inversion_steps defaults to num_inference_steps.""" + from utils.diffusion_config import DiffusionConfig + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cpu", + num_inference_steps=30, + ) + + assert config.num_inversion_steps == 30 + + def test_explicit_num_inversion_steps(self, mock_image_pipeline, mock_scheduler): + """Test explicit num_inversion_steps.""" + from utils.diffusion_config import DiffusionConfig + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cpu", + num_inference_steps=30, + num_inversion_steps=20, + ) + + assert config.num_inversion_steps 
== 20 + + def test_pipeline_type_property(self, mock_image_pipeline, mock_scheduler): + """Test pipeline_type property.""" + from utils.diffusion_config import DiffusionConfig + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cpu", + ) + + # Should return "image" for StableDiffusionPipeline + assert config.pipeline_type == "image" + + def test_is_image_pipeline_property(self, mock_image_pipeline, mock_scheduler): + """Test is_image_pipeline property.""" + from utils.diffusion_config import DiffusionConfig + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cpu", + ) + + assert config.is_image_pipeline is True + assert config.is_video_pipeline is False + + def test_gen_kwargs_stored(self, mock_image_pipeline, mock_scheduler): + """Test that extra kwargs are stored in gen_kwargs.""" + from utils.diffusion_config import DiffusionConfig + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_image_pipeline, + device="cpu", + custom_param="value", + another_param=42, + ) + + assert config.gen_kwargs["custom_param"] == "value" + assert config.gen_kwargs["another_param"] == 42 + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestMediaConversionRoundTrip: + """Integration tests for media conversion round trips.""" + + def test_pil_torch_pil_roundtrip(self): + """Test PIL -> Torch -> numpy -> PIL roundtrip.""" + from utils.media_utils import pil_to_torch, torch_to_numpy, numpy_to_pil + + original = Image.new("RGB", (64, 64), color=(128, 64, 192)) + tensor = pil_to_torch(original, normalize=True) + tensor = tensor.unsqueeze(0) # Add batch dim + numpy_arr = torch_to_numpy(tensor) + result = numpy_to_pil(numpy_arr[0]) + + # Check sizes match + assert result.size == original.size + + def test_numpy_pil_numpy_roundtrip(self): + """Test numpy -> PIL -> numpy roundtrip.""" + from utils.media_utils import numpy_to_pil, pil_to_cv2 + + original = np.random.rand(64, 64, 3).astype(np.float32) + pil_img = numpy_to_pil(original) + result = pil_to_cv2(pil_img) + + # Shape should be preserved + assert result.shape == original.shape + + +# ============================================================================ +# Additional Tests for media_utils.py - Coverage Improvement +# ============================================================================ + +class TestSetInversion: + """Tests for set_inversion function.""" + + def test_set_inversion_ddim(self): + """Test set_inversion with ddim type.""" + from utils.media_utils import set_inversion + + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.unet = MagicMock() + mock_pipe.device = "cpu" + + with patch("inversions.DDIMInversion") as mock_ddim: + mock_ddim.return_value = MagicMock() + result = set_inversion(mock_pipe, "ddim") + mock_ddim.assert_called_once_with(mock_pipe.scheduler, mock_pipe.unet, mock_pipe.device) + assert result is not None + + def test_set_inversion_exact(self): + """Test set_inversion with exact type.""" + from utils.media_utils import set_inversion + + mock_pipe = MagicMock() + mock_pipe.scheduler = MagicMock() + mock_pipe.unet = MagicMock() + mock_pipe.device = "cpu" + + with patch("inversions.ExactInversion") as mock_exact: + mock_exact.return_value = MagicMock() + result = set_inversion(mock_pipe, "exact") + mock_exact.assert_called_once_with(mock_pipe.scheduler, 
mock_pipe.unet, mock_pipe.device) + assert result is not None + + def test_set_inversion_invalid_type(self): + """Test set_inversion with invalid type raises ValueError.""" + from utils.media_utils import set_inversion + + mock_pipe = MagicMock() + with pytest.raises(ValueError, match="Invalid inversion type"): + set_inversion(mock_pipe, "invalid") + + +class TestTensor2Vid: + """Tests for tensor2vid function.""" + + def test_tensor2vid_np_output(self): + """Test tensor2vid with numpy output.""" + from utils.media_utils import tensor2vid + + # Create mock video tensor [B, C, F, H, W] + video = torch.rand(1, 3, 4, 64, 64) + + mock_processor = MagicMock() + mock_processor.postprocess.return_value = np.random.rand(4, 64, 64, 3) + + result = tensor2vid(video, mock_processor, output_type="np") + + assert isinstance(result, np.ndarray) + mock_processor.postprocess.assert_called() + + def test_tensor2vid_pt_output(self): + """Test tensor2vid with pytorch output.""" + from utils.media_utils import tensor2vid + + video = torch.rand(1, 3, 4, 64, 64) + + mock_processor = MagicMock() + mock_processor.postprocess.return_value = torch.rand(4, 64, 64, 3) + + result = tensor2vid(video, mock_processor, output_type="pt") + + assert isinstance(result, torch.Tensor) + + def test_tensor2vid_pil_output(self): + """Test tensor2vid with PIL output.""" + from utils.media_utils import tensor2vid + + video = torch.rand(1, 3, 4, 64, 64) + + mock_processor = MagicMock() + mock_processor.postprocess.return_value = [Image.new("RGB", (64, 64)) for _ in range(4)] + + result = tensor2vid(video, mock_processor, output_type="pil") + + assert isinstance(result, list) + + def test_tensor2vid_invalid_output_type(self): + """Test tensor2vid with invalid output type raises ValueError.""" + from utils.media_utils import tensor2vid + + video = torch.rand(1, 3, 4, 64, 64) + mock_processor = MagicMock() + mock_processor.postprocess.return_value = "invalid" + + with pytest.raises(ValueError, match="does not exist"): + tensor2vid(video, mock_processor, output_type="invalid") + + +class TestGetMediaLatents: + """Tests for get_media_latents function.""" + + def test_get_media_latents_image(self): + """Test get_media_latents for image pipeline.""" + from utils.media_utils import get_media_latents + + mock_pipe = MagicMock() + + image = torch.rand(1, 3, 512, 512) + + with patch("utils.media_utils.get_pipeline_type", return_value="image"): + with patch("utils.media_utils._get_image_latents") as mock_get: + mock_get.return_value = torch.rand(1, 4, 64, 64) + result = get_media_latents(mock_pipe, image) + mock_get.assert_called_once() + assert result is not None + + def test_get_media_latents_t2v(self): + """Test get_media_latents for text-to-video pipeline.""" + from utils.media_utils import get_media_latents + from diffusers import TextToVideoSDPipeline + + mock_pipe = MagicMock(spec=TextToVideoSDPipeline) + + video = torch.rand(8, 3, 512, 512) + + with patch("utils.media_utils.get_pipeline_type", return_value="t2v"): + with patch("utils.media_utils._get_video_latents") as mock_get: + mock_get.return_value = torch.rand(1, 4, 8, 64, 64) + result = get_media_latents(mock_pipe, video) + mock_get.assert_called_once() + assert result is not None + + def test_get_media_latents_unsupported_pipeline(self): + """Test get_media_latents with unsupported pipeline type.""" + from utils.media_utils import get_media_latents + + mock_pipe = MagicMock() + image = torch.rand(1, 3, 512, 512) + + with patch("utils.media_utils.get_pipeline_type", 
return_value=None): + with pytest.raises(ValueError, match="Unsupported pipeline type"): + get_media_latents(mock_pipe, image) + + +class TestDecodeMediaLatents: + """Tests for decode_media_latents function.""" + + def test_decode_media_latents_image(self): + """Test decode_media_latents for image pipeline.""" + from utils.media_utils import decode_media_latents + from diffusers import StableDiffusionPipeline + + mock_pipe = MagicMock(spec=StableDiffusionPipeline) + + latents = torch.rand(1, 4, 64, 64) + + with patch("utils.media_utils.get_pipeline_type", return_value="image"): + with patch("utils.media_utils._decode_image_latents") as mock_decode: + mock_decode.return_value = torch.rand(1, 3, 512, 512) + result = decode_media_latents(mock_pipe, latents) + mock_decode.assert_called_once() + assert result is not None + + def test_decode_media_latents_video(self): + """Test decode_media_latents for video pipeline.""" + from utils.media_utils import decode_media_latents + from diffusers import TextToVideoSDPipeline + + mock_pipe = MagicMock(spec=TextToVideoSDPipeline) + + latents = torch.rand(1, 4, 8, 64, 64) + + with patch("utils.media_utils.get_pipeline_type", return_value="t2v"): + with patch("utils.media_utils._decode_video_latents") as mock_decode: + mock_decode.return_value = np.random.rand(1, 8, 512, 512, 3) + result = decode_media_latents(mock_pipe, latents) + mock_decode.assert_called_once() + assert result is not None + + def test_decode_media_latents_unsupported_pipeline(self): + """Test decode_media_latents with unsupported pipeline type.""" + from utils.media_utils import decode_media_latents + + mock_pipe = MagicMock() + latents = torch.rand(1, 4, 64, 64) + + with patch("utils.media_utils.get_pipeline_type", return_value=None): + with pytest.raises(ValueError, match="Unsupported pipeline type"): + decode_media_latents(mock_pipe, latents) + + +class TestGetVideoLatents: + """Tests for _get_video_latents function.""" + + def test_get_video_latents_sample(self): + """Test _get_video_latents with sampling.""" + from utils.media_utils import _get_video_latents + + mock_pipe = MagicMock() + mock_dist = MagicMock() + mock_dist.sample.return_value = torch.rand(8, 4, 64, 64) + mock_pipe.vae.encode.return_value.latent_dist = mock_dist + + video_frames = torch.rand(8, 3, 512, 512) + + result = _get_video_latents(mock_pipe, video_frames, sample=True, permute=True) + + assert result is not None + mock_dist.sample.assert_called_once() + + def test_get_video_latents_mode(self): + """Test _get_video_latents with mode (no sampling).""" + from utils.media_utils import _get_video_latents + + mock_pipe = MagicMock() + mock_dist = MagicMock() + mock_dist.mode.return_value = torch.rand(8, 4, 64, 64) + mock_pipe.vae.encode.return_value.latent_dist = mock_dist + + video_frames = torch.rand(8, 3, 512, 512) + + result = _get_video_latents(mock_pipe, video_frames, sample=False, permute=False) + + assert result is not None + mock_dist.mode.assert_called_once() + + def test_get_video_latents_decoder_inv_not_implemented(self): + """Test _get_video_latents raises NotImplementedError for decoder_inv.""" + from utils.media_utils import _get_video_latents + + mock_pipe = MagicMock() + mock_dist = MagicMock() + mock_dist.sample.return_value = torch.rand(8, 4, 64, 64) + mock_pipe.vae.encode.return_value.latent_dist = mock_dist + + video_frames = torch.rand(8, 3, 512, 512) + + with pytest.raises(NotImplementedError, match="Decoder inversion is not implemented"): + _get_video_latents(mock_pipe, video_frames, 
decoder_inv=True) + + +class TestDecodeVideoLatents: + """Tests for _decode_video_latents function.""" + + def test_decode_video_latents_with_num_frames(self): + """Test _decode_video_latents with num_frames specified.""" + from utils.media_utils import _decode_video_latents + + mock_pipe = MagicMock() + mock_pipe.decode_latents.return_value = torch.rand(1, 3, 8, 64, 64) + mock_pipe.video_processor.postprocess.return_value = np.random.rand(8, 64, 64, 3) + + latents = torch.rand(1, 4, 8, 64, 64) + + with patch("utils.media_utils.tensor2vid") as mock_tensor2vid: + mock_tensor2vid.return_value = np.random.rand(1, 8, 64, 64, 3) + result = _decode_video_latents(mock_pipe, latents, num_frames=8) + mock_pipe.decode_latents.assert_called_once_with(latents, 8) + + def test_decode_video_latents_without_num_frames(self): + """Test _decode_video_latents without num_frames.""" + from utils.media_utils import _decode_video_latents + + mock_pipe = MagicMock() + mock_pipe.decode_latents.return_value = torch.rand(1, 3, 8, 64, 64) + + latents = torch.rand(1, 4, 8, 64, 64) + + with patch("utils.media_utils.tensor2vid") as mock_tensor2vid: + mock_tensor2vid.return_value = np.random.rand(1, 8, 64, 64, 3) + result = _decode_video_latents(mock_pipe, latents) + mock_pipe.decode_latents.assert_called_once_with(latents) + + +class TestGetImageLatents: + """Tests for _get_image_latents function.""" + + def test_get_image_latents_sample(self): + """Test _get_image_latents with sampling.""" + from utils.media_utils import _get_image_latents + + mock_pipe = MagicMock() + mock_dist = MagicMock() + mock_dist.sample.return_value = torch.rand(1, 4, 64, 64) + mock_pipe.vae.encode.return_value.latent_dist = mock_dist + + image = torch.rand(1, 3, 512, 512) + + result = _get_image_latents(mock_pipe, image, sample=True) + + assert result is not None + mock_dist.sample.assert_called_once() + + def test_get_image_latents_mode(self): + """Test _get_image_latents with mode (no sampling).""" + from utils.media_utils import _get_image_latents + + mock_pipe = MagicMock() + mock_dist = MagicMock() + mock_dist.mode.return_value = torch.rand(1, 4, 64, 64) + mock_pipe.vae.encode.return_value.latent_dist = mock_dist + + image = torch.rand(1, 3, 512, 512) + + result = _get_image_latents(mock_pipe, image, sample=False) + + assert result is not None + mock_dist.mode.assert_called_once() + + +class TestDecodeImageLatents: + """Tests for _decode_image_latents function.""" + + def test_decode_image_latents(self): + """Test _decode_image_latents basic functionality.""" + from utils.media_utils import _decode_image_latents + + mock_pipe = MagicMock() + mock_pipe.vae.decode.return_value = [torch.rand(1, 3, 512, 512)] + + latents = torch.rand(1, 4, 64, 64) + + result = _decode_image_latents(mock_pipe, latents) + + assert result is not None + mock_pipe.vae.decode.assert_called_once() + + +# ============================================================================ +# Additional Tests for diffusion_config.py - Coverage Improvement +# ============================================================================ + +class TestDiffusionConfigAdditional: + """Additional tests for DiffusionConfig class.""" + + def test_pipeline_requirements_image(self): + """Test pipeline_requirements property for image pipeline.""" + from utils.diffusion_config import DiffusionConfig + from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler + + mock_pipe = MagicMock(spec=StableDiffusionPipeline) + mock_scheduler = 
MagicMock(spec=DPMSolverMultistepScheduler) + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_pipe, + device="cpu", + ) + + requirements = config.pipeline_requirements + assert requirements["required_params"] == [] + assert "height" in requirements["optional_params"] + + def test_pipeline_requirements_t2v(self): + """Test pipeline_requirements property for t2v pipeline.""" + from utils.diffusion_config import DiffusionConfig + from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler + + mock_pipe = MagicMock(spec=TextToVideoSDPipeline) + mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler) + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_pipe, + device="cpu", + num_frames=8, + ) + + requirements = config.pipeline_requirements + assert "num_frames" in requirements["required_params"] + assert "fps" in requirements["optional_params"] + + def test_pipeline_requirements_i2v(self): + """Test pipeline_requirements property for i2v pipeline.""" + from utils.diffusion_config import DiffusionConfig + from diffusers import StableVideoDiffusionPipeline, DPMSolverMultistepScheduler + + mock_pipe = MagicMock(spec=StableVideoDiffusionPipeline) + mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler) + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_pipe, + device="cpu", + num_frames=8, + ) + + requirements = config.pipeline_requirements + assert "input_image" in requirements["required_params"] + assert "num_frames" in requirements["required_params"] + + def test_validate_pipeline_config_unsupported(self): + """Test _validate_pipeline_config raises ValueError for unsupported pipeline.""" + from utils.diffusion_config import DiffusionConfig + from diffusers import DPMSolverMultistepScheduler + + mock_pipe = MagicMock() # Generic mock, not a specific pipeline type + mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler) + + with pytest.raises(ValueError, match="Unsupported pipeline type"): + DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_pipe, + device="cpu", + ) + + def test_validate_pipeline_config_video_needs_frames(self): + """Test _validate_pipeline_config raises ValueError when video pipeline has no frames.""" + from utils.diffusion_config import DiffusionConfig + from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler + + mock_pipe = MagicMock(spec=TextToVideoSDPipeline) + mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler) + + with pytest.raises(ValueError, match="num_frames must be >= 1"): + DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_pipe, + device="cpu", + num_frames=-1, # Invalid for video + ) + + def test_validate_pipeline_config_image_auto_corrects_frames(self): + """Test _validate_pipeline_config auto-corrects num_frames for image pipeline.""" + from utils.diffusion_config import DiffusionConfig + from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler + + mock_pipe = MagicMock(spec=StableDiffusionPipeline) + mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler) + + config = DiffusionConfig( + scheduler=mock_scheduler, + pipe=mock_pipe, + device="cpu", + num_frames=10, # Should be auto-corrected to -1 for image + ) + + assert config.num_frames == -1 + + def test_is_video_pipeline_for_t2v(self): + """Test is_video_pipeline returns True for t2v pipeline.""" + from utils.diffusion_config import DiffusionConfig + from diffusers import TextToVideoSDPipeline, DPMSolverMultistepScheduler + + mock_pipe = 
MagicMock(spec=TextToVideoSDPipeline)
+        mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler)
+
+        config = DiffusionConfig(
+            scheduler=mock_scheduler,
+            pipe=mock_pipe,
+            device="cpu",
+            num_frames=8,
+        )
+
+        assert config.is_video_pipeline is True
+        assert config.is_image_pipeline is False
+
+    def test_is_video_pipeline_for_i2v(self):
+        """Test is_video_pipeline returns True for i2v pipeline."""
+        from utils.diffusion_config import DiffusionConfig
+        from diffusers import StableVideoDiffusionPipeline, DPMSolverMultistepScheduler
+
+        mock_pipe = MagicMock(spec=StableVideoDiffusionPipeline)
+        mock_scheduler = MagicMock(spec=DPMSolverMultistepScheduler)
+
+        config = DiffusionConfig(
+            scheduler=mock_scheduler,
+            pipe=mock_pipe,
+            device="cpu",
+            num_frames=8,
+        )
+
+        assert config.is_video_pipeline is True
+
+
+# ============================================================================
+# Additional Tests for utils.py - Coverage Improvement
+# ============================================================================
+
+class TestLoadConfigFileAdditional:
+    """Additional tests for load_config_file function."""
+
+    def test_load_config_file_unexpected_error(self, tmp_path, capsys):
+        """Test load_config_file handles unexpected (non-JSON) errors."""
+        from utils.utils import load_config_file
+
+        # Trigger an unexpected error by passing a directory instead of a file
+        config_path = tmp_path / "config"
+        config_path.mkdir()
+
+        result = load_config_file(str(config_path))
+        assert result is None
+
+        captured = capsys.readouterr()
+        assert "error" in captured.out.lower()
+
+
+class TestCreateDirectoryForFileAdditional:
+    """Additional tests for create_directory_for_file function."""
+
+    def test_create_directory_for_file_with_nested_path(self, tmp_path):
+        """Test create_directory_for_file with a nested path under a temp directory."""
+        from utils.utils import create_directory_for_file
+        import os
+
+        # Create a file path with a subdirectory in tmp_path
+        file_path = str(tmp_path / "subdir" / "file.txt")
+        create_directory_for_file(file_path)
+
+        # Verify the directory was created
+        assert (tmp_path / "subdir").exists()
+        assert (tmp_path / "subdir").is_dir()
diff --git a/test/test_video_editor.py b/test/test_video_editor.py
new file mode 100644
index 0000000..0487b0a
--- /dev/null
+++ b/test/test_video_editor.py
@@ -0,0 +1,698 @@
+# Copyright 2025 THU-BPM MarkDiffusion.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unit tests for video editor classes in MarkDiffusion.
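+
+Each editor exposes an edit(frames) method that takes a list of PIL Images and
+returns the edited frames as a new list of PIL Images (the VideoEditor base
+class's edit is a stub that returns None).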
+ +Tests cover: +- VideoEditor: Base class +- MPEG4Compression: MPEG-4 video compression +- VideoCodecAttack: Re-encode with various codecs +- FrameAverage: Averaging frames in sliding window +- FrameRateAdapter: Frame rate conversion +- FrameSwap: Random adjacent frame swapping +- FrameInterpolationAttack: Insert interpolated frames +""" + +import pytest +import numpy as np +from PIL import Image +from unittest.mock import patch, MagicMock +import shutil + +from evaluation.tools.video_editor import ( + VideoEditor, + MPEG4Compression, + VideoCodecAttack, + FrameAverage, + FrameRateAdapter, + FrameSwap, + FrameInterpolationAttack, +) + + +# ============================================================================ +# Fixtures +# ============================================================================ + +@pytest.fixture +def sample_frames(): + """Create a list of sample RGB frames.""" + frames = [] + for i in range(10): + # Create frames with different colors for variation + color = (i * 25, 100, 255 - i * 25) + img = Image.new("RGB", (128, 128), color=color) + frames.append(img) + return frames + + +@pytest.fixture +def sample_gradient_frames(): + """Create gradient frames for testing interpolation.""" + frames = [] + for i in range(5): + arr = np.full((64, 64, 3), i * 50, dtype=np.uint8) + frames.append(Image.fromarray(arr)) + return frames + + +@pytest.fixture +def single_frame(): + """Create a single frame.""" + return [Image.new("RGB", (128, 128), color="red")] + + +@pytest.fixture +def two_frames(): + """Create two frames for edge case testing.""" + return [ + Image.new("RGB", (64, 64), color="red"), + Image.new("RGB", (64, 64), color="blue"), + ] + + +@pytest.fixture +def empty_frames(): + """Return empty frame list.""" + return [] + + +# ============================================================================ +# Tests for VideoEditor Base Class +# ============================================================================ + +class TestVideoEditor: + """Tests for VideoEditor base class.""" + + def test_initialization(self): + """Test base class can be instantiated.""" + editor = VideoEditor() + assert editor is not None + + def test_edit_method_exists(self): + """Test edit method exists.""" + editor = VideoEditor() + assert hasattr(editor, 'edit') + + def test_edit_returns_none_by_default(self, sample_frames): + """Test base edit method returns None.""" + editor = VideoEditor() + result = editor.edit(sample_frames) + assert result is None + + +# ============================================================================ +# Tests for MPEG4Compression +# ============================================================================ + +class TestMPEG4Compression: + """Tests for MPEG4Compression editor.""" + + def test_default_fps(self): + """Test default fps is 24.0.""" + editor = MPEG4Compression() + assert editor.fps == 24.0 + + def test_custom_fps(self): + """Test custom fps setting.""" + editor = MPEG4Compression(fps=30.0) + assert editor.fps == 30.0 + + def test_fourcc_initialized(self): + """Test fourcc codec is initialized.""" + editor = MPEG4Compression() + assert editor.fourcc is not None + + def test_edit_returns_list(self, sample_frames): + """Test edit returns a list of PIL Images.""" + editor = MPEG4Compression(fps=24.0) + result = editor.edit(sample_frames) + assert isinstance(result, list) + assert all(isinstance(f, Image.Image) for f in result) + + def test_edit_preserves_frame_size(self, sample_frames): + """Test edit preserves frame dimensions.""" + editor = 
MPEG4Compression(fps=24.0) + result = editor.edit(sample_frames) + original_size = sample_frames[0].size + for frame in result: + assert frame.size == original_size + + def test_compression_changes_pixels(self, sample_frames): + """Test compression may change pixel values due to lossy encoding.""" + editor = MPEG4Compression(fps=24.0) + result = editor.edit(sample_frames) + + # At least some frames should have different pixels due to compression + original_arr = np.array(sample_frames[0]) + result_arr = np.array(result[0]) + + # Due to lossy compression, pixels may differ + # We just check the shapes match + assert original_arr.shape == result_arr.shape + + def test_various_fps_values(self, two_frames): + """Test various fps values.""" + for fps in [15.0, 24.0, 30.0, 60.0]: + editor = MPEG4Compression(fps=fps) + result = editor.edit(two_frames) + assert isinstance(result, list) + + +# ============================================================================ +# Tests for VideoCodecAttack +# ============================================================================ + +class TestVideoCodecAttack: + """Tests for VideoCodecAttack editor.""" + + def test_default_parameters(self): + """Test default parameters.""" + # Skip if ffmpeg not available + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + editor = VideoCodecAttack() + assert editor.codec == "h264" + assert editor.bitrate == "2M" + assert editor.fps == 24.0 + + def test_custom_codec(self): + """Test custom codec setting.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + editor = VideoCodecAttack(codec="h265") + assert editor.codec == "h265" + + def test_hevc_alias(self): + """Test hevc is aliased to h265.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + editor = VideoCodecAttack(codec="hevc") + assert editor.codec == "h265" + + def test_unsupported_codec_raises(self): + """Test unsupported codec raises ValueError.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + with pytest.raises(ValueError, match="Unsupported codec"): + VideoCodecAttack(codec="invalid_codec") + + def test_ffmpeg_not_found_raises(self): + """Test missing ffmpeg raises EnvironmentError.""" + with patch('shutil.which', return_value=None): + with pytest.raises(EnvironmentError, match="ffmpeg executable not found"): + VideoCodecAttack(ffmpeg_path=None) + + def test_custom_bitrate(self): + """Test custom bitrate setting.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + editor = VideoCodecAttack(bitrate="5M") + assert editor.bitrate == "5M" + + def test_edit_empty_frames(self): + """Test edit with empty frames returns empty list.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + editor = VideoCodecAttack() + result = editor.edit([]) + assert result == [] + + def test_edit_returns_list(self, sample_frames): + """Test edit returns a list of PIL Images.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + editor = VideoCodecAttack(codec="h264") + result = editor.edit(sample_frames) + assert isinstance(result, list) + assert all(isinstance(f, Image.Image) for f in result) + + def test_edit_preserves_frame_size(self, sample_frames): + """Test edit preserves frame dimensions.""" + if shutil.which("ffmpeg") is None: + pytest.skip("ffmpeg not available") + editor = VideoCodecAttack(codec="h264") + result = editor.edit(sample_frames) + original_size = sample_frames[0].size + for 
frame in result: + assert frame.size == original_size + + def test_codec_map_entries(self): + """Test codec map has expected entries.""" + expected_codecs = {"h264", "h265", "hevc", "vp9", "av1"} + assert expected_codecs.issubset(VideoCodecAttack._CODEC_MAP.keys()) + + +# ============================================================================ +# Tests for FrameAverage +# ============================================================================ + +class TestFrameAverage: + """Tests for FrameAverage editor.""" + + def test_default_n_frames(self): + """Test default n_frames is 3.""" + editor = FrameAverage() + assert editor.n_frames == 3 + + def test_custom_n_frames(self): + """Test custom n_frames setting.""" + editor = FrameAverage(n_frames=5) + assert editor.n_frames == 5 + + def test_edit_returns_list(self, sample_frames): + """Test edit returns a list of PIL Images.""" + editor = FrameAverage(n_frames=3) + result = editor.edit(sample_frames) + assert isinstance(result, list) + assert all(isinstance(f, Image.Image) for f in result) + + def test_edit_preserves_frame_count(self, sample_frames): + """Test edit preserves number of frames.""" + editor = FrameAverage(n_frames=3) + result = editor.edit(sample_frames) + assert len(result) == len(sample_frames) + + def test_edit_preserves_frame_size(self, sample_frames): + """Test edit preserves frame dimensions.""" + editor = FrameAverage(n_frames=3) + result = editor.edit(sample_frames) + original_size = sample_frames[0].size + for frame in result: + assert frame.size == original_size + + def test_averaging_effect(self): + """Test that averaging smooths frames.""" + # Create frames with distinct pixel values + frames = [] + for i in range(5): + arr = np.full((64, 64, 3), i * 60, dtype=np.uint8) # 0, 60, 120, 180, 240 + frames.append(Image.fromarray(arr)) + + editor = FrameAverage(n_frames=3) + result = editor.edit(frames) + + # Middle frame (index 2, value 120) should be averaged with neighbors (60, 180) + # Average of [60, 120, 180] = 120, so it stays the same + # But first frame (value 0) averaged with [0, 60] should change + first_original = np.array(frames[0]).mean() # 0 + first_result = np.array(result[0]).mean() # avg of [0, 60] = 30 + + # First frame should be affected by second frame in the window + assert first_result > first_original + + def test_single_frame(self, single_frame): + """Test with single frame.""" + editor = FrameAverage(n_frames=3) + result = editor.edit(single_frame) + assert len(result) == 1 + assert isinstance(result[0], Image.Image) + + def test_n_frames_larger_than_video(self, two_frames): + """Test when n_frames is larger than video length.""" + editor = FrameAverage(n_frames=10) + result = editor.edit(two_frames) + assert len(result) == 2 + + def test_various_n_frames(self, sample_frames): + """Test various n_frames values.""" + for n in [1, 2, 3, 5, 7]: + editor = FrameAverage(n_frames=n) + result = editor.edit(sample_frames) + assert len(result) == len(sample_frames) + + +# ============================================================================ +# Tests for FrameRateAdapter +# ============================================================================ + +class TestFrameRateAdapter: + """Tests for FrameRateAdapter editor.""" + + def test_default_parameters(self): + """Test default parameters.""" + editor = FrameRateAdapter() + assert editor.source_fps == 30.0 + assert editor.target_fps == 24.0 + + def test_custom_parameters(self): + """Test custom fps settings.""" + editor = 
FrameRateAdapter(source_fps=60.0, target_fps=30.0) + assert editor.source_fps == 60.0 + assert editor.target_fps == 30.0 + + def test_invalid_fps_raises(self): + """Test invalid fps raises ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + FrameRateAdapter(source_fps=0) + with pytest.raises(ValueError, match="must be positive"): + FrameRateAdapter(target_fps=-10) + + def test_edit_returns_list(self, sample_frames): + """Test edit returns a list of PIL Images.""" + editor = FrameRateAdapter(source_fps=30.0, target_fps=24.0) + result = editor.edit(sample_frames) + assert isinstance(result, list) + assert all(isinstance(f, Image.Image) for f in result) + + def test_edit_preserves_frame_size(self, sample_frames): + """Test edit preserves frame dimensions.""" + editor = FrameRateAdapter(source_fps=30.0, target_fps=24.0) + result = editor.edit(sample_frames) + original_size = sample_frames[0].size + for frame in result: + assert frame.size == original_size + + def test_same_fps_returns_copies(self, sample_frames): + """Test same source and target fps returns copies.""" + editor = FrameRateAdapter(source_fps=30.0, target_fps=30.0) + result = editor.edit(sample_frames) + assert len(result) == len(sample_frames) + + def test_downsampling(self, sample_frames): + """Test downsampling reduces frame count.""" + editor = FrameRateAdapter(source_fps=30.0, target_fps=15.0) + result = editor.edit(sample_frames) + # Should have roughly half the frames + assert len(result) < len(sample_frames) + + def test_upsampling(self, sample_frames): + """Test upsampling increases frame count.""" + editor = FrameRateAdapter(source_fps=15.0, target_fps=30.0) + result = editor.edit(sample_frames) + # Should have roughly double the frames + assert len(result) > len(sample_frames) + + def test_empty_frames(self, empty_frames): + """Test with empty frames.""" + editor = FrameRateAdapter() + result = editor.edit(empty_frames) + assert result == [] + + def test_single_frame(self, single_frame): + """Test with single frame.""" + editor = FrameRateAdapter() + result = editor.edit(single_frame) + assert len(result) == 1 + + def test_interpolation_smoothness(self, sample_gradient_frames): + """Test interpolated frames have smooth transitions.""" + editor = FrameRateAdapter(source_fps=15.0, target_fps=30.0) + result = editor.edit(sample_gradient_frames) + + # Check that interpolated values are between neighbors + for i in range(1, len(result) - 1): + arr = np.array(result[i]) + # Values should be within valid range + assert arr.min() >= 0 + assert arr.max() <= 255 + + +# ============================================================================ +# Tests for FrameSwap +# ============================================================================ + +class TestFrameSwap: + """Tests for FrameSwap editor.""" + + def test_default_probability(self): + """Test default swap probability is 0.25.""" + editor = FrameSwap() + assert editor.p == 0.25 + + def test_custom_probability(self): + """Test custom swap probability.""" + editor = FrameSwap(p=0.5) + assert editor.p == 0.5 + + def test_edit_returns_list(self, sample_frames): + """Test edit returns a list of PIL Images.""" + editor = FrameSwap(p=0.5) + result = editor.edit(sample_frames.copy()) + assert isinstance(result, list) + assert all(isinstance(f, Image.Image) for f in result) + + def test_edit_preserves_frame_count(self, sample_frames): + """Test edit preserves number of frames.""" + editor = FrameSwap(p=0.5) + frames_copy = sample_frames.copy() + result 
= editor.edit(frames_copy) + assert len(result) == len(sample_frames) + + def test_edit_preserves_frame_size(self, sample_frames): + """Test edit preserves frame dimensions.""" + editor = FrameSwap(p=0.5) + frames_copy = sample_frames.copy() + result = editor.edit(frames_copy) + original_size = sample_frames[0].size + for frame in result: + assert frame.size == original_size + + def test_high_probability_no_swap(self, sample_frames): + """Test high probability (p=1.0) doesn't swap (inverted logic in implementation).""" + # Note: The implementation swaps when random() >= p + # So p=1.0 means never swap (random() < 1.0 always) + editor = FrameSwap(p=1.0) + frames_copy = [f.copy() for f in sample_frames] + original_arrays = [np.array(f) for f in frames_copy] + + result = editor.edit(frames_copy) + result_arrays = [np.array(f) for f in result] + + for orig, res in zip(original_arrays, result_arrays): + np.testing.assert_array_equal(orig, res) + + def test_zero_probability_always_swaps(self, sample_frames): + """Test zero probability (p=0.0) always swaps (inverted logic in implementation).""" + # Note: The implementation swaps when random() >= p + # So p=0.0 means always swap (random() >= 0.0 always) + editor = FrameSwap(p=0.0) + frames_copy = [f.copy() for f in sample_frames] + + editor.edit(frames_copy) + # Just check it runs without error and returns correct length + assert len(frames_copy) == len(sample_frames) + + def test_single_frame(self, single_frame): + """Test with single frame.""" + editor = FrameSwap(p=0.5) + result = editor.edit(single_frame) + assert len(result) == 1 + + def test_two_frames(self, two_frames): + """Test with two frames.""" + editor = FrameSwap(p=0.5) + result = editor.edit(two_frames) + assert len(result) == 2 + + +# ============================================================================ +# Tests for FrameInterpolationAttack +# ============================================================================ + +class TestFrameInterpolationAttack: + """Tests for FrameInterpolationAttack editor.""" + + def test_default_parameters(self): + """Test default interpolated_frames is 1.""" + editor = FrameInterpolationAttack() + assert editor.interpolated_frames == 1 + + def test_custom_parameters(self): + """Test custom interpolated_frames setting.""" + editor = FrameInterpolationAttack(interpolated_frames=3) + assert editor.interpolated_frames == 3 + + def test_negative_frames_raises(self): + """Test negative interpolated_frames raises ValueError.""" + with pytest.raises(ValueError, match="must be non-negative"): + FrameInterpolationAttack(interpolated_frames=-1) + + def test_edit_returns_list(self, sample_frames): + """Test edit returns a list of PIL Images.""" + editor = FrameInterpolationAttack(interpolated_frames=1) + result = editor.edit(sample_frames) + assert isinstance(result, list) + assert all(isinstance(f, Image.Image) for f in result) + + def test_edit_preserves_frame_size(self, sample_frames): + """Test edit preserves frame dimensions.""" + editor = FrameInterpolationAttack(interpolated_frames=1) + result = editor.edit(sample_frames) + original_size = sample_frames[0].size + for frame in result: + assert frame.size == original_size + + def test_frame_count_increases(self, sample_frames): + """Test interpolation increases frame count.""" + n_interp = 2 + editor = FrameInterpolationAttack(interpolated_frames=n_interp) + result = editor.edit(sample_frames) + + # Expected: original frames + (n_original - 1) * n_interp + expected_count = len(sample_frames) + 
(len(sample_frames) - 1) * n_interp + assert len(result) == expected_count + + def test_zero_interpolation(self, sample_frames): + """Test zero interpolated_frames returns copies.""" + editor = FrameInterpolationAttack(interpolated_frames=0) + result = editor.edit(sample_frames) + assert len(result) == len(sample_frames) + + def test_empty_frames(self, empty_frames): + """Test with empty frames.""" + editor = FrameInterpolationAttack(interpolated_frames=1) + result = editor.edit(empty_frames) + assert result == [] + + def test_single_frame(self, single_frame): + """Test with single frame.""" + editor = FrameInterpolationAttack(interpolated_frames=2) + result = editor.edit(single_frame) + assert len(result) == 1 + + def test_two_frames_interpolation(self, two_frames): + """Test interpolation between two frames.""" + editor = FrameInterpolationAttack(interpolated_frames=1) + result = editor.edit(two_frames) + + # Should have: frame1, interpolated, frame2 + assert len(result) == 3 + + def test_interpolated_values(self, sample_gradient_frames): + """Test interpolated frames have intermediate values.""" + editor = FrameInterpolationAttack(interpolated_frames=1) + result = editor.edit(sample_gradient_frames) + + # Check that interpolated frame (index 1) is between original frames + original_0 = np.array(sample_gradient_frames[0]).astype(float) + original_1 = np.array(sample_gradient_frames[1]).astype(float) + interpolated = np.array(result[1]).astype(float) + + # Interpolated value should be close to average of neighbors + expected = (original_0 + original_1) / 2 + np.testing.assert_array_almost_equal(interpolated, expected, decimal=0) + + def test_various_interpolation_counts(self, two_frames): + """Test various interpolation counts.""" + for n in [0, 1, 2, 3, 5]: + editor = FrameInterpolationAttack(interpolated_frames=n) + result = editor.edit([f.copy() for f in two_frames]) + expected = 2 + 1 * n # 2 original + 1 gap * n interpolated + assert len(result) == expected + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestVideoEditorChaining: + """Test chaining multiple video editors.""" + + def test_chain_frame_editors(self, sample_frames): + """Test chaining frame-based editors.""" + editors = [ + FrameAverage(n_frames=3), + FrameSwap(p=0.1), + ] + + result = sample_frames + for editor in editors: + result = editor.edit(result) + + assert isinstance(result, list) + assert len(result) > 0 + assert all(isinstance(f, Image.Image) for f in result) + + def test_chain_with_interpolation(self, sample_frames): + """Test chaining with interpolation editor.""" + editors = [ + FrameInterpolationAttack(interpolated_frames=1), + FrameAverage(n_frames=3), + ] + + result = sample_frames + for editor in editors: + result = editor.edit(result) + + assert isinstance(result, list) + assert len(result) > len(sample_frames) # Should have more frames + + def test_chain_with_rate_adapter(self, sample_frames): + """Test chaining with frame rate adapter.""" + editors = [ + FrameRateAdapter(source_fps=30.0, target_fps=24.0), + FrameAverage(n_frames=3), + ] + + result = sample_frames + for editor in editors: + result = editor.edit(result) + + assert isinstance(result, list) + assert all(isinstance(f, Image.Image) for f in result) + + +class TestVideoEditorEdgeCases: + """Test edge cases for video editors.""" + + def test_large_frame_count(self): + """Test with many frames.""" + 
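# 100 frames cycling through the red channel; the sliding-window
+        # average should preserve the frame count.
+        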
frames = [Image.new("RGB", (32, 32), color=(i % 256, 0, 0)) for i in range(100)] + + editor = FrameAverage(n_frames=5) + result = editor.edit(frames) + + assert len(result) == 100 + + def test_small_frame_size(self): + """Test with very small frames.""" + frames = [Image.new("RGB", (8, 8), color="red") for _ in range(5)] + + editor = FrameAverage(n_frames=3) + result = editor.edit(frames) + + assert len(result) == 5 + assert all(f.size == (8, 8) for f in result) + + def test_large_frame_size(self): + """Test with large frames.""" + frames = [Image.new("RGB", (512, 512), color="blue") for _ in range(3)] + + editor = FrameAverage(n_frames=3) + result = editor.edit(frames) + + assert len(result) == 3 + assert all(f.size == (512, 512) for f in result) + + def test_non_square_frames(self): + """Test with non-square frames.""" + frames = [Image.new("RGB", (320, 180), color="green") for _ in range(5)] + + editor = FrameAverage(n_frames=3) + result = editor.edit(frames) + + assert len(result) == 5 + assert all(f.size == (320, 180) for f in result) + + def test_grayscale_to_rgb_conversion(self): + """Test editors handle RGB frames correctly.""" + # Create RGB frames + frames = [Image.new("RGB", (64, 64), color=(128, 128, 128)) for _ in range(5)] + + editor = FrameAverage(n_frames=3) + result = editor.edit(frames) + + assert all(f.mode == "RGB" for f in result) diff --git a/test/test_watermark_algorithms.py b/test/test_watermark_algorithms.py index 1398e8d..3cbb8c5 100644 --- a/test/test_watermark_algorithms.py +++ b/test/test_watermark_algorithms.py @@ -67,19 +67,79 @@ def test_image_watermark_generation(algorithm_name, image_diffusion_config, skip pytest.skip("Generation tests skipped by --skip-generation flag") try: + # smaller test params for generation tests + image_diffusion_config.num_inference_steps = 10 + image_diffusion_config.image_size = (128, 128) + watermark = AutoWatermark.load( algorithm_name, algorithm_config=f'config/{algorithm_name}.json', diffusion_config=image_diffusion_config ) + + # Generate watermarked image + + # more tests for specific algorithms if applicable + if algorithm_name == "TR": + tr_w_patterns = ["seed_ring", "seed_zeros", "seed_rand", "rand", "zeros", "const", "ring"] + for w_pattern in tr_w_patterns: + watermark.config.w_pattern = w_pattern + watermarked_image = watermark.generate_watermarked_media(TEST_PROMPT_IMAGE) + assert watermarked_image is not None + assert isinstance(watermarked_image, Image.Image) + assert watermarked_image.size == (128, 128) + print(f"✓ {algorithm_name} generated watermarked image with {w_pattern} pattern successfully") + elif algorithm_name == "GM": + gm_w_patterns = ["seed_ring", "seed_zeros", "seed_rand", "rand", "zeros", "const", "signal_ring"] + for w_pattern in gm_w_patterns: + watermark.config.w_pattern = w_pattern + watermarked_image = watermark.generate_watermarked_media(TEST_PROMPT_IMAGE) + assert watermarked_image is not None + assert isinstance(watermarked_image, Image.Image) + assert watermarked_image.size == (128, 128) + print(f"✓ {algorithm_name} generated watermarked image with {w_pattern} pattern successfully") + elif algorithm_name == "SFW": + # Test SFW with wm_type="wm" (non-HSQR mode uses Fourier treering pattern) + # Create a temporary config file with wm_type="wm" + import json + import tempfile + sfw_config_wm = { + "algorithm_name": "SFW", + "w_seed": 42, + "wm_type": "wm", # Test with non-HSQR mode + "delta": 1, + "w_channel": 3, + "threshold": 50 + } + with tempfile.NamedTemporaryFile(mode='w', 
suffix='.json', delete=False) as f: + json.dump(sfw_config_wm, f) + temp_config_path = f.name + + try: + watermark_wm = AutoWatermark.load( + "SFW", + algorithm_config=temp_config_path, + diffusion_config=image_diffusion_config + ) + # Generate with wm_type="wm" + watermarked_image_wm = watermark_wm.generate_watermarked_media(TEST_PROMPT_IMAGE) + assert watermarked_image_wm is not None + assert isinstance(watermarked_image_wm, Image.Image) + assert watermarked_image_wm.size == (128, 128) + print(f" ✓ SFW wm_type='wm' generation passed") + finally: + import os + os.unlink(temp_config_path) + + watermarked_image = watermark.generate_watermarked_media(TEST_PROMPT_IMAGE) # Validate output assert watermarked_image is not None assert isinstance(watermarked_image, Image.Image) - assert watermarked_image.size == (IMAGE_SIZE[1], IMAGE_SIZE[0]) + assert watermarked_image.size == (128, 128) print(f"✓ {algorithm_name} generated watermarked image successfully") @@ -142,6 +202,53 @@ def test_image_watermark_detection(algorithm_name, image_diffusion_config, skip_ assert detection_result_wm is not None assert isinstance(detection_result_wm, dict) assert detection_result_wm['is_watermarked'] is True + + # Test other detector_type for specific algorithms if applicable + detector_types = [] + if algorithm_name == "RI": + modes = ['real', 'imag'] + for mode in modes: + watermark.config.mode = mode + detection_result_mode = watermark.detect_watermark_in_media(watermarked_image) + elif algorithm_name == "TR": + detection_result_mode = watermark.detect_watermark_in_media(watermarked_image, detector_type='p_value') + assert detection_result_mode is not None + assert isinstance(detection_result_mode, dict) + elif algorithm_name == "GS": + # Test GS with chacha=False (non-ChaCha mode uses simple XOR key) + # Create a temporary config file with chacha=False + import json + import tempfile + gs_config_no_chacha = { + "algorithm_name": "GS", + "channel_copy": 1, + "wm_key": 42, + "hw_copy": 8, + "chacha": False, # Test with chacha disabled + "chacha_key_seed": 123456, + "chacha_nonce_seed": 789012, + "threshold": 0.7 + } + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(gs_config_no_chacha, f) + temp_config_path = f.name + + try: + watermark_no_chacha = AutoWatermark.load( + "GS", + algorithm_config=temp_config_path, + diffusion_config=image_diffusion_config + ) + # Generate and detect with chacha=False + watermarked_image_no_chacha = watermark_no_chacha.generate_watermarked_media(TEST_PROMPT_IMAGE) + detection_result_no_chacha = watermark_no_chacha.detect_watermark_in_media(watermarked_image_no_chacha) + assert detection_result_no_chacha is not None + assert isinstance(detection_result_no_chacha, dict) + assert detection_result_no_chacha['is_watermarked'] is True + print(f" ✓ GS chacha=False detection passed: {detection_result_no_chacha}") + finally: + import os + os.unlink(temp_config_path) # Detect watermark in unwatermarked image detection_result_unwm = watermark.detect_watermark_in_media(unwatermarked_image) @@ -190,6 +297,11 @@ def test_video_watermark_generation(algorithm_name, video_diffusion_config, skip pytest.skip("Generation tests skipped by --skip-generation flag") try: + # smaller test params for generation tests + video_diffusion_config.num_inference_steps = 10 + video_diffusion_config.num_frames = 8 + video_diffusion_config.image_size = (128, 128) + watermark = AutoWatermark.load( algorithm_name, algorithm_config=f'config/{algorithm_name}.json', @@ -697,6 
+809,12 @@ def test_watermark_visualization(algorithm_name, image_diffusion_config, video_d # Call the method method(**params) + + # add `channel` parameter if needed + if 'channel' in sig.parameters: + params['channel'] = 0 + method(**params) + subclass_tested.append(method_name) plt.close(fig) except Exception as e: @@ -717,3 +835,537 @@ def test_watermark_visualization(algorithm_name, image_diffusion_config, video_d pytest.skip(f"{algorithm_name} visualization not fully implemented: {e}") except Exception as e: pytest.fail(f"Failed to test visualization for {algorithm_name}: {e}") + + +# ============================================================================ +# Test Cases - BaseVisualizer Unit Tests +# ============================================================================ + +class TestBaseVisualizerMethods: + """Unit tests for BaseVisualizer's visualize(), _draw_single_image, and _draw_video_frames methods.""" + + @pytest.fixture + def mock_data_for_image(self): + """Create mock DataForVisualization for image tests.""" + import torch + from unittest.mock import MagicMock + + mock_data = MagicMock() + mock_data.algorithm_name = "TestAlgorithm" + + # Create a test image (PIL Image) + test_image = Image.new("RGB", (64, 64), color=(128, 64, 192)) + mock_data.image = test_image + + # Create test latents for image: [B, C, H, W] + mock_data.orig_watermarked_latents = torch.randn(1, 4, 8, 8) + mock_data.reversed_latents = [torch.randn(1, 4, 8, 8) for _ in range(5)] + + return mock_data + + @pytest.fixture + def mock_data_for_video(self): + """Create mock DataForVisualization for video tests.""" + import torch + from unittest.mock import MagicMock + import numpy as np + + mock_data = MagicMock() + mock_data.algorithm_name = "TestVideoAlgorithm" + + # Create test video frames (list of PIL Images) + video_frames = [Image.new("RGB", (64, 64), color=(i * 30, 100, 200)) for i in range(8)] + mock_data.video_frames = video_frames + mock_data.image = None + + # Create test latents for video: [B, C, F, H, W] + mock_data.orig_watermarked_latents = torch.randn(1, 4, 8, 8, 8) + mock_data.reversed_latents = [torch.randn(1, 4, 8, 8, 8) for _ in range(5)] + + return mock_data + + @pytest.fixture + def image_visualizer(self, mock_data_for_image): + """Create a concrete visualizer for image tests.""" + from visualize.base import BaseVisualizer + + class ConcreteImageVisualizer(BaseVisualizer): + """Concrete implementation for testing.""" + pass + + return ConcreteImageVisualizer( + data_for_visualization=mock_data_for_image, + dpi=100, + is_video=False + ) + + @pytest.fixture + def video_visualizer(self, mock_data_for_video): + """Create a concrete visualizer for video tests.""" + from visualize.base import BaseVisualizer + + class ConcreteVideoVisualizer(BaseVisualizer): + """Concrete implementation for testing.""" + pass + + return ConcreteVideoVisualizer( + data_for_visualization=mock_data_for_video, + dpi=100, + is_video=True + ) + + # ------------------------------------------------------------------------- + # Tests for _draw_single_image + # ------------------------------------------------------------------------- + + @pytest.mark.visualization + def test_draw_single_image_with_pil_image(self, image_visualizer): + """Test _draw_single_image with PIL Image input.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + + result = image_visualizer._draw_single_image( + title="Test Single Image", + ax=ax + ) + + assert result is not None + assert result == ax + 
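The fixtures above pin down the layout conventions the rest of these unit tests rely on: image latents are [B, C, H, W], video latents are [B, C, F, H, W], and data.image may arrive as a PIL image or as a tensor in either a 0-255 or a [-1, 1] range. A minimal sketch of the normalization the tensor-input tests below expect, mirroring the [-1, 1] handling this patch adds to visualize/base.py (the helper name to_displayable is illustrative, not part of the repository API):

```python
import numpy as np
import torch
from PIL import Image


def to_displayable(image) -> np.ndarray:
    """Convert a PIL image or latent-style tensor into an HWC float array in [0, 1]."""
    if isinstance(image, Image.Image):
        return np.asarray(image).astype(np.float32) / 255.0
    arr = image.detach().cpu().float()
    if arr.dim() == 4:                       # [B, C, H, W] -> take the first sample
        arr = arr[0]
    array = arr.permute(1, 2, 0).numpy()     # [C, H, W] -> [H, W, C]
    if array.max() > 1.0:                    # assume a 0-255 range
        array = array / 255.0
    if array.min() < 0:                      # assume a [-1, 1] range
        array = (array + 1.0) / 2.0
    return np.clip(array, 0.0, 1.0)          # safe for plt.imshow
```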
plt.close(fig) + + @pytest.mark.visualization + def test_draw_single_image_with_tensor(self, mock_data_for_image): + """Test _draw_single_image with tensor input.""" + import torch + import matplotlib.pyplot as plt + from visualize.base import BaseVisualizer + + class ConcreteVisualizer(BaseVisualizer): + pass + + # Replace image with tensor + mock_data_for_image.image = torch.rand(1, 3, 64, 64) # [B, C, H, W] + + visualizer = ConcreteVisualizer( + data_for_visualization=mock_data_for_image, + dpi=100, + is_video=False + ) + + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + result = visualizer._draw_single_image(title="Tensor Image", ax=ax) + + assert result is not None + plt.close(fig) + + @pytest.mark.visualization + def test_draw_single_image_with_3d_tensor(self, mock_data_for_image): + """Test _draw_single_image with 3D tensor input [C, H, W].""" + import torch + import matplotlib.pyplot as plt + from visualize.base import BaseVisualizer + + class ConcreteVisualizer(BaseVisualizer): + pass + + # Replace image with 3D tensor + mock_data_for_image.image = torch.rand(3, 64, 64) # [C, H, W] + + visualizer = ConcreteVisualizer( + data_for_visualization=mock_data_for_image, + dpi=100, + is_video=False + ) + + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + result = visualizer._draw_single_image(title="3D Tensor Image", ax=ax) + + assert result is not None + plt.close(fig) + + @pytest.mark.visualization + def test_draw_single_image_with_normalized_tensor(self, mock_data_for_image): + """Test _draw_single_image with tensor in [-1, 1] range.""" + import torch + import matplotlib.pyplot as plt + from visualize.base import BaseVisualizer + + class ConcreteVisualizer(BaseVisualizer): + pass + + # Replace image with tensor in [-1, 1] range + mock_data_for_image.image = torch.rand(1, 3, 64, 64) * 2 - 1 # [-1, 1] + + visualizer = ConcreteVisualizer( + data_for_visualization=mock_data_for_image, + dpi=100, + is_video=False + ) + + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + result = visualizer._draw_single_image(title="Normalized Tensor", ax=ax) + + assert result is not None + plt.close(fig) + + @pytest.mark.visualization + def test_draw_single_image_empty_title(self, image_visualizer): + """Test _draw_single_image with empty title.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + + result = image_visualizer._draw_single_image(title="", ax=ax) + + assert result is not None + plt.close(fig) + + # ------------------------------------------------------------------------- + # Tests for _draw_video_frames + # ------------------------------------------------------------------------- + + @pytest.mark.visualization + def test_draw_video_frames_basic(self, video_visualizer): + """Test _draw_video_frames with basic parameters.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + + result = video_visualizer._draw_video_frames( + title="Test Video Frames", + num_frames=4, + ax=ax + ) + + assert result is not None + assert result == ax + plt.close(fig) + + @pytest.mark.visualization + def test_draw_video_frames_single_frame(self, video_visualizer): + """Test _draw_video_frames with single frame.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + + result = video_visualizer._draw_video_frames( + title="Single Frame", + num_frames=1, + ax=ax + ) + + assert result is not None + plt.close(fig) + + @pytest.mark.visualization + def test_draw_video_frames_all_frames(self, video_visualizer): + """Test 
_draw_video_frames requesting more frames than available.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + + # Request 16 frames but only 8 are available + result = video_visualizer._draw_video_frames( + title="All Frames", + num_frames=16, + ax=ax + ) + + assert result is not None + plt.close(fig) + + @pytest.mark.visualization + def test_draw_video_frames_with_numpy_frames(self, mock_data_for_video): + """Test _draw_video_frames with numpy array frames.""" + import numpy as np + import matplotlib.pyplot as plt + from visualize.base import BaseVisualizer + + class ConcreteVisualizer(BaseVisualizer): + pass + + # Replace frames with numpy arrays + mock_data_for_video.video_frames = [ + np.random.rand(64, 64, 3).astype(np.float32) for _ in range(8) + ] + + visualizer = ConcreteVisualizer( + data_for_visualization=mock_data_for_video, + dpi=100, + is_video=True + ) + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + result = visualizer._draw_video_frames(title="Numpy Frames", num_frames=4, ax=ax) + + assert result is not None + plt.close(fig) + + @pytest.mark.visualization + def test_draw_video_frames_with_tensor_frames(self, mock_data_for_video): + """Test _draw_video_frames with tensor frames.""" + import torch + import matplotlib.pyplot as plt + from visualize.base import BaseVisualizer + + class ConcreteVisualizer(BaseVisualizer): + pass + + # Replace frames with tensors [C, H, W] + mock_data_for_video.video_frames = [ + torch.rand(3, 64, 64) for _ in range(8) + ] + + visualizer = ConcreteVisualizer( + data_for_visualization=mock_data_for_video, + dpi=100, + is_video=True + ) + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + result = visualizer._draw_video_frames(title="Tensor Frames", num_frames=4, ax=ax) + + assert result is not None + plt.close(fig) + + @pytest.mark.visualization + def test_draw_video_frames_no_frames_raises_error(self, mock_data_for_video): + """Test _draw_video_frames raises error when no frames available.""" + import matplotlib.pyplot as plt + from visualize.base import BaseVisualizer + + class ConcreteVisualizer(BaseVisualizer): + pass + + # Remove video_frames + mock_data_for_video.video_frames = None + + visualizer = ConcreteVisualizer( + data_for_visualization=mock_data_for_video, + dpi=100, + is_video=True + ) + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + + with pytest.raises(ValueError, match="No video frames available"): + visualizer._draw_video_frames(title="No Frames", num_frames=4, ax=ax) + + plt.close(fig) + + @pytest.mark.visualization + def test_draw_video_frames_empty_title(self, video_visualizer): + """Test _draw_video_frames with empty title.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + + result = video_visualizer._draw_video_frames(title="", num_frames=4, ax=ax) + + assert result is not None + plt.close(fig) + + # ------------------------------------------------------------------------- + # Tests for draw_watermarked_image (dispatches to _draw_single_image or _draw_video_frames) + # ------------------------------------------------------------------------- + + @pytest.mark.visualization + def test_draw_watermarked_image_dispatches_to_single_image(self, image_visualizer): + """Test draw_watermarked_image dispatches to _draw_single_image for images.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + + result = image_visualizer.draw_watermarked_image( + title="Watermarked Image", + ax=ax + ) + + assert result is not None 
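Together, the dispatch tests here pin down draw_watermarked_image's contract: visualizers constructed with is_video=True route to the frame-grid path, all others to the single-image path. A minimal self-contained sketch of that dispatch (assumed structure; the real method lives in visualize/base.py and may differ in signature):

```python
class VisualizerSketch:
    """Minimal stand-in for the assumed dispatch; not the repo's BaseVisualizer."""

    def __init__(self, is_video: bool):
        self.is_video = is_video

    def _draw_single_image(self, title, ax):
        return ax  # the real method renders the image and returns the axis

    def _draw_video_frames(self, title, num_frames, ax):
        return ax  # the real method tiles up to num_frames frames on the axis

    def draw_watermarked_image(self, title="", ax=None, num_frames=6):
        # Dispatch on media type, as the two tests here verify.
        if self.is_video:
            return self._draw_video_frames(title=title, num_frames=num_frames, ax=ax)
        return self._draw_single_image(title=title, ax=ax)
```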
+ plt.close(fig) + + @pytest.mark.visualization + def test_draw_watermarked_image_dispatches_to_video_frames(self, video_visualizer): + """Test draw_watermarked_image dispatches to _draw_video_frames for videos.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + + result = video_visualizer.draw_watermarked_image( + title="Watermarked Video", + num_frames=4, + ax=ax + ) + + assert result is not None + plt.close(fig) + + # ------------------------------------------------------------------------- + # Tests for visualize() method + # ------------------------------------------------------------------------- + + @pytest.mark.visualization + def test_visualize_single_method(self, image_visualizer, tmp_path): + """Test visualize() with single method.""" + result = image_visualizer.visualize( + rows=1, + cols=1, + methods=["draw_watermarked_image"], + figsize=(5, 5) + ) + + assert result is not None + import matplotlib.pyplot as plt + plt.close(result) + + @pytest.mark.visualization + def test_visualize_multiple_methods(self, image_visualizer, tmp_path): + """Test visualize() with multiple methods in grid layout.""" + result = image_visualizer.visualize( + rows=2, + cols=2, + methods=[ + "draw_watermarked_image", + "draw_orig_latents", + "draw_orig_latents_fft", + "draw_inverted_latents" + ], + figsize=(10, 10), + method_kwargs=[ + {"title": "Image"}, + {"channel": 0, "title": "Original Latents Ch0"}, + {"channel": 0, "title": "FFT Ch0"}, + {"channel": 0, "title": "Inverted Latents Ch0"} + ] + ) + + assert result is not None + import matplotlib.pyplot as plt + plt.close(result) + + @pytest.mark.visualization + def test_visualize_saves_to_file(self, image_visualizer, tmp_path): + """Test visualize() saves figure to file.""" + save_path = str(tmp_path / "test_visualization.png") + + result = image_visualizer.visualize( + rows=1, + cols=1, + methods=["draw_watermarked_image"], + save_path=save_path + ) + + assert result is not None + assert (tmp_path / "test_visualization.png").exists() + + import matplotlib.pyplot as plt + plt.close(result) + + @pytest.mark.visualization + def test_visualize_with_default_figsize(self, image_visualizer): + """Test visualize() with default figsize calculation.""" + result = image_visualizer.visualize( + rows=2, + cols=2, + methods=[ + "draw_watermarked_image", + "draw_orig_latents", + "draw_orig_latents_fft", + "draw_inverted_latents" + ], + method_kwargs=[ + {}, + {"channel": 0}, + {"channel": 0}, + {"channel": 0} + ] + ) + + assert result is not None + import matplotlib.pyplot as plt + plt.close(result) + + @pytest.mark.visualization + def test_visualize_mismatched_layout_raises_error(self, image_visualizer): + """Test visualize() raises error when methods don't match layout.""" + with pytest.raises(ValueError, match="not compatible with the layout"): + image_visualizer.visualize( + rows=2, + cols=2, + methods=["draw_watermarked_image"] # Only 1 method for 2x2 layout + ) + + @pytest.mark.visualization + def test_visualize_invalid_method_raises_error(self, image_visualizer): + """Test visualize() raises error for invalid method name.""" + with pytest.raises(ValueError, match="Method .* not found"): + image_visualizer.visualize( + rows=1, + cols=1, + methods=["nonexistent_method"] + ) + + @pytest.mark.visualization + def test_visualize_video_with_frame_selection(self, video_visualizer, tmp_path): + """Test visualize() for video with frame parameter.""" + result = video_visualizer.visualize( + rows=1, + cols=2, + methods=[ + 
"draw_watermarked_image", + "draw_orig_latents" + ], + method_kwargs=[ + {"num_frames": 4}, + {"channel": 0, "frame": 0} + ] + ) + + assert result is not None + import matplotlib.pyplot as plt + plt.close(result) + + @pytest.mark.visualization + def test_visualize_single_row(self, image_visualizer): + """Test visualize() with single row layout.""" + result = image_visualizer.visualize( + rows=1, + cols=3, + methods=[ + "draw_watermarked_image", + "draw_orig_latents", + "draw_orig_latents_fft" + ], + method_kwargs=[ + {}, + {"channel": 0}, + {"channel": 0} + ] + ) + + assert result is not None + import matplotlib.pyplot as plt + plt.close(result) + + @pytest.mark.visualization + def test_visualize_single_column(self, image_visualizer): + """Test visualize() with single column layout.""" + result = image_visualizer.visualize( + rows=3, + cols=1, + methods=[ + "draw_watermarked_image", + "draw_orig_latents", + "draw_orig_latents_fft" + ], + method_kwargs=[ + {}, + {"channel": 0}, + {"channel": 0} + ] + ) + + assert result is not None + import matplotlib.pyplot as plt + plt.close(result) diff --git a/utils/__init__.py b/utils/__init__.py index bb43b0b..f662ca6 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -12,18 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Utility functions for MarkDiffusion. +"""Utility functions and helpers for MarkDiffusion.""" -This module provides various utility functions for callbacks, diffusion configuration, -media processing, pipeline utilities, and general utilities. -""" +from .utils import set_random_seed, load_config_file +from .diffusion_config import DiffusionConfig +from .media_utils import ( + pil_to_torch, + torch_to_numpy, + numpy_to_pil, + get_media_latents, + decode_media_latents, +) +from .pipeline_utils import ( + get_pipeline_type, + is_image_pipeline, + is_video_pipeline, + is_t2v_pipeline, + is_i2v_pipeline, +) __all__ = [ - 'callbacks', - 'diffusion_config', - 'media_utils', - 'pipeline_utils', - 'utils', + "set_random_seed", + "load_config_file", + "DiffusionConfig", + "pil_to_torch", + "torch_to_numpy", + "numpy_to_pil", + "get_media_latents", + "decode_media_latents", + "get_pipeline_type", + "is_image_pipeline", + "is_video_pipeline", + "is_t2v_pipeline", + "is_i2v_pipeline", ] - diff --git a/utils/diffusion_config.py b/utils/diffusion_config.py index 0eed9d6..f014558 100644 --- a/utils/diffusion_config.py +++ b/utils/diffusion_config.py @@ -97,24 +97,3 @@ def is_video_pipeline(self) -> bool: def is_image_pipeline(self) -> bool: """Check if this is an image pipeline.""" return self.pipeline_type == PIPELINE_TYPE_IMAGE - - @property - def pipeline_requirements(self) -> Dict[str, Any]: - """Get the requirements for this pipeline type.""" - if self.pipeline_type == PIPELINE_TYPE_IMAGE: - return { - "required_params": [], - "optional_params": ["height", "width", "num_images_per_prompt"] - } - elif self.pipeline_type == PIPELINE_TYPE_TEXT_TO_VIDEO: - return { - "required_params": ["num_frames"], - "optional_params": ["height", "width", "fps"] - } - elif self.pipeline_type == PIPELINE_TYPE_IMAGE_TO_VIDEO: - return { - "required_params": ["input_image", "num_frames"], - "optional_params": ["height", "width", "fps"] - } - else: - return {"required_params": [], "optional_params": []} \ No newline at end of file diff --git a/visualize/__init__.py b/visualize/__init__.py index a859de2..9e2299e 100644 --- a/visualize/__init__.py +++ b/visualize/__init__.py @@ -12,26 +12,10 
@@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Visualization module for MarkDiffusion. +"""Visualization module for watermarking analysis and results display.""" -This module provides visualization tools for different watermarking algorithms. -""" - -__all__ = [ - 'auto_visualization', - 'base', - 'data_for_visualization', - 'gm', - 'gs', - 'prc', - 'ri', - 'robin', - 'seal', - 'sfw', - 'tr', - 'videomark', - 'videoshield', - 'wind', -] +from .base import BaseVisualizer +from .data_for_visualization import DataForVisualization +from .auto_visualization import AutoVisualizer +__all__ = ["BaseVisualizer", "DataForVisualization", "AutoVisualizer"] diff --git a/visualize/base.py b/visualize/base.py index 5e26176..8f46f48 100644 --- a/visualize/base.py +++ b/visualize/base.py @@ -534,6 +534,10 @@ def _draw_single_image(self, # Normalize to 0-1 if needed if image_array.max() > 1.0: image_array = image_array / 255.0 + + # Normalize [-1, 1] range to [0, 1] for imshow + if image_array.min() < 0: + image_array = (image_array + 1.0) / 2.0 # Normalize [-1, 1] range to [0, 1] for imshow if image_array.min() < 0: @@ -639,6 +643,13 @@ def _draw_video_frames(self, frame = frame.astype(np.float32) if frame.max() > 1.0: frame = frame / 255.0 + + # Normalize [-1, 1] range to [0, 1] for imshow + if frame.min() < 0: + frame = (frame + 1.0) / 2.0 + + # Clip to valid range [0, 1] + frame = np.clip(frame, 0, 1) # Normalize [-1, 1] range to [0, 1] for imshow if frame.min() < 0: diff --git a/watermark/__init__.py b/watermark/__init__.py index 44a48df..bc901b7 100644 --- a/watermark/__init__.py +++ b/watermark/__init__.py @@ -13,26 +13,34 @@ # limitations under the License. """ -Watermark module for MarkDiffusion. +MarkDiffusion - An Open-Source Toolkit for Generative Watermarking of Latent Diffusion Models. -This module provides watermarking functionality for different algorithms -including GM, GS, PRC, RI, ROBIN, SEAL, SFW, TR, VideoMark, VideoShield, and WIND. 
+This package provides watermarking algorithms for diffusion models including: +- Tree-Ring (TR) +- Gaussian Shading (GS) +- RingID (RI) +- PRC +- ROBIN +- Gaussian Marking (GM) +- SFW (Stable Few Watermarks) +- SEAL +- WIND +- VideoMark +- VideoShield """ +__version__ = "0.1.6.post1" +__author__ = "THU-BPM MarkDiffusion Team" +__license__ = "Apache-2.0" + +from .base import BaseWatermark, BaseConfig +from .auto_watermark import AutoWatermark +from .auto_config import AutoConfig + __all__ = [ - 'auto_config', - 'auto_watermark', - 'base', - 'gm', - 'gs', - 'prc', - 'ri', - 'robin', - 'seal', - 'sfw', - 'tr', - 'videomark', - 'videoshield', - 'wind', + "__version__", + "BaseWatermark", + "BaseConfig", + "AutoWatermark", + "AutoConfig", ] - diff --git a/watermark/gm/gm.py b/watermark/gm/gm.py index 9b54b78..db6ab41 100644 --- a/watermark/gm/gm.py +++ b/watermark/gm/gm.py @@ -355,22 +355,22 @@ def _build_watermarking_mask(self) -> torch.Tensor: mask[:, :, base_mask] = True else: mask[:, self.config.w_channel, base_mask] = True - elif shape == "square": - anchor = self.latent_shape[-1] // 2 - sl = slice(anchor - self.config.w_radius, anchor + self.config.w_radius) - if self.config.w_channel == -1: - mask[:, :, sl, sl] = True - else: - mask[:, self.config.w_channel, sl, sl] = True - elif shape == "signal_circle": - mask = torch.zeros(self.latent_shape, dtype=torch.long, device=self.device) - label = 1 - for radius in self.radius_list: - base_mask = torch.tensor(circle_mask(self.latent_shape[-1], radius), device=self.device) - mask[:, :, base_mask] = label - label += 1 - elif shape == "no": - return mask + # elif shape == "square": + # anchor = self.latent_shape[-1] // 2 + # sl = slice(anchor - self.config.w_radius, anchor + self.config.w_radius) + # if self.config.w_channel == -1: + # mask[:, :, sl, sl] = True + # else: + # mask[:, self.config.w_channel, sl, sl] = True + # elif shape == "signal_circle": + # mask = torch.zeros(self.latent_shape, dtype=torch.long, device=self.device) + # label = 1 + # for radius in self.radius_list: + # base_mask = torch.tensor(circle_mask(self.latent_shape[-1], radius), device=self.device) + # mask[:, :, base_mask] = label + # label += 1 + # elif shape == "no": + # return mask else: raise NotImplementedError(f"Unsupported watermark mask shape: {shape}") @@ -382,14 +382,24 @@ def _build_gnr_restorer(self) -> Optional[GNRRestorer]: return None checkpoint_path = Path(checkpoint) repo = getattr(self.config, "huggingface_repo", None) - hf_dir=getattr(self.config, "hf_dir", None) + hf_dir = getattr(self.config, "hf_dir", None) if repo: - try: - hf_path = hf_hub_download(repo_id=repo, filename=Path(checkpoint).name,cache_dir=hf_dir) - print(f"Downloaded GNR checkpoint from Huggingface Hub: {hf_path}") - checkpoint_path = Path(hf_path) - except Exception as e: - raise FileNotFoundError(f"GNR checkpoint not found on ({repo}). 
error: {e}") + # Check if file already exists locally before downloading + local_path = checkpoint_path if checkpoint_path.is_file() else None + if hf_dir: + potential_local = Path(hf_dir) / Path(checkpoint).name + if potential_local.is_file(): + local_path = potential_local + if local_path and local_path.is_file(): + print(f"Using existing GNR checkpoint: {local_path}") + checkpoint_path = local_path + else: + try: + hf_path = hf_hub_download(repo_id=repo, filename=Path(checkpoint).name, cache_dir=hf_dir) + print(f"Downloaded GNR checkpoint from Huggingface Hub: {hf_path}") + checkpoint_path = Path(hf_path) + except Exception as e: + raise FileNotFoundError(f"GNR checkpoint not found on ({repo}). error: {e}") in_channels = self.config.latent_channels * (2 if self.config.gnr_classifier_type == 1 else 1) return GNRRestorer( checkpoint_path=checkpoint_path, @@ -410,13 +420,25 @@ def _build_fuser(self): "joblib is required to load the GaussMarker fuser. Install joblib or disable the fuser." ) repo = getattr(self.config, "huggingface_repo", None) - hf_dir=getattr(self.config, "hf_dir", None) + hf_dir = getattr(self.config, "hf_dir", None) + candidates = [] if repo: - try: - hf_path = hf_hub_download(repo_id=repo, filename=Path(checkpoint).name,cache_dir=hf_dir) - candidates = [Path(hf_path)] - except Exception as e: - raise FileNotFoundError(f"Fuser checkpoint not found on ({repo}). error: {e}") + # Check if file already exists locally before downloading + local_path = Path(checkpoint) if Path(checkpoint).is_file() else None + if hf_dir: + potential_local = Path(hf_dir) / Path(checkpoint).name + if potential_local.is_file(): + local_path = potential_local + if local_path and local_path.is_file(): + print(f"Using existing fuser checkpoint: {local_path}") + candidates = [local_path] + else: + try: + hf_path = hf_hub_download(repo_id=repo, filename=Path(checkpoint).name, cache_dir=hf_dir) + print(f"Downloaded fuser checkpoint from Huggingface Hub: {hf_path}") + candidates = [Path(hf_path)] + except Exception as e: + raise FileNotFoundError(f"Fuser checkpoint not found on ({repo}). 
error: {e}") base_dir = Path(__file__).resolve().parent candidates.append(base_dir / checkpoint) candidates.append(base_dir.parent.parent / checkpoint) @@ -448,30 +470,30 @@ def _inject_complex(self, latents: torch.Tensor) -> torch.Tensor: injected = torch.fft.ifft2(torch.fft.ifftshift(fft_latents, dim=(-1, -2))).real return injected - def _inject_seed(self, latents: torch.Tensor) -> torch.Tensor: - mask = self.watermarking_mask - injected = latents.clone() - injected[mask] = self.gt_patch[mask].clone() - return injected - - def _inject_signal(self, latents: torch.Tensor) -> torch.Tensor: - fft_latents = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2)) - mask = self.watermarking_mask - signals = extract_complex_sign(self.gt_patch) - fft_latents_signal = set_complex_sign(fft_latents, signals) - fft_latents[mask != 0] = fft_latents_signal[mask != 0] - injected = torch.fft.ifft2(torch.fft.ifftshift(fft_latents, dim=(-1, -2))).real - return injected + # def _inject_seed(self, latents: torch.Tensor) -> torch.Tensor: + # mask = self.watermarking_mask + # injected = latents.clone() + # injected[mask] = self.gt_patch[mask].clone() + # return injected + + # def _inject_signal(self, latents: torch.Tensor) -> torch.Tensor: + # fft_latents = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2)) + # mask = self.watermarking_mask + # signals = extract_complex_sign(self.gt_patch) + # fft_latents_signal = set_complex_sign(fft_latents, signals) + # fft_latents[mask != 0] = fft_latents_signal[mask != 0] + # injected = torch.fft.ifft2(torch.fft.ifftshift(fft_latents, dim=(-1, -2))).real + # return injected def inject_watermark(self, base_latents: torch.Tensor) -> torch.Tensor: base_latents = base_latents.to(self.device, dtype=torch.float32) injection = self.config.w_injection.lower() if "complex" in injection: watermarked = self._inject_complex(base_latents) - elif "seed" in injection: - watermarked = self._inject_seed(base_latents) - elif "signal" in injection: - watermarked = self._inject_signal(base_latents) + # elif "seed" in injection: + # watermarked = self._inject_seed(base_latents) + # elif "signal" in injection: + # watermarked = self._inject_signal(base_latents) else: raise NotImplementedError(f"Unsupported injection mode: {self.config.w_injection}") return watermarked.to(self.config.dtype) diff --git a/watermark/gm/train_GNR.py b/watermark/gm/train_GNR.py deleted file mode 100644 index 53903fc..0000000 --- a/watermark/gm/train_GNR.py +++ /dev/null @@ -1,423 +0,0 @@ -# modify from the official implementation of GaussMarker. 
To check it, see https://github.com/SunnierLee/GaussMarker -# run python train_GNR.py --train_steps 50000 --r 180 --s_min 1.0 --s_max 1.2 --fp 0.35 --neg_p 0.5 --model_nf 128 --batch_size 32 --num_workers 16 --w_info_path w1_256.pth -# After training, the model weight will be saved as `./ckpts/model_final.pth` - -import os -import argparse -import logging -from tqdm import tqdm -import datetime -import numpy as np -import torch -from torch import nn -import torch.nn.functional as F -from torch.utils.data import Dataset, DataLoader, IterableDataset, TensorDataset -from torchvision import transforms -from torchvision.utils import save_image - -from scipy.stats import norm,truncnorm -from functools import reduce -from scipy.special import betainc -from Crypto.Cipher import ChaCha20 -from Crypto.Random import get_random_bytes - -class DoubleConv(nn.Module): - """(convolution => [BN] => ReLU) * 2""" - - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(mid_channels), - nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) - ) - - def forward(self, x): - return self.double_conv(x) - - -class Down(nn.Module): - """Downscaling with maxpool then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool2d(2), - DoubleConv(in_channels, out_channels) - ) - - def forward(self, x): - return self.maxpool_conv(x) - - -class Up(nn.Module): - """Upscaling then double conv""" - - def __init__(self, in_channels, out_channels, bilinear=True): - super().__init__() - - # if bilinear, use the normal convolutions to reduce the number of channels - if bilinear: - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) - else: - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2): - x1 = self.up(x1) - # input is CHW - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, - diffY // 2, diffY - diffY // 2]) - # if you have padding issues, see - # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a - # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - - -class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): - super(OutConv, self).__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - return self.conv(x) - -class UNet(nn.Module): - def __init__(self, n_channels, n_classes, nf=64, bilinear=False): - super(UNet, self).__init__() - self.n_channels = n_channels - self.n_classes = n_classes - self.bilinear = bilinear - - self.inc = (DoubleConv(n_channels, nf)) - self.down1 = (Down(nf, nf*2)) - self.down2 = (Down(nf*2, nf*4)) - self.down3 = (Down(nf*4, nf*8)) - factor = 2 if bilinear else 1 - # self.down4 = (Down(512, 1024 // factor)) - # self.up1 = (Up(1024, 512 // factor, bilinear)) - self.up2 = (Up(nf*8, nf*4 // factor, 
bilinear)) - self.up3 = (Up(nf*4, nf*2 // factor, bilinear)) - self.up4 = (Up(nf*2, nf, bilinear)) - self.outc = (OutConv(nf, n_classes)) - - def forward(self, x): - x1 = self.inc(x) - x2 = self.down1(x1) - x3 = self.down2(x2) - x4 = self.down3(x3) - # x5 = self.down4(x4) - # print(x5.shape) - # x = self.up1(x5, x4) - # print(x4.shape) - x = self.up2(x4, x3) - x = self.up3(x, x2) - x = self.up4(x, x1) - logits = self.outc(x) - return logits - - def use_checkpointing(self): - self.inc = torch.utils.checkpoint(self.inc) - self.down1 = torch.utils.checkpoint(self.down1) - self.down2 = torch.utils.checkpoint(self.down2) - self.down3 = torch.utils.checkpoint(self.down3) - self.down4 = torch.utils.checkpoint(self.down4) - self.up1 = torch.utils.checkpoint(self.up1) - self.up2 = torch.utils.checkpoint(self.up2) - self.up3 = torch.utils.checkpoint(self.up3) - self.up4 = torch.utils.checkpoint(self.up4) - self.outc = torch.utils.checkpoint(self.outc) - -class Gaussian_Shading_chacha: - def __init__(self, ch_factor, w_factor, h_factor, fpr, user_number, watermark=None, key=None, nonce=None, m=None): - self.ch = ch_factor - self.w = w_factor - self.h = h_factor - self.nonce = nonce - self.key = key - self.watermark = watermark - self.m = m - self.latentlength = 4 * 64 * 64 - self.marklength = self.latentlength//(self.ch * self.w * self.h) - - self.threshold = 1 if self.h == 1 and self.w == 1 and self.ch == 1 else self.ch * self.w * self.h // 2 - self.tp_onebit_count = 0 - self.tp_bits_count = 0 - self.tau_onebit = None - self.tau_bits = None - - for i in range(self.marklength): - fpr_onebit = betainc(i+1, self.marklength-i, 0.5) - fpr_bits = betainc(i+1, self.marklength-i, 0.5) * user_number - if fpr_onebit <= fpr and self.tau_onebit is None: - self.tau_onebit = i / self.marklength - if fpr_bits <= fpr and self.tau_bits is None: - self.tau_bits = i / self.marklength - - def truncSampling(self, message): - z = np.zeros(self.latentlength) - denominator = 2.0 - ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)] - for i in range(self.latentlength): - dec_mes = reduce(lambda a, b: 2 * a + b, message[i : i + 1]) - dec_mes = int(dec_mes) - z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1]) - z = torch.from_numpy(z).reshape(1, 4, 64, 64).half() - return z - - def create_watermark_and_return_w(self): - if self.watermark is None: - self.watermark = torch.randint(0, 2, [1, 4 // self.ch, 64 // self.w, 64 // self.h]) - sd = self.watermark.repeat(1,self.ch,self.w,self.h) - m = self.stream_key_encrypt(sd.flatten().numpy()) - self.m = torch.from_numpy(m).reshape(1, 4, 64, 64) - w = self.truncSampling(self.m) - return w - - # def create_watermark_and_return_w_sd(self): - # self.watermark = torch.randint(0, 2, [1, 4 // self.ch, 64 // self.hw, 64 // self.hw]) - # sd = self.watermark.repeat(1,self.ch,self.hw,self.hw) - # m = self.stream_key_encrypt(sd.flatten().numpy()) - # w = self.truncSampling(m) - # return w, sd - - def create_watermark_and_return_w_m(self): - if self.watermark is None: - self.watermark = torch.randint(0, 2, [1, 4 // self.ch, 64 // self.w, 64 // self.h]) - sd = self.watermark.repeat(1, self.ch, self.w, self.h) - self.m = self.stream_key_encrypt(sd.flatten().numpy()) - w = self.truncSampling(self.m) - return w, torch.from_numpy(self.m).reshape(1, 4, 64, 64) - - def stream_key_encrypt(self, sd): - if self.key is None or self.nonce is None: - self.key = get_random_bytes(32) - self.nonce = get_random_bytes(12) - cipher = ChaCha20.new(key=self.key, nonce=self.nonce) - m_byte = 
cipher.encrypt(np.packbits(sd).tobytes()) - m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8)) - return m_bit - - def stream_key_decrypt(self, reversed_m): - cipher = ChaCha20.new(key=self.key, nonce=self.nonce) - sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes()) - sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8)) - sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8) - return sd_tensor - - # def stream_key_encrypt(self, sd): - # return sd - - # def stream_key_decrypt(self, reversed_m): - # return torch.from_numpy(reversed_m).reshape(1, 4, 64, 64).to(torch.uint8) - - def diffusion_inverse(self,watermark_r): - ch_stride = 4 // self.ch - w_stride = 64 // self.w - h_stride = 64 // self.h - ch_list = [ch_stride] * self.ch - w_list = [w_stride] * self.w - h_list = [h_stride] * self.h - split_dim1 = torch.cat(torch.split(watermark_r, tuple(ch_list), dim=1), dim=0) - split_dim2 = torch.cat(torch.split(split_dim1, tuple(w_list), dim=2), dim=0) - split_dim3 = torch.cat(torch.split(split_dim2, tuple(h_list), dim=3), dim=0) - vote = torch.sum(split_dim3, dim=0).clone() - vote[vote <= self.threshold] = 0 - vote[vote > self.threshold] = 1 - return vote - - def pred_m_from_latent(self, reversed_w): - reversed_m = (reversed_w > 0).int() - return reversed_m - - def pred_w_from_latent(self, reversed_w): - reversed_m = (reversed_w > 0).int() - reversed_sd = self.stream_key_decrypt(reversed_m.flatten().cpu().numpy()) - reversed_watermark = self.diffusion_inverse(reversed_sd) - return reversed_watermark - - def pred_w_from_m(self, reversed_m): - reversed_sd = self.stream_key_decrypt(reversed_m.flatten().cpu().numpy()) - reversed_watermark = self.diffusion_inverse(reversed_sd) - return reversed_watermark - -def flip_tensor(tensor, flip_prob): - random_tensor = torch.rand(tensor.size()) - flipped_tensor = tensor.clone() - flipped_tensor[random_tensor < flip_prob] = 1 - flipped_tensor[random_tensor < flip_prob] - return flipped_tensor - -def Affine_random(latent, r, t, s_min, s_max, sh): - config = dict(degrees=(-r, r), translate=(t, t), scale_ranges=(s_min, s_max), shears=(-sh, sh), img_size=latent.shape[-2:]) - r, (tx, ty), s, (shx, shy) = transforms.RandomAffine.get_params(**config) - - b, c, w, h = latent.shape - new_latent = transforms.functional.affine(latent.view(b*c, 1, w, h), angle=r, translate=(tx, ty), scale=s, shear=(shx, shy), fill=999999) - new_latent = new_latent.view(b, c, w, h) - - mask = (new_latent[:, :1, ...] < 999998).float() - new_latent = new_latent * mask + torch.randint_like(new_latent, low=0, high=2) * (1-mask) - - return new_latent, (r, tx, ty, s) - -class LatentDataset_m(IterableDataset): - def __init__(self, watermark, args): - super(LatentDataset_m, self).__init__() - self.watermark = watermark - self.args = args - if self.args.num_watermarks > 1: - t_m = torch.from_numpy(self.watermark.m).reshape(1, 4, 64, 64) - o_m = torch.randint(low=0, high=2, size=(self.args.num_watermarks-1, 4, 64, 64)) - self.m = torch.cat([t_m, o_m]) - else: - self.m = torch.from_numpy(self.watermark.m).reshape(1, 4, 64, 64) - self.args.neg_p = 1 / (1 + self.args.num_watermarks) - - def __iter__(self): - while True: - random_index = torch.randint(0, self.args.num_watermarks, (1,)).item() - latents_m = self.m[random_index:random_index+1] - false_latents_m = torch.randint_like(latents_m, low=0, high=2) - # latents_m = latents_m[:, :1, ...] - # false_latents_m = false_latents_m[:, :1, ...] 
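The detection path of this removed trainer depends on ChaCha20 being a synchronous stream cipher: stream_key_decrypt applied with the same key and nonce exactly inverts stream_key_encrypt, so diffusion_inverse can majority-vote clean repeats of the watermark bits. A self-contained round-trip check of that property, using the same pycryptodome calls as the deleted code:

```python
import numpy as np
from Crypto.Cipher import ChaCha20
from Crypto.Random import get_random_bytes

key, nonce = get_random_bytes(32), get_random_bytes(12)
bits = np.random.randint(0, 2, size=4 * 64 * 64).astype(np.uint8)

# Encrypt the repeated watermark bits (as in stream_key_encrypt above).
enc = ChaCha20.new(key=key, nonce=nonce).encrypt(np.packbits(bits).tobytes())
enc_bits = np.unpackbits(np.frombuffer(enc, dtype=np.uint8))

# Decrypt the bits recovered from an inverted latent (as in stream_key_decrypt).
dec = ChaCha20.new(key=key, nonce=nonce).decrypt(np.packbits(enc_bits).tobytes())
dec_bits = np.unpackbits(np.frombuffer(dec, dtype=np.uint8))

assert np.array_equal(dec_bits, bits)  # exact round trip, so voting sees clean repeats
```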
- if np.random.rand() > self.args.neg_p: - aug_latents_m, params = Affine_random(latents_m.float(), self.args.r, self.args.t, self.args.s_min, self.args.s_max, self.args.sh) - aug_latents_m = flip_tensor(aug_latents_m, self.args.fp) - yield aug_latents_m.squeeze(0).float(), latents_m.squeeze(0).float() - else: - aug_false_latents_m, params = Affine_random(false_latents_m.float(), self.args.r, self.args.t, self.args.s_min, self.args.s_max, self.args.sh) - aug_false_latents_m = flip_tensor(aug_false_latents_m, self.args.fp) - yield aug_false_latents_m.squeeze(0).float(), aug_false_latents_m.squeeze(0).float() - - -def set_logger(gfile_stream): - handler = logging.StreamHandler(gfile_stream) - formatter = logging.Formatter( - '%(levelname)s - %(filename)s - %(asctime)s - %(message)s') - handler.setFormatter(formatter) - logger = logging.getLogger() - logger.addHandler(handler) - logger.setLevel('INFO') - -def main(args): - os.makedirs(args.output_path, exist_ok=True) - gfile_stream = open(os.path.join(args.output_path, 'log.txt'), 'a') - set_logger(gfile_stream) - logging.info(args) - - num_steps = args.train_steps - bs = args.batch_size - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model = UNet(4, 4, nf=args.model_nf).to(device) - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) - n_params = sum([np.prod(p.size()) for p in model_parameters]) - print('Number of trainable parameters in model: %d' % n_params) - logging.info('Number of trainable parameters in model: %d' % n_params) - - criterion = torch.nn.BCEWithLogitsLoss() - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - - if os.path.exists(args.w_info_path): - w_info = torch.load(args.w_info_path) - watermark = Gaussian_Shading_chacha(args.channel_copy, args.w_copy, args.h_copy, args.fpr, args.user_number, watermark=w_info["w"], m=w_info["m"], key=w_info["key"], nonce=w_info["nonce"]) - else: - watermark = Gaussian_Shading_chacha(args.channel_copy, args.w_copy, args.h_copy, args.fpr, args.user_number) - _ = watermark.create_watermark_and_return_w_m() - torch.save({"w": watermark.watermark, "m": watermark.m, "key": watermark.key, "nonce": watermark.nonce}, args.w_info_path) - - if args.sample_type == "m": - dataset = LatentDataset_m(watermark, args) - else: - raise NotImplementedError - - data_loader = DataLoader(dataset, batch_size=bs, num_workers=args.num_workers) - - for i, batch in tqdm(enumerate(data_loader)): - x, y = batch - # print(x[0, 0]) - x = x.to(device) - y = y.to(device).float() - - pred = model(x) - loss = criterion(pred, y) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # if i % 2000 == 0: - # torch.save(model.state_dict(), os.path.join(args.output_path, "model_{}.pth".format(i))) - if i % 2000 == 0: - # torch.save(model.state_dict(), os.path.join(args.output_path, "model_{}.pth".format(i))) - pred = F.sigmoid(pred) - save_imgs = torch.cat([x[:, :1, ...].unsqueeze(0), pred[:, :1, ...].unsqueeze(0), y[:, :1, ...].unsqueeze(0)]).permute(1, 0, 2, 3, 4).contiguous() - save_imgs = save_imgs.view(-1, save_imgs.shape[2], save_imgs.shape[3], save_imgs.shape[4])[:64] - save_image(save_imgs, os.path.join(args.output_path, "sample_{}.png".format(i)), nrow=6) - if i % 200 == 0: - print(loss.item()) - torch.save(model.state_dict(), os.path.join(args.output_path, "checkpoint.pth".format(i))) - logging.info("Iter {} Loss {}".format(i, loss.item())) - - if i > num_steps: - break - - torch.save(model.state_dict(), os.path.join(args.output_path, 
"model_final.pth")) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Gaussian Shading') - parser.add_argument('--num', default=1000, type=int) - parser.add_argument('--image_length', default=512, type=int) - parser.add_argument('--guidance_scale', default=7.5, type=float) - parser.add_argument('--num_inference_steps', default=50, type=int) - parser.add_argument('--num_inversion_steps', default=None, type=int) - parser.add_argument('--gen_seed', default=0, type=int) - parser.add_argument('--channel_copy', default=1, type=int) - parser.add_argument('--w_copy', default=8, type=int) - parser.add_argument('--h_copy', default=8, type=int) - parser.add_argument('--user_number', default=1000000, type=int) - parser.add_argument('--fpr', default=0.000001, type=float) - parser.add_argument('--output_path', default='./ckpts') - parser.add_argument('--chacha', action='store_true', help='chacha20 for cipher') - parser.add_argument('--reference_model', default=None) - parser.add_argument('--reference_model_pretrain', default=None) - parser.add_argument('--dataset_path', default='Gustavosta/Stable-Diffusion-Prompts') - parser.add_argument('--model_path', default='stabilityai/stable-diffusion-2-1-base') - parser.add_argument('--w_info_path', default='./w1.pth') - - parser.add_argument('--train_steps', type=int, default=10000) - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--lr', type=float, default=1e-4) - parser.add_argument('--sample_type', default="m") - parser.add_argument('--r', type=float, default=8) - parser.add_argument('--t', type=float, default=0) - parser.add_argument('--s_min', type=float, default=0.5) - parser.add_argument('--s_max', type=float, default=2.0) - parser.add_argument('--sh', type=float, default=0) - parser.add_argument('--fp', type=float, default=0.00) - parser.add_argument('--neg_p', type=float, default=0.5) - parser.add_argument('--num_workers', type=int, default=8) - parser.add_argument('--num_watermarks', type=int, default=1) - - parser.add_argument('--model_nf', type=int, default=64) - parser.add_argument('--exp_description', '-ed', default="") - - args = parser.parse_args() - - # multiprocessing.set_start_method("spawn") - nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') - # args.output_path = args.output_path + 'r{}_t{}_s_{}_{}_sh{}_fp{}_np{}_{}_{}'.format(args.r, args.t, args.s_min, args.s_max, args.sh, args.fp, args.neg_p, args.exp_description, nowTime) - args.output_path = args.output_path + '_' + args.exp_description - - main(args) diff --git a/watermark/robin/robin.py b/watermark/robin/robin.py index f41cb47..52c01c5 100644 --- a/watermark/robin/robin.py +++ b/watermark/robin/robin.py @@ -13,7 +13,7 @@ from visualize.data_for_visualization import DataForVisualization from evaluation.dataset import StableDiffusionPromptsDataset from utils.media_utils import get_random_latents -from .watermark_generator import OptimizedDataset, get_watermarking_mask, get_watermarking_pattern, inject_watermark, optimizer_wm_prompt, ROBINWatermarkedImageGeneration +from .watermark_generator import get_watermarking_mask, get_watermarking_pattern, inject_watermark, ROBINWatermarkedImageGeneration # OptimizedDataset, optimizer_wm_prompt from detection.robin.robin_detection import ROBINDetector class ROBINConfig(BaseConfig): @@ -94,26 +94,26 @@ def build_generation_params(self, **kwargs) -> Dict: return generation_params - def generate_clean_images(self, dataset: StableDiffusionPromptsDataset, **kwargs) -> 
List[Image.Image]:
-        """Generate clean images for optimization."""
-        generation_params = self.build_generation_params(**kwargs, guidance_scale=self.config.data_guidance_scale)
-
-        clean_images = []
-        for i, prompt in enumerate(dataset):
-            formatted_img_filename = f"ori-lg{generation_params['guidance_scale']}-{i}.jpg"
-            if os.path.exists(os.path.join(self.config.output_img_dir, formatted_img_filename)):
-                clean_images.append(Image.open(os.path.join(self.config.output_img_dir, formatted_img_filename)))
-            else:
-                no_watermarked_image = self.config.pipe(
-                    prompt,
-                    **generation_params,
-                ).images[0]
-                clean_images.append(no_watermarked_image)
+    # def generate_clean_images(self, dataset: StableDiffusionPromptsDataset, **kwargs) -> List[Image.Image]:
+    #     """Generate clean images for optimization."""
+    #     generation_params = self.build_generation_params(**kwargs, guidance_scale=self.config.data_guidance_scale)
+
+    #     clean_images = []
+    #     for i, prompt in enumerate(dataset):
+    #         formatted_img_filename = f"ori-lg{generation_params['guidance_scale']}-{i}.jpg"
+    #         if os.path.exists(os.path.join(self.config.output_img_dir, formatted_img_filename)):
+    #             clean_images.append(Image.open(os.path.join(self.config.output_img_dir, formatted_img_filename)))
+    #         else:
+    #             no_watermarked_image = self.config.pipe(
+    #                 prompt,
+    #                 **generation_params,
+    #             ).images[0]
+    #             clean_images.append(no_watermarked_image)
 
-            os.makedirs(self.config.output_img_dir, exist_ok=True)
-            no_watermarked_image.save(os.path.join(self.config.output_img_dir, f"ori-lg{generation_params['guidance_scale']}-{i}.jpg"))
+    #         os.makedirs(self.config.output_img_dir, exist_ok=True)
+    #         no_watermarked_image.save(os.path.join(self.config.output_img_dir, f"ori-lg{generation_params['guidance_scale']}-{i}.jpg"))
 
-        return clean_images
+    #     return clean_images
 
     def build_watermarking_args(self) -> types.SimpleNamespace:
         """Build watermarking arguments from config."""
@@ -153,55 +153,81 @@ def optimize_watermark(self, dataset: StableDiffusionPromptsDataset, watermarkin
         # Build hyperparameters
         hyperparameters = self.build_hyperparameters()
-        checkpoint_path=hf_hub_download(repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin", filename=f"optimized_wm5-30_embedding-step-{hyperparameters['max_train_steps']}.pt", cache_dir=self.config.hf_dir)
+        filename = f"optimized_wm5-30_embedding-step-{hyperparameters['max_train_steps']}.pt"
 
-        # if os.path.exists(checkpoint_path):
-        if (not self.config.is_training_from_scratch):
-            if not os.path.exists(checkpoint_path):
-                os.makedirs(self.config.ckpt_dir, exist_ok=True)
-                from huggingface_hub import snapshot_download
-                snapshot_download(
-                    repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
-                    local_dir=self.config.ckpt_dir,
-                    repo_type="model",
-                    local_dir_use_symlinks=False,
-                    endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
-                )
-
-            print(f"Loading checkpoint from {checkpoint_path}")
-            checkpoint = torch.load(checkpoint_path, map_location=self.config.device)
-            optimized_watermark = checkpoint['opt_wm'].to(self.config.device)
-            optimized_watermarking_signal = checkpoint['opt_acond'].to(self.config.device)
+        # Check if file already exists locally before downloading
+        base_dir = os.path.dirname(os.path.abspath(__file__))
+        checkpoint_path = None
 
-            return watermarking_mask, optimized_watermark, optimized_watermarking_signal
-        else:
-            print(f"Start training from scratch")
-            # Generate clean images
-            clean_images = self.generate_clean_images(dataset)
-            # Create training dataset
-            train_dataset = OptimizedDataset(
-                data_root=self.config.output_img_dir,
-                custom_dataset=dataset,
-                size=512,
-                repeats=10,
-                interpolation="bicubic",
+        # Check multiple potential local paths
+        potential_paths = [
+            os.path.join(base_dir, self.config.hf_dir, filename) if self.config.hf_dir else None,
+            os.path.join(self.config.hf_dir, filename) if self.config.hf_dir else None,
+            os.path.join(self.config.ckpt_dir, filename),
+        ]
+
+        for path in potential_paths:
+            if path and os.path.exists(path):
+                checkpoint_path = path
+                print(f"Using existing ROBIN checkpoint: {checkpoint_path}")
+                break
+
+        # If not found locally, download from HuggingFace
+        if checkpoint_path is None:
+            checkpoint_path = hf_hub_download(
+                repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
+                filename=filename,
+                cache_dir=self.config.hf_dir
             )
+            print(f"Downloaded ROBIN checkpoint from Huggingface Hub: {checkpoint_path}")
+
+        # if os.path.exists(checkpoint_path):
+        # if (not self.config.is_training_from_scratch):
+        if not os.path.exists(checkpoint_path):
+            os.makedirs(self.config.ckpt_dir, exist_ok=True)
+            from huggingface_hub import snapshot_download
+            snapshot_download(
+                repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
+                local_dir=self.config.ckpt_dir,
+                repo_type="model",
+                local_dir_use_symlinks=False,
+                endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
+            )
+
+        print(f"Loading checkpoint from {checkpoint_path}")
+        checkpoint = torch.load(checkpoint_path, map_location=self.config.device)
+        optimized_watermark = checkpoint['opt_wm'].to(self.config.device)
+        optimized_watermarking_signal = checkpoint['opt_acond'].to(self.config.device)
+
+        return watermarking_mask, optimized_watermark, optimized_watermarking_signal
+        # else:
+        #     print(f"Start training from scratch")
+        #     # Generate clean images
+        #     clean_images = self.generate_clean_images(dataset)
+        #     # Create training dataset
+        #     train_dataset = OptimizedDataset(
+        #         data_root=self.config.output_img_dir,
+        #         custom_dataset=dataset,
+        #         size=512,
+        #         repeats=10,
+        #         interpolation="bicubic",
+        #     )
 
-            train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config.train_batch_size, shuffle=True)
+        #     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config.train_batch_size, shuffle=True)
 
-            opt_watermark = get_watermarking_pattern(pipe=self.config.pipe, args=watermarking_args, device=self.config.device)
+        #     opt_watermark = get_watermarking_pattern(pipe=self.config.pipe, args=watermarking_args, device=self.config.device)
 
-            optimized_watermark, optimized_watermarking_signal = optimizer_wm_prompt(
-                pipe=self.config.pipe,
-                dataloader=train_dataloader,
-                hyperparameters=hyperparameters,
-                mask=watermarking_mask,
-                opt_wm=opt_watermark,
-                save_path=self.config.ckpt_dir,
-                args=watermarking_args,
-            )
+        #     optimized_watermark, optimized_watermarking_signal = optimizer_wm_prompt(
+        #         pipe=self.config.pipe,
+        #         dataloader=train_dataloader,
+        #         hyperparameters=hyperparameters,
+        #         mask=watermarking_mask,
+        #         opt_wm=opt_watermark,
+        #         save_path=self.config.ckpt_dir,
+        #         args=watermarking_args,
+        #     )
 
-            return watermarking_mask, optimized_watermark, optimized_watermarking_signal
+        #     return watermarking_mask, optimized_watermark, optimized_watermarking_signal
 
     def initialize_detector(self, watermarking_mask, optimized_watermark) -> ROBINDetector:
         """Initialize the ROBIN detector."""
diff --git a/watermark/robin/watermark_generator.py b/watermark/robin/watermark_generator.py
index 448e602..058941a 100644
--- a/watermark/robin/watermark_generator.py
+++ b/watermark/robin/watermark_generator.py
@@ -35,69 +35,69 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
 
-class OptimizedDataset(Dataset):
-    def __init__(
-        self,
-        data_root,
-        custom_dataset: BaseDataset,
-        size=512,
-        repeats=10,
-        interpolation="bicubic",
-        set="train",
-        center_crop=False,
-    ):
-
-        self.data_root = data_root
-        self.size = size
-        self.center_crop = center_crop
-
-        file_list = os.listdir(self.data_root)
-        file_list.sort(key=lambda x: int(x.split('-')[-1].split('.')[0])) # ori-lg7.5-xx.jpg
-        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in file_list]
-        self.dataset = custom_dataset
+# class OptimizedDataset(Dataset):
+#     def __init__(
+#         self,
+#         data_root,
+#         custom_dataset: BaseDataset,
+#         size=512,
+#         repeats=10,
+#         interpolation="bicubic",
+#         set="train",
+#         center_crop=False,
+#     ):
+
+#         self.data_root = data_root
+#         self.size = size
+#         self.center_crop = center_crop
+
+#         file_list = os.listdir(self.data_root)
+#         file_list.sort(key=lambda x: int(x.split('-')[-1].split('.')[0])) # ori-lg7.5-xx.jpg
+#         self.image_paths = [os.path.join(self.data_root, file_path) for file_path in file_list]
+#         self.dataset = custom_dataset
 
-        self.num_images = len(self.image_paths)
-        self._length = self.num_images
+#         self.num_images = len(self.image_paths)
+#         self._length = self.num_images
 
-        if set == "train":
-            self._length = self.num_images * repeats
+#         if set == "train":
+#             self._length = self.num_images * repeats
 
-        self.interpolation = {
-            "bilinear": Image.BILINEAR,
-            "bicubic": Image.BICUBIC,
-            "lanczos": Image.LANCZOS,
-        }[interpolation]
+#         self.interpolation = {
+#             "bilinear": Image.BILINEAR,
+#             "bicubic": Image.BICUBIC,
+#             "lanczos": Image.LANCZOS,
+#         }[interpolation]
 
-    def __len__(self):
-        return self._length
+#     def __len__(self):
+#         return self._length
 
-    def __getitem__(self, i):
-        example = {}
-        image = Image.open(self.image_paths[i % self.num_images])
+#     def __getitem__(self, i):
+#         example = {}
+#         image = Image.open(self.image_paths[i % self.num_images])
 
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
+#         if not image.mode == "RGB":
+#             image = image.convert("RGB")
 
-        text = self.dataset[i % self.num_images] # __getitem__ of BaseDataset: return prompt[idx]
-        example["prompt"] = text
+#         text = self.dataset[i % self.num_images] # __getitem__ of BaseDataset: return prompt[idx]
+#         example["prompt"] = text
 
-        # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)
+#         # default to score-sde preprocessing
+#         img = np.array(image).astype(np.uint8)
 
-        if self.center_crop:
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = (
-                img.shape[0],
-                img.shape[1],
-            )
-            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+#         if self.center_crop:
+#             crop = min(img.shape[0], img.shape[1])
+#             h, w, = (
+#                 img.shape[0],
+#                 img.shape[1],
+#             )
+#             img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
 
-        image = Image.fromarray(img)
-        image = image.resize((self.size, self.size), resample=self.interpolation)
+#         image = Image.fromarray(img)
+#         image = image.resize((self.size, self.size), resample=self.interpolation)
 
-        example["pixel_values"] = pil_to_torch(image, normalize=False) # scale to [0, 1]
+#         example["pixel_values"] = pil_to_torch(image, normalize=False) # scale to [0, 1]
 
-        return example
+#         return example
 
 
 def circle_mask(size=64, r_max=10, r_min=0, x_offset=0, y_offset=0):
@@ -199,234 +199,234 @@ def inject_watermark(init_latents_w, watermarking_mask, gt_patch, args):
     return init_latents_w
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-def to_ring(latent_fft, args):
-    # Calculate mean value for each ring
-    num_rings = args.w_up_radius - args.w_low_radius
-    r_max = args.w_up_radius
-    for i in range(num_rings):
-        # ring_mask = mask[..., (radii[i * 2] <= distances) & (distances < radii[i * 2 + 1])]
-        ring_mask = circle_mask(latent_fft.shape[-1], r_max=r_max, r_min=r_max-1)
-        ring_mean = latent_fft[:, args.w_channel,ring_mask].real.mean().item()
-        # print(f'ring mean: {ring_mean}')
-        latent_fft[:, args.w_channel,ring_mask] = ring_mean
-        r_max = r_max - 1
-
-    return latent_fft
-
-def optimizer_wm_prompt(pipe: StableDiffusionPipeline,
-                        dataloader: OptimizedDataset,
-                        hyperparameters: dict,
-                        mask: torch.Tensor,
-                        opt_wm: torch.Tensor,
-                        save_path: str,
-                        args: dict,
-                        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-                        eta: float = 0.0,) -> tuple[torch.Tensor, torch.Tensor]:
-    train_batch_size = hyperparameters["train_batch_size"]
-    gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
-    learning_rate = hyperparameters["learning_rate"]
-    max_train_steps = hyperparameters["max_train_steps"]
-    output_dir = hyperparameters["output_dir"]
-    gradient_checkpointing = hyperparameters["gradient_checkpointing"]
-    original_guidance_scale = hyperparameters["guidance_scale"]
-    optimized_guidance_scale = hyperparameters["optimized_guidance_scale"]
-
-    # Check if checkpoint exists
-    checkpoint_path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{max_train_steps}.pt")
-    # checkpoint_path = "/workspace/panleyi/gs/ROBIN/ckpts/optimized_wm5-30_embedding-step-2000.pt"
-    if os.path.exists(checkpoint_path):
-        logger.info(f"Loading checkpoint from {checkpoint_path}")
-        checkpoint = torch.load(checkpoint_path)
-        opt_wm = checkpoint['opt_wm'].to(pipe.device)
-        opt_wm_embedding = checkpoint['opt_acond'].to(pipe.device)
-        return opt_wm, opt_wm_embedding
-
-    text_encoder: CLIPTextModel = pipe.text_encoder
-    unet: UNet2DConditionModel = pipe.unet
-    vae: AutoencoderKL = pipe.vae
-    scheduler: DPMSolverMultistepScheduler = pipe.scheduler
-
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
-    freeze_params(text_encoder.parameters())
-
-    accelerator = Accelerator(
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        mixed_precision=hyperparameters["mixed_precision"]
-    )
-
-    if gradient_checkpointing:
-        text_encoder.gradient_checkpointing_enable()
-        unet.enable_gradient_checkpointing()
-
-    if hyperparameters["scale_lr"]:
-        learning_rate = (
-            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
-        )
-
-    tester_prompt = '' # assume at the detection time, the original prompt is unknown
-    # null text, text_embedding.dtype = torch.float16
-    do_classifier_free_guidance = False # guidance_scale = 1.0
-    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
-        prompt=tester_prompt,
-        device=pipe.device,
-        do_classifier_free_guidance=do_classifier_free_guidance,
-        num_images_per_prompt=1,
-    )
+# def freeze_params(params):
+#     for param in params:
+#         param.requires_grad = False
+
+# def to_ring(latent_fft, args):
+#     # Calculate mean value for each ring
+#     num_rings = args.w_up_radius - args.w_low_radius
+#     r_max = args.w_up_radius
+#     for i in range(num_rings):
+#         # ring_mask = mask[..., (radii[i * 2] <= distances) & (distances < radii[i * 2 + 1])]
+#         ring_mask = circle_mask(latent_fft.shape[-1], r_max=r_max, r_min=r_max-1)
+#         ring_mean = latent_fft[:, args.w_channel,ring_mask].real.mean().item()
+#         # print(f'ring mean: {ring_mean}')
+#         latent_fft[:, args.w_channel,ring_mask] = ring_mean
+#         r_max = r_max - 1
+
+#     return latent_fft
+
+# def optimizer_wm_prompt(pipe: StableDiffusionPipeline,
+#                         dataloader: OptimizedDataset,
+#                         hyperparameters: dict,
+#                         mask: torch.Tensor,
+#                         opt_wm: torch.Tensor,
+#                         save_path: str,
+#                         args: dict,
+#                         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+#                         eta: float = 0.0,) -> tuple[torch.Tensor, torch.Tensor]:
+#     train_batch_size = hyperparameters["train_batch_size"]
+#     gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
+#     learning_rate = hyperparameters["learning_rate"]
+#     max_train_steps = hyperparameters["max_train_steps"]
+#     output_dir = hyperparameters["output_dir"]
+#     gradient_checkpointing = hyperparameters["gradient_checkpointing"]
+#     original_guidance_scale = hyperparameters["guidance_scale"]
+#     optimized_guidance_scale = hyperparameters["optimized_guidance_scale"]
+
+#     # Check if checkpoint exists
+#     checkpoint_path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{max_train_steps}.pt")
+#     # checkpoint_path = "/workspace/panleyi/gs/ROBIN/ckpts/optimized_wm5-30_embedding-step-2000.pt"
+#     if os.path.exists(checkpoint_path):
+#         logger.info(f"Loading checkpoint from {checkpoint_path}")
+#         checkpoint = torch.load(checkpoint_path)
+#         opt_wm = checkpoint['opt_wm'].to(pipe.device)
+#         opt_wm_embedding = checkpoint['opt_acond'].to(pipe.device)
+#         return opt_wm, opt_wm_embedding
+
+#     text_encoder: CLIPTextModel = pipe.text_encoder
+#     unet: UNet2DConditionModel = pipe.unet
+#     vae: AutoencoderKL = pipe.vae
+#     scheduler: DPMSolverMultistepScheduler = pipe.scheduler
+
+#     freeze_params(vae.parameters())
+#     freeze_params(unet.parameters())
+#     freeze_params(text_encoder.parameters())
+
+#     accelerator = Accelerator(
+#         gradient_accumulation_steps=gradient_accumulation_steps,
+#         mixed_precision=hyperparameters["mixed_precision"]
+#     )
+
+#     if gradient_checkpointing:
+#         text_encoder.gradient_checkpointing_enable()
+#         unet.enable_gradient_checkpointing()
+
+#     if hyperparameters["scale_lr"]:
+#         learning_rate = (
+#             learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
+#         )
+
+#     tester_prompt = '' # assume at the detection time, the original prompt is unknown
+#     # null text, text_embedding.dtype = torch.float16
+#     do_classifier_free_guidance = False # guidance_scale = 1.0
+#     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
+#         prompt=tester_prompt,
+#         device=pipe.device,
+#         do_classifier_free_guidance=do_classifier_free_guidance,
+#         num_images_per_prompt=1,
+#     )
 
-    text_embeddings = prompt_embeds
-
-    extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
-
-    unet, text_encoder, dataloader,text_embeddings = accelerator.prepare(
-        unet, text_encoder, dataloader, text_embeddings
-    )
-
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    # Move vae and unet to device
-    vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
-
-    # Keep vae in eval mode as we don't train it
-    vae.eval()
-    # Keep unet in train mode to enable gradient checkpointing
-    unet.train()
-
-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
-    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
-
-    # Train!
-    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(dataloader)}")
-    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {max_train_steps}")
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
-    global_step = 0
-
-    scaler = GradScaler(device=accelerator.device)
-    # pipe.scheduler.set_timesteps(1000) # need for compute the next state
+#     text_embeddings = prompt_embeds
+
+#     extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
+
+#     unet, text_encoder, dataloader,text_embeddings = accelerator.prepare(
+#         unet, text_encoder, dataloader, text_embeddings
+#     )
+
+#     weight_dtype = torch.float32
+#     if accelerator.mixed_precision == "fp16":
+#         weight_dtype = torch.float16
+#     elif accelerator.mixed_precision == "bf16":
+#         weight_dtype = torch.bfloat16
+
+#     # Move vae and unet to device
+#     vae.to(accelerator.device, dtype=weight_dtype)
+#     unet.to(accelerator.device, dtype=weight_dtype)
+
+#     # Keep vae in eval mode as we don't train it
+#     vae.eval()
+#     # Keep unet in train mode to enable gradient checkpointing
+#     unet.train()
+
+#     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+#     num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
+#     num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+#     # Train!
+#     total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
+
+#     logger.info("***** Running training *****")
+#     logger.info(f"  Num examples = {len(dataloader)}")
+#     logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
+#     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+#     logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
+#     logger.info(f"  Total optimization steps = {max_train_steps}")
+#     # Only show the progress bar once on each machine.
+#     progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
+#     progress_bar.set_description("Steps")
+#     global_step = 0
+
+#     scaler = GradScaler(device=accelerator.device)
+#     # pipe.scheduler.set_timesteps(1000) # need for compute the next state
 
-    do_classifier_free_guidance = False # guidance_scale = 1.0
-    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
-        prompt='',
-        device=pipe.device,
-        do_classifier_free_guidance=do_classifier_free_guidance,
-        num_images_per_prompt=1,
-    )
+#     do_classifier_free_guidance = False # guidance_scale = 1.0
+#     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
+#         prompt='',
+#         device=pipe.device,
+#         do_classifier_free_guidance=do_classifier_free_guidance,
+#         num_images_per_prompt=1,
+#     )
 
-    opt_wm_embedding = prompt_embeds
-    null_embedding = opt_wm_embedding.clone()
-    total_time = 0
-    with autocast(device_type=accelerator.device.type):
-        for epoch in range(num_train_epochs):
-            for step, batch in enumerate(dataloader):
-                with accelerator.accumulate(unet):
-                    # Convert images to latent space
-                    gt_tensor = batch["pixel_values"]
-                    image = 2.0 * gt_tensor - 1.0
-                    latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample().detach()
-                    latents = latents * 0.18215
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    ori_timesteps = torch.randint(200, 300, (bsz,), device=latents.device).long() # 35~40steps
-                    timesteps = len(scheduler) - 1 - ori_timesteps
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
-                    opt_wm = opt_wm.to(noisy_latents.device).to(torch.complex64) # add wm to latents
-
-
-                    ### detailed the inject_watermark function for fft.grad
-                    init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(noisy_latents), dim=(-1, -2))
-                    init_latents_w_fft[mask] = opt_wm[mask].clone()
-                    init_latents_w_fft.requires_grad = True
-                    noisy_latents = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real
-                    ### Get the text embedding for conditioning CFG
-                    prompt = batch["prompt"]
-                    do_classifier_free_guidance = False # guidance_scale = 1.0
-                    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
-                        prompt=prompt,
-                        device=pipe.device,
-                        do_classifier_free_guidance=do_classifier_free_guidance,
-                        num_images_per_prompt=1,
-                    )
+#     opt_wm_embedding = prompt_embeds
+#     null_embedding = opt_wm_embedding.clone()
+#     total_time = 0
+#     with autocast(device_type=accelerator.device.type):
+#         for epoch in range(num_train_epochs):
+#             for step, batch in enumerate(dataloader):
+#                 with accelerator.accumulate(unet):
+#                     # Convert images to latent space
+#                     gt_tensor = batch["pixel_values"]
+#                     image = 2.0 * gt_tensor - 1.0
+#                     latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample().detach()
+#                     latents = latents * 0.18215
+#                     # Sample noise that we'll add to the latents
+#                     noise = torch.randn_like(latents)
+#                     bsz = latents.shape[0]
+#                     # Sample a random timestep for each image
+#                     ori_timesteps = torch.randint(200, 300, (bsz,), device=latents.device).long() # 35~40steps
+#                     timesteps = len(scheduler) - 1 - ori_timesteps
+
+#                     # Add noise to the latents according to the noise magnitude at each timestep
+#                     noisy_latents = scheduler.add_noise(latents, noise, timesteps)
+#                     opt_wm = opt_wm.to(noisy_latents.device).to(torch.complex64) # add wm to latents
+
+
+#                     ### detailed the inject_watermark function for fft.grad
+#                     init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(noisy_latents), dim=(-1, -2))
+#                     init_latents_w_fft[mask] = opt_wm[mask].clone()
+#                     init_latents_w_fft.requires_grad = True
+#                     noisy_latents = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real
+#                     ### Get the text embedding for conditioning CFG
+#                     prompt = batch["prompt"]
+#                     do_classifier_free_guidance = False # guidance_scale = 1.0
+#                     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
+#                         prompt=prompt,
+#                         device=pipe.device,
+#                         do_classifier_free_guidance=do_classifier_free_guidance,
+#                         num_images_per_prompt=1,
+#                     )
 
-                    cond_embedding = prompt_embeds
-                    text_embeddings = torch.cat([opt_wm_embedding, cond_embedding, null_embedding])
-                    text_embeddings.requires_grad = True
-
-                    ### Predict the noise residual with CFG
-                    latent_model_input = torch.cat([noisy_latents] * 3)
-                    latent_model_input = scheduler.scale_model_input(latent_model_input, timesteps)
-                    noise_pred = unet(latent_model_input, ori_timesteps, encoder_hidden_states=text_embeddings).sample
-                    noise_pred_wm, noise_pred_text, noise_pred_null = noise_pred.chunk(3)
-                    noise_pred = noise_pred_null + original_guidance_scale * (noise_pred_text - noise_pred_null) + optimized_guidance_scale * (noise_pred_wm - noise_pred_null) # different guidance scale
+#                     cond_embedding = prompt_embeds
+#                     text_embeddings = torch.cat([opt_wm_embedding, cond_embedding, null_embedding])
+#                     text_embeddings.requires_grad = True
+
+#                     ### Predict the noise residual with CFG
+#                     latent_model_input = torch.cat([noisy_latents] * 3)
+#                     latent_model_input = scheduler.scale_model_input(latent_model_input, timesteps)
+#                     noise_pred = unet(latent_model_input, ori_timesteps, encoder_hidden_states=text_embeddings).sample
+#                     noise_pred_wm, noise_pred_text, noise_pred_null = noise_pred.chunk(3)
+#                     noise_pred = noise_pred_null + original_guidance_scale * (noise_pred_text - noise_pred_null) + optimized_guidance_scale * (noise_pred_wm - noise_pred_null) # different guidance scale
 
-                    ### get the predicted x0 tensor
-                    scheduler._init_step_index(timesteps)
-                    x0_latents = scheduler.convert_model_output(model_output=noise_pred, sample=noisy_latents) #predict x0 in one-step
-                    x0_tensor = decode_media_latents(pipe=pipe, latents=x0_latents)
+#                     ### get the predicted x0 tensor
+#                     scheduler._init_step_index(timesteps)
+#                     x0_latents = scheduler.convert_model_output(model_output=noise_pred, sample=noisy_latents) #predict x0 in one-step
+#                     x0_tensor = decode_media_latents(pipe=pipe, latents=x0_latents)
 
-                    loss_noise = F.mse_loss(x0_tensor.float(), gt_tensor.float(), reduction="mean") # pixel alignment
-                    loss_wm = torch.mean(torch.abs(opt_wm[mask].real))
-                    loss_constrain = F.mse_loss(noise_pred_wm.float(), noise_pred_null.float(), reduction="mean") # prompt constraint
-
-                    ### optimize wm pattern and uncond prompt alternately
-                    if (global_step // 500) % 2 == 0:
-                        loss = 10 * loss_noise + loss_constrain - 0.00001 * loss_wm # opt wm pattern
-                        accelerator.backward(loss)
-                        with torch.no_grad():
-                            grads = init_latents_w_fft.grad
-                            init_latents_w_fft = init_latents_w_fft - 1.0 * grads # update wm pattern
-                            init_latents_w_fft = to_ring(init_latents_w_fft, args)
-                            opt_wm = init_latents_w_fft.detach()
-                    else:
-                        loss = 10 * loss_noise + loss_constrain # opt prompt
-                        accelerator.backward(loss)
-                        with torch.no_grad():
-                            grads = text_embeddings.grad
-                            text_embeddings = text_embeddings - 5e-04 * grads
-                            opt_wm_embedding = text_embeddings[0].unsqueeze(0).detach() # update acond embedding
-
-
-                    print(f'global_step: {global_step}, loss_mse: {loss_noise}, loss_wm: {loss_wm}, loss_cons: {loss_constrain},loss: {loss}')
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
-                    progress_bar.update(1)
-                    global_step += 1
-                    if global_step % hyperparameters["save_steps"] == 0:
-                        path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{global_step}.pt")
-                        torch.save({'opt_acond': opt_wm_embedding, 'opt_wm': opt_wm.cpu()}, path)
-
-                logs = {"loss": loss.detach().item()}
-                progress_bar.set_postfix(**logs)
-
-                if global_step >= max_train_steps:
-                    break
-
-    accelerator.wait_for_everyone()
-
-    return opt_wm, opt_wm_embedding
+#                     loss_noise = F.mse_loss(x0_tensor.float(), gt_tensor.float(), reduction="mean") # pixel alignment
+#                     loss_wm = torch.mean(torch.abs(opt_wm[mask].real))
+#                     loss_constrain = F.mse_loss(noise_pred_wm.float(), noise_pred_null.float(), reduction="mean") # prompt constraint
+
+#                     ### optimize wm pattern and uncond prompt alternately
+#                     if (global_step // 500) % 2 == 0:
+#                         loss = 10 * loss_noise + loss_constrain - 0.00001 * loss_wm # opt wm pattern
+#                         accelerator.backward(loss)
+#                         with torch.no_grad():
+#                             grads = init_latents_w_fft.grad
+#                             init_latents_w_fft = init_latents_w_fft - 1.0 * grads # update wm pattern
+#                             init_latents_w_fft = to_ring(init_latents_w_fft, args)
+#                             opt_wm = init_latents_w_fft.detach()
+#                     else:
+#                         loss = 10 * loss_noise + loss_constrain # opt prompt
+#                         accelerator.backward(loss)
+#                         with torch.no_grad():
+#                             grads = text_embeddings.grad
+#                             text_embeddings = text_embeddings - 5e-04 * grads
+#                             opt_wm_embedding = text_embeddings[0].unsqueeze(0).detach() # update acond embedding
+
+
+#                     print(f'global_step: {global_step}, loss_mse: {loss_noise}, loss_wm: {loss_wm}, loss_cons: {loss_constrain},loss: {loss}')
+
+#                 # Checks if the accelerator has performed an optimization step behind the scenes
+#                 if accelerator.sync_gradients:
+#                     progress_bar.update(1)
+#                     global_step += 1
+#                     if global_step % hyperparameters["save_steps"] == 0:
+#                         path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{global_step}.pt")
+#                         torch.save({'opt_acond': opt_wm_embedding, 'opt_wm': opt_wm.cpu()}, path)
+
+#                 logs = {"loss": loss.detach().item()}
+#                 progress_bar.set_postfix(**logs)
+
+#                 if global_step >= max_train_steps:
+#                     break
+
+#     accelerator.wait_for_everyone()
+
+#     return opt_wm, opt_wm_embedding
 
 class ROBINStableDiffusionPipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/watermark/tr/tr.py b/watermark/tr/tr.py
index a69d8f9..a2cabd1 100644
--- a/watermark/tr/tr.py
+++ b/watermark/tr/tr.py
@@ -18,7 +18,7 @@ def initialize_parameters(self) -> None:
         self.w_seed = self.config_dict['w_seed']
         self.w_channel = self.config_dict['w_channel']
         self.w_pattern = self.config_dict['w_pattern']
-        self.w_mask_shape = self.config_dict['w_mask_shape']
+        # self.w_mask_shape = self.config_dict['w_mask_shape']
         self.w_radius = self.config_dict['w_radius']
         self.w_pattern_const = self.config_dict['w_pattern_const']
         self.threshold = self.config_dict['threshold']
@@ -97,26 +97,26 @@ def _get_watermarking_mask(self, init_latents: torch.Tensor) -> torch.Tensor:
         """Get the watermarking mask."""
         watermarking_mask = torch.zeros(init_latents.shape, dtype=torch.bool).to(self.config.device)
 
-        if self.config.w_mask_shape == 'circle':
-            np_mask = self._circle_mask(init_latents.shape[-1], r=self.config.w_radius)
-            torch_mask = torch.tensor(np_mask).to(self.config.device)
+        # if self.config.w_mask_shape == 'circle':
+        np_mask = self._circle_mask(init_latents.shape[-1], r=self.config.w_radius)
+        torch_mask = torch.tensor(np_mask).to(self.config.device)
 
-            if self.config.w_channel == -1:
-                # all channels
-                watermarking_mask[:, :] = torch_mask
-            else:
-                watermarking_mask[:, self.config.w_channel] = torch_mask
-        elif self.config.w_mask_shape == 'square':
-            anchor_p = init_latents.shape[-1] // 2
-            if self.config.w_channel == -1:
-                # all channels
-                watermarking_mask[:, :, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
-            else:
-                watermarking_mask[:, self.config.w_channel, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
-        elif self.config.w_mask_shape == 'no':
-            pass
+        if self.config.w_channel == -1:
+            # all channels
+            watermarking_mask[:, :] = torch_mask
         else:
-            raise NotImplementedError(f'w_mask_shape: {self.config.w_mask_shape}')
+            watermarking_mask[:, self.config.w_channel] = torch_mask
+        # elif self.config.w_mask_shape == 'square':
+        #     anchor_p = init_latents.shape[-1] // 2
+        #     if self.config.w_channel == -1:
+        #         # all channels
+        #         watermarking_mask[:, :, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
+        #     else:
+        #         watermarking_mask[:, self.config.w_channel, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
+        # elif self.config.w_mask_shape == 'no':
+        #     pass
+        # else:
+        #     raise NotImplementedError(f'w_mask_shape: {self.config.w_mask_shape}')
 
         return watermarking_mask
diff --git a/watermark/videomark/video_mark.py b/watermark/videomark/video_mark.py
index 333a513..2c495d2 100644
--- a/watermark/videomark/video_mark.py
+++ b/watermark/videomark/video_mark.py
@@ -398,6 +398,8 @@ def _detect_watermark_in_video(self,
         inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
         self.config.pipe.scheduler = inverse_scheduler
 
+        video_latents = video_latents.to(self.config.pipe.unet.dtype)
+
         final_reversed_latents = self.config.pipe(
             prompt=prompt,
             latents=video_latents,
@@ -479,6 +481,8 @@ def get_data_for_visualize(self,
         self.config.pipe.scheduler = inverse_scheduler
         collector = DenoisingLatentsCollector(save_every_n_steps=1, to_cpu=True)
 
+        video_latents = video_latents.to(self.config.pipe.unet.dtype)
+
         final_reversed_latents = self.config.pipe(
             prompt=prompt,
             latents=video_latents,
diff --git a/watermark/videoshield/video_shield.py b/watermark/videoshield/video_shield.py
index a84f651..29f32c6 100644
--- a/watermark/videoshield/video_shield.py
+++ b/watermark/videoshield/video_shield.py
@@ -436,6 +436,8 @@ def _detect_watermark_in_video(self,
         inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
         self.config.pipe.scheduler = inverse_scheduler
 
+        video_latents = video_latents.to(self.config.pipe.unet.dtype)
+
         final_reversed_latents = self.config.pipe(
             prompt=prompt,
             latents=video_latents,
@@ -495,6 +497,8 @@ def get_data_for_visualize(self,
         self.config.pipe.scheduler = inverse_scheduler
         collector = DenoisingLatentsCollector(save_every_n_steps=1, to_cpu=True)
 
+        video_latents = video_latents.to(self.config.pipe.unet.dtype)
+
         final_reversed_latents = self.config.pipe(
             prompt=prompt,
             latents=video_latents,
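
Note on the ROBIN change above: `optimize_watermark` now probes a few local directories for the optimized-watermark checkpoint before falling back to the Hub. A minimal sketch of that resolution order, assuming the same config fields (`hf_dir`, `ckpt_dir`) as the patch; the helper name `resolve_robin_checkpoint` is illustrative:

import os
from huggingface_hub import hf_hub_download

def resolve_robin_checkpoint(filename: str, hf_dir: str, ckpt_dir: str, base_dir: str) -> str:
    # Prefer any existing local copy; the search order mirrors the patch.
    candidates = [
        os.path.join(base_dir, hf_dir, filename) if hf_dir else None,
        os.path.join(hf_dir, filename) if hf_dir else None,
        os.path.join(ckpt_dir, filename),
    ]
    for path in candidates:
        if path and os.path.exists(path):
            return path
    # Fall back to the Hub; hf_hub_download returns the cached file's path.
    return hf_hub_download(
        repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
        filename=filename,
        cache_dir=hf_dir,
    )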
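
Note on the tr.py change above: `_get_watermarking_mask` now always builds the circular mask, and the `square` and `no` branches are kept only as comments. For reference, a centered circular mask of the kind `_circle_mask` returns can be sketched as follows; this is an illustrative stand-in, not the repository's exact implementation:

import numpy as np

def circle_mask(size: int = 64, r: int = 10) -> np.ndarray:
    # Boolean disk of radius r centered in a size x size grid; Tree-Ring-style
    # watermarks use it to select a region of the latent Fourier plane.
    y, x = np.ogrid[:size, :size]
    center = size // 2
    return ((x - center) ** 2 + (y - center) ** 2) <= r ** 2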
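
Note on the video_mark.py and video_shield.py hunks: all four apply the same one-line fix, casting the inverted video latents to the UNet's parameter dtype before the DDIM inversion call, so fp16/bf16 pipelines are not handed fp32 latents. The pattern in isolation (helper name illustrative; `pipe` is any diffusers-style pipeline exposing a `unet`):

import torch

def align_latents_to_unet(video_latents: torch.Tensor, pipe) -> torch.Tensor:
    # A half-precision UNet raises a dtype-mismatch error when fed fp32
    # latents, so match the latents to the UNet's parameter dtype up front.
    return video_latents.to(pipe.unet.dtype)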