Commit 9b82df4

Remove _wrap() class method from base class Datapoint (#7805)
1 parent 2030d20 commit 9b82df4

7 files changed: +29 / -18 lines


test/test_datapoints.py

Lines changed: 20 additions & 0 deletions
@@ -113,6 +113,26 @@ def test_detach_wrapping():
     assert type(image_detached) is datapoints.Image
 
 
+def test_no_wrapping_exceptions_with_metadata():
+    # Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata
+    format, canvas_size = "XYXY", (32, 32)
+    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
+
+    bbox = bbox.clone()
+    assert bbox.format, bbox.canvas_size == (format, canvas_size)
+
+    bbox = bbox.to(torch.float64)
+    assert bbox.format, bbox.canvas_size == (format, canvas_size)
+
+    bbox = bbox.detach()
+    assert bbox.format, bbox.canvas_size == (format, canvas_size)
+
+    assert not bbox.requires_grad
+    bbox.requires_grad_(True)
+    assert bbox.format, bbox.canvas_size == (format, canvas_size)
+    assert bbox.requires_grad
+
+
 def test_other_op_no_wrapping():
     image = datapoints.Image(torch.rand(3, 16, 16))
 

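Not part of the diff: a small standalone sketch of what the new test exercises, assuming the public torchvision.datapoints API shown above. The explicit tuple comparison is a reworded, stricter form of the test's assertions, not the test itself.

import torch
from torchvision import datapoints

bbox = datapoints.BoundingBoxes(
    [[0, 0, 5, 5], [2, 2, 7, 7]], format="XYXY", canvas_size=(32, 32)
)

# clone/to/detach are re-wrapped via wrap_like, so the result should still
# be a BoundingBoxes instance carrying the same metadata.
for op in (lambda b: b.clone(), lambda b: b.to(torch.float64), lambda b: b.detach()):
    out = op(bbox)
    assert type(out) is datapoints.BoundingBoxes
    assert (out.format, out.canvas_size) == (bbox.format, bbox.canvas_size)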
torchvision/datapoints/_bounding_box.py

Lines changed: 4 additions & 9 deletions
@@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint):
     canvas_size: Tuple[int, int]
 
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
+    def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
+        if isinstance(format, str):
+            format = BoundingBoxFormat[format.upper()]
         bounding_boxes = tensor.as_subclass(cls)
         bounding_boxes.format = format
         bounding_boxes.canvas_size = canvas_size
@@ -59,10 +61,6 @@ def __new__(
         requires_grad: Optional[bool] = None,
     ) -> BoundingBoxes:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-
-        if isinstance(format, str):
-            format = BoundingBoxFormat[format.upper()]
-
         return cls._wrap(tensor, format=format, canvas_size=canvas_size)
 
     @classmethod
@@ -71,7 +69,7 @@ def wrap_like(
         other: BoundingBoxes,
         tensor: torch.Tensor,
         *,
-        format: Optional[BoundingBoxFormat] = None,
+        format: Optional[Union[BoundingBoxFormat, str]] = None,
         canvas_size: Optional[Tuple[int, int]] = None,
     ) -> BoundingBoxes:
         """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
@@ -85,9 +83,6 @@ def wrap_like(
                 omitted, it is taken from the reference.
 
         """
-        if isinstance(format, str):
-            format = BoundingBoxFormat[format.upper()]
-
         return cls._wrap(
             tensor,
             format=format if format is not None else other.format,

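Not part of the diff: a minimal usage sketch of what the widened format annotation allows now that the string-to-enum conversion lives in BoundingBoxes._wrap. The tensors and values are placeholders; it assumes BoundingBoxFormat is importable from torchvision.datapoints.

import torch
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat

reference = datapoints.BoundingBoxes(
    [[0, 0, 5, 5]], format=BoundingBoxFormat.XYXY, canvas_size=(32, 32)
)
plain = torch.tensor([[1.0, 1.0, 4.0, 4.0]])

# _wrap upper-cases a string format and looks it up in BoundingBoxFormat,
# so wrap_like accepts either spelling and yields the same enum member.
a = datapoints.BoundingBoxes.wrap_like(reference, plain, format="xyxy")
b = datapoints.BoundingBoxes.wrap_like(reference, plain, format=BoundingBoxFormat.XYXY)
assert a.format is b.format is BoundingBoxFormat.XYXY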
torchvision/datapoints/_datapoint.py

Lines changed: 1 addition & 5 deletions
@@ -32,13 +32,9 @@ def _to_tensor(
         requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
         return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
 
-    @classmethod
-    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
-        return tensor.as_subclass(cls)
-
     @classmethod
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        return cls._wrap(tensor)
+        return tensor.as_subclass(cls)
 
     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),

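Not part of the diff: since the base-class change boils down to Tensor.as_subclass, here is a PyTorch-only sketch of what that one remaining line does. The MyDatapoint class is made up for illustration and is not torchvision code.

import torch

class MyDatapoint(torch.Tensor):
    """Toy subclass standing in for a metadata-free Datapoint."""
    pass

t = torch.rand(3)
wrapped = t.as_subclass(MyDatapoint)

# as_subclass reinterprets the same storage under the subclass type;
# nothing is copied and no extra attributes are attached, which is why
# subclasses with metadata (BoundingBoxes, _LabelBase) keep their own _wrap.
assert type(wrapped) is MyDatapoint
assert wrapped.data_ptr() == t.data_ptr()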
torchvision/datapoints/_image.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __new__(
         elif tensor.ndim == 2:
             tensor = tensor.unsqueeze(0)
 
-        return cls._wrap(tensor)
+        return tensor.as_subclass(cls)
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()

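Not part of the diff: a short usage sketch of the __new__ path touched here, assuming datapoints.Image behaves as the surrounding context lines indicate.

import torch
from torchvision import datapoints

# A 2D input is unsqueezed to a single-channel CHW tensor before being
# returned as an Image subclass via as_subclass.
img = datapoints.Image(torch.rand(16, 16))
assert type(img) is datapoints.Image
assert img.shape == (1, 16, 16)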
torchvision/datapoints/_mask.py

Lines changed: 1 addition & 1 deletion
@@ -36,4 +36,4 @@ def __new__(
             data = F.pil_to_tensor(data)
 
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return cls._wrap(tensor)
+        return tensor.as_subclass(cls)

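Not part of the diff: the Mask path is the same one-line substitution; a tensor-input sketch (a PIL input would first go through F.pil_to_tensor, as the context line shows).

import torch
from torchvision import datapoints

# A plain tensor is converted via _to_tensor and wrapped as-is.
mask = datapoints.Mask(torch.zeros(16, 16, dtype=torch.uint8))
assert type(mask) is datapoints.Mask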
torchvision/datapoints/_video.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def __new__(
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         if data.ndim < 4:
             raise ValueError
-        return cls._wrap(tensor)
+        return tensor.as_subclass(cls)
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()

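Not part of the diff: a sketch of the Video path, assuming the usual (T, C, H, W) layout; per the context line above, inputs with fewer than 4 dimensions are rejected.

import torch
from torchvision import datapoints

video = datapoints.Video(torch.rand(2, 3, 8, 8))  # (T, C, H, W)
assert type(video) is datapoints.Video

try:
    datapoints.Video(torch.rand(3, 8, 8))  # ndim < 4 raises ValueError
except ValueError:
    pass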
torchvision/prototype/datapoints/_label.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ class _LabelBase(Datapoint):
     categories: Optional[Sequence[str]]
 
     @classmethod
-    def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:  # type: ignore[override]
+    def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
         label_base = tensor.as_subclass(cls)
         label_base.categories = categories
         return label_base

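Not part of the diff: the prototype _LabelBase keeps its own _wrap because it attaches categories metadata, so with the base-class _wrap gone only the "type: ignore[override]" comment could be dropped. A self-contained sketch of the same pattern, using a hypothetical TaggedTensor class rather than torchvision code:

from typing import Optional, Sequence

import torch

class TaggedTensor(torch.Tensor):
    """Hypothetical subclass mirroring _LabelBase's metadata handling."""

    categories: Optional[Sequence[str]]

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> "TaggedTensor":
        # Wrap first, then attach the metadata the base class cannot know about.
        wrapped = tensor.as_subclass(cls)
        wrapped.categories = categories
        return wrapped

label = TaggedTensor._wrap(torch.tensor([1, 0, 2]), categories=["cat", "dog", "bird"])
assert label.categories == ["cat", "dog", "bird"]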