diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 93d4ba45d65..10d9ce36973 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -58,7 +58,7 @@ def __init__( p: float = 0.5, scale: Sequence[float] = (0.02, 0.33), ratio: Sequence[float] = (0.3, 3.3), - value: float = 0.0, + value: Union[float, Sequence[float]] = 0.0, inplace: bool = False, ): super().__init__(p=p) @@ -77,13 +77,12 @@ def __init__( self.scale = scale self.ratio = ratio if isinstance(value, (int, float)): - self.value = [float(value)] + value = [float(value)] elif isinstance(value, str): - self.value = None + value = None elif isinstance(value, (list, tuple)): - self.value = [float(v) for v in value] - else: - self.value = value + value = [float(v) for v in value] + self.value: Optional[Sequence[float]] = value self.inplace = inplace self._log_ratio = torch.log(torch.tensor(self.ratio))