diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000..c5bf1fb
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,60 @@
+# visdet - Repository Guidelines
+
+This is a minimal version of MMDetection, supporting only Swin Mask R-CNN for object detection and instance segmentation.
+
+## Repository Structure
+
+### Documentation Files
+
+- **README.md** - Main project documentation (keep at root)
+- **AGENTS.md** - This file, repository-wide guidelines (keep at root)
+- **Other capitalized markdown files** - Write these to the `scratch_pads/` directory
+
+The `scratch_pads/` directory is intended for experimental documentation, planning documents, and other markdown files that are not part of the core repository documentation. This keeps the root directory clean and organized.
+
+## Key Principles
+
+1. **Single Model Focus**: Only support Swin Transformer + Mask R-CNN
+2. **COCO Format**: Only support COCO-style datasets
+3. **Essential Components**: Keep only what's needed for this specific model
+4. **Absolute Imports**: Always use absolute imports (e.g., `from visdet.engine import X`) instead of relative imports (e.g., `from .engine import X`) to avoid circular import issues
+
+## What to Keep
+
+### Models
+
+- **Backbone**: SwinTransformer only
+- **Neck**: FPN only
+- **Head**: RPNHead, StandardRoIHead (with bbox and mask branches)
+- **Detector**: MaskRCNN (two-stage detector)
+
+### Data
+
+- COCO dataset format support
+- Essential data transforms for training/inference
+- DetDataPreprocessor
+
+### Evaluation
+
+- COCO metrics for object detection and instance segmentation
+
+## What to Remove
+
+- All other backbones (ResNet, RegNet, etc.)
+- All other detectors (YOLO, RetinaNet, DETR, etc.)
+- All other necks (PAFPN, NAS-FPN, etc.)
+- Video/tracking components
+- 3D detection components
+- Panoptic segmentation
+- All other dataset formats
+
+## Dependencies
+
+- Training infrastructure (visengine)
+- Image operations (viscv)
+- pycocotools for COCO evaluation
+
+---
+
+*For model-specific guidelines, see `visdet/AGENTS.md`*
+*For personal development guidelines, see `~/.claude/CLAUDE.md` (local only)*
diff --git a/libs/viscv/AGENTS.md b/libs/viscv/AGENTS.md
deleted file mode 100644
index a731186..0000000
--- a/libs/viscv/AGENTS.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# viscv
-
-This is a trimmed-down version of MMCV, designed to support only the minimal computer vision operations needed for Swin Mask R-CNN.
-
-## Key Principles
-
-1. **No C++ Extensions**: All operations should be implemented in pure PyTorch or use torchvision ops
-2. **Minimal Dependencies**: Only depend on PyTorch, torchvision, and numpy
-3. **No ext_loader**: Remove all references to MMCV's extension loader system
-
-## What to Keep
-
-- Image I/O operations
-- Basic image transformations (resize, normalize, etc.)
-- Color space conversions
-- Bounding box operations (pure Python/PyTorch)
-- Basic visualization utilities
-
-## What to Remove
-
-- All C++ extensions and CUDA kernels
-- Complex ops that require compilation (deformable conv, etc.)
-- Video processing capabilities -- 3D operations -- Ops not used by Swin Mask R-CNN - -## Import Structure - -All imports should use absolute paths: `from viscv.image import imread` - ---- - -*For machine learning guidelines, see the machine_learning/AGENTS.md file.* -*For general repository guidelines, see the root AGENTS.md file.* diff --git a/libs/viscv/BUILD.pkl b/libs/viscv/BUILD.pkl deleted file mode 100644 index 7af534c..0000000 --- a/libs/viscv/BUILD.pkl +++ /dev/null @@ -1,43 +0,0 @@ -amends "@grog/package.pkl" - -local py_sources = List( - "viscv/**/*", - "pyproject.toml" -) - -// @inferred_deps -local inferred_deps = List( - "//machine_learning/packages/visengine" -) - -targets { - new { - name = "viscv" - inputs { - ...py_sources - } - } - - new { - name = "test" - command = "uv run pytest" - inputs { - ...py_sources - "tests/**/*" - } - - dependencies { - ...inferred_deps - "//tools:uv" - } - - platform { - os { - "linux" - } - arch { - "amd64" - } - } - } -} diff --git a/libs/viscv/CLAUDE.md b/libs/viscv/CLAUDE.md deleted file mode 120000 index 47dc3e3..0000000 --- a/libs/viscv/CLAUDE.md +++ /dev/null @@ -1 +0,0 @@ -AGENTS.md \ No newline at end of file diff --git a/libs/viscv/pyproject.toml b/libs/viscv/pyproject.toml deleted file mode 100644 index 1c7ef48..0000000 --- a/libs/viscv/pyproject.toml +++ /dev/null @@ -1,29 +0,0 @@ -[project] -name = "viscv" -version = "0.1.0" -description = "Minimal computer vision operations for Swin Mask R-CNN" -readme = "README.md" -requires-python = ">=3.10" -dependencies = [ - "numpy>=2.0.0", - "Pillow>=10.0.0", - "opencv-python-headless>=4.8.0", - "matplotlib>=3.6.0", - "torch==2.5.1; platform_system == 'Linux' and platform_machine == 'x86_64'", - "torchvision==0.20.1; platform_system == 'Linux' and platform_machine == 'x86_64'", - "visengine", -] - -[build-system] -requires = ["setuptools==80.9.0", "wheel"] -build-backend = "setuptools.build_meta" - -[tool.setuptools.packages.find] -where = ["."] -include = ["viscv*"] - -[tool.setuptools.package-data] -'*' = ['*.yaml', '*.json'] - -[tool.uv.sources] -visengine = { workspace = true } diff --git a/libs/viscv/tests/test_cnn/test_build_layers.py b/libs/viscv/tests/test_cnn/test_build_layers.py deleted file mode 100644 index b044b6c..0000000 --- a/libs/viscv/tests/test_cnn/test_build_layers.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
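> Editor's note: the deleted viscv guidelines above pin down the library's key design rule — ops are pure PyTorch/torchvision calls, never compiled kernels. A minimal sketch of what that looks like in practice; the `(dets, inds)` return contract is taken from `tests/test_ops/test_nms.py` later in this patch, while the wrapper body itself is an illustrative assumption, not the actual viscv source:

```python
import torch
from torchvision.ops import nms as tv_nms  # pure op, no custom CUDA kernel


def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):
    """Hypothetical pure-PyTorch NMS wrapper matching the tested contract."""
    inds = tv_nms(boxes, scores, iou_threshold)
    # dets carries (x1, y1, x2, y2, score), as the nms tests assert
    dets = torch.cat([boxes[inds], scores[inds, None]], dim=1)
    return dets, inds
```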
-import pytest -import torch.nn as nn -from viscv.cnn.bricks import ( - ConvTranspose2d, - build_conv_layer, - build_norm_layer, - build_upsample_layer, -) -from viscv.cnn.bricks.norm import infer_abbr as infer_norm_abbr -from visengine.registry import MODELS - - -def test_build_conv_layer(): - with pytest.raises(TypeError): - # cfg must be a dict - cfg = "Conv2d" - build_conv_layer(cfg) - - with pytest.raises(KeyError): - # `type` must be in cfg - cfg = dict(kernel_size=3) - build_conv_layer(cfg) - - with pytest.raises(KeyError): - # unsupported conv type - cfg = dict(type="FancyConv") - build_conv_layer(cfg) - - kwargs = dict(in_channels=4, out_channels=8, kernel_size=3, groups=2, dilation=2) - cfg = None - layer = build_conv_layer(cfg, **kwargs) - assert isinstance(layer, nn.Conv2d) - assert layer.in_channels == kwargs["in_channels"] - assert layer.out_channels == kwargs["out_channels"] - assert layer.kernel_size == (kwargs["kernel_size"], kwargs["kernel_size"]) - assert layer.groups == kwargs["groups"] - assert layer.dilation == (kwargs["dilation"], kwargs["dilation"]) - - cfg = dict(type="Conv") - layer = build_conv_layer(cfg, **kwargs) - assert isinstance(layer, nn.Conv2d) - assert layer.in_channels == kwargs["in_channels"] - assert layer.out_channels == kwargs["out_channels"] - assert layer.kernel_size == (kwargs["kernel_size"], kwargs["kernel_size"]) - assert layer.groups == kwargs["groups"] - assert layer.dilation == (kwargs["dilation"], kwargs["dilation"]) - - cfg = dict(type="deconv") - layer = build_conv_layer(cfg, **kwargs) - assert isinstance(layer, nn.ConvTranspose2d) - assert layer.in_channels == kwargs["in_channels"] - assert layer.out_channels == kwargs["out_channels"] - assert layer.kernel_size == (kwargs["kernel_size"], kwargs["kernel_size"]) - assert layer.groups == kwargs["groups"] - assert layer.dilation == (kwargs["dilation"], kwargs["dilation"]) - - -def test_infer_norm_abbr(): - with pytest.raises(TypeError): - # class_type must be a class - infer_norm_abbr(0) - - class MyNorm: - _abbr_ = "mn" - - assert infer_norm_abbr(MyNorm) == "mn" - - class FancyBatchNorm: - pass - - assert infer_norm_abbr(FancyBatchNorm) == "bn" - - class FancyInstanceNorm: - pass - - assert infer_norm_abbr(FancyInstanceNorm) == "in" - - class FancyLayerNorm: - pass - - assert infer_norm_abbr(FancyLayerNorm) == "ln" - - class FancyGroupNorm: - pass - - assert infer_norm_abbr(FancyGroupNorm) == "gn" - - class FancyNorm: - pass - - assert infer_norm_abbr(FancyNorm) == "norm_layer" - - -def test_build_norm_layer(): - with pytest.raises(TypeError): - # cfg must be a dict - cfg = "BN" - build_norm_layer(cfg, 3) - - with pytest.raises(KeyError): - # `type` must be in cfg - cfg = dict() - build_norm_layer(cfg, 3) - - with pytest.raises(KeyError): - # unsupported norm type - cfg = dict(type="FancyNorm") - build_norm_layer(cfg, 3) - - with pytest.raises(AssertionError): - # postfix must be int or str - cfg = dict(type="BN") - build_norm_layer(cfg, 3, postfix=[1, 2]) - - with pytest.raises(AssertionError): - # `num_groups` must be in cfg when using 'GN' - cfg = dict(type="GN") - build_norm_layer(cfg, 3) - - # test each type of norm layer in norm_cfg - abbr_mapping = { - "BN": "bn", - "BN1d": "bn", - "BN2d": "bn", - "BN3d": "bn", - "SyncBN": "bn", - "GN": "gn", - "LN": "ln", - "IN": "in", - "IN1d": "in", - "IN2d": "in", - "IN3d": "in", - } - for type_name, module in MODELS.module_dict.items(): - if type_name not in abbr_mapping: - continue - if type_name == "MMSyncBN": # skip MMSyncBN - continue - for 
postfix in ["_test", 1]:
-            for type_name_ in (type_name, module):
-                cfg = dict(type=type_name_)
-                if type_name == "GN":
-                    cfg["num_groups"] = 3
-                name, layer = build_norm_layer(cfg, 3, postfix=postfix)
-                assert name == abbr_mapping[type_name] + str(postfix)
-                assert isinstance(layer, module)
-                if type_name == "GN":
-                    assert layer.num_channels == 3
-                    assert layer.num_groups == cfg["num_groups"]
-                elif type_name != "LN":
-                    assert layer.num_features == 3
-
-
-def test_upsample_layer():
-    with pytest.raises(TypeError):
-        # cfg must be a dict
-        cfg = "bilinear"
-        build_upsample_layer(cfg)
-
-    with pytest.raises(KeyError):
-        # `type` must be in cfg
-        cfg = dict()
-        build_upsample_layer(cfg)
-
-    with pytest.raises(KeyError):
-        # unsupported upsample type
-        cfg = dict(type="FancyUpsample")
-        build_upsample_layer(cfg)
-
-    for type_name in ["nearest", "bilinear"]:
-        cfg = dict(type=type_name)
-        layer = build_upsample_layer(cfg)
-        assert isinstance(layer, nn.Upsample)
-        assert layer.mode == type_name
-
-    cfg = dict(type=nn.Upsample)
-    layer_from_cls = build_upsample_layer(cfg)
-    assert isinstance(layer_from_cls, nn.Upsample)
-    assert layer_from_cls.mode == "nearest"
-
-    cfg = dict(type="deconv", in_channels=3, out_channels=3, kernel_size=3, stride=2)
-    layer = build_upsample_layer(cfg)
-    assert isinstance(layer, nn.ConvTranspose2d)
-
-    # both the string alias and the class itself are valid `type` values
-    for type_name in ("deconv", ConvTranspose2d):
-        cfg = dict(type=type_name)
-        kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2)
-        layer = build_upsample_layer(cfg, **kwargs)
-        assert isinstance(layer, nn.ConvTranspose2d)
-        assert layer.in_channels == kwargs["in_channels"]
-        assert layer.out_channels == kwargs["out_channels"]
-        assert layer.kernel_size == (kwargs["kernel_size"], kwargs["kernel_size"])
-        assert layer.stride == (kwargs["stride"], kwargs["stride"])
-
-        layer = build_upsample_layer(cfg, 3, 3, 3, 2)
-        assert isinstance(layer, nn.ConvTranspose2d)
-        assert layer.in_channels == kwargs["in_channels"]
-        assert layer.out_channels == kwargs["out_channels"]
-        assert layer.kernel_size == (kwargs["kernel_size"], kwargs["kernel_size"])
-        assert layer.stride == (kwargs["stride"], kwargs["stride"])
diff --git a/libs/viscv/tests/test_cnn/test_conv_module.py b/libs/viscv/tests/test_cnn/test_conv_module.py
deleted file mode 100644
index aa1679d..0000000
--- a/libs/viscv/tests/test_cnn/test_conv_module.py
+++ /dev/null
@@ -1,265 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
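> Editor's note: the file below tests `ConvModule`, viscv's conv + norm + activation bundle. For orientation, a small usage sketch relying only on behaviour the tests themselves assert (a BN config, the default ReLU, automatic bias removal before norm, and the 256 → 255 shape change of an unpadded 2×2 kernel):

```python
import torch
from viscv.cnn.bricks import ConvModule

# conv -> BN -> ReLU in one module; the conv bias is dropped before BN
conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"))
x = torch.rand(1, 3, 256, 256)
assert conv(x).shape == (1, 8, 255, 255)  # 2x2 kernel, stride 1, no padding
```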
-import warnings -from unittest.mock import patch - -import pytest -import torch -import torch.nn as nn -from viscv.cnn.bricks import ConvModule, HSigmoid, HSwish -from visengine.registry import MODELS - - -@MODELS.register_module() -class ExampleConv(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - norm_cfg=None, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - self.bias = bias - self.norm_cfg = norm_cfg - self.output_padding = (0, 0, 0) - self.transposed = False - - self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size) - self.init_weights() - - def forward(self, x): - x = self.conv0(x) - return x - - def init_weights(self): - nn.init.constant_(self.conv0.weight, 0) - - -def test_conv_module(): - with pytest.raises(AssertionError): - # conv_cfg must be a dict or None - conv_cfg = "conv" - ConvModule(3, 8, 2, conv_cfg=conv_cfg) - - with pytest.raises(AssertionError): - # norm_cfg must be a dict or None - norm_cfg = "norm" - ConvModule(3, 8, 2, norm_cfg=norm_cfg) - - with pytest.raises(KeyError): - # softmax is not supported - act_cfg = dict(type="softmax") - ConvModule(3, 8, 2, act_cfg=act_cfg) - - # conv + norm + act - conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN")) - assert conv.with_activation - assert hasattr(conv, "activate") - assert conv.with_norm - assert hasattr(conv, "norm") - x = torch.rand(1, 3, 256, 256) - output = conv(x) - assert output.shape == (1, 8, 255, 255) - - # conv + norm with efficient mode - efficient_conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"), efficient_conv_bn_eval=True).eval() - plain_conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"), efficient_conv_bn_eval=False).eval() - for efficient_param, plain_param in zip( - efficient_conv.state_dict().values(), - plain_conv.state_dict().values(), - strict=False, - ): - plain_param.copy_(efficient_param) - - efficient_mode_output = efficient_conv(x) - plain_mode_output = plain_conv(x) - assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5) - - # `conv` attribute can be dynamically modified in efficient mode - efficient_conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"), efficient_conv_bn_eval=True).eval() - new_conv = nn.Conv2d(3, 8, 2).eval() - efficient_conv.conv = new_conv - efficient_mode_output = efficient_conv(x) - plain_mode_output = efficient_conv.activate(efficient_conv.norm(new_conv(x))) - assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5) - - # conv + act - conv = ConvModule(3, 8, 2) - assert conv.with_activation - assert hasattr(conv, "activate") - assert not conv.with_norm - assert conv.norm is None - x = torch.rand(1, 3, 256, 256) - output = conv(x) - assert output.shape == (1, 8, 255, 255) - - # conv - conv = ConvModule(3, 8, 2, act_cfg=None) - assert not conv.with_norm - assert conv.norm is None - assert not conv.with_activation - assert not hasattr(conv, "activate") - x = torch.rand(1, 3, 256, 256) - output = conv(x) - assert output.shape == (1, 8, 255, 255) - - # conv with its own `init_weights` method - conv_module = ConvModule(3, 8, 2, conv_cfg=dict(type="ExampleConv"), act_cfg=None) - assert torch.equal(conv_module.conv.conv0.weight, torch.zeros(8, 3, 2, 2)) - - # with_spectral_norm=True - conv = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True) - assert 
hasattr(conv.conv, "weight_orig") - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - # padding_mode='reflect' - conv = ConvModule(3, 8, 3, padding=1, padding_mode="reflect") - assert isinstance(conv.padding_layer, nn.ReflectionPad2d) - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - # non-existing padding mode - with pytest.raises(KeyError): - conv = ConvModule(3, 8, 3, padding=1, padding_mode="non_exists") - - # leaky relu - conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type="LeakyReLU")) - assert isinstance(conv.activate, nn.LeakyReLU) - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - # tanh - conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type="Tanh")) - assert isinstance(conv.activate, nn.Tanh) - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - # Sigmoid - conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type="Sigmoid")) - assert isinstance(conv.activate, nn.Sigmoid) - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - # PReLU - conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type="PReLU")) - assert isinstance(conv.activate, nn.PReLU) - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - # HSwish - conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type="HSwish")) - # We always use our custom HSwish implementation - assert isinstance(conv.activate, HSwish) - - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - # HSigmoid - conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type="HSigmoid")) - assert isinstance(conv.activate, HSigmoid) - output = conv(x) - assert output.shape == (1, 8, 256, 256) - - -def test_bias(): - # bias: auto, without norm - conv = ConvModule(3, 8, 2) - assert conv.conv.bias is not None - - # bias: auto, with norm - conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN")) - assert conv.conv.bias is None - - # bias: False, without norm - conv = ConvModule(3, 8, 2, bias=False) - assert conv.conv.bias is None - - # bias: True, with batch norm - with pytest.warns(UserWarning) as record: - ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type="BN")) - assert len(record) == 1 - assert record[0].message.args[0] == "Unnecessary conv bias before batch/instance norm" - - # bias: True, with instance norm - with pytest.warns(UserWarning) as record: - ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type="IN")) - assert len(record) == 1 - assert record[0].message.args[0] == "Unnecessary conv bias before batch/instance norm" - - # bias: True, with other norm - with pytest.warns(UserWarning) as record: - norm_cfg = dict(type="GN", num_groups=1) - ConvModule(3, 8, 2, bias=True, norm_cfg=norm_cfg) - warnings.warn("No warnings") - assert len(record) == 1 - assert record[0].message.args[0] == "No warnings" - - -def conv_forward(self, x): - return x + "_conv" - - -def bn_forward(self, x): - return x + "_bn" - - -def relu_forward(self, x): - return x + "_relu" - - -@patch("torch.nn.ReLU.forward", relu_forward) -@patch("torch.nn.BatchNorm2d.forward", bn_forward) -@patch("torch.nn.Conv2d.forward", conv_forward) -def test_order(): - with pytest.raises(AssertionError): - # order must be a tuple - order = ["conv", "norm", "act"] - ConvModule(3, 8, 2, order=order) - - with pytest.raises(AssertionError): - # length of order must be 3 - order = ("conv", "norm") - ConvModule(3, 8, 2, order=order) - - with pytest.raises(AssertionError): - # order must be an order of 'conv', 'norm', 'act' - order = ("conv", "norm", "norm") - ConvModule(3, 8, 2, order=order) - - with pytest.raises(AssertionError): - # 
order must be an order of 'conv', 'norm', 'act'
-        order = ("conv", "norm", "something")
-        ConvModule(3, 8, 2, order=order)
-
-    # ('conv', 'norm', 'act')
-    conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"))
-    out = conv("input")
-    assert out == "input_conv_bn_relu"
-
-    # ('norm', 'conv', 'act')
-    conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"), order=("norm", "conv", "act"))
-    out = conv("input")
-    assert out == "input_bn_conv_relu"
-
-    # ('conv', 'norm', 'act'), activate=False
-    conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"))
-    out = conv("input", activate=False)
-    assert out == "input_conv_bn"
-
-    # ('conv', 'norm', 'act'), norm=False
-    conv = ConvModule(3, 8, 2, norm_cfg=dict(type="BN"))
-    out = conv("input", norm=False)
-    assert out == "input_conv_relu"
diff --git a/libs/viscv/tests/test_cnn/test_transformer.py b/libs/viscv/tests/test_cnn/test_transformer.py
deleted file mode 100644
index 1523aa4..0000000
--- a/libs/viscv/tests/test_cnn/test_transformer.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import pytest
-import torch
-from viscv.cnn.bricks.transformer import FFN, build_dropout
-
-
-def test_ffn():
-    with pytest.raises(AssertionError):
-        # num_fcs should be no less than 2
-        FFN(num_fcs=1)
-    ffn = FFN(dropout=0, add_identity=True)
-
-    input_tensor = torch.rand(2, 20, 256)
-    input_tensor_nbc = input_tensor.transpose(0, 1)
-    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
-    residual = torch.rand_like(input_tensor)
-    assert torch.allclose(
-        ffn(input_tensor, residual=residual).sum(),
-        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum(),
-    )
-
-    assert torch.allclose(
-        ffn(input_tensor, identity=residual).sum(),
-        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum(),
-    )
-
-    # test with layer_scale
-    ffn = FFN(dropout=0, add_identity=True, layer_scale_init_value=0.1)
-
-    input_tensor = torch.rand(2, 20, 256)
-    input_tensor_nbc = input_tensor.transpose(0, 1)
-    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
-
-
-def test_build_dropout():
-    # Test None config returns None
-    assert build_dropout(None) is None
-
-    # Test dict config with type
-    cfg = dict(type="Dropout", drop_prob=0.5)
-    dropout = build_dropout(cfg)
-    assert isinstance(dropout, torch.nn.Dropout)
-    assert dropout.p == 0.5
-
-    # Test dict config with DropPath
-    cfg = dict(type="DropPath", drop_prob=0.3)
-    dropout = build_dropout(cfg)
-    assert dropout.drop_prob == 0.3
-
-    # Test with float (should create Dropout)
-    dropout = build_dropout(0.2)
-    assert isinstance(dropout, torch.nn.Dropout)
-    assert dropout.p == 0.2
diff --git a/libs/viscv/tests/test_cnn/test_transformer_complete.py b/libs/viscv/tests/test_cnn/test_transformer_complete.py
deleted file mode 100644
index cf49d32..0000000
--- a/libs/viscv/tests/test_cnn/test_transformer_complete.py
+++ /dev/null
@@ -1,319 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
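> Editor's note: the transformer-brick tests below build every component from config dicts. A condensed sketch of that pattern, using only values that appear in the tests (`embed_dims=256`, `num_heads=8`, `feedforward_channels=1024`, `batch_first=True`, and the post-norm operation order); a sketch under those assumptions, not the canonical usage:

```python
import torch
from viscv.cnn.bricks.transformer import build_transformer_layer

layer = build_transformer_layer(
    dict(
        type="BaseTransformerLayer",
        batch_first=True,
        attn_cfgs=dict(type="MultiheadAttention", embed_dims=256, num_heads=8),
        ffn_cfgs=dict(type="FFN", embed_dims=256, feedforward_channels=1024),
        operation_order=("self_attn", "norm", "ffn", "norm"),
    )
)
x = torch.rand(2, 10, 256)        # (batch, num_query, embed_dims)
assert layer(x).shape == x.shape  # residual connections preserve the shape
```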
- -import pytest -import torch -from viscv.cnn.bricks.transformer import ( - FFN, - AdaptivePadding, - BaseTransformerLayer, - MultiheadAttention, - PatchEmbed, - PatchMerging, - TransformerLayerSequence, - build_attention, - build_feedforward_network, - build_transformer_layer, - build_transformer_layer_sequence, -) - - -def test_build_functions(): - """Test build functions for transformer components.""" - # Test build_attention - cfg = dict(type="MultiheadAttention", embed_dims=256, num_heads=8) - attn = build_attention(cfg) - assert isinstance(attn, MultiheadAttention) - assert attn.embed_dims == 256 - assert attn.num_heads == 8 - - # Test build_feedforward_network - cfg = dict(type="FFN", embed_dims=256, feedforward_channels=1024) - ffn = build_feedforward_network(cfg) - assert isinstance(ffn, FFN) - assert ffn.embed_dims == 256 - assert ffn.feedforward_channels == 1024 - - # Test build_transformer_layer - cfg = dict( - type="BaseTransformerLayer", - attn_cfgs=dict(type="MultiheadAttention", embed_dims=256, num_heads=8), - ffn_cfgs=dict(type="FFN", embed_dims=256, feedforward_channels=1024), - operation_order=("self_attn", "norm", "ffn", "norm"), - ) - layer = build_transformer_layer(cfg) - assert isinstance(layer, BaseTransformerLayer) - assert layer.embed_dims == 256 - - # Test build_transformer_layer_sequence - cfg = dict( - type="TransformerLayerSequence", - transformerlayers=dict( - type="BaseTransformerLayer", - attn_cfgs=dict(type="MultiheadAttention", embed_dims=256, num_heads=8), - ffn_cfgs=dict(type="FFN", embed_dims=256, feedforward_channels=1024), - operation_order=("self_attn", "norm", "ffn", "norm"), - ), - num_layers=2, - ) - sequence = build_transformer_layer_sequence(cfg) - assert isinstance(sequence, TransformerLayerSequence) - assert sequence.num_layers == 2 - - -def test_adaptive_padding(): - """Test AdaptivePadding module.""" - for padding in ("same", "corner"): - kernel_size = 16 - stride = 16 - dilation = 1 - input = torch.rand(1, 1, 15, 17) - adap_pad = AdaptivePadding(kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding) - out = adap_pad(input) - # padding to divisible by 16 - assert (out.shape[2], out.shape[3]) == (16, 32) - - input = torch.rand(1, 1, 16, 17) - out = adap_pad(input) - # padding to divisible by 16 - assert (out.shape[2], out.shape[3]) == (16, 32) - - # assert only support "same" "corner" - with pytest.raises(AssertionError): - AdaptivePadding(kernel_size=kernel_size, stride=stride, dilation=dilation, padding=1) - - -def test_patch_embed(): - """Test PatchEmbed module.""" - B = 2 - H = 3 - W = 4 - C = 3 - embed_dims = 10 - kernel_size = 3 - stride = 1 - dummy_input = torch.rand(B, C, H, W) - patch_merge_1 = PatchEmbed( - in_channels=C, - embed_dims=embed_dims, - kernel_size=kernel_size, - stride=stride, - padding=0, - dilation=1, - norm_cfg=None, - ) - - x1, shape = patch_merge_1(dummy_input) - # test out shape - assert x1.shape == (2, 2, 10) - # test outsize is correct - assert shape == (1, 2) - # test L = out_h * out_w - assert shape[0] * shape[1] == x1.shape[1] - - -def test_patch_merging(): - """Test PatchMerging module.""" - # Test the model with int padding - in_c = 3 - out_c = 4 - kernel_size = 3 - stride = 3 - padding = 1 - dilation = 1 - bias = False - # test the case with int padding - patch_merge = PatchMerging( - in_channels=in_c, - out_channels=out_c, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - B, L, C = 1, 100, 3 - input_size = (10, 10) - x = 
torch.rand(B, L, C) - x_out, out_size = patch_merge(x, input_size) - assert x_out.size() == (1, 16, 4) - assert out_size == (4, 4) - # assert out size is consistent with real output - assert x_out.size(1) == out_size[0] * out_size[1] - - -def test_multiheadattention(): - """Test MultiheadAttention module.""" - batch_dim = 2 - embed_dim = 5 - num_query = 100 - attn_batch_first = MultiheadAttention( - embed_dims=5, - num_heads=5, - attn_drop=0, - proj_drop=0, - dropout_layer=dict(type="DropPath", drop_prob=0.0), - batch_first=True, - ) - - attn_query_first = MultiheadAttention( - embed_dims=5, - num_heads=5, - attn_drop=0, - proj_drop=0, - dropout_layer=dict(type="DropPath", drop_prob=0.0), - batch_first=False, - ) - - param_dict = dict(attn_query_first.named_parameters()) - for n, v in attn_batch_first.named_parameters(): - param_dict[n].data = v.data - - input_batch_first = torch.rand(batch_dim, num_query, embed_dim) - input_query_first = input_batch_first.transpose(0, 1) - - assert torch.allclose( - attn_query_first(input_query_first).sum(), - attn_batch_first(input_batch_first).sum(), - ) - - key_batch_first = torch.rand(batch_dim, num_query, embed_dim) - key_query_first = key_batch_first.transpose(0, 1) - - assert torch.allclose( - attn_query_first(input_query_first, key_query_first).sum(), - attn_batch_first(input_batch_first, key_batch_first).sum(), - ) - - identity = torch.ones_like(input_query_first) - - # check deprecated arguments can be used normally - assert torch.allclose( - attn_query_first(input_query_first, key_query_first, identity=identity).sum(), - attn_batch_first(input_batch_first, key_batch_first).sum() + identity.sum() - input_batch_first.sum(), - ) - - -def test_basetransformerlayer(): - """Test BaseTransformerLayer module.""" - # Test basic functionality - operation_order = ("self_attn", "norm", "ffn", "norm") - baselayer = BaseTransformerLayer( - operation_order=operation_order, - batch_first=True, - attn_cfgs=dict( - type="MultiheadAttention", - embed_dims=256, - num_heads=8, - ), - ffn_cfgs=dict( - type="FFN", - embed_dims=256, - feedforward_channels=1024, - ), - ) - - x = torch.rand(2, 10, 256) - output = baselayer(x) - assert output.shape == torch.Size([2, 10, 256]) - - # Test with cross attention - operation_order = ("self_attn", "norm", "cross_attn", "norm", "ffn", "norm") - baselayer = BaseTransformerLayer( - operation_order=operation_order, - batch_first=True, - attn_cfgs=[ - dict(type="MultiheadAttention", embed_dims=256, num_heads=8), - dict(type="MultiheadAttention", embed_dims=256, num_heads=8), - ], - ffn_cfgs=dict( - type="FFN", - embed_dims=256, - feedforward_channels=1024, - ), - ) - - query = torch.rand(2, 10, 256) - key = value = torch.rand(2, 20, 256) - output = baselayer(query, key, value) - assert output.shape == torch.Size([2, 10, 256]) - - # Test pre-norm - operation_order = ("norm", "self_attn", "norm", "ffn") - baselayer = BaseTransformerLayer( - operation_order=operation_order, - batch_first=True, - attn_cfgs=dict( - type="MultiheadAttention", - embed_dims=256, - num_heads=8, - ), - ffn_cfgs=dict( - type="FFN", - embed_dims=256, - feedforward_channels=1024, - ), - ) - assert baselayer.pre_norm is True - - x = torch.rand(2, 10, 256) - output = baselayer(x) - assert output.shape == torch.Size([2, 10, 256]) - - -def test_transformerlayersequence(): - """Test TransformerLayerSequence module.""" - # Test with dict config - transformerlayers = dict( - type="BaseTransformerLayer", - attn_cfgs=dict(type="MultiheadAttention", embed_dims=256, 
num_heads=8), - ffn_cfgs=dict(type="FFN", embed_dims=256, feedforward_channels=1024), - operation_order=("self_attn", "norm", "ffn", "norm"), - ) - num_layers = 3 - - sequence = TransformerLayerSequence(transformerlayers=transformerlayers, num_layers=num_layers) - - assert sequence.num_layers == 3 - assert len(sequence.layers) == 3 - assert sequence.embed_dims == 256 - - # Test forward - query = key = value = torch.rand(2, 10, 256) - output = sequence(query, key, value) - assert output.shape == torch.Size([2, 10, 256]) - - # Test with list config - transformerlayers = [ - dict( - type="BaseTransformerLayer", - attn_cfgs=dict(type="MultiheadAttention", embed_dims=256, num_heads=8), - ffn_cfgs=dict(type="FFN", embed_dims=256, feedforward_channels=1024), - operation_order=("self_attn", "norm", "ffn", "norm"), - ) - for _ in range(2) - ] - - sequence = TransformerLayerSequence(transformerlayers=transformerlayers, num_layers=2) - - assert sequence.num_layers == 2 - assert len(sequence.layers) == 2 - - # Test forward with positional encoding and masks - query = key = value = torch.rand(2, 10, 256) - query_pos = key_pos = torch.rand(2, 10, 256) - # Each layer has 1 self-attention, so we need 1 mask per layer - attn_masks = None # Let it be handled automatically - - output = sequence(query, key, value, query_pos=query_pos, key_pos=key_pos, attn_masks=attn_masks) - assert output.shape == torch.Size([2, 10, 256]) - - -if __name__ == "__main__": - test_build_functions() - test_adaptive_padding() - test_patch_embed() - test_patch_merging() - test_multiheadattention() - test_basetransformerlayer() - test_transformerlayersequence() - print("All tests passed!") diff --git a/libs/viscv/tests/test_image/test_cache.py b/libs/viscv/tests/test_image/test_cache.py deleted file mode 100644 index 85851f1..0000000 --- a/libs/viscv/tests/test_image/test_cache.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Unit tests for ImageCache class.""" - -import shutil -import tempfile -import time -from pathlib import Path - -import numpy as np -import pytest -from viscv.image import ImageCache - - -@pytest.fixture -def temp_cache_dir(): - """Create temporary directory for cache tests.""" - temp_dir = tempfile.mkdtemp() - yield Path(temp_dir) - # Cleanup after test - if Path(temp_dir).exists(): - shutil.rmtree(temp_dir) - - -@pytest.fixture -def temp_image_file(tmp_path): - """Create a temporary test image file.""" - img_path = tmp_path / "test_image.jpg" - # Just create the path - we'll use it for mtime testing - img_path.touch() - return img_path - - -class TestImageCache: - """Test suite for ImageCache class.""" - - def test_cache_initialization(self, temp_cache_dir): - """Test cache initialization creates necessary directories and database.""" - cache = ImageCache(cache_dir=temp_cache_dir, max_size_gb=1.0, enabled=True) - - assert cache.cache_dir.exists() - assert cache.db_path.exists() - assert cache.enabled is True - assert cache.max_size_bytes == 1024 * 1024 * 1024 - - def test_cache_disabled(self): - """Test that disabled cache doesn't create any files.""" - cache = ImageCache(enabled=False) - - assert cache.enabled is False - # get() should return None when disabled - result = cache.get("any_path.jpg", (100, 100)) - assert result is None - - def test_cache_key_generation(self, temp_cache_dir, temp_image_file): - """Test cache key generation is consistent.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - target_size = (224, 224) - mtime = temp_image_file.stat().st_mtime - 
- key1 = cache._generate_cache_key(img_path, target_size, mtime) - key2 = cache._generate_cache_key(img_path, target_size, mtime) - - assert key1 == key2 - assert len(key1) == 32 # MD5 hash length - - def test_cache_miss(self, temp_cache_dir, temp_image_file): - """Test cache miss returns None.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - result = cache.get(str(temp_image_file), (224, 224)) - assert result is None - - def test_cache_put_and_get(self, temp_cache_dir, temp_image_file): - """Test basic cache put and get operations.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - # Create test image - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - img_path = str(temp_image_file) - target_size = (224, 224) - - # Put image in cache - cache.put(img_path, target_size, test_img) - - # Get image from cache - cached_img = cache.get(img_path, target_size) - - assert cached_img is not None - assert np.array_equal(cached_img, test_img) - assert cached_img.shape == test_img.shape - assert cached_img.dtype == test_img.dtype - - def test_cache_different_sizes(self, temp_cache_dir, temp_image_file): - """Test caching same image at different sizes.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - test_img_224 = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - test_img_512 = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) - - # Cache at two different sizes - cache.put(img_path, (224, 224), test_img_224) - cache.put(img_path, (512, 512), test_img_512) - - # Retrieve both - cached_224 = cache.get(img_path, (224, 224)) - cached_512 = cache.get(img_path, (512, 512)) - - assert np.array_equal(cached_224, test_img_224) - assert np.array_equal(cached_512, test_img_512) - - def test_cache_invalidation_on_mtime_change(self, temp_cache_dir, temp_image_file): - """Test cache invalidates when file mtime changes.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - - # Cache the image - cache.put(img_path, (224, 224), test_img) - assert cache.get(img_path, (224, 224)) is not None - - # Modify file mtime - time.sleep(0.01) # Ensure mtime changes - temp_image_file.touch() - - # Cache should miss now - result = cache.get(img_path, (224, 224)) - assert result is None - - def test_cache_access_count(self, temp_cache_dir, temp_image_file): - """Test cache tracks access count.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - - # Cache the image - cache.put(img_path, (224, 224), test_img) - - # Access multiple times - for _ in range(5): - cache.get(img_path, (224, 224)) - - # Check stats - stats = cache.get_stats() - assert stats["enabled"] is True - assert stats["total_entries"] == 1 - # Access count should be at least 5 (may be more due to put operation) - assert stats["avg_access_count"] >= 5 - - def test_cache_eviction(self, temp_cache_dir, temp_image_file): - """Test LRU eviction when cache size exceeds limit.""" - # Create cache with very small size limit (1 MB) - cache = ImageCache(cache_dir=temp_cache_dir, max_size_gb=0.001, enabled=True) - - img_path = str(temp_image_file) - - # Create several large images to exceed cache size - images = [] - for i in range(5): - # Each image is ~500KB - img = np.random.randint(0, 255, (400, 400, 3), 
dtype=np.uint8) - images.append(img) - cache.put(img_path, (400 + i, 400 + i), img) - - # Check that eviction occurred - stats = cache.get_stats() - # Should have fewer than 5 entries due to eviction - assert stats["total_entries"] < 5 - - def test_cache_clear(self, temp_cache_dir, temp_image_file): - """Test clearing all cache entries.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - - # Add some entries - cache.put(img_path, (224, 224), test_img) - cache.put(img_path, (512, 512), test_img) - - assert cache.get_stats()["total_entries"] == 2 - - # Clear cache - cache.clear() - - # Verify cache is empty - assert cache.get_stats()["total_entries"] == 0 - assert cache.get(img_path, (224, 224)) is None - - def test_cache_corrupted_file_handling(self, temp_cache_dir, temp_image_file): - """Test cache handles corrupted cache files gracefully.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - - # Cache the image - cache.put(img_path, (224, 224), test_img) - - # Get cache key and corrupt the file - mtime = temp_image_file.stat().st_mtime - cache_key = cache._generate_cache_key(img_path, (224, 224), mtime) - cache_path = cache._get_cache_path(cache_key) - - # Corrupt the cache file - with open(cache_path, "wb") as f: - f.write(b"corrupted data") - - # Should return None and clean up corrupted file - result = cache.get(img_path, (224, 224)) - assert result is None - - def test_cache_nonexistent_file(self, temp_cache_dir): - """Test cache handles non-existent image files gracefully.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - # Try to get cache for non-existent file - result = cache.get("/nonexistent/path.jpg", (224, 224)) - assert result is None - - # Try to put cache for non-existent file (should be no-op) - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - cache.put("/nonexistent/path.jpg", (224, 224), test_img) - - # Verify nothing was cached - assert cache.get_stats()["total_entries"] == 0 - - def test_cache_stats(self, temp_cache_dir, temp_image_file): - """Test cache statistics are accurate.""" - cache = ImageCache(cache_dir=temp_cache_dir, max_size_gb=1.0, enabled=True) - - stats = cache.get_stats() - assert stats["enabled"] is True - assert stats["total_entries"] == 0 - assert stats["total_size_mb"] == 0 - assert "cache_dir" in stats - - # Add some entries - img_path = str(temp_image_file) - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - cache.put(img_path, (224, 224), test_img) - - stats = cache.get_stats() - assert stats["total_entries"] == 1 - assert stats["total_size_mb"] > 0 - - def test_cache_repr(self, temp_cache_dir): - """Test cache string representation.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - repr_str = repr(cache) - - assert "ImageCache" in repr_str - assert str(temp_cache_dir) in repr_str - - # Test disabled cache repr - disabled_cache = ImageCache(enabled=False) - disabled_repr = repr(disabled_cache) - assert "enabled=False" in disabled_repr - - def test_cache_subdirectory_structure(self, temp_cache_dir, temp_image_file): - """Test cache creates subdirectories for better filesystem performance.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - test_img = np.random.randint(0, 255, (224, 224, 3), 
dtype=np.uint8) - - # Cache the image - cache.put(img_path, (224, 224), test_img) - - # Verify subdirectory structure exists - # Cache should create subdirectories based on first 2 chars of hash - subdirs = [d for d in temp_cache_dir.iterdir() if d.is_dir()] - assert len(subdirs) > 0 # At least one subdirectory should exist - - def test_cache_concurrent_access(self, temp_cache_dir, temp_image_file): - """Test cache handles multiple access patterns.""" - cache = ImageCache(cache_dir=temp_cache_dir, enabled=True) - - img_path = str(temp_image_file) - test_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) - - # Cache the image - cache.put(img_path, (224, 224), test_img) - - # Multiple reads - results = [cache.get(img_path, (224, 224)) for _ in range(10)] - - # All should return the same image - assert all(r is not None for r in results) - assert all(np.array_equal(r, test_img) for r in results) diff --git a/libs/viscv/tests/test_image/test_io.py b/libs/viscv/tests/test_image/test_io.py deleted file mode 100644 index ad2c4ef..0000000 --- a/libs/viscv/tests/test_image/test_io.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Tests for viscv.image.io module.""" - -import sys -from pathlib import Path - -import cv2 -import numpy as np -import pytest - -# Add viscv to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -from viscv.image import imfrombytes - - -class TestImfrombytes: - """Test imfrombytes function.""" - - @pytest.fixture - def image_bytes(self): - """Create test image bytes.""" - # Create a simple test image - img = np.random.randint(0, 255, (50, 50, 3), dtype=np.uint8) - - # Encode as JPEG bytes - success, buffer = cv2.imencode(".jpg", img) - assert success - - return buffer.tobytes(), img - - def test_basic_decode(self, image_bytes): - """Test basic image decoding from bytes.""" - img_bytes, original = image_bytes - - decoded = imfrombytes(img_bytes) - - assert isinstance(decoded, np.ndarray) - assert decoded.shape == original.shape - assert decoded.dtype == np.uint8 - - def test_color_flag(self, image_bytes): - """Test different color flags.""" - img_bytes, _ = image_bytes - - # Color image - color_img = imfrombytes(img_bytes, flag="color") - assert len(color_img.shape) == 3 - - # Grayscale image - gray_img = imfrombytes(img_bytes, flag="grayscale") - assert len(gray_img.shape) == 2 - - # Unchanged - unchanged_img = imfrombytes(img_bytes, flag="unchanged") - assert isinstance(unchanged_img, np.ndarray) - - def test_channel_order(self, image_bytes): - """Test channel order parameter.""" - img_bytes, _ = image_bytes - - # Default BGR - bgr_img = imfrombytes(img_bytes, channel_order="bgr") - - # RGB order - rgb_img = imfrombytes(img_bytes, channel_order="rgb") - - # The channels should be swapped - # Note: Due to JPEG compression, exact equality won't work - assert bgr_img.shape == rgb_img.shape - - def test_backend_parameter(self, image_bytes): - """Test backend parameter.""" - img_bytes, _ = image_bytes - - # Test cv2 backend (default) - cv2_img = imfrombytes(img_bytes, backend="cv2") - assert isinstance(cv2_img, np.ndarray) - - # Test explicit cv2 backend - cv2_explicit = imfrombytes(img_bytes, backend="cv2") - assert isinstance(cv2_explicit, np.ndarray) - - def test_invalid_backend(self, image_bytes): - """Test invalid backend raises error.""" - img_bytes, _ = image_bytes - - with pytest.raises(ValueError, match="backend: invalid is not supported"): - imfrombytes(img_bytes, backend="invalid") - - def test_png_bytes(self): - """Test decoding PNG 
bytes.""" - # Create image with transparency - img = np.random.randint(0, 255, (30, 30, 4), dtype=np.uint8) - - # Encode as PNG - success, buffer = cv2.imencode(".png", img) - assert success - png_bytes = buffer.tobytes() - - # Decode - decoded = imfrombytes(png_bytes, flag="unchanged") - assert decoded.shape == img.shape - - def test_empty_bytes(self): - """Test handling of empty bytes.""" - # cv2.imdecode returns None for invalid data - result = imfrombytes(b"") - assert result is None - - result = imfrombytes(b"invalid image data") - assert result is None - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/libs/viscv/tests/test_ops/test_nms.py b/libs/viscv/tests/test_ops/test_nms.py deleted file mode 100644 index f26b87b..0000000 --- a/libs/viscv/tests/test_ops/test_nms.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np -import pytest -import torch -from viscv.ops import batched_nms, nms - - -class TestNMS: - def test_nms_cpu(self): - """Test NMS on CPU.""" - np_boxes = np.array( - [ - [6.0, 3.0, 8.0, 7.0], - [3.0, 6.0, 9.0, 11.0], - [3.0, 7.0, 10.0, 12.0], - [1.0, 4.0, 13.0, 7.0], - ], - dtype=np.float32, - ) - np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) - - boxes = torch.from_numpy(np_boxes) - scores = torch.from_numpy(np_scores) - - dets, inds = nms(boxes, scores, iou_threshold=0.3) - - # Check that highest scoring box is kept - assert inds[0] == 1 # index of box with score 0.9 - - # Check shape - assert dets.shape[1] == 5 # x1, y1, x2, y2, score - assert len(inds) <= len(boxes) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") - def test_nms_cuda(self): - """Test NMS on CUDA.""" - np_boxes = np.array( - [ - [6.0, 3.0, 8.0, 7.0], - [3.0, 6.0, 9.0, 11.0], - [3.0, 7.0, 10.0, 12.0], - [1.0, 4.0, 13.0, 7.0], - ], - dtype=np.float32, - ) - np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) - - boxes = torch.from_numpy(np_boxes).cuda() - scores = torch.from_numpy(np_scores).cuda() - - dets, inds = nms(boxes, scores, iou_threshold=0.3) - - # Check that results are on cuda - assert dets.is_cuda - assert inds.is_cuda - - # Check that highest scoring box is kept - assert inds[0].cpu() == 1 # index of box with score 0.9 - - def test_batched_nms(self): - """Test batched NMS.""" - boxes = torch.tensor( - [ - [6.0, 3.0, 8.0, 7.0], - [3.0, 6.0, 9.0, 11.0], - [3.0, 7.0, 10.0, 12.0], - [1.0, 4.0, 13.0, 7.0], - ] - ) - scores = torch.tensor([0.6, 0.9, 0.7, 0.2]) - idxs = torch.tensor([0, 0, 1, 1]) # Two different classes/batches - - # Test with class-based NMS - dets, keep = batched_nms(boxes, scores, idxs, nms_cfg=dict(iou_threshold=0.3)) - - # Should keep the top scoring box from each class - assert len(keep) >= 2 # At least one from each class - assert 1 in keep # Highest score in class 0 - assert 2 in keep # Highest score in class 1 - - def test_batched_nms_split_thr(self): - """Test batched NMS with split threshold.""" - # Create many boxes to trigger split - n_boxes = 10000 - boxes = torch.rand(n_boxes, 4) - boxes[:, 2:] = boxes[:, :2] + 0.1 # Ensure x2 > x1, y2 > y1 - scores = torch.rand(n_boxes) - idxs = torch.zeros(n_boxes, dtype=torch.long) - - # Test with split threshold - dets, keep = batched_nms( - boxes, - scores, - idxs, - nms_cfg=dict(type="nms", iou_threshold=0.5, split_thr=1000), - ) - - # Check output shapes - assert dets.shape[1] == 5 - assert len(keep) > 0 - assert len(keep) <= n_boxes diff --git a/libs/viscv/tests/test_ops/test_roi_align.py 
b/libs/viscv/tests/test_ops/test_roi_align.py deleted file mode 100644 index 102ec78..0000000 --- a/libs/viscv/tests/test_ops/test_roi_align.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest -import torch -from viscv.ops import roi_align - - -class TestRoIAlign: - def test_roi_align_cpu(self): - """Test RoIAlign on CPU.""" - # Create a simple 4x4 feature map - features = torch.tensor( - [ - [ - [ - [1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0], - ] - ] - ], - dtype=torch.float32, - ) - - # Single RoI covering the top-left quadrant - rois = torch.tensor([[0.0, 0.0, 0.0, 2.0, 2.0]], dtype=torch.float32) - - output_size = (2, 2) - spatial_scale = 1.0 - sampling_ratio = 2 - - output = roi_align( - features, - rois, - output_size, - spatial_scale=spatial_scale, - sampling_ratio=sampling_ratio, - ) - - # Check output shape - assert output.shape == (1, 1, 2, 2) - - # Values should be interpolated from top-left quadrant - assert output.min() >= 1.0 - assert output.max() <= 6.0 - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") - def test_roi_align_cuda(self): - """Test RoIAlign on CUDA.""" - features = torch.tensor( - [ - [ - [ - [1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0], - ] - ] - ], - dtype=torch.float32, - ).cuda() - - rois = torch.tensor([[0.0, 0.0, 0.0, 2.0, 2.0]], dtype=torch.float32).cuda() - - output_size = (2, 2) - spatial_scale = 1.0 - sampling_ratio = 2 - - output = roi_align( - features, - rois, - output_size, - spatial_scale=spatial_scale, - sampling_ratio=sampling_ratio, - ) - - # Check output is on CUDA - assert output.is_cuda - - # Check output shape - assert output.shape == (1, 1, 2, 2) - - def test_roi_align_multiple_rois(self): - """Test RoIAlign with multiple RoIs.""" - features = torch.rand(2, 3, 8, 8) # 2 images, 3 channels, 8x8 - - # 3 RoIs: first two for image 0, last one for image 1 - rois = torch.tensor( - [ - [0.0, 0.0, 0.0, 4.0, 4.0], # Top-left of image 0 - [0.0, 4.0, 4.0, 8.0, 8.0], # Bottom-right of image 0 - [1.0, 2.0, 2.0, 6.0, 6.0], # Center of image 1 - ] - ) - - output_size = (3, 3) - output = roi_align(features, rois, output_size) - - # Check output shape: (num_rois, channels, h, w) - assert output.shape == (3, 3, 3, 3) - - def test_roi_align_with_different_scales(self): - """Test RoIAlign with different spatial scales.""" - features = torch.rand(1, 2, 16, 16) - - # RoI in original image coordinates - rois = torch.tensor([[0.0, 8.0, 8.0, 24.0, 24.0]]) - - # spatial_scale = 0.5 means features are half the size of original - output = roi_align(features, rois, (4, 4), spatial_scale=0.5) - - assert output.shape == (1, 2, 4, 4) - - def test_roi_align_empty_rois(self): - """Test RoIAlign with empty RoIs.""" - features = torch.rand(1, 3, 8, 8) - rois = torch.empty((0, 5)) # Empty tensor - - output = roi_align(features, rois, (4, 4)) - - # Should return empty output - assert output.shape == (0, 3, 4, 4) diff --git a/libs/viscv/tests/test_transforms/__init__.py b/libs/viscv/tests/test_transforms/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/libs/viscv/tests/test_transforms/test_base.py b/libs/viscv/tests/test_transforms/test_base.py deleted file mode 100644 index 4de8f84..0000000 --- a/libs/viscv/tests/test_transforms/test_base.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Tests for viscv.transforms.base module.""" - -import sys -from pathlib import Path - -import 
pytest - -# Add viscv to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -from viscv.transforms.base import BaseTransform - - -class TestBaseTransform: - """Test BaseTransform abstract class.""" - - def test_abstract_class(self): - """Test that BaseTransform cannot be instantiated directly.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - BaseTransform() - - def test_concrete_implementation(self): - """Test concrete implementation of BaseTransform.""" - - class ConcreteTransform(BaseTransform): - def __init__(self, value=1): - self.value = value - - def transform(self, results: dict) -> dict | None: - results["transformed"] = True - results["value"] = self.value - return results - - # Should be able to instantiate concrete class - transform = ConcreteTransform(value=42) - assert transform.value == 42 - - # Test __call__ method - input_dict = {"data": "test"} - output = transform(input_dict) - - assert output is not None - assert output["data"] == "test" - assert output["transformed"] is True - assert output["value"] == 42 - - def test_none_return(self): - """Test transform that returns None.""" - - class NoneTransform(BaseTransform): - def transform(self, results: dict) -> dict | None: - if results.get("skip", False): - return None - return results - - transform = NoneTransform() - - # Normal case - output = transform({"data": "test"}) - assert output == {"data": "test"} - - # Skip case - output = transform({"skip": True}) - assert output is None - - def test_transform_pipeline(self): - """Test chaining multiple transforms.""" - - class AddOneTransform(BaseTransform): - def transform(self, results: dict) -> dict | None: - results["value"] = results.get("value", 0) + 1 - return results - - class MultiplyTwoTransform(BaseTransform): - def transform(self, results: dict) -> dict | None: - results["value"] = results.get("value", 1) * 2 - return results - - # Create pipeline - transforms = [AddOneTransform(), MultiplyTwoTransform(), AddOneTransform()] - - # Apply transforms - data = {"value": 5} - for t in transforms: - data = t(data) - if data is None: - break - - # (5 + 1) * 2 + 1 = 13 - assert data["value"] == 13 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/libs/viscv/tests/test_transforms/test_builder.py b/libs/viscv/tests/test_transforms/test_builder.py deleted file mode 100644 index ddb7c0f..0000000 --- a/libs/viscv/tests/test_transforms/test_builder.py +++ /dev/null @@ -1,141 +0,0 @@ -"""Tests for viscv.transforms.builder module.""" - -import sys -from pathlib import Path - -import pytest - -# Add viscv to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -from viscv.transforms.base import BaseTransform -from viscv.transforms.builder import TRANSFORMS, Registry - - -class TestRegistry: - """Test Registry class.""" - - def test_init(self): - """Test Registry initialization.""" - registry = Registry("test_registry") - assert registry._name == "test_registry" - assert registry._module_dict == {} - - def test_register_module_direct(self): - """Test direct module registration.""" - registry = Registry("test") - - class DummyTransform(BaseTransform): - def transform(self, results): - return results - - # Register with explicit name - registry.register_module(name="Dummy", module=DummyTransform) - assert registry.get("Dummy") == DummyTransform - - # Register without name (uses class name) - class AnotherTransform(BaseTransform): - def transform(self, results): - return results - 
- registry.register_module(module=AnotherTransform) - assert registry.get("AnotherTransform") == AnotherTransform - - def test_register_module_decorator(self): - """Test decorator-based registration.""" - registry = Registry("test") - - @registry.register_module() - class DecoratedTransform(BaseTransform): - def transform(self, results): - return results - - assert registry.get("DecoratedTransform") == DecoratedTransform - - # Test with custom name - @registry.register_module(name="CustomName") - class AnotherDecoratedTransform(BaseTransform): - def transform(self, results): - return results - - assert registry.get("CustomName") == AnotherDecoratedTransform - assert registry.get("AnotherDecoratedTransform") is None - - def test_register_duplicate(self): - """Test registering duplicate names.""" - registry = Registry("test") - - class Transform1(BaseTransform): - def transform(self, results): - return results - - class Transform2(BaseTransform): - def transform(self, results): - return results - - registry.register_module(name="Transform", module=Transform1) - - # Should raise error without force - with pytest.raises(KeyError, match="Transform is already registered"): - registry.register_module(name="Transform", module=Transform2) - - # Should work with force=True - registry.register_module(name="Transform", module=Transform2, force=True) - assert registry.get("Transform") == Transform2 - - def test_build(self): - """Test building modules from config.""" - registry = Registry("test") - - class ConfigurableTransform(BaseTransform): - def __init__(self, param1=1, param2="default"): - self.param1 = param1 - self.param2 = param2 - - def transform(self, results): - return results - - registry.register_module(module=ConfigurableTransform) - - # Test basic build - config = {"type": "ConfigurableTransform"} - instance = registry.build(config) - assert isinstance(instance, ConfigurableTransform) - assert instance.param1 == 1 - assert instance.param2 == "default" - - # Test build with parameters - config = {"type": "ConfigurableTransform", "param1": 42, "param2": "custom"} - instance = registry.build(config) - assert instance.param1 == 42 - assert instance.param2 == "custom" - - def test_build_errors(self): - """Test error cases in build.""" - registry = Registry("test") - - # Test non-dict config - with pytest.raises(TypeError, match="cfg must be a dict"): - registry.build("not a dict") - - # Test missing type - with pytest.raises(KeyError, match='cfg must contain the key "type"'): - registry.build({}) - - # Test unregistered type - with pytest.raises(KeyError, match="UnknownType is not in the test registry"): - registry.build({"type": "UnknownType"}) - - def test_transforms_registry(self): - """Test the global TRANSFORMS registry.""" - assert isinstance(TRANSFORMS, Registry) - assert TRANSFORMS._name == "transforms" - - # Check that LoadImageFromFile is registered - from viscv.transforms.loading import LoadImageFromFile - - assert TRANSFORMS.get("LoadImageFromFile") == LoadImageFromFile - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/libs/viscv/tests/test_transforms/test_cache_integration.py b/libs/viscv/tests/test_transforms/test_cache_integration.py deleted file mode 100644 index 479af3d..0000000 --- a/libs/viscv/tests/test_transforms/test_cache_integration.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Integration tests for image cache with transform pipeline.""" - -import shutil -import tempfile -from pathlib import Path - -import numpy as np -import pytest -from PIL import 
Image -from viscv.transforms.loading import LoadImageFromFile -from viscv.transforms.processing import Resize - - -@pytest.fixture -def temp_cache_dir(): - """Create temporary directory for cache tests.""" - temp_dir = tempfile.mkdtemp() - yield Path(temp_dir) - # Cleanup after test - if Path(temp_dir).exists(): - shutil.rmtree(temp_dir) - - -@pytest.fixture -def temp_image(tmp_path): - """Create a temporary test image file.""" - img_path = tmp_path / "test_image.jpg" - # Create and save a test image - test_img = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8) - Image.fromarray(test_img).save(img_path) - return str(img_path) - - -class TestCacheIntegration: - """Test suite for cache integration with transforms.""" - - def test_load_and_resize_with_cache(self, temp_cache_dir, temp_image): - """Test complete pipeline: load image with cache, resize, cache resized.""" - # Create transforms with cache enabled - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - cache_max_size_gb=1.0, - ) - resize_transform = Resize(scale=(224, 224), keep_ratio=True, enable_cache=True) - - # First pass - should be cache miss - results = {"img_path": temp_image} - results = load_transform(results) - assert results is not None - assert results["_cache_hit"] is False - assert "img" in results - - # Resize - results = resize_transform(results) - assert results is not None - assert "img" in results - assert results["img_shape"][0] <= 224 - assert results["img_shape"][1] <= 224 - - # Second pass - should be cache hit - # Need to provide target size from first run - results2 = { - "img_path": temp_image, - "_cache_target_size": results["_cache_target_size"], - } - results2 = load_transform(results2) - assert results2 is not None - assert results2["_cache_hit"] is True # Should hit cache this time - - # Verify cached image is correct - assert results2["img"].shape[:2] == results["img"].shape[:2] - - def test_cache_disabled_by_default_in_resize(self, temp_cache_dir, temp_image): - """Test that resize still works when cache is disabled.""" - load_transform = LoadImageFromFile() # Cache disabled - resize_transform = Resize(scale=(224, 224), enable_cache=True) - - results = {"img_path": temp_image} - results = load_transform(results) - results = resize_transform(results) - - assert results is not None - assert "img" in results - - def test_multiple_sizes_cached_separately(self, temp_cache_dir, temp_image): - """Test that different resize sizes are cached separately.""" - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - ) - resize_224 = Resize(scale=(224, 224), enable_cache=True) - resize_512 = Resize(scale=(512, 512), enable_cache=True) - - # Process at 224x224 - results_224 = {"img_path": temp_image} - results_224 = load_transform(results_224) - results_224 = resize_224(results_224) - - # Process at 512x512 - results_512 = {"img_path": temp_image} - results_512 = load_transform(results_512) - results_512 = resize_512(results_512) - - # Verify both are cached - stats = load_transform.cache.get_stats() - assert stats["total_entries"] == 2 - - # Verify second load hits cache for 224 (need to provide target size) - results_224_cached = { - "img_path": temp_image, - "_cache_target_size": results_224["_cache_target_size"], - } - results_224_cached = load_transform(results_224_cached) - assert results_224_cached["_cache_hit"] is True - assert results_224_cached["img"].shape[:2] == results_224["img"].shape[:2] - - def 
test_cache_instance_passed_through_pipeline(self, temp_cache_dir, temp_image): - """Test that cache instance is properly passed through results dict.""" - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - ) - - results = {"img_path": temp_image} - results = load_transform(results) - - # Verify cache instance is in results - assert "_image_cache" in results - assert results["_image_cache"] is load_transform.cache - - def test_cache_not_passed_when_disabled(self, temp_image): - """Test that cache is not passed when disabled.""" - load_transform = LoadImageFromFile(enable_cache=False) - - results = {"img_path": temp_image} - results = load_transform(results) - - # Verify cache instance is NOT in results - assert "_image_cache" not in results - - def test_resize_saves_to_cache_after_first_load(self, temp_cache_dir, temp_image): - """Test that resize saves to cache only after first load.""" - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - ) - resize_transform = Resize(scale=(224, 224), enable_cache=True) - - # First pass - load and resize - results = {"img_path": temp_image} - results = load_transform(results) - initial_cache_entries = load_transform.cache.get_stats()["total_entries"] - - results = resize_transform(results) - - # After resize, cache should have the resized image - final_cache_entries = load_transform.cache.get_stats()["total_entries"] - assert final_cache_entries > initial_cache_entries - - def test_cache_with_keep_ratio_false(self, temp_cache_dir, temp_image): - """Test cache works with keep_ratio=False.""" - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - ) - resize_transform = Resize(scale=(224, 224), keep_ratio=False, enable_cache=True) - - # First pass - results = {"img_path": temp_image} - results = load_transform(results) - results = resize_transform(results) - assert results["img"].shape[:2] == (224, 224) - - # Second pass - should hit cache (need to provide target size) - results2 = { - "img_path": temp_image, - "_cache_target_size": results["_cache_target_size"], - } - results2 = load_transform(results2) - assert results2["_cache_hit"] is True - - def test_cache_stats_accuracy(self, temp_cache_dir, temp_image): - """Test that cache statistics are accurate after transform pipeline.""" - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - cache_max_size_gb=1.0, - ) - resize_transform = Resize(scale=(224, 224), enable_cache=True) - - # Run pipeline multiple times - results = {"img_path": temp_image} - results = load_transform(results) - results = resize_transform(results) - target_size = results["_cache_target_size"] - - # Subsequent runs with target size - for _ in range(2): - results = {"img_path": temp_image, "_cache_target_size": target_size} - results = load_transform(results) - results = resize_transform(results) - - stats = load_transform.cache.get_stats() - assert stats["enabled"] is True - assert stats["total_entries"] == 1 # Same image, same size - assert stats["avg_access_count"] >= 2 # At least 2 accesses (1 write, 2+ reads) - - def test_pipeline_without_resize(self, temp_cache_dir, temp_image): - """Test that cache works even without resize transform.""" - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - ) - - # Load without resize (cache won't save anything since no target size) - results = {"img_path": temp_image} - results = 
load_transform(results) - - assert results is not None - assert "_cache_hit" in results - - def test_cache_repr_in_transform(self, temp_cache_dir): - """Test transform repr includes cache status.""" - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - ) - - repr_str = repr(load_transform) - assert "cache_enabled=True" in repr_str - - def test_multiple_images_cached(self, temp_cache_dir, tmp_path): - """Test that multiple different images are cached correctly.""" - # Create multiple test images - images = [] - for i in range(3): - img_path = tmp_path / f"test_image_{i}.jpg" - test_img = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8) - Image.fromarray(test_img).save(img_path) - images.append(str(img_path)) - - load_transform = LoadImageFromFile( - enable_cache=True, - cache_dir=str(temp_cache_dir), - ) - resize_transform = Resize(scale=(224, 224), enable_cache=True) - - # Process all images and store target sizes - target_sizes = {} - for img_path in images: - results = {"img_path": img_path} - results = load_transform(results) - results = resize_transform(results) - target_sizes[img_path] = results["_cache_target_size"] - - # Verify all are cached - stats = load_transform.cache.get_stats() - assert stats["total_entries"] == 3 - - # Verify cache hits on second pass (with target sizes) - for img_path in images: - results = { - "img_path": img_path, - "_cache_target_size": target_sizes[img_path], - } - results = load_transform(results) - assert results["_cache_hit"] is True diff --git a/libs/viscv/tests/test_transforms/test_integration.py b/libs/viscv/tests/test_transforms/test_integration.py deleted file mode 100644 index 9e924ce..0000000 --- a/libs/viscv/tests/test_transforms/test_integration.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Integration tests for viscv transforms.""" - -import sys -import tempfile -from pathlib import Path - -import cv2 -import numpy as np - -# Add viscv to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -from viscv.transforms import TRANSFORMS, LoadImageFromFile - - -def test_end_to_end_image_loading(): - """Test complete image loading workflow.""" - # Create a test image - test_img = np.array( - [ - [[255, 0, 0], [0, 255, 0], [0, 0, 255]], # RGB pixels - [[255, 255, 0], [255, 0, 255], [0, 255, 255]], - [[128, 128, 128], [0, 0, 0], [255, 255, 255]], - ], - dtype=np.uint8, - ) - - with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: - cv2.imwrite(f.name, test_img) - temp_path = f.name - - try: - # Method 1: Direct instantiation - print("Testing direct instantiation...") - loader = LoadImageFromFile(to_float32=False) - result = loader({"img_path": temp_path}) - - assert result is not None - assert "img" in result - assert np.array_equal(result["img"], test_img) - print("✓ Direct instantiation works") - - # Method 2: Using registry - print("\nTesting registry build...") - config = { - "type": "LoadImageFromFile", - "to_float32": True, - "color_type": "color", - } - loader2 = TRANSFORMS.build(config) - result2 = loader2({"img_path": temp_path}) - - assert result2 is not None - assert result2["img"].dtype == np.float32 - print("✓ Registry build works") - - # Method 3: Pipeline simulation - print("\nTesting pipeline simulation...") - pipeline_config = [ - {"type": "LoadImageFromFile", "to_float32": False}, - # Could add more transforms here - ] - - data = {"img_path": temp_path} - for cfg in pipeline_config: - transform = TRANSFORMS.build(cfg) - data = transform(data) - if data is 
None: - break - - assert data is not None - assert "img" in data - print("✓ Pipeline simulation works") - - print("\n✅ All integration tests passed!") - - finally: - # Cleanup - Path(temp_path).unlink(missing_ok=True) - - -if __name__ == "__main__": - test_end_to_end_image_loading() diff --git a/libs/viscv/tests/test_transforms/test_loading.py b/libs/viscv/tests/test_transforms/test_loading.py deleted file mode 100644 index 3953a86..0000000 --- a/libs/viscv/tests/test_transforms/test_loading.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Tests for viscv.transforms.loading module.""" - -import sys -import tempfile -from pathlib import Path - -import cv2 -import numpy as np -import pytest - -# Add viscv to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -from viscv.transforms import TRANSFORMS, LoadImageFromFile - - -class TestLoadImageFromFile: - """Test LoadImageFromFile transform.""" - - @pytest.fixture - def temp_image(self): - """Create a temporary test image.""" - # Create a simple 100x100 RGB image - img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) - - with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: - cv2.imwrite(f.name, img) - yield f.name, img - - # Cleanup - Path(f.name).unlink(missing_ok=True) - - def test_init(self): - """Test LoadImageFromFile initialization.""" - loader = LoadImageFromFile() - assert loader.to_float32 is False - assert loader.color_type == "color" - assert loader.imdecode_backend == "cv2" - assert loader.ignore_empty is False - assert loader.file_client_args is None - assert loader.backend_args is None - - # Test with custom parameters - loader = LoadImageFromFile( - to_float32=True, - color_type="grayscale", - imdecode_backend="pillow", - ignore_empty=True, - ) - assert loader.to_float32 is True - assert loader.color_type == "grayscale" - assert loader.imdecode_backend == "pillow" - assert loader.ignore_empty is True - - def test_transform_basic(self, temp_image): - """Test basic image loading.""" - img_path, original_img = temp_image - - loader = LoadImageFromFile() - results = {"img_path": img_path} - - output = loader.transform(results) - - assert output is not None - assert "img" in output - assert "img_shape" in output - assert "ori_shape" in output - - # Check image is loaded correctly - assert isinstance(output["img"], np.ndarray) - assert output["img"].shape == original_img.shape - assert output["img_shape"] == original_img.shape[:2] - assert output["ori_shape"] == original_img.shape[:2] - - # Check data type - assert output["img"].dtype == np.uint8 - - def test_transform_float32(self, temp_image): - """Test loading image as float32.""" - img_path, _ = temp_image - - loader = LoadImageFromFile(to_float32=True) - results = {"img_path": img_path} - - output = loader.transform(results) - - assert output["img"].dtype == np.float32 - - def test_transform_grayscale(self, temp_image): - """Test loading image as grayscale.""" - img_path, original_img = temp_image - - loader = LoadImageFromFile(color_type="grayscale") - results = {"img_path": img_path} - - output = loader.transform(results) - - # Grayscale image should have 2 dimensions - assert len(output["img"].shape) == 2 - assert output["img_shape"] == original_img.shape[:2] - - def test_transform_missing_file(self): - """Test behavior with missing file.""" - loader = LoadImageFromFile(ignore_empty=False) - results = {"img_path": "/nonexistent/path/image.jpg"} - - with pytest.raises(Exception): - loader.transform(results) - - # Test with 
ignore_empty=True - loader = LoadImageFromFile(ignore_empty=True) - output = loader.transform(results) - assert output is None - - def test_registry(self): - """Test that LoadImageFromFile is registered.""" - assert TRANSFORMS.get("LoadImageFromFile") == LoadImageFromFile - - # Test building from config - config = { - "type": "LoadImageFromFile", - "to_float32": True, - "color_type": "color", - } - loader = TRANSFORMS.build(config) - assert isinstance(loader, LoadImageFromFile) - assert loader.to_float32 is True - assert loader.color_type == "color" - - def test_repr(self): - """Test string representation.""" - loader = LoadImageFromFile(to_float32=True, ignore_empty=True, backend_args={"backend": "disk"}) - repr_str = repr(loader) - assert "LoadImageFromFile" in repr_str - assert "to_float32=True" in repr_str - assert "ignore_empty=True" in repr_str - assert "backend_args=" in repr_str - - def test_deprecated_file_client_args(self): - """Test deprecated file_client_args parameter.""" - with pytest.warns(DeprecationWarning): - loader = LoadImageFromFile(file_client_args={"backend": "disk"}) - assert loader.file_client_args == {"backend": "disk"} - assert loader.backend_args is None - - # Test that both args cannot be set - with pytest.raises(ValueError): - LoadImageFromFile(file_client_args={"backend": "disk"}, backend_args={"backend": "disk"}) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/libs/viscv/viscv/__init__.py b/libs/viscv/viscv/__init__.py deleted file mode 100644 index 3dc113a..0000000 --- a/libs/viscv/viscv/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from . import image, transforms -from .image import imfrombytes, imwrite - -__all__ = ["image", "imfrombytes", "imwrite", "transforms"] diff --git a/libs/viscv/viscv/cnn/__init__.py b/libs/viscv/viscv/cnn/__init__.py deleted file mode 100644 index c550b0a..0000000 --- a/libs/viscv/viscv/cnn/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .bricks import ConvModule, build_conv_layer, build_upsample_layer - -__all__ = ["ConvModule", "build_conv_layer", "build_upsample_layer"] diff --git a/libs/viscv/viscv/cnn/bricks/__init__.py b/libs/viscv/viscv/cnn/bricks/__init__.py deleted file mode 100644 index 583e4bd..0000000 --- a/libs/viscv/viscv/cnn/bricks/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from . 
import wrappers as wrappers # Registers ConvTranspose2d as 'deconv' -from .activation import HSigmoid, HSwish, Swish, build_activation_layer -from .conv import build_conv_layer -from .conv_module import ConvModule -from .drop import build_dropout -from .norm import build_norm_layer -from .padding import build_padding_layer -from .scale import LayerScale, Scale -from .transformer import FFN, MultiheadAttention -from .upsample import build_upsample_layer -from .wrappers import ( - Conv2d, - Conv3d, - ConvTranspose2d, - ConvTranspose3d, - Linear, - MaxPool2d, - MaxPool3d, -) - -__all__ = [ - "FFN", - "Conv2d", - "Conv3d", - "ConvModule", - "ConvTranspose2d", - "ConvTranspose3d", - "HSigmoid", - "HSwish", - "LayerScale", - "Linear", - "MaxPool2d", - "MaxPool3d", - "MultiheadAttention", - "Scale", - "Swish", - "build_activation_layer", - "build_conv_layer", - "build_dropout", - "build_norm_layer", - "build_padding_layer", - "build_upsample_layer", -] diff --git a/libs/viscv/viscv/cnn/bricks/activation.py b/libs/viscv/viscv/cnn/bricks/activation.py deleted file mode 100644 index 0f911ac..0000000 --- a/libs/viscv/viscv/cnn/bricks/activation.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from visengine.registry import MODELS - -# Import activation classes from their modules -from .hsigmoid import HSigmoid -from .hswish import HSwish -from .swish import Swish - -for module in [ - nn.ReLU, - nn.LeakyReLU, - nn.PReLU, - nn.RReLU, - nn.ReLU6, - nn.ELU, - nn.Sigmoid, - nn.Tanh, -]: - MODELS.register_module(module=module) - -# Register custom activation modules -MODELS.register_module(module=HSigmoid) -MODELS.register_module(module=HSwish) -MODELS.register_module(module=Swish) -MODELS.register_module(module=nn.SiLU, name="SiLU") -MODELS.register_module(module=nn.GELU) - - -@MODELS.register_module(name="Clip") -@MODELS.register_module() -class Clamp(nn.Module): - """Clamp activation layer. - - This activation function clamps the feature map values to the range - :math:`[min, max]`. More details can be found in ``torch.clamp()``. - - Args: - min (Number, optional): Lower bound of the range to be clamped to. - Default: -1.0. - max (Number, optional): Upper bound of the range to be clamped to. - Default: 1.0. - """ - - def __init__(self, min: float = -1.0, max: float = 1.0): - super().__init__() - self.min = min - self.max = max - - def forward(self, x) -> torch.Tensor: - """Forward function. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: Clamped tensor. - """ - return torch.clamp(x, min=self.min, max=self.max) - - -class GELU(nn.Module): - r"""Applies the Gaussian Error Linear Units function: - - .. math:: - \text{GELU}(x) = x * \Phi(x) - - where :math:`\Phi(x)` is the Cumulative Distribution Function of the - Gaussian Distribution. - - Shape: - - Input: :math:`(N, *)` where `*` means any number of additional - dimensions - - Output: :math:`(N, *)`, same shape as the input - - Examples:: - - >>> m = nn.GELU() - >>> input = torch.randn(2) - >>> output = m(input) - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.gelu(input) - - -def build_activation_layer(cfg: dict) -> nn.Module: - """Build activation layer. - - Args: - cfg (dict): The activation layer config, which should contain: - - - type (str): Layer type. - - layer args: Args needed to instantiate an activation layer.
- - Returns: - nn.Module: Created activation layer. - """ - return MODELS.build(cfg) diff --git a/libs/viscv/viscv/cnn/bricks/conv.py b/libs/viscv/viscv/cnn/bricks/conv.py deleted file mode 100644 index c665549..0000000 --- a/libs/viscv/viscv/cnn/bricks/conv.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect - -from torch import nn -from visengine.registry import MODELS - -MODELS.register_module("Conv1d", module=nn.Conv1d) -MODELS.register_module("Conv2d", module=nn.Conv2d) -MODELS.register_module("Conv3d", module=nn.Conv3d) -MODELS.register_module("Conv", module=nn.Conv2d) - - -def build_conv_layer(cfg: dict | None, *args, **kwargs) -> nn.Module: - """Build convolution layer. - - Args: - cfg (None or dict): The conv layer config, which should contain: - - type (str): Layer type. - - layer args: Args needed to instantiate a conv layer. - args (argument list): Arguments passed to the `__init__` - method of the corresponding conv layer. - kwargs (keyword arguments): Keyword arguments passed to the `__init__` - method of the corresponding conv layer. - - Returns: - nn.Module: Created conv layer. - """ - if cfg is None: - cfg_ = dict(type="Conv2d") - else: - if not isinstance(cfg, dict): - raise TypeError("cfg must be a dict") - if "type" not in cfg: - raise KeyError('the cfg dict must contain the key "type"') - cfg_ = cfg.copy() - - layer_type = cfg_.pop("type") - if inspect.isclass(layer_type): - return layer_type(*args, **kwargs, **cfg_) # type: ignore - # Switch registry to the target scope. If `conv_layer` cannot be found - # in the registry, fall back to searching for `conv_layer` in - # visengine.MODELS. - with MODELS.switch_scope_and_registry(None) as registry: - conv_layer = registry.get(layer_type) - if conv_layer is None: - raise KeyError(f"Cannot find {layer_type} in registry under scope name {registry.scope}") - layer = conv_layer(*args, **kwargs, **cfg_) - - return layer diff --git a/libs/viscv/viscv/cnn/bricks/conv_module.py b/libs/viscv/viscv/cnn/bricks/conv_module.py deleted file mode 100644 index f473b91..0000000 --- a/libs/viscv/viscv/cnn/bricks/conv_module.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from functools import partial - -import torch -import torch.nn as nn -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.instancenorm import _InstanceNorm -from visengine.model import constant_init, kaiming_init -from visengine.registry import MODELS - -from .activation import build_activation_layer -from .conv import build_conv_layer -from .norm import build_norm_layer -from .padding import build_padding_layer - - -def efficient_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor): - """ - Implementation based on https://arxiv.org/abs/2305.11624 - "Tune-Mode ConvBN Blocks For Efficient Transfer Learning". - It leverages the associative law between convolution and affine transform, - i.e., normalize (weight conv feature) = (normalize weight) conv feature. - It works for the eval mode of ConvBN blocks during validation, and can be - used for training as well. It reduces memory and computation cost. - - Args: - bn (_BatchNorm): a BatchNorm module. - conv (nn._ConvNd): a conv module. - x (torch.Tensor): Input feature map.
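- - In scalar form, for a BatchNorm in eval mode: - - bn(conv(x)) = gamma * (W*x + b - mu) / sqrt(var + eps) + beta - = ((gamma / sqrt(var + eps)) * W) * x - + (beta + (gamma / sqrt(var + eps)) * (b - mu)) - - so the BN statistics can be folded into an adjusted conv weight and bias, - which is what this function computes on the fly.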
- """ - # These lines of code are designed to deal with various cases - # like bn without affine transform, and conv without bias - weight_on_the_fly = conv.weight - if conv.bias is not None: - bias_on_the_fly = conv.bias - else: - bias_on_the_fly = torch.zeros_like(bn.running_var) - - if bn.weight is not None: - bn_weight = bn.weight - else: - bn_weight = torch.ones_like(bn.running_var) - - if bn.bias is not None: - bn_bias = bn.bias - else: - bn_bias = torch.zeros_like(bn.running_var) - - # shape of [C_out, 1, 1, 1] in Conv2d - weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape([-1] + [1] * (len(conv.weight.shape) - 1)) - # shape of [C_out, 1, 1, 1] in Conv2d - coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff - - # shape of [C_out, C_in, k, k] in Conv2d - weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly - # shape of [C_out] in Conv2d - bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (bias_on_the_fly - bn.running_mean) - - return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly) - - -@MODELS.register_module() -class ConvModule(nn.Module): - """A conv block that bundles conv/norm/activation layers. - - This block simplifies the usage of convolution layers, which are commonly - used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). - It is based upon three build methods: `build_conv_layer()`, - `build_norm_layer()` and `build_activation_layer()`. - - Besides, we add some additional features in this module. - 1. Automatically set `bias` of the conv layer. - 2. Spectral norm is supported. - 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only - supports zero and circular padding, and we add "reflect" padding mode. - - Args: - in_channels (int): Number of channels in the input feature map. - Same as that in ``nn._ConvNd``. - out_channels (int): Number of channels produced by the convolution. - Same as that in ``nn._ConvNd``. - kernel_size (int | tuple[int]): Size of the convolving kernel. - Same as that in ``nn._ConvNd``. - stride (int | tuple[int]): Stride of the convolution. - Same as that in ``nn._ConvNd``. - padding (int | tuple[int]): Zero-padding added to both sides of - the input. Same as that in ``nn._ConvNd``. - dilation (int | tuple[int]): Spacing between kernel elements. - Same as that in ``nn._ConvNd``. - groups (int): Number of blocked connections from input channels to - output channels. Same as that in ``nn._ConvNd``. - bias (bool | str): If specified as `auto`, it will be decided by the - norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise - False. Default: "auto". - conv_cfg (dict): Config dict for convolution layer. Default: None, - which means using conv2d. - norm_cfg (dict): Config dict for normalization layer. Default: None. - act_cfg (dict): Config dict for activation layer. - Default: dict(type='ReLU'). - inplace (bool): Whether to use inplace mode for activation. - Default: True. - with_spectral_norm (bool): Whether use spectral norm in conv module. - Default: False. - padding_mode (str): If the `padding_mode` has not been supported by - current `Conv2d` in PyTorch, we will use our own padding layer - instead. Currently, we support ['zeros', 'circular'] with official - implementation and ['reflect'] with our own implementation. - Default: 'zeros'. - order (tuple[str]): The order of conv/norm/activation layers. It is a - sequence of "conv", "norm" and "act". Common examples are - ("conv", "norm", "act") and ("act", "conv", "norm"). 
- Default: ('conv', 'norm', 'act'). - efficient_conv_bn_eval (bool): Whether to use the efficient conv when - the consecutive bn is in eval mode (either training or testing), as - proposed in https://arxiv.org/abs/2305.11624. Default: `False`. - """ - - _abbr_ = "conv_block" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int | tuple[int, int], - stride: int | tuple[int, int] = 1, - padding: int | tuple[int, int] = 0, - dilation: int | tuple[int, int] = 1, - groups: int = 1, - bias: bool | str = "auto", - conv_cfg: dict | None = None, - norm_cfg: dict | None = None, - act_cfg: dict | None = dict(type="ReLU"), - inplace: bool = True, - with_spectral_norm: bool = False, - padding_mode: str = "zeros", - order: tuple = ("conv", "norm", "act"), - efficient_conv_bn_eval: bool = False, - ): - super().__init__() - assert conv_cfg is None or isinstance(conv_cfg, dict) - assert norm_cfg is None or isinstance(norm_cfg, dict) - assert act_cfg is None or isinstance(act_cfg, dict) - official_padding_mode = ["zeros", "circular"] - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg - self.inplace = inplace - self.with_spectral_norm = with_spectral_norm - self.with_explicit_padding = padding_mode not in official_padding_mode - self.order = order - assert isinstance(self.order, tuple) and len(self.order) == 3 - assert set(order) == {"conv", "norm", "act"} - - self.with_norm = norm_cfg is not None - self.with_activation = act_cfg is not None - # if the conv layer is before a norm layer, bias is unnecessary. - if bias == "auto": - bias = not self.with_norm - self.with_bias = bias - - if self.with_explicit_padding: - pad_cfg = dict(type=padding_mode) - self.padding_layer = build_padding_layer(pad_cfg, padding) - - # reset padding to 0 for conv module - conv_padding = 0 if self.with_explicit_padding else padding - # build convolution layer - self.conv = build_conv_layer( - conv_cfg, - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=conv_padding, - dilation=dilation, - groups=groups, - bias=bias, - ) - # export the attributes of self.conv to a higher level for convenience - self.in_channels = self.conv.in_channels - self.out_channels = self.conv.out_channels - self.kernel_size = self.conv.kernel_size - self.stride = self.conv.stride - self.padding = padding - self.dilation = self.conv.dilation - self.transposed = self.conv.transposed - self.output_padding = self.conv.output_padding - self.groups = self.conv.groups - - if self.with_spectral_norm: - self.conv = nn.utils.spectral_norm(self.conv) - - # build normalization layers - if self.with_norm: - # norm layer is after conv layer - if order.index("norm") > order.index("conv"): - norm_channels = out_channels - else: - norm_channels = in_channels - self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) # type: ignore - self.add_module(self.norm_name, norm) - if self.with_bias: - if isinstance(norm, (_BatchNorm, _InstanceNorm)): - warnings.warn("Unnecessary conv bias before batch/instance norm") - else: - self.norm_name = None # type: ignore - - self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval) - - # build activation layer - if self.with_activation: - act_cfg_ = act_cfg.copy() # type: ignore - # nn.Tanh has no 'inplace' argument - if act_cfg_["type"] not in [ - "Tanh", - "PReLU", - "Sigmoid", - "HSigmoid", - "Swish", - "GELU", - ]: - act_cfg_.setdefault("inplace", inplace) - self.activate = build_activation_layer(act_cfg_) - - # Use msra init by default -
self.init_weights() - - @property - def norm(self): - if self.norm_name: - return getattr(self, self.norm_name) - else: - return None - - def init_weights(self): - # 1. It is mainly for customized conv layers with their own - # initialization manners by calling their own ``init_weights()``, - # and we do not want ConvModule to override the initialization. - # 2. For customized conv layers without their own initialization - # manners (that is, they don't have their own ``init_weights()``) - # and PyTorch's conv layers, they will be initialized by - # this method with default ``kaiming_init``. - # Note: For PyTorch's conv layers, they will be overwritten by our - # initialization implementation using default ``kaiming_init``. - if not hasattr(self.conv, "init_weights"): - if self.with_activation and self.act_cfg["type"] == "LeakyReLU": - nonlinearity = "leaky_relu" - a = self.act_cfg.get("negative_slope", 0.01) - else: - nonlinearity = "relu" - a = 0 - kaiming_init(self.conv, a=a, nonlinearity=nonlinearity) - if self.with_norm: - constant_init(self.norm, 1, bias=0) - - def forward(self, x: torch.Tensor, activate: bool = True, norm: bool = True) -> torch.Tensor: - layer_index = 0 - while layer_index < len(self.order): - layer = self.order[layer_index] - if layer == "conv": - if self.with_explicit_padding: - x = self.padding_layer(x) - # if the next operation is norm and we have a norm layer in - # eval mode and we have enabled `efficient_conv_bn_eval` for - # the conv operator, then activate the optimized forward and - # skip the next norm operator since it has been fused - if ( - layer_index + 1 < len(self.order) - and self.order[layer_index + 1] == "norm" - and norm - and self.with_norm - and not self.norm.training - and self.efficient_conv_bn_eval_forward is not None - ): - self.conv.forward = partial(self.efficient_conv_bn_eval_forward, self.norm, self.conv) - layer_index += 1 - x = self.conv(x) - del self.conv.forward - else: - x = self.conv(x) - elif layer == "norm" and norm and self.with_norm: - x = self.norm(x) - elif layer == "act" and activate and self.with_activation: - x = self.activate(x) - layer_index += 1 - return x - - def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True): - # efficient_conv_bn_eval works for conv + bn - # with `track_running_stats` option - if efficient_conv_bn_eval and self.norm and isinstance(self.norm, _BatchNorm) and self.norm.track_running_stats: - self.efficient_conv_bn_eval_forward = efficient_conv_bn_eval_forward - else: - self.efficient_conv_bn_eval_forward = None # type: ignore - - @staticmethod - def create_from_conv_bn( - conv: torch.nn.modules.conv._ConvNd, - bn: torch.nn.modules.batchnorm._BatchNorm, - efficient_conv_bn_eval=True, - ) -> "ConvModule": - """Create a ConvModule from a conv and a bn module.""" - self = ConvModule.__new__(ConvModule) - super(ConvModule, self).__init__() - - self.conv_cfg = None - self.norm_cfg = None - self.act_cfg = None - self.inplace = False - self.with_spectral_norm = False - self.with_explicit_padding = False - self.order = ("conv", "norm", "act") - - self.with_norm = True - self.with_activation = False - self.with_bias = conv.bias is not None - - # build convolution layer - self.conv = conv - # export the attributes of self.conv to a higher level for convenience - self.in_channels = self.conv.in_channels - self.out_channels = self.conv.out_channels - self.kernel_size = self.conv.kernel_size - self.stride = self.conv.stride - self.padding = self.conv.padding - self.dilation = self.conv.dilation 
- self.transposed = self.conv.transposed - self.output_padding = self.conv.output_padding - self.groups = self.conv.groups - - # build normalization layers - self.norm_name, norm = "bn", bn - self.add_module(self.norm_name, norm) - - self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval) - - return self diff --git a/libs/viscv/viscv/cnn/bricks/drop.py b/libs/viscv/viscv/cnn/bricks/drop.py deleted file mode 100644 index ca1b07b..0000000 --- a/libs/viscv/viscv/cnn/bricks/drop.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Any - -import torch -import torch.nn as nn -from visengine.registry import MODELS - - -def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: - """Drop paths (Stochastic Depth) per sample (when applied in main path of - residual blocks). - - We follow the implementation - https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py - # noqa: E501 - """ - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - # handle tensors with different dimensions, not just 4D tensors. - shape = (x.shape[0],) + (1,) * (x.ndim - 1) - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - output = x.div(keep_prob) * random_tensor.floor() - return output - - -@MODELS.register_module() -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of - residual blocks). - - We follow the implementation - https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501 - - Args: - drop_prob (float): Probability of the path to be zeroed. Default: 0.1 - """ - - def __init__(self, drop_prob: float = 0.1): - super().__init__() - self.drop_prob = drop_prob - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return drop_path(x, self.drop_prob, self.training) - - -@MODELS.register_module() -class Dropout(nn.Dropout): - """A wrapper for ``torch.nn.Dropout``. We rename the ``p`` of - ``torch.nn.Dropout`` to ``drop_prob`` to be consistent with - ``DropPath``. - - Args: - drop_prob (float): Probability of the elements to be - zeroed. Default: 0.5. - inplace (bool): Do the operation inplace or not. Default: False. - """ - - def __init__(self, drop_prob: float = 0.5, inplace: bool = False): - super().__init__(p=drop_prob, inplace=inplace) - - -def build_dropout(cfg: dict | float | None, default_args: dict | None = None) -> Any: - """Builder for dropout layers.""" - if cfg is None: - return None - if isinstance(cfg, float): - cfg = dict(type="Dropout", drop_prob=cfg) - return MODELS.build(cfg, default_args=default_args) diff --git a/libs/viscv/viscv/cnn/bricks/hsigmoid.py b/libs/viscv/viscv/cnn/bricks/hsigmoid.py deleted file mode 100644 index 5d2be8b..0000000 --- a/libs/viscv/viscv/cnn/bricks/hsigmoid.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings - -import torch -import torch.nn as nn -from visengine.registry import MODELS - - -@MODELS.register_module() -class HSigmoid(nn.Module): - """Hard Sigmoid Module. Applies the hard sigmoid function: - Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value) - Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1) - - Note: - In MMCV v1.4.4, we modified the default value of args to align with - PyTorch official. - - Args: - bias (float): Bias of the input feature map. Default: 3.0.
- divisor (float): Divisor of the input feature map. Default: 6.0. - min_value (float): Lower bound value. Default: 0.0. - max_value (float): Upper bound value. Default: 1.0. - - Returns: - Tensor: The output tensor. - """ - - def __init__( - self, - bias: float = 3.0, - divisor: float = 6.0, - min_value: float = 0.0, - max_value: float = 1.0, - ): - super().__init__() - warnings.warn( - "In MMCV v1.4.4, we modified the default value of args to align " - "with PyTorch official. Previous Implementation: " - "Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). " - "Current Implementation: " - "Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).", - stacklevel=2, - ) - self.bias = bias - self.divisor = divisor - assert self.divisor != 0 - self.min_value = min_value - self.max_value = max_value - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = (x + self.bias) / self.divisor - - return x.clamp_(self.min_value, self.max_value) diff --git a/libs/viscv/viscv/cnn/bricks/hswish.py b/libs/viscv/viscv/cnn/bricks/hswish.py deleted file mode 100644 index 490c17b..0000000 --- a/libs/viscv/viscv/cnn/bricks/hswish.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn -from visengine.registry import MODELS - -HSwish = nn.Hardswish -MODELS.register_module(module=nn.Hardswish, name="HSwish") diff --git a/libs/viscv/viscv/cnn/bricks/norm.py b/libs/viscv/viscv/cnn/bricks/norm.py deleted file mode 100644 index 1fd75eb..0000000 --- a/libs/viscv/viscv/cnn/bricks/norm.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect - -import torch.nn as nn -from torch.nn import SyncBatchNorm -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.instancenorm import _InstanceNorm -from visengine.registry import MODELS -from visengine.utils import is_tuple_of - -MODELS.register_module("BN", module=nn.BatchNorm2d) -MODELS.register_module("BN1d", module=nn.BatchNorm1d) -MODELS.register_module("BN2d", module=nn.BatchNorm2d) -MODELS.register_module("BN3d", module=nn.BatchNorm3d) -MODELS.register_module("SyncBN", module=SyncBatchNorm) -MODELS.register_module("GN", module=nn.GroupNorm) -MODELS.register_module("LN", module=nn.LayerNorm) -MODELS.register_module("IN", module=nn.InstanceNorm2d) -MODELS.register_module("IN1d", module=nn.InstanceNorm1d) -MODELS.register_module("IN2d", module=nn.InstanceNorm2d) -MODELS.register_module("IN3d", module=nn.InstanceNorm3d) - - -def infer_abbr(class_type): - """Infer abbreviation from the class name. - - When we build a norm layer with `build_norm_layer()`, we want to preserve - the norm type in variable names, e.g., self.bn1, self.gn. This method will - infer the abbreviation to map class types to abbreviations. - - Rule 1: If the class has the property "_abbr_", return the property. - Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or - InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and - "in" respectively. - Rule 3: If the class name contains "batch", "group", "layer" or "instance", - the abbreviation of this layer will be "bn", "gn", "ln" and "in" - respectively. - Rule 4: Otherwise, the abbreviation falls back to "norm_layer". - - Args: - class_type (type): The norm layer type. - - Returns: - str: The inferred abbreviation.
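- - Example (illustrative, applying the rules above to common PyTorch layers): - >>> import torch.nn as nn - >>> infer_abbr(nn.BatchNorm2d) - 'bn' - >>> infer_abbr(nn.InstanceNorm2d) - 'in' - >>> infer_abbr(nn.GroupNorm) - 'gn'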
- """ - if not inspect.isclass(class_type): - raise TypeError(f"class_type must be a type, but got {type(class_type)}") - if hasattr(class_type, "_abbr_"): - return class_type._abbr_ - if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN - return "in" - elif issubclass(class_type, _BatchNorm): - return "bn" - elif issubclass(class_type, nn.GroupNorm): - return "gn" - elif issubclass(class_type, nn.LayerNorm): - return "ln" - else: - class_name = class_type.__name__.lower() - if "batch" in class_name: - return "bn" - elif "group" in class_name: - return "gn" - elif "layer" in class_name: - return "ln" - elif "instance" in class_name: - return "in" - else: - return "norm_layer" - - -def build_norm_layer(cfg: dict, num_features: int, postfix: int | str = "") -> tuple[str, nn.Module]: - """Build normalization layer. - - Args: - cfg (dict): The norm layer config, which should contain: - - - type (str): Layer type. - - layer args: Args needed to instantiate a norm layer. - - requires_grad (bool, optional): Whether stop gradient updates. - num_features (int): Number of input channels. - postfix (int | str): The postfix to be appended into norm abbreviation - to create named layer. - - Returns: - tuple[str, nn.Module]: The first element is the layer name consisting - of abbreviation and postfix, e.g., bn1, gn. The second element is the - created norm layer. - """ - if not isinstance(cfg, dict): - raise TypeError("cfg must be a dict") - if "type" not in cfg: - raise KeyError('the cfg dict must contain the key "type"') - cfg_ = cfg.copy() - - layer_type = cfg_.pop("type") - - if inspect.isclass(layer_type): - norm_layer = layer_type - else: - # Switch registry to the target scope. If `norm_layer` cannot be found - # in the registry, fallback to search `norm_layer` in the - # mmengine.MODELS. - with MODELS.switch_scope_and_registry(None) as registry: - norm_layer = registry.get(layer_type) - if norm_layer is None: - raise KeyError(f"Cannot find {norm_layer} in registry under scope name {registry.scope}") - abbr = infer_abbr(norm_layer) - - assert isinstance(postfix, (int, str)) - name = abbr + str(postfix) - - requires_grad = cfg_.pop("requires_grad", True) - cfg_.setdefault("eps", 1e-5) - if norm_layer is not nn.GroupNorm: - layer = norm_layer(num_features, **cfg_) - if layer_type == "SyncBN" and hasattr(layer, "_specify_ddp_gpu_num"): - layer._specify_ddp_gpu_num(1) - else: - assert "num_groups" in cfg_ - layer = norm_layer(num_channels=num_features, **cfg_) - - for param in layer.parameters(): - param.requires_grad = requires_grad - - return name, layer - - -def is_norm(layer: nn.Module, exclude: type | tuple | None = None) -> bool: - """Check if a layer is a normalization layer. - - Args: - layer (nn.Module): The layer to be checked. - exclude (type | tuple[type]): Types to be excluded. - - Returns: - bool: Whether the layer is a norm layer. 
- """ - if exclude is not None: - if not isinstance(exclude, tuple): - exclude = (exclude,) - if not is_tuple_of(exclude, type): - raise TypeError( - f'"exclude" must be either None or type or a tuple of types, but got {type(exclude)}: {exclude}' - ) - - if exclude and isinstance(layer, exclude): - return False - - all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) - return isinstance(layer, all_norm_bases) diff --git a/libs/viscv/viscv/cnn/bricks/padding.py b/libs/viscv/viscv/cnn/bricks/padding.py deleted file mode 100644 index df0b537..0000000 --- a/libs/viscv/viscv/cnn/bricks/padding.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect - -import torch.nn as nn -from visengine.registry import MODELS - -MODELS.register_module("zero", module=nn.ZeroPad2d) -MODELS.register_module("reflect", module=nn.ReflectionPad2d) -MODELS.register_module("replicate", module=nn.ReplicationPad2d) - - -def build_padding_layer(cfg: dict, *args, **kwargs) -> nn.Module: - """Build padding layer. - - Args: - cfg (dict): The padding layer config, which should contain: - - type (str): Layer type. - - layer args: Args needed to instantiate a padding layer. - - Returns: - nn.Module: Created padding layer. - """ - if not isinstance(cfg, dict): - raise TypeError("cfg must be a dict") - if "type" not in cfg: - raise KeyError('the cfg dict must contain the key "type"') - - cfg_ = cfg.copy() - padding_type = cfg_.pop("type") - if inspect.isclass(padding_type): - return padding_type(*args, **kwargs, **cfg_) - # Switch registry to the target scope. If `padding_layer` cannot be found - # in the registry, fallback to search `padding_layer` in the - # mmengine.MODELS. - with MODELS.switch_scope_and_registry(None) as registry: - padding_layer = registry.get(padding_type) - if padding_layer is None: - raise KeyError(f"Cannot find {padding_layer} in registry under scope name {registry.scope}") - layer = padding_layer(*args, **kwargs, **cfg_) - - return layer diff --git a/libs/viscv/viscv/cnn/bricks/scale.py b/libs/viscv/viscv/cnn/bricks/scale.py deleted file mode 100644 index e708786..0000000 --- a/libs/viscv/viscv/cnn/bricks/scale.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn - - -class Scale(nn.Module): - """A learnable scale parameter. - - This layer scales the input by a learnable factor. It multiplies a - learnable scale parameter of shape (1,) with input of any shape. - - Args: - scale (float): Initial value of scale factor. Default: 1.0 - """ - - def __init__(self, scale: float = 1.0): - super().__init__() - self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * self.scale - - -class LayerScale(nn.Module): - """LayerScale layer. - - Args: - dim (int): Dimension of input features. - inplace (bool): Whether performs operation in-place. - Default: `False`. - data_format (str): The input data format, could be 'channels_last' - or 'channels_first', representing (B, C, H, W) and - (B, N, C) format data respectively. Default: 'channels_last'. - scale (float): Initial value of scale factor. Default: 1.0 - """ - - def __init__( - self, - dim: int, - inplace: bool = False, - data_format: str = "channels_last", - scale: float = 1e-5, - ): - super().__init__() - assert data_format in ("channels_last", "channels_first"), ( - "'data_format' could only be channels_last or channels_first." 
- ) - self.inplace = inplace - self.data_format = data_format - self.weight = nn.Parameter(torch.ones(dim) * scale) - - def forward(self, x) -> torch.Tensor: - if self.data_format == "channels_first": - shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2)))) # noqa: C409 - else: - shape = tuple((*(1 for _ in range(x.dim() - 1)), -1)) # noqa: C409 - if self.inplace: - return x.mul_(self.weight.view(*shape)) - else: - return x * self.weight.view(*shape) diff --git a/libs/viscv/viscv/cnn/bricks/swish.py b/libs/viscv/viscv/cnn/bricks/swish.py deleted file mode 100644 index 192d4f5..0000000 --- a/libs/viscv/viscv/cnn/bricks/swish.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn -from visengine.registry import MODELS - - -@MODELS.register_module() -class Swish(nn.Module): - """Swish Module. - - This module applies the swish function: - - .. math:: - Swish(x) = x * Sigmoid(x) - - Returns: - Tensor: The output tensor. - """ - - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * torch.sigmoid(x) diff --git a/libs/viscv/viscv/cnn/bricks/transformer.py b/libs/viscv/viscv/cnn/bricks/transformer.py deleted file mode 100644 index ea042a9..0000000 --- a/libs/viscv/viscv/cnn/bricks/transformer.py +++ /dev/null @@ -1,932 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import math -import warnings -from collections.abc import Sequence - -import torch -import torch.nn as nn -import torch.nn.functional as F -from visengine.config import ConfigDict -from visengine.model import BaseModule, ModuleList, Sequential -from visengine.registry import MODELS -from visengine.utils import deprecated_api_warning, to_2tuple - -from viscv.cnn.bricks.activation import build_activation_layer -from viscv.cnn.bricks.conv import build_conv_layer -from viscv.cnn.bricks.norm import build_norm_layer -from viscv.cnn.bricks.wrappers import Linear - -from .drop import build_dropout -from .scale import LayerScale - - -def build_positional_encoding(cfg, default_args=None): - """Builder for Position Encoding.""" - return MODELS.build(cfg, default_args=default_args) - - -def build_attention(cfg, default_args=None): - """Builder for attention.""" - return MODELS.build(cfg, default_args=default_args) - - -def build_feedforward_network(cfg, default_args=None): - """Builder for feed-forward network (FFN).""" - return MODELS.build(cfg, default_args=default_args) - - -def build_transformer_layer(cfg, default_args=None): - """Builder for transformer layer.""" - return MODELS.build(cfg, default_args=default_args) - - -def build_transformer_layer_sequence(cfg, default_args=None): - """Builder for transformer encoder and transformer decoder.""" - return MODELS.build(cfg, default_args=default_args) - - -class AdaptivePadding(nn.Module): - """Applies padding adaptively to the input. - - This module pads the input adaptively so that it gets fully covered by the - filter you specify. It supports two modes, "same" and "corner". The - "same" mode is the same as the "SAME" padding mode in TensorFlow and pads - zeros around the input. The "corner" mode pads zeros - to the bottom right. - - Args: - kernel_size (int | tuple): Size of the kernel. Default: 1. - stride (int | tuple): Stride of the filter. Default: 1. - dilation (int | tuple): Spacing between kernel elements. - Default: 1. - padding (str): Supports "same" and "corner"; the "corner" mode - pads zeros to the bottom right, and the "same" mode - pads zeros around the input. Default: "corner".
- - Example: - >>> kernel_size = 16 - >>> stride = 16 - >>> dilation = 1 - >>> input = torch.rand(1, 1, 15, 17) - >>> adap_pad = AdaptivePadding( - >>> kernel_size=kernel_size, - >>> stride=stride, - >>> dilation=dilation, - >>> padding="corner") - >>> out = adap_pad(input) - >>> assert (out.shape[2], out.shape[3]) == (16, 32) - >>> input = torch.rand(1, 1, 16, 17) - >>> out = adap_pad(input) - >>> assert (out.shape[2], out.shape[3]) == (16, 32) - """ - - def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): - super().__init__() - assert padding in ("same", "corner") - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - dilation = to_2tuple(dilation) - - self.padding = padding - self.kernel_size = kernel_size - self.stride = stride - self.dilation = dilation - - def get_pad_shape(self, input_shape): - """Calculate the padding size of input. - - Args: - input_shape (:obj:`torch.Size`): arranged as (H, W). - - Returns: - Tuple[int]: The padding size along the - original H and W directions. - """ - input_h, input_w = input_shape - kernel_h, kernel_w = self.kernel_size - stride_h, stride_w = self.stride - output_h = math.ceil(input_h / stride_h) - output_w = math.ceil(input_w / stride_w) - pad_h = max( - (output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, - 0, - ) - pad_w = max( - (output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, - 0, - ) - return pad_h, pad_w - - def forward(self, x): - """Add padding to `x`. - - Args: - x (Tensor): Input tensor has shape (B, C, H, W). - - Returns: - Tensor: The tensor with adaptive padding. - """ - pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) - if pad_h > 0 or pad_w > 0: - if self.padding == "corner": - x = F.pad(x, [0, pad_w, 0, pad_h]) - elif self.padding == "same": - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - return x - - -class PatchEmbed(BaseModule): - """Image to Patch Embedding. - - We use a conv layer to implement PatchEmbed. - - Args: - in_channels (int): The num of input channels. Default: 3 - embed_dims (int): The dimensions of embedding. Default: 768 - conv_type (str): The type of convolution - to generate patch embedding. Default: "Conv2d". - kernel_size (int): The kernel_size of embedding conv. Default: 16. - stride (int): The stride of the embedding conv. - Default: 16. - padding (int | tuple | str): The padding length of the - embedding conv. When it is a string, it specifies the mode - of adaptive padding; "same" and "corner" are supported. - Default: "corner". - dilation (int): The dilation rate of embedding conv. Default: 1. - bias (bool): Bias of embed conv. Default: True. - norm_cfg (dict, optional): Config dict for normalization layer. - Default: None. - input_size (int | tuple | None): The size of input, which will be - used to calculate the out size. Only works when `dynamic_size` - is False. Default: None. - init_cfg (`ConfigDict`, optional): The Config for initialization. - Default: None.
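- - Example (illustrative; shapes follow the default 16x16 kernel and stride): - >>> import torch - >>> patch_embed = PatchEmbed(in_channels=3, embed_dims=96) - >>> x, out_size = patch_embed(torch.rand(1, 3, 224, 224)) - >>> x.shape - torch.Size([1, 196, 96]) - >>> out_size - (14, 14)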
- """ - - def __init__( - self, - in_channels=3, - embed_dims=768, - conv_type="Conv2d", - kernel_size=16, - stride=16, - padding="corner", - dilation=1, - bias=True, - norm_cfg=None, - input_size=None, - init_cfg=None, - ): - super().__init__(init_cfg=init_cfg) - - self.embed_dims = embed_dims - if stride is None: - stride = kernel_size - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - dilation = to_2tuple(dilation) - - if isinstance(padding, str): - self.adaptive_padding = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding, - ) - # disable the padding of conv - padding = 0 - else: - self.adaptive_padding = None - padding = to_2tuple(padding) - - self.projection = build_conv_layer( - dict(type=conv_type), - in_channels=in_channels, - out_channels=embed_dims, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - - if norm_cfg is not None: - self.norm = build_norm_layer(norm_cfg, embed_dims)[1] - else: - self.norm = None - - if input_size: - input_size = to_2tuple(input_size) - # `init_out_size` would be used outside to - # calculate the num_patches - # e.g. when `use_abs_pos_embed` outside - self.init_input_size = input_size - if self.adaptive_padding: - pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size) - input_h, input_w = input_size - input_h = input_h + pad_h - input_w = input_w + pad_w - input_size = (input_h, input_w) - - # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html - h_out = (input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1 - w_out = (input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 - self.init_out_size = (h_out, w_out) - else: - self.init_input_size = None - self.init_out_size = None - - def forward(self, x): - """ - Args: - x (Tensor): Has shape (B, C, H, W). In most case, C is 3. - - Returns: - tuple: Contains merged results and its spatial shape. - - - x (Tensor): Has shape (B, out_h * out_w, embed_dims) - - out_size (tuple[int]): Spatial shape of x, arrange as - (out_h, out_w). - """ - - if self.adaptive_padding: - x = self.adaptive_padding(x) - - x = self.projection(x) - out_size = (x.shape[2], x.shape[3]) - x = x.flatten(2).transpose(1, 2) - if self.norm is not None: - x = self.norm(x) - return x, out_size - - -class PatchMerging(BaseModule): - """Merge patch feature map. - - This layer groups feature map by kernel_size, and applies norm and linear - layers to the grouped feature map ((used in Swin Transformer)). - Our implementation uses `nn.Unfold` to - merge patches, which is about 25% faster than the original - implementation. However, we need to modify pretrained - models for compatibility. - - Args: - in_channels (int): The num of input channels. - to gets fully covered by filter and stride you specified. - out_channels (int): The num of output channels. - kernel_size (int | tuple, optional): the kernel size in the unfold - layer. Defaults to 2. - stride (int | tuple, optional): the stride of the sliding blocks in the - unfold layer. Default: None. (Would be set as `kernel_size`) - padding (int | tuple | string ): The padding length of - embedding conv. When it is a string, it means the mode - of adaptive padding, support "same" and "corner" now. - Default: "corner". - dilation (int | tuple, optional): dilation parameter in the unfold - layer. Default: 1. - bias (bool, optional): Whether to add bias in linear layer or not. 
- Defaults: False. - norm_cfg (dict, optional): Config dict for normalization layer. - Default: dict(type='LN'). - init_cfg (dict, optional): The extra config for initialization. - Default: None. - """ - - def __init__( - self, - in_channels, - out_channels, - kernel_size=2, - stride=None, - padding="corner", - dilation=1, - bias=False, - norm_cfg=dict(type="LN"), - init_cfg=None, - ): - super().__init__(init_cfg=init_cfg) - self.in_channels = in_channels - self.out_channels = out_channels - if stride: - stride = stride - else: - stride = kernel_size - - kernel_size = to_2tuple(kernel_size) - stride = to_2tuple(stride) - dilation = to_2tuple(dilation) - - if isinstance(padding, str): - self.adaptive_padding = AdaptivePadding( - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding, - ) - # disable the padding of unfold - padding = 0 - else: - self.adaptive_padding = None - - padding = to_2tuple(padding) - self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) - - sample_dim = kernel_size[0] * kernel_size[1] * in_channels - - if norm_cfg is not None: - self.norm = build_norm_layer(norm_cfg, sample_dim)[1] - else: - self.norm = None - - self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) - - def forward(self, x, input_size): - """ - Args: - x (Tensor): Has shape (B, H*W, C_in). - input_size (tuple[int]): The spatial shape of x, arrange as (H, W). - Default: None. - - Returns: - tuple: Contains merged results and its spatial shape. - - - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) - - out_size (tuple[int]): Spatial shape of x, arrange as - (Merged_H, Merged_W). - """ - B, L, C = x.shape - assert isinstance(input_size, Sequence), f"Expect input_size is `Sequence` but get {input_size}" - - H, W = input_size - assert L == H * W, "input feature has wrong size" - - x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W - - if self.adaptive_padding: - x = self.adaptive_padding(x) - H, W = x.shape[-2:] - - # Use nn.Unfold to merge patch. About 25% faster than original method, - # but need to modify pretrained model for compatibility - # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) - x = self.sampler(x) - - out_h = ( - H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1 - ) // self.sampler.stride[0] + 1 - out_w = ( - W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1 - ) // self.sampler.stride[1] + 1 - - output_size = (out_h, out_w) - x = x.transpose(1, 2) # B, H/2*W/2, 4*C - x = self.norm(x) if self.norm else x - x = self.reduction(x) - return x, output_size - - -@MODELS.register_module() -class MultiheadAttention(BaseModule): - """A wrapper for ``torch.nn.MultiheadAttention``. - - This module implements MultiheadAttention with identity connection, - and positional encoding is also passed as input. - - Args: - embed_dims (int): The embedding dimension. - num_heads (int): Parallel attention heads. - attn_drop (float): A Dropout layer on attn_output_weights. - Default: 0.0. - proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. - Default: 0.0. - dropout_layer (obj:`ConfigDict`): The dropout_layer used - when adding the shortcut. - init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. - Default: None. - batch_first (bool): When it is True, Key, Query and Value are shape of - (batch, n, embed_dim), otherwise (n, batch, embed_dim). - Default to False. 
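-
-    Example:
-        >>> # usage sketch; the dimensions below are assumptions
-        >>> self_attn = MultiheadAttention(
-        >>>     embed_dims=256, num_heads=8, batch_first=True)
-        >>> x = torch.rand(2, 100, 256)  # (bs, num_queries, embed_dims)
-        >>> out = self_attn(x)  # key/value/identity default to the query
-        >>> assert out.shape == x.shape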
- """ - - def __init__( - self, - embed_dims, - num_heads, - attn_drop=0.0, - proj_drop=0.0, - dropout_layer=dict(type="Dropout", drop_prob=0.0), - init_cfg=None, - batch_first=False, - **kwargs, - ): - super().__init__(init_cfg) - if "dropout" in kwargs: - warnings.warn( - "The arguments `dropout` in MultiheadAttention " - "has been deprecated, now you can separately " - "set `attn_drop`(float), proj_drop(float), " - "and `dropout_layer`(dict) ", - DeprecationWarning, - ) - attn_drop = kwargs["dropout"] - dropout_layer["drop_prob"] = kwargs.pop("dropout") - - self.embed_dims = embed_dims - self.num_heads = num_heads - self.batch_first = batch_first - - self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, **kwargs) - - self.proj_drop = nn.Dropout(proj_drop) - self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else nn.Identity() - - @deprecated_api_warning({"residual": "identity"}, cls_name="MultiheadAttention") - def forward( - self, - query, - key=None, - value=None, - identity=None, - query_pos=None, - key_pos=None, - attn_mask=None, - key_padding_mask=None, - **kwargs, - ): - """Forward function for `MultiheadAttention`. - - **kwargs allow passing a more general data flow when combining - with other operations in `transformerlayer`. - - Args: - query (Tensor): The input query with shape [num_queries, bs, - embed_dims] if self.batch_first is False, else - [bs, num_queries embed_dims]. - key (Tensor): The key tensor with shape [num_keys, bs, - embed_dims] if self.batch_first is False, else - [bs, num_keys, embed_dims] . - If None, the ``query`` will be used. Defaults to None. - value (Tensor): The value tensor with same shape as `key`. - Same in `nn.MultiheadAttention.forward`. Defaults to None. - If None, the `key` will be used. - identity (Tensor): This tensor, with the same shape as x, - will be used for the identity link. - If None, `x` will be used. Defaults to None. - query_pos (Tensor): The positional encoding for query, with - the same shape as `x`. If not None, it will - be added to `x` before forward function. Defaults to None. - key_pos (Tensor): The positional encoding for `key`, with the - same shape as `key`. Defaults to None. If not None, it will - be added to `key` before forward function. If None, and - `query_pos` has the same shape as `key`, then `query_pos` - will be used for `key_pos`. Defaults to None. - attn_mask (Tensor): ByteTensor mask with shape [num_queries, - num_keys]. Same in `nn.MultiheadAttention.forward`. - Defaults to None. - key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. - Defaults to None. - - Returns: - Tensor: forwarded results with shape - [num_queries, bs, embed_dims] - if self.batch_first is False, else - [bs, num_queries embed_dims]. 
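-
-        Example:
-            >>> # call sketch with positional encoding; shapes are
-            >>> # assumptions, and the layout is (num_queries, bs,
-            >>> # embed_dims) because batch_first defaults to False
-            >>> attn = MultiheadAttention(embed_dims=256, num_heads=8)
-            >>> query = torch.rand(100, 2, 256)
-            >>> query_pos = torch.rand(100, 2, 256)
-            >>> out = attn(query, query_pos=query_pos)
-            >>> # key/value fall back to query, and query_pos is reused
-            >>> # as key_pos since the shapes match
-            >>> assert out.shape == query.shape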
- """ - - if key is None: - key = query - if value is None: - value = key - if identity is None: - identity = query - if key_pos is None: - if query_pos is not None: - # use query_pos if key_pos is not available - if query_pos.shape == key.shape: - key_pos = query_pos - else: - warnings.warn(f"position encoding of key ismissing in {self.__class__.__name__}.") - if query_pos is not None: - query = query + query_pos - if key_pos is not None: - key = key + key_pos - - # Because the dataflow('key', 'query', 'value') of - # ``torch.nn.MultiheadAttention`` is (num_query, batch, - # embed_dims), We should adjust the shape of dataflow from - # batch_first (batch, num_query, embed_dims) to num_query_first - # (num_query ,batch, embed_dims), and recover ``attn_output`` - # from num_query_first to batch_first. - if self.batch_first: - query = query.transpose(0, 1) - key = key.transpose(0, 1) - value = value.transpose(0, 1) - - out = self.attn( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - )[0] - - if self.batch_first: - out = out.transpose(0, 1) - - return identity + self.dropout_layer(self.proj_drop(out)) - - -@MODELS.register_module() -class FFN(BaseModule): - """Implements feed-forward networks (FFNs) with identity connection. - - Args: - embed_dims (int): The feature dimension. Same as - `MultiheadAttention`. Defaults: 256. - feedforward_channels (int): The hidden dimension of FFNs. - Defaults: 1024. - num_fcs (int, optional): The number of fully-connected layers in - FFNs. Default: 2. - act_cfg (dict, optional): The activation config for FFNs. - Default: dict(type='ReLU') - ffn_drop (float, optional): Probability of an element to be - zeroed in FFN. Default 0.0. - add_identity (bool, optional): Whether to add the - identity connection. Default: `True`. - dropout_layer (obj:`ConfigDict`): The dropout_layer used - when adding the shortcut. - init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. - Default: None. - layer_scale_init_value (float): Initial value of scale factor in - LayerScale. Default: 1.0 - """ - - @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN") - def __init__( - self, - embed_dims=256, - feedforward_channels=1024, - num_fcs=2, - act_cfg=dict(type="ReLU", inplace=True), - ffn_drop=0.0, - dropout_layer=None, - add_identity=True, - init_cfg=None, - layer_scale_init_value=0.0, - ): - super().__init__(init_cfg) - assert num_fcs >= 2, f"num_fcs should be no less than 2. got {num_fcs}." - self.embed_dims = embed_dims - self.feedforward_channels = feedforward_channels - self.num_fcs = num_fcs - - layers = [] - in_channels = embed_dims - for _ in range(num_fcs - 1): - layers.append( - Sequential( - Linear(in_channels, feedforward_channels), - build_activation_layer(act_cfg), - nn.Dropout(ffn_drop), - ) - ) - in_channels = feedforward_channels - layers.append(Linear(feedforward_channels, embed_dims)) - layers.append(nn.Dropout(ffn_drop)) - self.layers = Sequential(*layers) - self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity() - self.add_identity = add_identity - - if layer_scale_init_value > 0: - self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value) - else: - self.gamma2 = nn.Identity() - - @deprecated_api_warning({"residual": "identity"}, cls_name="FFN") - def forward(self, x, identity=None): - """Forward function for `FFN`. - - The function would add x to the output tensor if residue is None. 
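-
-        Example:
-            >>> # usage sketch following the documented defaults
-            >>> ffn = FFN(embed_dims=256, feedforward_channels=1024)
-            >>> x = torch.rand(2, 100, 256)
-            >>> out = ffn(x)  # identity defaults to x: out = x + layers(x)
-            >>> assert out.shape == x.shape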
- """ - out = self.layers(x) - out = self.gamma2(out) - if not self.add_identity: - return self.dropout_layer(out) - if identity is None: - identity = x - return identity + self.dropout_layer(out) - - -@MODELS.register_module() -class BaseTransformerLayer(BaseModule): - """Base `TransformerLayer` for vision transformer. - - It can be built from `mmcv.ConfigDict` and support more flexible - customization, for example, using any number of `FFN or LN ` and - use different kinds of `attention` by specifying a list of `ConfigDict` - named `attn_cfgs`. It is worth mentioning that it supports `prenorm` - when you specifying `norm` as the first element of `operation_order`. - More details about the `prenorm`: `On Layer Normalization in the - Transformer Architecture `_ . - - Args: - attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): - Configs for `self_attention` or `cross_attention` modules, - The order of the configs in the list should be consistent with - corresponding attentions in operation_order. - If it is a dict, all of the attention modules in operation_order - will be built with this config. Default: None. - ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): - Configs for FFN, The order of the configs in the list should be - consistent with corresponding ffn in operation_order. - If it is a dict, all of the attention modules in operation_order - will be built with this config. - operation_order (tuple[str]): The execution order of operation - in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). - Support `prenorm` when you specifying first element as `norm`. - Default:None. - norm_cfg (dict): Config dict for normalization layer. - Default: dict(type='LN'). - init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. - Default: None. - batch_first (bool): Key, Query and Value are shape - of (batch, n, embed_dim) - or (n, batch, embed_dim). Default to False. - """ - - def __init__( - self, - attn_cfgs=None, - ffn_cfgs=dict( - type="FFN", - embed_dims=256, - feedforward_channels=1024, - num_fcs=2, - ffn_drop=0.0, - act_cfg=dict(type="ReLU", inplace=True), - ), - operation_order=None, - norm_cfg=dict(type="LN"), - init_cfg=None, - batch_first=False, - **kwargs, - ): - deprecated_args = dict( - feedforward_channels="feedforward_channels", - ffn_dropout="ffn_drop", - ffn_num_fcs="num_fcs", - ) - for ori_name, new_name in deprecated_args.items(): - if ori_name in kwargs: - warnings.warn( - f"The arguments `{ori_name}` in BaseTransformerLayer " - f"has been deprecated, now you should set `{new_name}` " - f"and other FFN related arguments " - f"to a dict named `ffn_cfgs`. ", - DeprecationWarning, - ) - ffn_cfgs[new_name] = kwargs[ori_name] - - super().__init__(init_cfg) - - self.batch_first = batch_first - - assert set(operation_order) & {"self_attn", "norm", "ffn", "cross_attn"} == set(operation_order), ( - f"The operation_order of {self.__class__.__name__} should contains all four operation type {['self_attn', 'norm', 'ffn', 'cross_attn']}" - ) - - num_attn = operation_order.count("self_attn") + operation_order.count("cross_attn") - if isinstance(attn_cfgs, dict): - attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)] - else: - assert num_attn == len(attn_cfgs), ( - f"The length of attn_cfg {num_attn} is not consistent with the number of attentionin operation_order {operation_order}." 
- ) - - self.num_attn = num_attn - self.operation_order = operation_order - self.norm_cfg = norm_cfg - self.pre_norm = operation_order[0] == "norm" - self.attentions = ModuleList() - - index = 0 - for operation_name in operation_order: - if operation_name in ["self_attn", "cross_attn"]: - if "batch_first" in attn_cfgs[index]: - assert self.batch_first == attn_cfgs[index]["batch_first"] - else: - attn_cfgs[index]["batch_first"] = self.batch_first - attention = build_attention(attn_cfgs[index]) - # Some custom attentions used as `self_attn` - # or `cross_attn` can have different behavior. - attention.operation_name = operation_name - self.attentions.append(attention) - index += 1 - - self.embed_dims = self.attentions[0].embed_dims - - self.ffns = ModuleList() - num_ffns = operation_order.count("ffn") - if isinstance(ffn_cfgs, dict): - ffn_cfgs = ConfigDict(ffn_cfgs) - if isinstance(ffn_cfgs, dict): - ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] - assert len(ffn_cfgs) == num_ffns - for ffn_index in range(num_ffns): - if "embed_dims" not in ffn_cfgs[ffn_index]: - ffn_cfgs[ffn_index]["embed_dims"] = self.embed_dims - else: - assert ffn_cfgs[ffn_index]["embed_dims"] == self.embed_dims - self.ffns.append(build_feedforward_network(ffn_cfgs[ffn_index], dict(type="FFN"))) - - self.norms = ModuleList() - num_norms = operation_order.count("norm") - for _ in range(num_norms): - self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1]) - - def forward( - self, - query, - key=None, - value=None, - query_pos=None, - key_pos=None, - attn_masks=None, - query_key_padding_mask=None, - key_padding_mask=None, - **kwargs, - ): - """Forward function for `TransformerDecoderLayer`. - - **kwargs contains some specific arguments of attentions. - - Args: - query (Tensor): The input query with shape - [num_queries, bs, embed_dims] if - self.batch_first is False, else - [bs, num_queries embed_dims]. - key (Tensor): The key tensor with shape [num_keys, bs, - embed_dims] if self.batch_first is False, else - [bs, num_keys, embed_dims] . - value (Tensor): The value tensor with same shape as `key`. - query_pos (Tensor): The positional encoding for `query`. - Default: None. - key_pos (Tensor): The positional encoding for `key`. - Default: None. - attn_masks (List[Tensor] | None): 2D Tensor used in - calculation of corresponding attention. The length of - it should equal to the number of `attention` in - `operation_order`. Default: None. - query_key_padding_mask (Tensor): ByteTensor for `query`, with - shape [bs, num_queries]. Only used in `self_attn` layer. - Defaults to None. - key_padding_mask (Tensor): ByteTensor for `query`, with - shape [bs, num_keys]. Default: None. - - Returns: - Tensor: forwarded results with shape [num_queries, bs, embed_dims]. 
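-
-        Example:
-            >>> # construction-and-call sketch; the config values are
-            >>> # assumptions rather than a shipped configuration
-            >>> layer = BaseTransformerLayer(
-            >>>     attn_cfgs=dict(
-            >>>         type='MultiheadAttention', embed_dims=256,
-            >>>         num_heads=8),
-            >>>     ffn_cfgs=dict(
-            >>>         type='FFN', embed_dims=256,
-            >>>         feedforward_channels=1024),
-            >>>     operation_order=('norm', 'self_attn', 'norm', 'ffn'))
-            >>> query = torch.rand(10, 2, 256)  # (num_queries, bs, embed_dims)
-            >>> out = layer(query)
-            >>> assert out.shape == query.shape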
- """ - - norm_index = 0 - attn_index = 0 - ffn_index = 0 - identity = query - if attn_masks is None: - attn_masks = [None for _ in range(self.num_attn)] - elif isinstance(attn_masks, torch.Tensor): - attn_masks = [copy.deepcopy(attn_masks) for _ in range(self.num_attn)] - warnings.warn(f"Use same attn_mask in all attentions in {self.__class__.__name__} ") - else: - assert len(attn_masks) == self.num_attn, ( - f"The length of attn_masks {len(attn_masks)} must be equal to the number of attention in operation_order {self.num_attn}" - ) - - for layer in self.operation_order: - if layer == "self_attn": - temp_key = temp_value = query - query = self.attentions[attn_index]( - query, - temp_key, - temp_value, - identity if self.pre_norm else None, - query_pos=query_pos, - key_pos=query_pos, - attn_mask=attn_masks[attn_index], - key_padding_mask=query_key_padding_mask, - **kwargs, - ) - attn_index += 1 - identity = query - - elif layer == "norm": - query = self.norms[norm_index](query) - norm_index += 1 - - elif layer == "cross_attn": - query = self.attentions[attn_index]( - query, - key, - value, - identity if self.pre_norm else None, - query_pos=query_pos, - key_pos=key_pos, - attn_mask=attn_masks[attn_index], - key_padding_mask=key_padding_mask, - **kwargs, - ) - attn_index += 1 - identity = query - - elif layer == "ffn": - query = self.ffns[ffn_index](query, identity if self.pre_norm else None) - ffn_index += 1 - - return query - - -@MODELS.register_module() -class TransformerLayerSequence(BaseModule): - """Base class for TransformerEncoder and TransformerDecoder in vision - transformer. - - As base-class of Encoder and Decoder in vision transformer. - Support customization such as specifying different kind - of `transformer_layer` in `transformer_coder`. - - Args: - transformerlayer (list[obj:`mmcv.ConfigDict`] | - obj:`mmcv.ConfigDict`): Config of transformerlayer - in TransformerCoder. If it is obj:`mmcv.ConfigDict`, - it would be repeated `num_layer` times to a - list[`mmcv.ConfigDict`]. Default: None. - num_layers (int): The number of `TransformerLayer`. Default: None. - init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. - Default: None. - """ - - def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None): - super().__init__(init_cfg) - if isinstance(transformerlayers, dict): - transformerlayers = [copy.deepcopy(transformerlayers) for _ in range(num_layers)] - else: - assert isinstance(transformerlayers, list) and len(transformerlayers) == num_layers - self.num_layers = num_layers - self.layers = ModuleList() - for i in range(num_layers): - self.layers.append(build_transformer_layer(transformerlayers[i])) - self.embed_dims = self.layers[0].embed_dims - self.pre_norm = self.layers[0].pre_norm - - def forward( - self, - query, - key, - value, - query_pos=None, - key_pos=None, - attn_masks=None, - query_key_padding_mask=None, - key_padding_mask=None, - **kwargs, - ): - """Forward function for `TransformerCoder`. - - Args: - query (Tensor): Input query with shape - `(num_queries, bs, embed_dims)`. - key (Tensor): The key tensor with shape - `(num_keys, bs, embed_dims)`. - value (Tensor): The value tensor with shape - `(num_keys, bs, embed_dims)`. - query_pos (Tensor): The positional encoding for `query`. - Default: None. - key_pos (Tensor): The positional encoding for `key`. - Default: None. - attn_masks (List[Tensor], optional): Each element is 2D Tensor - which is used in calculation of corresponding attention in - operation_order. Default: None. 
- query_key_padding_mask (Tensor): ByteTensor for `query`, with - shape [bs, num_queries]. Only used in self-attention - Default: None. - key_padding_mask (Tensor): ByteTensor for `query`, with - shape [bs, num_keys]. Default: None. - - Returns: - Tensor: results with shape [num_queries, bs, embed_dims]. - """ - for layer in self.layers: - query = layer( - query, - key, - value, - query_pos=query_pos, - key_pos=key_pos, - attn_masks=attn_masks, - query_key_padding_mask=query_key_padding_mask, - key_padding_mask=key_padding_mask, - **kwargs, - ) - return query diff --git a/libs/viscv/viscv/cnn/bricks/upsample.py b/libs/viscv/viscv/cnn/bricks/upsample.py deleted file mode 100644 index 5c2d768..0000000 --- a/libs/viscv/viscv/cnn/bricks/upsample.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect - -import torch -import torch.nn as nn -import torch.nn.functional as F -from visengine.model import xavier_init -from visengine.registry import MODELS - -MODELS.register_module("nearest", module=nn.Upsample) -MODELS.register_module("bilinear", module=nn.Upsample) - - -@MODELS.register_module(name="pixel_shuffle") -class PixelShufflePack(nn.Module): - """Pixel Shuffle upsample layer. - - This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to - achieve a simple upsampling with pixel shuffle. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - scale_factor (int): Upsample ratio. - upsample_kernel (int): Kernel size of the conv layer to expand the - channels. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - scale_factor: int, - upsample_kernel: int, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.scale_factor = scale_factor - self.upsample_kernel = upsample_kernel - self.upsample_conv = nn.Conv2d( - self.in_channels, - self.out_channels * scale_factor * scale_factor, - self.upsample_kernel, - padding=(self.upsample_kernel - 1) // 2, - ) - self.init_weights() - - def init_weights(self): - xavier_init(self.upsample_conv, distribution="uniform") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.upsample_conv(x) - x = F.pixel_shuffle(x, self.scale_factor) - return x - - -def build_upsample_layer(cfg: dict, *args, **kwargs) -> nn.Module: - """Build upsample layer. - - Args: - cfg (dict): The upsample layer config, which should contain: - - - type (str): Layer type. - - scale_factor (int): Upsample ratio, which is not applicable to - deconv. - - layer args: Args needed to instantiate a upsample layer. - args (argument list): Arguments passed to the ``__init__`` - method of the corresponding conv layer. - kwargs (keyword arguments): Keyword arguments passed to the - ``__init__`` method of the corresponding conv layer. - - Returns: - nn.Module: Created upsample layer. - """ - if not isinstance(cfg, dict): - raise TypeError(f"cfg must be a dict, but got {type(cfg)}") - if "type" not in cfg: - raise KeyError(f'the cfg dict must contain the key "type", but got {cfg}') - cfg_ = cfg.copy() - - layer_type = cfg_.pop("type") - - if inspect.isclass(layer_type): - upsample = layer_type - # Switch registry to the target scope. If `upsample` cannot be found - # in the registry, fallback to search `upsample` in the - # mmengine.MODELS. 
- else: - with MODELS.switch_scope_and_registry(None) as registry: - upsample = registry.get(layer_type) - if upsample is None: - raise KeyError(f"Cannot find {upsample} in registry under scope name {registry.scope}") - if upsample is nn.Upsample: - cfg_["mode"] = layer_type - layer = upsample(*args, **kwargs, **cfg_) - return layer diff --git a/libs/viscv/viscv/cnn/bricks/wrappers.py b/libs/viscv/viscv/cnn/bricks/wrappers.py deleted file mode 100644 index 88301fa..0000000 --- a/libs/viscv/viscv/cnn/bricks/wrappers.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501 - -Wrap some nn modules to support empty tensor input. Currently, these wrappers -are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask -heads are trained on only positive RoIs. -""" - -import math - -import torch -import torch.nn as nn -from torch.nn.modules.utils import _pair, _triple -from visengine.registry import MODELS - -if torch.__version__ == "parrots": - TORCH_VERSION = torch.__version__ -else: - # torch.__version__ could be 1.3.1+cu92, we only need the first two - # for comparison - TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2]) - - -def obsolete_torch_version(torch_version, version_threshold) -> bool: - return torch_version == "parrots" or torch_version <= version_threshold - - -class NewEmptyTensorOp(torch.autograd.Function): - @staticmethod - def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor: - ctx.shape = x.shape - return x.new_empty(new_shape) - - @staticmethod - def backward(ctx, grad: torch.Tensor) -> tuple: - shape = ctx.shape - return NewEmptyTensorOp.apply(grad, shape), None - - -@MODELS.register_module("Conv", force=True) -class Conv2d(nn.Conv2d): - def forward(self, x: torch.Tensor) -> torch.Tensor: - if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: - out_shape = [x.shape[0], self.out_channels] - for i, k, p, s, d in zip( - x.shape[-2:], - self.kernel_size, - self.padding, - self.stride, - self.dilation, - strict=False, - ): - o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 - out_shape.append(o) - empty = NewEmptyTensorOp.apply(x, out_shape) - if self.training: - # produce dummy gradient to avoid DDP warning. - dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 - return empty + dummy - else: - return empty - - return super().forward(x) - - -@MODELS.register_module("Conv3d", force=True) -class Conv3d(nn.Conv3d): - def forward(self, x: torch.Tensor) -> torch.Tensor: - if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: - out_shape = [x.shape[0], self.out_channels] - for i, k, p, s, d in zip( - x.shape[-3:], - self.kernel_size, - self.padding, - self.stride, - self.dilation, - strict=False, - ): - o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1 - out_shape.append(o) - empty = NewEmptyTensorOp.apply(x, out_shape) - if self.training: - # produce dummy gradient to avoid DDP warning. 
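-                # (multiplying the parameter sum by 0.0 contributes nothing
-                # to the output, but it keeps every parameter in the autograd
-                # graph so DistributedDataParallel does not mark them unused)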
- dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 - return empty + dummy - else: - return empty - - return super().forward(x) - - -@MODELS.register_module() -@MODELS.register_module("deconv") -class ConvTranspose2d(nn.ConvTranspose2d): - def forward(self, x: torch.Tensor) -> torch.Tensor: - if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: - out_shape = [x.shape[0], self.out_channels] - for i, k, p, s, d, op in zip( - x.shape[-2:], - self.kernel_size, - self.padding, - self.stride, - self.dilation, - self.output_padding, - strict=False, - ): - out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) - empty = NewEmptyTensorOp.apply(x, out_shape) - if self.training: - # produce dummy gradient to avoid DDP warning. - dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 - return empty + dummy - else: - return empty - - return super().forward(x) - - -@MODELS.register_module() -@MODELS.register_module("deconv3d") -class ConvTranspose3d(nn.ConvTranspose3d): - def forward(self, x: torch.Tensor) -> torch.Tensor: - if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0: - out_shape = [x.shape[0], self.out_channels] - for i, k, p, s, d, op in zip( - x.shape[-3:], - self.kernel_size, - self.padding, - self.stride, - self.dilation, - self.output_padding, - strict=False, - ): - out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op) - empty = NewEmptyTensorOp.apply(x, out_shape) - if self.training: - # produce dummy gradient to avoid DDP warning. - dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 - return empty + dummy - else: - return empty - - return super().forward(x) - - -class MaxPool2d(nn.MaxPool2d): - def forward(self, x: torch.Tensor) -> torch.Tensor: - # PyTorch 1.9 does not support empty tensor inference yet - if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0: - out_shape = list(x.shape[:2]) - for i, k, p, s, d in zip( - x.shape[-2:], - _pair(self.kernel_size), - _pair(self.padding), - _pair(self.stride), - _pair(self.dilation), - strict=False, - ): - o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 - o = math.ceil(o) if self.ceil_mode else math.floor(o) - out_shape.append(o) - empty = NewEmptyTensorOp.apply(x, out_shape) - return empty - - return super().forward(x) - - -class MaxPool3d(nn.MaxPool3d): - def forward(self, x: torch.Tensor) -> torch.Tensor: - # PyTorch 1.9 does not support empty tensor inference yet - if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0: - out_shape = list(x.shape[:2]) - for i, k, p, s, d in zip( - x.shape[-3:], - _triple(self.kernel_size), - _triple(self.padding), - _triple(self.stride), - _triple(self.dilation), - strict=False, - ): - o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1 - o = math.ceil(o) if self.ceil_mode else math.floor(o) - out_shape.append(o) - empty = NewEmptyTensorOp.apply(x, out_shape) - return empty - - return super().forward(x) - - -class Linear(torch.nn.Linear): - def forward(self, x: torch.Tensor) -> torch.Tensor: - # empty tensor forward of Linear layer is supported in Pytorch 1.6 - if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0: - out_shape = [x.shape[0], self.out_features] - empty = NewEmptyTensorOp.apply(x, out_shape) - if self.training: - # produce dummy gradient to avoid DDP warning. 
- dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 - return empty + dummy - else: - return empty - - return super().forward(x) diff --git a/libs/viscv/viscv/fileio/__init__.py b/libs/viscv/viscv/fileio/__init__.py deleted file mode 100644 index 859a1a6..0000000 --- a/libs/viscv/viscv/fileio/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .io import FileClient, get - -__all__ = ["FileClient", "get"] diff --git a/libs/viscv/viscv/fileio/io.py b/libs/viscv/viscv/fileio/io.py deleted file mode 100644 index ed92a85..0000000 --- a/libs/viscv/viscv/fileio/io.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""Simple file I/O utilities for viscv.""" - -from pathlib import Path - - -def get(filepath: str | Path, backend_args: dict | None = None) -> bytes: - """Read bytes from a given filepath. - - Args: - filepath: Path to read from. - backend_args: Backend-specific arguments (currently unused). - - Returns: - bytes: File contents as bytes. - """ - filepath = Path(filepath) - with open(filepath, "rb") as f: - return f.read() - - -class FileClient: - """Simple file client for reading files.""" - - def __init__(self, backend="disk", **kwargs): - self.backend = backend - - @classmethod - def infer_client(cls, file_client_args, filename): - """Infer the file client from arguments.""" - return cls(**file_client_args) - - def get(self, filepath: str | Path) -> bytes: - """Read bytes from a given filepath.""" - return get(filepath) diff --git a/libs/viscv/viscv/image/__init__.py b/libs/viscv/viscv/image/__init__.py deleted file mode 100644 index 00d44b1..0000000 --- a/libs/viscv/viscv/image/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .cache import ImageCache -from .geometric import ( - imcrop, - imflip, - impad, - imrescale, - imresize, - imrotate, - imshear, - imtranslate, - rescale_size, -) -from .io import imfrombytes, imwrite -from .photometric import hsv2bgr, imdenormalize, imnormalize - -__all__ = [ - "hsv2bgr", - "ImageCache", - "imcrop", - "imdenormalize", - "imflip", - "imfrombytes", - "imnormalize", - "impad", - "imrescale", - "imresize", - "imrotate", - "imshear", - "imtranslate", - "imwrite", - "rescale_size", -] diff --git a/libs/viscv/viscv/image/cache.py b/libs/viscv/viscv/image/cache.py deleted file mode 100644 index 5773feb..0000000 --- a/libs/viscv/viscv/image/cache.py +++ /dev/null @@ -1,347 +0,0 @@ -"""On-disk cache for downsized images to avoid repeated decoding and resizing.""" - -import hashlib -import shutil -import sqlite3 -import time -from pathlib import Path -from typing import Any - -import numpy as np - - -class ImageCache: - """On-disk cache for downsized images. - - This cache stores downsized images on disk after the first load, avoiding - repeated JPEG decoding and resizing operations in subsequent training epochs. - - The cache uses: - - NPY format for fast numpy array serialization - - MD5 hashing for cache keys (based on image path, target size, and modification time) - - SQLite for metadata and LRU tracking - - Automatic eviction when cache size exceeds max_size_gb - - Args: - cache_dir: Directory to store cached images. Defaults to ~/.cache/viscv/image_cache - max_size_gb: Maximum cache size in gigabytes. When exceeded, least recently used - images are evicted. Defaults to 10.0 GB. - enabled: Whether caching is enabled. Defaults to True. 
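-
-    Example:
-        >>> # usage sketch; the path, target size, and `load_and_resize`
-        >>> # helper are hypothetical
-        >>> cache = ImageCache(max_size_gb=2.0)
-        >>> img = cache.get('data/train/0001.jpg', (800, 600))
-        >>> if img is None:
-        >>>     img = load_and_resize('data/train/0001.jpg', (800, 600))
-        >>>     cache.put('data/train/0001.jpg', (800, 600), img)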
- """ - - def __init__( - self, - cache_dir: Path | str = Path.home() / ".cache" / "viscv" / "image_cache", - max_size_gb: float = 10.0, - enabled: bool = True, - ) -> None: - self.enabled = enabled - if not self.enabled: - return - - self.cache_dir = Path(cache_dir) - self.max_size_bytes = int(max_size_gb * 1024 * 1024 * 1024) - - # Create cache directory - self.cache_dir.mkdir(parents=True, exist_ok=True) - - # Initialize SQLite database for metadata - self.db_path = self.cache_dir / "cache_metadata.db" - self._init_db() - - def _init_db(self) -> None: - """Initialize SQLite database for cache metadata.""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - # Create table for cache entries - cursor.execute(""" - CREATE TABLE IF NOT EXISTS cache_entries ( - cache_key TEXT PRIMARY KEY, - img_path TEXT NOT NULL, - target_width INTEGER NOT NULL, - target_height INTEGER NOT NULL, - mtime REAL NOT NULL, - file_size INTEGER NOT NULL, - last_accessed REAL NOT NULL, - access_count INTEGER DEFAULT 1 - ) - """) - - conn.commit() - conn.close() - - def _generate_cache_key(self, img_path: str, target_size: tuple[int, int], mtime: float) -> str: - """Generate unique cache key for an image. - - Args: - img_path: Path to original image - target_size: Target size (width, height) after resize - mtime: Modification time of original image - - Returns: - MD5 hash as cache key - """ - key_str = f"{img_path}:{target_size[0]}x{target_size[1]}:{mtime:.6f}" - return hashlib.md5(key_str.encode()).hexdigest() - - def _get_cache_path(self, cache_key: str) -> Path: - """Get file path for cached image. - - Args: - cache_key: Cache key hash - - Returns: - Path to NPY file - """ - # Use first 2 chars of hash for subdirectory (256 buckets) - subdir = cache_key[:2] - cache_subdir = self.cache_dir / subdir - cache_subdir.mkdir(exist_ok=True) - return cache_subdir / f"{cache_key}.npy" - - def get(self, img_path: str, target_size: tuple[int, int]) -> np.ndarray | None: - """Get cached image if available. - - Args: - img_path: Path to original image - target_size: Target size (width, height) after resize - - Returns: - Cached image array, or None if not in cache or cache is disabled - """ - if not self.enabled: - return None - - # Check if original file exists and get mtime - img_file = Path(img_path) - if not img_file.exists(): - return None - - mtime = img_file.stat().st_mtime - cache_key = self._generate_cache_key(img_path, target_size, mtime) - cache_path = self._get_cache_path(cache_key) - - # Check if cached file exists - if not cache_path.exists(): - return None - - # Verify cache entry is valid (check mtime hasn't changed) - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - cursor.execute( - "SELECT mtime FROM cache_entries WHERE cache_key = ?", - (cache_key,), - ) - result = cursor.fetchone() - - if result is None: - # Cache file exists but no metadata - clean up - cache_path.unlink(missing_ok=True) - conn.close() - return None - - cached_mtime = result[0] - if abs(cached_mtime - mtime) > 0.001: # Allow small floating point diff - # Image has been modified - invalidate cache - self._remove_entry(cache_key, cache_path, cursor) - conn.commit() - conn.close() - return None - - # Load cached image - try: - img = np.load(cache_path) - - # Update access time and count - now = time.time() - cursor.execute( - """ - UPDATE cache_entries - SET last_accessed = ?, access_count = access_count + 1 - WHERE cache_key = ? 
- """, - (now, cache_key), - ) - conn.commit() - conn.close() - - return img - - except Exception: - # Failed to load - clean up corrupted cache - self._remove_entry(cache_key, cache_path, cursor) - conn.commit() - conn.close() - return None - - def put(self, img_path: str, target_size: tuple[int, int], img: np.ndarray) -> None: - """Store image in cache. - - Args: - img_path: Path to original image - target_size: Target size (width, height) after resize - img: Image array to cache - """ - if not self.enabled: - return - - # Get original file mtime - img_file = Path(img_path) - if not img_file.exists(): - return - - mtime = img_file.stat().st_mtime - cache_key = self._generate_cache_key(img_path, target_size, mtime) - cache_path = self._get_cache_path(cache_key) - - # Save to disk - try: - np.save(cache_path, img) - file_size = cache_path.stat().st_size - - # Add metadata - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - now = time.time() - cursor.execute( - """ - INSERT OR REPLACE INTO cache_entries - (cache_key, img_path, target_width, target_height, mtime, file_size, last_accessed, access_count) - VALUES (?, ?, ?, ?, ?, ?, ?, 1) - """, - ( - cache_key, - img_path, - target_size[0], - target_size[1], - mtime, - file_size, - now, - ), - ) - - conn.commit() - conn.close() - - # Check if we need to evict - self._maybe_evict() - - except Exception: - # Failed to save - clean up - cache_path.unlink(missing_ok=True) - - def _remove_entry(self, cache_key: str, cache_path: Path, cursor: sqlite3.Cursor) -> None: - """Remove a cache entry. - - Args: - cache_key: Cache key to remove - cache_path: Path to cached file - cursor: SQLite cursor (must be committed by caller) - """ - cache_path.unlink(missing_ok=True) - cursor.execute("DELETE FROM cache_entries WHERE cache_key = ?", (cache_key,)) - - def _maybe_evict(self) -> None: - """Evict least recently used entries if cache size exceeds max_size_bytes.""" - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - # Get total cache size - cursor.execute("SELECT SUM(file_size) FROM cache_entries") - result = cursor.fetchone() - total_size = result[0] if result[0] is not None else 0 - - if total_size <= self.max_size_bytes: - conn.close() - return - - # Evict LRU entries until under limit - bytes_to_free = total_size - self.max_size_bytes - - cursor.execute( - """ - SELECT cache_key, file_size - FROM cache_entries - ORDER BY last_accessed ASC - """, - ) - - freed_bytes = 0 - for cache_key, file_size in cursor.fetchall(): - cache_path = self._get_cache_path(cache_key) - self._remove_entry(cache_key, cache_path, cursor) - freed_bytes += file_size - - if freed_bytes >= bytes_to_free: - break - - conn.commit() - conn.close() - - def clear(self) -> None: - """Clear all cached images.""" - if not self.enabled: - return - - # Remove all cache files - if self.cache_dir.exists(): - shutil.rmtree(self.cache_dir) - self.cache_dir.mkdir(parents=True, exist_ok=True) - - # Reinitialize database - self._init_db() - - def get_stats(self) -> dict[str, Any]: - """Get cache statistics. 
- - Returns: - Dictionary with cache statistics: - - total_entries: Number of cached images - - total_size_mb: Total cache size in megabytes - - hit_rate: Estimated cache hit rate (based on access counts) - """ - if not self.enabled: - return {"enabled": False} - - conn = sqlite3.connect(self.db_path) - cursor = conn.cursor() - - cursor.execute( - """ - SELECT - COUNT(*) as total_entries, - SUM(file_size) as total_size, - AVG(access_count) as avg_access_count - FROM cache_entries - """, - ) - - result = cursor.fetchone() - conn.close() - - total_entries = result[0] if result[0] is not None else 0 - total_size = result[1] if result[1] is not None else 0 - avg_access = result[2] if result[2] is not None else 0 - - return { - "enabled": True, - "total_entries": total_entries, - "total_size_mb": total_size / (1024 * 1024), - "total_size_gb": total_size / (1024 * 1024 * 1024), - "avg_access_count": float(avg_access), - "cache_dir": str(self.cache_dir), - } - - def __repr__(self) -> str: - stats = self.get_stats() - if not stats["enabled"]: - return f"{self.__class__.__name__}(enabled=False)" - - return ( - f"{self.__class__.__name__}(" - f"cache_dir={stats['cache_dir']}, " - f"entries={stats['total_entries']}, " - f"size={stats['total_size_gb']:.2f}GB)" - ) diff --git a/libs/viscv/viscv/image/geometric.py b/libs/viscv/viscv/image/geometric.py deleted file mode 100644 index ee625e7..0000000 --- a/libs/viscv/viscv/image/geometric.py +++ /dev/null @@ -1,575 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# type: ignore -import numbers - -import cv2 -import numpy as np - -try: - from PIL import Image -except ImportError: - Image = None - - -def _scale_size( - size: tuple[int, int], - scale: float | int | tuple[float, float] | tuple[int, int], -) -> tuple[int, int]: - """Rescale a size by a ratio. - - Args: - size (tuple[int]): (w, h). - scale (float | int | tuple(float) | tuple(int)): Scaling factor. - - Returns: - tuple[int]: scaled size. - """ - if isinstance(scale, float | int): - scale = (scale, scale) - w, h = size - return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) - - -cv2_interp_codes = { - "nearest": cv2.INTER_NEAREST, - "bilinear": cv2.INTER_LINEAR, - "bicubic": cv2.INTER_CUBIC, - "area": cv2.INTER_AREA, - "lanczos": cv2.INTER_LANCZOS4, -} - -cv2_border_modes = { - "constant": cv2.BORDER_CONSTANT, - "replicate": cv2.BORDER_REPLICATE, - "reflect": cv2.BORDER_REFLECT, - "wrap": cv2.BORDER_WRAP, - "reflect_101": cv2.BORDER_REFLECT_101, - "transparent": cv2.BORDER_TRANSPARENT, - "isolated": cv2.BORDER_ISOLATED, -} - - -def imresize( - img: np.ndarray, - size: tuple[int, int], - return_scale: bool = False, - interpolation: str = "bilinear", - out: np.ndarray | None = None, - backend: str | None = None, -) -> tuple[np.ndarray, float, float] | np.ndarray: - """Resize image to a given size. - - Args: - img (ndarray): The input image. - size (tuple[int]): Target size (w, h). - return_scale (bool): Whether to return `w_scale` and `h_scale`. - interpolation (str): Interpolation method, accepted values are - "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' - backend. - out (ndarray): The output destination. - backend (str | None): The image resize backend type. Options are `cv2`, - `pillow`, `None`. If backend is None, `cv2` will be used. Default: None. - - Returns: - tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or - `resized_img`. 
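-
-    Example:
-        >>> # minimal sketch; the array contents are arbitrary
-        >>> img = np.zeros((100, 200, 3), dtype=np.uint8)
-        >>> out, w_scale, h_scale = imresize(img, (50, 40), return_scale=True)
-        >>> out.shape, w_scale, h_scale
-        ((40, 50, 3), 0.25, 0.4)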
- """ - h, w = img.shape[:2] - if backend is None: - backend = "cv2" - - if backend == "cv2": - resized_img = cv2.resize(img, size, dst=out, interpolation=cv2_interp_codes[interpolation]) - else: - raise ValueError(f"backend: {backend} is not supported for resize.") - - if not return_scale: - return resized_img - else: - w_scale = size[0] / w - h_scale = size[1] / h - return resized_img, w_scale, h_scale - - -def rescale_size(old_size: tuple, scale: float | int | tuple[int, int], return_scale: bool = False) -> tuple: - """Calculate the new size to be rescaled to. - - Args: - old_size (tuple[int]): The old size (w, h) of image. - scale (float | int | tuple[int]): The scaling factor or maximum size. - If it is a float number or an integer, then the image will be - rescaled by this factor, else if it is a tuple of 2 integers, then - the image will be rescaled as large as possible within the scale. - return_scale (bool): Whether to return the scaling factor besides the - rescaled image size. - - Returns: - tuple[int]: The new rescaled image size. - """ - w, h = old_size - if isinstance(scale, float | int): - if scale <= 0: - raise ValueError(f"Invalid scale {scale}, must be positive.") - scale_factor = scale - elif isinstance(scale, tuple): - max_long_edge = max(scale) - max_short_edge = min(scale) - scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) - else: - raise TypeError(f"Scale must be a number or tuple of int, but got {type(scale)}") - - new_size = _scale_size((w, h), scale_factor) - - if return_scale: - return new_size, scale_factor - else: - return new_size - - -def imrescale( - img: np.ndarray, - scale: float | int | tuple[int, int], - return_scale: bool = False, - interpolation: str = "bilinear", - backend: str | None = None, -) -> np.ndarray | tuple[np.ndarray, float]: # type: ignore[return] - """Resize image while keeping the aspect ratio. - - Args: - img (ndarray): The input image. - scale (float | int | tuple[int]): The scaling factor or maximum size. - If it is a float number or an integer, then the image will be - rescaled by this factor, else if it is a tuple of 2 integers, then - the image will be rescaled as large as possible within the scale. - return_scale (bool): Whether to return the scaling factor besides the - rescaled image. - interpolation (str): Same as :func:`resize`. - backend (str | None): Same as :func:`resize`. - - Returns: - ndarray: The rescaled image. - """ - h, w = img.shape[:2] - new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) - rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend) - if return_scale: - return rescaled_img, scale_factor - else: - return rescaled_img - - -def imflip(img: np.ndarray, direction: str = "horizontal") -> np.ndarray: - """Flip an image horizontally or vertically. - - Args: - img (ndarray): Image to be flipped. - direction (str): The flip direction, either "horizontal" or - "vertical" or "diagonal". - - Returns: - ndarray: The flipped image. - """ - assert direction in ["horizontal", "vertical", "diagonal"] - if direction == "horizontal": - return np.flip(img, axis=1) - elif direction == "vertical": - return np.flip(img, axis=0) - else: - return np.flip(img, axis=(0, 1)) - - -def imrotate( - img: np.ndarray, - angle: float, - center: tuple[float, float] | None = None, - scale: float = 1.0, - border_value: int = 0, - interpolation: str = "bilinear", - auto_bound: bool = False, - border_mode: str = "constant", -) -> np.ndarray: - """Rotate an image. 
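-
-    Note that ``cv2.getRotationMatrix2D`` treats positive angles as
-    counter-clockwise, so ``-angle`` is passed internally to keep the
-    documented clockwise convention.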
- - Args: - img (np.ndarray): Image to be rotated. - angle (float): Rotation angle in degrees, positive values mean - clockwise rotation. - center (tuple[float], optional): Center point (w, h) of the rotation in - the source image. If not specified, the center of the image will be - used. - scale (float): Isotropic scale factor. - border_value (int): Border value used in case of a constant border. - Defaults to 0. - interpolation (str): Same as :func:`resize`. - auto_bound (bool): Whether to adjust the image size to cover the whole - rotated image. - border_mode (str): Pixel extrapolation method. Defaults to 'constant'. - - Returns: - np.ndarray: The rotated image. - """ - if center is not None and auto_bound: - raise ValueError("`auto_bound` conflicts with `center`") - h, w = img.shape[:2] - if center is None: - center = ((w - 1) * 0.5, (h - 1) * 0.5) - assert isinstance(center, tuple) - - matrix = cv2.getRotationMatrix2D(center, -angle, scale) - if auto_bound: - cos = np.abs(matrix[0, 0]) - sin = np.abs(matrix[0, 1]) - new_w = h * sin + w * cos - new_h = h * cos + w * sin - matrix[0, 2] += (new_w - w) * 0.5 - matrix[1, 2] += (new_h - h) * 0.5 - w = int(np.round(new_w)) - h = int(np.round(new_h)) - rotated = cv2.warpAffine( - img, - matrix, - (w, h), - flags=cv2_interp_codes[interpolation], - borderMode=cv2_border_modes[border_mode], - borderValue=border_value, - ) - return rotated - - -def impad( - img: np.ndarray, - *, - shape: tuple[int, int] | None = None, - padding: int | tuple | None = None, - pad_val: float | list = 0, - padding_mode: str = "constant", -) -> np.ndarray: - """Pad the given image to a certain shape or pad on all sides with - specified padding mode and padding value. - - Args: - img (ndarray): Image to be padded. - shape (tuple[int]): Expected padding shape (h, w). Default: None. - padding (int or tuple[int]): Padding on each border. If a single int is - provided this is used to pad all borders. If tuple of length 2 is - provided this is the padding on left/right and top/bottom - respectively. If a tuple of length 4 is provided this is the - padding for the left, top, right and bottom borders respectively. - Default: None. Note that `shape` and `padding` can not be both - set. - pad_val (Number | Sequence[Number]): Values to be filled in padding - areas when padding_mode is 'constant'. Default: 0. - padding_mode (str): Type of padding. Should be: constant, edge, - reflect or symmetric. Default: constant. - - - constant: pads with a constant value, this value is specified - with pad_val. - - edge: pads with the last value at the edge of the image. - - reflect: pads with reflection of image without repeating the last - value on the edge. For example, padding [1, 2, 3, 4] with 2 - elements on both sides in reflect mode will result in - [3, 2, 1, 2, 3, 4, 3, 2]. - - symmetric: pads with reflection of image repeating the last value - on the edge. For example, padding [1, 2, 3, 4] with 2 elements on - both sides in symmetric mode will result in - [2, 1, 1, 2, 3, 4, 4, 3] - - Returns: - ndarray: The padded image. - """ - - assert (shape is not None) ^ (padding is not None) - if shape is not None: - width = max(shape[1] - img.shape[1], 0) - height = max(shape[0] - img.shape[0], 0) - padding = (0, 0, width, height) - - # check pad_val - if isinstance(pad_val, tuple): - assert len(pad_val) == img.shape[-1] - elif not isinstance(pad_val, numbers.Number): - raise TypeError(f"pad_val must be a int or a tuple. 
But received {type(pad_val)}") - - # check padding - if isinstance(padding, tuple) and len(padding) in [2, 4]: - if len(padding) == 2: - padding = (padding[0], padding[1], padding[0], padding[1]) - elif isinstance(padding, numbers.Number): - padding = (padding, padding, padding, padding) - else: - raise ValueError(f"Padding must be a int or a 2, or 4 element tuple.But received {padding}") - - # check padding mode - assert padding_mode in ["constant", "edge", "reflect", "symmetric"] - - border_type = { - "constant": cv2.BORDER_CONSTANT, - "edge": cv2.BORDER_REPLICATE, - "reflect": cv2.BORDER_REFLECT_101, - "symmetric": cv2.BORDER_REFLECT, - } - img = cv2.copyMakeBorder( - img, - padding[1], - padding[3], - padding[0], - padding[2], - border_type[padding_mode], - value=pad_val, - ) - - return img - - -def _get_shear_matrix(magnitude: int | float, direction: str = "horizontal") -> np.ndarray: - """Generate the shear matrix for transformation. - - Args: - magnitude (int | float): The magnitude used for shear. - direction (str): The flip direction, either "horizontal" - or "vertical". - - Returns: - ndarray: The shear matrix with dtype float32. - """ - if direction == "horizontal": - shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]]) - elif direction == "vertical": - shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]]) - return shear_matrix - - -def imshear( - img: np.ndarray, - magnitude: int | float, - direction: str = "horizontal", - border_value: int | tuple[int, int] = 0, - interpolation: str = "bilinear", -) -> np.ndarray: - """Shear an image. - - Args: - img (ndarray): Image to be sheared with format (h, w) - or (h, w, c). - magnitude (int | float): The magnitude used for shear. - direction (str): The flip direction, either "horizontal" - or "vertical". - border_value (int | tuple[int]): Value used in case of a - constant border. - interpolation (str): Same as :func:`resize`. - - Returns: - ndarray: The sheared image. - """ - assert direction in ["horizontal", "vertical"], f"Invalid direction: {direction}" - height, width = img.shape[:2] - if img.ndim == 2: - channels = 1 - elif img.ndim == 3: - channels = img.shape[-1] - else: - raise ValueError(f"Invalid image dimensions: {img.ndim}") - if isinstance(border_value, int): - border_value = tuple([border_value] * channels) # type: ignore - elif isinstance(border_value, tuple): - assert len(border_value) == channels, ( - f"Expected the num of elements in tuple equals the channels of input image. Found {len(border_value)} vs {channels}" - ) - else: - raise ValueError(f"Invalid type {type(border_value)} for `border_value`") - shear_matrix = _get_shear_matrix(magnitude, direction) - sheared = cv2.warpAffine( - img, - shear_matrix, - (width, height), - # Note case when the number elements in `border_value` - # greater than 3 (e.g. shearing masks whose channels large - # than 3) will raise TypeError in `cv2.warpAffine`. - # Here simply slice the first 3 values in `border_value`. - borderValue=border_value[:3], # type: ignore - flags=cv2_interp_codes[interpolation], - ) - return sheared - - -def _get_translate_matrix(offset: int | float, direction: str = "horizontal") -> np.ndarray: - """Generate the translate matrix. - - Args: - offset (int | float): The offset used for translate. - direction (str): The translate direction, either - "horizontal" or "vertical". - - Returns: - ndarray: The translate matrix with dtype float32. 
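-
-    Example:
-        >>> # sketch of the horizontal case
-        >>> _get_translate_matrix(5.0).tolist()
-        [[1.0, 0.0, 5.0], [0.0, 1.0, 0.0]]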
- """ - if direction == "horizontal": - translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]]) - elif direction == "vertical": - translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]]) - return translate_matrix - - -def imtranslate( - img: np.ndarray, - offset: int | float, - direction: str = "horizontal", - border_value: int | tuple = 0, - interpolation: str = "bilinear", -) -> np.ndarray: - """Translate an image. - - Args: - img (ndarray): Image to be translated with format - (h, w) or (h, w, c). - offset (int | float): The offset used for translate. - direction (str): The translate direction, either "horizontal" - or "vertical". - border_value (int | tuple[int]): Value used in case of a - constant border. - interpolation (str): Same as :func:`resize`. - - Returns: - ndarray: The translated image. - """ - assert direction in ["horizontal", "vertical"], f"Invalid direction: {direction}" - height, width = img.shape[:2] - if img.ndim == 2: - channels = 1 - elif img.ndim == 3: - channels = img.shape[-1] - else: - raise ValueError(f"Invalid image dimensions: {img.ndim}") - if isinstance(border_value, int): - border_value = tuple([border_value] * channels) - elif isinstance(border_value, tuple): - assert len(border_value) == channels, ( - f"Expected the num of elements in tuple equals the channels of input image. Found {len(border_value)} vs {channels}" - ) - else: - raise ValueError(f"Invalid type {type(border_value)} for `border_value`.") - translate_matrix = _get_translate_matrix(offset, direction) - translated = cv2.warpAffine( - img, - translate_matrix, - (width, height), - # Note case when the number elements in `border_value` - # greater than 3 (e.g. translating masks whose channels - # large than 3) will raise TypeError in `cv2.warpAffine`. - # Here simply slice the first 3 values in `border_value`. - borderValue=border_value[:3], - flags=cv2_interp_codes[interpolation], - ) - return translated - - -def bbox_clip(bboxes: np.ndarray, img_shape: tuple[int, int]) -> np.ndarray: - """Clip bboxes to fit the image shape. - - Args: - bboxes (ndarray): Shape (..., 4*k) in format (x1, y1, x2, y2, ...) - img_shape (tuple[int]): (height, width) of the image. - - Returns: - ndarray: Clipped bboxes. - """ - assert bboxes.shape[-1] % 4 == 0 - clipped_bboxes = bboxes.copy() - - h, w = img_shape - - # Process each group of 4 coordinates (x1, y1, x2, y2) - for i in range(0, bboxes.shape[-1], 4): - # Clip x coordinates - clipped_bboxes[..., i] = np.clip(clipped_bboxes[..., i], 0, w) # x1 - clipped_bboxes[..., i + 2] = np.clip(clipped_bboxes[..., i + 2], 0, w) # x2 - - # Clip y coordinates - clipped_bboxes[..., i + 1] = np.clip(clipped_bboxes[..., i + 1], 0, h) # y1 - clipped_bboxes[..., i + 3] = np.clip(clipped_bboxes[..., i + 3], 0, h) # y2 - - return clipped_bboxes - - -def bbox_scaling(bboxes: np.ndarray, scale: float, clip_shape=None) -> np.ndarray: - """Scaling bboxes w.r.t the box center. - - Args: - bboxes (ndarray): Shape(..., 4). - scale (float): Scaling factor. - clip_shape (tuple[int], optional): If specified, bboxes that exceed the - boundary will be clipped according to the given shape (h, w). - - Returns: - ndarray: Scaled bboxes. 
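-
-    Example:
-        >>> # a 10x10 box doubled about its centre; values chosen for clarity
-        >>> bbox_scaling(np.array([10., 10., 19., 19.]), 2.0)
-        array([ 5.,  5., 24., 24.])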
- """ - if float(scale) == 1.0: - scaled_bboxes = bboxes.copy() - else: - w = bboxes[..., 2] - bboxes[..., 0] + 1 - h = bboxes[..., 3] - bboxes[..., 1] + 1 - dw = (w * (scale - 1)) * 0.5 - dh = (h * (scale - 1)) * 0.5 - scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1) - if clip_shape is not None: - return bbox_clip(scaled_bboxes, clip_shape) - else: - return scaled_bboxes - - -def imcrop( - img: np.ndarray, - bboxes: np.ndarray, - scale: float = 1.0, - pad_fill: float | list | None = None, -) -> np.ndarray | list[np.ndarray]: - """Crop image patches. - - 3 steps: scale the bboxes -> clip bboxes -> crop and pad. - - Args: - img (ndarray): Image to be cropped. - bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes. - scale (float, optional): Scale ratio of bboxes, the default value - 1.0 means no scaling. - pad_fill (Number | list[Number]): Value to be filled for padding. - Default: None, which means no padding. - - Returns: - list[ndarray] | ndarray: The cropped image patches. - """ - chn = 1 if img.ndim == 2 else img.shape[2] - if pad_fill is not None: - if isinstance(pad_fill, (int, float)): - pad_fill = [pad_fill for _ in range(chn)] - assert len(pad_fill) == chn - - _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes - scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32) - clipped_bbox = bbox_clip(scaled_bboxes, img.shape) - - patches = [] - for i in range(clipped_bbox.shape[0]): - x1, y1, x2, y2 = tuple(clipped_bbox[i, :]) - if pad_fill is None: - patch = img[y1 : y2 + 1, x1 : x2 + 1, ...] - else: - _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :]) - patch_h = _y2 - _y1 + 1 - patch_w = _x2 - _x1 + 1 - if chn == 1: - patch_shape = (patch_h, patch_w) - else: - patch_shape = (patch_h, patch_w, chn) # type: ignore - patch = np.array(pad_fill, dtype=img.dtype) * np.ones(patch_shape, dtype=img.dtype) - x_start = 0 if _x1 >= 0 else -_x1 - y_start = 0 if _y1 >= 0 else -_y1 - w = x2 - x1 + 1 - h = y2 - y1 + 1 - patch[y_start : y_start + h, x_start : x_start + w, ...] = img[y1 : y1 + h, x1 : x1 + w, ...] - patches.append(patch) - - if bboxes.ndim == 1: - return patches[0] - else: - return patches diff --git a/libs/viscv/viscv/image/io.py b/libs/viscv/viscv/image/io.py deleted file mode 100644 index c7e90b8..0000000 --- a/libs/viscv/viscv/image/io.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import io -import os.path as osp -import warnings -from typing import Optional - -import cv2 -import numpy as np -from cv2 import ( - IMREAD_COLOR, - IMREAD_GRAYSCALE, - IMREAD_IGNORE_ORIENTATION, - IMREAD_UNCHANGED, -) - -try: - from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG -except ImportError: - TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None - -try: - from PIL import Image, ImageOps -except ImportError: - Image = None - -try: - import tifffile -except ImportError: - tifffile = None - -jpeg = None -supported_backends = ["cv2", "turbojpeg", "pillow", "tifffile"] - -imread_flags = { - "color": IMREAD_COLOR, - "grayscale": IMREAD_GRAYSCALE, - "unchanged": IMREAD_UNCHANGED, - "color_ignore_orientation": IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR, - "grayscale_ignore_orientation": IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE, -} - -imread_backend = "cv2" - - -def imfrombytes( - content: bytes, - flag: str = "color", - channel_order: str = "bgr", - backend: str | None = None, -) -> np.ndarray: - """Read an image from bytes. - - Args: - content (bytes): Image bytes got from files or other streams. 
-        flag (str): Same as :func:`imread`.
-        channel_order (str): The channel order of the output, candidates
-            are 'bgr' and 'rgb'. Defaults to 'bgr'.
-        backend (str | None): The image decoding backend type. Options are
-            `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is
-            None, the global imread_backend specified by ``use_backend()`` will
-            be used. Default: None.
-
-    Returns:
-        ndarray: Loaded image array.
-    """
-    if backend is None:
-        backend = imread_backend
-    if backend not in supported_backends:
-        raise ValueError(
-            f"backend: {backend} is not supported. Supported backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'"
-        )
-
-    if backend == "turbojpeg":
-        global jpeg
-        if TurboJPEG is None:
-            raise ImportError("`PyTurboJPEG` is not installed")
-        if jpeg is None:
-            # Lazily create the decoder on first use.
-            jpeg = TurboJPEG()
-        # TurboJPEG already honours the requested pixel format, so no
-        # further color conversion is needed; just squeeze grayscale output.
-        img = jpeg.decode(content, _jpegflag(flag, channel_order))
-        if img.shape[-1] == 1:
-            img = img[:, :, 0]
-        return img
-
-    elif backend == "pillow":
-        if Image is None:
-            raise ImportError("`Pillow` is not installed")
-        with io.BytesIO(content) as buff:
-            img = Image.open(buff)
-            img = _pillow2array(img, flag, channel_order)
-        return img
-
-    elif backend == "tifffile":
-        if tifffile is None:
-            raise ImportError("`tifffile` is not installed")
-        with io.BytesIO(content) as buff:
-            img = tifffile.imread(buff)
-        return img
-
-    else:
-        # cv2 backend
-        if len(content) == 0:
-            return None
-        img_np = np.frombuffer(content, np.uint8)
-        flag = imread_flags[flag] if isinstance(flag, str) else flag
-        img = cv2.imdecode(img_np, flag)
-        if img is not None and flag == IMREAD_COLOR and channel_order == "rgb":
-            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
-        return img
-
-
-def _jpegflag(flag: str = "color", channel_order: str = "bgr"):
-    channel_order = channel_order.lower()
-    if channel_order not in ["rgb", "bgr"]:
-        raise ValueError('channel order must be either "rgb" or "bgr"')
-
-    if flag == "color":
-        if channel_order == "bgr":
-            return TJPF_BGR
-        elif channel_order == "rgb":
-            return TJCS_RGB
-    elif flag == "grayscale":
-        return TJPF_GRAY
-    else:
-        raise ValueError('flag must be "color" or "grayscale"')
-
-
-def _pillow2array(img, flag: str = "color", channel_order: str = "bgr") -> np.ndarray:
-    """Convert a pillow image to numpy array.
-
-    Args:
-        img (:obj:`PIL.Image.Image`): The image loaded using PIL
-        flag (str): Flags specifying the color type of a loaded image,
-            candidates are 'color', 'grayscale' and 'unchanged'.
-            Defaults to 'color'.
-        channel_order (str): The channel order of the output image array,
-            candidates are 'bgr' and 'rgb'. Defaults to 'bgr'.
-
-    Returns:
-        np.ndarray: The converted numpy array
-    """
-    channel_order = channel_order.lower()
-    if channel_order not in ["rgb", "bgr"]:
-        raise ValueError('channel order must be either "rgb" or "bgr"')
-
-    if flag == "unchanged":
-        array = np.array(img)
-        if array.ndim >= 3 and array.shape[2] >= 3:  # color image
-            array[:, :, :3] = array[:, :, (2, 1, 0)]  # RGB to BGR
-    else:
-        # Handle exif orientation tag
-        if flag in ["color", "grayscale"]:
-            if hasattr(img, "_getexif") and img._getexif() is not None:
-                # Not all images have exif info
-                exif = img._getexif()
-                orientation = exif.get(274, 1)  # 274 is the orientation tag id
-                if orientation > 1:
-                    img = ImageOps.exif_transpose(img)
-
-        # If the image mode is not 'RGB', convert it to 'RGB' first.
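-        # Most source modes (palette 'P', 'CMYK', 'RGBA', ...) convert
-        # cleanly; 'LA' is special-cased below so its alpha channel is
-        # composited onto a canvas instead of being dropped.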
- if flag == "color": - if img.mode != "RGB": - if img.mode != "LA": - # Most formats except 'LA' can be directly converted to RGB - img = img.convert("RGB") - else: - # When the mode is 'LA', the default conversion will fill in - # the canvas with black, which sometimes shadows black objects - # in the foreground. - # Therefore, a random color (124, 117, 104) is used for canvas - img_rgba = img.convert("RGBA") - img = Image.new("RGB", img_rgba.size, (124, 117, 104)) - img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha - if channel_order == "bgr": - array = np.array(img, dtype=np.uint8) - array = array[:, :, ::-1] # RGB to BGR - else: - array = np.array(img, dtype=np.uint8) - elif flag == "grayscale": - img = img.convert("L") - array = np.array(img, dtype=np.uint8) - - return array - - -def imwrite( - img: np.ndarray, - file_path: str, - params: list | None = None, - auto_mkdir: bool | None = None, -) -> bool: - """Write image to file. - - Warning: - The parameter `auto_mkdir` will be deprecated in the future and every - file clients will make directory automatically. - - Args: - img (ndarray): Image array to be written. - file_path (str): Image file path. - params (None or list): Same as opencv :func:`imwrite` interface. - auto_mkdir (bool): If the parent folder of `file_path` does not exist, - whether to create it automatically. It will be deprecated. - - Returns: - bool: Successful or not. - - Examples: - >>> # write to hard disk client - >>> ret = viscv.imwrite(img, '/path/to/img.jpg') - """ - file_path = str(file_path) - if auto_mkdir is not None: - warnings.warn( - "The parameter `auto_mkdir` will be deprecated in the future and " - "every file clients will make directory automatically." - ) - - # Create directory if it doesn't exist - dir_name = osp.dirname(file_path) - if dir_name and not osp.exists(dir_name): - import os - - os.makedirs(dir_name, exist_ok=True) - - img_ext = osp.splitext(file_path)[-1] - # Encode image according to image suffix. - # For example, if image path is '/path/your/img.jpg', the encode - # format is '.jpg'. - flag, img_buff = cv2.imencode(img_ext, img, params) - - if flag: - with open(file_path, "wb") as f: - f.write(img_buff.tobytes()) - - return flag diff --git a/libs/viscv/viscv/image/photometric.py b/libs/viscv/viscv/image/photometric.py deleted file mode 100644 index f501a77..0000000 --- a/libs/viscv/viscv/image/photometric.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import cv2 -import numpy as np -import torch - - -def imnormalize(img, mean, std, to_rgb=True): - """Normalize an image with mean and std. - - Args: - img (ndarray): Image to be normalized. - mean (ndarray): The mean to be used for image normalize. - std (ndarray): The std to be used for image normalize. - to_rgb (bool): Whether to convert to rgb. - - Returns: - ndarray: The normalized image. - """ - img = img.copy().astype(np.float32) - mean = np.array(mean, dtype=np.float32) - std = np.array(std, dtype=np.float32) - - assert img.dtype != np.uint8 - mean = mean.reshape(1, -1) - std = std.reshape(1, -1) - - if to_rgb: - cv2 = None - try: - import cv2 - except ImportError: - pass - - if cv2 is not None: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - else: - img = img[..., ::-1] - - img = (img - mean) / std - return img - - -def imdenormalize(img, mean, std, to_bgr=True): - """Denormalize an image with mean and std. - - Args: - img (ndarray or Tensor): Image to be denormalized. - mean (ndarray): The mean to be used for image denormalize. 
- std (ndarray): The std to be used for image denormalize. - to_bgr (bool): Whether to convert to bgr. - - Returns: - ndarray or Tensor: The denormalized image. - """ - if isinstance(img, torch.Tensor): - mean = torch.tensor(mean, dtype=img.dtype, device=img.device) - std = torch.tensor(std, dtype=img.dtype, device=img.device) - if img.dim() == 4: # (N, C, H, W) - mean = mean.view(1, -1, 1, 1) - std = std.view(1, -1, 1, 1) - elif img.dim() == 3: # (C, H, W) - mean = mean.view(-1, 1, 1) - std = std.view(-1, 1, 1) - img = img * std + mean - if to_bgr and img.shape[-3] == 3: - img = img.flip(-3) - else: # numpy array - mean = np.array(mean, dtype=img.dtype) - std = np.array(std, dtype=img.dtype) - img = img * std + mean - if to_bgr and img.shape[-1] == 3: - cv2 = None - try: - import cv2 - except ImportError: - pass - - if cv2 is not None: - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - else: - img = img[..., ::-1].copy() - - return img - - -def hsv2bgr(img: np.ndarray) -> np.ndarray: - """Convert a HSV image to BGR image. - - Args: - img (ndarray): The input HSV image. - - Returns: - ndarray: The converted BGR image. - """ - return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) diff --git a/libs/viscv/viscv/ops/__init__.py b/libs/viscv/viscv/ops/__init__.py deleted file mode 100644 index 573d442..0000000 --- a/libs/viscv/viscv/ops/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .nms import batched_nms, nms -from .roi_align import RoIAlign, roi_align - -__all__ = ["RoIAlign", "batched_nms", "nms", "roi_align"] diff --git a/libs/viscv/viscv/ops/nms.py b/libs/viscv/viscv/ops/nms.py deleted file mode 100644 index be7d6c7..0000000 --- a/libs/viscv/viscv/ops/nms.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import torch -from torch import Tensor -from torchvision.ops import nms as torch_nms - - -def batched_nms( - boxes: Tensor, - scores: Tensor, - idxs: Tensor, - nms_cfg: dict | None, - class_agnostic: bool = False, -) -> tuple[Tensor, Tensor]: - r"""Performs non-maximum suppression in a batched fashion. - - Modified from `torchvision/ops/boxes.py#L39 - `_. - In order to perform NMS independently per class, we add an offset to all - the boxes. The offset is dependent only on the class idx, and is large - enough so that boxes from different classes do not overlap. - - Note: - In v1.4.1 and later, ``batched_nms`` supports skipping the NMS and - returns sorted raw results when `nms_cfg` is None. - - Args: - boxes (torch.Tensor): boxes in shape (N, 4) or (N, 5). - scores (torch.Tensor): scores in shape (N, ). - idxs (torch.Tensor): each index value correspond to a bbox cluster, - and NMS will not be applied between elements of different idxs, - shape (N, ). - nms_cfg (dict | optional): Supports skipping the nms when `nms_cfg` - is None, otherwise it should specify nms type and other - parameters like `iou_thr`. Possible keys includes the following. - - - iou_threshold (float): IoU threshold used for NMS. - - split_thr (float): threshold number of boxes. In some cases the - number of boxes is large (e.g., 200k). To avoid OOM during - training, the users could set `split_thr` to a small value. - If the number of boxes is greater than the threshold, it will - perform NMS on each group of boxes separately and sequentially. - Defaults to 10000. - class_agnostic (bool): if true, nms is class agnostic, - i.e. IoU thresholding happens over all boxes, - regardless of the predicted class. Defaults to False. 
-
-    Returns:
-        tuple: kept dets and indices.
-
-        - boxes (Tensor): Bboxes with score after nms, has shape
-          (num_bboxes, 5), where the last dimension is arranged as
-          (x1, y1, x2, y2, score)
-        - keep (Tensor): The indices of remaining boxes in input
-          boxes.
-    """
-    # skip nms when nms_cfg is None
-    if nms_cfg is None:
-        scores, inds = scores.sort(descending=True)
-        boxes = boxes[inds]
-        return torch.cat([boxes, scores[:, None]], -1), inds
-
-    nms_cfg_ = nms_cfg.copy()
-    class_agnostic = nms_cfg_.pop("class_agnostic", class_agnostic)
-    if class_agnostic:
-        boxes_for_nms = boxes
-    else:
-        # When using rotated boxes, only apply offsets on center.
-        if boxes.size(-1) == 5:
-            # Strictly, the maximum coordinates of the rotating box
-            # (x,y,w,h,a) should be calculated by polygon coordinates.
-            # But the conversion from rotated box to polygon will
-            # slow down the speed.
-            # So we use max(x,y) + max(w,h) as max coordinate
-            # which is larger than polygon max coordinate
-            # max(x1, y1, x2, y2,x3, y3, x4, y4)
-            max_coordinate = boxes[..., :2].max() + boxes[..., 2:4].max()
-            offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
-            boxes_ctr_for_nms = boxes[..., :2] + offsets[:, None]
-            boxes_for_nms = torch.cat([boxes_ctr_for_nms, boxes[..., 2:5]], dim=-1)
-        else:
-            max_coordinate = boxes.max()
-            offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
-            boxes_for_nms = boxes + offsets[:, None]
-
-    # Only the plain "nms" type is supported in this trimmed-down version;
-    # pop it so it does not leak into the remaining config keys.
-    nms_type = nms_cfg_.pop("type", "nms")
-    if isinstance(nms_type, str):
-        assert nms_type == "nms", f"Unsupported nms type: {nms_type}"
-
-    # Extract the IoU threshold once, before any branching: popping it inside
-    # the per-class loop below would silently fall back to the default after
-    # the first iteration.
-    iou_threshold = nms_cfg_.pop("iou_threshold", nms_cfg_.pop("iou_thr", 0.5))
-
-    split_thr = nms_cfg_.pop("split_thr", 10000)
-    # Won't split to multiple nms nodes when exporting to onnx
-    if boxes_for_nms.shape[0] < split_thr:
-        # Use torchvision's nms instead of a custom implementation
-        keep = torch_nms(boxes_for_nms, scores, iou_threshold)
-
-        # Apply max_num if specified
-        max_num = nms_cfg_.get("max_num", -1)
-        if max_num > 0 and keep.shape[0] > max_num:
-            keep = keep[:max_num]
-
-        boxes = boxes[keep]
-        scores = scores[keep]
-    else:
-        max_num = nms_cfg_.pop("max_num", -1)
-        total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
-        # Some types of nms reweight the score, such as SoftNMS
-        scores_after_nms = scores.new_zeros(scores.size())
-        for class_id in torch.unique(idxs):
-            mask = (idxs == class_id).nonzero(as_tuple=False).view(-1)
-            # Run torchvision's nms on each class separately
-            keep = torch_nms(boxes_for_nms[mask], scores[mask], iou_threshold)
-            total_mask[mask[keep]] = True
-            scores_after_nms[mask[keep]] = scores[mask[keep]]
-        keep = total_mask.nonzero(as_tuple=False).view(-1)
-
-        scores, inds = scores_after_nms[keep].sort(descending=True)
-        keep = keep[inds]
-        boxes = boxes[keep]
-
-        if max_num > 0:
-            keep = keep[:max_num]
-            boxes = boxes[:max_num]
-            scores = scores[:max_num]
-
-    boxes = torch.cat([boxes, scores[:, None]], -1)
-    return boxes, keep
-
-
-def nms(
-    boxes: Tensor,
-    scores: Tensor,
-    iou_threshold: float,
-    offset: int = 0,
-    score_threshold: float = 0,
-    max_num: int = -1,
-) -> tuple[Tensor, Tensor]:
-    """Dispatch to torchvision NMS implementation.
-
-    The inputs must be torch Tensors. This implementation uses
-    torchvision's NMS, which is optimized and runs on GPU if available.
-
-    Args:
-        boxes (torch.Tensor): boxes in shape (N, 4).
-        scores (torch.Tensor): scores in shape (N, ).
- iou_threshold (float): IoU threshold for NMS. - offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset). - score_threshold (float): score threshold for NMS. - max_num (int): maximum number of boxes after NMS. - - Returns: - tuple: kept dets (boxes and scores) and indice. - - Example: - >>> boxes = torch.tensor([[49.1, 32.4, 51.0, 35.9], - >>> [49.3, 32.9, 51.0, 35.3], - >>> [49.2, 31.8, 51.0, 35.4], - >>> [35.1, 11.5, 39.1, 15.7], - >>> [35.6, 11.8, 39.3, 14.2], - >>> [35.3, 11.5, 39.9, 14.5], - >>> [35.2, 11.7, 39.7, 15.7]], dtype=torch.float32) - >>> scores = torch.tensor([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3], - >>> dtype=torch.float32) - >>> iou_threshold = 0.6 - >>> dets, inds = nms(boxes, scores, iou_threshold) - >>> assert len(inds) == len(dets) == 3 - """ - assert isinstance(boxes, Tensor) - assert isinstance(scores, Tensor) - assert boxes.size(1) == 4 - assert boxes.size(0) == scores.size(0) - assert offset in (0, 1) - - # Filter by score threshold if needed - valid_inds = None - if score_threshold > 0: - valid_mask = scores > score_threshold - boxes = boxes[valid_mask] - scores = scores[valid_mask] - valid_inds = torch.nonzero(valid_mask, as_tuple=False).squeeze(dim=1) - - # Apply offset if needed (mmcv compatibility) - if offset == 1: - boxes_for_nms = boxes.clone() - boxes_for_nms[:, 2:] += offset - else: - boxes_for_nms = boxes - - # Use torchvision's NMS - keep = torch_nms(boxes_for_nms, scores, iou_threshold) - - # Apply max_num constraint - if max_num > 0 and keep.shape[0] > max_num: - keep = keep[:max_num] - - # Map back to original indices if we filtered by score - if score_threshold > 0 and valid_inds is not None: - keep = valid_inds[keep] - - dets = torch.cat((boxes[keep], scores[keep].reshape(-1, 1)), dim=1) - return dets, keep diff --git a/libs/viscv/viscv/ops/roi_align.py b/libs/viscv/viscv/ops/roi_align.py deleted file mode 100644 index 0f93678..0000000 --- a/libs/viscv/viscv/ops/roi_align.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn -from torchvision.ops import RoIAlign as TVRoIAlign -from torchvision.ops import roi_align as tv_roi_align - - -class RoIAlign(nn.Module): - """RoI align pooling layer using torchvision's implementation. - - Args: - output_size (tuple): h, w - spatial_scale (float): scale the input boxes by this number - sampling_ratio (int): number of inputs samples to take for each - output sample. 0 to take samples densely for current models. - pool_mode (str): pooling mode in each bin. - aligned (bool): if False, use the legacy implementation in - MMDetection. If True, align the results more perfectly. - use_torchvision (bool): whether to use torchvision's implementation. - We set this to True by default for better performance. 
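-
-    Example:
-        A minimal sketch: one 4x4 feature map and a single RoI covering its
-        top-left corner (the first column of ``rois`` is the batch index).
-
-        >>> import torch
-        >>> feats = torch.arange(16.0).reshape(1, 1, 4, 4)
-        >>> rois = torch.tensor([[0.0, 0.0, 0.0, 2.0, 2.0]])
-        >>> RoIAlign(output_size=(2, 2), spatial_scale=1.0)(feats, rois).shape
-        torch.Size([1, 1, 2, 2])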
- """ - - def __init__( - self, - output_size, - spatial_scale=1.0, - sampling_ratio=0, - pool_mode="avg", - aligned=True, - use_torchvision=True, - ): - super().__init__() - self.output_size = output_size - self.spatial_scale = spatial_scale - self.sampling_ratio = sampling_ratio - self.pool_mode = pool_mode - self.aligned = aligned - self.use_torchvision = use_torchvision - - if isinstance(self.output_size, int): - self.output_size = (self.output_size, self.output_size) - - # We always use torchvision's implementation for simplicity - self.roi_align = TVRoIAlign( - output_size=self.output_size, - spatial_scale=self.spatial_scale, - sampling_ratio=self.sampling_ratio, - aligned=self.aligned, - ) - - def forward(self, input, rois): - """ - Args: - input: NCHW images - rois: Bx5 boxes. First column is the index into N. - The other 4 columns are xyxy. - """ - return self.roi_align(input, rois) - - def __repr__(self): - s = self.__class__.__name__ - s += f"(output_size={self.output_size}, " - s += f"spatial_scale={self.spatial_scale}, " - s += f"sampling_ratio={self.sampling_ratio}, " - s += f"pool_mode={self.pool_mode}, " - s += f"aligned={self.aligned}, " - s += f"use_torchvision={self.use_torchvision})" - return s - - -# Functional interface -def roi_align( - input, - rois, - output_size, - spatial_scale=1.0, - sampling_ratio=0, - pool_mode="avg", - aligned=True, -): - """RoI align pooling layer functional interface. - - Args: - input (Tensor): input tensor. - rois (Tensor): RoIs tensor. - output_size (tuple): h, w - spatial_scale (float): scale the input boxes by this number - sampling_ratio (int): number of inputs samples to take for each - output sample. 0 to take samples densely. - pool_mode (str): pooling mode in each bin. - aligned (bool): if False, use the legacy implementation in - MMDetection. If True, align the results more perfectly. - - Returns: - Tensor: RoI align pooling result. - """ - if isinstance(output_size, int): - output_size = (output_size, output_size) - - # Use torchvision's roi_align directly - return tv_roi_align(input, rois, output_size, spatial_scale, sampling_ratio, aligned) diff --git a/libs/viscv/viscv/transforms/__init__.py b/libs/viscv/viscv/transforms/__init__.py deleted file mode 100644 index 52da0ca..0000000 --- a/libs/viscv/viscv/transforms/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -from .base import BaseTransform -from .builder import TRANSFORMS, build_from_cfg, build_transforms -from .formatting import to_tensor -from .loading import LoadAnnotations, LoadImageFromFile -from .processing import Normalize, Pad, RandomFlip, RandomResize, Resize -from .wrappers import ( - Compose, - KeyMapper, - RandomApply, - RandomChoice, - TransformBroadcaster, -) - -__all__ = [ - "TRANSFORMS", - "BaseTransform", - "Compose", - "KeyMapper", - "LoadAnnotations", - "LoadImageFromFile", - "Normalize", - "Pad", - "RandomApply", - "RandomChoice", - "RandomFlip", - "RandomResize", - "Resize", - "TransformBroadcaster", - "build_from_cfg", - "build_transforms", - "to_tensor", -] diff --git a/libs/viscv/viscv/transforms/base.py b/libs/viscv/viscv/transforms/base.py deleted file mode 100644 index 74f0b50..0000000 --- a/libs/viscv/viscv/transforms/base.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from abc import ABCMeta, abstractmethod - - -class BaseTransform(metaclass=ABCMeta): - """Base class for all transformations.""" - - def __call__(self, results: dict) -> dict | tuple[list, list] | None: - return self.transform(results) - - @abstractmethod - def transform(self, results: dict) -> dict | tuple[list, list] | None: - """The transform function. All subclass of BaseTransform should - override this method. - - This function takes the result dict as the input, and can add new - items to the dict or modify existing items in the dict. And the result - dict will be returned in the end, which allows to concate multiple - transforms into a pipeline. - - Args: - results (dict): The result dict. - - Returns: - dict: The result dict. - """ diff --git a/libs/viscv/viscv/transforms/builder.py b/libs/viscv/viscv/transforms/builder.py deleted file mode 100644 index a529486..0000000 --- a/libs/viscv/viscv/transforms/builder.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Create a simple registry for transforms -class Registry: - """A simple registry to register transforms.""" - - def __init__(self, name): - self._name = name - self._module_dict = {} - - def register_module(self, name=None, module=None, force=False): - """Register a module. - - Args: - name (str | None): The module name to be registered. If not - specified, the class name will be used. - module (type): Module class to be registered. - force (bool): Whether to override an existing class with the same - name. Default: False. - """ - if module is not None: - self._register_module(module_class=module, module_name=name, force=force) - return module - - # use as a decorator - def _register(cls): - self._register_module(module_class=cls, module_name=name, force=force) - return cls - - return _register - - def _register_module(self, module_class, module_name=None, force=False): - if module_name is None: - module_name = module_class.__name__ - if not force and module_name in self._module_dict: - raise KeyError(f"{module_name} is already registered in {self._name}") - self._module_dict[module_name] = module_class - - def get(self, key): - """Get the registered module.""" - return self._module_dict.get(key, None) - - def build(self, cfg): - """Build a module from config dict.""" - if isinstance(cfg, dict): - cfg = cfg.copy() - if "type" not in cfg: - raise KeyError('cfg must contain the key "type"') - module_type = cfg.pop("type") - if module_type not in self._module_dict: - raise KeyError(f"{module_type} is not in the {self._name} registry") - module_cls = self._module_dict[module_type] - return module_cls(**cfg) - else: - raise TypeError("cfg must be a dict") - - -TRANSFORMS = Registry("transforms") - - -def build_from_cfg(cfg, registry, default_args=None): - """Build a module from config dict. - - Args: - cfg (dict): Configuration dict. It should at least contain the key "type". - registry (Registry): The registry to find the type from. - default_args (dict, optional): Default initialization arguments. - - Returns: - obj: The constructed object. 
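-
-    Example:
-        An illustrative sketch; ``FakeTransform`` is a hypothetical class
-        registered only for this doctest.
-
-        >>> @TRANSFORMS.register_module()
-        ... class FakeTransform:
-        ...     def __init__(self, scale=1.0):
-        ...         self.scale = scale
-        >>> obj = build_from_cfg(dict(type="FakeTransform", scale=2.0), TRANSFORMS)
-        >>> obj.scale
-        2.0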
- """ - if not isinstance(cfg, dict): - raise TypeError(f"cfg must be a dict, but got {type(cfg)}") - - if "type" not in cfg: - raise KeyError('cfg must contain the key "type"') - - cfg = cfg.copy() - - # Merge default arguments - if default_args is not None: - for k, v in default_args.items(): - cfg.setdefault(k, v) - - obj_type = cfg.pop("type") - - if isinstance(obj_type, str): - obj_cls = registry.get(obj_type) - if obj_cls is None: - raise KeyError(f"{obj_type} is not in the {registry._name} registry") - else: - obj_cls = obj_type - - return obj_cls(**cfg) - - -def build_transforms(cfg): - """Build a transform or a sequence of transforms. - - Args: - cfg (dict, list[dict]): Transform config or list of configs. - - Returns: - transform: The transform or a composed transform. - """ - if isinstance(cfg, list): - transforms = [] - for transform_cfg in cfg: - transform = build_from_cfg(transform_cfg, TRANSFORMS) - transforms.append(transform) - - # Import Compose here to avoid circular imports - from viscv.transforms.compose import Compose - - return Compose(transforms) - else: - return build_from_cfg(cfg, TRANSFORMS) diff --git a/libs/viscv/viscv/transforms/formatting.py b/libs/viscv/viscv/transforms/formatting.py deleted file mode 100644 index e075692..0000000 --- a/libs/viscv/viscv/transforms/formatting.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from collections.abc import Sequence - -import numpy as np -import torch -from visengine.utils import is_str - -from .base import BaseTransform -from .builder import TRANSFORMS - - -def to_tensor(data: torch.Tensor | np.ndarray | Sequence | int | float) -> torch.Tensor: - """Convert objects of various python types to :obj:`torch.Tensor`. - - Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, - :class:`Sequence`, :class:`int` and :class:`float`. - - Args: - data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to - be converted. - - Returns: - torch.Tensor: the converted data. - """ - - if isinstance(data, torch.Tensor): - return data - elif isinstance(data, np.ndarray): - return torch.from_numpy(data) - elif isinstance(data, Sequence) and not is_str(data): - return torch.tensor(data) - elif isinstance(data, int): - return torch.LongTensor([data]) - elif isinstance(data, float): - return torch.FloatTensor([data]) - else: - raise TypeError(f"type {type(data)} cannot be converted to tensor.") - - -@TRANSFORMS.register_module() -class ToTensor(BaseTransform): - """Convert some results to :obj:`torch.Tensor` by given keys. - - Required keys: - - - all these keys in `keys` - - Modified Keys: - - - all these keys in `keys` - - Args: - keys (Sequence[str]): Keys that need to be converted to Tensor. - """ - - def __init__(self, keys: Sequence[str]) -> None: - self.keys = keys - - def transform(self, results: dict) -> dict: - """Transform function to convert data to `torch.Tensor`. - - Args: - results (dict): Result dict from loading pipeline. - Returns: - dict: `keys` in results will be updated. 
- """ - for key in self.keys: - key_list = key.split(".") - cur_item = results - for i in range(len(key_list)): - if key_list[i] not in cur_item: - raise KeyError(f"Can not find key {key}") - if i == len(key_list) - 1: - cur_item[key_list[i]] = to_tensor(cur_item[key_list[i]]) - break - cur_item = cur_item[key_list[i]] - - return results - - def __repr__(self) -> str: - return self.__class__.__name__ + f"(keys={self.keys})" - - -@TRANSFORMS.register_module() -class ImageToTensor(BaseTransform): - """Convert image to :obj:`torch.Tensor` by given keys. - - The dimension order of input image is (H, W, C). The pipeline will convert - it to (C, H, W). If only 2 dimension (H, W) is given, the output would be - (1, H, W). - - Required keys: - - - all these keys in `keys` - - Modified Keys: - - - all these keys in `keys` - - Args: - keys (Sequence[str]): Key of images to be converted to Tensor. - """ - - def __init__(self, keys: dict) -> None: - self.keys = keys - - def transform(self, results: dict) -> dict: - """Transform function to convert image in results to - :obj:`torch.Tensor` and transpose the channel order. - - Args: - results (dict): Result dict contains the image data to convert. - Returns: - dict: The result dict contains the image converted - to :obj:``torch.Tensor`` and transposed to (C, H, W) order. - """ - for key in self.keys: - img = results[key] - if len(img.shape) < 3: - img = np.expand_dims(img, -1) - results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous() - return results - - def __repr__(self) -> str: - return self.__class__.__name__ + f"(keys={self.keys})" diff --git a/libs/viscv/viscv/transforms/loading.py b/libs/viscv/viscv/transforms/loading.py deleted file mode 100644 index 0f126d9..0000000 --- a/libs/viscv/viscv/transforms/loading.py +++ /dev/null @@ -1,403 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings - -import numpy as np -from visengine import fileio as engine_fileio - -import viscv -import viscv.fileio as fileio - -from .base import BaseTransform -from .builder import TRANSFORMS - - -@TRANSFORMS.register_module() -class LoadImageFromFile(BaseTransform): - """Load an image from file. - - Required Keys: - - - img_path - - Modified Keys: - - - img - - img_shape - - ori_shape - - Args: - to_float32 (bool): Whether to convert the loaded image to a float32 - numpy array. If set to False, the loaded image is an uint8 array. - Defaults to False. - color_type (str): The flag argument for :func:`viscv.imfrombytes`. - Defaults to 'color'. - imdecode_backend (str): The image decoding backend type. The backend - argument for :func:`viscv.imfrombytes`. - See :func:`viscv.imfrombytes` for details. - Defaults to 'cv2'. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`visengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - Deprecated in version 2.0.0rc4. - ignore_empty (bool): Whether to allow loading empty image or file path - not existent. Defaults to False. - backend_args (dict, optional): Instantiates the corresponding file - backend. It may contain `backend` key to specify the file - backend. If it contains, the file backend corresponding to this - value will be used and initialized with the remaining values, - otherwise the corresponding file backend will be selected - based on the prefix of the file path. Defaults to None. - New in version 2.0.0rc4. 
- enable_cache (bool): Whether to enable on-disk caching of resized images. - When enabled, images are cached after resizing to avoid repeated - decoding and resizing in subsequent epochs. Defaults to False. - cache_dir (str, optional): Directory to store cached images. If None, - uses default cache directory (~/.cache/viscv/image_cache). - Only used if enable_cache is True. - cache_max_size_gb (float): Maximum cache size in gigabytes. When exceeded, - least recently used images are evicted. Defaults to 10.0. - Only used if enable_cache is True. - """ - - def __init__( - self, - to_float32: bool = False, - color_type: str = "color", - imdecode_backend: str = "cv2", - file_client_args: dict | None = None, - ignore_empty: bool = False, - *, - backend_args: dict | None = None, - enable_cache: bool = False, - cache_dir: str | None = None, - cache_max_size_gb: float = 10.0, - ) -> None: - self.ignore_empty = ignore_empty - self.to_float32 = to_float32 - self.color_type = color_type - self.imdecode_backend = imdecode_backend - - self.file_client_args: dict | None = None - self.backend_args: dict | None = None - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. Please use "backend_args" instead', - DeprecationWarning, - stacklevel=2, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.') - - self.file_client_args = file_client_args.copy() - if backend_args is not None: - self.backend_args = backend_args.copy() - - # Initialize cache - from viscv.image import ImageCache - - self.cache = ImageCache( - cache_dir=cache_dir if cache_dir is not None else ImageCache.__init__.__defaults__[0], # type: ignore - max_size_gb=cache_max_size_gb, - enabled=enable_cache, - ) - - def transform(self, results: dict) -> dict | None: - """Functions to load image. - - Args: - results (dict): Result dict from - :class:`visengine.dataset.BaseDataset`. - - Returns: - dict: The dict contains loaded image and meta information. 
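-
-        Example:
-            A hypothetical sketch; assumes a ``demo.jpg`` file exists on disk.
-
-            >>> loader = LoadImageFromFile(to_float32=True)
-            >>> results = loader.transform(dict(img_path="demo.jpg"))
-            >>> sorted(k for k in results if not k.startswith("_"))
-            ['img', 'img_path', 'img_shape', 'ori_shape']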
- """ - - filename = results["img_path"] - - # Check cache if resize target is known - target_size = results.get("_cache_target_size") - if target_size is not None and self.cache.enabled: - cached_img = self.cache.get(filename, target_size) - if cached_img is not None: - img = cached_img - if self.to_float32: - img = img.astype(np.float32) - - results["img"] = img - results["img_shape"] = img.shape[:2] - results["ori_shape"] = cached_img.shape[:2] # Store original cached size - results["_cache_hit"] = True - results["_image_cache"] = self.cache - return results - - # Cache miss - load normally - try: - if self.file_client_args is not None: - file_client = fileio.FileClient.infer_client(self.file_client_args, filename) - img_bytes = file_client.get(filename) - else: - img_bytes = fileio.get(filename, backend_args=self.backend_args) - img = viscv.imfrombytes(img_bytes, flag=self.color_type, backend=self.imdecode_backend) - except Exception as e: - if self.ignore_empty: - return None - else: - raise e - # in some cases, images are not read successfully, the img would be - # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427 - assert img is not None, f"failed to load image: {filename}" - if self.to_float32: - img = img.astype(np.float32) - - results["img"] = img - results["img_shape"] = img.shape[:2] - results["ori_shape"] = img.shape[:2] - results["_cache_hit"] = False - - # Pass cache instance to downstream transforms - if self.cache.enabled: - results["_image_cache"] = self.cache - - return results - - def __repr__(self): - repr_str = ( - f"{self.__class__.__name__}(" - f"ignore_empty={self.ignore_empty}, " - f"to_float32={self.to_float32}, " - f"color_type='{self.color_type}', " - f"imdecode_backend='{self.imdecode_backend}', " - f"cache_enabled={self.cache.enabled}, " - ) - - if self.file_client_args is not None: - repr_str += f"file_client_args={self.file_client_args})" - else: - repr_str += f"backend_args={self.backend_args})" - - return repr_str - - -@TRANSFORMS.register_module() -class LoadAnnotations(BaseTransform): - """Load and process the ``instances`` and ``seg_map`` annotation provided - by dataset. - - The annotation format is as the following: - - .. code-block:: python - - { - 'instances': - [ - { - # List of 4 numbers representing the bounding box of the - # instance, in (x1, y1, x2, y2) order. - 'bbox': [x1, y1, x2, y2], - - # Label of image classification. - 'bbox_label': 1, - - # Used in key point detection. - # Can only load the format of [x1, y1, v1,…, xn, yn, vn]. v[i] - # means the visibility of this keypoint. n must be equal to the - # number of keypoint categories. - 'keypoints': [x1, y1, v1, ..., xn, yn, vn] - } - ] - # Filename of semantic or panoptic segmentation ground truth file. - 'seg_map_path': 'a/b/c' - } - - After this module, the annotation has been changed to the format below: - - .. code-block:: python - - { - # In (x1, y1, x2, y2) order, float type. N is the number of bboxes - # in np.float32 - 'gt_bboxes': np.ndarray(N, 4) - # In np.int64 type. - 'gt_bboxes_labels': np.ndarray(N, ) - # In uint8 type. - 'gt_seg_map': np.ndarray (H, W) - # with (x, y, v) order, in np.float32 type. 
- 'gt_keypoints': np.ndarray(N, NK, 3) - } - - Required Keys: - - - instances - - - bbox (optional) - - bbox_label - - keypoints (optional) - - - seg_map_path (optional) - - Added Keys: - - - gt_bboxes (np.float32) - - gt_bboxes_labels (np.int64) - - gt_seg_map (np.uint8) - - gt_keypoints (np.float32) - - Args: - with_bbox (bool): Whether to parse and load the bbox annotation. - Defaults to True. - with_label (bool): Whether to parse and load the label annotation. - Defaults to True. - with_seg (bool): Whether to parse and load the semantic segmentation - annotation. Defaults to False. - with_keypoints (bool): Whether to parse and load the keypoints - annotation. Defaults to False. - imdecode_backend (str): The image decoding backend type. The backend - argument for :func:`viscv.imfrombytes`. - See :func:`viscv.imfrombytes` for details. - Defaults to 'cv2'. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`visengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - Deprecated in version 2.0.0rc4. - backend_args (dict, optional): Instantiates the corresponding file - backend. It may contain `backend` key to specify the file - backend. If it contains, the file backend corresponding to this - value will be used and initialized with the remaining values, - otherwise the corresponding file backend will be selected - based on the prefix of the file path. Defaults to None. - New in version 2.0.0rc4. - """ - - def __init__( - self, - with_bbox: bool = True, - with_label: bool = True, - with_seg: bool = False, - with_keypoints: bool = False, - imdecode_backend: str = "cv2", - file_client_args: dict | None = None, - *, - backend_args: dict | None = None, - ) -> None: - super().__init__() - self.with_bbox = with_bbox - self.with_label = with_label - self.with_seg = with_seg - self.with_keypoints = with_keypoints - self.imdecode_backend = imdecode_backend - - self.file_client_args: dict | None = None - self.backend_args: dict | None = None - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. Please use "backend_args" instead', - DeprecationWarning, - stacklevel=2, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.') - - self.file_client_args = file_client_args.copy() - if backend_args is not None: - self.backend_args = backend_args.copy() - - def _load_bboxes(self, results: dict) -> None: - """Private function to load bounding box annotations. - - Args: - results (dict): Result dict from - :class:`visengine.dataset.BaseDataset`. - - Returns: - dict: The dict contains loaded bounding box annotations. - """ - gt_bboxes = [] - for instance in results["instances"]: - gt_bboxes.append(instance["bbox"]) - results["gt_bboxes"] = np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4) - - def _load_labels(self, results: dict) -> None: - """Private function to load label annotations. - - Args: - results (dict): Result dict from - :class:`visengine.dataset.BaseDataset`. - - Returns: - dict: The dict contains loaded label annotations. - """ - gt_bboxes_labels = [] - for instance in results["instances"]: - gt_bboxes_labels.append(instance["bbox_label"]) - results["gt_bboxes_labels"] = np.array(gt_bboxes_labels, dtype=np.int64) - - def _load_seg_map(self, results: dict) -> None: - """Private function to load semantic segmentation annotations. 
- - Args: - results (dict): Result dict from - :class:`visengine.dataset.BaseDataset`. - - Returns: - dict: The dict contains loaded semantic segmentation annotations. - """ - if self.file_client_args is not None: - file_client = engine_fileio.FileClient.infer_client(self.file_client_args, results["seg_map_path"]) - img_bytes = file_client.get(results["seg_map_path"]) - else: - img_bytes = engine_fileio.get(results["seg_map_path"], backend_args=self.backend_args) - - results["gt_seg_map"] = viscv.imfrombytes(img_bytes, flag="unchanged", backend=self.imdecode_backend).squeeze() - - def _load_kps(self, results: dict) -> None: - """Private function to load keypoints annotations. - - Args: - results (dict): Result dict from - :class:`visengine.dataset.BaseDataset`. - - Returns: - dict: The dict contains loaded keypoints annotations. - """ - gt_keypoints = [] - for instance in results["instances"]: - gt_keypoints.append(instance["keypoints"]) - results["gt_keypoints"] = np.array(gt_keypoints, np.float32).reshape((len(gt_keypoints), -1, 3)) - - def transform(self, results: dict) -> dict: - """Function to load multiple types annotations. - - Args: - results (dict): Result dict from - :class:`visengine.dataset.BaseDataset`. - - Returns: - dict: The dict contains loaded bounding box, label and - semantic segmentation and keypoints annotations. - """ - if self.with_bbox: - self._load_bboxes(results) - if self.with_label: - self._load_labels(results) - if self.with_seg: - self._load_seg_map(results) - if self.with_keypoints: - self._load_kps(results) - return results - - def __repr__(self) -> str: - repr_str = self.__class__.__name__ - repr_str += f"(with_bbox={self.with_bbox}, " - repr_str += f"with_label={self.with_label}, " - repr_str += f"with_seg={self.with_seg}, " - repr_str += f"with_keypoints={self.with_keypoints}, " - repr_str += f"imdecode_backend='{self.imdecode_backend}', " - - if self.file_client_args is not None: - repr_str += f"file_client_args={self.file_client_args})" - else: - repr_str += f"backend_args={self.backend_args})" - - return repr_str diff --git a/libs/viscv/viscv/transforms/processing.py b/libs/viscv/viscv/transforms/processing.py deleted file mode 100644 index 6d87216..0000000 --- a/libs/viscv/viscv/transforms/processing.py +++ /dev/null @@ -1,1497 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import random -import warnings -from collections.abc import Iterable, Sequence -from itertools import product - -import numpy as np -from visengine.utils import is_list_of, is_seq_of, is_tuple_of - -from viscv.image import hsv2bgr, imcrop, imflip, imnormalize, impad - -from .base import BaseTransform -from .builder import TRANSFORMS -from .utils import cache_randomness -from .wrappers import Compose - -Number = int | float - - -@TRANSFORMS.register_module() -class Normalize(BaseTransform): - """Normalize the image. - - Required Keys: - - - img - - Modified Keys: - - - img - - Added Keys: - - - img_norm_cfg - - - mean - - std - - to_rgb - - - Args: - mean (sequence): Mean values of 3 channels. - std (sequence): Std values of 3 channels. - to_rgb (bool): Whether to convert the image from BGR to RGB before - normlizing the image. If ``to_rgb=True``, the order of mean and std - should be RGB. If ``to_rgb=False``, the order of mean and std - should be the same order of the image. Defaults to True. 
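-
-    Example:
-        An illustrative doctest; a uint8 pixel value of 255 maps to 1.0 under
-        mean=std=127.5:
-
-        >>> import numpy as np
-        >>> t = Normalize(mean=[127.5] * 3, std=[127.5] * 3, to_rgb=False)
-        >>> t.transform(dict(img=np.full((1, 1, 3), 255, np.uint8)))["img"]
-        array([[[1., 1., 1.]]], dtype=float32)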
- """ - - def __init__(self, mean: Sequence[Number], std: Sequence[Number], to_rgb: bool = True) -> None: - self.mean = np.array(mean, dtype=np.float32) - self.std = np.array(std, dtype=np.float32) - self.to_rgb = to_rgb - - def transform(self, results: dict) -> dict: - """Function to normalize images. - - Args: - results (dict): Result dict from loading pipeline. - - Returns: - dict: Normalized results, key 'img_norm_cfg' key is added in to - result dict. - """ - - results["img"] = imnormalize(results["img"], self.mean, self.std, self.to_rgb) - results["img_norm_cfg"] = { - "mean": self.mean, - "std": self.std, - "to_rgb": self.to_rgb, - } - return results - - def __repr__(self) -> str: - repr_str = self.__class__.__name__ - repr_str += f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})" - return repr_str - - -@TRANSFORMS.register_module() -class Resize(BaseTransform): - """Resize images & bbox & mask. - - This transform resizes the input image to the given scale. - Bboxes and masks are then resized accordingly. - - Required Keys: - - - img - - gt_bboxes (optional) - - gt_masks (optional) - - Modified Keys: - - - img - - gt_bboxes - - gt_masks - - img_shape - - Added Keys: - - - scale_factor - - Args: - scale (tuple[int]): Images scales for resizing in (w, h) order. - keep_ratio (bool): Whether to keep the aspect ratio when resizing the - image. Defaults to True. - enable_cache (bool): Whether to enable caching of resized images. - When enabled, resized images will be saved to disk cache if the - LoadImageFromFile transform has caching enabled. Defaults to True. - """ - - def __init__(self, scale, keep_ratio=True, enable_cache=True): - if isinstance(scale, int): - scale = (scale, scale) - self.scale = scale - self.keep_ratio = keep_ratio - self.enable_cache = enable_cache - - def transform(self, results): - """Transform function to resize images, bounding boxes and masks. - - Args: - results (dict): Result dict from loading pipeline. - - Returns: - dict: Resized results. 
- """ - from viscv.image import imresize - - img = results["img"] - h, w = img.shape[:2] - - if self.keep_ratio: - # Calculate scale factor to fit within target size - scale_factor = min(self.scale[0] / w, self.scale[1] / h) - new_w = int(w * scale_factor) - new_h = int(h * scale_factor) - else: - new_w, new_h = self.scale - scale_factor = (new_w / w, new_h / h) - - # Resize image - resized_img = imresize(img, (new_w, new_h)) - results["img"] = resized_img - results["img_shape"] = resized_img.shape[:2] - - # Record the scale factor - if isinstance(scale_factor, int | float): - scale_factor = (scale_factor, scale_factor) - results["scale_factor"] = scale_factor - - # Resize bboxes - if "gt_bboxes" in results: - bboxes = results["gt_bboxes"] - - # Check if it's a numpy array or a BaseBoxes object - if hasattr(bboxes, "rescale_"): - # It's a BaseBoxes object, use its rescale_ method (in-place) - bboxes.rescale_(scale_factor) - results["gt_bboxes"] = bboxes - elif isinstance(bboxes, np.ndarray) and len(bboxes) > 0: - # It's a numpy array, scale directly - bboxes = bboxes.copy() - bboxes[:, 0::2] *= scale_factor[0] # x coordinates - bboxes[:, 1::2] *= scale_factor[1] # y coordinates - results["gt_bboxes"] = bboxes - - # Resize masks - if "gt_masks" in results: - gt_masks = results["gt_masks"] - # Handle different mask formats - if hasattr(gt_masks, "resize"): - results["gt_masks"] = gt_masks.resize((new_h, new_w)) - elif isinstance(gt_masks, list): - # Polygon format - scale coordinates - resized_masks = [] - for mask in gt_masks: - resized_mask = [] - for poly in mask: - # Scale polygon coordinates - poly = np.array(poly).reshape(-1, 2) - poly[:, 0] *= scale_factor[0] - poly[:, 1] *= scale_factor[1] - resized_mask.append(poly.reshape(-1).tolist()) - resized_masks.append(resized_mask) - results["gt_masks"] = resized_masks - - # Save to cache if enabled and not already cached - if self.enable_cache and not results.get("_cache_hit", False): - img_path = results.get("img_path") - cache = results.get("_image_cache") - - if img_path is not None and cache is not None: - # Store target size so LoadImageFromFile can check cache on next load - target_size = (new_w, new_h) - results["_cache_target_size"] = target_size - - # Save resized image to cache - cache.put(img_path, target_size, resized_img) - - return results - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += f"(scale={self.scale}, keep_ratio={self.keep_ratio})" - return repr_str - - -@TRANSFORMS.register_module() -class Pad(BaseTransform): - """Pad the image & segmentation map. - - There are three padding modes: (1) pad to a fixed size and (2) pad to the - minimum size that is divisible by some number. and (3)pad to square. Also, - pad to square and pad to the minimum size can be used as the same time. - - Required Keys: - - - img - - gt_bboxes (optional) - - gt_seg_map (optional) - - Modified Keys: - - - img - - gt_seg_map - - img_shape - - Added Keys: - - - pad_shape - - pad_fixed_size - - pad_size_divisor - - Args: - size (tuple, optional): Fixed padding size. - Expected padding shape (w, h). Defaults to None. - size_divisor (int, optional): The divisor of padded size. Defaults to - None. - pad_to_square (bool): Whether to pad the image into a square. - Currently only used for YOLOX. Defaults to False. - pad_val (Number | dict[str, Number], optional): Padding value for if - the pad_mode is "constant". If it is a single number, the value - to pad the image is the number and to pad the semantic - segmentation map is 255. 
If it is a dict, it should
-            have the following keys:
-
-            - img: The value to pad the image.
-            - seg: The value to pad the semantic segmentation map.
-
-            Defaults to dict(img=0, seg=255).
-        padding_mode (str): Type of padding. Should be: constant, edge,
-            reflect or symmetric. Defaults to 'constant'.
-
-            - constant: pads with a constant value, this value is specified
-              with pad_val.
-            - edge: pads with the last value at the edge of the image.
-            - reflect: pads with reflection of image without repeating the last
-              value on the edge. For example, padding [1, 2, 3, 4] with 2
-              elements on both sides in reflect mode will result in
-              [3, 2, 1, 2, 3, 4, 3, 2].
-            - symmetric: pads with reflection of image repeating the last value
-              on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
-              both sides in symmetric mode will result in
-              [2, 1, 1, 2, 3, 4, 4, 3]
-    """
-
-    def __init__(
-        self,
-        size: tuple[int, int] | None = None,
-        size_divisor: int | None = None,
-        pad_to_square: bool = False,
-        pad_val: Number | dict = dict(img=0, seg=255),
-        padding_mode: str = "constant",
-    ) -> None:
-        self.size = size
-        self.size_divisor = size_divisor
-        if isinstance(pad_val, int):
-            pad_val = dict(img=pad_val, seg=255)
-        assert isinstance(pad_val, dict), f"pad_val must be an int or a dict, but got {type(pad_val)}"
-        self.pad_val = pad_val
-        self.pad_to_square = pad_to_square
-
-        if pad_to_square:
-            assert size is None, "size must be None when pad_to_square is True"
-        else:
-            assert size is not None or size_divisor is not None, "one of size and size_divisor must be set"
-            assert size is None or size_divisor is None, "only one of size and size_divisor should be set"
-        assert padding_mode in ["constant", "edge", "reflect", "symmetric"]
-        self.padding_mode = padding_mode
-
-    def _pad_img(self, results: dict) -> None:
-        """Pad images according to ``self.size``."""
-        pad_val = self.pad_val.get("img", 0)
-
-        size = None
-        if self.pad_to_square:
-            max_size = max(results["img"].shape[:2])
-            size = (max_size, max_size)
-        if self.size_divisor is not None:
-            if size is None:
-                size = (results["img"].shape[0], results["img"].shape[1])
-            pad_h = int(np.ceil(size[0] / self.size_divisor)) * self.size_divisor
-            pad_w = int(np.ceil(size[1] / self.size_divisor)) * self.size_divisor
-            size = (pad_h, pad_w)
-        elif self.size is not None:
-            # self.size is (w, h); impad expects (h, w)
-            size = self.size[::-1]
-        if isinstance(pad_val, int) and results["img"].ndim == 3:
-            pad_val = tuple(pad_val for _ in range(results["img"].shape[2]))
-        padded_img = impad(results["img"], shape=size, pad_val=pad_val, padding_mode=self.padding_mode)
-
-        results["img"] = padded_img
-        results["pad_shape"] = padded_img.shape
-        results["pad_fixed_size"] = self.size
-        results["pad_size_divisor"] = self.size_divisor
-        results["img_shape"] = padded_img.shape[:2]
-
-    def _pad_seg(self, results: dict) -> None:
-        """Pad semantic segmentation map according to
-        ``results['pad_shape']``."""
-        if results.get("gt_seg_map", None) is not None:
-            pad_val = self.pad_val.get("seg", 255)
-            if isinstance(pad_val, int) and results["gt_seg_map"].ndim == 3:
-                pad_val = tuple(pad_val for _ in range(results["gt_seg_map"].shape[2]))
-            results["gt_seg_map"] = impad(
-                results["gt_seg_map"],
-                shape=results["pad_shape"][:2],
-                pad_val=pad_val,
-                padding_mode=self.padding_mode,
-            )
-
-    def transform(self, results: dict) -> dict:
-        """Call function to pad images, masks, semantic segmentation maps.
-
-        Args:
-            results (dict): Result dict from loading pipeline.
-
-        Returns:
-            dict: Updated result dict.
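-
-        Example:
-            An illustrative doctest (assuming ``impad`` pads to the target
-            shape like mmcv's implementation), padding a 50x60 image up to
-            the nearest multiple of 32 on each side:
-
-            >>> import numpy as np
-            >>> t = Pad(size_divisor=32)
-            >>> t.transform(dict(img=np.zeros((50, 60, 3), np.uint8)))["pad_shape"]
-            (64, 64, 3)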
- """ - self._pad_img(results) - self._pad_seg(results) - return results - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += f"(size={self.size}, " - repr_str += f"size_divisor={self.size_divisor}, " - repr_str += f"pad_to_square={self.pad_to_square}, " - repr_str += f"pad_val={self.pad_val}), " - repr_str += f"padding_mode={self.padding_mode})" - return repr_str - - -@TRANSFORMS.register_module() -class CenterCrop(BaseTransform): - """Crop the center of the image, segmentation masks, bounding boxes and key - points. If the crop area exceeds the original image and ``auto_pad`` is - True, the original image will be padded before cropping. - - Required Keys: - - - img - - gt_seg_map (optional) - - gt_bboxes (optional) - - gt_keypoints (optional) - - Modified Keys: - - - img - - img_shape - - gt_seg_map (optional) - - gt_bboxes (optional) - - gt_keypoints (optional) - - Added Key: - - - pad_shape - - - Args: - crop_size (Union[int, Tuple[int, int]]): Expected size after cropping - with the format of (w, h). If set to an integer, then cropping - width and height are equal to this integer. - auto_pad (bool): Whether to pad the image if it's smaller than the - ``crop_size``. Defaults to False. - pad_cfg (dict): Base config for padding. Refer to ``mmcv.Pad`` for - detail. Defaults to ``dict(type='Pad')``. - clip_object_border (bool): Whether to clip the objects - outside the border of the image. In some dataset like MOT17, the - gt bboxes are allowed to cross the border of images. Therefore, - we don't need to clip the gt bboxes in these cases. - Defaults to True. - """ - - def __init__( - self, - crop_size: int | tuple[int, int], - auto_pad: bool = False, - pad_cfg: dict = dict(type="Pad"), - clip_object_border: bool = True, - ) -> None: - super().__init__() - assert isinstance(crop_size, int) or (isinstance(crop_size, tuple) and len(crop_size) == 2), ( - "The expected crop_size is an integer, or a tuple containing two " - ) - "intergers" - - if isinstance(crop_size, int): - crop_size = (crop_size, crop_size) - assert crop_size[0] > 0 and crop_size[1] > 0 - self.crop_size = crop_size - self.auto_pad = auto_pad - - self.pad_cfg = pad_cfg.copy() - # size will be overwritten - if "size" in self.pad_cfg and auto_pad: - warnings.warn( - "``size`` is set in ``pad_cfg``,however this argument will be overwritten according to crop size and image size" - ) - - self.clip_object_border = clip_object_border - - def _crop_img(self, results: dict, bboxes: np.ndarray) -> None: - """Crop image. - - Args: - results (dict): Result dict contains the data to transform. - bboxes (np.ndarray): Shape (4, ), location of cropped bboxes. - """ - if results.get("img", None) is not None: - img = imcrop(results["img"], bboxes=bboxes) - img_shape = img.shape[:2] # type: ignore - results["img"] = img - results["img_shape"] = img_shape - results["pad_shape"] = img_shape - - def _crop_seg_map(self, results: dict, bboxes: np.ndarray) -> None: - """Crop semantic segmentation map. - - Args: - results (dict): Result dict contains the data to transform. - bboxes (np.ndarray): Shape (4, ), location of cropped bboxes. - """ - if results.get("gt_seg_map", None) is not None: - img = imcrop(results["gt_seg_map"], bboxes=bboxes) - results["gt_seg_map"] = img - - def _crop_bboxes(self, results: dict, bboxes: np.ndarray) -> None: - """Update bounding boxes according to CenterCrop. - - Args: - results (dict): Result dict contains the data to transform. - bboxes (np.ndarray): Shape (4, ), location of cropped bboxes. 
- """ - if "gt_bboxes" in results: - offset_w = bboxes[0] - offset_h = bboxes[1] - bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h]) - # gt_bboxes has shape (num_gts, 4) in (tl_x, tl_y, br_x, br_y) - # order. - gt_bboxes = results["gt_bboxes"] - bbox_offset - if self.clip_object_border: - # Clip coordinates - img_h, img_w = results["img"].shape[:2] - - # Clip x coordinates - gt_bboxes[:, 0] = np.clip(gt_bboxes[:, 0], 0, img_w) # x_min - gt_bboxes[:, 2] = np.clip(gt_bboxes[:, 2], 0, img_w) # x_max - - # Clip y coordinates - gt_bboxes[:, 1] = np.clip(gt_bboxes[:, 1], 0, img_h) # y_min - gt_bboxes[:, 3] = np.clip(gt_bboxes[:, 3], 0, img_h) # y_max - results["gt_bboxes"] = gt_bboxes - - def _crop_keypoints(self, results: dict, bboxes: np.ndarray) -> None: - """Update key points according to CenterCrop. Keypoints that not in the - cropped image will be set invisible. - - Args: - results (dict): Result dict contains the data to transform. - bboxes (np.ndarray): Shape (4, ), location of cropped bboxes. - """ - if "gt_keypoints" in results: - offset_w = bboxes[0] - offset_h = bboxes[1] - keypoints_offset = np.array([offset_w, offset_h, 0]) - # gt_keypoints has shape (N, NK, 3) in (x, y, visibility) order, - # NK = number of points per object - gt_keypoints = results["gt_keypoints"] - keypoints_offset - # set gt_kepoints out of the result image invisible - height, width = results["img"].shape[:2] - valid_pos = ( - (gt_keypoints[:, :, 0] >= 0) - * (gt_keypoints[:, :, 0] < width) - * (gt_keypoints[:, :, 1] >= 0) - * (gt_keypoints[:, :, 1] < height) - ) - gt_keypoints[:, :, 2] = np.where(valid_pos, gt_keypoints[:, :, 2], 0) - gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0, results["img"].shape[1]) - gt_keypoints[:, :, 1] = np.clip(gt_keypoints[:, :, 1], 0, results["img"].shape[0]) - results["gt_keypoints"] = gt_keypoints - - def transform(self, results: dict) -> dict: - """Apply center crop on results. - - Args: - results (dict): Result dict contains the data to transform. - - Returns: - dict: Results with CenterCropped image and semantic segmentation - map. 
- """ - crop_width, crop_height = self.crop_size[0], self.crop_size[1] - - assert "img" in results, "`img` is not found in results" - img = results["img"] - # img.shape has length 2 for grayscale, length 3 for color - img_height, img_width = img.shape[:2] - - if crop_height > img_height or crop_width > img_width: - if self.auto_pad: - # pad the area - img_height = max(img_height, crop_height) - img_width = max(img_width, crop_width) - pad_size = (img_width, img_height) - _pad_cfg = self.pad_cfg.copy() - _pad_cfg.update(dict(size=pad_size)) - pad_transform = TRANSFORMS.build(_pad_cfg) - results = pad_transform(results) - else: - crop_height = min(crop_height, img_height) - crop_width = min(crop_width, img_width) - - y1 = max(0, int(round((img_height - crop_height) / 2.0))) - x1 = max(0, int(round((img_width - crop_width) / 2.0))) - y2 = min(img_height, y1 + crop_height) - 1 - x2 = min(img_width, x1 + crop_width) - 1 - bboxes = np.array([x1, y1, x2, y2]) - - # crop the image - self._crop_img(results, bboxes) - # crop the gt_seg_map - self._crop_seg_map(results, bboxes) - # crop the bounding box - self._crop_bboxes(results, bboxes) - # crop the keypoints - self._crop_keypoints(results, bboxes) - return results - - def __repr__(self) -> str: - repr_str = self.__class__.__name__ - repr_str += f"(crop_size = {self.crop_size}" - repr_str += f", auto_pad={self.auto_pad}" - repr_str += f", pad_cfg={self.pad_cfg}" - repr_str += f",clip_object_border = {self.clip_object_border})" - return repr_str - - -@TRANSFORMS.register_module() -class RandomGrayscale(BaseTransform): - """Randomly convert image to grayscale with a probability. - - Required Key: - - - img - - Modified Key: - - - img - - Added Keys: - - - grayscale - - grayscale_weights - - Args: - prob (float): Probability that image should be converted to - grayscale. Defaults to 0.1. - keep_channels (bool): Whether keep channel number the same as - input. Defaults to False. - channel_weights (tuple): The grayscale weights of each channel, - and the weights will be normalized. For example, (1, 2, 1) - will be normalized as (0.25, 0.5, 0.25). Defaults to - (1., 1., 1.). - color_format (str): Color format set to be any of 'bgr', - 'rgb', 'hsv'. Note: 'hsv' image will be transformed into 'bgr' - format no matter whether it is grayscaled. Defaults to 'bgr'. - """ - - def __init__( - self, - prob: float = 0.1, - keep_channels: bool = False, - channel_weights: Sequence[float] = (1.0, 1.0, 1.0), - color_format: str = "bgr", - ) -> None: - super().__init__() - assert 0.0 <= prob <= 1.0, "The range of ``prob`` value is [0., 1.]," + f" but got {prob} instead" - self.prob = prob - self.keep_channels = keep_channels - self.channel_weights = channel_weights - assert color_format in ["bgr", "rgb", "hsv"] - self.color_format = color_format - - @cache_randomness - def _random_prob(self): - return random.random() - - def transform(self, results: dict) -> dict: - """Apply random grayscale on results. - - Args: - results (dict): Result dict contains the data to transform. - - Returns: - dict: Results with grayscale image. 
-        """
-        img = results["img"]
-        # convert hsv to bgr
-        if self.color_format == "hsv":
-            img = hsv2bgr(img)
-        img = img[..., None] if img.ndim == 2 else img
-        num_output_channels = img.shape[2]
-        if self._random_prob() < self.prob:
-            if num_output_channels > 1:
-                assert num_output_channels == len(self.channel_weights), (
-                    "The length of ``channel_weights`` is supposed to be "
-                    f"num_output_channels, but got {len(self.channel_weights)} "
-                    "instead."
-                )
-                normalized_weights = np.array(self.channel_weights) / sum(self.channel_weights)
-                img = (normalized_weights * img).sum(axis=2)
-                img = img.astype("uint8")
-                if self.keep_channels:
-                    img = img[:, :, None]
-                    results["img"] = np.dstack([img for _ in range(num_output_channels)])
-                else:
-                    results["img"] = img
-                return results
-        img = img.astype("uint8")
-        results["img"] = img
-        return results
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(prob = {self.prob}"
-        repr_str += f", keep_channels = {self.keep_channels}"
-        repr_str += f", channel_weights = {self.channel_weights}"
-        repr_str += f", color_format = {self.color_format})"
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class MultiScaleFlipAug(BaseTransform):
-    """Test-time augmentation with multiple scales and flipping.
-
-    An example configuration is as follows:
-
-    .. code-block::
-
-        dict(
-            type='MultiScaleFlipAug',
-            scales=[(1333, 400), (1333, 800)],
-            allow_flip=True,
-            transforms=[
-                dict(type='Normalize', **img_norm_cfg),
-                dict(type='Pad', size_divisor=1),
-                dict(type='ImageToTensor', keys=['img']),
-                dict(type='Collect', keys=['img'])
-            ])
-
-    ``results`` will be resized using all the sizes in ``scales``.
-    If ``allow_flip`` is True, then flipped results will also be added into
-    the output list.
-
-    For the above configuration, there are four combinations of resize
-    and flip:
-
-    - Resize to (1333, 400) + no flip
-    - Resize to (1333, 400) + flip
-    - Resize to (1333, 800) + no flip
-    - Resize to (1333, 800) + flip
-
-    The four results are then transformed with ``transforms`` argument.
-    After that, results are wrapped into lists of the same length as below:
-
-    .. code-block::
-
-        dict(
-            inputs=[...],
-            data_samples=[...]
-        )
-
-    The lengths of ``inputs`` and ``data_samples`` are both 4.
-
-    Required Keys:
-
-    - Depending on the requirements of the ``transforms`` parameter.
-
-    Modified Keys:
-
-    - All output keys of each transform.
-
-    Args:
-        transforms (list[dict]): Transforms to be applied to each resized
-            and flipped result.
-        scales (tuple | list[tuple] | None): Image scales for resizing.
-        scale_factor (float | list[float], optional): Scale factors for
-            resizing. Defaults to None.
-        allow_flip (bool): Whether to apply flip augmentation. Defaults to
-            False.
-        flip_direction (str | list[str]): Flip augmentation directions,
-            options are "horizontal", "vertical" and "diagonal". If
-            flip_direction is a list, multiple flip augmentations will be
-            applied. It has no effect when ``allow_flip`` is False.
-            Defaults to "horizontal".
-        resize_cfg (dict): Base config for resizing. Defaults to
-            ``dict(type='Resize', keep_ratio=True)``.
-        flip_cfg (dict): Base config for flipping. Defaults to
-            ``dict(type='RandomFlip')``.
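-
-    Examples:
-        A hedged end-to-end sketch; the packing transform and its output
-        keys are assumed here, not guaranteed by this class:
-
-        >>> tta = MultiScaleFlipAug(
-        ...     transforms=[dict(type='PackDetInputs')],
-        ...     scales=[(1333, 800)],
-        ...     allow_flip=True)
-        >>> packed = tta(results)  # doctest: +SKIP
-        >>> # 1 scale x (no flip + horizontal flip) = 2 augmented views
-        >>> len(packed['inputs'])  # doctest: +SKIP
-        2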
-    """
-
-    def __init__(
-        self,
-        transforms: list[dict],
-        scales: tuple | list[tuple] | None = None,
-        scale_factor: float | list[float] | None = None,
-        allow_flip: bool = False,
-        flip_direction: str | list[str] = "horizontal",
-        resize_cfg: dict = dict(type="Resize", keep_ratio=True),
-        flip_cfg: dict = dict(type="RandomFlip"),
-    ) -> None:
-        super().__init__()
-        self.transforms = Compose(transforms)  # type: ignore
-
-        if scales is not None:
-            self.scales = scales if isinstance(scales, list) else [scales]
-            self.scale_key = "scale"
-            assert is_list_of(self.scales, tuple)
-        else:
-            # if both ``scales`` and ``scale_factor`` are ``None``, keep the
-            # original scale
-            if scale_factor is None:
-                self.scales = [1.0]  # type: ignore
-            elif isinstance(scale_factor, list):
-                self.scales = scale_factor  # type: ignore
-            else:
-                self.scales = [scale_factor]  # type: ignore
-
-            self.scale_key = "scale_factor"
-
-        self.allow_flip = allow_flip
-        self.flip_direction = flip_direction if isinstance(flip_direction, list) else [flip_direction]
-        assert is_list_of(self.flip_direction, str)
-        if not self.allow_flip and self.flip_direction != ["horizontal"]:
-            warnings.warn("flip_direction has no effect when allow_flip is set to False")
-        self.resize_cfg = resize_cfg.copy()
-        self.flip_cfg = flip_cfg
-
-    def transform(self, results: dict) -> dict:
-        """Apply test time augment transforms on results.
-
-        Args:
-            results (dict): Result dict contains the data to transform.
-
-        Returns:
-            dict: The augmented data, where each value is wrapped
-            into a list.
-        """
-
-        data_samples = []
-        inputs = []
-        flip_args = [(False, "")]
-        if self.allow_flip:
-            flip_args += [(True, direction) for direction in self.flip_direction]
-        for scale in self.scales:
-            for flip, direction in flip_args:
-                _resize_cfg = self.resize_cfg.copy()
-                _resize_cfg.update({self.scale_key: scale})
-                _resize_flip = [_resize_cfg]
-
-                if flip:
-                    _flip_cfg = self.flip_cfg.copy()
-                    _flip_cfg.update(prob=1.0, direction=direction)
-                    _resize_flip.append(_flip_cfg)
-                else:
-                    results["flip"] = False
-                    results["flip_direction"] = None
-
-                resize_flip = Compose(_resize_flip)
-                _results = resize_flip(results.copy())
-                packed_results = self.transforms(_results)  # type: ignore
-
-                inputs.append(packed_results["inputs"])  # type: ignore
-                data_samples.append(packed_results["data_samples"])  # type: ignore
-        return dict(inputs=inputs, data_samples=data_samples)
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(transforms={self.transforms}"
-        repr_str += f", scales={self.scales}"
-        repr_str += f", allow_flip={self.allow_flip}"
-        repr_str += f", flip_direction={self.flip_direction})"
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class TestTimeAug(BaseTransform):
-    """Test-time augmentation transform.
-
-    An example configuration is as follows:
-
-    .. code-block::
-
-        dict(type='TestTimeAug',
-             transforms=[
-                [dict(type='Resize', scale=(1333, 400), keep_ratio=True),
-                 dict(type='Resize', scale=(1333, 800), keep_ratio=True)],
-                [dict(type='RandomFlip', prob=1.),
-                 dict(type='RandomFlip', prob=0.)],
-                [dict(type='PackDetInputs',
-                      meta_keys=('img_id', 'img_path', 'ori_shape',
-                                 'img_shape', 'scale_factor', 'flip',
-                                 'flip_direction'))]])
-
-    ``results`` will be transformed using all transforms defined in
-    ``transforms`` argument.
-
-    For the above configuration, there are four combinations of resize
-    and flip:
-
-    - Resize to (1333, 400) + no flip
-    - Resize to (1333, 400) + flip
-    - Resize to (1333, 800) + no flip
-    - Resize to (1333, 800) + flip
-
-    After that, results are wrapped into lists of the same length as below:
-
-    .. code-block::
-
-        dict(
-            inputs=[...],
-            data_samples=[...]
-        )
-
-    The lengths of ``inputs`` and ``data_samples`` are both 4.
-
-    Required Keys:
-
-    - Depending on the requirements of the ``transforms`` parameter.
-
-    Modified Keys:
-
-    - All output keys of each transform.
-
-    Args:
-        transforms (list[list[dict]]): Transforms to be applied to data sampled
-            from dataset. ``transforms`` is a list of lists, and each inner
-            list usually represents a series of transforms with the same
-            type and different arguments. Data will be processed by each list
-            element sequentially. See more information in :meth:`transform`.
-    """
-
-    def __init__(self, transforms: list):
-        for i, transform_list in enumerate(transforms):
-            for j, transform in enumerate(transform_list):
-                if isinstance(transform, dict):
-                    transform_list[j] = TRANSFORMS.build(transform)
-                elif callable(transform):
-                    continue
-                else:
-                    raise TypeError(f"transform must be callable or a dict, but got {type(transform)}")
-            transforms[i] = transform_list
-
-        self.subroutines = [Compose(subroutine) for subroutine in product(*transforms)]
-
-    def transform(self, results: dict) -> dict:
-        """Apply all transforms defined in :attr:`transforms` to the results.
-
-        As the example given in :obj:`TestTimeAug`, ``transforms`` consists of
-        2 ``Resize``, 2 ``RandomFlip`` and 1 ``PackDetInputs``.
-        The data sampled from dataset will be processed as follows:
-
-        1. Data will be processed by 2 ``Resize``, which returns a list
-           of 2 results.
-        2. Each result in the list will be further passed to 2
-           ``RandomFlip``, which aggregates into a list of 4 results.
-        3. Each result will be processed by ``PackDetInputs``, which
-           returns a list of dicts.
-        4. Aggregate the same fields of results, and finally return
-           a dict. Each value of the dict represents 4 transformed
-           results.
-
-        Args:
-            results (dict): Result dict contains the data to transform.
-
-        Returns:
-            dict: The augmented data, where each value is wrapped
-            into a list.
-        """
-        results_list = []  # type: ignore
-        for subroutine in self.subroutines:
-            result = subroutine(copy.deepcopy(results))
-            assert result is not None, (
-                f"Data processed by {subroutine} in `TestTimeAug` should not "
-                "be None! Please check your validation dataset and the "
-                f"transforms in {subroutine}"
-            )
-            assert isinstance(result, dict), f"Data processed by {subroutine} must return a dict, but got {result}"
-            results_list.append(result)
-
-        aug_data_dict = {
-            key: [item[key] for item in results_list]  # type: ignore
-            for key in results_list[0]  # type: ignore
-        }
-        return aug_data_dict
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += "transforms=\n"
-        for subroutine in self.subroutines:
-            repr_str += f"{subroutine!r}\n"
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class RandomChoiceResize(BaseTransform):
-    """Resize images & bbox & mask from a list of multiple scales.
-
-    This transform resizes the input image to some scale. Bboxes and masks are
-    then resized with the same scale factor. Resize scale will be randomly
-    selected from ``scales``.
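-
-    A hedged usage sketch in config form (``keep_ratio`` is forwarded to the
-    underlying resize class, as described in the Note below):
-
-    .. code-block::
-
-        dict(type='RandomChoiceResize',
-             scales=[(1333, 640), (1333, 800)],
-             keep_ratio=True)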
-
-    How to choose the target scale to resize the image will follow the rules
-    below:
-
-    - if `scale` is a list of tuple, the target scale is sampled from the list
-      uniformly.
-    - if `scale` is a tuple, the target scale will be set to the tuple.
-
-    Required Keys:
-
-    - img
-    - gt_bboxes (optional)
-    - gt_seg_map (optional)
-    - gt_keypoints (optional)
-
-    Modified Keys:
-
-    - img
-    - img_shape
-    - gt_bboxes (optional)
-    - gt_seg_map (optional)
-    - gt_keypoints (optional)
-
-    Added Keys:
-
-    - scale
-    - scale_factor
-    - scale_idx
-    - keep_ratio
-
-    Args:
-        scales (Sequence[int | tuple]): Image scales for resizing.
-        resize_type (str): The type of resize class to use. Defaults to
-            "Resize".
-        **resize_kwargs: Other keyword arguments for the ``resize_type``.
-
-    Note:
-        By default, the ``resize_type`` is "Resize", if it's not overwritten
-        by your registry, it indicates the :class:`mmcv.Resize`. And therefore,
-        ``resize_kwargs`` accepts any keyword arguments of it, like
-        ``keep_ratio``, ``interpolation`` and so on.
-
-        If you want to use your custom resize class, the class should accept
-        ``scale`` argument and have a ``scale`` attribute which determines the
-        resize shape.
-    """
-
-    def __init__(
-        self,
-        scales: Sequence[int | tuple],
-        resize_type: str = "Resize",
-        **resize_kwargs,
-    ) -> None:
-        super().__init__()
-        if isinstance(scales, list):
-            self.scales = scales
-        else:
-            self.scales = [scales]
-        assert is_seq_of(self.scales, (tuple, int))
-
-        self.resize_cfg = dict(type=resize_type, **resize_kwargs)
-        # create an empty Resize object
-        self.resize = TRANSFORMS.build({"scale": 0, **self.resize_cfg})
-
-    @cache_randomness
-    def _random_select(self) -> tuple[int, int]:
-        """Randomly select a scale from given candidates.
-
-        Returns:
-            tuple: Returns a tuple ``(scale, scale_idx)``,
-            where ``scale`` is the selected image scale and
-            ``scale_idx`` is the selected index in the given candidates.
-        """
-
-        scale_idx = np.random.randint(len(self.scales))
-        scale = self.scales[scale_idx]
-        return scale, scale_idx
-
-    def transform(self, results: dict) -> dict:
-        """Apply resize transforms on results from a list of scales.
-
-        Args:
-            results (dict): Result dict contains the data to transform.
-
-        Returns:
-            dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
-            'gt_keypoints', 'scale', 'scale_factor', 'img_shape',
-            and 'keep_ratio' keys are updated in result dict.
-        """
-
-        target_scale, scale_idx = self._random_select()
-        self.resize.scale = target_scale
-        results = self.resize(results)
-        results["scale_idx"] = scale_idx
-        return results
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(scales={self.scales}"
-        repr_str += f", resize_cfg={self.resize_cfg})"
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class RandomFlip(BaseTransform):
-    """Flip the image & bbox & keypoints & segmentation map. Added or Updated
-    keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and
-    gt_keypoints. There are 3 flip modes:
-
-    - ``prob`` is float, ``direction`` is string: the image will be
-      ``direction``ly flipped with probability of ``prob``.
-      E.g., ``prob=0.5``, ``direction='horizontal'``,
-      then image will be horizontally flipped with probability of 0.5.
-
-    - ``prob`` is float, ``direction`` is list of string: the image will
-      be ``direction[i]``ly flipped with probability of
-      ``prob/len(direction)``.
-      E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``,
-      then image will be horizontally flipped with probability of 0.25,
-      vertically with probability of 0.25.
-
-    - ``prob`` is list of float, ``direction`` is list of string:
-      given ``len(prob) == len(direction)``, the image will
-      be ``direction[i]``ly flipped with probability of ``prob[i]``.
-      E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
-      'vertical']``, then image will be horizontally flipped with
-      probability of 0.3, vertically with probability of 0.5.
-
-    Required Keys:
-
-    - img
-    - gt_bboxes (optional)
-    - gt_seg_map (optional)
-    - gt_keypoints (optional)
-
-    Modified Keys:
-
-    - img
-    - gt_bboxes (optional)
-    - gt_seg_map (optional)
-    - gt_keypoints (optional)
-
-    Added Keys:
-
-    - flip
-    - flip_direction
-    - swap_seg_labels (optional)
-
-    Args:
-        prob (float | list[float], optional): The flipping probability.
-            Defaults to None.
-        direction (str | list[str]): The flipping direction. Options are
-            'horizontal', 'vertical' and 'diagonal'. If input is a list, the
-            length must equal ``prob``. Each element in ``prob`` indicates
-            the flip probability of the corresponding direction. Defaults
-            to 'horizontal'.
-        swap_seg_labels (list, optional): The label pairs that need to be
-            swapped in the ground truth, like 'left arm' and 'right arm'
-            need to be swapped after horizontal flipping. For example,
-            ``[(1, 5)]``, where 1/5 is the label of the left/right arm.
-            Defaults to None.
-    """
-
-    def __init__(
-        self,
-        prob: float | Iterable[float] | None = None,
-        direction: str | Sequence[str | None] = "horizontal",
-        swap_seg_labels: Sequence | None = None,
-    ) -> None:
-        if isinstance(prob, list):
-            assert is_list_of(prob, float)
-            assert 0 <= sum(prob) <= 1
-        elif isinstance(prob, float):
-            assert 0 <= prob <= 1
-        else:
-            raise ValueError(f"prob must be float or list of float, but got `{type(prob)}`.")
-        self.prob = prob
-        self.swap_seg_labels = swap_seg_labels
-
-        valid_directions = ["horizontal", "vertical", "diagonal"]
-        if isinstance(direction, str):
-            assert direction in valid_directions
-        elif isinstance(direction, list):
-            assert is_list_of(direction, str)
-            assert set(direction).issubset(set(valid_directions))
-        else:
-            raise ValueError(f"direction must be either str or list of str, but got `{type(direction)}`.")
-        self.direction = direction
-
-        if isinstance(prob, list):
-            assert len(prob) == len(self.direction)
-
-    def _flip_bbox(self, bboxes: np.ndarray, img_shape: tuple[int, int], direction: str) -> np.ndarray:
-        """Flip bboxes horizontally, vertically or diagonally.
-
-        Args:
-            bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
-            img_shape (tuple[int]): Image shape (height, width)
-            direction (str): Flip direction. Options are 'horizontal',
-                'vertical', and 'diagonal'.
-
-        Returns:
-            numpy.ndarray: Flipped bounding boxes.
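-
-        Examples:
-            A quick sanity check with made-up numbers: horizontally flipping
-            ``[10, 20, 30, 40]`` in an image of shape ``(100, 100)`` gives
-            ``[w - 30, 20, w - 10, 40] = [70, 20, 90, 40]``.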
-        """
-        # Handle BaseBoxes objects using their own flip method
-        if hasattr(bboxes, "flip_"):
-            flipped = bboxes.clone()
-            flipped.flip_(img_shape, direction)
-            return flipped
-
-        # Handle numpy arrays
-        assert bboxes.shape[-1] % 4 == 0
-        flipped = bboxes.copy()
-        h, w = img_shape
-        if direction == "horizontal":
-            flipped[..., 0::4] = w - bboxes[..., 2::4]
-            flipped[..., 2::4] = w - bboxes[..., 0::4]
-        elif direction == "vertical":
-            flipped[..., 1::4] = h - bboxes[..., 3::4]
-            flipped[..., 3::4] = h - bboxes[..., 1::4]
-        elif direction == "diagonal":
-            flipped[..., 0::4] = w - bboxes[..., 2::4]
-            flipped[..., 1::4] = h - bboxes[..., 3::4]
-            flipped[..., 2::4] = w - bboxes[..., 0::4]
-            flipped[..., 3::4] = h - bboxes[..., 1::4]
-        else:
-            raise ValueError(f"Flipping direction must be 'horizontal', 'vertical', or 'diagonal', but got '{direction}'")
-        return flipped
-
-    def _flip_keypoints(
-        self,
-        keypoints: np.ndarray,
-        img_shape: tuple[int, int],
-        direction: str,
-    ) -> np.ndarray:
-        """Flip keypoints horizontally, vertically or diagonally.
-
-        Args:
-            keypoints (numpy.ndarray): Keypoints, shape (..., 2)
-            img_shape (tuple[int]): Image shape (height, width)
-            direction (str): Flip direction. Options are 'horizontal',
-                'vertical', and 'diagonal'.
-
-        Returns:
-            numpy.ndarray: Flipped keypoints.
-        """
-
-        meta_info = keypoints[..., 2:]
-        keypoints = keypoints[..., :2]
-        flipped = keypoints.copy()
-        h, w = img_shape
-        if direction == "horizontal":
-            flipped[..., 0::2] = w - keypoints[..., 0::2]
-        elif direction == "vertical":
-            flipped[..., 1::2] = h - keypoints[..., 1::2]
-        elif direction == "diagonal":
-            flipped[..., 0::2] = w - keypoints[..., 0::2]
-            flipped[..., 1::2] = h - keypoints[..., 1::2]
-        else:
-            raise ValueError(f"Flipping direction must be 'horizontal', 'vertical', or 'diagonal', but got '{direction}'")
-        flipped = np.concatenate([flipped, meta_info], axis=-1)
-        return flipped
-
-    def _flip_seg_map(self, seg_map: np.ndarray, direction: str) -> np.ndarray:
-        """Flip segmentation map horizontally, vertically or diagonally.
-
-        Args:
-            seg_map (numpy.ndarray): segmentation map, shape (H, W).
-            direction (str): Flip direction. Options are 'horizontal',
-                'vertical', and 'diagonal'.
-
-        Returns:
-            numpy.ndarray: Flipped segmentation map.
-        """
-        seg_map = imflip(seg_map, direction=direction)
-        if self.swap_seg_labels is not None:
-            # to handle datasets with left/right annotations
-            # like 'Left-arm' and 'Right-arm' in LIP dataset
-            # Modified from https://github.com/openseg-group/openseg.pytorch/blob/master/lib/datasets/tools/cv2_aug_transforms.py
-            # Licensed under MIT license
-            temp = seg_map.copy()
-            assert isinstance(self.swap_seg_labels, (tuple, list))
-            for pair in self.swap_seg_labels:
-                assert isinstance(pair, (tuple, list)) and len(pair) == 2, (
-                    f"swap_seg_labels must be a sequence of pairs, but got {self.swap_seg_labels}."
-                )
-                seg_map[temp == pair[0]] = pair[1]
-                seg_map[temp == pair[1]] = pair[0]
-        return seg_map
-
-    @cache_randomness
-    def _choose_direction(self) -> str:
-        """Choose the flip direction according to `prob` and `direction`."""
-        if isinstance(self.direction, Sequence) and not isinstance(self.direction, str):
-            # None means non-flip
-            direction_list: list = list(self.direction) + [None]
-        elif isinstance(self.direction, str):
-            # None means non-flip
-            direction_list = [self.direction, None]
-
-        if isinstance(self.prob, list):
-            non_prob: float = 1 - sum(self.prob)
-            prob_list = self.prob + [non_prob]
-        elif isinstance(self.prob, float):
-            non_prob = 1.0 - self.prob
-            # exclude non-flip
-            single_ratio = self.prob / (len(direction_list) - 1)
-            prob_list = [single_ratio] * (len(direction_list) - 1) + [non_prob]
-
-        cur_dir = np.random.choice(direction_list, p=prob_list)
-
-        return cur_dir
-
-    def _flip(self, results: dict) -> None:
-        """Flip images, bounding boxes, semantic segmentation map and
-        keypoints."""
-        # flip image
-        results["img"] = imflip(results["img"], direction=results["flip_direction"])
-
-        img_shape = results["img"].shape[:2]
-
-        # flip bboxes
-        if results.get("gt_bboxes", None) is not None:
-            results["gt_bboxes"] = self._flip_bbox(results["gt_bboxes"], img_shape, results["flip_direction"])
-
-        # flip keypoints
-        if results.get("gt_keypoints", None) is not None:
-            results["gt_keypoints"] = self._flip_keypoints(
-                results["gt_keypoints"], img_shape, results["flip_direction"]
-            )
-
-        # flip seg map
-        if results.get("gt_seg_map", None) is not None:
-            results["gt_seg_map"] = self._flip_seg_map(results["gt_seg_map"], direction=results["flip_direction"])
-            results["swap_seg_labels"] = self.swap_seg_labels
-
-    def _flip_on_direction(self, results: dict) -> None:
-        """Function to flip images, bounding boxes, semantic segmentation map
-        and keypoints."""
-        cur_dir = self._choose_direction()
-        if cur_dir is None:
-            results["flip"] = False
-            results["flip_direction"] = None
-        else:
-            results["flip"] = True
-            results["flip_direction"] = cur_dir
-            self._flip(results)
-
-    def transform(self, results: dict) -> dict:
-        """Transform function to flip images, bounding boxes, semantic
-        segmentation map and keypoints.
-
-        Args:
-            results (dict): Result dict from loading pipeline.
-
-        Returns:
-            dict: Flipped results, 'img', 'gt_bboxes', 'gt_seg_map',
-            'gt_keypoints', 'flip', and 'flip_direction' keys are
-            updated in result dict.
-        """
-        self._flip_on_direction(results)
-
-        return results
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(prob={self.prob}, "
-        repr_str += f"direction={self.direction})"
-
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class RandomResize(BaseTransform):
-    """Random resize images & bbox & keypoints.
-
-    How to choose the target scale to resize the image will follow the rules
-    below:
-
-    - if ``scale`` is a sequence of tuple
-
-    .. math::
-        target\\_scale[0] \\sim Uniform([scale[0][0], scale[1][0]])
-    .. math::
-        target\\_scale[1] \\sim Uniform([scale[0][1], scale[1][1]])
-
-    Following the resize order of width and height in cv2, ``scale[i][0]``
-    is for width, and ``scale[i][1]`` is for height.
-
-    - if ``scale`` is a tuple
-
-    .. math::
-        target\\_scale[0] \\sim Uniform([ratio\\_range[0], ratio\\_range[1]])
-            * scale[0]
-    .. math::
-        target\\_scale[1] \\sim Uniform([ratio\\_range[0], ratio\\_range[1]])
-            * scale[1]
-
-    Following the resize order of width and height in cv2, ``ratio_range[0]``
-    is for width, and ``ratio_range[1]`` is for height.
-
-    - if ``keep_ratio`` is True, the minimum value of ``target_scale`` will be
-      used to set the shorter side and the maximum value will be used to
-      set the longer side.
-
-    - if ``keep_ratio`` is False, the value of ``target_scale`` will be used to
-      resize the width and height accordingly.
-
-    Required Keys:
-
-    - img
-    - gt_bboxes
-    - gt_seg_map
-    - gt_keypoints
-
-    Modified Keys:
-
-    - img
-    - gt_bboxes
-    - gt_seg_map
-    - gt_keypoints
-    - img_shape
-
-    Added Keys:
-
-    - scale
-    - scale_factor
-    - keep_ratio
-
-    Args:
-        scale (tuple or Sequence[tuple]): Image scales for resizing.
-            Defaults to None.
-        ratio_range (tuple[float], optional): (min_ratio, max_ratio).
-            Defaults to None.
-        resize_type (str): The type of resize class to use. Defaults to
-            "Resize".
-        **resize_kwargs: Other keyword arguments for the ``resize_type``.
-
-    Note:
-        By default, the ``resize_type`` is "Resize", if it's not overwritten
-        by your registry, it indicates the :class:`mmcv.Resize`. And therefore,
-        ``resize_kwargs`` accepts any keyword arguments of it, like
-        ``keep_ratio``, ``interpolation`` and so on.
-
-        If you want to use your custom resize class, the class should accept
-        ``scale`` argument and have a ``scale`` attribute which determines the
-        resize shape.
-    """
-
-    def __init__(
-        self,
-        scale: tuple[int, int] | Sequence[tuple[int, int]],
-        ratio_range: tuple[float, float] | None = None,
-        resize_type: str = "Resize",
-        **resize_kwargs,
-    ) -> None:
-        self.scale = scale
-        self.ratio_range = ratio_range
-
-        self.resize_cfg = dict(type=resize_type, **resize_kwargs)
-        # create an empty Resize object
-        self.resize = TRANSFORMS.build({"scale": 0, **self.resize_cfg})
-
-    @staticmethod
-    def _random_sample(scales: Sequence[tuple[int, int]]) -> tuple:
-        """Private function to randomly sample a scale from a list of tuples.
-
-        Args:
-            scales (list[tuple]): Image scale range for sampling.
-                There must be two tuples in scales, which specify the lower
-                and upper bound of image scales.
-
-        Returns:
-            tuple: The targeted scale of the image to be resized.
-        """
-
-        assert is_list_of(scales, tuple) and len(scales) == 2
-        scale_0 = [scales[0][0], scales[1][0]]
-        scale_1 = [scales[0][1], scales[1][1]]
-        edge_0 = np.random.randint(min(scale_0), max(scale_0) + 1)
-        edge_1 = np.random.randint(min(scale_1), max(scale_1) + 1)
-        scale = (edge_0, edge_1)
-        return scale
-
-    @staticmethod
-    def _random_sample_ratio(scale: tuple, ratio_range: tuple[float, float]) -> tuple:
-        """Private function to randomly sample a scale from a tuple.
-
-        A ratio will be randomly sampled from the range specified by
-        ``ratio_range``. Then it would be multiplied with ``scale`` to
-        generate the sampled scale.
-
-        Args:
-            scale (tuple): Image scale base to multiply with ratio.
-            ratio_range (tuple[float]): The minimum and maximum ratio to scale
-                the ``scale``.
-
-        Returns:
-            tuple: The targeted scale of the image to be resized.
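-
-        Examples:
-            With assumed inputs ``scale=(640, 480)`` and a sampled ratio of
-            ``0.5``, the target scale is
-            ``(int(640 * 0.5), int(480 * 0.5)) = (320, 240)``.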
-        """
-
-        assert isinstance(scale, tuple) and len(scale) == 2
-        min_ratio, max_ratio = ratio_range
-        assert min_ratio <= max_ratio
-        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
-        scale = int(scale[0] * ratio), int(scale[1] * ratio)
-        return scale
-
-    @cache_randomness
-    def _random_scale(self) -> tuple:
-        """Private function to randomly sample a scale according to the type
-        of ``scale``.
-
-        Returns:
-            tuple: The targeted scale of the image to be resized.
-        """
-
-        if is_tuple_of(self.scale, int):
-            assert self.ratio_range is not None and len(self.ratio_range) == 2
-            scale = self._random_sample_ratio(
-                self.scale,  # type: ignore
-                self.ratio_range,
-            )
-        elif is_seq_of(self.scale, tuple):
-            scale = self._random_sample(self.scale)  # type: ignore
-        else:
-            raise NotImplementedError(f'Sampling is not supported for "{self.scale}"')
-
-        return scale
-
-    def transform(self, results: dict) -> dict:
-        """Transform function to resize images, bounding boxes, semantic
-        segmentation map.
-
-        Args:
-            results (dict): Result dict from loading pipeline.
-
-        Returns:
-            dict: Resized results, ``img``, ``gt_bboxes``, ``gt_semantic_seg``,
-            ``gt_keypoints``, ``scale``, ``scale_factor``, ``img_shape``, and
-            ``keep_ratio`` keys are updated in result dict.
-        """
-        results["scale"] = self._random_scale()
-        self.resize.scale = results["scale"]
-        results = self.resize(results)
-        return results
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(scale={self.scale}, "
-        repr_str += f"ratio_range={self.ratio_range}, "
-        repr_str += f"resize_cfg={self.resize_cfg})"
-        return repr_str
diff --git a/libs/viscv/viscv/transforms/utils.py b/libs/viscv/viscv/transforms/utils.py
deleted file mode 100644
index 38c9b59..0000000
--- a/libs/viscv/viscv/transforms/utils.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-import copy
-import functools
-import inspect
-import weakref
-from collections import defaultdict
-from collections.abc import Callable, Iterable
-from contextlib import contextmanager
-
-from .base import BaseTransform
-
-
-class cache_randomness:
-    """Decorator that marks the method with random return value(s) in a
-    transform class.
-
-    This decorator is usually used together with the context-manager
-    :func:`cache_random_params`. In this context, a decorated method will
-    cache its return value(s) at the first time of being invoked, and always
-    return the cached values when being invoked again.
-
-    .. note::
-        Only an instance method can be decorated with ``cache_randomness``.
-    """
-
-    def __init__(self, func):
-        # Check `func` is to be bound as an instance method
-        if not inspect.isfunction(func):
-            raise TypeError("Unsupported callable to decorate with @cache_randomness.")
-        func_args = inspect.getfullargspec(func).args
-        if len(func_args) == 0 or func_args[0] != "self":
-            raise TypeError(
-                "@cache_randomness should only be used to decorate instance methods (the first argument is ``self``)."
-            )
-
-        functools.update_wrapper(self, func)
-        self.func = func
-        self.instance_ref = None
-
-    def __set_name__(self, owner, name):
-        # Maintain a record of decorated methods in the class
-        if not hasattr(owner, "_methods_with_randomness"):
-            owner._methods_with_randomness = []
-
-        # Here `name` equals `self.__name__`, i.e., the name of the
-        # decorated function, due to the invocation of `update_wrapper` in
-        # `self.__init__()`
-        owner._methods_with_randomness.append(name)
-
-    def __call__(self, *args, **kwargs):
-        # Get the transform instance whose method is decorated
-        # by cache_randomness
-        instance = self.instance_ref()
-        name = self.__name__
-
-        # Check the flag ``self._cache_enabled``, which should be
-        # set by context managers like ``cache_random_params``
-        cache_enabled = getattr(instance, "_cache_enabled", False)
-
-        if cache_enabled:
-            # Initialize the cache of the transform instances. The flag
-            # ``_cache_enabled`` is set by context managers like
-            # ``cache_random_params``.
-            if not hasattr(instance, "_cache"):
-                instance._cache = {}
-
-            if name not in instance._cache:
-                instance._cache[name] = self.func(instance, *args, **kwargs)
-            # Return the cached value
-            return instance._cache[name]
-        else:
-            # Clear cache
-            if hasattr(instance, "_cache"):
-                del instance._cache
-            # Return function output
-            return self.func(instance, *args, **kwargs)
-
-    def __get__(self, obj, cls):
-        self.instance_ref = weakref.ref(obj)
-        # Return a copy to avoid multiple transform instances sharing
-        # one `cache_randomness` instance, which may cause data races
-        # in multithreading cases.
-        return copy.copy(self)
-
-
-def avoid_cache_randomness(cls):
-    """Decorator that marks a data transform class (subclass of
-    :class:`BaseTransform`) prohibited from caching randomness. With this
-    decorator, errors will be raised in following cases:
-
-    1. A method is defined in the class with the decorator
-       `cache_randomness`;
-    2. An instance of the class is invoked with the context
-       `cache_random_params`.
-
-    A typical usage of `avoid_cache_randomness` is to decorate the data
-    transforms with non-cacheable random behaviors (e.g., the random behavior
-    cannot be defined in a method, thus cannot be decorated with
-    `cache_randomness`). This is for preventing unintentional use of such data
-    transforms within the context of caching randomness, which may lead to
-    unexpected results.
-    """
-
-    # Check that cls is a data transform class
-    assert issubclass(cls, BaseTransform)
-
-    # Check that no method is decorated with `cache_randomness` in cls
-    if getattr(cls, "_methods_with_randomness", None):
-        raise RuntimeError(
-            f"Class {cls.__name__} decorated with "
-            "``avoid_cache_randomness`` should not have methods decorated "
-            "with ``cache_randomness`` (invalid methods: "
-            f"{cls._methods_with_randomness})"
-        )
-
-    class AvoidCacheRandomness:
-        def __get__(self, obj, objtype=None):
-            # Here we check the value in `objtype.__dict__` instead of
-            # directly checking the attribute
-            # `objtype._avoid_cache_randomness`. So if the base class is
-            # decorated with :func:`avoid_cache_randomness`, it will not be
-            # inherited by subclasses.
-            return objtype.__dict__.get("_avoid_cache_randomness", False)
-
-    cls.avoid_cache_randomness = AvoidCacheRandomness()
-    cls._avoid_cache_randomness = True
-
-    return cls
-
-
-@contextmanager
-def cache_random_params(transforms: BaseTransform | Iterable):
-    """Context-manager that enables the cache of return values of methods
-    decorated with ``cache_randomness`` in transforms.
-
-    In this mode, decorated methods will cache their return values on the
-    first invocation, and always return the cached value afterward. This
-    allows applying random transforms in a deterministic way, e.g., applying
-    the same transforms on multiple examples. See ``cache_randomness`` for
-    more information.
-
-    Args:
-        transforms (BaseTransform|list[BaseTransform]): The transforms to
-            enable cache.
-    """
-
-    # key2method stores the original methods that are replaced by the wrapped
-    # ones. These methods will be restored when exiting the context.
-    key2method = dict()
-
-    # key2counter stores the usage number of each cache_randomness. This is
-    # used to check that any cache_randomness is invoked at most once while
-    # processing one data sample.
-    key2counter: dict = defaultdict(int)
-
-    def _add_invoke_counter(obj, method_name):
-        method = getattr(obj, method_name)
-        key = f"{id(obj)}.{method_name}"
-        key2method[key] = method
-
-        @functools.wraps(method)
-        def wrapped(*args, **kwargs):
-            key2counter[key] += 1
-            return method(*args, **kwargs)
-
-        return wrapped
-
-    def _add_invoke_checker(obj, method_name):
-        # check that the method in _methods_with_randomness has been
-        # invoked at most once
-        method = getattr(obj, method_name)
-        key = f"{id(obj)}.{method_name}"
-        key2method[key] = method
-
-        @functools.wraps(method)
-        def wrapped(*args, **kwargs):
-            # clear counter
-            for name in obj._methods_with_randomness:
-                key = f"{id(obj)}.{name}"
-                key2counter[key] = 0
-
-            output = method(*args, **kwargs)
-
-            for name in obj._methods_with_randomness:
-                key = f"{id(obj)}.{name}"
-                if key2counter[key] > 1:
-                    raise RuntimeError(
-                        "The method decorated with ``cache_randomness`` "
-                        "should be invoked at most once during processing "
-                        f"one data sample. The method {name} of {obj} has "
-                        f"been invoked {key2counter[key]} times."
-                    )
-            return output
-
-        return wrapped
-
-    def _start_cache(t: BaseTransform):
-        # Check if cache is allowed for `t`
-        if getattr(t, "avoid_cache_randomness", False):
-            raise RuntimeError(
-                f"Class {t.__class__.__name__} decorated with "
-                "``avoid_cache_randomness`` is not allowed to be used with"
-                " ``cache_random_params`` (e.g. wrapped by "
-                "``ApplyToMultiple`` with ``share_random_params==True``)."
-            )
-
-        # Skip transforms w/o random method
-        if not hasattr(t, "_methods_with_randomness"):
-            return
-
-        # Set cache enabled flag
-        t._cache_enabled = True
-
-        # Store the original method and init the counter
-        if hasattr(t, "_methods_with_randomness"):
-            t.transform = _add_invoke_checker(t, "transform")
-            for name in t._methods_with_randomness:
-                setattr(t, name, _add_invoke_counter(t, name))
-
-    def _end_cache(t: BaseTransform):
-        # Skip transforms w/o random method
-        if not hasattr(t, "_methods_with_randomness"):
-            return
-
-        # Remove cache enabled flag
-        delattr(t, "_cache_enabled")
-        if hasattr(t, "_cache"):
-            delattr(t, "_cache")
-
-        # Restore the original method
-        if hasattr(t, "_methods_with_randomness"):
-            for name in t._methods_with_randomness:
-                key = f"{id(t)}.{name}"
-                setattr(t, name, key2method[key])
-
-            key_transform = f"{id(t)}.transform"
-            t.transform = key2method[key_transform]
-
-    def _apply(t: BaseTransform | Iterable, func: Callable[[BaseTransform], None]):
-        if isinstance(t, BaseTransform):
-            func(t)
-        if isinstance(t, Iterable):
-            for _t in t:
-                _apply(_t, func)
-
-    try:
-        _apply(transforms, _start_cache)
-        yield
-    finally:
-        _apply(transforms, _end_cache)
diff --git a/libs/viscv/viscv/transforms/wrappers.py b/libs/viscv/viscv/transforms/wrappers.py
deleted file mode 100644
index ef8df90..0000000
--- a/libs/viscv/viscv/transforms/wrappers.py
+++ /dev/null
@@ -1,639 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-from collections.abc import Callable, Sequence
-from typing import Any, Union
-
-import numpy as np
-from visengine.utils import is_seq_of
-
-from .base import BaseTransform
-from .builder import TRANSFORMS
-from .utils import cache_random_params, cache_randomness
-
-# Define type of transform or transform config
-Transform = Union[dict, Callable[[dict], dict]]
-
-# Indicator of keys marked by KeyMapper._map_input, which means ignoring the
-# marked keys in KeyMapper._apply_transform so they will be invisible to
-# wrapped transforms.
-# There are 2 possible cases:
-# 1. The key is required but missing in results
-# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means
-#    the original value in results should be ignored
-IgnoreKey = object()
-
-# Import nullcontext if python>=3.7, otherwise use a simple alternative
-# implementation.
-try:
-    from contextlib import nullcontext  # type: ignore
-except ImportError:
-    from contextlib import contextmanager
-
-    @contextmanager  # type: ignore
-    def nullcontext(resource=None):
-        try:
-            yield resource
-        finally:
-            pass
-
-
-@TRANSFORMS.register_module()
-class Compose(BaseTransform):
-    """Compose multiple transforms sequentially.
-
-    Args:
-        transforms (list[dict | callable]): Sequence of transform objects or
-            config dicts to be composed.
-
-    Examples:
-        >>> pipeline = [
-        >>>     dict(type='Compose',
-        >>>         transforms=[
-        >>>             dict(type='LoadImageFromFile'),
-        >>>             dict(type='Normalize')
-        >>>         ]
-        >>>     )
-        >>> ]
-    """
-
-    def __init__(self, transforms: Transform | Sequence[Transform]):
-        super().__init__()
-
-        if not isinstance(transforms, Sequence):
-            transforms = [transforms]
-        self.transforms: list = []
-        for transform in transforms:
-            if isinstance(transform, dict):
-                transform = TRANSFORMS.build(transform)
-                self.transforms.append(transform)
-            elif callable(transform):
-                self.transforms.append(transform)
-            else:
-                raise TypeError(f"transform must be callable or a dict, but got {type(transform)}")
-
-    def __iter__(self):
-        """Allow easy iteration over the transform sequence."""
-        return iter(self.transforms)
-
-    def transform(self, results: dict) -> dict | None:
-        """Call function to apply transforms sequentially.
-
-        Args:
-            results (dict): A result dict contains the results to transform.
-
-        Returns:
-            dict or None: Transformed results.
-        """
-        for t in self.transforms:
-            results = t(results)  # type: ignore
-            if results is None:
-                return None
-        return results
-
-    def __repr__(self):
-        """Compute the string representation."""
-        format_string = self.__class__.__name__ + "("
-        for t in self.transforms:
-            format_string += f"\n    {t}"
-        format_string += "\n)"
-        return format_string
-
-
-@TRANSFORMS.register_module()
-class KeyMapper(BaseTransform):
-    """A transform wrapper to map and reorganize the input/output of the
-    wrapped transforms (or sub-pipeline).
-
-    Args:
-        transforms (list[dict | callable], optional): Sequence of transform
-            objects or config dicts to be wrapped.
-        mapping (dict): A dict that defines the input key mapping.
-            The keys correspond to the inner keys (i.e., kwargs of the
-            ``transform`` method), and should be string type. The values
-            correspond to the outer keys (i.e., the keys of the
-            data/results), and should have a type of string, list or dict.
-            None means not applying input mapping. Default: None.
-        remapping (dict): A dict that defines the output key mapping.
-            The keys and values have the same meanings and rules as in the
-            ``mapping``. Default: None.
-        auto_remap (bool, optional): If True, an inverse of the mapping will
-            be used as the remapping. If auto_remap is not given, it will be
-            automatically set True if 'remapping' is not given, and vice
-            versa. Default: None.
-        allow_nonexist_keys (bool): If False, the outer keys in the mapping
-            must exist in the input data, or an exception will be raised.
-            Default: False.
-
-    Examples:
-        >>> # Example 1: KeyMapper 'gt_img' to 'img'
-        >>> pipeline = [
-        >>>     # Use KeyMapper to convert outer (original) field name
-        >>>     # 'gt_img' to inner (used by inner transforms) field name
-        >>>     # 'img'
-        >>>     dict(type='KeyMapper',
-        >>>         mapping={'img': 'gt_img'},
-        >>>         # auto_remap=True means output key mapping is the inverse
-        >>>         # of the input key mapping, e.g. inner 'img' will be
-        >>>         # mapped back to outer 'gt_img'
-        >>>         auto_remap=True,
-        >>>         transforms=[
-        >>>             # In all transforms' implementation just use 'img'
-        >>>             # as a standard field name
-        >>>             dict(type='Crop', crop_size=(384, 384)),
-        >>>             dict(type='Normalize'),
-        >>>         ])
-        >>> ]
-
-        >>> # Example 2: Collect and structure multiple items
-        >>> pipeline = [
-        >>>     # The inner field 'imgs' will be a dict with keys 'img_src'
-        >>>     # and 'img_tar', whose values are outer fields 'img1' and
-        >>>     # 'img2' respectively.
-        >>>     dict(
-        >>>         type='KeyMapper',
-        >>>         mapping=dict(
-        >>>             imgs=dict(
-        >>>                 img_src='img1',
-        >>>                 img_tar='img2')),
-        >>>         transforms=...)
-        >>> ]
-
-        >>> # Example 3: Manually set ignored keys by "..."
-        >>> pipeline = [
-        >>>     ...
-        >>>     dict(type='KeyMapper',
-        >>>         mapping={
-        >>>             # map outer key "gt_img" to inner key "img"
-        >>>             'img': 'gt_img',
-        >>>             # ignore outer key "mask"
-        >>>             'mask': ...,
-        >>>         },
-        >>>         transforms=[
-        >>>             dict(type='RandomFlip'),
-        >>>         ])
-        >>>     ...
-        >>> ]
-    """
-
-    def __init__(
-        self,
-        transforms: Transform | list[Transform] | None = None,
-        mapping: dict | None = None,
-        remapping: dict | None = None,
-        auto_remap: bool | None = None,
-        allow_nonexist_keys: bool = False,
-    ):
-        super().__init__()
-
-        self.allow_nonexist_keys = allow_nonexist_keys
-        self.mapping = mapping
-
-        if auto_remap is None:
-            auto_remap = remapping is None
-        self.auto_remap = auto_remap
-
-        if self.auto_remap:
-            if remapping is not None:
-                raise ValueError("KeyMapper: ``remapping`` must be None if ``auto_remap`` is set True.")
-            self.remapping = mapping
-        else:
-            self.remapping = remapping
-
-        if transforms is None:
-            transforms = []
-        self.transforms = Compose(transforms)
-
-    def __iter__(self):
-        """Allow easy iteration over the transform sequence."""
-        return iter(self.transforms)
-
-    def _map_input(self, data: dict, mapping: dict | None) -> dict[str, Any]:
-        """KeyMapper inputs for the wrapped transforms by gathering and
-        renaming data items according to the mapping.
-
-        Args:
-            data (dict): The original input data
-            mapping (dict, optional): The input key mapping. See the document
-                of ``mmcv.transforms.wrappers.KeyMapper`` for details. If
-                set to None, the input data is returned directly.
-
-        Returns:
-            dict: The input data with remapped keys. This will be the actual
-            input of the wrapped pipeline.
-        """
-
-        if mapping is None:
-            return data.copy()
-
-        def _map(data, m):
-            if isinstance(m, dict):
-                # m is a dict {inner_key: outer_key, ...}
-                return {k_in: _map(data, k_out) for k_in, k_out in m.items()}
-            if isinstance(m, (tuple, list)):
-                # m is a list or tuple [outer_key1, outer_key2, ...]
-                # This is the case when we collect items from the original
-                # data to form a list or tuple to feed to the wrapped
-                # transforms.
-                return m.__class__(_map(data, e) for e in m)
-
-            # allow manually mark a key to be ignored by ...
-            if m is ...:
-                return IgnoreKey
-
-            # m is an outer_key
-            if self.allow_nonexist_keys:
-                return data.get(m, IgnoreKey)
-            else:
-                return data[m]
-
-        collected = _map(data, mapping)
-
-        # Retain unmapped items
-        inputs = data.copy()
-        inputs.update(collected)
-
-        return inputs
-
-    def _map_output(self, data: dict, remapping: dict | None) -> dict[str, Any]:
-        """KeyMapper outputs from the wrapped transforms by gathering and
-        renaming data items according to the remapping.
-
-        Args:
-            data (dict): The output of the wrapped pipeline.
-            remapping (dict, optional): The output key mapping. See the
-                document of ``mmcv.transforms.wrappers.KeyMapper`` for
-                details. If ``remapping`` is None, no key mapping is
-                applied; only the special token ``IgnoreKey`` is removed.
-
-        Returns:
-            dict: The output with remapped keys.
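-
-        Examples:
-            A small sketch with assumed values: with
-            ``remapping={'img': 'gt_img'}``, an inner result ``{'img': x}``
-            is returned as ``{'gt_img': x}``; entries mapped to ``...`` or
-            holding ``IgnoreKey`` are dropped.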
-        """
-
-        # Remove ``IgnoreKey``
-        if remapping is None:
-            return {k: v for k, v in data.items() if v is not IgnoreKey}
-
-        def _map(data, m):
-            if isinstance(m, dict):
-                assert isinstance(data, dict)
-                results = {}
-                for k_in, k_out in m.items():
-                    assert k_in in data
-                    results.update(_map(data[k_in], k_out))
-                return results
-            if isinstance(m, (list, tuple)):
-                assert isinstance(data, (list, tuple))
-                assert len(data) == len(m)
-                results = {}
-                for m_i, d_i in zip(m, data, strict=False):
-                    results.update(_map(d_i, m_i))
-                return results
-
-            # ``m is ...`` means the key is marked ignored, in which case the
-            # inner results will not affect the outer results in remapping.
-            # Another case that will have ``data is IgnoreKey`` is that the
-            # key is missing in the inputs. In this case, if the inner key is
-            # created by the wrapped transforms, it will be remapped to the
-            # corresponding outer key during remapping.
-            if m is ... or data is IgnoreKey:
-                return {}
-
-            return {m: data}
-
-        # Note that unmapped items are not retained, which is different from
-        # the behavior in _map_input. This is to avoid original data items
-        # being overwritten by intermediate namesakes
-        return _map(data, remapping)
-
-    def _apply_transforms(self, inputs: dict) -> dict:
-        """Apply ``self.transforms``.
-
-        Note that the special token ``IgnoreKey`` will be invisible to
-        ``self.transforms``, but not removed in this method. It will be
-        eventually removed in :meth:`_map_output`.
-        """
-        results = inputs.copy()
-        inputs = {k: v for k, v in inputs.items() if v is not IgnoreKey}
-        outputs = self.transforms(inputs)
-
-        if outputs is None:
-            raise ValueError(f"Transforms wrapped by {self.__class__.__name__} should not return None.")
-
-        results.update(outputs)  # type: ignore
-        return results
-
-    def transform(self, results: dict) -> dict:
-        """Apply mapping, wrapped transforms and remapping."""
-
-        # Apply mapping
-        inputs = self._map_input(results, self.mapping)
-        # Apply wrapped transforms
-        outputs = self._apply_transforms(inputs)
-        # Apply remapping
-        outputs = self._map_output(outputs, self.remapping)
-
-        results.update(outputs)  # type: ignore
-        return results
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(transforms = {self.transforms}"
-        repr_str += f", mapping = {self.mapping}"
-        repr_str += f", remapping = {self.remapping}"
-        repr_str += f", auto_remap = {self.auto_remap}"
-        repr_str += f", allow_nonexist_keys = {self.allow_nonexist_keys})"
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class TransformBroadcaster(KeyMapper):
-    """A transform wrapper to apply the wrapped transforms to multiple data
-    items. For example, apply Resize to multiple images.
-
-    Args:
-        transforms (list[dict | callable]): Sequence of transform objects or
-            config dicts to be wrapped.
-        mapping (dict): A dict that defines the input key mapping.
-            Note that to apply the transforms to multiple data items, the
-            outer keys of the target items should be remapped as a list with
-            the standard inner key (the key required by the wrapped
-            transforms). See the following example and the document of
-            ``mmcv.transforms.wrappers.KeyMapper`` for details.
-        remapping (dict): A dict that defines the output key mapping.
-            The keys and values have the same meanings and rules as in the
-            ``mapping``. Default: None.
-        auto_remap (bool, optional): If True, an inverse of the mapping will
-            be used as the remapping. If auto_remap is not given, it will be
-            automatically set True if 'remapping' is not given, and vice
-            versa. Default: None.
-        allow_nonexist_keys (bool): If False, the outer keys in the mapping
-            must exist in the input data, or an exception will be raised.
-            Default: False.
-        share_random_params (bool): If True, the random transforms
-            (e.g., RandomFlip) will be conducted in a deterministic way and
-            have the same behavior on all data items. For example, the input
-            image and the ground-truth image are either both flipped or both
-            kept as-is. Default: False.
-
-    .. note::
-        To apply the transforms to each element of a list or tuple, instead
-        of separating data items, you can map the outer key of the target
-        sequence to the standard inner key. See Example 2.
-
-    Examples:
-        >>> # Example 1: Broadcast to enumerated keys, each contains a single
-        >>> # data element
-        >>> pipeline = [
-        >>>     dict(type='LoadImageFromFile', key='lq'),  # low-quality img
-        >>>     dict(type='LoadImageFromFile', key='gt'),  # ground-truth img
-        >>>     # TransformBroadcaster maps multiple outer fields to the
-        >>>     # standard inner field and processes them with wrapped
-        >>>     # transforms respectively
-        >>>     dict(type='TransformBroadcaster',
-        >>>         # case 1: from multiple outer fields
-        >>>         mapping={'img': ['lq', 'gt']},
-        >>>         auto_remap=True,
-        >>>         # share_random_params=True means using identical random
-        >>>         # parameters in every processing
-        >>>         share_random_params=True,
-        >>>         transforms=[
-        >>>             dict(type='Crop', crop_size=(384, 384)),
-        >>>             dict(type='Normalize'),
-        >>>         ])
-        >>> ]
-
-        >>> # Example 2: Broadcast to keys that contain data sequences
-        >>> pipeline = [
-        >>>     dict(type='LoadImageFromFile', key='lq'),  # low-quality img
-        >>>     dict(type='LoadImageFromFile', key='gt'),  # ground-truth img
-        >>>     # TransformBroadcaster maps multiple outer fields to the
-        >>>     # standard inner field and processes them with wrapped
-        >>>     # transforms respectively
-        >>>     dict(type='TransformBroadcaster',
-        >>>         # case 2: from one outer field that contains multiple
-        >>>         # data elements (e.g. a list)
-        >>>         mapping={'img': 'images'},
-        >>>         auto_remap=True,
-        >>>         share_random_params=True,
-        >>>         transforms=[
-        >>>             dict(type='Crop', crop_size=(384, 384)),
-        >>>             dict(type='Normalize'),
-        >>>         ])
-        >>> ]
-
-        >>> # Example 3: Set ignored keys in broadcasting
-        >>> pipeline = [
-        >>>     dict(type='TransformBroadcaster',
-        >>>         # Broadcast the wrapped transforms to multiple images
-        >>>         # 'lq' and 'gt', but only update 'img_shape' once
-        >>>         mapping={
-        >>>             'img': ['lq', 'gt'],
-        >>>             'img_shape': ['img_shape', ...],
-        >>>         },
-        >>>         auto_remap=True,
-        >>>         share_random_params=True,
-        >>>         transforms=[
-        >>>             # `RandomCrop` will modify the field "img",
-        >>>             # and optionally update "img_shape" if it exists
-        >>>             dict(type='RandomCrop'),
-        >>>         ])
-        >>> ]
-    """
-
-    def __init__(
-        self,
-        transforms: list[dict | Callable[[dict], dict]],
-        mapping: dict | None = None,
-        remapping: dict | None = None,
-        auto_remap: bool | None = None,
-        allow_nonexist_keys: bool = False,
-        share_random_params: bool = False,
-    ):
-        super().__init__(transforms, mapping, remapping, auto_remap, allow_nonexist_keys)
-
-        self.share_random_params = share_random_params
-
-    def scatter_sequence(self, data: dict) -> list[dict]:
-        """Scatter the broadcasting targets to a list of inputs of the wrapped
-        transforms."""
-
-        # infer split number from input
-        seq_len = 0
-        key_rep = None
-
-        if self.mapping:
-            keys = self.mapping.keys()
-        else:
-            keys = data.keys()
-
-        for key in keys:
-            assert isinstance(data[key], Sequence)
-            if seq_len:
-                if len(data[key]) != seq_len:
-                    raise ValueError(
-                        f"Got inconsistent sequence length: {seq_len} ({key_rep}) vs. {len(data[key])} ({key})"
-                    )
-            else:
-                seq_len = len(data[key])
-                key_rep = key
-
-        assert seq_len > 0, "Failed to get the number of broadcasting targets"
-
-        scatters = []
-        for i in range(seq_len):  # type: ignore
-            scatter = data.copy()
-            for key in keys:
-                scatter[key] = data[key][i]
-            scatters.append(scatter)
-        return scatters
-
-    def transform(self, results: dict):
-        """Broadcast wrapped transforms to multiple targets."""
-
-        # Apply input remapping
-        inputs = self._map_input(results, self.mapping)
-
-        # Scatter sequential inputs into a list
-        input_scatters = self.scatter_sequence(inputs)
-
-        # Control random parameter sharing with a context manager
-        if self.share_random_params:
-            # The context manager :func:`cache_random_params` will let
-            # cacheable methods of the transforms cache their outputs. Thus
-            # the random parameters will only be generated once and shared
-            # by all data items.
-            ctx = cache_random_params  # type: ignore
-        else:
-            ctx = nullcontext  # type: ignore
-
-        with ctx(self.transforms):
-            output_scatters = [self._apply_transforms(_input) for _input in input_scatters]
-
-        # Collate output scatters (list of dict to dict of list)
-        outputs = {key: [_output[key] for _output in output_scatters] for key in output_scatters[0]}
-
-        # Apply remapping
-        outputs = self._map_output(outputs, self.remapping)
-
-        results.update(outputs)
-        return results
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(transforms = {self.transforms}"
-        repr_str += f", mapping = {self.mapping}"
-        repr_str += f", remapping = {self.remapping}"
-        repr_str += f", auto_remap = {self.auto_remap}"
-        repr_str += f", allow_nonexist_keys = {self.allow_nonexist_keys}"
-        repr_str += f", share_random_params = {self.share_random_params})"
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class RandomChoice(BaseTransform):
-    """Process data with a randomly chosen transform from given candidates.
-
-    Args:
-        transforms (list[list]): A list of transform candidates, each is a
-            sequence of transforms.
-        prob (list[float], optional): The probabilities associated
-            with each pipeline. The length should be equal to the pipeline
-            number and the sum should be 1. If not given, a uniform
-            distribution will be assumed.
-
-    Examples:
-        >>> # config
-        >>> pipeline = [
-        >>>     dict(type='RandomChoice',
-        >>>         transforms=[
-        >>>             [dict(type='RandomHorizontalFlip')],  # subpipeline 1
-        >>>             [dict(type='RandomRotate')],  # subpipeline 2
-        >>>         ]
-        >>>     )
-        >>> ]
-    """
-
-    def __init__(
-        self,
-        transforms: list[Transform | list[Transform]],
-        prob: list[float] | None = None,
-    ):
-        super().__init__()
-
-        if prob is not None:
-            assert is_seq_of(prob, float)
-            assert len(transforms) == len(prob), (
-                f"``transforms`` and ``prob`` must have same lengths. Got {len(transforms)} vs {len(prob)}."
-            )
-            assert sum(prob) == 1
-
-        self.prob = prob
-        self.transforms = [Compose(t) for t in transforms]
-
-    def __iter__(self):
-        return iter(self.transforms)
-
-    @cache_randomness
-    def random_pipeline_index(self) -> int:
-        """Return a random transform index."""
-        indices = np.arange(len(self.transforms))
-        return np.random.choice(indices, p=self.prob)
-
-    def transform(self, results: dict) -> dict | None:
-        """Randomly choose a transform to apply."""
-        idx = self.random_pipeline_index()
-        return self.transforms[idx](results)
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(transforms = {self.transforms}"
-        repr_str += f", prob = {self.prob})"
-        return repr_str
-
-
-@TRANSFORMS.register_module()
-class RandomApply(BaseTransform):
-    """Apply transforms randomly with a given probability.
-
-    Args:
-        transforms (list[dict | callable]): The transform or transform list
-            to randomly apply.
-        prob (float): The probability to apply transforms. Default: 0.5
-
-    Examples:
-        >>> # config
-        >>> pipeline = [
-        >>>     dict(type='RandomApply',
-        >>>         transforms=[dict(type='HorizontalFlip')],
-        >>>         prob=0.3)
-        >>> ]
-    """
-
-    def __init__(self, transforms: Transform | list[Transform], prob: float = 0.5):
-        super().__init__()
-        self.prob = prob
-        self.transforms = Compose(transforms)
-
-    def __iter__(self):
-        return iter(self.transforms)
-
-    @cache_randomness
-    def random_apply(self) -> bool:
-        """Return a random bool value indicating whether to apply the
-        transform."""
-        return np.random.rand() < self.prob
-
-    def transform(self, results: dict) -> dict | None:
-        """Randomly apply the transform."""
-        if self.random_apply():
-            return self.transforms(results)  # type: ignore
-        else:
-            return results
-
-    def __repr__(self) -> str:
-        repr_str = self.__class__.__name__
-        repr_str += f"(transforms = {self.transforms}"
-        repr_str += f", prob = {self.prob})"
-        return repr_str
diff --git a/libs/visengine/AGENTS.md b/libs/visengine/AGENTS.md
deleted file mode 100644
index 507da40..0000000
--- a/libs/visengine/AGENTS.md
+++ /dev/null
@@ -1,41 +0,0 @@
-# visengine
-
-This is a simplified version of MMEngine, focused on supporting training and inference for Swin Mask R-CNN.
-
-## Key Principles
-
-1. **Simplified Runner**: Keep only the essential training loop functionality
-2. **Basic Hooks**: Only maintain hooks actually used by Swin Mask R-CNN training
-3. **No Complex Strategies**: Remove distributed training strategies we don't need
-
-## What to Keep
-
-- Basic Runner class for training/validation loops
-- Essential hooks:
-  - CheckpointHook
-  - LoggerHook
-  - OptimizerHook
-  - IterTimerHook
-- Config system (simplified)
-- Registry system
-- Basic data structures
-- File I/O utilities
-
-## What to Remove
-
-- Complex distributed strategies (keep only single GPU and basic DDP)
-- Unused hooks and runners
-- Advanced features not needed for Swin Mask R-CNN
-- Profiling and debugging tools we don't use
-
-## Dependencies
-
-Should only depend on:
-- PyTorch
-- Basic Python packages (numpy, etc.)
-- viscv for image operations - ---- - -*For machine learning guidelines, see the machine_learning/AGENTS.md file.* -*For general repository guidelines, see the root AGENTS.md file.* diff --git a/libs/visengine/BUILD.pkl b/libs/visengine/BUILD.pkl deleted file mode 100644 index 26fa8bb..0000000 --- a/libs/visengine/BUILD.pkl +++ /dev/null @@ -1,39 +0,0 @@ -amends "@grog/package.pkl" - -local py_sources = List( - "visengine/**/*", - "pyproject.toml" -) - -targets { - new { - name = "visengine" - inputs { - ...py_sources - } - } - - new { - name = "test" - command = "uv run pytest" - inputs { - ...py_sources - "tests/**/*" - } - - dependencies { - "//tools:uv" - "//machine_learning/packages/ml_env_config" - "//machine_learning/packages/viscv" - } - - platform { - os { - "linux" - } - arch { - "amd64" - } - } - } -} diff --git a/libs/visengine/CLAUDE.md b/libs/visengine/CLAUDE.md deleted file mode 120000 index 47dc3e3..0000000 --- a/libs/visengine/CLAUDE.md +++ /dev/null @@ -1 +0,0 @@ -AGENTS.md \ No newline at end of file diff --git a/libs/visengine/pyproject.toml b/libs/visengine/pyproject.toml deleted file mode 100644 index 229e2e0..0000000 --- a/libs/visengine/pyproject.toml +++ /dev/null @@ -1,33 +0,0 @@ -[project] -name = "visengine" -version = "0.1.0" -description = "Training engine for Swin Mask R-CNN" -readme = "README.md" -requires-python = ">=3.10" -dependencies = [ - "numpy>=2.0.0", - "torch==2.5.1; platform_system == 'Linux' and platform_machine == 'x86_64'", - "pyyaml>=6.0.0", - "addict>=2.4.0", - "yapf>=0.30.0", - "termcolor>=2.0.0", - "rich>=13.0.0", - "opencv-python-headless>=4.8.0", - "matplotlib>=3.6.0", - "tqdm>=4.60.0", - "torchvision>=0.19.1", - "cloudpathlib>=0.18.1", - "google-cloud-storage>=2.18.2", - "bitsandbytes>=0.41.0; platform_system == 'Linux'", -] - -[build-system] -requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" - -[tool.setuptools.packages.find] -where = ["."] -include = ["visengine*"] - -[tool.setuptools.package-data] -'*' = ['*.yaml', '*.json'] diff --git a/libs/visengine/tests/test_placeholder.py b/libs/visengine/tests/test_placeholder.py deleted file mode 100644 index ef3a3c6..0000000 --- a/libs/visengine/tests/test_placeholder.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -Placeholder test to ensure pytest doesn't fail when no tests are found. -This should be replaced with actual tests for the visengine package. -""" - - -def test_placeholder(): - """Placeholder test that always passes.""" - assert True diff --git a/libs/visengine/visengine/__init__.py b/libs/visengine/visengine/__init__.py deleted file mode 100644 index f552890..0000000 --- a/libs/visengine/visengine/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -# flake8: noqa - -# Import version first to ensure it's available -from .version import __version__, version_info - -# Import other modules -from .config import * -from .fileio import * -from .logging import * -from .registry import * -from .utils import * - -# Re-export version info explicitly at module level -globals()["__version__"] = __version__ -globals()["version_info"] = version_info diff --git a/libs/visengine/visengine/_strategy/__init__.py b/libs/visengine/visengine/_strategy/__init__.py deleted file mode 100644 index 47bfa90..0000000 --- a/libs/visengine/visengine/_strategy/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
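For reference, the two wrapper transforms deleted above compose naturally in pipeline configs: `RandomChoice` draws one sub-pipeline per sample according to `prob`, while `RandomApply` runs its transforms with a fixed probability. A minimal sketch, with placeholder transform names:

```python
# A sketch of how the two wrappers compose in a pipeline config.
# Transform names are illustrative placeholders.
pipeline = [
    # Pick exactly one sub-pipeline per sample: flip 70% of the time,
    # rotate the remaining 30%.
    dict(
        type="RandomChoice",
        transforms=[
            [dict(type="RandomHorizontalFlip")],  # sub-pipeline 1
            [dict(type="RandomRotate")],          # sub-pipeline 2
        ],
        prob=[0.7, 0.3],  # must sum to 1 and match len(transforms)
    ),
    # Independently apply a transform with probability 0.5.
    dict(type="RandomApply", transforms=[dict(type="ColorJitter")], prob=0.5),
]
```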
-from .base import BaseStrategy -from .distributed import DDPStrategy -from .single_device import SingleDeviceStrategy - -__all__ = [ - "BaseStrategy", - "DDPStrategy", - "SingleDeviceStrategy", -] diff --git a/libs/visengine/visengine/_strategy/base.py b/libs/visengine/visengine/_strategy/base.py deleted file mode 100644 index 086b939..0000000 --- a/libs/visengine/visengine/_strategy/base.py +++ /dev/null @@ -1,985 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from __future__ import annotations - -import copy -import os.path as osp -import platform -import time -from abc import ABCMeta, abstractmethod -from collections import OrderedDict -from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, TypeVar, Union - -T = TypeVar("T") - -import torch -import torch.nn as nn -from torch.optim import Optimizer - -import visengine -from visengine.config import Config, ConfigDict -from visengine.dist import broadcast, get_dist_info, infer_launcher, is_distributed -from visengine.logging import MMLogger -from visengine.model.wrappers import is_model_wrapper -from visengine.registry import MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS -from visengine.utils import digit_version -from visengine.utils.dl_utils import TORCH_VERSION, collect_env, set_multi_processing - -if TYPE_CHECKING: - from visengine.optim import ( - BaseOptimWrapper, - OptimWrapperDict, - _ParamScheduler, - build_optim_wrapper, - ) - - ParamSchedulerType = Union[list[_ParamScheduler], dict[str, list[_ParamScheduler]]] -else: - # For runtime, we need to define ParamSchedulerType without the imports - ParamSchedulerType = Union[list, dict[str, list]] - - -class BaseStrategy(metaclass=ABCMeta): - """Base class for all strategies. - - In the process of supporting FSDP, DeepSpeed, and ColossalAI, the - scalability of the Runner faced challenges, which led to the redefinition - of the Runner's responsibilities. The Strategy abstraction was split out, - which is responsible for constructing, initializing, and saving/loading - the state of training components such as models, optimizers, and parameter - schedulers. - - Warning: - This is an experimental feature, and its interface is subject to - change. - - Keyword Args: - work_dir (str): The working directory to save checkpoints. The logs - will be saved in the subdirectory of `work_dir` named - :attr:`timestamp`. Defaults to 'work_dirs'. - experiment_name (str, optional): Name of current experiment. If not - specified, timestamp will be used as :attr:`experiment_name`. - Defaults to None. - env_kwargs (dict, optional): Environment config passed in - :meth:`setup_env`. Defaults to None. - log_kwargs (dict, optional): Logger config passed in - :meth:`build_logger`. Defaults to None. - auto_scale_lr (dict, Optional): Config to scale the learning rate - automatically. It includes ``base_batch_size`` and ``enable``. - ``base_batch_size`` is the batch size that the optimizer lr is - based on. ``enable`` is the switch to turn on and off the feature. 
- """ - - model: nn.Module - optim_wrapper: BaseOptimWrapper - param_schedulers: ParamSchedulerType - - def __init__( - self, - *, - work_dir: str = "work_dirs", - experiment_name: str | None = None, - env_kwargs: dict | None = None, - log_kwargs: dict | None = None, - auto_scale_lr: dict | None = None, - ): - self._work_dir = osp.abspath(work_dir) - mmengine.mkdir_or_exist(self._work_dir) - - self._env_kwargs = env_kwargs or {} - self._setup_env(**self._env_kwargs) - - if experiment_name is not None: - self._experiment_name = f"{experiment_name}_{self.timestamp}" - else: - self._experiment_name = self.timestamp - - self._log_dir = osp.join(self.work_dir, self.timestamp) - mmengine.mkdir_or_exist(self._log_dir) - - log_kwargs = log_kwargs or {} - self.logger = self.build_logger(**log_kwargs) - - self._auto_scale_lr = auto_scale_lr - - self.dispatch_kwargs: dict = {} - self._prepared = False - - @property - def work_dir(self): - return self._work_dir - - @property - def log_dir(self): - return self._log_dir - - @property - def experiment_name(self): - return self._experiment_name - - @property - def launcher(self): - return self._launcher - - @property - def distributed(self): - return self._distributed - - @property - def seed(self): - return self._seed - - @property - def rank(self): - return self._rank - - @property - def world_size(self): - return self._world_size - - @property - def timestamp(self): - return self._timestamp - - @property - def randomness(self): - return self._randomness - - @abstractmethod - def prepare( - self, - model: nn.Module | dict, - *, - optim_wrapper: "BaseOptimWrapper" | dict | None = None, - param_scheduler: "_ParamScheduler" | dict | list | None = None, - compile: dict | bool = False, - dispatch_kwargs: dict | None = None, - ): - """Prepare model and some components. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It - can be a dict used for building a model. - - Keyword Args: - optim_wrapper (BaseOptimWrapper or dict, optional): Computing the - gradient of model parameters and updating them. - Defaults to None. - See :meth:`build_optim_wrapper` for examples. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optim_wrapper` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - compile (dict, optional): Config to compile model. - Defaults to False. Requires PyTorch>=2.0. - dispatch_kwargs (dict, optional): Kwargs to be passed to other - methods of Strategy. Defaults to None. - """ - - def _setup_env( - self, - *, - launcher: str | None = None, - cudnn_benchmark: bool = False, - mp_cfg: dict | None = None, - dist_cfg: dict | None = None, - resource_limit: int = 4096, - randomness: dict | None = None, - ): - """Setup environment. - - This method will do the following things: - - 1. setup multi-processing - 2. setup distributed - 3. set random seed - - Keyword Args: - launcher (str, optional): Way to launcher multi-process. Supported - launchers are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' - is provided, non-distributed environment will be launched. - If launcher is None, the launcher will be inferred according - some specified environments. Defaults to None. - cudnn_benchmark (bool): Whether to enable cudnn benchmark. - Defaults to False. - mp_cfg (dict, optional): Multi-processing config. Defaults to None. - dist_cfg (dict, optional): Distributed config. Defaults to None. 
- resource_limit (int): Resource limit. Defaults to 4096. - randomness (dict): Some settings to make the experiment as - reproducible as possible like seed and deterministic. - Defaults to ``dict(seed=None)``. If seed is None, a random - number will be generated and it will be broadcasted to all - other processes if in distributed environment. - If ``cudnn_benchmark`` is ``True`` in but ``deterministic`` is - ``True`` in ``randomness``, the value of - ``torch.backends.cudnn.benchmark`` will be ``False`` finally. - """ - if randomness is None: - randomness = {"seed": None} - if launcher is None: - launcher = infer_launcher() - - self._launcher = launcher - if self._launcher == "none": - self._distributed = False - else: - self._distributed = True - - if cudnn_benchmark: - torch.backends.cudnn.benchmark = True - - mp_cfg = mp_cfg if mp_cfg is not None else {} - set_multi_processing(**mp_cfg, distributed=self._distributed) - - # init distributed env first, since logger depends on the dist info. - if self._distributed and not is_distributed(): - dist_cfg = dist_cfg if dist_cfg is not None else {} - self._setup_distributed(launcher, **dist_cfg) - - self._rank, self._world_size = get_dist_info() - - timestamp = torch.tensor(time.time(), dtype=torch.float64) - # broadcast timestamp from 0 process to other processes - broadcast(timestamp) - self._timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime(timestamp.item())) - - # https://github.com/pytorch/pytorch/issues/973 - # set resource limit - if platform.system() != "Windows": - import resource - - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - base_soft_limit = rlimit[0] - hard_limit = rlimit[1] - soft_limit = min(max(resource_limit, base_soft_limit), hard_limit) - resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) - - self._randomness = randomness - self._set_randomness(**randomness) - - def _setup_distributed(self, *args, **kwargs): - """Setup distributed environment.""" - pass - - def _set_randomness( - self, - seed: int | None = None, - diff_rank_seed: bool = False, - deterministic: bool = False, - ) -> None: - """Set random seed to guarantee reproducible results. - - Args: - seed (int, optional): A number to set random modules. - Defaults to None. - diff_rank_seed (bool): Whether or not set different seeds according - to global rank. Defaults to False. - deterministic (bool): Whether to set the deterministic option for - CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` - to True and `torch.backends.cudnn.benchmark` to False. - Defaults to False. - See https://pytorch.org/docs/stable/notes/randomness.html for - more details. - """ - from visengine.runner import set_random_seed - - self._seed = set_random_seed(seed=seed, deterministic=deterministic, diff_rank_seed=diff_rank_seed) - - def build_model(self, model: nn.Module | dict) -> nn.Module: - """Build model. - - If ``model`` is a dict, it will be used to build a ``nn.Module`` - object. Otherwise, if ``model`` is a ``nn.Module`` object it will be - returned directly. - - An example of ``model``:: - - model = dict(type='ResNet') - - Args: - model (nn.Module or dict): A ``nn.Module`` object or a dict to - build ``nn.Module`` object. If ``model`` is a ``nn.Module`` - object, just returns itself. - - Note: - The returned model must implement ``train_step``, ``test_step`` - if ``runner.train`` or ``runner.test`` will be called. If - ``runner.val`` will be called or ``val_cfg`` is configured, - model must implement `val_step`. 
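The environment setup above is driven entirely by the `env_kwargs` dict passed to the constructor. A minimal sketch of a reproducible single-seed configuration; all concrete values are hypothetical:

```python
# Hypothetical env_kwargs for BaseStrategy(..., env_kwargs=env_kwargs).
env_kwargs = dict(
    launcher=None,             # None -> infer_launcher() decides
    cudnn_benchmark=False,     # kept False because deterministic=True below
    dist_cfg=dict(backend="nccl"),
    randomness=dict(
        seed=42,               # None would draw a seed and broadcast it
        diff_rank_seed=False,  # True -> vary the seed per global rank
        deterministic=True,    # forces torch.backends.cudnn.benchmark=False
    ),
)
```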
- - Returns: - nn.Module: Model build from ``model``. - """ - if isinstance(model, nn.Module): - return model - elif isinstance(model, dict): - model = MODELS.build(model) - return model # type: ignore - else: - raise TypeError(f"model should be a nn.Module object or dict, but got {model}") - - def compile_model( - self, - model: nn.Module, - compile: dict | bool = False, - ) -> nn.Module: - """Compile model. - - Args: - model (nn.Module): Model to compile. - - Returns: - nn.Module: Compiled model. - """ - if isinstance(compile, bool) and not compile: - return model - - assert digit_version(TORCH_VERSION) >= digit_version("2.0.0"), ( - "PyTorch >= 2.0.0 is required to enable torch.compile" - ) - - if isinstance(compile, bool): - compile = {} - - target = compile.pop("target", "forward") - func = getattr(model, target) - compiled_func = torch.compile(func, **compile) - setattr(model, target, compiled_func) - self.logger.info('Model has been "compiled". The first few iterations will be slow, please be patient.') - - return model - - def _init_model_weights(self, model: nn.Module) -> nn.Module: - """Initialize the model weights if the model has - :meth:`init_weights`""" - if hasattr(model, "init_weights") and self.dispatch_kwargs.get("init_weights_for_test_or_val", True): - model.init_weights() - # sync params and buffers - for _, params in model.state_dict().items(): - broadcast(params) - - return model - - def build_optim_wrapper( - self, - optim_wrapper: Optimizer | BaseOptimWrapper | dict, - model: nn.Module | None = None, - ) -> BaseOptimWrapper: - """Build optimizer wrapper. - - If ``optim_wrapper`` is a config dict for only one optimizer, - the keys must contain ``optimizer``, and ``type`` is optional. - It will build a :obj:`OptimWrapper` by default. - - If ``optim_wrapper`` is a config dict for multiple optimizers, i.e., - it has multiple keys and each key is for an optimizer wrapper. The - constructor must be specified since - :obj:`DefaultOptimizerConstructor` cannot handle the building of - training with multiple optimizers. - - If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e., - each value of ``optim_wrapper`` represents an ``OptimWrapper`` - instance. ``build_optim_wrapper`` will directly build the - :obj:`OptimWrapperDict` instance from ``optim_wrapper``. - - Args: - optim_wrapper (BaseOptimWrapper or dict): An OptimWrapper object or a - dict to build OptimWrapper objects. If ``optim_wrapper`` is an - OptimWrapper, just return an ``OptimizeWrapper`` instance. - - Note: - For single optimizer training, if `optim_wrapper` is a config - dict, `type` is optional(defaults to :obj:`OptimWrapper`) and it - must contain `optimizer` to build the corresponding optimizer. - - Examples: - >>> # build an optimizer - >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)) - >>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> # is also valid. 
- >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build optimizer without `type` - >>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - maximize: False - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build multiple optimizers - >>> optim_wrapper_cfg = dict( - ... generator=dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)), - ... discriminator=dict(type='OptimWrapper', optimizer=dict( - ... type='Adam', lr=0.001)) - ... # need to customize a multiple optimizer constructor - ... constructor='CustomMultiOptimizerConstructor', - ...) - >>> optim_wrapper = runner.optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - name: generator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.1 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - name: discriminator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - 'discriminator': Adam ( - Parameter Group 0 - dampening: 0 - lr: 0.02 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - - Important: - If you need to build multiple optimizers, you should implement a - MultiOptimWrapperConstructor which gets parameters passed to - corresponding optimizers and compose the ``OptimWrapperDict``. - More details about how to customize OptimizerConstructor can be - found at `optimizer-docs`_. - - Returns: - BaseOptimWrapper: Optimizer wrapper build from ``optimizer_cfg``. - """ - if isinstance(optim_wrapper, BaseOptimWrapper): - return optim_wrapper - if isinstance(optim_wrapper, dict | ConfigDict | Config): - # optimizer must be defined for single optimizer training. - optimizer = optim_wrapper.get("optimizer", None) - - # If optimizer is a built `Optimizer` instance, the optimizer - # wrapper should be built by `OPTIM_WRAPPERS` registry. - if isinstance(optimizer, Optimizer): - optim_wrapper.setdefault("type", "OptimWrapper") - return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore - - # If `optimizer` is not None or `constructor` is defined, it means, - # optimizer wrapper will be built by optimizer wrapper - # constructor. Therefore, `build_optim_wrapper` should be called. - if optimizer is not None or "constructor" in optim_wrapper: - assert model is not None - from visengine.optim import build_optim_wrapper - - return build_optim_wrapper(model, optim_wrapper) - else: - # if `optimizer` is not defined, it should be the case of - # training with multiple optimizers. If `constructor` is not - # defined either, each value of `optim_wrapper` must be an - # `OptimWrapper` instance since `DefaultOptimizerConstructor` - # will not handle the case of training with multiple - # optimizers. 
`build_optim_wrapper` will directly build the - # `OptimWrapperDict` instance from `optim_wrapper.` - optim_wrappers = OrderedDict() - for name, optim in optim_wrapper.items(): - if not isinstance(optim, BaseOptimWrapper): - raise ValueError( - f'each item mush be an optimizer object when "type" and "constructor" are not in optimizer, but got {name}={optim}' - ) - optim_wrappers[name] = optim - from visengine.optim import OptimWrapperDict - - return OptimWrapperDict(**optim_wrappers) # type: ignore - else: - raise TypeError(f"optimizer wrapper should be an OptimWrapper object or dict, but got {optim_wrapper}") - - def _build_param_scheduler( - self, - scheduler: "_ParamScheduler" | dict | list, - optim_wrapper: "BaseOptimWrapper", - default_args: dict, - ) -> list["_ParamScheduler"]: - """Build parameter schedulers for a single optimizer. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - optim_wrapper (BaseOptimWrapper): An optimizer wrapper object is - passed to construct ParamScheduler object. - - Returns: - list[_ParamScheduler]: List of parameter schedulers build from - ``scheduler``. - - Note: - If the train loop is built, when building parameter schedulers, - it supports setting the max epochs/iters as the default ``end`` - of schedulers, and supports converting epoch-based schedulers - to iter-based according to the ``convert_to_iter_based`` key. - """ - if not isinstance(scheduler, Sequence): - schedulers = [scheduler] - else: - schedulers = scheduler - - max_epochs = default_args.pop("max_epochs", None) - max_iters = default_args.pop("max_iters", None) - - param_schedulers = [] - for scheduler in schedulers: - from visengine.optim import _ParamScheduler - - if isinstance(scheduler, _ParamScheduler): - param_schedulers.append(scheduler) - elif isinstance(scheduler, dict): - _scheduler = copy.deepcopy(scheduler) - - # Set default end - if _scheduler.get("by_epoch", True): - if max_epochs is None: - raise ValueError("max_epochs must be specified in default_args") - default_end = max_epochs - else: - if max_iters is None: - raise ValueError("max_iters must be specified in default_args") - default_end = max_iters - _scheduler.setdefault("end", default_end) - self.logger.debug( - f"The `end` of {_scheduler['type']} is not set. Use the max epochs/iters of train loop as default." - ) - - param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict(optimizer=optim_wrapper, **default_args), - ) - ) - else: - raise TypeError(f"scheduler should be a _ParamScheduler object or dict, but got {scheduler}") - return param_schedulers - - def build_param_scheduler( - self, - scheduler: "_ParamScheduler" | dict | list, - optim_wrapper: "BaseOptimWrapper", - default_args: dict | None = None, - ) -> ParamSchedulerType: - """Build parameter schedulers. - - ``build_param_scheduler`` should be called after - ``build_optim_wrapper`` because the building logic will change - according to the number of optimizers built by the runner. - The cases are as below: - - - Single optimizer: When only one optimizer is built and used in the - runner, ``build_param_scheduler`` will return a list of - parameter schedulers. - - Multiple optimizers: When two or more optimizers are built and used - in runner, ``build_param_scheduler`` will return a dict containing - the same keys with multiple optimizers and each value is a list of - parameter schedulers. 
Note that, if you want different optimizers to - use different parameter schedulers to update optimizer's - hyper-parameters, the input parameter ``scheduler`` also needs to be - a dict and its key are consistent with multiple optimizers. - Otherwise, the same parameter schedulers will be used to update - optimizer's hyper-parameters. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - - Examples: - >>> # build one scheduler - >>> optim_cfg = dict(dict(type='SGD', lr=0.01)) - >>> runner.optim_wrapper = runner.build_optim_wrapper( - >>> optim_cfg) - >>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2]) - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [] # noqa: E501 - - >>> # build multiple schedulers - >>> scheduler_cfg = [ - ... dict(type='MultiStepLR', milestones=[1, 2]), - ... dict(type='StepLR', step_size=1) - ... ] - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [, # noqa: E501 - ] - - Above examples only provide the case of one optimizer and one scheduler - or multiple schedulers. If you want to know how to set parameter - scheduler when using multiple optimizers, you can find more examples - `optimizer-docs`_. - - Returns: - list[_ParamScheduler] or dict[str, list[_ParamScheduler]]: List of - parameter schedulers or a dictionary contains list of parameter - schedulers build from ``scheduler``. - - .. _optimizer-docs: - https://mmengine.readthedocs.io/en/latest/tutorials/optim_wrapper.html - """ - if default_args is None: - default_args = {} - if "epoch_length" in self.dispatch_kwargs: - default_args["epoch_length"] = self.dispatch_kwargs["epoch_length"] - if "max_epochs" in self.dispatch_kwargs: - default_args["max_epochs"] = self.dispatch_kwargs["max_epochs"] - if "max_iters" in self.dispatch_kwargs: - default_args["max_iters"] = self.dispatch_kwargs["max_iters"] - - param_schedulers: ParamSchedulerType - from visengine.optim import OptimWrapperDict - - if not isinstance(optim_wrapper, OptimWrapperDict): - # Since `OptimWrapperDict` inherits from `OptimWrapper`, - # `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell - # whether `self.optim_wrapper` is an `OptimizerWrapper` or - # `OptimWrapperDict` instance. Therefore, here we simply check - # self.optim_wrapper is not an `OptimWrapperDict` instance and - # then assert it is an OptimWrapper instance. - assert isinstance(optim_wrapper, BaseOptimWrapper), ( - "`build_optimizer` should be called before`build_param_scheduler` because the latter depends on the former" - ) - param_schedulers = self._build_param_scheduler(scheduler, optim_wrapper, default_args) # type: ignore - return param_schedulers - else: - param_schedulers = {} - for name, optimizer in optim_wrapper.items(): - if isinstance(scheduler, dict) and "type" not in scheduler: - # scheduler is a dict and each item is a ParamScheduler - # object or a config to build ParamScheduler objects - param_schedulers[name] = self._build_param_scheduler(scheduler[name], optimizer, default_args) - else: - param_schedulers[name] = self._build_param_scheduler(scheduler, optimizer, default_args) - - return param_schedulers - - def _scale_lr(self) -> None: - """Automatically scaling learning rate in training according to the - ratio of ``base_batch_size`` in ``autoscalelr_cfg`` and real batch - size. - - It scales the learning rate linearly according to the - `paper `_. 
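A worked example of the linear scaling rule that `_scale_lr` applies, with hypothetical batch sizes:

```python
# auto_scale_lr config as described in the class docstring.
auto_scale_lr = dict(enable=True, base_batch_size=16)

# Hypothetical run: 8 GPUs, 4 samples per GPU.
world_size = 8
train_micro_batch_size_per_gpu = 4

real_bs = world_size * train_micro_batch_size_per_gpu  # 32
ratio = real_bs / auto_scale_lr["base_batch_size"]     # 32 / 16 = 2.0

lr = 0.02
scaled_lr = lr * ratio  # 0.04, applied to every param group
```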
-
-        Note:
-            ``scale_lr`` must be called after building optimizer wrappers
-            and before building parameter schedulers.
-        """
-        if self._auto_scale_lr is None or not self._auto_scale_lr.get("enable", False):
-            return None
-
-        assert "base_batch_size" in self._auto_scale_lr, "Lack of `base_batch_size` in `auto_scale_lr`."
-
-        real_bs = self.world_size * self.dispatch_kwargs["train_micro_batch_size_per_gpu"]
-        base_bs = self._auto_scale_lr["base_batch_size"]
-        ratio = float(real_bs) / float(base_bs)
-        self.logger.info(
-            f"LR is set based on batch size of {base_bs} and the current batch size is {real_bs}. Scaling the original LR by {ratio}."
-        )
-
-        def _is_built(schedulers):
-            if isinstance(schedulers, dict):
-                return False if "type" in schedulers else any(_is_built(s) for s in schedulers.values())
-            if isinstance(schedulers, list):
-                return any(_is_built(s) for s in schedulers)
-            from visengine.optim import _ParamScheduler
-
-            return isinstance(schedulers, _ParamScheduler)
-
-        if hasattr(self, "param_schedulers") and _is_built(self.param_schedulers):
-            raise RuntimeError(
-                "`scale_lr` should be called before building ParamScheduler because ParamScheduler will store initial lr from optimizer wrappers"
-            )
-
-        assert isinstance(self.optim_wrapper, BaseOptimWrapper), (
-            "`scale_lr` should be called after building OptimWrapper"
-        )
-
-        from visengine.optim import OptimWrapperDict
-
-        if isinstance(self.optim_wrapper, OptimWrapperDict):
-            wrappers = list(self.optim_wrapper.values())
-        else:
-            wrappers = [self.optim_wrapper]  # type: ignore
-
-        for wrapper in wrappers:
-            for group in wrapper.optimizer.param_groups:
-                group["lr"] = group["lr"] * ratio
-
-    def build_logger(
-        self,
-        log_level: int | str = "INFO",
-        log_file: str | None = None,
-        **kwargs,
-    ) -> MMLogger:
-        """Build a globally accessible MMLogger.
-
-        Args:
-            log_level (int or str): The log level of MMLogger handlers.
-                Defaults to 'INFO'.
-            log_file (str, optional): Path of filename to save log.
-                Defaults to None.
-            **kwargs: Remaining parameters passed to ``MMLogger``.
-
-        Returns:
-            MMLogger: An MMLogger object built from the given arguments.
-        """
-        if log_file is None:
-            log_file = osp.join(self.log_dir, f"{self._timestamp}.log")
-
-        log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
-        log_cfg.setdefault("name", self.experiment_name)
-        # `torch.compile` in PyTorch 2.0 could close all user-defined handlers
-        # unexpectedly. Using file mode 'a' can help prevent abnormal
-        # termination of the FileHandler and ensure that the log file could
-        # be continuously updated during the lifespan of the runner.
- log_cfg.setdefault("file_mode", "a") - - return MMLogger.get_instance(**log_cfg) # type: ignore - - def model_state_dict(self) -> dict[str, Any]: - """Get model state dict.""" - from visengine.runner import weights_to_cpu - - return weights_to_cpu(self.model.state_dict()) - - def optim_state_dict(self) -> dict[str, Any]: - """Get optimizer state dict.""" - if isinstance(self.optim_wrapper, BaseOptimWrapper): - return self.optim_wrapper.state_dict() - else: - raise TypeError(f"self.optim_wrapper should be a `BaseOptimWrapper` instance, but got {self.optim_wrapper}") - - def scheduler_state_dict(self) -> dict | list: - """Get parameter scheduler state dict.""" - if isinstance(self.param_schedulers, dict): - state_dict: dict = {} - for name, schedulers in self.param_schedulers.items(): - state_dict[name] = [] - for scheduler in schedulers: - state_dict[name].append(scheduler.state_dict()) - return state_dict - else: - state_list = [] - for scheduler in self.param_schedulers: # type: ignore - state_list.append(scheduler.state_dict()) - return state_list - - def load_model_state_dict( - self, - state_dict: dict, - *, - strict: bool = False, - revise_keys: list | None = None, - ) -> None: - """Load model state from dict.""" - from visengine.runner.checkpoint import _load_checkpoint_to_model - - if revise_keys is None: - revise_keys = [(r"^module.", "")] - if is_model_wrapper(self.model): - model = self.model.module - else: - model = self.model - - _load_checkpoint_to_model(model, state_dict, strict=strict, revise_keys=revise_keys) - - def load_optim_state_dict(self, state_dict: dict) -> None: - """Load optimizer state from dict.""" - self.optim_wrapper.load_state_dict(state_dict) - - def load_scheduler_state_dict(self, state_dict: dict | list) -> None: - """Load scheduler state from dict.""" - if isinstance(self.param_schedulers, dict): - assert isinstance(state_dict, dict) - for name, schedulers in self.param_schedulers.items(): - for scheduler, ckpt_scheduler in zip(schedulers, state_dict[name], strict=False): - scheduler.load_state_dict(ckpt_scheduler) - else: - for scheduler, ckpt_scheduler in zip(self.param_schedulers, state_dict, strict=False): # type: ignore - scheduler.load_state_dict(ckpt_scheduler) - - def load_or_resume( - self, - *, - load_from: str | None = None, - resume: bool | str = False, - ) -> dict | None: - """Load checkpoint or resume from checkpoint. - - Args: - load_from (str, optional): The checkpoint file to load from. - Defaults to None. - resume (bool or str): Whether to resume training. Defaults to - False. If ``resume`` is True and ``load_from`` is None, - automatically to find latest checkpoint from ``work_dir``. - If not found, resuming does nothing. If ``resume`` is a string, - it will be treated as the checkpoint file to resume from. 
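The cases described above resolve as follows; the `strategy` instance in the last line is hypothetical:

```python
# How the (load_from, resume) pair is resolved by load_or_resume:
#   resume="ckpt.pth"                -> resume from "ckpt.pth"
#   resume=True,  load_from=None     -> resume from latest ckpt in work_dir
#   resume=True,  load_from="x.pth"  -> resume from "x.pth"
#   resume=False, load_from="x.pth"  -> load weights only
#   resume=False, load_from=None     -> no-op, returns None
extra_ckpt = strategy.load_or_resume(load_from=None, resume=True)
```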
- """ - from visengine.runner import find_latest_checkpoint - - if not resume and load_from is None: - return None - - # decide to load from checkpoint or resume from checkpoint - resume_from = None - if isinstance(resume, str): - resume_from = resume - elif resume and load_from is None: - # auto resume from the latest checkpoint - resume_from = find_latest_checkpoint(self._work_dir) - self.logger.info(f"Auto resumed from the latest checkpoint {resume_from}.") - elif resume and load_from is not None: - # resume from the specified checkpoint - resume_from = load_from - - if resume_from is not None: - return self.resume(resume_from) - elif load_from is not None: - return self.load_checkpoint(load_from) - - return None - - @abstractmethod - def load_checkpoint( - self, - filename: str, - *, - map_location: str | Callable = "cpu", - strict: bool = False, - revise_keys: list | None = None, - callback: Callable | None = None, - ) -> dict: - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - callback (callable, callable): Callback function to modify the - checkpoint after loading the checkpoint. - Defaults to None. - """ - - @abstractmethod - def resume( - self, - filename: str, - *, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: str | Callable = "default", - callback: Callable | None = None, - ) -> dict: - """Resume training from given ``filename``. - - Four types of states will be resumed. - - - model state - - optimizer state - - scheduler state - - randomness state - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. - map_location (str or callable):A string or a callable function to - specifying how to remap storage locations. - Defaults to 'default'. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. - """ - - @abstractmethod - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: dict | None = None, - callback: Callable | None = None, - ) -> None: - """Save checkpoint to given ``filename``. - - Args: - filename (str): Filename to save checkpoint. - - Keyword Args: - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - extra_ckpt (dict, optional): Extra checkpoint to save. - Defaults to None. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. 
- """ - - def collect_env(self) -> tuple[dict, dict]: - """Collect the information of the running environments.""" - system_env = collect_env() - runtime_env: OrderedDict = OrderedDict() - runtime_env.update(self._env_kwargs) - runtime_env.update(self.randomness) - runtime_env["Distributed launcher"] = self.launcher - runtime_env["Distributed training"] = self.distributed - runtime_env["GPU number"] = self.world_size - - return system_env, runtime_env - - def _prepared_components(self): - return_items = [self.model] - if hasattr(self, "optim_wrapper"): - return_items.append(self.optim_wrapper) - - if hasattr(self, "param_schedulers"): - return_items.append(self.param_schedulers) - - return return_items[0] if len(return_items) == 1 else return_items diff --git a/libs/visengine/visengine/_strategy/deepspeed.py b/libs/visengine/visengine/_strategy/deepspeed.py deleted file mode 100644 index 0f1fd3c..0000000 --- a/libs/visengine/visengine/_strategy/deepspeed.py +++ /dev/null @@ -1,575 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import json -import os.path as osp -import time -from collections.abc import Callable -from typing import Any - -import torch - -from visengine.logging import print_log - -try: - import deepspeed -except ImportError: - deepspeed = None - -import logging - -import torch.nn as nn - -import visengine -from visengine.dist import init_dist, is_main_process -from visengine.optim import BaseOptimWrapper, _ParamScheduler -from visengine.registry import MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS, STRATEGIES -from visengine.runner.checkpoint import save_checkpoint, weights_to_cpu -from visengine.utils import apply_to, digit_version, get_git_hash - -from .base import BaseStrategy - - -def register_deepspeed_optimizers() -> list[str]: - """Register optimizers in ``deepspeed`` to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. 
- """ - deepspeed_optimizers = [] - try: - import deepspeed # noqa: F401 - except ImportError: - pass - else: - from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam - from deepspeed.ops.lamb import FusedLamb - from deepspeed.runtime.fp16.onebit import OnebitAdam, OnebitLamb, ZeroOneAdam - - OPTIMIZERS.register_module(module=DeepSpeedCPUAdam) - deepspeed_optimizers.append("DeepSpeedCPUAdam") - OPTIMIZERS.register_module(module=FusedAdam) - deepspeed_optimizers.append("FusedAdam") - OPTIMIZERS.register_module(module=FusedLamb) - deepspeed_optimizers.append("FusedLamb") - OPTIMIZERS.register_module(module=OnebitAdam) - deepspeed_optimizers.append("OnebitAdam") - OPTIMIZERS.register_module(module=OnebitLamb) - deepspeed_optimizers.append("OnebitLamb") - OPTIMIZERS.register_module(module=ZeroOneAdam) - deepspeed_optimizers.append("ZeroOneAdam") - - return deepspeed_optimizers - - -@OPTIM_WRAPPERS.register_module(force=True) -class DeepSpeedOptimWrapper(BaseOptimWrapper): - def __init__(self, optimizer): - super().__init__(optimizer) - self._model = None - - @property - def model(self): - if self._model is None: - raise ValueError("model attribute should be set before accessing.") - return self._model - - @model.setter - def model(self, value): - self._model = value - - def update_params(self, loss) -> None: # type: ignore - """Update parameters in :attr:`optimizer`.""" - self.backward(loss) - self.step() - - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """ "Perform gradient back propagation.""" - self.model.backward(loss) - - def zero_grad(self, **kwargs) -> None: - raise NotImplementedError("DeepSpeedOptimWrapper does not support zero_grad method currently.") - - def step(self, **kwargs): - self.model.step() - - def state_dict(self) -> dict: - state_dict = {} - if self.base_param_settings is not None: - state_dict["base_param_settings"] = self.base_param_settings - - return state_dict - - def load_state_dict(self, state_dict: dict) -> None: - base_param_settings = state_dict.pop("base_param_settings", None) - - if base_param_settings is not None: - self.base_param_settings = base_param_settings - - -@MODEL_WRAPPERS.register_module(force=True) -class MMDeepSpeedEngineWrapper: - def __init__( - self, - *, - model: Any, # Use Any instead of 'deepspeed.DeepSpeedEngine' - inputs_to_half: list[int | str] | None = None, - ): - self.model = model - self._inputs_to_half = inputs_to_half - - def __getattr__(self, name): - return getattr(self.model, name) - - def train_step( - self, - data: dict | tuple | list, - optim_wrapper: DeepSpeedOptimWrapper, - ) -> dict[str, torch.Tensor]: - data = self.model.module.data_preprocessor(data, training=True) - data = self._cast_inputs_half(data) - losses = self._run_forward(data, mode="loss") - parsed_loss, log_vars = self.model.module.parse_losses(losses) - optim_wrapper.update_params(parsed_loss) - - return log_vars - - def val_step(self, data: dict | tuple | list) -> list: - """Gets the prediction of module during validation process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - data = self.model.module.data_preprocessor(data, False) - data = self._cast_inputs_half(data) - return self._run_forward(data, mode="predict") - - def test_step(self, data: dict | tuple | list) -> list: - """Gets the predictions of module during testing process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. 
- """ - data = self.model.module.data_preprocessor(data, False) - data = self._cast_inputs_half(data) - return self._run_forward(data, mode="predict") - - def _run_forward(self, data: dict | tuple | list, mode: str) -> dict | list: - """Unpacks data for :meth:`forward` - - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - - Returns: - dict or list: Results of training or testing mode. - """ - if isinstance(data, dict): - results = self.model(**data, mode=mode) - elif isinstance(data, list | tuple): - results = self.model(*data, mode=mode) - else: - raise TypeError(f"Output of `data_preprocessor` should be list, tuple or dict, but got {type(data)}") - return results - - def _cast_inputs_half(self, inputs: list | tuple | dict | None) -> list | tuple | dict: - """Cast inputs to half precision if needed. - - Args: - inputs (list or tuple or dict or None): Inputs to be casted. - - Returns: - list or tuple or dict: Casted inputs. - """ - if self._inputs_to_half is None or inputs is None: - # Return an empty dict instead of None to satisfy the type checker - return inputs if inputs is not None else {} - - dtype = next(self.model.parameters()).dtype - if isinstance(inputs, list | tuple): - new_inputs = [] - for i, v in enumerate(inputs): - if i in self._inputs_to_half: - new_inputs.append(apply_to(v, lambda x: hasattr(x, "to"), lambda x: x.to(dtype))) - else: - new_inputs.append(v) - return inputs.__class__(new_inputs) - elif isinstance(inputs, dict): - for k, v in inputs.items(): - if k in self._inputs_to_half: - inputs[k] = apply_to(v, lambda x: hasattr(x, "to"), lambda x: x.to(dtype)) - return inputs - else: - raise TypeError(f"inputs should be list, tuple or dict, but got {type(inputs)}") - - -@STRATEGIES.register_module(force=True) -class DeepSpeedStrategy(BaseStrategy): - """Support training models with DeepSpeed. - - Note: - The detailed usage of parameters can be found at - https://www.deepspeed.ai/docs/config-json/. - - Args: - config (str or dict, optional): If it is a string, it is a path to load - config for deepspeed. Defaults to None. - zero_optimization (dict, optional): Enabling and configuring ZeRO - memory optimizations. Defaults to None. - gradient_clipping (float, optional): Enable gradient clipping with - value. Defaults to None. - fp16 (dict, optional): Configuration for using mixed precision/FP16 - training that leverages NVIDIA's Apex package. Defaults to None. - inputs_to_half (list[int or str], optional): Which inputs are to - converted to half precision. Defaults to None. - If ``fp16`` is enabled, it also should be set. - bf16 (dict, optional): Configuration for using bfloat16 floating-point - format as an alternative to FP16. Defaults to None. - amp (dict, optional): Configuration for using automatic mixed - precision (AMP) training that leverages NVIDIA's Apex AMP package. - Defaults to None. - activation_checkpointing (dict, optional): Reduce memory usage by - clearing activations of certain layers and recomputing them - during a backward pass. - Defaults to None. - aio (dict, optional): Configuring the asynchronous I/O module for - offloading parameter and optimizer states to persistent (NVMe) - storage. This module uses Linux native asynchronous I/O (libaio). - Defaults to None. - train_micro_batch_size_per_gpu (int, optional): Batch size to be - processed by one GPU in one step (without gradient accumulation). - Defaults to None. 
- gradient_accumulation_steps (int, optional): Number of training steps - to accumulate gradients before averaging and applying them. - Defaults to None. - exclude_frozen_parameters (bool, optional): Exclude frozen parameters - from saved checkpoint. - """ - - def __init__( - self, - *, - # the following args are for deepspeed - config: str | dict | None = None, - zero_optimization: dict | None = None, - gradient_clipping: float | None = None, - fp16: dict | None = None, - inputs_to_half: list[int | str] | None = None, - bf16: dict | None = None, - amp: dict | None = None, - activation_checkpointing: dict | None = None, - aio: dict | None = None, - train_micro_batch_size_per_gpu: int | None = None, - gradient_accumulation_steps: int | None = None, - # disable the log printed by deepseed - steps_per_print: int = 10000000000000, - # the following args are for BaseStrategy - exclude_frozen_parameters: bool | None = None, - **kwargs, - ): - assert deepspeed is not None, ( - "DeepSpeed is not installed. Please check https://github.com/microsoft/DeepSpeed#installation." - ) - - super().__init__(**kwargs) - - self.config = self._parse_config(config) - if zero_optimization is not None: - self.config["zero_optimization"] = zero_optimization - if gradient_clipping is not None: - self.config["gradient_clipping"] = gradient_clipping - if fp16 is not None: - self.config["fp16"] = fp16 - if bf16 is not None: - self.config["bf16"] = bf16 - if amp is not None: - self.config["amp"] = amp - if activation_checkpointing is not None: - self.config["activation_checkpointing"] = activation_checkpointing - if aio is not None: - self.config["aio"] = aio - if train_micro_batch_size_per_gpu is not None: - self.config["train_micro_batch_size_per_gpu"] = train_micro_batch_size_per_gpu - if gradient_accumulation_steps is not None: - self.config["gradient_accumulation_steps"] = gradient_accumulation_steps - else: - self.config.setdefault("gradient_accumulation_steps", 1) - self.config["steps_per_print"] = steps_per_print - self._inputs_to_half = inputs_to_half - assert exclude_frozen_parameters is None or digit_version(deepspeed.__version__) >= digit_version("0.13.2"), ( - "DeepSpeed >= 0.13.2 is required to enable exclude_frozen_parameters" - ) - self.exclude_frozen_parameters = exclude_frozen_parameters - - register_deepspeed_optimizers() - - def _parse_config(self, config): - if config is None: - config = {} - elif isinstance(config, str): - with open(config) as f: - config = json.load(f) - return config - - def _setup_distributed( # type: ignore - self, - launcher: str | None = None, - backend: str = "nccl", - **kwargs, - ): - """Setup distributed environment. - - Args: - launcher (str, optional): Way to launch multi processes. - DeepSpeedStrategy does not support the launcher argument. - backend (str): Communication Backends. Supported backends are - 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. - **kwargs: Other arguments for :func:`deepspeed.init_distributed`. - """ - init_dist(launcher, backend, init_backend="deepspeed", **kwargs) - - def prepare( - self, - model: nn.Module | dict, - *, - optim_wrapper: BaseOptimWrapper | dict | None = None, - param_scheduler: _ParamScheduler | dict | list | None = None, - compile: dict | bool = False, - dispatch_kwargs: dict | None = None, - ): - """Prepare model and some components. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It - can be a dict used for build a model. 
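Collecting the keyword arguments above, a hypothetical ZeRO stage-2 fp16 configuration (values are illustrative, not defaults):

```python
strategy = dict(
    type="DeepSpeedStrategy",
    fp16=dict(enabled=True, initial_scale_power=16),
    inputs_to_half=[0],               # cast the first input when fp16 is on
    zero_optimization=dict(stage=2),
    gradient_clipping=1.0,
    train_micro_batch_size_per_gpu=4,
    gradient_accumulation_steps=2,
)
```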
- - Keyword Args: - optim_wrapper (BaseOptimWrapper or dict, optional): Computing the - gradient of model parameters and updating them. - Defaults to None. - See :meth:`build_optim_wrapper` for examples. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optim_wrapper` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - compile (dict, optional): Config to compile model. - Defaults to False. Requires PyTorch>=2.0. - dispatch_kwargs (dict, optional): Kwargs to be passed to other - methods of Strategy. Defaults to None. - """ - if self._prepared: - return self._prepared_components() - assert dispatch_kwargs is not None - self.dispatch_kwargs.update(dispatch_kwargs) - - model = self.build_model(model) - model = self._init_model_weights(model) - - if optim_wrapper is not None: - self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model) - self.model = self._wrap_model(model) - - self.optim_wrapper.model = self.model # type: ignore - - else: - self.model = self._wrap_model(model) - - if param_scheduler is not None: - self.param_schedulers = self.build_param_scheduler(param_scheduler, self.optim_wrapper) - self._prepared = True - # Store the components but don't return anything to match BaseStrategy - self._prepared_components() - # Return None to match the base class return type - return None - - def _wrap_model(self, model: nn.Module) -> nn.Module: - if hasattr(self, "optim_wrapper"): - engine, self.optim_wrapper.optimizer, *_ = deepspeed.initialize( - model=model, optimizer=self.optim_wrapper.optimizer, config=self.config - ) - else: - engine, *_ = deepspeed.initialize(model=model, config=self.config) - - wrapper = MMDeepSpeedEngineWrapper(model=engine, inputs_to_half=self._inputs_to_half) - return wrapper - - def load_checkpoint( - self, - filename: str, - *, - map_location: str | Callable = "cpu", - strict: bool = False, - revise_keys: list | None = None, - callback: Callable | None = None, - ) -> dict: - """Load checkpoint from given ``filename``. - - Warning: - `map_localtion` and `callback` parameters are not supported yet. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - """ - if revise_keys is None: - revise_keys = [(r"^module.", "")] - self.logger.info(f"Load checkpoint from {filename}") - - dirname, basename = osp.split(filename) - if digit_version(deepspeed.__version__) >= digit_version("0.13.2"): - _, extra_ckpt = self.model.load_checkpoint( - dirname, - tag=basename, - load_optimizer_states=False, - load_module_strict=not self.exclude_frozen_parameters, - ) - else: - _, extra_ckpt = self.model.load_checkpoint(dirname, tag=basename, load_optimizer_states=False) - - return extra_ckpt - - def resume( - self, - filename: str, - *, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: str | Callable = "default", - callback: Callable | None = None, - ) -> dict: - """Resume training from given ``filename``. - - Warning: - `map_location` and `callback` parameters are not supported yet. - - Args: - filename (str): Accept local filepath. - - Keyword Args: - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. 
- """ - self.logger.info(f"Resume checkpoint from {filename}") - - dirname, basename = osp.split(filename) - if digit_version(deepspeed.__version__) >= digit_version("0.13.2"): - _, extra_ckpt = self.model.load_checkpoint( - dirname, - tag=basename, - load_optimizer_states=resume_optimizer, - load_module_strict=not self.exclude_frozen_parameters, - ) - else: - _, extra_ckpt = self.model.load_checkpoint(dirname, tag=basename, load_optimizer_states=resume_optimizer) - - if resume_optimizer: - self.load_optim_state_dict(extra_ckpt.pop("optim_wrapper")) - - if resume_param_scheduler and hasattr(self, "param_schedulers"): - param_schedulers = extra_ckpt.pop("param_schedulers") - self.load_scheduler_state_dict(param_schedulers) - - # resume random seed - resumed_seed = extra_ckpt["meta"].get("seed", None) - current_seed = self._randomness.get("seed") - if resumed_seed is not None and resumed_seed != current_seed: - if current_seed is not None: - self.logger.warning( - f'The value of random seed in the checkpoint "{resumed_seed}" is different from the value in `randomness` config "{current_seed}"' - ) - self._randomness.update(seed=resumed_seed) - self._set_randomness(**self._randomness) - - return extra_ckpt - - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: dict | None = None, - callback: Callable | None = None, - ) -> None: - """Save checkpoint to given ``filename``. - - Warning: - `callback` parameter is not supported yet. - - Args: - filename (str): Filename to save checkpoint. - - Keyword Args: - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - extra_ckpt (dict, optional): Extra checkpoint to save. - Defaults to None. - """ - if extra_ckpt is None: - extra_ckpt = {} - if "meta" not in extra_ckpt: - extra_ckpt["meta"] = {} - extra_ckpt["meta"].update( - seed=self.seed, - time=time.strftime("%Y%m%d_%H%M%S", time.localtime()), - mmengine=mmengine.__version__ + get_git_hash(), - ) - - if save_param_scheduler and hasattr(self, "param_schedulers"): - extra_ckpt["param_schedulers"] = self.scheduler_state_dict() - - if ( - not save_optimizer - and self.model.zero_optimization_partition_weights() - and not self.model.zero_gather_16bit_weights_on_model_save() - ): - print_log( - "Configured to `save_optimizer=False`, but currently using " - "DeepSpeed's ZeRO stage 3 with " - "`gather_16bit_weights_on_model_save=False`. In " - "this configuration, the model cannot be saved properly " - "and will be saved with the optimizer state. " - "To support `save_optimizer=False`, please set " - "`gather_16bit_weights_on_model_save=True` in your " - "DeepSpeed config.", - logger="current", - level=logging.WARNING, - ) - save_optimizer = True - - state_dict_kwargs = {} - if digit_version(deepspeed.__version__) >= digit_version("0.13.2"): - state_dict_kwargs["exclude_frozen_parameters"] = self.exclude_frozen_parameters - - if save_optimizer: - if hasattr(self, "optim_wrapper"): - # The key can not be 'optimizer', otherwise error will be - # thrown when loading or resuming checkpoint. 
- extra_ckpt["optim_wrapper"] = self.optim_state_dict() - - dirname, basename = osp.split(filename) - self.model.save_checkpoint( - dirname, - tag=basename, - client_state=extra_ckpt, - save_latest=False, - **state_dict_kwargs, - ) - else: - if self.model.zero_optimization_partition_weights(): - state_dict = self.model._zero3_consolidated_16bit_state_dict(**state_dict_kwargs) - else: - state_dict = self.model.module_state_dict(**state_dict_kwargs) - - if is_main_process(): - ckpt = {"state_dict": weights_to_cpu(state_dict), **extra_ckpt} - save_checkpoint(ckpt, filename) diff --git a/libs/visengine/visengine/_strategy/distributed.py b/libs/visengine/visengine/_strategy/distributed.py deleted file mode 100644 index 9d4fb7b..0000000 --- a/libs/visengine/visengine/_strategy/distributed.py +++ /dev/null @@ -1,126 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import os -from collections.abc import Callable - -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel - -from visengine.device import get_device -from visengine.dist import init_dist, is_distributed, master_only -from visengine.registry import MODEL_WRAPPERS, STRATEGIES -from visengine.model import convert_sync_batchnorm, is_model_wrapper -from .single_device import SingleDeviceStrategy - - -@STRATEGIES.register_module(force=True) -class DDPStrategy(SingleDeviceStrategy): - """Distribution strategy for distributed data parallel training. - - Args: - model_wrapper (dict): Dict for model wrapper. Defaults to None. - sync_bn (str): Type of sync batch norm. Defaults to None. - Options are 'torch' and 'mmcv'. - **kwargs: Other arguments for :class:`BaseStrategy`. - """ - - def __init__( - self, - *, - model_wrapper: dict | None = None, - sync_bn: str | None = None, - **kwargs, - ): - super().__init__(**kwargs) - self.model_wrapper = model_wrapper - self.sync_bn = sync_bn - - def _setup_distributed( # type: ignore - self, - launcher: str = "pytorch", - backend: str = "nccl", - **kwargs, - ): - """Setup distributed environment. - - Args: - launcher (str): Way to launcher multi processes. Supported - launchers are 'pytorch', 'mpi' and 'slurm'. - backend (str): Communication Backends. Supported backends are - 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. - **kwargs: Other arguments for :func:`init_dist`. - """ - if not is_distributed(): - init_dist(launcher, backend, **kwargs) - - def convert_model(self, model: nn.Module) -> nn.Module: - """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm`` - (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers. - - Args: - model (nn.Module): Model to be converted. - - Returns: - nn.Module: Converted model. - """ - if self.sync_bn is not None: - try: - model = convert_sync_batchnorm(model, self.sync_bn) - except ValueError as e: - self.logger.error(f'cfg.sync_bn should be "torch" or "mmcv", but got {self.sync_bn}') - raise e - - return model - - def _wrap_model(self, model: nn.Module) -> DistributedDataParallel: - """Wrap the model to :obj:``MMDistributedDataParallel`` or other custom - distributed data-parallel module wrappers. - - Args: - model (nn.Module): Model to be wrapped. - - Returns: - nn.Module or DistributedDataParallel: nn.Module or subclass of - ``DistributedDataParallel``. 
- """ - if is_model_wrapper(model): - return model - - model = model.to(get_device()) - - model = self.convert_model(model) - - if self.model_wrapper is None: - # set broadcast_buffers as False to keep compatibility with - # OpenMMLab repos - self.model_wrapper = { - "type": "MMDistributedDataParallel", - "broadcast_buffers": False, - } - - default_args = { - "type": "MMDistributedDataParallel", - "module": model, - "device_ids": [int(os.environ["LOCAL_RANK"])], - } - model = MODEL_WRAPPERS.build(self.model_wrapper, default_args=default_args) - return model - - @master_only - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: dict | None = None, - callback: Callable | None = None, - ) -> None: - super().save_checkpoint( - filename=filename, - save_optimizer=save_optimizer, - save_param_scheduler=save_param_scheduler, - extra_ckpt=extra_ckpt, - callback=callback, - ) diff --git a/libs/visengine/visengine/_strategy/fsdp.py b/libs/visengine/visengine/_strategy/fsdp.py deleted file mode 100644 index c0443ed..0000000 --- a/libs/visengine/visengine/_strategy/fsdp.py +++ /dev/null @@ -1,646 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from __future__ import annotations - -import copy -import inspect -import os -import os.path as osp -import time -from collections import OrderedDict -from collections.abc import Callable, Sequence -from functools import partial -from typing import TYPE_CHECKING - -import torch.nn as nn -from torch.distributed.fsdp import ( - FullStateDictConfig, - FullyShardedDataParallel, - LocalStateDictConfig, - StateDictType, -) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullOptimStateDictConfig, - LocalOptimStateDictConfig, - OptimStateDictConfig, - StateDictConfig, -) -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler - -import visengine -from visengine.config import Config, ConfigDict -from visengine.device import get_device -from visengine.dist import get_rank, is_main_process -from visengine.model import is_model_wrapper - -if TYPE_CHECKING: - from visengine.model import BaseDataPreprocessor - from visengine.optim import ( - AmpOptimWrapper, - BaseOptimWrapper, - OptimWrapper, - OptimWrapperDict, - _ParamScheduler, - build_optim_wrapper, - ) -from visengine.registry import ( - FUNCTIONS, - MODEL_WRAPPERS, - OPTIM_WRAPPERS, - PARAM_SCHEDULERS, - STRATEGIES, - Registry, -) -from visengine.utils import get_git_hash, mkdir_or_exist - -from .distributed import DDPStrategy -from .utils import MetaTensorContext - -FSDP = FullyShardedDataParallel -FSDP_CONFIGS = Registry("fsdp configs") -FSDP_CONFIGS.register_module(module=FullOptimStateDictConfig) -FSDP_CONFIGS.register_module(module=LocalOptimStateDictConfig) -FSDP_CONFIGS.register_module(module=FullStateDictConfig) -FSDP_CONFIGS.register_module(module=LocalStateDictConfig) - - -@STRATEGIES.register_module(force=True) -class FSDPStrategy(DDPStrategy): - """Support training model with FullyShardedDataParallel (FSDP). - - Keyword Args: - model_wrapper (dict, optional): Config dict for model wrapper. The - default configuration is: - - Examples: - >>> model_wrapper = dict( - >>> type='MMFullyShardedDataParallel', - >>> use_orig_params=True, - >>> ) - - See more configurable arguments in - :class:`MMFullyShardedDataParallel`. Defaults to None - skip_init_weights (bool, optional): Whether to skip initialization of - weights. Defaults to False. 
-            the large model are loaded from a checkpoint, since skipping the
-            initialization of weights can save a lot of time.
-        state_dict_cfg (str or dict): Configuration for
-            how to save and load the state dict of the model, optimizer, and
-            scheduler.
-
-            - "local": save and load the sharded state dict in all ranks.
-            - "full": save and load the full state dict in rank 0.
-            - `dict` object: save and load the state dict more flexibly. For
-              example, you can first offload the state dict to the 'cpu' and
-              then save it to the disk. This can help you to load the
-              checkpoint in a non-gpu environment:
-
-              Examples:
-                  >>> state_dict_cfg=dict(
-                  >>>     state_dict_type='FULL_STATE_DICT',
-                  >>>     state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
-                  >>>     optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
-                  >>> )
-
-              See more configurable arguments for ``state_dict_cfg``,
-              ``state_dict_config``, and ``optim_state_dict_config`` in
-              `FSDP official api documents`_
-        kwargs (dict): Additional arguments passed to :class:`DDPStrategy`:
-
-            - work_dir (str): The working directory to save checkpoints.
-              The logs will be saved in the subdirectory of `work_dir` named
-              :attr:`timestamp`. Defaults to 'work_dirs'.
-            - experiment_name (str, optional): Name of current experiment. If
-              not specified, timestamp will be used as :attr:`experiment_name`.
-              Defaults to None.
-            - env_kwargs (dict, optional): Environment config passed in
-              :meth:`setup_env`. Defaults to None.
-            - log_kwargs (dict, optional): Logger config passed in
-              :meth:`build_logger`. Defaults to None.
-        activation_checkpointing (dict, optional): Config dict for gradient
-            checkpoint.
-
-            Examples:
-                >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
-                >>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))
-
-            - ``check_fn`` field should behave consistently with
-              ``auto_wrap_policy`` defined in `model_wrapper`, and other
-              fields will be passed to ``apply_activation_checkpointing``
-
-            `New in version 0.9.0.`
-
-    .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
-    """
-
-    def __init__(
-        self,
-        *,
-        model_wrapper: dict | None = None,
-        skip_init_weights=False,
-        state_dict_cfg: str | dict = "local",
-        activation_checkpointing: dict | None = None,
-        **kwargs,
-    ):
-        super().__init__(model_wrapper=model_wrapper, **kwargs)
-        self._init_state_dict_cfg(state_dict_cfg)
-        if not isinstance(skip_init_weights, bool):
-            raise TypeError(f"skip_init_weights must be a boolean, but got {type(skip_init_weights)}")
-        self.skip_init_weights = skip_init_weights
-        self.activation_checkpointing = activation_checkpointing
-
-    def _wrap_model(self, model: nn.Module) -> nn.Module:
-        """Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
-        custom fully sharded data parallel module wrappers.
-
-        Args:
-            model (nn.Module): Model to be wrapped.
-
-        Returns:
-            FullyShardedDataParallel: ``MMFullyShardedDataParallel``
-            or subclass of ``FullyShardedDataParallel``.
- """ - try: - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - apply_activation_checkpointing, - ) - except ImportError: - apply_activation_checkpointing = None - - for module in model.modules(): - if isinstance(module, BaseDataPreprocessor): - module.to(get_device()) - - if is_model_wrapper(model): - return - - if self.model_wrapper is None: - self.model_wrapper = {"type": "MMFullyShardedDataParallel"} - - default_args = { - "module": model, - "device_id": int(os.environ["LOCAL_RANK"]), - "type": "MMFullyShardedDataParallel", - } - model = MODEL_WRAPPERS.build(self.model_wrapper, default_args=default_args) - model.set_state_dict_type( - model, - self.state_dict_type, - self.state_dict_config, - self.optim_state_dict_config, - ) - - if self.activation_checkpointing is not None: - if apply_activation_checkpointing is None: - raise RuntimeError( - "activation_checkpointing maybe deprecated by current " - "PyTorch version, maybe you could switch to PyTorch 2.0 " - "or 2.1 to use `activation_checkpointing`." - ) - cfg = copy.deepcopy(self.activation_checkpointing) - with FUNCTIONS.switch_scope_and_registry(None): - check_fn = cfg.pop("check_fn") - if isinstance(check_fn, str): - check_fn = FUNCTIONS.get(check_fn) - elif isinstance(check_fn, dict): - fn_type = check_fn.pop("type") - if isinstance(fn_type, str): - fn_type = FUNCTIONS.get(fn_type) - check_fn = partial(fn_type, **cfg) - - if not callable(check_fn): - raise TypeError("`check_fn` must be a callable function") - apply_activation_checkpointing(model, check_fn=check_fn, **cfg) - return model - - def _is_full_state_dict(self): - """Whether to save and load the full state_dict in rank 0.""" - return self.state_dict_type == StateDictType.FULL_STATE_DICT - - def build_model(self, model: nn.Module | dict) -> nn.Module: - """Build model. - - If skip_init_weights is True, the model will be built with an empty - weights. It means that :meth:`load_checkpoint` must be called to fill - the weights before training. - - Args: - model (nn.Module or dict): A ``nn.Module`` object or a dict to - build ``nn.Module`` object. If ``model`` is a ``nn.Module`` - object, just returns itself. - - Returns: - nn.Module: Model build from ``model``. - """ - if self.skip_init_weights: - if isinstance(model, dict): - # Accelerate initialization by skipping init weights - with MetaTensorContext(): - model = super().build_model(model) - model.to_empty(device="cpu") - else: - model = super().build_model(model) - - # `id_to_name` will be used to convert the `optim_state_dict` of the - # raw optimizer to the `optim_state_dict` - # returned by `FSDP.optim_state_dict` in - # `StateDictType.FULL_STATE_DICT` mode. - self.id_to_name = {} - for name, param in model.named_parameters(): - self.id_to_name[id(param)] = name - return model - - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: dict | None = None, - callback: Callable | None = None, - ) -> None: - """Save checkpoint to given ``filename``. - - If ``state_dict_type`` is `full`, the checkpoint will only be saved in - rank0. The structure of the saved checkpoint is the same as the one - saved by ``DDPStrategy`` - - If ``state_dict_type`` is `local`, each rank will save the sharded - state dict to a directory, which means the saved structure will look - like this: - - .. code-block:: bash - - ── epoch_0.pth - ├── rank0.pth - ├── rank1.pth - ├── ... 
-                └── rank8.pth
-
-        Args:
-            filename (str): Filename to save checkpoint.
-
-        Keyword Args:
-            save_optimizer (bool): Whether to save the optimizer to
-                the checkpoint. Defaults to True.
-            save_param_scheduler (bool): Whether to save the param_scheduler
-                to the checkpoint. Defaults to True.
-            extra_ckpt (dict, optional): Extra checkpoint to save.
-                Defaults to None.
-            callback (callable, optional): Callback function to modify the
-                checkpoint before saving the checkpoint.
-                Defaults to None.
-        """
-        from visengine.runner.checkpoint import save_checkpoint
-
-        state_dict: dict = {}
-        state_dict["state_dict"] = self.model_state_dict()
-
-        # save optimizer state dict
-        if save_optimizer and hasattr(self, "optim_wrapper"):
-            state_dict["optimizer"] = self.optim_state_dict()
-
-        # save param scheduler state dict
-        if save_param_scheduler and hasattr(self, "param_schedulers"):
-            state_dict["param_schedulers"] = self.scheduler_state_dict()
-
-        # save extra checkpoint passed by users
-        if extra_ckpt is None:
-            extra_ckpt = {}
-        if "meta" not in extra_ckpt:
-            extra_ckpt["meta"] = {}
-
-        extra_ckpt["meta"].update(
-            seed=self.seed,
-            time=time.strftime("%Y%m%d_%H%M%S", time.localtime()),
-            visengine=visengine.__version__ + get_git_hash(),
-        )
-        state_dict.update(extra_ckpt)
-
-        # users can do some modification before saving checkpoint
-        if callback is not None:
-            callback(state_dict)
-
-        # In non-FULL_STATE_DICT mode, FSDPStrategy will save checkpoints
-        # of different ranks in different files.
-        if not self._is_full_state_dict():
-            rank = get_rank()
-            mkdir_or_exist(filename)
-            ckpt_name = f"rank{rank}.pth"
-            filename = osp.join(filename, ckpt_name)
-            save_checkpoint(state_dict, filename)
-
-        if is_main_process():
-            save_checkpoint(state_dict, filename)
-
-    def model_state_dict(self) -> dict:
-        """Get model state dict based on the ``state_dict_type``.
-
-        If ``state_dict_type`` is `full`, the model state dict will be the
-        same as the one of the original unsharded model.
-
-        If ``state_dict_type`` is ``local``, and ``use_orig_params`` is ``True``
-        in ``model_wrapper``, the keys of the state dict will be the same as
-        the ones of the original unsharded model, but the values will be the
-        sharded ones.
-
-        If ``state_dict_type`` is `local`, and ``use_orig_params`` is
-        ``False`` in ``model_wrapper``, the flattened and sharded state dict
-        will be returned.
-
-        See more details in the `official api documents`_
-
-        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
-        """
-        # We've set state_dict by `FSDP.set_state_dict_type`, therefore we
-        # should get model state dict by `FSDP.state_dict`
-        return self.model.state_dict()
-
-    def optim_state_dict(self) -> dict:
-        """Get optimizer state dict based on the ``state_dict_type``.
-
-        If ``state_dict_type`` is ``full``, the optimizer state dict can be
-        loaded by the original unsharded optimizer.
-
-        Otherwise, the optimizer state dict could only be loaded by the
-        optimizer with sharded parameters.
-
-        Note:
-            The optimizer state dict is not the same as the one of the
-            original optimizer even in ``full`` mode, although they can be
-            loaded correctly.
-
-        See more details in the `official api documents`_
-
-        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
-        """
-        return FSDP.optim_state_dict(self.model, self.optim_wrapper)
-
-    def load_checkpoint(self, filename: str, **kwargs) -> dict:
-        """Load checkpoint from given ``filename``.
-
-        Note:
-            If ``state_dict_type`` is `local`, the filename should be a
-            directory containing ``rank{i}.pth``.
-
-        Args:
-            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
-                ``open-mmlab://xxx``.
-
-        Keyword Args:
-            map_location (str or callable): A string or a callable function
-                specifying how to remap storage locations.
-                Defaults to 'cpu'.
-            strict (bool): Whether to allow different params for the model and
-                checkpoint.
-            revise_keys (list): A list of customized keywords to modify the
-                state_dict in checkpoint. Each item is a (pattern, replacement)
-                pair of the regular expression operations. Defaults to strip
-                the prefix 'module.' by [(r'^module\\.', '')].
-            callback (callable, optional): Callback function to modify the
-                checkpoint after loading the checkpoint.
-                Defaults to None.
-        """
-        if self._is_full_state_dict():
-            return super(DDPStrategy, self).load_checkpoint(filename, **kwargs)
-        else:
-            rank = get_rank()
-            filename = osp.join(filename, f"rank{rank}.pth")
-            return super(DDPStrategy, self).load_checkpoint(filename, **kwargs)
-
-    def load_model_state_dict(
-        self,
-        state_dict: dict,
-        *,
-        strict: bool = False,
-        revise_keys: list | None = None,
-    ) -> None:  # type: ignore
-        """Load model state from dict.
-
-        Warning:
-            `revise_keys` is not supported yet.
-
-        Args:
-            state_dict (dict): Model state dict returned by
-                :meth:`FSDPStrategy.model_state_dict`. If ``state_dict_type``
-                is ``full``, ``state_dict`` could be the result of
-                ``model.state_dict()``
-            strict (bool): Whether to load model state dict strictly.
-                Defaults to False.
-        """
-        # We should load state dict by `FSDP.load_state_dict`
-        if revise_keys is None:
-            revise_keys = [(r"^module.", "")]
-        self.model.load_state_dict(state_dict, strict=strict)
-
-    def load_optim_state_dict(self, state_dict: dict) -> None:
-        """Load optimizer state from dict.
-
-        Args:
-            state_dict (dict): The optimizer state dict. If ``state_dict_type``
-                is ``full``,
``state_dict`` could be the result of - ``optimizer.state_dict()`` - """ - optim_state_dict = FSDP.optim_state_dict_to_load(state_dict, self.model, self.optim_wrapper.optimizer) - self.optim_wrapper.load_state_dict(optim_state_dict) - - def _init_state_dict_cfg(self, state_dict_cfg: str | dict) -> None: - """Make ``state_dict_type`` and ``state_dict_config`` can be configured - with string.""" - if isinstance(state_dict_cfg, str): - if state_dict_cfg == "full": - self.state_dict_type = StateDictType.FULL_STATE_DICT - self.state_dict_config = FullStateDictConfig(rank0_only=True, offload_to_cpu=True) - self.optim_state_dict_config = FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True) - elif state_dict_cfg == "local": - self.state_dict_type = StateDictType.LOCAL_STATE_DICT - self.state_dict_config = LocalStateDictConfig() - self.optim_state_dict_config = LocalOptimStateDictConfig() - else: - raise ValueError(f"FSDP only supports `full` and `local` state_dict_type, but got {state_dict_cfg}") - elif isinstance(state_dict_cfg, dict): - if "state_dict_type" not in state_dict_cfg: - self.state_dict_type = StateDictType.LOCAL_STATE_DICT - else: - state_dict_type = state_dict_cfg["state_dict_type"] - if isinstance(state_dict_type, str): - self.state_dict_type = StateDictType[state_dict_cfg["state_dict_type"]] - else: - self.state_dict_type = state_dict_type - state_dict_config = state_dict_cfg.get("state_dict_config") - if state_dict_config is None: - self.state_dict_config = LocalStateDictConfig() - elif isinstance(state_dict_config, dict): - self.state_dict_config = FSDP_CONFIGS.build(state_dict_cfg["state_dict_config"]) - else: - self.state_dict_config = state_dict_config - - optim_state_dict_config = state_dict_cfg.get("optim_state_dict_config") - if optim_state_dict_config is None: - self.optim_state_dict_config = LocalOptimStateDictConfig() - elif isinstance(optim_state_dict_config, dict): - self.optim_state_dict_config = FSDP_CONFIGS.build(state_dict_cfg["optim_state_dict_config"]) - else: - self.optim_state_dict_config = optim_state_dict_config - else: - raise TypeError(f"state_dict_cfg should be a `str` or a `dict`, but got {type(state_dict_cfg)}") - - if not isinstance(self.state_dict_type, StateDictType): - raise TypeError(f"state_dict_type must be StateDictType, but got {type(self.state_dict_type)}") - if not isinstance(self.state_dict_config, StateDictConfig): - raise TypeError(f"state_dict_config must be StateDictConfig, but got {type(self.state_dict_config)}") - if not isinstance(self.optim_state_dict_config, OptimStateDictConfig): - raise TypeError( - f"optim_state_dict_config must be OptimStateDictConfig, but got {type(self.optim_state_dict_config)}" - ) - - def build_optim_wrapper( - self, - optim_wrapper: Optimizer | OptimWrapper | dict, - model: nn.Module | None = None, - ) -> BaseOptimWrapper: - """Support sharding the optimizer state dict given a built optimizer or - optim_wrapper. - - See specific usage in :meth:`BaseStrategy.build_optim_wrapper`. 
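-
-        Examples:
-            >>> # A minimal sketch, not from the original code: `model` and
-            >>> # `strategy` are assumed to exist, with `strategy.build_model`
-            >>> # already called so that `id_to_name` is populated.
-            >>> from torch.optim import SGD
-            >>> optimizer = SGD(model.parameters(), lr=0.01)
-            >>> optim_wrapper = strategy.build_optim_wrapper(optimizer, model)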
- """ - if isinstance(optim_wrapper, Optimizer): - optim_wrapper = OptimWrapper(optim_wrapper) - if isinstance(optim_wrapper, BaseOptimWrapper): - assert model is not None - # NOTE: The only difference is that FSDPStrategy will shard - # the the built OptimWrapper - optimizer = optim_wrapper.optimizer - param_groups = optimizer.param_groups - optim_state_dict = optimizer.state_dict() - assert not optim_state_dict["state"], ( - "Optimizer state_dict should be empty when giving an built optim_wrapper to FSDPStrategy" - ) - # Align the state_dict with state_dict generated by - # FSDP.full_optim_state_dict - new_param_groups = [] - for group in param_groups: - new_group = {key: value for key, value in group.items() if key != "param"} - new_group["params"] = [self.id_to_name[id(param)] for param in group["params"]] - new_param_groups.append(new_group) - optim_state_dict["param_groups"] = new_param_groups - defaults = {k: v for k, v in optimizer.defaults.items() if k != "differentiable"} - - params_dict = {} - for k, v in model.named_parameters(): - if "_fsdp_wrapped_module" in k: - k = k.replace("_fsdp_wrapped_module.", "") - params_dict[k] = v - - params = [] - for param_group in new_param_groups: - _params = [] - for param_name in param_group["params"]: - if param_name not in params_dict: - raise RuntimeError( - "Failed to reconstruct the sharded optimizer. You can try to set `use_orig_params=True` in `model_wrapper`" - ) - _params.append(params_dict[param_name]) - param_group = {k: v for k, v in param_group.items() if k != "param"} - param_group["params"] = _params - params.append(param_group) - - new_optimizer = optimizer.__class__(params, **defaults) - - # Force to load the converted optim_state_dict in full mode. - with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): - optim_state_dict = FSDP.optim_state_dict_to_load(optim_state_dict, model, new_optimizer) - new_optimizer.load_state_dict(optim_state_dict) - optim_wrapper.optimizer = new_optimizer - return optim_wrapper - if isinstance(optim_wrapper, dict | ConfigDict | Config): - assert model is not None - # optimizer must be defined for single optimizer training. - optimizer = optim_wrapper.get("optimizer", None) - optim_wrapper.setdefault("type", "OptimWrapper") - if optim_wrapper.get("type", "AmpOptimWrapper") in ( - "AmpOptimWrapper", - AmpOptimWrapper, - ): - optim_wrapper.setdefault("use_fsdp", True) - - # If optimizer is a built `Optimizer` instance, the optimizer - # wrapper should be built by `OPTIM_WRAPPERS` registry. - if isinstance(optimizer, Optimizer): - return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore - - # If `optimizer` is not None or `constructor` is defined, it means, - # optimizer wrapper will be built by optimizer wrapper - # constructor. Therefore, `build_optim_wrapper` should be called. - if optimizer is not None or "constructor" in optim_wrapper: - return build_optim_wrapper(model, optim_wrapper) - else: - # if `optimizer` is not defined, it should be the case of - # training with multiple optimizers. If `constructor` is not - # defined either, each value of `optim_wrapper` must be an - # `OptimWrapper` instance since `DefaultOptimizerConstructor` - # will not handle the case of training with multiple - # optimizers. 
-                # We then directly build the `OptimWrapperDict` instance
-                # from `optim_wrapper`.
-                optim_wrappers = OrderedDict()
-                for name, optim in optim_wrapper.items():
-                    if not isinstance(optim, OptimWrapper):
-                        raise ValueError(
-                            f'each item must be an optimizer object when "type" and "constructor" are not in optimizer, but got {name}={optim}'
-                        )
-                    optim_wrappers[name] = optim
-                return OptimWrapperDict(**optim_wrappers)
-        else:
-            raise TypeError(f"optimizer wrapper should be an OptimWrapper object or dict, but got {optim_wrapper}")
-
-    def _build_param_scheduler(
-        self,
-        scheduler: _ParamScheduler | dict | list,
-        optim_wrapper: BaseOptimWrapper,
-        default_args: dict,
-    ) -> list[_ParamScheduler]:
-        """Override this method to update the scheduler with the reconstructed
-        sharded optimizer."""
-        if not isinstance(scheduler, Sequence):
-            schedulers = [scheduler]
-        else:
-            schedulers = scheduler
-
-        max_epochs = default_args.pop("max_epochs", None)
-        max_iters = default_args.pop("max_iters", None)
-
-        param_schedulers = []
-        for scheduler in schedulers:
-            # Update the built scheduler with the sharded optimizer
-            if isinstance(scheduler, _ParamScheduler | LRScheduler):
-                parameter_keys = inspect.signature(scheduler.__class__).parameters.keys()
-                kwargs = {k: v for k, v in scheduler.state_dict().items() if k in parameter_keys}
-                scheduler = scheduler.__class__(optim_wrapper, **kwargs)
-                # Keep the rebuilt scheduler; otherwise it would be discarded.
-                param_schedulers.append(scheduler)
-            elif isinstance(scheduler, dict):
-                _scheduler = copy.deepcopy(scheduler)
-
-                # Set default end
-                if _scheduler.get("by_epoch", True):
-                    if max_epochs is None:
-                        raise ValueError("max_epochs must be specified in default_args")
-                    default_end = max_epochs
-                else:
-                    if max_iters is None:
-                        raise ValueError("max_iters must be specified in default_args")
-                    default_end = max_iters
-                _scheduler.setdefault("end", default_end)
-                self.logger.debug(
-                    f"The `end` of {_scheduler['type']} is not set. Use the max epochs/iters of train loop as default."
-                )
-
-                param_schedulers.append(
-                    PARAM_SCHEDULERS.build(
-                        _scheduler,
-                        default_args=dict(optimizer=optim_wrapper, **default_args),
-                    )
-                )
-            else:
-                raise TypeError(f"scheduler should be a _ParamScheduler object or dict, but got {scheduler}")
-        return param_schedulers
diff --git a/libs/visengine/visengine/_strategy/single_device.py b/libs/visengine/visengine/_strategy/single_device.py
deleted file mode 100644
index 90894bb..0000000
--- a/libs/visengine/visengine/_strategy/single_device.py
+++ /dev/null
@@ -1,287 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import time
-from collections.abc import Callable
-
-import torch.nn as nn
-
-import visengine
-from visengine.device import get_device
-from visengine.model import revert_sync_batchnorm
-from visengine.optim import BaseOptimWrapper, _ParamScheduler
-from visengine.registry import STRATEGIES
-from visengine.utils import get_git_hash
-
-from visengine._strategy.base import BaseStrategy
-
-
-@STRATEGIES.register_module(force=True)
-class SingleDeviceStrategy(BaseStrategy):
-    """Strategy for single device training."""
-
-    def prepare(
-        self,
-        model: nn.Module | dict,
-        *,
-        optim_wrapper: BaseOptimWrapper | dict | None = None,
-        param_scheduler: _ParamScheduler | dict | list | None = None,
-        compile: dict | bool = False,
-        dispatch_kwargs: dict | None = None,
-    ):
-        """Prepare model and some components.
-
-        Args:
-            model (:obj:`torch.nn.Module` or dict): The model to be run. It
-                can be a dict used to build a model.
- - Keyword Args: - optim_wrapper (BaseOptimWrapper or dict, optional): Computing the - gradient of model parameters and updating them. - Defaults to None. - See :meth:`build_optim_wrapper` for examples. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optim_wrapper` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - compile (dict, optional): Config to compile model. - Defaults to False. Requires PyTorch>=2.0. - dispatch_kwargs (dict, optional): Kwargs to be passed to other - methods of Strategy. Defaults to None. - If ``accumulative_counts`` is set in ``optim_wrapper``, you - need to provide ``max_iters`` in ``dispatch_kwargs``. - """ - if self._prepared: - return self._prepared_components() - if dispatch_kwargs is not None: - self.dispatch_kwargs.update(dispatch_kwargs) - - model = self.build_model(model) - model = self._init_model_weights(model) - model = self._wrap_model(model) - model = self.compile_model(model, compile=compile) - - self.model = model - - if optim_wrapper is not None: - self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model) - self._scale_lr() - - accumulative_counts = getattr(self.optim_wrapper, "_accumulative_counts", 1) - if accumulative_counts > 1: - if "max_iters" not in self.dispatch_kwargs: - raise ValueError( - f'"max_iters" must be specified because "accumulative_counts" was set as {accumulative_counts} which is greater than 1.' - ) - - self.optim_wrapper.initialize_count_status( # type: ignore - self.model, 0, self.dispatch_kwargs["max_iters"] - ) - - if param_scheduler is not None: - self.param_schedulers = self.build_param_scheduler(param_scheduler, self.optim_wrapper) - - self._prepared = True - return self._prepared_components() - - def _wrap_model(self, model: nn.Module) -> nn.Module: - model = self.convert_model(model) - current_device = get_device() - return model.to(current_device) - - def convert_model(self, model: nn.Module) -> nn.Module: - """Convert layers of model. - - convert all ``SyncBatchNorm`` (SyncBN) and - ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers in the model to - ``BatchNormXd`` layers. - - Args: - model (nn.Module): Model to convert. - """ - self.logger.info( - "Distributed training is not used, all SyncBatchNorm (SyncBN) " - "layers in the model will be automatically reverted to " - "BatchNormXd layers if they are used." - ) - model = revert_sync_batchnorm(model) - return model - - def load_checkpoint( - self, - filename: str, - *, - map_location: str | Callable = "cpu", - strict: bool = False, - revise_keys: list | None = None, - callback: Callable | None = None, - ) -> dict: - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - callback (callable, callable): Callback function to modify the - checkpoint after loading the checkpoint. - Defaults to None. 
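-
-        Examples:
-            >>> # Illustrative sketch only: the checkpoint path is a
-            >>> # placeholder and `strategy` is an already-prepared
-            >>> # SingleDeviceStrategy instance.
-            >>> checkpoint = strategy.load_checkpoint('work_dirs/epoch_12.pth')
-            >>> # 'state_dict' has been popped and loaded into the model; the
-            >>> # remaining keys (e.g. 'meta') are returned unchanged.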
- """ - from visengine.runner.checkpoint import _load_checkpoint - - if revise_keys is None: - revise_keys = [(r"^module.", "")] - self.logger.info(f"Load checkpoint from {filename}") - - if map_location == "default": - device = get_device() - checkpoint = _load_checkpoint(filename, map_location=device) - else: - checkpoint = _load_checkpoint(filename, map_location=map_location) - - # users can do some modification after loading checkpoint - if callback is not None: - callback(checkpoint) - - state_dict = checkpoint.pop("state_dict") - self.load_model_state_dict(state_dict, strict=strict, revise_keys=revise_keys) - - return checkpoint - - def resume( - self, - filename: str, - *, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: str | Callable = "default", - callback: Callable | None = None, - ) -> dict: - """Resume training from given ``filename``. - - Four types of states will be resumed. - - - model state - - optimizer state - - scheduler state - - randomness state - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - - Keyword Args: - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. - map_location (str or callable):A string or a callable function to - specifying how to remap storage locations. - Defaults to 'default'. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. - """ - self.logger.info(f"Resume checkpoint from {filename}") - - checkpoint = self.load_checkpoint(filename, map_location=map_location, callback=callback) - - if resume_optimizer: - self.load_optim_state_dict(checkpoint.pop("optimizer")) - - if resume_param_scheduler and hasattr(self, "param_schedulers"): - self.load_scheduler_state_dict(checkpoint.pop("param_schedulers")) - - # resume random seed - resumed_seed = checkpoint["meta"].get("seed", None) - current_seed = self._randomness.get("seed") - if resumed_seed is not None and resumed_seed != current_seed: - if current_seed is not None: - self.logger.warning( - f'The value of random seed in the checkpoint "{resumed_seed}" is different from the value in `randomness` config "{current_seed}"' - ) - self._randomness.update(seed=resumed_seed) - self._set_randomness(**self._randomness) - - # resume iter - cur_iter = checkpoint["meta"]["iter"] - - if hasattr(self, "optim_wrapper"): - accumulative_counts = getattr(self.optim_wrapper, "_accumulative_counts", 1) - if accumulative_counts > 1: - if "max_iters" not in self.dispatch_kwargs: - raise ValueError( - f'"max_iters" must be specified because "accumulative_counts" was set as {accumulative_counts} which is greater than 1.' - ) - # Initiate inner count of `optim_wrapper`. - self.optim_wrapper.initialize_count_status( # type: ignore - self.model, cur_iter, self.dispatch_kwargs["max_iters"] - ) - - return checkpoint - - def save_checkpoint( - self, - filename: str, - *, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - extra_ckpt: dict | None = None, - callback: Callable | None = None, - ) -> None: - """Save checkpoint to given ``filename``. - - Args: - filename (str): Filename to save checkpoint. - - Keyword Args: - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. 
- extra_ckpt (dict, optional): Extra checkpoint to save. - Defaults to None. - callback (callable, callable): Callback function to modify the - checkpoint before saving the checkpoint. - Defaults to None. - """ - from visengine.runner.checkpoint import save_checkpoint - - state_dict: dict = {} - state_dict["state_dict"] = self.model_state_dict() - - # save optimizer state dict - if save_optimizer and hasattr(self, "optim_wrapper"): - state_dict["optimizer"] = self.optim_state_dict() - - if save_param_scheduler and hasattr(self, "param_schedulers"): - state_dict["param_schedulers"] = self.scheduler_state_dict() - - # save extra checkpoint passed by users - if extra_ckpt is None: - extra_ckpt = {} - if "meta" not in extra_ckpt: - extra_ckpt["meta"] = {} - extra_ckpt["meta"].update( - seed=self.seed, - time=time.strftime("%Y%m%d_%H%M%S", time.localtime()), - visengine=visengine.__version__ + get_git_hash(), - ) - - state_dict.update(extra_ckpt) - - # users can do some modification before saving checkpoint - if callback is not None: - callback(state_dict) - - save_checkpoint(state_dict, filename) diff --git a/libs/visengine/visengine/_strategy/utils.py b/libs/visengine/visengine/_strategy/utils.py deleted file mode 100644 index 0da0c43..0000000 --- a/libs/visengine/visengine/_strategy/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from torch._subclasses.fake_tensor import _is_tensor_constructor -from torch.utils._python_dispatch import TorchDispatchMode - - -class MetaTensorContext(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if _is_tensor_constructor(func): - device_idx = [arg.name for arg in func._schema.arguments].index("device") - if len(args) > device_idx: - args = list(args) - args[device_idx] = "meta" - else: - kwargs["device"] = "meta" - return func(*args, **kwargs) diff --git a/libs/visengine/visengine/config/__init__.py b/libs/visengine/visengine/config/__init__.py deleted file mode 100644 index 43f0264..0000000 --- a/libs/visengine/visengine/config/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .config import Config, ConfigDict, DictAction, read_base - -__all__ = ["Config", "ConfigDict", "DictAction", "read_base"] diff --git a/libs/visengine/visengine/config/config.py b/libs/visengine/visengine/config/config.py deleted file mode 100644 index 96f8a53..0000000 --- a/libs/visengine/visengine/config/config.py +++ /dev/null @@ -1,1792 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
-import ast -import copy -import difflib -import os -import os.path as osp -import platform -import shutil -import sys -import tempfile -import types -import uuid -import warnings -from argparse import Action, ArgumentParser, Namespace -from collections import OrderedDict, abc -from collections.abc import Sequence -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Union - -import yapf -from addict import Dict -from rich.console import Console -from rich.text import Text -from yapf.yapflib.yapf_api import FormatCode - -from visengine.fileio import dump, load -from visengine.logging import print_log -from visengine.utils import ( - check_file_exist, - digit_version, - get_installed_path, - import_modules_from_strings, - is_installed, -) - -from .lazy import LazyAttr, LazyObject -from .utils import ( - ConfigParsingError, - ImportTransformer, - RemoveAssignFromAST, - _gather_abs_import_lazyobj, - _get_external_cfg_base_path, - _get_external_cfg_path, - _get_package_and_cfg_path, - _is_builtin_module, -) - -BASE_KEY = "_base_" -DELETE_KEY = "_delete_" -DEPRECATION_KEY = "_deprecation_" -RESERVED_KEYS = ["filename", "text", "pretty_text", "env_variables"] - -if platform.system() == "Windows": - import regex as re -else: - import re # type: ignore - - -def _lazy2string(cfg_dict, dict_type=None): - if isinstance(cfg_dict, dict): - dict_type = dict_type or type(cfg_dict) - return dict_type({k: _lazy2string(v, dict_type) for k, v in dict.items(cfg_dict)}) - elif isinstance(cfg_dict, tuple | list): - return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict) - elif isinstance(cfg_dict, LazyAttr | LazyObject): - return f"{cfg_dict.module}.{cfg_dict!s}" - else: - return cfg_dict - - -class ConfigDict(Dict): - """A dictionary for config which has the same interface as python's built- - in dictionary and can be used as a normal dictionary. - - The Config class would transform the nested fields (dictionary-like fields) - in config file into ``ConfigDict``. - - If the class attribute ``lazy`` is ``False``, users will get the - object built by ``LazyObject`` or ``LazyAttr``, otherwise users will get - the ``LazyObject`` or ``LazyAttr`` itself. - - The ``lazy`` should be set to ``True`` to avoid building the imported - object during configuration parsing, and it should be set to False outside - the Config to ensure that users do not experience the ``LazyObject``. - """ - - lazy = False - - def __init__(self, *args, **kwargs): - object.__setattr__(self, "__parent", kwargs.pop("__parent", None)) - object.__setattr__(self, "__key", kwargs.pop("__key", None)) - object.__setattr__(self, "__frozen", False) - for arg in args: - if not arg: - continue - # Since ConfigDict.items will convert LazyObject to real object - # automatically, we need to call super().items() to make sure - # the LazyObject will not be converted. 
- if isinstance(arg, ConfigDict): - for key, val in dict.items(arg): - self[key] = self._hook(val) - elif isinstance(arg, dict): - for key, val in arg.items(): - self[key] = self._hook(val) - elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): - self[arg[0]] = self._hook(arg[1]) - else: - for key, val in iter(arg): - self[key] = self._hook(val) - - for key, val in dict.items(kwargs): - self[key] = self._hook(val) - - def __missing__(self, name): - raise KeyError(name) - - def __getattr__(self, name): - try: - value = super().__getattr__(name) - if isinstance(value, LazyAttr | LazyObject) and not self.lazy: - value = value.build() - except KeyError: - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - except Exception as e: - raise e - else: - return value - - @classmethod - def _hook(cls, item): - # avoid to convert user defined dict to ConfigDict. - if type(item) in (dict, OrderedDict): - return cls(item) - elif isinstance(item, list | tuple): - return type(item)(cls._hook(elem) for elem in item) - return item - - def __setattr__(self, name, value): - value = self._hook(value) - return super().__setattr__(name, value) - - def __setitem__(self, name, value): - value = self._hook(value) - return super().__setitem__(name, value) - - def __getitem__(self, key): - return self.build_lazy(super().__getitem__(key)) - - def __deepcopy__(self, memo): - other = self.__class__() - memo[id(self)] = other - for key, value in super().items(): - other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) - return other - - def __copy__(self): - other = self.__class__() - for key, value in super().items(): - other[key] = value - return other - - copy = __copy__ - - def __iter__(self): - # Implement `__iter__` to overwrite the unpacking operator `**cfg_dict` - # to get the built lazy object - return iter(self.keys()) - - def get(self, key: str, default: Any | None = None) -> Any: - """Get the value of the key. If class attribute ``lazy`` is True, the - LazyObject will be built and returned. - - Args: - key (str): The key. - default (any, optional): The default value. Defaults to None. - - Returns: - Any: The value of the key. - """ - return self.build_lazy(super().get(key, default)) - - def pop(self, key, default=None): - """Pop the value of the key. If class attribute ``lazy`` is True, the - LazyObject will be built and returned. - - Args: - key (str): The key. - default (any, optional): The default value. Defaults to None. - - Returns: - Any: The value of the key. - """ - return self.build_lazy(super().pop(key, default)) - - def update(self, *args, **kwargs) -> None: - """Override this method to make sure the LazyObject will not be built - during updating.""" - other = {} - if args: - if len(args) > 1: - raise TypeError("update only accept one positional argument") - # Avoid to used self.items to build LazyObject - for key, value in dict.items(args[0]): - other[key] = value - - for key, value in dict(kwargs).items(): - other[key] = value - for k, v in other.items(): - if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)): - self[k] = self._hook(v) - else: - self[k].update(v) - - def build_lazy(self, value: Any) -> Any: - """If class attribute ``lazy`` is False, the LazyObject will be built - and returned. - - Args: - value (Any): The value to be built. - - Returns: - Any: The built value. 
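-
-        Examples:
-            >>> # Hedged sketch: non-lazy values pass through unchanged; the
-            >>> # ``LazyObject`` branch only triggers for values produced by
-            >>> # lazy-import config parsing.
-            >>> cfg = ConfigDict(a=1)
-            >>> cfg.build_lazy(1)
-            1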
- """ - if isinstance(value, LazyAttr | LazyObject) and not self.lazy: - value = value.build() - return value - - def values(self): - """Yield the values of the dictionary. - - If class attribute ``lazy`` is False, the value of ``LazyObject`` or - ``LazyAttr`` will be built and returned. - """ - values = [] - for value in super().values(): - values.append(self.build_lazy(value)) - return values - - def items(self): - """Yield the keys and values of the dictionary. - - If class attribute ``lazy`` is False, the value of ``LazyObject`` or - ``LazyAttr`` will be built and returned. - """ - items = [] - for key, value in super().items(): - items.append((key, self.build_lazy(value))) - return items - - def merge(self, other: dict): - """Merge another dictionary into current dictionary. - - Args: - other (dict): Another dictionary. - """ - default = object() - - def _merge_a_into_b(a, b): - if isinstance(a, dict): - if not isinstance(b, dict): - a.pop(DELETE_KEY, None) - return a - if a.pop(DELETE_KEY, False): - b.clear() - all_keys = list(b.keys()) + list(a.keys()) - return { - key: _merge_a_into_b(a.get(key, default), b.get(key, default)) - for key in all_keys - if key != DELETE_KEY - } - else: - return a if a is not default else b - - merged = _merge_a_into_b(copy.deepcopy(other), copy.deepcopy(self)) - self.clear() - for key, value in merged.items(): - self[key] = value - - def __reduce_ex__(self, proto): - # Override __reduce_ex__ to avoid `self.items` will be - # called by CPython interpreter during pickling. See more details in - # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 - if digit_version(platform.python_version()) < digit_version("3.8"): - return (self.__class__, (dict(super().items()),), None, None, None) - else: - return (self.__class__, (dict(super().items()),), None, None, None, None) - - def __eq__(self, other): - if isinstance(other, ConfigDict): - return other.to_dict() == self.to_dict() - elif isinstance(other, dict): - return dict(self.items()) == other - else: - return False - - def _to_lazy_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and keep - the ``LazyObject`` or ``LazyAttr`` object not built.""" - - def _to_dict(data): - if isinstance(data, ConfigDict): - return {key: _to_dict(value) for key, value in Dict.items(data)} - elif isinstance(data, dict): - return {key: _to_dict(value) for key, value in data.items()} - elif isinstance(data, list | tuple): - return type(data)(_to_dict(item) for item in data) - else: - return data - - return _to_dict(self) - - def to_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and - convert the ``LazyObject`` or ``LazyAttr`` to string.""" - return _lazy2string(self, dict_type=dict) - - -def add_args(parser: ArgumentParser, cfg: dict, prefix: str = "") -> ArgumentParser: - """Add config fields into argument parser. - - Args: - parser (ArgumentParser): Argument parser. - cfg (dict): Config dictionary. - prefix (str, optional): Prefix of parser argument. - Defaults to ''. - - Returns: - ArgumentParser: Argument parser containing config fields. 
- """ - for k, v in cfg.items(): - if isinstance(v, str): - parser.add_argument("--" + prefix + k) - elif isinstance(v, bool): - parser.add_argument("--" + prefix + k, action="store_true") - elif isinstance(v, int): - parser.add_argument("--" + prefix + k, type=int) - elif isinstance(v, float): - parser.add_argument("--" + prefix + k, type=float) - elif isinstance(v, dict): - add_args(parser, v, prefix + k + ".") - elif isinstance(v, abc.Iterable): - parser.add_argument("--" + prefix + k, type=type(next(iter(v))), nargs="+") - else: - print_log(f"cannot parse key {prefix + k} of type {type(v)}", logger="current") - return parser - - -class Config: - """A facility for config and config files. - - It supports common file formats as configs: python/json/yaml. - ``Config.fromfile`` can parse a dictionary from a config file, then - build a ``Config`` instance with the dictionary. - The interface is the same as a dict object and also allows access config - values as attributes. - - Args: - cfg_dict (dict, optional): A config dictionary. Defaults to None. - cfg_text (str, optional): Text of config. Defaults to None. - filename (str or Path, optional): Name of config file. - Defaults to None. - format_python_code (bool): Whether to format Python code by yapf. - Defaults to True. - - Here is a simple example: - - Examples: - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> cfg.a - 1 - >>> cfg.b - {'b1': [0, 1]} - >>> cfg.b.b1 - [0, 1] - >>> cfg = Config.fromfile('tests/data/config/a.py') - >>> cfg.filename - "/home/username/projects/mmengine/tests/data/config/a.py" - >>> cfg.item4 - 'test' - >>> cfg - "Config [path: /home/username/projects/mmengine/tests/data/config/a.py] - :" - "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" - - You can find more advance usage in the `config tutorial`_. - - .. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html - """ - - def __init__( - self, - cfg_dict: dict | None = None, - cfg_text: str | None = None, - filename: str | Path | None = None, - env_variables: dict | None = None, - format_python_code: bool = True, - ): - filename = str(filename) if isinstance(filename, Path) else filename - if cfg_dict is None: - cfg_dict = {} - elif not isinstance(cfg_dict, dict): - raise TypeError(f"cfg_dict must be a dict, but got {type(cfg_dict)}") - for key in cfg_dict: - if key in RESERVED_KEYS: - raise KeyError(f"{key} is reserved for config file") - - if not isinstance(cfg_dict, ConfigDict): - cfg_dict = ConfigDict(cfg_dict) - super().__setattr__("_cfg_dict", cfg_dict) - super().__setattr__("_filename", filename) - super().__setattr__("_format_python_code", format_python_code) - if not hasattr(self, "_imported_names"): - super().__setattr__("_imported_names", set()) - - if cfg_text: - text = cfg_text - elif filename: - with open(filename, encoding="utf-8") as f: - text = f.read() - else: - text = "" - super().__setattr__("_text", text) - if env_variables is None: - env_variables = {} - super().__setattr__("_env_variables", env_variables) - - @staticmethod - def fromfile( - filename: str | Path, - use_predefined_variables: bool = True, - import_custom_modules: bool = True, - use_environment_variables: bool = True, - lazy_import: bool | None = None, - format_python_code: bool = True, - ) -> "Config": - """Build a Config instance from config file. - - Args: - filename (str or Path): Name of config file. - use_predefined_variables (bool, optional): Whether to use - predefined variables. Defaults to True. 
- import_custom_modules (bool, optional): Whether to support - importing custom modules in config. Defaults to None. - use_environment_variables (bool, optional): Whether to use - environment variables. Defaults to True. - lazy_import (bool): Whether to load config in `lazy_import` mode. - If it is `None`, it will be deduced by the content of the - config file. Defaults to None. - format_python_code (bool): Whether to format Python code by yapf. - Defaults to True. - - Returns: - Config: Config instance built from config file. - """ - filename = str(filename) if isinstance(filename, Path) else filename - if lazy_import is False or (lazy_import is None and not Config._is_lazy_import(filename)): - cfg_dict, cfg_text, env_variables = Config._file2dict( - filename, - use_predefined_variables, - use_environment_variables, - lazy_import, - ) - if import_custom_modules and cfg_dict.get("custom_imports", None): - try: - import_modules_from_strings(**cfg_dict["custom_imports"]) - except ImportError as e: - err_msg = ( - f"Failed to import custom modules from {cfg_dict['custom_imports']}, the current sys.path is: " - ) - for p in sys.path: - err_msg += f"\n {p}" - err_msg += "\nYou should set `PYTHONPATH` to make `sys.path` include the directory which contains your custom module" - raise ImportError(err_msg) from e - return Config( - cfg_dict, - cfg_text=cfg_text, - filename=filename, - env_variables=env_variables, - ) - else: - # Enable lazy import when parsing the config. - # Using try-except to make sure ``ConfigDict.lazy`` will be reset - # to False. See more details about lazy in the docstring of - # ConfigDict - ConfigDict.lazy = True - try: - cfg_dict, imported_names = Config._parse_lazy_import(filename) - except Exception as e: - raise e - finally: - # disable lazy import to get the real type. See more details - # about lazy in the docstring of ConfigDict - ConfigDict.lazy = False - - cfg = Config(cfg_dict, filename=filename, format_python_code=format_python_code) - object.__setattr__(cfg, "_imported_names", imported_names) - return cfg - - @staticmethod - def fromstring(cfg_str: str, file_format: str) -> "Config": - """Build a Config instance from config text. - - Args: - cfg_str (str): Config text. - file_format (str): Config file format corresponding to the - config str. Only py/yml/yaml/json type are supported now! - - Returns: - Config: Config object generated from ``cfg_str``. - """ - if file_format not in [".py", ".json", ".yaml", ".yml"]: - raise OSError("Only py/yml/yaml/json type are supported now!") - if file_format != ".py" and "dict(" in cfg_str: - # check if users specify a wrong suffix for python - warnings.warn('Please check "file_format", the file format may be .py', stacklevel=2) - - # A temporary file can not be opened a second time on Windows. - # See https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile for more details. - # `temp_file` is opened first in `tempfile.NamedTemporaryFile` and - # second in `Config.from_file`. - # In addition, a named temporary file will be removed after closed. - # As a workaround we set `delete=False` and close the temporary file - # before opening again. - - with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=file_format, delete=False) as temp_file: - temp_file.write(cfg_str) - - cfg = Config.fromfile(temp_file.name) - os.remove(temp_file.name) # manually delete the temporary file - return cfg - - @staticmethod - def _get_base_modules(nodes: list) -> list: - """Get base module name from parsed code. 
-
-        Args:
-            nodes (list): Parsed code of the config file.
-
-        Returns:
-            list: Name of base modules.
-        """
-
-        def _get_base_module_from_with(with_nodes: list) -> list:
-            """Get base module names from the ``with read_base()`` block in a
-            python file.
-
-            Args:
-                with_nodes (list): List of nodes in the ``with`` block.
-
-            Returns:
-                list: Name of base modules.
-            """
-            base_modules = []
-            for node in with_nodes:
-                assert isinstance(node, ast.ImportFrom), (
-                    "Illegal syntax in config file! Only `from ... import ...` could be implemented in `with read_base()`"
-                )
-                assert node.module is not None, (
-                    "Illegal syntax in config file! Syntax like `from . import xxx` is not allowed in `with read_base()`"
-                )
-                base_modules.append(node.level * "." + node.module)
-            return base_modules
-
-        for idx, node in enumerate(nodes):
-            if (
-                isinstance(node, ast.Assign)
-                and isinstance(node.targets[0], ast.Name)
-                and node.targets[0].id == BASE_KEY
-            ):
-                raise ConfigParsingError(
-                    "The configuration file type in the inheritance chain "
-                    "must match the current configuration file type, either "
-                    '"lazy_import" or non-"lazy_import". You got this error '
-                    f'since you use the syntax like `_base_ = "{node.targets[0].id}"` '
-                    "in your config. You should use `with read_base(): ...` to "
-                    "mark the inherited config file. See more information "
-                    "in https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html"
-                )
-
-            if not isinstance(node, ast.With):
-                continue
-
-            expr = node.items[0].context_expr
-            if (
-                not isinstance(expr, ast.Call)
-                or not expr.func.id == "read_base"  # type: ignore
-                or len(node.items) > 1
-            ):
-                raise ConfigParsingError("Only `read_base` context manager can be used in the config")
-
-            # The original code:
-            # ```
-            # with read_base():
-            #     from .._base_.default_runtime import *
-            # ```
-            # The processed code:
-            # ```
-            # from .._base_.default_runtime import *
-            # ```
-            # As you can see, the with statement is removed and the
-            # from ... import statement will be unindented
-            for nested_idx, nested_node in enumerate(node.body):
-                nodes.insert(idx + nested_idx + 1, nested_node)
-            nodes.pop(idx)
-            return _get_base_module_from_with(node.body)
-        return []
-
-    @staticmethod
-    def _validate_py_syntax(filename: str):
-        """Validate syntax of python config.
-
-        Args:
-            filename (str): Filename of python config file.
-        """
-        with open(filename, encoding="utf-8") as f:
-            content = f.read()
-        try:
-            ast.parse(content)
-        except SyntaxError as e:
-            raise SyntaxError(f"There are syntax errors in config file {filename}: {e}")
-
-    @staticmethod
-    def _substitute_predefined_vars(filename: str, temp_config_name: str):
-        """Substitute predefined variables in config with actual values.
-
-        Sometimes we want some variables in the config to be related to the
-        current path or file name, etc.
-
-        Here is an example of a typical usage scenario. When training a model,
-        we define a working directory in the config that saves the models and
-        logs. For different configs, we expect to define different working
-        directories. A common way for users is to use the config file name
-        directly as part of the working directory name, e.g. for the config
-        ``config_setting1.py``, the working directory is
-        ``./work_dir/config_setting1``.
-
-        This can be easily achieved using predefined variables, which can be
-        written in the config `config_setting1.py` as follows
-
-        .. code-block:: python
-
-            work_dir = './work_dir/{{ fileBasenameNoExtension }}'
-
-        Here `{{ fileBasenameNoExtension }}` indicates the file name of the
-        config (without the extension), and when the config class reads the
-        config file, it will automatically parse this double-bracketed string
-        to the corresponding actual value.
-
-        .. code-block:: python
-
-            cfg = Config.fromfile('./config_setting1.py')
-            cfg.work_dir  # "./work_dir/config_setting1"
-
-        For details, please refer to docs/zh_cn/advanced_tutorials/config.md .
-
-        Args:
-            filename (str): Filename of config.
-            temp_config_name (str): Temporary filename to save substituted
-                config.
-        """
-        file_dirname = osp.dirname(filename)
-        file_basename = osp.basename(filename)
-        file_basename_no_extension = osp.splitext(file_basename)[0]
-        file_extname = osp.splitext(filename)[1]
-        support_templates = {
-            "fileDirname": file_dirname,
-            "fileBasename": file_basename,
-            "fileBasenameNoExtension": file_basename_no_extension,
-            "fileExtname": file_extname,
-        }
-        with open(filename, encoding="utf-8") as f:
-            config_file = f.read()
-        for key, value in support_templates.items():
-            regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
-            value = value.replace("\\", "/")
-            config_file = re.sub(regexp, value, config_file)
-        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
-            tmp_config_file.write(config_file)
-
-    @staticmethod
-    def _substitute_env_variables(filename: str, temp_config_name: str):
-        """Substitute environment variables in config with actual values.
-
-        Sometimes, we want to change some items in the config with environment
-        variables. For example, we expect to change the dataset root by setting
-        ``DATASET_ROOT=/dataset/root/path`` in the command line. This can be
-        easily achieved by writing lines in the config as follows
-
-        .. code-block:: python
-
-            data_root = '{{$DATASET_ROOT:/default/dataset}}/images'
-
-        Here, ``{{$DATASET_ROOT:/default/dataset}}`` indicates using the
-        environment variable ``DATASET_ROOT`` to replace the part between
-        ``{{}}``. If the ``DATASET_ROOT`` is not set, the default value
-        ``/default/dataset`` will be used.
-
-        Environment variables can not only replace items in the string, they
-        can also substitute other types of data in config. In this situation,
-        we can write the config as below
-
-        .. code-block:: python
-
-            model = dict(
-                bbox_head = dict(num_classes={{'$NUM_CLASSES:80'}}))
-
-        For details, please refer to docs/zh_cn/tutorials/config.md .
-
-        Args:
-            filename (str): Filename of config.
-            temp_config_name (str): Temporary filename to save substituted
-                config.
-        """
-        with open(filename, encoding="utf-8") as f:
-            config_file = f.read()
-        regexp = r"\{\{[\'\"]?\s*\$(\w+)\s*\:\s*(\S*?)\s*[\'\"]?\}\}"
-        keys = re.findall(regexp, config_file)
-        env_variables = {}
-        for var_name, value in keys:
-            regexp = r"\{\{[\'\"]?\s*\$" + var_name + r"\s*\:\s*" + value + r"\s*[\'\"]?\}\}"
-            if var_name in os.environ:
-                value = os.environ[var_name]
-                env_variables[var_name] = value
-                print_log(
-                    f"Using env variable `{var_name}` with value of {value} to replace item in config.",
-                    logger="current",
-                )
-            if not value:
-                raise KeyError(
-                    f"`{var_name}` cannot be found in `os.environ`. Please set `{var_name}` in environment or give a default value."
-                )
-            config_file = re.sub(regexp, value, config_file)
-
-        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
-            tmp_config_file.write(config_file)
-        return env_variables
-
-    @staticmethod
-    def _pre_substitute_base_vars(filename: str, temp_config_name: str) -> dict:
-        """Preceding step for substituting variables in base config with actual
-        values.
-
-        Args:
-            filename (str): Filename of config.
-            temp_config_name (str): Temporary filename to save substituted
-                config.
-
-        Returns:
-            dict: A dictionary containing variables in base config.
-        """
-        with open(filename, encoding="utf-8") as f:
-            config_file = f.read()
-        base_var_dict = {}
-        regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
-        base_vars = set(re.findall(regexp, config_file))
-        for base_var in base_vars:
-            randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
-            base_var_dict[randstr] = base_var
-            regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
-            config_file = re.sub(regexp, f'"{randstr}"', config_file)
-        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
-            tmp_config_file.write(config_file)
-        return base_var_dict
-
-    @staticmethod
-    def _substitute_base_vars(cfg: Any, base_var_dict: dict, base_cfg: dict) -> Any:
-        """Substitute base variables from strings to their actual values.
-
-        Args:
-            cfg (Any): Config dictionary.
-            base_var_dict (dict): A dictionary containing variables in base
-                config.
-            base_cfg (dict): Base config dictionary.
-
-        Returns:
-            Any: The config with the original base variables
-                substituted with actual values.
-        """
-        cfg = copy.deepcopy(cfg)
-
-        if isinstance(cfg, dict):
-            for k, v in cfg.items():
-                if isinstance(v, str) and v in base_var_dict:
-                    new_v = base_cfg
-                    for new_k in base_var_dict[v].split("."):
-                        new_v = new_v[new_k]
-                    cfg[k] = new_v
-                elif isinstance(v, list | tuple | dict):
-                    cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
-        elif isinstance(cfg, tuple):
-            cfg = tuple(Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg)
-        elif isinstance(cfg, list):
-            cfg = [Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg]
-        elif isinstance(cfg, str) and cfg in base_var_dict:
-            new_v = base_cfg
-            for new_k in base_var_dict[cfg].split("."):
-                new_v = new_v[new_k]
-            cfg = new_v
-
-        return cfg
-
-    @staticmethod
-    def _file2dict(
-        filename: str,
-        use_predefined_variables: bool = True,
-        use_environment_variables: bool = True,
-        lazy_import: bool | None = None,
-    ) -> tuple[dict, str, dict]:
-        """Transform file to variables dictionary.
-
-        Args:
-            filename (str): Name of config file.
-            use_predefined_variables (bool, optional): Whether to use
-                predefined variables. Defaults to True.
-            use_environment_variables (bool, optional): Whether to use
-                environment variables. Defaults to True.
-            lazy_import (bool): Whether to load config in `lazy_import` mode.
-                If it is `None`, it will be deduced by the content of the
-                config file. Defaults to None.
-
-        Returns:
-            Tuple[dict, str, dict]: Variables dictionary, config text, and the
-                environment variables used by the config.
- """ - # Auto-detect lazy_import if not specified - if lazy_import is None: - lazy_import = Config._is_lazy_import(filename) - - # If this is a lazy import file, use lazy import parsing - if lazy_import: - cfg_dict, imported_names = Config._parse_lazy_import(filename) - cfg_dict = Config._dict_to_config_dict_lazy(cfg_dict) - return cfg_dict, "" - - filename = osp.abspath(osp.expanduser(filename)) - check_file_exist(filename) - fileExtname = osp.splitext(filename)[1] - if fileExtname not in [".py", ".json", ".yaml", ".yml"]: - raise OSError("Only py/yml/yaml/json type are supported now!") - try: - with tempfile.TemporaryDirectory() as temp_config_dir: - temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=fileExtname, delete=False) - if platform.system() == "Windows": - temp_config_file.close() - - # Substitute predefined variables - if use_predefined_variables: - Config._substitute_predefined_vars(filename, temp_config_file.name) - else: - shutil.copyfile(filename, temp_config_file.name) - # Substitute environment variables - env_variables = {} - if use_environment_variables: - env_variables = Config._substitute_env_variables(temp_config_file.name, temp_config_file.name) - # Substitute base variables from placeholders to strings - base_var_dict = Config._pre_substitute_base_vars(temp_config_file.name, temp_config_file.name) - - # Handle base files - base_cfg_dict = ConfigDict() - cfg_text_list = [] - for base_cfg_path in Config._get_base_files(temp_config_file.name): - base_cfg_path, scope = Config._get_cfg_path(base_cfg_path, filename) - _cfg_dict, _cfg_text, _env_variables = Config._file2dict( - filename=base_cfg_path, - use_predefined_variables=use_predefined_variables, - use_environment_variables=use_environment_variables, - lazy_import=lazy_import, - ) - cfg_text_list.append(_cfg_text) - env_variables.update(_env_variables) - duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys() - if len(duplicate_keys) > 0: - raise KeyError(f"Duplicate key is not allowed among bases. Duplicate keys: {duplicate_keys}") - - # _dict_to_config_dict will do the following things: - # 1. Recursively converts ``dict`` to :obj:`ConfigDict`. - # 2. Set `_scope_` for the outer dict variable for the base - # config. - # 3. Set `scope` attribute for each base variable. - # Different from `_scope_`, `scope` is not a key of base - # dict, `scope` attribute will be parsed to key `_scope_` - # by function `_parse_scope` only if the base variable is - # accessed by the current config. - _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope) - base_cfg_dict.update(_cfg_dict) - - if filename.endswith(".py"): - with open(temp_config_file.name, encoding="utf-8") as f: - parsed_codes = ast.parse(f.read()) - parsed_codes = RemoveAssignFromAST(BASE_KEY).visit(parsed_codes) - codeobj = compile(parsed_codes, filename, mode="exec") - # Support load global variable in nested function of the - # config. 
- global_locals_var = {BASE_KEY: base_cfg_dict} - ori_keys = set(global_locals_var.keys()) - eval(codeobj, global_locals_var, global_locals_var) - cfg_dict = { - key: value - for key, value in global_locals_var.items() - if (key not in ori_keys and not key.startswith("__")) - } - elif filename.endswith((".yml", ".yaml", ".json")): - cfg_dict = load(temp_config_file.name) - # close temp file - for key, value in list(cfg_dict.items()): - if isinstance(value, types.FunctionType | types.ModuleType): - cfg_dict.pop(key) - temp_config_file.close() - - # If the current config accesses a base variable of base - # configs, The ``scope`` attribute of corresponding variable - # will be converted to the `_scope_`. - Config._parse_scope(cfg_dict) - except Exception as e: - if osp.exists(temp_config_dir): - shutil.rmtree(temp_config_dir) - raise e - - # check deprecation information - if DEPRECATION_KEY in cfg_dict: - deprecation_info = cfg_dict.pop(DEPRECATION_KEY) - warning_msg = f"The config file {filename} will be deprecated in the future." - if "expected" in deprecation_info: - warning_msg += f" Please use {deprecation_info['expected']} instead." - if "reference" in deprecation_info: - warning_msg += f" More information can be found at {deprecation_info['reference']}" - warnings.warn(warning_msg, DeprecationWarning, stacklevel=2) - - cfg_text = filename + "\n" - with open(filename, encoding="utf-8") as f: - # Setting encoding explicitly to resolve coding issue on windows - cfg_text += f.read() - - # Substitute base variables from strings to their actual values - cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, base_cfg_dict) - cfg_dict.pop(BASE_KEY, None) - - cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) - cfg_dict = {k: v for k, v in cfg_dict.items() if not k.startswith("__")} - - # merge cfg_text - cfg_text_list.append(cfg_text) - cfg_text = "\n".join(cfg_text_list) - - return cfg_dict, cfg_text, env_variables - - @staticmethod - def _parse_lazy_import(filename: str) -> tuple[ConfigDict, set]: - """Transform file to variables dictionary. - - Args: - filename (str): Name of config file. - - Returns: - Tuple[dict, dict]: ``cfg_dict`` and ``imported_names``. - - - cfg_dict (dict): Variables dictionary of parsed config. - - imported_names (set): Used to mark the names of - imported object. - """ - # In lazy import mode, users can use the Python syntax `import` to - # implement inheritance between configuration files, which is easier - # for users to understand the hierarchical relationships between - # different configuration files. - - # Besides, users can also using `import` syntax to import corresponding - # module which will be filled in the `type` field. It means users - # can directly navigate to the source of the module in the - # configuration file by clicking the `type` field. - - # To avoid really importing the third party package like `torch` - # during import `type` object, we use `_parse_lazy_import` to parse the - # configuration file, which will not actually trigger the import - # process, but simply parse the imported `type`s as LazyObject objects. - - # The overall pipeline of _parse_lazy_import is: - # 1. Parse the base module from the config file. - # || - # \/ - # base_module = ['mmdet.configs.default_runtime'] - # || - # \/ - # 2. recursively parse the base module and gather imported objects to - # a dict. 
- # || - # \/ - # The base_dict will be: - # { - # 'mmdet.configs.default_runtime': {...} - # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...} - # ... - # }, each item in base_dict is a dict of `LazyObject` - # 3. parse the current config file filling the imported variable - # with the base_dict. - # - # 4. During the parsing process, all imported variable will be - # recorded in the `imported_names` set. These variables can be - # accessed, but will not be dumped by default. - - with open(filename, encoding="utf-8") as f: - global_dict = {"LazyObject": LazyObject, "__file__": filename} - base_dict = {} - - parsed_codes = ast.parse(f.read()) - # get the names of base modules, and remove the - # `with read_base():'` statement - base_modules = Config._get_base_modules(parsed_codes.body) - base_imported_names = set() - for base_module in base_modules: - # If base_module means a relative import, assuming the level is - # 2, which means the module is imported like - # "from ..a.b import c". we must ensure that c is an - # object `defined` in module b, and module b should not be a - # package including `__init__` file but a single python file. - level = len(re.match(r"\.*", base_module).group()) - if level > 0: - # Relative import - base_dir = osp.dirname(filename) - module_path = osp.join( - base_dir, - *([".."] * (level - 1)), - f"{base_module[level:].replace('.', '/')}.py", - ) - else: - # Absolute import - module_list = base_module.split(".") - if len(module_list) == 1: - raise ConfigParsingError( - "The imported configuration file should not be " - f"an independent package {module_list[0]}. Here " - "is an example: " - "`with read_base(): from mmdet.configs.retinanet_r50_fpn_1x_coco import *`" - ) - else: - package = module_list[0] - root_path = get_installed_path(package) - module_path = f"{osp.join(root_path, *module_list[1:])}.py" - if not osp.isfile(module_path): - raise ConfigParsingError( - f"{module_path} not found! It means that incorrect " - "module is defined in " - f"`with read_base(): = from {base_module} import ...`, please " - "make sure the base config module is valid " - "and is consistent with the prior import " - "logic" - ) - _base_cfg_dict, _base_imported_names = Config._parse_lazy_import(module_path) - base_imported_names |= _base_imported_names - # The base_dict will be: - # { - # 'mmdet.configs.default_runtime': {...} - # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...} - # ... - # } - base_dict[base_module] = _base_cfg_dict - - # `base_dict` contains all the imported modules from `base_cfg`. - # In order to collect the specific imported module from `base_cfg` - # before parse the current file, we using AST Transform to - # transverse the imported module from base_cfg and merge then into - # the global dict. 
After the AST transformation, most import
-        # syntax will be removed (except for builtin imports) and
-        # replaced with `LazyObject` assignments.
-        transform = ImportTransformer(global_dict=global_dict, base_dict=base_dict, filename=filename)
-        modified_code = transform.visit(parsed_codes)
-        modified_code, abs_imported = _gather_abs_import_lazyobj(modified_code, filename=filename)
-        imported_names = transform.imported_obj | abs_imported
-        imported_names |= base_imported_names
-        modified_code = ast.fix_missing_locations(modified_code)
-        exec(compile(modified_code, filename, mode="exec"), global_dict, global_dict)
-
-        ret: dict = {}
-        for key, value in global_dict.items():
-            if key.startswith("__") or key in ["LazyObject"]:
-                continue
-            ret[key] = value
-        # convert dict to ConfigDict
-        cfg_dict = Config._dict_to_config_dict_lazy(ret)
-
-        return cfg_dict, imported_names
-
-    @staticmethod
-    def _dict_to_config_dict_lazy(cfg: dict):
-        """Recursively converts ``dict`` to :obj:`ConfigDict`. The only
-        difference between ``_dict_to_config_dict_lazy`` and
-        ``_dict_to_config_dict`` is that the former one does not consider
-        the scope, and will not trigger the building of ``LazyObject``.
-
-        Args:
-            cfg (dict): Config dict.
-
-        Returns:
-            ConfigDict: Converted dict.
-        """
-        # Only the outer dict with key `type` should have the key `_scope_`.
-        if isinstance(cfg, dict):
-            cfg_dict = ConfigDict()
-            for key, value in cfg.items():
-                cfg_dict[key] = Config._dict_to_config_dict_lazy(value)
-            return cfg_dict
-        if isinstance(cfg, tuple | list):
-            return type(cfg)(Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg)
-        return cfg
-
-    @staticmethod
-    def _dict_to_config_dict(cfg: dict, scope: str | None = None, has_scope=True):
-        """Recursively converts ``dict`` to :obj:`ConfigDict`.
-
-        Args:
-            cfg (dict): Config dict.
-            scope (str, optional): Scope of instance.
-            has_scope (bool): Whether to add `_scope_` key to config dict.
-
-        Returns:
-            ConfigDict: Converted dict.
-        """
-        # Only the outer dict with key `type` should have the key `_scope_`.
-        if isinstance(cfg, dict):
-            if has_scope and "type" in cfg:
-                has_scope = False
-                if scope is not None and cfg.get("_scope_", None) is None:
-                    cfg._scope_ = scope  # type: ignore
-            cfg = ConfigDict(cfg)
-            dict.__setattr__(cfg, "scope", scope)
-            for key, value in cfg.items():
-                cfg[key] = Config._dict_to_config_dict(value, scope=scope, has_scope=has_scope)
-        elif isinstance(cfg, tuple):
-            cfg = tuple(Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) for _cfg in cfg)
-        elif isinstance(cfg, list):
-            cfg = [Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) for _cfg in cfg]
-        return cfg
-
-    @staticmethod
-    def _parse_scope(cfg: dict) -> None:
-        """Adds ``_scope_`` to :obj:`ConfigDict` instances that mark a base
-        variable.
-
-        If the config dict already has the scope, scope will not be
-        overwritten.
-
-        Args:
-            cfg (dict): Config to be parsed with scope.
-        """
-        if isinstance(cfg, ConfigDict):
-            cfg._scope_ = cfg.scope
-        elif isinstance(cfg, tuple | list):
-            [Config._parse_scope(value) for value in cfg]
-        else:
-            return
-
-    @staticmethod
-    def _get_base_files(filename: str) -> list:
-        """Get the base config files.
-
-        Args:
-            filename (str): The config file.
-
-        Raises:
-            ConfigParsingError: If the config file type is not one of
-                py/json/yaml/yml.
-
-        Returns:
-            list: A list of base config files.
- """ - file_format = osp.splitext(filename)[1] - if file_format == ".py": - Config._validate_py_syntax(filename) - with open(filename, encoding="utf-8") as f: - parsed_codes = ast.parse(f.read()).body - - def is_base_line(c): - return ( - isinstance(c, ast.Assign) and isinstance(c.targets[0], ast.Name) and c.targets[0].id == BASE_KEY - ) - - base_code = next((c for c in parsed_codes if is_base_line(c)), None) - if base_code is not None: - base_code = ast.Expression( # type: ignore - body=base_code.value - ) # type: ignore - base_files = eval(compile(base_code, "", mode="eval")) # type: ignore - else: - base_files = [] - elif file_format in (".yml", ".yaml", ".json"): - import visengine - - cfg_dict = visengine.load(filename) - base_files = cfg_dict.get(BASE_KEY, []) - else: - raise ConfigParsingError(f"The config type should be py, json, yaml or yml, but got {file_format}") - base_files = base_files if isinstance(base_files, list) else [base_files] - return base_files - - @staticmethod - def _get_cfg_path(cfg_path: str, filename: str) -> tuple[str, str | None]: - """Get the config path from the current or external package. - - Args: - cfg_path (str): Relative path of config. - filename (str): The config file being parsed. - - Returns: - Tuple[str, str or None]: Path and scope of config. If the config - is not an external config, the scope will be `None`. - """ - if "::" in cfg_path: - # `cfg_path` startswith '::' means an external config path. - # Get package name and relative config path. - scope = cfg_path.partition("::")[0] - package, cfg_path = _get_package_and_cfg_path(cfg_path) - - if not is_installed(package): - raise ModuleNotFoundError(f"{package} is not installed, please install {package} manually") - - # Get installed package path. - package_path = get_installed_path(package) - try: - # Get config path from meta file. - cfg_path = _get_external_cfg_path(package_path, cfg_path) - except ValueError: - # Since base config does not have a metafile, it should be - # concatenated with package path and relative config path. - cfg_path = _get_external_cfg_base_path(package_path, cfg_path) - except FileNotFoundError as e: - raise e - return cfg_path, scope - else: - # Get local config path. - cfg_dir = osp.dirname(filename) - cfg_path = osp.join(cfg_dir, cfg_path) - return cfg_path, None - - @staticmethod - def _merge_a_into_b(a: dict, b: dict, allow_list_keys: bool = False) -> dict: - """Merge dict ``a`` into dict ``b`` (non-inplace). - - Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid - in-place modifications. - - Args: - a (dict): The source dict to be merged into ``b``. - b (dict): The origin dict to be fetch keys from ``a``. - allow_list_keys (bool): If True, int string keys (e.g. '0', '1') - are allowed in source ``a`` and will replace the element of the - corresponding index in b if b is a list. Defaults to False. - - Returns: - dict: The modified dict of ``b`` using ``a``. - - Examples: - # Normally merge a into b. - >>> Config._merge_a_into_b( - ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) - {'obj': {'a': 2}} - - # Delete b first and merge a into b. - >>> Config._merge_a_into_b( - ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) - {'obj': {'a': 2}} - - # b is a list - >>> Config._merge_a_into_b( - ... 
{'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) - [{'a': 2}, {'b': 2}] - """ - b = b.copy() - for k, v in a.items(): - if allow_list_keys and k.isdigit() and isinstance(b, list): - k = int(k) - if len(b) <= k: - raise KeyError(f"Index {k} exceeds the length of list {b}") - b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) - elif isinstance(v, dict): - if k in b and not v.pop(DELETE_KEY, False): - allowed_types: tuple | type = (dict, list) if allow_list_keys else dict - if not isinstance(b[k], allowed_types): - raise TypeError( - f"{k}={v} in child config cannot inherit from " - f"base because {k} is a dict in the child config " - f"but is of type {type(b[k])} in base config. " - f"You may set `{DELETE_KEY}=True` to ignore the " - f"base config." - ) - b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) - else: - b[k] = ConfigDict(v) - else: - b[k] = v - return b - - @staticmethod - def auto_argparser(description=None): - """Generate argparser from config file automatically (experimental)""" - partial_parser = ArgumentParser(description=description) - partial_parser.add_argument("config", help="config file path") - cfg_file = partial_parser.parse_known_args()[0].config - cfg = Config.fromfile(cfg_file) - parser = ArgumentParser(description=description) - parser.add_argument("config", help="config file path") - add_args(parser, cfg) - return parser, cfg - - @property - def filename(self) -> str: - """Get file name of config.""" - return self._filename - - @property - def text(self) -> str: - """Get config text.""" - return self._text - - @property - def env_variables(self) -> dict: - """Get used environment variables.""" - return self._env_variables - - @property - def pretty_text(self) -> str: - """Get formatted python config text.""" - - indent = 4 - - def _indent(s_, num_spaces): - s = s_.split("\n") - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - def _format_basic_types(k, v, use_mapping=False): - if isinstance(v, str): - v_str = repr(v) - else: - v_str = str(v) - - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f"{k_str}: {v_str}" - else: - attr_str = f"{k!s}={v_str}" - attr_str = _indent(attr_str, indent) - - return attr_str - - def _format_list_tuple(k, v, use_mapping=False): - if isinstance(v, list): - left = "[" - right = "]" - else: - left = "(" - right = ")" - - v_str = f"{left}\n" - # check if all items in the list are dict - for item in v: - if isinstance(item, dict): - v_str += f"dict({_indent(_format_dict(item), indent)}),\n" - elif isinstance(item, tuple): - v_str += f"{_indent(_format_list_tuple(None, item), indent)},\n" - elif isinstance(item, list): - v_str += f"{_indent(_format_list_tuple(None, item), indent)},\n" - elif isinstance(item, str): - v_str += f"{_indent(repr(item), indent)},\n" - else: - v_str += str(item) + ",\n" - if k is None: - return _indent(v_str, indent) + right - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f"{k_str}: {v_str}" - else: - attr_str = f"{k!s}={v_str}" - attr_str = _indent(attr_str, indent) + right - return attr_str - - def _contain_invalid_identifier(dict_str): - contain_invalid_identifier = False - for key_name in dict_str: - contain_invalid_identifier |= not str(key_name).isidentifier() - return contain_invalid_identifier - - def _format_dict(input_dict, outest_level=False): - r = "" - s = [] - - use_mapping = _contain_invalid_identifier(input_dict) - 
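# Editor's note: keys that are not valid Python identifiers (e.g. '0' or
# 'a.b') cannot be emitted as dict(...) keyword arguments, so the branch
# below falls back to a literal {...} mapping.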
if use_mapping: - r += "{" - for idx, (k, v) in enumerate(sorted(input_dict.items(), key=lambda x: str(x[0]))): - is_last = idx >= len(input_dict) - 1 - end = "" if outest_level or is_last else "," - if isinstance(v, dict): - v_str = "\n" + _format_dict(v) - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f"{k_str}: dict({v_str}" - else: - attr_str = f"{k!s}=dict({v_str}" - attr_str = _indent(attr_str, indent) + ")" + end - elif isinstance(v, list | tuple): - attr_str = _format_list_tuple(k, v, use_mapping) + end - else: - attr_str = _format_basic_types(k, v, use_mapping) + end - - s.append(attr_str) - r += "\n".join(s) - if use_mapping: - r += "}" - return r - - cfg_dict = self.to_dict() - text = _format_dict(cfg_dict, outest_level=True) - if self._format_python_code: - # copied from setup.cfg - yapf_style = { - "based_on_style": "pep8", - "blank_line_before_nested_class_or_def": True, - "split_before_expression_after_opening_paren": True, - } - try: - if digit_version(yapf.__version__) >= digit_version("0.40.2"): - text, _ = FormatCode(text, style_config=yapf_style) - else: - text, _ = FormatCode(text, style_config=yapf_style, verify=True) - except: # noqa: E722 - raise SyntaxError(f"Failed to format the config file, please check the syntax of: \n{text}") - return text - - def __repr__(self): - return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}" - - def __len__(self): - return len(self._cfg_dict) - - def __getattr__(self, name: str) -> Any: - return getattr(self._cfg_dict, name) - - def __getitem__(self, name): - return self._cfg_dict.__getitem__(name) - - def __setattr__(self, name, value): - if isinstance(value, dict): - value = ConfigDict(value) - self._cfg_dict.__setattr__(name, value) - - def __setitem__(self, name, value): - if isinstance(value, dict): - value = ConfigDict(value) - self._cfg_dict.__setitem__(name, value) - - def __iter__(self): - return iter(self._cfg_dict) - - def __getstate__(self) -> tuple[dict, str | None, str | None, dict, bool, set]: - state = ( - self._cfg_dict, - self._filename, - self._text, - self._env_variables, - self._format_python_code, - self._imported_names, - ) - return state - - def __deepcopy__(self, memo): - cls = self.__class__ - other = cls.__new__(cls) - memo[id(self)] = other - - for key, value in self.__dict__.items(): - super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) - - return other - - def __copy__(self): - cls = self.__class__ - other = cls.__new__(cls) - other.__dict__.update(self.__dict__) - super(Config, other).__setattr__("_cfg_dict", self._cfg_dict.copy()) - - return other - - copy = __copy__ - - def __setstate__(self, state: tuple[dict, str | None, str | None, dict, bool, set]): - super().__setattr__("_cfg_dict", state[0]) - super().__setattr__("_filename", state[1]) - super().__setattr__("_text", state[2]) - super().__setattr__("_env_variables", state[3]) - super().__setattr__("_format_python_code", state[4]) - super().__setattr__("_imported_names", state[5]) - - def dump(self, file: str | Path | None = None): - """Dump config to file or return config text. - - Args: - file (str or Path, optional): If not specified, then the object - is dumped to a str, otherwise to a file specified by the filename. - Defaults to None. - - Returns: - str or None: Config text. 
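        Examples (editor's sketch, not part of the original docstring):
            >>> cfg = Config(dict(a=1, b=dict(c=2)))
            >>> cfg.dump('example.py')  # writes ``pretty_text`` to example.py
            >>> text = cfg.dump()       # no file given: returns the config text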
- """ - file = str(file) if isinstance(file, Path) else file - cfg_dict = self.to_dict() - if file is None: - if self.filename is None or self.filename.endswith(".py"): - return self.pretty_text - else: - file_format = self.filename.split(".")[-1] - return dump(cfg_dict, file_format=file_format) - elif file.endswith(".py"): - with open(file, "w", encoding="utf-8") as f: - f.write(self.pretty_text) - else: - file_format = file.split(".")[-1] - return dump(cfg_dict, file=file, file_format=file_format) - - def merge_from_dict(self, options: dict, allow_list_keys: bool = True) -> None: - """Merge list into cfg_dict. - - Merge the dict parsed by MultipleKVAction into this cfg. - - Args: - options (dict): dict of configs to merge from. - allow_list_keys (bool): If True, int string keys (e.g. '0', '1') - are allowed in ``options`` and will replace the element of the - corresponding index in the config if the config is a list. - Defaults to True. - - Examples: - >>> from mmengine import Config - >>> # Merge dictionary element - >>> options = {'model.backbone.depth': 50, 'model.backbone.with_cp': True} - >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) - >>> cfg.merge_from_dict(options) - >>> cfg._cfg_dict - {'model': {'backbone': {'type': 'ResNet', 'depth': 50, 'with_cp': True}}} - >>> # Merge list element - >>> cfg = Config( - >>> dict(pipeline=[dict(type='LoadImage'), - >>> dict(type='LoadAnnotations')])) - >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) - >>> cfg.merge_from_dict(options, allow_list_keys=True) - >>> cfg._cfg_dict - {'pipeline': [{'type': 'SelfLoadImage'}, {'type': 'LoadAnnotations'}]} - """ - option_cfg_dict: dict = {} - for full_key, v in options.items(): - d = option_cfg_dict - key_list = full_key.split(".") - for subkey in key_list[:-1]: - d.setdefault(subkey, ConfigDict()) - d = d[subkey] - subkey = key_list[-1] - d[subkey] = v - - cfg_dict = super().__getattribute__("_cfg_dict") - super().__setattr__( - "_cfg_dict", - Config._merge_a_into_b(option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys), - ) - - @staticmethod - def diff(cfg1: Union[str, "Config"], cfg2: Union[str, "Config"]) -> str: - if isinstance(cfg1, str): - cfg1 = Config.fromfile(cfg1) - - if isinstance(cfg2, str): - cfg2 = Config.fromfile(cfg2) - - res = difflib.unified_diff(cfg1.pretty_text.split("\n"), cfg2.pretty_text.split("\n")) - - # Convert into rich format for better visualization - console = Console() - text = Text() - for line in res: - if line.startswith("+"): - color = "bright_green" - elif line.startswith("-"): - color = "bright_red" - else: - color = "bright_white" - _text = Text(line + "\n") - _text.stylize(color) - text.append(_text) - - with console.capture() as capture: - console.print(text) - - return capture.get() - - @staticmethod - def _is_lazy_import(filename: str) -> bool: - if not filename.endswith(".py"): - return False - with open(filename, encoding="utf-8") as f: - codes_str = f.read() - parsed_codes = ast.parse(codes_str) - for node in ast.walk(parsed_codes): - if ( - isinstance(node, ast.Assign) - and isinstance(node.targets[0], ast.Name) - and node.targets[0].id == BASE_KEY - ): - return False - - if isinstance(node, ast.With): - expr = node.items[0].context_expr - if not isinstance(expr, ast.Call) or not expr.func.id == "read_base": # type: ignore - raise ConfigParsingError("Only `read_base` context manager can be used in the config") - return True - if isinstance(node, ast.ImportFrom): - # relative import -> lazy_import - if node.level != 0: 
- return True - # Skip checking when using `visengine.config` in cfg file - if node.module == "visengine" and len(node.names) == 1 and node.names[0].name == "Config": - continue - if not isinstance(node.module, str): - continue - # non-builtin module -> lazy_import - if not _is_builtin_module(node.module): - return True - if isinstance(node, ast.Import): - for alias_node in node.names: - if not _is_builtin_module(alias_node.name): - return True - return False - - def _to_lazy_dict(self, keep_imported: bool = False) -> dict: - """Convert config object to dictionary with lazy object, and filter the - imported object.""" - res = self._cfg_dict._to_lazy_dict() - if hasattr(self, "_imported_names") and not keep_imported: - res = {key: value for key, value in res.items() if key not in self._imported_names} - return res - - def to_dict(self, keep_imported: bool = False): - """Convert all data in the config to a builtin ``dict``. - - Args: - keep_imported (bool): Whether to keep the imported field. - Defaults to False - - If you import third-party objects in the config file, all imported - objects will be converted to a string like ``torch.optim.SGD`` - """ - cfg_dict = self._cfg_dict.to_dict() - if hasattr(self, "_imported_names") and not keep_imported: - cfg_dict = {key: value for key, value in cfg_dict.items() if key not in self._imported_names} - return cfg_dict - - -class DictAction(Action): - """Argparse action to split an argument into KEY=VALUE form on the first = - and append to a dictionary. - - List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3', - or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested - brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' - """ - - @staticmethod - def _parse_int_float_bool(val: str) -> int | float | bool | Any: - """Parse int/float/bool value in the string.""" - try: - return int(val) - except ValueError: - pass - try: - return float(val) - except ValueError: - pass - if val.lower() in ["true", "false"]: - return True if val.lower() == "true" else False - if val == "None": - return None - return val - - @staticmethod - def _parse_iterable(val: str) -> list | tuple | Any: - """Parse iterable values in the string. - - All elements inside '()' or '[]' are treated as iterable values. - - Args: - val (str): Value string. - - Returns: - list | tuple | Any: The expanded list or tuple from the string, - or single value if no iterable values are found. - - Examples: - >>> DictAction._parse_iterable('1,2,3') - [1, 2, 3] - >>> DictAction._parse_iterable('[a, b, c]') - ['a', 'b', 'c'] - >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') - [(1, 2, 3), ['a', 'b'], 'c'] - """ - - def find_next_comma(string): - """Find the position of next comma in the string. - - If no ',' is found in the string, return the string length. All - chars inside '()' and '[]' are treated as one element and thus ',' - inside these brackets are ignored. - """ - assert (string.count("(") == string.count(")")) and (string.count("[") == string.count("]")), ( - f"Imbalanced brackets exist in {string}" - ) - end = len(string) - for idx, char in enumerate(string): - pre = string[:idx] - # The string before this ',' is balanced - if (char == ",") and (pre.count("(") == pre.count(")")) and (pre.count("[") == pre.count("]")): - end = idx - break - return end - - # Strip ' and " characters and replace whitespace. 
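# Editor's note: e.g. "'[(1, 2), (3, 4)]'" becomes "[(1,2),(3,4)]" before
# the recursive element parsing below.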
- val = val.strip("'\"").replace(" ", "") - is_tuple = False - if val.startswith("(") and val.endswith(")"): - is_tuple = True - val = val[1:-1] - elif val.startswith("[") and val.endswith("]"): - val = val[1:-1] - elif "," not in val: - # val is a single value - return DictAction._parse_int_float_bool(val) - - values = [] - while len(val) > 0: - comma_idx = find_next_comma(val) - element = DictAction._parse_iterable(val[:comma_idx]) - values.append(element) - val = val[comma_idx + 1 :] - - if is_tuple: - return tuple(values) - - return values - - def __call__( - self, - parser: ArgumentParser, - namespace: Namespace, - values: str | Sequence[Any] | None, - option_string: str | None = None, - ): # type: ignore - """Parse Variables in string and add them into argparser. - - Args: - parser (ArgumentParser): Argument parser. - namespace (Namespace): Argument namespace. - values (Union[str, Sequence[Any], None]): Argument string. - option_string (list[str], optional): Option string. - Defaults to None. - """ - # Copied behavior from `argparse._ExtendAction`. - options = copy.copy(getattr(namespace, self.dest, None) or {}) - if values is not None: - for kv in values: - key, val = kv.split("=", maxsplit=1) - options[key] = self._parse_iterable(val) - setattr(namespace, self.dest, options) - - -@contextmanager -def read_base(): - """Context manager to mark the base config. - - The pure Python-style configuration file allows you to use the import - syntax. However, it is important to note that you need to import the base - configuration file within the context of ``read_base``, and import other - dependencies outside of it. - - You can see more usage of Python-style configuration in the `tutorial`_ - - .. _tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta - """ - yield diff --git a/libs/visengine/visengine/config/lazy.py b/libs/visengine/visengine/config/lazy.py deleted file mode 100644 index 1f2b8f9..0000000 --- a/libs/visengine/visengine/config/lazy.py +++ /dev/null @@ -1,241 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import importlib -from typing import Any, Optional, Union - -from visengine.utils import is_seq_of - - -class LazyObject: - """LazyObject is used to lazily initialize the imported module during - parsing the configuration file. - - During parsing process, the syntax like: - - Examples: - >>> import torch.nn as nn - >>> from mmdet.models import RetinaNet - >>> import mmcls.models - >>> import mmcls.datasets - >>> import mmcls - - Will be parsed as: - - Examples: - >>> # import torch.nn as nn - >>> nn = lazyObject('torch.nn') - >>> # from mmdet.models import RetinaNet - >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet') - >>> # import mmcls.models; import mmcls.datasets; import mmcls - >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) - - ``LazyObject`` records all module information and will be further - referenced by the configuration file. - - Args: - module (str or list or tuple): The module name to be imported. - imported (str, optional): The imported module name. Defaults to None. - location (str, optional): The filename and line number of the imported - module statement happened. 
- """ - - def __init__( - self, - module: Union[str, list, tuple], - imported: Optional[str] = None, - location: Optional[str] = None, - ): - if not isinstance(module, str) and not is_seq_of(module, str): - raise TypeError( - "module should be `str`, `list`, or `tuple`" - f"but got {type(module)}, this might be " - "a bug of MMEngine, please report it to " - "https://github.com/open-mmlab/mmengine/issues" - ) - self._module: Union[str, list, tuple] = module - - if not isinstance(imported, str) and imported is not None: - raise TypeError( - "imported should be `str` or None, but got " - f"{type(imported)}, this might be " - "a bug of MMEngine, please report it to " - "https://github.com/open-mmlab/mmengine/issues" - ) - self._imported = imported - self.location = location - - def build(self) -> Any: - """Return imported object. - - Returns: - Any: Imported object - """ - if isinstance(self._module, str): - try: - module = importlib.import_module(self._module) - except Exception as e: - raise type(e)(f"Failed to import {self._module} in {self.location} for {e}") - - if self._imported is not None: - if hasattr(module, self._imported): - module = getattr(module, self._imported) - else: - raise ImportError(f"Failed to import {self._imported} from {self._module} in {self.location}") - - return module - else: - # import xxx.xxx - # import xxx.yyy - # import xxx.zzz - # return imported xxx - try: - for module in self._module: - importlib.import_module(module) # type: ignore - module_name = self._module[0].split(".")[0] - return importlib.import_module(module_name) - except Exception as e: - raise type(e)(f"Failed to import {self.module} in {self.location} for {e}") - - @property - def module(self): - if isinstance(self._module, str): - return self._module - return self._module[0].split(".")[0] - - def __call__(self, *args, **kwargs): - raise RuntimeError() - - def __deepcopy__(self, memo): - return LazyObject(self._module, self._imported, self.location) - - def __getattr__(self, name): - # Cannot locate the line number of the getting attribute. - # Therefore only record the filename. - if self.location is not None: - location = self.location.split(", line")[0] - else: - location = self.location - return LazyAttr(name, self, location) - - def __str__(self) -> str: - if self._imported is not None: - return self._imported - return self.module - - __repr__ = __str__ - - # `pickle.dump` will try to get the `__getstate__` and `__setstate__` - # methods of the dumped object. If these two methods are not defined, - # LazyObject will return a `__getstate__` LazyObject` or `__setstate__` - # LazyObject. - def __getstate__(self): - return self.__dict__ - - def __setstate__(self, state): - self.__dict__ = state - - -class LazyAttr: - """The attribute of the LazyObject. - - When parsing the configuration file, the imported syntax will be - parsed as the assignment ``LazyObject``. During the subsequent parsing - process, users may reference the attributes of the LazyObject. - To ensure that these attributes also contain information needed to - reconstruct the attribute itself, LazyAttr was introduced. 
- - Examples: - >>> models = LazyObject(['mmdet.models']) - >>> model = dict(type=models.RetinaNet) - >>> print(type(model['type'])) # - >>> print(model['type'].build()) # - """ # noqa: E501 - - def __init__(self, name: str, source: Union["LazyObject", "LazyAttr"], location=None): - self.name = name - self.source: Union[LazyAttr, LazyObject] = source - - if isinstance(self.source, LazyObject): - if isinstance(self.source._module, str): - if self.source._imported is None: - # source code: - # from xxx.yyy import zzz - # equivalent code: - # zzz = LazyObject('xxx.yyy', 'zzz') - # The source code of get attribute: - # eee = zzz.eee - # Then, `eee._module` should be "xxx.yyy.zzz" - self._module = self.source._module - else: - # source code: - # import xxx.yyy as zzz - # equivalent code: - # zzz = LazyObject('xxx.yyy') - # The source code of get attribute: - # eee = zzz.eee - # Then, `eee._module` should be "xxx.yyy" - self._module = f"{self.source._module}.{self.source}" - else: - # The source code of LazyObject should be - # 1. import xxx.yyy - # 2. import xxx.zzz - # Equivalent to - # xxx = LazyObject(['xxx.yyy', 'xxx.zzz']) - - # The source code of LazyAttr should be - # eee = xxx.eee - # Then, eee._module = xxx - self._module = str(self.source) - elif isinstance(self.source, LazyAttr): - # 1. import xxx - # 2. zzz = xxx.yyy.zzz - - # Equivalent to: - # xxx = LazyObject('xxx') - # zzz = xxx.yyy.zzz - # zzz._module = xxx.yyy._module + zzz.name - self._module = f"{self.source._module}.{self.source.name}" - self.location = location - - @property - def module(self): - return self._module - - def __call__(self, *args, **kwargs: Any) -> Any: - raise RuntimeError() - - def __getattr__(self, name: str) -> "LazyAttr": - return LazyAttr(name, self) - - def __deepcopy__(self, memo): - return LazyAttr(self.name, self.source) - - def build(self) -> Any: - """Return the attribute of the imported object. - - Returns: - Any: attribute of the imported object. - """ - obj = self.source.build() - try: - return getattr(obj, self.name) - except AttributeError: - raise ImportError(f"Failed to import {self.module}.{self.name} in {self.location}") - except ImportError as e: - raise e - - def __str__(self) -> str: - return self.name - - __repr__ = __str__ - - # `pickle.dump` will try to get the `__getstate__` and `__setstate__` - # methods of the dumped object. If these two methods are not defined, - # LazyAttr will return a `__getstate__` LazyAttr` or `__setstate__` - # LazyAttr. - def __getstate__(self): - return self.__dict__ - - def __setstate__(self, state): - self.__dict__ = state diff --git a/libs/visengine/visengine/config/utils.py b/libs/visengine/visengine/config/utils.py deleted file mode 100644 index 37d5cab..0000000 --- a/libs/visengine/visengine/config/utils.py +++ /dev/null @@ -1,469 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
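The LazyObject/LazyAttr pair above defers all real imports until `build()` is
called. A minimal editor's sketch of that behavior, assuming the pre-deletion
`visengine.config.lazy` module shown in this diff (`torch.optim` is only an
illustrative target):

    from visengine.config.lazy import LazyObject

    optim = LazyObject("torch.optim")  # records the module path; imports nothing
    sgd = optim.SGD                    # attribute access yields a LazyAttr
    print(optim.module, str(sgd))      # "torch.optim SGD" -- still nothing imported
    sgd_cls = sgd.build()              # only now is torch.optim actually imported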
-import ast -import os.path as osp -import re -import sys -import warnings -from collections import defaultdict -from importlib.util import find_spec -from typing import List, Optional, Tuple, Union - -from visengine.fileio import load -from visengine.utils import check_file_exist - -PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable)) -SYSTEM_PYTHON_PREFIX = "/usr/lib/python" - -MODULE2PACKAGE = { - "mmcls": "mmcls", - "mmdet": "mmdet", - "visdet": "visdet", - "mmdet3d": "mmdet3d", - "mmseg": "mmsegmentation", - "mmaction": "mmaction2", - "mmtrack": "mmtrack", - "mmpose": "mmpose", - "mmedit": "mmedit", - "mmocr": "mmocr", - "mmgen": "mmgen", - "mmfewshot": "mmfewshot", - "mmrazor": "mmrazor", - "mmflow": "mmflow", - "mmhuman3d": "mmhuman3d", - "mmrotate": "mmrotate", - "mmselfsup": "mmselfsup", - "mmyolo": "mmyolo", - "mmpretrain": "mmpretrain", - "mmagic": "mmagic", -} - -# PKG2PROJECT is not a proper name to represent the mapping between module name -# (module import from) and package name (used by pip install). Therefore, -# PKG2PROJECT will be deprecated and this alias will only be kept until -# MMEngine v1.0.0 -PKG2PROJECT = MODULE2PACKAGE - - -class ConfigParsingError(RuntimeError): - """Raise error when failed to parse pure Python style config files.""" - - -def _get_cfg_metainfo(package_path: str, cfg_path: str) -> dict: - """Get target meta information from all 'metafile.yml' defined in `mode- - index.yml` of external package. - - Args: - package_path (str): Path of external package. - cfg_path (str): Name of experiment config. - - Returns: - dict: Meta information of target experiment. - """ - meta_index_path = osp.join(package_path, ".mim", "model-index.yml") - meta_index = load(meta_index_path) - cfg_dict = dict() - for meta_path in meta_index["Import"]: - meta_path = osp.join(package_path, meta_path) - cfg_meta = load(meta_path) - for model_cfg in cfg_meta["Models"]: - if "Config" not in model_cfg: - warnings.warn(f"There is not `Config` define in {model_cfg}") - continue - cfg_name = model_cfg["Config"].partition("/")[-1] - # Some config could have multiple weights, we only pick the - # first one. - if cfg_name in cfg_dict: - continue - cfg_dict[cfg_name] = model_cfg - if cfg_path not in cfg_dict: - raise ValueError(f"Expected configs: {cfg_dict.keys()}, but got {cfg_path}") - return cfg_dict[cfg_path] - - -def _get_external_cfg_path(package_path: str, cfg_file: str) -> str: - """Get config path of external package. - - Args: - package_path (str): Path of external package. - cfg_file (str): Name of experiment config. - - Returns: - str: Absolute config path from external package. - """ - cfg_file = cfg_file.split(".")[0] - model_cfg = _get_cfg_metainfo(package_path, cfg_file) - cfg_path = osp.join(package_path, model_cfg["Config"]) - check_file_exist(cfg_path) - return cfg_path - - -def _get_external_cfg_base_path(package_path: str, cfg_name: str) -> str: - """Get base config path of external package. - - Args: - package_path (str): Path of external package. - cfg_name (str): External relative config path with 'package::'. - - Returns: - str: Absolute config path from external package. - """ - cfg_path = osp.join(package_path, "configs", cfg_name) - check_file_exist(cfg_path) - return cfg_path - - -def _get_package_and_cfg_path(cfg_path: str) -> Tuple[str, str]: - """Get package name and relative config path. - - Args: - cfg_path (str): External relative config path with 'package::'. - - Returns: - Tuple[str, str]: Package name and config path. 
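    Examples (editor's sketch, not part of the original docstring):
        >>> _get_package_and_cfg_path(
        ...     'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py')
        ('mmdet', 'faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py')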
- """ - if re.match(r"\w*::\w*/\w*", cfg_path) is None: - raise ValueError( - "`_get_package_and_cfg_path` is used for get external package, " - "please specify the package name and relative config path, just " - "like `mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py`" - ) - package_cfg = cfg_path.split("::") - if len(package_cfg) > 2: - raise ValueError( - f"`::` should only be used to separate package and config name, but found multiple `::` in {cfg_path}" - ) - package, cfg_path = package_cfg - assert package in MODULE2PACKAGE, f"mmengine does not support to load {package} config." - package = MODULE2PACKAGE[package] - return package, cfg_path - - -class RemoveAssignFromAST(ast.NodeTransformer): - """Remove Assign node if the target's name match the key. - - Args: - key (str): The target name of the Assign node. - """ - - def __init__(self, key): - self.key = key - - def visit_Assign(self, node): - if isinstance(node.targets[0], ast.Name) and node.targets[0].id == self.key: - return None - else: - return node - - -def _is_builtin_module(module_name: str) -> bool: - """Check if a module is a built-in module. - - Arg: - module_name: name of module. - """ - if module_name.startswith("."): - return False - if module_name.startswith("visengine.config"): - return True - if module_name in sys.builtin_module_names: - return True - spec = find_spec(module_name.split(".")[0]) - # Module not found - if spec is None: - return False - origin_path = getattr(spec, "origin", None) - if origin_path is None: - return False - origin_path = osp.abspath(origin_path) - if ( - "site-package" in origin_path - or "dist-package" in origin_path - or not origin_path.startswith((PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX)) - ): - return False - else: - return True - - -class ImportTransformer(ast.NodeTransformer): - """Convert the import syntax to the assignment of - :class:`mmengine.config.LazyObject` and preload the base variable before - parsing the configuration file. - - Since you are already looking at this part of the code, I believe you must - be interested in the mechanism of the ``lazy_import`` feature of - :class:`Config`. In this docstring, we will dive deeper into its - principles. - - Most of OpenMMLab users maybe bothered with that: - - * In most of popular IDEs, they cannot navigate to the source code in - configuration file - * In most of popular IDEs, they cannot jump to the base file in current - configuration file, which is much painful when the inheritance - relationship is complex. - - In order to solve this problem, we introduce the ``lazy_import`` mode. - - A very intuitive idea for solving this problem is to import the module - corresponding to the "type" field using the ``import`` syntax. Similarly, - we can also ``import`` base file. - - However, this approach has a significant drawback. It requires triggering - the import logic to parse the configuration file, which can be - time-consuming. Additionally, it implies downloading numerous dependencies - solely for the purpose of parsing the configuration file. - However, it's possible that only a portion of the config will actually be - used. For instance, the package used in the ``train_pipeline`` may not - be necessary for an evaluation task. Forcing users to download these - unused packages is not a desirable solution. - - To avoid this problem, we introduce :class:`mmengine.config.LazyObject` and - :class:`mmengine.config.LazyAttr`. 
Before we proceed with further - explanations, you may refer to the documentation of these two modules to - gain an understanding of their functionalities. - - Actually, one of the functions of ``ImportTransformer`` is to hack the - ``import`` syntax. It will replace the import syntax - (exclude import the base files) with the assignment of ``LazyObject``. - - As for the import syntax of the base file, we cannot lazy import it since - we're eager to merge the fields of current file and base files. Therefore, - another function of the ``ImportTransformer`` is to collaborate with - ``Config._parse_lazy_import`` to parse the base files. - - Args: - global_dict (dict): The global dict of the current configuration file. - If we divide ordinary Python syntax into two parts, namely the - import section and the non-import section (assuming a simple case - with imports at the beginning and the rest of the code following), - the variables generated by the import statements are stored in - global variables for subsequent code use. In this context, - the ``global_dict`` represents the global variables required when - executing the non-import code. ``global_dict`` will be filled - during visiting the parsed code. - base_dict (dict): All variables defined in base files. - - Examples: - >>> from visengine.config import read_base - >>> - >>> - >>> with read_base(): - >>> from .._base_.default_runtime import * - >>> from .._base_.datasets.coco_detection import dataset - - In this case, the base_dict will be: - - Examples: - >>> base_dict = { - >>> '.._base_.default_runtime': ... - >>> '.._base_.datasets.coco_detection': dataset} - - and `global_dict` will be updated like this: - - Examples: - >>> global_dict.update(base_dict['.._base_.default_runtime']) # `import *` means update all data - >>> global_dict.update(dataset=base_dict['.._base_.datasets.coco_detection']['dataset']) # only update `dataset` - """ # noqa: E501 - - def __init__( - self, - global_dict: dict, - base_dict: Optional[dict] = None, - filename: Optional[str] = None, - ): - self.base_dict = base_dict if base_dict is not None else {} - self.global_dict = global_dict - # In Windows, the filename could be like this: - # "C:\\Users\\runneradmin\\AppData\\Local\\" - # Although it has been an raw string, ast.parse will firstly escape - # it as the executed code: - # "C:\Users\runneradmin\AppData\Local\\\" - # As you see, the `\U` will be treated as a part of - # the escape sequence during code parsing, leading to an - # parsing error - # Here we use `encode('unicode_escape').decode()` for double escaping - if isinstance(filename, str): - filename = filename.encode("unicode_escape").decode() - self.filename = filename - self.imported_obj: set = set() - super().__init__() - - def visit_ImportFrom(self, node: ast.ImportFrom) -> Optional[Union[List[ast.Assign], ast.ImportFrom]]: - """Hack the ``from ... import ...`` syntax and update the global_dict. - - Examples: - >>> from mmdet.models import RetinaNet - - Will be parsed as: - - Examples: - >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet') - - ``global_dict`` will also be updated by ``base_dict`` as the - class docstring says. - - Args: - node (ast.AST): The node of the current import statement. - - Returns: - Optional[List[ast.Assign]]: There three cases: - - * If the node is a statement of importing base files. - None will be returned. 
- * If the node is a statement of importing a builtin module, - node will be directly returned - * Otherwise, it will return the assignment statements of - ``LazyObject``. - """ - # Built-in modules will not be parsed as LazyObject - module = f"{node.level * '.'}{node.module}" - if _is_builtin_module(module): - # Make sure builtin module will be added into `self.imported_obj` - for alias in node.names: - if alias.asname is not None: - self.imported_obj.add(alias.asname) - elif alias.name == "*": - raise ConfigParsingError("Cannot import * from non-base config") - else: - self.imported_obj.add(alias.name) - return node - - if module in self.base_dict: - for alias_node in node.names: - if alias_node.name == "*": - self.global_dict.update(self.base_dict[module]) - return None - if alias_node.asname is not None: - base_key = alias_node.asname - else: - base_key = alias_node.name - self.global_dict[base_key] = self.base_dict[module][alias_node.name] - return None - - nodes: List[ast.Assign] = [] - for alias_node in node.names: - # `ast.alias` has lineno attr after Python 3.10, - if hasattr(alias_node, "lineno"): - lineno = alias_node.lineno - else: - lineno = node.lineno - if alias_node.name == "*": - # TODO: If users import * from a non-config module, it should - # fallback to import the real module and raise a warning to - # remind users the real module will be imported which will slow - # down the parsing speed. - raise ConfigParsingError( - "Illegal syntax in config! `from xxx import *` is not allowed to appear outside the `if base:` statement" - ) - elif alias_node.asname is not None: - # case1: - # from visengine.dataset import BaseDataset as Dataset -> - # Dataset = LazyObject('visengine.dataset', 'BaseDataset') - code = f'{alias_node.asname} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501 - self.imported_obj.add(alias_node.asname) - else: - # case2: - # from visengine.model import BaseModel - # BaseModel = LazyObject('visengine.model', 'BaseModel') - code = ( - f'{alias_node.name} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501 - ) - self.imported_obj.add(alias_node.name) - try: - nodes.append(ast.parse(code).body[0]) # type: ignore - except Exception as e: - raise ConfigParsingError( - f"Cannot import {alias_node} from {module}" - "1. Cannot import * from 3rd party lib in the config " - "file\n" - "2. Please check if the module is a base config which " - "should be added to `_base_`\n" - ) from e - return nodes - - def visit_Import(self, node) -> Union[ast.Assign, ast.Import]: - """Work with ``_gather_abs_import_lazyobj`` to hack the ``import ...`` - syntax. - - Examples: - >>> import mmcls.models - >>> import mmcls.datasets - >>> import mmcls - - Will be parsed as: - - Examples: - >>> # import mmcls.models; import mmcls.datasets; import mmcls - >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) - - Args: - node (ast.AST): The node of the current import statement. - - Returns: - ast.Assign: If the import statement is ``import ... as ...``, - ast.Assign will be returned, otherwise node will be directly - returned. - """ - # For absolute import like: `import mmdet.configs as configs`. 
- # It will be parsed as: - # configs = LazyObject('mmdet.configs') - # For absolute import like: - # `import mmdet.configs` - # `import mmdet.configs.default_runtime` - # This will be parsed as - # mmdet = LazyObject(['mmdet.configs.default_runtime', 'mmdet.configs]) - # However, visit_Import cannot gather other import information, so - # `_gather_abs_import_LazyObject` will gather all import information - # from the same module and construct the LazyObject. - alias_list = node.names - assert len(alias_list) == 1, "Illegal syntax in config! import multiple modules in one line is not supported" - # TODO Support multiline import - alias = alias_list[0] - if alias.asname is not None: - self.imported_obj.add(alias.asname) - if _is_builtin_module(alias.name.split(".")[0]): - return node - return ast.parse( # type: ignore - f'{alias.asname} = LazyObject("{alias.name}",location="{self.filename}, line {node.lineno}")' - ).body[0] - return node - - -def _gather_abs_import_lazyobj(tree: ast.Module, filename: Optional[str] = None): - """Experimental implementation of gathering absolute import information.""" - if isinstance(filename, str): - filename = filename.encode("unicode_escape").decode() - imported = defaultdict(list) - abs_imported = set() - new_body: List[ast.stmt] = [] - # module2node is used to get lineno when Python < 3.10 - module2node: dict = dict() - for node in tree.body: - if isinstance(node, ast.Import): - for alias in node.names: - # Skip converting built-in module to LazyObject - if _is_builtin_module(alias.name): - new_body.append(node) - continue - module = alias.name.split(".")[0] - module2node.setdefault(module, node) - imported[module].append(alias) - continue - new_body.append(node) - - for key, value in imported.items(): - names = [_value.name for _value in value] - if hasattr(value[0], "lineno"): - lineno = value[0].lineno - else: - lineno = module2node[key].lineno - lazy_module_assign = ast.parse( - f'{key} = LazyObject({names}, location="{filename}, line {lineno}")' # noqa: E501 - ) # noqa: E501 - abs_imported.add(key) - new_body.insert(0, lazy_module_assign.body[0]) - tree.body = new_body - return tree, abs_imported diff --git a/libs/visengine/visengine/dataset/__init__.py b/libs/visengine/visengine/dataset/__init__.py deleted file mode 100644 index 06b756a..0000000 --- a/libs/visengine/visengine/dataset/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .base_dataset import BaseDataset, Compose, force_full_init -from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset -from .sampler import DefaultSampler, InfiniteSampler -from .utils import COLLATE_FUNCTIONS, default_collate, pseudo_collate, worker_init_fn - -__all__ = [ - "COLLATE_FUNCTIONS", - "BaseDataset", - "ClassBalancedDataset", - "Compose", - "ConcatDataset", - "DefaultSampler", - "InfiniteSampler", - "RepeatDataset", - "default_collate", - "force_full_init", - "pseudo_collate", - "worker_init_fn", -] diff --git a/libs/visengine/visengine/dataset/base_dataset.py b/libs/visengine/visengine/dataset/base_dataset.py deleted file mode 100644 index 72da4d6..0000000 --- a/libs/visengine/visengine/dataset/base_dataset.py +++ /dev/null @@ -1,812 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
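As an editor's sketch of the rewrite `_gather_abs_import_lazyobj` performs
above: two absolute imports of one top-level package are folded into a single
LazyObject assignment (module names are illustrative):

    import ast

    src = "import mmcls.models\nimport mmcls.datasets\n"
    names = [a.name for node in ast.parse(src).body for a in node.names]
    folded = ast.parse(f'mmcls = LazyObject({names}, location="demo.py, line 1")')
    print(ast.unparse(folded.body[0]))
    # mmcls = LazyObject(['mmcls.models', 'mmcls.datasets'], location='demo.py, line 1')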
-import copy -import functools -import gc -import logging -import pickle -from collections.abc import Callable, Mapping, Sequence -from typing import Any - -import numpy as np -from torch.utils.data import Dataset - -from visengine.config import Config -from visengine.fileio import join_path, list_from_file, load -from visengine.logging import print_log -from visengine.registry import TRANSFORMS -from visengine.utils import is_abs - - -class Compose: - """Compose multiple transforms sequentially. - - Args: - transforms (Sequence[dict, callable], optional): Sequence of transform - object or config dict to be composed. - """ - - def __init__(self, transforms: Sequence[dict | Callable] | None): - self.transforms: list[Callable] = [] - - if transforms is None: - transforms = [] - - for transform in transforms: - # `Compose` can be built with config dict with type and - # corresponding arguments. - if isinstance(transform, dict): - transform = TRANSFORMS.build(transform) - if not callable(transform): - raise TypeError(f"transform should be a callable object, but got {type(transform)}") - self.transforms.append(transform) - elif callable(transform): - self.transforms.append(transform) - else: - raise TypeError(f"transform must be a callable object or dict, but got {type(transform)}") - - def __call__(self, data: dict) -> dict | None: - """Call function to apply transforms sequentially. - - Args: - data (dict): A result dict contains the data to transform. - - Returns: - dict: Transformed data. - """ - for t in self.transforms: - data = t(data) - # The transform will return None when it failed to load images or - # cannot find suitable augmentation parameters to augment the data. - # Here we simply return None if the transform returns None and the - # dataset will handle it by randomly selecting another data sample. - if data is None: - return None - return data - - def __repr__(self): - """Print ``self.transforms`` in sequence. - - Returns: - str: Formatted string. - """ - format_string = self.__class__.__name__ + "(" - for t in self.transforms: - format_string += "\n" - format_string += f" {t}" - format_string += "\n)" - return format_string - - -def force_full_init(old_func: Callable) -> Any: - """Those methods decorated by ``force_full_init`` will be forced to call - ``full_init`` if the instance has not been fully initiated. - - Args: - old_func (Callable): Decorated function, make sure the first arg is an - instance with ``full_init`` method. - - Returns: - Any: Depends on old_func. - """ - - @functools.wraps(old_func) - def wrapper(obj: object, *args, **kwargs): - # The instance must have `full_init` method. - if not hasattr(obj, "full_init"): - raise AttributeError(f"{type(obj)} does not have full_init method.") - # If instance does not have `_fully_initialized` attribute or - # `_fully_initialized` is False, call `full_init` and set - # `_fully_initialized` to True - if not getattr(obj, "_fully_initialized", False): - print_log( - f"Attribute `_fully_initialized` is not defined in " - f"{type(obj)} or `type(obj)._fully_initialized is " - "False, `full_init` will be called and " - f"{type(obj)}._fully_initialized will be set to True", - logger="current", - level=logging.WARNING, - ) - obj.full_init() # type: ignore - obj._fully_initialized = True # type: ignore - - return old_func(obj, *args, **kwargs) - - return wrapper - - -class BaseDataset(Dataset): - r"""BaseDataset for open source projects in OpenMMLab. - - The annotation format is shown as follows. - - .. 
code-block:: none
-
-        {
-            "metainfo":
-            {
-                "dataset_type": "test_dataset",
-                "task_name": "test_task"
-            },
-            "data_list":
-            [
-                {
-                    "img_path": "test_img.jpg",
-                    "height": 604,
-                    "width": 640,
-                    "instances":
-                    [
-                        {
-                            "bbox": [0, 0, 10, 20],
-                            "bbox_label": 1,
-                            "mask": [[0,0],[0,10],[10,20],[20,0]],
-                            "extra_anns": [1,2,3]
-                        },
-                        {
-                            "bbox": [10, 10, 110, 120],
-                            "bbox_label": 2,
-                            "mask": [[10,10],[10,110],[110,120],[120,10]],
-                            "extra_anns": [4,5,6]
-                        }
-                    ]
-                },
-            ]
-        }
-
-    Args:
-        ann_file (str, optional): Annotation file path. Defaults to ''.
-        metainfo (Mapping or Config, optional): Meta information for
-            dataset, such as class information. Defaults to None.
-        data_root (str, optional): The root directory for ``data_prefix`` and
-            ``ann_file``. Defaults to ''.
-        data_prefix (dict): Prefix for training data. Defaults to
-            dict(img_path='').
-        filter_cfg (dict, optional): Config for filter data. Defaults to None.
-        indices (int or Sequence[int], optional): Support using first few
-            data in annotation file to facilitate training/testing on a
-            smaller dataset. Defaults to None.
-        serialize_data (bool, optional): Whether to hold memory using
-            serialized objects. When enabled, data loader workers can use
-            shared RAM from the master process instead of making a copy.
-            Defaults to True.
-        pipeline (list, optional): Processing pipeline. Defaults to [].
-        test_mode (bool, optional): ``test_mode=True`` means in test phase.
-            Defaults to False.
-        lazy_init (bool, optional): Whether to load annotation during
-            instantiation. In some cases, such as visualization, only the
-            meta information of the dataset is needed, so it is not necessary
-            to load the annotation file. ``BaseDataset`` can skip loading
-            annotations to save time by setting ``lazy_init=True``.
-            Defaults to False.
-        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
-            ``None`` image, the maximum extra number of cycles to get a valid
-            image. Defaults to 1000.
-
-    Note:
-        BaseDataset collects meta information from the ``annotation file``
-        (the lowest priority), ``BaseDataset.METAINFO`` (medium) and the
-        ``metainfo`` parameter (highest) passed to the constructor. Lower
-        priority meta information will be overwritten by higher priority
-        meta information.
-
-    Note:
-        Dataset wrappers such as ``ConcatDataset``, ``RepeatDataset``, etc.
-        should not inherit from ``BaseDataset`` since ``get_subset`` and
-        ``get_subset_`` could produce an ambiguous sub-dataset which
-        conflicts with the original dataset.
-
-    Examples:
-        >>> # Assume the annotation file is given above.
-        >>> class CustomDataset(BaseDataset):
-        >>>     METAINFO: dict = dict(task_name='custom_task',
-        >>>                           dataset_type='custom_type')
-        >>> metainfo = dict(task_name='custom_task_name')
-        >>> custom_dataset = CustomDataset(
-        >>>     'path/to/ann_file',
-        >>>     metainfo=metainfo)
-        >>> # The meta information of the annotation file will be overwritten
-        >>> # by `CustomDataset.METAINFO`. The merged meta information will
-        >>> # further be overwritten by the argument `metainfo`.
- >>> custom_dataset.metainfo - {'task_name': custom_task_name, dataset_type: custom_type} - """ - - METAINFO: dict = {} - _fully_initialized: bool = False - - def __init__( - self, - ann_file: str | None = "", - metainfo: Mapping | Config | None = None, - data_root: str | None = "", - data_prefix: dict | None = None, - filter_cfg: dict | None = None, - indices: int | Sequence[int] | None = None, - serialize_data: bool = True, - pipeline: list[dict | Callable] | None = None, - test_mode: bool = False, - lazy_init: bool = False, - max_refetch: int = 1000, - ): - if data_prefix is None: - data_prefix = {"img_path": ""} - if pipeline is None: - pipeline = [] - self.ann_file = ann_file - self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) - self.data_root = data_root - self.data_prefix = copy.copy(data_prefix) - self.filter_cfg = copy.deepcopy(filter_cfg) - self._indices = indices - self.serialize_data = serialize_data - self.test_mode = test_mode - self.max_refetch = max_refetch - self.data_list: list[dict] = [] - self.data_bytes: np.ndarray - - # Join paths. - self._join_prefix() - - # Build pipeline. - self.pipeline = Compose(pipeline) - # Full initialize the dataset. - if not lazy_init: - self.full_init() - - @force_full_init - def get_data_info(self, idx: int) -> dict: - """Get annotation by index and automatically call ``full_init`` if the - dataset has not been fully initialized. - - Args: - idx (int): The index of data. - - Returns: - dict: The idx-th annotation of the dataset. - """ - if self.serialize_data: - start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() - end_addr = self.data_address[idx].item() - bytes = memoryview(self.data_bytes[start_addr:end_addr]) # type: ignore - data_info = pickle.loads(bytes) # type: ignore - else: - data_info = copy.deepcopy(self.data_list[idx]) - # Some codebase needs `sample_idx` of data information. Here we convert - # the idx to a positive number and save it in data information. - if idx >= 0: - data_info["sample_idx"] = idx - else: - data_info["sample_idx"] = len(self) + idx - - return data_info - - def full_init(self): - """Load annotation file and set ``BaseDataset._fully_initialized`` to - True. - - If ``lazy_init=False``, ``full_init`` will be called during the - instantiation and ``self._fully_initialized`` will be set to True. If - ``obj._fully_initialized=False``, the class method decorated by - ``force_full_init`` will call ``full_init`` automatically. - - Several steps to initialize annotation: - - - load_data_list: Load annotations from annotation file. - - filter data information: Filter annotations according to - filter_cfg. - - slice_data: Slice dataset according to ``self._indices`` - - serialize_data: Serialize ``self.data_list`` if - ``self.serialize_data`` is True. - """ - if self._fully_initialized: - return - # load data information - self.data_list = self.load_data_list() - # filter illegal data, such as data that has no annotations. - self.data_list = self.filter_data() - # Get subset data according to indices. - if self._indices is not None: - self.data_list = self._get_unserialized_subset(self._indices) - - # serialize data_list - if self.serialize_data: - self.data_bytes, self.data_address = self._serialize_data() - - self._fully_initialized = True - - @property - def metainfo(self) -> dict: - """Get meta information of dataset. - - Returns: - dict: meta information collected from ``BaseDataset.METAINFO``, - annotation file and metainfo argument during instantiation. 
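-
-        The returned dict is a deep copy, so modifying it will not affect
-        the dataset's own meta information.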
- """ - return copy.deepcopy(self._metainfo) - - def parse_data_info(self, raw_data_info: dict) -> dict | list[dict]: - """Parse raw annotation to target format. - - This method should return dict or list of dict. Each dict or list - contains the data information of a training sample. If the protocol of - the sample annotations is changed, this function can be overridden to - update the parsing logic while keeping compatibility. - - Args: - raw_data_info (dict): Raw data information load from ``ann_file`` - - Returns: - list or list[dict]: Parsed annotation. - """ - for prefix_key, prefix in self.data_prefix.items(): - assert prefix_key in raw_data_info, ( - f"raw_data_info: {raw_data_info} dose not contain prefix key{prefix_key}, please check your data_prefix." - ) - raw_data_info[prefix_key] = join_path(prefix, raw_data_info[prefix_key]) - return raw_data_info - - def filter_data(self) -> list[dict]: - """Filter annotations according to filter_cfg. Defaults return all - ``data_list``. - - If some ``data_list`` could be filtered according to specific logic, - the subclass should override this method. - - Returns: - list[int]: Filtered results. - """ - return self.data_list - - def get_cat_ids(self, idx: int) -> list[int]: - """Get category ids by index. Dataset wrapped by ClassBalancedDataset - must implement this method. - - The ``ClassBalancedDataset`` requires a subclass which implements this - method. - - Args: - idx (int): The index of data. - - Returns: - list[int]: All categories in the image of specified index. - """ - raise NotImplementedError(f"{type(self)} must implement `get_cat_ids` method") - - def __getitem__(self, idx: int) -> dict: - """Get the idx-th image and data information of dataset after - ``self.pipeline``, and ``full_init`` will be called if the dataset has - not been fully initialized. - - During training phase, if ``self.pipeline`` get ``None``, - ``self._rand_another`` will be called until a valid image is fetched or - the maximum limit of refetech is reached. - - Args: - idx (int): The index of self.data_list. - - Returns: - dict: The idx-th image and data information of dataset after - ``self.pipeline``. - """ - # Performing full initialization by calling `__getitem__` will consume - # extra memory. If a dataset is not fully initialized by setting - # `lazy_init=True` and then fed into the dataloader. Different workers - # will simultaneously read and parse the annotation. It will cost more - # time and memory, although this may work. Therefore, it is recommended - # to manually call `full_init` before dataset fed into dataloader to - # ensure all workers use shared RAM from master process. - if not self._fully_initialized: - print_log( - "Please call `full_init()` method manually to accelerate the speed.", - logger="current", - level=logging.WARNING, - ) - self.full_init() - - if self.test_mode: - data = self.prepare_data(idx) - if data is None: - raise Exception("Test time pipline should not get `None` data_sample") - return data - - for _ in range(self.max_refetch + 1): - data = self.prepare_data(idx) - # Broken images or random augmentations may cause the returned data - # to be None - if data is None: - idx = self._rand_another() - continue - return data - - raise Exception(f"Cannot find valid image after {self.max_refetch}! 
Please check your image path and pipeline") - - def load_data_list(self) -> list[dict]: - """Load annotations from an annotation file named as ``self.ann_file`` - - If the annotation file does not follow `OpenMMLab 2.0 format dataset - `_ . - The subclass must override this method for load annotations. The meta - information of annotation file will be overwritten :attr:`METAINFO` - and ``metainfo`` argument of constructor. - - Returns: - list[dict]: A list of annotation. - """ - # `self.ann_file` denotes the absolute annotation file path if - # `self.root=None` or relative path if `self.root=/path/to/data/`. - annotations = load(self.ann_file) - if not isinstance(annotations, dict): - raise TypeError( - f"The annotations loaded from annotation file should be a dict, but got {type(annotations)}!" - ) - if "data_list" not in annotations or "metainfo" not in annotations: - raise ValueError("Annotation must have data_list and metainfo keys") - metainfo = annotations["metainfo"] - raw_data_list = annotations["data_list"] - - # Meta information load from annotation file will not influence the - # existed meta information load from `BaseDataset.METAINFO` and - # `metainfo` arguments defined in constructor. - for k, v in metainfo.items(): - self._metainfo.setdefault(k, v) - - # load and parse data_infos. - data_list = [] - for raw_data_info in raw_data_list: - # parse raw data information to target format - data_info = self.parse_data_info(raw_data_info) - if isinstance(data_info, dict): - # For image tasks, `data_info` should information if single - # image, such as dict(img_path='xxx', width=360, ...) - data_list.append(data_info) - elif isinstance(data_info, list): - # For video tasks, `data_info` could contain image - # information of multiple frames, such as - # [dict(video_path='xxx', timestamps=...), - # dict(video_path='xxx', timestamps=...)] - for item in data_info: - if not isinstance(item, dict): - raise TypeError(f"data_info must be list of dict, but got {type(item)}") - data_list.extend(data_info) - else: - raise TypeError(f"data_info should be a dict or list of dict, but got {type(data_info)}") - - return data_list - - @classmethod - def _load_metainfo(cls, metainfo: Mapping | Config | None = None) -> dict: - """Collect meta information from the dictionary of meta. - - Args: - metainfo (Mapping or Config, optional): Meta information dict. - If ``metainfo`` contains existed filename, it will be - parsed by ``list_from_file``. - - Returns: - dict: Parsed meta information. - """ - # avoid `cls.METAINFO` being overwritten by `metainfo` - cls_metainfo = copy.deepcopy(cls.METAINFO) - if metainfo is None: - return cls_metainfo - if not isinstance(metainfo, Mapping | Config): - raise TypeError(f"metainfo should be a Mapping or Config, but got {type(metainfo)}") - - for k, v in metainfo.items(): - if isinstance(v, str): - # If type of value is string, and can be loaded from - # corresponding backend. it means the file name of meta file. - try: - cls_metainfo[k] = list_from_file(v) - except (TypeError, FileNotFoundError): - print_log( - f"{v} is not a meta file, simply parsed as meta information", - logger="current", - level=logging.WARNING, - ) - cls_metainfo[k] = v - else: - cls_metainfo[k] = v - return cls_metainfo - - def _join_prefix(self): - """Join ``self.data_root`` with ``self.data_prefix`` and - ``self.ann_file``. 
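-        Paths that are already absolute are left unchanged, and joining is
-        skipped when ``data_root`` is empty.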
- - Examples: - >>> # self.data_prefix contains relative paths - >>> self.data_root = 'a/b/c' - >>> self.data_prefix = dict(img='d/e/') - >>> self.ann_file = 'f' - >>> self._join_prefix() - >>> self.data_prefix - dict(img='a/b/c/d/e') - >>> self.ann_file - 'a/b/c/f' - >>> # self.data_prefix contains absolute paths - >>> self.data_root = 'a/b/c' - >>> self.data_prefix = dict(img='/d/e/') - >>> self.ann_file = 'f' - >>> self._join_prefix() - >>> self.data_prefix - dict(img='/d/e') - >>> self.ann_file - 'a/b/c/f' - """ - # Automatically join annotation file path with `self.root` if - # `self.ann_file` is not an absolute path. - if self.ann_file and not is_abs(self.ann_file) and self.data_root: - self.ann_file = join_path(self.data_root, self.ann_file) - # Automatically join data directory with `self.root` if path value in - # `self.data_prefix` is not an absolute path. - for data_key, prefix in self.data_prefix.items(): - if not isinstance(prefix, str): - raise TypeError(f"prefix should be a string, but got {type(prefix)}") - if not is_abs(prefix) and self.data_root: - self.data_prefix[data_key] = join_path(self.data_root, prefix) - else: - self.data_prefix[data_key] = prefix - - @force_full_init - def get_subset_(self, indices: Sequence[int] | int) -> None: - """The in-place version of ``get_subset`` to convert dataset to a - subset of original dataset. - - This method will convert the original dataset to a subset of dataset. - If type of indices is int, ``get_subset_`` will return a subdataset - which contains the first or last few data information according to - indices is positive or negative. If type of indices is a sequence of - int, the subdataset will extract the data information according to - the index given in indices. - - Examples: - >>> dataset = BaseDataset('path/to/ann_file') - >>> len(dataset) - 100 - >>> dataset.get_subset_(90) - >>> len(dataset) - 90 - >>> # if type of indices is sequence, extract the corresponding - >>> # index data information - >>> dataset.get_subset_([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - >>> len(dataset) - 10 - >>> dataset.get_subset_(-3) - >>> len(dataset) # Get the latest few data information. - 3 - - Args: - indices (int or Sequence[int]): If type of indices is int, indices - represents the first or last few data of dataset according to - indices is positive or negative. If type of indices is - Sequence, indices represents the target data information - index of dataset. - """ - # Get subset of data from serialized data or data information sequence - # according to `self.serialize_data`. - if self.serialize_data: - self.data_bytes, self.data_address = self._get_serialized_subset(indices) - else: - self.data_list = self._get_unserialized_subset(indices) - - @force_full_init - def get_subset(self, indices: Sequence[int] | int) -> "BaseDataset": - """Return a subset of dataset. - - This method will return a subset of original dataset. If type of - indices is int, ``get_subset_`` will return a subdataset which - contains the first or last few data information according to - indices is positive or negative. If type of indices is a sequence of - int, the subdataset will extract the information according to the index - given in indices. 
- - Examples: - >>> dataset = BaseDataset('path/to/ann_file') - >>> len(dataset) - 100 - >>> subdataset = dataset.get_subset(90) - >>> len(sub_dataset) - 90 - >>> # if type of indices is list, extract the corresponding - >>> # index data information - >>> subdataset = dataset.get_subset([0, 1, 2, 3, 4, 5, 6, 7, - >>> 8, 9]) - >>> len(sub_dataset) - 10 - >>> subdataset = dataset.get_subset(-3) - >>> len(subdataset) # Get the latest few data information. - 3 - - Args: - indices (int or Sequence[int]): If type of indices is int, indices - represents the first or last few data of dataset according to - indices is positive or negative. If type of indices is - Sequence, indices represents the target data information - index of dataset. - - Returns: - BaseDataset: A subset of dataset. - """ - # Get subset of data from serialized data or data information list - # according to `self.serialize_data`. Since `_get_serialized_subset` - # will recalculate the subset data information, - # `_copy_without_annotation` will copy all attributes except data - # information. - sub_dataset = self._copy_without_annotation() - # Get subset of dataset with serialize and unserialized data. - if self.serialize_data: - data_bytes, data_address = self._get_serialized_subset(indices) - sub_dataset.data_bytes = data_bytes.copy() - sub_dataset.data_address = data_address.copy() - else: - data_list = self._get_unserialized_subset(indices) - sub_dataset.data_list = copy.deepcopy(data_list) - return sub_dataset - - def _get_serialized_subset(self, indices: Sequence[int] | int) -> tuple[np.ndarray, np.ndarray]: - """Get subset of serialized data information list. - - Args: - indices (int or Sequence[int]): If type of indices is int, - indices represents the first or last few data of serialized - data information list. If type of indices is Sequence, indices - represents the target data information index which consist of - subset data information. - - Returns: - Tuple[np.ndarray, np.ndarray]: subset of serialized data - information. - """ - sub_data_bytes: list | np.ndarray - sub_data_address: list | np.ndarray - if isinstance(indices, int): - if indices >= 0: - assert indices < len(self.data_address), f"{indices} is out of dataset length({len(self)}" - # Return the first few data information. - end_addr = self.data_address[indices - 1].item() if indices > 0 else 0 - # Slicing operation of `np.ndarray` does not trigger a memory - # copy. - sub_data_bytes = self.data_bytes[:end_addr] - # Since the buffer size of first few data information is not - # changed, - sub_data_address = self.data_address[:indices] - else: - assert -indices <= len(self.data_address), f"{indices} is out of dataset length({len(self)}" - # Return the last few data information. - ignored_bytes_size = self.data_address[indices - 1] - start_addr = self.data_address[indices - 1].item() - sub_data_bytes = self.data_bytes[start_addr:] - sub_data_address = self.data_address[indices:] - sub_data_address = sub_data_address - ignored_bytes_size - elif isinstance(indices, Sequence): - sub_data_bytes = [] - sub_data_address = [] - for idx in indices: - assert len(self) > idx >= -len(self) - start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() - end_addr = self.data_address[idx].item() - # Get data information by address. - sub_data_bytes.append(self.data_bytes[start_addr:end_addr]) - # Get data information size. - sub_data_address.append(end_addr - start_addr) - # Handle indices is an empty list. 
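-            # At this point `sub_data_address` holds per-item byte sizes;
-            # `np.cumsum` below turns them back into end addresses. For
-            # example (illustrative): sizes [5, 3, 7] -> addresses [5, 8, 15],
-            # so item i occupies data_bytes[addresses[i-1]:addresses[i]].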
-            if sub_data_bytes:
-                sub_data_bytes = np.concatenate(sub_data_bytes)
-                sub_data_address = np.cumsum(sub_data_address)
-            else:
-                sub_data_bytes = np.array([])
-                sub_data_address = np.array([])
-        else:
-            raise TypeError(f"indices should be an int or a sequence of int, but got {type(indices)}")
-        return sub_data_bytes, sub_data_address  # type: ignore
-
-    def _get_unserialized_subset(self, indices: Sequence[int] | int) -> list:
-        """Get subset of data information list.
-
-        Args:
-            indices (int or Sequence[int]): If type of indices is int,
-                indices represents the first or last few data of data
-                information. If type of indices is Sequence, indices
-                represents the target data information indices which make up
-                the subset data information.
-
-        Returns:
-            list: Subset of data information.
-        """
-        if isinstance(indices, int):
-            if indices >= 0:
-                # Return the first few data information.
-                sub_data_list = self.data_list[:indices]
-            else:
-                # Return the last few data information.
-                sub_data_list = self.data_list[indices:]
-        elif isinstance(indices, Sequence):
-            # Return the data information according to given indices.
-            sub_data_list = []
-            for idx in indices:
-                sub_data_list.append(self.data_list[idx])
-        else:
-            raise TypeError(f"indices should be an int or a sequence of int, but got {type(indices)}")
-        return sub_data_list
-
-    def _serialize_data(self) -> tuple[np.ndarray, np.ndarray]:
-        """Serialize ``self.data_list`` to save memory when launching multiple
-        workers in data loading. This function will be called in ``full_init``.
-
-        Hold memory using serialized objects, and data loader workers can use
-        shared RAM from the master process instead of making a copy.
-
-        Returns:
-            Tuple[np.ndarray, np.ndarray]: Serialized result and corresponding
-            address.
-        """
-
-        def _serialize(data):
-            buffer = pickle.dumps(data, protocol=4)
-            return np.frombuffer(buffer, dtype=np.uint8)
-
-        # Serialize the data information list to avoid making multiple copies
-        # of `self.data_list` when iterating over the dataloader with multiple
-        # workers.
-        data_list = [_serialize(x) for x in self.data_list]
-        address_list = np.asarray([len(x) for x in data_list], dtype=np.int64)
-        data_address: np.ndarray = np.cumsum(address_list)
-        # TODO Check if np.concatenate is necessary
-        data_bytes = np.concatenate(data_list)
-        # Clear `self.data_list` to prevent it from being copied into every
-        # worker process when loading data with multiple processes.
-        self.data_list.clear()
-        gc.collect()
-        return data_bytes, data_address
-
-    def _rand_another(self) -> int:
-        """Get a random index.
-
-        Returns:
-            int: Random index from 0 to ``len(self)-1``.
-        """
-        return np.random.randint(0, len(self))
-
-    def prepare_data(self, idx) -> Any:
-        """Get data processed by ``self.pipeline``.
-
-        Args:
-            idx (int): The index of ``data_info``.
-
-        Returns:
-            Any: Depends on ``self.pipeline``.
-        """
-        data_info = self.get_data_info(idx)
-        return self.pipeline(data_info)
-
-    @force_full_init
-    def __len__(self) -> int:
-        """Get the length of the filtered dataset and automatically call
-        ``full_init`` if the dataset has not been fully initialized.
-
-        Returns:
-            int: The length of the filtered dataset.
-        """
-        if self.serialize_data:
-            return len(self.data_address)
-        else:
-            return len(self.data_list)
-
-    def _copy_without_annotation(self, memo=None) -> "BaseDataset":
-        """Deepcopy for all attributes other than ``data_list``,
-        ``data_address`` and ``data_bytes``.
-
-        Args:
-            memo: Memo dict used to reconstruct complex objects correctly.
- """ - if memo is None: - memo = {} - cls = self.__class__ - other = cls.__new__(cls) - memo[id(self)] = other - - for key, value in self.__dict__.items(): - if key in ["data_list", "data_address", "data_bytes"]: - continue - super(BaseDataset, other).__setattr__(key, copy.deepcopy(value, memo)) - - return other diff --git a/libs/visengine/visengine/dataset/dataset_wrapper.py b/libs/visengine/visengine/dataset/dataset_wrapper.py deleted file mode 100644 index ef5d504..0000000 --- a/libs/visengine/visengine/dataset/dataset_wrapper.py +++ /dev/null @@ -1,527 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import bisect -import copy -import logging -import math -from collections import defaultdict -from collections.abc import Sequence - -import numpy as np -from torch.utils.data.dataset import ConcatDataset as _ConcatDataset - -from visengine.logging import print_log -from visengine.registry import DATASETS - -from .base_dataset import BaseDataset, force_full_init - - -@DATASETS.register_module(force=True) -class ConcatDataset(_ConcatDataset): - """A wrapper of concatenated dataset. - - Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init. - - Note: - ``ConcatDataset`` should not inherit from ``BaseDataset`` since - ``get_subset`` and ``get_subset_`` could produce ambiguous meaning - sub-dataset which conflicts with original dataset. If you want to use - a sub-dataset of ``ConcatDataset``, you should set ``indices`` - arguments for wrapped dataset which inherit from ``BaseDataset``. - - Args: - datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets - which will be concatenated. - lazy_init (bool, optional): Whether to load annotation during - instantiation. Defaults to False. - ignore_keys (List[str] or str): Ignore the keys that can be - unequal in `dataset.metainfo`. Defaults to None. - `New in version 0.3.0.` - """ - - def __init__( - self, - datasets: Sequence[BaseDataset | dict], - lazy_init: bool = False, - ignore_keys: str | list[str] | None = None, - ): - self.datasets: list[BaseDataset] = [] - for i, dataset in enumerate(datasets): - if isinstance(dataset, dict): - self.datasets.append(DATASETS.build(dataset)) - elif isinstance(dataset, BaseDataset): - self.datasets.append(dataset) - else: - raise TypeError( - f"elements in datasets sequence should be config or `BaseDataset` instance, but got {type(dataset)}" - ) - if ignore_keys is None: - self.ignore_keys = [] - elif isinstance(ignore_keys, str): - self.ignore_keys = [ignore_keys] - elif isinstance(ignore_keys, list): - self.ignore_keys = ignore_keys - else: - raise TypeError(f"ignore_keys should be a list or str, but got {type(ignore_keys)}") - - meta_keys: set = set() - for dataset in self.datasets: - meta_keys |= dataset.metainfo.keys() - # Only use metainfo of first dataset. 
-        self._metainfo = self.datasets[0].metainfo
-        for i, dataset in enumerate(self.datasets, 1):
-            for key in meta_keys:
-                if key in self.ignore_keys:
-                    continue
-                if key not in dataset.metainfo:
-                    raise ValueError(f"{key} is not in the meta information of the {i}-th dataset")
-                first_type = type(self._metainfo[key])
-                cur_type = type(dataset.metainfo[key])
-                if first_type is not cur_type:  # type: ignore
-                    raise TypeError(
-                        f"The type {cur_type} of {key} in the {i}-th dataset should be the same as that in the first dataset {first_type}"
-                    )
-                if (
-                    isinstance(self._metainfo[key], np.ndarray)
-                    and not np.array_equal(self._metainfo[key], dataset.metainfo[key])
-                ) or (not isinstance(self._metainfo[key], np.ndarray) and self._metainfo[key] != dataset.metainfo[key]):
-                    raise ValueError(
-                        f"The meta information of the {i}-th dataset does not match the meta information of the first dataset"
-                    )
-
-        self._fully_initialized = False
-        if not lazy_init:
-            self.full_init()
-
-    @property
-    def metainfo(self) -> dict:
-        """Get the meta information of the first dataset in ``self.datasets``.
-
-        Returns:
-            dict: Meta information of the first dataset.
-        """
-        # Prevent `self._metainfo` from being modified by outside.
-        return copy.deepcopy(self._metainfo)
-
-    def full_init(self):
-        """Loop to ``full_init`` each dataset."""
-        if self._fully_initialized:
-            return
-        for d in self.datasets:
-            d.full_init()
-        # Get the cumulative sizes of `self.datasets`. For example, if the
-        # lengths of `self.datasets` are [2, 3, 4], the cumulative sizes are
-        # [2, 5, 9].
-        super().__init__(self.datasets)
-        self._fully_initialized = True
-
-    @force_full_init
-    def _get_ori_dataset_idx(self, idx: int) -> tuple[int, int]:
-        """Convert global index to local index.
-
-        Args:
-            idx (int): Global index of ``ConcatDataset``.
-
-        Returns:
-            Tuple[int, int]: The index of ``self.datasets`` and the local
-            index of data.
-        """
-        if idx < 0:
-            if -idx > len(self):
-                raise ValueError(f"absolute value of index ({idx}) should not exceed dataset length ({len(self)}).")
-            idx = len(self) + idx
-        # Get `dataset_idx` to tell which dataset idx belongs to.
-        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
-        # Get the inner index of the single dataset.
-        if dataset_idx == 0:
-            sample_idx = idx
-        else:
-            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
-
-        return dataset_idx, sample_idx
-
-    @force_full_init
-    def get_data_info(self, idx: int) -> dict:
-        """Get annotation by index.
-
-        Args:
-            idx (int): Global index of ``ConcatDataset``.
-
-        Returns:
-            dict: The idx-th annotation of the datasets.
-        """
-        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
-        return self.datasets[dataset_idx].get_data_info(sample_idx)
-
-    @force_full_init
-    def __len__(self):
-        return super().__len__()
-
-    def __getitem__(self, idx):
-        if not self._fully_initialized:
-            print_log(
-                "Please call `full_init` method manually to accelerate the speed.",
-                logger="current",
-                level=logging.WARNING,
-            )
-            self.full_init()
-        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
-        return self.datasets[dataset_idx][sample_idx]
-
-    def get_subset_(self, indices: list[int] | int) -> None:
-        """Not supported in ``ConcatDataset`` for the ambiguous meaning of
-        sub-dataset."""
-        raise NotImplementedError(
-            "`ConcatDataset` does not support `get_subset` and "
-            "`get_subset_` interfaces because this will lead to ambiguous "
-            "implementation of some methods. 
If you want to use `get_subset` " - "or `get_subset_` interfaces, please use them in the wrapped " - "dataset first and then use `ConcatDataset`." - ) - - def get_subset(self, indices: list[int] | int) -> "BaseDataset": - """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub- - dataset.""" - raise NotImplementedError( - "`ConcatDataset` dose not support `get_subset` and " - "`get_subset_` interfaces because this will lead to ambiguous " - "implementation of some methods. If you want to use `get_subset` " - "or `get_subset_` interfaces, please use them in the wrapped " - "dataset first and then use `ConcatDataset`." - ) - - -@DATASETS.register_module(force=True) -class RepeatDataset: - """A wrapper of repeated dataset. - - The length of repeated dataset will be `times` larger than the original - dataset. This is useful when the data loading time is long but the dataset - is small. Using RepeatDataset can reduce the data loading time between - epochs. - - Note: - ``RepeatDataset`` should not inherit from ``BaseDataset`` since - ``get_subset`` and ``get_subset_`` could produce ambiguous meaning - sub-dataset which conflicts with original dataset. If you want to use - a sub-dataset of ``RepeatDataset``, you should set ``indices`` - arguments for wrapped dataset which inherit from ``BaseDataset``. - - Args: - dataset (BaseDataset or dict): The dataset to be repeated. - times (int): Repeat times. - lazy_init (bool): Whether to load annotation during - instantiation. Defaults to False. - """ - - def __init__(self, dataset: BaseDataset | dict, times: int, lazy_init: bool = False): - self.dataset: BaseDataset - if isinstance(dataset, dict): - self.dataset = DATASETS.build(dataset) - elif isinstance(dataset, BaseDataset): - self.dataset = dataset - else: - raise TypeError( - f"elements in datasets sequence should be config or `BaseDataset` instance, but got {type(dataset)}" - ) - self.times = times - self._metainfo = self.dataset.metainfo - - self._fully_initialized = False - if not lazy_init: - self.full_init() - - @property - def metainfo(self) -> dict: - """Get the meta information of the repeated dataset. - - Returns: - dict: The meta information of repeated dataset. - """ - return copy.deepcopy(self._metainfo) - - def full_init(self): - """Loop to ``full_init`` each dataset.""" - if self._fully_initialized: - return - - self.dataset.full_init() - self._ori_len = len(self.dataset) - self._fully_initialized = True - - @force_full_init - def _get_ori_dataset_idx(self, idx: int) -> int: - """Convert global index to local index. - - Args: - idx: Global index of ``RepeatDataset``. - - Returns: - idx (int): Local index of data. - """ - return idx % self._ori_len - - @force_full_init - def get_data_info(self, idx: int) -> dict: - """Get annotation by index. - - Args: - idx (int): Global index of ``ConcatDataset``. - - Returns: - dict: The idx-th annotation of the datasets. 
- """ - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset.get_data_info(sample_idx) - - def __getitem__(self, idx): - if not self._fully_initialized: - print_log( - "Please call `full_init` method manually to accelerate the speed.", - logger="current", - level=logging.WARNING, - ) - self.full_init() - - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset[sample_idx] - - @force_full_init - def __len__(self): - return self.times * self._ori_len - - def get_subset_(self, indices: list[int] | int) -> None: - """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- - dataset.""" - raise NotImplementedError( - "`RepeatDataset` dose not support `get_subset` and " - "`get_subset_` interfaces because this will lead to ambiguous " - "implementation of some methods. If you want to use `get_subset` " - "or `get_subset_` interfaces, please use them in the wrapped " - "dataset first and then use `RepeatDataset`." - ) - - def get_subset(self, indices: list[int] | int) -> "BaseDataset": - """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- - dataset.""" - raise NotImplementedError( - "`RepeatDataset` dose not support `get_subset` and " - "`get_subset_` interfaces because this will lead to ambiguous " - "implementation of some methods. If you want to use `get_subset` " - "or `get_subset_` interfaces, please use them in the wrapped " - "dataset first and then use `RepeatDataset`." - ) - - -@DATASETS.register_module(force=True) -class ClassBalancedDataset: - """A wrapper of class balanced dataset. - - Suitable for training on class imbalanced datasets like LVIS. Following - the sampling strategy in the `paper `_, - in each epoch, an image may appear multiple times based on its - "repeat factor". - The repeat factor for an image is a function of the frequency the rarest - category labeled in that image. The "frequency of category c" in [0, 1] - is defined by the fraction of images in the training set (without repeats) - in which category c appears. - The dataset needs to instantiate :meth:`get_cat_ids` to support - ClassBalancedDataset. - - The repeat factor is computed as followed. - - 1. For each category c, compute the fraction # of images - that contain it: :math:`f(c)` - 2. For each category c, compute the category-level repeat factor: - :math:`r(c) = max(1, sqrt(t/f(c)))` - 3. For each image I, compute the image-level repeat factor: - :math:`r(I) = max_{c in I} r(c)` - - Note: - ``ClassBalancedDataset`` should not inherit from ``BaseDataset`` - since ``get_subset`` and ``get_subset_`` could produce ambiguous - meaning sub-dataset which conflicts with original dataset. If you - want to use a sub-dataset of ``ClassBalancedDataset``, you should set - ``indices`` arguments for wrapped dataset which inherit from - ``BaseDataset``. - - Args: - dataset (BaseDataset or dict): The dataset to be repeated. - oversample_thr (float): frequency threshold below which data is - repeated. For categories with ``f_c >= oversample_thr``, there is - no oversampling. For categories with ``f_c < oversample_thr``, the - degree of oversampling following the square-root inverse frequency - heuristic above. - lazy_init (bool, optional): whether to load annotation during - instantiation. 
Defaults to False.
-    """
-
-    def __init__(
-        self,
-        dataset: BaseDataset | dict,
-        oversample_thr: float,
-        lazy_init: bool = False,
-    ):
-        if isinstance(dataset, dict):
-            self.dataset = DATASETS.build(dataset)
-        elif isinstance(dataset, BaseDataset):
-            self.dataset = dataset
-        else:
-            raise TypeError(
-                f"elements in datasets sequence should be config or `BaseDataset` instance, but got {type(dataset)}"
-            )
-        self.oversample_thr = oversample_thr
-        self._metainfo = self.dataset.metainfo
-
-        self._fully_initialized = False
-        if not lazy_init:
-            self.full_init()
-
-    @property
-    def metainfo(self) -> dict:
-        """Get the meta information of the repeated dataset.
-
-        Returns:
-            dict: The meta information of the repeated dataset.
-        """
-        return copy.deepcopy(self._metainfo)
-
-    def full_init(self):
-        """Loop to ``full_init`` each dataset."""
-        if self._fully_initialized:
-            return
-
-        self.dataset.full_init()
-        # Get repeat factors for each image.
-        repeat_factors = self._get_repeat_factors(self.dataset, self.oversample_thr)
-        # Repeat the dataset's indices according to repeat_factors. For
-        # example, if `repeat_factors = [1, 2, 3]` and `len(dataset) == 3`,
-        # the repeated indices will be [0, 1, 1, 2, 2, 2].
-        repeat_indices = []
-        for dataset_index, repeat_factor in enumerate(repeat_factors):
-            repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
-        self.repeat_indices = repeat_indices
-
-        self._fully_initialized = True
-
-    def _get_repeat_factors(self, dataset: BaseDataset, repeat_thr: float) -> list[float]:
-        """Get the repeat factor for each image in the dataset.
-
-        Args:
-            dataset (BaseDataset): The dataset.
-            repeat_thr (float): The threshold of frequency. If an image
-                contains categories whose frequency is below the threshold,
-                it will be repeated.
-
-        Returns:
-            List[float]: The repeat factor for each image in the dataset.
-        """
-        # 1. For each category c, compute the fraction # of images
-        #    that contain it: f(c)
-        category_freq: defaultdict = defaultdict(float)
-        num_images = len(dataset)
-        for idx in range(num_images):
-            cat_ids = set(self.dataset.get_cat_ids(idx))
-            for cat_id in cat_ids:
-                category_freq[cat_id] += 1
-        for k, v in category_freq.items():
-            assert v > 0, f"category {k} does not contain any images"
-            category_freq[k] = v / num_images
-
-        # 2. For each category c, compute the category-level repeat factor:
-        #    r(c) = max(1, sqrt(t/f(c)))
-        category_repeat = {
-            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq)) for cat_id, cat_freq in category_freq.items()
-        }
-
-        # 3. For each image I and its labels L(I), compute the image-level
-        #    repeat factor:
-        #    r(I) = max_{c in L(I)} r(c)
-        repeat_factors = []
-        for idx in range(num_images):
-            # The length of `repeat_factors` needs to equal the length of the
-            # dataset. Hence, if `cat_ids` is empty, the repeat_factor should
-            # be 1.
-            repeat_factor: float = 1.0
-            cat_ids = set(self.dataset.get_cat_ids(idx))
-            if len(cat_ids) != 0:
-                repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
-            repeat_factors.append(repeat_factor)
-
-        return repeat_factors
-
-    @force_full_init
-    def _get_ori_dataset_idx(self, idx: int) -> int:
-        """Convert global index to local index.
-
-        Args:
-            idx (int): Global index of ``ClassBalancedDataset``.
-
-        Returns:
-            int: Local index of data.
-        """
-        return self.repeat_indices[idx]
-
-    @force_full_init
-    def get_cat_ids(self, idx: int) -> list[int]:
-        """Get category ids of class balanced dataset by index.
-
-        Args:
-            idx (int): Index of data.
- - Returns: - List[int]: All categories in the image of specified index. - """ - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset.get_cat_ids(sample_idx) - - @force_full_init - def get_data_info(self, idx: int) -> dict: - """Get annotation by index. - - Args: - idx (int): Global index of ``ConcatDataset``. - - Returns: - dict: The idx-th annotation of the dataset. - """ - sample_idx = self._get_ori_dataset_idx(idx) - return self.dataset.get_data_info(sample_idx) - - def __getitem__(self, idx): - if not self._fully_initialized: - print_log( - "Please call `full_init` method manually to accelerate the speed.", - logger="current", - level=logging.WARNING, - ) - self.full_init() - - ori_index = self._get_ori_dataset_idx(idx) - return self.dataset[ori_index] - - @force_full_init - def __len__(self): - return len(self.repeat_indices) - - def get_subset_(self, indices: list[int] | int) -> None: - """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning - of sub-dataset.""" - raise NotImplementedError( - "`ClassBalancedDataset` dose not support `get_subset` and " - "`get_subset_` interfaces because this will lead to ambiguous " - "implementation of some methods. If you want to use `get_subset` " - "or `get_subset_` interfaces, please use them in the wrapped " - "dataset first and then use `ClassBalancedDataset`." - ) - - def get_subset(self, indices: list[int] | int) -> "BaseDataset": - """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning - of sub-dataset.""" - raise NotImplementedError( - "`ClassBalancedDataset` dose not support `get_subset` and " - "`get_subset_` interfaces because this will lead to ambiguous " - "implementation of some methods. If you want to use `get_subset` " - "or `get_subset_` interfaces, please use them in the wrapped " - "dataset first and then use `ClassBalancedDataset`." - ) diff --git a/libs/visengine/visengine/dataset/sampler.py b/libs/visengine/visengine/dataset/sampler.py deleted file mode 100644 index 84fa0de..0000000 --- a/libs/visengine/visengine/dataset/sampler.py +++ /dev/null @@ -1,162 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import itertools -import math -from collections.abc import Iterator, Sized - -import torch -from torch.utils.data import Sampler - -from visengine.dist import get_dist_info, sync_random_seed -from visengine.registry import DATA_SAMPLERS - - -@DATA_SAMPLERS.register_module(force=True) -class DefaultSampler(Sampler): - """The default data sampler for both distributed and non-distributed - environment. - - It has several differences from the PyTorch ``DistributedSampler`` as - below: - - 1. This sampler supports non-distributed environment. - - 2. The round up behaviors are a little different. - - - If ``round_up=True``, this sampler will add extra samples to make the - number of samples is evenly divisible by the world size. And - this behavior is the same as the ``DistributedSampler`` with - ``drop_last=False``. - - If ``round_up=False``, this sampler won't remove or add any samples - while the ``DistributedSampler`` with ``drop_last=True`` will remove - tail samples. - - Args: - dataset (Sized): The dataset. - shuffle (bool): Whether shuffle the dataset or not. Defaults to True. - seed (int, optional): Random seed used to shuffle the sampler if - :attr:`shuffle=True`. This number should be identical across all - processes in the distributed group. Defaults to None. 
- round_up (bool): Whether to add extra samples to make the number of - samples evenly divisible by the world size. Defaults to True. - """ - - def __init__( - self, - dataset: Sized, - shuffle: bool = True, - seed: int | None = None, - round_up: bool = True, - ) -> None: - rank, world_size = get_dist_info() - self.rank = rank - self.world_size = world_size - - self.dataset = dataset - self.shuffle = shuffle - if seed is None: - seed = sync_random_seed() - self.seed = seed - self.epoch = 0 - self.round_up = round_up - - if self.round_up: - self.num_samples = math.ceil(len(self.dataset) / world_size) - self.total_size = self.num_samples * self.world_size - else: - self.num_samples = math.ceil((len(self.dataset) - rank) / world_size) - self.total_size = len(self.dataset) - - def __iter__(self) -> Iterator[int]: - """Iterate the indices.""" - # deterministically shuffle based on epoch and seed - if self.shuffle: - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() - else: - indices = torch.arange(len(self.dataset)).tolist() - - # add extra samples to make it evenly divisible - if self.round_up: - indices = (indices * int(self.total_size / len(indices) + 1))[: self.total_size] - - # subsample - indices = indices[self.rank : self.total_size : self.world_size] - - return iter(indices) - - def __len__(self) -> int: - """The number of samples in this rank.""" - return self.num_samples - - def set_epoch(self, epoch: int) -> None: - """Sets the epoch for this sampler. - - When :attr:`shuffle=True`, this ensures all replicas use a different - random ordering for each epoch. Otherwise, the next iteration of this - sampler will yield the same ordering. - - Args: - epoch (int): Epoch number. - """ - self.epoch = epoch - - -@DATA_SAMPLERS.register_module(force=True) -class InfiniteSampler(Sampler): - """It's designed for iteration-based runner and yields a mini-batch indices - each time. - - The implementation logic is referred to - https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/distributed_sampler.py - - Args: - dataset (Sized): The dataset. - shuffle (bool): Whether shuffle the dataset or not. Defaults to True. - seed (int, optional): Random seed. If None, set a random seed. - Defaults to None. 
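-
-    Note:
-        This sampler never raises ``StopIteration``; iteration-based runners
-        decide when to stop consuming indices.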
- """ - - def __init__(self, dataset: Sized, shuffle: bool = True, seed: int | None = None) -> None: - rank, world_size = get_dist_info() - self.rank = rank - self.world_size = world_size - - self.dataset = dataset - self.world_size = world_size - self.rank = rank - self.shuffle = shuffle - if seed is None: - seed = sync_random_seed() - self.seed = seed - self.size = len(dataset) - self.indices = self._indices_of_rank() - - def _infinite_indices(self) -> Iterator[int]: - """Infinitely yield a sequence of indices.""" - g = torch.Generator() - g.manual_seed(self.seed) - while True: - if self.shuffle: - yield from torch.randperm(self.size, generator=g).tolist() - - else: - yield from torch.arange(self.size).tolist() - - def _indices_of_rank(self) -> Iterator[int]: - """Slice the infinite indices by rank.""" - yield from itertools.islice(self._infinite_indices(), self.rank, None, self.world_size) - - def __iter__(self) -> Iterator[int]: - """Iterate the indices.""" - yield from self.indices - - def __len__(self) -> int: - """Length of base dataset.""" - return self.size - - def set_epoch(self, epoch: int) -> None: - """Not supported in iteration-based runner.""" - pass diff --git a/libs/visengine/visengine/dataset/utils.py b/libs/visengine/visengine/dataset/utils.py deleted file mode 100644 index fcf271b..0000000 --- a/libs/visengine/visengine/dataset/utils.py +++ /dev/null @@ -1,155 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import random -import warnings -from collections.abc import Mapping, Sequence -from typing import Any - -import numpy as np -import torch -from torch.utils.data._utils.collate import default_collate as torch_default_collate - -from visengine.registry import FUNCTIONS -from visengine.structures import BaseDataElement - -# FUNCTIONS is new in MMEngine v0.7.0. Reserve the `COLLATE_FUNCTIONS` to keep -# the compatibility. -COLLATE_FUNCTIONS = FUNCTIONS - - -def worker_init_fn( - worker_id: int, - num_workers: int, - rank: int, - seed: int, - disable_subprocess_warning: bool = False, -) -> None: - """This function will be called on each worker subprocess after seeding and - before data loading. - - Args: - worker_id (int): Worker id in [0, num_workers - 1]. - num_workers (int): How many subprocesses to use for data loading. - rank (int): Rank of process in distributed environment. If in - non-distributed environment, it is a constant number `0`. - seed (int): Random seed. - """ - # The seed of each worker equals to - # num_worker * rank + worker_id + user_seed - worker_seed = num_workers * rank + worker_id + seed - np.random.seed(worker_seed) - random.seed(worker_seed) - torch.manual_seed(worker_seed) - if disable_subprocess_warning and worker_id != 0: - warnings.simplefilter("ignore") - - -@FUNCTIONS.register_module(force=True) -def pseudo_collate(data_batch: Sequence) -> Any: - """Convert list of data sampled from dataset into a batch of data, of which - type consistent with the type of each data_itement in ``data_batch``. - - The default behavior of dataloader is to merge a list of samples to form - a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate`` - will not stack tensors to batch tensors, and convert int, float, ndarray to - tensors. - - This code is referenced from: - `Pytorch default_collate `_. - - Args: - data_batch (Sequence): Batch of data from dataloader. - - Returns: - Any: Transversed Data in the same format as the data_itement of - ``data_batch``. 
- """ - data_item = data_batch[0] - data_item_type = type(data_item) - if isinstance(data_item, str | bytes): - return data_batch - elif isinstance(data_item, tuple) and hasattr(data_item, "_fields"): - # named tuple - return data_item_type(*(pseudo_collate(samples) for samples in zip(*data_batch, strict=False))) - elif isinstance(data_item, Sequence): - # check to make sure that the data_itements in batch have - # consistent size - it = iter(data_batch) - data_item_size = len(next(it)) - if not all(len(data_item) == data_item_size for data_item in it): - raise RuntimeError("each data_itement in list of batch should be of equal size") - transposed = list(zip(*data_batch, strict=False)) - - if isinstance(data_item, tuple): - return [pseudo_collate(samples) for samples in transposed] # Compat with Pytorch. - else: - try: - return data_item_type([pseudo_collate(samples) for samples in transposed]) - except TypeError: - # The sequence type may not support `__init__(iterable)` - # (e.g., `range`). - return [pseudo_collate(samples) for samples in transposed] - elif isinstance(data_item, Mapping): - return data_item_type({key: pseudo_collate([d[key] for d in data_batch]) for key in data_item}) - else: - return data_batch - - -@FUNCTIONS.register_module(force=True) -def default_collate(data_batch: Sequence) -> Any: - """Convert list of data sampled from dataset into a batch of data, of which - type consistent with the type of each data_itement in ``data_batch``. - - Different from :func:`pseudo_collate`, ``default_collate`` will stack - tensor contained in ``data_batch`` into a batched tensor with the - first dimension batch size, and then move input tensor to the target - device. - - Different from ``default_collate`` in pytorch, ``default_collate`` will - not process ``BaseDataElement``. - - This code is referenced from: - `Pytorch default_collate `_. - - Note: - ``default_collate`` only accept input tensor with the same shape. - - Args: - data_batch (Sequence): Data sampled from dataset. - - Returns: - Any: Data in the same format as the data_itement of ``data_batch``, of which - tensors have been stacked, and ndarray, int, float have been - converted to tensors. - """ - data_item = data_batch[0] - data_item_type = type(data_item) - - if isinstance(data_item, BaseDataElement | str | bytes): - return data_batch - elif isinstance(data_item, tuple) and hasattr(data_item, "_fields"): - # named_tuple - return data_item_type(*(default_collate(samples) for samples in zip(*data_batch, strict=False))) - elif isinstance(data_item, Sequence): - # check to make sure that the data_itements in batch have - # consistent size - it = iter(data_batch) - data_item_size = len(next(it)) - if not all(len(data_item) == data_item_size for data_item in it): - raise RuntimeError("each data_itement in list of batch should be of equal size") - transposed = list(zip(*data_batch, strict=False)) - - if isinstance(data_item, tuple): - return [default_collate(samples) for samples in transposed] # Compat with Pytorch. - else: - try: - return data_item_type([default_collate(samples) for samples in transposed]) - except TypeError: - # The sequence type may not support `__init__(iterable)` - # (e.g., `range`). 
- return [default_collate(samples) for samples in transposed] - elif isinstance(data_item, Mapping): - return data_item_type({key: default_collate([d[key] for d in data_batch]) for key in data_item}) - else: - return torch_default_collate(data_batch) diff --git a/libs/visengine/visengine/device/__init__.py b/libs/visengine/visengine/device/__init__.py deleted file mode 100644 index 22c6b53..0000000 --- a/libs/visengine/visengine/device/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .utils import ( - get_device, - get_max_cuda_memory, - is_cuda_available, -) - -# Add lambda function that always returns False -# Required by mmcv -is_mlu_available = lambda: False -is_npu_available = lambda: False -is_musa_available = lambda: False -is_mps_available = lambda: False - -__all__ = [ - "get_device", - "get_max_cuda_memory", - "is_cuda_available", - "is_mlu_available", - "is_npu_available", - "is_musa_available", - "is_mps_available", -] diff --git a/libs/visengine/visengine/device/utils.py b/libs/visengine/visengine/device/utils.py deleted file mode 100644 index 1332daa..0000000 --- a/libs/visengine/visengine/device/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import torch - - -def get_max_cuda_memory(device: torch.device | None = None) -> int: - """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for - a given device. By default, this returns the peak allocated memory since - the beginning of this program. - - Args: - device (torch.device, optional): selected device. Returns - statistic for the current device, given by - :func:`~torch.cuda.current_device`, if ``device`` is None. - Defaults to None. - - Returns: - int: The maximum GPU memory occupied by tensors in megabytes - for a given device. - """ - mem = torch.cuda.max_memory_allocated(device=device) - mem_mb = torch.tensor([int(mem) // (1024 * 1024)], dtype=torch.int, device=device) - torch.cuda.reset_peak_memory_stats() - return int(mem_mb.item()) - - -def is_cuda_available() -> bool: - """Returns True if cuda devices exist.""" - return torch.cuda.is_available() - - -DEVICE = "cpu" -if is_cuda_available(): - DEVICE = "cuda" - - -def get_device() -> str: - """Returns the currently existing device type. - - Returns: - str: cuda | cpu. - """ - return DEVICE diff --git a/libs/visengine/visengine/evaluator/__init__.py b/libs/visengine/visengine/evaluator/__init__.py deleted file mode 100644 index 4e634f2..0000000 --- a/libs/visengine/visengine/evaluator/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .evaluator import Evaluator -from .metric import BaseMetric, DumpResults -from .utils import get_metric_value - -__all__ = ["BaseMetric", "DumpResults", "Evaluator", "get_metric_value"] diff --git a/libs/visengine/visengine/evaluator/evaluator.py b/libs/visengine/visengine/evaluator/evaluator.py deleted file mode 100644 index 2bdd578..0000000 --- a/libs/visengine/visengine/evaluator/evaluator.py +++ /dev/null @@ -1,134 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
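-# Typical online-evaluation flow (illustrative; the metric config is
-# hypothetical): build the evaluator from metric configs, feed it batches via
-# `process`, then reduce with `evaluate`:
-#
-#     evaluator = Evaluator([dict(type="CocoMetric")])
-#     for data_batch, data_samples in validation_loop:
-#         evaluator.process(data_samples, data_batch)
-#     metrics = evaluator.evaluate(size=len(val_dataset))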
-from collections.abc import Iterator, Sequence
-from typing import Any
-
-from visengine.dataset import pseudo_collate
-from visengine.registry import EVALUATOR, METRICS
-from visengine.structures import BaseDataElement
-
-from .metric import BaseMetric
-
-
-@EVALUATOR.register_module(force=True)
-class Evaluator:
-    """Wrapper class to compose multiple :class:`BaseMetric` instances.
-
-    Args:
-        metrics (dict or BaseMetric or Sequence): The config of metrics.
-    """
-
-    def __init__(self, metrics: dict | BaseMetric | Sequence):
-        self._dataset_meta: dict | None = None
-        if not isinstance(metrics, Sequence):
-            metrics = [metrics]
-        self.metrics: list[BaseMetric] = []
-        for metric in metrics:
-            if isinstance(metric, dict):
-                self.metrics.append(METRICS.build(metric))
-            else:
-                self.metrics.append(metric)
-
-    @property
-    def dataset_meta(self) -> dict | None:
-        """Optional[dict]: Meta info of the dataset."""
-        return self._dataset_meta
-
-    @dataset_meta.setter
-    def dataset_meta(self, dataset_meta: dict) -> None:
-        """Set the dataset meta info to the evaluator and its metrics."""
-        self._dataset_meta = dataset_meta
-        for metric in self.metrics:
-            metric.dataset_meta = dataset_meta
-
-    def process(self, data_samples: Sequence[BaseDataElement], data_batch: Any | None = None):
-        """Convert ``BaseDataElement`` to dict and invoke the process method of
-        each metric.
-
-        Args:
-            data_samples (Sequence[BaseDataElement]): Predictions of the model,
-                and the ground truth of the validation set.
-            data_batch (Any, optional): A batch of data from the dataloader.
-        """
-        _data_samples = []
-        for data_sample in data_samples:
-            if isinstance(data_sample, BaseDataElement):
-                _data_samples.append(data_sample.to_dict())
-            else:
-                _data_samples.append(data_sample)
-
-        for metric in self.metrics:
-            metric.process(data_batch, _data_samples)
-
-    def evaluate(self, size: int) -> dict:
-        """Invoke the ``evaluate`` method of each metric and collect the
-        metrics dictionary.
-
-        Args:
-            size (int): Length of the entire validation dataset. When batch
-                size > 1, the dataloader may pad some data samples to make
-                sure all ranks have the same length of dataset slice. The
-                ``collect_results`` function will drop the padded data based on
-                this size.
-
-        Returns:
-            dict: Evaluation results of all metrics. The keys are the names
-            of the metrics, and the values are corresponding results.
-        """
-        metrics = {}
-        for metric in self.metrics:
-            _results = metric.evaluate(size)
-
-            # Check metric name conflicts
-            for name in _results.keys():
-                if name in metrics:
-                    raise ValueError(
-                        "There are multiple evaluation results with the same "
-                        f"metric name {name}. Please make sure all metrics "
-                        "have different prefixes."
-                    )
-
-            metrics.update(_results)
-        return metrics
-
-    def offline_evaluate(self, data_samples: Sequence, data: Sequence | None = None, chunk_size: int = 1):
-        """Offline evaluate the dumped predictions on the given data.
-
-        Args:
-            data_samples (Sequence): All predictions and ground truth of the
-                model and the validation set.
-            data (Sequence, optional): All data of the validation set.
-            chunk_size (int): The number of data samples and predictions to be
-                processed in a batch.
- """ - - # support chunking iterable objects - def get_chunks(seq: Iterator, chunk_size=1): - stop = False - while not stop: - chunk = [] - for _ in range(chunk_size): - try: - chunk.append(next(seq)) - except StopIteration: - stop = True - break - if chunk: - yield chunk - - if data is not None: - assert len(data_samples) == len(data), ( - f"data_samples and data should have the same length, but got data_samples length: {len(data_samples)} data length: {len(data)}" - ) - data = get_chunks(iter(data), chunk_size) - - size = 0 - for output_chunk in get_chunks(iter(data_samples), chunk_size): - if data is not None: - data_chunk = pseudo_collate(next(data)) # type: ignore - else: - data_chunk = None - size += len(output_chunk) - self.process(output_chunk, data_chunk) - return self.evaluate(size) diff --git a/libs/visengine/visengine/evaluator/metric.py b/libs/visengine/visengine/evaluator/metric.py deleted file mode 100644 index 8557336..0000000 --- a/libs/visengine/visengine/evaluator/metric.py +++ /dev/null @@ -1,197 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from abc import ABCMeta, abstractmethod -from collections.abc import Sequence -from typing import Any - -from torch import Tensor - -from visengine.dist import broadcast_object_list, collect_results, is_main_process -from visengine.fileio import dump -from visengine.logging import print_log -from visengine.registry import METRICS -from visengine.structures import BaseDataElement - - -class BaseMetric(metaclass=ABCMeta): - """Base class for a metric. - - The metric first processes each batch of data_samples and predictions, - and appends the processed results to the results list. Then it - collects all results together from all ranks if distributed training - is used. Finally, it computes the metrics of the entire dataset. - - A subclass of class:`BaseMetric` should assign a meaningful value to the - class attribute `default_prefix`. See the argument `prefix` for details. - - Args: - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be 'cpu' or - 'gpu'. Defaults to 'cpu'. - prefix (str, optional): The prefix that will be added in the metric - names to disambiguate homonymous metrics of different evaluators. - If prefix is not provided in the argument, self.default_prefix - will be used instead. Default: None - collect_dir: (str, optional): Synchronize directory for collecting data - from different ranks. This argument should only be configured when - ``collect_device`` is 'cpu'. Defaults to None. 
-            `New in version 0.7.3.`
-    """
-
-    default_prefix: str | None = None
-
-    def __init__(
-        self,
-        collect_device: str = "cpu",
-        prefix: str | None = None,
-        collect_dir: str | None = None,
-    ) -> None:
-        if collect_dir is not None and collect_device != "cpu":
-            raise ValueError("`collect_dir` could only be configured when `collect_device='cpu'`")
-
-        self._dataset_meta: None | dict = None
-        self.collect_device = collect_device
-        self.results: list[Any] = []
-        self.prefix = prefix or self.default_prefix
-        self.collect_dir = collect_dir
-
-        if self.prefix is None:
-            print_log(
-                f"The prefix is not set in metric class {self.__class__.__name__}.",
-                logger="current",
-                level=logging.WARNING,
-            )
-
-    @property
-    def dataset_meta(self) -> dict | None:
-        """Optional[dict]: Meta info of the dataset."""
-        return self._dataset_meta
-
-    @dataset_meta.setter
-    def dataset_meta(self, dataset_meta: dict) -> None:
-        """Set the dataset meta info to the metric."""
-        self._dataset_meta = dataset_meta
-
-    @abstractmethod
-    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
-        """Process one batch of data samples and predictions. The processed
-        results should be stored in ``self.results``, which will be used to
-        compute the metrics when all batches have been processed.
-
-        Args:
-            data_batch (Any): A batch of data from the dataloader.
-            data_samples (Sequence[dict]): A batch of outputs from
-                the model.
-        """
-
-    @abstractmethod
-    def compute_metrics(self, results: list) -> dict:
-        """Compute the metrics from processed results.
-
-        Args:
-            results (list): The processed results of each batch.
-
-        Returns:
-            dict: The computed metrics. The keys are the names of the metrics,
-            and the values are corresponding results.
-        """
-
-    def evaluate(self, size: int) -> dict:
-        """Evaluate the model performance of the whole dataset after processing
-        all batches.
-
-        Args:
-            size (int): Length of the entire validation dataset. When batch
-                size > 1, the dataloader may pad some data samples to make
-                sure all ranks have the same length of dataset slice. The
-                ``collect_results`` function will drop the padded data based on
-                this size.
-
-        Returns:
-            dict: Evaluation metrics dict on the val dataset. The keys are the
-            names of the metrics, and the values are corresponding results.
-        """
-        if len(self.results) == 0:
-            print_log(
-                f"{self.__class__.__name__} got empty `self.results`. Please "
-                "ensure that the processed results are properly added into "
-                "`self.results` in `process` method.",
-                logger="current",
-                level=logging.WARNING,
-            )
-
-        if self.collect_device == "cpu":
-            results = collect_results(self.results, size, self.collect_device, tmpdir=self.collect_dir)
-        else:
-            results = collect_results(self.results, size, self.collect_device)
-
-        if is_main_process():
-            # cast all tensors in results list to cpu
-            results = _to_cpu(results)
-            _metrics = self.compute_metrics(results)  # type: ignore
-            # Add prefix to metric names
-            if self.prefix:
-                _metrics = {"/".join((self.prefix, k)): v for k, v in _metrics.items()}
-            metrics = [_metrics]
-        else:
-            metrics = [None]  # type: ignore
-
-        broadcast_object_list(metrics)
-
-        # reset the results list
-        self.results.clear()
-        return metrics[0]
-
-
-@METRICS.register_module(force=True)
-class DumpResults(BaseMetric):
-    """Dump model predictions to a pickle file for offline evaluation.
-
-    Args:
-        out_file_path (str): Path of the dumped file. Must end with '.pkl'
-            or '.pickle'.
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-        collect_dir (str, optional): Synchronize directory for collecting data
-            from different ranks. This argument should only be configured when
-            ``collect_device`` is 'cpu'. Defaults to None.
-            `New in version 0.7.3.`
-    """
-
-    def __init__(
-        self,
-        out_file_path: str,
-        collect_device: str = "cpu",
-        collect_dir: str | None = None,
-    ) -> None:
-        super().__init__(collect_device=collect_device, collect_dir=collect_dir)
-        if not out_file_path.endswith((".pkl", ".pickle")):
-            raise ValueError("The output file must be a pkl file.")
-        self.out_file_path = out_file_path
-
-    def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
-        """Transfer tensors in predictions to CPU."""
-        self.results.extend(_to_cpu(predictions))
-
-    def compute_metrics(self, results: list) -> dict:
-        """Dump the prediction results to a pickle file."""
-        dump(results, self.out_file_path)
-        print_log(f"Results have been saved to {self.out_file_path}.", logger="current")
-        return {}
-
-
-def _to_cpu(data: Any) -> Any:
-    """Transfer all tensors and BaseDataElement to CPU."""
-    if isinstance(data, Tensor | BaseDataElement):
-        return data.to("cpu")
-    elif isinstance(data, list):
-        return [_to_cpu(d) for d in data]
-    elif isinstance(data, tuple):
-        return tuple(_to_cpu(d) for d in data)
-    elif isinstance(data, dict):
-        return {k: _to_cpu(v) for k, v in data.items()}
-    else:
-        return data
diff --git a/libs/visengine/visengine/evaluator/utils.py b/libs/visengine/visengine/evaluator/utils.py
deleted file mode 100644
index c7bb23c..0000000
--- a/libs/visengine/visengine/evaluator/utils.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any
-
-
-def get_metric_value(indicator: str, metrics: dict) -> Any:
-    """Get the metric value specified by an indicator, which can be either a
-    metric name or a full name with an evaluator prefix.
-
-    Args:
-        indicator (str): The metric indicator, which can be the metric name
-            (e.g. 'AP') or the full name with prefix (e.g. 'COCO/AP').
-        metrics (dict): The evaluation results output by the evaluator.
-
-    Returns:
-        Any: The specified metric value.
-    """
-
-    if "/" in indicator:
-        # The indicator is a full name
-        if indicator in metrics:
-            return metrics[indicator]
-        else:
-            raise ValueError(f'The indicator "{indicator}" cannot match any metric in {list(metrics.keys())}')
-    else:
-        # The indicator is a metric name without prefix
-        matched = [k for k in metrics.keys() if k.split("/")[-1] == indicator]
-
-        if not matched:
-            raise ValueError(f"The indicator {indicator} cannot match any metric in {list(metrics.keys())}")
-        elif len(matched) > 1:
-            raise ValueError(f'The indicator "{indicator}" matches multiple metrics {matched}')
-        else:
-            return metrics[matched[0]]
diff --git a/libs/visengine/visengine/fileio/__init__.py b/libs/visengine/visengine/fileio/__init__.py
deleted file mode 100644
index dde2837..0000000
--- a/libs/visengine/visengine/fileio/__init__.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
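# ---------------------------------------------------------------------------
# A minimal single-process sketch (not from the deleted sources) tying
# together BaseMetric, Evaluator and get_metric_value from the evaluator
# package above. SimpleAccuracy is a hypothetical metric, written here only
# to show the process/compute_metrics contract.
from visengine.evaluator import BaseMetric, Evaluator, get_metric_value


class SimpleAccuracy(BaseMetric):
    default_prefix = "demo"

    def process(self, data_batch, data_samples):
        # Stash only what compute_metrics needs; evaluate() gathers
        # self.results across ranks before compute_metrics() runs.
        for sample in data_samples:
            self.results.append(sample["pred"] == sample["gt"])

    def compute_metrics(self, results):
        return {"accuracy": sum(results) / len(results)}


evaluator = Evaluator(SimpleAccuracy())
evaluator.process([{"pred": 1, "gt": 1}, {"pred": 0, "gt": 1}])
metrics = evaluator.evaluate(size=2)  # {'demo/accuracy': 0.5}
assert get_metric_value("accuracy", metrics) == 0.5  # bare-name lookup
# ---------------------------------------------------------------------------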
-from .backends import ( - BaseStorageBackend, - HTTPBackend, - LmdbBackend, - LocalBackend, - MemcachedBackend, - PetrelBackend, - register_backend, -) -from .file_client import FileClient, HardDiskBackend -from .handlers import ( - BaseFileHandler, - JsonHandler, - PickleHandler, - YamlHandler, - register_handler, -) -from .io import ( - copy_if_symlink_fails, - copyfile, - copyfile_from_local, - copyfile_to_local, - copytree, - copytree_from_local, - copytree_to_local, - dump, - exists, - generate_presigned_url, - get, - get_file_backend, - get_local_path, - get_text, - isdir, - isfile, - join_path, - list_dir_or_file, - load, - put, - put_text, - remove, - rmtree, -) -from .parse import dict_from_file, list_from_file - -__all__ = [ - "BaseFileHandler", - "BaseStorageBackend", - "FileClient", - "HTTPBackend", - "HardDiskBackend", - "JsonHandler", - "LmdbBackend", - "LocalBackend", - "MemcachedBackend", - "PetrelBackend", - "PickleHandler", - "YamlHandler", - "copy_if_symlink_fails", - "copyfile", - "copyfile_from_local", - "copyfile_to_local", - "copytree", - "copytree_from_local", - "copytree_to_local", - "dict_from_file", - "dump", - "exists", - "generate_presigned_url", - "get", - "get_file_backend", - "get_local_path", - "get_text", - "isdir", - "isfile", - "join_path", - "list_dir_or_file", - "list_from_file", - "load", - "put", - "put_text", - "register_backend", - "register_handler", - "remove", - "rmtree", -] diff --git a/libs/visengine/visengine/fileio/backends/__init__.py b/libs/visengine/visengine/fileio/backends/__init__.py deleted file mode 100644 index 88b8fbf..0000000 --- a/libs/visengine/visengine/fileio/backends/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .base import BaseStorageBackend -from .http_backend import HTTPBackend -from .lmdb_backend import LmdbBackend -from .local_backend import LocalBackend -from .memcached_backend import MemcachedBackend -from .petrel_backend import PetrelBackend -from .registry_utils import backends, prefix_to_backends, register_backend - -__all__ = [ - "BaseStorageBackend", - "HTTPBackend", - "LmdbBackend", - "LocalBackend", - "MemcachedBackend", - "PetrelBackend", - "backends", - "prefix_to_backends", - "register_backend", -] diff --git a/libs/visengine/visengine/fileio/backends/base.py b/libs/visengine/visengine/fileio/backends/base.py deleted file mode 100644 index 37b4a1f..0000000 --- a/libs/visengine/visengine/fileio/backends/base.py +++ /dev/null @@ -1,43 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from abc import ABCMeta, abstractmethod - -from visengine.logging import print_log - - -class BaseStorageBackend(metaclass=ABCMeta): - """Abstract class of storage backends. - - All backends need to implement two apis: :meth:`get()` and - :meth:`get_text()`. - - - :meth:`get()` reads the file as a byte stream. - - :meth:`get_text()` reads the file as texts. - """ - - # a flag to indicate whether the backend can create a symlink for a file - # This attribute will be deprecated in future. 
-    _allow_symlink = False
-
-    @property
-    def allow_symlink(self):
-        print_log(
-            "allow_symlink will be deprecated in future",
-            logger="current",
-            level=logging.WARNING,
-        )
-        return self._allow_symlink
-
-    @property
-    def name(self):
-        return self.__class__.__name__
-
-    @abstractmethod
-    def get(self, filepath):
-        pass
-
-    @abstractmethod
-    def get_text(self, filepath):
-        pass
diff --git a/libs/visengine/visengine/fileio/backends/http_backend.py b/libs/visengine/visengine/fileio/backends/http_backend.py
deleted file mode 100644
index 334dfbb..0000000
--- a/libs/visengine/visengine/fileio/backends/http_backend.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import tempfile
-from collections.abc import Generator
-from contextlib import contextmanager
-from pathlib import Path
-from urllib.request import urlopen
-
-from .base import BaseStorageBackend
-
-
-class HTTPBackend(BaseStorageBackend):
-    """HTTP and HTTPS storage backend."""
-
-    def get(self, filepath: str) -> bytes:
-        """Read bytes from a given ``filepath``.
-
-        Args:
-            filepath (str): Path to read data.
-
-        Returns:
-            bytes: Expected bytes object.
-
-        Examples:
-            >>> backend = HTTPBackend()
-            >>> backend.get('http://path/of/file')
-            b'hello world'
-        """
-        return urlopen(filepath).read()
-
-    def get_text(self, filepath, encoding="utf-8") -> str:
-        """Read text from a given ``filepath``.
-
-        Args:
-            filepath (str): Path to read data.
-            encoding (str): The encoding format used to open the ``filepath``.
-                Defaults to 'utf-8'.
-
-        Returns:
-            str: Expected text reading from ``filepath``.
-
-        Examples:
-            >>> backend = HTTPBackend()
-            >>> backend.get_text('http://path/of/file')
-            'hello world'
-        """
-        return urlopen(filepath).read().decode(encoding)
-
-    @contextmanager
-    def get_local_path(self, filepath: str) -> Generator[str | Path, None, None]:
-        """Download a file from ``filepath`` to a local temporary directory,
-        and return the temporary path.
-
-        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
-        It can be called with a ``with`` statement, and when exiting the
-        ``with`` statement, the temporary path will be released.
-
-        Args:
-            filepath (str): Download a file from ``filepath``.
-
-        Yields:
-            Iterable[str]: Only yield one temporary path.
-
-        Examples:
-            >>> backend = HTTPBackend()
-            >>> # After exiting the ``with`` clause,
-            >>> # the path will be removed
-            >>> with backend.get_local_path('http://path/of/file') as path:
-            ...     # do something here
-        """
-        try:
-            f = tempfile.NamedTemporaryFile(delete=False)
-            f.write(self.get(filepath))
-            f.close()
-            yield f.name
-        finally:
-            os.remove(f.name)
diff --git a/libs/visengine/visengine/fileio/backends/lmdb_backend.py b/libs/visengine/visengine/fileio/backends/lmdb_backend.py
deleted file mode 100644
index 5d308aa..0000000
--- a/libs/visengine/visengine/fileio/backends/lmdb_backend.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-from pathlib import Path
-
-from .base import BaseStorageBackend
-
-
-class LmdbBackend(BaseStorageBackend):
-    """Lmdb storage backend.
-
-    Args:
-        db_path (str): Lmdb database path.
-        readonly (bool): Lmdb environment parameter. If True, disallow any
-            write operations. Defaults to True.
-        lock (bool): Lmdb environment parameter. If False, when concurrent
-            access occurs, do not lock the database. Defaults to False.
-        readahead (bool): Lmdb environment parameter.
If False, disable the OS - filesystem readahead mechanism, which may improve random read - performance when a database is larger than RAM. Defaults to False. - **kwargs: Keyword arguments passed to `lmdb.open`. - - Attributes: - db_path (str): Lmdb database path. - """ - - def __init__(self, db_path, readonly=True, lock=False, readahead=False, **kwargs): - try: - import lmdb # noqa: F401 - except ImportError: - raise ImportError('Please run "pip install lmdb" to enable LmdbBackend.') - - self.db_path = str(db_path) - self.readonly = readonly - self.lock = lock - self.readahead = readahead - self.kwargs = kwargs - self._client = None - - def get(self, filepath: str | Path) -> bytes: - """Get values according to the filepath. - - Args: - filepath (str or Path): Here, filepath is the lmdb key. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = LmdbBackend('path/to/lmdb') - >>> backend.get('key') - b'hello world' - """ - if self._client is None: - self._client = self._get_client() - - filepath = str(filepath) - with self._client.begin(write=False) as txn: - value_buf = txn.get(filepath.encode("ascii")) - return value_buf - - def get_text(self, filepath, encoding=None): - raise NotImplementedError - - def _get_client(self): - import lmdb - - return lmdb.open( - self.db_path, - readonly=self.readonly, - lock=self.lock, - readahead=self.readahead, - **self.kwargs, - ) - - def __del__(self): - if self._client is not None: - self._client.close() diff --git a/libs/visengine/visengine/fileio/backends/local_backend.py b/libs/visengine/visengine/fileio/backends/local_backend.py deleted file mode 100644 index 288e34f..0000000 --- a/libs/visengine/visengine/fileio/backends/local_backend.py +++ /dev/null @@ -1,548 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -import shutil -from collections.abc import Generator, Iterator -from contextlib import contextmanager -from pathlib import Path -import io -from PIL import Image, ImageOps - -from visengine.utils.path import mkdir_or_exist - -from .base import BaseStorageBackend - - -class LocalBackend(BaseStorageBackend): - """Raw local storage backend.""" - - _allow_symlink = True - use_exif = False - - def get(self, filepath: str | Path) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get(filepath) - b'hello world' - """ - if self.use_exif: - with open(filepath, "rb") as f: - image = Image.open(f) - image = ImageOps.exif_transpose(image) - buffer = io.BytesIO() - image.save(buffer, format=image.format or "JPEG") - value = buffer.getvalue() - else: - with open(filepath, "rb") as f: - value = f.read() - return value - - def get_text(self, filepath: str | Path, encoding: str = "utf-8") -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. 
- - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get_text(filepath) - 'hello world' - """ - with open(filepath, encoding=encoding) as f: - text = f.read() - return text - - def put(self, obj: bytes, filepath: str | Path) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put(b'hello world', filepath) - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, "wb") as f: - f.write(obj) - - def put_text(self, obj: str, filepath: str | Path, encoding: str = "utf-8") -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put_text('hello world', filepath) - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, "w", encoding=encoding) as f: - f.write(obj) - - def exists(self, filepath: str | Path) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.exists(filepath) - True - """ - return osp.exists(filepath) - - def isdir(self, filepath: str | Path) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/dir' - >>> backend.isdir(filepath) - True - """ - return osp.isdir(filepath) - - def isfile(self, filepath: str | Path) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.isfile(filepath) - True - """ - return osp.isfile(filepath) - - def join_path(self, filepath: str | Path, *filepaths: str | Path) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - - Examples: - >>> backend = LocalBackend() - >>> filepath1 = '/path/of/dir1' - >>> filepath2 = 'dir2' - >>> filepath3 = 'path/of/file' - >>> backend.join_path(filepath1, filepath2, filepath3) - '/path/of/dir/dir2/path/of/file' - """ - # TODO, if filepath or filepaths are Path, should return Path - return osp.join(filepath, *filepaths) - - @contextmanager - def get_local_path( - self, - filepath: str | Path, - ) -> Generator[str | Path, None, None]: - """Only for unified API and do nothing. 
- - Args: - filepath (str or Path): Path to be read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> backend = LocalBackend() - >>> with backend.get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - """ - yield filepath - - def copyfile( - self, - src: str | Path, - dst: str | Path, - ) -> str: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> backend.copyfile(src, dst) - '/path1/of/dir/file' - """ - return shutil.copy(src, dst) - - def copytree( - self, - src: str | Path, - dst: str | Path, - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - TODO: Whether to support dirs_exist_ok parameter. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree(src, dst) - '/path/of/dir2' - """ - return shutil.copytree(src, dst) - - def copyfile_from_local( - self, - src: str | Path, - dst: str | Path, - ) -> str: - """Copy a local file src to dst and return the destination file. Same - as :meth:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_from_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_from_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_from_local( - self, - src: str | Path, - dst: str | Path, - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. Same as - :meth:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. 
- - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def copyfile_to_local( - self, - src: str | Path, - dst: str | Path, - ) -> str: - """Copy the file src to local dst and return the destination file. Same - as :meth:`copyfile`. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_to_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_to_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_to_local( - self, - src: str | Path, - dst: str | Path, - ) -> str: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def remove(self, filepath: str | Path) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - - Raises: - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.remove(filepath) - """ - if not self.exists(filepath): - raise FileNotFoundError(f"filepath {filepath} does not exist") - - if self.isdir(filepath): - raise IsADirectoryError("filepath should be a file") - - os.remove(filepath) - - def rmtree(self, dir_path: str | Path) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - - Examples: - >>> dir_path = '/path/of/dir' - >>> backend.rmtree(dir_path) - """ - shutil.rmtree(dir_path) - - def copy_if_symlink_fails( - self, - src: str | Path, - dst: str | Path, - ) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directly copy src - to dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - - Returns: - bool: Return True if successfully create a symbolic link pointing - to src. Otherwise, return False. 
- - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> backend.copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> backend.copy_if_symlink_fails(src, dst) - True - """ - try: - os.symlink(src, dst) - return True - except Exception: - if self.isfile(src): - self.copyfile(src, dst) - else: - self.copytree(src, dst) - return False - - def list_dir_or_file( - self, - dir_path: str | Path, - list_dir: bool = True, - list_file: bool = True, - suffix: str | tuple[str] | None = None, - recursive: bool = False, - ) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> backend = LocalBackend() - >>> dir_path = '/path/of/dir' - >>> # list those files and directories in current directory - >>> for file_path in backend.list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ - if list_dir and suffix is not None: - raise TypeError("`suffix` should be None when `list_dir` is True") - - if (suffix is not None) and not isinstance(suffix, str | tuple): - raise TypeError("`suffix` must be a string or tuple of strings") - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith(".") and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - if (suffix is None or rel_path.endswith(suffix)) and list_file: - yield rel_path - elif osp.isdir(entry.path): - if list_dir: - rel_dir = osp.relpath(entry.path, root) - yield rel_dir - if recursive: - yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive) - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/libs/visengine/visengine/fileio/backends/memcached_backend.py b/libs/visengine/visengine/fileio/backends/memcached_backend.py deleted file mode 100644 index 8ecf664..0000000 --- a/libs/visengine/visengine/fileio/backends/memcached_backend.py +++ /dev/null @@ -1,59 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from pathlib import Path - -from .base import BaseStorageBackend - - -class MemcachedBackend(BaseStorageBackend): - """Memcached storage backend. - - Attributes: - server_list_cfg (str): Config file for memcached server list. - client_cfg (str): Config file for memcached client. 
-        sys_path (str, optional): Additional path to be appended to `sys.path`.
-            Defaults to None.
-    """
-
-    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
-        if sys_path is not None:
-            import sys
-
-            sys.path.append(sys_path)
-        try:
-            import mc
-        except ImportError:
-            raise ImportError("Please install memcached to enable MemcachedBackend.")
-
-        self.server_list_cfg = server_list_cfg
-        self.client_cfg = client_cfg
-        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
-        # mc.pyvector serves as a pointer to a memory cache
-        self._mc_buffer = mc.pyvector()
-
-    def get(self, filepath: str | Path):
-        """Get values according to the filepath.
-
-        Args:
-            filepath (str or Path): Path to read data.
-
-        Returns:
-            bytes: Expected bytes object.
-
-        Examples:
-            >>> server_list_cfg = '/path/of/server_list.conf'
-            >>> client_cfg = '/path/of/mc.conf'
-            >>> backend = MemcachedBackend(server_list_cfg, client_cfg)
-            >>> backend.get('/path/of/file')
-            b'hello world'
-        """
-        filepath = str(filepath)
-        import mc
-
-        self._client.Get(filepath, self._mc_buffer)
-        value_buf = mc.ConvertBuffer(self._mc_buffer)
-        return value_buf
-
-    def get_text(self, filepath, encoding=None):
-        raise NotImplementedError
diff --git a/libs/visengine/visengine/fileio/backends/petrel_backend.py b/libs/visengine/visengine/fileio/backends/petrel_backend.py
deleted file mode 100644
index adebdb6..0000000
--- a/libs/visengine/visengine/fileio/backends/petrel_backend.py
+++ /dev/null
@@ -1,762 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import os.path as osp
-import re
-import tempfile
-from collections.abc import Generator, Iterator
-from contextlib import contextmanager
-from pathlib import Path
-from shutil import SameFileError
-
-import visengine
-from visengine.utils import has_method
-
-from .base import BaseStorageBackend
-
-
-class PetrelBackend(BaseStorageBackend):
-    """Petrel storage backend (for internal usage).
-
-    PetrelBackend supports reading and writing data to multiple clusters.
-    If the file path contains the cluster name, PetrelBackend will read data
-    from the specified cluster or write data to it. Otherwise, PetrelBackend
-    will access the default cluster.
-
-    Args:
-        path_mapping (dict, optional): Path mapping dict from local path to
-            Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
-            ``filepath`` will be replaced by ``dst``. Defaults to None.
-        enable_mc (bool, optional): Whether to enable memcached support.
-            Defaults to True.
-        conf_path (str, optional): Config path of Petrel client. Default: None.
-            `New in version 0.3.3`.
-
-    Examples:
-        >>> backend = PetrelBackend()
-        >>> filepath1 = 'petrel://path/of/file'
-        >>> filepath2 = 'cluster-name:petrel://path/of/file'
-        >>> backend.get(filepath1)  # get data from default cluster
-        >>> backend.get(filepath2)  # get data from 'cluster-name' cluster
-    """
-
-    def __init__(
-        self,
-        path_mapping: dict | None = None,
-        enable_mc: bool = True,
-        conf_path: str | None = None,
-    ):
-        try:
-            from petrel_client import client
-        except ImportError:
-            raise ImportError("Please install petrel_client to enable PetrelBackend.")
-
-        self._client = client.Client(conf_path=conf_path, enable_mc=enable_mc)
-        assert isinstance(path_mapping, dict) or path_mapping is None
-        self.path_mapping = path_mapping
-
-    def _map_path(self, filepath: str | Path) -> str:
-        """Map ``filepath`` to a string path whose prefix will be replaced by
-        :attr:`self.path_mapping`.
- - Args: - filepath (str or Path): Path to be mapped. - """ - filepath = str(filepath) - if self.path_mapping is not None: - for k, v in self.path_mapping.items(): - filepath = filepath.replace(k, v, 1) - return filepath - - def _format_path(self, filepath: str) -> str: - """Convert a ``filepath`` to standard format of petrel oss. - - If the ``filepath`` is concatenated by ``os.path.join``, in a Windows - environment, the ``filepath`` will be the format of - 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the - above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. - - Args: - filepath (str): Path to be formatted. - """ - return re.sub(r"\\+", "/", filepath) - - def _replace_prefix(self, filepath: str | Path) -> str: - filepath = str(filepath) - return filepath.replace("petrel://", "s3://") - - def get(self, filepath: str | Path) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Return bytes read from filepath. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.get(filepath) - b'hello world' - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - value = self._client.Get(filepath) - return value - - def get_text( - self, - filepath: str | Path, - encoding: str = "utf-8", - ) -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.get_text(filepath) - 'hello world' - """ - return str(self.get(filepath), encoding=encoding) - - def put(self, obj: bytes, filepath: str | Path) -> None: - """Write bytes to a given ``filepath``. - - Args: - obj (bytes): Data to be saved. - filepath (str or Path): Path to write data. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.put(b'hello world', filepath) - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - self._client.put(filepath, obj) - - def put_text( - self, - obj: str, - filepath: str | Path, - encoding: str = "utf-8", - ) -> None: - """Write text to a given ``filepath``. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to encode the ``obj``. - Defaults to 'utf-8'. - - Examples: - >>> backend = PetrelBackend() - >>> filepath = 'petrel://path/of/file' - >>> backend.put_text('hello world', filepath) - """ - self.put(bytes(obj, encoding=encoding), filepath) - - def exists(self, filepath: str | Path) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. 
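# ---------------------------------------------------------------------------
# PetrelBackend requires the internal petrel_client SDK, so here is a
# standalone, stdlib-only restatement (not from the deleted sources) of what
# the three path helpers above (_map_path, _format_path, _replace_prefix)
# do; the mapping values are illustrative.
import re

path_mapping = {"data/": "petrel://bucket/data/"}


def map_path(filepath: str) -> str:
    # _map_path: rewrite the first occurrence of each configured prefix.
    for k, v in path_mapping.items():
        filepath = filepath.replace(k, v, 1)
    return filepath


def format_path(filepath: str) -> str:
    # _format_path: collapse backslashes produced by os.path.join on Windows.
    return re.sub(r"\\+", "/", filepath)


def replace_prefix(filepath: str) -> str:
    # _replace_prefix: the underlying SDK expects 's3://', not 'petrel://'.
    return filepath.replace("petrel://", "s3://")


p = replace_prefix(format_path(map_path("data/train/img.jpg")))
assert p == "s3://bucket/data/train/img.jpg"
# ---------------------------------------------------------------------------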
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> filepath = 'petrel://path/of/file'
-            >>> backend.exists(filepath)
-            True
-        """
-        if not (has_method(self._client, "contains") and has_method(self._client, "isdir")):
-            raise NotImplementedError(
-                "Current version of Petrel Python SDK has not supported "
-                "the `contains` and `isdir` methods, please use a higher "
-                "version or dev branch instead."
-            )
-
-        filepath = self._map_path(filepath)
-        filepath = self._format_path(filepath)
-        filepath = self._replace_prefix(filepath)
-        return self._client.contains(filepath) or self._client.isdir(filepath)
-
-    def isdir(self, filepath: str | Path) -> bool:
-        """Check whether a file path is a directory.
-
-        Args:
-            filepath (str or Path): Path to be checked whether it is a
-                directory.
-
-        Returns:
-            bool: Return ``True`` if ``filepath`` points to a directory,
-            ``False`` otherwise.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> filepath = 'petrel://path/of/dir'
-            >>> backend.isdir(filepath)
-            True
-        """
-        if not has_method(self._client, "isdir"):
-            raise NotImplementedError(
                "Current version of Petrel Python SDK has not supported the `isdir` method, please use a higher version or dev branch instead."
-            )
-
-        filepath = self._map_path(filepath)
-        filepath = self._format_path(filepath)
-        filepath = self._replace_prefix(filepath)
-        return self._client.isdir(filepath)
-
-    def isfile(self, filepath: str | Path) -> bool:
-        """Check whether a file path is a file.
-
-        Args:
-            filepath (str or Path): Path to be checked whether it is a file.
-
-        Returns:
-            bool: Return ``True`` if ``filepath`` points to a file, ``False``
-            otherwise.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> filepath = 'petrel://path/of/file'
-            >>> backend.isfile(filepath)
-            True
-        """
-        if not has_method(self._client, "contains"):
-            raise NotImplementedError(
                "Current version of Petrel Python SDK has not supported the `contains` method, please use a higher version or dev branch instead."
-            )
-
-        filepath = self._map_path(filepath)
-        filepath = self._format_path(filepath)
-        filepath = self._replace_prefix(filepath)
-        return self._client.contains(filepath)
-
-    def join_path(
-        self,
-        filepath: str | Path,
-        *filepaths: str | Path,
-    ) -> str:
-        r"""Concatenate all file paths.
-
-        Join one or more filepath components intelligently. The return value
-        is the concatenation of filepath and any members of \*filepaths.
-
-        Args:
-            filepath (str or Path): Path to be concatenated.
-
-        Returns:
-            str: The result after concatenation.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> filepath = 'petrel://path/of/file'
-            >>> backend.join_path(filepath, 'another/path')
-            'petrel://path/of/file/another/path'
-            >>> backend.join_path(filepath, '/another/path')
-            'petrel://path/of/file/another/path'
-        """
-        filepath = self._format_path(self._map_path(filepath))
-        if filepath.endswith("/"):
-            filepath = filepath[:-1]
-        formatted_paths = [filepath]
-        for path in filepaths:
-            formatted_path = self._format_path(self._map_path(path))
-            formatted_paths.append(formatted_path.lstrip("/"))
-
-        return "/".join(formatted_paths)
-
-    @contextmanager
-    def get_local_path(
-        self,
-        filepath: str | Path,
-    ) -> Generator[str | Path, None, None]:
-        """Download a file from ``filepath`` to a local temporary directory,
-        and return the temporary path.
-
-        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
-        It can be called with a ``with`` statement, and when exiting the
-        ``with`` statement, the temporary path will be released.
-
-        Args:
-            filepath (str or Path): Download a file from ``filepath``.
-
-        Yields:
-            Iterable[str]: Only yield one temporary path.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> # After exiting the ``with`` clause,
-            >>> # the path will be removed
-            >>> filepath = 'petrel://path/of/file'
-            >>> with backend.get_local_path(filepath) as path:
-            ...     # do something here
-        """
-        assert self.isfile(filepath)
-        try:
-            f = tempfile.NamedTemporaryFile(delete=False)
-            f.write(self.get(filepath))
-            f.close()
-            yield f.name
-        finally:
-            os.remove(f.name)
-
-    def copyfile(
-        self,
-        src: str | Path,
-        dst: str | Path,
-    ) -> str:
-        """Copy a file src to dst and return the destination file.
-
-        src and dst should have the same prefix. If dst specifies a directory,
-        the file will be copied into dst using the base filename from src. If
-        dst specifies a file that already exists, it will be replaced.
-
-        Args:
-            src (str or Path): A file to be copied.
-            dst (str or Path): Copy file to dst.
-
-        Returns:
-            str: The destination file.
-
-        Raises:
-            SameFileError: If src and dst are the same file, a SameFileError
-                will be raised.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> # dst is a file
-            >>> src = 'petrel://path/of/file'
-            >>> dst = 'petrel://path/of/file1'
-            >>> backend.copyfile(src, dst)
-            'petrel://path/of/file1'
-
-            >>> # dst is a directory
-            >>> dst = 'petrel://path/of/dir'
-            >>> backend.copyfile(src, dst)
-            'petrel://path/of/dir/file'
-        """
-        src = self._format_path(self._map_path(src))
-        dst = self._format_path(self._map_path(dst))
-        if self.isdir(dst):
-            dst = self.join_path(dst, src.split("/")[-1])
-
-        if src == dst:
-            raise SameFileError("src and dst should not be same")
-
-        self.put(self.get(src), dst)
-        return dst
-
-    def copytree(
-        self,
-        src: str | Path,
-        dst: str | Path,
-    ) -> str:
-        """Recursively copy an entire directory tree rooted at src to a
-        directory named dst and return the destination directory.
-
-        src and dst should have the same prefix.
-
-        Args:
-            src (str or Path): A directory to be copied.
-            dst (str or Path): Copy directory to dst.
-            backend_args (dict, optional): Arguments to instantiate the
-                prefix of uri corresponding backend. Defaults to None.
-
-        Returns:
-            str: The destination directory.
-
-        Raises:
-            FileExistsError: If dst had already existed, a FileExistsError will
-                be raised.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> src = 'petrel://path/of/dir'
-            >>> dst = 'petrel://path/of/dir1'
-            >>> backend.copytree(src, dst)
-            'petrel://path/of/dir1'
-        """
-        src = self._format_path(self._map_path(src))
-        dst = self._format_path(self._map_path(dst))
-
-        if self.exists(dst):
-            raise FileExistsError("dst should not exist")
-
-        for path in self.list_dir_or_file(src, list_dir=False, recursive=True):
-            src_path = self.join_path(src, path)
-            dst_path = self.join_path(dst, path)
-            self.put(self.get(src_path), dst_path)
-
-        return dst
-
-    def copyfile_from_local(
-        self,
-        src: str | Path,
-        dst: str | Path,
-    ) -> str:
-        """Upload a local file src to dst and return the destination file.
-
-        Args:
-            src (str or Path): A local file to be copied.
-            dst (str or Path): Copy file to dst.
-            backend_args (dict, optional): Arguments to instantiate the
-                prefix of uri corresponding backend. Defaults to None.
-
-        Returns:
-            str: If dst specifies a directory, the file will be copied into dst
-                using the base filename from src.
- - Examples: - >>> backend = PetrelBackend() - >>> # dst is a file - >>> src = 'path/of/your/file' - >>> dst = 'petrel://path/of/file1' - >>> backend.copyfile_from_local(src, dst) - 'petrel://path/of/file1' - - >>> # dst is a directory - >>> dst = 'petrel://path/of/dir' - >>> backend.copyfile_from_local(src, dst) - 'petrel://path/of/dir/file' - """ - dst = self._format_path(self._map_path(dst)) - if self.isdir(dst): - dst = self.join_path(dst, osp.basename(src)) - - with open(src, "rb") as f: - self.put(f.read(), dst) - - return dst - - def copytree_from_local( - self, - src: str | Path, - dst: str | Path, - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = PetrelBackend() - >>> src = 'path/of/your/dir' - >>> dst = 'petrel://path/of/dir1' - >>> backend.copytree_from_local(src, dst) - 'petrel://path/of/dir1' - """ - dst = self._format_path(self._map_path(dst)) - if self.exists(dst): - raise FileExistsError("dst should not exist") - - src = str(src) - - for cur_dir, _, files in os.walk(src): - for f in files: - src_path = osp.join(cur_dir, f) - dst_path = self.join_path(dst, src_path.replace(src, "")) - self.copyfile_from_local(src_path, dst_path) - - return dst - - def copyfile_to_local( - self, - src: str | Path, - dst: str | Path, - ) -> str | Path: - """Copy the file src to local dst and return the destination file. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = PetrelBackend() - >>> # dst is a file - >>> src = 'petrel://path/of/file' - >>> dst = 'path/of/your/file' - >>> backend.copyfile_to_local(src, dst) - 'path/of/your/file' - - >>> # dst is a directory - >>> dst = 'path/of/your/dir' - >>> backend.copyfile_to_local(src, dst) - 'path/of/your/dir/file' - """ - if osp.isdir(dst): - basename = osp.basename(src) - if isinstance(dst, str): - dst = osp.join(dst, basename) - else: - assert isinstance(dst, Path) - dst = dst / basename - - with open(dst, "wb") as f: - f.write(self.get(src)) - - return dst - - def copytree_to_local( - self, - src: str | Path, - dst: str | Path, - ) -> str | Path: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. 
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> src = 'petrel://path/of/dir'
-            >>> dst = 'path/of/your/dir'
-            >>> backend.copytree_to_local(src, dst)
-            'path/of/your/dir'
-        """
-        for path in self.list_dir_or_file(src, list_dir=False, recursive=True):
-            dst_path = osp.join(dst, path)
-            visengine.mkdir_or_exist(osp.dirname(dst_path))
-            with open(dst_path, "wb") as f:
-                f.write(self.get(self.join_path(src, path)))
-
-        return dst
-
-    def remove(self, filepath: str | Path) -> None:
-        """Remove a file.
-
-        Args:
-            filepath (str or Path): Path to be removed.
-
-        Raises:
-            FileNotFoundError: If filepath does not exist, a FileNotFoundError
-                will be raised.
-            IsADirectoryError: If filepath is a directory, an IsADirectoryError
-                will be raised.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> filepath = 'petrel://path/of/file'
-            >>> backend.remove(filepath)
-        """
-        if not has_method(self._client, "delete"):
-            raise NotImplementedError(
                "Current version of Petrel Python SDK has not supported the `delete` method, please use a higher version or dev branch instead."
-            )
-
-        if not self.exists(filepath):
-            raise FileNotFoundError(f"filepath {filepath} does not exist")
-
-        if self.isdir(filepath):
-            raise IsADirectoryError("filepath should be a file")
-
-        filepath = self._map_path(filepath)
-        filepath = self._format_path(filepath)
-        filepath = self._replace_prefix(filepath)
-        self._client.delete(filepath)
-
-    def rmtree(self, dir_path: str | Path) -> None:
-        """Recursively delete a directory tree.
-
-        Args:
-            dir_path (str or Path): A directory to be removed.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> dir_path = 'petrel://path/of/dir'
-            >>> backend.rmtree(dir_path)
-        """
-        for path in self.list_dir_or_file(dir_path, list_dir=False, recursive=True):
-            filepath = self.join_path(dir_path, path)
-            self.remove(filepath)
-
-    def copy_if_symlink_fails(
-        self,
-        src: str | Path,
-        dst: str | Path,
-    ) -> bool:
-        """Create a symbolic link pointing to src named dst.
-
-        Directly copy src to dst because PetrelBackend does not support
-        creating a symbolic link.
-
-        Args:
-            src (str or Path): A file or directory to be copied.
-            dst (str or Path): Copy a file or directory to dst.
-            backend_args (dict, optional): Arguments to instantiate the
-                prefix of uri corresponding backend. Defaults to None.
-
-        Returns:
-            bool: Return False because PetrelBackend does not support
-                creating a symbolic link.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> src = 'petrel://path/of/file'
-            >>> dst = 'petrel://path/of/your/file'
-            >>> backend.copy_if_symlink_fails(src, dst)
-            False
-            >>> src = 'petrel://path/of/dir'
-            >>> dst = 'petrel://path/of/your/dir'
-            >>> backend.copy_if_symlink_fails(src, dst)
-            False
-        """
-        if self.isfile(src):
-            self.copyfile(src, dst)
-        else:
-            self.copytree(src, dst)
-        return False
-
-    def list_dir_or_file(
-        self,
-        dir_path: str | Path,
-        list_dir: bool = True,
-        list_file: bool = True,
-        suffix: str | tuple[str] | None = None,
-        recursive: bool = False,
-    ) -> Iterator[str]:
-        """Scan a directory to find the interested directories or files in
-        arbitrary order.
-
-        Note:
-            Petrel has no concept of directories but it simulates the directory
-            hierarchy in the filesystem through public prefixes. In addition,
-            if the returned path ends with '/', it means the path is a public
-            prefix which is a logical directory.
-
-        Note:
-            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
-            In addition, the returned path of a directory will not contain the
-            suffix '/', which is consistent with other backends.
-
-        Args:
-            dir_path (str or Path): Path of the directory.
-            list_dir (bool): List the directories. Defaults to True.
-            list_file (bool): List the path of files. Defaults to True.
-            suffix (str or tuple[str], optional): File suffix
-                that we are interested in. Defaults to None.
-            recursive (bool): If set to True, recursively scan the
-                directory. Defaults to False.
-
-        Yields:
-            Iterable[str]: A relative path to ``dir_path``.
-
-        Examples:
-            >>> backend = PetrelBackend()
-            >>> dir_path = 'petrel://path/of/dir'
-            >>> # list those files and directories in current directory
-            >>> for file_path in backend.list_dir_or_file(dir_path):
-            ...     print(file_path)
-            >>> # only list files
-            >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False):
-            ...     print(file_path)
-            >>> # only list directories
-            >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False):
-            ...     print(file_path)
-            >>> # only list files ending with specified suffixes
-            >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'):
-            ...     print(file_path)
-            >>> # list all files and directory recursively
-            >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True):
-            ...     print(file_path)
-        """
-        if not has_method(self._client, "list"):
-            raise NotImplementedError(
                "Current version of Petrel Python SDK has not supported the `list` method, please use a higher version or dev branch instead."
-            )
-
-        dir_path = self._map_path(dir_path)
-        dir_path = self._format_path(dir_path)
-        dir_path = self._replace_prefix(dir_path)
-        if list_dir and suffix is not None:
-            raise TypeError("`list_dir` should be False when `suffix` is not None")
-
-        if (suffix is not None) and not isinstance(suffix, str | tuple):
-            raise TypeError("`suffix` must be a string or tuple of strings")
-
-        # Petrel's simulated directory hierarchy assumes that directory paths
-        # should end with `/`
-        if not dir_path.endswith("/"):
-            dir_path += "/"
-
-        root = dir_path
-
-        def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
-            for path in self._client.list(dir_path):
-                # the `self.isdir` is not used here to determine whether path
-                # is a directory, because `self.isdir` relies on
-                # `self._client.list`
-                if path.endswith("/"):  # a directory path
-                    next_dir_path = self.join_path(dir_path, path)
-                    if list_dir:
-                        # get the relative path and exclude the last
-                        # character '/'
-                        rel_dir = next_dir_path[len(root) : -1]
-                        yield rel_dir
-                    if recursive:
-                        yield from _list_dir_or_file(next_dir_path, list_dir, list_file, suffix, recursive)
-                else:  # a file path
-                    absolute_path = self.join_path(dir_path, path)
-                    rel_path = absolute_path[len(root) :]
-                    if (suffix is None or rel_path.endswith(suffix)) and list_file:
-                        yield rel_path
-
-        return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
-
-    def generate_presigned_url(self, url: str, client_method: str = "get_object", expires_in: int = 3600) -> str:
-        """Generate the presigned url of a video stream, which can be passed to
-        mmcv.VideoReader. Now only works on the Petrel backend.
-
-        Note:
-            Now only works on the Petrel backend.
-
-        Args:
-            url (str): Url of video stream.
-            client_method (str): Method of client, 'get_object' or
-                'put_object'. Default: 'get_object'.
-            expires_in (int): expires, in seconds. Default: 3600.
-
-        Returns:
-            str: Generated presigned url.
- """ - return self._client.generate_presigned_url(url, client_method, expires_in) diff --git a/libs/visengine/visengine/fileio/backends/registry_utils.py b/libs/visengine/visengine/fileio/backends/registry_utils.py deleted file mode 100644 index 7447a8b..0000000 --- a/libs/visengine/visengine/fileio/backends/registry_utils.py +++ /dev/null @@ -1,121 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import inspect - -from .base import BaseStorageBackend -from .http_backend import HTTPBackend -from .lmdb_backend import LmdbBackend -from .local_backend import LocalBackend -from .memcached_backend import MemcachedBackend -from .petrel_backend import PetrelBackend - -backends: dict = {} -prefix_to_backends: dict = {} - - -def _register_backend( - name: str, - backend: type[BaseStorageBackend], - force: bool = False, - prefixes: str | list | tuple | None = None, -): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (BaseStorageBackend): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. - """ - global backends, prefix_to_backends - - if not isinstance(name, str): - raise TypeError(f"the backend name should be a string, but got {type(name)}") - - if not inspect.isclass(backend): - raise TypeError(f"backend should be a class, but got {type(backend)}") - if not issubclass(backend, BaseStorageBackend): - raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") - - if name in backends and not force: - raise ValueError( - f'{name} is already registered as a storage backend, add "force=True" if you want to override it' - ) - backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, list | tuple) - - for prefix in prefixes: - if prefix in prefix_to_backends and not force: - raise ValueError( - f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it' - ) - - prefix_to_backends[prefix] = backend - - -def register_backend( - name: str, - backend: type[BaseStorageBackend] | None = None, - force: bool = False, - prefixes: str | list | tuple | None = None, -): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. - - This method can be used as a normal method or a decorator. - - Examples: - - >>> class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... return filepath - >>> register_backend('new', NewBackend) - - >>> @register_backend('new') - ... class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... 
return filepath - """ - if backend is not None: - _register_backend(name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - _register_backend(name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - -register_backend("local", LocalBackend, prefixes="") -register_backend("memcached", MemcachedBackend) -register_backend("lmdb", LmdbBackend) -# To avoid breaking backward Compatibility, 's3' is also used as a -# prefix for PetrelBackend -register_backend("petrel", PetrelBackend, prefixes=["petrel", "s3"]) -register_backend("http", HTTPBackend, prefixes=["http", "https"]) diff --git a/libs/visengine/visengine/fileio/file_client.py b/libs/visengine/visengine/fileio/file_client.py deleted file mode 100644 index 8906522..0000000 --- a/libs/visengine/visengine/fileio/file_client.py +++ /dev/null @@ -1,460 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import logging -from collections.abc import Generator, Iterator -from contextlib import contextmanager -from pathlib import Path -from typing import Any - -from visengine.logging import print_log -from visengine.utils import is_filepath - -from .backends import ( - BaseStorageBackend, - HTTPBackend, - LmdbBackend, - LocalBackend, - MemcachedBackend, - PetrelBackend, -) - - -class HardDiskBackend(LocalBackend): - """Raw hard disks storage backend.""" - - def __init__(self, use_exif: bool = False) -> None: - print_log( - '"HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.', - logger="current", - level=logging.WARNING, - ) - self.use_exif = use_exif - - @property - def name(self): - return self.__class__.__name__ - - -class FileClient: - """A general file client to access files in different backends. - - The client loads a file or text in a specified backend from its path - and returns it as a binary or text file. There are two ways to choose a - backend, the name of backend and the prefix of path. Although both of them - can be used to choose a storage backend, ``backend`` has a higher priority - that is if they are all set, the storage backend will be chosen by the - backend argument. If they are all `None`, the disk backend will be chosen. - Note that It can also register other backend accessor with a given name, - prefixes, and backend class. In addition, We use the singleton pattern to - avoid repeated object creation. If the arguments are the same, the same - object will be returned. - - Warning: - `FileClient` will be deprecated in future. Please use io functions - in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io - - Args: - backend (str, optional): The storage backend type. Options are "disk", - "memcached", "lmdb", "http" and "petrel". Defaults to None. - prefix (str, optional): The prefix of the registered storage backend. - Options are "s3", "http", "https". Defaults to None. - - Examples: - >>> # only set backend - >>> file_client = FileClient(backend='petrel') - >>> # only set prefix - >>> file_client = FileClient(prefix='s3') - >>> # set both backend and prefix but use backend to choose client - >>> file_client = FileClient(backend='petrel', prefix='s3') - >>> # if the arguments are the same, the same object is returned - >>> file_client1 = FileClient(backend='petrel') - >>> file_client1 is file_client - True - - Attributes: - client (:obj:`BaseStorageBackend`): The backend object. 
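As an aside on the singleton behaviour the class docstring above describes, here is a minimal sketch. It assumes the legacy `visengine.fileio.FileClient` import path shown in this diff and uses only the local "disk" backend, so no Petrel SDK is needed; `use_exif` is the one option `HardDiskBackend` accepts.

```python
from visengine.fileio import FileClient

# Identical constructor arguments return the cached instance.
client_a = FileClient(backend="disk")
client_b = FileClient(backend="disk")
assert client_a is client_b

# Different kwargs produce (and cache) a distinct instance.
client_c = FileClient(backend="disk", use_exif=True)
assert client_c is not client_a
```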
- """ - - _backends = { - "disk": HardDiskBackend, - "memcached": MemcachedBackend, - "lmdb": LmdbBackend, - "petrel": PetrelBackend, - "http": HTTPBackend, - } - - _prefix_to_backends: dict = { - "s3": PetrelBackend, - "petrel": PetrelBackend, - "http": HTTPBackend, - "https": HTTPBackend, - } - - _instances: dict = {} - - client: Any - - def __new__(cls, backend=None, prefix=None, **kwargs): - print_log( - '"FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io', - logger="current", - level=logging.WARNING, - ) - if backend is None and prefix is None: - backend = "disk" - if backend is not None and backend not in cls._backends: - raise ValueError( - f"Backend {backend} is not supported. Currently supported ones are {list(cls._backends.keys())}" - ) - if prefix is not None and prefix not in cls._prefix_to_backends: - raise ValueError( - f"prefix {prefix} is not supported. Currently supported ones are {list(cls._prefix_to_backends.keys())}" - ) - - # concatenate the arguments to a unique key for determining whether - # objects with the same arguments were created - arg_key = f"{backend}:{prefix}" - for key, value in kwargs.items(): - arg_key += f":{key}:{value}" - - # if a backend was overridden, it will create a new object - if arg_key in cls._instances: - _instance = cls._instances[arg_key] - else: - # create a new object and put it to _instance - _instance = super().__new__(cls) - if backend is not None: - _instance.client = cls._backends[backend](**kwargs) - else: - _instance.client = cls._prefix_to_backends[prefix](**kwargs) - - cls._instances[arg_key] = _instance - - return _instance - - @property - def name(self): - return self.client.name - - @property - def allow_symlink(self): - return self.client.allow_symlink - - @staticmethod - def parse_uri_prefix(uri: str | Path) -> str | None: - """Parse the prefix of a uri. - - Args: - uri (str | Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> FileClient.parse_uri_prefix('s3://path/of/your/file') - 's3' - - Returns: - str | None: Return the prefix of uri if the uri contains '://' else - ``None``. - """ - assert is_filepath(uri) - uri = str(uri) - if "://" not in uri: - return None - else: - prefix, _ = uri.split("://") - # In the case of PetrelBackend, the prefix may contains the cluster - # name like clusterName:s3 - if ":" in prefix: - _, prefix = prefix.split(":") - return prefix - - @classmethod - def infer_client(cls, file_client_args: dict | None = None, uri: str | Path | None = None) -> "FileClient": - """Infer a suitable file client based on the URI and arguments. - - Args: - file_client_args (dict, optional): Arguments to instantiate a - FileClient. Defaults to None. - uri (str | Path, optional): Uri to be parsed that contains the file - prefix. Defaults to None. - - Examples: - >>> uri = 's3://path/of/your/file' - >>> file_client = FileClient.infer_client(uri=uri) - >>> file_client_args = {'backend': 'petrel'} - >>> file_client = FileClient.infer_client(file_client_args) - - Returns: - FileClient: Instantiated FileClient object. 
- """ - assert file_client_args is not None or uri is not None - if file_client_args is None: - file_prefix = cls.parse_uri_prefix(uri) # type: ignore - return cls(prefix=file_prefix) - else: - return cls(**file_client_args) - - @classmethod - def _register_backend(cls, name, backend, force=False, prefixes=None): - if not isinstance(name, str): - raise TypeError(f"the backend name should be a string, but got {type(name)}") - if not inspect.isclass(backend): - raise TypeError(f"backend should be a class but got {type(backend)}") - if not issubclass(backend, BaseStorageBackend): - raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") - if not force and name in cls._backends: - raise KeyError( - f'{name} is already registered as a storage backend, add "force=True" if you want to override it' - ) - - if name in cls._backends and force: - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, cls._backends[name]): - cls._instances.pop(arg_key) - cls._backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, list | tuple) - for prefix in prefixes: - if prefix not in cls._prefix_to_backends: - cls._prefix_to_backends[prefix] = backend - elif (prefix in cls._prefix_to_backends) and force: - overridden_backend = cls._prefix_to_backends[prefix] - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, overridden_backend): - cls._instances.pop(arg_key) - else: - raise KeyError( - f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it' - ) - - @classmethod - def register_backend(cls, name, backend=None, force=False, prefixes=None): - """Register a backend to FileClient. - - This method can be used as a normal class method or a decorator. - - .. code-block:: python - - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - FileClient.register_backend('new', NewBackend) - - or - - .. code-block:: python - - @FileClient.register_backend('new') - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. - force (bool, optional): Whether to override the backend if the name - has already been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefixes - of the registered storage backend. Defaults to None. - `New in version 1.3.15.` - """ - if backend is not None: - cls._register_backend(name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - cls._register_backend(name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - def get(self, filepath: str | Path) -> bytes | memoryview: - """Read data from a given ``filepath`` with 'rb' mode. - - Note: - There are two types of return values for ``get``, one is ``bytes`` - and the other is ``memoryview``. The advantage of using memoryview - is that you can avoid copying, and if you want to convert it to - ``bytes``, you can use ``.tobytes()``. - - Args: - filepath (str or Path): Path to read data. 
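To make the registration flow above concrete, a hypothetical toy backend registered through the decorator form. The class name, `"echo"` key, and prefix are invented for illustration, and the import paths assume the package layout shown in this diff.

```python
from visengine.fileio import FileClient
from visengine.fileio.backends import BaseStorageBackend


@FileClient.register_backend("echo", prefixes="echo")
class EchoBackend(BaseStorageBackend):
    """Toy backend that simply echoes the requested path back."""

    def get(self, filepath):
        return str(filepath).encode()

    def get_text(self, filepath, encoding="utf-8"):
        return str(filepath)


client = FileClient(backend="echo")
assert client.get_text("echo://hello") == "echo://hello"
```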
- - Returns: - bytes | memoryview: Expected bytes object or a memory view of the - bytes object. - """ - return self.client.get(filepath) - - def get_text(self, filepath: str | Path, encoding="utf-8") -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - return self.client.get_text(filepath, encoding) - - def put(self, obj: bytes, filepath: str | Path) -> None: - """Write data to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of ``filepath`` - does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - """ - self.client.put(obj, filepath) - - def put_text(self, obj: str, filepath: str | Path) -> None: - """Write data to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - `filepath`. Defaults to 'utf-8'. - """ - self.client.put_text(obj, filepath) - - def remove(self, filepath: str | Path) -> None: - """Remove a file. - - Args: - filepath (str, Path): Path to be removed. - """ - self.client.remove(filepath) - - def exists(self, filepath: str | Path) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - return self.client.exists(filepath) - - def isdir(self, filepath: str | Path) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - """ - return self.client.isdir(filepath) - - def isfile(self, filepath: str | Path) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - return self.client.isfile(filepath) - - def join_path(self, filepath: str | Path, *filepaths: str | Path) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - """ - return self.client.join_path(filepath, *filepaths) - - @contextmanager - def get_local_path(self, filepath: str | Path) -> Generator[str | Path, None, None]: - """Download data from ``filepath`` and write the data to local path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Note: - If the ``filepath`` is a local path, just return itself. - - .. warning:: - ``get_local_path`` is an experimental interface that may change in - the future. - - Args: - filepath (str or Path): Path to be read data. 
- - Examples: - >>> file_client = FileClient(prefix='s3') - >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - - Yields: - Iterable[str]: Only yield one path. - """ - with self.client.get_local_path(str(filepath)) as local_path: - yield local_path - - def list_dir_or_file( - self, - dir_path: str | Path, - list_dir: bool = True, - list_file: bool = True, - suffix: str | tuple[str] | None = None, - recursive: bool = False, - ) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the - directory. Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - """ - yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/libs/visengine/visengine/fileio/handlers/__init__.py b/libs/visengine/visengine/fileio/handlers/__init__.py deleted file mode 100644 index dfed722..0000000 --- a/libs/visengine/visengine/fileio/handlers/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .base import BaseFileHandler -from .json_handler import JsonHandler -from .pickle_handler import PickleHandler -from .registry_utils import file_handlers, register_handler -from .yaml_handler import YamlHandler - -__all__ = [ - "BaseFileHandler", - "JsonHandler", - "PickleHandler", - "YamlHandler", - "file_handlers", - "register_handler", -] diff --git a/libs/visengine/visengine/fileio/handlers/base.py b/libs/visengine/visengine/fileio/handlers/base.py deleted file mode 100644 index 93685ec..0000000 --- a/libs/visengine/visengine/fileio/handlers/base.py +++ /dev/null @@ -1,32 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod - - -class BaseFileHandler(metaclass=ABCMeta): - # `str_like` is a flag to indicate whether the type of file object is - # str-like object or bytes-like object. Pickle only processes bytes-like - # objects but json only processes str-like object. If it is str-like - # object, `StringIO` will be used to process the buffer. - str_like = True - - @abstractmethod - def load_from_fileobj(self, file, **kwargs): - pass - - @abstractmethod - def dump_to_fileobj(self, obj, file, **kwargs): - pass - - @abstractmethod - def dump_to_str(self, obj, **kwargs): - pass - - def load_from_path(self, filepath, mode="r", **kwargs): - with open(filepath, mode) as f: - return self.load_from_fileobj(f, **kwargs) - - def dump_to_path(self, obj, filepath, mode="w", **kwargs): - with open(filepath, mode) as f: - self.dump_to_fileobj(obj, f, **kwargs) diff --git a/libs/visengine/visengine/fileio/handlers/json_handler.py b/libs/visengine/visengine/fileio/handlers/json_handler.py deleted file mode 100644 index 70b230e..0000000 --- a/libs/visengine/visengine/fileio/handlers/json_handler.py +++ /dev/null @@ -1,37 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
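Given the `BaseFileHandler` contract deleted just above (three abstract methods plus path-based defaults), a sketch of the minimal subclass it implies; `TxtHandler` is hypothetical and only treats a file as one plain string.

```python
from visengine.fileio.handlers.base import BaseFileHandler


class TxtHandler(BaseFileHandler):
    """Reads and writes a file as a single string."""

    def load_from_fileobj(self, file, **kwargs):
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)
```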
-import json - -import numpy as np - -from .base import BaseFileHandler - - -def set_default(obj): - """Set default json values for non-serializable values. - - It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. - It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, - etc.) into plain numbers of plain python built-in types. - """ - if isinstance(obj, set | range): - return list(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.generic): - return obj.item() - raise TypeError(f"{type(obj)} is unsupported for json dump") - - -class JsonHandler(BaseFileHandler): - def load_from_fileobj(self, file): - return json.load(file) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault("default", set_default) - json.dump(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("default", set_default) - return json.dumps(obj, **kwargs) diff --git a/libs/visengine/visengine/fileio/handlers/pickle_handler.py b/libs/visengine/visengine/fileio/handlers/pickle_handler.py deleted file mode 100644 index 7aecd46..0000000 --- a/libs/visengine/visengine/fileio/handlers/pickle_handler.py +++ /dev/null @@ -1,27 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import pickle - -from .base import BaseFileHandler - - -class PickleHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return pickle.load(file, **kwargs) - - def load_from_path(self, filepath, **kwargs): - return super().load_from_path(filepath, mode="rb", **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("protocol", 2) - return pickle.dumps(obj, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault("protocol", 2) - pickle.dump(obj, file, **kwargs) - - def dump_to_path(self, obj, filepath, **kwargs): - super().dump_to_path(obj, filepath, mode="wb", **kwargs) diff --git a/libs/visengine/visengine/fileio/handlers/registry_utils.py b/libs/visengine/visengine/fileio/handlers/registry_utils.py deleted file mode 100644 index 16c4eb5..0000000 --- a/libs/visengine/visengine/fileio/handlers/registry_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from visengine.utils import is_list_of - -from .base import BaseFileHandler -from .json_handler import JsonHandler -from .pickle_handler import PickleHandler -from .yaml_handler import YamlHandler - -file_handlers = { - "json": JsonHandler(), - "yaml": YamlHandler(), - "yml": YamlHandler(), - "pickle": PickleHandler(), - "pkl": PickleHandler(), -} - - -def _register_handler(handler, file_formats): - """Register a handler for some file extensions. - - Args: - handler (:obj:`BaseFileHandler`): Handler to be registered. - file_formats (str or list[str]): File formats to be handled by this - handler. 
-    """
-    if not isinstance(handler, BaseFileHandler):
-        raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}")
-    if isinstance(file_formats, str):
-        file_formats = [file_formats]
-    if not is_list_of(file_formats, str):
-        raise TypeError("file_formats must be a str or a list of str")
-    for ext in file_formats:
-        file_handlers[ext] = handler
-
-
-def register_handler(file_formats, **kwargs):
-    def wrap(cls):
-        _register_handler(cls(**kwargs), file_formats)
-        return cls
-
-    return wrap
diff --git a/libs/visengine/visengine/fileio/handlers/yaml_handler.py b/libs/visengine/visengine/fileio/handlers/yaml_handler.py
deleted file mode 100644
index d969e09..0000000
--- a/libs/visengine/visengine/fileio/handlers/yaml_handler.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import yaml
-
-try:
-    from yaml import CDumper as Dumper  # type: ignore
-    from yaml import CLoader as Loader  # type: ignore
-except ImportError:
-    from yaml import Dumper, Loader  # type: ignore
-
-from .base import BaseFileHandler  # isort:skip
-
-
-class YamlHandler(BaseFileHandler):
-    def load_from_fileobj(self, file, **kwargs):
-        kwargs.setdefault("Loader", Loader)
-        return yaml.load(file, **kwargs)
-
-    def dump_to_fileobj(self, obj, file, **kwargs):
-        kwargs.setdefault("Dumper", Dumper)
-        yaml.dump(obj, file, **kwargs)
-
-    def dump_to_str(self, obj, **kwargs):
-        kwargs.setdefault("Dumper", Dumper)
-        return yaml.dump(obj, **kwargs)
diff --git a/libs/visengine/visengine/fileio/io.py b/libs/visengine/visengine/fileio/io.py
deleted file mode 100644
index 60c5781..0000000
--- a/libs/visengine/visengine/fileio/io.py
+++ /dev/null
@@ -1,912 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-"""This module provides unified file I/O related functions, which support
-operating I/O with different file backends based on the specified filepath or
-backend_args.
-
-MMEngine currently supports five file backends:
-
-- LocalBackend
-- PetrelBackend
-- HTTPBackend
-- LmdbBackend
-- MemcachedBackend
-
-Note that this module provides a union of all of the above file backends, so
-NotImplementedError will be raised if the interface in the file backend is not
-implemented.
-
-There are two ways to call a method of a file backend:
-
-- Initialize a file backend with ``get_file_backend`` and call its methods.
-- Directly call unified I/O functions, which will call ``get_file_backend``
-  first and then call the corresponding backend method.
-
-Examples:
-    >>> # Initialize a file backend and call its methods
-    >>> import visengine.fileio as fileio
-    >>> backend = fileio.get_file_backend(backend_args={'backend': 'petrel'})
-    >>> backend.get('s3://path/of/your/file')
-
-    >>> # Directly call unified I/O functions
-    >>> fileio.get('s3://path/of/your/file')
-"""
-
-import json
-import warnings
-from collections.abc import Generator, Iterator
-from contextlib import contextmanager
-from io import BytesIO, StringIO
-from pathlib import Path
-
-from visengine.utils import is_filepath, is_str
-
-from .backends import backends, prefix_to_backends
-from .file_client import FileClient
-
-# file_handlers and register_handler have been moved to
-# mmengine/fileio/handlers/registry_utils. Import them
-# in this file to keep backward compatibility.
-from .handlers import file_handlers, register_handler  # noqa: F401
-
-backend_instances: dict = {}
-
-
-def _parse_uri_prefix(uri: str | Path) -> str:
-    """Parse the prefix of uri.
- - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> _parse_uri_prefix('/home/path/of/your/file') - '' - >>> _parse_uri_prefix('s3://path/of/your/file') - 's3' - >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') - 's3' - - Returns: - str: Return the prefix of uri if the uri contains '://'. Otherwise, - return ''. - """ - assert is_filepath(uri) - uri = str(uri) - # if uri does not contains '://', the uri will be handled by - # LocalBackend by default - if "://" not in uri: - return "" - else: - prefix, _ = uri.split("://") - # In the case of PetrelBackend, the prefix may contain the cluster - # name like clusterName:s3://path/of/your/file - if ":" in prefix: - _, prefix = prefix.split(":") - return prefix - - -def _get_file_backend(prefix: str, backend_args: dict): - """Return a file backend based on the prefix or backend_args. - - Args: - prefix (str): Prefix of uri. - backend_args (dict): Arguments to instantiate the corresponding - backend. - """ - # backend name has a higher priority - if "backend" in backend_args: - # backend_args should not be modified - backend_args_bak = backend_args.copy() - backend_name = backend_args_bak.pop("backend") - backend = backends[backend_name](**backend_args_bak) - else: - backend = prefix_to_backends[prefix](**backend_args) - return backend - - -def get_file_backend( - uri: str | Path | None = None, - *, - backend_args: dict | None = None, - enable_singleton: bool = False, -): - """Return a file backend based on the prefix of uri or backend_args. - - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - enable_singleton (bool): Whether to enable the singleton pattern. - If it is True, the backend created will be reused if the - signature is same with the previous one. Defaults to False. - - Returns: - BaseStorageBackend: Instantiated Backend object. - - Examples: - >>> # get file backend based on the prefix of uri - >>> uri = 's3://path/of/your/file' - >>> backend = get_file_backend(uri) - >>> # get file backend based on the backend_args - >>> backend = get_file_backend(backend_args={'backend': 'petrel'}) - >>> # backend name has a higher priority if 'backend' in backend_args - >>> backend = get_file_backend(uri, backend_args={'backend': 'petrel'}) - """ - global backend_instances - - if backend_args is None: - backend_args = {} - - if uri is None and "backend" not in backend_args: - raise ValueError('uri should not be None when "backend" does not exist in backend_args') - - if uri is not None: - prefix = _parse_uri_prefix(uri) - else: - prefix = "" - - if enable_singleton: - # TODO: whether to pass sort_key to json.dumps - unique_key = f"{prefix}:{json.dumps(backend_args)}" - if unique_key in backend_instances: - return backend_instances[unique_key] - - backend = _get_file_backend(prefix, backend_args) - backend_instances[unique_key] = backend - return backend - else: - backend = _get_file_backend(prefix, backend_args) - return backend - - -def get( - filepath: str | Path, - backend_args: dict | None = None, -) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bytes: Expected bytes object. 
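The caching logic in `get_file_backend` above keys instances on the URI prefix plus the JSON-serialized `backend_args`. A small sketch of the resulting behaviour, assuming local paths (which have an empty prefix):

```python
import visengine.fileio as fileio

# Same prefix ('' for local paths) and same args -> the cached backend is reused.
b1 = fileio.get_file_backend("/tmp/a.txt", enable_singleton=True)
b2 = fileio.get_file_backend("/tmp/b.txt", enable_singleton=True)
assert b1 is b2

# Without enable_singleton, a fresh backend is constructed on every call.
b3 = fileio.get_file_backend("/tmp/a.txt")
assert b3 is not b1
```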
- - Examples: - >>> filepath = '/path/of/file' - >>> get(filepath) - b'hello world' - """ - backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True) - return backend.get(filepath) - - -def get_text( - filepath: str | Path, - encoding="utf-8", - backend_args: dict | None = None, -) -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> filepath = '/path/of/file' - >>> get_text(filepath) - 'hello world' - """ - backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True) - return backend.get_text(filepath, encoding) - - -def put( - obj: bytes, - filepath: str | Path, - backend_args: dict | None = None, -) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> filepath = '/path/of/file' - >>> put(b'hello world', filepath) - """ - backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True) - backend.put(obj, filepath) - - -def put_text( - obj: str, - filepath: str | Path, - backend_args: dict | None = None, -) -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - ``filepath``. Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> filepath = '/path/of/file' - >>> put_text('hello world', filepath) - """ - backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True) - backend.put_text(obj, filepath) - - -def exists( - filepath: str | Path, - backend_args: dict | None = None, -) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> filepath = '/path/of/file' - >>> exists(filepath) - True - """ - backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True) - return backend.exists(filepath) - - -def isdir( - filepath: str | Path, - backend_args: dict | None = None, -) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. 
-
-    Examples:
-        >>> filepath = '/path/of/dir'
-        >>> isdir(filepath)
-        True
-    """
-    backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True)
-    return backend.isdir(filepath)
-
-
-def isfile(
-    filepath: str | Path,
-    backend_args: dict | None = None,
-) -> bool:
-    """Check whether a file path is a file.
-
-    Args:
-        filepath (str or Path): Path to be checked whether it is a file.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Returns:
-        bool: Return ``True`` if ``filepath`` points to a file, ``False``
-        otherwise.
-
-    Examples:
-        >>> filepath = '/path/of/file'
-        >>> isfile(filepath)
-        True
-    """
-    backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True)
-    return backend.isfile(filepath)
-
-
-def join_path(
-    filepath: str | Path,
-    *filepaths: str | Path,
-    backend_args: dict | None = None,
-) -> str | Path:
-    r"""Concatenate all file paths.
-
-    Join one or more filepath components intelligently. The return value
-    is the concatenation of filepath and any members of \*filepaths.
-
-    Args:
-        filepath (str or Path): Path to be concatenated.
-        *filepaths (str or Path): Other paths to be concatenated.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Returns:
-        str: The result of concatenation.
-
-    Examples:
-        >>> filepath1 = '/path/of/dir1'
-        >>> filepath2 = 'dir2'
-        >>> filepath3 = 'path/of/file'
-        >>> join_path(filepath1, filepath2, filepath3)
-        '/path/of/dir1/dir2/path/of/file'
-    """
-    backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True)
-    return backend.join_path(filepath, *filepaths)
-
-
-@contextmanager
-def get_local_path(
-    filepath: str | Path,
-    backend_args: dict | None = None,
-) -> Generator[str | Path, None, None]:
-    """Download data from ``filepath`` and write the data to a local path.
-
-    ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
-    can be called with a ``with`` statement, and when exiting the ``with``
-    statement, the temporary path will be released.
-
-    Note:
-        If the ``filepath`` is a local path, just return itself and it will
-        not be released (removed).
-
-    Args:
-        filepath (str or Path): Path of the data to be read.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Yields:
-        Iterable[str]: Only yield one path.
-
-    Examples:
-        >>> with get_local_path('s3://bucket/abc.jpg') as path:
-        ...     # do something here
-    """
-    backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True)
-    with backend.get_local_path(str(filepath)) as local_path:
-        yield local_path
-
-
-def copyfile(
-    src: str | Path,
-    dst: str | Path,
-    backend_args: dict | None = None,
-) -> str | Path:
-    """Copy a file src to dst and return the destination file.
-
-    src and dst should have the same prefix. If dst specifies a directory,
-    the file will be copied into dst using the base filename from src. If
-    dst specifies a file that already exists, it will be replaced.
-
-    Args:
-        src (str or Path): A file to be copied.
-        dst (str or Path): Copy file to dst.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Returns:
-        str: The destination file.
-
-    Raises:
-        SameFileError: If src and dst are the same file, a SameFileError will
-            be raised.
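A quick round trip showing `put`, `copyfile`, and `get` working together on the local backend; the `/tmp` paths are arbitrary, and the directory is created by `put` itself per its docstring.

```python
import visengine.fileio as fileio

# put() documents that it creates missing parent directories itself.
fileio.put(b"hello", "/tmp/fileio_demo/src/data.bin")
fileio.copyfile("/tmp/fileio_demo/src/data.bin", "/tmp/fileio_demo/data.bin")
assert fileio.get("/tmp/fileio_demo/data.bin") == b"hello"
```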
- - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> copyfile(src, dst) - '/path1/of/dir/file' - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True) - return backend.copyfile(src, dst) - - -def copytree( - src: str | Path, - dst: str | Path, - backend_args: dict | None = None, -) -> str | Path: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will be - raised. - - Examples: - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> copytree(src, dst) - '/path/of/dir2' - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True) - return backend.copytree(src, dst) - - -def copyfile_from_local( - src: str | Path, - dst: str | Path, - backend_args: dict | None = None, -) -> str | Path: - """Copy a local file src to dst and return the destination file. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = 's3://openmmlab/mmengine/file1' - >>> # src will be copied to 's3://openmmlab/mmengine/file1' - >>> copyfile_from_local(src, dst) - s3://openmmlab/mmengine/file1 - - >>> # dst is a directory - >>> dst = 's3://openmmlab/mmengine' - >>> # src will be copied to 's3://openmmlab/mmengine/file'' - >>> copyfile_from_local(src, dst) - 's3://openmmlab/mmengine/file' - """ - backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True) - return backend.copyfile_from_local(src, dst) - - -def copytree_from_local( - src: str | Path, - dst: str | Path, - backend_args: dict | None = None, -) -> str | Path: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. 
-
-    Examples:
-        >>> src = '/path/of/dir'
-        >>> dst = 's3://openmmlab/mmengine/dir'
-        >>> copytree_from_local(src, dst)
-        's3://openmmlab/mmengine/dir'
-    """
-    backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True)
-    return backend.copytree_from_local(src, dst)
-
-
-def copyfile_to_local(
-    src: str | Path,
-    dst: str | Path,
-    backend_args: dict | None = None,
-) -> str | Path:
-    """Copy the file src to local dst and return the destination file.
-
-    If dst specifies a directory, the file will be copied into dst using
-    the base filename from src. If dst specifies a file that already
-    exists, it will be replaced.
-
-    Note:
-        If the backend is an instance of LocalBackend, it does the same
-        thing as :func:`copyfile`.
-
-    Args:
-        src (str or Path): A file to be copied.
-        dst (str or Path): Copy file to local dst.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Returns:
-        str: If dst specifies a directory, the file will be copied into dst
-        using the base filename from src.
-
-    Examples:
-        >>> # dst is a file
-        >>> src = 's3://openmmlab/mmengine/file'
-        >>> dst = '/path/of/file'
-        >>> # src will be copied to '/path/of/file'
-        >>> copyfile_to_local(src, dst)
-        '/path/of/file'
-
-        >>> # dst is a directory
-        >>> dst = '/path/of/dir'
-        >>> # src will be copied to '/path/of/dir/file'
-        >>> copyfile_to_local(src, dst)
-        '/path/of/dir/file'
-    """
-    backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True)
-    return backend.copyfile_to_local(src, dst)
-
-
-def copytree_to_local(
-    src: str | Path,
-    dst: str | Path,
-    backend_args: dict | None = None,
-) -> str | Path:
-    """Recursively copy an entire directory tree rooted at src to a local
-    directory named dst and return the destination directory.
-
-    Note:
-        If the backend is an instance of LocalBackend, it does the same
-        thing as :func:`copytree`.
-
-    Args:
-        src (str or Path): A directory to be copied.
-        dst (str or Path): Copy directory to local dst.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Returns:
-        str: The destination directory.
-
-    Examples:
-        >>> src = 's3://openmmlab/mmengine/dir'
-        >>> dst = '/path/of/dir'
-        >>> copytree_to_local(src, dst)
-        '/path/of/dir'
-    """
-    backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True)
-    return backend.copytree_to_local(src, dst)
-
-
-def remove(
-    filepath: str | Path,
-    backend_args: dict | None = None,
-) -> None:
-    """Remove a file.
-
-    Args:
-        filepath (str, Path): Path to be removed.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Raises:
-        FileNotFoundError: If filepath does not exist, a FileNotFoundError
-            will be raised.
-        IsADirectoryError: If filepath is a directory, an IsADirectoryError
-            will be raised.
-
-    Examples:
-        >>> filepath = '/path/of/file'
-        >>> remove(filepath)
-    """
-    backend = get_file_backend(filepath, backend_args=backend_args, enable_singleton=True)
-    backend.remove(filepath)
-
-
-def rmtree(
-    dir_path: str | Path,
-    backend_args: dict | None = None,
-) -> None:
-    """Recursively delete a directory tree.
-
-    Args:
-        dir_path (str or Path): A directory to be removed.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
- - Examples: - >>> dir_path = '/path/of/dir' - >>> rmtree(dir_path) - """ - backend = get_file_backend(dir_path, backend_args=backend_args, enable_singleton=True) - backend.rmtree(dir_path) - - -def copy_if_symlink_fails( - src: str | Path, - dst: str | Path, - backend_args: dict | None = None, -) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directory copy src to - dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return True if successfully create a symbolic link pointing to - src. Otherwise, return False. - - Examples: - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> copy_if_symlink_fails(src, dst) - True - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True) - return backend.copy_if_symlink_fails(src, dst) - - -def list_dir_or_file( - dir_path: str | Path, - list_dir: bool = True, - list_file: bool = True, - suffix: str | tuple[str] | None = None, - recursive: bool = False, - backend_args: dict | None = None, -) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> dir_path = '/path/of/dir' - >>> for file_path in list_dir_or_file(dir_path): - ... print(file_path) - >>> # list those files and directories in current directory - >>> for file_path in list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ - backend = get_file_backend(dir_path, backend_args=backend_args, enable_singleton=True) - yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) - - -def generate_presigned_url( - url: str, - client_method: str = "get_object", - expires_in: int = 3600, - backend_args: dict | None = None, -) -> str: - """Generate the presigned url of video stream which can be passed to - mmcv.VideoReader. Now only work on Petrel backend. - - Note: - Now only work on Petrel backend. - - Args: - url (str): Url of video stream. - client_method (str): Method of client, 'get_object' or - 'put_object'. Defaults to 'get_object'. 
-        expires_in (int): Expiration time, in seconds. Defaults to 3600.
-        backend_args (dict, optional): Arguments to instantiate the
-            corresponding backend. Defaults to None.
-
-    Returns:
-        str: Generated presigned url.
-    """
-    backend = get_file_backend(url, backend_args=backend_args, enable_singleton=True)
-    return backend.generate_presigned_url(url, client_method, expires_in)
-
-
-def load(file, file_format=None, file_client_args=None, backend_args=None, **kwargs):
-    """Load data from json/yaml/pickle files.
-
-    This method provides a unified API for loading data from serialized files.
-
-    ``load`` supports loading data from serialized files that can be stored
-    in different backends.
-
-    Args:
-        file (str or :obj:`Path` or file-like object): Filename or a file-like
-            object.
-        file_format (str, optional): If not specified, the file format will be
-            inferred from the file extension, otherwise use the specified one.
-            Currently supported formats include "json", "yaml/yml" and
-            "pickle/pkl".
-        file_client_args (dict, optional): Arguments to instantiate a
-            FileClient. See :class:`mmengine.fileio.FileClient` for details.
-            Defaults to None. It will be deprecated in future. Please use
-            ``backend_args`` instead.
-        backend_args (dict, optional): Arguments to instantiate the backend
-            corresponding to the prefix of the uri. Defaults to None.
-            New in v0.2.0.
-
-    Examples:
-        >>> load('/path/of/your/file')  # file is stored on disk
-        >>> load('https://path/of/your/file')  # file is stored on the Internet
-        >>> load('s3://path/of/your/file')  # file is stored in Petrel
-
-    Returns:
-        The content from the file.
-    """
-    if isinstance(file, Path):
-        file = str(file)
-    if file_format is None and is_str(file):
-        file_format = file.split(".")[-1]
-    if file_format not in file_handlers:
-        raise TypeError(f"Unsupported format: {file_format}")
-
-    if file_client_args is not None:
-        warnings.warn(
-            '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        if backend_args is not None:
-            raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
-
-    handler = file_handlers[file_format]
-    if is_str(file):
-        if file_client_args is not None:
-            file_client = FileClient.infer_client(file_client_args, file)
-            file_backend = file_client
-        else:
-            file_backend = get_file_backend(file, backend_args=backend_args)
-
-        if handler.str_like:
-            with StringIO(file_backend.get_text(file)) as f:
-                obj = handler.load_from_fileobj(f, **kwargs)
-        else:
-            with BytesIO(file_backend.get(file)) as f:
-                obj = handler.load_from_fileobj(f, **kwargs)
-    elif hasattr(file, "read"):
-        obj = handler.load_from_fileobj(file, **kwargs)
-    else:
-        raise TypeError('"file" must be a filepath str or a file-object')
-    return obj
-
-
-def dump(obj, file=None, file_format=None, file_client_args=None, backend_args=None, **kwargs):
-    """Dump data to json/yaml/pickle strings or files.
-
-    This method provides a unified API for dumping data as strings or to files,
-    and also supports custom arguments for each file format.
-
-    ``dump`` supports dumping data as strings or to files which are saved to
-    different backends.
-
-    Args:
-        obj (any): The python object to be dumped.
-        file (str or :obj:`Path` or file-like object, optional): If not
-            specified, then the object is dumped to a str, otherwise to a file
-            specified by the filename or file-like object.
-        file_format (str, optional): Same as :func:`load`.
-        file_client_args (dict, optional): Arguments to instantiate a
-            FileClient. See :class:`mmengine.fileio.FileClient` for details.
-            Defaults to None. It will be deprecated in future. Please use
-            ``backend_args`` instead.
-        backend_args (dict, optional): Arguments to instantiate the backend
-            corresponding to the prefix of the uri. Defaults to None.
-            New in v0.2.0.
-
-    Examples:
-        >>> dump('hello world', '/path/of/your/file')  # disk
-        >>> dump('hello world', 's3://path/of/your/file')  # ceph or petrel
-
-    Returns:
-        str or None: The dumped string if ``file`` is None; otherwise None.
-    """
-    if isinstance(file, Path):
-        file = str(file)
-    if file_format is None:
-        if is_str(file):
-            file_format = file.split(".")[-1]
-        elif file is None:
-            raise ValueError("file_format must be specified since file is None")
-    if file_format not in file_handlers:
-        raise TypeError(f"Unsupported format: {file_format}")
-
-    if file_client_args is not None:
-        warnings.warn(
-            '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        if backend_args is not None:
-            raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
-
-    handler = file_handlers[file_format]
-    if file is None:
-        return handler.dump_to_str(obj, **kwargs)
-    elif is_str(file):
-        if file_client_args is not None:
-            file_client = FileClient.infer_client(file_client_args, file)
-            file_backend = file_client
-        else:
-            file_backend = get_file_backend(file, backend_args=backend_args)
-
-        if handler.str_like:
-            with StringIO() as f:
-                handler.dump_to_fileobj(obj, f, **kwargs)
-                file_backend.put_text(f.getvalue(), file)
-        else:
-            with BytesIO() as f:
-                handler.dump_to_fileobj(obj, f, **kwargs)
-                file_backend.put(f.getvalue(), file)
-    elif hasattr(file, "write"):
-        handler.dump_to_fileobj(obj, file, **kwargs)
-    else:
-        raise TypeError('"file" must be a filename str or a file-object')
diff --git a/libs/visengine/visengine/fileio/parse.py b/libs/visengine/visengine/fileio/parse.py
deleted file mode 100644
index 500c941..0000000
--- a/libs/visengine/visengine/fileio/parse.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from io import StringIO
-
-from .file_client import FileClient
-from .io import get_text
-
-
-def list_from_file(
-    filename,
-    prefix="",
-    offset=0,
-    max_num=0,
-    encoding="utf-8",
-    file_client_args=None,
-    backend_args=None,
-):
-    """Load a text file and parse the content as a list of strings.
-
-    ``list_from_file`` supports loading a text file which can be stored in
-    different backends and parsing the content as a list of strings.
-
-    Args:
-        filename (str): Filename.
-        prefix (str): The prefix to be inserted at the beginning of each item.
-        offset (int): The offset of lines.
-        max_num (int): The maximum number of lines to be read;
-            zero and negative values mean no limit.
-        encoding (str): Encoding used to open the file. Defaults to utf-8.
-        file_client_args (dict, optional): Arguments to instantiate a
-            FileClient. See :class:`mmengine.fileio.FileClient` for details.
-            Defaults to None. It will be deprecated in future. Please use
-            ``backend_args`` instead.
-        backend_args (dict, optional): Arguments to instantiate the backend
-            corresponding to the prefix of the uri. Defaults to None.
-            New in v0.2.0.
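Putting `load` and `dump` together: the format is inferred from the extension when a path is given, and `dump` returns a serialized string when no file is passed. A minimal sketch, assuming both functions are re-exported from `visengine.fileio` as in upstream mmengine:

```python
from visengine.fileio import dump, load

data = {"classes": ["cat", "dog"], "num_images": 2}
dump(data, "/tmp/meta.json")              # format inferred from '.json'
assert load("/tmp/meta.json") == data

as_yaml = dump(data, file_format="yaml")  # no file -> returns a string
assert "classes" in as_yaml
```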
-
-    Examples:
-        >>> list_from_file('/path/of/your/file')  # disk
-        ['hello', 'world']
-        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
-        ['hello', 'world']
-
-    Returns:
-        list[str]: A list of strings.
-    """
-    if file_client_args is not None:
-        warnings.warn(
-            '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        if backend_args is not None:
-            raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
-    cnt = 0
-    item_list = []
-
-    if file_client_args is not None:
-        file_client = FileClient.infer_client(file_client_args, filename)
-        text = file_client.get_text(filename, encoding)
-    else:
-        text = get_text(filename, encoding, backend_args=backend_args)
-
-    with StringIO(text) as f:
-        for _ in range(offset):
-            f.readline()
-        for line in f:
-            if 0 < max_num <= cnt:
-                break
-            item_list.append(prefix + line.rstrip("\n\r"))
-            cnt += 1
-    return item_list
-
-
-def dict_from_file(filename, key_type=str, encoding="utf-8", file_client_args=None, backend_args=None):
-    """Load a text file and parse the content as a dict.
-
-    Each line of the text file will be two or more columns split by
-    whitespaces or tabs. The first column will be parsed as dict keys, and
-    the following columns will be parsed as dict values.
-
-    ``dict_from_file`` supports loading a text file which can be stored in
-    different backends and parsing the content as a dict.
-
-    Args:
-        filename (str): Filename.
-        key_type (type): Type of the dict keys. str is used by default and
-            type conversion will be performed if specified.
-        encoding (str): Encoding used to open the file. Defaults to utf-8.
-        file_client_args (dict, optional): Arguments to instantiate a
-            FileClient. See :class:`mmengine.fileio.FileClient` for details.
-            Defaults to None. It will be deprecated in future. Please use
-            ``backend_args`` instead.
-        backend_args (dict, optional): Arguments to instantiate the backend
-            corresponding to the prefix of the uri. Defaults to None.
-            New in v0.2.0.
-
-    Examples:
-        >>> dict_from_file('/path/of/your/file')  # disk
-        {'key1': 'value1', 'key2': 'value2'}
-        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
-        {'key1': 'value1', 'key2': 'value2'}
-
-    Returns:
-        dict: The parsed contents.
-    """
-    if file_client_args is not None:
-        warnings.warn(
-            '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        if backend_args is not None:
-            raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
-
-    mapping = {}
-
-    if file_client_args is not None:
-        file_client = FileClient.infer_client(file_client_args, filename)
-        text = file_client.get_text(filename, encoding)
-    else:
-        text = get_text(filename, encoding, backend_args=backend_args)
-
-    with StringIO(text) as f:
-        for line in f:
-            items = line.rstrip("\n").split()
-            assert len(items) >= 2
-            key = key_type(items[0])
-            val = items[1:] if len(items) > 2 else items[1]
-            mapping[key] = val
-    return mapping
diff --git a/libs/visengine/visengine/hooks/__init__.py b/libs/visengine/visengine/hooks/__init__.py
deleted file mode 100644
index 118f03e..0000000
--- a/libs/visengine/visengine/hooks/__init__.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
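A worked example of the two parsers deleted above. Per the `dict_from_file` body, single-value rows map to scalars while multi-value rows map to lists; the imports assume both helpers are re-exported from `visengine.fileio` as in upstream mmengine.

```python
from visengine.fileio import dict_from_file, list_from_file

# Write a small two-column file to parse.
with open("/tmp/classes.txt", "w") as f:
    f.write("1 cat\n2 dog cow\n")

assert list_from_file("/tmp/classes.txt") == ["1 cat", "2 dog cow"]
assert dict_from_file("/tmp/classes.txt", key_type=int) == {1: "cat", 2: ["dog", "cow"]}
```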
-from .checkpoint_hook import CheckpointHook -from .early_stopping_hook import EarlyStoppingHook -from .ema_hook import EMAHook -from .empty_cache_hook import EmptyCacheHook -from .hook import Hook -from .iter_timer_hook import IterTimerHook -from .logger_hook import LoggerHook -from .naive_visualization_hook import NaiveVisualizationHook -from .param_scheduler_hook import ParamSchedulerHook -from .profiler_hook import NPUProfilerHook, ProfilerHook -from .runtime_info_hook import RuntimeInfoHook -from .sampler_seed_hook import DistSamplerSeedHook -from .sync_buffer_hook import SyncBuffersHook - -__all__ = [ - "CheckpointHook", - "DistSamplerSeedHook", - "EMAHook", - "EarlyStoppingHook", - "EmptyCacheHook", - "Hook", - "IterTimerHook", - "LoggerHook", - "NPUProfilerHook", - "NaiveVisualizationHook", - "ParamSchedulerHook", - "ProfilerHook", - "RuntimeInfoHook", - "SyncBuffersHook", -] diff --git a/libs/visengine/visengine/hooks/checkpoint_hook.py b/libs/visengine/visengine/hooks/checkpoint_hook.py deleted file mode 100644 index 4bccff4..0000000 --- a/libs/visengine/visengine/hooks/checkpoint_hook.py +++ /dev/null @@ -1,655 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import hashlib -import logging -import os.path as osp -import pickle -from collections import deque -from collections.abc import Callable, Sequence -from math import inf -from pathlib import Path -from typing import Optional - -from visengine.dist import is_main_process, master_only -from visengine.fileio import FileClient, get_file_backend -from visengine.logging import print_log -from visengine.registry import HOOKS -from visengine.utils import is_list_of, is_seq_of - -from .hook import Hook - -DATA_BATCH = Optional[dict | tuple | list] - - -@HOOKS.register_module(force=True) -class CheckpointHook(Hook): - """Save checkpoints periodically. - - Args: - interval (int): The saving period. If ``by_epoch=True``, interval - indicates epochs, otherwise it indicates iterations. - Defaults to -1, which means "never". - by_epoch (bool): Saving checkpoints by epoch or by iteration. - Defaults to True. - save_optimizer (bool): Whether to save optimizer state_dict in the - checkpoint. It is usually used for resuming experiments. - Defaults to True. - save_param_scheduler (bool): Whether to save param_scheduler state_dict - in the checkpoint. It is usually used for resuming experiments. - Defaults to True. - out_dir (str, Path, Optional): The root directory to save checkpoints. - If not specified, ``runner.work_dir`` will be used by default. If - specified, the ``out_dir`` will be the concatenation of ``out_dir`` - and the last level directory of ``runner.work_dir``. For example, - if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is - ``./work_dir/cur_exp``, then the ckpt will be saved in - ``./tmp/cur_exp``. Defaults to None. - max_keep_ckpts (int): The maximum checkpoints to keep. - In some cases we want only the latest few checkpoints and would - like to delete old ones to save the disk space. - Defaults to -1, which means unlimited. - save_last (bool): Whether to force the last checkpoint to be - saved regardless of interval. Defaults to True. - save_best (str, List[str], optional): If a metric is specified, it - would measure the best checkpoint during evaluation. If a list of - metrics is passed, it would measure a group of best checkpoints - corresponding to the passed metrics. 
The information about best
-            checkpoint(s) would be saved in ``runner.message_hub`` to keep
-            the best score value and best checkpoint path, which will also be
-            loaded when resuming a checkpoint. Options are the evaluation
-            metrics on the test dataset, e.g. ``bbox_mAP``, ``segm_mAP`` for
-            bbox detection and instance segmentation, or ``AR@100`` for
-            proposal recall. If ``save_best`` is ``auto``, the first key of
-            the returned ``OrderedDict`` result will be used. Defaults to
-            None.
-        rule (str, List[str], optional): Comparison rule for best score. If
-            set to None, it will infer a reasonable rule. Keys such as 'acc',
-            'top', etc. will be inferred by the 'greater' rule. Keys containing
-            'loss' will be inferred by the 'less' rule. If ``save_best`` is a
-            list of metrics and ``rule`` is a str, all metrics in ``save_best``
-            will share the comparison rule. If ``save_best`` and ``rule`` are
-            both lists, their length must be the same, and metrics in
-            ``save_best`` will use the corresponding comparison rule in
-            ``rule``. Options are 'greater', 'less', None, or a list
-            containing 'greater' and 'less'. Defaults to None.
-        greater_keys (List[str], optional): Metric keys that will be
-            inferred by the 'greater' comparison rule. If ``None``,
-            _default_greater_keys will be used. Defaults to None.
-        less_keys (List[str], optional): Metric keys that will be
-            inferred by the 'less' comparison rule. If ``None``,
-            _default_less_keys will be used. Defaults to None.
-        file_client_args (dict, optional): Arguments to instantiate a
-            FileClient. See :class:`mmengine.fileio.FileClient` for details.
-            Defaults to None. It will be deprecated in the future. Please use
-            ``backend_args`` instead.
-        filename_tmpl (str, optional): String template to indicate checkpoint
-            name. If specified, must contain one and only one "{}", which will
-            be replaced with ``epoch + 1`` if ``by_epoch=True`` else
-            ``iteration + 1``.
-            Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
-            accordingly.
-        backend_args (dict, optional): Arguments to instantiate the
-            backend corresponding to the prefix of the URI. Defaults to None.
-            `New in version 0.2.0.`
-        published_keys (str, List[str], optional): If ``save_last`` is ``True``
-            or ``save_best`` is not ``None``, it will automatically
-            publish the model with the keys in the list after training.
-            Defaults to None.
-            `New in version 0.7.1.`
-        save_begin (int): Control the epoch number or iteration number
-            at which checkpoint saving begins. Defaults to 0, which means
-            saving at the beginning.
-            `New in version 0.8.3.`
-
-    Examples:
-        >>> # Save best based on a single metric
-        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
-        >>>                rule='less')
-        >>> # Save best based on multiple metrics with the same comparison rule
-        >>> CheckpointHook(interval=2, by_epoch=True,
-        >>>                save_best=['acc', 'mIoU'], rule='greater')
-        >>> # Save best based on multiple metrics with different comparison rules
-        >>> CheckpointHook(interval=2, by_epoch=True,
-        >>>                save_best=['FID', 'IS'], rule=['less', 'greater'])
-        >>> # Save best based on a single metric and publish the model after training
-        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
-        >>>                rule='less', published_keys=['meta', 'state_dict'])
-    """
-
-    out_dir: str
-
-    priority = "VERY_LOW"
-
-    # logic to save best checkpoints
-    # Since the key for determining greater or less is related to the
-    # downstream tasks, downstream repositories may need to overwrite
-    # the following inner variables accordingly.
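-    # For example (illustrative only), a downstream repo tracking panoptic
-    # quality could extend the defaults in a subclass:
-    #   _default_greater_keys = CheckpointHook._default_greater_keys + ["PQ"]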
- - rule_map = {"greater": lambda x, y: x > y, "less": lambda x, y: x < y} - init_value_map = {"greater": -inf, "less": inf} - _default_greater_keys = [ - "acc", - "top", - "AR@", - "auc", - "precision", - "mAP", - "mDice", - "mIoU", - "mAcc", - "aAcc", - ] - _default_less_keys = ["loss"] - - def __init__( - self, - interval: int = -1, - by_epoch: bool = True, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - out_dir: str | Path | None = None, - max_keep_ckpts: int = -1, - save_last: bool = True, - save_best: str | list[str] | None = None, - rule: str | list[str] | None = None, - greater_keys: Sequence[str] | None = None, - less_keys: Sequence[str] | None = None, - file_client_args: dict | None = None, - filename_tmpl: str | None = None, - backend_args: dict | None = None, - published_keys: str | list[str] | None = None, - save_begin: int = 0, - **kwargs, - ) -> None: - self.interval = interval - self.by_epoch = by_epoch - self.save_optimizer = save_optimizer - self.save_param_scheduler = save_param_scheduler - self.out_dir = out_dir # type: ignore - self.max_keep_ckpts = max_keep_ckpts - self.save_last = save_last - self.args = kwargs - - if file_client_args is not None: - print_log( - '"file_client_args" will be deprecated in future. Please use "backend_args" instead', - logger="current", - level=logging.WARNING, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.') - - self.file_client_args = file_client_args - self.backend_args = backend_args - - if filename_tmpl is None: - if self.by_epoch: - self.filename_tmpl = "epoch_{}.pth" - else: - self.filename_tmpl = "iter_{}.pth" - else: - self.filename_tmpl = filename_tmpl - - # save best logic - assert isinstance(save_best, str) or is_list_of(save_best, str) or (save_best is None), ( - f'"save_best" should be a str or list of str or None, but got {type(save_best)}' - ) - - if isinstance(save_best, list): - if "auto" in save_best: - assert len(save_best) == 1, 'Only support one "auto" in "save_best" list.' - assert len(save_best) == len(set(save_best)), 'Find duplicate element in "save_best".' - else: - # convert str to list[str] - if save_best is not None: - save_best = [save_best] # type: ignore - self.save_best = save_best - - # rule logic - assert isinstance(rule, str) or is_list_of(rule, str) or (rule is None), ( - f'"rule" should be a str or list of str or None, but got {type(rule)}' - ) - if isinstance(rule, list): - # check the length of rule list - assert len(rule) in [1, len(self.save_best)], ( # type: ignore - f'Number of "rule" must be 1 or the same as number of "save_best", but got {len(rule)}.' 
-            )
-        else:
-            # convert str/None to list
-            rule = [rule]  # type: ignore
-
-        if greater_keys is None:
-            self.greater_keys = self._default_greater_keys
-        else:
-            if not isinstance(greater_keys, list | tuple):
-                greater_keys = (greater_keys,)  # type: ignore
-            assert is_seq_of(greater_keys, str)
-            self.greater_keys = greater_keys  # type: ignore
-
-        if less_keys is None:
-            self.less_keys = self._default_less_keys
-        else:
-            if not isinstance(less_keys, list | tuple):
-                less_keys = (less_keys,)  # type: ignore
-            assert is_seq_of(less_keys, str)
-            self.less_keys = less_keys  # type: ignore
-
-        if self.save_best is not None:
-            self.is_better_than: dict[str, Callable] = {}
-            self._init_rule(rule, self.save_best)
-            if len(self.key_indicators) == 1:
-                self.best_ckpt_path: str | None = None
-            else:
-                self.best_ckpt_path_dict: dict = {}
-
-        # published keys
-        if not (isinstance(published_keys, str) or is_seq_of(published_keys, str) or published_keys is None):
-            raise TypeError(
-                f'"published_keys" should be a str or a sequence of str or None, but got {type(published_keys)}'
-            )
-
-        if isinstance(published_keys, str):
-            published_keys = [published_keys]
-        elif isinstance(published_keys, list | tuple):
-            assert len(published_keys) == len(set(published_keys)), 'Found duplicate elements in "published_keys".'
-        self.published_keys = published_keys
-
-        self.last_ckpt = None
-        if save_begin < 0:
-            raise ValueError(f"save_begin should not be less than 0, but got {save_begin}")
-        self.save_begin = save_begin
-
-    def before_train(self, runner) -> None:
-        """Finish all operations related to checkpoints.
-
-        This function will get the appropriate file client, and the directory
-        to save these checkpoints of the model.
-
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        if self.out_dir is None:
-            self.out_dir = runner.work_dir
-
-        # If self.file_client_args is None, self.file_client will not be
-        # used in CheckpointHook.
To avoid breaking backward compatibility, - # it will not be removed util the release of MMEngine1.0 - self.file_client = FileClient.infer_client(self.file_client_args, self.out_dir) - - if self.file_client_args is None: - self.file_backend = get_file_backend(self.out_dir, backend_args=self.backend_args) - else: - self.file_backend = self.file_client - - # if `self.out_dir` is not equal to `runner.work_dir`, it means that - # `self.out_dir` is set so the final `self.out_dir` is the - # concatenation of `self.out_dir` and the last level directory of - # `runner.work_dir` - if self.out_dir != runner.work_dir: - basename = osp.basename(runner.work_dir.rstrip(osp.sep)) - self.out_dir = self.file_backend.join_path(self.out_dir, basename) # type: ignore - - runner.logger.info(f"Checkpoints will be saved to {self.out_dir}.") - - if self.save_best is not None: - if len(self.key_indicators) == 1: - if "best_ckpt" not in runner.message_hub.runtime_info: - self.best_ckpt_path = None - else: - self.best_ckpt_path = runner.message_hub.get_info("best_ckpt") - else: - for key_indicator in self.key_indicators: - best_ckpt_name = f"best_ckpt_{key_indicator}" - if best_ckpt_name not in runner.message_hub.runtime_info: - self.best_ckpt_path_dict[key_indicator] = None - else: - self.best_ckpt_path_dict[key_indicator] = runner.message_hub.get_info(best_ckpt_name) - - if self.max_keep_ckpts > 0: - keep_ckpt_ids = [] - if "keep_ckpt_ids" in runner.message_hub.runtime_info: - keep_ckpt_ids = runner.message_hub.get_info("keep_ckpt_ids") - - while len(keep_ckpt_ids) > self.max_keep_ckpts: - step = keep_ckpt_ids.pop(0) - if is_main_process(): - path = self.file_backend.join_path(self.out_dir, self.filename_tmpl.format(step)) - if self.file_backend.isfile(path): - self.file_backend.remove(path) - elif self.file_backend.isdir(path): - # checkpoints saved by deepspeed are directories - self.file_backend.rmtree(path) - - self.keep_ckpt_ids: deque = deque(keep_ckpt_ids, self.max_keep_ckpts) - - def after_train_epoch(self, runner) -> None: - """Save the checkpoint and synchronize buffers after each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - if not self.by_epoch: - return - - # save checkpoint for following cases: - # 1. every ``self.interval`` epochs which start at ``self.save_begin`` - # 2. reach the last epoch of training - if self.every_n_epochs(runner, self.interval, self.save_begin) or ( - self.save_last and self.is_last_train_epoch(runner) - ): - runner.logger.info(f"Saving checkpoint at {runner.epoch + 1} epochs") - self._save_checkpoint(runner) - - def after_val_epoch(self, runner, metrics): - """Save the checkpoint and synchronize buffers after each evaluation - epoch. - - Args: - runner (Runner): The runner of the training process. - metrics (dict): Evaluation results of all metrics - """ - if len(metrics) == 0: - runner.logger.warning( - "Since `metrics` is an empty dict, the behavior to save the best checkpoint will be skipped in this evaluation." - ) - return - - self._save_best_checkpoint(runner, metrics) - - def after_train(self, runner) -> None: - """Publish the checkpoint after training. - - Args: - runner (Runner): The runner of the training process. 
- """ - if self.published_keys is None: - return - - if self.save_last and self.last_ckpt is not None: - self._publish_model(runner, self.last_ckpt) - - if getattr(self, "best_ckpt_path", None) is not None: - self._publish_model(runner, str(self.best_ckpt_path)) - if getattr(self, "best_ckpt_path_dict", None) is not None: - for best_ckpt in self.best_ckpt_path_dict.values(): - self._publish_model(runner, best_ckpt) - - @master_only - def _publish_model(self, runner, ckpt_path: str) -> None: - """Remove unnecessary keys from ckpt_path and save the new checkpoint. - - Args: - runner (Runner): The runner of the training process. - ckpt_path (str): The checkpoint path that ought to be published. - """ - from visengine.runner import save_checkpoint - from visengine.runner.checkpoint import _load_checkpoint - - checkpoint = _load_checkpoint(ckpt_path) - assert self.published_keys is not None - removed_keys = [] - for key in list(checkpoint.keys()): - if key not in self.published_keys: - removed_keys.append(key) - checkpoint.pop(key) - if removed_keys: - print_log( - f"Key {removed_keys} will be removed because they are not " - "found in published_keys. If you want to keep them, " - f"please set `{removed_keys}` in published_keys", - logger="current", - ) - checkpoint_data = pickle.dumps(checkpoint) - sha = hashlib.sha256(checkpoint_data).hexdigest() - final_path = osp.splitext(ckpt_path)[0] + f"-{sha[:8]}.pth" - save_checkpoint(checkpoint, final_path) - print_log( - f"The checkpoint ({ckpt_path}) is published to {final_path}.", - logger="current", - ) - - def _save_checkpoint_with_step(self, runner, step, meta): - # remove other checkpoints before save checkpoint to make the - # self.keep_ckpt_ids are saved as expected - if self.max_keep_ckpts > 0: - # _save_checkpoint and _save_best_checkpoint may call this - # _save_checkpoint_with_step in one epoch - if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step: - pass - else: - if len(self.keep_ckpt_ids) == self.max_keep_ckpts: - _step = self.keep_ckpt_ids.popleft() - if is_main_process(): - ckpt_path = self.file_backend.join_path(self.out_dir, self.filename_tmpl.format(_step)) - - if self.file_backend.isfile(ckpt_path): - self.file_backend.remove(ckpt_path) - elif self.file_backend.isdir(ckpt_path): - # checkpoints saved by deepspeed are directories - self.file_backend.rmtree(ckpt_path) - - self.keep_ckpt_ids.append(step) - runner.message_hub.update_info("keep_ckpt_ids", list(self.keep_ckpt_ids)) - - ckpt_filename = self.filename_tmpl.format(step) - self.last_ckpt = self.file_backend.join_path(self.out_dir, ckpt_filename) - runner.message_hub.update_info("last_ckpt", self.last_ckpt) - - runner.save_checkpoint( - self.out_dir, - ckpt_filename, - self.file_client_args, - save_optimizer=self.save_optimizer, - save_param_scheduler=self.save_param_scheduler, - meta=meta, - by_epoch=self.by_epoch, - backend_args=self.backend_args, - **self.args, - ) - - # Model parallel-like training should involve pulling sharded states - # from all ranks, but skip the following procedure. - if not is_main_process(): - return - - save_file = osp.join(runner.work_dir, "last_checkpoint") - with open(save_file, "w") as f: - f.write(self.last_ckpt) # type: ignore - - def _save_checkpoint(self, runner) -> None: - """Save the current checkpoint and delete outdated checkpoint. - - Args: - runner (Runner): The runner of the training process. 
- """ - if self.by_epoch: - step = runner.epoch + 1 - meta = {"epoch": step, "iter": runner.iter} - else: - step = runner.iter + 1 - meta = {"epoch": runner.epoch, "iter": step} - - self._save_checkpoint_with_step(runner, step, meta=meta) - - def _save_best_checkpoint(self, runner, metrics) -> None: - """Save the current checkpoint and delete outdated checkpoint. - - Args: - runner (Runner): The runner of the training process. - metrics (dict): Evaluation results of all metrics. - """ - if not self.save_best: - return - - if self.by_epoch: - ckpt_filename = self.filename_tmpl.format(runner.epoch) - cur_type, cur_time = "epoch", runner.epoch - else: - ckpt_filename = self.filename_tmpl.format(runner.iter) - cur_type, cur_time = "iter", runner.iter - - meta = {"epoch": runner.epoch, "iter": runner.iter} - - # handle auto in self.key_indicators and self.rules before the loop - if "auto" in self.key_indicators: - self._init_rule(self.rules, [next(iter(metrics.keys()))]) - - best_ckpt_updated = False - # save best logic - # get score from messagehub - for key_indicator, rule in zip(self.key_indicators, self.rules, strict=False): - key_score = metrics[key_indicator] - - if len(self.key_indicators) == 1: - best_score_key = "best_score" - runtime_best_ckpt_key = "best_ckpt" - best_ckpt_path = self.best_ckpt_path - else: - best_score_key = f"best_score_{key_indicator}" - runtime_best_ckpt_key = f"best_ckpt_{key_indicator}" - best_ckpt_path = self.best_ckpt_path_dict[key_indicator] - - if best_score_key not in runner.message_hub.runtime_info: - best_score = self.init_value_map[rule] - else: - best_score = runner.message_hub.get_info(best_score_key) - - if key_score is None or not self.is_better_than[key_indicator](key_score, best_score): - continue - - best_ckpt_updated = True - - best_score = key_score - runner.message_hub.update_info(best_score_key, best_score) - - if best_ckpt_path and is_main_process(): - is_removed = False - if self.file_backend.isfile(best_ckpt_path): - self.file_backend.remove(best_ckpt_path) - is_removed = True - elif self.file_backend.isdir(best_ckpt_path): - # checkpoints saved by deepspeed are directories - self.file_backend.rmtree(best_ckpt_path) - is_removed = True - - if is_removed: - runner.logger.info(f"The previous best checkpoint {best_ckpt_path} is removed") - - best_ckpt_name = f"best_{key_indicator}_{ckpt_filename}" - # Replace illegal characters for filename with `_` - best_ckpt_name = best_ckpt_name.replace("/", "_") - if len(self.key_indicators) == 1: - self.best_ckpt_path = self.file_backend.join_path( # type: ignore - self.out_dir, best_ckpt_name - ) - runner.message_hub.update_info(runtime_best_ckpt_key, self.best_ckpt_path) - else: - self.best_ckpt_path_dict[key_indicator] = self.file_backend.join_path( # type: ignore - self.out_dir, best_ckpt_name - ) - runner.message_hub.update_info(runtime_best_ckpt_key, self.best_ckpt_path_dict[key_indicator]) - runner.save_checkpoint( - self.out_dir, - filename=best_ckpt_name, - file_client_args=self.file_client_args, - save_optimizer=False, - save_param_scheduler=False, - meta=meta, - by_epoch=False, - backend_args=self.backend_args, - ) - runner.logger.info( - f"The best checkpoint with {best_score:0.4f} {key_indicator} at {cur_time} {cur_type} is saved to {best_ckpt_name}." 
-            )
-
-        # save the checkpoint again to update the best_score and best_ckpt
-        # stored in message_hub, because the checkpoint saved in the
-        # `after_train_epoch` or `after_train_iter` stage only keeps the
-        # previous best checkpoint, not the current one, which would prevent
-        # the current best checkpoint from being removed when resuming
-        # training.
-        if best_ckpt_updated and self.last_ckpt is not None:
-            self._save_checkpoint_with_step(runner, cur_time, meta)
-
-    def _init_rule(self, rules, key_indicators) -> None:
-        """Initialize rule, key_indicator, comparison_func, and best score. If
-        key_indicators is a list of strings and rules is a single string, all
-        metrics in key_indicators will share the same rule.
-
-        Here is the rule to determine which rule is used for a key indicator
-        when the rule is not specified (note that the key indicator matching
-        is case-insensitive):
-
-        1. If the key indicator is in ``self.greater_keys``, the rule
-           will be specified as 'greater'.
-        2. Or if the key indicator is in ``self.less_keys``, the rule
-           will be specified as 'less'.
-        3. Or if any one item in ``self.greater_keys`` is a substring of
-           key_indicator, the rule will be specified as 'greater'.
-        4. Or if any one item in ``self.less_keys`` is a substring of
-           key_indicator, the rule will be specified as 'less'.
-
-        Args:
-            rules (List[Optional[str]]): Comparison rules for best scores.
-            key_indicators (List[str]): Key indicators to determine
-                the comparison rule.
-        """
-        if len(rules) == 1:
-            rules = rules * len(key_indicators)
-
-        self.rules = []
-        for rule, key_indicator in zip(rules, key_indicators, strict=False):
-            if rule not in self.rule_map and rule is not None:
-                raise KeyError(f"rule must be greater, less or None, but got {rule}.")
-
-            if rule is None and key_indicator != "auto":
-                # `_lc` here means we use the lower case of keys for
-                # case-insensitive matching
-                key_indicator_lc = key_indicator.lower()
-                greater_keys = {key.lower() for key in self.greater_keys}
-                less_keys = {key.lower() for key in self.less_keys}
-
-                if key_indicator_lc in greater_keys:
-                    rule = "greater"
-                elif key_indicator_lc in less_keys:
-                    rule = "less"
-                elif any(key in key_indicator_lc for key in greater_keys):
-                    rule = "greater"
-                elif any(key in key_indicator_lc for key in less_keys):
-                    rule = "less"
-                else:
-                    raise ValueError(
-                        f"Cannot infer the rule for key {key_indicator}, thus a specific rule must be specified."
-                    )
-            if rule is not None:
-                self.is_better_than[key_indicator] = self.rule_map[rule]
-            self.rules.append(rule)
-
-        self.key_indicators = key_indicators
-
-    def after_train_iter(
-        self,
-        runner,
-        batch_idx: int,
-        data_batch: DATA_BATCH = None,
-        outputs: dict | None = None,
-    ) -> None:
-        """Save the checkpoint and synchronize buffers after each iteration.
-
-        Args:
-            runner (Runner): The runner of the training process.
-            batch_idx (int): The index of the current batch in the train loop.
-            data_batch (dict or tuple or list, optional): Data from dataloader.
-            outputs (dict, optional): Outputs from model.
-        """
-        if self.by_epoch:
-            return
-
-        # save checkpoint for following cases:
-        # 1. every ``self.interval`` iterations
-        #    which start at ``self.save_begin``
-        # 2.
reach the last iteration of training - if self.every_n_train_iters(runner, self.interval, self.save_begin) or ( - self.save_last and self.is_last_train_iter(runner) - ): - runner.logger.info(f"Saving checkpoint at {runner.iter + 1} iterations") - self._save_checkpoint(runner) diff --git a/libs/visengine/visengine/hooks/early_stopping_hook.py b/libs/visengine/visengine/hooks/early_stopping_hook.py deleted file mode 100644 index 8c1a44c..0000000 --- a/libs/visengine/visengine/hooks/early_stopping_hook.py +++ /dev/null @@ -1,158 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from math import inf, isfinite -from typing import Optional - -from visengine.registry import HOOKS - -from .hook import Hook - -DATA_BATCH = Optional[dict | tuple | list] - - -@HOOKS.register_module(force=True) -class EarlyStoppingHook(Hook): - """Early stop the training when the monitored metric reached a plateau. - - Args: - monitor (str): The monitored metric key to decide early stopping. - rule (str, optional): Comparison rule. Options are 'greater', - 'less'. Defaults to None. - min_delta (float, optional): Minimum difference to continue the - training. Defaults to 0.01. - strict (bool, optional): Whether to crash the training when `monitor` - is not found in the `metrics`. Defaults to False. - check_finite: Whether to stop training when the monitor becomes NaN or - infinite. Defaults to True. - patience (int, optional): The times of validation with no improvement - after which training will be stopped. Defaults to 5. - stopping_threshold (float, optional): Stop training immediately once - the monitored quantity reaches this threshold. Defaults to None. - - Note: - `New in version 0.7.0.` - """ - - priority = "LOWEST" - - rule_map = {"greater": lambda x, y: x > y, "less": lambda x, y: x < y} - _default_greater_keys = [ - "acc", - "top", - "AR@", - "auc", - "precision", - "mAP", - "mDice", - "mIoU", - "mAcc", - "aAcc", - ] - _default_less_keys = ["loss"] - - def __init__( - self, - monitor: str, - rule: str | None = None, - min_delta: float = 0.1, - strict: bool = False, - check_finite: bool = True, - patience: int = 5, - stopping_threshold: float | None = None, - ): - self.monitor = monitor - if rule is not None: - if rule not in ["greater", "less"]: - raise ValueError(f'`rule` should be either "greater" or "less", but got {rule}') - else: - rule = self._init_rule(monitor) - self.rule = rule - self.min_delta = min_delta if rule == "greater" else -1 * min_delta - self.strict = strict - self.check_finite = check_finite - self.patience = patience - self.stopping_threshold = stopping_threshold - - self.wait_count = 0 - self.best_score = -inf if rule == "greater" else inf - - def _init_rule(self, monitor: str) -> str: - greater_keys = {key.lower() for key in self._default_greater_keys} - less_keys = {key.lower() for key in self._default_less_keys} - monitor_lc = monitor.lower() - if monitor_lc in greater_keys: - rule = "greater" - elif monitor_lc in less_keys: - rule = "less" - elif any(key in monitor_lc for key in greater_keys): - rule = "greater" - elif any(key in monitor_lc for key in less_keys): - rule = "less" - else: - raise ValueError(f"Cannot infer the rule for {monitor}, thus rule must be specified.") - return rule - - def _check_stop_condition(self, current_score: float) -> tuple[bool, str]: - compare = self.rule_map[self.rule] - stop_training = False - reason_message = "" - - if self.check_finite and not isfinite(current_score): - stop_training = True - 
reason_message = f"Monitored metric {self.monitor} = {current_score} is infinite. Previous best value was {self.best_score:.3f}." - - elif self.stopping_threshold is not None and compare(current_score, self.stopping_threshold): - stop_training = True - self.best_score = current_score - reason_message = f"Stopping threshold reached: `{self.monitor}` = {current_score} is {self.rule} than {self.stopping_threshold}." - elif compare(self.best_score + self.min_delta, current_score): - self.wait_count += 1 - - if self.wait_count >= self.patience: - reason_message = f"the monitored metric did not improve in the last {self.wait_count} records. best score: {self.best_score:.3f}. " - stop_training = True - else: - self.best_score = current_score - self.wait_count = 0 - - return stop_training, reason_message - - def before_run(self, runner) -> None: - """Check `stop_training` variable in `runner.train_loop`. - - Args: - runner (Runner): The runner of the training process. - """ - - assert hasattr(runner.train_loop, "stop_training"), "`train_loop` should contain `stop_training` variable." - - def after_val_epoch(self, runner, metrics): - """Decide whether to stop the training process. - - Args: - runner (Runner): The runner of the training process. - metrics (dict): Evaluation results of all metrics - """ - - if self.monitor not in metrics: - if self.strict: - raise RuntimeError( - "Early stopping conditioned on metric " - f"`{self.monitor} is not available. Please check available" - f" metrics {metrics}, or set `strict=False` in " - "`EarlyStoppingHook`." - ) - warnings.warn( - f"Skip early stopping process since the evaluation results ({metrics.keys()}) do not include `monitor` ({self.monitor}).", - stacklevel=2, - ) - return - - current_score = metrics[self.monitor] - - stop_training, message = self._check_stop_condition(current_score) - if stop_training: - runner.train_loop.stop_training = True - runner.logger.info(message) diff --git a/libs/visengine/visengine/hooks/ema_hook.py b/libs/visengine/visengine/hooks/ema_hook.py deleted file mode 100644 index 411b916..0000000 --- a/libs/visengine/visengine/hooks/ema_hook.py +++ /dev/null @@ -1,238 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import itertools -import logging - -from visengine.logging import print_log -from visengine.model import is_model_wrapper -from visengine.registry import HOOKS, MODELS - -from .hook import DATA_BATCH, Hook - - -@HOOKS.register_module(force=True) -class EMAHook(Hook): - """A Hook to apply Exponential Moving Average (EMA) on the model during - training. - - Note: - - EMAHook takes priority over CheckpointHook. - - The original model parameters are actually saved in ema field after - train. - - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time. - - Args: - ema_type (str): The type of EMA strategy to use. You can find the - supported strategies in :mod:`mmengine.model.averaged_model`. - Defaults to 'ExponentialMovingAverage'. - strict_load (bool): Whether to strictly enforce that the keys of - ``state_dict`` in checkpoint match the keys returned by - ``self.module.state_dict``. Defaults to False. - Changed in v0.3.0. - begin_iter (int): The number of iteration to enable ``EMAHook``. - Defaults to 0. - begin_epoch (int): The number of epoch to enable ``EMAHook``. - Defaults to 0. 
- **kwargs: Keyword arguments passed to subclasses of - :obj:`BaseAveragedModel` - """ - - priority = "NORMAL" - - def __init__( - self, - ema_type: str = "ExponentialMovingAverage", - strict_load: bool = False, - begin_iter: int = 0, - begin_epoch: int = 0, - **kwargs, - ): - self.strict_load = strict_load - self.ema_cfg = dict(type=ema_type, **kwargs) - assert not (begin_iter != 0 and begin_epoch != 0), "`begin_iter` and `begin_epoch` should not be both set." - assert begin_iter >= 0, f"`begin_iter` must larger than or equal to 0, but got begin_iter: {begin_iter}" - assert begin_epoch >= 0, f"`begin_epoch` must larger than or equal to 0, but got begin_epoch: {begin_epoch}" - self.begin_iter = begin_iter - self.begin_epoch = begin_epoch - # If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be - # enabled at 0 iteration. - self.enabled_by_epoch = self.begin_epoch > 0 - - def before_run(self, runner) -> None: - """Create an ema copy of the model. - - Args: - runner (Runner): The runner of the training process. - """ - model = runner.model - if is_model_wrapper(model): - model = model.module - self.src_model = model - self.ema_model = MODELS.build(self.ema_cfg, default_args={"model": self.src_model}) - - def before_train(self, runner) -> None: - """Check the begin_epoch/iter is smaller than max_epochs/iters. - - Args: - runner (Runner): The runner of the training process. - """ - if self.enabled_by_epoch: - assert self.begin_epoch <= runner.max_epochs, ( - f"self.begin_epoch should be smaller than or equal to runner.max_epochs: {runner.max_epochs}, but got begin_epoch: {self.begin_epoch}" - ) - else: - assert self.begin_iter <= runner.max_iters, ( - f"self.begin_iter should be smaller than or equal to runner.max_iters: {runner.max_iters}, but got begin_iter: {self.begin_iter}" - ) - - def after_train_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: dict | None = None, - ) -> None: - """Update ema parameter. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict, optional): Outputs from model. Defaults to None. - """ - if self._ema_started(runner): - self.ema_model.update_parameters(self.src_model) - else: - ema_params = self.ema_model.module.state_dict() - src_params = self.src_model.state_dict() - for k, p in ema_params.items(): - p.data.copy_(src_params[k].data) - - def before_val_epoch(self, runner) -> None: - """We load parameter values from ema model to source model before - validation. - - Args: - runner (Runner): The runner of the training process. - """ - self._swap_ema_parameters() - - def after_val_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """We recover source model's parameter from ema model after validation. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - self._swap_ema_parameters() - - def before_test_epoch(self, runner) -> None: - """We load parameter values from ema model to source model before test. - - Args: - runner (Runner): The runner of the training process. 
- """ - self._swap_ema_parameters() - - def after_test_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """We recover source model's parameter from ema model after test. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - self._swap_ema_parameters() - - def before_save_checkpoint(self, runner, checkpoint: dict) -> None: - """Save ema parameters to checkpoint. - - Args: - runner (Runner): The runner of the testing process. - """ - checkpoint["ema_state_dict"] = self.ema_model.state_dict() - # Save ema parameters to the source model's state dict so that we - # can directly load the averaged model weights for deployment. - # Swapping the state_dict key-values instead of swapping model - # parameters because the state_dict is a shallow copy of model - # parameters. - self._swap_ema_state_dict(checkpoint) - - def after_load_checkpoint(self, runner, checkpoint: dict) -> None: - """Resume ema parameters from checkpoint. - - Args: - runner (Runner): The runner of the testing process. - """ - from visengine.runner.checkpoint import load_state_dict - - if "ema_state_dict" in checkpoint and runner._resume: - # The original model parameters are actually saved in ema - # field swap the weights back to resume ema state. - self._swap_ema_state_dict(checkpoint) - self.ema_model.load_state_dict(checkpoint["ema_state_dict"], strict=self.strict_load) - - # Support load checkpoint without ema state dict. - else: - if runner._resume: - print_log( - "There is no `ema_state_dict` in checkpoint. `EMAHook` will make a copy of `state_dict` as the initial `ema_state_dict`", - "current", - logging.WARNING, - ) - load_state_dict( - self.ema_model.module, - copy.deepcopy(checkpoint["state_dict"]), - strict=self.strict_load, - ) - - def _swap_ema_parameters(self) -> None: - """Swap the parameter of model with ema_model.""" - avg_param = ( - itertools.chain(self.ema_model.module.parameters(), self.ema_model.module.buffers()) - if self.ema_model.update_buffers - else self.ema_model.module.parameters() - ) - src_param = ( - itertools.chain(self.src_model.parameters(), self.src_model.buffers()) - if self.ema_model.update_buffers - else self.src_model.parameters() - ) - for p_avg, p_src in zip(avg_param, src_param, strict=False): - tmp = p_avg.data.clone() - p_avg.data.copy_(p_src.data) - p_src.data.copy_(tmp) - - def _swap_ema_state_dict(self, checkpoint): - """Swap the state dict values of model with ema_model.""" - model_state = checkpoint["state_dict"] - ema_state = checkpoint["ema_state_dict"] - for k in ema_state: - if k[:7] == "module.": - tmp = ema_state[k] - ema_state[k] = model_state[k[7:]] - model_state[k[7:]] = tmp - - def _ema_started(self, runner) -> bool: - """Whether ``EMAHook`` has been initialized at current iteration or - epoch. - - :attr:`ema_model` will be initialized when ``runner.iter`` or - ``runner.epoch`` is greater than ``self.begin`` for the first time. - - Args: - runner (Runner): Runner of the training, validation process. - - Returns: - bool: Whether ``EMAHook`` has been initialized. 
- """ - if self.enabled_by_epoch: - return runner.epoch + 1 >= self.begin_epoch - else: - return runner.iter + 1 >= self.begin_iter diff --git a/libs/visengine/visengine/hooks/empty_cache_hook.py b/libs/visengine/visengine/hooks/empty_cache_hook.py deleted file mode 100644 index 82820a6..0000000 --- a/libs/visengine/visengine/hooks/empty_cache_hook.py +++ /dev/null @@ -1,81 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from collections.abc import Sequence -from typing import Optional - -import torch - -from visengine.registry import HOOKS - -from ..device import is_cuda_available -from .hook import Hook - -DATA_BATCH = Optional[dict | tuple | list] - - -@HOOKS.register_module(force=True) -class EmptyCacheHook(Hook): - """Releases all unoccupied cached GPU memory during the process of - training. - - Args: - before_epoch (bool): Whether to release cache before an epoch. Defaults - to False. - after_epoch (bool): Whether to release cache after an epoch. Defaults - to True. - after_iter (bool): Whether to release cache after an iteration. - Defaults to False. - """ - - priority = "NORMAL" - - def __init__( - self, - before_epoch: bool = False, - after_epoch: bool = True, - after_iter: bool = False, - ) -> None: - self._do_before_epoch = before_epoch - self._do_after_epoch = after_epoch - self._do_after_iter = after_iter - - def _after_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: dict | Sequence | None = None, - mode: str = "train", - ) -> None: - """Empty cache after an iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (dict or sequence, optional): Outputs from model. - mode (str): Current mode of runner. Defaults to 'train'. - """ - if self._do_after_iter and is_cuda_available(): - torch.cuda.empty_cache() - - def _before_epoch(self, runner, mode: str = "train") -> None: - """Empty cache before an epoch. - - Args: - runner (Runner): The runner of the training process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - if self._do_before_epoch and is_cuda_available(): - torch.cuda.empty_cache() - - def _after_epoch(self, runner, mode: str = "train") -> None: - """Empty cache after an epoch. - - Args: - runner (Runner): The runner of the training process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - if self._do_after_epoch and is_cuda_available(): - torch.cuda.empty_cache() diff --git a/libs/visengine/visengine/hooks/hook.py b/libs/visengine/visengine/hooks/hook.py deleted file mode 100644 index 5106aae..0000000 --- a/libs/visengine/visengine/hooks/hook.py +++ /dev/null @@ -1,468 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from collections.abc import Sequence -from typing import Optional - -from visengine.utils import is_method_overridden - -DATA_BATCH = Optional[dict | tuple | list] - - -class Hook: - """Base hook class. - - All hooks should inherit from this class. 
- """ - - priority = "NORMAL" - stages = ( - "before_run", - "after_load_checkpoint", - "before_train", - "before_train_epoch", - "before_train_iter", - "after_train_iter", - "after_train_epoch", - "before_val", - "before_val_epoch", - "before_val_iter", - "after_val_iter", - "after_val_epoch", - "after_val", - "before_save_checkpoint", - "after_train", - "before_test", - "before_test_epoch", - "before_test_iter", - "after_test_iter", - "after_test_epoch", - "after_test", - "after_run", - ) - - def before_run(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before the training validation or testing process. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - """ - - def after_run(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before the training validation or testing process. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - """ - - def before_train(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before train. - - Args: - runner (Runner): The runner of the training process. - """ - - def after_train(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after train. - - Args: - runner (Runner): The runner of the training process. - """ - - def before_val(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before validation. - - Args: - runner (Runner): The runner of the validation process. - """ - - def after_val(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after validation. - - Args: - runner (Runner): The runner of the validation process. - """ - - def before_test(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before testing. - - Args: - runner (Runner): The runner of the testing process. - """ - - def after_test(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after testing. - - Args: - runner (Runner): The runner of the testing process. - """ - - def before_save_checkpoint(self, runner, checkpoint: dict) -> None: - """All subclasses should override this method, if they need any - operations before saving the checkpoint. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - checkpoint (dict): Model's checkpoint. - """ - - def after_load_checkpoint(self, runner, checkpoint: dict) -> None: - """All subclasses should override this method, if they need any - operations after loading the checkpoint. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - checkpoint (dict): Model's checkpoint. - """ - - def before_train_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each training epoch. - - Args: - runner (Runner): The runner of the training process. - """ - self._before_epoch(runner, mode="train") - - def before_val_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each validation epoch. - - Args: - runner (Runner): The runner of the validation process. 
- """ - self._before_epoch(runner, mode="val") - - def before_test_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each test epoch. - - Args: - runner (Runner): The runner of the testing process. - """ - self._before_epoch(runner, mode="test") - - def after_train_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after each training epoch. - - Args: - runner (Runner): The runner of the training process. - """ - self._after_epoch(runner, mode="train") - - def after_val_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """All subclasses should override this method, if they need any - operations after each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - self._after_epoch(runner, mode="val") - - def after_test_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """All subclasses should override this method, if they need any - operations after each test epoch. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - self._after_epoch(runner, mode="test") - - def before_train_iter(self, runner, batch_idx: int, data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each training iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - """ - self._before_iter(runner, batch_idx=batch_idx, data_batch=data_batch, mode="train") - - def before_val_iter(self, runner, batch_idx: int, data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each validation iteration. - - Args: - runner (Runner): The runner of the validation process. - batch_idx (int): The index of the current batch in the val loop. - data_batch (dict, optional): Data from dataloader. - Defaults to None. - """ - self._before_iter(runner, batch_idx=batch_idx, data_batch=data_batch, mode="val") - - def before_test_iter(self, runner, batch_idx: int, data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each test iteration. - - Args: - runner (Runner): The runner of the testing process. - batch_idx (int): The index of the current batch in the test loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - Defaults to None. - """ - self._before_iter(runner, batch_idx=batch_idx, data_batch=data_batch, mode="test") - - def after_train_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: dict | None = None, - ) -> None: - """All subclasses should override this method, if they need any - operations after each training iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict tuple or list, optional): Data from dataloader. 
- outputs (dict, optional): Outputs from model. - """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode="train", - ) - - def after_val_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Sequence | None = None, - ) -> None: - """All subclasses should override this method, if they need any - operations after each validation iteration. - - Args: - runner (Runner): The runner of the validation process. - batch_idx (int): The index of the current batch in the val loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (Sequence, optional): Outputs from model. - """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode="val", - ) - - def after_test_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Sequence | None = None, - ) -> None: - """All subclasses should override this method, if they need any - operations after each test iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the test loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (Sequence, optional): Outputs from model. - """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode="test", - ) - - def _before_epoch(self, runner, mode: str = "train") -> None: - """All subclasses should override this method, if they need any - operations before each epoch. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def _after_epoch(self, runner, mode: str = "train") -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def _before_iter(self, runner, batch_idx: int, data_batch: DATA_BATCH = None, mode: str = "train") -> None: - """All subclasses should override this method, if they need any - operations before each iter. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def _after_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Sequence | dict | None = None, - mode: str = "train", - ) -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (dict or Sequence, optional): Outputs from model. - mode (str): Current mode of runner. Defaults to 'train'. - """ - - def every_n_epochs(self, runner, n: int, start: int = 0) -> bool: - """Test whether current epoch can be evenly divided by n. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - n (int): Whether current epoch can be evenly divided by n. - start (int): Starting from `start` to check the logic for - every n epochs. Defaults to 0. 
- - Returns: - bool: Whether current epoch can be evenly divided by n. - """ - dividend = runner.epoch + 1 - start - return dividend % n == 0 if dividend >= 0 and n > 0 else False - - def every_n_inner_iters(self, batch_idx: int, n: int) -> bool: - """Test whether current inner iteration can be evenly divided by n. - - Args: - batch_idx (int): Current batch index of the training, validation - or testing loop. - n (int): Whether current inner iteration can be evenly - divided by n. - - Returns: - bool: Whether current inner iteration can be evenly - divided by n. - """ - return (batch_idx + 1) % n == 0 if n > 0 else False - - def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool: - """Test whether current training iteration can be evenly divided by n. - - Args: - runner (Runner): The runner of the training, validation or testing - process. - n (int): Whether current iteration can be evenly divided by n. - start (int): Starting from `start` to check the logic for - every n iterations. Defaults to 0. - - Returns: - bool: Return True if the current iteration can be evenly divided - by n, otherwise False. - """ - dividend = runner.iter + 1 - start - return dividend % n == 0 if dividend >= 0 and n > 0 else False - - def end_of_epoch(self, dataloader, batch_idx: int) -> bool: - """Check whether the current iteration reaches the last iteration of - the dataloader. - - Args: - dataloader (Dataloader): The dataloader of the training, - validation or testing process. - batch_idx (int): The index of the current batch in the loop. - Returns: - bool: Whether reaches the end of current epoch or not. - """ - return batch_idx + 1 == len(dataloader) - - def is_last_train_epoch(self, runner) -> bool: - """Test whether current epoch is the last train epoch. - - Args: - runner (Runner): The runner of the training process. - - Returns: - bool: Whether reaches the end of training epoch. - """ - return runner.epoch + 1 == runner.max_epochs - - def is_last_train_iter(self, runner) -> bool: - """Test whether current iteration is the last train iteration. - - Args: - runner (Runner): The runner of the training process. - - Returns: - bool: Whether current iteration is the last train iteration. - """ - return runner.iter + 1 == runner.max_iters - - def get_triggered_stages(self) -> list: - """Get all triggered stages with method name of the hook. - - Returns: - list: List of triggered stages. - """ - trigger_stages = set() - for stage in Hook.stages: - if is_method_overridden(stage, Hook, self): - trigger_stages.add(stage) - - # some methods will be triggered in multi stages - # use this dict to map method to stages. - method_stages_map = { - "_before_epoch": [ - "before_train_epoch", - "before_val_epoch", - "before_test_epoch", - ], - "_after_epoch": [ - "after_train_epoch", - "after_val_epoch", - "after_test_epoch", - ], - "_before_iter": [ - "before_train_iter", - "before_val_iter", - "before_test_iter", - ], - "_after_iter": ["after_train_iter", "after_val_iter", "after_test_iter"], - } - - for method, map_stages in method_stages_map.items(): - if is_method_overridden(method, Hook, self): - trigger_stages.update(map_stages) - - return list(trigger_stages) diff --git a/libs/visengine/visengine/hooks/iter_timer_hook.py b/libs/visengine/visengine/hooks/iter_timer_hook.py deleted file mode 100644 index fe19f55..0000000 --- a/libs/visengine/visengine/hooks/iter_timer_hook.py +++ /dev/null @@ -1,107 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
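-# Usage sketch (illustrative): IterTimerHook is usually enabled via the
-# runner's default hooks, e.g.
-#   default_hooks = dict(timer=dict(type="IterTimerHook"))
-# and then records per-iteration "data_time" and "time" scalars in the
-# message hub.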
-import time -from collections.abc import Sequence -from typing import Optional - -from visengine.registry import HOOKS - -from .hook import Hook - -DATA_BATCH = Optional[dict | tuple | list] - - -@HOOKS.register_module(force=True) -class IterTimerHook(Hook): - """A hook that logs the time spent during iteration. - - E.g. ``data_time`` for loading data and ``time`` for a model train step. - """ - - priority = "NORMAL" - - def __init__(self): - self.time_sec_tot = 0 - self.time_sec_test_val = 0 - self.start_iter = 0 - - def before_train(self, runner) -> None: - """Synchronize the number of iterations with the runner after resuming - from checkpoints. - - Args: - runner: The runner of the training, validation or testing - process. - """ - self.start_iter = runner.iter - - def _before_epoch(self, runner, mode: str = "train") -> None: - """Record timestamp before start an epoch. - - Args: - runner (Runner): The runner of the training validation and - testing process. - mode (str): Current mode of runner. Defaults to 'train'. - """ - self.t = time.time() - - def _after_epoch(self, runner, mode: str = "train") -> None: - self.time_sec_test_val = 0 - - def _before_iter(self, runner, batch_idx: int, data_batch: DATA_BATCH = None, mode: str = "train") -> None: - """Calculating time for loading data and updating "data_time" - ``HistoryBuffer`` of ``runner.message_hub``. - - Args: - runner (Runner): The runner of the training, validation and - testing process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from - dataloader. - mode (str): Current mode of runner. Defaults to 'train'. - """ - # Update data loading time in `runner.message_hub`. - runner.message_hub.update_scalar(f"{mode}/data_time", time.time() - self.t) - - def _after_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: dict | Sequence | None = None, - mode: str = "train", - ) -> None: - """Calculating time for an iteration and updating "time" - ``HistoryBuffer`` of ``runner.message_hub``. - - Args: - runner (Runner): The runner of the training validation and - testing process. - batch_idx (int): The index of the current batch in the loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (dict or sequence, optional): Outputs from model. - mode (str): Current mode of runner. Defaults to 'train'. - """ - # Update iteration time in `runner.message_hub`. - message_hub = runner.message_hub - message_hub.update_scalar(f"{mode}/time", time.time() - self.t) - self.t = time.time() - iter_time = message_hub.get_scalar(f"{mode}/time") - if mode == "train": - self.time_sec_tot += iter_time.current() - # Calculate average iterative time. - time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1) - # Calculate eta. 
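-            # (average seconds per finished train iter) x (remaining iters);
-            # a rough estimate that ignores validation and checkpoint time.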
-            eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
-            runner.message_hub.update_info("eta", eta_sec)
-        else:
-            if mode == "val":
-                cur_dataloader = runner.val_dataloader
-            else:
-                cur_dataloader = runner.test_dataloader
-
-            self.time_sec_test_val += iter_time.current()
-            time_sec_avg = self.time_sec_test_val / (batch_idx + 1)
-            eta_sec = time_sec_avg * (len(cur_dataloader) - batch_idx - 1)
-            runner.message_hub.update_info("eta", eta_sec)
diff --git a/libs/visengine/visengine/hooks/logger_hook.py b/libs/visengine/visengine/hooks/logger_hook.py
deleted file mode 100644
index f284c88..0000000
--- a/libs/visengine/visengine/hooks/logger_hook.py
+++ /dev/null
@@ -1,340 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-import os
-import os.path as osp
-from collections import OrderedDict
-from collections.abc import Sequence
-from pathlib import Path
-from typing import Optional, Union
-
-import numpy as np
-import torch
-
-from visengine.fileio import FileClient, dump
-from visengine.fileio.io import get_file_backend
-from visengine.hooks import Hook
-from visengine.logging import print_log
-from visengine.registry import HOOKS
-from visengine.utils import is_seq_of, scandir
-
-DATA_BATCH = Optional[dict | tuple | list]
-SUFFIX_TYPE = Union[Sequence[str], str]
-
-
-@HOOKS.register_module(force=True)
-class LoggerHook(Hook):
-    """Collect logs from different components of ``Runner`` and write them to
-    the terminal, JSON file, TensorBoard, wandb, etc.
-
-    ``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
-    the training/validation/testing phase. It is used to control the following
-    behaviors:
-
-    - The frequency of log updates in the terminal, local files, TensorBoard,
-      wandb, etc.
-    - The frequency of showing experiment information in the terminal.
-    - The work directory to save logs.
-
-    Args:
-        interval (int): Logging interval (every k iterations).
-            Defaults to 10.
-        ignore_last (bool): Ignore the log of last iterations in each epoch if
-            the number of remaining iterations is less than :attr:`interval`.
-            Defaults to True.
-        interval_exp_name (int): Logging interval for experiment name. This
-            feature is to help users conveniently get the experiment
-            information from screen or log file. Defaults to 1000.
-        out_dir (str or Path, optional): The root directory to save
-            logs. If not specified, ``runner.work_dir`` will be used
-            by default. If specified, the ``out_dir`` will be the concatenation
-            of ``out_dir`` and the last level directory of ``runner.work_dir``.
-            For example, if the input ``out_dir`` is ``./tmp`` and
-            ``runner.work_dir`` is ``./work_dir/cur_exp``, then the log will be
-            saved in ``./tmp/cur_exp``. Defaults to None.
-        out_suffix (Tuple[str] or str): Those files in ``runner._log_dir``
-            ending with ``out_suffix`` will be copied to ``out_dir``. Defaults
-            to ('.json', '.log', '.py', 'yaml').
-        keep_local (bool): Whether to keep local logs in the local machine
-            when :attr:`out_dir` is specified. If False, the local log will be
-            removed. Defaults to True.
-        file_client_args (dict, optional): Arguments to instantiate a
-            FileClient. See :class:`visengine.fileio.FileClient` for details.
-            Defaults to None. It will be deprecated in the future. Please use
-            `backend_args` instead.
-        log_metric_by_epoch (bool): Whether to output metric in validation step
-            by epoch. It can be true when running in epoch based runner.
-            If set to True, `after_val_epoch` will set `step` to self.epoch in
-            `runner.visualizer.add_scalars`. Otherwise `step` will be
-            self.iter. Defaults to True.
-        backend_args (dict, optional): Arguments to instantiate the
-            prefix of uri corresponding backend. Defaults to None.
-            New in v0.2.0.
-
-    Examples:
-        >>> # The simplest LoggerHook config.
-        >>> logger_hook_cfg = dict(interval=20)
-    """
-
-    priority = "BELOW_NORMAL"
-
-    def __init__(
-        self,
-        interval: int = 10,
-        ignore_last: bool = True,
-        interval_exp_name: int = 1000,
-        out_dir: str | Path | None = None,
-        out_suffix: SUFFIX_TYPE = (".json", ".log", ".py", "yaml"),
-        keep_local: bool = True,
-        file_client_args: dict | None = None,
-        log_metric_by_epoch: bool = True,
-        backend_args: dict | None = None,
-    ):
-        if not isinstance(interval, int):
-            raise TypeError("interval must be an integer")
-        if interval <= 0:
-            raise ValueError("interval must be greater than 0")
-
-        if not isinstance(ignore_last, bool):
-            raise TypeError("ignore_last must be a boolean")
-
-        if not isinstance(interval_exp_name, int):
-            raise TypeError("interval_exp_name must be an integer")
-        if interval_exp_name <= 0:
-            raise ValueError("interval_exp_name must be greater than 0")
-
-        if out_dir is not None and not isinstance(out_dir, str | Path):
-            raise TypeError("out_dir must be a str or Path object")
-
-        if not isinstance(keep_local, bool):
-            raise TypeError("keep_local must be a boolean")
-
-        if out_dir is None and file_client_args is not None:
-            raise ValueError('file_client_args should be "None" when `out_dir` is not specified.')
-
-        if file_client_args is not None:
-            print_log(
-                '"file_client_args" will be deprecated in the future. Please use "backend_args" instead',
-                logger="current",
-                level=logging.WARNING,
-            )
-            if backend_args is not None:
-                raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
-
-        if not (isinstance(out_suffix, str) or is_seq_of(out_suffix, str)):
-            raise TypeError(f"out_suffix should be a string or a sequence of strings, but got {type(out_suffix)}")
-
-        self.out_suffix = out_suffix
-        self.out_dir = out_dir
-        self.interval = interval
-        self.ignore_last = ignore_last
-        self.interval_exp_name = interval_exp_name
-        self.keep_local = keep_local
-        self.file_client_args = file_client_args
-        self.json_log_path: str | None = None
-
-        if self.out_dir is not None:
-            self.file_client = FileClient.infer_client(file_client_args, self.out_dir)
-            if file_client_args is None:
-                self.file_backend = get_file_backend(self.out_dir, backend_args=backend_args)
-            else:
-                self.file_backend = self.file_client
-
-        self.log_metric_by_epoch = log_metric_by_epoch
-
-    def before_run(self, runner) -> None:
-        """Infer ``self.file_client`` from ``self.out_dir``. Initialize the
-        ``self.start_iter`` and record the meta information.
-
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        if self.out_dir is not None:
-            # The final `self.out_dir` is the concatenation of `self.out_dir`
-            # and the last level directory of `runner.work_dir`
-            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
-            self.out_dir = self.file_backend.join_path(self.out_dir, basename)
-            runner.logger.info(f"Text logs will be saved to {self.out_dir} after the training process.")
-
-        self.json_log_path = f"{runner.timestamp}.json"
-
-    def after_train_iter(
-        self,
-        runner,
-        batch_idx: int,
-        data_batch: DATA_BATCH = None,
-        outputs: dict | None = None,
-    ) -> None:
-        """Record logs after training iteration.
-
-        Args:
-            runner (Runner): The runner of the training process.
-            batch_idx (int): The index of the current batch in the train loop.
-            data_batch (dict, tuple or list, optional): Data from dataloader.
-            outputs (dict, optional): Outputs from model.
-        """
-        # Print experiment name every n iterations.
-        if self.every_n_train_iters(runner, self.interval_exp_name) or (
-            self.end_of_epoch(runner.train_dataloader, batch_idx)
-        ):
-            exp_info = f"Exp name: {runner.experiment_name}"
-            runner.logger.info(exp_info)
-        if self.every_n_inner_iters(batch_idx, self.interval):
-            tag, log_str = runner.log_processor.get_log_after_iter(runner, batch_idx, "train")
-        elif self.end_of_epoch(runner.train_dataloader, batch_idx) and (
-            not self.ignore_last or len(runner.train_dataloader) <= self.interval
-        ):
-            # `runner.max_iters` may not be divisible by `self.interval`. If
-            # `self.ignore_last==False`, the log of the remaining iterations
-            # will also be recorded (e.g. for Epoch [4][1000/1007], the logs
-            # of iterations 998-1007 will be recorded).
-            tag, log_str = runner.log_processor.get_log_after_iter(runner, batch_idx, "train")
-        else:
-            return
-        runner.logger.info(log_str)
-        runner.visualizer.add_scalars(tag, step=runner.iter + 1, file_path=self.json_log_path)
-
-    def after_val_iter(
-        self,
-        runner,
-        batch_idx: int,
-        data_batch: DATA_BATCH = None,
-        outputs: Sequence | None = None,
-    ) -> None:
-        """Record logs after validation iteration.
-
-        Args:
-            runner (Runner): The runner of the validation process.
-            batch_idx (int): The index of the current batch in the validation
-                loop.
-            data_batch (dict or tuple or list, optional): Data from dataloader.
-                Defaults to None.
-            outputs (sequence, optional): Outputs from model.
-        """
-        if self.every_n_inner_iters(batch_idx, self.interval):
-            _, log_str = runner.log_processor.get_log_after_iter(runner, batch_idx, "val")
-            runner.logger.info(log_str)
-
-    def after_test_iter(
-        self,
-        runner,
-        batch_idx: int,
-        data_batch: DATA_BATCH = None,
-        outputs: Sequence | None = None,
-    ) -> None:
-        """Record logs after testing iteration.
-
-        Args:
-            runner (Runner): The runner of the testing process.
-            batch_idx (int): The index of the current batch in the test loop.
-            data_batch (dict or tuple or list, optional): Data from dataloader.
-            outputs (sequence, optional): Outputs from model.
-        """
-        if self.every_n_inner_iters(batch_idx, self.interval):
-            _, log_str = runner.log_processor.get_log_after_iter(runner, batch_idx, "test")
-            runner.logger.info(log_str)
-
-    def after_val_epoch(self, runner, metrics: dict[str, float] | None = None) -> None:
-        """All subclasses should override this method if they need any
-        operations after each validation epoch.
-
-        Args:
-            runner (Runner): The runner of the validation process.
-            metrics (Dict[str, float], optional): Evaluation results of all
-                metrics on validation dataset. The keys are the names of the
-                metrics, and the values are corresponding results.
-        """
-        tag, log_str = runner.log_processor.get_log_after_epoch(runner, len(runner.val_dataloader), "val")
-        runner.logger.info(log_str)
-        if self.log_metric_by_epoch:
-            # Accessing the epoch attribute of the runner will trigger
-            # the construction of the train_loop. Therefore, to avoid
-            # triggering the construction of the train_loop during
-            # validation, check before accessing the epoch.
- if isinstance(runner._train_loop, dict) or runner._train_loop is None: - epoch = 0 - else: - epoch = runner.epoch - runner.visualizer.add_scalars(tag, step=epoch, file_path=self.json_log_path) - else: - if isinstance(runner._train_loop, dict) or runner._train_loop is None: - iter = 0 - else: - iter = runner.iter - runner.visualizer.add_scalars(tag, step=iter, file_path=self.json_log_path) - - def after_test_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """All subclasses should override this method, if they need any - operations after each test epoch. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - tag, log_str = runner.log_processor.get_log_after_epoch( - runner, len(runner.test_dataloader), "test", with_non_scalar=True - ) - runner.logger.info(log_str) - dump(self._process_tags(tag), osp.join(runner.log_dir, self.json_log_path)) # type: ignore - - @staticmethod - def _process_tags(tags: dict): - """Convert tag values to json-friendly type.""" - - def process_val(value): - if isinstance(value, list | tuple): - # Array type of json - return [process_val(item) for item in value] - elif isinstance(value, dict): - # Object type of json - return {k: process_val(v) for k, v in value.items()} - elif isinstance(value, str | int | float | bool) or value is None: - # Other supported type of json - return value - elif isinstance(value, torch.Tensor | np.ndarray): - return value.tolist() - # Drop unsupported values. - - processed_tags = OrderedDict(process_val(tags)) - - return processed_tags - - def after_run(self, runner) -> None: - """Copy logs to ``self.out_dir`` if ``self.out_dir is not None`` - - Args: - runner (Runner): The runner of the training/testing/validation - process. - """ - # close the visualizer - runner.visualizer.close() - - # copy or upload logs to self.out_dir - if self.out_dir is None: - return - - removed_files = [] - for filename in scandir(runner._log_dir, self.out_suffix, True): - local_filepath = osp.join(runner._log_dir, filename) - removed_files.append(local_filepath) - out_filepath = self.file_backend.join_path(self.out_dir, filename) - with open(local_filepath) as f: - self.file_backend.put_text(f.read(), out_filepath) - - runner.logger.info(f"The file {local_filepath} has been uploaded to {out_filepath}.") - - if not self.keep_local: - runner.logger.info( - f"{local_filepath} was removed due to the `self.keep_local=False`. You can check the running logs in {out_filepath}" - ) - - if not self.keep_local: - # Close file handler to avoid PermissionError on Windows. - for handler in runner.logger.handlers: - if isinstance(handler, logging.FileHandler): - handler.close() - - for file in removed_files: - os.remove(file) diff --git a/libs/visengine/visengine/hooks/naive_visualization_hook.py b/libs/visengine/visengine/hooks/naive_visualization_hook.py deleted file mode 100644 index afc3c28..0000000 --- a/libs/visengine/visengine/hooks/naive_visualization_hook.py +++ /dev/null @@ -1,94 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
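The ``out_dir`` resolution described in the LoggerHook docstring above (final directory = ``out_dir`` plus the last path component of ``runner.work_dir``) can be checked in isolation; the paths below are illustrative:

```python
import os.path as osp

# Illustrative paths; LoggerHook.before_run applies these same two steps.
out_dir = "./tmp"
work_dir = "./work_dir/cur_exp"

basename = osp.basename(work_dir.rstrip(osp.sep))  # "cur_exp"
final_out_dir = osp.join(out_dir, basename)
assert final_out_dir == "./tmp/cur_exp"
```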
-import os.path as osp -from collections.abc import Sequence -from typing import Optional - -import cv2 -import numpy as np - -from visengine.hooks import Hook -from visengine.registry import HOOKS -from visengine.utils.dl_utils import tensor2imgs - -DATA_BATCH = Optional[dict | tuple | list] - - -# TODO: Due to interface changes, the current class -# functions incorrectly -@HOOKS.register_module(force=True) -class NaiveVisualizationHook(Hook): - """Show or Write the predicted results during the process of testing. - - Args: - interval (int): Visualization interval. Defaults to 1. - draw_gt (bool): Whether to draw the ground truth. Defaults to True. - draw_pred (bool): Whether to draw the predicted result. - Defaults to True. - """ - - priority = "NORMAL" - - def __init__(self, interval: int = 1, draw_gt: bool = True, draw_pred: bool = True): - self.draw_gt = draw_gt - self.draw_pred = draw_pred - self._interval = interval - - def _unpad(self, input: np.ndarray, unpad_shape: tuple[int, int]) -> np.ndarray: - """Unpad the input image. - - Args: - input (np.ndarray): The image to unpad. - unpad_shape (tuple): The shape of image before padding. - - Returns: - np.ndarray: The image before padding. - """ - unpad_width, unpad_height = unpad_shape - unpad_image = input[:unpad_height, :unpad_width] - return unpad_image - - def before_train(self, runner) -> None: - """Call add_graph method of visualizer. - - Args: - runner (Runner): The runner of the training process. - """ - runner.visualizer.add_graph(runner.model, None) - - def after_test_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Sequence | None = None, - ) -> None: - """Show or Write the predicted results. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the test loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - outputs (Sequence, optional): Outputs from model. - """ - if self.every_n_inner_iters(batch_idx, self._interval): - for data, output in zip(data_batch, outputs, strict=False): # type: ignore - input = data["inputs"] - data_sample = data["data_sample"] - input = tensor2imgs(input, **data_sample.get("img_norm_cfg", {}))[0] - # TODO We will implement a function to revert the augmentation - # in the future. - ori_shape = (data_sample.ori_width, data_sample.ori_height) - if "pad_shape" in data_sample: - input = self._unpad(input, data_sample.get("scale", ori_shape)) - origin_image = cv2.resize(input, ori_shape) - name = osp.basename(data_sample.img_path) - runner.visualizer.add_datasample( - name, - origin_image, - data_sample, - output, - self.draw_gt, - self.draw_pred, - ) diff --git a/libs/visengine/visengine/hooks/param_scheduler_hook.py b/libs/visengine/visengine/hooks/param_scheduler_hook.py deleted file mode 100644 index cd8a31d..0000000 --- a/libs/visengine/visengine/hooks/param_scheduler_hook.py +++ /dev/null @@ -1,133 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
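The ``_unpad`` helper in ``NaiveVisualizationHook`` above slices a padded image back to its pre-padding size; note that ``unpad_shape`` is given as (width, height) while NumPy indexes as (row, column). A small sketch with illustrative shapes:

```python
import numpy as np

padded = np.zeros((800, 1344, 3))        # (height, width, channels) after padding
unpad_width, unpad_height = 1333, 750    # shape before padding, as (w, h)
unpadded = padded[:unpad_height, :unpad_width]
assert unpadded.shape == (750, 1333, 3)
```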
-from typing import Optional - -from visengine.optim import _ParamScheduler -from visengine.registry import HOOKS -from visengine.utils import is_list_of - -from .hook import Hook - -DATA_BATCH = Optional[dict | tuple | list] - - -@HOOKS.register_module(force=True) -class ParamSchedulerHook(Hook): - """A hook to update some hyper-parameters in optimizer, e.g., learning rate - and momentum.""" - - priority = "LOW" - - def after_train_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: dict | None = None, - ) -> None: - """Call step function for each scheduler after each training iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (dict or tuple or list, optional): Data from dataloader. - In order to keep this interface consistent with other hooks, - we keep ``data_batch`` here. - outputs (dict, optional): Outputs from model. - In order to keep this interface consistent with other hooks, we - keep ``data_batch`` here. - """ - - if runner.param_schedulers is None: - return - - def step(param_schedulers): - assert isinstance(param_schedulers, list) - for scheduler in param_schedulers: - if not scheduler.by_epoch: - scheduler.step() - - if isinstance(runner.param_schedulers, list): - step(runner.param_schedulers) - elif isinstance(runner.param_schedulers, dict): - for param_schedulers in runner.param_schedulers.values(): - step(param_schedulers) - else: - raise TypeError( - "runner.param_schedulers should be list of ParamScheduler or " - "a dict containing list of ParamScheduler, " - f"but got {runner.param_schedulers}" - ) - - def after_train_epoch(self, runner) -> None: - """Call step function for each scheduler after each training epoch. - - Args: - runner (Runner): The runner of the training process. - """ - - if runner.param_schedulers is None: - return - - def step(param_schedulers): - assert isinstance(param_schedulers, list) - for scheduler in param_schedulers: - if scheduler.by_epoch: - scheduler.step() - - if isinstance(runner.param_schedulers, list): - step(runner.param_schedulers) - elif isinstance(runner.param_schedulers, dict): - for param_schedulers in runner.param_schedulers.values(): - step(param_schedulers) - else: - raise TypeError( - "runner.param_schedulers should be list of ParamScheduler or " - "a dict containing list of ParamScheduler, " - f"but got {runner.param_schedulers}" - ) - - def after_val_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """Call step function for each scheduler which has attribute - ``need_val_args`` after each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - - Note: - if ``runner.param_schedulers`` is not built before, - the hook ``after_val_epoch`` will be skipped. 
- """ - - if runner.param_schedulers is None: - return - - # avoid counting scheduler._global_step - # it has counted in after_train_* hook - if metrics is None: - return - - def step(param_schedulers): - # check param_schedulers is list and built - if not is_list_of(param_schedulers, _ParamScheduler): - return - - for scheduler in param_schedulers: - if scheduler.by_epoch and getattr(scheduler, "need_val_args", False): - scheduler.step(metrics) - - if isinstance(runner.param_schedulers, list): - step(runner.param_schedulers) - elif isinstance(runner.param_schedulers, dict): - for param_schedulers in runner.param_schedulers.values(): - step(param_schedulers) - else: - raise TypeError( - "runner.param_schedulers should be list of ParamScheduler or " - "a dict containing list of ParamScheduler, " - f"but got {runner.param_schedulers}" - ) diff --git a/libs/visengine/visengine/hooks/profiler_hook.py b/libs/visengine/visengine/hooks/profiler_hook.py deleted file mode 100644 index 18ea623..0000000 --- a/libs/visengine/visengine/hooks/profiler_hook.py +++ /dev/null @@ -1,333 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import os -import os.path as osp -import sys -from collections.abc import Callable - -import torch - -from visengine.dist import master_only -from visengine.hooks import Hook -from visengine.logging import print_log -from visengine.registry import HOOKS - - -def check_kineto() -> bool: - kineto_exist = False - try: - if torch.autograd.kineto_available(): - kineto_exist = True - except AttributeError: - print_log("NO KINETO", logger="current", level=logging.WARNING) - return kineto_exist - - -@HOOKS.register_module(force=True) -class ProfilerHook(Hook): - """A hook to analyze performance during training and inference. - - PyTorch Profiler is a tool that allows the collection of the performance - metrics during the training. More details on Profiler can be found at - `official docs `_ - - Args: - by_epoch (bool): Profile performance by epoch or by iteration. - Defaults to True. - profile_times (int): The period (epoch/iter) recorded by the profiler. - Defaults to 1. For example, profile_iters=10 and by_epoch=False, - indicate that 0-10 iterations are recorded. - activity_with_cpu (bool): Activities to be used in the analysis (CPU) - activity_with_cuda (bool): Activities to be used in the analysis (CUDA) - schedule (dict, optional): Key-word arguments passed to - `torch.profile.schedule `_. - Defaults to None, which means profiling without a schedule - on_trace_ready (callable, dict, optional): Either a handler or a dict - of generating handler. Defaults to None, which means profiling - without an on_trace_ready.The Callable type needs to construct its - own function that can handle 'torch.autograd.profiler.profile'. - Two officially recommended ways are provided: - - - ``schedule=dict(type='log_trace')``: Print the profiling result - in the terminal. See more details in the `PyTorch official tutorial`_. - The configurable arguments are the same as - ``prof.key_averages().table`` - - ``scheduler=dict(type='tb_trace')``: Profile the performance - with tensorboard. See more details in the tutorial - `profile with tensorboard`_. - - record_shapes (bool): Save information about operator's input shapes. - Defaults to False. - profile_memory (bool): Track tensor memory allocation/deallocation. - Defaults to False. - with_stack (bool): Record source information (file and line number) - for the ops. Defaults to False. 
-        with_flops (bool): Use a formula to estimate the FLOPs of specific
-            operators (matrix multiplication and 2D convolution).
-            Defaults to False.
-        json_trace_path (str, optional): Exports the collected trace in Chrome
-            JSON format. The file can be viewed in Chrome at 'chrome://tracing'.
-            Defaults to None, which means profiling does not store json files.
-
-    Warnings:
-        The profiler will be closed after ``profile_times`` iterations
-        automatically. Please make sure the configuration of your scheduler
-        will not close the profiler before the iteration reaches the value of
-        ``profile_times``.
-
-    Examples:
-        >>> # tensorboard trace
-        >>> trace_config = dict(type='tb_trace')
-        >>> profiler_hook_cfg = dict(on_trace_ready=trace_config)
-
-    .. _PyTorch official tutorial: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-execution-time
-    .. _profile with tensorboard: https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#pytorch-profiler-with-tensorboard
-    """
-
-    priority = "VERY_LOW"
-
-    def __init__(
-        self,
-        *,
-        by_epoch: bool = True,
-        profile_times: int = 1,
-        activity_with_cpu: bool = True,
-        activity_with_cuda: bool = False,
-        schedule: dict | None = None,
-        on_trace_ready: Callable | dict | None = None,
-        record_shapes: bool = False,
-        profile_memory: bool = False,
-        with_stack: bool = False,
-        with_flops: bool = False,
-        json_trace_path: str | None = None,
-    ) -> None:
-        try:
-            from torch import profiler
-        except ImportError:
-            raise ImportError("please upgrade torch to 1.8.1 or above")
-        if not check_kineto():
-            raise ImportError(
-                "Due to Kineto support issues, please upgrade PyTorch to 1.8.1 or above (Windows users: 1.9.1 or above)"
-            )
-
-        assert isinstance(by_epoch, bool), "``by_epoch`` should be a boolean."
-        self.by_epoch = by_epoch
-
-        if profile_times < 1:
-            raise ValueError(f"profile_times should be greater than 0, but got {profile_times}")
-        if by_epoch and profile_times > 1:
-            raise ValueError(
-                f"Profiler will profile 0-{profile_times} epochs.\n"
-                "Since the profiler will slow down the training, it is recommended"
-                " to train 1 epoch with ProfilerHook and adjust your setting "
-                "according to the profiler summary.\n"
-                "During normal training (epoch > 1), "
-                "you may disable the ProfilerHook."
-            )
-        self.profile_times = profile_times
-
-        assert isinstance(activity_with_cpu, bool), "``activity_with_cpu`` should be a boolean."
-        assert isinstance(activity_with_cuda, bool), "``activity_with_cuda`` should be a boolean."
-        self.activities = []
-        if activity_with_cpu:
-            self.activities.append(profiler.ProfilerActivity.CPU)
-        if activity_with_cuda:
-            self.activities.append(profiler.ProfilerActivity.CUDA)
-
-        if schedule is not None:
-            assert isinstance(schedule, dict), "``schedule`` should be a dict."
-            self.schedule = profiler.schedule(**schedule)
-        else:
-            self.schedule = None
-
-        self.on_trace_ready = on_trace_ready
-        self.record_shapes = record_shapes
-        self.profile_memory = profile_memory
-        self.with_stack = with_stack
-        self.with_flops = with_flops
-
-        self.json_trace_path = json_trace_path
-        self._closed = False
-
-    def before_run(self, runner):
-        """Initialize the profiler.
-
-        The runner is used to further validate the configured parameters.
- """ - max_times = runner.max_epochs if self.by_epoch else runner.max_iters - if max_times < self.profile_times: - raise ValueError(f"``profile_times`` should not be greater than {max_times}") - - on_trace_ready = self._parse_trace_config(runner) - - self.profiler = torch.profiler.profile( - activities=self.activities, - schedule=self.schedule, - on_trace_ready=on_trace_ready, - record_shapes=self.record_shapes, - profile_memory=self.profile_memory, - with_stack=self.with_stack, - with_flops=self.with_flops, - ) - - self.profiler.__enter__() - runner.logger.info("profiler is profiling...") - - def _parse_trace_config(self, runner): - """Used to parse the parameter 'on_trace_ready'.""" - if self.on_trace_ready is None: - _on_trace_ready = None - elif callable(self.on_trace_ready): - _on_trace_ready = self.on_trace_ready - elif isinstance(self.on_trace_ready, dict): - trace_cfg = self.on_trace_ready.copy() - trace_type = trace_cfg.pop("type") - - # Build a log printing handle - if trace_type == "log_trace": - - def _log_handler(_profile): - print(_profile.key_averages().table(**trace_cfg)) - - _on_trace_ready = _log_handler - - elif trace_type == "tb_trace": # tensorboard_trace handler - try: - import torch_tb_profiler # noqa: F401 - except ImportError: - raise ImportError("please run ``pip install torch-tb-profiler``") - - if "dir_name" not in trace_cfg: - trace_cfg["dir_name"] = osp.join(runner.log_dir, "tf_tracing_logs") - elif not osp.isabs(trace_cfg["dir_name"]): - trace_cfg["dir_name"] = osp.join(runner.log_dir, trace_cfg["dir_name"]) - runner.logger.info(f"trace_files of ProfilerHook will be saved to {trace_cfg['dir_name']}.") - - if self.json_trace_path is not None: - runner.logger.warn( - "When using tensorboard_trace, it is recommended to " - "save json files by setting ``worker_name`` instead of" - " setting ``json_trace_path``" - ) - _on_trace_ready = torch.profiler.tensorboard_trace_handler(**trace_cfg) - else: - raise ValueError(f'trace_type should be "log_trace" or "tb_trace", but got {trace_type}') - else: - raise ValueError(f"``on_trace_ready`` should be a handler, or dict, or None, but got {self.on_trace_ready}") - return _on_trace_ready - - def after_train_epoch(self, runner): - """Determine if the content is exported.""" - # `after_train_epoch` will also be called in IterBasedTrainLoop. - # Here we check `self._closed` to avoid exiting twice. - if not self._closed: - self._export_chrome_trace(runner) - - def after_train_iter(self, runner, batch_idx, data_batch, outputs): - """Profiler will call `step` method if it is not closed.""" - if not self._closed: - self.profiler.step() - if runner.iter == self.profile_times - 1 and not self.by_epoch: - self._export_chrome_trace(runner) - - def _export_chrome_trace(self, runner): - """Exporting content.""" - self._closed = True - runner.logger.info("profiler may take a few minutes...") - self.profiler.__exit__(None, None, None) - if self.json_trace_path is not None: - self.profiler.export_chrome_trace(self.json_trace_path) - - -@HOOKS.register_module(force=True) -class NPUProfilerHook(Hook): - """NPUProfiler to analyze performance during training. - - NPU Profiling is used to count the device execution time of all operators. - The torch_npu.npu.profile interface is used to complete the profiling data - collection at each stage of the project, and the data is analyzed by the - msprof tool and the data can be dumped to further manually analyze the - key performance bottlenecks. 
For more details on the torch_npu.npu.profile
-    interface, please visit
-    https://gitee.com/ascend/pytorch/blob/master/torch_npu/npu/profiler.py#profile
-
-    Args:
-        begin (int): The iteration at which profiling starts. Defaults to 0.
-        end (int): The iteration at which profiling ends. Defaults to 1.
-        result_path (str): The path to save the profiling results file.
-            Defaults to 'cann_profiling'.
-        exit_after_profiling (bool): Whether to exit the program after
-            profiling. Defaults to True.
-        use_e2e_profiler (bool): Turn on E2E profiling. E2E profiling combines
-            performance data at the PyTorch level and the NPU level to analyze
-            the bottlenecks of model performance end-to-end; it cannot show
-            detailed content and serves only as an auxiliary analysis.
-            Defaults to False.
-        ge_profiling_to_std_out (bool): Turn on GE profiling. GE is used to
-            collect the profiling data of the host-side scheduling of the
-            Ascend device. Defaults to False.
-
-    Examples:
-        >>> cfg = ...
-        >>> custom_hooks = [dict(type='NPUProfilerHook', end=2)]
-        >>> cfg.merge_from_dict({'custom_hooks': custom_hooks})
-        >>> runner = Runner.from_cfg(cfg)
-        >>> runner.train()
-    """
-
-    priority = "VERY_LOW"
-
-    def __init__(
-        self,
-        *,
-        begin: int = 0,
-        end: int = 1,
-        result_path: str = "cann_profiling",
-        exit_after_profiling: bool = True,
-        use_e2e_profiler: bool = False,
-        ge_profiling_to_std_out: bool = False,
-    ):
-        try:
-            import torch_npu
-        except ImportError:
-            raise ImportError("Failed to import torch_npu module")
-
-        if begin >= end:
-            raise ValueError("The iteration to start profiling should not be greater than or equal to the profiling end")
-
-        self.begin = begin
-        self.end = end
-        self.result_path = result_path
-        self.exit_after_profiling = exit_after_profiling
-
-        if ge_profiling_to_std_out:
-            os.environ["GE_PROFILING_TO_STD_OUT"] = "1"
-
-        if not osp.exists(self.result_path):
-            os.makedirs(self.result_path, exist_ok=True)
-
-        self.profiler = torch_npu.npu.profile(self.result_path, use_e2e_profiler=use_e2e_profiler)
-
-    @master_only
-    def before_run(self, runner):
-        if self.end > runner.max_iters:
-            raise ValueError("The profiling end iteration should not be greater than the max iteration")
-
-    @master_only
-    def before_train_iter(self, runner, batch_idx, data_batch=None):
-        if runner.iter == self.begin:
-            self.profiler.__enter__()
-            runner.logger.info("NPUProfiler starts profiling...")
-
-    @master_only
-    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
-        if runner.iter == self.end - 1:
-            runner.logger.info("profiler may take a few minutes to save the profiling result.")
-            self.profiler.__exit__(None, None, None)
-            if self.exit_after_profiling:
-                sys.exit()
diff --git a/libs/visengine/visengine/hooks/runtime_info_hook.py b/libs/visengine/visengine/hooks/runtime_info_hook.py
deleted file mode 100644
index 73f91f2..0000000
--- a/libs/visengine/visengine/hooks/runtime_info_hook.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Optional
-
-import numpy as np
-import torch
-
-from visengine.registry import HOOKS
-from visengine.utils import get_git_hash
-from visengine.version import __version__
-
-from .hook import Hook
-
-DATA_BATCH = Optional[dict | tuple | list]
-
-
-def _is_scalar(value: Any) -> bool:
-    """Determine whether the value is a scalar type value.
-
-    Args:
-        value (Any): value of log.
-
-    Returns:
-        bool: whether the value is a scalar type value.
- """ - if isinstance(value, np.ndarray): - return value.size == 1 - elif isinstance(value, int | float | np.number): - return True - elif isinstance(value, torch.Tensor): - return value.numel() == 1 - return False - - -@HOOKS.register_module(force=True) -class RuntimeInfoHook(Hook): - """A hook that updates runtime information into message hub. - - E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the - training state. Components that cannot access the runner can get runtime - information through the message hub. - """ - - priority = "VERY_HIGH" - - def before_run(self, runner) -> None: - """Update metainfo. - - Args: - runner (Runner): The runner of the training process. - """ - metainfo = { - "cfg": runner.cfg.pretty_text, - "seed": runner.seed, - "experiment_name": runner.experiment_name, - "mmengine_version": __version__ + get_git_hash(), - } - runner.message_hub.update_info_dict(metainfo) - - self.last_loop_stage = None - - def before_train(self, runner) -> None: - """Update resumed training state. - - Args: - runner (Runner): The runner of the training process. - """ - runner.message_hub.update_info("loop_stage", "train") - runner.message_hub.update_info("epoch", runner.epoch) - runner.message_hub.update_info("iter", runner.iter) - runner.message_hub.update_info("max_epochs", runner.max_epochs) - runner.message_hub.update_info("max_iters", runner.max_iters) - if hasattr(runner.train_dataloader.dataset, "metainfo"): - runner.message_hub.update_info("dataset_meta", runner.train_dataloader.dataset.metainfo) - - def after_train(self, runner) -> None: - runner.message_hub.pop_info("loop_stage") - - def before_train_epoch(self, runner) -> None: - """Update current epoch information before every epoch. - - Args: - runner (Runner): The runner of the training process. - """ - runner.message_hub.update_info("epoch", runner.epoch) - - def before_train_iter(self, runner, batch_idx: int, data_batch: DATA_BATCH = None) -> None: - """Update current iter and learning rate information before every - iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - """ - runner.message_hub.update_info("iter", runner.iter) - lr_dict = runner.optim_wrapper.get_lr() - assert isinstance(lr_dict, dict), ( - "`runner.optim_wrapper.get_lr()` should return a dict " - "of learning rate when training with OptimWrapper(single " - "optimizer) or OptimWrapperDict(multiple optimizer), " - f"but got {type(lr_dict)} please check your optimizer " - "constructor return an `OptimWrapper` or `OptimWrapperDict` " - "instance" - ) - for name, lr in lr_dict.items(): - runner.message_hub.update_scalar(f"train/{name}", lr[0]) - - def after_train_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: dict | None = None, - ) -> None: - """Update ``log_vars`` in model outputs every iteration. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict, optional): Outputs from model. Defaults to None. 
- """ - if outputs is not None: - for key, value in outputs.items(): - runner.message_hub.update_scalar(f"train/{key}", value) - - def before_val(self, runner) -> None: - self.last_loop_stage = runner.message_hub.get_info("loop_stage") - runner.message_hub.update_info("loop_stage", "val") - - def after_val_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """All subclasses should override this method, if they need any - operations after each validation epoch. - - Args: - runner (Runner): The runner of the validation process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - if metrics is not None: - for key, value in metrics.items(): - if _is_scalar(value): - runner.message_hub.update_scalar(f"val/{key}", value) - else: - runner.message_hub.update_info(f"val/{key}", value) - - def after_val(self, runner) -> None: - # ValLoop may be called within the TrainLoop, so we need to reset - # the loop_stage - # workflow: before_train -> before_val -> after_val -> after_train - if self.last_loop_stage == "train": - runner.message_hub.update_info("loop_stage", self.last_loop_stage) - self.last_loop_stage = None - else: - runner.message_hub.pop_info("loop_stage") - - def before_test(self, runner) -> None: - runner.message_hub.update_info("loop_stage", "test") - - def after_test(self, runner) -> None: - runner.message_hub.pop_info("loop_stage") - - def after_test_epoch(self, runner, metrics: dict[str, float] | None = None) -> None: - """All subclasses should override this method, if they need any - operations after each test epoch. - - Args: - runner (Runner): The runner of the testing process. - metrics (Dict[str, float], optional): Evaluation results of all - metrics on test dataset. The keys are the names of the - metrics, and the values are corresponding results. - """ - if metrics is not None: - for key, value in metrics.items(): - if _is_scalar(value): - runner.message_hub.update_scalar(f"test/{key}", value) - else: - runner.message_hub.update_info(f"test/{key}", value) diff --git a/libs/visengine/visengine/hooks/sampler_seed_hook.py b/libs/visengine/visengine/hooks/sampler_seed_hook.py deleted file mode 100644 index ffab38d..0000000 --- a/libs/visengine/visengine/hooks/sampler_seed_hook.py +++ /dev/null @@ -1,38 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from visengine.registry import HOOKS - -from .hook import Hook - - -@HOOKS.register_module(force=True) -class DistSamplerSeedHook(Hook): - """Data-loading sampler for distributed training. - - When distributed training, it is only useful in conjunction with - :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same - purpose with :obj:`IterLoader`. - """ - - priority = "NORMAL" - - def before_train_epoch(self, runner) -> None: - """Set the seed for sampler and batch_sampler. - - Args: - runner (Runner): The runner of the training process. - """ - if hasattr(runner.train_loop.dataloader, "sampler") and hasattr( - runner.train_loop.dataloader.sampler, "set_epoch" - ): - # In case the` _SingleProcessDataLoaderIter` has no sampler, - # or data loader uses `SequentialSampler` in Pytorch. 
- runner.train_loop.dataloader.sampler.set_epoch(runner.epoch) - - elif hasattr(runner.train_loop.dataloader, "batch_sampler") and hasattr( - runner.train_loop.dataloader.batch_sampler.sampler, "set_epoch" - ): - # In case the` _SingleProcessDataLoaderIter` has no batch sampler. - # batch sampler in pytorch warps the sampler as its attributes. - runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(runner.epoch) diff --git a/libs/visengine/visengine/hooks/sync_buffer_hook.py b/libs/visengine/visengine/hooks/sync_buffer_hook.py deleted file mode 100644 index 1a31a26..0000000 --- a/libs/visengine/visengine/hooks/sync_buffer_hook.py +++ /dev/null @@ -1,46 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from visengine.dist import all_reduce_params, is_distributed -from visengine.registry import HOOKS - -from .hook import Hook - - -@HOOKS.register_module(force=True) -class SyncBuffersHook(Hook): - """Synchronize model buffers such as running_mean and running_var in BN at - the end of each epoch.""" - - priority = "NORMAL" - - def __init__(self) -> None: - self.distributed = is_distributed() - # A flag to mark whether synchronization has been done in - # after_train_epoch - self.called_in_train = False - - def before_val_epoch(self, runner) -> None: - """All-reduce model buffers before each validation epoch. - - Synchronize the buffers before each validation if they have not been - synchronized at the end of the previous training epoch. This method - will be called when using IterBasedTrainLoop. - - Args: - runner (Runner): The runner of the training process. - """ - if self.distributed: - if not self.called_in_train: - all_reduce_params(runner.model.buffers(), op="mean") - self.called_in_train = False - - def after_train_epoch(self, runner) -> None: - """All-reduce model buffers at the end of each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - if self.distributed: - all_reduce_params(runner.model.buffers(), op="mean") - self.called_in_train = True diff --git a/libs/visengine/visengine/hub/__init__.py b/libs/visengine/visengine/hub/__init__.py deleted file mode 100644 index aab9115..0000000 --- a/libs/visengine/visengine/hub/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .hub import get_config, get_model - -__all__ = ["get_config", "get_model"] diff --git a/libs/visengine/visengine/hub/deprecated.json b/libs/visengine/visengine/hub/deprecated.json deleted file mode 100644 index 25cf6f2..0000000 --- a/libs/visengine/visengine/hub/deprecated.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "resnet50_caffe": "detectron/resnet50_caffe", - "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr", - "resnet101_caffe": "detectron/resnet101_caffe", - "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr" -} diff --git a/libs/visengine/visengine/hub/hub.py b/libs/visengine/visengine/hub/hub.py deleted file mode 100644 index 1423861..0000000 --- a/libs/visengine/visengine/hub/hub.py +++ /dev/null @@ -1,97 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
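The reason ``DistSamplerSeedHook`` above calls ``set_epoch`` every epoch: ``DistributedSampler`` derives its shuffle order deterministically from ``(seed, epoch)``, so without the call every epoch would replay the same permutation. A standalone sketch; the single-process ``num_replicas``/``rank`` values are illustrative:

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# Explicit num_replicas/rank avoids needing an initialized process group.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

sampler.set_epoch(0)
epoch0 = list(sampler)
sampler.set_epoch(1)
epoch1 = list(sampler)  # typically a different permutation than epoch 0
```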
-import importlib
-import os.path as osp
-
-from visengine.config import Config
-from visengine.config.utils import (
-    _get_cfg_metainfo,
-    _get_external_cfg_base_path,
-    _get_package_and_cfg_path,
-)
-from visengine.registry import MODELS, DefaultScope
-from visengine.runner import load_checkpoint
-from visengine.utils import get_installed_path, install_package
-from ml_env_config.env import env
-from vision.tools.logger import logger
-from pathlib import Path
-
-
-def get_config(cfg_path: str, pretrained: bool = False) -> Config:
-    """Get config from external package.
-
-    Args:
-        cfg_path (str): External relative config path.
-        pretrained (bool): Whether to save the pretrained model path. If
-            ``pretrained==True``, the url of the pretrained model can be
-            accessed by ``cfg.model_path``. Defaults to False.
-
-    Examples:
-        >>> cfg = get_config('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True)
-        >>> # Equivalent to
-        >>> # cfg = Config.fromfile('/path/to/faster-rcnn_r50_fpn_1x_coco.py')
-        >>> cfg.model_path
-        https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
-
-    Returns:
-        Config: A `Config` parsed from external package.
-    """
-    # Get package name and relative config path.
-    package, cfg_path = _get_package_and_cfg_path(cfg_path)
-    # Handle the renaming from mmdet to visdet
-    if package == "mmdet":
-        package = "visdet"
-    package_path = osp.join(osp.dirname(osp.abspath(importlib.import_module(package).__file__)))
-    try:
-        # Use `cfg_path` to search target config file.
-        cfg_meta = _get_cfg_metainfo(package_path, cfg_path)
-        cfg_basepath = Path(package_path).parent
-        cfg_path = osp.join(cfg_basepath, cfg_meta["Config"])
-        logger.info(f"Config path --> {cfg_path}")
-        cfg = Config.fromfile(cfg_path)
-        if pretrained:
-            assert "Weights" in cfg_meta, "Cannot find `Weights` in cfg_file.metafile.yml, please check the metafile"
-            cfg.model_path = cfg_meta["Weights"]
-    except ValueError:
-        # Since the base config does not contain a metafile, the absolute
-        # config is `osp.join(package_path, cfg_path_prefix, cfg_name)`
-        cfg_path = _get_external_cfg_base_path(package_path, cfg_path)
-        cfg = Config.fromfile(cfg_path)
-    except Exception as e:
-        raise e
-    return cfg
-
-
-def get_model(cfg_path: str, pretrained: bool = False, **kwargs):
-    """Get built model from external package.
-
-    Args:
-        cfg_path (str): External relative config path with prefix
-            'package::' and without suffix.
-        pretrained (bool): Whether to load pretrained model. Defaults to False.
-        kwargs (dict): Default arguments to build model.
-
-    Examples:
-        >>> model = get_model('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True)
-        >>> type(model)
-        <class 'mmdet.models.detectors.faster_rcnn.FasterRCNN'>
-
-    Returns:
-        nn.Module: Built model.
-    """
-    package = cfg_path.split("::")[0]
-    with DefaultScope.overwrite_default_scope(package):  # type: ignore
-        cfg = get_config(cfg_path, pretrained)
-        if "data_preprocessor" in cfg:
-            cfg.model.data_preprocessor = cfg.data_preprocessor
-        models_module = importlib.import_module(f"{package}.utils")
-        models_module.register_all_modules()  # type: ignore
-        model = MODELS.build(cfg.model, default_args=kwargs)
-        if pretrained:
-            load_checkpoint(model, cfg.model_path)
-            # Hack to use pretrained weights.
-            # If we do not set _is_init here, Runner will call
-            # `model.init_weights()` to overwrite the pretrained model.
- model._is_init = True - return model diff --git a/libs/visengine/visengine/hub/mmcls.json b/libs/visengine/visengine/hub/mmcls.json deleted file mode 100644 index c073a41..0000000 --- a/libs/visengine/visengine/hub/mmcls.json +++ /dev/null @@ -1,59 +0,0 @@ -{ - "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth", - "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth", - "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth", - "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth", - "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth", - "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth", - "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth", - "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth", - "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth", - "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth", - "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth", - "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth", - "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth", - "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth", - "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth", - "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth", - "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth", - "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth", - "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth", - "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth", - "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth", - "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth", - "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth", - "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth", - "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth", - "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth", - "shufflenet_v1": 
"https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth", - "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth", - "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth", - "mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth", - "mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth", - "repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth", - "repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth", - "repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth", - "repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth", - "repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth", - "repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth", - "repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth", - "repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth", - "repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth", - "repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth", - "repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth", - "repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth", - "res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth", - "res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth", - "res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth", - "swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth", - "swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth", - "swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth", - "swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth", - "t2t_vit_t_14": 
"https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth", - "t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth", - "t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth", - "tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth", - "vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth", - "vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth", - "vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth" -} diff --git a/libs/visengine/visengine/hub/openmmlab.json b/libs/visengine/visengine/hub/openmmlab.json deleted file mode 100644 index 8311db4..0000000 --- a/libs/visengine/visengine/hub/openmmlab.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth", - "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth", - "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth", - "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth", - "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth", - "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth", - "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth", - "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth", - "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth", - "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth", - "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth", - "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth", - "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth", - "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth", - "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth", - "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth", - "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth", - "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth", - "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth", - "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth", - "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth", 
- "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth", - "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth", - "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth", - "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth", - "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth", - "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth", - "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth", - "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth", - "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth", - "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth", - "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth", - "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth", - "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth", - "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth", - "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth", - "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth", - "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth", - "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth", - "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth", - "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth", - "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth", - "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth", - "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth", - "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth", - "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth", - "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth", - "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth" -} diff --git a/libs/visengine/visengine/hub/torchvision_0.12.json b/libs/visengine/visengine/hub/torchvision_0.12.json deleted file mode 100644 index 9f457da..0000000 --- a/libs/visengine/visengine/hub/torchvision_0.12.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", - "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", - "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", - "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", - "efficientnet_b0": 
"https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", - "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", - "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", - "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", - "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", - "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", - "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - "shufflenetv2_x1.5": null, - "shufflenetv2_x2.0": null, - "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - "squeezenet1_1": 
"https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", - "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", - "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", - "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" -} diff --git a/libs/visengine/visengine/infer/__init__.py b/libs/visengine/visengine/infer/__init__.py deleted file mode 100644 index 761ec99..0000000 --- a/libs/visengine/visengine/infer/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .infer import BaseInferencer - -__all__ = ["BaseInferencer"] diff --git a/libs/visengine/visengine/infer/infer.py b/libs/visengine/visengine/infer/infer.py deleted file mode 100644 index 44143af..0000000 --- a/libs/visengine/visengine/infer/infer.py +++ /dev/null @@ -1,667 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import importlib -import os.path as osp -import re -import warnings -from abc import ABCMeta, abstractmethod -from collections.abc import Callable, Iterable, Sequence -from datetime import datetime -from typing import Any, Union - -import numpy as np -import torch -import torch.nn as nn -from rich.progress import track - -from visengine.config import Config, ConfigDict -from visengine.config.utils import MODULE2PACKAGE -from visengine.dataset import pseudo_collate -from visengine.device import get_device -from visengine.fileio import get_file_backend, isdir, join_path, list_dir_or_file, load -from visengine.logging import print_log -from visengine.registry import FUNCTIONS, MODELS, VISUALIZERS, DefaultScope -from visengine.runner.checkpoint import _load_checkpoint, _load_checkpoint_to_model -from visengine.structures import InstanceData -from visengine.visualization import Visualizer - -InstanceList = list[InstanceData] -InputType = Union[str, np.ndarray, torch.Tensor] -InputsType = Union[InputType, Sequence[InputType]] -ImgType = Union[np.ndarray, Sequence[np.ndarray]] -ResType = Union[dict, list[dict]] -ConfigType = Union[Config, ConfigDict] -ModelType = Union[dict, ConfigType, str] - - -class InferencerMeta(ABCMeta): - """Check the legality of the inferencer. - - All Inferencers should not define duplicated keys for - ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` and - ``postprocess_kwargs``. - """ - - def __init__(cls, *args, **kwargs): - super().__init__(*args, **kwargs) - assert isinstance(cls.preprocess_kwargs, set) - assert isinstance(cls.forward_kwargs, set) - assert isinstance(cls.visualize_kwargs, set) - assert isinstance(cls.postprocess_kwargs, set) - - all_kwargs = cls.preprocess_kwargs | cls.forward_kwargs | cls.visualize_kwargs | cls.postprocess_kwargs - - assert len(all_kwargs) == ( - len(cls.preprocess_kwargs) - + len(cls.forward_kwargs) - + len(cls.visualize_kwargs) - + len(cls.postprocess_kwargs) - ), ( - f"Class define error! {cls.__name__} should not " - "define duplicated keys for `preprocess_kwargs`, " - "`forward_kwargs`, `visualize_kwargs` and " - "`postprocess_kwargs` are not allowed." 
- ) - - -class BaseInferencer(metaclass=InferencerMeta): - """Base inferencer for downstream tasks. - - The BaseInferencer provides the standard workflow for inference as follows: - - 1. Preprocess the input data by :meth:`preprocess`. - 2. Forward the data to the model by :meth:`forward`. ``BaseInferencer`` - assumes the model inherits from :class:`mmengine.models.BaseModel` and - will call `model.test_step` in :meth:`forward` by default. - 3. Visualize the results by :meth:`visualize`. - 4. Postprocess and return the results by :meth:`postprocess`. - - When we call the subclasses inherited from BaseInferencer (not overriding - ``__call__``), the workflow will be executed in order. - - All subclasses of BaseInferencer could define the following class - attributes for customization: - - - ``preprocess_kwargs``: The keys of the kwargs that will be passed to - :meth:`preprocess`. - - ``forward_kwargs``: The keys of the kwargs that will be passed to - :meth:`forward` - - ``visualize_kwargs``: The keys of the kwargs that will be passed to - :meth:`visualize` - - ``postprocess_kwargs``: The keys of the kwargs that will be passed to - :meth:`postprocess` - - All attributes mentioned above should be a ``set`` of keys (strings), - and each key should not be duplicated. Actually, :meth:`__call__` will - dispatch all the arguments to the corresponding methods according to the - ``xxx_kwargs`` mentioned above, therefore, the key in sets should - be unique to avoid ambiguous dispatching. - - Warning: - If subclasses defined the class attributes mentioned above with - duplicated keys, an ``AssertionError`` will be raised during import - process. - - Subclasses inherited from ``BaseInferencer`` should implement - :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`: - - - _init_pipeline: Return a callable object to preprocess the input data. - - visualize: Visualize the results returned by :meth:`forward`. - - postprocess: Postprocess the results returned by :meth:`forward` and - :meth:`visualize`. - - Args: - model (str, optional): Path to the config file or the model name - defined in metafile. Take the `mmdet metafile `_ - as an example, the `model` could be `retinanet_r18_fpn_1x_coco` or - its alias. If model is not specified, user must provide the - `weights` saved by MMEngine which contains the config string. - Defaults to None. - weights (str, optional): Path to the checkpoint. If it is not specified - and model is a model name of metafile, the weights will be loaded - from metafile. Defaults to None. - device (str, optional): Device to run inference. If None, the available - device will be automatically used. Defaults to None. - scope (str, optional): The scope of the model. Defaults to None. - show_progress (bool): Control whether to display the progress bar during - the inference process. Defaults to True. - `New in version 0.7.4.` - - Note: - Since ``Inferencer`` could be used to infer batch data, - `collate_fn` should be defined. If `collate_fn` is not defined in config - file, the `collate_fn` will be `pseudo_collate` by default. 
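The class docstring above describes the dispatch contract: `__call__` routes every keyword argument to exactly one of `preprocess`, `forward`, `visualize` or `postprocess` by membership in the four class-level key sets, which is why the metaclass rejects duplicated keys at class-creation time. A minimal standalone sketch of that routing (`MiniInferencer` and its example keys are illustrative, not the visengine API):

```python
class MiniInferencer:
    # Each stage declares the kwargs it owns; the sets must be disjoint.
    preprocess_kwargs = {"batch_size_hint"}
    forward_kwargs = {"use_amp"}
    visualize_kwargs = {"show"}
    postprocess_kwargs = {"return_json"}

    def _dispatch_kwargs(self, **kwargs):
        known = (self.preprocess_kwargs | self.forward_kwargs
                 | self.visualize_kwargs | self.postprocess_kwargs)
        unknown = set(kwargs) - known
        if unknown:
            raise ValueError(f"unknown arguments: {unknown}")

        def pick(keys):
            return {k: v for k, v in kwargs.items() if k in keys}

        return (pick(self.preprocess_kwargs), pick(self.forward_kwargs),
                pick(self.visualize_kwargs), pick(self.postprocess_kwargs))


pre, fwd, vis, post = MiniInferencer()._dispatch_kwargs(use_amp=True, show=False)
assert fwd == {"use_amp": True} and vis == {"show": False}
```

Disjointness is what makes the dispatch unambiguous: if two sets shared a key, the same argument would have to be delivered to two stages.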
- """ - - preprocess_kwargs: set = set() - forward_kwargs: set = set() - visualize_kwargs: set = set() - postprocess_kwargs: set = set() - - def __init__( - self, - model: ModelType | str | None = None, - weights: str | None = None, - device: str | None = None, - scope: str | None = None, - show_progress: bool = True, - ) -> None: - if scope is None: - default_scope = DefaultScope.get_current_instance() - if default_scope is not None: - scope = default_scope.scope_name - self.scope = scope - # Load config to cfg - cfg: ConfigType - if isinstance(model, str): - if osp.isfile(model): - cfg = Config.fromfile(model) - else: - # Load config and weights from metafile. If `weights` is - # assigned, the weights defined in metafile will be ignored. - cfg, _weights = self._load_model_from_metafile(model) - if weights is None: - weights = _weights - elif isinstance(model, Config | ConfigDict): - cfg = copy.deepcopy(model) - elif isinstance(model, dict): - cfg = copy.deepcopy(ConfigDict(model)) - elif model is None: - if weights is None: - raise ValueError( - "If model is None, the weights must be specified since the config needs to be loaded from the weights" - ) - cfg = ConfigDict() - else: - raise TypeError(f"model must be a filepath or any ConfigTypeobject, but got {type(model)}") - - if device is None: - device = get_device() - - self.model = self._init_model(cfg, weights, device) # type: ignore - self.pipeline = self._init_pipeline(cfg) - self.collate_fn = self._init_collate(cfg) - self.visualizer = self._init_visualizer(cfg) - self.cfg = cfg - self.show_progress = show_progress - - def __call__( - self, - inputs: InputsType, - return_datasamples: bool = False, - batch_size: int = 1, - **kwargs, - ) -> dict: - """Call the inferencer. - - Args: - inputs (InputsType): Inputs for the inferencer. - return_datasamples (bool): Whether to return results as - :obj:`BaseDataElement`. Defaults to False. - batch_size (int): Batch size. Defaults to 1. - **kwargs: Key words arguments passed to :meth:`preprocess`, - :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. - Each key in kwargs should be in the corresponding set of - ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` - and ``postprocess_kwargs``. - - Returns: - dict: Inference and visualization results. - """ - ( - preprocess_kwargs, - forward_kwargs, - visualize_kwargs, - postprocess_kwargs, - ) = self._dispatch_kwargs(**kwargs) - - ori_inputs = self._inputs_to_list(inputs) - inputs = self.preprocess(ori_inputs, batch_size=batch_size, **preprocess_kwargs) - preds = [] - for data in track(inputs, description="Inference") if self.show_progress else inputs: - preds.extend(self.forward(data, **forward_kwargs)) - visualization = self.visualize(ori_inputs, preds, **visualize_kwargs) # type: ignore - results = self.postprocess(preds, visualization, return_datasamples, **postprocess_kwargs) - return results - - def _inputs_to_list(self, inputs: InputsType) -> list: - """Preprocess the inputs to a list. - - Preprocess inputs to a list according to its type: - - - list or tuple: return inputs - - str: - - Directory path: return all files in the directory - - other cases: return a list containing the string. The string - could be a path to file, a url or other types of string according - to the task. - - Args: - inputs (InputsType): Inputs for the inferencer. - - Returns: - list: List of input for the :meth:`preprocess`. 
- """ - if isinstance(inputs, str): - backend = get_file_backend(inputs) - if hasattr(backend, "isdir") and isdir(inputs): - # Backends like HttpsBackend do not implement `isdir`, so only - # those backends that implement `isdir` could accept the inputs - # as a directory - filename_list = list_dir_or_file(inputs, list_dir=False) - inputs = [join_path(inputs, filename) for filename in filename_list] - - if not isinstance(inputs, list | tuple): - inputs = [inputs] - - return list(inputs) - - def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): - """Process the inputs into a model-feedable format. - - Customize your preprocess by overriding this method. Preprocess should - return an iterable object, of which each item will be used as the - input of ``model.test_step``. - - ``BaseInferencer.preprocess`` will return an iterable chunked data, - which will be used in __call__ like this: - - .. code-block:: python - - def __call__(self, inputs, batch_size=1, **kwargs): - chunked_data = self.preprocess(inputs, batch_size, **kwargs) - for batch in chunked_data: - preds = self.forward(batch, **kwargs) - - Args: - inputs (InputsType): Inputs given by user. - batch_size (int): batch size. Defaults to 1. - - Yields: - Any: Data processed by the ``pipeline`` and ``collate_fn``. - """ - chunked_data = self._get_chunk_data(map(self.pipeline, inputs), batch_size) - yield from map(self.collate_fn, chunked_data) - - @torch.no_grad() - def forward(self, inputs: dict | tuple, **kwargs) -> Any: - """Feed the inputs to the model.""" - return self.model.test_step(inputs) - - @abstractmethod - def visualize(self, inputs: list, preds: Any, show: bool = False, **kwargs) -> list[np.ndarray]: - """Visualize predictions. - - Customize your visualization by overriding this method. visualize - should return visualization results, which could be np.ndarray or any - other objects. - - Args: - inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. - preds (Any): Predictions of the model. - show (bool): Whether to display the image in a popup window. - Defaults to False. - - Returns: - List[np.ndarray]: Visualization results. - """ - - @abstractmethod - def postprocess( - self, - preds: Any, - visualization: list[np.ndarray], - return_datasample=False, - **kwargs, - ) -> dict: - """Process the predictions and visualization results from ``forward`` - and ``visualize``. - - This method should be responsible for the following tasks: - - 1. Convert datasamples into a json-serializable dict if needed. - 2. Pack the predictions and visualization results and return them. - 3. Dump or log the predictions. - - Customize your postprocess by overriding this method. Make sure - ``postprocess`` will return a dict with visualization results and - inference results. - - Args: - preds (List[Dict]): Predictions of the model. - visualization (np.ndarray): Visualized predictions. - return_datasample (bool): Whether to return results as datasamples. - Defaults to False. - - Returns: - dict: Inference and visualization results with key ``predictions`` - and ``visualization`` - - - ``visualization (Any)``: Returned by :meth:`visualize` - - ``predictions`` (dict or DataSample): Returned by - :meth:`forward` and processed in :meth:`postprocess`. - If ``return_datasample=False``, it usually should be a - json-serializable dict containing only basic data elements such - as strings and numbers. - """ - - def _load_model_from_metafile(self, model: str) -> tuple[Config, str]: - """Load config and weights from metafile. 
- - Args: - model (str): model name defined in metafile. - - Returns: - Tuple[Config, str]: Loaded Config and weights path defined in - metafile. - """ - model = model.lower() - - assert self.scope is not None, "scope should be initialized if you want to load config from metafile." - assert self.scope in MODULE2PACKAGE, f"{self.scope} not in {MODULE2PACKAGE}!,please pass a valid scope." - - repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(self.scope) - for model_cfg in BaseInferencer._get_models_from_metafile(repo_or_mim_dir): - model_name = model_cfg["Name"].lower() - model_aliases = model_cfg.get("Alias", []) - if isinstance(model_aliases, str): - model_aliases = [model_aliases.lower()] - else: - model_aliases = [alias.lower() for alias in model_aliases] - if model_name == model or model in model_aliases: - cfg = Config.fromfile(osp.join(repo_or_mim_dir, model_cfg["Config"])) - weights = model_cfg["Weights"] - weights = weights[0] if isinstance(weights, list) else weights - return cfg, weights - raise ValueError(f"Cannot find model: {model} in {self.scope}") - - @staticmethod - def _get_repo_or_mim_dir(scope): - """Get the directory where the ``Configs`` located when the package is - installed or ``PYTHONPATH`` is set. - - Args: - scope (str): The scope of repository. - - Returns: - str: The directory where the ``Configs`` is located. - """ - try: - module = importlib.import_module(scope) - except ImportError: - if scope not in MODULE2PACKAGE: - raise KeyError(f"{scope} is not a valid scope. The available scopes are {MODULE2PACKAGE.keys()}") - else: - project = MODULE2PACKAGE[scope] - raise ImportError( - f'Cannot import {scope} correctly, please try to install the {project} by "pip install {project}"' - ) - # Since none of OpenMMLab series packages are namespace packages - # (https://docs.python.org/3/glossary.html#term-namespace-package), - # The first element of module.__path__ means package installation path. - package_path = module.__path__[0] - - if osp.exists(osp.join(osp.dirname(package_path), "configs")): - repo_dir = osp.dirname(package_path) - return repo_dir - else: - mim_dir = osp.join(package_path, ".mim") - if not osp.exists(osp.join(mim_dir, "configs")): - raise FileNotFoundError( - f"Cannot find `configs` directory in {package_path}!, please check the completeness of the {scope}." - ) - return mim_dir - - def _init_model( - self, - cfg: ConfigType, - weights: str | None, - device: str = "cpu", - ) -> nn.Module: - """Initialize the model with the given config and checkpoint on the - specific device. - - Args: - cfg (ConfigType): Config containing the model information. - weights (str, optional): Path to the checkpoint. - device (str, optional): Device to run inference. Defaults to 'cpu'. - - Returns: - nn.Module: Model loaded with checkpoint. - """ - checkpoint: dict | None = None - if weights is not None: - checkpoint = _load_checkpoint(weights, map_location="cpu") - - if not cfg: - assert checkpoint is not None - try: - # Prefer to get config from `message_hub` since `message_hub` - # is a more stable module to store all runtime information. - # However, the early version of MMEngine will not save config - # in `message_hub`, so we will try to load config from `meta`. 
- cfg_string = checkpoint["message_hub"]["runtime_info"]["cfg"] - except KeyError: - assert "meta" in checkpoint, ( - "If model(config) is not provided, the checkpoint must" - "contain the config string in `meta` or `message_hub`, " - "but both `meta` and `message_hub` are not found in the " - "checkpoint." - ) - meta = checkpoint["meta"] - if "cfg" in meta: - cfg_string = meta["cfg"] - else: - raise ValueError("Cannot find the config in the checkpoint.") - cfg.update(Config.fromstring(cfg_string, file_format=".py")._cfg_dict) - - # Delete the `pretrained` field to prevent model from loading the - # the pretrained weights unnecessarily. - if cfg.model.get("pretrained") is not None: - del cfg.model.pretrained - - model = MODELS.build(cfg.model) - model.cfg = cfg - self._load_weights_to_model(model, checkpoint, cfg) - model.to(device) - model.eval() - return model - - def _load_weights_to_model(self, model: nn.Module, checkpoint: dict | None, cfg: ConfigType | None) -> None: - """Loading model weights and meta information from cfg and checkpoint. - - Subclasses could override this method to load extra meta information - from ``checkpoint`` and ``cfg`` to model. - - Args: - model (nn.Module): Model to load weights and meta information. - checkpoint (dict, optional): The loaded checkpoint. - cfg (Config or ConfigDict, optional): The loaded config. - """ - if checkpoint is not None: - _load_checkpoint_to_model(model, checkpoint) - else: - warnings.warn( - "Checkpoint is not loaded, and the inference result is calculated by the randomly initialized model!", - stacklevel=2, - ) - - def _init_collate(self, cfg: ConfigType) -> Callable: - """Initialize the ``collate_fn`` with the given config. - - The returned ``collate_fn`` will be used to collate the batch data. - If will be used in :meth:`preprocess` like this - - .. code-block:: python - def preprocess(self, inputs, batch_size, **kwargs): - ... - dataloader = map(self.collate_fn, dataloader) - yield from dataloader - - Args: - cfg (ConfigType): Config which could contained the `collate_fn` - information. If `collate_fn` is not defined in config, it will - be :func:`pseudo_collate`. - - Returns: - Callable: Collate function. - """ - try: - with FUNCTIONS.switch_scope_and_registry(self.scope) as registry: - collate_fn = registry.get(cfg.test_dataloader.collate_fn) - except AttributeError: - collate_fn = pseudo_collate - return collate_fn # type: ignore - - @abstractmethod - def _init_pipeline(self, cfg: ConfigType) -> Callable: - """Initialize the test pipeline. - - Return a pipeline to handle various input data, such as ``str``, - ``np.ndarray``. It is an abstract method in BaseInferencer, and should - be implemented in subclasses. - - The returned pipeline will be used to process a single data. - It will be used in :meth:`preprocess` like this: - - .. code-block:: python - def preprocess(self, inputs, batch_size, **kwargs): - ... - dataset = map(self.pipeline, dataset) - ... - """ - - def _init_visualizer(self, cfg: ConfigType) -> Visualizer | None: - """Initialize visualizers. - - Args: - cfg (ConfigType): Config containing the visualizer information. - - Returns: - Visualizer or None: Visualizer initialized with config. 
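`_init_model` above recovers the training config from the checkpoint itself when no config is given, preferring the `message_hub` copy and falling back to the legacy `meta` field. The lookup order, reduced to a sketch (`cfg_string_from_checkpoint` is a hypothetical helper name):

```python
def cfg_string_from_checkpoint(checkpoint: dict) -> str:
    """Prefer message_hub's runtime copy of the config, then legacy meta."""
    try:
        return checkpoint["message_hub"]["runtime_info"]["cfg"]
    except KeyError:
        pass
    try:
        return checkpoint["meta"]["cfg"]
    except KeyError:
        raise ValueError("Cannot find the config in the checkpoint.") from None


ckpt = {"meta": {"cfg": "model = dict(type='MaskRCNN')"}}
assert cfg_string_from_checkpoint(ckpt).startswith("model =")
```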
- """ - if "visualizer" not in cfg: - return None - timestamp = str(datetime.timestamp(datetime.now())) - name = cfg.visualizer.get("name", timestamp) - if Visualizer.check_instance_created(name): - name = f"{name}-{timestamp}" - cfg.visualizer.name = name - return VISUALIZERS.build(cfg.visualizer) - - def _get_chunk_data(self, inputs: Iterable, chunk_size: int): - """Get batch data from dataset. - - Args: - inputs (Iterable): An iterable dataset. - chunk_size (int): Equivalent to batch size. - - Yields: - list: batch data. - """ - inputs_iter = iter(inputs) - while True: - try: - chunk_data = [] - for _ in range(chunk_size): - processed_data = next(inputs_iter) - chunk_data.append(processed_data) - yield chunk_data - except StopIteration: - if chunk_data: - yield chunk_data - break - - def _dispatch_kwargs(self, **kwargs) -> tuple[dict, dict, dict, dict]: - """Dispatch kwargs to preprocess(), forward(), visualize() and - postprocess() according to the actual demands. - - Returns: - Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess, - forward, visualize and postprocess respectively. - """ - # Ensure each argument only matches one function - method_kwargs = self.preprocess_kwargs | self.forward_kwargs | self.visualize_kwargs | self.postprocess_kwargs - - union_kwargs = method_kwargs | set(kwargs.keys()) - if union_kwargs != method_kwargs: - unknown_kwargs = union_kwargs - method_kwargs - raise ValueError( - f"unknown argument {unknown_kwargs} for `preprocess`, `forward`, `visualize` and `postprocess`" - ) - - preprocess_kwargs = {} - forward_kwargs = {} - visualize_kwargs = {} - postprocess_kwargs = {} - - for key, value in kwargs.items(): - if key in self.preprocess_kwargs: - preprocess_kwargs[key] = value - elif key in self.forward_kwargs: - forward_kwargs[key] = value - elif key in self.visualize_kwargs: - visualize_kwargs[key] = value - else: - postprocess_kwargs[key] = value - - return ( - preprocess_kwargs, - forward_kwargs, - visualize_kwargs, - postprocess_kwargs, - ) - - @staticmethod - def _get_models_from_metafile(dir: str): - """Load model config defined in metafile from package path. - - Args: - dir (str): Path to the directory of Config. It requires the - directory ``Config``, file ``model-index.yml`` exists in the - ``dir``. - - Yields: - dict: Model config defined in metafile. - """ - meta_indexes = load(osp.join(dir, "model-index.yml")) - for meta_path in meta_indexes["Import"]: - # meta_path example: mmcls/.mim/configs/conformer/metafile.yml - meta_path = osp.join(dir, meta_path) - metainfo = load(meta_path) - yield from metainfo["Models"] - - @staticmethod - def list_models(scope: str | None = None, patterns: str = r".*"): - """List models defined in metafile of corresponding packages. - - Args: - scope (str, optional): The scope to which the model belongs. - Defaults to None. - patterns (str, optional): Regular expressions for the searched - models. Once matched with ``Alias`` or ``Name`` filed in - metafile, corresponding model will be added to the return list. - Defaults to '.*'. - - Returns: - dict: Model dict with model name and its alias. - """ - matched_models = [] - if scope is None: - default_scope = DefaultScope.get_current_instance() - assert default_scope is not None, "scope should be initialized if you want to load config from metafile." - assert scope in MODULE2PACKAGE, f"{scope} not in {MODULE2PACKAGE}!, please make pass a valid scope." 
- root_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope) - for model_cfg in BaseInferencer._get_models_from_metafile(root_or_mim_dir): - model_name = [model_cfg["Name"]] - model_name.extend(model_cfg.get("Alias", [])) - for name in model_name: - if re.match(patterns, name) is not None: - matched_models.append(name) - output_str = "" - for name in matched_models: - output_str += f"model_name: {name}\n" - print_log(output_str, logger="current") - return matched_models diff --git a/libs/visengine/visengine/logging/__init__.py b/libs/visengine/visengine/logging/__init__.py deleted file mode 100644 index f5466de..0000000 --- a/libs/visengine/visengine/logging/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .history_buffer import HistoryBuffer -from .logger import MMLogger, print_log -from .message_hub import MessageHub - -__all__ = ["HistoryBuffer", "MMLogger", "MessageHub", "print_log"] diff --git a/libs/visengine/visengine/logging/history_buffer.py b/libs/visengine/visengine/logging/history_buffer.py deleted file mode 100644 index 85fb7bb..0000000 --- a/libs/visengine/visengine/logging/history_buffer.py +++ /dev/null @@ -1,224 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from collections.abc import Callable, Sequence -from typing import Any - -import numpy as np - - -class HistoryBuffer: - """Unified storage format for different log types. - - ``HistoryBuffer`` records the history of log for further statistics. - - Examples: - >>> history_buffer = HistoryBuffer() - >>> # Update history_buffer. - >>> history_buffer.update(1) - >>> history_buffer.update(2) - >>> history_buffer.min() # minimum of (1, 2) - 1 - >>> history_buffer.max() # maximum of (1, 2) - 2 - >>> history_buffer.mean() # mean of (1, 2) - 1.5 - >>> history_buffer.statistics('mean') # access method by string. - 1.5 - - Args: - log_history (Sequence): History logs. Defaults to []. - count_history (Sequence): Counts of history logs. Defaults to []. - max_length (int): The max length of history logs. Defaults to 1000000. - """ - - _statistics_methods: dict = {} - - def __init__( - self, - log_history: Sequence = [], - count_history: Sequence = [], - max_length: int = 1000000, - ): - self.max_length = max_length - self._set_default_statistics() - assert len(log_history) == len(count_history), "The lengths of log_history and count_histroy should be equal" - if len(log_history) > max_length: - warnings.warn( - f"The length of history buffer({len(log_history)}) exceeds the max_length({max_length}), the first few elements will be ignored.", - stacklevel=2, - ) - self._log_history = np.array(log_history[-max_length:]) - self._count_history = np.array(count_history[-max_length:]) - else: - self._log_history = np.array(log_history) - self._count_history = np.array(count_history) - - def _set_default_statistics(self) -> None: - """Register default statistic methods: min, max, current and mean.""" - self._statistics_methods.setdefault("min", HistoryBuffer.min) - self._statistics_methods.setdefault("max", HistoryBuffer.max) - self._statistics_methods.setdefault("current", HistoryBuffer.current) - self._statistics_methods.setdefault("mean", HistoryBuffer.mean) - - def update(self, log_val: int | float, count: int = 1) -> None: - """Update the log history. - - If the length of the buffer exceeds ``self._max_length``, the oldest - element will be removed from the buffer. 
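The two invariants just described, a hard length cap and values paired with counts, are the whole data model of `HistoryBuffer`. A toy version showing the cap and the count-weighted mean (`MiniHistoryBuffer` is illustrative, not the deleted class):

```python
import numpy as np


class MiniHistoryBuffer:
    def __init__(self, max_length=1_000_000):
        self.max_length = max_length
        self._values = np.array([])
        self._counts = np.array([])

    def update(self, value, count=1):
        # Appending past the cap silently drops the oldest entries.
        self._values = np.append(self._values, value)[-self.max_length:]
        self._counts = np.append(self._counts, count)[-self.max_length:]

    def mean(self, window_size=None):
        w = window_size or len(self._values)
        return self._values[-w:].sum() / self._counts[-w:].sum()


buf = MiniHistoryBuffer(max_length=3)
for v in (1, 2, 3, 4):  # the leading 1 falls off the end
    buf.update(v)
print(buf.mean())  # (2 + 3 + 4) / 3 == 3.0
```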
- - Args: - log_val (int or float): The value of log. - count (int): The accumulation times of log, defaults to 1. - ``count`` will be used in smooth statistics. - """ - if not isinstance(log_val, int | float) or not isinstance(count, int | float): - raise TypeError( - f"log_val must be int or float but got {type(log_val)}, count must be int but got {type(count)}" - ) - self._log_history = np.append(self._log_history, log_val) - self._count_history = np.append(self._count_history, count) - if len(self._log_history) > self.max_length: - self._log_history = self._log_history[-self.max_length :] - self._count_history = self._count_history[-self.max_length :] - - @property - def data(self) -> tuple[np.ndarray, np.ndarray]: - """Get the ``_log_history`` and ``_count_history``. - - Returns: - Tuple[np.ndarray, np.ndarray]: History logs and the counts of - the history logs. - """ - return self._log_history, self._count_history - - @classmethod - def register_statistics(cls, method: Callable) -> Callable: - """Register custom statistics method to ``_statistics_methods``. - - The registered method can be called by ``history_buffer.statistics`` - with corresponding method name and arguments. - - Examples: - >>> @HistoryBuffer.register_statistics - >>> def weighted_mean(self, window_size, weight): - >>> assert len(weight) == window_size - >>> return (self._log_history[-window_size:] * - >>> np.array(weight)).sum() / \ - >>> self._count_history[-window_size:] - - >>> log_buffer = HistoryBuffer([1, 2], [1, 1]) - >>> log_buffer.statistics('weighted_mean', 2, [2, 1]) - 2 - - Args: - method (Callable): Custom statistics method. - Returns: - Callable: Original custom statistics method. - """ - method_name = method.__name__ - assert method_name not in cls._statistics_methods, "method_name cannot be registered twice!" - cls._statistics_methods[method_name] = method - return method - - def statistics(self, method_name: str, *arg, **kwargs) -> Any: - """Access statistics method by name. - - Args: - method_name (str): Name of method. - - Returns: - Any: Depends on corresponding method. - """ - if method_name not in self._statistics_methods: - raise KeyError(f"{method_name} has not been registered in HistoryBuffer._statistics_methods") - method = self._statistics_methods[method_name] - # Provide self arguments for registered functions. - return method(self, *arg, **kwargs) - - def mean(self, window_size: int | None = None) -> np.ndarray: - """Return the mean of the latest ``window_size`` values in log - histories. - - If ``window_size is None`` or ``window_size > len(self._log_history)``, - return the global mean value of history logs. - - Args: - window_size (int, optional): Size of statistics window. - Returns: - np.ndarray: Mean value within the window. - """ - if window_size is not None: - assert isinstance(window_size, int), f"The type of window size should be int, but got {type(window_size)}" - else: - window_size = len(self._log_history) - logs_sum = self._log_history[-window_size:].sum() - counts_sum = self._count_history[-window_size:].sum() - return logs_sum / counts_sum - - def max(self, window_size: int | None = None) -> np.ndarray: - """Return the maximum value of the latest ``window_size`` values in log - histories. - - If ``window_size is None`` or ``window_size > len(self._log_history)``, - return the global maximum value of history logs. - - Args: - window_size (int, optional): Size of statistics window. - Returns: - np.ndarray: The maximum value within the window. 
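`register_statistics` and `statistics` above implement name-based dispatch: a class-level dict maps method names to callables so a statistic can be selected by string, e.g. from a config file. The core of that pattern, stripped of the class machinery:

```python
_methods = {}


def register(fn):
    """Decorator: make ``fn`` addressable by its name."""
    _methods[fn.__name__] = fn
    return fn


@register
def mean(values):
    return sum(values) / len(values)


@register
def latest(values):
    return values[-1]


def statistics(name, values):
    if name not in _methods:
        raise KeyError(f"{name} has not been registered")
    return _methods[name](values)


print(statistics("mean", [1, 2, 3]))    # 2.0
print(statistics("latest", [1, 2, 3]))  # 3
```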
- """ - if window_size is not None: - assert isinstance(window_size, int), f"The type of window size should be int, but got {type(window_size)}" - else: - window_size = len(self._log_history) - return self._log_history[-window_size:].max() - - def min(self, window_size: int | None = None) -> np.ndarray: - """Return the minimum value of the latest ``window_size`` values in log - histories. - - If ``window_size is None`` or ``window_size > len(self._log_history)``, - return the global minimum value of history logs. - - Args: - window_size (int, optional): Size of statistics window. - Returns: - np.ndarray: The minimum value within the window. - """ - if window_size is not None: - assert isinstance(window_size, int), f"The type of window size should be int, but got {type(window_size)}" - else: - window_size = len(self._log_history) - return self._log_history[-window_size:].min() - - def current(self) -> np.ndarray: - """Return the recently updated values in log histories. - - Returns: - np.ndarray: Recently updated values in log histories. - """ - if len(self._log_history) == 0: - raise ValueError("HistoryBuffer._log_history is an empty array! please call update first") - return self._log_history[-1] - - def __getstate__(self) -> dict: - """Make ``_statistics_methods`` can be resumed. - - Returns: - dict: State dict including statistics_methods. - """ - self.__dict__.update(statistics_methods=self._statistics_methods) - return self.__dict__ - - def __setstate__(self, state): - """Try to load ``_statistics_methods`` from state. - - Args: - state (dict): State dict. - """ - statistics_methods = state.pop("statistics_methods", {}) - self._set_default_statistics() - self._statistics_methods.update(statistics_methods) - self.__dict__.update(state) diff --git a/libs/visengine/visengine/logging/logger.py b/libs/visengine/visengine/logging/logger.py deleted file mode 100644 index a8be1cc..0000000 --- a/libs/visengine/visengine/logging/logger.py +++ /dev/null @@ -1,442 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import logging -import os -import os.path as osp -import sys -import warnings -from getpass import getuser -from logging import Logger, LogRecord, handlers -from socket import gethostname - -from termcolor import colored - -from visengine.utils import ManagerMixin -from visengine.utils.manager import _accquire_lock, _release_lock - - -class FilterDuplicateWarning(logging.Filter): - """Filter the repeated warning message. - - Args: - name (str): name of the filter. - """ - - def __init__(self, name: str = "mmengine"): - super().__init__(name) - self.seen: set = set() - - def filter(self, record: LogRecord) -> bool: - """Filter the repeated warning message. - - Args: - record (LogRecord): The log record. - - Returns: - bool: Whether to output the log record. - """ - if record.levelno != logging.WARNING: - return True - - if record.msg not in self.seen: - self.seen.add(record.msg) - return True - return False - - -class MMFormatter(logging.Formatter): - """Colorful format for MMLogger. If the log level is error, the logger will - additionally output the location of the code. - - Args: - color (bool): Whether to use colorful format. filehandler is not - allowed to use color format, otherwise it will be garbled. - blink (bool): Whether to blink the ``INFO`` and ``DEBUG`` logging - level. - **kwargs: Keyword arguments passed to - :meth:`logging.Formatter.__init__`. 
- """ - - _color_mapping: dict = { - "ERROR": "red", - "WARNING": "yellow", - "INFO": "white", - "DEBUG": "green", - } - - def __init__(self, color: bool = True, blink: bool = False, **kwargs): - super().__init__(**kwargs) - assert not (not color and blink), "blink should only be available when color is True" - # Get prefix format according to color. - error_prefix = self._get_prefix("ERROR", color, blink=True) - warn_prefix = self._get_prefix("WARNING", color, blink=True) - info_prefix = self._get_prefix("INFO", color, blink) - debug_prefix = self._get_prefix("DEBUG", color, blink) - - # Config output format. - self.err_format = ( - f"%(asctime)s - %(name)s - {error_prefix} - %(pathname)s - %(funcName)s - %(lineno)d - %(message)s" - ) - self.warn_format = f"%(asctime)s - %(name)s - {warn_prefix} - %(message)s" - self.info_format = f"%(asctime)s - %(name)s - {info_prefix} - %(message)s" - self.debug_format = f"%(asctime)s - %(name)s - {debug_prefix} - %(message)s" - - def _get_prefix(self, level: str, color: bool, blink=False) -> str: - """Get the prefix of the target log level. - - Args: - level (str): log level. - color (bool): Whether to get colorful prefix. - blink (bool): Whether the prefix will blink. - - Returns: - str: The plain or colorful prefix. - """ - if color: - attrs = ["underline"] - if blink: - attrs.append("blink") - prefix = colored(level, self._color_mapping[level], attrs=attrs) - else: - prefix = level - return prefix - - def format(self, record: LogRecord) -> str: - """Override the `logging.Formatter.format`` method `. Output the - message according to the specified log level. - - Args: - record (LogRecord): A LogRecord instance represents an event being - logged. - - Returns: - str: Formatted result. - """ - if record.levelno == logging.ERROR: - self._style._fmt = self.err_format - elif record.levelno == logging.WARNING: - self._style._fmt = self.warn_format - elif record.levelno == logging.INFO: - self._style._fmt = self.info_format - elif record.levelno == logging.DEBUG: - self._style._fmt = self.debug_format - - result = logging.Formatter.format(self, record) - return result - - -class MMLogger(Logger, ManagerMixin): - """Formatted logger used to record messages. - - ``MMLogger`` can create formatted logger to log message with different - log levels and get instance in the same way as ``ManagerMixin``. - ``MMLogger`` has the following features: - - - Distributed log storage, ``MMLogger`` can choose whether to save log of - different ranks according to `log_file`. - - Message with different log levels will have different colors and format - when displayed on terminal. - - Note: - - The `name` of logger and the ``instance_name`` of ``MMLogger`` could - be different. We can only get ``MMLogger`` instance by - ``MMLogger.get_instance`` but not ``logging.getLogger``. This feature - ensures ``MMLogger`` will not be incluenced by third-party logging - config. - - Different from ``logging.Logger``, ``MMLogger`` will not log warning - or error message without ``Handler``. - - Examples: - >>> logger = MMLogger.get_instance(name='MMLogger', - >>> logger_name='Logger') - >>> # Although logger has name attribute just like `logging.Logger` - >>> # We cannot get logger instance by `logging.getLogger`. - >>> assert logger.name == 'Logger' - >>> assert logger.instance_name = 'MMLogger' - >>> assert id(logger) != id(logging.getLogger('Logger')) - >>> # Get logger that do not store logs. - >>> logger1 = MMLogger.get_instance('logger1') - >>> # Get logger only save rank0 logs. 
- >>> logger2 = MMLogger.get_instance('logger2', log_file='out.log') - >>> # Get logger only save multiple ranks logs. - >>> logger3 = MMLogger.get_instance('logger3', log_file='out.log', - >>> distributed=True) - - Args: - name (str): Global instance name. - logger_name (str): ``name`` attribute of ``Logging.Logger`` instance. - If `logger_name` is not defined, defaults to 'visengine'. - log_file (str, optional): The log filename. If specified, a - ``FileHandler`` will be added to the logger. Defaults to None. - log_level (str): The log level of the handler. Defaults to - 'INFO'. If log level is 'DEBUG', distributed logs will be saved - during distributed training. - file_mode (str): The file mode used to open log file. Defaults to 'w'. - distributed (bool): Whether to save distributed logs, Defaults to - false. - file_handler_cfg (dict, optional): Configuration of file handler. - Defaults to None. If ``file_handler_cfg`` is not specified, - ``logging.FileHandler`` will be used by default. If it is - specified, the ``type`` key should be set. It can be - ``RotatingFileHandler``, ``TimedRotatingFileHandler``, - ``WatchedFileHandler`` or other file handlers, and the remaining - fields will be used to build the handler. - - Examples: - >>> file_handler_cfg = dict( - >>> type='TimedRotatingFileHandler', - >>> when='MIDNIGHT', - >>> interval=1, - >>> backupCount=365) - - `New in version 0.9.0.` - """ - - def __init__( - self, - name: str, - logger_name="visengine", - log_file: str | None = None, - log_level: int | str = "INFO", - file_mode: str = "w", - distributed=False, - file_handler_cfg: dict | None = None, - ): - Logger.__init__(self, logger_name) - ManagerMixin.__init__(self, name) - # Get rank in DDP mode. - if isinstance(log_level, str): - log_level = logging._nameToLevel[log_level] - global_rank = _get_rank() - device_id = _get_device_id() - - # Config stream_handler. If `rank != 0`. stream_handler can only - # export ERROR logs. - stream_handler = logging.StreamHandler(stream=sys.stdout) - # `StreamHandler` record month, day, hour, minute, and second - # timestamp. - stream_handler.setFormatter(MMFormatter(color=True, datefmt="%m/%d %H:%M:%S")) - # Only rank0 `StreamHandler` will log messages below error level. - if global_rank == 0: - stream_handler.setLevel(log_level) - else: - stream_handler.setLevel(logging.ERROR) - stream_handler.addFilter(FilterDuplicateWarning(logger_name)) - self.handlers.append(stream_handler) - - if log_file is not None: - world_size = _get_world_size() - is_distributed = (log_level <= logging.DEBUG or distributed) and world_size > 1 - if is_distributed: - filename, suffix = osp.splitext(osp.basename(log_file)) - hostname = _get_host_info() - if hostname: - filename = f"{filename}_{hostname}_device{device_id}_rank{global_rank}{suffix}" - else: - # Omit hostname if it is empty - filename = f"{filename}_device{device_id}_rank{global_rank}{suffix}" - log_file = osp.join(osp.dirname(log_file), filename) - # Save multi-ranks logs if distributed is True. The logs of rank0 - # will always be saved. 
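The handler setup above gates verbosity by rank: rank 0 logs at the configured level, every other rank is clamped to ERROR, so a distributed job prints one progress stream instead of one per process. A minimal sketch of that gating, assuming a torchrun-style `RANK` environment variable rather than visengine's `_get_rank`:

```python
import logging
import os
import sys


def build_stream_handler(log_level=logging.INFO):
    rank = int(os.getenv("RANK", "0"))  # assumption: launcher exports RANK
    handler = logging.StreamHandler(stream=sys.stdout)
    # Non-zero ranks only surface errors; rank 0 keeps the requested level.
    handler.setLevel(log_level if rank == 0 else logging.ERROR)
    return handler


logger = logging.getLogger("ddp-demo")
logger.setLevel(logging.INFO)
logger.addHandler(build_stream_handler())
logger.info("visible on rank 0 only")
```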
- if global_rank == 0 or is_distributed: - if file_handler_cfg is not None: - assert "type" in file_handler_cfg - file_handler_type = file_handler_cfg.pop("type") - file_handlers_map = _get_logging_file_handlers() - if file_handler_type in file_handlers_map: - file_handler_cls = file_handlers_map[file_handler_type] - file_handler_cfg.setdefault("filename", log_file) - file_handler = file_handler_cls(**file_handler_cfg) - else: - raise ValueError(f"`logging.handlers` does not contain {file_handler_type}") - else: - # Here, the default behavior of the official - # logger is 'a'. Thus, we provide an interface to - # change the file mode to the default behavior. - # `FileHandler` is not supported to have colors, - # otherwise it will appear garbled. - file_handler = logging.FileHandler(log_file, file_mode) - - # `StreamHandler` record year, month, day hour, minute, - # and second timestamp. file_handler will only record logs - # without color to avoid garbled code saved in files. - file_handler.setFormatter(MMFormatter(color=False, datefmt="%Y/%m/%d %H:%M:%S")) - file_handler.setLevel(log_level) - file_handler.addFilter(FilterDuplicateWarning(logger_name)) - self.handlers.append(file_handler) - self._log_file = log_file - - @property - def log_file(self): - return self._log_file - - @classmethod - def get_current_instance(cls) -> "MMLogger": - """Get latest created ``MMLogger`` instance. - - :obj:`MMLogger` can call :meth:`get_current_instance` before any - instance has been created, and return a logger with the instance name - "mmengine". - - Returns: - MMLogger: Configured logger instance. - """ - if not cls._instance_dict: - cls.get_instance("visengine") - return super().get_current_instance() - - def callHandlers(self, record: LogRecord) -> None: - """Pass a record to all relevant handlers. - - Override ``callHandlers`` method in ``logging.Logger`` to avoid - multiple warning messages in DDP mode. Loop through all handlers of - the logger instance and its parents in the logger hierarchy. If no - handler was found, the record will not be output. - - Args: - record (LogRecord): A ``LogRecord`` instance contains logged - message. - """ - for handler in self.handlers: - if record.levelno >= handler.level: - handler.handle(record) - - def setLevel(self, level): - """Set the logging level of this logger. - - If ``logging.Logger.selLevel`` is called, all ``logging.Logger`` - instances managed by ``logging.Manager`` will clear the cache. Since - ``MMLogger`` is not managed by ``logging.Manager`` anymore, - ``MMLogger`` should override this method to clear caches of all - ``MMLogger`` instance which is managed by :obj:`ManagerMixin`. - - level must be an int or a str. - """ - self.level = logging._checkLevel(level) - _accquire_lock() - # The same logic as `logging.Manager._clear_cache`. - for logger in MMLogger._instance_dict.values(): - logger._cache.clear() - _release_lock() - - -def print_log(msg, logger: Logger | str | None = None, level=logging.INFO) -> None: - """Print a log message. - - Args: - msg (str): The message to be logged. - logger (Logger or str, optional): If the type of logger is - ``logging.Logger``, we directly use logger to log messages. - Some special loggers are: - - - "silent": No message will be printed. - - "current": Use latest created logger to log message. - - other str: Instance name of logger. The corresponding logger - will log message if it has been created, otherwise ``print_log`` - will raise a `ValueError`. 
- - None: The `print()` method will be used to print log messages. - level (int): Logging level. Only available when `logger` is a Logger - object, "current", or a created logger instance name. - """ - if logger is None: - print(msg) - elif isinstance(logger, logging.Logger): - logger.log(level, msg) - elif logger == "silent": - pass - elif logger == "current": - logger_instance = MMLogger.get_current_instance() - logger_instance.log(level, msg) - elif isinstance(logger, str): - # If the type of `logger` is `str`, but not with value of `current` or - # `silent`, we assume it indicates the name of the logger. If the - # corresponding logger has not been created, `print_log` will raise - # a `ValueError`. - if MMLogger.check_instance_created(logger): - logger_instance = MMLogger.get_instance(logger) - logger_instance.log(level, msg) - else: - raise ValueError(f"MMLogger: {logger} has not been created!") - else: - raise TypeError( - f'`logger` should be either a logging.Logger object, str, "silent", "current" or None, but got {type(logger)}' - ) - - -def _get_world_size(): - """Support using logging module without torch.""" - try: - # requires torch - from visengine.dist import get_world_size - except ImportError: - return 1 - else: - return get_world_size() - - -def _get_rank(): - """Support using logging module without torch.""" - try: - # requires torch - from visengine.dist import get_rank - except ImportError: - return 0 - else: - return get_rank() - - -def _get_device_id(): - """Get device id of current machine.""" - try: - import torch - except ImportError: - return 0 - else: - local_rank = int(os.getenv("LOCAL_RANK", "0")) - # TODO: return device id of npu and mlu. - if not torch.cuda.is_available(): - return local_rank - cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", None) - if cuda_visible_devices is None: - num_device = torch.cuda.device_count() - cuda_visible_devices = list(range(num_device)) - else: - cuda_visible_devices = cuda_visible_devices.split(",") - try: - return int(cuda_visible_devices[local_rank]) - except ValueError: - # handle case for Multi-Instance GPUs - # see #1148 for details - return cuda_visible_devices[local_rank] - - -def _get_host_info() -> str: - """Get hostname and username. - - Return empty string if exception raised, e.g. ``getpass.getuser()`` will - lead to error in docker container - """ - host = "" - try: - host = f"{getuser()}@{gethostname()}" - except Exception as e: - warnings.warn(f"Host or user not found: {e!s}", stacklevel=2) - return host - - -def _get_logging_file_handlers() -> dict: - """Get additional file_handlers in ``logging.handlers``. - - Returns: - Dict: A map of file_handlers. - """ - file_handlers_map = {} - for module_name in dir(handlers): - if module_name.startswith("__"): - continue - _fh = getattr(handlers, module_name) - if inspect.isclass(_fh) and issubclass(_fh, logging.FileHandler): - file_handlers_map[module_name] = _fh - return file_handlers_map diff --git a/libs/visengine/visengine/logging/message_hub.py b/libs/visengine/visengine/logging/message_hub.py deleted file mode 100644 index 76f123a..0000000 --- a/libs/visengine/visengine/logging/message_hub.py +++ /dev/null @@ -1,470 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
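Before the next file: the `print_log` helper defined in `logger.py` above dispatches on the `logger` argument (`None` → plain `print`, `"silent"` → drop, `"current"` → latest `MMLogger`, any other string → a logger that must already exist). A usage sketch, runnable only against a checkout that still contains `visengine.logging`:

```python
from visengine.logging import MMLogger, print_log

print_log("plain print")               # logger=None -> builtin print()
print_log("dropped", logger="silent")  # swallowed entirely

MMLogger.get_instance("demo")          # create the named logger first
print_log("via named logger", logger="demo")
print_log("via latest logger", logger="current")
# print_log("boom", logger="missing")  # would raise ValueError
```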
-import copy -import logging -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Union - -import numpy as np - -from visengine.utils import ManagerMixin - -from .history_buffer import HistoryBuffer -from .logger import print_log - -if TYPE_CHECKING: - import torch - - -class MessageHub(ManagerMixin): - """Message hub for component interaction. MessageHub is created and - accessed in the same way as ManagerMixin. - - ``MessageHub`` will record log information and runtime information. The - log information refers to the learning rate, loss, etc. of the model - during training phase, which will be stored as ``HistoryBuffer``. The - runtime information refers to the iter times, meta information of - runner etc., which will be overwritten by next update. - - Args: - name (str): Name of message hub used to get corresponding instance - globally. - log_scalars (dict, optional): Each key-value pair in the - dictionary is the name of the log information such as "loss", "lr", - "metric" and their corresponding values. The type of value must be - HistoryBuffer. Defaults to None. - runtime_info (dict, optional): Each key-value pair in the - dictionary is the name of the runtime information and their - corresponding values. Defaults to None. - resumed_keys (dict, optional): Each key-value pair in the - dictionary decides whether the key in :attr:`_log_scalars` and - :attr:`_runtime_info` will be serialized. - - Note: - Key in :attr:`_resumed_keys` belongs to :attr:`_log_scalars` or - :attr:`_runtime_info`. The corresponding value cannot be set - repeatedly. - - Examples: - >>> # create empty `MessageHub`. - >>> message_hub1 = MessageHub('name') - >>> log_scalars = dict(loss=HistoryBuffer()) - >>> runtime_info = dict(task='task') - >>> resumed_keys = dict(loss=True) - >>> # create `MessageHub` from data. - >>> message_hub2 = MessageHub( - >>> name='name', - >>> log_scalars=log_scalars, - >>> runtime_info=runtime_info, - >>> resumed_keys=resumed_keys) - """ - - def __init__( - self, - name: str, - log_scalars: dict | None = None, - runtime_info: dict | None = None, - resumed_keys: dict | None = None, - ): - super().__init__(name) - self._log_scalars = self._parse_input("log_scalars", log_scalars) - self._runtime_info = self._parse_input("runtime_info", runtime_info) - self._resumed_keys = self._parse_input("resumed_keys", resumed_keys) - - for value in self._log_scalars.values(): - assert isinstance(value, HistoryBuffer), ( - f"The type of log_scalars'value must be HistoryBuffer, but got {type(value)}" - ) - - for key in self._resumed_keys.keys(): - assert key in self._log_scalars or key in self._runtime_info, ( - f"Key in `resumed_keys` must contained in `log_scalars` or `runtime_info`, but got {key}" - ) - - @classmethod - def get_current_instance(cls) -> "MessageHub": - """Get latest created ``MessageHub`` instance. - - :obj:`MessageHub` can call :meth:`get_current_instance` before any - instance has been created, and return a message hub with the instance - name "mmengine". - - Returns: - MessageHub: Empty ``MessageHub`` instance. - """ - if not cls._instance_dict: - cls.get_instance("mmengine") - return super().get_current_instance() - - def update_scalar( - self, - key: str, - value: Union[int, float, np.ndarray, "torch.Tensor"], - count: int = 1, - resumed: bool = True, - ) -> None: - """Update :attr:_log_scalars. - - Update ``HistoryBuffer`` in :attr:`_log_scalars`. 
If a ``HistoryBuffer`` for the corresponding - key has already been created, ``value`` and ``count`` are the - arguments of ``HistoryBuffer.update``. Otherwise, ``update_scalar`` - will create a ``HistoryBuffer`` with value and count via the - constructor of ``HistoryBuffer``. - - Examples: - >>> message_hub = MessageHub(name='name') - >>> # create loss `HistoryBuffer` with value=1, count=1 - >>> message_hub.update_scalar('loss', 1) - >>> # update loss `HistoryBuffer` with value - >>> message_hub.update_scalar('loss', 3) - >>> message_hub.update_scalar('loss', 3, resumed=False) - AssertionError: loss used to be true, but got false now. resumed - keys cannot be modified repeatedly - - Note: - The ``resumed`` argument needs to be consistent for the same - ``key``. - - Args: - key (str): Key of ``HistoryBuffer``. - value (torch.Tensor or np.ndarray or int or float): Value of log. - count (torch.Tensor or np.ndarray or int or float): Accumulation - times of log, defaults to 1. `count` will be used in smooth - statistics. - resumed (bool): Whether the corresponding ``HistoryBuffer`` - could be resumed. Defaults to True. - """ - self._set_resumed_keys(key, resumed) - checked_value = self._get_valid_value(value) - assert isinstance(count, int), f"The type of count must be int, but got {type(count)}: {count}" - if key in self._log_scalars: - self._log_scalars[key].update(checked_value, count) - else: - self._log_scalars[key] = HistoryBuffer([checked_value], [count]) - - def update_scalars(self, log_dict: dict, resumed: bool = True) -> None: - """Update :attr:`_log_scalars` with a dict. - - ``update_scalars`` iterates through each key-value pair of ``log_dict`` - and calls ``update_scalar``. If the type of a value is dict, the value - should be ``dict(value=xxx)`` or ``dict(value=xxx, count=xxx)``. Every - item in ``log_dict`` shares the same resume option. - - Note: - The ``resumed`` argument needs to be consistent for the same - ``log_dict``. - - Args: - log_dict (dict): Used for batch updating :attr:`_log_scalars`. - resumed (bool): Whether all ``HistoryBuffer`` instances referred - to in ``log_dict`` should be resumed. Defaults to True. - - Examples: - >>> message_hub = MessageHub.get_instance('mmengine') - >>> log_dict = dict(a=1, b=2, c=3) - >>> message_hub.update_scalars(log_dict) - >>> # The default count of `a`, `b` and `c` is 1. - >>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2)) - >>> message_hub.update_scalars(log_dict) - >>> # The count of `c` is 2. - """ - assert isinstance(log_dict, dict), f"`log_dict` must be a dict, but got {type(log_dict)}" - for log_name, log_val in log_dict.items(): - if isinstance(log_val, dict): - assert "value" in log_val, f"value must be defined in {log_val}" - count = self._get_valid_value(log_val.get("count", 1)) - value = log_val["value"] - else: - count = 1 - value = log_val - assert isinstance(count, int), f"The type of count must be int, but got {type(count)}: {count}" - self.update_scalar(log_name, value, count, resumed) - - def update_info(self, key: str, value: Any, resumed: bool = True) -> None: - """Update runtime information. - - The runtime information corresponding to the key will be overwritten - each time ``update_info`` is called. - - Note: - The ``resumed`` argument needs to be consistent for the same - ``key``. - - Examples: - >>> message_hub = MessageHub(name='name') - >>> message_hub.update_info('iter', 100) - - Args: - key (str): Key of runtime information. - value (Any): Value of runtime information. - resumed (bool): Whether the corresponding runtime information - could be resumed.
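To make the buffering behaviour of the scalar-update path above concrete, a small sketch, assuming the logging package re-exports `MessageHub` and that `mean()` is part of `HistoryBuffer`'s statistics interface:

```python
from visengine.logging import MessageHub

hub = MessageHub.get_instance("demo_hub")

hub.update_scalar("loss", 2.0)  # creates the HistoryBuffer for "loss"
hub.update_scalar("loss", 4.0)  # appends to the same buffer

# Batch form; a dict value may carry an explicit count for weighted smoothing.
hub.update_scalars({"lr": 0.01, "acc": {"value": 0.9, "count": 2}})

print(hub.get_scalar("loss").mean())  # 3.0, smoothed over the buffer
```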
- """ - self._set_resumed_keys(key, resumed) - self._runtime_info[key] = value - - def pop_info(self, key: str, default: Any | None = None) -> Any: - """Remove runtime information by key. If the key does not exist, this - method will return the default value. - - Args: - key (str): Key of runtime information. - default (Any, optional): The default returned value for the - given key. - - Returns: - Any: The runtime information if the key exists. - """ - return self._runtime_info.pop(key, default) - - def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None: - """Update runtime information with dictionary. - - The key corresponding runtime information will be overwritten each - time calling ``update_info``. - - Note: - The ``resumed`` argument needs to be consistent for the same - ``info_dict``. - - Examples: - >>> message_hub = MessageHub(name='name') - >>> message_hub.update_info({'iter': 100}) - - Args: - info_dict (str): Runtime information dictionary. - resumed (bool): Whether the corresponding ``HistoryBuffer`` - could be resumed. - """ - assert isinstance(info_dict, dict), f"`log_dict` must be a dict!, but got {type(info_dict)}" - for key, value in info_dict.items(): - self.update_info(key, value, resumed=resumed) - - def _set_resumed_keys(self, key: str, resumed: bool) -> None: - """Set corresponding resumed keys. - - This method is called by ``update_scalar``, ``update_scalars`` and - ``update_info`` to set the corresponding key is true or false in - :attr:`_resumed_keys`. - - Args: - key (str): Key of :attr:`_log_scalrs` or :attr:`_runtime_info`. - resumed (bool): Whether the corresponding ``HistoryBuffer`` - could be resumed. - """ - if key not in self._resumed_keys: - self._resumed_keys[key] = resumed - else: - assert self._resumed_keys[key] == resumed, ( - f"{key} used to be {self._resumed_keys[key]}, but got {{resumed}} now. resumed keys cannot be modified repeatedly." - ) - - @property - def log_scalars(self) -> OrderedDict: - """Get all ``HistoryBuffer`` instances. - - Note: - Considering the large memory footprint of history buffers in the - post-training, :meth:`get_scalar` will return a reference of - history buffer rather than a copy. - - Returns: - OrderedDict: All ``HistoryBuffer`` instances. - """ - return self._log_scalars - - @property - def runtime_info(self) -> OrderedDict: - """Get all runtime information. - - Returns: - OrderedDict: A copy of all runtime information. - """ - return self._runtime_info - - def get_scalar(self, key: str) -> HistoryBuffer: - """Get ``HistoryBuffer`` instance by key. - - Note: - Considering the large memory footprint of history buffers in the - post-training, :meth:`get_scalar` will not return a reference of - history buffer rather than a copy. - - Args: - key (str): Key of ``HistoryBuffer``. - - Returns: - HistoryBuffer: Corresponding ``HistoryBuffer`` instance if the - key exists. - """ - if key not in self.log_scalars: - raise KeyError( - f"{key} is not found in Messagehub.log_buffers: instance name is: {MessageHub.instance_name}" - ) - return self.log_scalars[key] - - def get_info(self, key: str, default: Any | None = None) -> Any: - """Get runtime information by key. If the key does not exist, this - method will return default information. - - Args: - key (str): Key of runtime information. - default (Any, optional): The default returned value for the - given key. - - Returns: - Any: A copy of corresponding runtime information if the key exists. 
- """ - if key not in self.runtime_info: - return default - else: - # TODO: There are restrictions on objects that can be saved - # return copy.deepcopy(self._runtime_info[key]) - return self._runtime_info[key] - - def _get_valid_value( - self, - value: Union["torch.Tensor", np.ndarray, np.number, int, float], - ) -> int | float: - """Convert value to python built-in type. - - Args: - value (torch.Tensor or np.ndarray or np.number or int or float): - value of log. - - Returns: - float or int: python built-in type value. - """ - if isinstance(value, np.ndarray | np.number): - assert value.size == 1 - value = value.item() - elif isinstance(value, int | float): - value = value - else: - # check whether value is torch.Tensor but don't want - # to import torch in this file - assert hasattr(value, "numel") and value.numel() == 1 - value = value.item() - return value # type: ignore - - def state_dict(self) -> dict: - """Returns a dictionary containing log scalars, runtime information and - resumed keys, which should be resumed. - - The returned ``state_dict`` can be loaded by :meth:`load_state_dict`. - - Returns: - dict: A dictionary contains ``log_scalars``, ``runtime_info`` and - ``resumed_keys``. - """ - saved_scalars = OrderedDict() - saved_info = OrderedDict() - - for key, value in self._log_scalars.items(): - if self._resumed_keys.get(key, False): - saved_scalars[key] = copy.deepcopy(value) - - for key, value in self._runtime_info.items(): - if self._resumed_keys.get(key, False): - try: - saved_info[key] = copy.deepcopy(value) - except: # noqa: E722 - print_log( - f"{key} in message_hub cannot be copied, just return its reference. ", - logger="current", - level=logging.WARNING, - ) - saved_info[key] = value - return { - "log_scalars": saved_scalars, - "runtime_info": saved_info, - "resumed_keys": self._resumed_keys, - } - - def load_state_dict(self, state_dict: Union["MessageHub", dict]) -> None: - """Loads log scalars, runtime information and resumed keys from - ``state_dict`` or ``message_hub``. - - If ``state_dict`` is a dictionary returned by :meth:`state_dict`, it - will only make copies of data which should be resumed from the source - ``message_hub``. - - If ``state_dict`` is a ``message_hub`` instance, it will make copies of - all data from the source message_hub. We suggest to load data from - ``dict`` rather than a ``MessageHub`` instance. - - Args: - state_dict (dict or MessageHub): A dictionary contains key - ``log_scalars`` ``runtime_info`` and ``resumed_keys``, or a - MessageHub instance. - """ - if isinstance(state_dict, dict): - for key in ("log_scalars", "runtime_info", "resumed_keys"): - assert key in state_dict, f"The loaded `state_dict` of `MessageHub` must contain key: `{key}`" - # The old `MessageHub` could save non-HistoryBuffer `log_scalars`, - # therefore the loaded `log_scalars` needs to be filtered. 
- for key, value in state_dict["log_scalars"].items(): - if not isinstance(value, HistoryBuffer): - print_log( - f"{key} in message_hub is not a HistoryBuffer, just skip resuming it.", - logger="current", - level=logging.WARNING, - ) - continue - self.log_scalars[key] = value - - for key, value in state_dict["runtime_info"].items(): - try: - self._runtime_info[key] = copy.deepcopy(value) - except: # noqa: E722 - print_log( - f"{key} in message_hub cannot be copied, just return its reference.", - logger="current", - level=logging.WARNING, - ) - self._runtime_info[key] = value - - for key, value in state_dict["resumed_keys"].items(): - if key not in set(self.log_scalars.keys()) | set(self._runtime_info.keys()): - print_log( - f"resumed key: {key} is not defined in message_hub, just skip resuming this key.", - logger="current", - level=logging.WARNING, - ) - continue - elif not value: - print_log( - f"Although resumed key: {key} is False, {key} " - "will still be loaded this time. This key will " - "not be saved by the next calling of " - "`MessageHub.state_dict()`", - logger="current", - level=logging.WARNING, - ) - self._resumed_keys[key] = value - - # Since some checkpoints saved a serialized `message_hub` instance, - # `load_state_dict` supports loading a `message_hub` instance for - # compatibility - else: - self._log_scalars = copy.deepcopy(state_dict._log_scalars) - self._runtime_info = copy.deepcopy(state_dict._runtime_info) - self._resumed_keys = copy.deepcopy(state_dict._resumed_keys) - - def _parse_input(self, name: str, value: Any) -> OrderedDict: - """Parse an input value. - - Args: - name (str): Name of the input value. - value (Any): Input value. - - Returns: - dict: Parsed input value. - """ - if value is None: - return OrderedDict() - elif isinstance(value, dict): - return OrderedDict(value) - else: - raise TypeError(f"{name} should be a dict or `None`, but got {type(value)}") diff --git a/libs/visengine/visengine/model/__init__.py b/libs/visengine/visengine/model/__init__.py deleted file mode 100644 index c9b203f..0000000 --- a/libs/visengine/visengine/model/__init__.py +++ /dev/null @@ -1,83 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved.
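A short sketch of the resume round-trip implemented by `MessageHub.state_dict`/`load_state_dict` above; the hub names are illustrative, and the import path assumes the logging package re-exports `MessageHub`:

```python
from visengine.logging import MessageHub

src = MessageHub.get_instance("src_hub")
src.update_scalar("loss", 1.0)                    # resumed=True by default
src.update_info("iter", 100)                      # resumed=True by default
src.update_info("tmp", "scratch", resumed=False)  # excluded from state_dict

state = src.state_dict()  # only resumed keys survive serialization
dst = MessageHub.get_instance("dst_hub")
dst.load_state_dict(state)

print(dst.get_info("iter"))  # 100
print(dst.get_info("tmp"))   # None: filtered out via resumed_keys
```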
-from visengine.utils.dl_utils import TORCH_VERSION -from visengine.utils.version_utils import digit_version -from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor -from .base_module import BaseModule, ModuleDict, ModuleList, Sequential -from .utils import ( - convert_sync_batchnorm, - detect_anomalous_params, - merge_dict, - revert_sync_batchnorm, - stack_batch, -) -from .weight_init import ( - BaseInit, - Caffe2XavierInit, - ConstantInit, - KaimingInit, - NormalInit, - PretrainedInit, - TruncNormalInit, - UniformInit, - XavierInit, - bias_init_with_prob, - caffe2_xavier_init, - constant_init, - initialize, - kaiming_init, - normal_init, - trunc_normal_init, - uniform_init, - update_init_info, - xavier_init, -) -from .wrappers import ( - MMDistributedDataParallel, - MMSeparateDistributedDataParallel, - is_model_wrapper, -) - -__all__ = [ - "BaseDataPreprocessor", - "BaseInit", - "BaseModel", - "BaseModule", - "Caffe2XavierInit", - "ConstantInit", - # "ExponentialMovingAverage", # Not imported - EMA is in hooks module - "ImgDataPreprocessor", - "KaimingInit", - "MMDistributedDataParallel", - "MMSeparateDistributedDataParallel", - "ModuleDict", - "ModuleList", - # "MomentumAnnealingEMA", # Not imported - EMA is in hooks module - "NormalInit", - "PretrainedInit", - "Sequential", - "TruncNormalInit", - "UniformInit", - "XavierInit", - "bias_init_with_prob", - "caffe2_xavier_init", - "constant_init", - "convert_sync_batchnorm", - "detect_anomalous_params", - "initialize", - "is_model_wrapper", - "kaiming_init", - "merge_dict", - "normal_init", - "revert_sync_batchnorm", - "stack_batch", - "trunc_normal_init", - "uniform_init", - "update_init_info", - "xavier_init", -] - -from .wrappers import MMFullyShardedDataParallel - -__all__.append("MMFullyShardedDataParallel") diff --git a/libs/visengine/visengine/model/base_model/__init__.py b/libs/visengine/visengine/model/base_model/__init__.py deleted file mode 100644 index cb93622..0000000 --- a/libs/visengine/visengine/model/base_model/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .base_model import BaseModel -from .data_preprocessor import BaseDataPreprocessor, ImgDataPreprocessor - -__all__ = ["BaseDataPreprocessor", "BaseModel", "ImgDataPreprocessor"] diff --git a/libs/visengine/visengine/model/base_model/base_model.py b/libs/visengine/visengine/model/base_model/base_model.py deleted file mode 100644 index 789b00b..0000000 --- a/libs/visengine/visengine/model/base_model/base_model.py +++ /dev/null @@ -1,411 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from abc import abstractmethod -from collections import OrderedDict - -import torch -import torch.nn as nn - -from visengine.logging import MMLogger -from visengine.registry import MODELS -from visengine.utils import is_list_of - -from ..base_module import BaseModule -from .data_preprocessor import BaseDataPreprocessor - - -class BaseModel(BaseModule): - """Base class for all algorithmic models. - - BaseModel implements the basic functions of the algorithmic model, such as - weights initialize, batch inputs preprocess(see more information in - :class:`BaseDataPreprocessor`), parse losses, and update model parameters. - - Subclasses inherit from BaseModel only need to implement the forward - method, which implements the logic to calculate loss and predictions, - then can be trained in the runner. 
- - Examples: - >>> @MODELS.register_module(force=True) - >>> class ToyModel(BaseModel): - >>> - >>> def __init__(self): - >>> super().__init__() - >>> self.backbone = nn.Sequential() - >>> self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5)) - >>> self.backbone.add_module('pool', nn.MaxPool2d(2, 2)) - >>> self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5)) - >>> self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120)) - >>> self.backbone.add_module('fc2', nn.Linear(120, 84)) - >>> self.backbone.add_module('fc3', nn.Linear(84, 10)) - >>> - >>> self.criterion = nn.CrossEntropyLoss() - >>> - >>> def forward(self, batch_inputs, data_samples, mode='tensor'): - >>> data_samples = torch.stack(data_samples) - >>> if mode == 'tensor': - >>> return self.backbone(batch_inputs) - >>> elif mode == 'predict': - >>> feats = self.backbone(batch_inputs) - >>> predictions = torch.argmax(feats, 1) - >>> return predictions - >>> elif mode == 'loss': - >>> feats = self.backbone(batch_inputs) - >>> loss = self.criterion(feats, data_samples) - >>> return dict(loss=loss) - - Args: - data_preprocessor (dict, optional): The pre-process config of - :class:`BaseDataPreprocessor`. - init_cfg (dict, optional): The weight initialized config for - :class:`BaseModule`. - - Attributes: - data_preprocessor (:obj:`BaseDataPreprocessor`): Used for - pre-processing data sampled by dataloader to the format accepted by - :meth:`forward`. - init_cfg (dict, optional): Initialization config dict. - """ - - def __init__( - self, - data_preprocessor: dict | nn.Module | None = None, - init_cfg: dict | None = None, - ): - super().__init__(init_cfg) - if data_preprocessor is None: - data_preprocessor = {"type": "BaseDataPreprocessor"} - if isinstance(data_preprocessor, nn.Module): - self.data_preprocessor = data_preprocessor - elif isinstance(data_preprocessor, dict): - self.data_preprocessor = MODELS.build(data_preprocessor) - else: - raise TypeError( - f"data_preprocessor should be a `dict` or `nn.Module` instance, but got {type(data_preprocessor)}" - ) - - def train_step(self, data: dict | tuple | list, optim_wrapper) -> dict[str, torch.Tensor]: - """Implements the default model training process including - preprocessing, model forward propagation, loss calculation, - optimization, and back-propagation. - - During non-distributed training. If subclasses do not override the - :meth:`train_step`, :class:`EpochBasedTrainLoop` or - :class:`IterBasedTrainLoop` will call this method to update model - parameters. The default parameter update process is as follows: - - 1. Calls ``self.data_processor(data, training=False)`` to collect - batch_inputs and corresponding data_samples(labels). - 2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw - loss - 3. Calls ``self.parse_losses`` to get ``parsed_losses`` tensor used to - backward and dict of loss tensor used to log messages. - 4. Calls ``optim_wrapper.update_params(loss)`` to update model. - - Args: - data (dict or tuple or list): Data sampled from dataset. - optim_wrapper (OptimWrapper): OptimWrapper instance - used to update model parameters. - - Returns: - Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. - """ - logger = MMLogger.get_current_instance() - logger.debug("[BaseModel] Starting train_step") - - # Enable automatic mixed precision training context. 
- with optim_wrapper.optim_context(self): - data = self.data_preprocessor(data, True) - logger.debug("[BaseModel] Data preprocessed, running forward pass in loss mode") - losses = self._run_forward(data, mode="loss") # type: ignore - - # Debug log: Check all loss components - logger.debug(f"[BaseModel] Raw losses from forward: {list(losses.keys())}") - for loss_name, loss_value in losses.items(): - if isinstance(loss_value, torch.Tensor): - logger.debug(f"[BaseModel] {loss_name}: {loss_value.item()}") - if torch.isnan(loss_value): - logger.error(f"[BaseModel] NaN detected in {loss_name}!") - if torch.isinf(loss_value): - logger.error(f"[BaseModel] Inf detected in {loss_name}!") - - parsed_losses, log_vars = self.parse_losses(losses) # type: ignore - logger.debug(f"[BaseModel] Parsed total loss: {parsed_losses.item()}") - - optim_wrapper.update_params(parsed_losses) - return log_vars - - def val_step(self, data: tuple | dict | list) -> list: - """Gets the predictions of given data. - - Calls ``self.data_preprocessor(data, False)`` and - ``self(inputs, data_sample, mode='predict')`` in order. Return the - predictions which will be passed to evaluator. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - data = self.data_preprocessor(data, False) - return self._run_forward(data, mode="predict") # type: ignore - - def test_step(self, data: dict | tuple | list) -> list: - """``BaseModel`` implements ``test_step`` the same as ``val_step``. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - data = self.data_preprocessor(data, False) - return self._run_forward(data, mode="predict") # type: ignore - - def parse_losses(self, losses: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: - """Parses the raw outputs (losses) of the network. - - Args: - losses (dict): Raw output of the network, which usually contain - losses and other necessary information. - - Returns: - tuple[Tensor, dict]: There are two elements. The first is the - loss tensor passed to optim_wrapper which may be a weighted sum - of all losses, and the second is log_vars which will be sent to - the logger. 
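The four-step update that `train_step` performs above, end to end, as a minimal sketch; the `OptimWrapper` import path mirrors mmengine's and is an assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from visengine.model import BaseModel
from visengine.optim import OptimWrapper  # assumed path, mirroring mmengine

class ToyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, inputs, data_samples=None, mode="tensor"):
        out = self.fc(inputs)
        if mode == "loss":
            return {"loss_cls": F.cross_entropy(out, data_samples)}
        if mode == "predict":
            return out.argmax(1).tolist()
        return out

model = ToyModel()
wrapper = OptimWrapper(torch.optim.SGD(model.parameters(), lr=0.1))
data = {"inputs": torch.randn(8, 4), "data_samples": torch.randint(0, 2, (8,))}
log_vars = model.train_step(data, wrapper)  # preprocess -> loss -> backward -> step
print(log_vars["loss"])
```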
- """ - logger = MMLogger.get_current_instance() - logger.debug(f"[BaseModel] Parsing losses: {list(losses.keys())}") - - log_vars = [] - for loss_name, loss_value in losses.items(): - if isinstance(loss_value, torch.Tensor): - loss_mean = loss_value.mean() - log_vars.append([loss_name, loss_mean]) - logger.debug(f"[BaseModel] {loss_name} mean: {loss_mean.item()}") - - if torch.isnan(loss_mean): - logger.error(f"[BaseModel] NaN detected in mean of {loss_name}!") - if torch.isinf(loss_mean): - logger.error(f"[BaseModel] Inf detected in mean of {loss_name}!") - - elif is_list_of(loss_value, torch.Tensor): - loss_sum = sum(_loss.mean() for _loss in loss_value) - log_vars.append([loss_name, loss_sum]) - logger.debug(f"[BaseModel] {loss_name} sum of means: {loss_sum.item()}") - - if torch.isnan(loss_sum): - logger.error(f"[BaseModel] NaN detected in sum of {loss_name}!") - if torch.isinf(loss_sum): - logger.error(f"[BaseModel] Inf detected in sum of {loss_name}!") - else: - raise TypeError(f"{loss_name} is not a tensor or list of tensors") - - loss = sum(value for key, value in log_vars if "loss" in key) - logger.debug(f"[BaseModel] Total loss (sum of all loss components): {loss.item()}") - - if torch.isnan(loss): - logger.error(f"[BaseModel] NaN detected in total loss!") - if torch.isinf(loss): - logger.error(f"[BaseModel] Inf detected in total loss!") - - log_vars.insert(0, ["loss", loss]) - log_vars = OrderedDict(log_vars) # type: ignore - - return loss, log_vars # type: ignore - - def to(self, *args, **kwargs) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.to` - additionally. - - Returns: - nn.Module: The model itself. - """ - - # Since Torch has not officially merged - # the npu-related fields, using the _parse_to function - # directly will cause the NPU to not be found. - # Here, the input parameters are processed to avoid errors. - if args and isinstance(args[0], str) and "npu" in args[0]: - import torch_npu - - args = ( - next(iter(args)).replace( - "npu", - (torch_npu.npu.native_device if hasattr(torch_npu.npu, "native_device") else "privateuseone"), - ), - ) - if kwargs and "npu" in str(kwargs.get("device", "")): - import torch_npu - - kwargs["device"] = kwargs["device"].replace( - "npu", - (torch_npu.npu.native_device if hasattr(torch_npu.npu, "native_device") else "privateuseone"), - ) - - device = torch._C._nn._parse_to(*args, **kwargs)[0] - if device is not None: - self._set_device(torch.device(device)) - return super().to(*args, **kwargs) - - def cuda( - self, - device: int | str | torch.device | None = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.cuda` - additionally. - - Returns: - nn.Module: The model itself. - """ - if device is None or isinstance(device, int): - device = torch.device("cuda", index=device) - self._set_device(torch.device(device)) - return super().cuda(device) - - def musa( - self, - device: int | str | torch.device | None = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.musa` - additionally. - - Returns: - nn.Module: The model itself. - """ - if device is None or isinstance(device, int): - device = torch.device("musa", index=device) - self._set_device(torch.device(device)) - return super().musa(device) - - def mlu( - self, - device: int | str | torch.device | None = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.mlu` - additionally. - - Returns: - nn.Module: The model itself. 
- """ - device = torch.device("mlu", torch.mlu.current_device()) - self._set_device(device) - return super().mlu() - - def npu( - self, - device: int | str | torch.device | None = None, - ) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.npu` - additionally. - - Returns: - nn.Module: The model itself. - - Note: - This generation of NPU(Ascend910) does not support - the use of multiple cards in a single process, - so the index here needs to be consistent with the default device - """ - device = torch.npu.current_device() - self._set_device(device) - return super().npu() - - def cpu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to call :meth:`BaseDataPreprocessor.cpu` - additionally. - - Returns: - nn.Module: The model itself. - """ - self._set_device(torch.device("cpu")) - return super().cpu() - - def _set_device(self, device: torch.device) -> None: - """Recursively set device for `BaseDataPreprocessor` instance. - - Args: - device (torch.device): the desired device of the parameters and - buffers in this module. - """ - - def apply_fn(module): - if not isinstance(module, BaseDataPreprocessor): - return - if device is not None: - module._device = device - - self.apply(apply_fn) - - @abstractmethod - def forward( - self, - inputs: torch.Tensor, - data_samples: list | None = None, - mode: str = "tensor", - ) -> dict[str, torch.Tensor] | list: - """Returns losses or predictions of training, validation, testing, and - simple inference process. - - ``forward`` method of BaseModel is an abstract method, its subclasses - must implement this method. - - Accepts ``batch_inputs`` and ``data_sample`` processed by - :attr:`data_preprocessor`, and returns results according to mode - arguments. - - During non-distributed training, validation, and testing process, - ``forward`` will be called by ``BaseModel.train_step``, - ``BaseModel.val_step`` and ``BaseModel.test_step`` directly. - - During distributed data parallel training process, - ``MMSeparateDistributedDataParallel.train_step`` will first call - ``DistributedDataParallel.forward`` to enable automatic - gradient synchronization, and then call ``forward`` to get training - loss. - - Args: - inputs (torch.Tensor): batch input tensor collated by - :attr:`data_preprocessor`. - data_samples (list, optional): - data samples collated by :attr:`data_preprocessor`. - mode (str): mode should be one of ``loss``, ``predict`` and - ``tensor`` - - - ``loss``: Called by ``train_step`` and return loss ``dict`` - used for logging - - ``predict``: Called by ``val_step`` and ``test_step`` - and return list of results used for computing metric. - - ``tensor``: Called by custom use to get ``Tensor`` type - results. - - Returns: - dict or list: - - If ``mode == loss``, return a ``dict`` of loss tensor used - for backward and logging. - - If ``mode == predict``, return a ``list`` of inference - results. - - If ``mode == tensor``, return a tensor or ``tuple`` of tensor - or ``dict`` of tensor for custom use. - """ - - def _run_forward(self, data: dict | tuple | list, mode: str) -> dict[str, torch.Tensor] | list: - """Unpacks data for :meth:`forward` - - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - - Returns: - dict or list: Results of training or testing mode. 
- """ - if isinstance(data, dict): - results = self(**data, mode=mode) - elif isinstance(data, list | tuple): - results = self(*data, mode=mode) - else: - raise TypeError(f"Output of `data_preprocessor` should be list, tuple or dict, but got {type(data)}") - return results diff --git a/libs/visengine/visengine/model/base_model/data_preprocessor.py b/libs/visengine/visengine/model/base_model/data_preprocessor.py deleted file mode 100644 index c2ae832..0000000 --- a/libs/visengine/visengine/model/base_model/data_preprocessor.py +++ /dev/null @@ -1,303 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import math -from collections.abc import Mapping, Sequence -from typing import Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from visengine.registry import MODELS -from visengine.structures import BaseDataElement -from visengine.utils import is_seq_of - -from ..utils import stack_batch - -CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, None] - - -@MODELS.register_module(force=True) -class BaseDataPreprocessor(nn.Module): - """Base data pre-processor used for copying data to the target device. - - Subclasses inherit from ``BaseDataPreprocessor`` could override the - forward method to implement custom data pre-processing, such as - batch-resize, MixUp, or CutMix. - - Args: - non_blocking (bool): Whether block current process - when transferring data to device. - New in version 0.3.0. - - Note: - Data dictionary returned by dataloader must be a dict and at least - contain the ``inputs`` key. - """ - - def __init__(self, non_blocking: bool | None = False): - super().__init__() - self._non_blocking = non_blocking - self._device = torch.device("cpu") - - def cast_data(self, data: CastData) -> CastData: - """Copying data to the target device. - - Args: - data (dict): Data returned by ``DataLoader``. - - Returns: - CollatedResult: Inputs and data sample at target device. - """ - if isinstance(data, Mapping): - return {key: self.cast_data(data[key]) for key in data} - elif isinstance(data, str | bytes) or data is None: - return data - elif isinstance(data, tuple) and hasattr(data, "_fields"): - # namedtuple - return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # yapf:disable - elif isinstance(data, Sequence): - return type(data)(self.cast_data(sample) for sample in data) # type: ignore # yapf:disable - elif isinstance(data, torch.Tensor | BaseDataElement): - return data.to(self.device, non_blocking=self._non_blocking) - else: - return data - - def forward(self, data: dict, training: bool = False) -> dict | list: - """Preprocesses the data into the model input format. - - After the data pre-processing of :meth:`cast_data`, ``forward`` - will stack the input tensor list to a batch tensor at the first - dimension. - - Args: - data (dict): Data returned by dataloader - training (bool): Whether to enable training time augmentation. - - Returns: - dict or list: Data in the same format as the model input. - """ - return self.cast_data(data) # type: ignore - - @property - def device(self): - return self._device - - def to(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - - # Since Torch has not officially merged - # the npu-related fields, using the _parse_to function - # directly will cause the NPU to not be found. - # Here, the input parameters are processed to avoid errors. 
- if args and isinstance(args[0], str) and "npu" in args[0]: - args = (next(iter(args)).replace("npu", torch.npu.native_device),) - if kwargs and "npu" in str(kwargs.get("device", "")): - kwargs["device"] = kwargs["device"].replace("npu", torch.npu.native_device) - - device = torch._C._nn._parse_to(*args, **kwargs)[0] - if device is not None: - self._device = torch.device(device) - return super().to(*args, **kwargs) - - def cuda(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.cuda.current_device()) - return super().cuda() - - def musa(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.musa.current_device()) - return super().musa() - - def npu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.npu.current_device()) - return super().npu() - - def mlu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device(torch.mlu.current_device()) - return super().mlu() - - def cpu(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the :attr:`device` - - Returns: - nn.Module: The model itself. - """ - self._device = torch.device("cpu") - return super().cpu() - - -@MODELS.register_module(force=True) -class ImgDataPreprocessor(BaseDataPreprocessor): - """Image pre-processor for normalization and bgr to rgb conversion. - - Accepts the data sampled by the dataloader, and preprocesses it into the - format of the model input. ``ImgDataPreprocessor`` provides the - basic data pre-processing as follows - - - Collates and moves data to the target device. - - Converts inputs from bgr to rgb if the shape of input is (3, H, W). - - Normalizes image with defined std and mean. - - Pads inputs to the maximum size of current batch with defined - ``pad_value``. The padded size is divisible by the defined - ``pad_size_divisor`` - - Stack inputs to batch_inputs. - - For ``ImgDataPreprocessor``, the dimension of the single inputs must be - (3, H, W). - - Note: - ``ImgDataPreprocessor`` and its subclasses are built in the - constructor of :class:`BaseDataset`. - - Args: - mean (Sequence[float or int], optional): The pixel mean of image - channels. If ``bgr_to_rgb=True`` it means the mean value of R, - G, B channels. If the length of `mean` is 1, it means all - channels have the same mean value, or the input is a gray image. - If it is not specified, images will not be normalized. Defaults - to None. - std (Sequence[float or int], optional): The pixel standard deviation of - image channels. If ``bgr_to_rgb=True`` it means the standard - deviation of R, G, B channels. If the length of `std` is 1, - it means all channels have the same standard deviation, or the - input is a gray image. If it is not specified, images will - not be normalized. Defaults to None. - pad_size_divisor (int): The size of padded image should be - divisible by ``pad_size_divisor``. Defaults to 1. - pad_value (float or int): The padded pixel value. Defaults to 0. - bgr_to_rgb (bool): whether to convert image from BGR to RGB. - Defaults to False. - rgb_to_bgr (bool): whether to convert image from RGB to BGR. - Defaults to False.
- non_blocking (bool): Whether block current process - when transferring data to device. - New in version v0.3.0. - - Note: - if images do not need to be normalized, `std` and `mean` should be - both set to None, otherwise both of them should be set to a tuple of - corresponding values. - """ - - def __init__( - self, - mean: Sequence[float | int] | None = None, - std: Sequence[float | int] | None = None, - pad_size_divisor: int = 1, - pad_value: float | int = 0, - bgr_to_rgb: bool = False, - rgb_to_bgr: bool = False, - non_blocking: bool | None = False, - ): - super().__init__(non_blocking) - assert not (bgr_to_rgb and rgb_to_bgr), "`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time" - assert (mean is None) == (std is None), "mean and std should be both None or tuple" - if mean is not None: - assert len(mean) == 3 or len(mean) == 1, ( - f"`mean` should have 1 or 3 values, to be compatible with RGB or gray image, but got {len(mean)} values" - ) - assert len(std) == 3 or len(std) == 1, ( # type: ignore - "`std` should have 1 or 3 values, to be compatible with RGB " # type: ignore - f"or gray image, but got {len(std)} values" - ) # type: ignore - self._enable_normalize = True - self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1), False) - self.register_buffer("std", torch.tensor(std).view(-1, 1, 1), False) - else: - self._enable_normalize = False - self._channel_conversion = rgb_to_bgr or bgr_to_rgb - self.pad_size_divisor = pad_size_divisor - self.pad_value = pad_value - - def forward(self, data: dict, training: bool = False) -> dict | list: - """Performs normalization, padding and bgr2rgb conversion based on - ``BaseDataPreprocessor``. - - Args: - data (dict): Data sampled from dataset. If the collate - function of DataLoader is :obj:`pseudo_collate`, data will be a - list of dict. If collate function is :obj:`default_collate`, - data will be a tuple with batch input tensor and list of data - samples. - training (bool): Whether to enable training time augmentation. If - subclasses override this method, they can perform different - preprocessing strategies for training and testing based on the - value of ``training``. - - Returns: - dict or list: Data in the same format as the model input. - """ - data = self.cast_data(data) # type: ignore - _batch_inputs = data["inputs"] # type: ignore - # Process data with `pseudo_collate`. - if is_seq_of(_batch_inputs, torch.Tensor): - batch_inputs = [] - for _batch_input in _batch_inputs: - # channel transform - if self._channel_conversion: - _batch_input = _batch_input[[2, 1, 0], ...] # type: ignore - # Convert to float after channel conversion to ensure - # efficiency - _batch_input = _batch_input.float() # type: ignore - # Normalization. - if self._enable_normalize: - if self.mean.shape[0] == 3: - assert _batch_input.dim() == 3 and _batch_input.shape[0] == 3, ( - "If the mean has 3 values, the input tensor " - "should in shape of (3, H, W), but got the tensor " - f"with shape {_batch_input.shape}" - ) - _batch_input = (_batch_input - self.mean) / self.std - batch_inputs.append(_batch_input) - # Pad and stack Tensor. - batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor, self.pad_value) - # Process data with `default_collate`. 
- elif isinstance(_batch_inputs, torch.Tensor): - assert _batch_inputs.dim() == 4, ( - f"The input of `ImgDataPreprocessor` should be a NCHW tensor or a list of tensor, but got a tensor with shape: {_batch_inputs.shape}" - ) - if self._channel_conversion: - _batch_inputs = _batch_inputs[:, [2, 1, 0], ...] - # Convert to float after channel conversion to ensure - # efficiency - _batch_inputs = _batch_inputs.float() - if self._enable_normalize: - _batch_inputs = (_batch_inputs - self.mean) / self.std - h, w = _batch_inputs.shape[2:] - target_h = math.ceil(h / self.pad_size_divisor) * self.pad_size_divisor - target_w = math.ceil(w / self.pad_size_divisor) * self.pad_size_divisor - pad_h = target_h - h - pad_w = target_w - w - batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h), "constant", self.pad_value) - else: - raise TypeError( - f"Output of `cast_data` should be a dict of list/tuple with inputs and data_samples, but got {type(data)}: {data}" - ) # type: ignore - data["inputs"] = batch_inputs # type: ignore - data.setdefault("data_samples", None) # type: ignore - return data # type: ignore diff --git a/libs/visengine/visengine/model/base_module.py b/libs/visengine/visengine/model/base_module.py deleted file mode 100644 index ede52ca..0000000 --- a/libs/visengine/visengine/model/base_module.py +++ /dev/null @@ -1,229 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -from abc import ABCMeta -from collections import defaultdict -from collections.abc import Iterable -from logging import FileHandler - -import torch.nn as nn - -from visengine.dist import master_only -from visengine.logging import MMLogger, print_log - -from .weight_init import PretrainedInit, initialize, update_init_info -from .wrappers.utils import is_model_wrapper - - -class BaseModule(nn.Module, metaclass=ABCMeta): - """Base module for all modules in openmmlab. ``BaseModule`` is a wrapper of - ``torch.nn.Module`` with additional functionality of parameter - initialization. Compared with ``torch.nn.Module``, ``BaseModule`` mainly - adds three attributes. - - - ``init_cfg``: the config to control the initialization. - - ``init_weights``: The function of parameter initialization and recording - initialization information. - - ``_params_init_info``: Used to track the parameter initialization - information. This attribute only exists during executing the - ``init_weights``. - - Note: - :obj:`PretrainedInit` has a higher priority than any other - initializer. The loaded pretrained weights will overwrite - the previous initialized weights. - - Args: - init_cfg (dict or List[dict], optional): Initialization config dict. - """ - - def __init__(self, init_cfg: dict | list[dict] | None = None): - """Initialize BaseModule, inherited from `torch.nn.Module`""" - - # NOTE init_cfg can be defined in different levels, but init_cfg - # in low levels has a higher priority. 
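The divisor-aligned right/bottom padding used in both `ImgDataPreprocessor.forward` branches above, in isolation (a `pad_size_divisor` of 32 is typical for hierarchical backbones such as Swin with FPN):

```python
import math

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 300, 500)  # NCHW batch
divisor, pad_value = 32, 0
target_h = math.ceil(x.shape[2] / divisor) * divisor  # 320
target_w = math.ceil(x.shape[3] / divisor) * divisor  # 512
padded = F.pad(x, (0, target_w - x.shape[3], 0, target_h - x.shape[2]),
               "constant", pad_value)
print(padded.shape)  # torch.Size([2, 3, 320, 512])
```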
- - super().__init__() - # define default value of init_cfg instead of hard code - # in init_weights() function - self._is_init = False - - self.init_cfg = copy.deepcopy(init_cfg) - - # Backward compatibility in derived classes - # if pretrained is not None: - # warnings.warn('DeprecationWarning: pretrained is a deprecated \ - # key, please consider using init_cfg') - # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) - - @property - def is_init(self): - return self._is_init - - @is_init.setter - def is_init(self, value): - self._is_init = value - - def init_weights(self): - """Initialize the weights.""" - - is_top_level_module = False - # check if it is top-level module - if not hasattr(self, "_params_init_info"): - # The `_params_init_info` is used to record the initialization - # information of the parameters - # the key should be the obj:`nn.Parameter` of model and the value - # should be a dict containing - # - init_info (str): The string that describes the initialization. - # - tmp_mean_value (FloatTensor): The mean of the parameter, - # which indicates whether the parameter has been modified. - # this attribute would be deleted after all parameters - # is initialized. - self._params_init_info = defaultdict(dict) - is_top_level_module = True - - # Initialize the `_params_init_info`, - # When detecting the `tmp_mean_value` of - # the corresponding parameter is changed, update related - # initialization information - for _name, param in self.named_parameters(): - self._params_init_info[param]["init_info"] = ( - f"The value is the same before and after calling `init_weights` of {self.__class__.__name__} " - ) - self._params_init_info[param]["tmp_mean_value"] = param.data.mean().cpu() - - # pass `params_init_info` to all submodules - # All submodules share the same `params_init_info`, - # so it will be updated when parameters are - # modified at any level of the model. - for sub_module in self.modules(): - sub_module._params_init_info = self._params_init_info - - module_name = self.__class__.__name__ - if not self._is_init: - if self.init_cfg: - print_log( - f"initialize {module_name} with init_cfg {self.init_cfg}", - logger="current", - level=logging.DEBUG, - ) - - init_cfgs = self.init_cfg - if isinstance(self.init_cfg, dict): - init_cfgs = [self.init_cfg] - - # PretrainedInit has higher priority than any other init_cfg. - # Therefore we initialize `pretrained_cfg` last to overwrite - # the previous initialized weights. 
- # See details in https://github.com/open-mmlab/mmengine/issues/691 - other_cfgs = [] - pretrained_cfg = [] - for init_cfg in init_cfgs: - assert isinstance(init_cfg, dict) - if init_cfg["type"] == "Pretrained" or init_cfg["type"] is PretrainedInit: - pretrained_cfg.append(init_cfg) - else: - other_cfgs.append(init_cfg) - - initialize(self, other_cfgs) - - for m in self.children(): - if is_model_wrapper(m) and not hasattr(m, "init_weights"): - m = m.module - if hasattr(m, "init_weights") and not getattr(m, "is_init", False): - m.init_weights() - # users may overload the `init_weights` - update_init_info( - m, - init_info=f"Initialized by user-defined `init_weights` in {m.__class__.__name__} ", - ) - if self.init_cfg and pretrained_cfg: - initialize(self, pretrained_cfg) - self._is_init = True - else: - print_log( - f"init_weights of {self.__class__.__name__} has been called more than once.", - logger="current", - level=logging.WARNING, - ) - - if is_top_level_module: - self._dump_init_info() - - for sub_module in self.modules(): - del sub_module._params_init_info - - @master_only - def _dump_init_info(self): - """Dump the initialization information to a file named - `initialization.log.json` in workdir.""" - - logger = MMLogger.get_current_instance() - with_file_handler = False - # dump the information to the logger file if there is a `FileHandler` - for handler in logger.handlers: - if isinstance(handler, FileHandler): - handler.stream.write("Name of parameter - Initialization information\n") - for name, param in self.named_parameters(): - handler.stream.write(f"\n{name} - {param.shape}: \n{self._params_init_info[param]['init_info']} \n") - handler.stream.flush() - with_file_handler = True - if not with_file_handler: - for name, param in self.named_parameters(): - logger.info(f"\n{name} - {param.shape}: \n{self._params_init_info[param]['init_info']} \n ") - - def __repr__(self): - s = super().__repr__() - if self.init_cfg: - s += f"\ninit_cfg={self.init_cfg}" - return s - - -class Sequential(BaseModule, nn.Sequential): - """Sequential module in openmmlab. - - Ensures that all modules in ``Sequential`` have a different initialization - strategy than the outer model - - Args: - init_cfg (dict, optional): Initialization config dict. - """ - - def __init__(self, *args, init_cfg: dict | None = None): - BaseModule.__init__(self, init_cfg) - nn.Sequential.__init__(self, *args) - - -class ModuleList(BaseModule, nn.ModuleList): - """ModuleList in openmmlab. - - Ensures that all modules in ``ModuleList`` have a different initialization - strategy than the outer model - - Args: - modules (iterable, optional): An iterable of modules to add. - init_cfg (dict, optional): Initialization config dict. - """ - - def __init__(self, modules: Iterable | None = None, init_cfg: dict | None = None): - BaseModule.__init__(self, init_cfg) - nn.ModuleList.__init__(self, modules) - - -class ModuleDict(BaseModule, nn.ModuleDict): - """ModuleDict in openmmlab. - - Ensures that all modules in ``ModuleDict`` have a different initialization - strategy than the outer model - - Args: - modules (dict, optional): A mapping (dictionary) of (string: module) - or an iterable of key-value pairs of type (string, module). - init_cfg (dict, optional): Initialization config dict. 
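A small sketch of the precedence `init_weights` implements above: layer-wise initializers from `init_cfg` run first, and a `Pretrained` cfg, when present, is applied last so loaded weights win. This assumes the "Kaiming" initializer registration that mmengine uses:

```python
import torch.nn as nn

from visengine.model import BaseModule

class ToyHead(BaseModule):
    def __init__(self):
        # Kaiming-init all Conv2d layers; a dict(type="Pretrained", ...)
        # entry appended to init_cfg would overwrite them afterwards.
        super().__init__(init_cfg=dict(type="Kaiming", layer="Conv2d"))
        self.conv = nn.Conv2d(3, 8, 3)

head = ToyHead()
head.init_weights()  # applies init_cfg and records per-parameter init info
```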
- """ - - def __init__(self, modules: dict | None = None, init_cfg: dict | None = None): - BaseModule.__init__(self, init_cfg) - nn.ModuleDict.__init__(self, modules) diff --git a/libs/visengine/visengine/model/efficient_conv_bn_eval.py b/libs/visengine/visengine/model/efficient_conv_bn_eval.py deleted file mode 100644 index 821ef29..0000000 --- a/libs/visengine/visengine/model/efficient_conv_bn_eval.py +++ /dev/null @@ -1,149 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from operator import attrgetter - -import torch -import torch.nn as nn - - -def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor): - """Code borrowed from mmcv 2.0.1, so that this feature can be used for old - mmcv versions. - - Implementation based on https://arxiv.org/abs/2305.11624 - "Tune-Mode ConvBN Blocks For Efficient Transfer Learning" - It leverages the associative law between convolution and affine transform, - i.e., normalize (weight conv feature) = (normalize weight) conv feature. - It works for Eval mode of ConvBN blocks during validation, and can be used - for training as well. It reduces memory and computation cost. - Args: - bn (_BatchNorm): a BatchNorm module. - conv (nn._ConvNd): a conv module - x (torch.Tensor): Input feature map. - """ - # These lines of code are designed to deal with various cases - # like bn without affine transform, and conv without bias - weight_on_the_fly = conv.weight - if conv.bias is not None: - bias_on_the_fly = conv.bias - else: - bias_on_the_fly = torch.zeros_like(bn.running_var) - - if bn.weight is not None: - bn_weight = bn.weight - else: - bn_weight = torch.ones_like(bn.running_var) - - if bn.bias is not None: - bn_bias = bn.bias - else: - bn_bias = torch.zeros_like(bn.running_var) - - # shape of [C_out, 1, 1, 1] in Conv2d - weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape([-1] + [1] * (len(conv.weight.shape) - 1)) - # shape of [C_out, 1, 1, 1] in Conv2d - coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff - - # shape of [C_out, C_in, k, k] in Conv2d - weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly - # shape of [C_out] in Conv2d - bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * (bias_on_the_fly - bn.running_mean) - - return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly) - - -def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor): - """This function controls whether to use `efficient_conv_bn_eval_forward`. - - If the following `bn` is in `eval` mode, then we turn on the special - `efficient_conv_bn_eval_forward`. - """ - if not bn.training: - # bn in eval mode - output = efficient_conv_bn_eval_forward(bn, conv, x) - return output - else: - conv_out = conv._conv_forward(x, conv.weight, conv.bias) - return bn(conv_out) - - -def efficient_conv_bn_eval_graph_transform(fx_model): - """Find consecutive conv+bn calls in the graph, inplace modify the graph - with the fused operation.""" - modules = dict(fx_model.named_modules()) - - patterns = [(torch.nn.modules.conv._ConvNd, torch.nn.modules.batchnorm._BatchNorm)] - - pairs = [] - # Iterate through nodes in the graph to find ConvBN blocks - for node in fx_model.graph.nodes: - # If our current node isn't calling a Module then we can ignore it. 
- if node.op != "call_module": - continue - target_module = modules[node.target] - found_pair = False - for conv_class, bn_class in patterns: - if isinstance(target_module, bn_class): - source_module = modules[node.args[0].target] - if isinstance(source_module, conv_class): - found_pair = True - # Not a conv-BN pattern or output of conv is used by other nodes - if not found_pair or len(node.args[0].users) > 1: - continue - - # Find a pair of conv and bn computation nodes to optimize - conv_node = node.args[0] - bn_node = node - pairs.append([conv_node, bn_node]) - - for conv_node, bn_node in pairs: - # set insertion point - fx_model.graph.inserting_before(conv_node) - # create `get_attr` node to access modules - # note that we directly call `create_node` to fill the `name` - # argument. `fx_model.graph.get_attr` and - # `fx_model.graph.call_function` does not allow the `name` argument. - conv_get_node = fx_model.graph.create_node(op="get_attr", target=conv_node.target, name="get_conv") - bn_get_node = fx_model.graph.create_node(op="get_attr", target=bn_node.target, name="get_bn") - # prepare args for the fused function - args = (bn_get_node, conv_get_node, conv_node.args[0]) - # create a new node - new_node = fx_model.graph.create_node( - op="call_function", - target=efficient_conv_bn_eval_control, - args=args, - name="efficient_conv_bn_eval", - ) - # this node replaces the original conv + bn, and therefore - # should replace the uses of bn_node - bn_node.replace_all_uses_with(new_node) - # take care of the deletion order: - # delete bn_node first, and then conv_node - fx_model.graph.erase_node(bn_node) - fx_model.graph.erase_node(conv_node) - - # regenerate the code - fx_model.graph.lint() - fx_model.recompile() - - -def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module): - import torch.fx as fx - - # currently we use `fx.symbolic_trace` to trace models. - # in the future, we might turn to pytorch 2.0 compile infrastructure to - # get the `fx.GraphModule` IR. Nonetheless, the graph transform function - # can remain unchanged. We just need to change the way - # we get `fx.GraphModule`. - fx_model: fx.GraphModule = fx.symbolic_trace(model) - efficient_conv_bn_eval_graph_transform(fx_model) - model.forward = fx_model.forward - - -def turn_on_efficient_conv_bn_eval(model: torch.nn.Module, modules: list[str] | str): - if isinstance(modules, str): - modules = [modules] - for module_name in modules: - module = attrgetter(module_name)(model) - turn_on_efficient_conv_bn_eval_for_single_model(module) diff --git a/libs/visengine/visengine/model/utils.py b/libs/visengine/visengine/model/utils.py deleted file mode 100644 index 8757b6a..0000000 --- a/libs/visengine/visengine/model/utils.py +++ /dev/null @@ -1,257 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import warnings - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from visengine.logging import print_log - -# Use torch's SyncBatchNorm directly -SyncBatchNorm = nn.SyncBatchNorm - - -def stack_batch( - tensor_list: list[torch.Tensor], - pad_size_divisor: int = 1, - pad_value: int | float = 0, -) -> torch.Tensor: - """Stack multiple tensors to form a batch and pad the tensor to the max - shape use the right bottom padding mode in these images. If - ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is - divisible by ``pad_size_divisor``. - - Args: - tensor_list (List[Tensor]): A list of tensors with the same dim. 
- pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding - to ensure the shape of each dim is divisible by - ``pad_size_divisor``. This depends on the model, and many - models need to be divisible by 32. Defaults to 1 - pad_value (int, float): The padding value. Defaults to 0. - - Returns: - Tensor: The n dim tensor. - """ - assert isinstance(tensor_list, list), f"Expected input type to be list, but got {type(tensor_list)}" - assert tensor_list, "`tensor_list` could not be an empty list" - assert len({tensor.ndim for tensor in tensor_list}) == 1, ( - f"Expected the dimensions of all tensors must be the same, but got {[tensor.ndim for tensor in tensor_list]}" - ) - - dim = tensor_list[0].dim() - num_img = len(tensor_list) - all_sizes: torch.Tensor = torch.Tensor([tensor.shape for tensor in tensor_list]) - max_sizes = torch.ceil(torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor - padded_sizes = max_sizes - all_sizes - # The first dim normally means channel, which should not be padded. - padded_sizes[:, 0] = 0 - if padded_sizes.sum() == 0: - return torch.stack(tensor_list) - # `pad` is the second arguments of `F.pad`. If pad is (1, 2, 3, 4), - # it means that padding the last dim with 1(left) 2(right), padding the - # penultimate dim to 3(top) 4(bottom). The order of `pad` is opposite of - # the `padded_sizes`. Therefore, the `padded_sizes` needs to be reversed, - # and only odd index of pad should be assigned to keep padding "right" and - # "bottom". - pad = torch.zeros(num_img, 2 * dim, dtype=torch.int) - pad[:, 1::2] = padded_sizes[:, range(dim - 1, -1, -1)] - batch_tensor = [] - for idx, tensor in enumerate(tensor_list): - batch_tensor.append(F.pad(tensor, tuple(pad[idx].tolist()), value=pad_value)) - return torch.stack(batch_tensor) - - -def detect_anomalous_params(loss: torch.Tensor, model) -> None: - parameters_in_graph = set() - visited = set() - - def traverse(grad_fn): - if grad_fn is None: - return - if grad_fn not in visited: - visited.add(grad_fn) - if hasattr(grad_fn, "variable"): - parameters_in_graph.add(grad_fn.variable) - parents = grad_fn.next_functions - if parents is not None: - for parent in parents: - grad_fn = parent[0] - traverse(grad_fn) - - traverse(loss.grad_fn) - for n, p in model.named_parameters(): - if p not in parameters_in_graph and p.requires_grad: - print_log( - f"{n} with shape {p.size()} is not in the computational graph \n", - logger="current", - level=logging.ERROR, - ) - - -def merge_dict(*args): - """Merge all dictionaries into one dictionary. - - If pytorch version >= 1.8, ``merge_dict`` will be wrapped - by ``torch.fx.wrap``, which will make ``torch.fx.symbolic_trace`` skip - trace ``merge_dict``. - - Note: - If a function needs to be traced by ``torch.fx.symbolic_trace``, - but inevitably needs to use ``update`` method of ``dict``(``update`` - is not traceable). It should use ``merge_dict`` to replace - ``xxx.update``. - - Args: - *args: dictionary needs to be merged. - - Returns: - dict: Merged dict from args - """ - output = {} - for item in args: - assert isinstance(item, dict), f"all arguments of merge_dict should be a dict, but got {type(item)}" - output.update(item) - return output - - -# torch.fx is only available when pytorch version >= 1.8. -# If the subclass of `BaseModel` has multiple submodules, and each module -# will return a loss dict during training process, i.e., `TwoStageDetector` -# in mmdet. It should use `merge_dict` to get the total loss, rather than -# `loss.update` to keep model traceable. 
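Concretely, `stack_batch` above pads each tensor at the right/bottom up to the divisor-aligned maximum shape and then stacks; a sketch using the export from `visengine.model`:

```python
import torch

from visengine.model import stack_batch

imgs = [torch.randn(3, 300, 400), torch.randn(3, 280, 420)]
batch = stack_batch(imgs, pad_size_divisor=32, pad_value=0)
print(batch.shape)  # torch.Size([2, 3, 320, 448]); channels are never padded
```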
-try: - import torch.fx - - # make torch.fx skip trace `merge_dict`. - merge_dict = torch.fx.wrap(merge_dict) - -except ImportError: - warnings.warn( - "Cannot import torch.fx, `merge_dict` is a simple function to merge multiple dicts", - stacklevel=2, - ) - - -class _BatchNormXd(nn.modules.batchnorm._BatchNorm): - """A general BatchNorm layer without input dimension check. - - Reproduced from @kapily's work: - (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) - The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc - is `_check_input_dim` that is designed for tensor sanity checks. - The check has been bypassed in this class for the convenience of converting - SyncBatchNorm. - """ - - def _check_input_dim(self, input: torch.Tensor): - return - - -def revert_sync_batchnorm(module: nn.Module) -> nn.Module: - """Helper function to convert all `SyncBatchNorm` (SyncBN) and - `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to - `BatchNormXd` layers. - - Adapted from @kapily's work: - (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) - - Args: - module (nn.Module): The module containing `SyncBatchNorm` layers. - - Returns: - module_output: The converted module with `BatchNormXd` layers. - """ - module_output = module - module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] - module_checklist.append(SyncBatchNorm) - - if isinstance(module, tuple(module_checklist)): - module_output = _BatchNormXd( - module.num_features, - module.eps, - module.momentum, - module.affine, - module.track_running_stats, - ) - if module.affine: - # no_grad() may not be needed here but - # just to be consistent with `convert_sync_batchnorm()` - with torch.no_grad(): - module_output.weight = module.weight - module_output.bias = module.bias - module_output.running_mean = module.running_mean - module_output.running_var = module.running_var - module_output.num_batches_tracked = module.num_batches_tracked - module_output.training = module.training - # qconfig exists in quantized models - if hasattr(module, "qconfig"): - module_output.qconfig = module.qconfig - for name, child in module.named_children(): - # Some custom modules or 3rd party implemented modules may raise an - # error when calling `add_module`. Therefore, try to catch the error - # and do not raise it. See https://github.com/open-mmlab/mmengine/issues/638 - # for more details. - try: - module_output.add_module(name, revert_sync_batchnorm(child)) - except Exception: - print_log( - f"Failed to convert {child} from SyncBN to BN!", - logger="current", - level=logging.WARNING, - ) - del module - return module_output - - -def convert_sync_batchnorm(module: nn.Module, implementation="torch") -> nn.Module: - """Helper function to convert all `BatchNorm` layers in the model to - `SyncBatchNorm` (SyncBN) or `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) - layers. Adapted from `PyTorch convert sync batchnorm`_. - - Args: - module (nn.Module): The module containing `SyncBatchNorm` layers. - implementation (str): The type of `SyncBatchNorm` to convert to. - - - 'torch': convert to `torch.nn.modules.batchnorm.SyncBatchNorm`. - - 'mmcv': convert to `mmcv.ops.sync_bn.SyncBatchNorm`. - - Returns: - nn.Module: The converted module with `SyncBatchNorm` layers. - - .. 
_PyTorch convert sync batchnorm: - https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm - """ - module_output = module - - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): - if implementation == "torch": - SyncBatchNorm = torch.nn.modules.batchnorm.SyncBatchNorm - elif implementation == "mmcv": - from mmcv.ops import SyncBatchNorm # type: ignore - else: - raise ValueError(f'sync_bn should be "torch" or "mmcv", but got {implementation}') - - module_output = SyncBatchNorm( - module.num_features, - module.eps, - module.momentum, - module.affine, - module.track_running_stats, - ) - - if module.affine: - with torch.no_grad(): - module_output.weight = module.weight - module_output.bias = module.bias - module_output.running_mean = module.running_mean - module_output.running_var = module.running_var - module_output.num_batches_tracked = module.num_batches_tracked - if hasattr(module, "qconfig"): - module_output.qconfig = module.qconfig - for name, child in module.named_children(): - module_output.add_module(name, convert_sync_batchnorm(child, implementation)) - del module - return module_output diff --git a/libs/visengine/visengine/model/weight_init.py b/libs/visengine/visengine/model/weight_init.py deleted file mode 100644 index 87219f3..0000000 --- a/libs/visengine/visengine/model/weight_init.py +++ /dev/null @@ -1,670 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import math -import warnings - -import numpy as np -import torch -import torch.nn as nn -from torch import Tensor - -from visengine.logging import print_log -from visengine.registry import WEIGHT_INITIALIZERS, build_from_cfg - - -def update_init_info(module, init_info): - """Update the `_params_init_info` in the module if the value of parameters - are changed. - - Args: - module (obj:`nn.Module`): The module of PyTorch with a user-defined - attribute `_params_init_info` which records the initialization - information. - init_info (str): The string that describes the initialization. - """ - assert hasattr(module, "_params_init_info"), f"Can not find `_params_init_info` in {module}" - for name, param in module.named_parameters(): - assert param in module._params_init_info, ( - f"Find a new :obj:`Parameter` " - f"named `{name}` during executing the " - f"`init_weights` of " - f"`{module.__class__.__name__}`. " - f"Please do not add or " - f"replace parameters during executing " - f"the `init_weights`. 
" - ) - - # The parameter has been changed during executing the - # `init_weights` of module - mean_value = param.data.mean().cpu() - if module._params_init_info[param]["tmp_mean_value"] != mean_value: - module._params_init_info[param]["init_info"] = init_info - module._params_init_info[param]["tmp_mean_value"] = mean_value - - -def constant_init(module, val, bias=0): - if hasattr(module, "weight") and module.weight is not None: - nn.init.constant_(module.weight, val) - if hasattr(module, "bias") and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def xavier_init(module, gain=1, bias=0, distribution="normal"): - assert distribution in ["uniform", "normal"] - if hasattr(module, "weight") and module.weight is not None: - if distribution == "uniform": - nn.init.xavier_uniform_(module.weight, gain=gain) - else: - nn.init.xavier_normal_(module.weight, gain=gain) - if hasattr(module, "bias") and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def normal_init(module, mean=0, std=1, bias=0): - if hasattr(module, "weight") and module.weight is not None: - nn.init.normal_(module.weight, mean, std) - if hasattr(module, "bias") and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def trunc_normal_init( - module: nn.Module, - mean: float = 0, - std: float = 1, - a: float = -2, - b: float = 2, - bias: float = 0, -) -> None: - if hasattr(module, "weight") and module.weight is not None: - trunc_normal_(module.weight, mean, std, a, b) # type: ignore - if hasattr(module, "bias") and module.bias is not None: - nn.init.constant_(module.bias, bias) # type: ignore - - -def uniform_init(module, a=0, b=1, bias=0): - if hasattr(module, "weight") and module.weight is not None: - nn.init.uniform_(module.weight, a, b) - if hasattr(module, "bias") and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def kaiming_init(module, a=0, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"): - assert distribution in ["uniform", "normal"] - if hasattr(module, "weight") and module.weight is not None: - if distribution == "uniform": - nn.init.kaiming_uniform_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity) - else: - nn.init.kaiming_normal_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity) - if hasattr(module, "bias") and module.bias is not None: - nn.init.constant_(module.bias, bias) - - -def caffe2_xavier_init(module, bias=0): - # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch - # Acknowledgment to FAIR's internal code - kaiming_init( - module, - a=1, - mode="fan_in", - nonlinearity="leaky_relu", - bias=bias, - distribution="uniform", - ) - - -def bias_init_with_prob(prior_prob): - """Initialize conv/fc bias value according to a given probability value.""" - bias_init = float(-np.log((1 - prior_prob) / prior_prob)) - return bias_init - - -def _get_bases_name(m): - return [b.__name__ for b in m.__class__.__bases__] - - -class BaseInit: - def __init__(self, *, bias=0, bias_prob=None, layer=None): - self.wholemodule = False - if not isinstance(bias, int | float): - raise TypeError(f"bias must be a number, but got a {type(bias)}") - - if bias_prob is not None: - if not isinstance(bias_prob, float): - raise TypeError( - f"bias_prob type must be float, \ - but got {type(bias_prob)}" - ) - - if layer is not None: - if not isinstance(layer, str | list): - raise TypeError( - f"layer must be a str or a list of str, \ - but got a {type(layer)}" - ) - else: - layer = [] - - if bias_prob is not None: - 
self.bias = bias_init_with_prob(bias_prob) - else: - self.bias = bias - self.layer = [layer] if isinstance(layer, str) else layer - - def _get_init_info(self): - info = f"{self.__class__.__name__}, bias={self.bias}" - return info - - -@WEIGHT_INITIALIZERS.register_module(name="Constant") -class ConstantInit(BaseInit): - """Initialize module parameters with constant values. - - Args: - val (int | float): the value to fill the weights in the module with - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. - """ - - def __init__(self, val, **kwargs): - super().__init__(**kwargs) - self.val = val - - def __call__(self, module): - def init(m): - if self.wholemodule: - constant_init(m, self.val, self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & {layername, *basesname}): - constant_init(m, self.val, self.bias) - - module.apply(init) - if hasattr(module, "_params_init_info"): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f"{self.__class__.__name__}: val={self.val}, bias={self.bias}" - return info - - -@WEIGHT_INITIALIZERS.register_module(name="Xavier") -class XavierInit(BaseInit): - r"""Initialize module parameters with values according to the method - described in the paper below. - - `Understanding the difficulty of training deep feedforward - neural networks - Glorot, X. & Bengio, Y. (2010). - `_ - - Args: - gain (int | float): an optional scaling factor. Defaults to 1. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - distribution (str): distribution either be ``'normal'`` - or ``'uniform'``. Defaults to ``'normal'``. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. - """ - - def __init__(self, gain=1, distribution="normal", **kwargs): - super().__init__(**kwargs) - self.gain = gain - self.distribution = distribution - - def __call__(self, module): - def init(m): - if self.wholemodule: - xavier_init(m, self.gain, self.bias, self.distribution) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & {layername, *basesname}): - xavier_init(m, self.gain, self.bias, self.distribution) - - module.apply(init) - if hasattr(module, "_params_init_info"): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f"{self.__class__.__name__}: gain={self.gain}, distribution={self.distribution}, bias={self.bias}" - return info - - -@WEIGHT_INITIALIZERS.register_module(name="Normal") -class NormalInit(BaseInit): - r"""Initialize module parameters with the values drawn from the normal - distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. - - Args: - mean (int | float):the mean of the normal distribution. Defaults to 0. - std (int | float): the standard deviation of the normal distribution. - Defaults to 1. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. 
- """ - - def __init__(self, mean=0, std=1, **kwargs): - super().__init__(**kwargs) - self.mean = mean - self.std = std - - def __call__(self, module): - def init(m): - if self.wholemodule: - normal_init(m, self.mean, self.std, self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & {layername, *basesname}): - normal_init(m, self.mean, self.std, self.bias) - - module.apply(init) - if hasattr(module, "_params_init_info"): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f"{self.__class__.__name__}: mean={self.mean}, std={self.std}, bias={self.bias}" - return info - - -@WEIGHT_INITIALIZERS.register_module(name="TruncNormal") -class TruncNormalInit(BaseInit): - r"""Initialize module parameters with the values drawn from the normal - distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values - outside :math:`[a, b]`. - - Args: - mean (float): the mean of the normal distribution. Defaults to 0. - std (float): the standard deviation of the normal distribution. - Defaults to 1. - a (float): The minimum cutoff value. - b ( float): The maximum cutoff value. - bias (float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. - """ - - def __init__(self, mean: float = 0, std: float = 1, a: float = -2, b: float = 2, **kwargs) -> None: - super().__init__(**kwargs) - self.mean = mean - self.std = std - self.a = a - self.b = b - - def __call__(self, module: nn.Module) -> None: - def init(m): - if self.wholemodule: - trunc_normal_init(m, self.mean, self.std, self.a, self.b, self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & {layername, *basesname}): - trunc_normal_init(m, self.mean, self.std, self.a, self.b, self.bias) - - module.apply(init) - if hasattr(module, "_params_init_info"): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f"{self.__class__.__name__}: a={self.a}, b={self.b}, mean={self.mean}, std={self.std}, bias={self.bias}" - return info - - -@WEIGHT_INITIALIZERS.register_module(name="Uniform") -class UniformInit(BaseInit): - r"""Initialize module parameters with values drawn from the uniform - distribution :math:`\mathcal{U}(a, b)`. - - Args: - a (int | float): the lower bound of the uniform distribution. - Defaults to 0. - b (int | float): the upper bound of the uniform distribution. - Defaults to 1. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. 
- """ - - def __init__(self, a=0, b=1, **kwargs): - super().__init__(**kwargs) - self.a = a - self.b = b - - def __call__(self, module): - def init(m): - if self.wholemodule: - uniform_init(m, self.a, self.b, self.bias) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & {layername, *basesname}): - uniform_init(m, self.a, self.b, self.bias) - - module.apply(init) - if hasattr(module, "_params_init_info"): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f"{self.__class__.__name__}: a={self.a}, b={self.b}, bias={self.bias}" - return info - - -@WEIGHT_INITIALIZERS.register_module(name="Kaiming") -class KaimingInit(BaseInit): - r"""Initialize module parameters with the values according to the method - described in the paper below. - - `Delving deep into rectifiers: Surpassing human-level - performance on ImageNet classification - He, K. et al. (2015). - `_ - - Args: - a (int | float): the negative slope of the rectifier used after this - layer (only used with ``'leaky_relu'``). Defaults to 0. - mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing - ``'fan_in'`` preserves the magnitude of the variance of the weights - in the forward pass. Choosing ``'fan_out'`` preserves the - magnitudes in the backwards pass. Defaults to ``'fan_out'``. - nonlinearity (str): the non-linear function (`nn.functional` name), - recommended to use only with ``'relu'`` or ``'leaky_relu'`` . - Defaults to 'relu'. - bias (int | float): the value to fill the bias. Defaults to 0. - bias_prob (float, optional): the probability for bias initialization. - Defaults to None. - distribution (str): distribution either be ``'normal'`` or - ``'uniform'``. Defaults to ``'normal'``. - layer (str | list[str], optional): the layer will be initialized. - Defaults to None. - """ - - def __init__(self, a=0, mode="fan_out", nonlinearity="relu", distribution="normal", **kwargs): - super().__init__(**kwargs) - self.a = a - self.mode = mode - self.nonlinearity = nonlinearity - self.distribution = distribution - - def __call__(self, module): - def init(m): - if self.wholemodule: - kaiming_init( - m, - self.a, - self.mode, - self.nonlinearity, - self.bias, - self.distribution, - ) - else: - layername = m.__class__.__name__ - basesname = _get_bases_name(m) - if len(set(self.layer) & {layername, *basesname}): - kaiming_init( - m, - self.a, - self.mode, - self.nonlinearity, - self.bias, - self.distribution, - ) - - module.apply(init) - if hasattr(module, "_params_init_info"): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = ( - f"{self.__class__.__name__}: a={self.a}, mode={self.mode}, " - f"nonlinearity={self.nonlinearity}, " - f"distribution ={self.distribution}, bias={self.bias}" - ) - return info - - -@WEIGHT_INITIALIZERS.register_module(name="Caffe2Xavier") -class Caffe2XavierInit(KaimingInit): - # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch - # Acknowledgment to FAIR's internal code - def __init__(self, **kwargs): - super().__init__( - a=1, - mode="fan_in", - nonlinearity="leaky_relu", - distribution="uniform", - **kwargs, - ) - - def __call__(self, module): - super().__call__(module) - - -@WEIGHT_INITIALIZERS.register_module(name="Pretrained") -class PretrainedInit: - """Initialize module by loading a pretrained model. - - Args: - checkpoint (str): the checkpoint file of the pretrained model should - be load. 
- prefix (str, optional): the prefix of a sub-module in the pretrained - model. it is for loading a part of the pretrained model to - initialize. For example, if we would like to only load the - backbone of a detector model, we can set ``prefix='backbone.'``. - Defaults to None. - map_location (str): map tensors into proper locations. Defaults to cpu. - """ - - def __init__(self, checkpoint, prefix=None, map_location="cpu"): - self.checkpoint = checkpoint - self.prefix = prefix - self.map_location = map_location - - def __call__(self, module): - from visengine.runner.checkpoint import ( - _load_checkpoint_with_prefix, - load_checkpoint, - load_state_dict, - ) - - if self.prefix is None: - print_log(f"load model from: {self.checkpoint}", logger="current") - load_checkpoint( - module, - self.checkpoint, - map_location=self.map_location, - strict=False, - logger="current", - ) - else: - print_log(f"load {self.prefix} in model from: {self.checkpoint}", logger="current") - state_dict = _load_checkpoint_with_prefix(self.prefix, self.checkpoint, map_location=self.map_location) - load_state_dict(module, state_dict, strict=False, logger="current") - - if hasattr(module, "_params_init_info"): - update_init_info(module, init_info=self._get_init_info()) - - def _get_init_info(self): - info = f"{self.__class__.__name__}: load from {self.checkpoint}" - return info - - -def _initialize(module, cfg, wholemodule=False): - func = build_from_cfg(cfg, WEIGHT_INITIALIZERS) - # wholemodule flag is for override mode, there is no layer key in override - # and initializer will give init values for the whole module with the name - # in override. - func.wholemodule = wholemodule - func(module) - - -def _initialize_override(module, override, cfg): - if not isinstance(override, dict | list): - raise TypeError( - f"override must be a dict or a list of dict, \ - but got {type(override)}" - ) - - override = [override] if isinstance(override, dict) else override - - for override_ in override: - cp_override = copy.deepcopy(override_) - name = cp_override.pop("name", None) - if name is None: - raise ValueError(f'`override` must contain the key "name",but got {cp_override}') - # if override only has name key, it means use args in init_cfg - if not cp_override: - cp_override.update(cfg) - # if override has name key and other args except type key, it will - # raise error - elif "type" not in cp_override.keys(): - raise ValueError(f'`override` need "type" key, but got {cp_override}') - - if hasattr(module, name): - _initialize(getattr(module, name), cp_override, wholemodule=True) - else: - raise RuntimeError(f"module did not have attribute {name}, but init_cfg is {cp_override}.") - - -def initialize(module, init_cfg): - r"""Initialize a module. - - Args: - module (``torch.nn.Module``): the module will be initialized. - init_cfg (dict | list[dict]): initialization configuration dict to - define initializer. OpenMMLab has implemented 6 initializers - including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``, - ``Kaiming``, and ``Pretrained``. 
- - Example: - >>> module = nn.Linear(2, 3, bias=True) - >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2) - >>> initialize(module, init_cfg) - >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2)) - >>> # define key ``'layer'`` for initializing layer with different - >>> # configuration - >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1), - dict(type='Constant', layer='Linear', val=2)] - >>> initialize(module, init_cfg) - >>> # define key``'override'`` to initialize some specific part in - >>> # module - >>> class FooNet(nn.Module): - >>> def __init__(self): - >>> super().__init__() - >>> self.feat = nn.Conv2d(3, 16, 3) - >>> self.reg = nn.Conv2d(16, 10, 3) - >>> self.cls = nn.Conv2d(16, 5, 3) - >>> model = FooNet() - >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d', - >>> override=dict(type='Constant', name='reg', val=3, bias=4)) - >>> initialize(model, init_cfg) - >>> model = ResNet(depth=50) - >>> # Initialize weights with the pretrained model. - >>> init_cfg = dict(type='Pretrained', - checkpoint='torchvision://resnet50') - >>> initialize(model, init_cfg) - >>> # Initialize weights of a sub-module with the specific part of - >>> # a pretrained model by using "prefix". - >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\ - >>> 'retinanet_r50_fpn_1x_coco/'\ - >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' - >>> init_cfg = dict(type='Pretrained', - checkpoint=url, prefix='backbone.') - """ - if not isinstance(init_cfg, dict | list): - raise TypeError( - f"init_cfg must be a dict or a list of dict, \ - but got {type(init_cfg)}" - ) - - if isinstance(init_cfg, dict): - init_cfg = [init_cfg] - - for cfg in init_cfg: - # should deeply copy the original config because cfg may be used by - # other modules, e.g., one init_cfg shared by multiple bottleneck - # blocks, the expected cfg will be changed after pop and will change - # the initialization behavior of other modules - cp_cfg = copy.deepcopy(cfg) - override = cp_cfg.pop("override", None) - _initialize(module, cp_cfg) - - if override is not None: - cp_cfg.pop("layer", None) - _initialize_override(module, override, cp_cfg) - else: - # All attributes in module have same initialization. - pass - - -def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, b: float) -> Tensor: - # Method based on - # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - # Modified from - # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.", - stacklevel=2, - ) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - lower = norm_cdf((a - mean) / std) - upper = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [lower, upper], then translate - # to [2lower-1, 2upper-1]. 
- tensor.uniform_(2 * lower - 1, 2 * upper - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> Tensor: - r"""Fills the input Tensor with values drawn from a truncated normal - distribution. The values are effectively drawn from the normal distribution - :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside - :math:`[a, b]` redrawn until they are within the bounds. The method used - for generating the random values works best when :math:`a \leq \text{mean} - \leq b`. - - Modified from - https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py - - Args: - tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`. - mean (float): the mean of the normal distribution. - std (float): the standard deviation of the normal distribution. - a (float): the minimum cutoff value. - b (float): the maximum cutoff value. - """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/libs/visengine/visengine/model/wrappers/__init__.py b/libs/visengine/visengine/model/wrappers/__init__.py deleted file mode 100644 index 2c59803..0000000 --- a/libs/visengine/visengine/model/wrappers/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from visengine.utils.dl_utils import TORCH_VERSION -from visengine.utils.version_utils import digit_version -from .distributed import MMDistributedDataParallel -from .seperate_distributed import MMSeparateDistributedDataParallel -from .utils import is_model_wrapper - -__all__ = [ - "MMDistributedDataParallel", - "MMSeparateDistributedDataParallel", - "is_model_wrapper", -] - -from .fully_sharded_distributed import MMFullyShardedDataParallel - -__all__.append("MMFullyShardedDataParallel") diff --git a/libs/visengine/visengine/model/wrappers/distributed.py b/libs/visengine/visengine/model/wrappers/distributed.py deleted file mode 100644 index 49f40e4..0000000 --- a/libs/visengine/visengine/model/wrappers/distributed.py +++ /dev/null @@ -1,169 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import torch -from torch.nn.parallel import DataParallel, DistributedDataParallel - -from visengine.registry import MODEL_WRAPPERS - -from ..utils import detect_anomalous_params - -if TYPE_CHECKING: - from visengine.optim import OptimWrapper - -MODEL_WRAPPERS.register_module(module=DistributedDataParallel) -MODEL_WRAPPERS.register_module(module=DataParallel) - - -@MODEL_WRAPPERS.register_module(force=True) -class MMDistributedDataParallel(DistributedDataParallel): - """A distributed model wrapper used for training,testing and validation in - loop. - - Different from DistributedDataParallel, MMDistributedDataParallel - implements three methods :meth:`train_step`, :meth:`val_step` and - :meth:`test_step`, which will be called by ``train_loop``, ``val_loop`` - and ``test_loop``. - - - ``train_step``: Called by ``runner.train_loop``, and implement - default model forward, gradient back propagation, parameter updating - logic. 
-      To take advantage of DistributedDataParallel's automatic gradient
-      synchronization, ``train_step`` calls ``DistributedDataParallel.forward``
-      to calculate the losses, and calls other methods of :class:`BaseModel`
-      to pre-process data and parse losses. Finally, it updates model
-      parameters with :class:`OptimWrapper` and returns the loss dictionary
-      used for logging.
-
-    - ``val_step``: Called by ``runner.val_loop`` and gets the inference
-      results. Since there is no gradient synchronization requirement,
-      this procedure is equivalent to ``BaseModel.val_step``.
-
-    - ``test_step``: Called by ``runner.test_loop``; equivalent to
-      ``val_step``.
-
-    Args:
-        detect_anomalous_params (bool): This option is only used for
-            debugging and will slow down the training speed.
-            Detect anomalous parameters that are not included in
-            the computational graph with `loss` as the root.
-            There are two cases:
-
-                - Parameters were not used during the forward pass.
-                - Parameters were not used to produce the loss.
-
-            Defaults to False.
-
-        **kwargs: keyword arguments passed to ``DistributedDataParallel``.
-
-            - device_ids (List[int] or torch.device, optional): CUDA devices
-              for module.
-            - output_device (int or torch.device, optional): Device location of
-              output for single-device CUDA modules.
-            - dim (int): Defaults to 0.
-            - broadcast_buffers (bool): Flag that enables syncing
-              (broadcasting) buffers of the module at the beginning of the
-              ``forward`` function. Defaults to True.
-            - find_unused_parameters (bool): Whether to find parameters of
-              the module which are not in the forward graph. Defaults to
-              False.
-            - process_group (ProcessGroup, optional): The process group to be
-              used for distributed data all-reduction.
-            - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults
-              to 25.
-            - check_reduction (bool): This argument is deprecated. Defaults
-              to False.
-            - gradient_as_bucket_view (bool): Defaults to False.
-            - static_graph (bool): Defaults to False.
-
-    See more information about arguments in
-    :class:`torch.nn.parallel.DistributedDataParallel`.
-
-    Note:
-        If the model has multiple submodules and each module has
-        a separate optimization strategy,
-        :class:`MMSeparateDistributedDataParallel` should be used to wrap
-        the model.
-
-    Note:
-        If the model itself has a custom optimization strategy, rather than
-        simply running the forward pass and updating the model, a custom
-        model wrapper inheriting from ``MMDistributedDataParallel`` should be
-        defined and override the ``train_step`` method.
-    """
-
-    def __init__(self, module, detect_anomalous_params: bool = False, **kwargs):
-        super().__init__(module=module, **kwargs)
-        self.detect_anomalous_params = detect_anomalous_params
-
-    def train_step(self, data: dict | tuple | list, optim_wrapper) -> dict[str, torch.Tensor]:
-        """Interface for model forward, backward and parameters updating during
-        training process.
-
-        :meth:`train_step` will perform the following steps in order:
-
-        - If :attr:`module` defines the preprocess method,
-          call ``module.preprocess`` to pre-process data.
-        - Call ``module.forward(**data)`` and get losses.
-        - Parse losses.
-        - Call ``optim_wrapper.optimizer_step`` to update parameters.
-        - Return log messages of losses.
-
-        Args:
-            data (dict or tuple or list): Data sampled from dataset.
-            optim_wrapper (OptimWrapper): A wrapper of optimizer to
-                update parameters.
-
-        Returns:
-            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
-        """
-        # Enable automatic mixed precision training context.
- with optim_wrapper.optim_context(self): - data = self.module.data_preprocessor(data, training=True) - losses = self._run_forward(data, mode="loss") - parsed_loss, log_vars = self.module.parse_losses(losses) - optim_wrapper.update_params(parsed_loss) - if self.detect_anomalous_params: - detect_anomalous_params(parsed_loss, model=self) - return log_vars - - def val_step(self, data: dict | tuple | list) -> list: - """Gets the prediction of module during validation process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - return self.module.val_step(data) - - def test_step(self, data: dict | tuple | list) -> list: - """Gets the predictions of module during testing process. - - Args: - data (dict or tuple or list): Data sampled from dataset. - - Returns: - list: The predictions of given data. - """ - return self.module.test_step(data) - - def _run_forward(self, data: dict | tuple | list, mode: str) -> Any: - """Unpacks data for :meth:`forward` - - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - - Returns: - dict or list: Results of training or testing mode. - """ - if isinstance(data, dict): - results = self(**data, mode=mode) - elif isinstance(data, list | tuple): - results = self(*data, mode=mode) - else: - raise TypeError(f"Output of `data_preprocessor` should be list, tuple or dict, but got {type(data)}") - return results diff --git a/libs/visengine/visengine/model/wrappers/fully_sharded_distributed.py b/libs/visengine/visengine/model/wrappers/fully_sharded_distributed.py deleted file mode 100644 index c61d258..0000000 --- a/libs/visengine/visengine/model/wrappers/fully_sharded_distributed.py +++ /dev/null @@ -1,436 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from __future__ import annotations - -from collections.abc import Callable, Iterable -from functools import partial -from typing import TYPE_CHECKING, Any - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup - -# yapf: disable -from torch.distributed.fsdp.api import ( - FullStateDictConfig, - LocalOptimStateDictConfig, - LocalStateDictConfig, - OptimStateDictConfig, - ShardedOptimStateDictConfig, - ShardedStateDictConfig, - ShardingStrategy, - StateDictConfig, - StateDictSettings, - StateDictType, -) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - BackwardPrefetch, - CPUOffload, - FullOptimStateDictConfig, - FullyShardedDataParallel, - MixedPrecision, -) - -# yapf: enable -from visengine.registry import FUNCTIONS, MODEL_WRAPPERS -from visengine.structures import BaseDataElement -from visengine.utils import digit_version, is_seq_of - -if TYPE_CHECKING: - from visengine.optim import OptimWrapper - - -@MODEL_WRAPPERS.register_module(force=True) -class MMFullyShardedDataParallel(FullyShardedDataParallel): - """A wrapper for sharding Module parameters across data parallel workers. - - Different from FullyShardedDataParallel, MMFullyShardedDataParallel - implements three methods :meth:`train_step`, :meth:`val_step` and - :meth:`test_step`, which will be called by ``train_loop``, ``val_loop`` - and ``test_loop``. - - - ``train_step``: Called by ``runner.train_loop``, and implement - default model forward, gradient back propagation, parameter updating - logic. - - - ``val_step``: Called by ``runner.val_loop`` and get the inference - results. 
-      Specifically, since MMFullyShardedDataParallel wraps the model
-      recursively, it may cause some problems if one just uses
-      ``BaseModel.val_step`` to implement ``val_step`` here. To avoid that,
-      ``val_step`` will call methods of :obj:`BaseModel` to pre-process
-      data first, and use ``FullyShardedDataParallel.forward`` to get the
-      result.
-
-    - ``test_step``: Called by ``runner.test_loop`` and gets the inference
-      results. Its logic is equivalent to ``val_step``.
-
-    Args:
-        module (nn.Module): module to be wrapped with FSDP.
-        process_group (ProcessGroup, optional): process group for sharding.
-        cpu_offload (bool, CPUOffload, optional): CPU offloading config.
-            Different from FullyShardedDataParallel, since it can be set by
-            users' pre-defined config in MMEngine, its type is expected to be
-            `None`, `bool` or `CPUOffload`.
-
-            Currently, only parameter and gradient CPU offload is supported.
-            It can be enabled via passing in
-            ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
-            currently implicitly enables gradient offloading to CPU in order
-            for params and grads to be on the same device to work with the
-            optimizer. This API is subject to change. Default is ``None``, in
-            which case there will be no offloading.
-        auto_wrap_policy (str or Callable, optional): Specifying a policy to
-            recursively wrap layers with FSDP. Different from
-            FullyShardedDataParallel, since it can be set by users'
-            pre-defined config in MMEngine, its type is expected to be
-            `None`, `str` or `Callable`. If it's `str`, then
-            MMFullyShardedDataParallel will try to get the specified method
-            from the ``FSDP_WRAP_POLICIES`` registry, and this method will be
-            passed to FullyShardedDataParallel to finally initialize the
-            model.
-
-            Note that this policy currently will only apply to child modules
-            of the passed-in module. The remainder modules are always wrapped
-            in the returned FSDP root instance.
-            ``default_auto_wrap_policy`` written in
-            ``torch.distributed.fsdp.wrap`` is an example of an
-            ``auto_wrap_policy`` callable; this policy wraps layers with
-            parameter sizes larger than 100M. Users can supply a customized
-            ``auto_wrap_policy`` callable that should accept the following
-            arguments: ``module: nn.Module``, ``recurse: bool``,
-            ``unwrapped_params: int``; extra customized arguments could be
-            added to the customized ``auto_wrap_policy`` callable as well.
-
-            Example::
-
-                >>> def custom_auto_wrap_policy(
-                >>>     module: nn.Module,
-                >>>     recurse: bool,
-                >>>     unwrapped_params: int,
-                >>>     # These are customizable for this policy function.
-                >>>     min_num_params: int = int(1e8),
-                >>> ) -> bool:
-                >>>     return unwrapped_params >= min_num_params
-
-        backward_prefetch (str or BackwardPrefetch, optional): Different from
-            FullyShardedDataParallel, this argument could be a string or a
-            BackwardPrefetch instance. If it's a string, then it should be
-            ``BACKWARD_PRE`` or ``BACKWARD_POST``.
-        mixed_precision (dict or MixedPrecision, optional): This configures
-            native mixed precision for FSDP. Different from the native FSDP,
-            this argument can be a dict like this:
-
-            Examples:
-                >>> mixed_precision=dict(param_dtype='float16',
-                >>>                      buffer_dtype='float32',
-                >>>                      reduce_dtype='float32')
-
-            Defaults to None.
-        use_orig_params (bool): Different from native
-            ``FullyShardedDataParallel``, it defaults to True.
-        **kwargs: Keyword arguments passed to
-            :class:`FullyShardedDataParallel`.
-    """
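One detail of the constructor that follows is worth isolating: several enum-typed arguments also accept plain strings so they can come straight from config files, and the strings are resolved by enum-name lookup. A minimal sketch (plain PyTorch, illustrative only):

```python
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy

# Enum members are looked up by name, so configs can carry plain strings
# such as "FULL_SHARD" or "BACKWARD_PRE".
assert ShardingStrategy["FULL_SHARD"] is ShardingStrategy.FULL_SHARD
assert BackwardPrefetch["BACKWARD_PRE"] is BackwardPrefetch.BACKWARD_PRE
```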
- """ - - def __init__( - self, - module: nn.Module, - process_group: dict | ProcessGroup | None = None, - sharding_strategy: str | ShardingStrategy = None, - cpu_offload: bool | CPUOffload | None = None, - auto_wrap_policy: str | Callable | None = None, - backward_prefetch: str | BackwardPrefetch | None = None, - mixed_precision: dict | MixedPrecision | None = None, - param_init_fn: str | Callable[[nn.Module], None] | None = None, # type: ignore - use_orig_params: bool = True, - **kwargs, - ): - if isinstance(sharding_strategy, str): - sharding_strategy = ShardingStrategy[sharding_strategy] - if not (isinstance(sharding_strategy, ShardingStrategy) or sharding_strategy is None): - raise TypeError( - f"sharding_strategy must be str or enum of `ShardingStrategy` , but got {sharding_strategy}" - ) - - if isinstance(cpu_offload, bool): - cpu_offload = CPUOffload(offload_params=cpu_offload) - if not (isinstance(cpu_offload, CPUOffload) or cpu_offload is None): - raise TypeError(f"`cpu_offload` should be `None`, `bool`or `CPUOffload`, but has type {type(cpu_offload)}") - - with FUNCTIONS.switch_scope_and_registry(None): - if isinstance(auto_wrap_policy, str): - auto_wrap_policy = FUNCTIONS.get(auto_wrap_policy) # type: ignore - if auto_wrap_policy is None: - raise ValueError("`auto_wrap_policy` is not registered!") - elif isinstance(auto_wrap_policy, dict): - policy = auto_wrap_policy.pop("type") - if isinstance(policy, str): - policy = FUNCTIONS.get(policy) # type: ignore - if policy is None: - raise ValueError("`auto_wrap_policy` is not registered!") - auto_wrap_policy = partial(policy, **auto_wrap_policy) - - if not (auto_wrap_policy is None or callable(auto_wrap_policy)): # type: ignore - raise TypeError( - f"`auto_wrap_policy` should be a str, a callable, a dict or None, but has type {type(auto_wrap_policy)}" - ) - - if isinstance(backward_prefetch, str): - backward_prefetch = BackwardPrefetch[backward_prefetch] - if not (isinstance(backward_prefetch, BackwardPrefetch) or backward_prefetch is None): - raise TypeError( - "`backward_prefetch` should be `None`, string of " - '"BACKWARD_PRE" and "BACKWARD_POST", or ' - f"`BackwardPrefetch`, but has type {type(backward_prefetch)}" - ) - - if isinstance(param_init_fn, str): - param_init_fn = FUNCTIONS.get(param_init_fn) # type: ignore - if param_init_fn is None: - raise ValueError("`param_init_fn` is not registered!") - elif isinstance(param_init_fn, dict): - init_fn = param_init_fn.pop("type") - if isinstance(param_init_fn, str): - init_fn = FUNCTIONS.get(init_fn) # type: ignore - if init_fn is None: - raise ValueError("`param_init_fn` is not registered!") - param_init_fn = partial(init_fn, **param_init_fn) - - if not (callable(param_init_fn) or param_init_fn is None): - raise TypeError( - f"`param_init_fn` should be a str, a callable, a dict or None, but has type {type(param_init_fn)}" - ) - - def parse_dtype(dtype): - if dtype is None: - return None - elif isinstance(dtype, str): - return getattr(torch, dtype) - elif isinstance(dtype, torch.dtype): - return dtype - else: - raise TypeError(f"`dtype` should be `None`, `str` or `torch.dtype`, but has type {type(dtype)}") - - if isinstance(mixed_precision, dict): - mixed_precision["param_dtype"] = parse_dtype(mixed_precision.get("param_dtype", None)) - mixed_precision["reduce_dtype"] = parse_dtype(mixed_precision.get("reduce_dtype", None)) - mixed_precision["buffer_dtype"] = parse_dtype(mixed_precision.get("buffer_dtype", None)) - mixed_precision = MixedPrecision(**mixed_precision) - elif 
isinstance(mixed_precision, MixedPrecision): - mixed_precision = mixed_precision - elif mixed_precision is not None: - raise TypeError( - f"`mixed_precision` should be `None`, `dict` or `MixedPrecision`, but has type {type(mixed_precision)}" - ) - - # ignored_parameters and ignored_modules will be deprecated by PyTorch. - # Therefore we hide them in **kwargs. - # TODO: Update when PyTorch 2.1.0 released - if "ignored_parameters" in kwargs: - kwargs["ignored_parameters"] = self._get_ignored_params(module, kwargs["ignored_parameters"]) - - if "ignored_modules" in kwargs: - kwargs["ignored_modules"] = self._get_ignored_modules(module, kwargs["ignored_modules"]) - - super().__init__( - module=module, - process_group=process_group, - sharding_strategy=sharding_strategy, - auto_wrap_policy=auto_wrap_policy, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - param_init_fn=param_init_fn, - use_orig_params=use_orig_params, - **kwargs, - ) - - def train_step(self, data: dict, optim_wrapper: "OptimWrapper") -> dict[str, torch.Tensor]: - """Interface for model forward, backward and parameters updating during - training process. - - :meth:`train_step` will perform the following steps in order: - - - If :attr:`module` defines the preprocess method, - call ``module.preprocess`` to pre-processing data. - - Call ``module.forward(**data)`` and get losses. - - Parse losses. - - Call ``optim_wrapper.optimizer_step`` to update parameters. - - Return log messages of losses. - - Args: - data (dict): Data sampled by dataloader. - optim_wrapper (OptimWrapper): A wrapper of optimizer to - update parameters. - - Returns: - Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. - """ - # enable automatic mixed precision training context. - with optim_wrapper.optim_context(self): - data = self.module.data_preprocessor(data, training=True) - if isinstance(data, dict): - losses = self(**data, mode="loss") - elif isinstance(data, list | tuple): - losses = self(*data, mode="loss") - else: - raise TypeError(f"Output of `data_preprocessor` should be list tuple or dict, but got {type(data)}") - parsed_loss, log_vars = self.module.parse_losses(losses) - optim_wrapper.update_params(parsed_loss) - return log_vars - - def val_step(self, data: dict) -> list[BaseDataElement]: - """Gets the prediction of module during validation process. - - Args: - data (dict): Data sampled by dataloader. - - Returns: - List[BaseDataElement] or dict: The predictions of given data. - """ - data = self.module.data_preprocessor(data, False) - return self._run_forward(data, mode="predict") # type: ignore - - def test_step(self, data: dict) -> list[BaseDataElement]: - """Gets the predictions of module during testing process. - - Args: - data (dict): Data sampled by dataloader. - - Returns: - List[BaseDataElement]: The predictions of given data. - """ - data = self.module.data_preprocessor(data, False) - return self._run_forward(data, mode="predict") # type: ignore - - def _run_forward(self, data: dict | tuple | list, mode: str) -> dict[str, torch.Tensor] | list: - """Unpacks data for :meth:`forward` - Args: - data (dict or tuple or list): Data sampled from dataset. - mode (str): Mode of forward. - Returns: - dict or list: Results of training or testing mode. 
- """ - if isinstance(data, dict): - results = self(**data, mode=mode) - elif isinstance(data, list | tuple): - results = self(*data, mode=mode) - else: - raise TypeError(f"Output of `data_preprocessor` should be list, tuple or dict, but got {type(data)}") - return results - - def _get_ignored_params(self, module: nn.Module, ignored_parameters: Iterable[str] | Iterable[nn.Module]): - """Get params from string.""" - params_dict = dict(module.named_parameters()) - if is_seq_of(ignored_parameters, str): - ignored_parameters = [params_dict[name] for name in ignored_parameters] - if not is_seq_of(ignored_parameters, nn.Parameter) and ignored_parameters is not None: - raise TypeError( - f"`ignored_modules` should be `None`, `Iterable[str]` or `Iterable[nn.Parameters]`, but has type {type(ignored_parameters)}" - ) - return ignored_parameters - - def _get_ignored_modules(self, module: nn.Module, ignored_modules: Iterable[str] | Iterable[nn.Module]): - """Get modules from string.""" - modules_dict = dict(module.named_modules()) - if is_seq_of(ignored_modules, str): - ignored_modules = [modules_dict[name] for name in ignored_modules] - if not is_seq_of(ignored_modules, nn.Module) and ignored_modules is not None: - raise TypeError( - f"`ignored_modules` should be `None`, `Iterable[str]` or `Iterable[nn.Module]`, but has type {type(ignored_modules)}" - ) - return ignored_modules - - if digit_version(torch.__version__) < digit_version("2.0.1"): - - @staticmethod - def optim_state_dict( - model: torch.nn.Module, - optim: torch.optim.Optimizer, - group: dist.ProcessGroup | None = None, - ) -> dict[str, Any]: - """Copied from pytorch 2.0.1 which has fixed some bugs.""" - state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model) - return FullyShardedDataParallel._optim_state_dict_impl( - model=model, - optim=optim, - optim_state_dict=optim.state_dict(), - optim_input=None, - rank0_only=getattr(state_dict_settings.optim_state_dict_config, "rank0_only", False), - full_state_dict=state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT, - group=group, - ) - - @staticmethod - def set_state_dict_type( - module: nn.Module, - state_dict_type: StateDictType, - state_dict_config: StateDictConfig | None = None, - optim_state_dict_config: OptimStateDictConfig | None = None, - ) -> StateDictSettings: - """Copied from pytorch 2.0.1 which has fixed some bugs.""" - import torch.distributed.fsdp._traversal_utils as traversal_utils - - _state_dict_type_to_config = { - StateDictType.FULL_STATE_DICT: FullStateDictConfig, - StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, - StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, - } - _optim_state_dict_type_to_config = { - StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig, - StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig, - StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig, - } - - # Use the default config if a state_dict config is not set. 
-            state_dict_config_type = _state_dict_type_to_config[state_dict_type]
-            optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type]
-            if state_dict_config is None:
-                state_dict_config = state_dict_config_type()
-            if optim_state_dict_config is None:
-                optim_state_dict_config = optim_state_dict_config_type()
-            if state_dict_config_type != type(state_dict_config):
-                raise RuntimeError(
-                    f"Expected state_dict_config of type {state_dict_config_type} but got {type(state_dict_config)}"
-                )
-            if optim_state_dict_config_type != type(optim_state_dict_config):
-                raise RuntimeError(
-                    f"Expected optim_state_dict_config of type {optim_state_dict_config_type} but got {type(optim_state_dict_config)}"
-                )
-
-            # Set the state_dict type and configurations.
-            prev_state_dict_type = None
-            prev_state_dict_config = None
-            prev_optim_state_dict_config = None
-            for submodule in traversal_utils._get_fsdp_states(module):
-                if prev_state_dict_type is None:
-                    prev_state_dict_type = submodule._state_dict_type
-                else:
-                    assert prev_state_dict_type == submodule._state_dict_type, (
-                        "All FSDP modules should have the same state_dict_type."
-                    )
-                if prev_state_dict_config is None:
-                    prev_state_dict_config = submodule._state_dict_config
-                else:
-                    assert isinstance(submodule._state_dict_config, type(prev_state_dict_config)), (
-                        "All FSDP modules must have the same type of state_dict_config."
-                    )
-                if prev_optim_state_dict_config is None:
-                    prev_optim_state_dict_config = submodule._optim_state_dict_config
-                else:
-                    assert isinstance(
-                        submodule._optim_state_dict_config,
-                        type(prev_optim_state_dict_config),
-                    ), "All FSDP modules must have the same type of optim_state_dict_config."
-
-                submodule._state_dict_type = state_dict_type
-                submodule._state_dict_config = state_dict_config
-                submodule._optim_state_dict_config = optim_state_dict_config
-
-            return StateDictSettings(
-                prev_state_dict_type,
-                prev_state_dict_config,
-                prev_optim_state_dict_config,
-            )
diff --git a/libs/visengine/visengine/model/wrappers/seperate_distributed.py b/libs/visengine/visengine/model/wrappers/seperate_distributed.py
deleted file mode 100644
index 03f661b..0000000
--- a/libs/visengine/visengine/model/wrappers/seperate_distributed.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-from __future__ import annotations
-
-from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING
-
-import torch
-import torch.nn as nn
-from torch.nn.parallel.distributed import DistributedDataParallel
-
-from visengine.device import get_device
-from visengine.registry import MODEL_WRAPPERS
-
-from .distributed import MMDistributedDataParallel
-
-if TYPE_CHECKING:
-    from visengine.optim import OptimWrapperDict
-
-
-@MODEL_WRAPPERS.register_module(force=True)
-class MMSeparateDistributedDataParallel(DistributedDataParallel):
-    """A DistributedDataParallel wrapper for models in MMGeneration.
-
-    In MMEditing and MMGeneration there is a need to wrap different modules
-    in the models with separate DistributedDataParallel. Otherwise, it will
-    cause errors for GAN training. For example, a GAN model usually has two
-    submodules: generator and discriminator. If we wrap both of them in one
-    standard DistributedDataParallel, it will cause errors during training,
-    because when we update the parameters of the generator (or
-    discriminator), the parameters of the discriminator (or generator) are
-    not updated, which is not allowed for DistributedDataParallel.
-    So we design this wrapper to separately wrap DistributedDataParallel for
-    the generator and the discriminator. In this wrapper, we perform two
-    operations:
-
-    1. Wrap each module in the model with a separate
-       MMDistributedDataParallel. Note that only modules with parameters
-       will be wrapped.
-    2. Call ``train_step``, ``val_step`` and ``test_step`` of submodules to
-       get losses and predictions.
-
-    Args:
-        module (nn.Module): model containing multiple submodules which have
-            separate updating strategies.
-        broadcast_buffers (bool): Same as that in
-            ``torch.nn.parallel.distributed.DistributedDataParallel``.
-            Defaults to False.
-        find_unused_parameters (bool): Same as that in
-            ``torch.nn.parallel.distributed.DistributedDataParallel``.
-            Traverse the autograd graph of all tensors contained in the
-            returned value of the wrapped module's forward function.
-            Defaults to False.
-        **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``.
-
-            - device_ids (List[int] or torch.device, optional): CUDA devices
-              for module.
-            - output_device (int or torch.device, optional): Device location of
-              output for single-device CUDA modules.
-            - dim (int): Defaults to 0.
-            - process_group (ProcessGroup, optional): The process group to be
-              used for distributed data all-reduction.
-            - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults
-              to 25.
-            - check_reduction (bool): This argument is deprecated. Defaults
-              to False.
-            - gradient_as_bucket_view (bool): Defaults to False.
-            - static_graph (bool): Defaults to False.
-
-    See more information about arguments in
-    :class:`torch.nn.parallel.DistributedDataParallel`.
-    """
-
-    def __init__(
-        self,
-        module: nn.Module,
-        broadcast_buffers: bool = False,
-        find_unused_parameters: bool = False,
-        **kwargs,
-    ):
-        super(DistributedDataParallel, self).__init__()
-        self.module = module
-        device = get_device()
-        # Wrap the submodules of `self.module` that have parameters in
-        # `MMDistributedDataParallel`
-        for name, sub_module in module._modules.items():
-            # module without parameters.
-            if next(sub_module.parameters(), None) is None:
-                sub_module = sub_module.to(device)
-            elif all(not p.requires_grad for p in sub_module.parameters()):
-                sub_module = sub_module.to(device)
-            else:
-                sub_module = MMDistributedDataParallel(
-                    module=sub_module.to(device),
-                    broadcast_buffers=broadcast_buffers,
-                    find_unused_parameters=find_unused_parameters,
-                    **kwargs,
-                )
-            module._modules[name] = sub_module
-
-    def train_step(self, data: dict | tuple | list, optim_wrapper: "OptimWrapperDict") -> dict[str, torch.Tensor]:
-        """Interface for model forward, backward and parameters updating during
-        training process.
-
-        Args:
-            data (dict or tuple or list): Data sampled from dataset.
-            optim_wrapper (OptimWrapperDict): A wrapper of optimizer to
-                update parameters.
-
-        Returns:
-            Dict[str, torch.Tensor]: A dict of tensor for logging.
-        """
-        return self.module.train_step(data, optim_wrapper)
-
-    def val_step(self, data: dict | tuple | list) -> list:
-        """Gets the prediction of module during validation process.
-
-        Args:
-            data (dict or tuple or list): Data sampled from dataset.
-
-        Returns:
-            list: The predictions of given data.
-        """
-        return self.module.val_step(data)
-
-    def test_step(self, data: dict | tuple | list) -> list:
-        """Gets the predictions of module during testing process.
-
-        Args:
-            data (dict or tuple or list): Data sampled from dataset.
-
-        Returns:
-            list: The predictions of given data.
-        """
-        return self.module.test_step(data)
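The per-submodule wrapping strategy is easiest to picture on the GAN case the docstring describes. A purely illustrative sketch (hypothetical module names; distributed process-group setup omitted):

```python
import torch.nn as nn


# A toy two-submodule model in the shape the docstring describes.
class ToyGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(16, 32)
        self.discriminator = nn.Linear(32, 1)


# Inside an initialized process group, wrapping would look like:
#   wrapped = MMSeparateDistributedDataParallel(ToyGAN().cuda())
# after which each parameterized child is its own DDP instance,
#   wrapped.module.generator      -> MMDistributedDataParallel
#   wrapped.module.discriminator  -> MMDistributedDataParallel
# so generator and discriminator can be stepped on alternating iterations
# without DDP complaining about parameters that received no gradients.
```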
- """ - return self.module.test_step(data) - - @contextmanager - def no_sync(self): - """Enables ``no_sync`` context of all sub ``MMDistributedDataParallel`` - modules.""" - with ExitStack() as stack: - for sub_ddp_model in self.module._modules.values(): - stack.enter_context(sub_ddp_model.no_sync()) - yield - - def train(self, mode: bool = True) -> "MMSeparateDistributedDataParallel": - """Sets the module in training mode. - - In order to make the ddp wrapper inheritance hierarchy more uniform, - ``MMSeparateDistributedDataParallel`` inherits from - ``DistributedDataParallel``, but will not call its constructor. - Since the attributes of ``DistributedDataParallel`` have not been - initialized, call the ``train`` method of ``DistributedDataParallel`` - will raise an error if pytorch version <= 1.9. Therefore, override - this method to call the ``train`` method of submodules. - - Args: - mode (bool): whether to set training mode (``True``) or evaluation - mode (``False``). Defaults to ``True``. - - Returns: - Module: self. - """ - self.training = mode - self.module.train(mode) - return self diff --git a/libs/visengine/visengine/model/wrappers/utils.py b/libs/visengine/visengine/model/wrappers/utils.py deleted file mode 100644 index 4f9c456..0000000 --- a/libs/visengine/visengine/model/wrappers/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn - -from visengine.registry import MODEL_WRAPPERS, Registry - - -def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS): - """Check if a module is a model wrapper. - - The following 4 model in MMEngine (and their subclasses) are regarded as - model wrappers: DataParallel, DistributedDataParallel, - MMDataParallel, MMDistributedDataParallel. You may add you own - model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``. - - Args: - model (nn.Module): The model to be checked. - registry (Registry): The parent registry to search for model wrappers. - - Returns: - bool: True if the input model is a model wrapper. - """ - module_wrappers = tuple(registry.module_dict.values()) - if isinstance(model, module_wrappers): - return True - - if not registry.children: - return False - - return any(is_model_wrapper(model, child) for child in registry.children.values()) diff --git a/libs/visengine/visengine/optim/__init__.py b/libs/visengine/visengine/optim/__init__.py deleted file mode 100644 index 6599c25..0000000 --- a/libs/visengine/visengine/optim/__init__.py +++ /dev/null @@ -1,66 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
-from .optimizer import ( - OPTIM_WRAPPER_CONSTRUCTORS, - OPTIMIZERS, - AmpOptimWrapper, - BaseOptimWrapper, - DefaultOptimWrapperConstructor, - OptimWrapper, - OptimWrapperDict, - build_optim_wrapper, -) - -# yapf: disable -from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, - CosineAnnealingLR, CosineAnnealingMomentum, - CosineAnnealingParamScheduler, ExponentialLR, - ExponentialMomentum, ExponentialParamScheduler, - LinearLR, LinearMomentum, LinearParamScheduler, - MultiStepLR, MultiStepMomentum, - MultiStepParamScheduler, OneCycleLR, - OneCycleParamScheduler, PolyLR, PolyMomentum, - PolyParamScheduler, ReduceOnPlateauLR, - ReduceOnPlateauMomentum, ReduceOnPlateauParamScheduler, - StepLR, StepMomentum, StepParamScheduler, - _ParamScheduler) - -# yapf: enable -__all__ = [ - "OPTIMIZERS", - "OPTIM_WRAPPER_CONSTRUCTORS", - "AmpOptimWrapper", - "BaseOptimWrapper", - "ConstantLR", - "ConstantMomentum", - "ConstantParamScheduler", - "CosineAnnealingLR", - "CosineAnnealingMomentum", - "CosineAnnealingParamScheduler", - "DefaultOptimWrapperConstructor", - "ExponentialLR", - "ExponentialMomentum", - "ExponentialParamScheduler", - "LinearLR", - "LinearMomentum", - "LinearParamScheduler", - "MultiStepLR", - "MultiStepMomentum", - "MultiStepParamScheduler", - "OneCycleLR", - "OneCycleParamScheduler", - "OptimWrapper", - "OptimWrapperDict", - "PolyLR", - "PolyMomentum", - "PolyParamScheduler", - "ReduceOnPlateauLR", - "ReduceOnPlateauMomentum", - "ReduceOnPlateauParamScheduler", - "StepLR", - "StepMomentum", - "StepParamScheduler", - "_ParamScheduler", - "build_optim_wrapper", -] diff --git a/libs/visengine/visengine/optim/optimizer/__init__.py b/libs/visengine/visengine/optim/optimizer/__init__.py deleted file mode 100644 index a4d7403..0000000 --- a/libs/visengine/visengine/optim/optimizer/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .amp_optimizer_wrapper import AmpOptimWrapper -from .base import BaseOptimWrapper -from .builder import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, build_optim_wrapper -from .default_constructor import DefaultOptimWrapperConstructor -from .optimizer_wrapper import OptimWrapper -from .optimizer_wrapper_dict import OptimWrapperDict - -__all__ = [ - "OPTIMIZERS", - "OPTIM_WRAPPER_CONSTRUCTORS", - "AmpOptimWrapper", - "BaseOptimWrapper", - "DefaultOptimWrapperConstructor", - "OptimWrapper", - "OptimWrapperDict", - "build_optim_wrapper", -] diff --git a/libs/visengine/visengine/optim/optimizer/amp_optimizer_wrapper.py b/libs/visengine/visengine/optim/optimizer/amp_optimizer_wrapper.py deleted file mode 100644 index 5f24188..0000000 --- a/libs/visengine/visengine/optim/optimizer/amp_optimizer_wrapper.py +++ /dev/null @@ -1,179 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from contextlib import contextmanager - -import torch -import torch.nn as nn - -from visengine.registry import OPTIM_WRAPPERS - -from .optimizer_wrapper import OptimWrapper - -# updated from torch.cuda.amp -> torch.amp -# due to a deprecation warning -from torch.amp import GradScaler -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from visengine.runner.amp import autocast - - -# There's also an APEX implementation that came before this, -# but the torch implementation (here) is recommended. 
-# https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994 -# https://github.com/open-mmlab/mmengine/tree/main/mmengine/optim/optimizer -@OPTIM_WRAPPERS.register_module(force=True) -class AmpOptimWrapper(OptimWrapper): - """A subclass of :class:`OptimWrapper` that supports automatic mixed - precision training based on torch.cuda.amp. - - ``AmpOptimWrapper`` provides a unified interface with - ``OptimWrapper``, so ``AmpOptimWrapper`` can be used in the same way - as ``OptimWrapper``. - - Warnings: - ``AmpOptimWrapper`` requires PyTorch >= 1.6. - - Args: - loss_scale (float or str or dict): The initial configuration of - `torch.cuda.amp.GradScaler`. See more specific arguments - introduction at `PyTorch AMP `_ # noqa: E501 - Defaults to ``dynamic``. - - - "dynamic": Initialize GradScale without any arguments. - - float: Initialize GradScaler with ``init_scale``. - - dict: Initialize GradScaler with more detail configuration. - - dtype (str or torch.dtype, optional): The data type to autocast in amp. - If a ``str`` is given, it will be converted to ``torch.dtype``. - Valid ``str`` format are `'float16'`, `'bfloat16'`, `'float32'` and - `'float64'`. If set to ``None``, the default data type will be used. - Defaults to None. - `New in version 0.6.1.` - use_fsdp (bool): Using ``ShardedGradScaler`` when it is True. It should - be enabled when using ``FullyShardedDataParallel``. - Defaults to False. - `New in version 0.8.0.` - **kwargs: Keyword arguments passed to OptimWrapper. - - Warnings: - ``dtype`` argument is only available with PyTorch version >= 1.10.0. If - you use PyTorch of an older version, it will be ignored. - - Note: - If you use ``IterBasedRunner`` and enable gradient accumulation, - the original `max_iters` should be multiplied by - ``accumulative_counts``. - """ - - valid_dtypes = ("float16", "bfloat16", "float32", "float64") - - def __init__( - self, - loss_scale: str = "dynamic", - dtype: str | torch.dtype = None, - use_fsdp: bool = False, - **kwargs, - ): - super().__init__(**kwargs) - self._scale_update_param = None - - if use_fsdp: - scaler_type = ShardedGradScaler - else: - scaler_type = GradScaler - - if loss_scale == "dynamic": - # If loss_scale is a string, it must be 'dynamic', then dynamic - # loss scaling will be used. - self.loss_scaler = scaler_type() - elif isinstance(loss_scale, float): - # Static loss scaling - self._scale_update_param = loss_scale - self.loss_scaler = scaler_type(init_scale=loss_scale) - elif isinstance(loss_scale, dict): - # More specific configuration. - self.loss_scaler = scaler_type(**loss_scale) - else: - raise TypeError(f'loss_scale must be of type float, dict, or "dynamic", but got {loss_scale}') - - # convert string value to torch.dtype - if isinstance(dtype, str): - assert dtype in self.valid_dtypes, f"dtype should be any of {self.valid_dtypes}, got {dtype}" - dtype = getattr(torch, dtype) - - assert dtype is None or isinstance(dtype, torch.dtype), ( - f"dtype should be None or instance of torch.dtype, got {dtype}" - ) - self.cast_dtype = dtype - - def backward(self, loss: torch.Tensor, **kwargs): - """Perform gradient back propagation with :attr:`loss_scaler`. - - Args: - loss (torch.Tensor): The loss of current iteration. - kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward` - """ - self.loss_scaler.scale(loss).backward(**kwargs) - self._inner_count += 1 - - def step(self, **kwargs): - """Update parameters with :attr:`loss_scaler`. 
- - Args: - kwargs: Keyword arguments passed to - :meth:`torch.optim.Optimizer.step`. - """ - if self.clip_grad_kwargs: - self.loss_scaler.unscale_(self.optimizer) - self._clip_grad() - self.loss_scaler.step(self.optimizer, **kwargs) - self.loss_scaler.update(self._scale_update_param) - - def state_dict(self) -> dict: - """Get the state dictionary of :attr:`optimizer` and - :attr:`loss_scaler`. - - Based on the state dictionary of the optimizer, the returned state - dictionary will add a key named "loss_scaler". - - Returns: - dict: The merged state dict of :attr:`loss_scaler` and - :attr:`optimizer`. - """ - # save state_dict of loss_scaler - state_dict = super().state_dict() - state_dict["loss_scaler"] = self.loss_scaler.state_dict() - return state_dict - - def load_state_dict(self, state_dict: dict): - """Load and parse the state dictionary of :attr:`optimizer` and - :attr:`loss_scaler`. - - If state_dict contains "loss_scaler.", the :attr:`loss_scaler` will - load the corresponding keys. Otherwise, only the :attr:`optimizer` - will load the state dictionary. - - Args: - state_dict (dict): The state dict of :attr:`optimizer` and - :attr:`loss_scaler` - """ - if "loss_scaler" in state_dict: - self.loss_scaler.load_state_dict(state_dict.pop("loss_scaler")) - - if "base_param_settings" in state_dict: - self.base_param_settings = state_dict.pop("base_param_settings") - - # load state_dict of optimizer - self.optimizer.load_state_dict(state_dict) - - @contextmanager - def optim_context(self, model: nn.Module): - """Enables the context for mixed precision training, and enables the - context for disabling gradient synchronization during gradient - accumulation context. - - Args: - model (nn.Module): The training model. - """ - with super().optim_context(model), autocast(dtype=self.cast_dtype): - yield diff --git a/libs/visengine/visengine/optim/optimizer/base.py b/libs/visengine/visengine/optim/optimizer/base.py deleted file mode 100644 index 0a2234b..0000000 --- a/libs/visengine/visengine/optim/optimizer/base.py +++ /dev/null @@ -1,128 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod - -import torch - - -class BaseOptimWrapper(metaclass=ABCMeta): - def __init__(self, optimizer): - self.optimizer = optimizer - - # The Following code is used to initialize `base_param_settings`. - # `base_param_settings` is used to store the parameters that are not - # updated by the optimizer. - # The `base_param_settings` used for tracking the base learning in the - # optimizer. If the optimizer has multiple parameter groups, this - # params will not be scaled by the loss factor. 
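-        # Concretely: when there are multiple param groups, a dummy group
-        # holding a single zero tensor is appended via the `param_groups`
-        # property below, so parameter schedulers can track a "base"
-        # learning rate even though no real model parameters live in it.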
- if len(optimizer.param_groups) > 1: - self.base_param_settings = {"params": torch.tensor([0.0], dtype=torch.float)} - self.base_param_settings.update(**self.optimizer.defaults) - else: - self.base_param_settings = None # type: ignore - - @abstractmethod - def update_params(self, *args, **kwargs): - """Update parameters in :attr:`optimizer`.""" - - @abstractmethod - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """Perform gradient back propagation.""" - - @abstractmethod - def zero_grad(self, **kwargs) -> None: - """A wrapper of ``Optimizer.zero_grad``.""" - - @abstractmethod - def step(self, **kwargs): - """Call the step method of optimizer.""" - - def state_dict(self) -> dict: - """A wrapper of ``Optimizer.state_dict``.""" - state_dict = self.optimizer.state_dict() - if self.base_param_settings is not None: - state_dict["base_param_settings"] = self.base_param_settings - return state_dict - - def load_state_dict(self, state_dict: dict) -> None: - """A wrapper of ``Optimizer.load_state_dict``. load the state dict of - :attr:`optimizer`. - - Provide unified ``load_state_dict`` interface compatible with automatic - mixed precision training. Subclass can overload this method to - implement the required logic. For example, the state dictionary of - GradScaler should be loaded when training with ``torch.cuda.amp``. - - Args: - state_dict (dict): The state dictionary of :attr:`optimizer`. - """ - base_param_settings = state_dict.pop("base_param_settings", None) - - if base_param_settings is not None: - self.base_param_settings = base_param_settings - - # load state_dict of optimizer - self.optimizer.load_state_dict(state_dict) - - @property - def param_groups(self) -> list[dict]: - """A wrapper of ``Optimizer.param_groups``. - - Make OptimizeWrapper compatible with :class:`_ParamScheduler`. - - Returns: - dict: the ``param_groups`` of :attr:`optimizer`. - """ - if self.base_param_settings is not None: - return [*self.optimizer.param_groups, self.base_param_settings] - else: - return self.optimizer.param_groups - - @property - def defaults(self) -> dict: - """A wrapper of ``Optimizer.defaults``. - - Make OptimizeWrapper compatible with :class:`_ParamScheduler`. - - Returns: - dict: the ``param_groups`` of :attr:`optimizer`. - """ - return self.optimizer.defaults - - def get_lr(self): - """Get the learning rate of the optimizer. - - Provide unified interface to get learning rate of optimizer. - - Returns: - Dict[str, List[float]]: - param_groups learning rate of the optimizer. - """ - res = {} - if self.base_param_settings is not None: - res["base_lr"] = [self.base_param_settings["lr"]] - - res["lr"] = [group["lr"] for group in self.optimizer.param_groups] - - return res - - def get_momentum(self) -> dict[str, list[float]]: - """Get the momentum of the optimizer. - - Provide unified interface to get momentum of optimizer. - - Returns: - Dict[str, List[float]]: Momentum of the optimizer. - """ - momentum = [] - for group in self.optimizer.param_groups: - # Get momentum of SGD. - if "momentum" in group.keys(): - momentum.append(group["momentum"]) - # Get momentum of Adam. 
- elif "betas" in group.keys(): - momentum.append(group["betas"][0]) - else: - momentum.append(0) - return {"momentum": momentum} diff --git a/libs/visengine/visengine/optim/optimizer/builder.py b/libs/visengine/visengine/optim/optimizer/builder.py deleted file mode 100644 index 9470981..0000000 --- a/libs/visengine/visengine/optim/optimizer/builder.py +++ /dev/null @@ -1,187 +0,0 @@ -import copy -import inspect - -import torch -import torch.nn as nn - -from visengine.config import Config, ConfigDict -from visengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS - -from .optimizer_wrapper import OptimWrapper - - -def register_torch_optimizers() -> list[str]: - """Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. - """ - torch_optimizers = [] - for module_name in dir(torch.optim): - if module_name.startswith("__"): - continue - _optim = getattr(torch.optim, module_name) - if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer): - if module_name == "Adafactor": - OPTIMIZERS.register_module(name="TorchAdafactor", module=_optim) - else: - OPTIMIZERS.register_module(module=_optim) - torch_optimizers.append(module_name) - return torch_optimizers - - -TORCH_OPTIMIZERS = register_torch_optimizers() - - -def register_torch_npu_optimizers() -> list[str]: - """Register optimizers in ``torch npu`` to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. - """ - return [] - - -NPU_OPTIMIZERS = [] - - -def register_dadaptation_optimizers() -> list[str]: - """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. - """ - dadaptation_optimizers = [] - try: - import dadaptation - except ImportError: - pass - else: - for module_name in ["DAdaptAdaGrad", "DAdaptAdam", "DAdaptSGD"]: - _optim = getattr(dadaptation, module_name) - if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer): - OPTIMIZERS.register_module(module=_optim) - dadaptation_optimizers.append(module_name) - return dadaptation_optimizers - - -DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers() - - -def register_lion_optimizers() -> list[str]: - """Register Lion optimizer to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. - """ - optimizers = [] - try: - from lion_pytorch import Lion - except ImportError: - pass - else: - OPTIMIZERS.register_module(module=Lion) - optimizers.append("Lion") - return optimizers - - -LION_OPTIMIZERS = register_lion_optimizers() - - -def register_sophia_optimizers() -> list[str]: - """Register Sophia optimizer to the ``OPTIMIZERS`` registry. - - Returns: - List[str]: A list of registered optimizers' name. - """ - optimizers = [] - try: - import Sophia - except ImportError: - pass - else: - for module_name in dir(Sophia): - _optim = getattr(Sophia, module_name) - if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer): - OPTIMIZERS.register_module(module=_optim) - optimizers.append(module_name) - return optimizers - - -SOPHIA_OPTIMIZERS = register_sophia_optimizers() - - -def register_bitsandbytes_optimizers() -> list[str]: - """Register optimizers in ``bitsandbytes`` to the ``OPTIMIZERS`` registry. - - In the `bitsandbytes` library, optimizers that have the same name as the - default optimizers in PyTorch are prefixed with ``bnb_``. For example, - ``bnb_Adagrad``. 
-
-    Returns:
-        List[str]: A list of registered optimizers' name.
-    """
-    bitsandbytes_optimizers = []
-    try:
-        import bitsandbytes as bnb
-    except ImportError:
-        # Optional dependency: skip registration when bitsandbytes is
-        # absent, matching the other optional registrars above.
-        pass
-    else:
-        optim_classes = inspect.getmembers(
-            bnb.optim,
-            lambda _optim: (inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer)),
-        )
-        for name, optim_cls in optim_classes:
-            if name in OPTIMIZERS:
-                name = f"bnb_{name}"
-            OPTIMIZERS.register_module(module=optim_cls, name=name)
-            bitsandbytes_optimizers.append(name)
-    return bitsandbytes_optimizers
-
-
-BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()
-
-
-def register_transformers_optimizers():
-    transformer_optimizers = []
-    try:
-        from transformers import Adafactor
-    except ImportError:
-        pass
-    else:
-        OPTIMIZERS.register_module(name="Adafactor", module=Adafactor)
-        transformer_optimizers.append("Adafactor")
-    return transformer_optimizers
-
-
-TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()
-
-
-def build_optim_wrapper(model: nn.Module, cfg: dict | Config | ConfigDict) -> OptimWrapper:
-    """Build function of OptimWrapper.
-
-    If ``constructor`` is set in the ``cfg``, this method will build an
-    optimizer wrapper constructor and use it to build the optimizer
-    wrapper. If ``constructor`` is not set, the
-    ``DefaultOptimWrapperConstructor`` will be used by default.
-
-    Args:
-        model (nn.Module): Model to be optimized.
-        cfg (dict): Config of optimizer wrapper, optimizer constructor and
-            optimizer.
-
-    Returns:
-        OptimWrapper: The built optimizer wrapper.
-    """
-    optim_wrapper_cfg = copy.deepcopy(cfg)
-    constructor_type = optim_wrapper_cfg.pop("constructor", "DefaultOptimWrapperConstructor")
-    paramwise_cfg = optim_wrapper_cfg.pop("paramwise_cfg", None)
-
-    # NPU is no longer supported
-
-    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
-        {
-            "type": constructor_type,
-            "optim_wrapper_cfg": optim_wrapper_cfg,
-            "paramwise_cfg": paramwise_cfg,
-        }
-    )
-    optim_wrapper = optim_wrapper_constructor(model)
-    return optim_wrapper
diff --git a/libs/visengine/visengine/optim/optimizer/default_constructor.py b/libs/visengine/visengine/optim/optimizer/default_constructor.py
deleted file mode 100644
index 652c0a1..0000000
--- a/libs/visengine/visengine/optim/optimizer/default_constructor.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import inspect
-import logging
-
-import torch
-import torch.nn as nn
-from torch.nn import GroupNorm, LayerNorm
-from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.instancenorm import _InstanceNorm
-
-from visengine.logging import print_log
-from visengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS
-from visengine.utils import is_list_of
-
-
-@OPTIM_WRAPPER_CONSTRUCTORS.register_module(force=True)
-class DefaultOptimWrapperConstructor:
-    """Default constructor for optimizers.
-
-    By default, each parameter shares the same optimizer settings, and we
-    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
-    It is a dict and may contain the following fields:
-
-    - ``custom_keys`` (dict): Specifies parameter-wise settings by keys. If
-      one of the keys in ``custom_keys`` is a substring of the name of one
-      parameter, then the setting of the parameter will be specified by
-      ``custom_keys[key]`` and other settings like ``bias_lr_mult`` etc. will
-      be ignored. It should be noted that the aforementioned ``key`` is the
-      longest key that is a substring of the name of the parameter.
If there - are multiple matched keys with the same length, then the key with lower - alphabet order will be chosen. - ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` - and ``decay_mult``. See Example 2 below. - - ``bias_lr_mult`` (float): It will be multiplied to the learning - rate for all bias parameters (except for those in normalization - layers and offset layers of DCN). - - ``bias_decay_mult`` (float): It will be multiplied to the weight - decay for all bias parameters (except for those in - normalization layers, depthwise conv layers, offset layers of DCN). - - ``norm_decay_mult`` (float): It will be multiplied to the weight - decay for all weight and bias parameters of normalization - layers. - - ``flat_decay_mult`` (float): It will be multiplied to the weight - decay for all one-dimensional parameters - - ``dwconv_decay_mult`` (float): It will be multiplied to the weight - decay for all weight and bias parameters of depthwise conv - layers. - - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning - rate for parameters of offset layer in the deformable convs - of a model. - - ``bypass_duplicate`` (bool): If true, the duplicate parameters - would not be added into optimizer. Defaults to False. - - Note: - - 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will - override the effect of ``bias_lr_mult`` in the bias of offset layer. - So be careful when using both ``bias_lr_mult`` and - ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset - layer in deformable convs, set ``dcn_offset_lr_mult`` to the original - ``dcn_offset_lr_mult`` * ``bias_lr_mult``. - - 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will - apply it to all the DCN layers in the model. So be careful when the - model contains multiple DCN layers in places other than backbone. - - Args: - optim_wrapper_cfg (dict): The config dict of the optimizer wrapper. - - Required fields of ``optim_wrapper_cfg`` are - - - ``type``: class name of the OptimizerWrapper - - ``optimizer``: The configuration of optimizer. - - Optional fields of ``optim_wrapper_cfg`` are - - - any arguments of the corresponding optimizer wrapper type, - e.g., accumulative_counts, clip_grad, etc. - - Required fields of ``optimizer`` are - - - `type`: class name of the optimizer. - - Optional fields of ``optimizer`` are - - - any arguments of the corresponding optimizer type, e.g., - lr, weight_decay, momentum, etc. - - paramwise_cfg (dict, optional): Parameter-wise options. - - Example 1: - >>> model = torch.nn.modules.Conv1d(1, 1, 1) - >>> optim_wrapper_cfg = dict( - >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01, - >>> momentum=0.9, weight_decay=0.0001)) - >>> paramwise_cfg = dict(norm_decay_mult=0.) - >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( - >>> optim_wrapper_cfg, paramwise_cfg) - >>> optim_wrapper = optim_wrapper_builder(model) - - Example 2: - >>> # assume model have attribute model.backbone and model.cls_head - >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( - >>> type='SGD', lr=0.01, weight_decay=0.95)) - >>> paramwise_cfg = dict(custom_keys={ - >>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)}) - >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( - >>> optim_wrapper_cfg, paramwise_cfg) - >>> optim_wrapper = optim_wrapper_builder(model) - >>> # Then the `lr` and `weight_decay` for model.backbone is - >>> # (0.01 * 0.1, 0.95 * 0.9). 
`lr` and `weight_decay` for - >>> # model.cls_head is (0.01, 0.95). - """ - - def __init__(self, optim_wrapper_cfg: dict, paramwise_cfg: dict | None = None): - if not isinstance(optim_wrapper_cfg, dict): - raise TypeError("optimizer_cfg should be a dict", f"but got {type(optim_wrapper_cfg)}") - assert "optimizer" in optim_wrapper_cfg, '`optim_wrapper_cfg` must contain "optimizer" config' - self.optim_wrapper_cfg = optim_wrapper_cfg.copy() - self.optimizer_cfg = self.optim_wrapper_cfg.pop("optimizer") - self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg - self.base_lr = self.optimizer_cfg.get("lr", None) - self.base_wd = self.optimizer_cfg.get("weight_decay", None) - self._validate_cfg() - - def _validate_cfg(self) -> None: - """Verify the correctness of the config.""" - if not isinstance(self.paramwise_cfg, dict): - raise TypeError(f"paramwise_cfg should be None or a dict, but got {type(self.paramwise_cfg)}") - - if "custom_keys" in self.paramwise_cfg: - if not isinstance(self.paramwise_cfg["custom_keys"], dict): - raise TypeError( - f"If specified, custom_keys must be a dict, but got {type(self.paramwise_cfg['custom_keys'])}" - ) - if self.base_wd is None: - for key in self.paramwise_cfg["custom_keys"]: - if "decay_mult" in self.paramwise_cfg["custom_keys"][key]: - raise ValueError("base_wd should not be None") - - # get base lr and weight decay - # weight_decay must be explicitly specified if mult is specified - if ( - "bias_decay_mult" in self.paramwise_cfg - or "norm_decay_mult" in self.paramwise_cfg - or "dwconv_decay_mult" in self.paramwise_cfg - ): - if self.base_wd is None: - raise ValueError("base_wd should not be None") - - def _is_in(self, param_group: dict, param_group_list: list) -> bool: - """Check whether the `param_group` is in the`param_group_list`""" - assert is_list_of(param_group_list, dict) - param = set(param_group["params"]) - param_set = set() - for group in param_group_list: - param_set.update(set(group["params"])) - - return not param.isdisjoint(param_set) - - def add_params( - self, - params: list[dict], - module: nn.Module, - prefix: str = "", - is_dcn_module: int | float | None = None, - ) -> None: - """Add all parameters of module to the params list. - - The parameters of the given module will be added to the list of param - groups, with specific rules defined by paramwise_cfg. - - Args: - params (list[dict]): A list of param groups, it will be modified - in place. - module (nn.Module): The module to be added. - prefix (str): The prefix of the module - is_dcn_module (int|float|None): If the current module is a - submodule of DCN, `is_dcn_module` will be passed to - control conv_offset layer's learning rate. Defaults to None. 
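-
-        Example (an illustrative sketch; ``constructor`` is an assumed
-        ``DefaultOptimWrapperConstructor`` built with
-        ``paramwise_cfg=dict(norm_decay_mult=0.)``):
-
-            >>> params = []  # filled in place, one group per parameter
-            >>> constructor.add_params(params, model)
-            >>> # groups belonging to norm layers now carry weight_decay=0.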
- """ - # get param-wise options - custom_keys = self.paramwise_cfg.get("custom_keys", {}) - # first sort with alphabet order and then sort with reversed len of str - sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) - - bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", None) - bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", None) - norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", None) - dwconv_decay_mult = self.paramwise_cfg.get("dwconv_decay_mult", None) - flat_decay_mult = self.paramwise_cfg.get("flat_decay_mult", None) - bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False) - dcn_offset_lr_mult = self.paramwise_cfg.get("dcn_offset_lr_mult", None) - - # special rules for norm layers and depth-wise conv layers - is_norm = isinstance(module, GroupNorm | LayerNorm | _BatchNorm | _InstanceNorm) - is_dwconv = isinstance(module, torch.nn.Conv2d) and module.in_channels == module.groups - - for name, param in module.named_parameters(recurse=False): - param_group = {"params": [param]} - if bypass_duplicate and self._is_in(param_group, params): - print_log( - f"{prefix} is duplicate. It is skipped since bypass_duplicate={bypass_duplicate}", - logger="current", - level=logging.WARNING, - ) - continue - if not param.requires_grad: - print_log( - (f"{prefix}.{name} is skipped since its requires_grad={param.requires_grad}"), - logger="current", - level=logging.WARNING, - ) - continue - - # if the parameter match one of the custom keys, ignore other rules - is_custom = False - for key in sorted_keys: - if key in f"{prefix}.{name}": - is_custom = True - lr_mult = custom_keys[key].get("lr_mult", 1.0) - param_group["lr"] = self.base_lr * lr_mult - if self.base_wd is not None: - decay_mult = custom_keys[key].get("decay_mult", 1.0) - param_group["weight_decay"] = self.base_wd * decay_mult - # add custom settings to param_group - for k, v in custom_keys[key].items(): - param_group[k] = v - break - - if not is_custom: - # bias_lr_mult affects all bias parameters - # except for norm.bias dcn.conv_offset.bias - if name == "bias" and not (is_norm or is_dcn_module) and bias_lr_mult is not None: - param_group["lr"] = self.base_lr * bias_lr_mult - - if ( - prefix.find("conv_offset") != -1 - and is_dcn_module - and dcn_offset_lr_mult is not None - and isinstance(module, torch.nn.Conv2d) - ): - # deal with both dcn_offset's bias & weight - param_group["lr"] = self.base_lr * dcn_offset_lr_mult - - # apply weight decay policies - if self.base_wd is not None: - # norm decay - if is_norm and norm_decay_mult is not None: - param_group["weight_decay"] = self.base_wd * norm_decay_mult - # bias lr and decay - elif name == "bias" and not is_dcn_module and bias_decay_mult is not None: - param_group["weight_decay"] = self.base_wd * bias_decay_mult - # depth-wise conv - elif is_dwconv and dwconv_decay_mult is not None: - param_group["weight_decay"] = self.base_wd * dwconv_decay_mult - # flatten parameters except dcn offset - elif param.ndim == 1 and not is_dcn_module and flat_decay_mult is not None: - param_group["weight_decay"] = self.base_wd * flat_decay_mult - params.append(param_group) - for key, value in param_group.items(): - if key == "params": - continue - full_name = f"{prefix}.{name}" if prefix else name - print_log(f"paramwise_options -- {full_name}:{key}={value}", logger="current") - - # Removed the deformable convolutions because they're not used by the Swin MaskRCNN - is_dcn_module = False - for child_name, child_mod in module.named_children(): - 
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
-            self.add_params(params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module)
-
-    def __call__(self, model: nn.Module):  # -> OptimWrapper:
-        if hasattr(model, "module"):
-            model = model.module
-
-        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
-        optim_wrapper_cfg.setdefault("type", "OptimWrapper")
-        optimizer_cfg = self.optimizer_cfg.copy()
-        optimizer_cls = self.optimizer_cfg["type"]
-        # Optimizers like HybridAdam in colossalai require the argument name
-        # `model_params` rather than `params`. Here we get the first argument
-        # name and fill it with the model parameters.
-        if isinstance(optimizer_cls, str):
-            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
-                optimizer_cls = registry.get(self.optimizer_cfg["type"])
-        first_arg_name = next(iter(inspect.signature(optimizer_cls).parameters))
-        # if no paramwise option is specified, just use the global setting
-        if not self.paramwise_cfg:
-            optimizer_cfg[first_arg_name] = model.parameters()
-            optimizer = OPTIMIZERS.build(optimizer_cfg)
-        else:
-            # set param-wise lr and weight decay recursively
-            params: list = []
-            self.add_params(params, model)
-            optimizer_cfg[first_arg_name] = params
-            optimizer = OPTIMIZERS.build(optimizer_cfg)
-        optim_wrapper = OPTIM_WRAPPERS.build(optim_wrapper_cfg, default_args={"optimizer": optimizer})
-        return optim_wrapper
diff --git a/libs/visengine/visengine/optim/optimizer/optimizer_wrapper.py b/libs/visengine/visengine/optim/optimizer/optimizer_wrapper.py
deleted file mode 100644
index 1bde7d6..0000000
--- a/libs/visengine/visengine/optim/optimizer/optimizer_wrapper.py
+++ /dev/null
@@ -1,532 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-from contextlib import contextmanager
-
-import torch
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from visengine.logging import MessageHub, print_log, MMLogger
-from visengine.registry import OPTIM_WRAPPERS
-from visengine.utils.dl_utils import has_batch_norm
-
-from .base import BaseOptimWrapper
-
-
-@OPTIM_WRAPPERS.register_module(force=True)
-class OptimWrapper(BaseOptimWrapper):
-    """Optimizer wrapper provides a common interface for updating parameters.
-
-    Optimizer wrapper provides a unified interface for single precision
-    training and automatic mixed precision training with different hardware.
-    OptimWrapper encapsulates the optimizer to provide simplified interfaces
-    for commonly used training techniques such as gradient accumulation and
-    gradient clipping. ``OptimWrapper`` implements the basic logic of gradient
-    accumulation and gradient clipping based on ``torch.optim.Optimizer``.
-    The subclasses only need to override some methods to implement the mixed
-    precision training. See more information in :class:`AmpOptimWrapper`.
-
-    Args:
-        optimizer (Optimizer): Optimizer used to update model parameters.
-        accumulative_counts (int): The number of iterations to accumulate
-            gradients. The parameters will be updated per
-            ``accumulative_counts``.
-        clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
-            the arguments of :func:`torch.nn.utils.clip_grad_norm_` or
-            :func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a
-            dict, and the keys could be set as follows:
-
-            If the key ``type`` is not set, or ``type`` is "norm",
-            the accepted keys are as follows:
-
-            - max_norm (float or int): Max norm of the gradients.
-            - norm_type (float or int): Type of the used p-norm.
Can be - ``'inf'`` for infinity norm. - - error_if_nonfinite (bool): If True, an error is thrown if - the total norm of the gradients from :attr:`parameters` is - ``nan``, ``inf``, or ``-inf``. Defaults to False (will switch - to True in the future) - - If the key ``type`` is set to "value", the accepted keys are as - follows: - - - clip_value (float or int): maximum allowed value of the - gradients. The gradients are clipped in the range - ``(-clip_value, +clip_value)``. - - Note: - If ``accumulative_counts`` is larger than 1, perform - :meth:`update_params` under the context of ``optim_context`` - could avoid unnecessary gradient synchronization. - - Note: - If you use ``IterBasedRunner`` and enable gradient accumulation, - the original `max_iters` should be multiplied by - ``accumulative_counts``. - - Note: - The subclass should ensure that once :meth:`update_params` is called, - ``_inner_count += 1`` is automatically performed. - - Examples: - >>> # Config sample of OptimWrapper and enable clipping gradient by - >>> # norm. - >>> optim_wrapper_cfg = dict( - >>> type='OptimWrapper', - >>> _accumulative_counts=1, - >>> clip_grad=dict(max_norm=0.2)) - >>> # Config sample of OptimWrapper and enable clipping gradient by - >>> # value. - >>> optim_wrapper_cfg = dict( - >>> type='OptimWrapper', - >>> _accumulative_counts=1, - >>> clip_grad=dict(type='value', clip_value=0.2)) - >>> # Use OptimWrapper to update model. - >>> import torch.nn as nn - >>> import torch - >>> from torch.optim import SGD - >>> from torch.utils.data import DataLoader - >>> from visengine.optim import OptimWrapper - >>> - >>> model = nn.Linear(1, 1) - >>> dataset = torch.randn(10, 1, 1) - >>> dataloader = DataLoader(dataset) - >>> optimizer = SGD(model.parameters(), lr=0.1) - >>> optim_wrapper = OptimWrapper(optimizer) - >>> - >>> for data in dataloader: - >>> loss = model(data) - >>> optim_wrapper.update_params(loss) - >>> # Enable gradient accumulation - >>> optim_wrapper_cfg = dict( - >>> type='OptimWrapper', - >>> _accumulative_counts=3, - >>> clip_grad=dict(max_norm=0.2)) - >>> ddp_model = DistributedDataParallel(model) - >>> optimizer = SGD(ddp_model.parameters(), lr=0.1) - >>> optim_wrapper = OptimWrapper(optimizer) - >>> optim_wrapper.initialize_count_status(0, len(dataloader)) - >>> # If model is a subclass instance of DistributedDataParallel, - >>> # `optim_context` context manager can avoid unnecessary gradient - >>> # synchronize. - >>> for iter, data in enumerate(dataloader): - >>> with optim_wrapper.optim_context(ddp_model): - >>> loss = model(data) - >>> optim_wrapper.update_params(loss) - """ - - def __init__( - self, - optimizer: Optimizer, - accumulative_counts: int = 1, - clip_grad: dict | None = None, - ): - assert accumulative_counts > 0, "_accumulative_counts at least greater than or equal to 1" - self._accumulative_counts = accumulative_counts - self.optimizer = optimizer - - if clip_grad is not None: - # clip_grad_kwargs should not be non-empty dict. - assert isinstance(clip_grad, dict) and clip_grad, ( - "If `clip_grad` is not None, it should be a `dict` which is the arguments of `torch.nn.utils.clip_grad_norm_` or clip_grad_value_`." 
- ) - clip_type = clip_grad.pop("type", "norm") - if clip_type == "norm": - self.clip_func = torch.nn.utils.clip_grad_norm_ - self.grad_name = "grad_norm" - elif clip_type == "value": - self.clip_func = torch.nn.utils.clip_grad_value_ - self.grad_name = "grad_value" - else: - raise ValueError(f'type of clip_grad should be "norm" or "value" but got {clip_type}') - assert clip_grad, ( - "`clip_grad` should contain other arguments " - "besides `type`. The arguments should match " - "with the `torch.nn.utils.clip_grad_norm_` or " - "clip_grad_value_`" - ) - self.clip_grad_kwargs = clip_grad - # Used to update `grad_norm` log message. - self.message_hub = MessageHub.get_current_instance() - self._inner_count = 0 - # `_max_counts` means the total number of parameter updates. It - # ensures that the gradient of the last few iterations will not be - # lost when the `_max_counts` is not divisible by - # `accumulative_counts`. - self._max_counts = -1 - # The `_remainder_iter` is used for calculating loss factor at the - # last few iterations. If `_max_counts` has not been initialized, - # the loss factor will always be the same as `_accumulative_counts`. - self._remainder_counts = -1 - - # The Following code is used to initialize `base_param_settings`. - # `base_param_settings` is used to store the parameters that are not - # updated by the optimizer. - # The `base_param_settings` used for tracking the base learning in the - # optimizer. If the optimizer has multiple parameter groups, this - # params will not be scaled by the loss factor. - if len(optimizer.param_groups) > 1: - self.base_param_settings = {"params": torch.tensor([0.0], dtype=torch.float)} - self.base_param_settings.update(**self.optimizer.defaults) - else: - self.base_param_settings = None # type: ignore - - def update_params( # type: ignore - self, - loss: torch.Tensor, - step_kwargs: dict | None = None, - zero_kwargs: dict | None = None, - ) -> None: - """Update parameters in :attr:`optimizer`. - - Args: - loss (torch.Tensor): A tensor for back propagation. - step_kwargs (dict): Arguments for optimizer.step. - Defaults to None. - New in version v0.4.0. - zero_kwargs (dict): Arguments for optimizer.zero_grad. - Defaults to None. - New in version v0.4.0. - """ - logger = MMLogger.get_current_instance() - - # Debug log: Check loss before scaling - logger.debug(f"[OptimWrapper] update_params called with loss: {loss.item()}") - if torch.isnan(loss): - logger.error(f"[OptimWrapper] NaN loss detected before scaling! Loss: {loss}") - if torch.isinf(loss): - logger.error(f"[OptimWrapper] Inf loss detected before scaling! Loss: {loss}") - - if step_kwargs is None: - step_kwargs = {} - if zero_kwargs is None: - zero_kwargs = {} - - loss = self.scale_loss(loss) - - # Debug log: Check loss after scaling - logger.debug(f"[OptimWrapper] Loss after scaling: {loss.item()}") - if torch.isnan(loss): - logger.error(f"[OptimWrapper] NaN loss detected after scaling! Scaled loss: {loss}") - if torch.isinf(loss): - logger.error(f"[OptimWrapper] Inf loss detected after scaling! Scaled loss: {loss}") - - self.backward(loss) - # Update parameters only if `self._inner_count` is divisible by - # `self._accumulative_counts` or `self._inner_count` equals to - # `self._max_counts` - if self.should_update(): - logger.debug(f"[OptimWrapper] Updating parameters at inner_count: {self._inner_count}") - self.step(**step_kwargs) - self.zero_grad(**zero_kwargs) - - def backward(self, loss: torch.Tensor, **kwargs) -> None: - """Perform gradient back propagation. 
- - Provide unified ``backward`` interface compatible with automatic mixed - precision training. Subclass can overload this method to implement the - required logic. For example, ``torch.cuda.amp`` require some extra - operation on GradScaler during backward process. - - Note: - If subclasses inherit from ``OptimWrapper`` override - ``backward``, ``_inner_count +=1`` must be implemented. - - Args: - loss (torch.Tensor): The loss of current iteration. - kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`. - """ - logger = MMLogger.get_current_instance() - logger.debug(f"[OptimWrapper] Starting backward pass with loss: {loss.item()}") - - # Check for NaN/Inf before backward - if torch.isnan(loss): - logger.error(f"[OptimWrapper] NaN loss detected before backward! Cannot compute gradients.") - if torch.isinf(loss): - logger.error(f"[OptimWrapper] Inf loss detected before backward! Cannot compute gradients.") - - loss.backward(**kwargs) - - # Check gradients after backward - logger.debug(f"[OptimWrapper] Backward pass completed, checking gradients...") - self._check_gradients() - - self._inner_count += 1 - - def zero_grad(self, **kwargs) -> None: - """A wrapper of ``Optimizer.zero_grad``. - - Provide unified ``zero_grad`` interface compatible with automatic mixed - precision training. Subclass can overload this method to implement the - required logic. - - Args: - kwargs: Keyword arguments passed to - :meth:`torch.optim.Optimizer.zero_grad`. - """ - self.optimizer.zero_grad(**kwargs) - - def step(self, **kwargs) -> None: - """A wrapper of ``Optimizer.step``. - - Provide unified ``step`` interface compatible with automatic mixed - precision training. Subclass can overload this method to implement the - required logic. For example, ``torch.cuda.amp`` require some extra - operation on ``GradScaler`` during step process. - - Clip grad if :attr:`clip_grad_kwargs` is not None, and then update - parameters. - - Args: - kwargs: Keyword arguments passed to - :meth:`torch.optim.Optimizer.step`. - """ - logger = MMLogger.get_current_instance() - logger.debug(f"[OptimWrapper] Starting optimizer step...") - - # Check gradients before clipping - self._check_gradients("before clipping") - - if self.clip_grad_kwargs: - self._clip_grad() - # Check gradients after clipping - self._check_gradients("after clipping") - - # Check parameters before step - self._check_parameters("before step") - - self.optimizer.step(**kwargs) - - # Check parameters after step - self._check_parameters("after step") - - logger.debug(f"[OptimWrapper] Optimizer step completed") - - @contextmanager - def optim_context(self, model: nn.Module): - """A Context for gradient accumulation and automatic mix precision - training. - - If subclasses need to enable the context for mix precision training, - e.g., ``:class:`AmpOptimWrapper``, the corresponding context should be - enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32 - training, ``optim_context`` will only enable the context for - blocking the unnecessary gradient synchronization during gradient - accumulation - - If model is an instance with ``no_sync`` method (which means - blocking the gradient synchronization) and - ``self._accumulative_counts != 1``. The model will not automatically - synchronize gradients if ``cur_iter`` is divisible by - ``self._accumulative_counts``. Otherwise, this method will enable an - empty context. - - Args: - model (nn.Module): The training model. 
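-
-        Example (an illustrative sketch; assumes ``accumulative_counts=2``
-        and a DDP-wrapped ``ddp_model`` that exposes ``no_sync``):
-
-            >>> with optim_wrapper.optim_context(ddp_model):
-            ...     loss = ddp_model(data)
-            ...     # gradient all-reduce is skipped on the iterations
-            ...     # where parameters are not updated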
- """ - # During gradient accumulation process, the gradient synchronize - # should only happen before updating parameters. - if not self.should_sync() and hasattr(model, "no_sync"): - with model.no_sync(): - yield - else: - yield - - def _clip_grad(self) -> None: - """Clip the gradients of parameters.""" - params: list[torch.Tensor] = [] - for param_group in self.optimizer.param_groups: - params.extend(param_group["params"]) - - params = list(filter(lambda p: p.requires_grad and p.grad is not None, params)) - if len(params) > 0: - grad = self.clip_func(params, **self.clip_grad_kwargs) - # `torch.nn.utils.clip_grad_value_` will return None. - if grad is not None: - self.message_hub.update_scalar(f"train/{self.grad_name}", float(grad)) - - def initialize_count_status(self, model: nn.Module, init_counts: int, max_counts: int) -> None: - """Initialize gradient accumulation related attributes. - - ``OptimWrapper`` can be used without calling - ``initialize_iter_status``. However, Consider the case of ``len( - dataloader) == 10``, and the ``accumulative_iter == 3``. Since 10 is - not divisible by 3, the last iteration will not trigger - ``optimizer.step()``, resulting in one less parameter updating. - - Args: - model (nn.Module): Training model - init_counts (int): The initial value of the inner count. - max_counts (int): The maximum value of the inner count. - """ - self._inner_count = init_counts - self._max_counts = max_counts - if self._inner_count % self._accumulative_counts != 0: - print_log( - "Resumed iteration number is not divisible by " - "`_accumulative_counts` in `GradientCumulativeOptimizerHook`, " - "which means the gradient of some iterations is lost and the " - "result may be influenced slightly.", - logger="current", - level=logging.WARNING, - ) - - if has_batch_norm(model) and self._accumulative_counts > 1: - print_log( - "Gradient accumulative may slightly decrease performance because the model has BatchNorm layers.", - logger="current", - level=logging.WARNING, - ) - # Remainder of `_max_counts` divided by `_accumulative_counts` - self._remainder_counts = self._max_counts % self._accumulative_counts - - def should_update(self) -> bool: - """Decide whether the parameters should be updated at the current - iteration. - - Called by :meth:`update_params` and check whether the optimizer - wrapper should update parameters at current iteration. - - Returns: - bool: Whether to update parameters. - """ - return self._inner_count % self._accumulative_counts == 0 or self._inner_count == self._max_counts - - def should_sync(self) -> bool: - """Decide whether the automatic gradient synchronization should be - allowed at the current iteration. - - It takes effect when gradient accumulation is used to skip - synchronization at the iterations where the parameter is not updated. - - Since ``should_sync`` is called by :meth:`optim_context`, and it is - called before :meth:`backward` which means ``self._inner_count += 1`` - has not happened yet. Therefore, ``self._inner_count += 1`` should be - performed manually here. - - Returns: - bool: Whether to block the automatic gradient synchronization. - """ - return (self._inner_count + 1) % self._accumulative_counts == 0 or (self._inner_count + 1) == self._max_counts - - def scale_loss(self, loss: torch.Tensor) -> torch.Tensor: - """Get scaled loss according to ``_accumulative_counts``, - ``_inner_count`` and max_counts. - - Args: - loss (torch.Tensor): Original loss calculated by model. - - Returns: - loss (torch.Tensor): Scaled loss. 
- """ - if self._accumulative_counts == 1: - # update parameters without gradient accumulation. The gradient - # should not be rescaled and `loss_factor=1`. - loss_factor = 1 - elif self._max_counts == -1: - loss_factor = self._accumulative_counts - else: - # if `self._accumulative_counts > 1`, the gradient needs to be - # rescaled and accumulated. In most cases, `loss_factor` equals to - # `self._accumulative_counts`. However, `self._max_counts` may not - # be divisible by `self._accumulative_counts`, so the - # `loss_scale` for the last few iterations needs to be - # recalculated. - if self._inner_count < self._max_counts - self._remainder_counts: - loss_factor = self._accumulative_counts - else: - loss_factor = self._remainder_counts - assert loss_factor > 0, ( - "loss_factor should be larger than zero! This error could " - "happened when initialize_iter_status called with an " - "error `init_counts` or `max_counts`" - ) - - loss = loss / loss_factor - return loss - - @property - def inner_count(self): - """Get the number of updating parameters of optimizer wrapper.""" - return self._inner_count - - def __repr__(self): - wrapper_info = f"Type: {type(self).__name__}\n_accumulative_counts: {self._accumulative_counts}\noptimizer: \n" - optimizer_str = repr(self.optimizer) + "\n" - return wrapper_info + optimizer_str - - def _check_gradients(self, stage: str = "") -> None: - """Check gradients for NaN/Inf values and log statistics.""" - logger = MMLogger.get_current_instance() - - nan_grad_params = [] - inf_grad_params = [] - zero_grad_params = [] - normal_grad_params = 0 - total_grad_norm = 0.0 - - for i, param_group in enumerate(self.optimizer.param_groups): - for j, param in enumerate(param_group["params"]): - if param.grad is not None: - grad_norm = param.grad.norm().item() - total_grad_norm += grad_norm**2 - - if torch.isnan(param.grad).any(): - nan_grad_params.append(f"group_{i}_param_{j}") - logger.error(f"[OptimWrapper] NaN gradient detected {stage} in param group {i}, param {j}") - elif torch.isinf(param.grad).any(): - inf_grad_params.append(f"group_{i}_param_{j}") - logger.error(f"[OptimWrapper] Inf gradient detected {stage} in param group {i}, param {j}") - elif grad_norm == 0: - zero_grad_params.append(f"group_{i}_param_{j}") - else: - normal_grad_params += 1 - - total_grad_norm = total_grad_norm**0.5 - - if nan_grad_params or inf_grad_params: - logger.error( - f"[OptimWrapper] Gradient check {stage}: " - f"NaN params: {len(nan_grad_params)}, " - f"Inf params: {len(inf_grad_params)}, " - f"Zero grad params: {len(zero_grad_params)}, " - f"Normal params: {normal_grad_params}" - ) - else: - logger.debug( - f"[OptimWrapper] Gradient check {stage}: " - f"Total grad norm: {total_grad_norm:.6f}, " - f"Zero grad params: {len(zero_grad_params)}, " - f"Normal params: {normal_grad_params}" - ) - - def _check_parameters(self, stage: str = "") -> None: - """Check parameters for NaN/Inf values and log statistics.""" - logger = MMLogger.get_current_instance() - - nan_params = [] - inf_params = [] - normal_params = 0 - - for i, param_group in enumerate(self.optimizer.param_groups): - for j, param in enumerate(param_group["params"]): - if torch.isnan(param).any(): - nan_params.append(f"group_{i}_param_{j}") - logger.error(f"[OptimWrapper] NaN parameter detected {stage} in param group {i}, param {j}") - elif torch.isinf(param).any(): - inf_params.append(f"group_{i}_param_{j}") - logger.error(f"[OptimWrapper] Inf parameter detected {stage} in param group {i}, param {j}") - else: - normal_params += 
1
-
-        if nan_params or inf_params:
-            logger.error(
-                f"[OptimWrapper] Parameter check {stage}: "
-                f"NaN params: {len(nan_params)}, "
-                f"Inf params: {len(inf_params)}, "
-                f"Normal params: {normal_params}"
-            )
-        else:
-            logger.debug(f"[OptimWrapper] Parameter check {stage}: All {normal_params} parameters are normal")
diff --git a/libs/visengine/visengine/optim/optimizer/optimizer_wrapper_dict.py b/libs/visengine/visengine/optim/optimizer/optimizer_wrapper_dict.py
deleted file mode 100644
index d462f8f..0000000
--- a/libs/visengine/visengine/optim/optimizer/optimizer_wrapper_dict.py
+++ /dev/null
@@ -1,187 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-from collections.abc import Iterator
-from contextlib import contextmanager
-
-import torch
-import torch.nn as nn
-
-from .optimizer_wrapper import OptimWrapper
-
-
-class OptimWrapperDict(OptimWrapper):
-    """A dictionary container of :obj:`OptimWrapper`.
-
-    If the runner is training with multiple optimizers, all optimizer
-    wrappers should be managed by :obj:`OptimWrapperDict`, which is built by
-    ``CustomOptimWrapperConstructor``. ``OptimWrapperDict`` will load and save
-    the state dictionary of all optimizer wrappers.
-
-    Considering the semantic ambiguity of calling :meth:`update_params` or
-    :meth:`backward` of all optimizer wrappers at once, ``OptimWrapperDict``
-    does not implement these methods.
-
-    Examples:
-        >>> import torch.nn as nn
-        >>> from torch.optim import SGD
-        >>> from visengine.optim import OptimWrapperDict, OptimWrapper
-        >>> model1 = nn.Linear(1, 1)
-        >>> model2 = nn.Linear(1, 1)
-        >>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1))
-        >>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1))
-        >>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1,
-        >>>                                       model2=optim_wrapper2)
-
-    Note:
-        The optimizer wrappers contained in ``OptimWrapperDict`` can be
-        accessed in the same way as a `dict`.
-
-    Args:
-        **optim_wrappers: A dictionary of ``OptimWrapper`` instances.
-    """
-
-    def __init__(self, **optim_wrapper_dict: OptimWrapper):
-        for key, value in optim_wrapper_dict.items():
-            assert isinstance(value, OptimWrapper), (
-                f"`OptimWrapperDict` only accepts OptimWrapper instances, but got {key}: {type(value)}"
-            )
-        self.optim_wrappers = optim_wrapper_dict
-
-    def update_params(  # type: ignore
-        self,
-        loss: torch.Tensor,
-        step_kwargs: dict | None = None,
-        zero_kwargs: dict | None = None,
-    ) -> None:
-        """Updating all optimizer wrappers at once would lead to duplicate
-        backward errors, and ``OptimWrapperDict`` does not know which
-        optimizer wrapper should be updated.
-
-        Therefore, this method is not implemented. Access the target
-        optimizer wrapper in ``OptimWrapperDict`` and call its
-        ``update_params`` instead.
-        """
-        raise NotImplementedError("`update_params` should be called by each optimizer separately")
-
-    def backward(self, loss: torch.Tensor, **kwargs) -> None:
-        """Since ``OptimWrapperDict`` doesn't know which optimizer wrapper's
-        backward method should be called (``loss_scaler`` may differ between
-        :obj:`AmpOptimWrapper` instances), this method is not implemented.
-
-        Access the target optimizer wrapper in ``OptimWrapperDict`` and call
-        its ``backward`` instead.
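-
-        Example (illustrative; ``loss_gen`` and ``loss_disc`` are assumed
-        per-submodule losses for the wrappers built in the class docstring):
-
-            >>> optim_wrapper_dict['model1'].backward(loss_gen)
-            >>> optim_wrapper_dict['model2'].backward(loss_disc)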
- """ - raise NotImplementedError("`backward` should be called by each optimizer separately`") - - def step(self, **kwargs) -> None: - """Since the backward method is not implemented, the step should not be - implemented either.""" - raise NotImplementedError("`step` should be called by each optimizer separately`") - - def zero_grad(self, **kwargs) -> None: - """Set the gradients of all optimizer wrappers to zero.""" - for optim_wrapper in self.optim_wrappers.values(): - optim_wrapper.zero_grad() - - @contextmanager - def optim_context(self, model: nn.Module): - """``optim_context`` should be called by each optimizer separately.""" - raise NotImplementedError("`optim_context` should be called by each optimizer separately") - - def initialize_count_status(self, model: nn.Module, cur_iter, max_iters) -> None: - """Do nothing but provide unified interface for :obj:`OptimWrapper` - - Since ``OptimWrapperDict`` does not know the correspondence between - model and optimizer wrapper. ``initialize_iter_status`` will do nothing - and each optimizer wrapper should call ``initialize_iter_status`` - separately. - """ - return - - @property - def param_groups(self): - """Returns the parameter groups of each OptimWrapper.""" - param_groups = {} - for key, value in self.optim_wrappers.items(): - param_groups[key] = value.param_groups - return param_groups - - def get_lr(self) -> dict[str, list[float]]: - """Get the learning rate of all optimizers. - - Returns: - Dict[str, List[float]]: Learning rate of all optimizers. - """ - lr_dict = {} - for name, optim_wrapper in self.optim_wrappers.items(): - inner_lr_dict = optim_wrapper.get_lr() - if "base_lr" in inner_lr_dict: - lr_dict[f"{name}.base_lr"] = inner_lr_dict["base_lr"] - lr_dict[f"{name}.lr"] = inner_lr_dict["lr"] - return lr_dict - - def get_momentum(self) -> dict[str, list[float]]: - """Get the momentum of all optimizers. - - Returns: - Dict[str, List[float]]: momentum of all optimizers. - """ - momentum_dict = {} - for name, optim_wrapper in self.optim_wrappers.items(): - momentum_dict[f"{name}.momentum"] = optim_wrapper.get_momentum()["momentum"] - return momentum_dict - - def state_dict(self) -> dict: - """Get the state dictionary of all optimizer wrappers. - - Returns: - dict: Each key-value pair in the dictionary represents the name - and state dictionary of corresponding :obj:`OptimWrapper`. - """ - state_dict = {} - for name, optim_wrapper in self.optim_wrappers.items(): - state_dict[name] = optim_wrapper.state_dict() - return state_dict - - def load_state_dict(self, state_dict: dict) -> None: - """Load the state dictionary from the ``state_dict``. - - Args: - state_dict (dict): Each key-value pair in `state_dict` represents - the name and the state dictionary of corresponding - :obj:`OptimWrapper`. - """ - for name, _state_dict in state_dict.items(): - assert name in self.optim_wrappers, f"Mismatched `state_dict`! 
cannot found {name} in OptimWrapperDict" - self.optim_wrappers[name].load_state_dict(_state_dict) - - def items(self) -> Iterator[tuple[str, OptimWrapper]]: - """A generator to get the name and corresponding :obj:`OptimWrapper`""" - yield from self.optim_wrappers.items() - - def values(self) -> Iterator[OptimWrapper]: - """A generator to get :obj:`OptimWrapper`""" - yield from self.optim_wrappers.values() - - def keys(self) -> Iterator[str]: - """A generator to get the name of :obj:`OptimWrapper`""" - yield from self.optim_wrappers.keys() - - def __getitem__(self, key: str) -> OptimWrapper: - assert key in self.optim_wrappers, ( - f"Cannot find {key} in OptimWrapperDict, please check your optimizer constructor." - ) - return self.optim_wrappers[key] - - def __contains__(self, key: str) -> bool: - return key in self.optim_wrappers - - def __len__(self) -> int: - return len(self.optim_wrappers) - - def __repr__(self) -> str: - desc = "" - for name, optim_wrapper in self.optim_wrappers.items(): - desc += f"name: {name}\n" - desc += repr(optim_wrapper) - return desc diff --git a/libs/visengine/visengine/optim/scheduler/__init__.py b/libs/visengine/visengine/optim/scheduler/__init__.py deleted file mode 100644 index fc9dbb3..0000000 --- a/libs/visengine/visengine/optim/scheduler/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -# yapf: disable -from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR, - ExponentialLR, LinearLR, MultiStepLR, OneCycleLR, - PolyLR, ReduceOnPlateauLR, StepLR) -from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum, - CosineRestartMomentum, ExponentialMomentum, - LinearMomentum, MultiStepMomentum, - PolyMomentum, ReduceOnPlateauMomentum, - StepMomentum) -from .param_scheduler import (ConstantParamScheduler, - CosineAnnealingParamScheduler, - CosineRestartParamScheduler, - ExponentialParamScheduler, LinearParamScheduler, - MultiStepParamScheduler, OneCycleParamScheduler, - PolyParamScheduler, - ReduceOnPlateauParamScheduler, - StepParamScheduler, _ParamScheduler) - -# yapf: enable -__all__ = [ - "ConstantLR", - "ConstantMomentum", - "ConstantParamScheduler", - "CosineAnnealingLR", - "CosineAnnealingMomentum", - "CosineAnnealingParamScheduler", - "CosineRestartLR", - "CosineRestartMomentum", - "CosineRestartParamScheduler", - "ExponentialLR", - "ExponentialMomentum", - "ExponentialParamScheduler", - "LinearLR", - "LinearMomentum", - "LinearParamScheduler", - "MultiStepLR", - "MultiStepMomentum", - "MultiStepParamScheduler", - "OneCycleLR", - "OneCycleParamScheduler", - "PolyLR", - "PolyMomentum", - "PolyParamScheduler", - "ReduceOnPlateauLR", - "ReduceOnPlateauMomentum", - "ReduceOnPlateauParamScheduler", - "StepLR", - "StepMomentum", - "StepParamScheduler", - "_ParamScheduler", -] diff --git a/libs/visengine/visengine/optim/scheduler/lr_scheduler.py b/libs/visengine/visengine/optim/scheduler/lr_scheduler.py deleted file mode 100644 index 2e5d894..0000000 --- a/libs/visengine/visengine/optim/scheduler/lr_scheduler.py +++ /dev/null @@ -1,386 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
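-# Each scheduler below is a thin subclass: `LRSchedulerMixin` pins the
-# scheduled quantity to "lr", while the paired `*ParamScheduler` base class
-# supplies the schedule logic. An illustrative config sketch (the values
-# are examples, not defaults shipped with the repo):
-#
-#     param_scheduler = [
-#         dict(type="LinearLR", start_factor=0.001, by_epoch=False,
-#              begin=0, end=500),
-#         dict(type="MultiStepLR", by_epoch=True, milestones=[8, 11],
-#              gamma=0.1),
-#     ]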
diff --git a/libs/visengine/visengine/optim/scheduler/lr_scheduler.py b/libs/visengine/visengine/optim/scheduler/lr_scheduler.py
deleted file mode 100644
index 2e5d894..0000000
--- a/libs/visengine/visengine/optim/scheduler/lr_scheduler.py
+++ /dev/null
@@ -1,386 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-from visengine.registry import PARAM_SCHEDULERS
-
-# yapf: disable
-from .param_scheduler import (
-    ConstantParamScheduler,
-    CosineAnnealingParamScheduler,
-    CosineRestartParamScheduler,
-    ExponentialParamScheduler,
-    LinearParamScheduler,
-    MultiStepParamScheduler,
-    OneCycleParamScheduler,
-    PolyParamScheduler,
-    ReduceOnPlateauParamScheduler,
-    StepParamScheduler,
-)
-
-# yapf: enable
-
-
-class LRSchedulerMixin:
-    """A mixin class for learning rate schedulers."""
-
-    def __init__(self, optimizer, *args, **kwargs):
-        super().__init__(optimizer, "lr", *args, **kwargs)
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class ConstantLR(LRSchedulerMixin, ConstantParamScheduler):
-    """Decays the learning rate value of each parameter group by a small
-    constant factor until the number of epochs reaches a pre-defined milestone:
-    ``end``. Notice that such decay can happen simultaneously with other
-    changes to the learning rate value from outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
-        factor (float): The number we multiply learning rate until the
-            milestone. Defaults to 1./3.
-        begin (int): Step at which to start updating the learning rate.
-            Defaults to 0.
-        end (int): Step at which to stop updating the learning rate.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without state
-            dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled learning rate is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the learning rate for each update.
-            Defaults to False.
-    """
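`LRSchedulerMixin` is the entire mechanism that specializes the generic parameter schedulers to learning rates: it pins `param_name` to `"lr"` and forwards everything else. The momentum variants later in this diff follow the same pattern. A hypothetical re-derivation of the pattern for a different parameter:

```python
# Hypothetical: the same two-line mixin pattern, pinned to weight decay
# instead of lr. Combined with any *ParamScheduler via multiple inheritance,
# the cooperative super().__init__ call threads the pinned name through.
class WeightDecaySchedulerMixin:
    def __init__(self, optimizer, *args, **kwargs):
        super().__init__(optimizer, "weight_decay", *args, **kwargs)


# class StepWeightDecay(WeightDecaySchedulerMixin, StepParamScheduler):
#     """Would decay group["weight_decay"] by gamma every step_size epochs."""
```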
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler):
-    r"""Set the learning rate of each parameter group using a cosine annealing
-    schedule, where :math:`\eta_{max}` is set to the initial value and
-    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
-
-    .. math::
-        \begin{aligned}
-            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
-            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
-            & T_{cur} \neq (2k+1)T_{max}; \\
-            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
-            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
-            & T_{cur} = (2k+1)T_{max}.
-        \end{aligned}
-
-    Notice that because the schedule
-    is defined recursively, the learning rate can be simultaneously modified
-    outside this scheduler by other operators. If the learning rate is set
-    solely by this scheduler, the learning rate at each step becomes:
-
-    .. math::
-        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
-
-    It has been proposed in
-    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this
-    only implements the cosine annealing part of SGDR, and not the restarts.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
-        T_max (int): Maximum number of iterations.
-        eta_min (float): Minimum learning rate. Defaults to None.
-        begin (int): Step at which to start updating the learning rate.
-            Defaults to 0.
-        end (int): Step at which to stop updating the learning rate.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled learning rate is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the learning rate for each update.
-            Defaults to False.
-        eta_min_ratio (float, optional): The ratio of the minimum parameter
-            value to the base parameter value. Either `eta_min` or
-            `eta_min_ratio` should be specified. Defaults to None.
-            New in version 0.3.2.
-
-    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-        https://arxiv.org/abs/1608.03983
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler):
-    """Decays the learning rate of each parameter group by gamma every epoch.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
-        gamma (float): Multiplicative factor of learning rate decay.
-        begin (int): Step at which to start updating the learning rate.
-            Defaults to 0.
-        end (int): Step at which to stop updating the learning rate.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled learning rate is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the learning rate for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class LinearLR(LRSchedulerMixin, LinearParamScheduler):
-    """Decays the learning rate of each parameter group by linearly changing a
-    small multiplicative factor until the number of epochs reaches a
-    pre-defined milestone: ``end``.
-
-    Notice that such decay can happen simultaneously with other changes to the
-    learning rate from outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
-        start_factor (float): The number we multiply learning rate in the
-            first epoch. The multiplication factor changes towards end_factor
-            in the following epochs. Defaults to 1./3.
-        end_factor (float): The number we multiply learning rate at the end
-            of linear changing process. Defaults to 1.0.
-        begin (int): Step at which to start updating the learning rate.
-            Defaults to 0.
-        end (int): Step at which to stop updating the learning rate.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled learning rate is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the learning rate for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler):
-    """Decays the specified learning rate in each parameter group by gamma once
-    the number of epochs reaches one of the milestones. Notice that such decay
-    can happen simultaneously with other changes to the learning rate from
-    outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
-        milestones (list): List of epoch indices. Must be increasing.
-        gamma (float): Multiplicative factor of learning rate decay.
-            Defaults to 0.1.
-        begin (int): Step at which to start updating the learning rate.
-            Defaults to 0.
-        end (int): Step at which to stop updating the learning rate.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled learning rate is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the learning rate for each update.
-            Defaults to False.
- """ - - -@PARAM_SCHEDULERS.register_module(force=True) -class StepLR(LRSchedulerMixin, StepParamScheduler): - """Decays the learning rate of each parameter group by gamma every - step_size epochs. Notice that such decay can happen simultaneously with - other changes to the learning rate from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - step_size (int): Period of learning rate decay. - gamma (float): Multiplicative factor of learning rate decay. - Defaults to 0.1. - begin (int): Step at which to start updating the learning rate. - Defaults to 0. - end (int): Step at which to stop updating the learning rate. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled learning rate is updated by - epochs. Defaults to True. - verbose (bool): Whether to print the learning rate for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module(force=True) -class PolyLR(LRSchedulerMixin, PolyParamScheduler): - """Decays the learning rate of each parameter group in a polynomial decay - scheme. - - Notice that such decay can happen simultaneously with other changes to the - parameter value from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): Wrapped optimizer. - eta_min (float): Minimum learning rate at the end of scheduling. - Defaults to 0. - power (float): The power of the polynomial. Defaults to 1.0. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module(force=True) -class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler): - r"""Sets the learning rate of each parameter group according to the 1cycle - learning rate policy. The 1cycle policy anneals the learning rate from an - initial learning rate to some maximum learning rate and then from that - maximum learning rate to some minimum learning rate much lower than the - initial learning rate. This policy was initially described in the paper - `Super-Convergence: Very Fast Training of Neural Networks Using Large - Learning Rates`_. - - The 1cycle learning rate policy changes the learning rate after every - batch. `step` should be called after a batch has been used for training. - - This scheduler is not chainable. - - Note also that the total number of steps in the cycle can be determined in - one of two ways (listed in order of precedence): - - #. A value for total_steps is explicitly provided. - #. A number of epochs (epochs) and a number of steps per epoch - (steps_per_epoch) are provided. - In this case, the number of total steps is inferred by - total_steps = epochs * steps_per_epoch - - You must either provide a value for total_steps or provide a value for both - epochs and steps_per_epoch. - - The default behaviour of this scheduler follows the fastai implementation - of 1cycle, which claims that "unpublished work has shown even better - results by using only two phases". To mimic the behaviour of the original - paper instead, set ``three_phase=True``. - - Args: - optimizer (Optimizer): Wrapped optimizer. 
-        eta_max (float or list): Upper parameter value boundaries in the cycle
-            for each parameter group.
-        total_steps (int): The total number of steps in the cycle. Note that
-            if a value is not provided here, it is inferred as
-            ``end - begin``. Defaults to None.
-        pct_start (float): The percentage of the cycle (in number of steps)
-            spent increasing the learning rate.
-            Defaults to 0.3
-        anneal_strategy (str): {'cos', 'linear'}
-            Specifies the annealing strategy: "cos" for cosine annealing,
-            "linear" for linear annealing.
-            Defaults to 'cos'
-        div_factor (float): Determines the initial learning rate via
-            initial_param = eta_max/div_factor
-            Defaults to 25
-        final_div_factor (float): Determines the minimum learning rate via
-            eta_min = initial_param/final_div_factor
-            Defaults to 1e4
-        three_phase (bool): If ``True``, use a third phase of the schedule to
-            annihilate the learning rate according to 'final_div_factor'
-            instead of modifying the second phase (the first two phases will be
-            symmetrical about the step indicated by 'pct_start').
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-
-    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
-        https://arxiv.org/abs/1708.07120
-    """
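Since an explicit `total_steps` takes precedence over the `end - begin` fallback, a one-cycle run only needs its step budget stated once. A hedged sketch of iteration-based usage against the deleted API (note `eta_max` in place of PyTorch's `max_lr`):

```python
import torch.nn as nn
from torch.optim import SGD

from visengine.optim.scheduler import OneCycleLR

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)

# Warm up from eta_max/div_factor (0.04) to eta_max (1.0) over the first 30%
# of 100 steps, then anneal down towards eta_max/div_factor/final_div_factor.
scheduler = OneCycleLR(optimizer, eta_max=1.0, total_steps=100, by_epoch=False)
for _ in range(100):
    optimizer.step()  # backward pass elided
    scheduler.step()
```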
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
-    """Sets the learning rate of each parameter group according to the cosine
-    annealing with restarts scheme. The cosine restart policy anneals the
-    learning rate from the initial value to `eta_min` with a cosine annealing
-    schedule and then restarts another period from the maximum value multiplied
-    with `restart_weight`.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        periods (list[int]): Periods for each cosine annealing cycle.
-        restart_weights (list[float]): Restart weights at each
-            restart iteration. Defaults to [1].
-        eta_min (float): Minimum parameter value at the end of scheduling.
-            Defaults to None.
-        eta_min_ratio (float, optional): The ratio of minimum parameter value
-            to the base parameter value. Either `eta_min` or `eta_min_ratio`
-            should be specified. Defaults to None.
-        begin (int): Step at which to start updating the parameters.
-            Defaults to 0.
-        end (int): Step at which to stop updating the parameters.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class ReduceOnPlateauLR(LRSchedulerMixin, ReduceOnPlateauParamScheduler):
-    """Reduce the learning rate of each parameter group when a metric has
-    stopped improving. Models often benefit from reducing the learning rate by
-    a factor of 2-10 once learning stagnates. This scheduler reads a metrics
-    quantity and if no improvement is seen for a ``patience`` number of epochs,
-    the learning rate is reduced.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        monitor (str): Key name of the value to monitor in metrics dict.
-        rule (str): One of `less`, `greater`. In `less` rule, learning rate
-            will be reduced when the quantity monitored has stopped
-            decreasing; in `greater` rule it will be reduced when the
-            quantity monitored has stopped increasing. Defaults to 'less'.
-            ``rule`` is the renamed ``mode`` argument from PyTorch.
-        factor (float): Factor by which the learning rate will be
-            reduced. new_param = param * factor. Defaults to 0.1.
-        patience (int): Number of epochs with no improvement after
-            which learning rate will be reduced. For example, if
-            ``patience = 2``, then we will ignore the first 2 epochs
-            with no improvement, and will only decrease the learning rate after
-            the 3rd epoch if the monitor value still hasn't improved then.
-            Defaults to 10.
-        threshold (float): Threshold for measuring the new optimum,
-            to only focus on significant changes. Defaults to 1e-4.
-        threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
-            dynamic_threshold = best * ( 1 + threshold ) in 'greater'
-            rule or best * ( 1 - threshold ) in `less` rule.
-            In `abs` rule, dynamic_threshold = best + threshold in
-            `greater` rule or best - threshold in `less` rule.
-            Defaults to 'rel'.
-        cooldown (int): Number of epochs to wait before resuming
-            normal operation after learning rate has been reduced.
-            Defaults to 0.
-        min_value (float or list[float]): A scalar or a sequence of scalars.
-            A lower bound on the learning rate of each parameter group
-            respectively. Defaults to 0.
-        eps (float): Minimal decay applied to learning rate. If the difference
-            between new and old learning rate is smaller than eps, the update
-            is ignored. Defaults to 1e-8.
-        begin (int): Step at which to start triggering the scheduler
-            to monitor in val within the interval calculated
-            according to epoch of training. Defaults to 0.
-        end (int): Step at which to stop triggering the scheduler
-            to monitor in val within the interval calculated
-            according to epoch of training. Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-    """
diff --git a/libs/visengine/visengine/optim/scheduler/momentum_scheduler.py b/libs/visengine/visengine/optim/scheduler/momentum_scheduler.py
deleted file mode 100644
index eeed998..0000000
--- a/libs/visengine/visengine/optim/scheduler/momentum_scheduler.py
+++ /dev/null
@@ -1,364 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-from visengine.registry import PARAM_SCHEDULERS
-
-# yapf: disable
-from .param_scheduler import (
-    ConstantParamScheduler,
-    CosineAnnealingParamScheduler,
-    CosineRestartParamScheduler,
-    ExponentialParamScheduler,
-    LinearParamScheduler,
-    MultiStepParamScheduler,
-    PolyParamScheduler,
-    ReduceOnPlateauParamScheduler,
-    StepParamScheduler,
-)
-
-# yapf: enable
-
-
-class MomentumSchedulerMixin:
-    """A mixin class for momentum schedulers.
-
-    It can schedule the momentum in SGD and the beta_0 in Adam series.
-    """
-
-    def __init__(self, optimizer, *args, **kwargs):
-        self.use_betas = False
-        if "momentum" in optimizer.defaults:
-            param_name = "momentum"
-        elif "betas" in optimizer.defaults:
-            # for Adam series optimizers, the momentum is beta_0
-            self.use_betas = True
-            param_name = "momentum"
-            for group in optimizer.param_groups:
-                # set a reference momentum in the param groups for scheduling
-                group[param_name] = group["betas"][0]
-        else:
-            raise ValueError("optimizer must support momentum when using momentum scheduler")
-        super().__init__(optimizer, param_name, *args, **kwargs)
-
-    def step(self):
-        """Adjusts the momentum of each parameter group based on the specified
-        schedule."""
-        super().step()
-        if self.use_betas:
-            for group in self.optimizer.param_groups:
-                _, beta_1 = group["betas"]
-                # update the betas with the calculated value
-                group["betas"] = (group["momentum"], beta_1)
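For Adam-family optimizers the mixin mirrors `beta_0` into a synthetic `group["momentum"]` entry, schedules that, and writes it back, leaving `beta_1` untouched. A small sanity sketch against the deleted API (values illustrative; `StepMomentum` is defined just below):

```python
import torch.nn as nn
from torch.optim import Adam

from visengine.optim.scheduler import StepMomentum

model = nn.Linear(4, 2)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# Halve beta_0 every 2 steps; beta_1 stays at 0.999 throughout.
scheduler = StepMomentum(optimizer, step_size=2, gamma=0.5, by_epoch=False)
for _ in range(4):
    optimizer.step()  # gradients elided
    scheduler.step()
print(optimizer.param_groups[0]["betas"])  # expected roughly (0.225, 0.999)
```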
- """ - - def __init__(self, optimizer, *args, **kwargs): - self.use_betas = False - if "momentum" in optimizer.defaults: - param_name = "momentum" - elif "betas" in optimizer.defaults: - # for Adam series optimizer, the momentum is beta_0 - self.use_betas = True - param_name = "momentum" - for group in optimizer.param_groups: - # set a reference momentum in the param groups for scheduling - group[param_name] = group["betas"][0] - else: - raise ValueError("optimizer must support momentum when using momentum scheduler") - super().__init__(optimizer, param_name, *args, **kwargs) - - def step(self): - """Adjusts the momentum of each parameter group based on the specified - schedule.""" - super().step() - if self.use_betas: - for group in self.optimizer.param_groups: - _, beta_1 = group["betas"] - # update the betas with the calculated value - group["betas"] = (group["momentum"], beta_1) - - -@PARAM_SCHEDULERS.register_module(force=True) -class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler): - """Decays the momentum value of each parameter group by a small constant - factor until the number of epoch reaches a pre-defined milestone: ``end``. - Notice that such decay can happen simultaneously with other changes to the - momentum value from outside this scheduler. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - factor (float): The number we multiply momentum until the milestone. - Defaults to 1./3. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. - last_step (int): The index of last step. Used for resume without state - dict. Defaults to -1. - by_epoch (bool): Whether the scheduled momentum is updated by epochs. - Defaults to True. - verbose (bool): Whether to print the momentum for each update. - Defaults to False. - """ - - -@PARAM_SCHEDULERS.register_module(force=True) -class CosineAnnealingMomentum(MomentumSchedulerMixin, CosineAnnealingParamScheduler): - r"""Set the momentum of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial value and - :math:`T_{cur}` is the number of epochs since the last restart in SGDR: - - .. math:: - \begin{aligned} - \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; \\ - \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) - \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), - & T_{cur} = (2k+1)T_{max}. - \end{aligned} - - Notice that because the schedule - is defined recursively, the momentum can be simultaneously modified - outside this scheduler by other operators. If the momentum is set - solely by this scheduler, the momentum at each step becomes: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this - only implements the cosine annealing part of SGDR, and not the restarts. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - T_max (int): Maximum number of iterations. - eta_min (float): Minimum momentum value. Defaults to None. - begin (int): Step at which to start updating the momentum. - Defaults to 0. - end (int): Step at which to stop updating the momentum. - Defaults to INF. 
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled momentum is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the momentum for each update.
-            Defaults to False.
-        eta_min_ratio (float, optional): The ratio of the minimum parameter
-            value to the base parameter value. Either `eta_min` or
-            `eta_min_ratio` should be specified. Defaults to None.
-            New in version 0.3.2.
-
-    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-        https://arxiv.org/abs/1608.03983
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler):
-    """Decays the momentum of each parameter group by gamma every epoch.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        gamma (float): Multiplicative factor of momentum value decay.
-        begin (int): Step at which to start updating the momentum.
-            Defaults to 0.
-        end (int): Step at which to stop updating the momentum.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled momentum is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the momentum for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler):
-    """Decays the momentum of each parameter group by linearly changing a
-    small multiplicative factor until the number of epochs reaches a
-    pre-defined milestone: ``end``.
-
-    Notice that such decay can happen simultaneously with other changes to the
-    momentum from outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        start_factor (float): The number we multiply momentum in the
-            first epoch. The multiplication factor changes towards end_factor
-            in the following epochs. Defaults to 1./3.
-        end_factor (float): The number we multiply momentum at the end
-            of linear changing process. Defaults to 1.0.
-        begin (int): Step at which to start updating the momentum.
-            Defaults to 0.
-        end (int): Step at which to stop updating the momentum.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled momentum is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the momentum for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler):
-    """Decays the specified momentum in each parameter group by gamma once the
-    number of epochs reaches one of the milestones. Notice that such decay can
-    happen simultaneously with other changes to the momentum from outside this
-    scheduler.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        milestones (list): List of epoch indices. Must be increasing.
-        gamma (float): Multiplicative factor of momentum value decay.
-            Defaults to 0.1.
-        begin (int): Step at which to start updating the momentum.
-            Defaults to 0.
-        end (int): Step at which to stop updating the momentum.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled momentum is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the momentum for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class StepMomentum(MomentumSchedulerMixin, StepParamScheduler):
-    """Decays the momentum of each parameter group by gamma every step_size
-    epochs. Notice that such decay can happen simultaneously with other changes
-    to the momentum from outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        step_size (int): Period of momentum value decay.
-        gamma (float): Multiplicative factor of momentum value decay.
-            Defaults to 0.1.
-        begin (int): Step at which to start updating the momentum.
-            Defaults to 0.
-        end (int): Step at which to stop updating the momentum.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled momentum is updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the momentum for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
-    """Decays the momentum of each parameter group in a polynomial decay
-    scheme.
-
-    Notice that such decay can happen simultaneously with other changes to the
-    parameter value from outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        eta_min (float): Minimum momentum at the end of scheduling.
-            Defaults to 0.
-        power (float): The power of the polynomial. Defaults to 1.0.
-        begin (int): Step at which to start updating the parameters.
-            Defaults to 0.
-        end (int): Step at which to stop updating the parameters.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-    """
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class CosineRestartMomentum(MomentumSchedulerMixin, CosineRestartParamScheduler):
-    """Sets the momentum of each parameter group according to the cosine
-    annealing with restarts scheme. The cosine restart policy anneals the
-    momentum from the initial value to `eta_min` with a cosine annealing
-    schedule and then restarts another period from the maximum value multiplied
-    with `restart_weight`.
-
-    Args:
-        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
-            optimizer.
-        periods (list[int]): Periods for each cosine annealing cycle.
-        restart_weights (list[float]): Restart weights at each
-            restart iteration. Defaults to [1].
-        eta_min (float): Minimum parameter value at the end of scheduling.
-            Defaults to None.
-        eta_min_ratio (float, optional): The ratio of minimum parameter value
-            to the base parameter value. Either `eta_min` or `eta_min_ratio`
-            should be specified. Defaults to None.
-        begin (int): Step at which to start updating the parameters.
-            Defaults to 0.
-        end (int): Step at which to stop updating the parameters.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
- """ - - -@PARAM_SCHEDULERS.register_module(force=True) -class ReduceOnPlateauMomentum(MomentumSchedulerMixin, ReduceOnPlateauParamScheduler): - """Reduce the momentum of each parameter group when a metric has stopped - improving. Models often benefit from reducing the momentum by a factor of - 2-10 once learning stagnates. This scheduler reads a metrics quantity and - if no improvement is seen for a ``patience`` number of epochs, the momentum - is reduced. - - Args: - optimizer (Optimizer or OptimWrapper): optimizer or Wrapped - optimizer. - monitor (str): Key name of the value to monitor in metrics dict. - rule (str): One of `less`, `greater`. In `less` rule, momentum will - be reduced when the quantity monitored has stopped - decreasing; in `greater` rule it will be reduced when the - quantity monitored has stopped increasing. Defaults to 'less'. - The ``rule`` is the renaming of ``mode`` in pytorch. - factor (float): Factor by which the momentum will be - reduced. new_param = param * factor. Defaults to 0.1. - patience (int): Number of epochs with no improvement after - which momentum will be reduced. For example, if - ``patience = 2``, then we will ignore the first 2 epochs - with no improvement, and will only decrease the momentum after - the 3rd epoch if the monitor value still hasn't improved then. - Defaults to 10. - threshold (float): Threshold for measuring the new optimum, - to only focus on significant changes. Defaults to 1e-4. - threshold_rule (str): One of `rel`, `abs`. In `rel` rule, - dynamic_threshold = best * ( 1 + threshold ) in 'greater' - rule or best * ( 1 - threshold ) in `less` rule. - In `abs` rule, dynamic_threshold = best + threshold in - `greater` rule or best - threshold in `less` rule. - Defaults to 'rel'. - cooldown (int): Number of epochs to wait before resuming - normal operation after momentum has been reduced. Defaults to 0. - min_value (float or list[float]): A scalar or a sequence of scalars. - A lower bound on the momentum of each parameter group - respectively. Defaults to 0. . - eps (float): Minimal decay applied to momentum. If the difference - between new and old momentum is smaller than eps, the update is - ignored. Defaults to 1e-8. - begin (int): Step at which to start triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to 0. - end (int): Step at which to stop triggering the scheduler - to monitor in val within the interval calculated - according to epoch of training. Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def step(self, metrics=None): - """Adjusts the momentum of each parameter group based on the specified - schedule. - - Args: - metrics (Dict[str, float], optional): Evaluation results of all - metrics on validation dataset. The keys are the names of the - metrics, and the values are corresponding results. - Defaults to None. 
- """ - super(MomentumSchedulerMixin, self).step(metrics) - if self.use_betas: - for group in self.optimizer.param_groups: - _, beta_1 = group["betas"] - # update the betas with the calculated value - group["betas"] = (group["momentum"], beta_1) diff --git a/libs/visengine/visengine/optim/scheduler/param_scheduler.py b/libs/visengine/visengine/optim/scheduler/param_scheduler.py deleted file mode 100644 index 84d7106..0000000 --- a/libs/visengine/visengine/optim/scheduler/param_scheduler.py +++ /dev/null @@ -1,1515 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -# ------------------------------------------------------------------------ -# Modified from https://github.com/pytorch/pytorch -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# ------------------------------------------------------------------------ - -import math -import warnings -import weakref -from collections import Counter -from collections.abc import Callable, Sequence -from functools import wraps -from typing import Union - -from torch.optim import Optimizer - -from visengine.logging import print_log -from visengine.optim import BaseOptimWrapper -from visengine.registry import PARAM_SCHEDULERS - -INF = int(1e9) - -OptimizerType = Union[BaseOptimWrapper, Optimizer] - - -class _ParamScheduler: - """Base class for parameter schedulers. - - It should be inherited by all schedulers that schedule parameters in the - optimizer's ``param_groups``. All subclasses should overwrite the - ``_get_value()`` according to their own schedule strategy. - The implementation is motivated by - https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py. - - Args: - optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resuming without - state dict. Default value ``-1`` means the ``step`` function is - never be called before. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__( - self, - optimizer: OptimizerType, - param_name: str, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False, - ): - # Attach optimizer - if not isinstance(optimizer, Optimizer | BaseOptimWrapper): - raise TypeError(f"``optimizer`` should be an Optimizer,but got {type(optimizer).__name__}") - self.optimizer = optimizer - self.param_name = param_name - - if end <= begin: - raise ValueError(f"end should be larger than begin, but got begin={begin}, end={end}") - self.begin = begin - self.end = end - - self.by_epoch = by_epoch - - assert isinstance(last_step, int) and last_step >= -1 - # Initialize valid step count and base values - if last_step == -1: - for group in optimizer.param_groups: - # If the param is never be scheduled, record the current value - # as the initial value. 
- group.setdefault(f"initial_{param_name}", group[param_name]) - else: - for i, group in enumerate(optimizer.param_groups): - if f"initial_{param_name}" not in group: - raise KeyError( - f"param 'initial_{param_name}' is not specified in param_groups[{{}}] when resuming an optimizer".format( - i - ) - ) - self.base_values = [group[f"initial_{param_name}"] for group in optimizer.param_groups] - self.last_step = last_step - - # Following https://github.com/pytorch/pytorch/issues/20124 - # We would like to ensure that `scheduler.step()` is called after - # `optimizer.step()` - def with_counter(method: Callable): - if getattr(method, "_with_counter", False): - # `optimizer.step()` has already been replaced, return. - return method - - # Keep a weak reference to the optimizer instance to prevent - # cyclic references. - instance_ref = weakref.ref(method.__self__) # type: ignore - # Get the unbound method for the same purpose. - func = method.__func__ # type: ignore - cls = instance_ref().__class__ # type: ignore - del method - - @wraps(func) - def wrapper(*args, **kwargs): - instance = instance_ref() - instance._global_step += 1 - wrapped = func.__get__(instance, cls) - return wrapped(*args, **kwargs) - - # Note that the returned function here is no longer a bound method, - # so attributes like `__func__` and `__self__` no longer exist. - wrapper._with_counter = True # type: ignore - return wrapper - - # add counter to optimizer - self.optimizer.step = with_counter(self.optimizer.step) # type: ignore - self.optimizer._global_step = -1 # type: ignore - - self._global_step = -1 - self.verbose = verbose - - self.step() - - def state_dict(self) -> dict: - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which is not - the optimizer. - - Returns: - dict: scheduler state. - """ - return {key: value for key, value in self.__dict__.items() if key != "optimizer"} - - def load_state_dict(self, state_dict: dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_value(self): - """Return the last computed value by current scheduler. - - Returns: - list: A list of the last computed value of the optimizer's - ``param_group``. - """ - return self._last_value - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - raise NotImplementedError - - def print_value(self, is_verbose: bool, group: int, value: float): - """Display the current parameter value. - - Args: - is_verbose (bool): Whether to print the value. - group (int): The index of the current ``param_group``. - value (float): The parameter value. - """ - if is_verbose: - print_log( - f"Adjusting parameter value of group {group} to {value:.4e}.", - logger="current", - ) - - def step(self): - """Adjusts the parameter value of each parameter group based on the - specified schedule.""" - # Raise a warning if old pattern is detected - # https://github.com/pytorch/pytorch/issues/20124 - if self._global_step == 0: - if not hasattr(self.optimizer.step, "_with_counter"): - warnings.warn( - "Seems like `optimizer.step()` has been overridden after " - "parameter value scheduler initialization. Please, make " - "sure to call `optimizer.step()` before " - "`scheduler.step()`. 
See more details at " - "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", - UserWarning, - stacklevel=2, - ) - - # Just check if there were two first scheduler.step() calls - # before optimizer.step() - elif self.optimizer._global_step < 0: - warnings.warn( - "Detected call of `scheduler.step()` before " - "`optimizer.step()`. In PyTorch 1.1.0 and later, you " - "should call them in the opposite order: " - "`optimizer.step()` before `scheduler.step()`. " - "Failure to do this will result in PyTorch skipping " - "the first value of the parameter value schedule. " - "See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", - UserWarning, - stacklevel=2, - ) - self._global_step += 1 - - # Compute parameter value per param group in the effective range - if self.begin <= self._global_step < self.end: - self.last_step += 1 - values = self._get_value() - - for i, data in enumerate(zip(self.optimizer.param_groups, values, strict=False)): - param_group, value = data - param_group[self.param_name] = value - self.print_value(self.verbose, i, value) - - self._last_value = [group[self.param_name] for group in self.optimizer.param_groups] - - -@PARAM_SCHEDULERS.register_module(force=True) -class StepParamScheduler(_ParamScheduler): - """Decays the parameter value of each parameter group by gamma every - step_size epochs. Notice that such decay can happen simultaneously with - other changes to the parameter value from outside this scheduler. - - Args: - optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - step_size (int): Period of parameter value decay. - gamma (float): Multiplicative factor of parameter value decay. - Defaults to 0.1. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__( - self, - optimizer: OptimizerType, - param_name: str, - step_size: int, - gamma: float = 0.1, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False, - ): - self.step_size = step_size - self.gamma = gamma - super().__init__( - optimizer=optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose, - ) - - @classmethod - def build_iter_from_epoch( - cls, - *args, - step_size, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs, - ): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based." - assert epoch_length is not None and epoch_length > 0, ( - f"`epoch_length` must be a positive integer, but got {epoch_length}." 
-        )
-        by_epoch = False
-        step_size = step_size * epoch_length
-        begin = int(begin * epoch_length)
-        if end != INF:
-            end = int(end * epoch_length)
-        return cls(
-            *args,
-            step_size=step_size,
-            begin=begin,
-            end=end,
-            by_epoch=by_epoch,
-            **kwargs,
-        )
-
-    def _get_value(self):
-        """Compute value using chainable form of the scheduler."""
-        if (self.last_step == 0) or (self.last_step % self.step_size != 0):
-            return [group[self.param_name] for group in self.optimizer.param_groups]
-        return [group[self.param_name] * self.gamma for group in self.optimizer.param_groups]
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class MultiStepParamScheduler(_ParamScheduler):
-    """Decays the specified parameter in each parameter group by gamma once the
-    number of epochs reaches one of the milestones. Notice that such decay can
-    happen simultaneously with other changes to the parameter from outside this
-    scheduler.
-
-    Args:
-        optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer.
-        param_name (str): Name of the parameter to be adjusted, such as
-            ``lr``, ``momentum``.
-        milestones (list): List of epoch indices. Must be increasing.
-        gamma (float): Multiplicative factor of parameter value decay.
-            Defaults to 0.1.
-        begin (int): Step at which to start updating the parameters.
-            Defaults to 0.
-        end (int): Step at which to stop updating the parameters.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-    """
-
-    def __init__(
-        self,
-        optimizer: OptimizerType,
-        param_name: str,
-        milestones: list[int],
-        gamma: float = 0.1,
-        last_step: int = -1,
-        begin: int = 0,
-        end: int = INF,
-        by_epoch: bool = True,
-        verbose: bool = False,
-    ):
-        self.milestones = Counter(milestones)
-        self.gamma = gamma
-        super().__init__(
-            optimizer,
-            param_name=param_name,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose,
-        )
-
-    @classmethod
-    def build_iter_from_epoch(
-        cls,
-        *args,
-        milestones,
-        begin=0,
-        end=INF,
-        by_epoch=True,
-        epoch_length=None,
-        **kwargs,
-    ):
-        """Build an iter-based instance of this scheduler from an epoch-based
-        config."""
-        assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based."
-        assert epoch_length is not None and epoch_length > 0, (
-            f"`epoch_length` must be a positive integer, but got {epoch_length}."
-        )
-        by_epoch = False
-        milestones = [i * epoch_length for i in milestones]
-        begin = int(begin * epoch_length)
-        if end != INF:
-            end = int(end * epoch_length)
-        return cls(
-            *args,
-            milestones=milestones,
-            begin=begin,
-            end=end,
-            by_epoch=by_epoch,
-            **kwargs,
-        )
-
-    def _get_value(self):
-        """Compute value using chainable form of the scheduler."""
-        if self.last_step not in self.milestones:
-            return [group[self.param_name] for group in self.optimizer.param_groups]
-        return [
-            group[self.param_name] * self.gamma ** self.milestones[self.last_step]
-            for group in self.optimizer.param_groups
-        ]
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class ConstantParamScheduler(_ParamScheduler):
-    """Decays the parameter value of each parameter group by a small constant
-    factor until the number of epochs reaches a pre-defined milestone: ``end``.
-    Notice that such decay can happen simultaneously with other changes to the
-    parameter value from outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
-            optimizer.
-        param_name (str): Name of the parameter to be adjusted, such as
-            ``lr``, ``momentum``.
-        factor (float): The number we multiply parameter value until the
-            milestone. Defaults to 1./3.
-        begin (int): Step at which to start updating the parameters.
-            Defaults to 0.
-        end (int): Step at which to stop updating the parameters.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-    """
-
-    def __init__(
-        self,
-        optimizer: OptimizerType,
-        param_name: str,
-        factor: float = 1.0 / 3,
-        begin: int = 0,
-        end: int = INF,
-        last_step: int = -1,
-        by_epoch: bool = True,
-        verbose: bool = False,
-    ):
-        if factor > 1.0 or factor < 0:
-            raise ValueError("Constant multiplicative factor should be between 0 and 1.")
-
-        self.factor = factor
-        self.total_iters = end - begin - 1
-        super().__init__(
-            optimizer,
-            param_name=param_name,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose,
-        )
-
-    @classmethod
-    def build_iter_from_epoch(cls, *args, begin=0, end=INF, by_epoch=True, epoch_length=None, **kwargs):
-        """Build an iter-based instance of this scheduler from an epoch-based
-        config."""
-        assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based."
-        assert epoch_length is not None and epoch_length > 0, (
-            f"`epoch_length` must be a positive integer, but got {epoch_length}."
-        )
-        by_epoch = False
-        begin = int(begin * epoch_length)
-        if end != INF:
-            end = int(end * epoch_length)
-        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
-
-    def _get_value(self):
-        """Compute value using chainable form of the scheduler."""
-        if self.last_step == 0:
-            return [group[self.param_name] * self.factor for group in self.optimizer.param_groups]
-
-        # `last_step != total_iters` already covers `last_step > total_iters`.
-        if self.last_step != self.total_iters:
-            return [group[self.param_name] for group in self.optimizer.param_groups]
-
-        return [group[self.param_name] * (1.0 / self.factor) for group in self.optimizer.param_groups]
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class ExponentialParamScheduler(_ParamScheduler):
-    """Decays the parameter value of each parameter group by gamma every epoch.
-
-    Args:
-        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
-            optimizer.
-        param_name (str): Name of the parameter to be adjusted, such as
-            ``lr``, ``momentum``.
-        gamma (float): Multiplicative factor of parameter value decay.
-        begin (int): Step at which to start updating the parameters.
-            Defaults to 0.
-        end (int): Step at which to stop updating the parameters.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
- """ - - def __init__( - self, - optimizer: OptimizerType, - param_name: str, - gamma: float, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False, - ): - self.gamma = gamma - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose, - ) - - @classmethod - def build_iter_from_epoch(cls, *args, begin=0, end=INF, by_epoch=True, epoch_length=None, **kwargs): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based." - assert epoch_length is not None and epoch_length > 0, ( - f"`epoch_length` must be a positive integer, but got {epoch_length}." - ) - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - if self.last_step == 0: - return [group[self.param_name] for group in self.optimizer.param_groups] - return [group[self.param_name] * self.gamma for group in self.optimizer.param_groups] - - -@PARAM_SCHEDULERS.register_module(force=True) -class CosineAnnealingParamScheduler(_ParamScheduler): - r"""Set the parameter value of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial value and - :math:`T_{cur}` is the number of epochs since the last restart in SGDR: - - .. math:: - \begin{aligned} - \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), - & T_{cur} \neq (2k+1)T_{max}; \\ - \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) - \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), - & T_{cur} = (2k+1)T_{max}. - \end{aligned} - - Notice that because the schedule - is defined recursively, the parameter value can be simultaneously modified - outside this scheduler by other operators. If the parameter value is set - solely by this scheduler, the parameter value at each step becomes: - - .. math:: - \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + - \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) - - It has been proposed in - `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this - only implements the cosine annealing part of SGDR, and not the restarts. - - Args: - optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped - optimizer. - param_name (str): Name of the parameter to be adjusted, such as - ``lr``, ``momentum``. - T_max (int, optional): Maximum number of iterations. If not specified, - use ``end - begin``. Defaults to None. - eta_min (float, optional): Minimum parameter value. Defaults to None. - begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - eta_min_ratio (float, optional): The ratio of the minimum parameter - value to the base parameter value. Either `eta_min` or - `eta_min_ratio` should be specified. Defaults to None. - New in version 0.3.2. - - .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts:
-        https://arxiv.org/abs/1608.03983
-    """
-
-    def __init__(
-        self,
-        optimizer: Optimizer | BaseOptimWrapper,
-        param_name: str,
-        T_max: int | None = None,
-        eta_min: float | None = None,
-        begin: int = 0,
-        end: int = INF,
-        last_step: int = -1,
-        by_epoch: bool = True,
-        verbose: bool = False,
-        eta_min_ratio: float | None = None,
-    ):
-        # To preserve backwards compatibility
-        if eta_min is None and eta_min_ratio is None:
-            eta_min = 0.0
-        assert (eta_min is None) ^ (eta_min_ratio is None), "Either `eta_min` or `eta_min_ratio` should be specified"
-        self.T_max = T_max or (end - begin)
-        self.eta_min = eta_min
-        self.eta_min_ratio = eta_min_ratio
-        super().__init__(
-            optimizer,
-            param_name=param_name,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose,
-        )
-
-    @classmethod
-    def build_iter_from_epoch(
-        cls,
-        *args,
-        T_max=None,
-        begin=0,
-        end=INF,
-        by_epoch=True,
-        epoch_length=None,
-        **kwargs,
-    ):
-        """Build an iter-based instance of this scheduler from an epoch-based
-        config."""
-        assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based."
-        assert epoch_length is not None and epoch_length > 0, (
-            f"`epoch_length` must be a positive integer, but got {epoch_length}."
-        )
-        by_epoch = False
-        if T_max is not None:
-            T_max = T_max * epoch_length
-        begin = int(begin * epoch_length)
-        if end != INF:
-            end = int(end * epoch_length)
-        return cls(*args, T_max=T_max, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
-
-    def _get_value(self) -> list:
-        """Compute value using chainable form of the scheduler."""
-
-        def _get_eta_min(base_value):
-            if self.eta_min_ratio is None:
-                return self.eta_min
-            return base_value * self.eta_min_ratio
-
-        if self.last_step == 0:
-            return [group[self.param_name] for group in self.optimizer.param_groups]
-        elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0:
-            return [
-                group[self.param_name]
-                + (base_value - _get_eta_min(base_value)) * (1 - math.cos(math.pi / self.T_max)) / 2
-                for base_value, group in zip(self.base_values, self.optimizer.param_groups, strict=False)
-            ]
-        return [
-            (1 + math.cos(math.pi * self.last_step / self.T_max))
-            / (1 + math.cos(math.pi * (self.last_step - 1) / self.T_max))
-            * (group[self.param_name] - _get_eta_min(base_value))
-            + _get_eta_min(base_value)
-            for base_value, group in zip(self.base_values, self.optimizer.param_groups, strict=False)
-        ]
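The closed form in the docstring above, with `_get_eta_min` resolving the floor either from `eta_min` or per group from `base_value * eta_min_ratio`, can be checked numerically; a rough standalone sketch:

```python
import math

def cosine_value(base, eta_min, t, t_max):
    # eta_min + (base - eta_min) / 2 * (1 + cos(pi * t / T_max))
    return eta_min + (base - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

base_lr, t_max = 0.1, 10
assert cosine_value(base_lr, 0.0, 0, t_max) == base_lr        # starts at base
assert abs(cosine_value(base_lr, 0.0, t_max, t_max)) < 1e-12  # ends at eta_min
# With eta_min_ratio=0.01 the floor becomes base_lr * 0.01 for this group:
print(cosine_value(base_lr, base_lr * 0.01, 5, t_max))        # midpoint ~0.0505
```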
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class LinearParamScheduler(_ParamScheduler):
-    """Decays the parameter value of each parameter group by linearly changing
-    a small multiplicative factor until the number of epochs reaches a
-    pre-defined milestone: ``end``.
-
-    Notice that such decay can happen simultaneously with other changes to the
-    parameter value from outside this scheduler.
-
-    Args:
-        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
-            optimizer.
-        param_name (str): Name of the parameter to be adjusted, such as
-            ``lr``, ``momentum``.
-        start_factor (float): The number we multiply parameter value in the
-            first epoch. The multiplication factor changes towards end_factor
-            in the following epochs. Defaults to 1./3.
-        end_factor (float): The number we multiply parameter value at the end
-            of linear changing process. Defaults to 1.0.
-        begin (int): Step at which to start updating the parameters.
-            Defaults to 0.
-        end (int): Step at which to stop updating the parameters.
-            Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-    """
-
-    def __init__(
-        self,
-        optimizer: Optimizer | BaseOptimWrapper,
-        param_name: str,
-        start_factor: float = 1.0 / 3,
-        end_factor: float = 1.0,
-        begin: int = 0,
-        end: int = INF,
-        last_step: int = -1,
-        by_epoch: bool = True,
-        verbose: bool = False,
-    ):
-        if start_factor > 1.0 or start_factor < 0:
-            raise ValueError("Starting multiplicative factor should be between 0 and 1.")
-
-        if end_factor > 1.0 or end_factor < 0:
-            raise ValueError("Ending multiplicative factor should be between 0 and 1.")
-
-        self.start_factor = start_factor
-        self.end_factor = end_factor
-        self.total_iters = end - begin - 1
-        super().__init__(
-            optimizer,
-            param_name=param_name,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose,
-        )
-
-    @classmethod
-    def build_iter_from_epoch(cls, *args, begin=0, end=INF, by_epoch=True, epoch_length=None, **kwargs):
-        """Build an iter-based instance of this scheduler from an epoch-based
-        config."""
-        assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based."
-        assert epoch_length is not None and epoch_length > 0, (
-            f"`epoch_length` must be a positive integer, but got {epoch_length}."
-        )
-        by_epoch = False
-        begin = int(begin * epoch_length)
-        if end != INF:
-            end = int(end * epoch_length)
-        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
-
-    def _get_value(self):
-        """Compute value using chainable form of the scheduler."""
-        if self.last_step == 0:
-            return [group[self.param_name] * self.start_factor for group in self.optimizer.param_groups]
-
-        return [
-            group[self.param_name]
-            * (
-                1.0
-                + (self.end_factor - self.start_factor)
-                / (self.total_iters * self.start_factor + (self.last_step - 1) * (self.end_factor - self.start_factor))
-            )
-            for group in self.optimizer.param_groups
-        ]
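The chainable per-step multiplier in `_get_value` above is algebraically equivalent to a straight line from `base * start_factor` to `base * end_factor` over `total_iters` steps; a non-chainable sketch of the same schedule, for intuition:

```python
def linear_value(base, start_factor, end_factor, step, total_iters):
    # Closed (non-chainable) form of LinearParamScheduler's schedule.
    step = min(step, total_iters)
    factor = start_factor + (end_factor - start_factor) * step / total_iters
    return base * factor

# Warm up lr=0.1 from 10% to 100% over total_iters=500 steps.
for step in (0, 250, 500):
    print(step, linear_value(0.1, 0.1, 1.0, step, 500))  # ~0.01, 0.055, 0.1
```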
-    """
-
-    def __init__(
-        self,
-        optimizer: Optimizer | BaseOptimWrapper,
-        param_name: str,
-        eta_min: float = 0,
-        power: float = 1.0,
-        begin: int = 0,
-        end: int = INF,
-        last_step: int = -1,
-        by_epoch: bool = True,
-        verbose: bool = False,
-    ):
-        self.eta_min = eta_min
-        self.power = power
-        self.total_iters = end - begin - 1
-
-        super().__init__(
-            optimizer,
-            param_name=param_name,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose,
-        )
-
-    @classmethod
-    def build_iter_from_epoch(cls, *args, begin=0, end=INF, by_epoch=True, epoch_length=None, **kwargs):
-        """Build an iter-based instance of this scheduler from an epoch-based
-        config."""
-        assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based."
-        assert epoch_length is not None and epoch_length > 0, (
-            f"`epoch_length` must be a positive integer, but got {epoch_length}."
-        )
-        by_epoch = False
-        begin = int(begin * epoch_length)
-        if end != INF:
-            end = int(end * epoch_length)
-        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
-
-    def _get_value(self):
-        """Compute value using chainable form of the scheduler."""
-        if self.last_step == 0:
-            return [group[self.param_name] for group in self.optimizer.param_groups]
-
-        return [
-            (group[self.param_name] - self.eta_min) * (1 - 1 / (self.total_iters - self.last_step + 1)) ** self.power
-            + self.eta_min
-            for group in self.optimizer.param_groups
-        ]
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class OneCycleParamScheduler(_ParamScheduler):
-    r"""Sets the parameters of each parameter group according to the 1cycle
-    learning rate policy. The 1cycle policy anneals the learning rate from an
-    initial learning rate to some maximum learning rate and then from that
-    maximum learning rate to some minimum learning rate much lower than the
-    initial learning rate. This policy was initially described in the paper
-    `Super-Convergence: Very Fast Training of Neural Networks Using Large
-    Learning Rates`_.
-
-    The 1cycle learning rate policy changes the learning rate after every
-    batch. `step` should be called after a batch has been used for training.
-
-    This scheduler is not chainable.
-
-    Note also that the total number of steps in the cycle can be determined in
-    one of two ways (listed in order of precedence):
-
-    #. A value for total_steps is explicitly provided.
-    #. If total_steps is not defined, ``begin`` and ``end`` of this
-       ParamScheduler are used instead. In this case, the number of total
-       steps is inferred by ``total_steps = end - begin``.
-
-    The default behaviour of this scheduler follows the fastai implementation
-    of 1cycle, which claims that "unpublished work has shown even better
-    results by using only two phases". To mimic the behaviour of the original
-    paper instead, set ``three_phase=True``.
-
-    Args:
-        optimizer (Optimizer): Wrapped optimizer.
-        param_name (str): Name of the parameter to be adjusted, such as
-            ``lr``, ``momentum``.
-        eta_max (float or list): Upper parameter value boundaries in the cycle
-            for each parameter group.
-        total_steps (int): The total number of steps in the cycle. Note that
-            if a value is not provided here, then it will be equal to
-            ``end - begin``. Defaults to None.
-        pct_start (float): The percentage of the cycle (in number of steps)
-            spent increasing the learning rate.
-            Defaults to 0.3.
-        anneal_strategy (str): {'cos', 'linear'}
-            Specifies the annealing strategy: "cos" for cosine annealing,
-            "linear" for linear annealing.
-            Defaults to 'cos'.
-        div_factor (float): Determines the initial learning rate via
-            initial_param = eta_max/div_factor.
-            Defaults to 25.
-        final_div_factor (float): Determines the minimum learning rate via
-            eta_min = initial_param/final_div_factor.
-            Defaults to 1e4.
-        three_phase (bool): If ``True``, use a third phase of the schedule to
-            annihilate the learning rate according to 'final_div_factor'
-            instead of modifying the second phase (the first two phases will be
-            symmetrical about the step indicated by 'pct_start').
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-
-    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
-        https://arxiv.org/abs/1708.07120
-    """
-
-    def __init__(
-        self,
-        optimizer: Optimizer | BaseOptimWrapper,
-        param_name: str,
-        eta_max: float = 0,
-        total_steps: int | None = None,
-        pct_start: float = 0.3,
-        anneal_strategy: str = "cos",
-        div_factor: float = 25.0,
-        final_div_factor: float = 1e4,
-        three_phase: bool = False,
-        begin: int = 0,
-        end: int = INF,
-        last_step: int = -1,
-        by_epoch: bool = True,
-        verbose: bool = False,
-    ):
-        assert param_name == "lr", f"OneCycle only works for learning rate updating, but got param_name as {param_name}"
-
-        self.eta_max = eta_max
-        self.div_factor = div_factor
-        self.final_div_factor = final_div_factor
-
-        # Validate total_steps
-        if total_steps is not None:
-            if total_steps <= 0 or not isinstance(total_steps, int):
-                raise ValueError(f"Expected positive integer total_steps, but got {total_steps}")
-            self.total_steps = total_steps
-        else:
-            self.total_steps = end - begin
-
-        # Validate pct_start
-        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
-            raise ValueError(f"Expected float between 0 and 1 for pct_start, but got {pct_start}")
-
-        # Validate anneal_strategy
-        if anneal_strategy not in ["cos", "linear"]:
-            raise ValueError(f'anneal_strategy must be one of "cos" or "linear", instead got {anneal_strategy}')
-        elif anneal_strategy == "cos":
-            self.anneal_func = self._annealing_cos
-        elif anneal_strategy == "linear":
-            self.anneal_func = self._annealing_linear
-
-        if three_phase:
-            self._schedule_phases = [
-                {
-                    "end_step": float(pct_start * self.total_steps) - 1,
-                    f"start_{param_name}": f"initial_{param_name}",
-                    f"end_{param_name}": f"max_{param_name}",
-                },
-                {
-                    "end_step": float(2 * pct_start * self.total_steps) - 2,
-                    f"start_{param_name}": f"max_{param_name}",
-                    f"end_{param_name}": f"initial_{param_name}",
-                },
-                {
-                    "end_step": self.total_steps - 1,
-                    f"start_{param_name}": f"initial_{param_name}",
-                    f"end_{param_name}": f"min_{param_name}",
-                },
-            ]
-        else:
-            self._schedule_phases = [
-                {
-                    "end_step": float(pct_start * self.total_steps) - 1,
-                    f"start_{param_name}": f"initial_{param_name}",
-                    f"end_{param_name}": f"max_{param_name}",
-                },
-                {
-                    "end_step": self.total_steps - 1,
-                    f"start_{param_name}": f"max_{param_name}",
-                    f"end_{param_name}": f"min_{param_name}",
-                },
-            ]
-
-        # Initialize parameters
-        max_values = self._format_param(f"max_{param_name}", optimizer, eta_max)
-        if last_step == -1:
-            for idx, group in enumerate(optimizer.param_groups):
-                group[f"initial_{param_name}"] = max_values[idx] / div_factor
-                group[f"max_{param_name}"] = max_values[idx]
-                group[f"min_{param_name}"] = group[f"initial_{param_name}"] / final_div_factor
-
-        super().__init__(
-            optimizer=optimizer,
-            param_name=param_name,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose,
-        )
-
-    def _format_param(self, name, optimizer, param):
-        """Return correctly formatted lr/momentum for each param group."""
-        if isinstance(param, list | tuple):
-            if len(param) != len(optimizer.param_groups):
-                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
-            return param
-        else:
-            return [param] * len(optimizer.param_groups)
-
-    @staticmethod
-    def _annealing_cos(start, end, pct):
-        """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
-
-        cos_out = math.cos(math.pi * pct) + 1
-        return end + (start - end) / 2.0 * cos_out
-
-    @staticmethod
-    def _annealing_linear(start, end, pct):
-        """Linearly anneal from `start` to `end` as pct goes from 0.0 to
-        1.0."""
-        return (end - start) * pct + start
-
-    @classmethod
-    def build_iter_from_epoch(
-        cls,
-        *args,
-        begin=0,
-        end=INF,
-        total_steps=None,
-        by_epoch=True,
-        epoch_length=None,
-        **kwargs,
-    ):
-        """Build an iter-based instance of this scheduler from an epoch-based
-        config."""
-        assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based."
-        assert epoch_length is not None and epoch_length > 0, (
-            f"`epoch_length` must be a positive integer, but got {epoch_length}."
-        )
-        by_epoch = False
-        begin = int(begin * epoch_length)
-        if end != INF:
-            end = int(end * epoch_length)
-        if total_steps is not None:
-            total_steps = total_steps * epoch_length
-        return cls(
-            *args,
-            begin=begin,
-            end=end,
-            total_steps=total_steps,
-            by_epoch=by_epoch,
-            **kwargs,
-        )
-
-    def _get_value(self):
-        """Compute value using chainable form of the scheduler."""
-
-        params = []
-        step_num = self.last_step
-
-        if step_num > self.total_steps:
-            raise ValueError(
-                f"Tried to step {step_num + 1} times. The specified number of total steps is {self.total_steps}"
-            )
-
-        for group in self.optimizer.param_groups:
-            start_step = 0
-            for i, phase in enumerate(self._schedule_phases):
-                end_step = phase["end_step"]
-                if step_num <= end_step or i == len(self._schedule_phases) - 1:
-                    pct = (step_num - start_step) / (end_step - start_step)
-                    computed_param = self.anneal_func(
-                        group[phase["start_" + self.param_name]],
-                        group[phase["end_" + self.param_name]],
-                        pct,
-                    )
-                    break
-                start_step = phase["end_step"]
-
-            params.append(computed_param)
-
-        return params
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class CosineRestartParamScheduler(_ParamScheduler):
-    """Sets the parameters of each parameter group according to the cosine
-    annealing with restarts scheme. The cosine restart policy anneals the
-    parameter from the initial value to `eta_min` with a cosine annealing
-    schedule and then restarts another period from the maximum value multiplied
-    with `restart_weight`.
-
-    Args:
-        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
-            optimizer.
-        param_name (str): Name of the parameter to be adjusted, such as
-            ``lr``, ``momentum``.
-        periods (list[int]): Periods for each cosine annealing cycle.
-        restart_weights (list[float]): Restart weights at each
-            restart iteration. Defaults to (1,).
-        eta_min (float, optional): Minimum parameter value at the end of
-            scheduling. Defaults to None.
-        eta_min_ratio (float, optional): The ratio of minimum parameter value
-            to the base parameter value. Either `eta_min` or `eta_min_ratio`
-            should be specified. Defaults to None.
- begin (int): Step at which to start updating the parameters. - Defaults to 0. - end (int): Step at which to stop updating the parameters. - Defaults to INF. - last_step (int): The index of last step. Used for resume without - state dict. Defaults to -1. - by_epoch (bool): Whether the scheduled parameters are updated by - epochs. Defaults to True. - verbose (bool): Whether to print the value for each update. - Defaults to False. - """ - - def __init__( - self, - optimizer: Optimizer | BaseOptimWrapper, - param_name: str, - periods: list[int], - restart_weights: Sequence[float] = (1,), - eta_min: float | None = None, - eta_min_ratio: float | None = None, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False, - ): - assert (eta_min is None) ^ (eta_min_ratio is None) - self.periods = periods - self.eta_min = eta_min - self.eta_min_ratio = eta_min_ratio - self.restart_weights = restart_weights - assert len(self.periods) == len(self.restart_weights), ( - "periods and restart_weights should have the same length." - ) - self.cumulative_periods = [sum(self.periods[0 : i + 1]) for i in range(0, len(self.periods))] - - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose, - ) - - @classmethod - def build_iter_from_epoch( - cls, - *args, - periods, - begin=0, - end=INF, - by_epoch=True, - epoch_length=None, - **kwargs, - ): - """Build an iter-based instance of this scheduler from an epoch-based - config.""" - assert by_epoch, "Only epoch-based kwargs whose `by_epoch=True` can be converted to iter-based." - assert epoch_length is not None and epoch_length > 0, ( - f"`epoch_length` must be a positive integer, but got {epoch_length}." - ) - periods = [p * epoch_length for p in periods] - by_epoch = False - begin = int(begin * epoch_length) - if end != INF: - end = int(end * epoch_length) - return cls(*args, periods=periods, begin=begin, end=end, by_epoch=by_epoch, **kwargs) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - idx = self.get_position_from_periods(self.last_step, self.cumulative_periods) - # if current step is not in the periods, return origin parameters - if idx is None: - return [group[self.param_name] for group in self.optimizer.param_groups] - current_weight = self.restart_weights[idx] - nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1] - current_periods = self.periods[idx] - step = self.last_step - nearest_restart - values = [] - for base_value, group in zip(self.base_values, self.optimizer.param_groups, strict=False): - eta_max = base_value * current_weight - if self.eta_min_ratio is None: - eta_min = self.eta_min - else: - eta_min = base_value * self.eta_min_ratio - if step == 0: - values.append(eta_max) - else: - values.append( - (1 + math.cos(math.pi * step / current_periods)) - / (1 + math.cos(math.pi * (step - 1) / current_periods)) - * (group[self.param_name] - eta_min) - + eta_min - ) - - return values - - @staticmethod - def get_position_from_periods(iteration: int, cumulative_periods: list[int]) -> int | None: - """Get the position from a period list. - - It will return the index of the right-closest number in the period - list. - For example, the cumulative_periods = [100, 200, 300, 400], - if iteration == 50, return 0; - if iteration == 210, return 2; - if iteration == 300, return 3. - - Args: - iteration (int): Current iteration. 
-            cumulative_periods (list[int]): Cumulative period list.
-
-        Returns:
-            Optional[int]: The position of the right-closest number in the
-                period list. If not in the period, return None.
-        """
-        for i, period in enumerate(cumulative_periods):
-            if iteration < period:
-                return i
-        return None
-
-
-@PARAM_SCHEDULERS.register_module(force=True)
-class ReduceOnPlateauParamScheduler(_ParamScheduler):
-    """Reduce the parameters of each parameter group when a metric has stopped
-    improving. Models often benefit from reducing the parameters by a factor of
-    2-10 once learning stagnates. This scheduler reads a metric quantity and,
-    if no improvement is seen for a ``patience`` number of epochs, the
-    parameters are reduced.
-
-    The implementation is motivated by `PyTorch ReduceLROnPlateau`_.
-
-    Args:
-        optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped
-            optimizer.
-        param_name (str): Name of the parameter to be adjusted, such as
-            ``lr``, ``momentum``.
-        monitor (str): The name of the metric to measure whether
-            the performance of the model is improved.
-        rule (str): One of `less`, `greater`. In `less` rule, parameters will
-            be reduced when the quantity monitored has stopped
-            decreasing; in `greater` rule it will be reduced when the
-            quantity monitored has stopped increasing. Defaults to 'less'.
-            The ``rule`` is the renaming of ``mode`` in PyTorch.
-        factor (float): Factor by which the parameters will be
-            reduced. new_param = param * factor. Defaults to 0.1.
-        patience (int): Number of epochs with no improvement after
-            which parameters will be reduced. For example, if
-            ``patience = 2``, then we will ignore the first 2 epochs
-            with no improvement, and will only decrease the parameters after
-            the 3rd epoch if the monitor value still hasn't improved then.
-            Defaults to 10.
-        threshold (float): Threshold for measuring the new optimum,
-            to only focus on significant changes. Defaults to 1e-4.
-        threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
-            dynamic_threshold = best * ( 1 + threshold ) in 'greater'
-            rule or best * ( 1 - threshold ) in `less` rule.
-            In `abs` rule, dynamic_threshold = best + threshold in
-            `greater` rule or best - threshold in `less` rule.
-            Defaults to 'rel'.
-        cooldown (int): Number of epochs to wait before resuming
-            normal operation after parameters have been reduced. Defaults to 0.
-        min_value (float or list[float]): A scalar or a sequence of scalars.
-            A lower bound on the parameters of each parameter group
-            respectively. Defaults to 0.
-        eps (float): Minimal decay applied to parameters. If the difference
-            between new and old parameters is smaller than eps, the update is
-            ignored. Defaults to 1e-8.
-        begin (int): Step at which to start triggering the scheduler
-            to monitor in val within the interval calculated
-            according to epoch of training. Defaults to 0.
-        end (int): Step at which to stop triggering the scheduler
-            to monitor in val within the interval calculated
-            according to epoch of training. Defaults to INF.
-        last_step (int): The index of last step. Used for resume without
-            state dict. Defaults to -1.
-        by_epoch (bool): Whether the scheduled parameters are updated by
-            epochs. Defaults to True.
-        verbose (bool): Whether to print the value for each update.
-            Defaults to False.
-
-    .. _PyTorch ReduceLROnPlateau:
-        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py
-    """
-
-    need_val_args = True
-
-    def __init__(
-        self,
-        optimizer: OptimizerType,
-        param_name: str,
-        monitor: str = "loss",
-        rule: str = "less",
-        factor: float = 0.1,
-        patience: int = 10,
-        threshold: float = 1e-4,
-        threshold_rule: str = "rel",
-        cooldown: int = 0,
-        min_value: float | Sequence[float] = 0.0,
-        eps: float = 1e-8,
-        begin: int = 0,
-        end: int = INF,
-        last_step: int = -1,
-        by_epoch: bool = True,
-        verbose: bool = False,
-    ):
-        # Attach optimizer
-        if not isinstance(optimizer, Optimizer | BaseOptimWrapper):
-            raise TypeError(f"``optimizer`` should be an Optimizer or BaseOptimWrapper, but got {type(optimizer).__name__}")
-        self.optimizer = optimizer
-        self.param_name = param_name
-
-        if end <= begin:
-            raise ValueError(f"end should be larger than begin, but got begin={begin}, end={end}")
-        self.begin = begin
-        self.end = end
-
-        assert by_epoch, f"Currently {type(self).__name__} only supports by_epoch=True"
-        self.by_epoch = by_epoch
-
-        assert isinstance(last_step, int) and last_step >= -1
-        # Initialize valid step count and base values
-        if last_step == -1:
-            for group in optimizer.param_groups:
-                # If the param has never been scheduled, record the current
-                # value as the initial value.
-                group.setdefault(f"initial_{param_name}", group[param_name])
-        else:
-            for i, group in enumerate(optimizer.param_groups):
-                if f"initial_{param_name}" not in group:
-                    raise KeyError(
-                        f"param 'initial_{param_name}' is not specified in param_groups[{i}] when resuming an optimizer"
-                    )
-
-        self.last_step = last_step
-
-        self._global_step = 0
-        self.verbose = verbose
-
-        if factor >= 1.0:
-            raise ValueError("Factor should be < 1.0.")
-        self.factor = factor
-
-        # This code snippet handles compatibility with the optimizer wrapper.
-        # The optimizer wrapper includes an additional parameter to record the
-        # base learning rate (lr) which is not affected by the paramwise_cfg.
-        # By retrieving the base lr, we can obtain the actual base lr that
-        # reflects the learning progress.
-        if isinstance(optimizer, BaseOptimWrapper):
-            raw_optimizer = optimizer.optimizer
-        else:
-            raw_optimizer = optimizer
-
-        if isinstance(min_value, list | tuple):
-            if len(min_value) != len(raw_optimizer.param_groups):
-                raise ValueError(f"expected {len(raw_optimizer.param_groups)} values for min_value, got {len(min_value)}")
-            self.min_values = list(min_value)
-            # Consider the `min_value` of the last param_groups
-            # as the base setting. And we only add this value when
-            # the optimizer is OptimWrapper.
-            if isinstance(optimizer, BaseOptimWrapper) and optimizer.base_param_settings is not None:  # type: ignore
-                self.min_values.append(self.min_values[-1])
-
-        else:
-            self.min_values = [min_value] * len(optimizer.param_groups)  # type: ignore
-
-        self.patience = patience
-        self.cooldown = cooldown
-        self.cooldown_counter = 0
-        self.rule_worse = None  # the worst value for the chosen rule
-        self.best = None
-        self.num_bad_epochs = 0
-        self.eps = eps
-
-        self.monitor = monitor
-        self._init_is_better(rule=rule, threshold=threshold, threshold_rule=threshold_rule)
-        self._reset()
-
-        # Unlike the base class, ``self.step()`` is not called here;
-        # ``self._global_step`` starts at 0.
-        self._last_value = [group[self.param_name] for group in self.optimizer.param_groups]
-
-    def step(self, metrics=None):
-        """Adjusts the parameter value of each parameter group based on the
-        specified schedule.
-
-        Args:
-            metrics (Dict[str, float], optional): Evaluation results of all
-                metrics on validation dataset. The keys are the names of the
-                metrics, and the values are corresponding results.
-                Defaults to None.
-        """
-        if metrics is None:
-            # only to count self._global_step
-            self._global_step += 1
-            return
-
-        if not isinstance(metrics, dict):
-            raise TypeError(f"metrics type should be dict, but got type {type(metrics)}")
-
-        # Compute parameter value per param group in the effective range
-        if self.begin <= self._global_step < self.end:
-            self.last_step += 1
-
-            # convert `metric` to float, in case it's a zero-dim Tensor
-            metric = metrics.get(self.monitor, None)
-            if metric is not None:
-                if self._is_better(metric, self.best):
-                    self.best = metric
-                    self.num_bad_epochs = 0
-                else:
-                    self.num_bad_epochs += 1
-
-                if self._in_cooldown():
-                    self.cooldown_counter -= 1
-                    self.num_bad_epochs = 0  # ignore bad epochs in cooldown
-
-                if self.num_bad_epochs > self.patience:
-                    values = self._get_value()
-
-                    for i, data in enumerate(zip(self.optimizer.param_groups, values, strict=False)):
-                        param_group, value = data
-                        if param_group[self.param_name] - value > self.eps:
-                            param_group[self.param_name] = value
-                            self.print_value(self.verbose, i, value)
-                    self.cooldown_counter = self.cooldown
-                    self.num_bad_epochs = 0
-
-            else:
-                raise KeyError(f"Expected the monitored key {self.monitor} to be in {list(metrics.keys())}, but it was not found")
-
-        self._last_value = [group[self.param_name] for group in self.optimizer.param_groups]
-
-    def print_value(self, is_verbose: bool, group: int, value: float) -> None:
-        """Display the current parameter value.
-
-        Args:
-            is_verbose (bool): Whether to print the value.
-            group (int): The index of the current ``param_group``.
-            value (float): The parameter value.
- """ - if is_verbose: - step_name = "epoch" if self.by_epoch else "iter" - print_log( - f"Adjusting parameter value of group {group} to {value:.4e} in {step_name} {self.last_step}.", - logger="current", - ) - - def _get_value(self): - """Compute value using chainable form of the scheduler.""" - values = [float(group[self.param_name]) * self.factor for group in self.optimizer.param_groups] - return [max(v, min_v) for v, min_v in zip(values, self.min_values, strict=False)] - - def _in_cooldown(self): - """Judge whether it is in cooldown.""" - return self.cooldown_counter > 0 - - def _is_better(self, a, best): - """Judge whether the monitor value is better.""" - if self.rule == "less" and self.threshold_rule == "rel": - rel_epsilon = 1.0 - self.threshold - return a < best * rel_epsilon - - elif self.rule == "less" and self.threshold_rule == "abs": - return a < best - self.threshold - - elif self.rule == "greater" and self.threshold_rule == "rel": - rel_epsilon = self.threshold + 1.0 - return a > best * rel_epsilon - - else: # rule == 'greater' and epsilon_mode == 'abs': - return a > best + self.threshold - - def _init_is_better(self, rule, threshold, threshold_rule): - """Initialize rule and its associated values.""" - if threshold < 0: - raise ValueError(f"threshold {threshold} should be >= 0.") - if rule not in {"less", "greater"}: - raise ValueError(f"mode {rule} is unknown!") - if threshold_rule not in {"rel", "abs"}: - raise ValueError(f"threshold mode {threshold_rule} is unknown!") - - if rule == "less": - self.rule_worse = INF - else: # rule == 'greater': - self.rule_worse = -INF - - self.rule = rule - self.threshold = threshold - self.threshold_rule = threshold_rule - - def _reset(self): - """Resets num_bad_epochs counter and cooldown counter.""" - self.best = self.rule_worse - self.cooldown_counter = 0 - self.num_bad_epochs = 0 diff --git a/libs/visengine/visengine/registry/__init__.py b/libs/visengine/visengine/registry/__init__.py deleted file mode 100644 index 867c1f9..0000000 --- a/libs/visengine/visengine/registry/__init__.py +++ /dev/null @@ -1,72 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
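-#
-# A minimal usage sketch (illustrative only; ``MaskRCNN`` is assumed to be
-# registered by the downstream visdet package):
-#
-#   from visengine.registry import MODELS
-#   model = MODELS.build(dict(type="MaskRCNN"))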
-from .build_functions import (
-    build_from_cfg,
-    build_model_from_cfg,
-    build_runner_from_cfg,
-    build_scheduler_from_cfg,
-)
-from .default_scope import DefaultScope
-from .registry import Registry
-from .root import (
-    DATA_SAMPLERS,
-    DATASETS,
-    EVALUATOR,
-    FUNCTIONS,
-    HOOKS,
-    INFERENCERS,
-    LOG_PROCESSORS,
-    LOOPS,
-    METRICS,
-    MODEL_WRAPPERS,
-    MODELS,
-    OPTIM_WRAPPER_CONSTRUCTORS,
-    OPTIM_WRAPPERS,
-    OPTIMIZERS,
-    PARAM_SCHEDULERS,
-    RUNNER_CONSTRUCTORS,
-    RUNNERS,
-    STRATEGIES,
-    TASK_UTILS,
-    TRANSFORMS,
-    VISBACKENDS,
-    VISUALIZERS,
-    WEIGHT_INITIALIZERS,
-)
-from .utils import count_registered_modules, init_default_scope, traverse_registry_tree
-
-__all__ = [
-    "DATASETS",
-    "DATA_SAMPLERS",
-    "EVALUATOR",
-    "FUNCTIONS",
-    "HOOKS",
-    "INFERENCERS",
-    "LOG_PROCESSORS",
-    "LOOPS",
-    "METRICS",
-    "MODELS",
-    "MODEL_WRAPPERS",
-    "OPTIMIZERS",
-    "OPTIM_WRAPPERS",
-    "OPTIM_WRAPPER_CONSTRUCTORS",
-    "PARAM_SCHEDULERS",
-    "RUNNERS",
-    "RUNNER_CONSTRUCTORS",
-    "STRATEGIES",
-    "TASK_UTILS",
-    "TRANSFORMS",
-    "VISBACKENDS",
-    "VISUALIZERS",
-    "WEIGHT_INITIALIZERS",
-    "DefaultScope",
-    "Registry",
-    "build_from_cfg",
-    "build_model_from_cfg",
-    "build_runner_from_cfg",
-    "build_scheduler_from_cfg",
-    "count_registered_modules",
-    "init_default_scope",
-    "traverse_registry_tree",
-]
diff --git a/libs/visengine/visengine/registry/build_functions.py b/libs/visengine/visengine/registry/build_functions.py
deleted file mode 100644
index b6da052..0000000
--- a/libs/visengine/visengine/registry/build_functions.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Union
-
-import torch
-import torch.nn as nn
-
-from visengine.config import Config, ConfigDict
-from visengine.utils import ManagerMixin, digit_version
-
-from .registry import Registry
-
-# Neither of these imports is actually needed; they're just for type checking.
-# from mmengine.optim.scheduler import _ParamScheduler
-# from mmengine.runner import Runner
-
-
-def build_from_cfg(
-    cfg: dict | ConfigDict | Config,
-    registry: Registry,
-    default_args: dict | ConfigDict | Config | None = None,
-) -> Any:
-    """Build a module from config dict when it is a class configuration, or
-    call a function from config dict when it is a function configuration.
-
-    If the global variable default scope (:obj:`DefaultScope`) exists,
-    :meth:`build` will firstly get the responding registry and then call
-    its own :meth:`build`.
-
-    At least one of ``cfg`` and ``default_args`` must contain the key "type",
-    which should be either a str or a class. If both contain it, the key in
-    ``cfg`` will be used, because ``cfg`` has a higher priority than
-    ``default_args``; that is, if a key exists in both of them, the value
-    ``cfg[key]`` wins. The two dicts will be merged first, the key "type"
-    will be popped, and the remaining keys will be used as initialization
-    arguments.
- - Examples: - >>> from mmengine import Registry, build_from_cfg - >>> MODELS = Registry('models') - >>> @MODELS.register_module(force=True) - >>> class ResNet: - >>> def __init__(self, depth, stages=4): - >>> self.depth = depth - >>> self.stages = stages - >>> cfg = dict(type='ResNet', depth=50) - >>> model = build_from_cfg(cfg, MODELS) - >>> # Returns an instantiated object - >>> @MODELS.register_module(force=True) - >>> def resnet50(): - >>> pass - >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS) - >>> # Return a result of the calling function - - Args: - cfg (dict or ConfigDict or Config): Config dict. It should at least - contain the key "type". - registry (:obj:`Registry`): The registry to search the type from. - default_args (dict or ConfigDict or Config, optional): Default - initialization arguments. Defaults to None. - - Returns: - object: The constructed object. - """ - # Avoid circular import - from ..logging import print_log - - if not isinstance(cfg, dict | ConfigDict | Config): - raise TypeError(f"cfg should be a dict, ConfigDict or Config, but got {type(cfg)}") - - if "type" not in cfg: - if default_args is None or "type" not in default_args: - raise KeyError(f'`cfg` or `default_args` must contain the key "type", but got {cfg}\n{default_args}') - - if not isinstance(registry, Registry): - raise TypeError(f"registry must be a mmengine.Registry object, but got {type(registry)}") - - if not (isinstance(default_args, dict | ConfigDict | Config) or default_args is None): - raise TypeError(f"default_args should be a dict, ConfigDict, Config or None, but got {type(default_args)}") - - args = cfg.copy() - if default_args is not None: - for name, value in default_args.items(): - args.setdefault(name, value) - - # Instance should be built under target scope, if `_scope_` is defined - # in cfg, current default scope should switch to specified scope - # temporarily. - scope = args.pop("_scope_", None) - with registry.switch_scope_and_registry(scope) as registry: - obj_type = args.pop("type") - if isinstance(obj_type, str): - obj_cls = registry.get(obj_type) - if obj_cls is None: - raise KeyError( - f"{obj_type} is not in the {registry.scope}::{registry.name} registry. " - f"Please check whether the value of `{obj_type}` is " - "correct or it was registered as expected. More details " - "can be found at " - "https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module" - ) - # this will include classes, functions, partial functions and more - elif callable(obj_type): - obj_cls = obj_type - else: - raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") - - # If `obj_cls` inherits from `ManagerMixin`, it should be - # instantiated by `ManagerMixin.get_instance` to ensure that it - # can be accessed globally. 
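-        # A hedged illustration: for a ManagerMixin subclass built from a cfg
-        # containing ``name='demo'``, ``obj_cls.get_instance('demo')`` can
-        # later retrieve the exact same object from anywhere in the process.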
-        if inspect.isclass(obj_cls) and issubclass(obj_cls, ManagerMixin):  # type: ignore
-            obj = obj_cls.get_instance(**args)  # type: ignore
-        else:
-            obj = obj_cls(**args)  # type: ignore
-
-        if inspect.isclass(obj_cls) or inspect.isfunction(obj_cls) or inspect.ismethod(obj_cls):
-            print_log(
-                f"An `{obj_cls.__name__}` instance is built from "  # type: ignore
-                "registry, and its implementation can be found in "
-                f"{obj_cls.__module__}",  # type: ignore
-                logger="current",
-                level=logging.DEBUG,
-            )
-        else:
-            print_log(
-                f"An instance is built from registry, and its constructor is {obj_cls}",
-                logger="current",
-                level=logging.DEBUG,
-            )
-        return obj
-
-
-def build_runner_from_cfg(cfg: dict | ConfigDict | Config, registry: Registry):  # -> Runner:
-    """Build a Runner object.
-
-    Examples:
-        >>> from visengine.registry import Registry, build_runner_from_cfg
-        >>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg)
-        >>> @RUNNERS.register_module(force=True)
-        >>> class CustomRunner(Runner):
-        >>>     def setup_env(env_cfg):
-        >>>         pass
-        >>> cfg = dict(runner_type='CustomRunner', ...)
-        >>> custom_runner = RUNNERS.build(cfg)
-
-    Args:
-        cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key
-            exists, it will be used to build a custom runner. Otherwise, it
-            will be used to build a default runner.
-        registry (:obj:`Registry`): The registry to search the type from.
-
-    Returns:
-        object: The constructed runner object.
-    """
-    # Avoid circular import
-    from ..logging import print_log
-
-    assert isinstance(cfg, dict | ConfigDict | Config), (
-        f"cfg should be a dict, ConfigDict or Config, but got {type(cfg)}"
-    )
-    assert isinstance(registry, Registry), f"registry should be a mmengine.Registry object, but got {type(registry)}"
-
-    args = cfg.copy()
-    # Runner should be built under target scope, if `_scope_` is defined
-    # in cfg, current default scope should switch to specified scope
-    # temporarily.
-    scope = args.pop("_scope_", None)
-    with registry.switch_scope_and_registry(scope) as registry:
-        obj_type = args.get("runner_type", "Runner")
-        if isinstance(obj_type, str):
-            runner_cls = registry.get(obj_type)
-            if runner_cls is None:
-                raise KeyError(
-                    f"{obj_type} is not in the {registry.name} registry. "
-                    f"Please check whether the value of `{obj_type}` is "
-                    "correct or it was registered as expected. More details "
-                    "can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module"
-                )
-        elif inspect.isclass(obj_type):
-            runner_cls = obj_type
-        else:
-            raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
-
-        runner = runner_cls.from_cfg(args)  # type: ignore
-        print_log(
-            f"An `{runner_cls.__name__}` instance is built from "  # type: ignore
-            "registry, and its implementation can be found in "
-            f"{runner_cls.__module__}",  # type: ignore
-            logger="current",
-            level=logging.DEBUG,
-        )
-        return runner
-
-
-def build_model_from_cfg(
-    cfg: dict | ConfigDict | Config,
-    registry: Registry,
-    default_args: Union[dict, "ConfigDict", "Config"] | None = None,
-) -> "nn.Module":
-    """Build a PyTorch model from config dict(s). Different from
-    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
-
-    Args:
-        cfg (dict, list[dict]): The config of modules, which is either a config
-            dict or a list of config dicts. If cfg is a list, the built
-            modules will be wrapped with ``nn.Sequential``.
-        registry (:obj:`Registry`): A registry the module belongs to.
-        default_args (dict, optional): Default arguments to build the module.
-            Defaults to None.
-
-    Returns:
-        nn.Module: A built nn.Module.
-    """
-    from ..model import Sequential
-
-    if isinstance(cfg, list):
-        modules = [build_from_cfg(_cfg, registry, default_args) for _cfg in cfg]
-        return Sequential(*modules)
-    else:
-        return build_from_cfg(cfg, registry, default_args)
-
-
-def build_optimizer_from_cfg(
-    cfg: dict | ConfigDict | Config,
-    registry: Registry,
-    default_args: dict | ConfigDict | Config | None = None,
-) -> Any:
-    # Avoid circular import
-    from ..logging import print_log
-
-    if "type" in cfg and cfg["type"] == "Adafactor" and digit_version(torch.__version__) >= digit_version("2.5.0"):
-        print_log("the torch version of Adafactor is registered as TorchAdafactor")
-    return build_from_cfg(cfg, registry, default_args)
-
-
-def build_scheduler_from_cfg(
-    cfg: dict | ConfigDict | Config,
-    registry: Registry,
-    default_args: dict | ConfigDict | Config | None = None,
-):  # -> "_ParamScheduler":
-    """Builds a ``ParamScheduler`` instance from config.
-
-    ``ParamScheduler`` supports building instance by its constructor or
-    method ``build_iter_from_epoch``. Therefore, its registry needs a build
-    function to handle both cases.
-
-    Args:
-        cfg (dict or ConfigDict or Config): Config dictionary. If it contains
-            the key ``convert_to_iter_based``, instance will be built by method
-            ``convert_to_iter_based``, otherwise instance will be built by its
-            constructor.
-        registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry.
-        default_args (dict or ConfigDict or Config, optional): Default
-            initialization arguments. It must contain key ``optimizer``. If
-            ``convert_to_iter_based`` is defined in ``cfg``, it must
-            additionally contain key ``epoch_length``. Defaults to None.
-
-    Returns:
-        object: The constructed ``ParamScheduler``.
-    """
-    assert isinstance(cfg, dict | ConfigDict | Config), (
-        f"cfg should be a dict, ConfigDict or Config, but got {type(cfg)}"
-    )
-    assert isinstance(registry, Registry), f"registry should be a mmengine.Registry object, but got {type(registry)}"
-
-    args = cfg.copy()
-    if default_args is not None:
-        for name, value in default_args.items():
-            args.setdefault(name, value)
-    scope = args.pop("_scope_", None)
-    with registry.switch_scope_and_registry(scope) as registry:
-        convert_to_iter = args.pop("convert_to_iter_based", False)
-        if convert_to_iter:
-            scheduler_type = args.pop("type")
-            assert "epoch_length" in args and args.get("by_epoch", True), (
-                "Only epoch-based parameter scheduler can be converted to iter-based, and `epoch_length` should be set"
-            )
-            if isinstance(scheduler_type, str):
-                scheduler_cls = registry.get(scheduler_type)
-                if scheduler_cls is None:
-                    raise KeyError(
-                        f"{scheduler_type} is not in the {registry.name} "
-                        "registry. Please check whether the value of "
-                        f"`{scheduler_type}` is correct or it was registered "
-                        "as expected. More details can be found at "
-                        "https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module"
-                    )
-            elif inspect.isclass(scheduler_type):
-                scheduler_cls = scheduler_type
-            else:
-                raise TypeError(f"type must be a str or valid type, but got {type(scheduler_type)}")
-            return scheduler_cls.build_iter_from_epoch(**args)  # type: ignore
-        else:
-            args.pop("epoch_length", None)
-            return build_from_cfg(args, registry)
diff --git a/libs/visengine/visengine/registry/default_scope.py b/libs/visengine/visengine/registry/default_scope.py
deleted file mode 100644
index 396f3c5..0000000
--- a/libs/visengine/visengine/registry/default_scope.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# ruff: noqa
-# type: ignore
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import time
-from collections.abc import Generator
-from contextlib import contextmanager
-from typing import Optional
-
-from visengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock
-
-
-class DefaultScope(ManagerMixin):
-    """Scope of current task used to reset the current registry, which can be
-    accessed globally.
-
-    Consider the case of resetting the current ``Registry`` by
-    ``default_scope`` from an internal module that cannot access the runner
-    directly: it is difficult to get the ``default_scope`` defined in
-    ``Runner``. However, if ``Runner`` created a ``DefaultScope`` instance
-    with a given ``default_scope``, the internal module can get
-    ``default_scope`` by ``DefaultScope.get_current_instance`` everywhere.
-
-    Args:
-        name (str): Name of default scope for global access.
-        scope_name (str): Scope of current task.
-
-    Examples:
-        >>> from visengine.model import MODELS
-        >>> # Define default scope in runner.
-        >>> DefaultScope.get_instance('task', scope_name='mmdet')
-        >>> # Get default scope globally.
-        >>> scope_name = DefaultScope.get_instance('task').scope_name
-    """
-
-    def __init__(self, name: str, scope_name: str):
-        super().__init__(name)
-        assert isinstance(scope_name, str), f"scope_name should be a string, but got {scope_name}"
-        self._scope_name = scope_name
-
-    @property
-    def scope_name(self) -> str:
-        """
-        Returns:
-            str: Get current scope.
-        """
-        return self._scope_name
-
-    @classmethod
-    def get_current_instance(cls) -> Optional["DefaultScope"]:
-        """Get latest created default scope.
-
-        Since ``default_scope`` is an optional argument for ``Registry.build``,
-        ``get_current_instance`` should return ``None`` if there is no
-        ``DefaultScope`` created.
-
-        Examples:
-            >>> default_scope = DefaultScope.get_current_instance()
-            >>> # There is no `DefaultScope` created yet,
-            >>> # `get_current_instance` return `None`.
-            >>> default_scope = DefaultScope.get_instance(
-            >>>     'instance_name', scope_name='mmengine')
-            >>> default_scope.scope_name
-            mmengine
-            >>> default_scope = DefaultScope.get_current_instance()
-            >>> default_scope.scope_name
-            mmengine
-
-        Returns:
-            Optional[DefaultScope]: Return None if no ``DefaultScope``
                instance has been created yet, otherwise return the latest
                created DefaultScope instance.
-        """
-        _accquire_lock()
-        if cls._instance_dict:
-            instance = super().get_current_instance()
-        else:
-            instance = None
-        _release_lock()
-        return instance
-
-    @classmethod
-    @contextmanager
-    def overwrite_default_scope(cls, scope_name: str | None) -> Generator:
-        """Overwrite the current default scope with `scope_name`."""
-        if scope_name is None:
-            yield
-        else:
-            tmp = copy.deepcopy(cls._instance_dict)
-            # To avoid creating an instance with the same name.
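-            # (The instance name below embeds ``time.time()``, so a
-            # microsecond sleep keeps consecutive timestamps distinct even
-            # on platforms with coarse clock resolution.)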
- time.sleep(1e-6) - cls.get_instance(f"overwrite-{time.time()}", scope_name=scope_name) - try: - yield - finally: - cls._instance_dict = tmp diff --git a/libs/visengine/visengine/registry/registry.py b/libs/visengine/visengine/registry/registry.py deleted file mode 100644 index bffe7a2..0000000 --- a/libs/visengine/visengine/registry/registry.py +++ /dev/null @@ -1,691 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import logging -import sys -from collections.abc import Callable, Generator -from contextlib import contextmanager -from importlib import import_module -from typing import Any, Optional - -from rich.console import Console -from rich.table import Table - -from visengine.config.utils import MODULE2PACKAGE -from visengine.utils import get_object_from_string, is_seq_of - -from .default_scope import DefaultScope - - -class Registry: - """A registry to map strings to classes or functions. - - Registered object could be built from registry. Meanwhile, registered - functions could be called from registry. - - Args: - name (str): Registry name. - build_func (callable, optional): A function to construct instance - from Registry. :func:`build_from_cfg` is used if neither ``parent`` - or ``build_func`` is specified. If ``parent`` is specified and - ``build_func`` is not given, ``build_func`` will be inherited - from ``parent``. Defaults to None. - parent (:obj:`Registry`, optional): Parent registry. The class - registered in children registry could be built from parent. - Defaults to None. - scope (str, optional): The scope of registry. It is the key to search - for children registry. If not specified, scope will be the name of - the package where class is defined, e.g. mmdet, mmcls, mmseg. - Defaults to None. - locations (list): The locations to import the modules registered - in this registry. Defaults to []. - New in version 0.4.0. - - Examples: - >>> # define a registry - >>> MODELS = Registry('models') - >>> # registry the `ResNet` to `MODELS` - >>> @MODELS.register_module(force=True) - >>> class ResNet: - >>> pass - >>> # build model from `MODELS` - >>> resnet = MODELS.build(dict(type='ResNet')) - >>> @MODELS.register_module(force=True) - >>> def resnet50(): - >>> pass - >>> resnet = MODELS.build(dict(type='resnet50')) - - >>> # hierarchical registry - >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det') - >>> @DETECTORS.register_module(force=True) - >>> class FasterRCNN: - >>> pass - >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN')) - - >>> # add locations to enable auto import - >>> DETECTORS = Registry('detectors', parent=MODELS, - >>> scope='det', locations=['det.models.detectors']) - >>> # define this class in 'det.models.detectors' - >>> @DETECTORS.register_module(force=True) - >>> class MaskRCNN: - >>> pass - >>> # The registry will auto import det.models.detectors.MaskRCNN - >>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN')) - - More advanced usages can be found at - https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html. 
- """ - - def __init__( - self, - name: str, - build_func: Callable | None = None, - parent: Optional["Registry"] = None, - scope: str | None = None, - locations: list | None = None, - ): - from .build_functions import build_from_cfg - - if locations is None: - locations = [] - self._name = name - self._module_dict: dict[str, type] = {} - self._children: dict[str, Registry] = {} - self._locations = locations - self._imported = False - - if scope is not None: - assert isinstance(scope, str) - self._scope = scope - else: - self._scope = self.infer_scope() - - # See https://mypy.readthedocs.io/en/stable/common_issues.html# - # variables-vs-type-aliases for the use - self.parent: Registry | None - if parent is not None: - assert isinstance(parent, Registry) - parent._add_child(self) - self.parent = parent - else: - self.parent = None - - # self.build_func will be set with the following priority: - # 1. build_func - # 2. parent.build_func - # 3. build_from_cfg - self.build_func: Callable - if build_func is None: - if self.parent is not None: - self.build_func = self.parent.build_func - else: - self.build_func = build_from_cfg - else: - self.build_func = build_func - - def __len__(self): - return len(self._module_dict) - - def __contains__(self, key): - return self.get(key) is not None - - def __repr__(self): - table = Table(title=f"Registry of {self._name}") - table.add_column("Names", justify="left", style="cyan") - table.add_column("Objects", justify="left", style="green") - - for name, obj in sorted(self._module_dict.items()): - table.add_row(name, str(obj)) - - console = Console() - with console.capture() as capture: - console.print(table, end="") - - return capture.get() - - @staticmethod - def infer_scope() -> str: - """Infer the scope of registry. - - The name of the package where registry is defined will be returned. - - Returns: - str: The inferred scope name. - - Examples: - >>> # in mmdet/models/backbone/resnet.py - >>> MODELS = Registry('models') - >>> @MODELS.register_module(force=True) - >>> class ResNet: - >>> pass - >>> # The scope of ``ResNet`` will be ``mmdet``. - """ - from ..logging import print_log - - # `sys._getframe` returns the frame object that many calls below the - # top of the stack. The call stack for `infer_scope` can be listed as - # follow: - # frame-0: `infer_scope` itself - # frame-1: `__init__` of `Registry` which calls the `infer_scope` - # frame-2: Where the `Registry(...)` is called - module = inspect.getmodule(sys._getframe(2)) - if module is not None: - filename = module.__name__ - split_filename = filename.split(".") - scope = split_filename[0] - else: - # use "mmengine" to handle some cases which can not infer the scope - # like initializing Registry in interactive mode - scope = "mmengine" - print_log( - 'set scope as "mmengine" when scope can not be inferred. You ' - 'can silence this warning by passing a "scope" argument to ' - 'Registry like `Registry(name, scope="toy")`', - logger="current", - level=logging.WARNING, - ) - - return scope - - @staticmethod - def split_scope_key(key: str) -> tuple[str | None, str]: - """Split scope and key. - - The first scope will be split from key. - - Return: - tuple[str | None, str]: The former element is the first scope of - the key, which can be ``None``. The latter is the remaining key. 
- - Examples: - >>> Registry.split_scope_key('mmdet.ResNet') - 'mmdet', 'ResNet' - >>> Registry.split_scope_key('ResNet') - None, 'ResNet' - """ - split_index = key.find(".") - if split_index != -1: - return key[:split_index], key[split_index + 1 :] - else: - return None, key - - @property - def name(self): - return self._name - - @property - def scope(self): - return self._scope - - @property - def module_dict(self): - return self._module_dict - - @property - def children(self): - return self._children - - @property - def root(self): - return self._get_root_registry() - - @contextmanager - def switch_scope_and_registry(self, scope: str | None) -> Generator: - """Temporarily switch default scope to the target scope, and get the - corresponding registry. - - If the registry of the corresponding scope exists, yield the - registry, otherwise yield the current itself. - - Args: - scope (str, optional): The target scope. - - Examples: - >>> from visengine.registry import Registry, DefaultScope, MODELS - >>> import time - >>> # External Registry - >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet', - >>> parent=MODELS) - >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls', - >>> parent=MODELS) - >>> # Local Registry - >>> CUSTOM_MODELS = Registry('custom_model', scope='custom', - >>> parent=MODELS) - >>> - >>> # Initiate DefaultScope - >>> DefaultScope.get_instance(f'scope_{time.time()}', - >>> scope_name='custom') - >>> # Check default scope - >>> DefaultScope.get_current_instance().scope_name - custom - >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry. - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry: - >>> DefaultScope.get_current_instance().scope_name - mmcls - >>> registry.scope - mmcls - >>> # Nested switch scope - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry: - >>> DefaultScope.get_current_instance().scope_name - mmdet - >>> mmdet_registry.scope - mmdet - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry: - >>> DefaultScope.get_current_instance().scope_name - mmcls - >>> mmcls_registry.scope - mmcls - >>> - >>> # Check switch back to original scope. - >>> DefaultScope.get_current_instance().scope_name - custom - """ - from ..logging import print_log - - # Switch to the given scope temporarily. If the corresponding registry - # can be found in root registry, return the registry under the scope, - # otherwise return the registry itself. - with DefaultScope.overwrite_default_scope(scope): - # Get the global default scope - default_scope = DefaultScope.get_current_instance() - # Get registry by scope - if default_scope is not None: - scope_name = default_scope.scope_name - try: - import_module(f"{scope_name}.registry") - except (ImportError, AttributeError, ModuleNotFoundError): - if scope in MODULE2PACKAGE: - print_log( - f"{scope} is not installed and its " - "modules will not be registered. 
If you " - "want to use modules defined in " - f"{scope}, Please install {scope} by " - f"`pip install {MODULE2PACKAGE[scope]}.", - logger="current", - level=logging.WARNING, - ) - else: - print_log( - f"Failed to import `{scope}.registry` make sure the registry.py exists in `{scope}` package.", - logger="current", - level=logging.WARNING, - ) - root = self._get_root_registry() - registry = root._search_child(scope_name) - if registry is None: - # if `default_scope` can not be found, fallback to argument - # `registry` - print_log( - f'Failed to search registry with scope "{scope_name}" ' - f'in the "{root.name}" registry tree. ' - f'As a workaround, the current "{self.name}" registry ' - f'in "{self.scope}" is used to build instance. This ' - "may cause unexpected failure when running the built " - f'modules. Please check whether "{scope_name}" is a ' - "correct scope, or whether the registry is " - "initialized.", - logger="current", - level=logging.WARNING, - ) - registry = self - # If there is no built default scope, just return current registry. - else: - registry = self - yield registry - - def _get_root_registry(self) -> "Registry": - """Return the root registry.""" - root = self - while root.parent is not None: - root = root.parent - return root - - def import_from_location(self) -> None: - """Import modules from the pre-defined locations in self._location.""" - if not self._imported: - # Avoid circular import - from ..logging import print_log - - # avoid BC breaking - if len(self._locations) == 0 and self.scope in MODULE2PACKAGE: - print_log( - f'The "{self.name}" registry in {self.scope} did not ' - "set import location. Fallback to call " - f"`{self.scope}.utils.register_all_modules` " - "instead.", - logger="current", - level=logging.DEBUG, - ) - try: - module = import_module(f"{self.scope}.utils") - except (ImportError, AttributeError, ModuleNotFoundError): - if self.scope in MODULE2PACKAGE: - print_log( - f"{self.scope} is not installed and its " - "modules will not be registered. If you " - "want to use modules defined in " - f"{self.scope}, Please install {self.scope} by " - f"`pip install {MODULE2PACKAGE[self.scope]}.", - logger="current", - level=logging.WARNING, - ) - else: - print_log( - f"Failed to import {self.scope} and register its modules, please make sure you have registered the module manually.", - logger="current", - level=logging.WARNING, - ) - else: - # The import errors triggered during the registration - # may be more complex, here just throwing - # the error to avoid causing more implicit registry errors - # like `xxx`` not found in `yyy` registry. - module.register_all_modules(False) # type: ignore - - for loc in self._locations: - import_module(loc) - print_log( - f"Modules of {self.scope}'s {self.name} registry have been automatically imported from {loc}", - logger="current", - level=logging.DEBUG, - ) - self._imported = True - - def get(self, key: str) -> type | None: - """Get the registry record. - - If `key`` represents the whole object name with its module - information, for example, `mmengine.model.BaseModel`, ``get`` - will directly return the class object :class:`BaseModel`. - - Otherwise, it will first parse ``key`` and check whether it - contains a scope name. The logic to search for ``key``: - - - ``key`` does not contain a scope name, i.e., it is purely a module - name like "ResNet": :meth:`get` will search for ``ResNet`` from the - current registry to its parent or ancestors until finding it. 
- - - ``key`` contains a scope name and it is equal to the scope of the - current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get` - will only search for ``ResNet`` in the current registry. - - - ``key`` contains a scope name and it is not equal to the scope of - the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the - scope exists in its children, :meth:`get` will get "FCNet" from - them. If not, :meth:`get` will first get the root registry and root - registry call its own :meth:`get` method. - - Args: - key (str): Name of the registered item, e.g., the class name in - string format. - - Returns: - Type or None: Return the corresponding class if ``key`` exists, - otherwise return None. - - Examples: - >>> # define a registry - >>> MODELS = Registry('models') - >>> # register `ResNet` to `MODELS` - >>> @MODELS.register_module(force=True) - >>> class ResNet: - >>> pass - >>> resnet_cls = MODELS.get('ResNet') - - >>> # hierarchical registry - >>> DETECTORS = Registry('detector', parent=MODELS, scope='det') - >>> # `ResNet` does not exist in `DETECTORS` but `get` method - >>> # will try to search from its parents or ancestors - >>> resnet_cls = DETECTORS.get('ResNet') - >>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls') - >>> @CLASSIFIER.register_module(force=True) - >>> class MobileNet: - >>> pass - >>> # `get` from its sibling registries - >>> mobilenet_cls = DETECTORS.get('cls.MobileNet') - """ - # Avoid circular import - from ..logging import print_log - - if not isinstance(key, str): - raise TypeError(f"The key argument of `Registry.get` must be a str, got {type(key)}") - - scope, real_key = self.split_scope_key(key) - obj_cls = None - registry_name = self.name - scope_name = self.scope - - # lazy import the modules to register them into the registry - self.import_from_location() - - if scope is None or scope == self._scope: - # get from self - if real_key in self._module_dict: - obj_cls = self._module_dict[real_key] - elif scope is None: - # try to get the target from its parent or ancestors - parent = self.parent - while parent is not None: - if real_key in parent._module_dict: - obj_cls = parent._module_dict[real_key] - registry_name = parent.name - scope_name = parent.scope - break - parent = parent.parent - else: - # import the registry to add the nodes into the registry tree - try: - import_module(f"{scope}.registry") - print_log( - f"Registry node of {scope} has been automatically imported.", - logger="current", - level=logging.DEBUG, - ) - except (ImportError, AttributeError, ModuleNotFoundError): - print_log( - f"Cannot auto import {scope}.registry, please check " - f'whether the package "{scope}" is installed correctly ' - "or import the registry manually.", - logger="current", - level=logging.DEBUG, - ) - # get from self._children - if scope in self._children: - obj_cls = self._children[scope].get(real_key) - registry_name = self._children[scope].name - scope_name = scope - else: - root = self._get_root_registry() - - if scope != root._scope and scope not in root._children: - # If not skip directly, `root.get(key)` will recursively - # call itself until RecursionError is thrown. - pass - else: - obj_cls = root.get(key) - - if obj_cls is None: - # Actually, it's strange to implement this `try ... except` to - # get the object by its name in `Registry.get`. 
However, If we - # want to build the model using a configuration like - # `dict(type='visengine.model.BaseModel')`, which can - # be dumped by lazy import config, we need this code snippet - # for `Registry.get` to work. - try: - obj_cls = get_object_from_string(key) - except Exception: - raise RuntimeError(f"Failed to get {key}") - - if obj_cls is not None: - # For some rare cases (e.g. obj_cls is a partial function), obj_cls - # doesn't have `__name__`. Use default value to prevent error - cls_name = getattr(obj_cls, "__name__", str(obj_cls)) - print_log( - f'Get class `{cls_name}` from "{registry_name}" registry in "{scope_name}"', - logger="current", - level=logging.DEBUG, - ) - - return obj_cls - - def _search_child(self, scope: str) -> Optional["Registry"]: - """Depth-first search for the corresponding registry in its children. - - Note that the method only search for the corresponding registry from - the current registry. Therefore, if we want to search from the root - registry, :meth:`_get_root_registry` should be called to get the - root registry first. - - Args: - scope (str): The scope name used for searching for its - corresponding registry. - - Returns: - Registry or None: Return the corresponding registry if ``scope`` - exists, otherwise return None. - """ - if self._scope == scope: - return self - - for child in self._children.values(): - registry = child._search_child(scope) - if registry is not None: - return registry - - return None - - def build(self, cfg: dict, *args, **kwargs) -> Any: - """Build an instance. - - Build an instance by calling :attr:`build_func`. - - Args: - cfg (dict): Config dict needs to be built. - - Returns: - Any: The constructed object. - - Examples: - >>> from mmengine import Registry - >>> MODELS = Registry('models') - >>> @MODELS.register_module(force=True) - >>> class ResNet: - >>> def __init__(self, depth, stages=4): - >>> self.depth = depth - >>> self.stages = stages - >>> cfg = dict(type='ResNet', depth=50) - >>> model = MODELS.build(cfg) - """ - return self.build_func(cfg, *args, **kwargs, registry=self) - - def _add_child(self, registry: "Registry") -> None: - """Add a child for a registry. - - Args: - registry (:obj:`Registry`): The ``registry`` will be added as a - child of the ``self``. - """ - - assert isinstance(registry, Registry) - assert registry.scope is not None - # Allow re-registration of the same registry (idempotent) - if registry.scope in self.children: - if self.children[registry.scope] is registry: - # Same registry instance, skip - return - else: - raise AssertionError(f"scope {registry.scope} exists in {self.name} registry") - self.children[registry.scope] = registry - - def _register_module( - self, - module: type, - module_name: str | list[str] | None = None, - force: bool = False, - ) -> None: - """Register a module. - - Args: - module (type): Module to be registered. Typically a class or a - function, but generally all ``Callable`` are acceptable. - module_name (str or list of str, optional): The module name to be - registered. If not specified, the class name will be used. - Defaults to None. - force (bool): Whether to override an existing class with the same - name. Defaults to False. 
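# ---------------------------------------------------------------------------
# A minimal sketch of the lookup order `get` implements above, assuming
# `visengine.registry.Registry` is importable; the registry names and the
# `ToyBackbone` class are hypothetical and exist only for illustration.
from visengine.registry import Registry

TOY_MODELS = Registry("model", scope="toy")
TOY_DETECTORS = Registry("detector", parent=TOY_MODELS, scope="toy_det")

@TOY_MODELS.register_module(force=True)
class ToyBackbone:
    pass

# A plain key is searched in TOY_DETECTORS first, then in its ancestors.
assert TOY_DETECTORS.get("ToyBackbone") is ToyBackbone
# A key prefixed with the current scope is only searched locally.
assert TOY_MODELS.get("toy.ToyBackbone") is ToyBackbone
# ---------------------------------------------------------------------------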
- """ - from ..logging import print_log - - if not callable(module): - raise TypeError(f"module must be Callable, but got {type(module)}") - - if module_name is None: - module_name = module.__name__ - if isinstance(module_name, str): - module_name = [module_name] - for name in module_name: - if not force and name in self._module_dict: - existed_module = self.module_dict[name] - print_log( - f"Warning: {name} is already registered in {self.name} at {existed_module.__module__}, " - f"it will be overwritten by the module at {module.__module__}", - logger="current", - level=logging.WARNING, - ) - self._module_dict[name] = module - - def register_module( - self, - name: str | list[str] | None = None, - force: bool = False, - module: type | None = None, - ) -> type | Callable: - """Register a module. - - A record will be added to ``self._module_dict``, whose key is the class - name or the specified name, and value is the class itself. - It can be used as a decorator or a normal function. - - Args: - name (str or list of str, optional): The module name to be - registered. If not specified, the class name will be used. - force (bool): Whether to override an existing class with the same - name. Defaults to False. - module (type, optional): Module class or function to be registered. - Defaults to None. - - Examples: - >>> backbones = Registry('backbone') - >>> # as a decorator - >>> @backbones.register_module(force=True) - >>> class ResNet: - >>> pass - >>> backbones = Registry('backbone') - >>> @backbones.register_module(name='mnet') - >>> class MobileNet: - >>> pass - - >>> # as a normal function - >>> class ResNet: - >>> pass - >>> backbones.register_module(module=ResNet) - """ - if not isinstance(force, bool): - raise TypeError(f"force must be a boolean, but got {type(force)}") - - # raise the error ahead of time - if not (name is None or isinstance(name, str) or is_seq_of(name, str)): - raise TypeError(f"name must be None, an instance of str, or a sequence of str, but got {type(name)}") - - # use it as a normal method: x.register_module(module=SomeClass) - if module is not None: - self._register_module(module=module, module_name=name, force=force) - return module - - # use it as a decorator: @x.register_module(force=True) - def _register(module): - self._register_module(module=module, module_name=name, force=force) - return module - - return _register diff --git a/libs/visengine/visengine/registry/root.py b/libs/visengine/visengine/registry/root.py deleted file mode 100644 index 2978b8e..0000000 --- a/libs/visengine/visengine/registry/root.py +++ /dev/null @@ -1,72 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -"""MMEngine provides 20 root registries to support using modules across -projects. - -More datails can be found at -https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html. 
-""" - -from .build_functions import ( - build_model_from_cfg, - build_optimizer_from_cfg, - build_runner_from_cfg, - build_scheduler_from_cfg, -) -from .registry import Registry - -# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` -RUNNERS = Registry("runner", build_func=build_runner_from_cfg) -# manage runner constructors that define how to initialize runners -RUNNER_CONSTRUCTORS = Registry("runner constructor") -# manage all kinds of loops like `EpochBasedTrainLoop` -LOOPS = Registry("loop") -# manage all kinds of hooks like `CheckpointHook` -HOOKS = Registry("hook") - -# manage all kinds of strategies like `NativeStrategy` and `DDPStrategy` -STRATEGIES = Registry("strategy") - -# manage data-related modules -DATASETS = Registry("dataset") -DATA_SAMPLERS = Registry("data sampler") -TRANSFORMS = Registry("transform", locations=["viscv.transforms"]) - -# mangage all kinds of modules inheriting `nn.Module` -MODELS = Registry("model", build_model_from_cfg) -# mangage all kinds of model wrappers like 'MMDistributedDataParallel' -MODEL_WRAPPERS = Registry("model_wrapper") -# mangage all kinds of weight initialization modules like `Uniform` -WEIGHT_INITIALIZERS = Registry("weight initializer") - -# mangage all kinds of optimizers like `SGD` and `Adam` -OPTIMIZERS = Registry("optimizer", build_func=build_optimizer_from_cfg) -# manage optimizer wrapper -OPTIM_WRAPPERS = Registry("optim_wrapper") -# manage constructors that customize the optimization hyperparameters. -OPTIM_WRAPPER_CONSTRUCTORS = Registry("optimizer wrapper constructor") -# mangage all kinds of parameter schedulers like `MultiStepLR` -PARAM_SCHEDULERS = Registry("parameter scheduler", build_func=build_scheduler_from_cfg) - -# manage all kinds of metrics -METRICS = Registry("metric") -# manage evaluator -EVALUATOR = Registry("evaluator") - -# manage task-specific modules like anchor generators and box coders -TASK_UTILS = Registry("task util") - -# manage visualizer -VISUALIZERS = Registry("visualizer") -# manage visualizer backend -VISBACKENDS = Registry("vis_backend") - -# manage logprocessor -LOG_PROCESSORS = Registry("log_processor") - -# manage inferencer -INFERENCERS = Registry("inferencer") - -# manage function -FUNCTIONS = Registry("function") diff --git a/libs/visengine/visengine/registry/utils.py b/libs/visengine/visengine/registry/utils.py deleted file mode 100644 index 87a714f..0000000 --- a/libs/visengine/visengine/registry/utils.py +++ /dev/null @@ -1,119 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import datetime -import logging -import os.path as osp - -from visengine.fileio import dump -from visengine.logging import print_log - -from . import root -from .default_scope import DefaultScope -from .registry import Registry - - -def traverse_registry_tree(registry: Registry, verbose: bool = True) -> list: - """Traverse the whole registry tree from any given node, and collect - information of all registered modules in this registry tree. - - Args: - registry (Registry): a registry node in the registry tree. - verbose (bool): Whether to print log. Defaults to True - - Returns: - list: Statistic results of all modules in each node of the registry - tree. 
- """ - root_registry = registry.root - modules_info = [] - - def _dfs_registry(_registry): - if isinstance(_registry, Registry): - num_modules = len(_registry.module_dict) - scope = _registry.scope - registry_info = {"num_modules": num_modules, "scope": scope} - for name, registered_class in _registry.module_dict.items(): - folder = "/".join(registered_class.__module__.split(".")[:-1]) - if folder in registry_info: - registry_info[folder].append(name) - else: - registry_info[folder] = [name] - if verbose: - print_log( - f"Find {num_modules} modules in {scope}'s '{_registry.name}' registry ", - logger="current", - ) - modules_info.append(registry_info) - else: - return - for _, child in _registry.children.items(): - _dfs_registry(child) - - _dfs_registry(root_registry) - return modules_info - - -def count_registered_modules(save_path: str | None = None, verbose: bool = True) -> dict: - """Scan all modules in MMEngine's root and child registries and dump to - json. - - Args: - save_path (str, optional): Path to save the json file. - verbose (bool): Whether to print log. Defaults to True. - - Returns: - dict: Statistic results of all registered modules. - """ - # import modules to trigger registering - import visengine.dataset - import visengine.evaluator - import visengine.hooks - import visengine.model - import visengine.optim - import visengine.runner - import visengine.visualization # noqa: F401 - - registries_info = {} - # traverse all registries in MMEngine - for item in dir(root): - if not item.startswith("__"): - registry = getattr(root, item) - if isinstance(registry, Registry): - registries_info[item] = traverse_registry_tree(registry, verbose) - scan_data = { - "scan_date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "registries": registries_info, - } - if verbose: - print_log(f"Finish registry analysis, got: {scan_data}", logger="current") - if save_path is not None: - json_path = osp.join(save_path, "modules_statistic_results.json") - dump(scan_data, json_path, indent=2) - print_log(f"Result has been saved to {json_path}", logger="current") - return scan_data - - -def init_default_scope(scope: str) -> None: - """Initialize the given default scope. - - Args: - scope (str): The name of the default scope. - """ - never_created = DefaultScope.get_current_instance() is None or not DefaultScope.check_instance_created(scope) - if never_created: - DefaultScope.get_instance(scope, scope_name=scope) - return - current_scope = DefaultScope.get_current_instance() # type: ignore - if current_scope.scope_name != scope: # type: ignore - print_log( - "The current default scope " # type: ignore - f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore - "`init_default_scope` will force set the current" - f'default scope to "{scope}".', - logger="current", - level=logging.WARNING, - ) - # avoid name conflict - new_instance_name = f"{scope}-{datetime.datetime.now()}" - DefaultScope.get_instance(new_instance_name, scope_name=scope) diff --git a/libs/visengine/visengine/runner/__init__.py b/libs/visengine/visengine/runner/__init__.py deleted file mode 100644 index 9c12264..0000000 --- a/libs/visengine/visengine/runner/__init__.py +++ /dev/null @@ -1,52 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
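# ---------------------------------------------------------------------------
# A short sketch of the inspection helpers defined in registry/utils.py
# above, assuming the visengine subpackages they import are installed.
from visengine.registry.utils import count_registered_modules, init_default_scope

init_default_scope("visengine")  # set the default scope once, up front
stats = count_registered_modules(save_path=None, verbose=False)
print(stats["scan_date"], sorted(stats["registries"]))
# ---------------------------------------------------------------------------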
-from ._flexible_runner import FlexibleRunner -from .activation_checkpointing import turn_on_activation_checkpointing -from .amp import autocast -from .base_loop import BaseLoop -from .checkpoint import ( - CheckpointLoader, - find_latest_checkpoint, - get_deprecated_model_names, - get_external_models, - get_mmcls_models, - get_state_dict, - get_torchvision_models, - load_checkpoint, - load_state_dict, - save_checkpoint, - weights_to_cpu, -) -from .log_processor import LogProcessor -from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop -from .priority import Priority, get_priority -from .runner import Runner -from .utils import set_random_seed - -__all__ = [ - "BaseLoop", - "CheckpointLoader", - "EpochBasedTrainLoop", - "FlexibleRunner", - "IterBasedTrainLoop", - "LogProcessor", - "Priority", - "Runner", - "TestLoop", - "ValLoop", - "autocast", - "find_latest_checkpoint", - "get_deprecated_model_names", - "get_external_models", - "get_mmcls_models", - "get_priority", - "get_state_dict", - "get_torchvision_models", - "load_checkpoint", - "load_state_dict", - "save_checkpoint", - "set_random_seed", - "turn_on_activation_checkpointing", - "weights_to_cpu", -] diff --git a/libs/visengine/visengine/runner/_flexible_runner.py b/libs/visengine/visengine/runner/_flexible_runner.py deleted file mode 100644 index f8a971a..0000000 --- a/libs/visengine/visengine/runner/_flexible_runner.py +++ /dev/null @@ -1,1652 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -import os.path as osp -import pickle -import warnings -from collections.abc import Callable -from functools import partial -from typing import Union - -import torch.nn as nn -from torch.utils.data import DataLoader - -import visengine -from visengine._strategy import BaseStrategy -from visengine.config import Config, ConfigDict -from visengine.dataset import worker_init_fn as default_worker_init_fn -from visengine.dist import get_rank, infer_launcher, master_only -from visengine.evaluator import Evaluator -from visengine.fileio import FileClient, join_path -from visengine.hooks import Hook -from visengine.logging import MessageHub, print_log -from visengine.optim import OptimWrapper, OptimWrapperDict, _ParamScheduler -from visengine.registry import ( - DATA_SAMPLERS, - DATASETS, - EVALUATOR, - FUNCTIONS, - HOOKS, - LOG_PROCESSORS, - LOOPS, - RUNNERS, - STRATEGIES, - VISUALIZERS, - DefaultScope, -) -from visengine.utils import digit_version -from visengine.utils.dl_utils import TORCH_VERSION -from visengine.visualization import Visualizer - -from .base_loop import BaseLoop -from .checkpoint import find_latest_checkpoint -from .log_processor import LogProcessor -from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop -from .priority import Priority, get_priority -from .utils import _get_batch_size - -ConfigType = Union[dict, Config, ConfigDict] -ParamSchedulerType = Union[list[_ParamScheduler], dict[str, list[_ParamScheduler]]] -OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] - - -@RUNNERS.register_module(force=True) -class FlexibleRunner: - """A training helper for PyTorch. - - Runner object can be built from config by ``runner = Runner.from_cfg(cfg)`` - where the ``cfg`` usually contains training, validation, and test-related - configurations to build corresponding components. We usually use the - same config to launch training, testing, and validation tasks. 
However,
-    only some of these components are necessary at the same time, e.g.,
-    testing a model does not need training or validation-related components.
-
-    To avoid repeatedly modifying config, the construction of ``Runner`` adopts
-    lazy initialization to only initialize components when they are going to be
-    used. Therefore, the model is always initialized at the beginning, and
-    training-, validation-, and testing-related components are only initialized
-    when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
-    respectively.
-
-    Warning:
-        This is an experimental feature, and its interface is subject to
-        change.
-
-    Args:
-        model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
-            a dict used to build a model.
-
-    Kwargs:
-        work_dir (str, optional): The working directory to save checkpoints.
-            The logs will be saved in the subdirectory of `work_dir` named
-            :attr:`timestamp`. Defaults to 'work_dir'.
-        experiment_name (str, optional): Name of current experiment. If not
-            specified, timestamp will be used as ``experiment_name``.
-            Defaults to None.
-        train_dataloader (Dataloader or dict, optional): A dataloader object or
-            a dict to build a dataloader. If ``None`` is given, it means
-            skipping training steps. Defaults to None.
-            See :meth:`build_dataloader` for more details.
-        optim_wrapper (OptimWrapper or dict, optional):
-            Computing gradient of model parameters. If specified,
-            :attr:`train_dataloader` should also be specified. If automatic
-            mixed precision or gradient accumulation training is required,
-            the type of ``optim_wrapper`` should be AmpOptimizerWrapper.
-            See :meth:`build_optim_wrapper` for examples. Defaults to None.
-        param_scheduler (_ParamScheduler or dict or list, optional):
-            Parameter scheduler for updating optimizer parameters. If
-            specified, :attr:`optimizer` should also be specified.
-            Defaults to None.
-            See :meth:`build_param_scheduler` for examples.
-        train_cfg (dict, optional): A dict to build a training loop. If it does
-            not provide "type" key, it should contain "by_epoch" to decide
-            which type of training loop :class:`EpochBasedTrainLoop` or
-            :class:`IterBasedTrainLoop` should be used. If ``train_cfg`` is
-            specified, :attr:`train_dataloader` should also be specified.
-            Defaults to None. See :meth:`build_train_loop` for more details.
-        val_dataloader (Dataloader or dict, optional): A dataloader object or
-            a dict to build a dataloader. If ``None`` is given, it means
-            skipping validation steps. Defaults to None.
-            See :meth:`build_dataloader` for more details.
-        val_evaluator (Evaluator or dict or list, optional): An evaluator
-            object used for computing metrics for validation. It can be a dict
-            or a list of dicts to build an evaluator. If specified,
-            :attr:`val_dataloader` should also be specified. Defaults to None.
-        val_cfg (dict, optional): A dict to build a validation loop. If it does
-            not provide "type" key, :class:`ValLoop` will be used by default.
-            If ``val_cfg`` is specified, :attr:`val_dataloader` should also be
-            specified. If ``ValLoop`` is built with ``fp16=True``,
-            ``runner.val()`` will be performed under fp16 precision.
-            Defaults to None. See :meth:`build_val_loop` for more details.
-        test_dataloader (Dataloader or dict, optional): A dataloader object or
-            a dict to build a dataloader. If ``None`` is given, it means
-            skipping test steps. Defaults to None.
-            See :meth:`build_dataloader` for more details.
-        test_evaluator (Evaluator or dict or list, optional): An evaluator
-            object used for computing metrics for test steps. It can be a dict
-            or a list of dicts to build an evaluator. If specified,
-            :attr:`test_dataloader` should also be specified. Defaults to None.
-        test_cfg (dict, optional): A dict to build a test loop. If it does
-            not provide "type" key, :class:`TestLoop` will be used by default.
-            If ``test_cfg`` is specified, :attr:`test_dataloader` should also
-            be specified. If ``TestLoop`` is built with ``fp16=True``,
-            ``runner.test()`` will be performed under fp16 precision.
-            Defaults to None. See :meth:`build_test_loop` for more details.
-        strategy (BaseStrategy or dict, optional): A strategy object or a dict
-            to build a strategy. Defaults to None. If not specified, the
-            strategy will be inferred automatically.
-        auto_scale_lr (dict, optional): Config to scale the learning rate
-            automatically. It includes ``base_batch_size`` and ``enable``.
-            ``base_batch_size`` is the batch size that the optimizer lr is
-            based on. ``enable`` is the switch to turn on and off the feature.
-        default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks to
-            execute default actions like updating model parameters and saving
-            checkpoints. Default hooks are ``OptimizerHook``,
-            ``IterTimerHook``, ``LoggerHook``, ``ParamSchedulerHook`` and
-            ``CheckpointHook``. Defaults to None.
-            See :meth:`register_default_hooks` for more details.
-        custom_hooks (list[dict] or list[Hook], optional): Hooks to execute
-            custom actions like visualizing images processed by pipeline.
-            Defaults to None.
-        data_preprocessor (dict, optional): The pre-process config of
-            :class:`BaseDataPreprocessor`. If the ``model`` argument is a dict
-            and doesn't contain the key ``data_preprocessor``, set the argument
-            as the ``data_preprocessor`` of the ``model`` dict.
-            Defaults to None.
-        load_from (str, optional): The checkpoint file to load from.
-            Defaults to None.
-        resume (bool): Whether to resume training. Defaults to False. If
-            ``resume`` is True and ``load_from`` is None, the runner will
-            automatically find the latest checkpoint in ``work_dir``. If none
-            is found, resuming does nothing.
-        launcher (str, optional): Way to launch multiple processes. Supported
-            launchers are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is
-            provided, a non-distributed environment will be launched.
-            If launcher is None, the launcher will be inferred from
-            environment variables. Defaults to None.
-        env_cfg (dict): A dict used for setting environment. Defaults to
-            dict(dist_cfg=dict(backend='nccl')).
-        log_processor (dict, optional): A processor to format logs. Defaults to
-            None.
-        log_level (int or str): The log level of MMLogger handlers.
-            Defaults to 'INFO'.
-        visualizer (Visualizer or dict, optional): A Visualizer object or a
-            dict to build a Visualizer object. Defaults to None. If not
-            specified, default config will be used.
-        default_scope (str): Used to reset registries location.
-            Defaults to "mmengine".
-        randomness (dict): Some settings to make the experiment as reproducible
-            as possible like seed and deterministic.
-            Defaults to ``dict(seed=None)``. If seed is None, a random number
-            will be generated and it will be broadcast to all other processes
-            if in distributed environment. If ``cudnn_benchmark`` is
-            ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
-            ``randomness``, the value of ``torch.backends.cudnn.benchmark``
-            will be ``False`` finally.
-        compile (bool or dict, optional): Whether to enable ``torch.compile``.
- Defaults to False. - cfg (dict or Configdict or :obj:`Config`, optional): Full config. - Defaults to None. - - Note: - Since PyTorch 2.0.0, you can enable ``torch.compile`` by passing in - `compile = True`. If you want to control compile options, you - can pass a dict, e.g. ``cfg.compile = dict(backend='eager')``. - Refer to `PyTorch API Documentation `_ for more valid - options. - - Examples: - >>> from visengine.runner import Runner - >>> cfg = dict( - >>> model=dict(type='ToyModel'), - >>> work_dir='path/of/work_dir', - >>> train_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=True), - >>> batch_size=1, - >>> num_workers=0), - >>> val_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> test_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> auto_scale_lr=dict(base_batch_size=16, enable=False), - >>> optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict( - >>> type='SGD', lr=0.01)), - >>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - >>> val_evaluator=dict(type='ToyEvaluator'), - >>> test_evaluator=dict(type='ToyEvaluator'), - >>> train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1), - >>> val_cfg=dict(), - >>> test_cfg=dict(), - >>> custom_hooks=[], - >>> default_hooks=dict( - >>> timer=dict(type='IterTimerHook'), - >>> checkpoint=dict(type='CheckpointHook', interval=1), - >>> logger=dict(type='LoggerHook'), - >>> optimizer=dict(type='OptimizerHook', grad_clip=False), - >>> param_scheduler=dict(type='ParamSchedulerHook')), - >>> launcher='none', - >>> env_cfg=dict(dist_cfg=dict(backend='nccl')), - >>> log_processor=dict(window_size=20), - >>> visualizer=dict(type='Visualizer', - >>> vis_backends=[dict(type='LocalVisBackend', - >>> save_dir='temp_dir')]) - >>> ) - >>> runner = Runner.from_cfg(cfg) - >>> runner.train() - >>> runner.test() - """ - - cfg: Config - _train_loop: BaseLoop | dict | None - _val_loop: BaseLoop | dict | None - _test_loop: BaseLoop | dict | None - - def __init__( - self, - model: nn.Module | dict, - *, - work_dir: str = "work_dirs", - experiment_name: str | None = None, - train_dataloader: DataLoader | dict | None = None, - optim_wrapper: OptimWrapper | dict | None = None, - param_scheduler: _ParamScheduler | dict | list | None = None, - train_cfg: dict | None = None, - val_dataloader: DataLoader | dict | None = None, - val_evaluator: Evaluator | dict | list | None = None, - val_cfg: dict | None = None, - test_dataloader: DataLoader | dict | None = None, - test_evaluator: Evaluator | dict | list | None = None, - test_cfg: dict | None = None, - strategy: BaseStrategy | dict | None = None, - auto_scale_lr: dict | None = None, - default_hooks: dict[str, Hook | dict] | None = None, - custom_hooks: list[Hook | dict] | None = None, - data_preprocessor: nn.Module | dict | None = None, - load_from: str | None = None, - resume: str | bool = False, - launcher: str | None = None, - env_cfg: dict | None = None, - log_processor: dict | None = None, - log_level: str = "INFO", - visualizer: Visualizer | dict | None = None, - default_scope: str | None = "mmengine", - randomness: dict | None = None, - compile: bool | dict = False, - cfg: ConfigType | None = None, - ): - if randomness is None: - randomness = {"seed": None} - if env_cfg is None: - env_cfg = {"dist_cfg": 
{"backend": "nccl"}} - if isinstance(model, dict) and data_preprocessor is not None: - # Merge the data_preprocessor to model config. - model.setdefault("data_preprocessor", data_preprocessor) - self.model = model - - self._work_dir = osp.abspath(work_dir) - mmengine.mkdir_or_exist(self._work_dir) - - # recursively copy the `cfg` because `self.cfg` will be modified - # everywhere. - if cfg is not None: - if isinstance(cfg, Config): - self.cfg = copy.deepcopy(cfg) - elif isinstance(cfg, dict): - self.cfg = Config(cfg) - else: - self.cfg = Config({}) - - # lazy initialization - training_related = [train_dataloader, train_cfg, optim_wrapper] - if not (all(item is None for item in training_related) or all(item is not None for item in training_related)): - raise ValueError( - "train_dataloader, train_cfg, and optim_wrapper should be " - "either all None or not None, but got " - f"train_dataloader={train_dataloader}, " - f"train_cfg={train_cfg}, " - f"optim_wrapper={optim_wrapper}." - ) - self._train_dataloader = train_dataloader - self._train_loop = train_cfg - - self.optim_wrapper: OptimWrapper | dict | None - self.optim_wrapper = optim_wrapper - - self._auto_scale_lr = auto_scale_lr - - # If there is no need to adjust learning rate, momentum or other - # parameters of optimizer, param_scheduler can be None - if param_scheduler is not None and self.optim_wrapper is None: - raise ValueError(f"param_scheduler should be None when optim_wrapper is None, but got {param_scheduler}") - - self.param_schedulers = param_scheduler - - val_related = [val_dataloader, val_cfg, val_evaluator] - if not (all(item is None for item in val_related) or all(item is not None for item in val_related)): - raise ValueError( - "val_dataloader, val_cfg, and val_evaluator should be either " - "all None or not None, but got " - f"val_dataloader={val_dataloader}, val_cfg={val_cfg}, " - f"val_evaluator={val_evaluator}" - ) - self._val_dataloader = val_dataloader - self._val_loop = val_cfg - self._val_evaluator = val_evaluator - - test_related = [test_dataloader, test_cfg, test_evaluator] - if not (all(item is None for item in test_related) or all(item is not None for item in test_related)): - raise ValueError( - "test_dataloader, test_cfg, and test_evaluator should be " - "either all None or not None, but got " - f"test_dataloader={test_dataloader}, test_cfg={test_cfg}, " - f"test_evaluator={test_evaluator}" - ) - self._test_dataloader = test_dataloader - self._test_loop = test_cfg - self._test_evaluator = test_evaluator - - if not isinstance(compile, bool) and not isinstance(compile, dict): - raise TypeError(f"compile should be a bool or dict, but got {type(compile)}") - self._compile = compile - - if isinstance(resume, str) and load_from is not None: - raise ValueError("If resume is a str, load_from should be None.") - self._load_from = load_from - self._resume = resume - # flag to mark whether checkpoint has been loaded or resumed - self._has_loaded = False - - if launcher is None: - launcher = infer_launcher() - - if experiment_name is None and self.cfg.filename is not None: - experiment_name = osp.splitext(osp.basename(self.cfg.filename))[0] - - self._randomness_cfg = randomness - self.strategy = self.build_strategy( - strategy, - launcher=launcher, - randomness=randomness, - env_cfg=env_cfg, - experiment_name=experiment_name, - log_level=log_level, - ) - - # Used to reset registries location. See :meth:`Registry.build` for - # more details. 
- if default_scope is not None: - default_scope = DefaultScope.get_instance( # type: ignore - self.experiment_name, scope_name=default_scope - ) - self.default_scope = default_scope - # Build log processor to format message. - log_processor = {} if log_processor is None else log_processor - self.log_processor = self.build_log_processor(log_processor) - - # Collect and log environment information. - self._log_env() - - # Build `message_hub` for communication among components. - # `message_hub` can store log scalars (loss, learning rate) and - # runtime information (iter and epoch). Those components that do not - # have access to the runner can get iteration or epoch information - # from `message_hub`. For example, models can get the latest created - # `message_hub` by - # `self.message_hub=MessageHub.get_current_instance()` and then get - # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`. - # See `MessageHub` and `ManagerMixin` for more details. - self.message_hub = self.build_message_hub() - # visualizer used for writing log or visualizing all kinds of data - self.visualizer = self.build_visualizer(visualizer) - if self.cfg: - self.visualizer.add_config(self.cfg) - - self._hooks: list[Hook] = [] - # register hooks to `self._hooks` - self.register_hooks(default_hooks, custom_hooks) - # log hooks information - self.logger.info(f"Hooks will be executed in the following order:\n{self.get_hooks_info()}") - - # dump `cfg` to `work_dir` - self.dump_config() - - @classmethod - def from_cfg(cls, cfg: ConfigType) -> "FlexibleRunner": - """Build a runner from config. - - Args: - cfg (ConfigType): A config used for building runner. Keys of - ``cfg`` can see :meth:`__init__`. - - Returns: - Runner: A runner build from ``cfg``. - """ - cfg = copy.deepcopy(cfg) - runner = cls( - model=cfg["model"], - work_dir=cfg.get("work_dir", "work_dirs"), - experiment_name=cfg.get("experiment_name"), - train_dataloader=cfg.get("train_dataloader"), - optim_wrapper=cfg.get("optim_wrapper"), - param_scheduler=cfg.get("param_scheduler"), - train_cfg=cfg.get("train_cfg"), - val_dataloader=cfg.get("val_dataloader"), - val_evaluator=cfg.get("val_evaluator"), - val_cfg=cfg.get("val_cfg"), - test_dataloader=cfg.get("test_dataloader"), - test_evaluator=cfg.get("test_evaluator"), - test_cfg=cfg.get("test_cfg"), - strategy=cfg.get("strategy"), - auto_scale_lr=cfg.get("auto_scale_lr"), - default_hooks=cfg.get("default_hooks"), - custom_hooks=cfg.get("custom_hooks"), - data_preprocessor=cfg.get("data_preprocessor"), - load_from=cfg.get("load_from"), - resume=cfg.get("resume", False), - launcher=cfg.get("launcher"), - env_cfg=cfg.get("env_cfg"), # type: ignore - log_processor=cfg.get("log_processor"), - log_level=cfg.get("log_level", "INFO"), - visualizer=cfg.get("visualizer"), - default_scope=cfg.get("default_scope", "mmengine"), - randomness=cfg.get("randomness", {"seed": None}), - cfg=cfg, - ) - - return runner - - @property - def experiment_name(self): - """str: Name of experiment.""" - return self.strategy.experiment_name - - @property - def model_name(self): - """str: Name of the model, usually the module class name.""" - return self._model_name - - @property - def work_dir(self): - """str: The working directory to save checkpoints and logs.""" - return self._work_dir - - @property - def log_dir(self): - return self.strategy.log_dir - - @property - def logger(self): - return self.strategy.logger - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - if isinstance(self.train_loop, 
BaseLoop): - return self.train_loop.max_epochs - else: - return 0 - - @property - def max_iters(self): - """int: Total iterations to train model.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.max_iters - else: - return 0 - - @property - def epoch(self): - """int: Current epoch.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.epoch - else: - return 0 - - @property - def iter(self): - """int: Current iteration.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.iter - else: - return 0 - - @property - def distributed(self): - """bool: Whether current environment is distributed.""" - return self.strategy.distributed - - @property - def rank(self): - """int: Rank of current process.""" - return self.strategy.rank - - @property - def world_size(self): - """int: Number of processes participating in the job.""" - return self.strategy.world_size - - @property - def deterministic(self): - """int: Whether cudnn to select deterministic algorithms.""" - return self._deterministic - - @property - def seed(self): - """int: A number to set random modules.""" - return self.strategy.seed - - @property - def timestamp(self): - """str: Timestamp when creating experiment.""" - return self.strategy.timestamp - - @property - def hooks(self): - """List[:obj:`Hook`]: A list of registered hooks.""" - return self._hooks - - @property - def train_loop(self): - """:obj:`BaseLoop`: A loop to run training.""" - if isinstance(self._train_loop, BaseLoop) or self._train_loop is None: - return self._train_loop - else: - self._train_loop = self.build_train_loop(self._train_loop) - return self._train_loop - - @property - def val_loop(self): - """:obj:`BaseLoop`: A loop to run validation.""" - if isinstance(self._val_loop, BaseLoop) or self._val_loop is None: - return self._val_loop - else: - self._val_loop = self.build_val_loop(self._val_loop) - return self._val_loop - - @property - def test_loop(self): - """:obj:`BaseLoop`: A loop to run testing.""" - if isinstance(self._test_loop, BaseLoop) or self._test_loop is None: - return self._test_loop - else: - self._test_loop = self.build_test_loop(self._test_loop) - return self._test_loop - - @property - def train_dataloader(self): - """The data loader for training.""" - return self.train_loop.dataloader - - @property - def val_dataloader(self): - """The data loader for validation.""" - return self.val_loop.dataloader - - @property - def test_dataloader(self): - """The data loader for testing.""" - return self.test_loop.dataloader - - @property - def val_evaluator(self): - """:obj:`Evaluator`: An evaluator for validation.""" - return self.val_loop.evaluator - - @property - def test_evaluator(self): - """:obj:`Evaluator`: An evaluator for testing.""" - return self.test_loop.evaluator - - @property - def val_interval(self): - """int: Interval to run validation during training.""" - return self.train_loop.val_interval - - @property - def val_begin(self): - """int: The epoch/iteration to start running validation during - training.""" - return self.train_loop.val_begin - - def build_strategy( - self, - strategy: BaseStrategy | dict | None = None, - launcher: str = "none", - randomness: dict | None = None, - env_cfg: dict | None = None, - experiment_name: str | None = None, - log_level: str | None = None, - ) -> BaseStrategy: - """Build a strategy. - - Args: - strategy (BaseStrategy, optional): A strategy object or dict to - build the strategy. Defaults to None. - - Returns: - BaseStrategy: A strategy object. 
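# ---------------------------------------------------------------------------
# A hedged sketch of how `build_strategy` resolves its input; `runner` is
# assumed to be an existing FlexibleRunner instance. With launcher="none" and
# strategy=None a single-device strategy is inferred, while an explicit dict
# always wins.
strategy = runner.build_strategy(
    {"type": "DDPStrategy"},  # explicit choice overrides inference
    launcher="pytorch",
    randomness={"seed": 42},
    env_cfg={"dist_cfg": {"backend": "nccl"}},
    experiment_name="toy_exp",
    log_level="INFO",
)
# ---------------------------------------------------------------------------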
- """ - if env_cfg is None: - env_cfg = {"dist_cfg": {"backend": "nccl"}} - if isinstance(strategy, BaseStrategy): - strategy_obj = strategy - else: - if launcher == "none": - if strategy is None: - strategy = {"type": "SingleDeviceStrategy"} - else: - if strategy is None: - strategy = {"type": "DDPStrategy"} - - assert isinstance(strategy, dict) - - # train_micro_batch_size_per_gpu is required by DeepSpeed - if isinstance(strategy["type"], str): - strategy_name = strategy["type"] - else: - strategy_name = strategy["type"].__name__ - if strategy_name == "DeepSpeedStrategy": - if self._train_dataloader is None: - strategy["train_micro_batch_size_per_gpu"] = 1 - else: - strategy["train_micro_batch_size_per_gpu"] = _get_batch_size(self._train_dataloader) - - strategy.setdefault("work_dir", self._work_dir) - strategy.setdefault("experiment_name", experiment_name) - strategy.setdefault("auto_scale_lr", self._auto_scale_lr) - - env_kwargs = dict( - launcher=launcher, - randomness=randomness, - **env_cfg, - ) - strategy.setdefault("env_kwargs", env_kwargs) - - log_kwargs = {"log_level": log_level} - strategy.setdefault("log_kwargs", log_kwargs) - - strategy_obj = STRATEGIES.build(strategy) - - return strategy_obj - - def build_message_hub( - self, - message_hub: dict | None = None, - ) -> MessageHub: - """Build a global asscessable MessageHub. - - Args: - message_hub (dict, optional): A dict to build MessageHub object. - If not specified, default config will be used to build - MessageHub object. Defaults to None. - - Returns: - MessageHub: A MessageHub object build from ``message_hub``. - """ - if message_hub is None: - message_hub = {"name": self.experiment_name} - elif isinstance(message_hub, dict): - # ensure message_hub containing name key - message_hub.setdefault("name", self.experiment_name) - else: - raise TypeError(f"message_hub should be dict or None, but got {message_hub}") - - return MessageHub.get_instance(**message_hub) - - def build_visualizer( - self, - visualizer: Visualizer | dict | None = None, - ) -> Visualizer: - """Build a global asscessable Visualizer. - - Args: - visualizer (Visualizer or dict, optional): A Visualizer object - or a dict to build Visualizer object. If ``visualizer`` is a - Visualizer object, just returns itself. If not specified, - default config will be used to build Visualizer object. - Defaults to None. - - Returns: - Visualizer: A Visualizer object build from ``visualizer``. - """ - if visualizer is None: - visualizer = { - "name": self.experiment_name, - "vis_backends": [{"type": "LocalVisBackend"}], - "save_dir": self.log_dir, - } - return Visualizer.get_instance(**visualizer) - - if isinstance(visualizer, Visualizer): - return visualizer - - if isinstance(visualizer, dict): - # ensure visualizer containing name key - visualizer.setdefault("name", self.experiment_name) - visualizer.setdefault("save_dir", self.log_dir) - return VISUALIZERS.build(visualizer) - else: - raise TypeError(f"visualizer should be Visualizer object, a dict or None, but got {visualizer}") - - def build_evaluator( - self, - evaluator: dict | list | Evaluator, - ) -> Evaluator: - """Build evaluator. 
- - Examples of ``evaluator``:: - - # evaluator could be a built Evaluator instance - evaluator = Evaluator(metrics=[ToyMetric()]) - - # evaluator can also be a list of dict - evaluator = [ - dict(type='ToyMetric1'), - dict(type='ToyEvaluator2') - ] - - # evaluator can also be a list of built metric - evaluator = [ToyMetric1(), ToyMetric2()] - - # evaluator can also be a dict with key metrics - evaluator = dict(metrics=ToyMetric()) - # metric is a list - evaluator = dict(metrics=[ToyMetric()]) - - Args: - evaluator (Evaluator or dict or list): An Evaluator object or a - config dict or list of config dict used to build an Evaluator. - - Returns: - Evaluator: Evaluator build from ``evaluator``. - """ - if isinstance(evaluator, Evaluator): - return evaluator - elif isinstance(evaluator, dict): - # if `metrics` in dict keys, it means to build customized evalutor - if "metrics" in evaluator: - evaluator.setdefault("type", "Evaluator") - return EVALUATOR.build(evaluator) - # otherwise, default evalutor will be built - else: - return Evaluator(evaluator) # type: ignore - elif isinstance(evaluator, list): - # use the default `Evaluator` - return Evaluator(evaluator) # type: ignore - else: - raise TypeError(f"evaluator should be one of dict, list of dict, and Evaluator, but got {evaluator}") - - @staticmethod - def build_dataloader( - dataloader: DataLoader | dict, - seed: int | None = None, - diff_rank_seed: bool = False, - ) -> DataLoader: - """Build dataloader. - - The method builds three components: - - - Dataset - - Sampler - - Dataloader - - An example of ``dataloader``:: - - dataloader = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=1, - num_workers=9 - ) - - Args: - dataloader (DataLoader or dict): A Dataloader object or a dict to - build Dataloader object. If ``dataloader`` is a Dataloader - object, just returns itself. - seed (int, optional): Random seed. Defaults to None. - diff_rank_seed (bool): Whether or not set different seeds to - different ranks. If True, the seed passed to sampler is set - to None, in order to synchronize the seeds used in samplers - across different ranks. Defaults to False. - - Returns: - Dataloader: DataLoader build from ``dataloader_cfg``. 
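# ---------------------------------------------------------------------------
# Because `build_dataloader` is a staticmethod, it can be used on its own; a
# hedged sketch, where `ToyDataset` is a hypothetical dataset assumed to be
# registered in DATASETS.
loader = FlexibleRunner.build_dataloader(
    dict(
        dataset=dict(type="ToyDataset"),
        sampler=dict(type="DefaultSampler", shuffle=False),
        batch_size=2,
        num_workers=0,
    ),
    seed=2024,             # forwarded to the sampler and worker_init_fn
    diff_rank_seed=False,  # keep the same sampler seed on every rank
)
# ---------------------------------------------------------------------------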
- """ - if isinstance(dataloader, DataLoader): - return dataloader - - dataloader_cfg = copy.deepcopy(dataloader) - - # build dataset - dataset_cfg = dataloader_cfg.pop("dataset") - if isinstance(dataset_cfg, dict): - dataset = DATASETS.build(dataset_cfg) - if hasattr(dataset, "full_init"): - dataset.full_init() - else: - # fallback to raise error in dataloader - # if `dataset_cfg` is not a valid type - dataset = dataset_cfg - - # build sampler - sampler_cfg = dataloader_cfg.pop("sampler") - if isinstance(sampler_cfg, dict): - sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build(sampler_cfg, default_args={"dataset": dataset, "seed": sampler_seed}) - else: - # fallback to raise error in dataloader - # if `sampler_cfg` is not a valid type - sampler = sampler_cfg - - # build batch sampler - batch_sampler_cfg = dataloader_cfg.pop("batch_sampler", None) - if batch_sampler_cfg is None: - batch_sampler = None - elif isinstance(batch_sampler_cfg, dict): - batch_sampler = DATA_SAMPLERS.build( - batch_sampler_cfg, - default_args={ - "sampler": sampler, - "batch_size": dataloader_cfg.pop("batch_size"), - }, - ) - else: - # fallback to raise error in dataloader - # if `batch_sampler_cfg` is not a valid type - batch_sampler = batch_sampler_cfg - - # build dataloader - init_fn: partial | None - if "worker_init_fn" in dataloader_cfg: - worker_init_fn_cfg = dataloader_cfg.pop("worker_init_fn") - worker_init_fn_type = worker_init_fn_cfg.pop("type") - worker_init_fn = FUNCTIONS.get(worker_init_fn_type) - assert callable(worker_init_fn) - init_fn = partial(worker_init_fn, **worker_init_fn_cfg) # type: ignore - else: - if seed is not None: - disable_subprocess_warning = dataloader_cfg.pop("disable_subprocess_warning", False) - assert isinstance(disable_subprocess_warning, bool), ( - f"disable_subprocess_warning should be a bool, but got {type(disable_subprocess_warning)}" - ) - init_fn = partial( - default_worker_init_fn, - num_workers=dataloader_cfg.get("num_workers"), - rank=get_rank(), - seed=seed, - disable_subprocess_warning=disable_subprocess_warning, - ) - else: - init_fn = None - - # `persistent_workers` requires pytorch version >= 1.7 - if "persistent_workers" in dataloader_cfg and digit_version(TORCH_VERSION) < digit_version("1.7.0"): - print_log( - "`persistent_workers` is only available when pytorch version >= 1.7", - logger="current", - level=logging.WARNING, - ) - dataloader_cfg.pop("persistent_workers") - - # The default behavior of `collat_fn` in dataloader is to - # merge a list of samples to form a mini-batch of Tensor(s). - # However, in mmengine, if `collate_fn` is not defined in - # dataloader_cfg, `pseudo_collate` will only convert the list of - # samples into a dict without stacking the batch tensor. 
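# A hedged illustration of that default, assuming `pseudo_collate` behaves as
# its mmengine namesake:
#
#   batch = pseudo_collate([dict(inputs=1), dict(inputs=2)])
#   # -> {'inputs': [1, 2]}  (a dict of lists; no tensor stacking)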
- collate_fn_cfg = dataloader_cfg.pop("collate_fn", {"type": "pseudo_collate"}) - if isinstance(collate_fn_cfg, dict): - collate_fn_type = collate_fn_cfg.pop("type") - if isinstance(collate_fn_type, str): - collate_fn = FUNCTIONS.get(collate_fn_type) - else: - collate_fn = collate_fn_type - collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore - elif callable(collate_fn_cfg): - collate_fn = collate_fn_cfg - else: - raise TypeError(f"collate_fn should be a dict or callable object, but got {collate_fn_cfg}") - data_loader = DataLoader( - dataset=dataset, - sampler=sampler if batch_sampler is None else None, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - worker_init_fn=init_fn, - **dataloader_cfg, - ) - return data_loader - - def build_train_loop(self, loop: BaseLoop | dict) -> BaseLoop: - """Build training loop. - - Examples of ``loop``:: - - # `EpochBasedTrainLoop` will be used - loop = dict(by_epoch=True, max_epochs=3) - - # `IterBasedTrainLoop` will be used - loop = dict(by_epoch=False, max_epochs=3) - - # custom training loop - loop = dict(type='CustomTrainLoop', max_epochs=3) - - Args: - loop (BaseLoop or dict): A training loop or a dict to build - training loop. If ``loop`` is a training loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Training loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError(f"loop should be a Loop object or dict, but got {loop}") - - loop_cfg = copy.deepcopy(loop) - - if "type" in loop_cfg and "by_epoch" in loop_cfg: - raise RuntimeError("Only one of `type` or `by_epoch` can exist in `loop_cfg`.") - - if "type" in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args={"runner": self, "dataloader": self._train_dataloader}, - ) - else: - by_epoch = loop_cfg.pop("by_epoch") - if by_epoch: - loop = EpochBasedTrainLoop(**loop_cfg, runner=self, dataloader=self._train_dataloader) - else: - loop = IterBasedTrainLoop(**loop_cfg, runner=self, dataloader=self._train_dataloader) - return loop # type: ignore - - def build_val_loop(self, loop: BaseLoop | dict) -> BaseLoop: - """Build validation loop. - - Examples of ``loop``: - - # `ValLoop` will be used - loop = dict() - - # custom validation loop - loop = dict(type='CustomValLoop') - - Args: - loop (BaseLoop or dict): A validation loop or a dict to build - validation loop. If ``loop`` is a validation loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Validation loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError(f"train_loop should be a Loop object or dict, but got {loop}") - - loop_cfg = copy.deepcopy(loop) - - if "type" in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args={ - "runner": self, - "dataloader": self._val_dataloader, - "evaluator": self._val_evaluator, - }, - ) - else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator, - ) # type: ignore - - return loop # type: ignore - - def build_test_loop(self, loop: BaseLoop | dict) -> BaseLoop: - """Build test loop. - - Examples of ``loop``:: - - # `TestLoop` will be used - loop = dict() - - # custom test loop - loop = dict(type='CustomTestLoop') - - Args: - loop (BaseLoop or dict): A test loop or a dict to build test loop. - If ``loop`` is a test loop object, just returns itself. - - Returns: - :obj:`BaseLoop`: Test loop object build from ``loop_cfg``. 
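# ---------------------------------------------------------------------------
# A hedged sketch of the `type='CustomTestLoop'` form above: a custom loop
# registered in LOOPS must accept the default_args this method injects
# (runner, dataloader, evaluator). `CustomTestLoop` is hypothetical.
from visengine.registry import LOOPS
from visengine.runner import BaseLoop

@LOOPS.register_module(force=True)
class CustomTestLoop(BaseLoop):
    def __init__(self, runner, dataloader, evaluator):
        super().__init__(runner, dataloader)
        self.evaluator = evaluator

    def run(self):
        for data_batch in self.dataloader:
            ...  # inference and metric computation would go here
# ---------------------------------------------------------------------------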
- """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError(f"train_loop should be a Loop object or dict, but got {loop}") - - loop_cfg = copy.deepcopy(loop) # type: ignore - - if "type" in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args={ - "runner": self, - "dataloader": self._test_dataloader, - "evaluator": self._test_evaluator, - }, - ) - else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator, - ) # type: ignore - - return loop # type: ignore - - def build_log_processor( - self, - log_processor: LogProcessor | dict, - ) -> LogProcessor: - """Build test log_processor. - - Examples of ``log_processor``: - - # `LogProcessor` will be used - log_processor = dict() - - # custom log_processor - log_processor = dict(type='CustomLogProcessor') - - Args: - log_processor (LogProcessor or dict): A log processor or a dict - to build log processor. If ``log_processor`` is a log processor - object, just returns itself. - - Returns: - :obj:`LogProcessor`: Log processor object build from - ``log_processor_cfg``. - """ - if isinstance(log_processor, LogProcessor): - return log_processor - elif not isinstance(log_processor, dict): - raise TypeError(f"log processor should be a LogProcessor object or dict, butgot {log_processor}") - - log_processor_cfg = copy.deepcopy(log_processor) # type: ignore - - if "type" in log_processor_cfg: - log_processor = LOG_PROCESSORS.build(log_processor_cfg) - else: - log_processor = LogProcessor(**log_processor_cfg) # type: ignore - - return log_processor # type: ignore - - def get_hooks_info(self) -> str: - # Get hooks info in each stage - stage_hook_map: dict[str, list] = {stage: [] for stage in Hook.stages} - for hook in self.hooks: - try: - priority = Priority(hook.priority).name # type: ignore - except ValueError: - priority = hook.priority # type: ignore - classname = hook.__class__.__name__ - hook_info = f"({priority:<12}) {classname:<35}" - for trigger_stage in hook.get_triggered_stages(): - stage_hook_map[trigger_stage].append(hook_info) - - stage_hook_infos = [] - for stage in Hook.stages: - hook_infos = stage_hook_map[stage] - if len(hook_infos) > 0: - info = f"{stage}:\n" - info += "\n".join(hook_infos) - info += "\n -------------------- " - stage_hook_infos.append(info) - return "\n".join(stage_hook_infos) - - def load_or_resume(self): - """Load or resume checkpoint.""" - if self._has_loaded: - return None - - if not self._resume and self._load_from is None: - return None - - # decide to load from checkpoint or resume from checkpoint - resume_from = None - if isinstance(self._resume, str): - resume_from = self._resume - elif self._resume and self._load_from is None: - # auto resume from the latest checkpoint - resume_from = find_latest_checkpoint(self.work_dir) - self.logger.info(f"Auto resumed from the latest checkpoint {resume_from}.") - elif self._resume and self._load_from is not None: - # resume from the specified checkpoint - resume_from = self._load_from - - if resume_from is not None: - self.resume(resume_from) - self._has_loaded = True - elif self._load_from is not None: - self.load_checkpoint(self._load_from) - self._has_loaded = True - - def train(self) -> nn.Module: - """Launch training. - - Returns: - nn.Module: The model after training. - """ - if self._train_loop is None: - raise RuntimeError( - "`self._train_loop` should not be None when calling train " - "method. 
Please provide `train_dataloader`, `train_cfg`, " - "`optimizer` and `param_scheduler` arguments when " - "initializing runner." - ) - - self._train_loop = self.build_train_loop(self._train_loop) # type: ignore - - if self._val_loop is not None: - self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - - compile: dict | bool = False - if isinstance(self._compile, bool): - if self._compile: - compile = {"target": "train_step"} - else: - compile = copy.copy(self._compile) - compile.setdefault("target", "train_step") - - dispatch_kwargs = { - "epoch_length": len(self.train_dataloader), - "max_epochs": self.max_epochs, - "max_iters": self.max_iters, - "train_micro_batch_size_per_gpu": _get_batch_size(self.train_dataloader), - } # type: ignore - - self.strategy.prepare( - self.model, - optim_wrapper=self.optim_wrapper, - param_scheduler=self.param_schedulers, - compile=compile, - dispatch_kwargs=dispatch_kwargs, - ) - - self.model = self.strategy.model - self.optim_wrapper = self.strategy.optim_wrapper # type: ignore - if self.param_schedulers is not None: - self.param_schedulers = self.strategy.param_schedulers - - self.load_or_resume() - - # TODO: add a contextmanager to avoid calling `before_run` many times - self.call_hook("before_run") - - model = self.train_loop.run() # type: ignore - self.call_hook("after_run") - return model - - def val(self) -> dict: - """Launch validation. - - Returns: - dict: A dict of metrics on validation set. - """ - if self._val_loop is None: - raise RuntimeError( - "`self._val_loop` should not be None when calling val method." - "Please provide `val_dataloader`, `val_cfg` and " - "`val_evaluator` arguments when initializing runner." - ) - - self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - - dispatch_kwargs = {"init_weights_for_test_or_val": self.cfg.get("init_weights_for_test_or_val", True)} - self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) - self.model = self.strategy.model - - self.load_or_resume() - - self.call_hook("before_run") - metrics = self.val_loop.run() # type: ignore - self.call_hook("after_run") - - return metrics - - def test(self) -> dict: - """Launch test. - - Returns: - dict: A dict of metrics on testing set. - """ - if self._test_loop is None: - raise RuntimeError( - "`self._test_loop` should not be None when calling test " - "method. Please provide `test_dataloader`, `test_cfg` and " - "`test_evaluator` arguments when initializing runner." - ) - - self._test_loop = self.build_test_loop(self._test_loop) # type: ignore - dispatch_kwargs = {"init_weights_for_test_or_val": self.cfg.get("init_weights_for_test_or_val", True)} - self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) - self.model = self.strategy.model - - self.load_or_resume() - - self.call_hook("before_run") - metrics = self.test_loop.run() # type: ignore - self.call_hook("after_run") - - return metrics - - def call_hook(self, fn_name: str, **kwargs) -> None: - """Call all hooks. - - Args: - fn_name (str): The function name in each hook to be called, such as - "before_train_epoch". - **kwargs: Keyword arguments passed to hook. - """ - for hook in self._hooks: - # support adding additional custom hook methods - if hasattr(hook, fn_name): - try: - getattr(hook, fn_name)(self, **kwargs) - except TypeError as e: - raise TypeError(f"{e} in {hook}") from e - - def register_hook( - self, - hook: Hook | dict, - priority: str | int | Priority | None = None, - ) -> None: - """Register a hook into the hook list. 
- - The hook will be inserted into a priority queue, with the specified - priority (See :class:`Priority` for details of priorities). - For hooks with the same priority, they will be triggered in the same - order as they are registered. - - Priority of hook will be decided with the following priority: - - - ``priority`` argument. If ``priority`` is given, it will be priority - of hook. - - If ``hook`` argument is a dict and ``priority`` in it, the priority - will be the value of ``hook['priority']``. - - If ``hook`` argument is a dict but ``priority`` not in it or ``hook`` - is an instance of ``hook``, the priority will be ``hook.priority``. - - Args: - hook (:obj:`Hook` or dict): The hook to be registered. - priority (int or str or :obj:`Priority`, optional): Hook priority. - Lower value means higher priority. - """ - if not isinstance(hook, Hook | dict): - raise TypeError(f"hook should be an instance of Hook or dict, but got {hook}") - - _priority = None - if isinstance(hook, dict): - if "priority" in hook: - _priority = hook.pop("priority") - - hook_obj = HOOKS.build(hook) - else: - hook_obj = hook - - if priority is not None: - hook_obj.priority = priority - elif _priority is not None: - hook_obj.priority = _priority - - inserted = False - for i in range(len(self._hooks) - 1, -1, -1): - if get_priority(hook_obj.priority) >= get_priority(self._hooks[i].priority): - self._hooks.insert(i + 1, hook_obj) - inserted = True - break - if not inserted: - self._hooks.insert(0, hook_obj) - - def register_default_hooks( - self, - hooks: dict[str, Hook | dict] | None = None, - ) -> None: - """Register default hooks into hook list. - - ``hooks`` will be registered into runner to execute some default - actions like updating model parameters or saving checkpoints. - - Default hooks and their priorities: - - +----------------------+-------------------------+ - | Hooks | Priority | - +======================+=========================+ - | RuntimeInfoHook | VERY_HIGH (10) | - +----------------------+-------------------------+ - | IterTimerHook | NORMAL (50) | - +----------------------+-------------------------+ - | DistSamplerSeedHook | NORMAL (50) | - +----------------------+-------------------------+ - | LoggerHook | BELOW_NORMAL (60) | - +----------------------+-------------------------+ - | ParamSchedulerHook | LOW (70) | - +----------------------+-------------------------+ - | CheckpointHook | VERY_LOW (90) | - +----------------------+-------------------------+ - - If ``hooks`` is None, above hooks will be registered by - default:: - - default_hooks = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - sampler_seed=dict(type='DistSamplerSeedHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1), - ) - - If not None, ``hooks`` will be merged into ``default_hooks``. - If there are None value in default_hooks, the corresponding item will - be popped from ``default_hooks``:: - - hooks = dict(timer=None) - - The final registered default hooks will be :obj:`RuntimeInfoHook`, - :obj:`DistSamplerSeedHook`, :obj:`LoggerHook`, - :obj:`ParamSchedulerHook` and :obj:`CheckpointHook`. - - Args: - hooks (dict[str, Hook or dict], optional): Default hooks or configs - to be registered. 
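# ---------------------------------------------------------------------------
# A hedged sketch of the merging rules described above: mapping a default
# hook to None removes it, and a custom hook may carry its own priority.
# `runner` is an existing FlexibleRunner; `EMAHook` is assumed to be
# registered under HOOKS in this environment.
runner.register_hooks(
    default_hooks=dict(timer=None),  # drop IterTimerHook
    custom_hooks=[dict(type="EMAHook", priority="NORMAL")],
)
# ---------------------------------------------------------------------------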
- """ - default_hooks: dict = { - "runtime_info": {"type": "RuntimeInfoHook"}, - "timer": {"type": "IterTimerHook"}, - "sampler_seed": {"type": "DistSamplerSeedHook"}, - "logger": {"type": "LoggerHook"}, - "param_scheduler": {"type": "ParamSchedulerHook"}, - "checkpoint": {"type": "CheckpointHook", "interval": 1}, - } - if hooks is not None: - for name, hook in hooks.items(): - if name in default_hooks and hook is None: - # remove hook from _default_hooks - default_hooks.pop(name) - else: - assert hook is not None - default_hooks[name] = hook - - for hook in default_hooks.values(): - self.register_hook(hook) - - def register_custom_hooks(self, hooks: list[Hook | dict]) -> None: - """Register custom hooks into hook list. - - Args: - hooks (list[Hook | dict]): List of hooks or configs to be - registered. - """ - for hook in hooks: - self.register_hook(hook) - - def register_hooks( - self, - default_hooks: dict[str, Hook | dict] | None = None, - custom_hooks: list[Hook | dict] | None = None, - ) -> None: - """Register default hooks and custom hooks into hook list. - - Args: - default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks - to execute default actions like updating model parameters and - saving checkpoints. Defaults to None. - custom_hooks (list[dict] or list[Hook], optional): Hooks to execute - custom actions like visualizing images processed by pipeline. - Defaults to None. - """ - self.register_default_hooks(default_hooks) - - if custom_hooks is not None: - self.register_custom_hooks(custom_hooks) - - def resume( - self, - filename: str, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: str | Callable = "default", - ) -> None: - """Resume model from checkpoint. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. - resume_param_scheduler (bool): Whether to resume param scheduler - state. Defaults to True. - map_location (str or callable):A string or a callable function to - specifying how to remap storage locations. - Defaults to 'default'. - """ - - def callback(checkpoint): - self.call_hook("after_load_checkpoint", checkpoint=checkpoint) - - checkpoint = self.strategy.resume( - filename, - resume_optimizer=resume_optimizer, - resume_param_scheduler=resume_param_scheduler, - map_location=map_location, - callback=callback, - ) - - self.train_loop._epoch = checkpoint["meta"]["epoch"] - self.train_loop._iter = checkpoint["meta"]["iter"] - - # check whether the number of GPU used for current experiment - # is consistent with resuming from checkpoint - if "config" in checkpoint["meta"]: - config = mmengine.Config.fromstring(checkpoint["meta"]["config"], file_format=".py") - previous_gpu_ids = config.get("gpu_ids", None) - if previous_gpu_ids is not None and len(previous_gpu_ids) > 0 and len(previous_gpu_ids) != self.world_size: - # TODO, should we modify the iteration? - self.logger.info( - "Number of GPU used for current experiment is not consistent with resuming from checkpoint" - ) - if self._auto_scale_lr is None or not self._auto_scale_lr.get("enable", False): - raise RuntimeError( - "Cannot automatically rescale lr in resuming. Please " - "make sure the number of GPU is consistent with the " - "previous training state resuming from the checkpoint " - "or set `enable` in `auto_scale_lr to False." 
- ) - - resumed_dataset_meta = checkpoint["meta"].get("dataset_meta", None) - dataset_meta = getattr(self.train_dataloader.dataset, "metainfo", None) - - # `resumed_dataset_meta` and `dataset_meta` could be object like - # np.ndarray, which cannot be directly judged as equal or not, - # therefore we just compared their dumped results. - if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta): - self.logger.warning( - "The dataset metainfo from the resumed checkpoint is " - "different from the current training dataset, please " - "check the correctness of the checkpoint or the training " - "dataset." - ) - - self.message_hub.load_state_dict(checkpoint["message_hub"]) - - self.logger.info(f"resumed epoch: {self.epoch}, iter: {self.iter}") - - def load_checkpoint( - self, - filename: str, - map_location: str | Callable = "cpu", - strict: bool = False, - revise_keys: list | None = None, - ): - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - """ - - if revise_keys is None: - revise_keys = [(r"^module.", "")] - - def callback(checkpoint): - self.call_hook("after_load_checkpoint", checkpoint=checkpoint) - - self.strategy.load_checkpoint( - filename, - map_location=map_location, - strict=strict, - revise_keys=revise_keys, - callback=callback, - ) - - def save_checkpoint( - self, - out_dir: str, - filename: str, - file_client_args: dict | None = None, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - meta: dict | None = None, - by_epoch: bool = True, - backend_args: dict | None = None, - ): - """Save checkpoints. - - ``CheckpointHook`` invokes this method to save checkpoints - periodically. - - Args: - out_dir (str): The directory that checkpoints are saved. - filename (str): The checkpoint filename. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for - details. Defaults to None. It will be deprecated in future. - Please use `backend_args` instead. - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. - meta (dict, optional): The meta information to be saved in the - checkpoint. Defaults to None. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. 
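-
-        Examples:
-            >>> # An illustrative call, mirroring what ``CheckpointHook``
-            >>> # does after each epoch (the filename is arbitrary):
-            >>> runner.save_checkpoint(runner.work_dir, 'epoch_1.pth')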
- """ - if meta is None: - meta = {} - elif not isinstance(meta, dict): - raise TypeError(f"meta should be a dict or None, but got {type(meta)}") - - if by_epoch: - # self.epoch increments 1 after - # `self.call_hook('after_train_epoch)` but `save_checkpoint` is - # called by `after_train_epoch`` method of `CheckpointHook` so - # `epoch` should be `self.epoch + 1` - meta.update(epoch=self.epoch + 1, iter=self.iter) - else: - meta.update(epoch=self.epoch, iter=self.iter + 1) - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. Please use "backend_args" instead', - DeprecationWarning, - stacklevel=2, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.') - - file_client = FileClient.infer_client(file_client_args, out_dir) - filepath = file_client.join_path(out_dir, filename) - else: - filepath = join_path(out_dir, filename, backend_args=backend_args) # type: ignore - - meta.update(cfg=self.cfg.pretty_text, experiment_name=self.experiment_name) - - if hasattr(self.train_dataloader.dataset, "metainfo"): - meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) - - checkpoint = {"meta": meta, "message_hub": self.message_hub.state_dict()} - - def callback(checkpoint): - self.call_hook("before_save_checkpoint", checkpoint=checkpoint) - - self.strategy.save_checkpoint( - filename=filepath, - save_optimizer=save_optimizer, - save_param_scheduler=save_param_scheduler, - extra_ckpt=checkpoint, - callback=callback, - ) - - @master_only - def dump_config(self) -> None: - """Dump config to `work_dir`.""" - if self.cfg.filename is not None: - filename = osp.basename(self.cfg.filename) - else: - filename = f"{self.timestamp}.py" - self.cfg.dump(osp.join(self.work_dir, filename)) - - def _log_env(self) -> None: - """Logging environment information of the current task. - - Args: - env_cfg (dict): The environment config of the runner. - """ - # Collect and log environment information. - system_env, runtime_env = self.strategy.collect_env() - - env_info = "\n " + "\n ".join(f"{k}: {v}" for k, v in system_env.items()) - runtime_env_info = "\n " + "\n ".join(f"{k}: {v}" for k, v in runtime_env.items()) - dash_line = "-" * 60 - self.logger.info( - "\n" - + dash_line - + "\nSystem environment:" - + env_info - + "\n\nRuntime environment:" - + runtime_env_info - + "\n" - + dash_line - + "\n" - ) - - if self.cfg._cfg_dict: - self.logger.info(f"Config:\n{self.cfg.pretty_text}") diff --git a/libs/visengine/visengine/runner/activation_checkpointing.py b/libs/visengine/visengine/runner/activation_checkpointing.py deleted file mode 100644 index 3fd8c3b..0000000 --- a/libs/visengine/visengine/runner/activation_checkpointing.py +++ /dev/null @@ -1,24 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
-from functools import wraps -from operator import attrgetter - -import torch -from torch.utils.checkpoint import checkpoint - - -def wrap_forward(forward): - @wraps(forward) - def wrapper(*args): - return checkpoint(forward, *args) - - return wrapper - - -def turn_on_activation_checkpointing(model: torch.nn.Module, modules: list[str] | str): - if isinstance(modules, str): - modules = [modules] - for module_name in modules: - module = attrgetter(module_name)(model) - module.forward = wrap_forward(module.forward) diff --git a/libs/visengine/visengine/runner/amp.py b/libs/visengine/visengine/runner/amp.py deleted file mode 100644 index 9ce1d63..0000000 --- a/libs/visengine/visengine/runner/amp.py +++ /dev/null @@ -1,69 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from contextlib import contextmanager - -import torch - -from visengine.device import get_device, is_cuda_available -from visengine.logging import print_log -from visengine.utils import digit_version - - -@contextmanager -def autocast( - device_type: str | None = None, - dtype: torch.dtype | None = None, - enabled: bool = True, - cache_enabled: bool | None = None, -): - """A wrapper of ``torch.autocast``. - - Provides a unified interface for PyTorch autocast functionality. - Only supports PyTorch 2.0.0 and above. - - Args: - device_type (str, required): Whether to use 'cuda' or 'cpu' device. - enabled(bool): Whether autocasting should be enabled in the region. - Defaults to True - dtype (torch_dtype, optional): Whether to use ``torch.float16`` or - ``torch.bfloat16``. - cache_enabled(bool, optional): Whether the weight cache inside - autocast should be enabled. - """ - # Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py - # This code should update with the `torch.autocast`. - if cache_enabled is None: - cache_enabled = torch.is_autocast_cache_enabled() - device = get_device() - device_type = device if device_type is None else device_type - - if device_type == "cuda": - if dtype is None: - dtype = torch.get_autocast_gpu_dtype() - - if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): - raise RuntimeError("Current CUDA Device does not support bfloat16. Please switch dtype to float16.") - - elif device_type == "cpu": - if dtype is None: - dtype = torch.bfloat16 - assert dtype == torch.bfloat16, "In CPU autocast, only support `torch.bfloat16` dtype" - else: - # Device like MPS does not support fp16 training or testing. - # If an inappropriate device is set and fp16 is enabled, an error - # will be thrown. - if enabled is False: - yield - return - else: - raise ValueError(f"User specified autocast device_type must be cuda or cpu, but got {device_type}") - - with torch.autocast( - device_type=device_type, - enabled=enabled, - dtype=dtype, - cache_enabled=cache_enabled, - ): - yield diff --git a/libs/visengine/visengine/runner/base_loop.py b/libs/visengine/visengine/runner/base_loop.py deleted file mode 100644 index 68de6e1..0000000 --- a/libs/visengine/visengine/runner/base_loop.py +++ /dev/null @@ -1,37 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod -from typing import Any - -from torch.utils.data import DataLoader - - -class BaseLoop(metaclass=ABCMeta): - """Base loop class. - - All subclasses inherited from ``BaseLoop`` should overwrite the - :meth:`run` method. - - Args: - runner (Runner): A reference of runner. 
- dataloader (Dataloader or dict): An iterator to generate one batch of - dataset each iteration. - """ - - def __init__(self, runner, dataloader: DataLoader | dict) -> None: - self._runner = runner - if isinstance(dataloader, dict): - # Determine whether or not different ranks use different seed. - diff_rank_seed = runner._randomness_cfg.get("diff_rank_seed", False) - self.dataloader = runner.build_dataloader(dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed) - else: - self.dataloader = dataloader - - @property - def runner(self): - return self._runner - - @abstractmethod - def run(self) -> Any: - """Execute loop.""" diff --git a/libs/visengine/visengine/runner/checkpoint.py b/libs/visengine/visengine/runner/checkpoint.py deleted file mode 100644 index 3741eee..0000000 --- a/libs/visengine/visengine/runner/checkpoint.py +++ /dev/null @@ -1,867 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import io -import logging -import os -import os.path as osp -import pkgutil -import re -from collections import OrderedDict, namedtuple -from collections.abc import Callable -from importlib import import_module -from tempfile import TemporaryDirectory - -import torch -from cloudpathlib import GSClient, GSPath -from google.cloud import storage - -import visengine -from visengine.dist import get_dist_info -from visengine.fileio import FileClient, get_file_backend -from visengine.fileio import load as load_file -from visengine.logging import print_log -from visengine.model import is_model_wrapper -from visengine.utils import apply_to, deprecated_function, digit_version, mkdir_or_exist -from visengine.utils.dl_utils import load_url - -import sys - -# `MMENGINE_HOME` is the highest priority directory to save checkpoints -# downloaded from Internet. If it is not set, as a workaround, using -# `XDG_CACHE_HOME`` or `~/.cache` instead. -# Note that `XDG_CACHE_HOME` defines the base directory relative to which -# user-specific non-essential data files should be stored. If `XDG_CACHE_HOME` -# is either not set or empty, a default equal to `~/.cache` should be used. -ENV_MMENGINE_HOME = "MMENGINE_HOME" -ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" -DEFAULT_CACHE_DIR = "~/.cache" - - -# Create a fake mmengine module that redirects to visengine - -# tmp fix for loading old checkpoints -sys.modules["mmengine"] = visengine -sys.modules["mmengine.logging"] = visengine.logging -sys.modules["mmengine.logging.history_buffer"] = visengine.logging.history_buffer - - -class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])): - def __repr__(self): - if not self.missing_keys and not self.unexpected_keys: - return "" - return super().__repr__() - - __str__ = __repr__ - - -def _get_mmengine_home(): - mmengine_home = os.path.expanduser( - os.getenv( - ENV_MMENGINE_HOME, - os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmengine"), - ) - ) - - mkdir_or_exist(mmengine_home) - return mmengine_home - - -def load_state_dict(module, state_dict, strict=False, logger=None): - """Load state_dict to a module. - - This method is modified from :meth:`torch.nn.Module.load_state_dict`. - Default value for ``strict`` is set to ``False`` and the message for - param mismatch will be shown even if strict is False. - - Args: - module (Module): Module that receives the state_dict. - state_dict (OrderedDict): Weights. 
- strict (bool): whether to strictly enforce that the keys - in :attr:`state_dict` match the keys returned by this module's - :meth:`~torch.nn.Module.state_dict` function. Defaults to False. - logger (:obj:`logging.Logger`, optional): Logger to log the error - message. If not specified, print function will be used. - """ - unexpected_keys = [] - missing_keys = [] - err_msg = [] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - # use _load_from_state_dict to enable checkpoint version control - def load(module, local_state_dict, prefix=""): - # recursively check parallel module in case that the model has a - # complicated structure, e.g., nn.Module(nn.Module(DDP)) - if is_model_wrapper(module): - module = module.module - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict( - local_state_dict, - prefix, - local_metadata, - True, - missing_keys, - unexpected_keys, - err_msg, - ) - for name, child in module._modules.items(): - if child is not None: - child_prefix = prefix + name + "." - child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} - load(child, child_state_dict, child_prefix) - - # Note that the hook can modify missing_keys and unexpected_keys. - incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) - if hasattr(module, "_load_state_dict_post_hooks"): - for hook in module._load_state_dict_post_hooks.values(): - out = hook(module, incompatible_keys) - assert out is None, ( - "Hooks registered with " - "``register_load_state_dict_post_hook`` are not expected " - "to return new values, if incompatible_keys need to be " - "modified, it should be done inplace." - ) - - load(module, state_dict) - load = None # break load->load reference cycle - - # ignore "num_batches_tracked" of BN layers - missing_keys = [key for key in missing_keys if "num_batches_tracked" not in key] - - if unexpected_keys: - err_msg.append(f"unexpected key in source state_dict: {', '.join(unexpected_keys)}\n") - if missing_keys: - err_msg.append(f"missing keys in source state_dict: {', '.join(missing_keys)}\n") - - rank, _ = get_dist_info() - if len(err_msg) > 0 and rank == 0: - err_msg.insert(0, "The model and loaded state dict do not match exactly\n") - err_msg = "\n".join(err_msg) - if strict: - raise RuntimeError(err_msg) - else: - print_log(err_msg, logger=logger, level=logging.WARNING) - - -def get_torchvision_models(): - import torchvision - - if digit_version(torchvision.__version__) < digit_version("0.13.0a0"): - model_urls = {} - # When the version of torchvision is lower than 0.13, the model url is - # not declared in `torchvision.model.__init__.py`, so we need to - # iterate through `torchvision.models.__path__` to get the url for each - # model. - for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): - if ispkg: - continue - _zoo = import_module(f"torchvision.models.{name}") - if hasattr(_zoo, "model_urls"): - _urls = _zoo.model_urls - model_urls.update(_urls) - else: - # Since torchvision bumps to v0.13, the weight loading logic, - # model keys and model urls have been changed. Here the URLs of old - # version is loaded to avoid breaking back compatibility. If the - # torchvision version>=0.13.0, new URLs will be added. 
Users can get - # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', - # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. - json_path = osp.join(mmengine.__path__[0], "hub/torchvision_0.12.json") - model_urls = mmengine.load(json_path) - if digit_version(torchvision.__version__) < digit_version("0.14.0a0"): - weights_list = [ - cls for cls_name, cls in torchvision.models.__dict__.items() if cls_name.endswith("_Weights") - ] - else: - weights_list = [ - torchvision.models.get_model_weights(model) - for model in torchvision.models.list_models(torchvision.models) - ] - - for cls in weights_list: - # The name of torchvision model weights classes ends with - # `_Weights` such as `ResNet18_Weights`. However, some model weight - # classes, such as `MNASNet0_75_Weights` does not have any urls in - # torchvision 0.13.0 and cannot be iterated. Here we simply check - # `DEFAULT` attribute to ensure the class is not empty. - if not hasattr(cls, "DEFAULT"): - continue - # Since `cls.DEFAULT` can not be accessed by iterating cls, we set - # default urls explicitly. - cls_name = cls.__name__ - cls_key = cls_name.replace("_Weights", "").lower() - model_urls[f"{cls_key}.default"] = cls.DEFAULT.url - for weight_enum in cls: - cls_key = cls_name.replace("_Weights", "").lower() - cls_key = f"{cls_key}.{weight_enum.name.lower()}" - model_urls[cls_key] = weight_enum.url - - return model_urls - - -def get_external_models(): - mmengine_home = _get_mmengine_home() - default_json_path = osp.join(mmengine.__path__[0], "hub/openmmlab.json") - default_urls = load_file(default_json_path) - assert isinstance(default_urls, dict) - external_json_path = osp.join(mmengine_home, "open_mmlab.json") - if osp.exists(external_json_path): - external_urls = load_file(external_json_path) - assert isinstance(external_urls, dict) - default_urls.update(external_urls) - - return default_urls - - -def get_mmcls_models(): - mmcls_json_path = osp.join(mmengine.__path__[0], "hub/mmcls.json") - mmcls_urls = load_file(mmcls_json_path) - - return mmcls_urls - - -def get_deprecated_model_names(): - deprecate_json_path = osp.join(mmengine.__path__[0], "hub/deprecated.json") - deprecate_urls = load_file(deprecate_json_path) - assert isinstance(deprecate_urls, dict) - - return deprecate_urls - - -def _process_mmcls_checkpoint(checkpoint): - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - # Some checkpoints converted from 3rd-party repo don't - # have the "state_dict" key. - state_dict = checkpoint - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - if k.startswith("backbone."): - new_state_dict[k[9:]] = v - new_checkpoint = {"state_dict": new_state_dict} - - return new_checkpoint - - -class CheckpointLoader: - """A general checkpoint loader to manage all schemes.""" - - _schemes: dict[str, Callable] = {} - - @classmethod - def _register_scheme(cls, prefixes, loader, force=False): - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, list | tuple) - for prefix in prefixes: - if (prefix not in cls._schemes) or force: - cls._schemes[prefix] = loader - else: - raise KeyError( - f'{prefix} is already registered as a loader backend, add "force=True" if you want to override it' - ) - # sort, longer prefixes take priority - cls._schemes = OrderedDict(sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) - - @classmethod - def register_scheme(cls, prefixes, loader=None, force=False): - """Register a loader to CheckpointLoader. 
- - This method can be used as a normal class method or a decorator. - - Args: - prefixes (str or list[str] or tuple[str]): - The prefix of the registered loader. - loader (function, optional): The loader function to be registered. - When this method is used as a decorator, loader is None. - Defaults to None. - force (bool, optional): Whether to override the loader - if the prefix has already been registered. Defaults to False. - """ - - if loader is not None: - cls._register_scheme(prefixes, loader, force=force) - return - - def _register(loader_cls): - cls._register_scheme(prefixes, loader_cls, force=force) - return loader_cls - - return _register - - @classmethod - def _get_checkpoint_loader(cls, path): - """Finds a loader that supports the given path. Falls back to the local - loader if no other loader is found. - - Args: - path (str): checkpoint path - - Returns: - callable: checkpoint loader - """ - for p in cls._schemes: - # use regular match to handle some cases that where the prefix of - # loader has a prefix. For example, both 's3://path' and - # 'open-mmlab:s3://path' should return `load_from_ceph` - if re.match(p, path) is not None: - return cls._schemes[p] - - @classmethod - def load_checkpoint(cls, filename, map_location=None, logger="current"): - """Load checkpoint through URL scheme path. - - Args: - filename (str): checkpoint file name with given prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - logger (str): The logger for message. Defaults to 'current'. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - checkpoint_loader = cls._get_checkpoint_loader(filename) - class_name = checkpoint_loader.__name__ - print_log( - f"Loads checkpoint by {class_name[10:]} backend from path: {filename}", - logger=logger, - ) - return checkpoint_loader(filename, map_location) - - -@CheckpointLoader.register_scheme(prefixes="") -def load_from_local(filename, map_location): - """Load checkpoint by local file path. - - Args: - filename (str): local checkpoint file path - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - - ## GEORGE EDIT HERE - """ - import torch.serialization - - from visengine.logging.history_buffer import HistoryBuffer - - filename = osp.expanduser(filename) - if not osp.isfile(filename): - raise FileNotFoundError(f"{filename} can not be found.") - - # Try loading with safe_globals first - try: - with torch.serialization.safe_globals([HistoryBuffer]): - # GEORGE EDIT HERE - # weights_only=True to prevent warning on newer torch versions - checkpoint = torch.load(filename, map_location=map_location, weights_only=True) - return checkpoint - except Exception as e1: - # If that fails, try with weights_only=False - try: - checkpoint = torch.load(filename, map_location=map_location, weights_only=False) - return checkpoint - except Exception as e2: - # If both attempts fail, try with weights_only=True as last resort - try: - checkpoint = torch.load(filename, map_location=map_location, weights_only=True) - return checkpoint - except Exception as e3: - raise RuntimeError( - f"Failed to load checkpoint with all methods. Original file: {filename}\n" - f"1. safe_globals error: {e1!s}\n" - f"2. weights_only=False error: {e2!s}\n" - f"3. 
weights_only=True error: {e3!s}" - ) - - -@CheckpointLoader.register_scheme(prefixes=("http://", "https://")) -def load_from_http(filename, map_location=None, model_dir=None, progress=os.isatty(0)): - """Load checkpoint through HTTP or HTTPS scheme path. In distributed - setting, this function only download checkpoint at local rank 0. - - Args: - filename (str): checkpoint file path with modelzoo or - torchvision prefix - map_location (str, optional): Same as :func:`torch.load`. - model_dir (string, optional): directory in which to save the object, - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - rank, world_size = get_dist_info() - if rank == 0: - checkpoint = load_url(filename, model_dir=model_dir, map_location=map_location, progress=progress) - if world_size > 1: - torch.distributed.barrier() - if rank > 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress, - ) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes="gs://") -def load_from_gcs(filename, map_location=None): - """Load checkpoint from Google Cloud Storage. - - Args: - filename (str): GCS path starting with gs:// - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - rank, world_size = get_dist_info() - - # Create a temporary directory to download the checkpoint - checkpoint_dir = osp.join(_get_mmengine_home(), "checkpoints") - mkdir_or_exist(checkpoint_dir) - - # Extract bucket and blob path from gs:// URL - # Example: gs://binit-machine-learning/mmdet/model.pth - # -> bucket: binit-machine-learning, blob: mmdet/model.pth - filename_parts = filename.replace("gs://", "").split("/", 1) - if len(filename_parts) != 2: - raise ValueError(f"Invalid GCS path: {filename}") - - bucket_name = filename_parts[0] - blob_path = filename_parts[1] - - # Create a local filename based on the blob path - local_filename = osp.join(checkpoint_dir, blob_path.replace("/", "_")) - - # Download only on rank 0 in distributed setting - if rank == 0: - if not osp.exists(local_filename): - print_log(f"Downloading checkpoint from GCS: {filename}") - try: - # Use GSPath from cloudpathlib for downloading - gs_path = GSPath(filename, client=GSClient(storage_client=storage.Client())) - gs_path.download_to(local_filename) - print_log(f"Downloaded checkpoint to: {local_filename}") - except Exception as e: - raise RuntimeError(f"Failed to download checkpoint from GCS: {e}") - - # Synchronize in distributed setting - if world_size > 1: - torch.distributed.barrier() - - # Load the checkpoint - checkpoint = load_from_local(local_filename, map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes="pavi://") -def load_from_pavi(filename, map_location=None): - """Load checkpoint through the file path prefixed with pavi. In distributed - setting, this function download ckpt at all ranks to different temporary - directories. - - Args: - filename (str): checkpoint file path with pavi prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. 
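-
-    Examples:
-        >>> # Hypothetical path; requires the optional ``pavi`` package.
-        >>> checkpoint = load_from_pavi('pavi://path/to/model')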
- """ - assert filename.startswith("pavi://"), f"Expected filename startswith `pavi://`, but get {filename}" - model_path = filename[7:] - - try: - from pavi import modelcloud - except ImportError: - raise ImportError("Please install pavi to load checkpoint from modelcloud.") - - model = modelcloud.get(model_path) - with TemporaryDirectory() as tmp_dir: - downloaded_file = osp.join(tmp_dir, model.name) - model.download(downloaded_file) - checkpoint = torch.load(downloaded_file, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes=[r"(\S+\:)?s3://", r"(\S+\:)?petrel://"]) -def load_from_ceph(filename, map_location=None, backend="petrel"): - """Load checkpoint through the file path prefixed with s3. In distributed - setting, this function download ckpt at all ranks to different temporary - directories. - - Args: - filename (str): checkpoint file path with s3 prefix - map_location (str, optional): Same as :func:`torch.load`. - backend (str, optional): The storage backend type. - Defaults to 'petrel'. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - file_backend = get_file_backend(filename, backend_args={"backend": backend}) - with io.BytesIO(file_backend.get(filename)) as buffer: - checkpoint = torch.load(buffer, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes=("modelzoo://", "torchvision://")) -def load_from_torchvision(filename, map_location=None): - """Load checkpoint through the file path prefixed with modelzoo or - torchvision. - - Args: - filename (str): checkpoint file path with modelzoo or - torchvision prefix - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - model_urls = get_torchvision_models() - if filename.startswith("modelzoo://"): - print_log( - 'The URL scheme of "modelzoo://" is deprecated, please use "torchvision://" instead', - logger="current", - level=logging.WARNING, - ) - model_name = filename[11:] - else: - model_name = filename[14:] - return load_from_http(model_urls[model_name], map_location=map_location) - - -@CheckpointLoader.register_scheme(prefixes=("open-mmlab://", "openmmlab://")) -def load_from_openmmlab(filename, map_location=None): - """Load checkpoint through the file path prefixed with open-mmlab or - openmmlab. - - Args: - filename (str): checkpoint file path with open-mmlab or - openmmlab prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. 
- """ - - model_urls = get_external_models() - prefix_str = "open-mmlab://" - if filename.startswith(prefix_str): - model_name = filename[13:] - else: - model_name = filename[12:] - prefix_str = "openmmlab://" - - deprecated_urls = get_deprecated_model_names() - if model_name in deprecated_urls: - print_log( - f"{prefix_str}{model_name} is deprecated in favor of {prefix_str}{deprecated_urls[model_name]}", - logger="current", - level=logging.WARNING, - ) - model_name = deprecated_urls[model_name] - model_url = model_urls[model_name] - # check if is url - if model_url.startswith(("http://", "https://")): - checkpoint = load_from_http(model_url, map_location=map_location) - else: - filename = osp.join(_get_mmengine_home(), model_url) - if not osp.isfile(filename): - raise FileNotFoundError(f"{filename} can not be found.") - checkpoint = torch.load(filename, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes="mmcls://") -def load_from_mmcls(filename, map_location=None): - """Load checkpoint through the file path prefixed with mmcls. - - Args: - filename (str): checkpoint file path with mmcls prefix - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - model_urls = get_mmcls_models() - model_name = filename[8:] - checkpoint = load_from_http(model_urls[model_name], map_location=map_location) - checkpoint = _process_mmcls_checkpoint(checkpoint) - return checkpoint - - -def _load_checkpoint(filename, map_location=None, logger=None): - """Load checkpoint from somewhere (modelzoo, file, url). - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None. - logger (:mod:`logging.Logger`, optional): The logger for error message. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. It can be either an - OrderedDict storing model weights or a dict containing other - information, which depends on the checkpoint. - """ - return CheckpointLoader.load_checkpoint(filename, map_location, logger) - - -def _load_checkpoint_with_prefix(prefix, filename, map_location=None): - """Load partial pretrained model with specific prefix. - - Args: - prefix (str): The prefix of sub-module. - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str | None): Same as :func:`torch.load`. - Defaults to None. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - checkpoint = _load_checkpoint(filename, map_location=map_location) - - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - if not prefix.endswith("."): - prefix += "." 
- prefix_len = len(prefix) - - state_dict = {k[prefix_len:]: v for k, v in state_dict.items() if k.startswith(prefix)} - - assert state_dict, f"{prefix} is not in the pretrained model" - return state_dict - - -def _load_checkpoint_to_model(model, checkpoint, strict=False, logger=None, revise_keys=None): - # get state_dict from checkpoint - if revise_keys is None: - revise_keys = [(r"^module\.", "")] - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - - # strip prefix of state_dict - metadata = getattr(state_dict, "_metadata", OrderedDict()) - for p, r in revise_keys: - state_dict = OrderedDict({re.sub(p, r, k): v for k, v in state_dict.items()}) - # Keep metadata in state_dict - state_dict._metadata = metadata - - # load state_dict - load_state_dict(model, state_dict, strict, logger) - return checkpoint - - -def load_checkpoint(model, filename, map_location=None, strict=False, logger=None, revise_keys=None): - """Load checkpoint from a file or URI. - - Args: - model (Module): Module to load checkpoint. - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str): Same as :func:`torch.load`. - strict (bool): Whether to allow different params for the model and - checkpoint. - logger (:mod:`logging.Logger` or None): The logger for error message. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - if revise_keys is None: - revise_keys = [(r"^module\.", "")] - checkpoint = _load_checkpoint(filename, map_location, logger) - # OrderedDict is a subclass of dict - if not isinstance(checkpoint, dict): - raise RuntimeError(f"No state_dict found in checkpoint file {filename}") - - return _load_checkpoint_to_model(model, checkpoint, strict, logger, revise_keys) - - -def weights_to_cpu(state_dict): - """Copy a model state_dict to cpu. - - Args: - state_dict (OrderedDict): Model weights on GPU. - - Returns: - OrderedDict: Model weights on GPU. - """ - # stash metadata to put in state_dict later - metadata = getattr(state_dict, "_metadata", OrderedDict()) - state_dict = apply_to(state_dict, lambda x: hasattr(x, "cpu"), lambda x: x.cpu()) - state_dict._metadata = metadata - return state_dict - - -@deprecated_function( - since="0.3.0", - removed_in="0.5.0", - instructions="`_save_to_state_dict` will be deprecated in the future, please use `nn.Module._save_to_state_dict` directly.", -) -def _save_to_state_dict(module, destination, prefix, keep_vars): - """Saves module state to `destination` dictionary. - - This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. - - Args: - module (nn.Module): The module to generate state_dict. - destination (dict): A dict where state will be stored. - prefix (str): The prefix for parameters and buffers used in this - module. - keep_vars (bool): Whether to keep the variable property of the - parameters. 
- """ - for name, param in module._parameters.items(): - if param is not None: - destination[prefix + name] = param if keep_vars else param.detach() - for name, buf in module._buffers.items(): - if buf is not None and name not in module._non_persistent_buffers_set: - destination[prefix + name] = buf if keep_vars else buf.detach() - - -def get_state_dict(module, destination=None, prefix="", keep_vars=False): - """Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - This method is modified from :meth:`torch.nn.Module.state_dict` to - recursively check parallel module in case that the model has a complicated - structure, e.g., nn.Module(nn.Module(DDP)). - - Args: - module (nn.Module): The module to generate state_dict. - destination (OrderedDict): Returned dict for the state of the - module. - prefix (str): Prefix of the key. - keep_vars (bool): Whether to keep the variable property of the - parameters. Defaults to False. - - Returns: - dict: A dictionary containing a whole state of the module. - """ - # recursively check parallel module in case that the model has a - # complicated structure, e.g., nn.Module(nn.Module(DDP)) - if is_model_wrapper(module): - module = module.module - - # below is the same as torch.nn.Module.state_dict() - if destination is None: - destination = OrderedDict() - destination._metadata = OrderedDict() - destination._metadata[prefix[:-1]] = local_metadata = {"version": module._version} - module._save_to_state_dict(destination, prefix, keep_vars) - for name, child in module._modules.items(): - if child is not None: - get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars) - for hook in module._state_dict_hooks.values(): - hook_result = hook(module, destination, prefix, local_metadata) - if hook_result is not None: - destination = hook_result - return destination - - -def save_checkpoint(checkpoint, filename, file_client_args=None, backend_args=None): - """Save checkpoint to file. - - Args: - checkpoint (dict): Module whose params are to be saved. - filename (str): Checkpoint filename. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - `backend_args` instead. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - """ - if file_client_args is not None: - print_log( - '"file_client_args" will be deprecated in future. 
Please use "backend_args" instead', - logger="current", - level=logging.WARNING, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.') - - if filename.startswith("pavi://"): - if file_client_args is not None or backend_args is not None: - raise ValueError('"file_client_args" or "backend_args" should be "None" if filename starts with "pavi://"') - try: - from pavi import exception, modelcloud - except ImportError: - raise ImportError("Please install pavi to load checkpoint from modelcloud.") - model_path = filename[7:] - root = modelcloud.Folder() - model_dir, model_name = osp.split(model_path) - try: - model = modelcloud.get(model_dir) - except exception.NodeNotFoundError: - model = root.create_training_model(model_dir) - with TemporaryDirectory() as tmp_dir: - checkpoint_file = osp.join(tmp_dir, model_name) - with open(checkpoint_file, "wb") as f: - torch.save(checkpoint, f) - f.flush() - model.create_file(checkpoint_file, name=model_name) - else: - file_client = FileClient.infer_client(file_client_args, filename) - if file_client_args is None: - file_backend = get_file_backend(filename, backend_args=backend_args) - else: - file_backend = file_client - - with io.BytesIO() as f: - torch.save(checkpoint, f) - file_backend.put(f.getvalue(), filename) - - -def find_latest_checkpoint(path: str) -> str | None: - """Find the latest checkpoint from the given path. - - Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501 - - Args: - path(str): The path to find checkpoints. - - Returns: - str or None: File path of the latest checkpoint. - """ - save_file = osp.join(path, "last_checkpoint") - last_saved: str | None - if os.path.exists(save_file): - with open(save_file) as f: - last_saved = f.read().strip() - else: - print_log("Did not find last_checkpoint to be resumed.") - last_saved = None - return last_saved diff --git a/libs/visengine/visengine/runner/log_processor.py b/libs/visengine/visengine/runner/log_processor.py deleted file mode 100644 index 82d36b4..0000000 --- a/libs/visengine/visengine/runner/log_processor.py +++ /dev/null @@ -1,558 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import datetime -import re -from collections import OrderedDict -from itertools import chain - -import numpy as np -import torch - -from visengine.device import ( - get_max_cuda_memory, - is_cuda_available, -) -from visengine.registry import LOG_PROCESSORS - -from .utils import _get_batch_size - - -@LOG_PROCESSORS.register_module(force=True) -class LogProcessor: - """A log processor used to format log information collected from - ``runner.message_hub.log_scalars``. - - ``LogProcessor`` instance is built by runner and will format - ``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can - directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument - ``custom_cfg`` of constructor can control the statistics method of logs. - - Args: - window_size (int): default smooth interval. Defaults to 10. - by_epoch (bool): Whether to format logs with epoch stype. Defaults to - True. - custom_cfg (list[dict], optional): Contains multiple log config dict, - in which key means the data source name of log and value means the - statistic method and corresponding arguments used to count the - data source. Defaults to None. 
- - - If custom_cfg is None, all logs will be formatted via default - methods, such as smoothing loss by default window_size. If - custom_cfg is defined as a list of config dict, for example: - [dict(data_src='loss', method='mean', log_name='global_loss', - window_size='global')]. It means the log item ``loss`` will be - counted as global mean and additionally logged as ``global_loss`` - (defined by ``log_name``). If ``log_name`` is not defined in - config dict, the original logged key will be overwritten. - - - The original log item cannot be overwritten twice. Here is - an error example: - [dict(data_src='loss', method='mean', window_size='global'), - dict(data_src='loss', method='mean', window_size='epoch')]. - Both log config dict in custom_cfg do not have ``log_name`` key, - which means the loss item will be overwritten twice. - - - For those statistic methods with the ``window_size`` argument, - if ``by_epoch`` is set to False, ``windows_size`` should not be - `epoch` to statistics log value by epoch. - num_digits (int): The number of significant digit shown in the - logging message. Defaults to 4. - log_with_hierarchy (bool): Whether to log with hierarchy. If it is - True, the information is written to visualizer backend such as - :obj:`LocalVisBackend` and :obj:`TensorboardBackend` - with hierarchy. For example, ``loss`` will be saved as - ``train/loss``, and accuracy will be saved as ``val/accuracy``. - Defaults to False. - `New in version 0.7.0.` - mean_pattern (str): This is a regular expression used to match the log - that need to be included in the smoothing statistics. - `New in version 0.7.3.` - - Examples: - >>> # `log_name` is defined, `loss_large_window` will be an additional - >>> # record. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> log_name='loss_large_window', - >>> method_name='mean', - >>> window_size=100)]) - >>> # `log_name` is not defined. `loss` will be overwritten. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> method_name='mean', - >>> window_size=100)]) - >>> # Record loss with different statistics methods. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> log_name='loss_large_window', - >>> method_name='mean', - >>> window_size=100), - >>> dict(data_src='loss', - >>> method_name='mean', - >>> window_size=100)]) - >>> # Overwrite loss item twice will raise an error. - >>> log_processor = dict( - >>> window_size=10, - >>> by_epoch=True, - >>> custom_cfg=[dict(data_src='loss', - >>> method_name='mean', - >>> window_size=100), - >>> dict(data_src='loss', - >>> method_name='max', - >>> window_size=100)]) - AssertionError - """ - - def __init__( - self, - window_size=10, - by_epoch=True, - custom_cfg: list[dict] | None = None, - num_digits: int = 4, - log_with_hierarchy: bool = False, - mean_pattern=r".*(loss|time|data_time|grad_norm).*", - ): - self.window_size = window_size - self.by_epoch = by_epoch - self.custom_cfg = custom_cfg if custom_cfg else [] - self.num_digits = num_digits - self.log_with_hierarchy = log_with_hierarchy - self.mean_pattern = re.compile(mean_pattern) - self._check_custom_cfg() - - def get_log_after_iter(self, runner, batch_idx: int, mode: str) -> tuple[dict, str]: - """Format log string after training, validation or testing iteration. - - Args: - runner (Runner): The runner of training phase. 
- batch_idx (int): The index of the current batch in the current - loop. - mode (str): Current mode of runner, train, test or val. - - Return: - Tuple[dict, str]: Formatted log dict/string which will be - recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`. - """ - assert mode in ["train", "test", "val"] - # Overwrite ``window_size`` defined in ``custom_cfg`` to int value. - parsed_cfg = self._parse_windows_size(runner, batch_idx, self.custom_cfg) - # log_tag is used to write log information to terminal - log_tag = self._collect_scalars(parsed_cfg, runner, mode) - - # If `self.log_with_hierarchy` is False, the tag is the same as - # log_tag. Otherwise, each key in tag starts with prefix `train`, - # `test` or `val` - if not self.log_with_hierarchy: - tag = copy.deepcopy(log_tag) - else: - tag = self._collect_scalars(parsed_cfg, runner, mode, True) - - # Record learning rate. - lr_str_list = [] - for key, value in tag.items(): - if key.endswith("lr"): - key = self._remove_prefix(key, f"{mode}/") - log_tag.pop(key) - lr_str_list.append(f"{key}: {value:.{self.num_digits}e}") - lr_str = " ".join(lr_str_list) - # Format log header. - # by_epoch == True - # train/val: Epoch [5][5/10] ... - # test: Epoch [5/10] - # by_epoch == False - # train: Epoch [5/10000] ... (divided by `max_iter`) - # val/test: Epoch [5/2000] ... (divided by length of dataloader) - if self.by_epoch: - # Align the iteration log: - # Epoch(train) [ 9][010/270] - # ... ||| ||| - # Epoch(train) [ 10][100/270] - dataloader_len = self._get_dataloader_size(runner, mode) - cur_iter = self._get_iter(runner, batch_idx) - cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len))) - if mode in ["train", "val"]: - cur_epoch = self._get_epoch(runner, mode) - if not (isinstance(runner._train_loop, dict) or runner._train_loop is None): - # Right Align the epoch log: - # Epoch(train) [9][100/270] - # ... || - # Epoch(train) [100][100/270] - max_epochs = runner.max_epochs - # 3 means the three characters: "[", "]", and " " occupied - # in " [{max_epochs}]" - cur_epoch_str = f"[{cur_epoch}]".rjust(len(str(max_epochs)) + 3, " ") - else: - cur_epoch_str = f"[{cur_epoch}]" - tag["epoch"] = cur_epoch - log_str = f"Epoch({mode}){cur_epoch_str}[{cur_iter_str}/{dataloader_len}] " - else: - log_str = f"Epoch({mode}) [{cur_iter_str}/{dataloader_len}] " - else: - if mode == "train": - cur_iter = self._get_iter(runner, batch_idx) - cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters))) - log_str = f"Iter({mode}) [{cur_iter_str}/{runner.max_iters}] " - else: - dataloader_len = self._get_dataloader_size(runner, mode) - cur_iter_str = str(batch_idx + 1).rjust(len(str(dataloader_len))) - log_str = f"Iter({mode}) [{cur_iter_str}/{dataloader_len}] " - # Add global iter. - if isinstance(runner._train_loop, dict) or runner._train_loop is None: - tag["iter"] = 0 - else: - tag["iter"] = runner.iter + 1 - # Concatenate lr, momentum string with log header. - log_str += f"{lr_str} " - # If IterTimerHook used in runner, eta, time, and data_time should be - # recorded. 
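-        # At this point ``log_str`` looks like (illustrative values):
-        #   "Epoch(train)  [10][100/270] lr: 1.0000e-02 "
-        # and eta/time/data_time/img/s are appended below when available.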
- if all(item in log_tag for item in ["time", "data_time"]) and "eta" in runner.message_hub.runtime_info: - eta = runner.message_hub.get_info("eta") - eta_str = str(datetime.timedelta(seconds=int(eta))) - log_str += f"eta: {eta_str} " - log_str += ( - f"time: {log_tag['time']:.{self.num_digits}f} data_time: {log_tag['data_time']:.{self.num_digits}f} " - ) - - # Calculate and add images/second for both train and val - if mode == "train": - batch_size = _get_batch_size(runner._train_dataloader) - elif mode == "val": - batch_size = _get_batch_size(runner._val_dataloader) - else: - batch_size = None - - if batch_size: - images_per_second = batch_size / log_tag["time"] - log_str += f"img/s: {images_per_second:.{self.num_digits - 2}f} " - tag["images_per_second"] = images_per_second - - # Pop recorded keys - log_tag.pop("time") - log_tag.pop("data_time") - - # If cuda/musa is available, - # the max memory occupied should be calculated. - if is_cuda_available(): - max_memory = self._get_max_memory(runner) - log_str += f"memory: {max_memory} " - tag["memory"] = max_memory - - # Loop left keys to fill `log_str`. - if mode in ("train", "val"): - log_items = [] - for name, val in log_tag.items(): - if mode == "val" and not name.startswith("val/loss"): - continue - if isinstance(val, float): - val = f"{val:.{self.num_digits}f}" - log_items.append(f"{name}: {val}") - log_str += " ".join(log_items) - return tag, log_str - - def get_log_after_epoch(self, runner, batch_idx: int, mode: str, with_non_scalar: bool = False) -> tuple[dict, str]: - """Format log string after validation or testing epoch. - - Args: - runner (Runner): The runner of validation/testing phase. - batch_idx (int): The index of the current batch in the current - loop. - mode (str): Current mode of runner. - with_non_scalar (bool): Whether to include non-scalar infos in the - returned tag. Defaults to False. - - Return: - Tuple[dict, str]: Formatted log dict/string which will be - recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`. - """ - assert mode in ["test", "val"], f"`_get_metric_log_str` only accept val or test mode, but got {mode}" - dataloader_len = self._get_dataloader_size(runner, mode) - - # By epoch: - # Epoch(val) [10][1000/1000] ... - # Epoch(test) [1000/1000] ... - # By iteration: - # Iteration(val) [1000/1000] ... - # Iteration(test) [1000/1000] ... - if self.by_epoch: - if mode == "val": - cur_epoch = self._get_epoch(runner, mode) - log_str = f"Epoch({mode}) [{cur_epoch}][{dataloader_len}/{dataloader_len}] " - else: - log_str = f"Epoch({mode}) [{dataloader_len}/{dataloader_len}] " - - else: - log_str = f"Iter({mode}) [{dataloader_len}/{dataloader_len}] " - - custom_cfg_copy = copy.deepcopy(self.custom_cfg) - # remove prefix - custom_keys = [self._remove_prefix(cfg["data_src"], f"{mode}/") for cfg in custom_cfg_copy] - # Count the averaged time and data_time by epoch - if "time" not in custom_keys: - custom_cfg_copy.append({"data_src": "time", "window_size": "epoch", "method_name": "mean"}) - if "data_time" not in custom_keys: - custom_cfg_copy.append({"data_src": "data_time", "window_size": "epoch", "method_name": "mean"}) - parsed_cfg = self._parse_windows_size(runner, batch_idx, custom_cfg_copy) - # tag is used to write log information to different backends. 
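-        # e.g. with ``log_with_hierarchy=True`` a metric keeps its mode
-        # prefix ('val/accuracy'), so visualizer backends can group it;
-        # otherwise the bare key ('accuracy') is used.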
- ori_tag = self._collect_scalars(parsed_cfg, runner, mode, self.log_with_hierarchy) - non_scalar_tag = self._collect_non_scalars(runner, mode) - # move `time` or `data_time` to the end of the log - tag = OrderedDict() - time_tag = OrderedDict() - for key, value in ori_tag.items(): - if key in (f"{mode}/time", f"{mode}/data_time", "time", "data_time"): - time_tag[key] = value - else: - tag[key] = value - # Log other messages. - log_items = [] - log_str += " " - for name, val in chain(tag.items(), non_scalar_tag.items(), time_tag.items()): - if isinstance(val, float): - val = f"{val:.{self.num_digits}f}" - if isinstance(val, torch.Tensor | np.ndarray): - # newline to display tensor and array. - val = f"\n{val}\n" - log_items.append(f"{name}: {val}") - log_str += " ".join(log_items) - - if with_non_scalar: - tag.update(non_scalar_tag) - tag.update(time_tag) - return tag, log_str - - def _collect_scalars(self, custom_cfg: list[dict], runner, mode: str, reserve_prefix: bool = False) -> dict: - """Collect log information to compose a dict according to mode. - - Args: - custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int - ``window_size``. - runner (Runner): The runner of the training/testing/validation - process. - mode (str): Current mode of runner. - reserve_prefix (bool): Whether to reserve the prefix of the key. - - Returns: - dict: Statistical values of logs. - """ - custom_cfg = copy.deepcopy(custom_cfg) - tag = OrderedDict() - # history_scalars of train/val/test phase. - history_scalars = runner.message_hub.log_scalars - # corresponding mode history_scalars - mode_history_scalars = OrderedDict() - # extract log scalars and remove prefix to `mode_history_scalars` - # according to mode. - for prefix_key, log_buffer in history_scalars.items(): - if prefix_key.startswith(mode): - if not reserve_prefix: - key = self._remove_prefix(prefix_key, f"{mode}/") - else: - key = prefix_key - mode_history_scalars[key] = log_buffer - for key in mode_history_scalars: - # Update the latest learning rate and smoothed time logs. - if re.search(self.mean_pattern, key) is not None: - tag[key] = mode_history_scalars[key].mean(self.window_size) - else: - # Default statistic method is current. - tag[key] = mode_history_scalars[key].current() - # Update custom keys. - for log_cfg in custom_cfg: - data_src = log_cfg.pop("data_src") - log_name = log_cfg.pop("log_name", data_src) - if reserve_prefix: - data_src = f"{mode}/{data_src}" - log_name = f"{mode}/{log_name}" - # log item in custom_cfg could only exist in train or val - # mode. - if data_src in mode_history_scalars: - tag[log_name] = mode_history_scalars[data_src].statistics(**log_cfg) - return tag - - def _collect_non_scalars(self, runner, mode: str) -> dict: - """Collect log information to compose a dict according to mode. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - mode (str): Current mode of runner. - - Returns: - dict: non-scalar infos of the specified mode. - """ - # infos of train/val/test phase. - infos = runner.message_hub.runtime_info - # corresponding mode infos - mode_infos = OrderedDict() - # extract log info and remove prefix to `mode_infos` according to mode. 
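-        # e.g. a runtime info stored as 'val/<name>' is returned as '<name>'
-        # unless ``log_with_hierarchy`` is True (hypothetical key).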
- for prefix_key, value in infos.items(): - if prefix_key.startswith(mode): - if self.log_with_hierarchy: - key = prefix_key - else: - key = self._remove_prefix(prefix_key, f"{mode}/") - mode_infos[key] = value - return mode_infos - - def _remove_prefix(self, string: str, prefix: str): - """Remove the prefix ``train``, ``val`` and ``test`` of the key.""" - if string.startswith(prefix): - return string[len(prefix) :] - else: - return string - - def _check_custom_cfg(self) -> None: - """Check the legality of ``self.custom_cfg``.""" - - def _check_window_size(): - for log_cfg in self.custom_cfg: - if not self.by_epoch: - assert log_cfg["window_size"] != "epoch", ( - "window_size cannot be epoch if LoggerHook.by_epoch is False." - ) - - def _check_repeated_log_name(): - # The `log_name` of the same data_src should not be repeated. - # If `log_name` is not specified, `data_src` will be overwritten. - # But only allowed to be overwritten once. - check_set = set() - for log_cfg in self.custom_cfg: - assert "data_src" in log_cfg - data_src = log_cfg["data_src"] - log_name = log_cfg.get("log_name", data_src) - assert log_name not in check_set, ( - f"Found duplicate {log_name} for {data_src}. Please check" - "your `custom_cfg` for `log_processor`. You should " - f"neither define duplicate `{log_name}` for {data_src} " - f"nor do not define any {log_name} for multiple " - f"{data_src}, See more information in the docstring of " - "LogProcessor" - ) - - check_set.add(log_name) - - _check_repeated_log_name() - _check_window_size() - - def _parse_windows_size(self, runner, batch_idx: int, custom_cfg: list | None = None) -> list: - """Parse window_size defined in custom_cfg to int value. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - batch_idx (int): The iteration index of current dataloader. - custom_cfg (list): A copy of ``self.custom_cfg``. Defaults to None - to keep backward compatibility. - """ - if custom_cfg is None: - custom_cfg = copy.deepcopy(self.custom_cfg) - else: - custom_cfg = copy.deepcopy(custom_cfg) - for log_cfg in custom_cfg: - window_size = log_cfg.get("window_size", None) - if window_size is None or isinstance(window_size, int): - continue - elif window_size == "epoch": - log_cfg["window_size"] = batch_idx + 1 - elif window_size == "global": - log_cfg["window_size"] = runner.iter + 1 - else: - raise TypeError(f"window_size should be int, epoch or global, but got invalid {window_size}") - return custom_cfg - - def _get_max_memory(self, runner) -> int: - """Returns the maximum GPU memory occupied by tensors in megabytes (MB) - for a given device. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - - Returns: - The maximum GPU memory occupied by tensors in megabytes for a given - device. - """ - - device = getattr(runner.model, "output_device", None) - return get_max_cuda_memory(device) - - def _get_iter(self, runner, batch_idx: int) -> int: - """Get current iteration index. - - Args: - runner (Runner): The runner of the training/testing/validation - process. - batch_idx (int): The iteration index of current - dataloader. Defaults to None. - - Returns: - int: The current global iter or inner iter. - """ - if self.by_epoch: - current_iter = batch_idx + 1 - else: - current_iter = runner.iter + 1 - return current_iter - - def _get_epoch(self, runner, mode: str) -> int: - """Get current epoch according to mode. - - Args: - runner (Runner): The runner of the training/testing/validation - process. 
- mode (str): Current mode of runner. - - Returns: - int: The current epoch. - """ - if mode == "train": - epoch = runner.epoch + 1 - elif mode == "val": - if isinstance(runner._train_loop, dict) or runner._train_loop is None: - epoch = 0 - else: - # normal val mode - # runner.epoch += 1 has been done before validation - epoch = runner.epoch - else: - raise ValueError(f"runner mode should be 'train' or 'val', but got {mode}") - return epoch - - def _get_cur_loop(self, runner, mode: str): - """Get current loop according to mode. - - Args: - runner (Runner): The runner of the training/validation/testing - process. - mode (str): Current mode of runner. - - Returns: - BaseLoop: Current loop of runner. - """ - # returning a type hint here would cause a circular import - if mode == "train": - return runner.train_loop - elif mode == "val": - return runner.val_loop - else: - return runner.test_loop - - def _get_dataloader_size(self, runner, mode) -> int: - """Get dataloader size of current loop. - - Args: - runner (Runner): The runner of the training/validation/testing - process. - mode (str): Current mode of runner. - - Returns: - int: The dataloader size of current loop. - """ - return len(self._get_cur_loop(runner=runner, mode=mode).dataloader) diff --git a/libs/visengine/visengine/runner/loops.py b/libs/visengine/visengine/runner/loops.py deleted file mode 100644 index a789fb0..0000000 --- a/libs/visengine/visengine/runner/loops.py +++ /dev/null @@ -1,533 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import bisect -import logging -import time -from collections.abc import Sequence - -import torch -from torch.utils.data import DataLoader - -from visengine.evaluator import Evaluator -from visengine.logging import HistoryBuffer, print_log -from visengine.registry import LOOPS -from visengine.structures import BaseDataElement -from visengine.utils import is_list_of - -from .amp import autocast -from .base_loop import BaseLoop -from .utils import calc_dynamic_intervals - - -@LOOPS.register_module(force=True) -class EpochBasedTrainLoop(BaseLoop): - """Loop for epoch-based training. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - max_epochs (int): Total training epochs. - val_begin (int): The epoch that begins validating. - Defaults to 1. - val_interval (int): Validation interval. Defaults to 1. - dynamic_intervals (List[Tuple[int, int]], optional): The - first element in the tuple is a milestone and the second - element is an interval. The interval is used after the - corresponding milestone. Defaults to None. - """ - - def __init__( - self, - runner, - dataloader: DataLoader | dict, - max_epochs: int, - val_begin: int = 1, - val_interval: int = 1, - dynamic_intervals: list[tuple[int, int]] | None = None, - ) -> None: - super().__init__(runner, dataloader) - self._max_epochs = int(max_epochs) - assert self._max_epochs == max_epochs, f"`max_epochs` should be an integer, but got {max_epochs}." - self._max_iters = self._max_epochs * len(self.dataloader) - self._epoch = 0 - self._iter = 0 - self.val_begin = val_begin - self.val_interval = val_interval - # This attribute will be updated by `EarlyStoppingHook` - # when it is enabled. - self.stop_training = False - if hasattr(self.dataloader.dataset, "metainfo"): - self.runner.visualizer.dataset_meta = self.dataloader.dataset.metainfo - else: - print_log( - f"Dataset {self.dataloader.dataset.__class__.__name__} has no metainfo. 
``dataset_meta`` in visualizer will be None.", - logger="current", - level=logging.WARNING, - ) - - self.dynamic_milestones, self.dynamic_intervals = calc_dynamic_intervals(self.val_interval, dynamic_intervals) - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - return self._max_epochs - - @property - def max_iters(self): - """int: Total iterations to train model.""" - return self._max_iters - - @property - def epoch(self): - """int: Current epoch.""" - return self._epoch - - @property - def iter(self): - """int: Current iteration.""" - return self._iter - - def run(self) -> torch.nn.Module: - """Launch training.""" - self.runner.call_hook("before_train") - - while self._epoch < self._max_epochs and not self.stop_training: - self.run_epoch() - - self._decide_current_val_interval() - if ( - self.runner.val_loop is not None - and self._epoch >= self.val_begin - and (self._epoch % self.val_interval == 0 or self._epoch == self._max_epochs) - ): - self.runner.val_loop.run() - - self.runner.call_hook("after_train") - return self.runner.model - - def run_epoch(self) -> None: - """Iterate one epoch.""" - self.runner.call_hook("before_train_epoch") - self.runner.model.train() - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - self.runner.call_hook("after_train_epoch") - self._epoch += 1 - - def run_iter(self, idx, data_batch: Sequence[dict]) -> None: - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - self.runner.call_hook("before_train_iter", batch_idx=idx, data_batch=data_batch) - # Enable gradient accumulation mode and avoid unnecessary gradient - # synchronization during gradient accumulation process. - # outputs should be a dict of loss. - outputs = self.runner.model.train_step(data_batch, optim_wrapper=self.runner.optim_wrapper) - - self.runner.call_hook("after_train_iter", batch_idx=idx, data_batch=data_batch, outputs=outputs) - self._iter += 1 - - def _decide_current_val_interval(self) -> None: - """Dynamically modify the ``val_interval``.""" - step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1)) - self.val_interval = self.dynamic_intervals[step - 1] - - -class _InfiniteDataloaderIterator: - """An infinite dataloader iterator wrapper for IterBasedTrainLoop. - - It resets the dataloader to continue iterating when the iterator has - iterated over all the data. However, this approach is not efficient, as the - workers need to be restarted every time the dataloader is reset. It is - recommended to use `mmengine.dataset.InfiniteSampler` to enable the - dataloader to iterate infinitely. - """ - - def __init__(self, dataloader: DataLoader) -> None: - self._dataloader = dataloader - self._iterator = iter(self._dataloader) - self._epoch = 0 - - def __iter__(self): - return self - - def __next__(self) -> Sequence[dict]: - try: - data = next(self._iterator) - except StopIteration: - print_log( - "Reached the end of the dataloader, it will be " - "restarted and continue to iterate. It is " - "recommended to use " - "`mmengine.dataset.InfiniteSampler` to enable the " - "dataloader to iterate infinitely.", - logger="current", - level=logging.WARNING, - ) - self._epoch += 1 - if hasattr(self._dataloader, "sampler") and hasattr(self._dataloader.sampler, "set_epoch"): - # In case the `_SingleProcessDataLoaderIter` has no sampler, - # or the data loader uses `SequentialSampler` in PyTorch.
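- # e.g. with a DistributedSampler-backed loader (illustrative), successive - # calls to set_epoch(1), set_epoch(2), ... reseed the shuffle so each restart - # of the dataloader yields a different data ordering.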
- self._dataloader.sampler.set_epoch(self._epoch) - - elif hasattr(self._dataloader, "batch_sampler") and hasattr( - self._dataloader.batch_sampler.sampler, "set_epoch" - ): - # In case the `_SingleProcessDataLoaderIter` has no batch - # sampler. The batch sampler in PyTorch wraps the sampler as its - # attribute. - self._dataloader.batch_sampler.sampler.set_epoch(self._epoch) - time.sleep(2) # Prevent possible deadlock during epoch transition - self._iterator = iter(self._dataloader) - data = next(self._iterator) - return data - - -@LOOPS.register_module(force=True) -class IterBasedTrainLoop(BaseLoop): - """Loop for iter-based training. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - max_iters (int): Total training iterations. - val_begin (int): The iteration that begins validating. - Defaults to 1. - val_interval (int): Validation interval. Defaults to 1000. - dynamic_intervals (List[Tuple[int, int]], optional): The - first element in the tuple is a milestone and the second - element is an interval. The interval is used after the - corresponding milestone. Defaults to None. - """ - - def __init__( - self, - runner, - dataloader: DataLoader | dict, - max_iters: int, - val_begin: int = 1, - val_interval: int = 1000, - dynamic_intervals: list[tuple[int, int]] | None = None, - ) -> None: - super().__init__(runner, dataloader) - self._max_iters = int(max_iters) - assert self._max_iters == max_iters, f"`max_iters` should be an integer, but got {max_iters}" - self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop - self._epoch = 0 - self._iter = 0 - self.val_begin = val_begin - self.val_interval = val_interval - # This attribute will be updated by `EarlyStoppingHook` - # when it is enabled. - self.stop_training = False - if hasattr(self.dataloader.dataset, "metainfo"): - self.runner.visualizer.dataset_meta = self.dataloader.dataset.metainfo - else: - print_log( - f"Dataset {self.dataloader.dataset.__class__.__name__} has no metainfo. ``dataset_meta`` in visualizer will be None.", - logger="current", - level=logging.WARNING, - ) - # get the iterator of the dataloader - self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader) - - self.dynamic_milestones, self.dynamic_intervals = calc_dynamic_intervals(self.val_interval, dynamic_intervals) - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - return self._max_epochs - - @property - def max_iters(self): - """int: Total iterations to train model.""" - return self._max_iters - - @property - def epoch(self): - """int: Current epoch.""" - return self._epoch - - @property - def iter(self): - """int: Current iteration.""" - return self._iter - - def run(self) -> torch.nn.Module: - """Launch training.""" - self.runner.call_hook("before_train") - # In iteration-based training loop, we treat the whole training process - # as a big epoch and execute the corresponding hook.
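- # A typical iter-based configuration might look like (values illustrative): - # train_cfg = dict(type="IterBasedTrainLoop", max_iters=90000, - # val_interval=1000, dynamic_intervals=[(60000, 500)]) - # i.e. validate every 1000 iters, then every 500 iters once iter 60000 is passed.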
- self.runner.call_hook("before_train_epoch") - if self._iter > 0: - print_log( - f"Advance dataloader {self._iter} steps to skip data that has already been trained", - logger="current", - level=logging.WARNING, - ) - for _ in range(self._iter): - next(self.dataloader_iterator) - while self._iter < self._max_iters and not self.stop_training: - self.runner.model.train() - - data_batch = next(self.dataloader_iterator) - self.run_iter(data_batch) - - self._decide_current_val_interval() - if ( - self.runner.val_loop is not None - and self._iter >= self.val_begin - and (self._iter % self.val_interval == 0 or self._iter == self._max_iters) - ): - self.runner.val_loop.run() - - self.runner.call_hook("after_train_epoch") - self.runner.call_hook("after_train") - return self.runner.model - - def run_iter(self, data_batch: Sequence[dict]) -> None: - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - self.runner.call_hook("before_train_iter", batch_idx=self._iter, data_batch=data_batch) - # Enable gradient accumulation mode and avoid unnecessary gradient - # synchronization during gradient accumulation process. - # outputs should be a dict of loss. - outputs = self.runner.model.train_step(data_batch, optim_wrapper=self.runner.optim_wrapper) - - self.runner.call_hook( - "after_train_iter", - batch_idx=self._iter, - data_batch=data_batch, - outputs=outputs, - ) - self._iter += 1 - - def _decide_current_val_interval(self) -> None: - """Dynamically modify the ``val_interval``.""" - step = bisect.bisect(self.dynamic_milestones, (self._iter + 1)) - self.val_interval = self.dynamic_intervals[step - 1] - - -@LOOPS.register_module(force=True) -class ValLoop(BaseLoop): - """Loop for validation. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - evaluator (Evaluator or dict or list): Used for computing metrics. - fp16 (bool): Whether to enable fp16 validation. Defaults to - False. - """ - - def __init__( - self, - runner, - dataloader: DataLoader | dict, - evaluator: Evaluator | dict | list, - fp16: bool = False, - ) -> None: - super().__init__(runner, dataloader) - - if isinstance(evaluator, dict | list): - self.evaluator = runner.build_evaluator(evaluator) # type: ignore - else: - assert isinstance(evaluator, Evaluator), ( - f"evaluator must be one of dict, list or Evaluator instance, but got {type(evaluator)}." - ) - self.evaluator = evaluator # type: ignore - if hasattr(self.dataloader.dataset, "metainfo"): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = self.dataloader.dataset.metainfo - else: - print_log( - f"Dataset {self.dataloader.dataset.__class__.__name__} has no " - "metainfo. 
``dataset_meta`` in evaluator, metric and " - "visualizer will be None.", - logger="current", - level=logging.WARNING, - ) - self.fp16 = fp16 - self.val_loss: dict[str, HistoryBuffer] = {} - - def run(self) -> dict: - """Launch validation.""" - self.runner.call_hook("before_val") - self.runner.call_hook("before_val_epoch") - self.runner.model.eval() - - # clear val loss - self.val_loss.clear() - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - if self.val_loss: - loss_dict = _parse_losses(self.val_loss, "val") - metrics.update(loss_dict) - - self.runner.call_hook("after_val_epoch", metrics=metrics) - self.runner.call_hook("after_val") - return metrics - - @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[dict]): - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data - from dataloader. - """ - self.runner.call_hook("before_val_iter", batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = self.runner.model.val_step(data_batch) - - outputs, self.val_loss = _update_losses(outputs, self.val_loss) - - self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook("after_val_iter", batch_idx=idx, data_batch=data_batch, outputs=outputs) - - -@LOOPS.register_module(force=True) -class TestLoop(BaseLoop): - """Loop for test. - - Args: - runner (Runner): A reference of runner. - dataloader (Dataloader or dict): A dataloader object or a dict to - build a dataloader. - evaluator (Evaluator or dict or list): Used for computing metrics. - fp16 (bool): Whether to enable fp16 testing. Defaults to - False. - """ - - def __init__( - self, - runner, - dataloader: DataLoader | dict, - evaluator: Evaluator | dict | list, - fp16: bool = False, - ): - super().__init__(runner, dataloader) - - if isinstance(evaluator, dict) or isinstance(evaluator, list): - self.evaluator = runner.build_evaluator(evaluator) # type: ignore - else: - self.evaluator = evaluator # type: ignore - if hasattr(self.dataloader.dataset, "metainfo"): - self.evaluator.dataset_meta = self.dataloader.dataset.metainfo - self.runner.visualizer.dataset_meta = self.dataloader.dataset.metainfo - else: - print_log( - f"Dataset {self.dataloader.dataset.__class__.__name__} has no " - "metainfo. ``dataset_meta`` in evaluator, metric and " - "visualizer will be None.", - logger="current", - level=logging.WARNING, - ) - self.fp16 = fp16 - self.test_loss: dict[str, HistoryBuffer] = {} - - def run(self) -> dict: - """Launch test.""" - self.runner.call_hook("before_test") - self.runner.call_hook("before_test_epoch") - self.runner.model.eval() - - # clear test loss - self.test_loss.clear() - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - if self.test_loss: - loss_dict = _parse_losses(self.test_loss, "test") - metrics.update(loss_dict) - - self.runner.call_hook("after_test_epoch", metrics=metrics) - self.runner.call_hook("after_test") - return metrics - - @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[dict]) -> None: - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. 
- """ - self.runner.call_hook("before_test_iter", batch_idx=idx, data_batch=data_batch) - # predictions should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = self.runner.model.test_step(data_batch) - - outputs, self.test_loss = _update_losses(outputs, self.test_loss) - - self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook("after_test_iter", batch_idx=idx, data_batch=data_batch, outputs=outputs) - - -def _parse_losses(losses: dict[str, HistoryBuffer], stage: str) -> dict[str, float]: - """Parses the raw losses of the network. - - Args: - losses (dict): raw losses of the network. - stage (str): The stage of loss, e.g., 'val' or 'test'. - - Returns: - dict[str, float]: The key is the loss name, and the value is the - average loss. - """ - all_loss = 0 - loss_dict: dict[str, float] = {} - - for loss_name, loss_value in losses.items(): - avg_loss = loss_value.mean() - loss_dict[loss_name] = avg_loss - if "loss" in loss_name: - all_loss += avg_loss - - loss_dict[f"{stage}_loss"] = all_loss - return loss_dict - - -def _update_losses(outputs: list, losses: dict) -> tuple[list, dict]: - """Update and record the losses of the network. - - Args: - outputs (list): The outputs of the network. - losses (dict): The losses of the network. - - Returns: - list: The updated outputs of the network. - dict: The updated losses of the network. - """ - if isinstance(outputs[-1], BaseDataElement) and outputs[-1].keys() == ["loss"]: - loss = outputs[-1].loss # type: ignore - outputs = outputs[:-1] - else: - loss = {} - - for loss_name, loss_value in loss.items(): - if loss_name not in losses: - losses[loss_name] = HistoryBuffer() - if isinstance(loss_value, torch.Tensor): - losses[loss_name].update(loss_value.item()) - elif is_list_of(loss_value, torch.Tensor): - for loss_value_i in loss_value: - losses[loss_name].update(loss_value_i.item()) - return outputs, losses diff --git a/libs/visengine/visengine/runner/priority.py b/libs/visengine/visengine/runner/priority.py deleted file mode 100644 index fab102f..0000000 --- a/libs/visengine/visengine/runner/priority.py +++ /dev/null @@ -1,62 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from enum import Enum - - -class Priority(Enum): - """Hook priority levels. - - +--------------+------------+ - | Level | Value | - +==============+============+ - | HIGHEST | 0 | - +--------------+------------+ - | VERY_HIGH | 10 | - +--------------+------------+ - | HIGH | 30 | - +--------------+------------+ - | ABOVE_NORMAL | 40 | - +--------------+------------+ - | NORMAL | 50 | - +--------------+------------+ - | BELOW_NORMAL | 60 | - +--------------+------------+ - | LOW | 70 | - +--------------+------------+ - | VERY_LOW | 90 | - +--------------+------------+ - | LOWEST | 100 | - +--------------+------------+ - """ - - HIGHEST = 0 - VERY_HIGH = 10 - HIGH = 30 - ABOVE_NORMAL = 40 - NORMAL = 50 - BELOW_NORMAL = 60 - LOW = 70 - VERY_LOW = 90 - LOWEST = 100 - - -def get_priority(priority: int | str | Priority) -> int: - """Get priority value. - - Args: - priority (int or str or :obj:`Priority`): Priority. - - Returns: - int: The priority value. 
- """ - if isinstance(priority, int): - if priority < 0 or priority > 100: - raise ValueError("priority must be between 0 and 100") - return priority - elif isinstance(priority, Priority): - return priority.value - elif isinstance(priority, str): - return Priority[priority.upper()].value - else: - raise TypeError("priority must be an integer or Priority enum value") diff --git a/libs/visengine/visengine/runner/runner.py b/libs/visengine/visengine/runner/runner.py deleted file mode 100644 index 9637249..0000000 --- a/libs/visengine/visengine/runner/runner.py +++ /dev/null @@ -1,2388 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import logging -import os -import os.path as osp -import pickle -import platform -import time -import warnings -from collections import OrderedDict -from collections.abc import Callable, Sequence -from functools import partial -from typing import Union - -import torch -import torch.nn as nn -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.optim import Optimizer -from torch.utils.data import DataLoader - -import visengine -from visengine.version import __version__ -from visengine.utils.version_utils import get_git_hash -from visengine.config import Config, ConfigDict -from visengine.dataset import worker_init_fn as default_worker_init_fn -from visengine.device import get_device -from visengine.dist import ( - broadcast, - get_dist_info, - get_rank, - get_world_size, - init_dist, - is_distributed, - master_only, -) -from visengine.evaluator import Evaluator -from visengine.fileio import FileClient, join_path -from visengine.hooks import Hook -from visengine.logging import MessageHub, MMLogger, print_log -from visengine.model import ( - MMDistributedDataParallel, - convert_sync_batchnorm, - is_model_wrapper, - revert_sync_batchnorm, -) -from visengine.model.efficient_conv_bn_eval import turn_on_efficient_conv_bn_eval -from visengine.optim import ( - OptimWrapper, - OptimWrapperDict, - _ParamScheduler, - build_optim_wrapper, -) -from visengine.registry import ( - DATA_SAMPLERS, - DATASETS, - EVALUATOR, - FUNCTIONS, - HOOKS, - LOG_PROCESSORS, - LOOPS, - MODEL_WRAPPERS, - MODELS, - OPTIM_WRAPPERS, - PARAM_SCHEDULERS, - RUNNERS, - VISUALIZERS, - DefaultScope, -) -from visengine.utils import apply_to, get_git_hash, is_seq_of -from visengine.utils.dl_utils import collect_env, set_multi_processing -from visengine.visualization import Visualizer - -from .activation_checkpointing import turn_on_activation_checkpointing -from .base_loop import BaseLoop -from .checkpoint import ( - _load_checkpoint, - _load_checkpoint_to_model, - find_latest_checkpoint, - save_checkpoint, - weights_to_cpu, -) -from .log_processor import LogProcessor -from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop -from .priority import Priority, get_priority -from .utils import _get_batch_size, set_random_seed - -ConfigType = Union[dict, Config, ConfigDict] -ParamSchedulerType = Union[list[_ParamScheduler], dict[str, list[_ParamScheduler]]] -OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] - - -class _SlicedDataset: - def __init__(self, dataset, length) -> None: - self._dataset = dataset - self._length = length - - def __getattr__(self, name): - return getattr(self._dataset, name) - - def __getitem__(self, idx): - return self._dataset[idx] - - def __len__(self): - return self._length - - -@RUNNERS.register_module(force=True) -class Runner: - """A training helper for PyTorch. 
- - Runner object can be built from config by ``runner = Runner.from_cfg(cfg)`` - where the ``cfg`` usually contains training, validation, and test-related - configurations to build corresponding components. We usually use the - same config to launch training, testing, and validation tasks. However, - only some of these components are necessary at the same time, e.g., - testing a model does not need training or validation-related components. - - To avoid repeatedly modifying config, the construction of ``Runner`` adopts - lazy initialization to only initialize components when they are going to be - used. Therefore, the model is always initialized at the beginning, and - training, validation, and testing-related components are only initialized - when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``, - respectively. - - Args: - model (:obj:`torch.nn.Module` or dict): The model to be run. It can be - a dict used to build a model. - work_dir (str): The working directory to save checkpoints. The logs - will be saved in the subdirectory of `work_dir` named - :attr:`timestamp`. - train_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping training steps. Defaults to None. - See :meth:`build_dataloader` for more details. - val_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping validation steps. Defaults to None. - See :meth:`build_dataloader` for more details. - test_dataloader (Dataloader or dict, optional): A dataloader object or - a dict to build a dataloader. If ``None`` is given, it means - skipping test steps. Defaults to None. - See :meth:`build_dataloader` for more details. - train_cfg (dict, optional): A dict to build a training loop. If it does - not provide "type" key, it should contain "by_epoch" to decide - which type of training loop :class:`EpochBasedTrainLoop` or - :class:`IterBasedTrainLoop` should be used. If ``train_cfg`` is - specified, :attr:`train_dataloader` should also be specified. - Defaults to None. See :meth:`build_train_loop` for more details. - val_cfg (dict, optional): A dict to build a validation loop. If it does - not provide "type" key, :class:`ValLoop` will be used by default. - If ``val_cfg`` is specified, :attr:`val_dataloader` should also be - specified. If ``ValLoop`` is built with ``fp16=True``, - ``runner.val()`` will be performed under fp16 precision. - Defaults to None. See :meth:`build_val_loop` for more details. - test_cfg (dict, optional): A dict to build a test loop. If it does - not provide "type" key, :class:`TestLoop` will be used by default. - If ``test_cfg`` is specified, :attr:`test_dataloader` should also be - specified. If ``TestLoop`` is built with ``fp16=True``, - ``runner.test()`` will be performed under fp16 precision. - Defaults to None. See :meth:`build_test_loop` for more details. - auto_scale_lr (dict, optional): Config to scale the learning rate - automatically. It includes ``base_batch_size`` and ``enable``. - ``base_batch_size`` is the batch size that the optimizer lr is - based on. ``enable`` is the switch to turn on and off the feature. - optim_wrapper (OptimWrapper or dict, optional): - Computing gradient of model parameters. If specified, - :attr:`train_dataloader` should also be specified. If automatic - mixed precision or gradient accumulation - training is required, the type of ``optim_wrapper`` should be - AmpOptimizerWrapper.
See :meth:`build_optim_wrapper` for - examples. Defaults to None. - param_scheduler (_ParamScheduler or dict or list, optional): - Parameter scheduler for updating optimizer parameters. If - specified, :attr:`optimizer` should also be specified. - Defaults to None. - See :meth:`build_param_scheduler` for examples. - val_evaluator (Evaluator or dict or list, optional): An evaluator object - used for computing metrics for validation. It can be a dict or a - list of dicts to build an evaluator. If specified, - :attr:`val_dataloader` should also be specified. Defaults to None. - test_evaluator (Evaluator or dict or list, optional): An evaluator - object used for computing metrics for test steps. It can be a dict - or a list of dicts to build an evaluator. If specified, - :attr:`test_dataloader` should also be specified. Defaults to None. - default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks to - execute default actions like updating model parameters and saving - checkpoints. Default hooks are ``OptimizerHook``, - ``IterTimerHook``, ``LoggerHook``, ``ParamSchedulerHook`` and - ``CheckpointHook``. Defaults to None. - See :meth:`register_default_hooks` for more details. - custom_hooks (list[dict] or list[Hook], optional): Hooks to execute - custom actions like visualizing images processed by pipeline. - Defaults to None. - data_preprocessor (dict, optional): The pre-process config of - :class:`BaseDataPreprocessor`. If the ``model`` argument is a dict - and doesn't contain the key ``data_preprocessor``, set the argument - as the ``data_preprocessor`` of the ``model`` dict. - Defaults to None. - load_from (str, optional): The checkpoint file to load from. - Defaults to None. - resume (bool): Whether to resume training. Defaults to False. If - ``resume`` is True and ``load_from`` is None, the latest checkpoint - will automatically be found in ``work_dir``. If not found, resuming - does nothing. - launcher (str): Way to launch multiple processes. Supported launchers - are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is provided, - a non-distributed environment will be launched. - env_cfg (dict): A dict used for setting environment. Defaults to - dict(dist_cfg=dict(backend='nccl')). - log_processor (dict, optional): A processor to format logs. Defaults to - None. - log_level (int or str): The log level of MMLogger handlers. - Defaults to 'INFO'. - visualizer (Visualizer or dict, optional): A Visualizer object or a - dict to build a Visualizer object. Defaults to None. If not - specified, default config will be used. - default_scope (str): Used to reset registries location. - Defaults to "visengine". - randomness (dict): Some settings to make the experiment as reproducible - as possible, such as seed and deterministic. - Defaults to ``dict(seed=None)``. If seed is None, a random number - will be generated and it will be broadcasted to all other processes - if in distributed environment. If ``cudnn_benchmark`` is - ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in - ``randomness``, the value of ``torch.backends.cudnn.benchmark`` - will be ``False`` finally. - experiment_name (str, optional): Name of current experiment. If not - specified, timestamp will be used as ``experiment_name``. - Defaults to None. - cfg (dict or ConfigDict or :obj:`Config`, optional): Full config. - Defaults to None. - - Note: - Since PyTorch 2.0.0, you can enable ``torch.compile`` by passing in - ``cfg.compile = True``. If you want to control compile options, you - can pass a dict, e.g. ``cfg.compile = dict(backend='eager')``.
- Refer to `PyTorch API Documentation `_ for more valid - options. - - Examples: - >>> from visengine.runner import Runner - >>> cfg = dict( - >>> model=dict(type='ToyModel'), - >>> work_dir='path/of/work_dir', - >>> train_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=True), - >>> batch_size=1, - >>> num_workers=0), - >>> val_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> test_dataloader=dict( - >>> dataset=dict(type='ToyDataset'), - >>> sampler=dict(type='DefaultSampler', shuffle=False), - >>> batch_size=1, - >>> num_workers=0), - >>> auto_scale_lr=dict(base_batch_size=16, enable=False), - >>> optim_wrapper=dict(type='OptimizerWrapper', optimizer=dict( - >>> type='SGD', lr=0.01)), - >>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - >>> val_evaluator=dict(type='ToyEvaluator'), - >>> test_evaluator=dict(type='ToyEvaluator'), - >>> train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1), - >>> val_cfg=dict(), - >>> test_cfg=dict(), - >>> custom_hooks=[], - >>> default_hooks=dict( - >>> timer=dict(type='IterTimerHook'), - >>> checkpoint=dict(type='CheckpointHook', interval=1), - >>> logger=dict(type='LoggerHook'), - >>> optimizer=dict(type='OptimizerHook', grad_clip=False), - >>> param_scheduler=dict(type='ParamSchedulerHook')), - >>> launcher='none', - >>> env_cfg=dict(dist_cfg=dict(backend='nccl')), - >>> log_processor=dict(window_size=20), - >>> visualizer=dict(type='Visualizer', - >>> vis_backends=[dict(type='LocalVisBackend', - >>> save_dir='temp_dir')]) - >>> ) - >>> runner = Runner.from_cfg(cfg) - >>> runner.train() - >>> runner.test() - """ - - cfg: Config - _train_loop: BaseLoop | dict | None - _val_loop: BaseLoop | dict | None - _test_loop: BaseLoop | dict | None - readable_config: dict | None - - def __init__( - self, - model: nn.Module | dict, - work_dir: str, - train_dataloader: DataLoader | dict | None = None, - val_dataloader: DataLoader | dict | None = None, - test_dataloader: DataLoader | dict | None = None, - train_cfg: dict | None = None, - val_cfg: dict | None = None, - test_cfg: dict | None = None, - auto_scale_lr: dict | None = None, - optim_wrapper: OptimWrapper | dict | None = None, - param_scheduler: _ParamScheduler | dict | list | None = None, - val_evaluator: Evaluator | dict | list | None = None, - test_evaluator: Evaluator | dict | list | None = None, - default_hooks: dict[str, Hook | dict] | None = None, - custom_hooks: list[Hook | dict] | None = None, - data_preprocessor: nn.Module | dict | None = None, - load_from: str | None = None, - resume: bool = False, - launcher: str = "none", - env_cfg: dict | None = None, - log_processor: dict | None = None, - log_level: str = "INFO", - visualizer: Visualizer | dict | None = None, - default_scope: str = "visengine", - randomness: dict | None = None, - experiment_name: str | None = None, - cfg: ConfigType | None = None, - ): - if randomness is None: - randomness = {"seed": None} - if env_cfg is None: - env_cfg = {"dist_cfg": {"backend": "nccl"}} - self._work_dir = osp.abspath(work_dir) - from visengine.utils import mkdir_or_exist - - mkdir_or_exist(self._work_dir) - - # recursively copy the `cfg` because `self.cfg` will be modified - # everywhere. 
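- # Sketch of the intended behavior (names assumed): - # cfg = dict(model=dict(type="ToyModel"), work_dir="wd") - # runner = Runner.from_cfg(cfg) # runner.cfg is an independent deep copy, - # so later in-place edits to runner.cfg never leak back into the caller's cfg.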
- if cfg is not None: - if isinstance(cfg, Config): - self.cfg = copy.deepcopy(cfg) - elif isinstance(cfg, dict): - self.cfg = Config(cfg) - else: - self.cfg = Config({}) - - # lazy initialization - training_related = [train_dataloader, train_cfg, optim_wrapper] - if not (all(item is None for item in training_related) or all(item is not None for item in training_related)): - raise ValueError( - "train_dataloader, train_cfg, and optim_wrapper should be " - "either all None or not None, but got " - f"train_dataloader={train_dataloader}, " - f"train_cfg={train_cfg}, " - f"optim_wrapper={optim_wrapper}." - ) - self._train_dataloader = train_dataloader - self._train_loop = train_cfg - - self.optim_wrapper: OptimWrapper | dict | None - self.optim_wrapper = optim_wrapper - - self.auto_scale_lr = auto_scale_lr - - # If there is no need to adjust learning rate, momentum or other - # parameters of optimizer, param_scheduler can be None - if param_scheduler is not None and self.optim_wrapper is None: - raise ValueError(f"param_scheduler should be None when optim_wrapper is None, but got {param_scheduler}") - - # Parse `param_scheduler` to a list or a dict. If `optim_wrapper` is a - # `dict` with single optimizer, parsed param_scheduler will be a - # list of parameter schedulers. If `optim_wrapper` is - # a `dict` with multiple optimizers, parsed `param_scheduler` will be - # dict with multiple list of parameter schedulers. - self._check_scheduler_cfg(param_scheduler) - self.param_schedulers = param_scheduler - - val_related = [val_dataloader, val_cfg, val_evaluator] - if not (all(item is None for item in val_related) or all(item is not None for item in val_related)): - raise ValueError( - "val_dataloader, val_cfg, and val_evaluator should be either " - "all None or not None, but got " - f"val_dataloader={val_dataloader}, val_cfg={val_cfg}, " - f"val_evaluator={val_evaluator}" - ) - self._val_dataloader = val_dataloader - self._val_loop = val_cfg - self._val_evaluator = val_evaluator - - test_related = [test_dataloader, test_cfg, test_evaluator] - if not (all(item is None for item in test_related) or all(item is not None for item in test_related)): - raise ValueError( - "test_dataloader, test_cfg, and test_evaluator should be " - "either all None or not None, but got " - f"test_dataloader={test_dataloader}, test_cfg={test_cfg}, " - f"test_evaluator={test_evaluator}" - ) - self._test_dataloader = test_dataloader - self._test_loop = test_cfg - self._test_evaluator = test_evaluator - - self._launcher = launcher - if self._launcher == "none": - self._distributed = False - else: - self._distributed = True - - # self._timestamp will be set in the `setup_env` method. Besides, - # it also will initialize multi-process and (or) distributed - # environment. - self.setup_env(env_cfg) - # self._deterministic and self._seed will be set in the - # `set_randomness`` method - self._randomness_cfg = randomness - self.set_randomness(**randomness) - - if experiment_name is not None: - self._experiment_name = f"{experiment_name}_{self._timestamp}" - elif self.cfg.filename is not None: - filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0] - self._experiment_name = f"{filename_no_ext}_{self._timestamp}" - else: - self._experiment_name = self.timestamp - self._log_dir = osp.join(self.work_dir, self.timestamp) - mkdir_or_exist(self._log_dir) - # Used to reset registries location. See :meth:`Registry.build` for - # more details. 
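- # Illustrative effect: with default_scope="visengine", a config entry such as - # dict(type="EpochBasedTrainLoop") resolves against visengine's registries - # (e.g. LOOPS) without an explicit scope prefix.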
- if default_scope is not None: - default_scope = DefaultScope.get_instance( # type: ignore - self._experiment_name, scope_name=default_scope - ) - self.default_scope = default_scope - - # Build log processor to format message. - log_processor = {} if log_processor is None else log_processor - self.log_processor = self.build_log_processor(log_processor) - # Since `get_instance` could return any subclass of ManagerMixin. The - # corresponding attribute needs a type hint. - self.logger = self.build_logger(log_level=log_level) - - # Collect and log environment information. - self._log_env(env_cfg) - - # Build `message_hub` for communication among components. - # `message_hub` can store log scalars (loss, learning rate) and - # runtime information (iter and epoch). Those components that do not - # have access to the runner can get iteration or epoch information - # from `message_hub`. For example, models can get the latest created - # `message_hub` by - # `self.message_hub=MessageHub.get_current_instance()` and then get - # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`. - # See `MessageHub` and `ManagerMixin` for more details. - self.message_hub = self.build_message_hub() - # visualizer used for writing log or visualizing all kinds of data - self.visualizer = self.build_visualizer(visualizer) - if self.cfg: - self.visualizer.add_config(self.cfg) - - self._load_from = load_from - self._resume = resume - # flag to mark whether checkpoint has been loaded or resumed - self._has_loaded = False - - # build a model - if isinstance(model, dict) and data_preprocessor is not None: - # Merge the data_preprocessor to model config. - model.setdefault("data_preprocessor", data_preprocessor) - self.model = self.build_model(model) - # wrap model - self.model = self.wrap_model(self.cfg.get("model_wrapper_cfg"), self.model) - - # get model name from the model class - if hasattr(self.model, "module"): - self._model_name = self.model.module.__class__.__name__ - else: - self._model_name = self.model.__class__.__name__ - - self._hooks: list[Hook] = [] - # register hooks to `self._hooks` - self.register_hooks(default_hooks, custom_hooks) - # log hooks information - self.logger.info(f"Hooks will be executed in the following order:\n{self.get_hooks_info()}") - - # dump `cfg` to `work_dir` - self.dump_config() - - @classmethod - def from_cfg(cls, cfg: ConfigType) -> "Runner": - """Build a runner from config. - - Args: - cfg (ConfigType): A config used for building runner. Keys of - ``cfg`` can see :meth:`__init__`. - - Returns: - Runner: A runner build from ``cfg``. 
- """ - cfg = copy.deepcopy(cfg) - runner = cls( - model=cfg["model"], - work_dir=cfg["work_dir"], - train_dataloader=cfg.get("train_dataloader"), - val_dataloader=cfg.get("val_dataloader"), - test_dataloader=cfg.get("test_dataloader"), - train_cfg=cfg.get("train_cfg"), - val_cfg=cfg.get("val_cfg"), - test_cfg=cfg.get("test_cfg"), - auto_scale_lr=cfg.get("auto_scale_lr"), - optim_wrapper=cfg.get("optim_wrapper"), - param_scheduler=cfg.get("param_scheduler"), - val_evaluator=cfg.get("val_evaluator"), - test_evaluator=cfg.get("test_evaluator"), - default_hooks=cfg.get("default_hooks"), - custom_hooks=cfg.get("custom_hooks"), - data_preprocessor=cfg.get("data_preprocessor"), - load_from=cfg.get("load_from"), - resume=cfg.get("resume", False), - launcher=cfg.get("launcher", "none"), - env_cfg=cfg.get("env_cfg", {"dist_cfg": {"backend": "nccl"}}), - log_processor=cfg.get("log_processor"), - log_level=cfg.get("log_level", "INFO"), - visualizer=cfg.get("visualizer"), - default_scope=cfg.get("default_scope", "visengine"), - randomness=cfg.get("randomness", {"seed": None}), - experiment_name=cfg.get("experiment_name"), - cfg=cfg, - ) - - return runner - - @property - def experiment_name(self): - """str: Name of experiment.""" - return self._experiment_name - - @property - def model_name(self): - """str: Name of the model, usually the module class name.""" - return self._model_name - - @property - def work_dir(self): - """str: The working directory to save checkpoints and logs.""" - return self._work_dir - - @property - def log_dir(self): - return self._log_dir - - @property - def max_epochs(self): - """int: Total epochs to train model.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.max_epochs - else: - return 0 - - @property - def max_iters(self): - """int: Total iterations to train model.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.max_iters - else: - return 0 - - @property - def epoch(self): - """int: Current epoch.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.epoch - else: - return 0 - - @property - def iter(self): - """int: Current iteration.""" - if isinstance(self.train_loop, BaseLoop): - return self.train_loop.iter - else: - return 0 - - @property - def launcher(self): - """str: Way to launcher multi processes.""" - return self._launcher - - @property - def distributed(self): - """bool: Whether current environment is distributed.""" - return self._distributed - - @property - def rank(self): - """int: Rank of current process.""" - return self._rank - - @property - def world_size(self): - """int: Number of processes participating in the job.""" - return self._world_size - - @property - def deterministic(self): - """int: Whether cudnn to select deterministic algorithms.""" - return self._deterministic - - @property - def seed(self): - """int: A number to set random modules.""" - return self._seed - - @property - def timestamp(self): - """str: Timestamp when creating experiment.""" - return self._timestamp - - @property - def hooks(self): - """List[:obj:`Hook`]: A list of registered hooks.""" - return self._hooks - - @property - def train_loop(self): - """:obj:`BaseLoop`: A loop to run training.""" - if isinstance(self._train_loop, BaseLoop) or self._train_loop is None: - return self._train_loop - else: - self._train_loop = self.build_train_loop(self._train_loop) - return self._train_loop - - @property - def val_loop(self): - """:obj:`BaseLoop`: A loop to run validation.""" - if isinstance(self._val_loop, BaseLoop) 
or self._val_loop is None: - return self._val_loop - else: - self._val_loop = self.build_val_loop(self._val_loop) - return self._val_loop - - @property - def test_loop(self): - """:obj:`BaseLoop`: A loop to run testing.""" - if isinstance(self._test_loop, BaseLoop) or self._test_loop is None: - return self._test_loop - else: - self._test_loop = self.build_test_loop(self._test_loop) - return self._test_loop - - @property - def train_dataloader(self): - """The data loader for training.""" - return self.train_loop.dataloader - - @property - def val_dataloader(self): - """The data loader for validation.""" - return self.val_loop.dataloader - - @property - def test_dataloader(self): - """The data loader for testing.""" - return self.test_loop.dataloader - - @property - def val_evaluator(self): - """:obj:`Evaluator`: An evaluator for validation.""" - return self.val_loop.evaluator - - @property - def test_evaluator(self): - """:obj:`Evaluator`: An evaluator for testing.""" - return self.test_loop.evaluator - - @property - def val_interval(self): - """int: Interval to run validation during training.""" - return self.train_loop.val_interval - - @property - def val_begin(self): - """int: The epoch/iteration to start running validation during - training.""" - return self.train_loop.val_begin - - def setup_env(self, env_cfg: dict) -> None: - """Setup environment. - - An example of ``env_cfg``:: - - env_cfg = dict( - cudnn_benchmark=True, - mp_cfg=dict( - mp_start_method='fork', - opencv_num_threads=0 - ), - dist_cfg=dict(backend='nccl', timeout=1800), - resource_limit=4096 - ) - - Args: - env_cfg (dict): Config for setting environment. - """ - if env_cfg.get("cudnn_benchmark"): - torch.backends.cudnn.benchmark = True - - mp_cfg: dict = env_cfg.get("mp_cfg", {}) - set_multi_processing(**mp_cfg, distributed=self.distributed) - - # init distributed env first, since logger depends on the dist info. - if self.distributed and not is_distributed(): - dist_cfg: dict = env_cfg.get("dist_cfg", {}) - init_dist(self.launcher, **dist_cfg) - - self._rank, self._world_size = get_dist_info() - - timestamp = torch.tensor(time.time(), dtype=torch.float64) - # broadcast timestamp from 0 process to other processes - broadcast(timestamp) - self._timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime(timestamp.item())) - - # https://github.com/pytorch/pytorch/issues/973 - # set resource limit - if platform.system() != "Windows": - import resource - - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - base_soft_limit = rlimit[0] - hard_limit = rlimit[1] - soft_limit = min(max(env_cfg.get("resource_limit", 4096), base_soft_limit), hard_limit) - resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) - - def set_randomness(self, seed, diff_rank_seed: bool = False, deterministic: bool = False) -> None: - """Set random seed to guarantee reproducible results. - - Args: - seed (int): A number to set random modules. - diff_rank_seed (bool): Whether or not set different seeds according - to global rank. Defaults to False. - deterministic (bool): Whether to set the deterministic option for - CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` - to True and `torch.backends.cudnn.benchmark` to False. - Defaults to False. - See https://pytorch.org/docs/stable/notes/randomness.html for - more details. 
- """ - self._deterministic = deterministic - self._seed = set_random_seed(seed=seed, deterministic=deterministic, diff_rank_seed=diff_rank_seed) - - def build_logger(self, log_level: int | str = "INFO", log_file: str | None = None, **kwargs) -> MMLogger: - """Build a global asscessable MMLogger. - - Args: - log_level (int or str): The log level of MMLogger handlers. - Defaults to 'INFO'. - log_file (str, optional): Path of filename to save log. - Defaults to None. - **kwargs: Remaining parameters passed to ``MMLogger``. - - Returns: - MMLogger: A MMLogger object build from ``logger``. - """ - if log_file is None: - log_file = osp.join(self._log_dir, f"{self.timestamp}.log") - - log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs) - log_cfg.setdefault("name", self._experiment_name) - # `torch.compile` in PyTorch 2.0 could close all user defined handlers - # unexpectedly. Using file mode 'a' can help prevent abnormal - # termination of the FileHandler and ensure that the log file could - # be continuously updated during the lifespan of the runner. - log_cfg.setdefault("file_mode", "a") - - return MMLogger.get_instance(**log_cfg) # type: ignore - - def build_message_hub(self, message_hub: dict | None = None) -> MessageHub: - """Build a global asscessable MessageHub. - - Args: - message_hub (dict, optional): A dict to build MessageHub object. - If not specified, default config will be used to build - MessageHub object. Defaults to None. - - Returns: - MessageHub: A MessageHub object build from ``message_hub``. - """ - if message_hub is None: - message_hub = {"name": self._experiment_name} - elif isinstance(message_hub, dict): - # ensure message_hub containing name key - message_hub.setdefault("name", self._experiment_name) - else: - raise TypeError(f"message_hub should be dict or None, but got {message_hub}") - - return MessageHub.get_instance(**message_hub) - - def build_visualizer(self, visualizer: Visualizer | dict | None = None) -> Visualizer: - """Build a global asscessable Visualizer. - - Args: - visualizer (Visualizer or dict, optional): A Visualizer object - or a dict to build Visualizer object. If ``visualizer`` is a - Visualizer object, just returns itself. If not specified, - default config will be used to build Visualizer object. - Defaults to None. - - Returns: - Visualizer: A Visualizer object build from ``visualizer``. - """ - if visualizer is None: - visualizer = { - "name": self._experiment_name, - "vis_backends": [{"type": "LocalVisBackend"}], - "save_dir": self._log_dir, - } - return Visualizer.get_instance(**visualizer) - - if isinstance(visualizer, Visualizer): - return visualizer - - if isinstance(visualizer, dict): - # ensure visualizer containing name key - visualizer.setdefault("name", self._experiment_name) - visualizer.setdefault("save_dir", self._log_dir) - return VISUALIZERS.build(visualizer) - else: - raise TypeError(f"visualizer should be Visualizer object, a dict or None, but got {visualizer}") - - def build_model(self, model: nn.Module | dict) -> nn.Module: - """Build model. - - If ``model`` is a dict, it will be used to build a nn.Module object. - Else, if ``model`` is a nn.Module object it will be returned directly. - - An example of ``model``:: - - model = dict(type='ResNet') - - Args: - model (nn.Module or dict): A ``nn.Module`` object or a dict to - build nn.Module object. If ``model`` is a nn.Module object, - just returns itself. 
- - Note: - The returned model must implement ``train_step``, ``test_step`` - if ``runner.train`` or ``runner.test`` will be called. If - ``runner.val`` will be called or ``val_cfg`` is configured, - model must implement `val_step`. - - Returns: - nn.Module: Model build from ``model``. - """ - if isinstance(model, nn.Module): - return model - elif isinstance(model, dict): - model = MODELS.build(model) - return model # type: ignore - else: - raise TypeError(f"model should be a nn.Module object or dict, but got {model}") - - def wrap_model(self, model_wrapper_cfg: dict | None, model: nn.Module) -> DistributedDataParallel | nn.Module: - """Wrap the model to :obj:`MMDistributedDataParallel` or other custom - distributed data-parallel module wrappers. - - An example of ``model_wrapper_cfg``:: - - model_wrapper_cfg = dict( - broadcast_buffers=False, - find_unused_parameters=False - ) - - Args: - model_wrapper_cfg (dict, optional): Config to wrap model. If not - specified, ``DistributedDataParallel`` will be used in - distributed environment. Defaults to None. - model (nn.Module): Model to be wrapped. - - Returns: - nn.Module or DistributedDataParallel: nn.Module or subclass of - ``DistributedDataParallel``. - """ - if is_model_wrapper(model): - if model_wrapper_cfg is not None: - raise TypeError( - f'model has been wrapped and "model_wrapper_cfg" should be None, but got {model_wrapper_cfg}' - ) - - return model - - # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training. - model = model.to(get_device()) - - if not self.distributed: - self.logger.info( - "Distributed training is not used, all SyncBatchNorm (SyncBN) " - "layers in the model will be automatically reverted to " - "BatchNormXd layers if they are used." - ) - model = revert_sync_batchnorm(model) - return model # type: ignore - else: - sync_bn = self.cfg.get("sync_bn", None) - if sync_bn is not None: - try: - model = convert_sync_batchnorm(model, sync_bn) - except ValueError as e: - self.logger.error(f'cfg.sync_bn should be "torch" or "mmcv", but got {sync_bn}') - raise e - if model_wrapper_cfg is None: - find_unused_parameters = self.cfg.get("find_unused_parameters", False) - # Sets the `find_unused_parameters` parameter in - # torch.nn.parallel.DistributedDataParallel - # TODO: may use a more elegant way to get local device ID. 
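- # e.g. under a 4-GPU torchrun launch (illustrative), LOCAL_RANK is "0".."3" - # in each worker process, so every DDP replica is pinned to its own device.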
- model = MMDistributedDataParallel( - module=model, - device_ids=[int(os.environ["LOCAL_RANK"])], - broadcast_buffers=False, - find_unused_parameters=find_unused_parameters, - ) - else: - model_wrapper_cfg.setdefault("type", "MMDistributedDataParallel") - model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_cfg.get("type")) # type: ignore - default_args: dict = {} - if issubclass(model_wrapper_type, DistributedDataParallel): # type: ignore - default_args["device_ids"] = [int(os.environ["LOCAL_RANK"])] - default_args["module"] = model - model = MODEL_WRAPPERS.build(model_wrapper_cfg, default_args=default_args) - return model - - def _init_model_weights(self) -> None: - """Initialize the model weights if the model has - :meth:`init_weights`""" - model = self.model.module if is_model_wrapper(self.model) else self.model - if hasattr(model, "init_weights"): - model.init_weights() - # sync params and buffers - for _name, params in model.state_dict().items(): - broadcast(params) - - def scale_lr(self, optim_wrapper: OptimWrapper, auto_scale_lr: dict | None = None) -> None: - """Automatically scaling learning rate in training according to the - ratio of ``base_batch_size`` in ``autoscalelr_cfg`` and real batch - size. - - It scales the learning rate linearly according to the - `paper `_. - - Note: - ``scale_lr`` must be called after building optimizer wrappers - and before building parameter schedulers. - - Args: - optim_wrapper (OptimWrapper): An OptimWrapper object whose - parameter groups' learning rate need to be scaled. - auto_scale_lr (Dict, Optional): Config to scale the learning - rate automatically. It includes ``base_batch_size`` and - ``enable``. ``base_batch_size`` is the batch size that the - optimizer lr is based on. ``enable`` is the switch to turn on - and off the feature. - """ - if auto_scale_lr is None or not auto_scale_lr.get("enable", False): - return None - - assert "base_batch_size" in auto_scale_lr, "Lack of `base_batch_size` in `auto_scale_lr`." - dataloader: DataLoader | dict = self._train_dataloader - bs = dataloader.batch_size if isinstance(dataloader, DataLoader) else dataloader["batch_size"] - real_bs = self.world_size * bs - base_bs = auto_scale_lr["base_batch_size"] - ratio = float(real_bs) / float(base_bs) - self.logger.info( - f"LR is set based on batch size of {base_bs} and the current batch size is {real_bs}. Scaling the original LR by {ratio}." - ) - - def _is_built(schedulers): - if isinstance(schedulers, dict): - return False if "type" in schedulers else any(_is_built(s) for s in schedulers.values()) - if isinstance(schedulers, list): - return any(_is_built(s) for s in schedulers) - return isinstance(schedulers, _ParamScheduler) - - if _is_built(self.param_schedulers): - raise RuntimeError( - "`scale_lr` should be called before building ParamScheduler because ParamScheduler will store initial lr from optimizer wrappers" - ) - - assert isinstance(optim_wrapper, OptimWrapper), "`scale_lr should be called after building OptimWrapper" - wrappers = list(optim_wrapper.values()) if isinstance(optim_wrapper, OptimWrapperDict) else [optim_wrapper] - for wrapper in wrappers: - for group in wrapper.optimizer.param_groups: - group["lr"] = group["lr"] * ratio - - def build_optim_wrapper(self, optim_wrapper: Optimizer | OptimWrapper | dict) -> OptimWrapper | OptimWrapperDict: - """Build optimizer wrapper. - - If ``optim_wrapper`` is a config dict for only one optimizer, - the keys must contain ``optimizer``, and ``type`` is optional. 
- It will build a :obj:`OptimWrapper` by default. - - If ``optim_wrapper`` is a config dict for multiple optimizers, i.e., - it has multiple keys and each key is for an optimizer wrapper. The - constructor must be specified since - :obj:`DefaultOptimizerConstructor` cannot handle the building of - training with multiple optimizers. - - If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e., - each value of ``optim_wrapper`` represents an ``OptimWrapper`` - instance. ``build_optim_wrapper`` will directly build the - :obj:`OptimWrapperDict` instance from ``optim_wrapper``. - - Args: - optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a - dict to build OptimWrapper objects. If ``optim_wrapper`` is an - OptimWrapper, just return an ``OptimizeWrapper`` instance. - - Note: - For single optimizer training, if `optim_wrapper` is a config - dict, `type` is optional(defaults to :obj:`OptimWrapper`) and it - must contain `optimizer` to build the corresponding optimizer. - - Examples: - >>> # build an optimizer - >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)) - >>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> # is also valid. - >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build optimizer without `type` - >>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) - >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.01 - maximize: False - momentum: 0 - nesterov: False - weight_decay: 0 - ) - >>> # build multiple optimizers - >>> optim_wrapper_cfg = dict( - ... generator=dict(type='OptimWrapper', optimizer=dict( - ... type='SGD', lr=0.01)), - ... discriminator=dict(type='OptimWrapper', optimizer=dict( - ... type='Adam', lr=0.001)) - ... # need to customize a multiple optimizer constructor - ... constructor='CustomMultiOptimizerConstructor', - ...) - >>> optim_wrapper = runner.optim_wrapper(optim_wrapper_cfg) - >>> optim_wrapper - name: generator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - SGD ( - Parameter Group 0 - dampening: 0 - lr: 0.1 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - name: discriminator - Type: OptimWrapper - accumulative_counts: 1 - optimizer: - 'discriminator': Adam ( - Parameter Group 0 - dampening: 0 - lr: 0.02 - momentum: 0 - nesterov: False - weight_decay: 0 - ) - - Important: - If you need to build multiple optimizers, you should implement a - MultiOptimWrapperConstructor which gets parameters passed to - corresponding optimizers and compose the ``OptimWrapperDict``. - More details about how to customize OptimizerConstructor can be - found at `optimizer-docs`_. - - Returns: - OptimWrapper: Optimizer wrapper build from ``optimizer_cfg``. - """ - if isinstance(optim_wrapper, OptimWrapper): - return optim_wrapper - if isinstance(optim_wrapper, dict | ConfigDict | Config): - # optimizer must be defined for single optimizer training. - optimizer = optim_wrapper.get("optimizer", None) - - # If optimizer is a built `Optimizer` instance, the optimizer - # wrapper should be built by `OPTIM_WRAPPERS` registry. 
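As a sketch of the branch described in that comment, a config carrying an already-built torch `Optimizer` is dispatched to the `OPTIM_WRAPPERS` registry; the module below is a stand-in:

```python
import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(4, 2)  # stand-in module
# `optimizer` is a built Optimizer instance, so build_optim_wrapper
# defaults `type` to "OptimWrapper" and builds it from OPTIM_WRAPPERS.
optim_wrapper_cfg = dict(optimizer=SGD(model.parameters(), lr=0.01))
```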
- if isinstance(optimizer, Optimizer): - optim_wrapper.setdefault("type", "OptimWrapper") - return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore - - # If `optimizer` is not None or `constructor` is defined, it means, - # optimizer wrapper will be built by optimizer wrapper - # constructor. Therefore, `build_optim_wrapper` should be called. - if optimizer is not None or "constructor" in optim_wrapper: - return build_optim_wrapper(self.model, optim_wrapper) - else: - # if `optimizer` is not defined, it should be the case of - # training with multiple optimizers. If `constructor` is not - # defined either, each value of `optim_wrapper` must be an - # `OptimWrapper` instance since `DefaultOptimizerConstructor` - # will not handle the case of training with multiple - # optimizers. `build_optim_wrapper` will directly build the - # `OptimWrapperDict` instance from `optim_wrapper.` - optim_wrappers = OrderedDict() - for name, optim in optim_wrapper.items(): - if not isinstance(optim, OptimWrapper): - raise ValueError( - f'each item mush be an optimizer object when "type" and "constructor" are not in optimizer, but got {name}={optim}' - ) - optim_wrappers[name] = optim - return OptimWrapperDict(**optim_wrappers) - else: - raise TypeError(f"optimizer wrapper should be an OptimWrapper object or dict, but got {optim_wrapper}") - - def _build_param_scheduler( - self, scheduler: _ParamScheduler | dict | list, optim_wrapper: OptimWrapper - ) -> list[_ParamScheduler]: - """Build parameter schedulers for a single optimizer. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - optim_wrapper (OptimWrapper): An optimizer wrapper object is - passed to construct ParamScheduler object. - - Returns: - list[_ParamScheduler]: List of parameter schedulers build from - ``scheduler``. - - Note: - If the train loop is built, when building parameter schedulers, - it supports setting the max epochs/iters as the default ``end`` - of schedulers, and supports converting epoch-based schedulers - to iter-based according to the ``convert_to_iter_based`` key. - """ - if not isinstance(scheduler, Sequence): - schedulers = [scheduler] - else: - schedulers = scheduler - - param_schedulers = [] - for scheduler in schedulers: - if isinstance(scheduler, _ParamScheduler): - param_schedulers.append(scheduler) - elif isinstance(scheduler, dict): - _scheduler = copy.deepcopy(scheduler) - - # Set default end - if isinstance(self._train_loop, BaseLoop): - default_end = self.max_epochs if _scheduler.get("by_epoch", True) else self.max_iters - _scheduler.setdefault("end", default_end) - self.logger.debug( - f"The `end` of {_scheduler['type']} is not set. Use the max epochs/iters of train loop as default." - ) - - param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args={ - "optimizer": optim_wrapper, - "epoch_length": len(self.train_dataloader), - }, - ) - ) - else: - raise TypeError(f"scheduler should be a _ParamScheduler object or dict, but got {scheduler}") - return param_schedulers - - def build_param_scheduler(self, scheduler: _ParamScheduler | dict | list) -> ParamSchedulerType: - """Build parameter schedulers. - - ``build_param_scheduler`` should be called after - ``build_optim_wrapper`` because the building logic will change - according to the number of optimizers built by the runner. 
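For reference, a typical single-optimizer scheduler config that `_build_param_scheduler` consumes; the scheduler names follow common mmengine-style registry entries and are illustrative. When `end` is omitted it is filled from the train loop's max epochs/iters, per the note above:

```python
param_scheduler = [
    # linear warmup over the first 500 iterations
    dict(type="LinearLR", start_factor=0.001, by_epoch=False, begin=0, end=500),
    # step decay by epoch; `end` defaults to the train loop's max_epochs
    dict(type="MultiStepLR", by_epoch=True, milestones=[8, 11], gamma=0.1),
]
```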
- The cases are as below: - - - Single optimizer: When only one optimizer is built and used in the - runner, ``build_param_scheduler`` will return a list of - parameter schedulers. - - Multiple optimizers: When two or more optimizers are built and used - in runner, ``build_param_scheduler`` will return a dict containing - the same keys with multiple optimizers and each value is a list of - parameter schedulers. Note that, if you want different optimizers to - use different parameter schedulers to update optimizer's - hyper-parameters, the input parameter ``scheduler`` also needs to be - a dict and its key are consistent with multiple optimizers. - Otherwise, the same parameter schedulers will be used to update - optimizer's hyper-parameters. - - Args: - scheduler (_ParamScheduler or dict or list): A Param Scheduler - object or a dict or list of dict to build parameter schedulers. - - Examples: - >>> # build one scheduler - >>> optim_cfg = dict(dict(type='SGD', lr=0.01)) - >>> runner.optim_wrapper = runner.build_optim_wrapper( - >>> optim_cfg) - >>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2]) - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [] # noqa: E501 - - >>> # build multiple schedulers - >>> scheduler_cfg = [ - ... dict(type='MultiStepLR', milestones=[1, 2]), - ... dict(type='StepLR', step_size=1) - ... ] - >>> schedulers = runner.build_param_scheduler(scheduler_cfg) - >>> schedulers - [, # noqa: E501 - ] - - Above examples only provide the case of one optimizer and one scheduler - or multiple schedulers. If you want to know how to set parameter - scheduler when using multiple optimizers, you can find more examples - `optimizer-docs`_. - - Returns: - list[_ParamScheduler] or dict[str, list[_ParamScheduler]]: List of - parameter schedulers or a dictionary contains list of parameter - schedulers build from ``scheduler``. - - .. _optimizer-docs: - https://visengine.readthedocs.io/en/latest/tutorials/optim_wrapper.html - """ - param_schedulers: ParamSchedulerType - if not isinstance(self.optim_wrapper, OptimWrapperDict): - # Since `OptimWrapperDict` inherits from `OptimWrapper`, - # `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell - # whether `self.optim_wrapper` is an `OptimizerWrapper` or - # `OptimWrapperDict` instance. Therefore, here we simply check - # self.optim_wrapper is not an `OptimWrapperDict` instance and - # then assert it is an OptimWrapper instance. - assert isinstance(self.optim_wrapper, OptimWrapper), ( - "`build_optimizer` should be called before`build_param_scheduler` because the latter depends on the former" - ) - param_schedulers = self._build_param_scheduler(scheduler, self.optim_wrapper) # type: ignore - return param_schedulers - else: - param_schedulers = {} - for name, optimizer in self.optim_wrapper.items(): - if isinstance(scheduler, dict) and "type" not in scheduler: - # scheduler is a dict and each item is a ParamScheduler - # object or a config to build ParamScheduler objects - param_schedulers[name] = self._build_param_scheduler(scheduler[name], optimizer) - else: - param_schedulers[name] = self._build_param_scheduler(scheduler, optimizer) - - return param_schedulers - - def build_evaluator(self, evaluator: dict | list | Evaluator) -> Evaluator: - """Build evaluator. 
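A sketch of the multiple-optimizer case just described; the keys are illustrative and must mirror the optimizer wrapper names in the `OptimWrapperDict`:

```python
# Each value becomes that optimizer's own list of schedulers.
param_scheduler = dict(
    generator=[dict(type="MultiStepLR", milestones=[8, 11])],
    discriminator=[dict(type="StepLR", step_size=1)],
)
# build_param_scheduler(param_scheduler) then returns
# {"generator": [...], "discriminator": [...]}
```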
-
- Examples of ``evaluator``::
-
- # evaluator could be a built Evaluator instance
- evaluator = Evaluator(metrics=[ToyMetric()])
-
- # evaluator can also be a list of dict
- evaluator = [
- dict(type='ToyMetric1'),
- dict(type='ToyEvaluator2')
- ]
-
- # evaluator can also be a list of built metrics
- evaluator = [ToyMetric1(), ToyMetric2()]
-
- # evaluator can also be a dict with the key ``metrics``
- evaluator = dict(metrics=ToyMetric())
- # metrics can also be a list
- evaluator = dict(metrics=[ToyMetric()])
-
- Args:
- evaluator (Evaluator or dict or list): An Evaluator object or a
- config dict or list of config dicts used to build an Evaluator.
-
- Returns:
- Evaluator: Evaluator built from ``evaluator``.
- """
- if isinstance(evaluator, Evaluator):
- return evaluator
- elif isinstance(evaluator, dict):
- # if `metrics` is in the dict keys, build a customized evaluator
- if "metrics" in evaluator:
- evaluator.setdefault("type", "Evaluator")
- return EVALUATOR.build(evaluator)
- # otherwise, the default evaluator will be built
- else:
- return Evaluator(evaluator) # type: ignore
- elif isinstance(evaluator, list):
- # use the default `Evaluator`
- return Evaluator(evaluator) # type: ignore
- else:
- raise TypeError(f"evaluator should be a dict, a list of dicts, or an Evaluator, but got {evaluator}")
-
- @staticmethod
- def build_dataloader(
- dataloader: DataLoader | dict,
- seed: int | None = None,
- diff_rank_seed: bool = False,
- ) -> DataLoader:
- """Build dataloader.
-
- The method builds three components:
-
- - Dataset
- - Sampler
- - Dataloader
-
- An example of ``dataloader``::
-
- dataloader = dict(
- dataset=dict(type='ToyDataset'),
- sampler=dict(type='DefaultSampler', shuffle=True),
- batch_size=1,
- num_workers=9
- )
-
- Args:
- dataloader (DataLoader or dict): A Dataloader object or a dict to
- build a Dataloader object. If ``dataloader`` is a Dataloader
- object, just returns itself.
- seed (int, optional): Random seed. Defaults to None.
- diff_rank_seed (bool): Whether or not to set different seeds on
- different ranks. If True, the seed passed to the sampler is set
- to None, in order to synchronize the seeds used by samplers
- across different ranks.
-
- Returns:
- Dataloader: DataLoader built from ``dataloader_cfg``.
- """ - if isinstance(dataloader, DataLoader): - return dataloader - - dataloader_cfg = copy.deepcopy(dataloader) - - # build dataset - dataset_cfg = dataloader_cfg.pop("dataset") - if isinstance(dataset_cfg, dict): - dataset = DATASETS.build(dataset_cfg) - if hasattr(dataset, "full_init"): - dataset.full_init() - else: - # fallback to raise error in dataloader - # if `dataset_cfg` is not a valid type - dataset = dataset_cfg - - num_batch_per_epoch = dataloader_cfg.pop("num_batch_per_epoch", None) - if num_batch_per_epoch is not None: - world_size = get_world_size() - num_samples = num_batch_per_epoch * _get_batch_size(dataloader_cfg) * world_size - dataset = _SlicedDataset(dataset, num_samples) - - # build sampler - sampler_cfg = dataloader_cfg.pop("sampler") - if isinstance(sampler_cfg, dict): - sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build(sampler_cfg, default_args={"dataset": dataset, "seed": sampler_seed}) - else: - # fallback to raise error in dataloader - # if `sampler_cfg` is not a valid type - sampler = sampler_cfg - - # build batch sampler - batch_sampler_cfg = dataloader_cfg.pop("batch_sampler", None) - if batch_sampler_cfg is None: - batch_sampler = None - elif isinstance(batch_sampler_cfg, dict): - batch_sampler = DATA_SAMPLERS.build( - batch_sampler_cfg, - default_args={ - "sampler": sampler, - "batch_size": dataloader_cfg.pop("batch_size"), - }, - ) - else: - # fallback to raise error in dataloader - # if `batch_sampler_cfg` is not a valid type - batch_sampler = batch_sampler_cfg - - # build dataloader - init_fn: partial | None - - if "worker_init_fn" in dataloader_cfg: - worker_init_fn_cfg = dataloader_cfg.pop("worker_init_fn") - worker_init_fn_type = worker_init_fn_cfg.pop("type") - if isinstance(worker_init_fn_type, str): - worker_init_fn = FUNCTIONS.get(worker_init_fn_type) - elif callable(worker_init_fn_type): - worker_init_fn = worker_init_fn_type - else: - raise TypeError( - f"type of worker_init_fn should be string or callable object, but got {type(worker_init_fn_type)}" - ) - assert callable(worker_init_fn) - init_fn = partial(worker_init_fn, **worker_init_fn_cfg) # type: ignore - else: - if seed is not None: - disable_subprocess_warning = dataloader_cfg.pop("disable_subprocess_warning", False) - assert isinstance(disable_subprocess_warning, bool), ( - f"disable_subprocess_warning should be a bool, but got {type(disable_subprocess_warning)}" - ) - init_fn = partial( - default_worker_init_fn, - num_workers=dataloader_cfg.get("num_workers"), - rank=get_rank(), - seed=seed, - disable_subprocess_warning=disable_subprocess_warning, - ) - else: - init_fn = None - - # The default behavior of `collat_fn` in dataloader is to - # merge a list of samples to form a mini-batch of Tensor(s). - # However, in visengine, if `collate_fn` is not defined in - # dataloader_cfg, `pseudo_collate` will only convert the list of - # samples into a dict without stacking the batch tensor. 
- collate_fn_cfg = dataloader_cfg.pop("collate_fn", {"type": "pseudo_collate"}) - if isinstance(collate_fn_cfg, dict): - collate_fn_type = collate_fn_cfg.pop("type") - if isinstance(collate_fn_type, str): - collate_fn = FUNCTIONS.get(collate_fn_type) - else: - collate_fn = collate_fn_type - collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore - elif callable(collate_fn_cfg): - collate_fn = collate_fn_cfg - else: - raise TypeError(f"collate_fn should be a dict or callable object, but got {collate_fn_cfg}") - data_loader = DataLoader( - dataset=dataset, - sampler=sampler if batch_sampler is None else None, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - worker_init_fn=init_fn, - **dataloader_cfg, - ) - return data_loader - - def build_train_loop(self, loop: BaseLoop | dict) -> BaseLoop: - """Build training loop. - - Examples of ``loop``:: - - # `EpochBasedTrainLoop` will be used - loop = dict(by_epoch=True, max_epochs=3) - - # `IterBasedTrainLoop` will be used - loop = dict(by_epoch=False, max_epochs=3) - - # custom training loop - loop = dict(type='CustomTrainLoop', max_epochs=3) - - Args: - loop (BaseLoop or dict): A training loop or a dict to build - training loop. If ``loop`` is a training loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Training loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError(f"train_loop should be a Loop object or dict, but got {loop}") - - loop_cfg = copy.deepcopy(loop) - - if "type" in loop_cfg and "by_epoch" in loop_cfg: - raise RuntimeError("Only one of `type` or `by_epoch` can exist in `loop_cfg`.") - - if "type" in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args={"runner": self, "dataloader": self._train_dataloader}, - ) - else: - by_epoch = loop_cfg.pop("by_epoch") - if by_epoch: - loop = EpochBasedTrainLoop(**loop_cfg, runner=self, dataloader=self._train_dataloader) - else: - loop = IterBasedTrainLoop(**loop_cfg, runner=self, dataloader=self._train_dataloader) - return loop # type: ignore - - def build_val_loop(self, loop: BaseLoop | dict) -> BaseLoop: - """Build validation loop. - - Examples of ``loop``: - - # `ValLoop` will be used - loop = dict() - - # custom validation loop - loop = dict(type='CustomValLoop') - - Args: - loop (BaseLoop or dict): A validation loop or a dict to build - validation loop. If ``loop`` is a validation loop object, just - returns itself. - - Returns: - :obj:`BaseLoop`: Validation loop object build from ``loop``. - """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError(f"val_loop should be a Loop object or dict, but got {loop}") - - loop_cfg = copy.deepcopy(loop) - - if "type" in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args={ - "runner": self, - "dataloader": self._val_dataloader, - "evaluator": self._val_evaluator, - }, - ) - else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator, - ) # type: ignore - - return loop # type: ignore - - def build_test_loop(self, loop: BaseLoop | dict) -> BaseLoop: - """Build test loop. - - Examples of ``loop``:: - - # `TestLoop` will be used - loop = dict() - - # custom test loop - loop = dict(type='CustomTestLoop') - - Args: - loop (BaseLoop or dict): A test loop or a dict to build test loop. - If ``loop`` is a test loop object, just returns itself. - - Returns: - :obj:`BaseLoop`: Test loop object build from ``loop_cfg``. 
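Minimal loop configs corresponding to the builders above; the kwarg names assume mmengine-style loop signatures and are illustrative:

```python
train_cfg = dict(by_epoch=True, max_epochs=12)       # -> EpochBasedTrainLoop
# train_cfg = dict(by_epoch=False, max_iters=90000)  # -> IterBasedTrainLoop
val_cfg = dict()   # -> ValLoop
test_cfg = dict()  # -> TestLoop
```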
- """ - if isinstance(loop, BaseLoop): - return loop - elif not isinstance(loop, dict): - raise TypeError(f"test_loop should be a Loop object or dict, but got {loop}") - - loop_cfg = copy.deepcopy(loop) # type: ignore - - if "type" in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args={ - "runner": self, - "dataloader": self._test_dataloader, - "evaluator": self._test_evaluator, - }, - ) - else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator, - ) # type: ignore - - return loop # type: ignore - - def build_log_processor(self, log_processor: LogProcessor | dict) -> LogProcessor: - """Build test log_processor. - - Examples of ``log_processor``: - - # `LogProcessor` will be used - log_processor = dict() - - # custom log_processor - log_processor = dict(type='CustomLogProcessor') - - Args: - log_processor (LogProcessor or dict): A log processor or a dict - to build log processor. If ``log_processor`` is a log processor - object, just returns itself. - - Returns: - :obj:`LogProcessor`: Log processor object build from - ``log_processor_cfg``. - """ - if isinstance(log_processor, LogProcessor): - return log_processor - elif not isinstance(log_processor, dict): - raise TypeError(f"log processor should be a LogProcessor object or dict, butgot {log_processor}") - - log_processor_cfg = copy.deepcopy(log_processor) # type: ignore - - if "type" in log_processor_cfg: - log_processor = LOG_PROCESSORS.build(log_processor_cfg) - else: - log_processor = LogProcessor(**log_processor_cfg) # type: ignore - - return log_processor # type: ignore - - def get_hooks_info(self) -> str: - # Get hooks info in each stage - stage_hook_map: dict[str, list] = {stage: [] for stage in Hook.stages} - for hook in self.hooks: - try: - priority = Priority(hook.priority).name # type: ignore - except ValueError: - priority = hook.priority # type: ignore - classname = hook.__class__.__name__ - hook_info = f"({priority:<12}) {classname:<35}" - for trigger_stage in hook.get_triggered_stages(): - stage_hook_map[trigger_stage].append(hook_info) - - stage_hook_infos = [] - for stage in Hook.stages: - hook_infos = stage_hook_map[stage] - if len(hook_infos) > 0: - info = f"{stage}:\n" - info += "\n".join(hook_infos) - info += "\n -------------------- " - stage_hook_infos.append(info) - return "\n".join(stage_hook_infos) - - def load_or_resume(self) -> None: - """Load or resume checkpoint.""" - if self._has_loaded: - return None - - # decide to load from checkpoint or resume from checkpoint - resume_from = None - if self._resume and self._load_from is None: - # auto resume from the latest checkpoint - resume_from = find_latest_checkpoint(self.work_dir) - self.logger.info(f"Auto resumed from the latest checkpoint {resume_from}.") - elif self._resume and self._load_from is not None: - # resume from the specified checkpoint - resume_from = self._load_from - - if resume_from is not None: - self.resume(resume_from) - self._has_loaded = True - elif self._load_from is not None: - self.load_checkpoint(self._load_from) - self._has_loaded = True - - def train(self) -> nn.Module: - """Launch training. - - Returns: - nn.Module: The model after training. - """ - if is_model_wrapper(self.model): - ori_model = self.model.module - else: - ori_model = self.model - assert hasattr(ori_model, "train_step"), ( - "If you want to train your model, please make sure your model has implemented `train_step`." 
- ) - - if self._val_loop is not None: - assert hasattr(ori_model, "val_step"), ( - "If you want to validate your model, please make sure your model has implemented `val_step`." - ) - - if self._train_loop is None: - raise RuntimeError( - "`self._train_loop` should not be None when calling train " - "method. Please provide `train_dataloader`, `train_cfg`, " - "`optimizer` and `param_scheduler` arguments when " - "initializing runner." - ) - - self._train_loop = self.build_train_loop(self._train_loop) # type: ignore - - # `build_optimizer` should be called before `build_param_scheduler` - # because the latter depends on the former - self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) - # Automatically scaling lr by linear scaling rule - self.scale_lr(self.optim_wrapper, self.auto_scale_lr) - - if self.param_schedulers is not None: - self.param_schedulers = self.build_param_scheduler( # type: ignore - self.param_schedulers - ) # type: ignore - - if self._val_loop is not None: - self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - # TODO: add a contextmanager to avoid calling `before_run` many times - self.call_hook("before_run") - - # initialize the model weights - self._init_model_weights() - - # try to enable activation_checkpointing feature - modules = self.cfg.get("activation_checkpointing", None) - if modules is not None: - self.logger.info(f'Enabling the "activation_checkpointing" feature for sub-modules: {modules}') - turn_on_activation_checkpointing(ori_model, modules) - - # try to enable efficient_conv_bn_eval feature - modules = self.cfg.get("efficient_conv_bn_eval", None) - if modules is not None: - self.logger.info(f'Enabling the "efficient_conv_bn_eval" feature for sub-modules: {modules}') - turn_on_efficient_conv_bn_eval(ori_model, modules) - - # make sure checkpoint-related hooks are triggered after `before_run` - self.load_or_resume() - - # Initiate inner count of `optim_wrapper`. - self.optim_wrapper.initialize_count_status( - self.model, - self._train_loop.iter, - self._train_loop.max_iters, # type: ignore - ) # type: ignore - - # Maybe compile the model according to options in self.cfg.compile - # This must be called **AFTER** model has been wrapped. - self._maybe_compile("train_step") - - model = self.train_loop.run() # type: ignore - self.call_hook("after_run") - return model - - def val(self) -> dict: - """Launch validation. - - Returns: - dict: A dict of metrics on validation set. - """ - if self._val_loop is None: - raise RuntimeError( - "`self._val_loop` should not be None when calling val method." - "Please provide `val_dataloader`, `val_cfg` and " - "`val_evaluator` arguments when initializing runner." - ) - - self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - - self.call_hook("before_run") - - # make sure checkpoint-related hooks are triggered after `before_run` - self.load_or_resume() - - metrics = self.val_loop.run() # type: ignore - self.call_hook("after_run") - return metrics - - def test(self) -> dict: - """Launch test. - - Returns: - dict: A dict of metrics on testing set. - """ - if self._test_loop is None: - raise RuntimeError( - "`self._test_loop` should not be None when calling test " - "method. Please provide `test_dataloader`, `test_cfg` and " - "`test_evaluator` arguments when initializing runner." 
- ) - - self._test_loop = self.build_test_loop(self._test_loop) # type: ignore - - self.call_hook("before_run") - - # make sure checkpoint-related hooks are triggered after `before_run` - self.load_or_resume() - - metrics = self.test_loop.run() # type: ignore - self.call_hook("after_run") - return metrics - - def call_hook(self, fn_name: str, **kwargs) -> None: - """Call all hooks. - - Args: - fn_name (str): The function name in each hook to be called, such as - "before_train_epoch". - **kwargs: Keyword arguments passed to hook. - """ - for hook in self._hooks: - # support adding additional custom hook methods - if hasattr(hook, fn_name): - try: - method = getattr(hook, fn_name) - if not callable(method): - raise TypeError( - f"Hook method '{fn_name}' on {type(hook).__name__} is not callable, got {type(method)}" - ) - method(self, **kwargs) - except TypeError as e: - raise TypeError(f"{e} in {hook}") from None - - def register_hook(self, hook: Hook | dict, priority: str | int | Priority | None = None) -> None: - """Register a hook into the hook list. - - The hook will be inserted into a priority queue, with the specified - priority (See :class:`Priority` for details of priorities). - For hooks with the same priority, they will be triggered in the same - order as they are registered. - - Priority of hook will be decided with the following priority: - - - ``priority`` argument. If ``priority`` is given, it will be priority - of hook. - - If ``hook`` argument is a dict and ``priority`` in it, the priority - will be the value of ``hook['priority']``. - - If ``hook`` argument is a dict but ``priority`` not in it or ``hook`` - is an instance of ``hook``, the priority will be ``hook.priority``. - - Args: - hook (:obj:`Hook` or dict): The hook to be registered. - priority (int or str or :obj:`Priority`, optional): Hook priority. - Lower value means higher priority. - """ - if not isinstance(hook, Hook | dict): - raise TypeError(f"hook should be an instance of Hook or dict, but got {hook}") - - _priority = None - if isinstance(hook, dict): - if "priority" in hook: - _priority = hook.pop("priority") - - hook_obj = HOOKS.build(hook) - else: - hook_obj = hook - - if priority is not None: - hook_obj.priority = priority - elif _priority is not None: - hook_obj.priority = _priority - - inserted = False - for i in range(len(self._hooks) - 1, -1, -1): - if get_priority(hook_obj.priority) >= get_priority(self._hooks[i].priority): - self._hooks.insert(i + 1, hook_obj) - inserted = True - break - if not inserted: - self._hooks.insert(0, hook_obj) - - def register_default_hooks(self, hooks: dict[str, Hook | dict] | None = None) -> None: - """Register default hooks into hook list. - - ``hooks`` will be registered into runner to execute some default - actions like updating model parameters or saving checkpoints. 
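As a sketch of the dispatch in `call_hook` and the priority ordering in `register_hook`, here is a hypothetical custom hook; the `visengine.hooks` import path is assumed:

```python
import torch

from visengine.hooks import Hook  # assumed import path


class GpuCacheHook(Hook):
    """Hypothetical hook: free cached GPU memory after every train epoch."""

    def after_train_epoch(self, runner) -> None:
        # call_hook("after_train_epoch") finds and invokes this method.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Lower priority values mean higher priority, i.e. earlier triggering.
# runner.register_hook(GpuCacheHook(), priority="VERY_LOW")
```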
- - Default hooks and their priorities: - - +----------------------+-------------------------+ - | Hooks | Priority | - +======================+=========================+ - | RuntimeInfoHook | VERY_HIGH (10) | - +----------------------+-------------------------+ - | IterTimerHook | NORMAL (50) | - +----------------------+-------------------------+ - | DistSamplerSeedHook | NORMAL (50) | - +----------------------+-------------------------+ - | LoggerHook | BELOW_NORMAL (60) | - +----------------------+-------------------------+ - | ParamSchedulerHook | LOW (70) | - +----------------------+-------------------------+ - | CheckpointHook | VERY_LOW (90) | - +----------------------+-------------------------+ - - If ``hooks`` is None, above hooks will be registered by - default:: - - default_hooks = dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - sampler_seed=dict(type='DistSamplerSeedHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict(type='CheckpointHook', interval=1), - ) - - If not None, ``hooks`` will be merged into ``default_hooks``. - If there are None value in default_hooks, the corresponding item will - be popped from ``default_hooks``:: - - hooks = dict(timer=None) - - The final registered default hooks will be :obj:`RuntimeInfoHook`, - :obj:`DistSamplerSeedHook`, :obj:`LoggerHook`, - :obj:`ParamSchedulerHook` and :obj:`CheckpointHook`. - - Args: - hooks (dict[str, Hook or dict], optional): Default hooks or configs - to be registered. - """ - default_hooks: dict = { - "runtime_info": {"type": "RuntimeInfoHook"}, - "timer": {"type": "IterTimerHook"}, - "sampler_seed": {"type": "DistSamplerSeedHook"}, - "logger": {"type": "LoggerHook"}, - "param_scheduler": {"type": "ParamSchedulerHook"}, - "checkpoint": {"type": "CheckpointHook", "interval": 1}, - } - if hooks is not None: - for name, hook in hooks.items(): - if name in default_hooks and hook is None: - # remove hook from _default_hooks - default_hooks.pop(name) - else: - assert hook is not None - default_hooks[name] = hook - - for hook in default_hooks.values(): - self.register_hook(hook) - - def register_custom_hooks(self, hooks: list[Hook | dict]) -> None: - """Register custom hooks into hook list. - - Args: - hooks (list[Hook | dict]): List of hooks or configs to be - registered. - """ - for hook in hooks: - self.register_hook(hook) - - def register_hooks( - self, - default_hooks: dict[str, Hook | dict] | None = None, - custom_hooks: list[Hook | dict] | None = None, - ) -> None: - """Register default hooks and custom hooks into hook list. - - Args: - default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks - to execute default actions like updating model parameters and - saving checkpoints. Defaults to None. - custom_hooks (list[dict] or list[Hook], optional): Hooks to execute - custom actions like visualizing images processed by pipeline. - Defaults to None. - """ - self.register_default_hooks(default_hooks) - - if custom_hooks is not None: - self.register_custom_hooks(custom_hooks) - - def resume( - self, - filename: str, - resume_optimizer: bool = True, - resume_param_scheduler: bool = True, - map_location: str | Callable = "default", - ) -> None: - """Resume model from checkpoint. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - resume_optimizer (bool): Whether to resume optimizer state. - Defaults to True. 
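A sketch of overriding the defaults listed above; per the docstring, setting an entry to None drops that default hook:

```python
default_hooks = dict(
    # save checkpoints every 5 epochs instead of every epoch
    checkpoint=dict(type="CheckpointHook", interval=5),
    # drop IterTimerHook entirely
    timer=None,
)
# runner.register_hooks(default_hooks=default_hooks)
```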
- resume_param_scheduler (bool): Whether to resume param scheduler
- state. Defaults to True.
- map_location (str or callable): A string or a callable function
- specifying how to remap storage locations.
- Defaults to 'default'.
- """
- if map_location == "default":
- device = get_device()
- checkpoint = self.load_checkpoint(filename, map_location=device)
- else:
- checkpoint = self.load_checkpoint(filename, map_location=map_location)
-
- self.train_loop._epoch = checkpoint["meta"]["epoch"]
- self.train_loop._iter = checkpoint["meta"]["iter"]
-
- # check whether the number of GPUs used for the current experiment
- # is consistent with the checkpoint being resumed from
- if "config" in checkpoint["meta"]:
- config = visengine.Config.fromstring(checkpoint["meta"]["config"], file_format=".py")
- previous_gpu_ids = config.get("gpu_ids", None)
- if previous_gpu_ids is not None and len(previous_gpu_ids) > 0 and len(previous_gpu_ids) != self._world_size:
- # TODO, should we modify the iteration?
- if self.auto_scale_lr is None or not self.auto_scale_lr.get("enable", False):
- raise RuntimeError(
- "Number of GPUs used for current experiment is not "
- "consistent with the checkpoint being resumed from. "
- "This will result in poor performance due to the "
- "learning rate. You must set the "
- "`auto_scale_lr` parameter for Runner and make "
- '`auto_scale_lr["enable"]=True`.'
- )
- else:
- self.logger.info(
- "Number of GPUs used for current experiment is not "
- "consistent with the checkpoint being resumed from, "
- "but the learning rate will be adjusted according to "
- f"the setting in auto_scale_lr={self.auto_scale_lr}"
- )
-
- # resume random seed
- resumed_seed = checkpoint["meta"].get("seed", None)
- current_seed = self._randomness_cfg.get("seed")
- if resumed_seed is not None and resumed_seed != current_seed:
- if current_seed is not None:
- self.logger.warning(
- f'The value of random seed in the checkpoint "{resumed_seed}" is different from the value in `randomness` config "{current_seed}"'
- )
- self._randomness_cfg.update(seed=resumed_seed)
- self.set_randomness(**self._randomness_cfg)
-
- resumed_dataset_meta = checkpoint["meta"].get("dataset_meta", None)
- dataset_meta = getattr(self.train_dataloader.dataset, "metainfo", None)
-
- # `resumed_dataset_meta` and `dataset_meta` could be objects like
- # np.ndarray, which cannot be directly compared for equality,
- # therefore we just compare their pickled dumps.
- if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta):
- self.logger.warning(
- "The dataset metainfo from the resumed checkpoint is "
- "different from the current training dataset, please "
- "check the correctness of the checkpoint or the training "
- "dataset."
- ) - - self.message_hub.load_state_dict(checkpoint["message_hub"]) - - # resume optimizer - if "optimizer" in checkpoint and resume_optimizer: - self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) - self.optim_wrapper.load_state_dict(checkpoint["optimizer"]) # type: ignore - - # resume param scheduler - if resume_param_scheduler and self.param_schedulers is None: - self.logger.warning( - "`resume_param_scheduler` is True but `self.param_schedulers` is None, so skip resuming parameter schedulers" - ) - resume_param_scheduler = False - if "param_schedulers" in checkpoint and resume_param_scheduler: - self.param_schedulers = self.build_param_scheduler( # type: ignore - self.param_schedulers - ) # type: ignore - if isinstance(self.param_schedulers, dict): - for name, schedulers in self.param_schedulers.items(): - for scheduler, ckpt_scheduler in zip( - schedulers, checkpoint["param_schedulers"][name], strict=False - ): - scheduler.load_state_dict(ckpt_scheduler) - else: - for scheduler, ckpt_scheduler in zip( - self.param_schedulers, - checkpoint["param_schedulers"], - strict=False, # type: ignore - ): - scheduler.load_state_dict(ckpt_scheduler) - - self._has_loaded = True - - self.logger.info(f"resumed epoch: {self.epoch}, iter: {self.iter}") - - def load_checkpoint( - self, - filename: str, - map_location: str | Callable = "cpu", - strict: bool = False, - revise_keys: list | None = None, - ): - """Load checkpoint from given ``filename``. - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. - map_location (str or callable): A string or a callable function to - specifying how to remap storage locations. - Defaults to 'cpu'. - strict (bool): strict (bool): Whether to allow different params for - the model and checkpoint. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - """ - if revise_keys is None: - revise_keys = [(r"^module.", "")] - checkpoint = _load_checkpoint(filename, map_location=map_location) - - # Add comments to describe the usage of `after_load_ckpt` - self.call_hook("after_load_checkpoint", checkpoint=checkpoint) - - if is_model_wrapper(self.model): - model = self.model.module - else: - model = self.model - - checkpoint = _load_checkpoint_to_model(model, checkpoint, strict, revise_keys=revise_keys) - - self._has_loaded = True - - self.logger.info(f"Load checkpoint from {filename}") - - return checkpoint - - @master_only - def save_checkpoint( - self, - out_dir: str, - filename: str, - file_client_args: dict | None = None, - save_optimizer: bool = True, - save_param_scheduler: bool = True, - meta: dict | None = None, - by_epoch: bool = True, - backend_args: dict | None = None, - ): - """Save checkpoints. - - ``CheckpointHook`` invokes this method to save checkpoints - periodically. - - Args: - out_dir (str): The directory that checkpoints are saved. - filename (str): The checkpoint filename. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`visengine.fileio.FileClient` for - details. Defaults to None. It will be deprecated in future. - Please use `backend_args` instead. - save_optimizer (bool): Whether to save the optimizer to - the checkpoint. Defaults to True. - save_param_scheduler (bool): Whether to save the param_scheduler - to the checkpoint. Defaults to True. 
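A small sketch of how `revise_keys` rewrites checkpoint keys: each `(pattern, replacement)` regex pair is applied to every key, with the default stripping the `module.` prefix added by (Distributed)DataParallel wrappers. The state dict below is illustrative:

```python
import re

revise_keys = [(r"^module\.", "")]

state_dict = {"module.backbone.stem.weight": 0}  # illustrative key
for pattern, replacement in revise_keys:
    state_dict = {re.sub(pattern, replacement, key): value for key, value in state_dict.items()}

assert "backbone.stem.weight" in state_dict
```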
- meta (dict, optional): The meta information to be saved in the - checkpoint. Defaults to None. - by_epoch (bool): Decide the number of epoch or iteration saved in - checkpoint. Defaults to True. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - """ - if meta is None: - meta = {} - elif not isinstance(meta, dict): - raise TypeError(f"meta should be a dict or None, but got {type(meta)}") - - if by_epoch: - # self.epoch increments 1 after - # `self.call_hook('after_train_epoch)` but `save_checkpoint` is - # called by `after_train_epoch`` method of `CheckpointHook` so - # `epoch` should be `self.epoch + 1` - meta.setdefault("epoch", self.epoch + 1) - meta.setdefault("iter", self.iter) - else: - meta.setdefault("epoch", self.epoch) - meta.setdefault("iter", self.iter + 1) - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. Please use "backend_args" instead', - DeprecationWarning, - stacklevel=2, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.') - - file_client = FileClient.infer_client(file_client_args, out_dir) - filepath = file_client.join_path(out_dir, filename) - else: - filepath = join_path(out_dir, filename, backend_args=backend_args) # type: ignore - - meta.update( - cfg=self.cfg.pretty_text, - seed=self.seed, - experiment_name=self.experiment_name, - time=time.strftime("%Y%m%d_%H%M%S", time.localtime()), - visengine_version=__version__ + get_git_hash(), - ) - - if hasattr(self.train_dataloader.dataset, "metainfo"): - meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) - - if is_model_wrapper(self.model): - model = self.model.module - else: - model = self.model - - checkpoint = { - "meta": meta, - "state_dict": weights_to_cpu(model.state_dict()), - "message_hub": apply_to( - self.message_hub.state_dict(), - lambda x: hasattr(x, "cpu"), - lambda x: x.cpu(), - ), - } - # save optimizer state dict to checkpoint - if save_optimizer: - if isinstance(self.optim_wrapper, OptimWrapper): - checkpoint["optimizer"] = apply_to( - self.optim_wrapper.state_dict(), - lambda x: hasattr(x, "cpu"), - lambda x: x.cpu(), - ) - else: - raise TypeError( - f"self.optim_wrapper should be an `OptimWrapper` or `OptimWrapperDict` instance, but got {self.optim_wrapper}" - ) - - # save param scheduler state dict - if save_param_scheduler and self.param_schedulers is None: - self.logger.warning( - "`save_param_scheduler` is True but `self.param_schedulers` is None, so skip saving parameter schedulers" - ) - save_param_scheduler = False - if save_param_scheduler: - if isinstance(self.param_schedulers, dict): - checkpoint["param_schedulers"] = {} - for name, schedulers in self.param_schedulers.items(): - checkpoint["param_schedulers"][name] = [] - for scheduler in schedulers: - state_dict = scheduler.state_dict() - checkpoint["param_schedulers"][name].append(state_dict) - else: - checkpoint["param_schedulers"] = [] - for scheduler in self.param_schedulers: # type: ignore - state_dict = scheduler.state_dict() # type: ignore - checkpoint["param_schedulers"].append(state_dict) - - self.call_hook("before_save_checkpoint", checkpoint=checkpoint) - save_checkpoint( - checkpoint, - filepath, - file_client_args=file_client_args, - backend_args=backend_args, - ) - - @master_only - def dump_config(self) -> None: - """Dump config to `work_dir`.""" - if self.cfg.filename is not None: - filename = 
osp.basename(self.cfg.filename)
- else:
- filename = f"{self.timestamp}.py"
- self.cfg.dump(osp.join(self.work_dir, filename))
-
- def _check_scheduler_cfg(self, param_scheduler: dict | list | _ParamScheduler | None) -> None:
- """Parse `param_scheduler` to a list of parameter schedulers, or a
- `dict` of which each value is a list of parameter schedulers.
-
- If only one optimizer is used, the parsed config should be a
- list of parameter scheduler configs or instances. If multiple
- optimizers are used, the parsed config should be a `dict`.
- Its keys should be consistent with the optimizer `dict` and its
- values should be lists of parameter scheduler configs or instances.
- See :meth:`build_param_scheduler` for more details.
-
- Examples:
- >>> # valid scheduler:
- >>> # empty scheduler
- >>> scheduler = None
- >>> # Single scheduler
- >>> scheduler = dict(type='MultiStepLR', milestones=[1, 2])
- >>> # Single list of schedulers
- >>> scheduler = [dict(type='MultiStepLR', milestones=[1, 2]),
- >>> dict(type='MultiStepLR', milestones=[2, 3])]
- >>> # `dict` of schedulers
- >>> scheduler = dict(linear1=dict(type='MultiStepLR', milestones=[1, 2]),
- >>> linear2=dict(type='MultiStepLR', milestones=[1, 2]))
- >>> # `dict` of `list` of schedulers
- >>> scheduler = dict(linear1=[dict(type='MultiStepLR', milestones=[1, 2])],
- >>> linear2=[dict(type='MultiStepLR', milestones=[1, 2])])
- >>> # Single built scheduler
- >>> from visengine.optim import MultiStepLR
- >>> scheduler = MultiStepLR(milestones=[1, 2], optimizer=optimizer)
- >>> # Single built list of schedulers
- >>> scheduler = [MultiStepLR(milestones=[1, 2], optimizer=optimizer)]
- >>> # dict of built schedulers
- >>> scheduler = dict(linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer),
- >>> linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer))
- >>> # dict of built lists of schedulers
- >>> scheduler = dict(linear1=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)],
- >>> linear2=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)])
-
- Args:
- param_scheduler (dict or list): The original parameter scheduler.
- """
- if param_scheduler is None:
- return
- if isinstance(param_scheduler, _ParamScheduler):
- return
- if is_seq_of(param_scheduler, _ParamScheduler):
- return
-
- if is_seq_of(param_scheduler, dict):
- for _param_scheduler in param_scheduler:
- assert "type" in _param_scheduler, (
- f"Each parameter scheduler should contain the key type, but got {_param_scheduler}"
- )
- elif isinstance(param_scheduler, dict):
- if "type" not in param_scheduler:
- for _key, _param_scheduler in param_scheduler.items():
- assert isinstance(_param_scheduler, dict | tuple | list | _ParamScheduler), (
- f"Each value of `param_scheduler` should be a dict, tuple, list or "
- f"_ParamScheduler instance, but got {_param_scheduler} with type {type(_param_scheduler)}"
- )
-
- else:
- raise TypeError(
- "`param_scheduler` should be a `_ParamScheduler`, `dict`, "
- f"list or a tuple, but got {type(param_scheduler)}. If "
- "`param_scheduler` is a list of dicts, it means a list of "
- "scheduler configs for a single optimizer. If it is a dict "
- "and contains the key `type`, it means a scheduler config "
- "for a single optimizer. If it does not contain the key "
- "`type`, it means multiple lists of schedulers for multiple "
- "optimizers."
- )
-
- def _log_env(self, env_cfg: dict) -> None:
- """Log environment information of the current task.
-
- Args:
- env_cfg (dict): The environment config of the runner.
- """
- # Collect and log environment information.
- env = collect_env() - runtime_env = OrderedDict() - runtime_env.update(env_cfg) - runtime_env.update(self._randomness_cfg) - runtime_env["seed"] = self._seed - runtime_env["Distributed launcher"] = self._launcher - runtime_env["Distributed training"] = self._distributed - runtime_env["GPU number"] = self._world_size - - env_info = "\n " + "\n ".join(f"{k}: {v}" for k, v in env.items()) - runtime_env_info = "\n " + "\n ".join(f"{k}: {v}" for k, v in runtime_env.items()) - dash_line = "-" * 60 - self.logger.info( - "\n" - + dash_line - + "\nSystem environment:" - + env_info - + "\n\nRuntime environment:" - + runtime_env_info - + "\n" - + dash_line - + "\n" - ) - - if self.cfg._cfg_dict: - self.logger.info(f"Config:\n{self.cfg.pretty_text}") - - def _maybe_compile(self, target: str) -> None: - """Use `torch.compile` to optimize model/wrapped_model.""" - compile_cfg = self.cfg.get("compile", None) - if compile_cfg is None: - # no compile options given, won't compile - return - - if isinstance(compile_cfg, bool): - if not compile_cfg: - # compile=False, compilation is disabled - return - # compile=True, use default configurations - compile_cfg = {} - - func = getattr(self.model, target) - compiled_func = torch.compile(func, **compile_cfg) - setattr(self.model, target, compiled_func) - self.logger.info('Model has been "compiled". The first few iterations will be slow, please be patient.') diff --git a/libs/visengine/visengine/runner/utils.py b/libs/visengine/visengine/runner/utils.py deleted file mode 100644 index 7746a15..0000000 --- a/libs/visengine/visengine/runner/utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import random - -import numpy as np -import torch -from torch.utils.data import DataLoader - -from visengine.device import is_cuda_available -from visengine.dist import get_rank, sync_random_seed -from visengine.logging import print_log -from visengine.utils import is_list_of - - -def calc_dynamic_intervals( - start_interval: int, dynamic_interval_list: list[tuple[int, int]] | None = None -) -> tuple[list[int], list[int]]: - """Calculate dynamic intervals. - - Args: - start_interval (int): The interval used in the beginning. - dynamic_interval_list (List[Tuple[int, int]], optional): The - first element in the tuple is a milestone and the second - element is a interval. The interval is used after the - corresponding milestone. Defaults to None. - - Returns: - Tuple[List[int], List[int]]: a list of milestone and its corresponding - intervals. - """ - if dynamic_interval_list is None: - return [0], [start_interval] - - assert is_list_of(dynamic_interval_list, tuple) - - dynamic_milestones = [0] - dynamic_milestones.extend([dynamic_interval[0] for dynamic_interval in dynamic_interval_list]) - dynamic_intervals = [start_interval] - dynamic_intervals.extend([dynamic_interval[1] for dynamic_interval in dynamic_interval_list]) - return dynamic_milestones, dynamic_intervals - - -def set_random_seed(seed: int | None = None, deterministic: bool = False, diff_rank_seed: bool = False) -> int: - """Set random seed. - - Args: - seed (int, optional): Seed to be used. - deterministic (bool): Whether to set the deterministic option for - CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` - to True and `torch.backends.cudnn.benchmark` to False. - Defaults to False. - diff_rank_seed (bool): Whether to add rank number to the random seed to - have different random seed in different threads. 
Defaults to False. - """ - if seed is None: - seed = sync_random_seed() - - if diff_rank_seed: - rank = get_rank() - seed += rank - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - # torch.cuda.manual_seed(seed) - if is_cuda_available(): - torch.cuda.manual_seed_all(seed) - # os.environ['PYTHONHASHSEED'] = str(seed) - if deterministic: - if torch.backends.cudnn.benchmark: - print_log( - "torch.backends.cudnn.benchmark is going to be set as `False` to cause cuDNN to deterministically select an algorithm", - logger="current", - level=logging.WARNING, - ) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.use_deterministic_algorithms(True) - return seed - - -def _get_batch_size(dataloader: dict): - if isinstance(dataloader, dict): - if "batch_size" in dataloader: - return dataloader["batch_size"] - elif "batch_sampler" in dataloader and "batch_size" in dataloader["batch_sampler"]: - return dataloader["batch_sampler"]["batch_size"] - else: - raise ValueError("Please set batch_size in `Dataloader` or `batch_sampler`") - elif isinstance(dataloader, DataLoader): - return dataloader.batch_sampler.batch_size - else: - raise ValueError(f"dataloader should be a dict or a Dataloader instance, but got {type(dataloader)}") diff --git a/libs/visengine/visengine/structures/__init__.py b/libs/visengine/visengine/structures/__init__.py deleted file mode 100644 index 5c1789c..0000000 --- a/libs/visengine/visengine/structures/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .base_data_element import BaseDataElement -from .instance_data import InstanceData -from .label_data import LabelData -from .pixel_data import PixelData - -__all__ = ["BaseDataElement", "InstanceData", "LabelData", "PixelData"] diff --git a/libs/visengine/visengine/structures/base_data_element.py b/libs/visengine/visengine/structures/base_data_element.py deleted file mode 100644 index 5ce6132..0000000 --- a/libs/visengine/visengine/structures/base_data_element.py +++ /dev/null @@ -1,623 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from collections.abc import Iterator -from typing import Any - -import numpy as np -import torch - - -class BaseDataElement: - """A base data interface that supports Tensor-like and dict-like - operations. - - A typical data elements refer to predicted results or ground truth labels - on a task, such as predicted bboxes, instance masks, semantic - segmentation masks, etc. Because groundtruth labels and predicted results - often have similar properties (for example, the predicted bboxes and the - groundtruth bboxes), MMEngine uses the same abstract data interface to - encapsulate predicted results and groundtruth labels, and it is recommended - to use different name conventions to distinguish them, such as using - ``gt_instances`` and ``pred_instances`` to distinguish between labels and - predicted results. Additionally, we distinguish data elements at instance - level, pixel level, and label level. Each of these types has its own - characteristics. Therefore, MMEngine defines the base class - ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and - ``LabelData`` inheriting from ``BaseDataElement`` to represent different - types of ground truth labels or predictions. - - Another common data element is sample data. 
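Example usage of the `set_random_seed` helper defined above:

```python
from visengine.runner.utils import set_random_seed

# With diff_rank_seed=True each rank offsets the seed by its rank
# (rank 0 -> 42, rank 1 -> 43, ...); deterministic=True additionally
# disables cudnn.benchmark and enables deterministic algorithms.
seed = set_random_seed(seed=42, deterministic=True, diff_rank_seed=True)
```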
A sample data consists of input - data (such as an image) and its annotations and predictions. In general, - an image can have multiple types of annotations and/or predictions at the - same time (for example, both pixel-level semantic segmentation annotations - and instance-level detection bboxes annotations). All labels and - predictions of a training sample are often passed between Dataset, Model, - Visualizer, and Evaluator components. In order to simplify the interface - between components, we can treat them as a large data element and - encapsulate them. Such data elements are generally called XXDataSample in - the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` - allows `BaseDataElement` as its attribute. Such a class generally - encapsulates all the data of a sample in the algorithm library, and its - attributes generally are various types of data elements. For example, - MMDetection is assigned by the BaseDataElement to encapsulate all the data - elements of the sample labeling and prediction of a sample in the - algorithm library. - - The attributes in ``BaseDataElement`` are divided into two parts, - the ``metainfo`` and the ``data`` respectively. - - - ``metainfo``: Usually contains the - information about the image such as filename, - image_shape, pad_shape, etc. The attributes can be accessed or - modified by dict-like or object-like operations, such as - ``.`` (for data access and modification), ``in``, ``del``, - ``pop(str)``, ``get(str)``, ``metainfo_keys()``, - ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for - set or change key-value pairs in metainfo). - - - ``data``: Annotations or model predictions are - stored. The attributes can be accessed or modified by - dict-like or object-like operations, such as - ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, - ``values()``, ``items()``. Users can also apply tensor-like - methods to all :obj:`torch.Tensor` in the ``data_fields``, - such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, - ``to_tensor()``, ``.detach()``. - - Args: - metainfo (dict, optional): A dict contains the meta information - of single image, such as ``dict(img_shape=(512, 512, 3), - scale_factor=(1, 1, 1, 1))``. Defaults to None. - kwargs (dict, optional): A dict contains annotations of single image or - model predictions. Defaults to None. - - Examples: - >>> import torch - >>> from visengine.structures import BaseDataElement - >>> gt_instances = BaseDataElement() - >>> bboxes = torch.rand((5, 4)) - >>> scores = torch.rand((5,)) - >>> img_id = 0 - >>> img_shape = (800, 1333) - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=img_id, img_shape=img_shape), - ... bboxes=bboxes, scores=scores) - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=img_id, img_shape=(640, 640))) - - >>> # new - >>> gt_instances1 = gt_instances.new( - ... metainfo=dict(img_id=1, img_shape=(640, 640)), - ... bboxes=torch.rand((5, 4)), - ... 
scores=torch.rand((5,))) - >>> gt_instances2 = gt_instances1.new() - - >>> # add and process property - >>> gt_instances = BaseDataElement() - >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) - >>> assert 'img_shape' in gt_instances.metainfo_keys() - >>> assert 'img_shape' in gt_instances - >>> assert 'img_shape' not in gt_instances.keys() - >>> assert 'img_shape' in gt_instances.all_keys() - >>> print(gt_instances.img_shape) - (100, 100) - >>> gt_instances.scores = torch.rand((5,)) - >>> assert 'scores' in gt_instances.keys() - >>> assert 'scores' in gt_instances - >>> assert 'scores' in gt_instances.all_keys() - >>> assert 'scores' not in gt_instances.metainfo_keys() - >>> print(gt_instances.scores) - tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) - >>> gt_instances.bboxes = torch.rand((5, 4)) - >>> assert 'bboxes' in gt_instances.keys() - >>> assert 'bboxes' in gt_instances - >>> assert 'bboxes' in gt_instances.all_keys() - >>> assert 'bboxes' not in gt_instances.metainfo_keys() - >>> print(gt_instances.bboxes) - tensor([[0.0900, 0.0424, 0.1755, 0.4469], - [0.8648, 0.0592, 0.3484, 0.0913], - [0.5808, 0.1909, 0.6165, 0.7088], - [0.5490, 0.4209, 0.9416, 0.2374], - [0.3652, 0.1218, 0.8805, 0.7523]]) - - >>> # delete and change property - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=0, img_shape=(640, 640)), - ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) - >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) - >>> gt_instances.img_shape # (1280, 1280) - >>> gt_instances.bboxes = gt_instances.bboxes * 2 - >>> gt_instances.get('img_shape', None) # (1280, 1280) - >>> gt_instances.get('bboxes', None) # 6x4 tensor - >>> del gt_instances.img_shape - >>> del gt_instances.bboxes - >>> assert 'img_shape' not in gt_instances - >>> assert 'bboxes' not in gt_instances - >>> gt_instances.pop('img_shape', None) # None - >>> gt_instances.pop('bboxes', None) # None - - >>> # Tensor-like - >>> cuda_instances = gt_instances.cuda() - >>> cuda_instances = gt_instances.to('cuda:0') - >>> cpu_instances = cuda_instances.cpu() - >>> cpu_instances = cuda_instances.to('cpu') - >>> fp16_instances = cuda_instances.to( - ... device=None, dtype=torch.float16, non_blocking=False, - ... copy=False, memory_format=torch.preserve_format) - >>> cpu_instances = cuda_instances.detach() - >>> np_instances = cpu_instances.numpy() - - >>> # print - >>> metainfo = dict(img_shape=(800, 1196, 3)) - >>> gt_instances = BaseDataElement( - ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) - >>> sample = BaseDataElement(metainfo=metainfo, - ... gt_instances=gt_instances) - >>> print(sample) - - ) at 0x7f0fea49e130> - - >>> # inheritance - >>> class DetDataSample(BaseDataElement): - ... @property - ... def proposals(self): - ... return self._proposals - ... @proposals.setter - ... def proposals(self, value): - ... self.set_field(value, '_proposals', dtype=BaseDataElement) - ... @proposals.deleter - ... def proposals(self): - ... del self._proposals - ... @property - ... def gt_instances(self): - ... return self._gt_instances - ... @gt_instances.setter - ... def gt_instances(self, value): - ... self.set_field(value, '_gt_instances', - ... dtype=BaseDataElement) - ... @gt_instances.deleter - ... def gt_instances(self): - ... del self._gt_instances - ... @property - ... def pred_instances(self): - ... return self._pred_instances - ... @pred_instances.setter - ... def pred_instances(self, value): - ... self.set_field(value, '_pred_instances', - ... 
dtype=BaseDataElement) - ... @pred_instances.deleter - ... def pred_instances(self): - ... del self._pred_instances - >>> det_sample = DetDataSample() - >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) - >>> det_sample.proposals = proposals - >>> assert 'proposals' in det_sample - >>> assert det_sample.proposals == proposals - >>> del det_sample.proposals - >>> assert 'proposals' not in det_sample - >>> with self.assertRaises(AssertionError): - ... det_sample.proposals = torch.rand((5, 4)) - """ - - def __init__(self, *, metainfo: dict | None = None, **kwargs) -> None: - self._metainfo_fields: set = set() - self._data_fields: set = set() - - if metainfo is not None: - self.set_metainfo(metainfo=metainfo) - if kwargs: - self.set_data(kwargs) - - def set_metainfo(self, metainfo: dict) -> None: - """Set or change key-value pairs in ``metainfo_field`` by parameter - ``metainfo``. - - Args: - metainfo (dict): A dict contains the meta information - of image, such as ``img_shape``, ``scale_factor``, etc. - """ - assert isinstance(metainfo, dict), f"metainfo should be a ``dict`` but got {type(metainfo)}" - meta = copy.deepcopy(metainfo) - for k, v in meta.items(): - self.set_field(name=k, value=v, field_type="metainfo", dtype=None) - - def set_data(self, data: dict) -> None: - """Set or change key-value pairs in ``data_field`` by parameter - ``data``. - - Args: - data (dict): A dict contains annotations of image or - model predictions. - """ - assert isinstance(data, dict), f"data should be a `dict` but got {data}" - for k, v in data.items(): - # Use `setattr()` rather than `self.set_field` to allow `set_data` - # to set property method. - setattr(self, k, v) - - def update(self, instance: "BaseDataElement") -> None: - """The update() method updates the BaseDataElement with the elements - from another BaseDataElement object. - - Args: - instance (BaseDataElement): Another BaseDataElement object for - update the current object. - """ - assert isinstance(instance, BaseDataElement), f"instance should be a `BaseDataElement` but got {type(instance)}" - self.set_metainfo(dict(instance.metainfo_items())) - self.set_data(dict(instance.items())) - - def new(self, *, metainfo: dict | None = None, **kwargs) -> "BaseDataElement": - """Return a new data element with same type. If ``metainfo`` and - ``data`` are None, the new data element will have same metainfo and - data. If metainfo or data is not None, the new result will overwrite it - with the input value. - - Args: - metainfo (dict, optional): A dict contains the meta information - of image, such as ``img_shape``, ``scale_factor``, etc. - Defaults to None. - kwargs (dict): A dict contains annotations of image or - model predictions. - - Returns: - BaseDataElement: A new data element with same type. - """ - new_data = self.__class__() - - if metainfo is not None: - new_data.set_metainfo(metainfo) - else: - new_data.set_metainfo(dict(self.metainfo_items())) - if kwargs: - new_data.set_data(kwargs) - else: - new_data.set_data(dict(self.items())) - return new_data - - def clone(self): - """Deep copy the current data element. - - Returns: - BaseDataElement: The copy of current data element. - """ - clone_data = self.__class__() - clone_data.set_metainfo(dict(self.metainfo_items())) - clone_data.set_data(dict(self.items())) - return clone_data - - def keys(self) -> list: - """ - Returns: - list: Contains all keys in data_fields. - """ - # We assume that the name of the attribute related to property is - # '_' + the name of the property. 
We use this rule to filter out - # private keys. - # TODO: Use a more robust way to solve this problem - private_keys = {"_" + key for key in self._data_fields if isinstance(getattr(type(self), key, None), property)} - return list(self._data_fields - private_keys) - - def metainfo_keys(self) -> list: - """ - Returns: - list: Contains all keys in metainfo_fields. - """ - return list(self._metainfo_fields) - - def values(self) -> list: - """ - Returns: - list: Contains all values in data. - """ - return [getattr(self, k) for k in self.keys()] - - def metainfo_values(self) -> list: - """ - Returns: - list: Contains all values in metainfo. - """ - return [getattr(self, k) for k in self.metainfo_keys()] - - def all_keys(self) -> list: - """ - Returns: - list: Contains all keys in metainfo and data. - """ - return self.metainfo_keys() + self.keys() - - def all_values(self) -> list: - """ - Returns: - list: Contains all values in metainfo and data. - """ - return self.metainfo_values() + self.values() - - def all_items(self) -> Iterator[tuple[str, Any]]: - """ - Returns: - iterator: An iterator object whose element is (key, value) tuple - pairs for ``metainfo`` and ``data``. - """ - for k in self.all_keys(): - yield (k, getattr(self, k)) - - def items(self) -> Iterator[tuple[str, Any]]: - """ - Returns: - iterator: An iterator object whose element is (key, value) tuple - pairs for ``data``. - """ - for k in self.keys(): - yield (k, getattr(self, k)) - - def metainfo_items(self) -> Iterator[tuple[str, Any]]: - """ - Returns: - iterator: An iterator object whose element is (key, value) tuple - pairs for ``metainfo``. - """ - for k in self.metainfo_keys(): - yield (k, getattr(self, k)) - - @property - def metainfo(self) -> dict: - """dict: A dict contains metainfo of current data element.""" - return dict(self.metainfo_items()) - - def __setattr__(self, name: str, value: Any): - """Setattr is only used to set data.""" - if name in ("_metainfo_fields", "_data_fields"): - if not hasattr(self, name): - super().__setattr__(name, value) - else: - raise AttributeError(f"{name} has been used as a private attribute, which is immutable.") - else: - self.set_field(name=name, value=value, field_type="data", dtype=None) - - def __delattr__(self, item: str): - """Delete the item in dataelement. - - Args: - item (str): The key to delete. - """ - if item in ("_metainfo_fields", "_data_fields"): - raise AttributeError(f"{item} has been used as a private attribute, which is immutable.") - super().__delattr__(item) - if item in self._metainfo_fields: - self._metainfo_fields.remove(item) - elif item in self._data_fields: - self._data_fields.remove(item) - - # dict-like methods - __delitem__ = __delattr__ - - def get(self, key, default=None) -> Any: - """Get property in data and metainfo as the same as python.""" - # Use `getattr()` rather than `self.__dict__.get()` to allow getting - # properties. 
- return getattr(self, key, default) - - def pop(self, *args) -> Any: - """Pop property in data and metainfo as the same as python.""" - assert len(args) < 3, "``pop`` get more than 2 arguments" - name = args[0] - if name in self._metainfo_fields: - self._metainfo_fields.remove(args[0]) - return self.__dict__.pop(*args) - - elif name in self._data_fields: - self._data_fields.remove(args[0]) - return self.__dict__.pop(*args) - - # with default value - elif len(args) == 2: - return args[1] - else: - # don't just use 'self.__dict__.pop(*args)' for only popping key in - # metainfo or data - raise KeyError(f"{args[0]} is not contained in metainfo or data") - - def __contains__(self, item: str) -> bool: - """Whether the item is in dataelement. - - Args: - item (str): The key to inquire. - """ - return item in self._data_fields or item in self._metainfo_fields - - def set_field( - self, - value: Any, - name: str, - dtype: type | tuple[type, ...] | None = None, - field_type: str = "data", - ) -> None: - """Special method for set union field, used as property.setter - functions.""" - assert field_type in ["metainfo", "data"] - if dtype is not None: - assert isinstance(value, dtype), f"{value} should be a {dtype} but got {type(value)}" - - if field_type == "metainfo": - if name in self._data_fields: - raise AttributeError( - f"Cannot set {name} to be a field of metainfo because {name} is already a data field" - ) - self._metainfo_fields.add(name) - else: - if name in self._metainfo_fields: - raise AttributeError( - f"Cannot set {name} to be a field of data because {name} is already a metainfo field" - ) - self._data_fields.add(name) - super().__setattr__(name, value) - - # Tensor-like methods - def to(self, *args, **kwargs) -> "BaseDataElement": - """Apply same name function to all tensors in data_fields.""" - new_data = self.new() - for k, v in self.items(): - if hasattr(v, "to"): - v = v.to(*args, **kwargs) - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def cpu(self) -> "BaseDataElement": - """Convert all tensors to CPU in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, torch.Tensor | BaseDataElement): - v = v.cpu() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def cuda(self) -> "BaseDataElement": - """Convert all tensors to GPU in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, torch.Tensor | BaseDataElement): - v = v.cuda() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def musa(self) -> "BaseDataElement": - """Convert all tensors to musa in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, torch.Tensor | BaseDataElement): - v = v.musa() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def npu(self) -> "BaseDataElement": - """Convert all tensors to NPU in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, torch.Tensor | BaseDataElement): - v = v.npu() - data = {k: v} - new_data.set_data(data) - return new_data - - def mlu(self) -> "BaseDataElement": - """Convert all tensors to MLU in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, torch.Tensor | BaseDataElement): - v = v.mlu() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def detach(self) -> "BaseDataElement": - """Detach all tensors in data.""" - new_data = self.new() - for k, v in 
self.items(): - if isinstance(v, torch.Tensor | BaseDataElement): - v = v.detach() - data = {k: v} - new_data.set_data(data) - return new_data - - # Tensor-like methods - def numpy(self) -> "BaseDataElement": - """Convert all tensors to np.ndarray in data.""" - new_data = self.new() - for k, v in self.items(): - if isinstance(v, torch.Tensor | BaseDataElement): - v = v.detach().cpu().numpy() - data = {k: v} - new_data.set_data(data) - return new_data - - def to_tensor(self) -> "BaseDataElement": - """Convert all np.ndarray to tensor in data.""" - new_data = self.new() - for k, v in self.items(): - data = {} - if isinstance(v, np.ndarray): - v = torch.from_numpy(v) - data[k] = v - elif isinstance(v, BaseDataElement): - v = v.to_tensor() - data[k] = v - new_data.set_data(data) - return new_data - - def to_dict(self) -> dict: - """Convert BaseDataElement to dict.""" - return {k: v.to_dict() if isinstance(v, BaseDataElement) else v for k, v in self.all_items()} - - def __repr__(self) -> str: - """Represent the object.""" - - def _addindent(s_: str, num_spaces: int) -> str: - """This func is modified from `pytorch` https://github.com/pytorch/ - pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu - les/module.py#L29. - - Args: - s_ (str): The string to add spaces. - num_spaces (int): The num of space to add. - - Returns: - str: The string after add indent. - """ - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * " ") + line for line in s] - s = "\n".join(s) # type: ignore - s = first + "\n" + s # type: ignore - return s # type: ignore - - def dump(obj: Any) -> str: - """Represent the object. - - Args: - obj (Any): The obj to represent. - - Returns: - str: The represented str. - """ - _repr = "" - if isinstance(obj, dict): - for k, v in obj.items(): - _repr += f"\n{k}: {_addindent(dump(v), 4)}" - elif isinstance(obj, BaseDataElement): - _repr += "\n\n META INFORMATION" - metainfo_items = dict(obj.metainfo_items()) - _repr += _addindent(dump(metainfo_items), 4) - _repr += "\n\n DATA FIELDS" - items = dict(obj.items()) - _repr += _addindent(dump(items), 4) - classname = obj.__class__.__name__ - _repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>" - else: - _repr += repr(obj) - return _repr - - return dump(self) diff --git a/libs/visengine/visengine/structures/instance_data.py b/libs/visengine/visengine/structures/instance_data.py deleted file mode 100644 index 7c85be1..0000000 --- a/libs/visengine/visengine/structures/instance_data.py +++ /dev/null @@ -1,301 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
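-# A minimal usage sketch (illustrative only, not a test) of the tensor-like
-# helpers defined in base_data_element.py: each builds a fresh element via
-# ``new()`` with converted data fields and leaves ``metainfo`` untouched. It
-# assumes only the ``BaseDataElement`` API shown above.
-#
-#     >>> import torch
-#     >>> from visengine.structures import BaseDataElement
-#     >>> el = BaseDataElement(metainfo=dict(img_id=0), bboxes=torch.rand((3, 4)))
-#     >>> np_el = el.numpy()          # bboxes becomes np.ndarray, metainfo kept
-#     >>> isinstance(np_el.bboxes, torch.Tensor)
-#     False
-#     >>> back = np_el.to_tensor()    # np.ndarray fields back to torch.Tensor
-#     >>> isinstance(back.bboxes, torch.Tensor)
-#     True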
-import itertools -from collections.abc import Sized -from typing import Any, Union - -import numpy as np -import torch - -from visengine.device import get_device - -from .base_data_element import BaseDataElement - -BoolTypeTensor: Any -LongTypeTensor: Any - -if get_device() == "npu": - BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor] -elif get_device() == "mlu": - BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor] -elif get_device() == "musa": - BoolTypeTensor = Union[torch.BoolTensor, torch.musa.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.musa.LongTensor] -else: - BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] - LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor] - -IndexType: Any = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray] - - -# Modified from -# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py -class InstanceData(BaseDataElement): - """Data structure for instance-level annotations or predictions. - - Subclass of :class:`BaseDataElement`. All values in `data_fields` - must have the same length. This design refers to - https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 - ``InstanceData`` also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of value - in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, - or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. - - Examples: - >>> # custom data structure - >>> class TmpObject: - ... def __init__(self, tmp) -> None: - ... assert isinstance(tmp, list) - ... self.tmp = tmp - ... def __len__(self): - ... return len(self.tmp) - ... def __getitem__(self, item): - ... if isinstance(item, int): - ... if item >= len(self) or item < -len(self): # type:ignore - ... raise IndexError(f'Index {item} out of range!') - ... else: - ... # keep the dimension - ... item = slice(item, None, len(self)) - ... return TmpObject(self.tmp[item]) - ... @staticmethod - ... def cat(tmp_objs): - ... assert all(isinstance(results, TmpObject) for results in tmp_objs) - ... if len(tmp_objs) == 1: - ... return tmp_objs[0] - ... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] - ... tmp_list = list(itertools.chain(*tmp_list)) - ... new_data = TmpObject(tmp_list) - ... return new_data - ... def __repr__(self): - ... 
return str(self.tmp) - >>> from visengine.structures import InstanceData - >>> import numpy as np - >>> import torch - >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) - >>> instance_data = InstanceData(metainfo=img_meta) - >>> 'img_shape' in instance_data - True - >>> instance_data.det_labels = torch.LongTensor([2, 3]) - >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) - >>> instance_data.bboxes = torch.rand((2, 4)) - >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) - >>> len(instance_data) - 2 - >>> print(instance_data) - - >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] - >>> sorted_results.det_scores - tensor([0.7000, 0.8000]) - >>> print(instance_data[instance_data.det_scores > 0.75]) - - >>> print(instance_data[instance_data.det_scores > 1]) - - >>> print(instance_data.cat([instance_data, instance_data])) - - """ - - def __setattr__(self, name: str, value: Sized): - """Setattr is only used to set data. - - The value must have the attribute of `__len__` and have the same length - of `InstanceData`. - """ - if name in ("_metainfo_fields", "_data_fields"): - if not hasattr(self, name): - super().__setattr__(name, value) - else: - raise AttributeError(f"{name} has been used as a private attribute, which is immutable.") - - else: - assert isinstance(value, Sized), "value must contain `__len__` attribute" - - if len(self) > 0: - assert len(value) == len(self), ( - f"The length of values {len(value)} is not consistent with the length of this :obj:`InstanceData` {len(self)}" - ) - super().__setattr__(name, value) - - __setitem__ = __setattr__ - - def __getitem__(self, item: IndexType) -> "InstanceData": - """ - Args: - item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`, - :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`): - Get the corresponding values according to item. - - Returns: - :obj:`InstanceData`: Corresponding values. - """ - assert isinstance(item, IndexType.__args__) - if isinstance(item, list): - item = np.array(item) - if isinstance(item, np.ndarray): - # The default int type of numpy is platform dependent, int32 for - # windows and int64 for linux. `torch.Tensor` requires the index - # should be int64, therefore we simply convert it to int64 here. - # More details in https://github.com/numpy/numpy/issues/9464 - item = item.astype(np.int64) if item.dtype == np.int32 else item - item = torch.from_numpy(item) - - if isinstance(item, str): - return getattr(self, item) - - if isinstance(item, int): - if item >= len(self) or item < -len(self): # type:ignore - raise IndexError(f"Index {item} out of range!") - else: - # keep the dimension - item = slice(item, None, len(self)) - - new_data = self.__class__(metainfo=self.metainfo) - if isinstance(item, torch.Tensor): - assert item.dim() == 1, "Only support to get the values along the first dimension." - if isinstance(item, BoolTypeTensor.__args__): - assert len(item) == len(self), ( - "The shape of the " - "input(BoolTensor) " - f"{len(item)} " - "does not match the shape " - "of the indexed tensor " - "in results_field " - f"{len(self)} at " - "first dimension." 
- ) - - for k, v in self.items(): - if isinstance(v, torch.Tensor): - new_data[k] = v[item] - elif isinstance(v, np.ndarray): - new_data[k] = v[item.cpu().numpy()] - elif isinstance(v, str | list | tuple) or (hasattr(v, "__getitem__") and hasattr(v, "cat")): - # convert to indexes from BoolTensor - if isinstance(item, BoolTypeTensor.__args__): - indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist() - else: - indexes = item.cpu().numpy().tolist() - slice_list = [] - if indexes: - for index in indexes: - slice_list.append(slice(index, None, len(v))) - else: - slice_list.append(slice(None, 0, None)) - r_list = [v[s] for s in slice_list] - if isinstance(v, str | list | tuple): - new_value = r_list[0] - for r in r_list[1:]: - new_value = new_value + r - else: - new_value = v.cat(r_list) - new_data[k] = new_value - else: - raise ValueError( - f"The type of `{k}` is `{type(v)}`, which has no attribute of `cat`, so it does not support slice with `bool`" - ) - - else: - # item is a slice - for k, v in self.items(): - new_data[k] = v[item] - return new_data # type:ignore - - @staticmethod - def cat(instances_list: list["InstanceData"]) -> "InstanceData": - """Concat the instances of all :obj:`InstanceData` in the list. - - Note: To ensure that cat returns as expected, make sure that - all elements in the list must have exactly the same keys. - - Args: - instances_list (list[:obj:`InstanceData`]): A list - of :obj:`InstanceData`. - - Returns: - :obj:`InstanceData` - """ - assert all(isinstance(results, InstanceData) for results in instances_list) - assert len(instances_list) > 0 - if len(instances_list) == 1: - return instances_list[0] - - # metainfo and data_fields must be exactly the - # same for each element to avoid exceptions. - field_keys_list = [instances.all_keys() for instances in instances_list] - assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( - set(itertools.chain(*field_keys_list)) - ) == len(field_keys_list[0]), ( - "There are different keys in " - "`instances_list`, which may " - "cause the cat operation " - "to fail. Please make sure all " - "elements in `instances_list` " - "have the exact same key." - ) - - new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) - for k in instances_list[0].keys(): - values = [results[k] for results in instances_list] - v0 = values[0] - if isinstance(v0, torch.Tensor): - new_values = torch.cat(values, dim=0) - elif isinstance(v0, np.ndarray): - new_values = np.concatenate(values, axis=0) - elif isinstance(v0, str | list | tuple): - new_values = v0[:] - for v in values[1:]: - new_values += v - elif hasattr(v0, "cat"): - new_values = v0.cat(values) - else: - raise ValueError(f"The type of `{k}` is `{type(v0)}` which has no attribute of `cat`") - new_data[k] = new_values - return new_data # type:ignore - - def __len__(self) -> int: - """int: The length of InstanceData.""" - if len(self._data_fields) > 0: - return len(self.values()[0]) - else: - return 0 diff --git a/libs/visengine/visengine/structures/label_data.py b/libs/visengine/visengine/structures/label_data.py deleted file mode 100644 index 720b025..0000000 --- a/libs/visengine/visengine/structures/label_data.py +++ /dev/null @@ -1,46 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
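-# A minimal sketch (illustrative only, not a test) of the boolean indexing and
-# ``cat`` behaviour implemented by ``InstanceData`` above; it assumes only the
-# API defined in instance_data.py.
-#
-#     >>> import torch
-#     >>> from visengine.structures import InstanceData
-#     >>> data = InstanceData(metainfo=dict(img_shape=(800, 1196, 3)))
-#     >>> data.det_scores = torch.tensor([0.8, 0.3])
-#     >>> data.det_labels = torch.tensor([1, 2])
-#     >>> len(data[data.det_scores > 0.5])     # mask keeps row 0 of every field
-#     1
-#     >>> len(InstanceData.cat([data, data]))  # fields concatenated along dim 0
-#     4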
- -import torch - -from .base_data_element import BaseDataElement - - -class LabelData(BaseDataElement): - """Data structure for label-level annotations or predictions.""" - - @staticmethod - def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor: - """Convert the one-hot input to label. - - Args: - onehot (torch.Tensor, optional): The one-hot input. The format - of input must be one-hot. - - Returns: - torch.Tensor: The converted results. - """ - assert isinstance(onehot, torch.Tensor) - if onehot.ndim == 1 and onehot.max().item() <= 1 and onehot.min().item() >= 0: - return onehot.nonzero().squeeze(-1) - else: - raise ValueError("input is not one-hot and can not convert to label") - - @staticmethod - def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor: - """Convert the label-format input to one-hot. - - Args: - label (torch.Tensor): The label-format input. The format - of item must be label-format. - num_classes (int): The number of classes. - - Returns: - torch.Tensor: The converted results. - """ - assert isinstance(label, torch.Tensor) - onehot = label.new_zeros((num_classes,)) - assert max(label, default=torch.tensor(0)).item() < num_classes - onehot[label] = 1 - return onehot diff --git a/libs/visengine/visengine/structures/pixel_data.py b/libs/visengine/visengine/structures/pixel_data.py deleted file mode 100644 index 052fe9d..0000000 --- a/libs/visengine/visengine/structures/pixel_data.py +++ /dev/null @@ -1,122 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import warnings -from collections.abc import Sequence - -import numpy as np -import torch - -from .base_data_element import BaseDataElement - - -class PixelData(BaseDataElement): - """Data structure for pixel-level annotations or predictions. - - All data items in ``data_fields`` of ``PixelData`` meet the following - requirements: - - - They all have 3 dimensions in orders of channel, height, and width. - - They should have the same height and width. - - Examples: - >>> metainfo = dict( - ... img_id=random.randint(0, 100), - ... img_shape=(random.randint(400, 600), random.randint(400, 600))) - >>> image = np.random.randint(0, 255, (4, 20, 40)) - >>> featmap = torch.randint(0, 255, (10, 20, 40)) - >>> pixel_data = PixelData(metainfo=metainfo, - ... image=image, - ... featmap=featmap) - >>> print(pixel_data.shape) - (20, 40) - - >>> # slice - >>> slice_data = pixel_data[10:20, 20:40] - >>> assert slice_data.shape == (10, 20) - >>> slice_data = pixel_data[10, 20] - >>> assert slice_data.shape == (1, 1) - - >>> # set - >>> pixel_data.map3 = torch.randint(0, 255, (20, 40)) - >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40) - >>> with self.assertRaises(AssertionError): - ... # The dimension must be 3 or 2 - ... pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40)) - """ - - def __setattr__(self, name: str, value: torch.Tensor | np.ndarray): - """Set attributes of ``PixelData``. - - If the dimension of value is 2 and its shape meet the demand, it - will automatically expand its channel-dimension. - - Args: - name (str): The key to access the value, stored in `PixelData`. - value (Union[torch.Tensor, np.ndarray]): The value to store in. - The type of value must be `torch.Tensor` or `np.ndarray`, - and its shape must meet the requirements of `PixelData`. 
- """ - if name in ("_metainfo_fields", "_data_fields"): - if not hasattr(self, name): - super().__setattr__(name, value) - else: - raise AttributeError(f"{name} has been used as a private attribute, which is immutable.") - - else: - assert isinstance(value, torch.Tensor | np.ndarray), ( - f"Can not set {type(value)}, only support {(torch.Tensor, np.ndarray)}" - ) - - if self.shape: - assert tuple(value.shape[-2:]) == self.shape, ( - f"The height and width of values {tuple(value.shape[-2:])} is not consistent with the shape of this :obj:`PixelData` {self.shape}" - ) - assert value.ndim in [2, 3], f"The dim of value must be 2 or 3, but got {value.ndim}" - if value.ndim == 2: - value = value[None] - warnings.warn( - f"The shape of value will convert from {value.shape[-2:]} to {value.shape}", - stacklevel=2, - ) - super().__setattr__(name, value) - - # TODO torch.Long/bool - def __getitem__(self, item: Sequence[int | slice]) -> "PixelData": - """ - Args: - item (Sequence[Union[int, slice]]): Get the corresponding values - according to item. - - Returns: - :obj:`PixelData`: Corresponding values. - """ - - new_data = self.__class__(metainfo=self.metainfo) - if isinstance(item, tuple): - assert len(item) == 2, "Only support to slice height and width" - tmp_item: list[slice] = [] - for index, single_item in enumerate(item[::-1]): - if isinstance(single_item, int): - tmp_item.insert(0, slice(single_item, None, self.shape[-index - 1])) - elif isinstance(single_item, slice): - tmp_item.insert(0, single_item) - else: - raise TypeError(f"The type of element in input must be int or slice, but got {type(single_item)}") - tmp_item.insert(0, slice(None, None, None)) - item = tuple(tmp_item) - for k, v in self.items(): - setattr(new_data, k, v[item]) - else: - raise TypeError(f"Unsupported type {type(item)} for slicing PixelData") - return new_data - - @property - def shape(self): - """The shape of pixel data.""" - if len(self._data_fields) > 0: - return tuple(self.values()[0].shape[-2:]) - else: - return None - - # TODO padding, resize diff --git a/libs/visengine/visengine/utils/__init__.py b/libs/visengine/visengine/utils/__init__.py deleted file mode 100644 index a8f7c00..0000000 --- a/libs/visengine/visengine/utils/__init__.py +++ /dev/null @@ -1,104 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
-from .manager import ManagerMeta, ManagerMixin -from .misc import ( - apply_to, - check_prerequisites, - concat_list, - deprecated_api_warning, - deprecated_function, - get_object_from_string, - has_method, - import_modules_from_strings, - is_list_of, - is_method_overridden, - is_seq_of, - is_str, - is_tuple_of, - iter_cast, - list_cast, - requires_executable, - requires_package, - slice_list, - to_1tuple, - to_2tuple, - to_3tuple, - to_4tuple, - to_ntuple, - tuple_cast, -) -from .package_utils import ( - call_command, - get_installed_path, - install_package, - is_installed, -) -from .path import ( - check_file_exist, - fopen, - is_abs, - is_filepath, - mkdir_or_exist, - scandir, - symlink, -) -from .progressbar import ( - ProgressBar, - track_iter_progress, - track_parallel_progress, - track_progress, -) -from .progressbar_rich import track_progress_rich -from .timer import Timer, TimerError, check_time -from .version_utils import digit_version, get_git_hash - -__all__ = [ - "ManagerMeta", - "ManagerMixin", - "ProgressBar", - "Timer", - "TimerError", - "apply_to", - "call_command", - "check_file_exist", - "check_prerequisites", - "check_time", - "concat_list", - "deprecated_api_warning", - "deprecated_function", - "digit_version", - "fopen", - "get_git_hash", - "get_installed_path", - "get_object_from_string", - "has_method", - "import_modules_from_strings", - "install_package", - "is_abs", - "is_filepath", - "is_installed", - "is_list_of", - "is_method_overridden", - "is_seq_of", - "is_str", - "is_tuple_of", - "iter_cast", - "list_cast", - "mkdir_or_exist", - "requires_executable", - "requires_package", - "scandir", - "slice_list", - "symlink", - "to_1tuple", - "to_2tuple", - "to_3tuple", - "to_4tuple", - "to_ntuple", - "track_iter_progress", - "track_parallel_progress", - "track_progress", - "track_progress_rich", - "tuple_cast", -] diff --git a/libs/visengine/visengine/utils/dl_utils/__init__.py b/libs/visengine/visengine/utils/dl_utils/__init__.py deleted file mode 100644 index 59b04cc..0000000 --- a/libs/visengine/visengine/utils/dl_utils/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch - -from .collect_env import collect_env -from .hub import load_url -from .misc import has_batch_norm, is_norm, tensor2imgs -from .setup_env import set_multi_processing -from .torch_ops import torch_meshgrid -from .trace import is_jit_tracing - -TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2]) - -__all__ = [ - "TORCH_VERSION", - "collect_env", - "has_batch_norm", - "is_jit_tracing", - "is_norm", - "load_url", - "set_multi_processing", - "tensor2imgs", - "torch_meshgrid", -] diff --git a/libs/visengine/visengine/utils/dl_utils/collect_env.py b/libs/visengine/visengine/utils/dl_utils/collect_env.py deleted file mode 100644 index 1681110..0000000 --- a/libs/visengine/visengine/utils/dl_utils/collect_env.py +++ /dev/null @@ -1,55 +0,0 @@ -import sys -from collections import OrderedDict, defaultdict - -import cv2 -import numpy as np -import torch -import torchvision - -from visengine.device import is_cuda_available -from visengine.version import __version__ as visengine_version - - -def collect_env(): - """Collect the information of the running environments. - - Returns: - dict: The environment information. The following fields are contained. - - - sys.platform: The variable of ``sys.platform``. - - Python: Python version. - - CUDA available: Bool, indicating if CUDA is available. - - GPU devices: Device type of each GPU. - - CUDA_HOME (optional): The env var ``CUDA_HOME``. 
- - numpy_random_seed: The seed recorded from ``np.random.get_state()``. - - PyTorch: PyTorch version. - - TorchVision: TorchVision version. - - OpenCV: OpenCV version. - - VisEngine: VisEngine version. - """ - - env_info = OrderedDict() - env_info["sys.platform"] = sys.platform - env_info["Python"] = sys.version.replace("\n", "") - - cuda_available = is_cuda_available() - env_info["CUDA available"] = cuda_available - env_info["numpy_random_seed"] = np.random.get_state()[1][0] - - if cuda_available: - devices = defaultdict(list) - for k in range(torch.cuda.device_count()): - devices[torch.cuda.get_device_name(k)].append(str(k)) - for name, device_ids in devices.items(): - env_info["GPU " + ",".join(device_ids)] = name - - env_info["PyTorch"] = torch.__version__ - env_info["TorchVision"] = torchvision.__version__ - env_info["OpenCV"] = cv2.__version__ - env_info["VisEngine"] = visengine_version - - return env_info diff --git a/libs/visengine/visengine/utils/dl_utils/hub.py b/libs/visengine/visengine/utils/dl_utils/hub.py deleted file mode 100644 index b00abca..0000000 --- a/libs/visengine/visengine/utils/dl_utils/hub.py +++ /dev/null @@ -1,36 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. - -from typing import Any - -from ..path import mkdir_or_exist -from torch.hub import load_state_dict_from_url as _load_state_dict_from_url - -__all__ = ["mkdir_or_exist", "load_url"] - - -def load_url( - url: str, - model_dir: str | None = None, - map_location: Any | None = None, - progress: bool = True, - check_hash: bool = False, - file_name: str | None = None, - **kwargs: Any, -): - """Compat shim that delegates to ``torch.hub.load_state_dict_from_url``. - - Accepts the legacy ``torch.utils.model_zoo.load_url`` signature so existing - call sites remain unchanged while benefiting from the maintained API. - """ - - return _load_state_dict_from_url( - url, - model_dir=model_dir, - map_location=map_location, - progress=progress, - check_hash=check_hash, - file_name=file_name, - **kwargs, - ) diff --git a/libs/visengine/visengine/utils/dl_utils/misc.py b/libs/visengine/visengine/utils/dl_utils/misc.py deleted file mode 100644 index 8ed330e..0000000 --- a/libs/visengine/visengine/utils/dl_utils/misc.py +++ /dev/null @@ -1,111 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import pkgutil - -import numpy as np -import torch -import torch.nn as nn - -from ..misc import is_tuple_of -from torch.nn.modules.batchnorm import _BatchNorm - - -def is_norm(layer: nn.Module, exclude: type | tuple[type] | None = None) -> bool: - """Check if a layer is a normalization layer. - - Args: - layer (nn.Module): The layer to be checked. - exclude (type, tuple[type], optional): Types to be excluded. - - Returns: - bool: Whether the layer is a norm layer.
- """ - if exclude is not None: - if not isinstance(exclude, tuple): - exclude = (exclude,) - if not is_tuple_of(exclude, type): - raise TypeError( - f'"exclude" must be either None or type or a tuple of types, but got {type(exclude)}: {exclude}' - ) - - if exclude and isinstance(layer, exclude): - return False - - all_norm_bases = (nn.GroupNorm, nn.LayerNorm) - return isinstance(layer, all_norm_bases) - - -def tensor2imgs( - tensor: torch.Tensor, - mean: tuple[float, float, float] | None = None, - std: tuple[float, float, float] | None = None, - to_bgr: bool = True, -): - """Convert tensor to 3-channel images or 1-channel gray images. - - Args: - tensor (torch.Tensor): Tensor that contains multiple images, shape ( - N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format - should be RGB. - mean (tuple[float], optional): Mean of images. If None, - (0, 0, 0) will be used for tensor with 3-channel, - while (0, ) for tensor with 1-channel. Defaults to None. - std (tuple[float], optional): Standard deviation of images. If None, - (1, 1, 1) will be used for tensor with 3-channel, - while (1, ) for tensor with 1-channel. Defaults to None. - to_bgr (bool): For the tensor with 3 channel, convert its format to - BGR. For the tensor with 1 channel, it must be False. Defaults to - True. - - Returns: - list[np.ndarray]: A list that contains multiple images. - """ - - assert torch.is_tensor(tensor) and tensor.ndim == 4 - channels = tensor.size(1) - assert channels in [1, 3] - if mean is None: - mean = (0,) * channels - if std is None: - std = (1,) * channels - assert (channels == len(mean) == len(std) == 3) or (channels == len(mean) == len(std) == 1 and not to_bgr) - mean = tensor.new_tensor(mean).view(1, -1) - std = tensor.new_tensor(std).view(1, -1) - tensor = tensor.permute(0, 2, 3, 1) * std + mean - imgs = tensor.detach().cpu().numpy() - if to_bgr and channels == 3: - imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR - imgs = [np.ascontiguousarray(img) for img in imgs] - return imgs - - -def has_batch_norm(model: nn.Module) -> bool: - """Detect whether model has a BatchNormalization layer. - - Args: - model (nn.Module): training model. - - Returns: - bool: whether model has a BatchNormalization layer - """ - if isinstance(model, _BatchNorm): - return True - for m in model.children(): - if has_batch_norm(m): - return True - return False - - -def mmcv_full_available() -> bool: - """Check whether mmcv-full is installed. - - Returns: - bool: True if mmcv-full is installed else False. - """ - try: - import mmcv # noqa: F401 - except ImportError: - return False - ext_loader = pkgutil.find_loader("mmcv._ext") - return ext_loader is not None diff --git a/libs/visengine/visengine/utils/dl_utils/setup_env.py b/libs/visengine/visengine/utils/dl_utils/setup_env.py deleted file mode 100644 index 33c82eb..0000000 --- a/libs/visengine/visengine/utils/dl_utils/setup_env.py +++ /dev/null @@ -1,71 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import os -import platform -import warnings - -import torch.multiprocessing as mp - - -def set_multi_processing( - mp_start_method: str = "fork", - opencv_num_threads: int = 0, - distributed: bool = False, -) -> None: - """Set multi-processing related environment. - - Args: - mp_start_method (str): Set the method which should be used to start - child processes. Defaults to 'fork'. - opencv_num_threads (int): Number of threads for opencv. - Defaults to 0. - distributed (bool): True if distributed environment. - Defaults to False. 
- """ - # set multi-process start method as `fork` to speed up the training - if platform.system() != "Windows": - current_method = mp.get_start_method(allow_none=True) - if current_method is not None and current_method != mp_start_method: - warnings.warn( - f"Multi-processing start method `{mp_start_method}` is " - f"different from the previous setting `{current_method}`." - f"It will be force set to `{mp_start_method}`. You can " - "change this behavior by changing `mp_start_method` in " - "your config.", - stacklevel=2, - ) - mp.set_start_method(mp_start_method, force=True) - - try: - import cv2 - - # disable opencv multithreading to avoid system being overloaded - cv2.setNumThreads(opencv_num_threads) - except ImportError: - pass - - # setup OMP threads - # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py - if "OMP_NUM_THREADS" not in os.environ and distributed: - omp_num_threads = 1 - warnings.warn( - "Setting OMP_NUM_THREADS environment variable for each process" - f" to be {omp_num_threads} in default, to avoid your system " - "being overloaded, please further tune the variable for " - "optimal performance in your application as needed.", - stacklevel=2, - ) - os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) - - # setup MKL threads - if "MKL_NUM_THREADS" not in os.environ and distributed: - mkl_num_threads = 1 - warnings.warn( - "Setting MKL_NUM_THREADS environment variable for each process" - f" to be {mkl_num_threads} in default, to avoid your system " - "being overloaded, please further tune the variable for " - "optimal performance in your application as needed.", - stacklevel=2, - ) - os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads) diff --git a/libs/visengine/visengine/utils/dl_utils/time_counter.py b/libs/visengine/visengine/utils/dl_utils/time_counter.py deleted file mode 100644 index ba5acdc..0000000 --- a/libs/visengine/visengine/utils/dl_utils/time_counter.py +++ /dev/null @@ -1,135 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import time - -import torch - -from visengine.device import is_cuda_available -from visengine.dist.utils import master_only -from visengine.logging import MMLogger, print_log - - -class TimeCounter: - """A tool that counts the average running time of a function or a method. - Users can use it as a decorator or context manager to calculate the average - running time of code blocks. - - Args: - log_interval (int): The interval of logging. Defaults to 1. - warmup_interval (int): The interval of warmup. Defaults to 1. - with_sync (bool): Whether to synchronize cuda. Defaults to True. - tag (str, optional): Function tag. Used to distinguish between - different functions or methods being called. Defaults to None. - logger (MMLogger, optional): Formatted logger used to record messages. - Defaults to None. - - Examples: - >>> import time - >>> from visengine.utils.dl_utils import TimeCounter - >>> @TimeCounter() - ... def fun1(): - ... time.sleep(0.1) - ... fun1() - [fun1]-time per run averaged in the past 1 runs: 100.0 ms - - >>> @@TimeCounter(log_interval=2, tag='fun') - ... def fun2(): - ... time.sleep(0.2) - >>> for _ in range(3): - ... fun2() - [fun]-time per run averaged in the past 2 runs: 200.0 ms - - >>> with TimeCounter(tag='fun3'): - ... 
time.sleep(0.3) - [fun3]-time per run averaged in the past 1 runs: 300.0 ms - """ - - instance_dict: dict = {} - - log_interval: int - warmup_interval: int - logger: MMLogger | None - __count: int - __pure_inf_time: float - - def __new__( - cls, - log_interval: int = 1, - warmup_interval: int = 1, - with_sync: bool = True, - tag: str | None = None, - logger: MMLogger | None = None, - ): - assert warmup_interval >= 1 - if tag is not None and tag in cls.instance_dict: - return cls.instance_dict[tag] - - instance = super().__new__(cls) - cls.instance_dict[tag] = instance - - instance.log_interval = log_interval - instance.warmup_interval = warmup_interval - instance.with_sync = with_sync # type: ignore - instance.tag = tag - instance.logger = logger - - instance.__count = 0 - instance.__pure_inf_time = 0.0 - instance.__start_time = 0.0 - - return instance - - @master_only - def __call__(self, fn): - if self.tag is None: - self.tag = fn.__name__ - - def wrapper(*args, **kwargs): - self.__count += 1 - - if self.with_sync and is_cuda_available(): - torch.cuda.synchronize() - start_time = time.perf_counter() - - result = fn(*args, **kwargs) - - if self.with_sync and is_cuda_available(): - torch.cuda.synchronize() - elapsed = time.perf_counter() - start_time - self.print_time(elapsed) - - return result - - return wrapper - - @master_only - def __enter__(self): - assert self.tag is not None, ( - "In order to clearly distinguish printing information in different contexts, please specify the tag parameter" - ) - - self.__count += 1 - - if self.with_sync and torch.cuda.is_available(): - torch.cuda.synchronize() - self.__start_time = time.perf_counter() - - @master_only - def __exit__(self, exc_type, exc_val, exc_tb): - if self.with_sync and torch.cuda.is_available(): - torch.cuda.synchronize() - elapsed = time.perf_counter() - self.__start_time - self.print_time(elapsed) - - def print_time(self, elapsed: int | float) -> None: - """Print times per count.""" - if self.__count >= self.warmup_interval: - self.__pure_inf_time += elapsed - - if self.__count % self.log_interval == 0: - times_per_count = 1000 * self.__pure_inf_time / (self.__count - self.warmup_interval + 1) - print_log( - f"[{self.tag}]-time per run averaged in the past {self.__count} runs: {times_per_count:.1f} ms", - self.logger, - ) diff --git a/libs/visengine/visengine/utils/dl_utils/torch_ops.py b/libs/visengine/visengine/utils/dl_utils/torch_ops.py deleted file mode 100644 index 4b9de92..0000000 --- a/libs/visengine/visengine/utils/dl_utils/torch_ops.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - - -def torch_meshgrid(*tensors): - """A wrapper of torch.meshgrid to compat different PyTorch versions. - - Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``. - So we implement a wrapper here to avoid warning when using high-version - PyTorch and avoid compatibility issues when using previous versions of - PyTorch. - - Args: - tensors (List[Tensor]): List of scalars or 1 dimensional tensors. - - Returns: - Sequence[Tensor]: Sequence of meshgrid tensors. 
- """ - return torch.meshgrid(*tensors, indexing="ij") diff --git a/libs/visengine/visengine/utils/dl_utils/trace.py b/libs/visengine/visengine/utils/dl_utils/trace.py deleted file mode 100644 index 48ba268..0000000 --- a/libs/visengine/visengine/utils/dl_utils/trace.py +++ /dev/null @@ -1,6 +0,0 @@ -import torch - - -# This is not worth a whole file -def is_jit_tracing() -> bool: - return torch.jit.is_tracing() diff --git a/libs/visengine/visengine/utils/dl_utils/visualize.py b/libs/visengine/visengine/utils/dl_utils/visualize.py deleted file mode 100644 index 6b58ec5..0000000 --- a/libs/visengine/visengine/utils/dl_utils/visualize.py +++ /dev/null @@ -1,63 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import patch - -import torch -import torch.nn as nn - -from visengine.model import BaseModel -from visengine.registry import MODELS - - -@MODELS.register_module(force=True) -class ToyModel(BaseModel): - def __init__(self, *args, **kwargs): - super().__init__() - self.conv = nn.Conv2d(1, 1, 1) - - def forward(self, *args, **kwargs): - return {"loss": torch.tensor(0.0)} - - -def update_params_step(self, loss): - pass - - -def runtimeinfo_step(self, runner, batch_idx, data_batch=None): - runner.message_hub.update_info("iter", runner.iter) - lr_dict = runner.optim_wrapper.get_lr() - for name, lr in lr_dict.items(): - runner.message_hub.update_scalar(f"train/{name}", lr[0]) - - momentum_dict = runner.optim_wrapper.get_momentum() - for name, momentum in momentum_dict.items(): - runner.message_hub.update_scalar(f"train/{name}", momentum[0]) - - -@patch("visengine.optim.optimizer.OptimWrapper.update_params", update_params_step) -@patch("visengine.hooks.RuntimeInfoHook.before_train_iter", runtimeinfo_step) -def fake_run(cfg): - from visengine.runner import Runner - - cfg.pop("model") - cfg.pop("visualizer") - cfg.pop("val_dataloader") - cfg.pop("val_evaluator") - cfg.pop("val_cfg") - cfg.pop("test_dataloader") - cfg.pop("test_evaluator") - cfg.pop("test_cfg") - extra_cfg = { - "model": {"type": "ToyModel"}, - "visualizer": { - "type": "Visualizer", - "vis_backends": [{"type": "TensorboardVisBackend", "save_dir": "temp_dir"}], - }, - } - cfg.merge_from_dict(extra_cfg) - # build the runner from config - runner = Runner.from_cfg(cfg) - - # start training - runner.train() diff --git a/libs/visengine/visengine/utils/manager.py b/libs/visengine/visengine/utils/manager.py deleted file mode 100644 index bc0c259..0000000 --- a/libs/visengine/visengine/utils/manager.py +++ /dev/null @@ -1,169 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import threading -import warnings -from collections import OrderedDict -from typing import TypeVar - -_lock = threading.RLock() -T = TypeVar("T") - - -def _accquire_lock() -> None: - """Acquire the module-level lock for serializing access to shared data. - - This should be released with _release_lock(). - """ - if _lock: - _lock.acquire() - - -def _release_lock() -> None: - """Release the module-level lock acquired by calling _accquire_lock().""" - if _lock: - _lock.release() - - -class ManagerMeta(type): - """The metaclass for global accessible class. - - The subclasses inheriting from ``ManagerMeta`` will manage their - own ``_instance_dict`` and root instances. The constructors of subclasses - must contain the ``name`` argument. 
- - Examples: - >>> class SubClass1(metaclass=ManagerMeta): - >>> def __init__(self, *args, **kwargs): - >>> pass - AssertionError: <class '__main__.SubClass1'>.__init__ must have the - name argument. - >>> class SubClass2(metaclass=ManagerMeta): - >>> def __init__(self, name): - >>> pass - >>> # valid format. - """ - - def __init__(cls, *args): - cls._instance_dict = OrderedDict() - params = inspect.getfullargspec(cls) - params_names = params[0] if params[0] else [] - assert "name" in params_names, f"{cls} must have the `name` argument" - super().__init__(*args) - - -class ManagerMixin(metaclass=ManagerMeta): - """``ManagerMixin`` is the base class for classes that have global access - requirements. - - The subclasses inheriting from ``ManagerMixin`` can get their - global instances. - - Examples: - >>> class GlobalAccessible(ManagerMixin): - >>> def __init__(self, name=''): - >>> super().__init__(name) - >>> - >>> GlobalAccessible.get_instance('name') - >>> instance_1 = GlobalAccessible.get_instance('name') - >>> instance_2 = GlobalAccessible.get_instance('name') - >>> assert id(instance_1) == id(instance_2) - - Args: - name (str): Name of the instance. Defaults to ''. - """ - - def __init__(self, name: str = "", **kwargs): - assert isinstance(name, str) and name, "name argument must be a non-empty string." - self._instance_name = name - - @classmethod - def get_instance(cls: type[T], name: str, **kwargs) -> T: - """Get subclass instance by name if the name exists. - - If the corresponding name instance has not been created, ``get_instance`` - will create an instance, otherwise ``get_instance`` will return the - corresponding instance. - - Examples: - >>> instance1 = GlobalAccessible.get_instance('name1') - >>> # Create name1 instance. - >>> instance1.instance_name - name1 - >>> instance2 = GlobalAccessible.get_instance('name1') - >>> # Get name1 instance. - >>> assert id(instance1) == id(instance2) - - Args: - name (str): Name of instance. - - Returns: - object: Corresponding name instance, the latest instance, or root - instance. - """ - _accquire_lock() - assert isinstance(name, str), f"type of name should be str, but got {type(name)}" - instance_dict = cls._instance_dict # type: ignore - # Get the instance by name. - if name not in instance_dict: - instance = cls(name=name, **kwargs) # type: ignore - instance_dict[name] = instance # type: ignore - elif kwargs: - warnings.warn( - f"{cls} instance named {name} has already been created, so the method `get_instance` should not accept any other arguments", - stacklevel=2, - ) - # Get latest instantiated instance or root instance. - _release_lock() - return instance_dict[name] - - @classmethod - def get_current_instance(cls): - """Get latest created instance. - - Before calling ``get_current_instance``, the subclass must have called - ``get_instance(xxx)`` at least once. - - Examples: - >>> instance = GlobalAccessible.get_current_instance() - RuntimeError: Before calling GlobalAccessible.get_current_instance(), you should call get_instance(name=xxx) at least once. - >>> instance = GlobalAccessible.get_instance('name1') - >>> instance.instance_name - name1 - >>> instance = GlobalAccessible.get_current_instance() - >>> instance.instance_name - name1 - - Returns: - object: Latest created instance. - """ - _accquire_lock() - if not cls._instance_dict: - raise RuntimeError( - f"Before calling {cls.__name__}.get_current_instance(), you should call get_instance(name=xxx) at least once."
- ) - name = next(iter(reversed(cls._instance_dict))) - _release_lock() - return cls._instance_dict[name] - - @classmethod - def check_instance_created(cls, name: str) -> bool: - """Check whether the name corresponding instance exists. - - Args: - name (str): Name of instance. - - Returns: - bool: Whether the name corresponding instance exists. - """ - return name in cls._instance_dict - - @property - def instance_name(self) -> str: - """Get the name of instance. - - Returns: - str: Name of instance. - """ - return self._instance_name diff --git a/libs/visengine/visengine/utils/misc.py b/libs/visengine/visengine/utils/misc.py deleted file mode 100644 index 9dcb888..0000000 --- a/libs/visengine/visengine/utils/misc.py +++ /dev/null @@ -1,533 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import collections.abc -import functools -import itertools -import logging -import re -import subprocess -import textwrap -import warnings -from collections import abc -from collections.abc import Callable -from importlib import import_module -from inspect import getfullargspec, ismodule -from itertools import repeat -from typing import Any - - -# From PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = _ntuple - - -def is_str(x): - """Whether the input is an string instance. - - Note: This method is deprecated since python 2 is no longer supported. - """ - return isinstance(x, str) - - -def import_modules_from_strings(imports, allow_failed_imports=False): - """Import modules from the given list of strings. - - Args: - imports (list | str | None): The given module names to be imported. - allow_failed_imports (bool): If True, the failed imports will return - None. Otherwise, an ImportError is raise. Defaults to False. - - Returns: - list[module] | module | None: The imported modules. - - Examples: - >>> osp, sys = import_modules_from_strings( - ... ['os.path', 'sys']) - >>> import os.path as osp_ - >>> import sys as sys_ - >>> assert osp == osp_ - >>> assert sys == sys_ - """ - if not imports: - return - single_import = False - if isinstance(imports, str): - single_import = True - imports = [imports] - if not isinstance(imports, list): - raise TypeError(f"custom_imports must be a list but got type {type(imports)}") - imported = [] - for imp in imports: - if not isinstance(imp, str): - raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.") - try: - imported_tmp = import_module(imp) - except ImportError: - if allow_failed_imports: - warnings.warn(f"{imp} failed to import and is ignored.", UserWarning, stacklevel=2) - imported_tmp = None - else: - raise ImportError(f"Failed to import {imp}") - imported.append(imported_tmp) - if single_import: - imported = imported[0] - return imported - - -def iter_cast(inputs, dst_type, return_type=None): - """Cast elements of an iterable object into some type. - - Args: - inputs (Iterable): The input object. - dst_type (type): Destination type. - return_type (type, optional): If specified, the output object will be - converted to this type, otherwise an iterator. - - Returns: - iterator or specified type: The converted object. 
- """ - if not isinstance(inputs, abc.Iterable): - raise TypeError("inputs must be an iterable object") - if not isinstance(dst_type, type): - raise TypeError('"dst_type" must be a valid type') - - out_iterable = map(dst_type, inputs) - - if return_type is None: - return out_iterable - else: - return return_type(out_iterable) - - -def list_cast(inputs, dst_type): - """Cast elements of an iterable object into a list of some type. - - A partial method of :func:`iter_cast`. - """ - return iter_cast(inputs, dst_type, return_type=list) - - -def tuple_cast(inputs, dst_type): - """Cast elements of an iterable object into a tuple of some type. - - A partial method of :func:`iter_cast`. - """ - return iter_cast(inputs, dst_type, return_type=tuple) - - -def is_seq_of(seq: Any, expected_type: type | tuple, seq_type: type | None = None) -> bool: - """Check whether it is a sequence of some type. - - Args: - seq (Sequence): The sequence to be checked. - expected_type (type or tuple): Expected type of sequence items. - seq_type (type, optional): Expected sequence type. Defaults to None. - - Returns: - bool: Return True if ``seq`` is valid else False. - - Examples: - >>> from visengine.utils import is_seq_of - >>> seq = ['a', 'b', 'c'] - >>> is_seq_of(seq, str) - True - >>> is_seq_of(seq, int) - False - """ - if seq_type is None: - exp_seq_type = abc.Sequence - else: - assert isinstance(seq_type, type) - exp_seq_type = seq_type - if not isinstance(seq, exp_seq_type): - return False - for item in seq: - if not isinstance(item, expected_type): - return False - return True - - -def is_list_of(seq, expected_type): - """Check whether it is a list of some type. - - A partial method of :func:`is_seq_of`. - """ - return is_seq_of(seq, expected_type, seq_type=list) - - -def is_tuple_of(seq, expected_type): - """Check whether it is a tuple of some type. - - A partial method of :func:`is_seq_of`. - """ - return is_seq_of(seq, expected_type, seq_type=tuple) - - -def slice_list(in_list, lens): - """Slice a list into several sub lists by a list of given length. - - Args: - in_list (list): The list to be sliced. - lens(int or list): The expected length of each out list. - - Returns: - list: A list of sliced list. - """ - if isinstance(lens, int): - assert len(in_list) % lens == 0 - lens = [lens] * int(len(in_list) / lens) - if not isinstance(lens, list): - raise TypeError('"indices" must be an integer or a list of integers') - elif sum(lens) != len(in_list): - raise ValueError(f"sum of lens and list length does not match: {sum(lens)} != {len(in_list)}") - out_list = [] - idx = 0 - for i in range(len(lens)): - out_list.append(in_list[idx : idx + lens[i]]) - idx += lens[i] - return out_list - - -def concat_list(in_list): - """Concatenate a list of list into a single list. - - Args: - in_list (list): The list of list to be merged. - - Returns: - list: The concatenated flat list. - """ - return list(itertools.chain(*in_list)) - - -def apply_to(data: Any, expr: Callable, apply_func: Callable): - """Apply function to each element in dict, list or tuple that matches with - the expression. - - For examples, if you want to convert each element in a list of dict from - `np.ndarray` to `Tensor`. 
You can use the following code: - - Examples: - >>> from visengine.utils import apply_to - >>> import numpy as np - >>> import torch - >>> data = dict(array=[np.array(1)]) # {'array': [array(1)]} - >>> result = apply_to(data, lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x)) - >>> print(result) # {'array': [tensor(1)]} - - Args: - data (Any): Data to be applied. - expr (Callable): Expression to tell which data should be applied with - the function. It should return a boolean. - apply_func (Callable): Function applied to data. - - Returns: - Any: The data after applying. - """ - if isinstance(data, dict): - # Keep the original dict type - res = type(data)() - for key, value in data.items(): - res[key] = apply_to(value, expr, apply_func) - return res - elif isinstance(data, tuple) and hasattr(data, "_fields"): - # namedtuple - return type(data)(*(apply_to(sample, expr, apply_func) for sample in data)) # type: ignore # yapf:disable - elif isinstance(data, tuple | list): - return type(data)(apply_to(sample, expr, apply_func) for sample in data) # type: ignore # yapf:disable - elif expr(data): - return apply_func(data) - else: - return data - - -def check_prerequisites( - prerequisites, - checker, - msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' - "found, please install them first.", -): # yapf: disable - """A decorator factory to check if prerequisites are satisfied. - - Args: - prerequisites (str of list[str]): Prerequisites to be checked. - checker (callable): The checker method that returns True if a - prerequisite is meet, False otherwise. - msg_tmpl (str): The message template with two variables. - - Returns: - decorator: A specific decorator. - """ - - def wrap(func): - @functools.wraps(func) - def wrapped_func(*args, **kwargs): - requirements = [prerequisites] if isinstance(prerequisites, str) else prerequisites - missing = [] - for item in requirements: - if not checker(item): - missing.append(item) - if missing: - print(msg_tmpl.format(", ".join(missing), func.__name__)) - raise RuntimeError("Prerequisites not meet.") - else: - return func(*args, **kwargs) - - return wrapped_func - - return wrap - - -def _check_py_package(package): - try: - import_module(package) - except ImportError: - return False - else: - return True - - -def _check_executable(cmd): - if subprocess.call(f"which {cmd}", shell=True) != 0: - return False - else: - return True - - -def requires_package(prerequisites): - """A decorator to check if some python packages are installed. - - Example: - >>> @requires_package('numpy') - >>> func(arg1, args): - >>> return numpy.zeros(1) - array([0.]) - >>> @requires_package(['numpy', 'non_package']) - >>> func(arg1, args): - >>> return numpy.zeros(1) - ImportError - """ - return check_prerequisites(prerequisites, checker=_check_py_package) - - -def requires_executable(prerequisites): - """A decorator to check if some executable files are installed. - - Example: - >>> @requires_executable('ffmpeg') - >>> func(arg1, args): - >>> print(1) - 1 - """ - return check_prerequisites(prerequisites, checker=_check_executable) - - -def deprecated_api_warning(name_dict: dict, cls_name: str | None = None) -> Callable: - """A decorator to check if some arguments are deprecate and try to replace - deprecate src_arg_name to dst_arg_name. - - Args: - name_dict(dict): - key (str): Deprecate argument names. - val (str): Expected argument names. - - Returns: - func: New function. 
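The decorator documented above can be exercised as follows; note the import only works on a checkout from before this deletion, and `resize`/`img_scale` are hypothetical names used for illustration:

```python
import warnings

from visengine.utils import deprecated_api_warning  # removed by this diff


@deprecated_api_warning({"img_scale": "scale"})
def resize(image, scale=1.0):
    return image, scale


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _, s = resize("img", img_scale=2.0)  # old kwarg is remapped to `scale`

print(s)                   # 2.0
print(caught[0].category)  # <class 'DeprecationWarning'>
```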
- """ - - def api_warning_wrapper(old_func): - @functools.wraps(old_func) - def new_func(*args, **kwargs): - # get the arg spec of the decorated method - args_info = getfullargspec(old_func) - # get name of the function - func_name = old_func.__name__ - if cls_name is not None: - func_name = f"{cls_name}.{func_name}" - if args: - arg_names = args_info.args[: len(args)] - for src_arg_name, dst_arg_name in name_dict.items(): - if src_arg_name in arg_names: - warnings.warn( - f'"{src_arg_name}" is deprecated in `{func_name}`, please use "{dst_arg_name}" instead', - DeprecationWarning, - stacklevel=2, - ) - arg_names[arg_names.index(src_arg_name)] = dst_arg_name - if kwargs: - for src_arg_name, dst_arg_name in name_dict.items(): - if src_arg_name in kwargs: - assert dst_arg_name not in kwargs, ( - f"The expected behavior is to replace " - f"the deprecated key `{src_arg_name}` to " - f"new key `{dst_arg_name}`, but got them " - f"in the arguments at the same time, which " - f"is confusing. `{src_arg_name} will be " - f"deprecated in the future, please " - f"use `{dst_arg_name}` instead." - ) - - warnings.warn( - f'"{src_arg_name}" is deprecated in `{func_name}`, please use "{dst_arg_name}" instead', - DeprecationWarning, - stacklevel=2, - ) - kwargs[dst_arg_name] = kwargs.pop(src_arg_name) - - # apply converted arguments to the decorated method - output = old_func(*args, **kwargs) - return output - - return new_func - - return api_warning_wrapper - - -def is_method_overridden(method: str, base_class: type, derived_class: type | Any) -> bool: - """Check if a method of base class is overridden in derived class. - - Args: - method (str): the method name to check. - base_class (type): the class of the base class. - derived_class (type | Any): the class or instance of the derived class. - """ - assert isinstance(base_class, type), "base_class doesn't accept instance, Please pass class instead." - - if not isinstance(derived_class, type): - derived_class = derived_class.__class__ - - base_method = getattr(base_class, method) - derived_method = getattr(derived_class, method) - return derived_method != base_method - - -def has_method(obj: object, method: str) -> bool: - """Check whether the object has a method. - - Args: - method (str): The method name to check. - obj (object): The object to check. - - Returns: - bool: True if the object has the method else False. - """ - return hasattr(obj, method) and callable(getattr(obj, method)) - - -def deprecated_function(since: str, removed_in: str, instructions: str) -> Callable: - """Marks functions as deprecated. - - Throw a warning when a deprecated function is called, and add a note in the - docstring. Modified from https://github.com/pytorch/pytorch/blob/master/torch/onnx/_deprecation.py - - Args: - since (str): The version when the function was first deprecated. - removed_in (str): The version when the function will be removed. - instructions (str): The action users should take. - - Returns: - Callable: A new function, which will be deprecated soon. - """ - from visengine.logging import print_log - - def decorator(function): - @functools.wraps(function) - def wrapper(*args, **kwargs): - print_log( - f"'{function.__module__}.{function.__name__}' " - f"is deprecated in version {since} and will be " - f"removed in version {removed_in}. Please {instructions}.", - logger="current", - level=logging.WARNING, - ) - return function(*args, **kwargs) - - indent = " " - # Add a deprecation note to the docstring. 
- docstring = function.__doc__ or "" - # Add a note to the docstring. - deprecation_note = textwrap.dedent( - f"""\ - .. deprecated:: {since} - Deprecated and will be removed in version {removed_in}. - Please {instructions}. - """ - ) - # Split docstring at first occurrence of newline - pattern = "\n\n" - summary_and_body = re.split(pattern, docstring, 1) - - if len(summary_and_body) > 1: - summary, body = summary_and_body - body = textwrap.indent(textwrap.dedent(body), indent) - summary = "\n".join([textwrap.dedent(string) for string in summary.split("\n")]) - summary = textwrap.indent(summary, prefix=indent) - # Dedent the body. We cannot do this with the presence of the - # summary because the body contains leading whitespaces when the - # summary does not. - new_docstring_parts = [deprecation_note, "\n\n", summary, "\n\n", body] - else: - summary = summary_and_body[0] - summary = "\n".join([textwrap.dedent(string) for string in summary.split("\n")]) - summary = textwrap.indent(summary, prefix=indent) - new_docstring_parts = [deprecation_note, "\n\n", summary] - - wrapper.__doc__ = "".join(new_docstring_parts) - - return wrapper - - return decorator - - -def get_object_from_string(obj_name: str): - """Get object from name. - - Args: - obj_name (str): The name of the object. - - Examples: - >>> get_object_from_string('torch.optim.sgd.SGD') - >>> torch.optim.sgd.SGD - """ - parts = iter(obj_name.split(".")) - module_name = next(parts) - # import module - while True: - try: - module = import_module(module_name) - part = next(parts) - # mmcv.ops has nms.py and nms function at the same time. So the - # function will have a higher priority - obj = getattr(module, part, None) - if obj is not None and not ismodule(obj): - break - module_name = f"{module_name}.{part}" - except StopIteration: - # if obj is a module - return module - except ImportError: - return None - - # get class or attribute from module - obj = module - while True: - try: - obj = getattr(obj, part) - part = next(parts) - except StopIteration: - return obj - except AttributeError: - return None diff --git a/libs/visengine/visengine/utils/package_utils.py b/libs/visengine/visengine/utils/package_utils.py deleted file mode 100644 index c034b47..0000000 --- a/libs/visengine/visengine/utils/package_utils.py +++ /dev/null @@ -1,104 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import os.path as osp -import subprocess - - -def is_installed(package: str) -> bool: - """Check package whether installed. - - Args: - package (str): Name of package to be checked. - """ - # When executing `import visengine.runner`, - # pkg_resources will be imported and it takes too much time. - # Therefore, import it in function scope to save time. - import importlib.util - - import pkg_resources - from pkg_resources import get_distribution - - # refresh the pkg_resources - # more datails at https://github.com/pypa/setuptools/issues/373 - importlib.reload(pkg_resources) - try: - get_distribution(package) - return True - except pkg_resources.DistributionNotFound: - spec = importlib.util.find_spec(package) - if spec is None: - return False - elif spec.origin is not None: - return True - else: - return False - - -def get_installed_path(package: str) -> str: - """Get installed path of package. - - Args: - package (str): Name of package. 
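`get_object_from_string` resolves dotted paths by importing progressively; a simplified sketch of the common case (the original also walks nested attributes and module/attribute name collisions):

```python
from importlib import import_module


def get_object_from_string(obj_name):
    # Simplified: split at the last dot, import the module part, fetch
    # the attribute; return None on failure, like the original.
    module_name, _, attr = obj_name.rpartition(".")
    try:
        return getattr(import_module(module_name), attr, None)
    except ImportError:
        return None


print(get_object_from_string("os.path.join"))   # <function join at ...>
print(get_object_from_string("no.such.thing"))  # None
```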
- - Example: - >>> get_installed_path('mmcls') - >>> '.../lib/python3.7/site-packages/mmcls' - """ - import importlib.util - - from pkg_resources import DistributionNotFound, get_distribution - - # if the package name is not the same as module name, module name should be - # inferred. For example, mmcv-full is the package name, but mmcv is module - # name. If we want to get the installed path of mmcv-full, we should concat - # the pkg.location and module name - try: - pkg = get_distribution(package) - except DistributionNotFound as e: - # if the package is not installed, package path set in PYTHONPATH - # can be detected by `find_spec` - spec = importlib.util.find_spec(package) - if spec is not None: - if spec.origin is not None: - return osp.dirname(spec.origin) - else: - # `get_installed_path` cannot get the installed path of - # namespace packages - raise RuntimeError(f"{package} is a namespace package, which is invalid for `get_install_path`") - else: - raise e - - possible_path = osp.join(pkg.location, package) # type: ignore - if osp.exists(possible_path): - return possible_path - else: - return osp.join(pkg.location, package2module(package)) # type: ignore - - -def package2module(package: str): - """Infer module name from package. - - Args: - package (str): Package to infer module name. - """ - from pkg_resources import get_distribution - - pkg = get_distribution(package) - if pkg.has_metadata("top_level.txt"): - module_name = pkg.get_metadata("top_level.txt").split("\n")[0] - return module_name - else: - raise ValueError(f"can not infer the module name of {package}") - - -def call_command(cmd: list) -> None: - try: - subprocess.check_call(cmd) - except Exception as e: - raise e # type: ignore - - -def install_package(package: str): - if not is_installed(package): - call_command(["python", "-m", "pip", "install", package]) diff --git a/libs/visengine/visengine/utils/path.py b/libs/visengine/visengine/utils/path.py deleted file mode 100644 index 1ce5621..0000000 --- a/libs/visengine/visengine/utils/path.py +++ /dev/null @@ -1,116 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -from pathlib import Path - -from .misc import is_str - - -def is_filepath(x): - return is_str(x) or isinstance(x, Path) - - -def fopen(filepath, *args, **kwargs): - if is_str(filepath): - return open(filepath, *args, **kwargs) - elif isinstance(filepath, Path): - return filepath.open(*args, **kwargs) - raise ValueError("`filepath` should be a string or a Path") - - -def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): - if not osp.isfile(filename): - raise FileNotFoundError(msg_tmpl.format(filename)) - - -def mkdir_or_exist(dir_name, mode=0o777): - if dir_name == "": - return - dir_name = osp.expanduser(dir_name) - os.makedirs(dir_name, mode=mode, exist_ok=True) - - -def symlink(src, dst, overwrite=True, **kwargs): - if os.path.lexists(dst) and overwrite: - os.remove(dst) - os.symlink(src, dst, **kwargs) - - -def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): - """Scan a directory to find the interested files. - - Args: - dir_path (str | :obj:`Path`): Path of the directory. - suffix (str | tuple(str), optional): File suffix that we are - interested in. Defaults to None. - recursive (bool, optional): If set to True, recursively scan the - directory. Defaults to False. - case_sensitive (bool, optional) : If set to False, ignore the case of - suffix. Defaults to True. 
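`is_installed` above leans on the deprecated `pkg_resources` API; an equivalent check can be sketched with `importlib` alone (an assumption on my part, not the module's actual code):

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version


def is_installed(package: str) -> bool:
    # First ask package metadata; fall back to find_spec for packages
    # importable via PYTHONPATH without being installed.
    try:
        version(package)
        return True
    except PackageNotFoundError:
        return importlib.util.find_spec(package) is not None


print(is_installed("numpy"))        # True in this repo's environment
print(is_installed("no_such_pkg"))  # False
```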
- - Returns: - A generator for all the interested files with relative paths. - """ - if isinstance(dir_path, str | Path): - dir_path = str(dir_path) - else: - raise TypeError('"dir_path" must be a string or Path object') - - if (suffix is not None) and not isinstance(suffix, str | tuple): - raise TypeError('"suffix" must be a string or tuple of strings') - - if suffix is not None and not case_sensitive: - suffix = suffix.lower() if isinstance(suffix, str) else tuple(item.lower() for item in suffix) - - root = dir_path - - def _scandir(dir_path, suffix, recursive, case_sensitive): - for entry in os.scandir(dir_path): - if not entry.name.startswith(".") and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - _rel_path = rel_path if case_sensitive else rel_path.lower() - if suffix is None or _rel_path.endswith(suffix): - yield rel_path - elif recursive and os.path.isdir(entry.path): - # scan recursively if entry.path is a directory - yield from _scandir(entry.path, suffix, recursive, case_sensitive) - - return _scandir(dir_path, suffix, recursive, case_sensitive) - - -def find_vcs_root(path, markers=(".git",)): - """Finds the root directory (including itself) of specified markers. - - Args: - path (str): Path of directory or file. - markers (list[str], optional): List of file or directory names. - - Returns: - The directory contained one of the markers or None if not found. - """ - if osp.isfile(path): - path = osp.dirname(path) - - prev, cur = None, osp.abspath(osp.expanduser(path)) - while cur != prev: - if any(osp.exists(osp.join(cur, marker)) for marker in markers): - return cur - prev, cur = cur, osp.split(cur)[0] - return None - - -def is_abs(path: str) -> bool: - """Check if path is an absolute path in different backends. - - Args: - path (str): path of directory or file. - - Returns: - bool: whether path is an absolute path. - """ - if osp.isabs(path) or path.startswith(("http://", "https://", "s3://")): - return True - else: - return False diff --git a/libs/visengine/visengine/utils/progressbar.py b/libs/visengine/visengine/utils/progressbar.py deleted file mode 100644 index 2b220ae..0000000 --- a/libs/visengine/visengine/utils/progressbar.py +++ /dev/null @@ -1,239 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import sys -from collections.abc import Callable, Iterable, Sequence -from multiprocessing import Pool -from shutil import get_terminal_size - -from .timer import Timer - - -class ProgressBar: - """A progress bar which can print the progress. - - Args: - task_num (int): Number of total steps. Defaults to 0. - bar_width (int): Width of the progress bar. Defaults to 50. - start (bool): Whether to start the progress bar in the constructor. - Defaults to True. - file (callable): Progress bar output mode. Defaults to "sys.stdout". 
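Assuming a pre-deletion checkout where `visengine.utils.scandir` is still importable, the suffix filter behaves like this (relative paths shown for POSIX):

```python
import os
import tempfile

from visengine.utils import scandir  # removed by this diff

with tempfile.TemporaryDirectory() as root:
    for name in ("a.jpg", "b.png", "notes.txt"):
        open(os.path.join(root, name), "w").close()
    os.makedirs(os.path.join(root, "sub"))
    open(os.path.join(root, "sub", "c.jpg"), "w").close()

    print(sorted(scandir(root, suffix=(".jpg", ".png"), recursive=True)))
    # ['a.jpg', 'b.png', 'sub/c.jpg']
```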
- - Examples: - >>> import visengine - >>> import time - >>> bar = mmengine.ProgressBar(10) - >>> for i in range(10): - >>> bar.update() - >>> time.sleep(1) - """ - - def __init__( - self, - task_num: int = 0, - bar_width: int = 50, - start: bool = True, - file=sys.stdout, - ): - self.task_num = task_num - self.bar_width = bar_width - self.completed = 0 - self.file = file - if start: - self.start() - - @property - def terminal_width(self): - width, _ = get_terminal_size() - return width - - def start(self): - if self.task_num > 0: - self.file.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, elapsed: 0s, ETA:") - else: - self.file.write("completed: 0, elapsed: 0s") - self.file.flush() - self.timer = Timer() - - def update(self, num_tasks: int = 1): - """Update progressbar. - - Args: - num_tasks (int): Update step size. - """ - assert num_tasks > 0 - self.completed += num_tasks - elapsed = self.timer.since_start() - if elapsed > 0: - fps = self.completed / elapsed - else: - fps = float("inf") - if self.task_num > 0: - percentage = self.completed / float(self.task_num) - eta = int(elapsed * (1 - percentage) / percentage + 0.5) - msg = f"\r[{{}}] {self.completed}/{self.task_num}, {fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ETA: {eta:5}s" - - bar_width = min( - self.bar_width, - int(self.terminal_width - len(msg)) + 2, - int(self.terminal_width * 0.6), - ) - bar_width = max(2, bar_width) - mark_width = int(bar_width * percentage) - bar_chars = ">" * mark_width + " " * (bar_width - mark_width) - self.file.write(msg.format(bar_chars)) - else: - self.file.write(f"completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, {fps:.1f} tasks/s") - self.file.flush() - - -def track_progress(func: Callable, tasks: Sequence, bar_width: int = 50, file=sys.stdout, **kwargs): - """Track the progress of tasks execution with a progress bar. - - Tasks are done with a simple for-loop. - - Args: - func (callable): The function to be applied to each task. - tasks (Sequence): If tasks is a tuple, it must contain two elements, - the first being the tasks to be completed and the other being the - number of tasks. If it is not a tuple, it represents the tasks to - be completed. - bar_width (int): Width of progress bar. - - Returns: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] # type: ignore - elif isinstance(tasks, Sequence): - task_num = len(tasks) - else: - raise TypeError(f'"tasks" must be a tuple object or a sequence object, but got {type(tasks)}') - prog_bar = ProgressBar(task_num, bar_width, file=file) - results = [] - for task in tasks: - results.append(func(task, **kwargs)) - prog_bar.update() - prog_bar.file.write("\n") - return results - - -def init_pool(process_num, initializer=None, initargs=None): - if initializer is None: - return Pool(process_num) - elif initargs is None: - return Pool(process_num, initializer) - else: - if not isinstance(initargs, tuple): - raise TypeError('"initargs" must be a tuple') - return Pool(process_num, initializer, initargs) - - -def track_parallel_progress( - func: Callable, - tasks: Sequence, - nproc: int, - initializer: Callable | None = None, - initargs: tuple | None = None, - bar_width: int = 50, - chunksize: int = 1, - skip_first: bool = False, - keep_order: bool = True, - file=sys.stdout, -): - """Track the progress of parallel task execution with a progress bar. 
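Typical sequential use of `track_progress` defined above; the import assumes the module before this deletion:

```python
import time

from visengine.utils import track_progress  # removed by this diff


def work(x):
    time.sleep(0.01)  # stand-in for real per-item work
    return x * x


results = track_progress(work, list(range(50)), bar_width=40)
print(results[:5])  # [0, 1, 4, 9, 16]
```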
- - The built-in :mod:`multiprocessing` module is used for process pools and - tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. - - Args: - func (callable): The function to be applied to each task. - tasks (Sequence): If tasks is a tuple, it must contain two elements, - the first being the tasks to be completed and the other being the - number of tasks. If it is not a tuple, it represents the tasks to - be completed. - nproc (int): Process (worker) number. - initializer (None or callable): Refer to :class:`multiprocessing.Pool` - for details. - initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for - details. - chunksize (int): Refer to :class:`multiprocessing.Pool` for details. - bar_width (int): Width of progress bar. - skip_first (bool): Whether to skip the first sample for each worker - when estimating fps, since the initialization step may takes - longer. - keep_order (bool): If True, :func:`Pool.imap` is used, otherwise - :func:`Pool.imap_unordered` is used. - - Returns: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] # type: ignore - elif isinstance(tasks, Sequence): - task_num = len(tasks) - else: - raise TypeError(f'"tasks" must be a tuple object or a sequence object, but got {type(tasks)}') - pool = init_pool(nproc, initializer, initargs) - start = not skip_first - task_num -= nproc * chunksize * int(skip_first) - prog_bar = ProgressBar(task_num, bar_width, start, file=file) - results = [] - if keep_order: - gen = pool.imap(func, tasks, chunksize) - else: - gen = pool.imap_unordered(func, tasks, chunksize) - for result in gen: - results.append(result) - if skip_first: - if len(results) < nproc * chunksize: - continue - elif len(results) == nproc * chunksize: - prog_bar.start() - continue - prog_bar.update() - prog_bar.file.write("\n") - pool.close() - pool.join() - return results - - -def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout): - """Track the progress of tasks iteration or enumeration with a progress - bar. - - Tasks are yielded with a simple for-loop. - - Args: - tasks (Sequence): If tasks is a tuple, it must contain two elements, - the first being the tasks to be completed and the other being the - number of tasks. If it is not a tuple, it represents the tasks to - be completed. - bar_width (int): Width of progress bar. - - Yields: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] # type: ignore - elif isinstance(tasks, Sequence): - task_num = len(tasks) - else: - raise TypeError(f'"tasks" must be a tuple object or a sequence object, but got {type(tasks)}') - prog_bar = ProgressBar(task_num, bar_width, file=file) - for task in tasks: - yield task - prog_bar.update() - prog_bar.file.write("\n") diff --git a/libs/visengine/visengine/utils/progressbar_rich.py b/libs/visengine/visengine/utils/progressbar_rich.py deleted file mode 100644 index 1af8c87..0000000 --- a/libs/visengine/visengine/utils/progressbar_rich.py +++ /dev/null @@ -1,159 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
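The parallel variant fans the same task list out over a process pool; the worker must be a top-level function so `multiprocessing` can pickle it (the import again assumes a pre-deletion checkout):

```python
from visengine.utils import track_parallel_progress  # removed by this diff


def square(x):  # top-level so multiprocessing can pickle it
    return x * x


if __name__ == "__main__":
    results = track_parallel_progress(square, list(range(100)), nproc=4)
    print(sum(results))  # 328350
```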
-from collections.abc import Callable, Iterable, Sized -from multiprocessing import Pool - -from rich.progress import ( - BarColumn, - MofNCompleteColumn, - Progress, - Task, - TaskProgressColumn, - TextColumn, - TimeRemainingColumn, -) -from rich.text import Text - - -class _Worker: - """Function wrapper for ``track_progress_rich``""" - - def __init__(self, func) -> None: - self.func = func - - def __call__(self, inputs): - inputs, idx = inputs - if not isinstance(inputs, tuple | list): - inputs = (inputs,) - - return self.func(*inputs), idx - - -class _SkipFirstTimeRemainingColumn(TimeRemainingColumn): - """Skip calculating remaining time for the first few times. - - Args: - skip_times (int): The number of times to skip. Defaults to 0. - """ - - def __init__(self, *args, skip_times=0, **kwargs): - super().__init__(*args, **kwargs) - self.skip_times = skip_times - - def render(self, task: Task) -> Text: - """Show time remaining.""" - if task.completed <= self.skip_times: - return Text("-:--:--", style="progress.remaining") - return super().render(task) - - -def _tasks_with_index(tasks): - """Add index to tasks.""" - for idx, task in enumerate(tasks): - yield task, idx - - -def track_progress_rich( - func: Callable, - tasks: Iterable = (), - task_num: int | None = None, - nproc: int = 1, - chunksize: int = 1, - description: str = "Processing", - color: str = "blue", -) -> list: - """Track the progress of parallel task execution with a progress bar. The - built-in :mod:`multiprocessing` module is used for process pools and tasks - are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. - - Args: - func (callable): The function to be applied to each task. - tasks (Iterable or Sized): A tuple of tasks. There are several cases - for different format tasks: - - When ``func`` accepts no arguments: tasks should be an empty - tuple, and ``task_num`` must be specified. - - When ``func`` accepts only one argument: tasks should be a tuple - containing the argument. - - When ``func`` accepts multiple arguments: tasks should be a - tuple, with each element representing a set of arguments. - If an element is a ``dict``, it will be parsed as a set of - keyword-only arguments. - Defaults to an empty tuple. - task_num (int, optional): If ``tasks`` is an iterator which does not - have length, the number of tasks can be provided by ``task_num``. - Defaults to None. - nproc (int): Process (worker) number, if nuproc is 1, - use single process. Defaults to 1. - chunksize (int): Refer to :class:`multiprocessing.Pool` for details. - Defaults to 1. - description (str): The description of progress bar. - Defaults to "Process". - color (str): The color of progress bar. Defaults to "blue". - - Examples: - >>> import time - - >>> def func(x): - ... time.sleep(1) - ... return x**2 - >>> track_progress_rich(func, range(10), nproc=2) - - Returns: - list: The task results. 
- """ - if not callable(func): - raise TypeError("func must be a callable object") - if not isinstance(tasks, Iterable): - raise TypeError(f"tasks must be an iterable object, but got {type(tasks)}") - if isinstance(tasks, Sized): - if len(tasks) == 0: - if task_num is None: - raise ValueError("If tasks is an empty iterable, task_num must be set") - else: - tasks = tuple(() for _ in range(task_num)) - else: - if task_num is not None and task_num != len(tasks): - raise ValueError("task_num does not match the length of tasks") - task_num = len(tasks) - - if nproc <= 0: - raise ValueError("nproc must be a positive number") - - skip_times = nproc * chunksize if nproc > 1 else 0 - prog_bar = Progress( - TextColumn("{task.description}"), - BarColumn(), - _SkipFirstTimeRemainingColumn(skip_times=skip_times), - MofNCompleteColumn(), - TaskProgressColumn(show_speed=True), - ) - - worker = _Worker(func) - task_id = prog_bar.add_task(total=task_num, color=color, description=description) - tasks = _tasks_with_index(tasks) - - # Use single process when nproc is 1, else use multiprocess. - with prog_bar: - if nproc == 1: - results = [] - for task in tasks: - results.append(worker(task)[0]) - prog_bar.update(task_id, advance=1, refresh=True) - else: - with Pool(nproc) as pool: - results = [] - unordered_results = [] - gen = pool.imap_unordered(worker, tasks, chunksize) - try: - for result in gen: - result, idx = result - unordered_results.append((result, idx)) - results.append(None) - prog_bar.update(task_id, advance=1, refresh=True) - except Exception as e: - prog_bar.stop() - raise e - for result, idx in unordered_results: - results[idx] = result - return results diff --git a/libs/visengine/visengine/utils/timer.py b/libs/visengine/visengine/utils/timer.py deleted file mode 100644 index a52716c..0000000 --- a/libs/visengine/visengine/utils/timer.py +++ /dev/null @@ -1,119 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from time import time - - -class TimerError(Exception): - def __init__(self, message): - self.message = message - super().__init__(message) - - -class Timer: - """A flexible Timer class. - - Examples: - >>> import time - >>> import mmcv - >>> with mmcv.Timer(): - >>> # simulate a code block that will run for 1s - >>> time.sleep(1) - 1.000 - >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'): - >>> # simulate a code block that will run for 1s - >>> time.sleep(1) - it takes 1.0 seconds - >>> timer = mmcv.Timer() - >>> time.sleep(0.5) - >>> print(timer.since_start()) - 0.500 - >>> time.sleep(0.5) - >>> print(timer.since_last_check()) - 0.500 - >>> print(timer.since_start()) - 1.000 - """ - - def __init__(self, start=True, print_tmpl=None): - self._is_running = False - self.print_tmpl = print_tmpl if print_tmpl else "{:.3f}" - if start: - self.start() - - @property - def is_running(self): - """bool: indicate whether the timer is running""" - return self._is_running - - def __enter__(self): - self.start() - return self - - def __exit__(self, type, value, traceback): - print(self.print_tmpl.format(self.since_last_check())) - self._is_running = False - - def start(self): - """Start the timer.""" - if not self._is_running: - self._t_start = time() - self._is_running = True - self._t_last = time() - - def since_start(self): - """Total time since the timer is started. - - Returns: - float: Time in seconds. 
- """ - if not self._is_running: - raise TimerError("timer is not running") - self._t_last = time() - return self._t_last - self._t_start - - def since_last_check(self): - """Time since the last checking. - - Either :func:`since_start` or :func:`since_last_check` is a checking - operation. - - Returns: - float: Time in seconds. - """ - if not self._is_running: - raise TimerError("timer is not running") - dur = time() - self._t_last - self._t_last = time() - return dur - - -_g_timers = {} # global timers - - -def check_time(timer_id): - """Add check points in a single line. - - This method is suitable for running a task on a list of items. A timer will - be registered when the method is called for the first time. - - Examples: - >>> import time - >>> import mmcv - >>> for i in range(1, 6): - >>> # simulate a code block - >>> time.sleep(i) - >>> mmcv.check_time('task1') - 2.000 - 3.000 - 4.000 - 5.000 - - Args: - str: Timer identifier. - """ - if timer_id not in _g_timers: - _g_timers[timer_id] = Timer() - return 0 - else: - return _g_timers[timer_id].since_last_check() diff --git a/libs/visengine/visengine/utils/version_utils.py b/libs/visengine/visengine/utils/version_utils.py deleted file mode 100644 index c9660c0..0000000 --- a/libs/visengine/visengine/utils/version_utils.py +++ /dev/null @@ -1,93 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -import os -import subprocess -import warnings - -from packaging.version import parse - - -def digit_version(version_str: str, length: int = 4): - """Convert a version string into a tuple of integers. - - This method is usually used for comparing two versions. For pre-release - versions: alpha < beta < rc. - - Args: - version_str (str): The version string. - length (int): The maximum number of version levels. Defaults to 4. - - Returns: - tuple[int]: The version info in digits (integers). - """ - assert "parrots" not in version_str - version = parse(version_str) - assert version.release, f"failed to parse version {version_str}" - release = list(version.release) - release = release[:length] - if len(release) < length: - release = release + [0] * (length - len(release)) - if version.is_prerelease: - mapping = {"a": -3, "b": -2, "rc": -1} - val = -4 - # version.pre can be None - if version.pre: - if version.pre[0] not in mapping: - warnings.warn( - f"unknown prerelease version {version.pre[0]}, version checking may go wrong", - stacklevel=2, - ) - else: - val = mapping[version.pre[0]] - release.extend([val, version.pre[-1]]) - else: - release.extend([val, 0]) - - elif version.is_postrelease: - release.extend([1, version.post]) # type: ignore - else: - release.extend([0, 0]) - return tuple(release) - - -def _minimal_ext_cmd(cmd): - # construct minimal environment - env = {} - for k in ["SYSTEMROOT", "PATH", "HOME"]: - v = os.environ.get(k) - if v is not None: - env[k] = v - # LANGUAGE is used on win32 - env["LANGUAGE"] = "C" - env["LANG"] = "C" - env["LC_ALL"] = "C" - out, err = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env).communicate() - return out - - -def get_git_hash(fallback="unknown", digits=None): - """Get the git hash of the current repo. - - Args: - fallback (str, optional): The fallback string when git hash is - unavailable. Defaults to 'unknown'. - digits (int, optional): kept digits of the hash. Defaults to None, - meaning all digits are kept. - - Returns: - str: Git commit hash. 
- """ - - if digits is not None and not isinstance(digits, int): - raise TypeError("digits must be None or an integer") - - try: - out = _minimal_ext_cmd(["git", "rev-parse", "HEAD"]) - sha = out.strip().decode("ascii") - if digits is not None: - sha = sha[:digits] - except OSError: - sha = fallback - - return sha diff --git a/libs/visengine/visengine/version.py b/libs/visengine/visengine/version.py deleted file mode 100644 index 1e16cf5..0000000 --- a/libs/visengine/visengine/version.py +++ /dev/null @@ -1,28 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. - -__version__ = "0.10.6" - - -def parse_version_info(version_str): - """Parse the version information. - - Args: - version_str (str): version string like '0.1.0'. - - Returns: - tuple: version information contains major, minor, micro version. - """ - version_info = [] - for x in version_str.split("."): - if x.isdigit(): - version_info.append(int(x)) - elif x.find("rc") != -1: - patch_version = x.split("rc") - version_info.append(int(patch_version[0])) - version_info.append(f"rc{patch_version[1]}") - return tuple(version_info) - - -version_info = parse_version_info(__version__) diff --git a/libs/visengine/visengine/visualization/__init__.py b/libs/visengine/visengine/visualization/__init__.py deleted file mode 100644 index 2dff0a6..0000000 --- a/libs/visengine/visengine/visualization/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. -from .vis_backend import ( - AimVisBackend, - BaseVisBackend, - ClearMLVisBackend, - DVCLiveVisBackend, - LocalVisBackend, - MLflowVisBackend, - NeptuneVisBackend, - TensorboardVisBackend, - WandbVisBackend, -) -from .visualizer import Visualizer - -__all__ = [ - "AimVisBackend", - "BaseVisBackend", - "ClearMLVisBackend", - "DVCLiveVisBackend", - "LocalVisBackend", - "MLflowVisBackend", - "NeptuneVisBackend", - "TensorboardVisBackend", - "Visualizer", - "WandbVisBackend", -] diff --git a/libs/visengine/visengine/visualization/utils.py b/libs/visengine/visengine/visualization/utils.py deleted file mode 100644 index bdb3867..0000000 --- a/libs/visengine/visengine/visualization/utils.py +++ /dev/null @@ -1,243 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. - -from typing import TYPE_CHECKING, Any - -import cv2 -import numpy as np -import torch - -if TYPE_CHECKING: - from matplotlib.backends.backend_agg import FigureCanvasAgg - - -def tensor2ndarray(value: np.ndarray | torch.Tensor) -> np.ndarray: - """If the type of value is torch.Tensor, convert the value to np.ndarray. - - Args: - value (np.ndarray, torch.Tensor): value. - - Returns: - Any: value. - """ - if isinstance(value, torch.Tensor): - value = value.detach().cpu().numpy() - return value - - -def value2list(value: Any, valid_type: type | tuple[type, ...], expand_dim: int) -> list[Any]: - """If the type of ``value`` is ``valid_type``, convert the value to list - and expand to ``expand_dim``. - - Args: - value (Any): value. - valid_type (Union[Type, Tuple[Type, ...]): valid type. - expand_dim (int): expand dim. - - Returns: - List[Any]: value. - """ - if isinstance(value, valid_type): - value = [value] * expand_dim - return value - - -def check_type(name: str, value: Any, valid_type: type | tuple[type, ...]) -> None: - """Check whether the type of value is in ``valid_type``. - - Args: - name (str): value name. - value (Any): value. - valid_type (Type, Tuple[Type, ...]): expected type. 
- """ - if not isinstance(value, valid_type): - raise TypeError(f"`{name}` should be {valid_type} but got {type(value)}") - - -def check_length(name: str, value: Any, valid_length: int) -> None: - """If type of the ``value`` is list, check whether its length is equal with - or greater than ``valid_length``. - - Args: - name (str): value name. - value (Any): value. - valid_length (int): expected length. - """ - if isinstance(value, list): - if len(value) < valid_length: - raise AssertionError( - f"The length of {name} must equal with or greater than {valid_length}, but got {len(value)}" - ) - - -def check_type_and_length(name: str, value: Any, valid_type: type | tuple[type, ...], valid_length: int) -> None: - """Check whether the type of value is in ``valid_type``. If type of the - ``value`` is list, check whether its length is equal with or greater than - ``valid_length``. - - Args: - value (Any): value. - legal_type (Type, Tuple[Type, ...]): legal type. - valid_length (int): expected length. - - Returns: - List[Any]: value. - """ - check_type(name, value, valid_type) - check_length(name, value, valid_length) - - -def color_val_matplotlib( - colors: str | tuple | list[str | tuple], -) -> str | tuple | list[str | tuple]: - """Convert various input in RGB order to normalized RGB matplotlib color - tuples, - Args: - colors (Union[str, tuple, List[Union[str, tuple]]]): Color inputs - Returns: - Union[str, tuple, List[Union[str, tuple]]]: A tuple of 3 normalized - floats indicating RGB channels. - """ - if isinstance(colors, str): - return colors - elif isinstance(colors, tuple): - assert len(colors) == 3 - for channel in colors: - assert 0 <= channel <= 255 - colors = [channel / 255 for channel in colors] - return tuple(colors) - elif isinstance(colors, list): - colors = [ - color_val_matplotlib(color) # type:ignore - for color in colors - ] - return colors - else: - raise TypeError(f"Invalid type for color: {type(colors)}") - - -def color_str2rgb(color: str) -> tuple: - """Convert Matplotlib str color to an RGB color which range is 0 to 255, - silently dropping the alpha channel. - - Args: - color (str): Matplotlib color. - - Returns: - tuple: RGB color. - """ - import matplotlib - - rgb_color: tuple = matplotlib.colors.to_rgb(color) - rgb_color = tuple(int(c * 255) for c in rgb_color) - return rgb_color - - -def convert_overlay_heatmap( - feat_map: np.ndarray | torch.Tensor, - img: np.ndarray | None = None, - alpha: float = 0.5, -) -> np.ndarray: - """Convert feat_map to heatmap and overlay on image, if image is not None. - - Args: - feat_map (np.ndarray, torch.Tensor): The feat_map to convert - with of shape (H, W), where H is the image height and W is - the image width. - img (np.ndarray, optional): The origin image. The format - should be RGB. Defaults to None. - alpha (float): The transparency of featmap. Defaults to 0.5. 
- - Returns: - np.ndarray: heatmap - """ - assert feat_map.ndim == 2 or (feat_map.ndim == 3 and feat_map.shape[0] in [1, 3]) - if isinstance(feat_map, torch.Tensor): - feat_map = feat_map.detach().cpu().numpy() - - if feat_map.ndim == 3: - feat_map = feat_map.transpose(1, 2, 0) - - norm_img = np.zeros(feat_map.shape) - norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX) - norm_img = np.asarray(norm_img, dtype=np.uint8) - heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET) - heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB) - if img is not None: - heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0) - return heat_img - - -def wait_continue(figure, timeout: float = 0, continue_key: str = " ") -> int: - """Show the image and wait for the user's input. - - This implementation refers to - https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py - - Args: - timeout (float): If positive, continue after ``timeout`` seconds. - Defaults to 0. - continue_key (str): The key for users to continue. Defaults to - the space key. - - Returns: - int: If zero, means time out or the user pressed ``continue_key``, - and if one, means the user closed the show figure. - """ - import matplotlib.pyplot as plt - from matplotlib.backend_bases import CloseEvent - - is_inline = "inline" in plt.get_backend() - if is_inline: - # If use inline backend, interactive input and timeout is no use. - return 0 - - if figure.canvas.manager: # type: ignore - # Ensure that the figure is shown - figure.show() # type: ignore - - while True: - # Connect the events to the handler function call. - event = None - - def handler(ev): - # Set external event variable - nonlocal event - # Qt backend may fire two events at the same time, - # use a condition to avoid missing close event. - event = ev if not isinstance(event, CloseEvent) else event - figure.canvas.stop_event_loop() - - cids = [ - figure.canvas.mpl_connect(name, handler) # type: ignore - for name in ("key_press_event", "close_event") - ] - - try: - figure.canvas.start_event_loop(timeout) # type: ignore - finally: # Run even on exception like ctrl-c. - # Disconnect the callbacks. - for cid in cids: - figure.canvas.mpl_disconnect(cid) # type: ignore - - if isinstance(event, CloseEvent): - return 1 # Quit for close. - elif event is None or event.key == continue_key: - return 0 # Quit for continue. - - -def img_from_canvas(canvas: "FigureCanvasAgg") -> np.ndarray: - """Get RGB image from ``FigureCanvasAgg``. - - Args: - canvas (FigureCanvasAgg): The canvas to get image. - - Returns: - np.ndarray: the output of image in RGB. - """ - s, (width, height) = canvas.print_to_buffer() - buffer = np.frombuffer(s, dtype="uint8") - img_rgba = buffer.reshape(height, width, 4) - rgb, alpha = np.split(img_rgba, [3], axis=2) - return rgb.astype("uint8") diff --git a/libs/visengine/visengine/visualization/vis_backend.py b/libs/visengine/visengine/visualization/vis_backend.py deleted file mode 100644 index 6b165ee..0000000 --- a/libs/visengine/visengine/visualization/vis_backend.py +++ /dev/null @@ -1,1346 +0,0 @@ -# ruff: noqa -# type: ignore -# Copyright (c) OpenMMLab. All rights reserved. 
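Assuming `convert_overlay_heatmap` is importable from a pre-deletion checkout, the overlay path min-max normalizes the feature map to uint8, applies `COLORMAP_JET`, and alpha-blends into the RGB image:

```python
import numpy as np

from visengine.visualization.utils import convert_overlay_heatmap  # removed by this diff

img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
feat = np.random.rand(64, 64).astype(np.float32)

overlay = convert_overlay_heatmap(feat, img, alpha=0.5)
print(overlay.shape, overlay.dtype)  # (64, 64, 3) uint8
```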
-import copy -import functools -import logging -import os -import os.path as osp -import platform -import warnings -from abc import ABCMeta, abstractmethod -from collections.abc import Callable, MutableMapping, Sequence -from typing import Any - -import cv2 -import numpy as np -import torch - -from visengine.config import Config, ConfigDict -from visengine.fileio import dump -from visengine.hooks.logger_hook import SUFFIX_TYPE -from visengine.logging import MMLogger, print_log -from visengine.registry import VISBACKENDS -from visengine.utils import digit_version, scandir - - -def force_init_env(old_func: Callable) -> Any: - """Those methods decorated by ``force_init_env`` will be forced to call - ``_init_env`` if the instance has not been fully initiated. This function - will decorated all the `add_xxx` method and `experiment` method, because - `VisBackend` is initialized only when used its API. - - Args: - old_func (Callable): Decorated function, make sure the first arg is an - instance with ``_init_env`` method. - - Returns: - Any: Depends on old_func. - """ - - @functools.wraps(old_func) - def wrapper(obj: object, *args, **kwargs): - # The instance must have `_init_env` method. - if not hasattr(obj, "_init_env"): - raise AttributeError(f"{type(obj)} does not have _init_env method.") - # If instance does not have `_env_initialized` attribute or - # `_env_initialized` is False, call `_init_env` and set - # `_env_initialized` to True - if not getattr(obj, "_env_initialized", False): - print_log( - "Attribute `_env_initialized` is not defined in " - f"{type(obj)} or `{type(obj)}._env_initialized is " - "False, `_init_env` will be called and " - f"{type(obj)}._env_initialized will be set to True", - logger="current", - level=logging.DEBUG, - ) - obj._init_env() # type: ignore - obj._env_initialized = True # type: ignore - - return old_func(obj, *args, **kwargs) - - return wrapper - - -class BaseVisBackend(metaclass=ABCMeta): - """Base class for visualization backend. - - All backends must inherit ``BaseVisBackend`` and implement - the required functions. - - Args: - save_dir (str, optional): The root directory to save - the files produced by the backend. - """ - - def __init__(self, save_dir: str): - self._save_dir = save_dir - self._env_initialized = False - - @property - @abstractmethod - def experiment(self) -> Any: - """Return the experiment object associated with this visualization - backend. - - The experiment attribute can get the visualization backend, such as - wandb, tensorboard. If you want to write other data, such as writing a - table, you can directly get the visualization backend through - experiment. - """ - pass - - @abstractmethod - def _init_env(self) -> Any: - """Setup env for VisBackend.""" - pass - - def add_config(self, config: Config, **kwargs) -> None: - """Record the config. - - Args: - config (Config): The Config object - """ - pass - - def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], **kwargs) -> None: - """Record the model graph. - - Args: - model (torch.nn.Module): Model to draw. - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - pass - - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. 
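The lazy-initialization pattern behind `force_init_env`, reduced to its core; a sketch of the mechanism, not the decorator's full logging behavior:

```python
import functools


def force_init_env(old_func):
    @functools.wraps(old_func)
    def wrapper(obj, *args, **kwargs):
        # The first decorated call triggers _init_env exactly once.
        if not getattr(obj, "_env_initialized", False):
            obj._init_env()
            obj._env_initialized = True
        return old_func(obj, *args, **kwargs)
    return wrapper


class Backend:
    def _init_env(self):
        print("env initialized")

    @force_init_env
    def add_scalar(self, name, value):
        print(f"{name}={value}")


b = Backend()
b.add_scalar("mAP", 0.6)  # prints "env initialized", then "mAP=0.6"
b.add_scalar("mAP", 0.7)  # initialization does not run again
```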
- """ - pass - - def add_scalar(self, name: str, value: int | float, step: int = 0, **kwargs) -> None: - """Record the scalar. - - Args: - name (str): The scalar identifier. - value (int, float): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - pass - - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Defaults to None. - """ - pass - - def close(self) -> None: - """Close an opened object.""" - pass - - -@VISBACKENDS.register_module(force=True) -class LocalVisBackend(BaseVisBackend): - """Local visualization backend class. - - It can write image, config, scalars, etc. - to the local hard disk. You can get the drawing backend - through the experiment property for custom drawing. - - Examples: - >>> from visengine.visualization import LocalVisBackend - >>> import numpy as np - >>> local_vis_backend = LocalVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> local_vis_backend.add_image('img', img) - >>> local_vis_backend.add_scalar('mAP', 0.6) - >>> local_vis_backend.add_scalars({'loss': [1, 2, 3], 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> local_vis_backend.add_config(cfg) - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. If it is none, it means no data - is stored. - img_save_dir (str): The directory to save images. - Defaults to 'vis_image'. - config_save_file (str): The file name to save config. - Defaults to 'config.py'. - scalar_save_file (str): The file name to save scalar values. - Defaults to 'scalars.json'. - """ - - def __init__( - self, - save_dir: str, - img_save_dir: str = "vis_image", - config_save_file: str = "config.py", - scalar_save_file: str = "scalars.json", - ): - assert config_save_file.split(".")[-1] == "py" - assert scalar_save_file.split(".")[-1] == "json" - super().__init__(save_dir) - self._img_save_dir = img_save_dir - self._config_save_file = config_save_file - self._scalar_save_file = scalar_save_file - - def _init_env(self): - """Init save dir.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) - self._img_save_dir = osp.join(self._save_dir, self._img_save_dir) # type: ignore - self._config_save_file = osp.join(self._save_dir, self._config_save_file) # type: ignore - self._scalar_save_file = osp.join(self._save_dir, self._scalar_save_file) # type: ignore - - @property # type: ignore - @force_init_env - def experiment(self) -> "LocalVisBackend": - """Return the experiment object associated with this visualization - backend.""" - return self - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to disk. - - Args: - config (Config): The Config object - """ - assert isinstance(config, Config) - config.dump(self._config_save_file) - - @force_init_env - def add_image(self, name: str, image: np.array, step: int = 0, **kwargs) -> None: - """Record the image to disk. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. 
- """ - assert image.dtype == np.uint8 - drawn_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - os.makedirs(self._img_save_dir, exist_ok=True) - save_file_name = f"{name}_{step}.png" - cv2.imwrite(osp.join(self._img_save_dir, save_file_name), drawn_image) - - @force_init_env - def add_scalar( - self, - name: str, - value: int | float | torch.Tensor | np.ndarray, - step: int = 0, - **kwargs, - ) -> None: - """Record the scalar data to disk. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - if isinstance(value, torch.Tensor): - value = value.item() - self._dump({name: value, "step": step}, self._scalar_save_file, "json") - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalars to disk. - - The scalar dict will be written to the default and - specified files if ``file_path`` is specified. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. The value must be dumped - into json format. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): The scalar's data will be - saved to the ``file_path`` file at the same time - if the ``file_path`` parameter is specified. - Defaults to None. - """ - assert isinstance(scalar_dict, dict) - scalar_dict = copy.deepcopy(scalar_dict) - scalar_dict.setdefault("step", step) - - if file_path is not None: - assert file_path.split(".")[-1] == "json" - new_save_file_path = osp.join(self._save_dir, file_path) # type: ignore - assert new_save_file_path != self._scalar_save_file, ( - "``file_path`` and ``scalar_save_file`` have the same name, please set ``file_path`` to another value" - ) - self._dump(scalar_dict, new_save_file_path, "json") - self._dump(scalar_dict, self._scalar_save_file, "json") - - def _dump(self, value_dict: dict, file_path: str, file_format: str) -> None: - """Dump dict to file. - - Args: - value_dict (dict) : The dict data to saved. - file_path (str): The file path to save data. - file_format (str): The file format to save data. - """ - with open(file_path, "a+") as f: - dump(value_dict, f, file_format=file_format) - f.write("\n") - - -@VISBACKENDS.register_module(force=True) -class WandbVisBackend(BaseVisBackend): - """Wandb visualization backend class. - - Examples: - >>> from visengine.visualization import WandbVisBackend - >>> import numpy as np - >>> wandb_vis_backend = WandbVisBackend() - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> wandb_vis_backend.add_image('img', img) - >>> wandb_vis_backend.add_scaler('mAP', 0.6) - >>> wandb_vis_backend.add_scalars({'loss': [1, 2, 3],'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> wandb_vis_backend.add_config(cfg) - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. - init_kwargs (dict, optional): wandb initialization - input parameters. - See `wandb.init `_ for - details. Defaults to None. - define_metric_cfg (dict or list[dict], optional): - When a dict is set, it is a dict of metrics and summary for - ``wandb.define_metric``. - The key is metric and the value is summary. - When a list is set, each dict should be a valid argument of - the ``define_metric``. - For example, ``define_metric_cfg={'coco/bbox_mAP': 'max'}``, - means the maximum value of ``coco/bbox_mAP`` is logged on wandb UI. 
- When ``define_metric_cfg=[dict(name='loss', - step_metric='epoch')]``, - the "loss" will be plotted against the epoch. - See `wandb define_metric `_ for details. - Defaults to None. - commit (bool, optional) Save the metrics dict to the wandb server - and increment the step. If false `wandb.log` just updates the - current metrics dict with the row argument and metrics won't be - saved until `wandb.log` is called with `commit=True`. - Defaults to True. - log_code_name (str, optional) The name of code artifact. - By default, the artifact will be named - source-$PROJECT_ID-$ENTRYPOINT_RELPATH. See - `wandb log_code `_ - for details. Defaults to None. - `New in version 0.3.0.` - watch_kwargs (optional, dict): Agurments for ``wandb.watch``. - `New in version 0.4.0.` - """ - - def __init__( - self, - save_dir: str, - init_kwargs: dict | None = None, - define_metric_cfg: dict | list | None = None, - commit: bool | None = True, - log_code_name: str | None = None, - watch_kwargs: dict | None = None, - ): - super().__init__(save_dir) - self._init_kwargs = init_kwargs - self._define_metric_cfg = define_metric_cfg - self._commit = commit - self._log_code_name = log_code_name - self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {} - - def _init_env(self): - """Setup env for wandb.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - if self._init_kwargs is None: - self._init_kwargs = {"dir": self._save_dir} - else: - self._init_kwargs.setdefault("dir", self._save_dir) - try: - import wandb - except ImportError: - raise ImportError('Please run "pip install wandb" to install wandb') - - wandb.init(**self._init_kwargs) - if self._define_metric_cfg is not None: - if isinstance(self._define_metric_cfg, dict): - for metric, summary in self._define_metric_cfg.items(): - wandb.define_metric(metric, summary=summary) - elif isinstance(self._define_metric_cfg, list): - for metric_cfg in self._define_metric_cfg: - wandb.define_metric(**metric_cfg) - else: - raise ValueError("define_metric_cfg should be dict or list") - self._wandb = wandb - - @property # type: ignore - @force_init_env - def experiment(self): - """Return wandb object. - - The experiment attribute can get the wandb backend, If you want to - write other data, such as writing a table, you can directly get the - wandb backend through experiment. - """ - return self._wandb - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to wandb. - - Args: - config (Config): The Config object - """ - assert isinstance(self._init_kwargs, dict) - allow_val_change = self._init_kwargs.get("allow_val_change", False) - self._wandb.config.update(config.to_dict(), allow_val_change=allow_val_change) - self._wandb.run.log_code(name=self._log_code_name) - - @force_init_env - def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], **kwargs) -> None: - """Record the model graph. - - Args: - model (torch.nn.Module): Model to draw. - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - self._wandb.watch(model, **self._watch_kwargs) - - @force_init_env - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image to wandb. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Useless parameter. Wandb does not - need this parameter. Defaults to 0. 
- """ - image = self._wandb.Image(image) - self._wandb.log({name: image}, commit=self._commit) - - @force_init_env - def add_scalar( - self, - name: str, - value: int | float | torch.Tensor | np.ndarray, - step: int = 0, - **kwargs, - ) -> None: - """Record the scalar data to wandb. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Useless parameter. Wandb does not - need this parameter. Defaults to 0. - """ - self._wandb.log({name: value}, commit=self._commit) - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalar's data to wandb. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Useless parameter. Wandb does not - need this parameter. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - self._wandb.log(scalar_dict, commit=self._commit) - - def close(self) -> None: - """Close an opened wandb object.""" - if hasattr(self, "_wandb"): - self._wandb.join() - - -@VISBACKENDS.register_module(force=True) -class TensorboardVisBackend(BaseVisBackend): - """Tensorboard visualization backend class. - - It can write images, config, scalars, etc. to a - tensorboard file. - - Examples: - >>> from visengine.visualization import TensorboardVisBackend - >>> import numpy as np - >>> vis_backend = TensorboardVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> vis_backend.add_image('img', img) - >>> vis_backend.add_scaler('mAP', 0.6) - >>> vis_backend.add_scalars({'loss': 0.1,'acc':0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis_backend.add_config(cfg) - - Args: - save_dir (str): The root directory to save the files - produced by the backend. - """ - - def __init__(self, save_dir: str): - super().__init__(save_dir) - - def _init_env(self): - """Setup env for Tensorboard.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - from torch.utils.tensorboard import SummaryWriter - self._tensorboard = SummaryWriter(self._save_dir) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return Tensorboard object.""" - return self._tensorboard - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to tensorboard. - - Args: - config (Config): The Config object - """ - self._tensorboard.add_text("config", config.pretty_text) - - @force_init_env - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image to tensorboard. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Global step value to record. Defaults to 0. - """ - self._tensorboard.add_image(name, image, step, dataformats="HWC") - - @force_init_env - def add_scalar( - self, - name: str, - value: int | float | torch.Tensor | np.ndarray, - step: int = 0, - **kwargs, - ) -> None: - """Record the scalar data to tensorboard. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. 
- """ - if isinstance(value, int | float | torch.Tensor | np.ndarray | np.number): - self._tensorboard.add_scalar(name, value, step) - else: - warnings.warn( - f"Got {type(value)}, but numpy array, torch tensor, int or float are expected. skip it!", - stacklevel=2, - ) - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalar's data to tensorboard. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - assert isinstance(scalar_dict, dict) - assert "step" not in scalar_dict, "Please set it directly through the step parameter" - for key, value in scalar_dict.items(): - self.add_scalar(key, value, step) - - def close(self): - """Close an opened tensorboard object.""" - if hasattr(self, "_tensorboard"): - self._tensorboard.close() - - -@VISBACKENDS.register_module(force=True) -class MLflowVisBackend(BaseVisBackend): - """MLflow visualization backend class. - - It can write images, config, scalars, etc. to a - mlflow file. - - Examples: - >>> from visengine.visualization import MLflowVisBackend - >>> from mmengine import Config - >>> import numpy as np - >>> vis_backend = MLflowVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> vis_backend.add_image('img.png', img) - >>> vis_backend.add_scalar('mAP', 0.6) - >>> vis_backend.add_scalars({'loss': 0.1,'acc':0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis_backend.add_config(cfg) - - Args: - save_dir (str): The root directory to save the files - produced by the backend. - exp_name (str, optional): The experiment name. Defaults to None. - run_name (str, optional): The run name. Defaults to None. - tags (dict, optional): The tags to be added to the experiment. - Defaults to None. - params (dict, optional): The params to be added to the experiment. - Defaults to None. - tracking_uri (str, optional): The tracking uri. Defaults to None. - artifact_suffix (Tuple[str] or str, optional): The artifact suffix. - Defaults to ('.json', '.log', '.py', 'yaml'). - tracked_config_keys (dict, optional): The top level keys of config that - will be added to the experiment. If it is None, which means all - the config will be added. Defaults to None. - `New in version 0.7.4.` - artifact_location (str, optional): The location to store run artifacts. - If None, the server picks an appropriate default. - Defaults to None. 
- `New in version 0.10.4.` - """ - - def __init__( - self, - save_dir: str, - exp_name: str | None = None, - run_name: str | None = None, - tags: dict | None = None, - params: dict | None = None, - tracking_uri: str | None = None, - artifact_suffix: SUFFIX_TYPE = (".json", ".log", ".py", "yaml"), - tracked_config_keys: dict | None = None, - artifact_location: str | None = None, - ): - super().__init__(save_dir) - self._exp_name = exp_name - self._run_name = run_name - self._tags = tags - self._params = params - self._tracking_uri = tracking_uri - self._artifact_suffix = artifact_suffix - self._tracked_config_keys = tracked_config_keys - self._artifact_location = artifact_location - - def _init_env(self): - """Setup env for MLflow.""" - if not os.path.exists(self._save_dir): - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - - try: - import mlflow - except ImportError: - raise ImportError('Please run "pip install mlflow" to install mlflow') # type: ignore - self._mlflow = mlflow - - # when mlflow is imported, a default logger is created. - # at this time, the default logger's stream is None - # so the stream is reopened only when the stream is None - # or the stream is closed - logger = MMLogger.get_current_instance() - for handler in logger.handlers: - if handler.stream is None or handler.stream.closed: - handler.stream = open(handler.baseFilename, "a") - - if self._tracking_uri is not None: - logger.warning("Please make sure that the mlflow server is running.") - self._mlflow.set_tracking_uri(self._tracking_uri) - else: - if os.name == "nt": - file_url = f"file:\\{os.path.abspath(self._save_dir)}" - else: - file_url = f"file://{os.path.abspath(self._save_dir)}" - self._mlflow.set_tracking_uri(file_url) - - self._exp_name = self._exp_name or "Default" - - if self._mlflow.get_experiment_by_name(self._exp_name) is None: - self._mlflow.create_experiment(self._exp_name, artifact_location=self._artifact_location) - - self._mlflow.set_experiment(self._exp_name) - - if self._run_name is not None: - self._mlflow.set_tag("mlflow.runName", self._run_name) - if self._tags is not None: - self._mlflow.set_tags(self._tags) - if self._params is not None: - self._mlflow.log_params(self._params) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return MLflow object.""" - return self._mlflow - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to mlflow. - - Args: - config (Config): The Config object - """ - self.cfg = config - if self._tracked_config_keys is None: - self._mlflow.log_params(self._flatten(self.cfg.to_dict())) - else: - tracked_cfg = {} - for k in self._tracked_config_keys: - tracked_cfg[k] = self.cfg[k] - self._mlflow.log_params(self._flatten(tracked_cfg)) - self._mlflow.log_text(self.cfg.pretty_text, "config.py") - - @force_init_env - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image to mlflow. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Global step value to record. Default to 0. - """ - self._mlflow.log_image(image, name) - - @force_init_env - def add_scalar( - self, - name: str, - value: int | float | torch.Tensor | np.ndarray, - step: int = 0, - **kwargs, - ) -> None: - """Record the scalar data to mlflow. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. 
Defaults to 0. - """ - self._mlflow.log_metric(name, value, step) - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalar's data to mlflow. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - assert isinstance(scalar_dict, dict) - assert "step" not in scalar_dict, "Please set it directly through the step parameter" - self._mlflow.log_metrics(scalar_dict, step) - - def close(self) -> None: - """Close the mlflow.""" - if not hasattr(self, "_mlflow"): - return - - file_paths = {} - for filename in scandir(self.cfg.work_dir, self._artifact_suffix, True): - file_path = osp.join(self.cfg.work_dir, filename) - relative_path = os.path.relpath(file_path, self.cfg.work_dir) - dir_path = os.path.dirname(relative_path) - file_paths[file_path] = dir_path - - for file_path, dir_path in file_paths.items(): - self._mlflow.log_artifact(file_path, dir_path) - - self._mlflow.end_run() - - def _flatten(self, d, parent_key="", sep=".") -> dict: - """Flatten the dict.""" - items = {} - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, MutableMapping): - items.update(self._flatten(v, new_key, sep=sep)) - elif isinstance(v, list): - if any(isinstance(x, dict) for x in v): - for i, x in enumerate(v): - items.update(self._flatten(x, new_key + sep + str(i), sep=sep)) - else: - items[new_key] = v - else: - items[new_key] = v - return items - - -@VISBACKENDS.register_module(force=True) -class ClearMLVisBackend(BaseVisBackend): - """ClearML visualization backend class. It requires `clearml`_ to be - installed. - - Examples: - >>> from visengine.visualization import ClearMLVisBackend - >>> from mmengine import Config - >>> import numpy as np - >>> vis_backend = ClearMLVisBackend(save_dir='temp_dir') - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> vis_backend.add_image('img.png', img) - >>> vis_backend.add_scalar('mAP', 0.6) - >>> vis_backend.add_scalars({'loss': 0.1,'acc':0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis_backend.add_config(cfg) - - Args: - save_dir (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - init_kwargs (dict, optional): A dict containing the arguments of - ``clearml.Task.init``. See `taskinit`_ for more details. - Defaults to None. - artifact_suffix (Tuple[str] or str): The artifact suffix. - Defaults to ('.py', '.pth'). - - .. _clearml: - https://clear.ml/docs/latest/docs/ - - ..
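``_flatten`` exists because ``mlflow.log_params`` only accepts a flat mapping of scalar-valued params, so nested configs are collapsed into dotted keys (and lists of dicts into indexed keys). A standalone illustration of the same recursion:

```python
from collections.abc import MutableMapping

def flatten(d, parent_key="", sep="."):
    # Nested dicts become dotted keys; lists of dicts become key.0, key.1, ...
    items = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, MutableMapping):
            items.update(flatten(v, new_key, sep=sep))
        elif isinstance(v, list) and any(isinstance(x, dict) for x in v):
            for i, x in enumerate(v):
                items.update(flatten(x, f"{new_key}{sep}{i}", sep=sep))
        else:
            items[new_key] = v
    return items

cfg = {"optimizer": {"type": "AdamW", "lr": 1e-4}, "stages": [{"depth": 2}, {"depth": 6}]}
assert flatten(cfg) == {"optimizer.type": "AdamW", "optimizer.lr": 1e-4,
                        "stages.0.depth": 2, "stages.1.depth": 6}
```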
_taskinit: - https://clear.ml/docs/latest/docs/references/sdk/task/#taskinit - """ - - def __init__( - self, - save_dir: str | None = None, - init_kwargs: dict | None = None, - artifact_suffix: SUFFIX_TYPE = (".py", ".pth"), - ): - super().__init__(save_dir) # type: ignore - self._init_kwargs = init_kwargs - self._artifact_suffix = artifact_suffix - - def _init_env(self) -> None: - try: - import clearml - except ImportError: - raise ImportError('Please run "pip install clearml" to install clearml') - - task_kwargs = self._init_kwargs or {} - self._clearml = clearml - self._task = self._clearml.Task.init(**task_kwargs) - self._logger = self._task.get_logger() - - @property # type: ignore - @force_init_env - def experiment(self): - """Return clearml object.""" - return self._clearml - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to clearml. - - Args: - config (Config): The Config object - """ - self.cfg = config - self._task.connect_configuration(config.to_dict()) - - @force_init_env - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image to clearml. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Global step value to record. Defaults to 0. - """ - self._logger.report_image(title=name, series=name, iteration=step, image=image) - - @force_init_env - def add_scalar( - self, - name: str, - value: int | float | torch.Tensor | np.ndarray, - step: int = 0, - **kwargs, - ) -> None: - """Record the scalar data to clearml. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - self._logger.report_scalar(title=name, series=name, value=value, iteration=step) - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalar's data to clearml. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - assert "step" not in scalar_dict, "Please set it directly through the step parameter" - for key, value in scalar_dict.items(): - self._logger.report_scalar(title=key, series=key, value=value, iteration=step) - - def close(self) -> None: - """Close the clearml.""" - if not hasattr(self, "_clearml"): - return - - file_paths: list[str] = [] - if hasattr(self, "cfg") and osp.isdir(getattr(self.cfg, "work_dir", "")): - for filename in scandir(self.cfg.work_dir, self._artifact_suffix, False): - file_path = osp.join(self.cfg.work_dir, filename) - file_paths.append(file_path) - - for file_path in file_paths: - self._task.upload_artifact(os.path.basename(file_path), file_path) - self._task.close() - - -@VISBACKENDS.register_module(force=True) -class NeptuneVisBackend(BaseVisBackend): - """Neptune visualization backend class. 
- - Examples: - >>> from visengine.visualization import NeptuneVisBackend - >>> from mmengine import Config - >>> import numpy as np - >>> init_kwargs = {'project': 'your_project_name'} - >>> neptune_vis_backend = NeptuneVisBackend(init_kwargs=init_kwargs) - >>> img = np.random.randint(0, 256, size=(10, 10, 3)) - >>> neptune_vis_backend.add_image('img', img) - >>> neptune_vis_backend.add_scalar('mAP', 0.6) - >>> neptune_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> neptune_vis_backend.add_config(cfg) - - Note: - `New in version 0.9.0.` - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. NeptuneVisBackend does - not require this argument. Defaults to None. - init_kwargs (dict, optional): Neptune initialization parameters. - Defaults to None. - - - project (str): Name of a project in a form of - `namespace/project_name`. If `project` is not specified, - the value of `NEPTUNE_PROJECT` environment variable - will be taken. - - api_token (str): User's API token. If `api_token` is not specified, - the value of `NEPTUNE_API_TOKEN` environment variable will - be taken. Note: It is strongly recommended to use - `NEPTUNE_API_TOKEN` environment variable rather than - placing your API token here. - - If 'project' and 'api_token' are not specified in `init_kwargs`, - the 'mode' will be set to 'offline'. - See ``neptune.init_run`` for - details. - """ - - def __init__(self, save_dir: str | None = None, init_kwargs: dict | None = None): - super().__init__(save_dir) # type:ignore - self._init_kwargs = init_kwargs - - def _init_env(self): - """Setup env for neptune.""" - try: - import neptune - except ImportError: - raise ImportError('Please run "pip install -U neptune" to install neptune') - if self._init_kwargs is None: - self._init_kwargs = {"mode": "offline"} - - self._neptune = neptune.init_run(**self._init_kwargs) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return Neptune object.""" - return self._neptune - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to neptune. - - Args: - config (Config): The Config object - """ - from neptune.types import File - - self._neptune["config"].upload(File.from_content(config.pretty_text)) - - @force_init_env - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Global step value to record. Defaults to 0. - """ - from neptune.types import File - - # values in the array need to be in the [0, 1] range - img = image.astype(np.float32) / 255.0 - self._neptune["images"].append(File.as_image(img), name=name, step=step) - - @force_init_env - def add_scalar(self, name: str, value: int | float, step: int = 0, **kwargs) -> None: - """Record the scalar. - - Args: - name (str): The scalar identifier. - value (int, float): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - self._neptune[name].append(value, step=step) - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0.
- file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Defaults to None. - """ - assert isinstance(scalar_dict, dict) - assert "step" not in scalar_dict, "Please set it directly through the step parameter" - - for k, v in scalar_dict.items(): - self._neptune[k].append(v, step=step) - - def close(self) -> None: - """Close an opened object.""" - if hasattr(self, "_neptune"): - self._neptune.stop() - - -@VISBACKENDS.register_module(force=True) -class DVCLiveVisBackend(BaseVisBackend): - """DVCLive visualization backend class. - - Examples: - >>> from visengine.visualization import DVCLiveVisBackend - >>> import numpy as np - >>> dvclive_vis_backend = DVCLiveVisBackend(save_dir='temp_dir') - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> dvclive_vis_backend.add_image('img', img) - >>> dvclive_vis_backend.add_scalar('mAP', 0.6) - >>> dvclive_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> dvclive_vis_backend.add_config(cfg) - - Note: - `New in version 0.9.0.` - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. - artifact_suffix (Tuple[str] or str, optional): The artifact suffix. - Defaults to ('.json', '.py', 'yaml'). - init_kwargs (dict, optional): DVCLive initialization parameters. - See `DVCLive `_ for details. - Defaults to None. - """ - - def __init__( - self, - save_dir: str, - artifact_suffix: SUFFIX_TYPE = (".json", ".py", "yaml"), - init_kwargs: dict | None = None, - ): - super().__init__(save_dir) - self._artifact_suffix = artifact_suffix - self._init_kwargs = init_kwargs - - def _init_env(self): - """Setup env for dvclive.""" - if digit_version(platform.python_version()) < digit_version("3.8"): - raise RuntimeError("Please use Python 3.8 or higher version to use DVCLiveVisBackend.") - - try: - import pygit2 - from dvclive import Live - except ImportError: - raise ImportError('Please run "pip install dvclive" to install dvclive') - # if no git info, init dvc without git to avoid SCMError - try: - path = pygit2.discover_repository(os.fspath(os.curdir), True, "") - pygit2.Repository(path).default_signature - except KeyError: - os.system("dvc init -f --no-scm") - - if self._init_kwargs is None: - self._init_kwargs = {} - self._init_kwargs.setdefault("dir", self._save_dir) - self._init_kwargs.setdefault("save_dvc_exp", True) - self._init_kwargs.setdefault("cache_images", True) - - self._dvclive = Live(**self._init_kwargs) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return dvclive object. - - The experiment attribute can get the dvclive backend, If you want to - write other data, such as writing a table, you can directly get the - dvclive backend through experiment. - """ - return self._dvclive - - @force_init_env - def add_config(self, config: Config, **kwargs) -> None: - """Record the config to dvclive. - - Args: - config (Config): The Config object - """ - assert isinstance(config, Config) - self.cfg = config - self._dvclive.log_params(self._to_dvc_paramlike(self.cfg.to_dict())) - - @force_init_env - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image to dvclive. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. - step (int): Useless parameter. Dvclive does not - need this parameter. Defaults to 0. 
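One consequence of the Neptune defaults shown earlier: with no ``init_kwargs`` at all, ``_init_env`` falls back to ``mode='offline'``, so the backend can be smoke-tested without a Neptune account. A minimal sketch (requires the ``neptune`` package):

```python
import numpy as np
from visengine.visualization import NeptuneVisBackend

backend = NeptuneVisBackend()          # no project/api_token -> offline mode
backend.add_scalar("mAP", 0.6, step=1)
img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
backend.add_image("sample", img, step=1)  # rescaled to [0, 1] internally
backend.close()
```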
- """ - assert image.dtype == np.uint8 - save_file_name = f"{name}.png" - - self._dvclive.log_image(save_file_name, image) - - @force_init_env - def add_scalar( - self, - name: str, - value: int | float | torch.Tensor | np.ndarray, - step: int = 0, - **kwargs, - ) -> None: - """Record the scalar data to dvclive. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - if isinstance(value, torch.Tensor): - value = value.numpy() - self._dvclive.step = step - self._dvclive.log_metric(name, value) - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalar's data to dvclive. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - for key, value in scalar_dict.items(): - self.add_scalar(key, value, step, **kwargs) - - def close(self) -> None: - """Close an opened dvclive object.""" - if not hasattr(self, "_dvclive"): - return - - file_paths = {} - for filename in scandir(self._save_dir, self._artifact_suffix, True): - file_path = osp.join(self._save_dir, filename) - relative_path = os.path.relpath(file_path, self._save_dir) - dir_path = os.path.dirname(relative_path) - file_paths[file_path] = dir_path - - for file_path, dir_path in file_paths.items(): - self._dvclive.log_artifact(file_path, dir_path) - - self._dvclive.end() - - def _to_dvc_paramlike( - self, - value: int | float | dict | list | tuple | Config | ConfigDict | torch.Tensor | np.ndarray, - ): - """Convert the input value to a DVC `ParamLike` recursively. - - Or the `log_params` method of dvclive will raise an error. - """ - - if isinstance(value, dict | Config | ConfigDict): - return {k: self._to_dvc_paramlike(v) for k, v in value.items()} - elif isinstance(value, tuple | list): - return [self._to_dvc_paramlike(item) for item in value] - elif isinstance(value, torch.Tensor | np.ndarray): - return value.tolist() - elif isinstance(value, np.generic): - return value.item() - else: - return value - - -@VISBACKENDS.register_module(force=True) -class AimVisBackend(BaseVisBackend): - """Aim visualization backend class. - - Examples: - >>> from visengine.visualization import AimVisBackend - >>> import numpy as np - >>> aim_vis_backend = AimVisBackend() - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> aim_vis_backend.add_image('img', img) - >>> aim_vis_backend.add_scalar('mAP', 0.6) - >>> aim_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8}) - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> aim_vis_backend.add_config(cfg) - - Note: - 1. `New in version 0.9.0.` - 2. Refer to - `Github issue `_ , - Aim is not unable to be install on Windows for now. - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. - init_kwargs (dict, optional): Aim initialization parameters. See - `Aim `_ - for details. Defaults to None. 
- """ - - def __init__(self, save_dir: str | None = None, init_kwargs: dict | None = None): - super().__init__(save_dir) # type:ignore - self._init_kwargs = init_kwargs - - def _init_env(self): - """Setup env for Aim.""" - try: - from aim import Run - except ImportError: - raise ImportError('Please run "pip install aim" to install aim') - - from datetime import datetime - - if self._save_dir is not None: - path_list = os.path.normpath(self._save_dir).split(os.sep) - exp_name = f"{path_list[-2]}_{path_list[-1]}" - else: - exp_name = datetime.now().strftime("%Y%m%d_%H%M%S") - - if self._init_kwargs is None: - self._init_kwargs = {} - self._init_kwargs.setdefault("experiment", exp_name) - self._aim_run = Run(**self._init_kwargs) - - @property # type: ignore - @force_init_env - def experiment(self): - """Return Aim object.""" - return self._aim_run - - @force_init_env - def add_config(self, config, **kwargs) -> None: - """Record the config to Aim. - - Args: - config (Config): The Config object - """ - if isinstance(config, Config): - config = config.to_dict() - self._aim_run["hparams"] = config - - @force_init_env - def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. - """ - from aim import Image - - self._aim_run.track(name=name, value=Image(image), step=step) - - @force_init_env - def add_scalar( - self, - name: str, - value: int | float | torch.Tensor | np.ndarray, - step: int = 0, - **kwargs, - ) -> None: - """Record the scalar data to Aim. - - Args: - name (str): The scalar identifier. - value (int, float, torch.Tensor, np.ndarray): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._aim_run.track(name=name, value=value, step=step) - - @force_init_env - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalar's data to wandb. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Defaults to None. - """ - for key, value in scalar_dict.items(): - self._aim_run.track(name=key, value=value, step=step) - - def close(self) -> None: - """Close the Aim.""" - if not hasattr(self, "_aim_run"): - return - - self._aim_run.close() diff --git a/libs/visengine/visengine/visualization/visualizer.py b/libs/visengine/visengine/visualization/visualizer.py deleted file mode 100644 index 854f93a..0000000 --- a/libs/visengine/visengine/visualization/visualizer.py +++ /dev/null @@ -1,1185 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import inspect -import os.path as osp -import warnings -from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional, Union - -if TYPE_CHECKING: - from matplotlib.font_manager import FontProperties - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F - -from visengine.config import Config -from visengine.dist import master_only -from visengine.registry import VISBACKENDS, VISUALIZERS -from visengine.structures import BaseDataElement -from visengine.utils import ManagerMixin, is_seq_of -from visengine.visualization.utils import ( - check_type, - check_type_and_length, - color_str2rgb, - color_val_matplotlib, - convert_overlay_heatmap, - img_from_canvas, - tensor2ndarray, - value2list, - wait_continue, -) -from visengine.visualization.vis_backend import BaseVisBackend - -VisBackendsType = Union[list[dict | BaseVisBackend], dict, BaseVisBackend, None] - - -@VISUALIZERS.register_module() -class Visualizer(ManagerMixin): - """MMEngine provides a Visualizer class that uses the ``Matplotlib`` - library as the backend. It has the following functions: - - - Basic drawing methods - - - draw_bboxes: draw single or multiple bounding boxes - - draw_texts: draw single or multiple text boxes - - draw_points: draw single or multiple points - - draw_lines: draw single or multiple line segments - - draw_circles: draw single or multiple circles - - draw_polygons: draw single or multiple polygons - - draw_binary_masks: draw single or multiple binary masks - - draw_featmap: draw feature map - - - Basic visualizer backend methods - - - add_config: write config to all vis storage backends - - add_graph: write model graph to all vis storage backends - - add_image: write image to all vis storage backends - - add_scalar: write scalar to all vis storage backends - - add_scalars: write scalars to all vis storage backends - - add_datasample: write datasample to all vis storage \ - backends. The abstract drawing interface used by the user - - - Basic info methods - - - set_image: sets the original image data - - get_image: get the image data in Numpy format after drawing - - show: visualization - - close: close all resources that have been opened - - get_backend: get the specified vis backend - - - All the basic drawing methods support chain calls, which is convenient for - overlay drawing and display. Each downstream algorithm library can inherit - ``Visualizer`` and implement the add_datasample logic. For example, - ``DetLocalVisualizer`` in MMDetection inherits from ``Visualizer`` - and implements functions such as visualizing detection boxes, instance masks, - and semantic segmentation maps in the add_datasample interface. - - Args: - name (str): Name of the instance. Defaults to 'visualizer'. - image (np.ndarray, optional): the origin image to draw. The format - should be RGB. Defaults to None. - vis_backends (list, optional): Visual backend config list. - Defaults to None. - save_dir (str, optional): Save file dir for all storage backends. - If it is None, the backend storage will not save any data. - fig_save_cfg (dict): Keyword parameters of figure for saving. - Defaults to empty dict. - fig_show_cfg (dict): Keyword parameters of figure for showing. - Defaults to empty dict.
- - Examples: - >>> # Basic info methods - >>> vis = Visualizer() - >>> vis.set_image(image) - >>> vis.get_image() - >>> vis.show() - - >>> # Basic drawing methods - >>> vis = Visualizer(image=image) - >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edge_colors='g') - >>> vis.draw_bboxes(bbox=np.array([[1, 1, 2, 2], [2, 2, 3, 3]]), - >>> edge_colors=['g', 'r']) - >>> vis.draw_lines(x_datas=np.array([1, 3]), - >>> y_datas=np.array([1, 3]), - >>> colors='r', line_widths=1) - >>> vis.draw_lines(x_datas=np.array([[1, 3], [2, 4]]), - >>> y_datas=np.array([[1, 3], [2, 4]]), - >>> colors=['r', 'r'], line_widths=[1, 2]) - >>> vis.draw_texts(text='MMEngine', - >>> position=np.array([2, 2]), - >>> colors='b') - >>> vis.draw_texts(text=['MMEngine','OpenMMLab'], - >>> position=np.array([[2, 2], [5, 5]]), - >>> colors=['b', 'b']) - >>> vis.draw_circles(center=np.array([2, 2]), radius=np.array([1])) - >>> vis.draw_circles(center=np.array([[2, 2], [3, 5]]), - >>> radius=np.array([1, 2]), edge_colors=['g', 'r']) - >>> square = np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) - >>> vis.draw_polygons(polygons=square, edge_colors='g') - >>> squares = [np.array([[0, 0], [100, 0], [100, 100], [0, 100]]), - >>> np.array([[0, 0], [50, 0], [50, 50], [0, 50]])] - >>> vis.draw_polygons(polygons=squares, edge_colors=['g', 'r']) - >>> vis.draw_binary_masks(binary_mask, alpha=0.6) - >>> heatmap = vis.draw_featmap(featmap, img, - >>> channel_reduction='select_max') - >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None, - >>> topk=8, arrangement=(4, 2)) - >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None, - >>> topk=-1) - - >>> # chain calls - >>> vis.draw_bboxes().draw_texts().draw_circles().draw_binary_masks() - - >>> # Backend related methods - >>> vis = Visualizer(vis_backends=[dict(type='LocalVisBackend')], - >>> save_dir='temp_dir') - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> vis.add_config(cfg) - >>> image=np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - >>> vis.add_image('image',image) - >>> vis.add_scalar('mAP', 0.6) - >>> vis.add_scalars({'loss': 0.1,'acc':0.8}) - - >>> # inherit - >>> class DetLocalVisualizer(Visualizer): - >>> def add_datasample(self, - >>> name, - >>> image: np.ndarray, - >>> gt_sample: - >>> Optional['BaseDataElement'] = None, - >>> pred_sample: - >>> Optional['BaseDataElement'] = None, - >>> draw_gt: bool = True, - >>> draw_pred: bool = True, - >>> show: bool = False, - >>> wait_time: int = 0, - >>> step: int = 0) -> None: - >>> pass - """ - - def __init__( - self, - name="visualizer", - image: np.ndarray | None = None, - vis_backends: VisBackendsType = None, - save_dir: str | None = None, - fig_save_cfg=dict(frameon=False), - fig_show_cfg=dict(frameon=False), - ) -> None: - super().__init__(name) - self._dataset_meta: dict | None = None - self._vis_backends: dict[str, BaseVisBackend] = {} - - if vis_backends is None: - vis_backends = [] - - if isinstance(vis_backends, (dict, BaseVisBackend)): - vis_backends = [vis_backends] # type: ignore - - if not is_seq_of(vis_backends, (dict, BaseVisBackend)): - raise TypeError("vis_backends must be a list of dicts or a list of BaseVisBackend instances") - if save_dir is not None: - save_dir = osp.join(save_dir, "vis_data") - - for vis_backend in vis_backends: # type: ignore - name = None - if isinstance(vis_backend, dict): - name = vis_backend.pop("name", None) - vis_backend.setdefault("save_dir", save_dir) - vis_backend = VISBACKENDS.build(vis_backend) - - # If vis_backend requires `save_dir`
(with no default value) - # but is initialized with None, then don't add this - # vis_backend to the visualizer. - save_dir_arg = inspect.signature(vis_backend.__class__.__init__).parameters.get("save_dir") - if ( - save_dir_arg is not None - and save_dir_arg.default is save_dir_arg.empty - and vis_backend._save_dir is None - ): - warnings.warn(f"Failed to add {vis_backend.__class__}, please provide the `save_dir` argument.") - continue - - type_name = vis_backend.__class__.__name__ - name = name or type_name - - if name in self._vis_backends: - raise RuntimeError(f"vis_backend name {name} already exists") - self._vis_backends[name] = vis_backend # type: ignore - - self.fig_save = None - self.fig_save_cfg = fig_save_cfg - self.fig_show_cfg = fig_show_cfg - - (self.fig_save_canvas, self.fig_save, self.ax_save) = self._initialize_fig(fig_save_cfg) - self.dpi = self.fig_save.get_dpi() - - if image is not None: - self.set_image(image) - - @property # type: ignore - @master_only - def dataset_meta(self) -> dict | None: - """Optional[dict]: Meta info of the dataset.""" - return self._dataset_meta - - @dataset_meta.setter # type: ignore - @master_only - def dataset_meta(self, dataset_meta: dict) -> None: - """Set the dataset meta info to the Visualizer.""" - self._dataset_meta = dataset_meta - - @master_only - def show( - self, - drawn_img: np.ndarray | None = None, - win_name: str = "image", - wait_time: float = 0.0, - continue_key: str = " ", - backend: str = "matplotlib", - ) -> None: - """Show the drawn image. - - Args: - drawn_img (np.ndarray, optional): The image to show. If drawn_img - is None, it will show the image got by Visualizer. Defaults - to None. - win_name (str): The image title. Defaults to 'image'. - wait_time (float): Delay in seconds. 0 is the special - value that means "forever". Defaults to 0. - continue_key (str): The key for users to continue. Defaults to - the space key. - backend (str): The backend to show the image. Defaults to - 'matplotlib'. `New in version 0.7.3.` - """ - if backend == "matplotlib": - import matplotlib.pyplot as plt - - is_inline = "inline" in plt.get_backend() - img = self.get_image() if drawn_img is None else drawn_img - self._init_manager(win_name) - fig = self.manager.canvas.figure - # remove white edges by set subplot margin - fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - fig.clear() - ax = fig.add_subplot() - ax.axis(False) - ax.imshow(img) - self.manager.canvas.draw() - - # Find a better way for inline to show the image - if is_inline: - return fig - wait_continue(fig, timeout=wait_time, continue_key=continue_key) - elif backend == "cv2": - # Keep images are shown in the same window, and the title of window - # will be updated with `win_name`. - cv2.namedWindow(winname=f"{id(self)}") - cv2.setWindowTitle(f"{id(self)}", win_name) - cv2.imshow(str(id(self)), self.get_image() if drawn_img is None else drawn_img) - cv2.waitKey(int(np.ceil(wait_time * 1000))) - else: - raise ValueError(f'backend should be "matplotlib" or "cv2", but got {backend} instead') - - @master_only - def set_image(self, image: np.ndarray) -> None: - """Set the image to draw. - - Args: - image (np.ndarray): The image to draw. 
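``set_image``/``get_image`` form the core render loop: ``set_image`` rasterises the image onto an Agg canvas sized to match it, the ``draw_*`` methods paint onto the same axes, and ``get_image`` reads the canvas back as an RGB array. A minimal round trip, with no backends attached:

```python
import numpy as np
from visengine.visualization import Visualizer

vis = Visualizer()
img = np.random.randint(0, 256, size=(240, 320, 3), dtype=np.uint8)
vis.set_image(img)        # HWC expected; small-first-dim CHW inputs are transposed
drawn = vis.get_image()   # RGB ndarray rendered from the canvas
print(drawn.shape)        # (240, 320, 3) expected
```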
- """ - assert image is not None - - # Handle both CHW (channels, height, width) and HWC (height, width, channels) formats - if len(image.shape) == 3: - if image.shape[0] <= 4: # Assume CHW format if first dim is small (likely channels) - image = image.transpose(1, 2, 0) # Convert CHW to HWC - - image = image.astype("uint8") - self._image = image - self.width, self.height = image.shape[1], image.shape[0] - self._default_font_size = max(np.sqrt(self.height * self.width) // 90, 10) - - # add a small 1e-2 to avoid precision lost due to matplotlib's - # truncation (https://github.com/matplotlib/matplotlib/issues/15363) - self.fig_save.set_size_inches( # type: ignore - (self.width + 1e-2) / self.dpi, (self.height + 1e-2) / self.dpi - ) - # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) - self.ax_save.cla() - self.ax_save.axis(False) - self.ax_save.imshow(image, extent=(0, self.width, self.height, 0), interpolation="none") - - @master_only - def get_image(self) -> np.ndarray: - """Get the drawn image. The format is RGB. - - Returns: - np.ndarray: the drawn image which channel is RGB. - """ - assert self._image is not None, "Please set image using `set_image`" - return img_from_canvas(self.fig_save_canvas) # type: ignore - - def _initialize_fig(self, fig_cfg) -> tuple: - """Build figure according to fig_cfg. - - Args: - fig_cfg (dict): The config to build figure. - - Returns: - tuple: build canvas figure and axes. - """ - from matplotlib.backends.backend_agg import FigureCanvasAgg - from matplotlib.figure import Figure - - fig = Figure(**fig_cfg) - ax = fig.add_subplot() - ax.axis(False) - - # remove white edges by set subplot margin - fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - canvas = FigureCanvasAgg(fig) - return canvas, fig, ax - - def _init_manager(self, win_name: str) -> None: - """Initialize the matplot manager. - - Args: - win_name (str): The window name. - """ - from matplotlib.figure import Figure - from matplotlib.pyplot import new_figure_manager - - if getattr(self, "manager", None) is None: - self.manager = new_figure_manager(num=1, FigureClass=Figure, **self.fig_show_cfg) - - try: - self.manager.set_window_title(win_name) - except Exception: - self.manager = new_figure_manager(num=1, FigureClass=Figure, **self.fig_show_cfg) - self.manager.set_window_title(win_name) - - @master_only - def get_backend(self, name) -> "BaseVisBackend": - """Get vis backend by name. - - Args: - name (str): The name of vis backend - - Returns: - BaseVisBackend: The vis backend. - """ - return self._vis_backends.get(name) # type: ignore - - def _is_posion_valid(self, position: np.ndarray) -> bool: - """Judge whether the position is in image. - - Args: - position (np.ndarray): The position to judge which last dim must - be two and the format is [x, y]. - - Returns: - bool: Whether the position is in image. - """ - flag = ( - (position[..., 0] < self.width).all() - and (position[..., 0] >= 0).all() - and (position[..., 1] < self.height).all() - and (position[..., 1] >= 0).all() - ) - return flag - - @master_only - def draw_points( - self, - positions: np.ndarray | torch.Tensor, - colors: str | tuple | list[str] | list[tuple] = "g", - marker: str | None = None, - sizes: np.ndarray | torch.Tensor | None = None, - ): - """Draw single or multiple points. - - Args: - positions (Union[np.ndarray, torch.Tensor]): Positions to draw. - colors (Union[str, tuple, List[str], List[tuple]]): The colors - of points. ``colors`` can have the same length with points or - just single value. 
If ``colors`` is single value, all the - points will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g'. - marker (str, optional): The marker style. - See :mod:`matplotlib.markers` for more information about - marker styles. Defaults to None. - sizes (Optional[Union[np.ndarray, torch.Tensor]]): The marker size. - Defaults to None. - """ - check_type("positions", positions, (np.ndarray, torch.Tensor)) - positions = tensor2ndarray(positions) - - if len(positions.shape) == 1: - positions = positions[None] - assert positions.shape[-1] == 2, f"The shape of `positions` should be (N, 2), but got {positions.shape}" - colors = color_val_matplotlib(colors) # type: ignore - self.ax_save.scatter(positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) - return self - - @master_only - def draw_texts( - self, - texts: str | list[str], - positions: np.ndarray | torch.Tensor, - font_sizes: int | list[int] | None = None, - colors: str | tuple | list[str] | list[tuple] = "g", - vertical_alignments: str | list[str] = "top", - horizontal_alignments: str | list[str] = "left", - font_families: str | list[str] = "sans-serif", - bboxes: dict | list[dict] | None = None, - font_properties: Union["FontProperties", list["FontProperties"]] | None = None, - ) -> "Visualizer": - """Draw single or multiple text boxes. - - Args: - texts (Union[str, List[str]]): Texts to draw. - positions (Union[np.ndarray, torch.Tensor]): The position to draw - the texts, which should have the same length with texts and - each dim contain x and y. - font_sizes (Union[int, List[int]], optional): The font size of - texts. ``font_sizes`` can have the same length with texts or - just single value. If ``font_sizes`` is single value, all the - texts will have the same font size. Defaults to None. - colors (Union[str, tuple, List[str], List[tuple]]): The colors - of texts. ``colors`` can have the same length with texts or - just single value. If ``colors`` is single value, all the - texts will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g'. - vertical_alignments (Union[str, List[str]]): The verticalalignment - of texts. verticalalignment controls whether the y positional - argument for the text indicates the bottom, center or top side - of the text bounding box. - ``vertical_alignments`` can have the same length with - texts or just single value. If ``vertical_alignments`` is - single value, all the texts will have the same - verticalalignment. verticalalignment can be 'center', - 'top', 'bottom' or 'baseline'. Defaults to 'top'. - horizontal_alignments (Union[str, List[str]]): The - horizontalalignment of texts. Horizontalalignment controls - whether the x positional argument for the text indicates the - left, center or right side of the text bounding box. - ``horizontal_alignments`` can have - the same length with texts or just single value. - If ``horizontal_alignments`` is single value, all the texts - will have the same horizontalalignment. Horizontalalignment - can be 'center', 'right' or 'left'. Defaults to 'left'. - font_families (Union[str, List[str]]): The font family of - texts. ``font_families`` can have the same length with texts or - just single value. If ``font_families`` is single value, all - the texts will have the same font family. - font_family can be 'serif', 'sans-serif', 'cursive', 'fantasy' - or 'monospace'. Defaults to 'sans-serif'.
- bboxes (Union[dict, List[dict]], optional): The bounding box of the - texts. If bboxes is None, there are no bounding box around - texts. ``bboxes`` can have the same length with texts or - just single value. If ``bboxes`` is single value, all - the texts will have the same bbox. Reference to - https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyBboxPatch.html#matplotlib.patches.FancyBboxPatch - for more details. Defaults to None. - font_properties (Union[FontProperties, List[FontProperties]], optional): - The font properties of texts. FontProperties is - a ``font_manager.FontProperties()`` object. - If you want to draw Chinese texts, you need to prepare - a font file that can show Chinese characters properly. - For example: `simhei.ttf`, `simsun.ttc`, `simkai.ttf` and so on. - Then set ``font_properties=matplotlib.font_manager.FontProperties(fname='path/to/font_file')`` - ``font_properties`` can have the same length with texts or - just single value. If ``font_properties`` is single value, - all the texts will have the same font properties. - Defaults to None. - `New in version 0.6.0.` - """ - from matplotlib.font_manager import FontProperties - - check_type("texts", texts, (str, list)) - if isinstance(texts, str): - texts = [texts] - num_text = len(texts) - check_type("positions", positions, (np.ndarray, torch.Tensor)) - positions = tensor2ndarray(positions) - if len(positions.shape) == 1: - positions = positions[None] - assert positions.shape == (num_text, 2), ( - f"`positions` should have the shape of ({num_text}, 2), but got {positions.shape}" - ) - if not self._is_posion_valid(positions): - warnings.warn( - "Warning: The text is out of bounds, the drawn text may not be in the image", - UserWarning, - ) - positions = positions.tolist() - - if font_sizes is None: - font_sizes = self._default_font_size - check_type_and_length("font_sizes", font_sizes, (int, float, list), num_text) - font_sizes = value2list(font_sizes, (int, float), num_text) - - check_type_and_length("colors", colors, (str, tuple, list), num_text) - colors = value2list(colors, (str, tuple), num_text) - colors = color_val_matplotlib(colors) # type: ignore - - check_type_and_length("vertical_alignments", vertical_alignments, (str, list), num_text) - vertical_alignments = value2list(vertical_alignments, str, num_text) - - check_type_and_length("horizontal_alignments", horizontal_alignments, (str, list), num_text) - horizontal_alignments = value2list(horizontal_alignments, str, num_text) - - check_type_and_length("font_families", font_families, (str, list), num_text) - font_families = value2list(font_families, str, num_text) - - if font_properties is None: - font_properties = [None for _ in range(num_text)] # type: ignore - else: - check_type_and_length("font_properties", font_properties, (FontProperties, list), num_text) - font_properties = value2list(font_properties, FontProperties, num_text) - - if bboxes is None: - bboxes = [None for _ in range(num_text)] # type: ignore - else: - check_type_and_length("bboxes", bboxes, (dict, list), num_text) - bboxes = value2list(bboxes, dict, num_text) - - for i in range(num_text): - self.ax_save.text( - positions[i][0], - positions[i][1], - texts[i], - size=font_sizes[i], # type: ignore - bbox=bboxes[i], # type: ignore - verticalalignment=vertical_alignments[i], - horizontalalignment=horizontal_alignments[i], - family=font_families[i], - fontproperties=font_properties[i], - color=colors[i], - ) - return self - - @master_only - def draw_lines( - self, - x_datas: 
np.ndarray | torch.Tensor, - y_datas: np.ndarray | torch.Tensor, - colors: str | tuple | list[str] | list[tuple] = "g", - line_styles: str | list[str] = "-", - line_widths: int | float | list[int | float] = 2, - ) -> "Visualizer": - """Draw single or multiple line segments. - - Args: - x_datas (Union[np.ndarray, torch.Tensor]): The x coordinate of - each line's start and end points. - y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of - each line's start and end points. - colors (Union[str, tuple, List[str], List[tuple]]): The colors of - lines. ``colors`` can have the same length with lines or just - single value. If ``colors`` is single value, all the lines - will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g'. - line_styles (Union[str, List[str]]): The linestyle - of lines. ``line_styles`` can have the same length with - texts or just single value. If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - """ - from matplotlib.collections import LineCollection - - check_type("x_datas", x_datas, (np.ndarray, torch.Tensor)) - x_datas = tensor2ndarray(x_datas) - check_type("y_datas", y_datas, (np.ndarray, torch.Tensor)) - y_datas = tensor2ndarray(y_datas) - assert x_datas.shape == y_datas.shape, "`x_datas` and `y_datas` should have the same shape" - assert x_datas.shape[-1] == 2, f"The shape of `x_datas` should be (N, 2), but got {x_datas.shape}" - if len(x_datas.shape) == 1: - x_datas = x_datas[None] - y_datas = y_datas[None] - colors = color_val_matplotlib(colors) # type: ignore - lines = np.concatenate((x_datas.reshape(-1, 2, 1), y_datas.reshape(-1, 2, 1)), axis=-1) - if not self._is_posion_valid(lines): - warnings.warn( - "Warning: The line is out of bounds, the drawn line may not be in the image", - UserWarning, - ) - line_collect = LineCollection( - lines.tolist(), - colors=colors, - linestyles=line_styles, - linewidths=line_widths, - ) - self.ax_save.add_collection(line_collect) - return self - - @master_only - def draw_circles( - self, - center: np.ndarray | torch.Tensor, - radius: np.ndarray | torch.Tensor, - edge_colors: str | tuple | list[str] | list[tuple] = "g", - line_styles: str | list[str] = "-", - line_widths: int | float | list[int | float] = 2, - face_colors: str | tuple | list[str] | list[tuple] = "none", - alpha: float | int = 0.8, - ) -> "Visualizer": - """Draw single or multiple circles. - - Args: - center (Union[np.ndarray, torch.Tensor]): The center coordinate of - each circle, in the format (x, y). - radius (Union[np.ndarray, torch.Tensor]): The radius of - each circle. - edge_colors (Union[str, tuple, List[str], List[tuple]]): The - colors of circles. ``colors`` can have the same length with - lines or just single value. If ``colors`` is single value, - all the lines will have the same colors. Reference to - https://matplotlib.org/stable/gallery/color/named_colors.html - for more details. Defaults to 'g'. - line_styles (Union[str, List[str]]): The linestyle - of lines.
``line_styles`` can have the same length with - texts or just single value. If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - face_colors (Union[str, tuple, List[str], List[tuple]]): - The face colors. Defaults to None. - alpha (Union[int, float]): The transparency of circles. - Defaults to 0.8. - """ - from matplotlib.collections import PatchCollection - from matplotlib.patches import Circle - - check_type("center", center, (np.ndarray, torch.Tensor)) - center = tensor2ndarray(center) - check_type("radius", radius, (np.ndarray, torch.Tensor)) - radius = tensor2ndarray(radius) - if len(center.shape) == 1: - center = center[None] - assert center.shape == (radius.shape[0], 2), ( - f"The shape of `center` should be (radius.shape, 2), but got {center.shape}" - ) - if not ( - self._is_posion_valid(center - np.tile(radius.reshape((-1, 1)), (1, 2))) - and self._is_posion_valid(center + np.tile(radius.reshape((-1, 1)), (1, 2))) - ): - warnings.warn( - "Warning: The circle is out of bounds, the drawn circle may not be in the image", - UserWarning, - ) - - center = center.tolist() - radius = radius.tolist() - edge_colors = color_val_matplotlib(edge_colors) # type: ignore - face_colors = color_val_matplotlib(face_colors) # type: ignore - circles = [] - for i in range(len(center)): - circles.append(Circle(tuple(center[i]), radius[i])) - - if isinstance(line_widths, (int, float)): - line_widths = [line_widths] * len(circles) - line_widths = [min(max(linewidth, 1), self._default_font_size / 4) for linewidth in line_widths] - p = PatchCollection( - circles, - alpha=alpha, - facecolors=face_colors, - edgecolors=edge_colors, - linewidths=line_widths, - linestyles=line_styles, - ) - self.ax_save.add_collection(p) - return self - - @master_only - def draw_bboxes( - self, - bboxes: np.ndarray | torch.Tensor, - edge_colors: str | tuple | list[str] | list[tuple] = "g", - line_styles: str | list[str] = "-", - line_widths: int | float | list[int | float] = 2, - face_colors: str | tuple | list[str] | list[tuple] = "none", - alpha: int | float = 0.8, - ) -> "Visualizer": - """Draw single or multiple bboxes. - - Args: - bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with - the format of(x1,y1,x2,y2). - edge_colors (Union[str, tuple, List[str], List[tuple]]): The - colors of bboxes. ``colors`` can have the same length with - lines or just single value. If ``colors`` is single value, all - the lines will have the same colors. Refer to `matplotlib. - colors` for full list of formats that are accepted. - Defaults to 'g'. - line_styles (Union[str, List[str]]): The linestyle - of lines. ``line_styles`` can have the same length with - texts or just single value. If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. 
 ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - face_colors (Union[str, tuple, List[str], List[tuple]]): - The face colors. Defaults to 'none'. - alpha (Union[int, float]): The transparency of bboxes. - Defaults to 0.8. - """ - check_type("bboxes", bboxes, (np.ndarray, torch.Tensor)) - bboxes = tensor2ndarray(bboxes) - - if len(bboxes.shape) == 1: - bboxes = bboxes[None] - assert bboxes.shape[-1] == 4, f"The shape of `bboxes` should be (N, 4), but got {bboxes.shape}" - - assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= bboxes[:, 3]).all() - if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): - warnings.warn( - "Warning: The bbox is out of bounds, the drawn bbox may not be in the image", - UserWarning, - ) - poly = np.stack( - ( - bboxes[:, 0], - bboxes[:, 1], - bboxes[:, 2], - bboxes[:, 1], - bboxes[:, 2], - bboxes[:, 3], - bboxes[:, 0], - bboxes[:, 3], - ), - axis=-1, - ).reshape(-1, 4, 2) - poly = [p for p in poly] - return self.draw_polygons( - poly, - alpha=alpha, - edge_colors=edge_colors, - line_styles=line_styles, - line_widths=line_widths, - face_colors=face_colors, - ) - - @master_only - def draw_polygons( - self, - polygons: np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor], - edge_colors: str | tuple | list[str] | list[tuple] = "g", - line_styles: str | list[str] = "-", - line_widths: int | float | list[int | float] = 2, - face_colors: str | tuple | list[str] | list[tuple] = "none", - alpha: int | float = 0.8, - ) -> "Visualizer": - """Draw single or multiple polygons. - - Args: - polygons (Union[Union[np.ndarray, torch.Tensor],\ - List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw - with the format of (x1,y1,x2,y2,...,xn,yn). - edge_colors (Union[str, tuple, List[str], List[tuple]]): The - colors of polygons. ``colors`` can have the same length with - lines or just single value. If ``colors`` is single value, - all the lines will have the same colors. Refer to - `matplotlib.colors` for full list of formats that are accepted. - Defaults to 'g'. - line_styles (Union[str, List[str]]): The linestyle - of lines. ``line_styles`` can have the same length with - texts or just single value. If ``line_styles`` is single - value, all the lines will have the same linestyle. - Reference to - https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle - for more details. Defaults to '-'. - line_widths (Union[Union[int, float], List[Union[int, float]]]): - The linewidth of lines. ``line_widths`` can have - the same length with lines or just single value. - If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 2. - face_colors (Union[str, tuple, List[str], List[tuple]]): - The face colors. Defaults to 'none'. - alpha (Union[int, float]): The transparency of polygons. - Defaults to 0.8.
- """ - from matplotlib.collections import PolyCollection - - check_type("polygons", polygons, (list, np.ndarray, torch.Tensor)) - edge_colors = color_val_matplotlib(edge_colors) # type: ignore - face_colors = color_val_matplotlib(face_colors) # type: ignore - - if isinstance(polygons, (np.ndarray, torch.Tensor)): - polygons = [polygons] - if isinstance(polygons, list): - for polygon in polygons: - assert polygon.shape[1] == 2, ( - f"The shape of each polygon in `polygons` should be (M, 2), but got {polygon.shape}" - ) - polygons = [tensor2ndarray(polygon) for polygon in polygons] - for polygon in polygons: - if not self._is_posion_valid(polygon): - warnings.warn( - "Warning: The polygon is out of bounds, the drawn polygon may not be in the image", - UserWarning, - ) - if isinstance(line_widths, (int, float)): - line_widths = [line_widths] * len(polygons) - line_widths = [min(max(linewidth, 1), self._default_font_size / 4) for linewidth in line_widths] - polygon_collection = PolyCollection( - polygons, - alpha=alpha, - facecolor=face_colors, - linestyles=line_styles, - edgecolors=edge_colors, - linewidths=line_widths, - ) - - self.ax_save.add_collection(polygon_collection) - return self - - @master_only - def draw_binary_masks( - self, - binary_masks: np.ndarray | torch.Tensor, - colors: str | tuple | list[str] | list[tuple] = "g", - alphas: float | list[float] = 0.8, - ) -> "Visualizer": - """Draw single or multiple binary masks. - - Args: - binary_masks (np.ndarray, torch.Tensor): The binary_masks to draw - with of shape (N, H, W), where H is the image height and W is - the image width. Each value in the array is either a 0 or 1 - value of uint8 type. - colors (np.ndarray): The colors which binary_masks will convert to. - ``colors`` can have the same length with binary_masks or just - single value. If ``colors`` is single value, all the - binary_masks will convert to the same colors. The colors format - is RGB. Defaults to np.array([0, 255, 0]). - alphas (Union[int, List[int]]): The transparency of masks. - Defaults to 0.8. - """ - check_type("binary_masks", binary_masks, (np.ndarray, torch.Tensor)) - binary_masks = tensor2ndarray(binary_masks) - assert binary_masks.dtype == np.bool_, ( - f"The dtype of binary_masks should be np.bool_, but got {binary_masks.dtype}" - ) - binary_masks = binary_masks.astype("uint8") * 255 - img = self.get_image() - if binary_masks.ndim == 2: - binary_masks = binary_masks[None] - - # Debug logging for shape mismatch - import logging - - logger = logging.getLogger(__name__) - logger.debug(f"draw_binary_masks - img.shape: {img.shape}") - logger.debug(f"draw_binary_masks - binary_masks.shape: {binary_masks.shape}") - logger.debug(f"draw_binary_masks - img.shape[:2]: {img.shape[:2]}") - logger.debug(f"draw_binary_masks - binary_masks.shape[1:]: {binary_masks.shape[1:]}") - - # This helped to catch a problem where training was using mmcv's Pad which didnt' - # edit the masks, instead of mmdetection's Pad which does. - assert img.shape[:2] == binary_masks.shape[1:], ( - f"`binary_masks` must have the same shape with image. 
-        )
-        binary_mask_len = binary_masks.shape[0]
-
-        check_type_and_length("colors", colors, (str, tuple, list), binary_mask_len)
-        colors = value2list(colors, (str, tuple), binary_mask_len)
-        colors = [color_str2rgb(color) if isinstance(color, str) else color for color in colors]
-        for color in colors:
-            assert len(color) == 3
-            for channel in color:
-                assert 0 <= channel <= 255  # type: ignore
-
-        if isinstance(alphas, float):
-            alphas = [alphas] * binary_mask_len
-
-        for binary_mask, color, alpha in zip(binary_masks, colors, alphas, strict=False):
-            binary_mask_complement = cv2.bitwise_not(binary_mask)
-            rgb = np.zeros_like(img)
-            rgb[...] = color
-            rgb = cv2.bitwise_and(rgb, rgb, mask=binary_mask)
-            img_complement = cv2.bitwise_and(img, img, mask=binary_mask_complement)
-            rgb = rgb + img_complement
-            img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0)
-        self.ax_save.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
-        return self
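# [Editor's sketch, not part of the patch] draw_binary_masks above requires a
# bool array of shape (N, H, W) whose H and W match the current image, then
# alpha-blends each colored mask into it via cv2.addWeighted. With the
# hypothetical 224x224 `vis` from the first sketch:
masks = np.zeros((1, 224, 224), dtype=bool)
masks[0, 50:100, 50:100] = True  # one square mask
vis.draw_binary_masks(masks, colors="r", alphas=0.5)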
- """ - import matplotlib.pyplot as plt - - assert isinstance(featmap, torch.Tensor), f"`featmap` should be torch.Tensor, but got {type(featmap)}" - assert featmap.ndim == 3, f"Input dimension must be 3, but got {featmap.ndim}" - featmap = featmap.detach().cpu() - - if overlaid_image is not None: - if overlaid_image.ndim == 2: - overlaid_image = cv2.cvtColor(overlaid_image, cv2.COLOR_GRAY2RGB) - - if overlaid_image.shape[:2] != featmap.shape[1:]: - warnings.warn( - f"Since the spatial dimensions of " - f"overlaid_image: {overlaid_image.shape[:2]} and " - f"featmap: {featmap.shape[1:]} are not same, " - f"the feature map will be interpolated. " - f"This may cause mismatch problems !" - ) - if resize_shape is None: - featmap = F.interpolate( - featmap[None], - overlaid_image.shape[:2], - mode="bilinear", - align_corners=False, - )[0] - - if resize_shape is not None: - featmap = F.interpolate(featmap[None], resize_shape, mode="bilinear", align_corners=False)[0] - if overlaid_image is not None: - overlaid_image = cv2.resize(overlaid_image, resize_shape[::-1]) - - if channel_reduction is not None: - assert channel_reduction in ["squeeze_mean", "select_max"], ( - f'Mode only support "squeeze_mean", "select_max", but got {channel_reduction}' - ) - if channel_reduction == "select_max": - sum_channel_featmap = torch.sum(featmap, dim=(1, 2)) - _, indices = torch.topk(sum_channel_featmap, 1) - feat_map = featmap[indices] - else: - feat_map = torch.mean(featmap, dim=0) - return convert_overlay_heatmap(feat_map, overlaid_image, alpha) - elif topk <= 0: - featmap_channel = featmap.shape[0] - assert featmap_channel in [1, 3], ( - "The input tensor channel dimension must be 1 or 3 " - "when topk is less than 1, but the channel " - f"dimension you input is {featmap_channel}, you can use the" - " channel_reduction parameter or set topk greater than " - "0 to solve the error" - ) - return convert_overlay_heatmap(featmap, overlaid_image, alpha) - else: - row, col = arrangement - channel, height, width = featmap.shape - assert row * col >= topk, ( - "The product of row and col in the `arrangement` is less than topk, please set the `arrangement` correctly" - ) - - # Extract the feature map of topk - topk = min(channel, topk) - sum_channel_featmap = torch.sum(featmap, dim=(1, 2)) - _, indices = torch.topk(sum_channel_featmap, topk) - topk_featmap = featmap[indices] - - fig = plt.figure(frameon=False) - # Set the window layout - fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) - dpi = fig.get_dpi() - fig.set_size_inches((width * col + 1e-2) / dpi, (height * row + 1e-2) / dpi) - for i in range(topk): - axes = fig.add_subplot(row, col, i + 1) - axes.axis("off") - axes.text(2, 15, f"channel: {indices[i]}", fontsize=10) - axes.imshow(convert_overlay_heatmap(topk_featmap[i], overlaid_image, alpha)) - image = img_from_canvas(fig.canvas) - plt.close(fig) - return image - - @master_only - def add_config(self, config: Config, **kwargs): - """Record the config. - - Args: - config (Config): The Config object. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_config(config, **kwargs) - - @master_only - def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], **kwargs) -> None: - """Record the model graph. - - Args: - model (torch.nn.Module): Model to draw. - data_batch (Sequence[dict]): Batch of data from dataloader. 
- """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_graph(model, data_batch, **kwargs) - - @master_only - def add_image(self, name: str, image: np.ndarray, step: int = 0) -> None: - """Record the image. - - Args: - name (str): The image identifier. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Defaults to None. - step (int): Global step value to record. Defaults to 0. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_image(name, image, step) # type: ignore - - @master_only - def add_scalar(self, name: str, value: int | float, step: int = 0, **kwargs) -> None: - """Record the scalar data. - - Args: - name (str): The scalar identifier. - value (float, int): Value to save. - step (int): Global step value to record. Defaults to 0. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_scalar(name, value, step, **kwargs) # type: ignore - - @master_only - def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: str | None = None, **kwargs) -> None: - """Record the scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Defaults to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Defaults to None. - """ - for vis_backend in self._vis_backends.values(): - vis_backend.add_scalars(scalar_dict, step, file_path, **kwargs) - - @master_only - def add_datasample( - self, - name, - image: np.ndarray, - data_sample: Optional["BaseDataElement"] = None, - draw_gt: bool = True, - draw_pred: bool = True, - show: bool = False, - wait_time: int = 0, - step: int = 0, - ) -> None: - """Draw datasample.""" - pass - - def close(self) -> None: - """Close an opened object.""" - for vis_backend in self._vis_backends.values(): - vis_backend.close() - - @classmethod - def get_instance(cls, name: str, **kwargs) -> "Visualizer": - """Make subclass can get latest created instance by - ``Visualizer.get_current_instance()``. - - Downstream codebase may need to get the latest created instance - without knowing the specific Visualizer type. For example, mmdetection - builds visualizer in runner and some component which cannot access - runner wants to get latest created visualizer. In this case, - the component does not know which type of visualizer has been built - and cannot get target instance. Therefore, :class:`Visualizer` - overrides the :meth:`get_instance` and its subclass will register - the created instance to :attr:`_instance_dict` additionally. - :meth:`get_current_instance` will return the latest created subclass - instance. - - Examples: - >>> class DetLocalVisualizer(Visualizer): - >>> def __init__(self, name): - >>> super().__init__(name) - >>> - >>> visualizer1 = DetLocalVisualizer.get_instance('name1') - >>> visualizer2 = Visualizer.get_current_instance() - >>> visualizer3 = DetLocalVisualizer.get_current_instance() - >>> assert id(visualizer1) == id(visualizer2) == id(visualizer3) - - Args: - name (str): Name of instance. - - Returns: - object: Corresponding name instance. 
- """ - instance = super().get_instance(name, **kwargs) - Visualizer._instance_dict[name] = instance - return instance diff --git a/pyproject.toml b/pyproject.toml index a6bbc6b..dd07787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,8 +145,7 @@ dev = [ "zuban>=0.1.2", ] -[tool.uv.workspace] -members = ["libs/viscv", "libs/visengine"] +# Workspace members have been consolidated into visdet # Note: mmtracking git dependency removed due to build isolation issues # To install for tests: uv pip install --no-build-isolation git+https://github.com/open-mmlab/mmtracking.git diff --git a/DOCS_README.md b/scratch_pads/DOCS_README.md similarity index 100% rename from DOCS_README.md rename to scratch_pads/DOCS_README.md diff --git a/README_zh-CN.md b/scratch_pads/README_zh-CN.md similarity index 100% rename from README_zh-CN.md rename to scratch_pads/README_zh-CN.md diff --git a/REFACTORING_PLAN.md b/scratch_pads/REFACTORING_PLAN.md similarity index 100% rename from REFACTORING_PLAN.md rename to scratch_pads/REFACTORING_PLAN.md diff --git a/training.log b/training.log new file mode 100644 index 0000000..9aad0a5 --- /dev/null +++ b/training.log @@ -0,0 +1,4 @@ +Traceback (most recent call last): + File "/home/georgepearse/visdet-worktrees/main/scripts/train_cmr.py", line 15, in + from visdet import SimpleRunner +ModuleNotFoundError: No module named 'visdet' diff --git a/uv.lock b/uv.lock index c8e7d25..2a8e199 100644 --- a/uv.lock +++ b/uv.lock @@ -13,22 +13,6 @@ resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", ] -[manifest] -members = [ - "viscv", - "visdet", - "visengine", -] - -[[package]] -name = "addict" -version = "2.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/85/ef/fd7649da8af11d93979831e8f1f8097e85e82d5bfeabc8c68b39175d8e75/addict-2.4.0.tar.gz", hash = "sha256:b3b2210e0e067a281f5646c8c5db92e99b7231ea8b0eb5f74dbdf9e259d4e494", size = 9186, upload-time = "2020-11-21T16:21:31.416Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl", hash = "sha256:249bb56bbfd3cdc2a004ea0ff4c2b6ddc84d53bc2194761636eb314d5cfa5dfc", size = 3832, upload-time = "2020-11-21T16:21:29.588Z" }, -] - [[package]] name = "albucore" version = "0.0.24" @@ -121,31 +105,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/ff/392bff89415399a979be4a65357a41d92729ae8580a66073d8ec8d810f98/backrefs-5.9-py39-none-any.whl", hash = "sha256:f48ee18f6252b8f5777a22a00a09a85de0ca931658f1dd96d4406a34f3748c60", size = 380265, upload-time = "2025-06-22T19:34:12.405Z" }, ] -[[package]] -name = "bitsandbytes" -version = "0.48.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform != 'darwin'" }, - { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform != 'darwin'" }, - { name = "packaging", marker = "sys_platform != 'darwin'" }, - { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 
'linux')" }, - { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/a1/c66a0d02091a9ef6704a153e09f541e70fc48db02df94fbbf1632b321aca/bitsandbytes-0.48.1-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:b7f440aee5ec8cb1d028b0d3b2d71e97c302766dc605232293f4a0f7e48b5c75", size = 34033906, upload-time = "2025-10-02T17:40:25.159Z" }, - { url = "https://files.pythonhosted.org/packages/a1/63/00b3f2af0b88edf5855792681641a71cf0605d6d88c62e98d75cff86d105/bitsandbytes-0.48.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:3e72cf07ba6d2169e69a61282a6f072fc675efee86049e56a33de099a0363ef2", size = 60088899, upload-time = "2025-10-02T17:40:28.732Z" }, -] - -[[package]] -name = "cachetools" -version = "6.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cc/7e/b975b5814bd36faf009faebe22c1072a1fa1168db34d285ef0ba071ad78c/cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201", size = 31325, upload-time = "2025-10-12T14:55:30.139Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/c5/1e741d26306c42e2bf6ab740b2202872727e0f606033c9dd713f8b93f5a8/cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701", size = 11280, upload-time = "2025-10-12T14:55:28.382Z" }, -] - [[package]] name = "certifi" version = "2025.10.5" @@ -245,18 +204,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" }, ] -[[package]] -name = "cloudpathlib" -version = "0.23.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f4/18/2ac35d6b3015a0c74e923d94fc69baf8307f7c3233de015d69f99e17afa8/cloudpathlib-0.23.0.tar.gz", hash = "sha256:eb38a34c6b8a048ecfd2b2f60917f7cbad4a105b7c979196450c2f541f4d6b4b", size = 53126, upload-time = "2025-10-07T22:47:56.278Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/8a/c4bb04426d608be4a3171efa2e233d2c59a5c8937850c10d098e126df18e/cloudpathlib-0.23.0-py3-none-any.whl", hash = "sha256:8520b3b01468fee77de37ab5d50b1b524ea6b4a8731c35d1b7407ac0cd716002", size = 62755, upload-time = "2025-10-07T22:47:54.905Z" }, -] - [[package]] name = "codecov" version = "2.1.13" @@ -462,15 +409,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, ] -[[package]] -name = "filelock" -version = "3.20.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, -] - [[package]] name = "flake8" version = "7.3.0" @@ -527,15 +465,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/93/0dd45cd283c32dea1545151d8c3637b4b8c53cdb3a625aeb2885b184d74d/fonttools-4.60.1-py3-none-any.whl", hash = "sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb", size = 1143175, upload-time = "2025-09-29T21:13:24.134Z" }, ] -[[package]] -name = "fsspec" -version = "2025.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/de/e0/bab50af11c2d75c9c4a2a26a5254573c0bd97cea152254401510950486fa/fsspec-2025.9.0.tar.gz", hash = "sha256:19fd429483d25d28b65ec68f9f4adc16c17ea2c7c7bf54ec61360d478fb19c19", size = 304847, upload-time = "2025-09-02T19:10:49.215Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/71/70db47e4f6ce3e5c37a607355f80da8860a33226be640226ac52cb05ef2e/fsspec-2025.9.0-py3-none-any.whl", hash = "sha256:530dc2a2af60a414a832059574df4a6e10cce927f6f4a78209390fe38955cfb7", size = 199289, upload-time = "2025-09-02T19:10:47.708Z" }, -] - [[package]] name = "ghp-import" version = "2.1.0" @@ -572,118 +501,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, ] -[[package]] -name = "google-api-core" -version = "2.26.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "googleapis-common-protos" }, - { name = "proto-plus" }, - { name = "protobuf" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/32/ea/e7b6ac3c7b557b728c2d0181010548cbbdd338e9002513420c5a354fa8df/google_api_core-2.26.0.tar.gz", hash = "sha256:e6e6d78bd6cf757f4aee41dcc85b07f485fbb069d5daa3afb126defba1e91a62", size = 166369, upload-time = "2025-10-08T21:37:38.39Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/ad/f73cf9fe9bd95918502b270e3ddb8764e4c900b3bbd7782b90c56fac14bb/google_api_core-2.26.0-py3-none-any.whl", hash = "sha256:2b204bd0da2c81f918e3582c48458e24c11771f987f6258e6e227212af78f3ed", size = 162505, upload-time = "2025-10-08T21:37:36.651Z" }, -] - -[[package]] -name = "google-auth" -version = "2.41.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cachetools" }, - { name = "pyasn1-modules" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a8/af/5129ce5b2f9688d2fa49b463e544972a7c82b0fdb50980dafee92e121d9f/google_auth-2.41.1.tar.gz", hash = "sha256:b76b7b1f9e61f0cb7e88870d14f6a94aeef248959ef6992670efee37709cbfd2", size = 292284, upload-time = "2025-09-30T22:51:26.363Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/a4/7319a2a8add4cc352be9e3efeff5e2aacee917c85ca2fa1647e29089983c/google_auth-2.41.1-py2.py3-none-any.whl", hash = "sha256:754843be95575b9a19c604a848a41be03f7f2afd8c019f716dc1f51ee41c639d", size = 221302, upload-time = "2025-09-30T22:51:24.212Z" }, -] - -[[package]] -name = "google-cloud-core" -version = "2.4.3" -source = { registry = 
"https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d6/b8/2b53838d2acd6ec6168fd284a990c76695e84c65deee79c9f3a4276f6b4f/google_cloud_core-2.4.3.tar.gz", hash = "sha256:1fab62d7102844b278fe6dead3af32408b1df3eb06f5c7e8634cbd40edc4da53", size = 35861, upload-time = "2025-03-10T21:05:38.948Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/86/bda7241a8da2d28a754aad2ba0f6776e35b67e37c36ae0c45d49370f1014/google_cloud_core-2.4.3-py2.py3-none-any.whl", hash = "sha256:5130f9f4c14b4fafdff75c79448f9495cfade0d8775facf1b09c3bf67e027f6e", size = 29348, upload-time = "2025-03-10T21:05:37.785Z" }, -] - -[[package]] -name = "google-cloud-storage" -version = "3.4.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-crc32c" }, - { name = "google-resumable-media" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/bd/ef/7cefdca67a6c8b3af0ec38612f9e78e5a9f6179dd91352772ae1a9849246/google_cloud_storage-3.4.1.tar.gz", hash = "sha256:6f041a297e23a4b485fad8c305a7a6e6831855c208bcbe74d00332a909f82268", size = 17238203, upload-time = "2025-10-08T18:43:39.665Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/83/6e/b47d83d3a35231c6232566341b0355cce78fd4e6988a7343725408547b2c/google_cloud_storage-3.4.1-py3-none-any.whl", hash = "sha256:972764cc0392aa097be8f49a5354e22eb47c3f62370067fb1571ffff4a1c1189", size = 290142, upload-time = "2025-10-08T18:43:37.524Z" }, -] - -[[package]] -name = "google-crc32c" -version = "1.7.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/19/ae/87802e6d9f9d69adfaedfcfd599266bf386a54d0be058b532d04c794f76d/google_crc32c-1.7.1.tar.gz", hash = "sha256:2bff2305f98846f3e825dbeec9ee406f89da7962accdb29356e4eadc251bd472", size = 14495, upload-time = "2025-03-26T14:29:13.32Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/69/b1b05cf415df0d86691d6a8b4b7e60ab3a6fb6efb783ee5cd3ed1382bfd3/google_crc32c-1.7.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:b07d48faf8292b4db7c3d64ab86f950c2e94e93a11fd47271c28ba458e4a0d76", size = 30467, upload-time = "2025-03-26T14:31:11.92Z" }, - { url = "https://files.pythonhosted.org/packages/44/3d/92f8928ecd671bd5b071756596971c79d252d09b835cdca5a44177fa87aa/google_crc32c-1.7.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:7cc81b3a2fbd932a4313eb53cc7d9dde424088ca3a0337160f35d91826880c1d", size = 30311, upload-time = "2025-03-26T14:53:14.161Z" }, - { url = "https://files.pythonhosted.org/packages/33/42/c2d15a73df79d45ed6b430b9e801d0bd8e28ac139a9012d7d58af50a385d/google_crc32c-1.7.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c67ca0a1f5b56162951a9dae987988679a7db682d6f97ce0f6381ebf0fbea4c", size = 37889, upload-time = "2025-03-26T14:41:27.83Z" }, - { url = "https://files.pythonhosted.org/packages/57/ea/ac59c86a3c694afd117bb669bde32aaf17d0de4305d01d706495f09cbf19/google_crc32c-1.7.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc5319db92daa516b653600794d5b9f9439a9a121f3e162f94b0e1891c7933cb", size = 33028, upload-time = "2025-03-26T14:41:29.141Z" }, - { url = 
"https://files.pythonhosted.org/packages/60/44/87e77e8476767a4a93f6cf271157c6d948eacec63688c093580af13b04be/google_crc32c-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcdf5a64adb747610140572ed18d011896e3b9ae5195f2514b7ff678c80f1603", size = 38026, upload-time = "2025-03-26T14:41:29.921Z" }, - { url = "https://files.pythonhosted.org/packages/c8/bf/21ac7bb305cd7c1a6de9c52f71db0868e104a5b573a4977cd9d0ff830f82/google_crc32c-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:754561c6c66e89d55754106739e22fdaa93fafa8da7221b29c8b8e8270c6ec8a", size = 33476, upload-time = "2025-03-26T14:29:09.086Z" }, - { url = "https://files.pythonhosted.org/packages/f7/94/220139ea87822b6fdfdab4fb9ba81b3fff7ea2c82e2af34adc726085bffc/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6fbab4b935989e2c3610371963ba1b86afb09537fd0c633049be82afe153ac06", size = 30468, upload-time = "2025-03-26T14:32:52.215Z" }, - { url = "https://files.pythonhosted.org/packages/94/97/789b23bdeeb9d15dc2904660463ad539d0318286d7633fe2760c10ed0c1c/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:ed66cbe1ed9cbaaad9392b5259b3eba4a9e565420d734e6238813c428c3336c9", size = 30313, upload-time = "2025-03-26T14:57:38.758Z" }, - { url = "https://files.pythonhosted.org/packages/81/b8/976a2b843610c211e7ccb3e248996a61e87dbb2c09b1499847e295080aec/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee6547b657621b6cbed3562ea7826c3e11cab01cd33b74e1f677690652883e77", size = 33048, upload-time = "2025-03-26T14:41:30.679Z" }, - { url = "https://files.pythonhosted.org/packages/c9/16/a3842c2cf591093b111d4a5e2bfb478ac6692d02f1b386d2a33283a19dc9/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d68e17bad8f7dd9a49181a1f5a8f4b251c6dbc8cc96fb79f1d321dfd57d66f53", size = 32669, upload-time = "2025-03-26T14:41:31.432Z" }, - { url = "https://files.pythonhosted.org/packages/04/17/ed9aba495916fcf5fe4ecb2267ceb851fc5f273c4e4625ae453350cfd564/google_crc32c-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:6335de12921f06e1f774d0dd1fbea6bf610abe0887a1638f64d694013138be5d", size = 33476, upload-time = "2025-03-26T14:29:10.211Z" }, - { url = "https://files.pythonhosted.org/packages/dd/b7/787e2453cf8639c94b3d06c9d61f512234a82e1d12d13d18584bd3049904/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2d73a68a653c57281401871dd4aeebbb6af3191dcac751a76ce430df4d403194", size = 30470, upload-time = "2025-03-26T14:34:31.655Z" }, - { url = "https://files.pythonhosted.org/packages/ed/b4/6042c2b0cbac3ec3a69bb4c49b28d2f517b7a0f4a0232603c42c58e22b44/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:22beacf83baaf59f9d3ab2bbb4db0fb018da8e5aebdce07ef9f09fce8220285e", size = 30315, upload-time = "2025-03-26T15:01:54.634Z" }, - { url = "https://files.pythonhosted.org/packages/29/ad/01e7a61a5d059bc57b702d9ff6a18b2585ad97f720bd0a0dbe215df1ab0e/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19eafa0e4af11b0a4eb3974483d55d2d77ad1911e6cf6f832e1574f6781fd337", size = 33180, upload-time = "2025-03-26T14:41:32.168Z" }, - { url = "https://files.pythonhosted.org/packages/3b/a5/7279055cf004561894ed3a7bfdf5bf90a53f28fadd01af7cd166e88ddf16/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6d86616faaea68101195c6bdc40c494e4d76f41e07a37ffdef270879c15fb65", size = 32794, upload-time = 
"2025-03-26T14:41:33.264Z" }, - { url = "https://files.pythonhosted.org/packages/0f/d6/77060dbd140c624e42ae3ece3df53b9d811000729a5c821b9fd671ceaac6/google_crc32c-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:b7491bdc0c7564fcf48c0179d2048ab2f7c7ba36b84ccd3a3e1c3f7a72d3bba6", size = 33477, upload-time = "2025-03-26T14:29:10.94Z" }, - { url = "https://files.pythonhosted.org/packages/0b/43/31e57ce04530794917dfe25243860ec141de9fadf4aa9783dffe7dac7c39/google_crc32c-1.7.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8e9afc74168b0b2232fb32dd202c93e46b7d5e4bf03e66ba5dc273bb3559589", size = 28242, upload-time = "2025-03-26T14:41:42.858Z" }, - { url = "https://files.pythonhosted.org/packages/eb/f3/8b84cd4e0ad111e63e30eb89453f8dd308e3ad36f42305cf8c202461cdf0/google_crc32c-1.7.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa8136cc14dd27f34a3221c0f16fd42d8a40e4778273e61a3c19aedaa44daf6b", size = 28049, upload-time = "2025-03-26T14:41:44.651Z" }, - { url = "https://files.pythonhosted.org/packages/16/1b/1693372bf423ada422f80fd88260dbfd140754adb15cbc4d7e9a68b1cb8e/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85fef7fae11494e747c9fd1359a527e5970fc9603c90764843caabd3a16a0a48", size = 28241, upload-time = "2025-03-26T14:41:45.898Z" }, - { url = "https://files.pythonhosted.org/packages/fd/3c/2a19a60a473de48717b4efb19398c3f914795b64a96cf3fbe82588044f78/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6efb97eb4369d52593ad6f75e7e10d053cf00c48983f7a973105bc70b0ac4d82", size = 28048, upload-time = "2025-03-26T14:41:46.696Z" }, -] - -[[package]] -name = "google-resumable-media" -version = "2.7.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-crc32c" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/58/5a/0efdc02665dca14e0837b62c8a1a93132c264bd02054a15abb2218afe0ae/google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0", size = 2163099, upload-time = "2024-08-07T22:20:38.555Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/82/35/b8d3baf8c46695858cb9d8835a53baa1eeb9906ddaf2f728a5f5b640fd1e/google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa", size = 81251, upload-time = "2024-08-07T22:20:36.409Z" }, -] - -[[package]] -name = "googleapis-common-protos" -version = "1.63.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4f/bc/cb5c74fca58d9c37bc621642e2c2b19c004d078b472d49fb03d9fa8ffeef/googleapis-common-protos-1.63.1.tar.gz", hash = "sha256:c6442f7a0a6b2a80369457d79e6672bb7dcbaab88e0848302497e3ec80780a6a", size = 121632, upload-time = "2024-06-03T16:14:15.453Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/87/1608d23bb9879694579fff5dc56d60e3d48e012fd08670f140cf82f6cf26/googleapis_common_protos-1.63.1-py2.py3-none-any.whl", hash = "sha256:0e1c2cdfcbc354b76e4a211a35ea35d6926a835cba1377073c4861db904a1877", size = 229151, upload-time = "2024-06-03T16:14:13.169Z" }, -] - [[package]] name = "griffe" version = "1.14.0" @@ -1343,116 +1160,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/95/8e/2844c3959ce9a63acc7c8e50881133d86666f0420bcde695e115ced0920f/numpy-2.3.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:81b3a59793523e552c4a96109dde028aa4448ae06ccac5a76ff6532a85558a7f", size = 12973130, upload-time = "2025-10-15T16:18:09.397Z" }, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.4.5.8" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805, upload-time = "2024-04-03T20:57:06.025Z" }, -] - -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957, upload-time = "2024-04-03T20:55:01.564Z" }, -] - -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306, upload-time = "2024-04-03T20:56:01.463Z" }, -] - -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737, upload-time = "2024-04-03T20:54:51.355Z" }, -] - -[[package]] -name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741, upload-time = "2024-04-22T15:24:15.253Z" }, -] - -[[package]] -name = "nvidia-cufft-cu12" -version = "11.2.1.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117, upload-time = "2024-04-03T20:57:40.402Z" }, -] - -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.5.147" -source = { registry = 
"https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206, upload-time = "2024-04-03T20:58:08.722Z" }, -] - -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057, upload-time = "2024-04-03T20:58:28.735Z" }, -] - -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763, upload-time = "2024-04-03T20:58:59.995Z" }, -] - -[[package]] -name = "nvidia-nccl-cu12" -version = "2.21.5" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414, upload-time = "2024-04-03T15:32:57.427Z" }, -] - -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810, upload-time = "2024-04-03T20:59:46.957Z" }, -] - -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144, upload-time = "2024-04-03T20:56:12.406Z" }, -] - [[package]] name = "onnx" version = "1.7.0" @@ -1755,18 +1462,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/4a/a4/7c50e6992a5c6664e30c65b3e4884e93e19525eb66afbcda6c545c6cfbea/prek-0.2.10-py3-none-win_arm64.whl", hash = "sha256:62d77b3dce2eaf7f69f175a3bf6c95e351d4b55fdd8f5b31f9a739713c472c26", size = 4498683, upload-time = "2025-10-18T12:59:37.946Z" }, ] -[[package]] -name = "proto-plus" -version = "1.26.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, -] - [[package]] name = "protobuf" version = "3.19.6" @@ -1789,27 +1484,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", size = 98708, upload-time = "2021-11-04T17:17:00.152Z" }, ] -[[package]] -name = "pyasn1" -version = "0.6.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, -] - -[[package]] -name = "pyasn1-modules" -version = "0.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, -] - [[package]] name = "pycocotools" version = "2.0.10" @@ -2174,18 +1848,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" }, ] -[[package]] -name = "rsa" -version = "4.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = 
"sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, -] - [[package]] name = "ruff" version = "0.14.1" @@ -2362,15 +2024,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/92/ad/13646b9beb0a95528ca46d52b7babafbe115017814a611f2065ee4e61d20/scipy-1.16.2-cp312-cp312-win_arm64.whl", hash = "sha256:2a8ffaa4ac0df81a0b94577b18ee079f13fecdb924df3328fc44a7dc5ac46851", size = 25456070, upload-time = "2025-09-11T17:41:41.3Z" }, ] -[[package]] -name = "setuptools" -version = "80.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, -] - [[package]] name = "simsimd" version = "6.5.3" @@ -2541,15 +2194,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, ] -[[package]] -name = "termcolor" -version = "3.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/6c/3d75c196ac07ac8749600b60b03f4f6094d54e132c4d94ebac6ee0e0add0/termcolor-3.1.0.tar.gz", hash = "sha256:6a6dd7fbee581909eeec6a756cff1d7f7c376063b14e4a298dc4980309e55970", size = 14324, upload-time = "2025-04-30T11:37:53.791Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/bd/de8d508070629b6d84a30d01d57e4a65c69aa7f5abe7560b8fad3b50ea59/termcolor-3.1.0-py3-none-any.whl", hash = "sha256:591dd26b5c2ce03b9e43f391264626557873ce1d379019786f99b0c2bee140aa", size = 7684, upload-time = "2025-04-30T11:37:52.382Z" }, -] - [[package]] name = "terminaltables" version = "3.1.10" @@ -2630,135 +2274,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, ] -[[package]] -name = "torch" -version = "2.5.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version < '3.11' and platform_machine != 'aarch64' 
and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", -] -dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "fsspec", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "jinja2", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "sympy", version = "1.13.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/ef/834af4a885b31a0b32fff2d80e1e40f771e1566ea8ded55347502440786a/torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:71328e1bbe39d213b8721678f9dcac30dfc452a46d586f1d514a6aa0a99d4744", size = 906446312, upload-time = "2024-10-29T17:33:38.045Z" }, - { url = 
"https://files.pythonhosted.org/packages/69/f0/46e74e0d145f43fa506cb336eaefb2d240547e4ce1f496e442711093ab25/torch-2.5.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:34bfa1a852e5714cbfa17f27c49d8ce35e1b7af5608c4bc6e81392c352dbc601", size = 91919522, upload-time = "2024-10-29T17:39:08.74Z" }, - { url = "https://files.pythonhosted.org/packages/a5/13/1eb674c8efbd04d71e4a157ceba991904f633e009a584dd65dccbafbb648/torch-2.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:32a037bd98a241df6c93e4c789b683335da76a2ac142c0973675b715102dc5fa", size = 203088048, upload-time = "2024-10-29T17:34:10.913Z" }, - { url = "https://files.pythonhosted.org/packages/d1/35/e8b2daf02ce933e4518e6f5682c72fd0ed66c15910ea1fb4168f442b71c4/torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457", size = 906474467, upload-time = "2024-10-29T17:38:49.832Z" }, - { url = "https://files.pythonhosted.org/packages/40/04/bd91593a4ca178ece93ca55f27e2783aa524aaccbfda66831d59a054c31e/torch-2.5.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:340ce0432cad0d37f5a31be666896e16788f1adf8ad7be481196b503dad675b9", size = 91919450, upload-time = "2024-10-29T17:37:26.693Z" }, - { url = "https://files.pythonhosted.org/packages/0d/4a/e51420d46cfc90562e85af2fee912237c662ab31140ab179e49bd69401d6/torch-2.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:603c52d2fe06433c18b747d25f5c333f9c1d58615620578c326d66f258686f9a", size = 203098237, upload-time = "2024-10-29T17:36:11.731Z" }, - { url = "https://files.pythonhosted.org/packages/8b/5c/36c114d120bfe10f9323ed35061bc5878cc74f3f594003854b0ea298942f/torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03", size = 906389343, upload-time = "2024-10-29T17:37:06.758Z" }, - { url = "https://files.pythonhosted.org/packages/6d/69/d8ada8b6e0a4257556d5b4ddeb4345ea8eeaaef3c98b60d1cca197c7ad8e/torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697", size = 91811673, upload-time = "2024-10-29T17:32:42.789Z" }, - { url = "https://files.pythonhosted.org/packages/5f/ba/607d013b55b9fd805db2a5c2662ec7551f1910b4eef39653eeaba182c5b2/torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c", size = 203046841, upload-time = "2024-10-29T17:35:48.665Z" }, -] - -[[package]] -name = "torch" -version = "2.9.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'darwin'", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, 
marker = "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'darwin')" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'darwin')" }, - { name = "sympy", version = "1.14.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/bb/86/245c240d2138c17ed572c943c289056c2721abab70810d772c6bf5495b28/torch-2.9.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:030bbfe367379ae6a4ae4042b6c44da25383343b8b3c68abaa9c7231efbaf2dd", size = 104213554, upload-time = "2025-10-15T15:45:59.798Z" }, - { url = "https://files.pythonhosted.org/packages/58/b0/2b4e647b0fc706e88eb6c253d05511865578f5f67b55fad639bf3272a4a1/torch-2.9.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:413e1654c9203733138858780e184d9fc59442f0b3b209e16f39354eb893db9b", size = 74452019, upload-time = "2025-10-15T15:46:04.296Z" }, - { url = "https://files.pythonhosted.org/packages/58/fe/334225e6330e672b36aef23d77451fa906ea12881570c08638a91331a212/torch-2.9.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:c596708b5105d0b199215acf0c9be7c1db5f1680d88eddadf4b75a299259a677", size = 104230578, upload-time = "2025-10-15T15:46:08.182Z" }, - { url = "https://files.pythonhosted.org/packages/b3/b7/205ef3e94de636feffd64b28bb59a0dfac0771221201b9871acf9236f5ca/torch-2.9.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:614a185e4986326d526a91210c8fc1397e76e8cfafa78baf6296a790e53a9eec", size = 74463678, upload-time = "2025-10-15T15:46:29.779Z" }, - { url = "https://files.pythonhosted.org/packages/d1/d3/3985739f3b8e88675127bf70f82b3a48ae083e39cda56305dbd90398fec0/torch-2.9.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e5f7af1dc4c0a7c4a260c2534f41ddaf209714f7c89145e644c44712fbd6b642", size = 104107898, upload-time = "2025-10-15T15:46:20.883Z" }, - { url = "https://files.pythonhosted.org/packages/dd/5f/b85bd8c05312d71de9402bf5868d217c38827cfd09d8f8514e5be128a52b/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:33f58e9a102a91259af289d50525c30323b5c9ae1d31322b6447c0814da68695", size = 74478983, upload-time = "2025-10-15T15:46:39.406Z" }, -] - -[[package]] -name = "torchvision" -version = "0.20.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", -] 
-dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/f6/7ff89a9f8703f623f5664afd66c8600e3f09fe188e1e0b7e6f9a8617f865/torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8ffbdf8bf5b30eade22d459f5a313329eeadb20dc75efa142987b53c007098c3", size = 7238975, upload-time = "2024-10-29T17:41:03.374Z" }, - { url = "https://files.pythonhosted.org/packages/f7/ce/4c31e9b96cc4f9fec746b258d2aa35f8d1247f4f58d63f9c505ea5eb254d/torchvision-0.20.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:75f8a4d51a593c4bab6c9bf7d75bdd88691b00a53b07656678bc55a3a753dd73", size = 14265343, upload-time = "2024-10-29T17:40:57.799Z" }, - { url = "https://files.pythonhosted.org/packages/17/11/b5ce67715bbbec8798fb48c4a20ac28828aec1710ac01091a3eddcb8e075/torchvision-0.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:22c2fa44e20eb404b85e42b22b453863a14b0927d25e550fd4f84eea97fa5b39", size = 1562413, upload-time = "2024-10-29T17:40:39.991Z" }, - { url = "https://files.pythonhosted.org/packages/de/e9/e190ecec448d5a2abad8348cf085fcb39962a491e3f40dcb023721e04feb/torchvision-0.20.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:86f6523dee420000fe14c3527f6c8e0175139fda7d995b187f54a0b0ebec7eb6", size = 7241222, upload-time = "2024-10-29T17:40:38.056Z" }, - { url = "https://files.pythonhosted.org/packages/b1/a3/cbb8177e5e379f0c040b00c6f80f14d323a97e30495d7115d169b101b2f7/torchvision-0.20.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:a40d766345927639da322c693934e5f91b1ba2218846c7104b868dea2314ce8e", size = 14267510, upload-time = "2024-10-29T17:40:53.031Z" }, - { url = "https://files.pythonhosted.org/packages/69/55/ce836703ff77bb21582c3098d5311f8ddde7eadc7eab04be9561961f4725/torchvision-0.20.1-cp311-cp311-win_amd64.whl", hash = "sha256:5b501d5c04b034d2ecda96a31ed050e383cf8201352e4c9276ca249cbecfded0", size = 1562402, upload-time = "2024-10-29T17:40:49.052Z" }, - { url = "https://files.pythonhosted.org/packages/d4/75/00a852275ade58d3dc474530f7a7b6bc999a817148f0eb59d4fde12eb955/torchvision-0.20.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:17cd78adddf81dac57d7dccc9277a4d686425b1c55715f308769770cb26cad5c", size = 7240323, upload-time = "2024-10-29T17:40:44.951Z" }, - { url = "https://files.pythonhosted.org/packages/af/f0/ca1445406eb12cbeb7a41fc833a1941ede78e7c55621198b83ecd7bcfd0f/torchvision-0.20.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:9f853ba4497ac4691815ad41b523ee23cf5ba4f87b1ce869d704052e233ca8b7", size = 14266936, upload-time = "2024-10-29T17:40:31.335Z" }, - { url = 
"https://files.pythonhosted.org/packages/c3/18/00993d420b1d6e88582e51d4bc82c824c99a2e9c045d50eaf9b34fff729a/torchvision-0.20.1-cp312-cp312-win_amd64.whl", hash = "sha256:4a330422c36dbfc946d3a6c1caec3489db07ecdf3675d83369adb2e5a0ca17c4", size = 1562392, upload-time = "2024-10-29T17:40:47.6Z" }, -] - -[[package]] -name = "torchvision" -version = "0.24.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'darwin'", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'darwin')" }, - { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'darwin')" }, - { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/63/5b/1404eeab00819df71a30e916c2081654366741f7838fcc4fff86b7bd9e7e/torchvision-0.24.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e8d5e667deff87bd66d26df6d225f46224bb0782d4f3f8f5d2f3068b5fd4492", size = 1891723, upload-time = "2025-10-15T15:51:08.5Z" }, - { url = "https://files.pythonhosted.org/packages/88/e3/1b003ecd52bd721f8304aeb66691edfbc2002747ec83d36188ad6abab506/torchvision-0.24.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a110a51c75e89807a8382b0d8034f5e180fb9319570be3389ffd3d4ac4fd57a9", size = 2418988, upload-time = "2025-10-15T15:51:25.195Z" }, - { url = "https://files.pythonhosted.org/packages/a3/17/54ed2ec6944ea972b461a86424c8c7f98835982c90cbc45bf59bd962863a/torchvision-0.24.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f771cf918351ad509a28488be475f3e9cc71a750d6b1467842bfb64863a5e986", size = 1891719, upload-time = "2025-10-15T15:51:10.384Z" }, - { url = "https://files.pythonhosted.org/packages/f8/07/0cd6776eee784742ad3cb2bfd3295383d84cb2f9e87386119333d1587f0f/torchvision-0.24.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbd63bf4ebff84c48c50123eba90526cc9f794fe45bc9f5dd07cec19e8c62bce", size = 2420513, upload-time = "2025-10-15T15:51:18.087Z" }, - { url = "https://files.pythonhosted.org/packages/47/ef/81e4e69e02e2c4650b30e8c11c8974f946682a30e0ab7e9803a831beff76/torchvision-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c61d40bcd2e2451e932902a702ad495ba1ec6f279e90b1e15cef2bb55dc911e2", size = 1891726, upload-time = "2025-10-15T15:51:16.977Z" }, - { url = "https://files.pythonhosted.org/packages/00/7b/e3809b3302caea9a12c13f3adebe4fef127188438e719fd6c8dc93db1da6/torchvision-0.24.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash 
= "sha256:b0531d1483fc322d7da0d83be52f0df860a75114ab87dbeeb9de765feaeda843", size = 2419495, upload-time = "2025-10-15T15:51:11.885Z" }, -] - [[package]] name = "tqdm" version = "4.67.1" @@ -2771,19 +2286,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] -[[package]] -name = "triton" -version = "3.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013, upload-time = "2024-10-14T16:05:32.106Z" }, - { url = "https://files.pythonhosted.org/packages/86/17/d9a5cf4fcf46291856d1e90762e36cbabd2a56c7265da0d1d9508c8e3943/triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c", size = 209506424, upload-time = "2024-10-14T16:05:42.337Z" }, - { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444, upload-time = "2024-10-14T16:05:53.433Z" }, -] - [[package]] name = "typing" version = "3.10.0.0" @@ -2841,32 +2343,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] -[[package]] -name = "viscv" -version = "0.1.0" -source = { editable = "libs/viscv" } -dependencies = [ - { name = "matplotlib" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "opencv-python-headless" }, - { name = "pillow" }, - { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "torchvision", version = "0.20.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "visengine" }, -] - -[package.metadata] -requires-dist = [ - { name = "matplotlib", specifier = ">=3.6.0" }, - { name = "numpy", specifier = ">=2.0.0" }, - { name = "opencv-python-headless", specifier = ">=4.8.0" }, - { name = "pillow", specifier = ">=10.0.0" }, - { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'", specifier = "==2.5.1" }, - { name = "torchvision", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'", specifier = "==0.20.1" }, - { name = "visengine", editable = "libs/visengine" }, -] - [[package]] 
name = "visdet" version = "2.28.4" @@ -2983,47 +2459,6 @@ dev = [ { name = "zuban", specifier = ">=0.1.2" }, ] -[[package]] -name = "visengine" -version = "0.1.0" -source = { editable = "libs/visengine" } -dependencies = [ - { name = "addict" }, - { name = "bitsandbytes", marker = "sys_platform == 'linux'" }, - { name = "cloudpathlib" }, - { name = "google-cloud-storage" }, - { name = "matplotlib" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "opencv-python-headless" }, - { name = "pyyaml" }, - { name = "rich" }, - { name = "termcolor" }, - { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "torchvision", version = "0.20.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "torchvision", version = "0.24.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "tqdm" }, - { name = "yapf" }, -] - -[package.metadata] -requires-dist = [ - { name = "addict", specifier = ">=2.4.0" }, - { name = "bitsandbytes", marker = "sys_platform == 'linux'", specifier = ">=0.41.0" }, - { name = "cloudpathlib", specifier = ">=0.18.1" }, - { name = "google-cloud-storage", specifier = ">=2.18.2" }, - { name = "matplotlib", specifier = ">=3.6.0" }, - { name = "numpy", specifier = ">=2.0.0" }, - { name = "opencv-python-headless", specifier = ">=4.8.0" }, - { name = "pyyaml", specifier = ">=6.0.0" }, - { name = "rich", specifier = ">=13.0.0" }, - { name = "termcolor", specifier = ">=2.0.0" }, - { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'", specifier = "==2.5.1" }, - { name = "torchvision", specifier = ">=0.19.1" }, - { name = "tqdm", specifier = ">=4.60.0" }, - { name = "yapf", specifier = ">=0.30.0" }, -] - [[package]] name = "watchdog" version = "6.0.0" diff --git a/visdet/visdet/__init__.py b/visdet/__init__.py similarity index 100% rename from visdet/visdet/__init__.py rename to visdet/__init__.py diff --git a/visdet/visdet/apis/__init__.py b/visdet/apis/__init__.py similarity index 100% rename from visdet/visdet/apis/__init__.py rename to visdet/apis/__init__.py diff --git a/visdet/visdet/apis/det_inferencer.py b/visdet/apis/det_inferencer.py similarity index 100% rename from visdet/visdet/apis/det_inferencer.py rename to visdet/apis/det_inferencer.py diff --git a/visdet/visdet/apis/inference.py b/visdet/apis/inference.py similarity index 100% rename from visdet/visdet/apis/inference.py rename to visdet/apis/inference.py diff --git a/visdet/visdet/configs/__init__.py b/visdet/configs/__init__.py similarity index 100% rename from visdet/visdet/configs/__init__.py rename to visdet/configs/__init__.py diff --git a/visdet/visdet/configs/_base_/datasets/coco_instance.py b/visdet/configs/_base_/datasets/coco_instance.py similarity index 100% rename from visdet/visdet/configs/_base_/datasets/coco_instance.py rename to visdet/configs/_base_/datasets/coco_instance.py diff --git a/visdet/visdet/configs/_base_/default_runtime.py b/visdet/configs/_base_/default_runtime.py similarity 
index 100% rename from visdet/visdet/configs/_base_/default_runtime.py rename to visdet/configs/_base_/default_runtime.py diff --git a/visdet/visdet/configs/_base_/models/cascade-mask-rcnn_r50_fpn.py b/visdet/configs/_base_/models/cascade-mask-rcnn_r50_fpn.py similarity index 100% rename from visdet/visdet/configs/_base_/models/cascade-mask-rcnn_r50_fpn.py rename to visdet/configs/_base_/models/cascade-mask-rcnn_r50_fpn.py diff --git a/visdet/visdet/configs/_base_/models/mask-rcnn_r50_fpn.py b/visdet/configs/_base_/models/mask-rcnn_r50_fpn.py similarity index 100% rename from visdet/visdet/configs/_base_/models/mask-rcnn_r50_fpn.py rename to visdet/configs/_base_/models/mask-rcnn_r50_fpn.py diff --git a/visdet/visdet/configs/_base_/schedules/schedule_1x.py b/visdet/configs/_base_/schedules/schedule_1x.py similarity index 100% rename from visdet/visdet/configs/_base_/schedules/schedule_1x.py rename to visdet/configs/_base_/schedules/schedule_1x.py diff --git a/visdet/visdet/configs/swin/cascade-mask-rcnn_swin-s-p4-w7_fpn_ms-crop-3x_coco.py b/visdet/configs/swin/cascade-mask-rcnn_swin-s-p4-w7_fpn_ms-crop-3x_coco.py similarity index 100% rename from visdet/visdet/configs/swin/cascade-mask-rcnn_swin-s-p4-w7_fpn_ms-crop-3x_coco.py rename to visdet/configs/swin/cascade-mask-rcnn_swin-s-p4-w7_fpn_ms-crop-3x_coco.py diff --git a/visdet/visdet/configs/swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py b/visdet/configs/swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py similarity index 100% rename from visdet/visdet/configs/swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py rename to visdet/configs/swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py diff --git a/visdet/visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_amp-ms-crop-3x_coco.py b/visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_amp-ms-crop-3x_coco.py similarity index 100% rename from visdet/visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_amp-ms-crop-3x_coco.py rename to visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_amp-ms-crop-3x_coco.py diff --git a/visdet/visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py b/visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py similarity index 100% rename from visdet/visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py rename to visdet/configs/swin/mask-rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py diff --git a/visdet/visdet/configs/swin/metafile.yml b/visdet/configs/swin/metafile.yml similarity index 100% rename from visdet/visdet/configs/swin/metafile.yml rename to visdet/configs/swin/metafile.yml diff --git a/visdet/cv/__init__.py b/visdet/cv/__init__.py new file mode 100644 index 0000000..3a9b674 --- /dev/null +++ b/visdet/cv/__init__.py @@ -0,0 +1,18 @@ +# ruff: noqa +""" +Computer Vision utilities. + +This module provides access to computer vision functionality through visdet.cv +for better namespace organization and discoverability. + +Usage: + from visdet import cv + from visdet.cv import image, cnn, transforms, ops, fileio +""" + +# Import submodules to make them accessible under the `visdet.cv` namespace +# (e.g., `visdet.cv.image`) +from . 
import cnn, fileio, image, ops, transforms # noqa: F401 +from .image import imfrombytes, imwrite # noqa: F401 + +__all__ = ["cnn", "fileio", "image", "ops", "transforms", "imfrombytes", "imwrite"] diff --git a/visdet/cv/cnn/__init__.py b/visdet/cv/cnn/__init__.py new file mode 100644 index 0000000..f2941d7 --- /dev/null +++ b/visdet/cv/cnn/__init__.py @@ -0,0 +1,24 @@ +# ruff: noqa +""" +CNN building blocks and utilities module. + +This module provides neural network components and layers. +""" + +from .bricks import ( # noqa: F401 + ConvModule, + build_norm_layer, + build_conv_layer, + build_upsample_layer, + FFN, + build_dropout, +) + +__all__ = [ + "ConvModule", + "build_norm_layer", + "build_conv_layer", + "build_upsample_layer", + "FFN", + "build_dropout", +] diff --git a/visdet/cv/cnn/bricks/__init__.py b/visdet/cv/cnn/bricks/__init__.py new file mode 100644 index 0000000..a18a3cf --- /dev/null +++ b/visdet/cv/cnn/bricks/__init__.py @@ -0,0 +1,21 @@ +# ruff: noqa +""" +CNN building blocks module. + +This module provides basic neural network components like convolutions, +normalizations, activations, and attention mechanisms. +""" + +from .builder import build_conv_layer, build_upsample_layer # noqa: F401 +from .conv import ConvModule # noqa: F401 +from .norm import build_norm_layer # noqa: F401 +from .transformer import FFN, build_dropout # noqa: F401 + +__all__ = [ + "ConvModule", + "build_norm_layer", + "build_conv_layer", + "build_upsample_layer", + "FFN", + "build_dropout", +] diff --git a/visdet/cv/cnn/bricks/builder.py b/visdet/cv/cnn/bricks/builder.py new file mode 100644 index 0000000..0233768 --- /dev/null +++ b/visdet/cv/cnn/bricks/builder.py @@ -0,0 +1,62 @@ +# ruff: noqa +""" +Builder functions for CNN components. + +This module provides functions to build standard CNN layers. +""" + +from typing import Dict, Optional, Tuple, Union +import torch.nn as nn + +from .conv import ConvModule + + +def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module: + """Build a conv layer. + + Args: + cfg: Config dict for the conv layer + *args: Positional arguments passed to the conv layer + **kwargs: Keyword arguments passed to the conv layer + + Returns: + Built convolutional layer + """ + if cfg is None: + return nn.Conv2d(*args, **kwargs) + + cfg = cfg.copy() + layer_type = cfg.pop("type", "Conv2d") + + if layer_type == "Conv2d": + return nn.Conv2d(*args, **kwargs) + elif layer_type == "ConvModule": + return ConvModule(*args, **kwargs, **cfg) + else: + return nn.Conv2d(*args, **kwargs) + + +def build_upsample_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module: + """Build an upsample layer. + + Args: + cfg: Config dict for the upsample layer + *args: Positional arguments passed to the upsample layer + **kwargs: Keyword arguments passed to the upsample layer + + Returns: + Built upsampling layer + """ + if cfg is None: + return nn.Upsample(*args, **kwargs) + + cfg = cfg.copy() + layer_type = cfg.pop("type", "nearest") + + if isinstance(layer_type, str): + return nn.Upsample(mode=layer_type, *args, **kwargs) + else: + return nn.Upsample(*args, **kwargs) + + +__all__ = ["build_conv_layer", "build_upsample_layer"] diff --git a/visdet/cv/cnn/bricks/conv.py b/visdet/cv/cnn/bricks/conv.py new file mode 100644 index 0000000..2da6c84 --- /dev/null +++ b/visdet/cv/cnn/bricks/conv.py @@ -0,0 +1,114 @@ +# ruff: noqa +""" +Convolutional modules and utilities. + +This module provides convolutional building blocks. 
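+
+A minimal usage sketch (illustrative only; assumes a 4D NCHW tensor ``x``):
+
+    conv = ConvModule(3, 64, 3, padding=1,
+                      norm_cfg=dict(type="BN"),
+                      act_cfg=dict(type="ReLU"))
+    y = conv(x)  # conv -> norm -> activation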
+""" + +from typing import Dict, Optional, Tuple, Union +import torch.nn as nn + +from .norm import build_norm_layer + + +class ConvModule(nn.Module): + """Convolutional module with optional normalization and activation. + + A standard convolutional layer with optional batch normalization and activation. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + norm_cfg: Optional[Dict] = None, + act_cfg: Optional[Dict] = None, + inplace: bool = False, + **kwargs, + ) -> None: + """Initialize ConvModule. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of convolutional kernel + stride: Stride of convolution + padding: Padding for convolution + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + norm_cfg: Config dict for normalization layer + act_cfg: Config dict for activation layer + inplace: Whether to use inplace operations + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.with_bias = bias + self.with_norm = norm_cfg is not None + self.with_act = act_cfg is not None + self.inplace = inplace + + # Build conv layer + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs, + ) + + # Build norm layer + if self.with_norm: + norm_name, norm_layer = build_norm_layer(norm_cfg, out_channels) + self.add_module(norm_name if norm_name else "bn", norm_layer) + + # Build activation layer + if self.with_act: + if isinstance(act_cfg, dict): + act_cfg = act_cfg.copy() + act_type = act_cfg.pop("type", "ReLU") + if act_type == "ReLU": + self.activate = nn.ReLU(inplace=inplace) + elif act_type == "GELU": + self.activate = nn.GELU() + elif act_type == "SiLU": + self.activate = nn.SiLU(inplace=inplace) + else: + self.activate = nn.ReLU(inplace=inplace) + else: + self.activate = nn.ReLU(inplace=inplace) + else: + self.activate = None + + def forward(self, x): + """Forward pass.""" + x = self.conv(x) + if self.with_norm: + # Apply norm layer + for m in self.modules(): + if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + x = m(x) + break + if self.with_act: + x = self.activate(x) + return x + + +__all__ = ["ConvModule"] diff --git a/visdet/cv/cnn/bricks/norm.py b/visdet/cv/cnn/bricks/norm.py new file mode 100644 index 0000000..3677086 --- /dev/null +++ b/visdet/cv/cnn/bricks/norm.py @@ -0,0 +1,41 @@ +# ruff: noqa +"""Normalization layers for visdet.""" + +from typing import Any, Optional, Tuple + +import torch.nn as nn + + +def build_norm_layer(cfg: dict, num_features: int) -> Tuple[Optional[str], Optional[nn.Module]]: + """Build normalization layer. 
+ + Args: + cfg: Config dict with keys: + type: Type of norm layer (e.g., 'BN', 'LN', 'GN', 'SyncBN') + **kwargs: Additional arguments for the layer + num_features: Number of features (channels) + + Returns: + Tuple of (layer_name, layer_module) + """ + if cfg is None: + return None, None + + cfg = cfg.copy() + layer_type = cfg.pop("type", "BN") + + if layer_type == "BN": + return "bn", nn.BatchNorm2d(num_features, **cfg) + elif layer_type == "SyncBN": + # For now, use regular BatchNorm + return "bn", nn.BatchNorm2d(num_features, **cfg) + elif layer_type == "GN": + num_groups = cfg.pop("num_groups", 32) + return "gn", nn.GroupNorm(num_groups, num_features, **cfg) + elif layer_type == "LN": + return "ln", nn.LayerNorm(num_features, **cfg) + else: + raise ValueError(f"Unsupported norm layer type: {layer_type}") + + +__all__ = ["build_norm_layer"] diff --git a/visdet/cv/cnn/bricks/transformer.py b/visdet/cv/cnn/bricks/transformer.py new file mode 100644 index 0000000..8920378 --- /dev/null +++ b/visdet/cv/cnn/bricks/transformer.py @@ -0,0 +1,82 @@ +# ruff: noqa +""" +Transformer and attention modules. + +This module provides transformer layers and attention mechanisms. +""" + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + + +class FFN(nn.Module): + """Feed-Forward Network module. + + Simple MLP for transformer blocks. + """ + + def __init__( + self, + embed_dims: int, + feedforward_channels: int, + num_fcs: int = 2, + act_cfg: Optional[Dict] = None, + dropout_cfg: Optional[Dict] = None, + add_identity: bool = False, + ) -> None: + """Initialize FFN module.""" + super().__init__() + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + + layers = [] + in_channels = embed_dims + for i in range(num_fcs - 1): + layers.append(nn.Linear(in_channels, feedforward_channels)) + layers.append(nn.GELU()) + in_channels = feedforward_channels + layers.append(nn.Linear(feedforward_channels, embed_dims)) + + self.layers = nn.Sequential(*layers) + self.add_identity = add_identity + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + out = self.layers(x) + if self.add_identity: + out = out + x + return out + + +def build_dropout(dropout_cfg: Optional[Dict] = None) -> Optional[nn.Module]: + """Build dropout layer. + + Args: + dropout_cfg: Config dict with 'type' and other parameters + + Returns: + Dropout module or None + """ + if dropout_cfg is None or dropout_cfg.get("type") == "None": + return None + + if isinstance(dropout_cfg, dict): + dropout_cfg = dropout_cfg.copy() + dropout_type = dropout_cfg.pop("type", "Dropout") + p = dropout_cfg.pop("p", dropout_cfg.pop("drop_prob", 0.0)) + + if dropout_type == "Dropout": + return nn.Dropout(p) + elif dropout_type == "DropPath": + return nn.Dropout(p) # Fallback to Dropout + else: + return nn.Dropout(p) + elif isinstance(dropout_cfg, float): + return nn.Dropout(dropout_cfg) + + return None + + +__all__ = ["FFN", "build_dropout"] diff --git a/visdet/cv/fileio.py b/visdet/cv/fileio.py new file mode 100644 index 0000000..7e2b6b8 --- /dev/null +++ b/visdet/cv/fileio.py @@ -0,0 +1,8 @@ +# ruff: noqa +""" +File I/O utilities module. + +This module provides file I/O and path handling utilities. +""" + +from visdet.cv.fileio import * # noqa: F401, F403 diff --git a/visdet/cv/image.py b/visdet/cv/image.py new file mode 100644 index 0000000..852da8a --- /dev/null +++ b/visdet/cv/image.py @@ -0,0 +1,8 @@ +# ruff: noqa +""" +Image processing utilities module. 
+
+This module provides image I/O and processing functions.
+"""
+
+from visdet.cv.image import *  # noqa: F401, F403
diff --git a/visdet/cv/image/__init__.py b/visdet/cv/image/__init__.py
new file mode 100644
index 0000000..62b867f
--- /dev/null
+++ b/visdet/cv/image/__init__.py
@@ -0,0 +1,10 @@
+# ruff: noqa
+"""
+Image processing utilities.
+
+This module provides image I/O and processing functions.
+"""
+
+from .io import imfrombytes, imwrite  # noqa: F401
+
+__all__ = ["imfrombytes", "imwrite"]
diff --git a/visdet/cv/image/io.py b/visdet/cv/image/io.py
new file mode 100644
index 0000000..1f2725b
--- /dev/null
+++ b/visdet/cv/image/io.py
@@ -0,0 +1,54 @@
+# ruff: noqa
+"""
+Image input/output utilities.
+
+This module provides functions for reading and writing images.
+"""
+
+from typing import Optional
+import numpy as np
+import cv2
+
+
+def imfrombytes(
+    content: bytes,
+    flag: str = "color",
+    channel_order: str = "bgr",
+) -> np.ndarray:
+    """Read an image from bytes.
+
+    Args:
+        content: Image bytes
+        flag: How to read the image ('color'; any other value is read as
+            grayscale)
+        channel_order: Channel order ('bgr' or 'rgb')
+
+    Returns:
+        Numpy array of the image (None if the bytes cannot be decoded)
+    """
+    img_array = np.frombuffer(content, dtype=np.uint8)
+    imread_flag = cv2.IMREAD_COLOR if flag == "color" else cv2.IMREAD_GRAYSCALE
+    # cv2.imdecode returns None when the buffer is not a valid image.
+    img = cv2.imdecode(img_array, imread_flag)
+
+    if channel_order == "rgb" and flag == "color":
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+    return img
+
+
+def imwrite(
+    img: np.ndarray,
+    file_path: str,
+    params: Optional[list] = None,
+) -> None:
+    """Write an image to file.
+
+    Args:
+        img: Image array (BGR format)
+        file_path: Path to save the image
+        params: Encoding parameters for cv2.imwrite
+    """
+    # cv2.imwrite does not accept None for its optional params argument,
+    # so only pass params through when it is explicitly provided.
+    if params is None:
+        cv2.imwrite(file_path, img)
+    else:
+        cv2.imwrite(file_path, img, params)
+
+
+__all__ = ["imfrombytes", "imwrite"]
diff --git a/visdet/cv/ops/__init__.py b/visdet/cv/ops/__init__.py
new file mode 100644
index 0000000..a0a0075
--- /dev/null
+++ b/visdet/cv/ops/__init__.py
@@ -0,0 +1,10 @@
+# ruff: noqa
+"""
+Computer Vision operations module.
+
+This module provides various CV operations like NMS and ROI operations.
+"""
+
+from .nms import batched_nms, nms, box_iou  # noqa: F401
+
+__all__ = ["batched_nms", "nms", "box_iou"]
diff --git a/visdet/cv/ops/nms.py b/visdet/cv/ops/nms.py
new file mode 100644
index 0000000..9212b9d
--- /dev/null
+++ b/visdet/cv/ops/nms.py
@@ -0,0 +1,132 @@
+# ruff: noqa
+"""
+Non-Maximum Suppression (NMS) operations.
+
+This module provides NMS and related suppression algorithms.
+"""
+
+from typing import Dict, Tuple, Union
+import torch
+from torch import Tensor
+
+
+def batched_nms(
+    boxes: Tensor,
+    scores: Tensor,
+    idxs: Tensor,
+    nms_cfg: Union[Dict, float],
+) -> Tuple[Tensor, Tensor]:
+    """Apply NMS per class to boxes.
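+
+    Boxes are grouped by class index and suppressed independently within
+    each class, so boxes of different classes never suppress each other.
+    Illustrative call (tensor shapes as documented below):
+
+        dets, keep = batched_nms(boxes, scores, idxs, dict(iou_threshold=0.5))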
+ + Args: + boxes: Tensor of shape (N, 4) with box coordinates [x1, y1, x2, y2] + scores: Tensor of shape (N,) with detection scores + idxs: Tensor of shape (N,) with class indices + nms_cfg: NMS configuration dict with 'iou_threshold' or a float iou_threshold + + Returns: + Tuple of: + - dets: Tensor of shape (K, 5) with [x1, y1, x2, y2, score] + - keep: Tensor of shape (K,) with indices of kept boxes + """ + # Extract iou_threshold from nms_cfg + if isinstance(nms_cfg, dict): + iou_threshold = nms_cfg.get("iou_threshold", 0.5) + else: + iou_threshold = nms_cfg + + # Get unique class indices + unique_classes = torch.unique(idxs, sorted=False) + keep_list = [] + + for class_id in unique_classes: + # Get mask for this class + class_mask = idxs == class_id + class_boxes = boxes[class_mask] + class_scores = scores[class_mask] + class_inds = torch.where(class_mask)[0] + + # Apply NMS for this class + keep_class = nms(class_boxes, class_scores, iou_threshold) + keep_list.append(class_inds[keep_class]) + + # Combine kept indices from all classes + if len(keep_list) > 0: + keep = torch.cat(keep_list) + # Sort by original order + keep = keep.sort()[0] + else: + keep = torch.empty(0, dtype=torch.long, device=boxes.device) + + # Create detections with scores + if keep.numel() > 0: + kept_boxes = boxes[keep] + kept_scores = scores[keep] + dets = torch.cat([kept_boxes, kept_scores.unsqueeze(1)], dim=1) + else: + dets = torch.empty((0, 5), dtype=boxes.dtype, device=boxes.device) + + return dets, keep + + +def nms( + boxes: Tensor, + scores: Tensor, + iou_threshold: float = 0.5, +) -> Tensor: + """Apply NMS to boxes. + + Args: + boxes: Tensor of shape (N, 4) with box coordinates [x1, y1, x2, y2] + scores: Tensor of shape (N,) with detection scores + iou_threshold: IoU threshold for suppression + + Returns: + Tensor of shape (K,) with indices of kept boxes + """ + if boxes.shape[0] == 0: + return torch.empty(0, dtype=torch.long, device=boxes.device) + + # Sort by scores in descending order + sorted_scores, sorted_inds = scores.sort(descending=True) + sorted_boxes = boxes[sorted_inds] + + # Compute pairwise IoU + ious = box_iou(sorted_boxes, sorted_boxes) + + # Suppress boxes + keep = torch.ones(len(sorted_boxes), dtype=torch.bool, device=boxes.device) + for i in range(len(sorted_boxes)): + if not keep[i]: + continue + # Suppress all boxes with IoU > threshold + keep[i + 1 :] &= ious[i, i + 1 :] <= iou_threshold + + return sorted_inds[keep] + + +def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: + """Compute pairwise IoU between two sets of boxes. + + Args: + boxes1: Tensor of shape (N, 4) with box coordinates [x1, y1, x2, y2] + boxes2: Tensor of shape (M, 4) with box coordinates [x1, y1, x2, y2] + + Returns: + Tensor of shape (N, M) with pairwise IoU values + """ + area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) + area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) + + lt = torch.max(boxes1[:, :2].unsqueeze(1), boxes2[:, :2]) # [N, M, 2] + rb = torch.min(boxes1[:, 2:].unsqueeze(1), boxes2[:, 2:]) # [N, M, 2] + wh = (rb - lt).clamp(min=0) # [N, M, 2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N, M] + + union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter + iou = inter / union + + return iou + + +__all__ = ["batched_nms", "nms", "box_iou"] diff --git a/visdet/cv/transforms.py b/visdet/cv/transforms.py new file mode 100644 index 0000000..4e4dff7 --- /dev/null +++ b/visdet/cv/transforms.py @@ -0,0 +1,8 @@ +# ruff: noqa +""" +Computer Vision transforms module. 
+ +This module provides image transformation utilities for data processing. +""" + +from visdet.cv.transforms import * # noqa: F401, F403 diff --git a/visdet/visdet/datasets/__init__.py b/visdet/datasets/__init__.py similarity index 100% rename from visdet/visdet/datasets/__init__.py rename to visdet/datasets/__init__.py diff --git a/visdet/visdet/datasets/api_wrappers/__init__.py b/visdet/datasets/api_wrappers/__init__.py similarity index 100% rename from visdet/visdet/datasets/api_wrappers/__init__.py rename to visdet/datasets/api_wrappers/__init__.py diff --git a/visdet/visdet/datasets/api_wrappers/coco_api.py b/visdet/datasets/api_wrappers/coco_api.py similarity index 100% rename from visdet/visdet/datasets/api_wrappers/coco_api.py rename to visdet/datasets/api_wrappers/coco_api.py diff --git a/visdet/visdet/datasets/api_wrappers/cocoeval_mp.py b/visdet/datasets/api_wrappers/cocoeval_mp.py similarity index 100% rename from visdet/visdet/datasets/api_wrappers/cocoeval_mp.py rename to visdet/datasets/api_wrappers/cocoeval_mp.py diff --git a/visdet/visdet/datasets/base_det_dataset.py b/visdet/datasets/base_det_dataset.py similarity index 100% rename from visdet/visdet/datasets/base_det_dataset.py rename to visdet/datasets/base_det_dataset.py diff --git a/visdet/visdet/datasets/coco.py b/visdet/datasets/coco.py similarity index 100% rename from visdet/visdet/datasets/coco.py rename to visdet/datasets/coco.py diff --git a/visdet/visdet/datasets/samplers/__init__.py b/visdet/datasets/samplers/__init__.py similarity index 100% rename from visdet/visdet/datasets/samplers/__init__.py rename to visdet/datasets/samplers/__init__.py diff --git a/visdet/visdet/datasets/samplers/class_aware_sampler.py b/visdet/datasets/samplers/class_aware_sampler.py similarity index 100% rename from visdet/visdet/datasets/samplers/class_aware_sampler.py rename to visdet/datasets/samplers/class_aware_sampler.py diff --git a/visdet/visdet/datasets/transforms/__init__.py b/visdet/datasets/transforms/__init__.py similarity index 100% rename from visdet/visdet/datasets/transforms/__init__.py rename to visdet/datasets/transforms/__init__.py diff --git a/visdet/visdet/datasets/transforms/formatting.py b/visdet/datasets/transforms/formatting.py similarity index 100% rename from visdet/visdet/datasets/transforms/formatting.py rename to visdet/datasets/transforms/formatting.py diff --git a/visdet/visdet/datasets/transforms/load_image.py b/visdet/datasets/transforms/load_image.py similarity index 100% rename from visdet/visdet/datasets/transforms/load_image.py rename to visdet/datasets/transforms/load_image.py diff --git a/visdet/visdet/datasets/transforms/loading.py b/visdet/datasets/transforms/loading.py similarity index 100% rename from visdet/visdet/datasets/transforms/loading.py rename to visdet/datasets/transforms/loading.py diff --git a/visdet/visdet/datasets/transforms/transforms.py b/visdet/datasets/transforms/transforms.py similarity index 100% rename from visdet/visdet/datasets/transforms/transforms.py rename to visdet/datasets/transforms/transforms.py diff --git a/visdet/engine/__init__.py b/visdet/engine/__init__.py new file mode 100644 index 0000000..c4df192 --- /dev/null +++ b/visdet/engine/__init__.py @@ -0,0 +1,53 @@ +# ruff: noqa +# Copyright (c) OpenMMLab. All rights reserved. +""" +Engine utilities for training and inference. + +This module includes: +1. visdet-specific hooks +2. 
Training infrastructure and model management + +Usage: + from visdet import engine + from visdet.engine import Config, Runner + from visdet.engine import hooks +""" + +# Import submodules to make them accessible under the `visdet.engine` namespace +# (e.g., `visdet.engine.runner`) +from . import ( + config, + dataset, + dist, + evaluator, + fileio, + infer, + logging, + model, + registry, + runner, + structures, + utils, + visualization, +) # noqa: F401 + +# NOTE: We don't eagerly import `hooks` here to avoid the circular import +# issue identified during development. It remains accessible via direct import: +# `from visdet.engine import hooks` or `from visdet.engine.hooks import ...` + +# Export all submodules. +__all__ = [ + "config", + "dataset", + "dist", + "evaluator", + "fileio", + "infer", + "logging", + "model", + "registry", + "runner", + "structures", + "utils", + "visualization", +] diff --git a/visdet/engine/__pycache__/__init__.cpython-312.pyc b/visdet/engine/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..208f9cd Binary files /dev/null and b/visdet/engine/__pycache__/__init__.cpython-312.pyc differ diff --git a/visdet/engine/config.py b/visdet/engine/config.py new file mode 100644 index 0000000..acdc757 --- /dev/null +++ b/visdet/engine/config.py @@ -0,0 +1,8 @@ +# ruff: noqa +""" +Config module. + +This module provides access to config functionality for visdet. +""" + +from visdet.engine.config import * # noqa: F401, F403 diff --git a/visdet/visdet/engine/config/__init__.py b/visdet/engine/config/__init__.py similarity index 100% rename from visdet/visdet/engine/config/__init__.py rename to visdet/engine/config/__init__.py diff --git a/visdet/visdet/engine/config/config_wrapper.py b/visdet/engine/config/config_wrapper.py similarity index 85% rename from visdet/visdet/engine/config/config_wrapper.py rename to visdet/engine/config/config_wrapper.py index b870a54..0e2fde0 100644 --- a/visdet/visdet/engine/config/config_wrapper.py +++ b/visdet/engine/config/config_wrapper.py @@ -1,19 +1,43 @@ """Enhanced Config class with YAML support. -This module provides an enhanced Config class that extends visengine.Config -with support for YAML configuration files. +This module provides an enhanced Config class with support for YAML +configuration files and Pydantic validation. """ import warnings from pathlib import Path from typing import Any, Dict, Optional, Union -from visengine.config import Config as BaseConfig - from .schema_generator import validate_config_with_schema from .yaml_loader import load_yaml_config +class BaseConfig(dict): + """Minimal base config class for compatibility. + + This provides a dict-like interface for configuration dictionaries. + """ + + def __init__(self, cfg_dict: Optional[Dict] = None, **kwargs: Any) -> None: + """Initialize config from dict or kwargs.""" + if cfg_dict is None: + cfg_dict = {} + if kwargs: + cfg_dict.update(kwargs) + super().__init__(cfg_dict) + + def __getattr__(self, name: str) -> Any: + """Get config value by attribute access.""" + try: + return self[name] + except KeyError: + raise AttributeError(f"'Config' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + """Set config value by attribute access.""" + self[name] = value + + class Config(BaseConfig): """Enhanced Config class with YAML support and Pydantic validation. 
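+
+    BaseConfig gives dict-backed attribute access, e.g. (illustrative):
+
+        cfg = Config({"model": {"type": "MaskRCNN"}})
+        cfg.work_dir = "./work_dirs/demo"  # stored as cfg["work_dir"]
+        print(cfg.model["type"])           # attribute read -> dict lookup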
diff --git a/visdet/visdet/engine/config/schema_generator.py b/visdet/engine/config/schema_generator.py similarity index 100% rename from visdet/visdet/engine/config/schema_generator.py rename to visdet/engine/config/schema_generator.py diff --git a/visdet/visdet/engine/config/yaml_loader.py b/visdet/engine/config/yaml_loader.py similarity index 100% rename from visdet/visdet/engine/config/yaml_loader.py rename to visdet/engine/config/yaml_loader.py diff --git a/visdet/engine/dataset.py b/visdet/engine/dataset.py new file mode 100644 index 0000000..c0bd557 --- /dev/null +++ b/visdet/engine/dataset.py @@ -0,0 +1,83 @@ +# ruff: noqa +""" +Dataset module. + +This module provides access to dataset functionality for visdet. +""" + +from typing import Any, Dict, List, Optional +import torch + + +def pseudo_collate(data_batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Collate batch data items into a single batch. + + This function takes a list of data items and collates them into a single + batch by stacking tensors and grouping other data. + + Args: + data_batch: List of data items, each with "inputs" and "data_samples" keys + + Returns: + Collated batch dictionary with "inputs" as stacked tensor and + "data_samples" as list of samples + """ + inputs = [] + data_samples = [] + + for data_item in data_batch: + inputs.append(data_item.get("inputs")) + data_samples.append(data_item.get("data_samples")) + + # Stack inputs into a batch tensor + if inputs and inputs[0] is not None: + inputs = torch.stack(inputs, dim=0) + else: + inputs = None + + return {"inputs": inputs, "data_samples": data_samples} + + +class BaseDataset: + """Base dataset class for visdet. + + This is a minimal implementation for type checking. + In a full implementation, this would provide comprehensive dataset features. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize dataset. + + Args: + **kwargs: Dataset configuration arguments + """ + pass + + def __len__(self) -> int: + """Get dataset length.""" + return 0 + + def __getitem__(self, idx: int) -> Dict: + """Get a sample from the dataset. + + Args: + idx: Sample index + + Returns: + Sample dictionary + """ + return {} + + def get_cat_ids(self, idx: int) -> List: + """Get category IDs for a sample. + + Args: + idx: Sample index + + Returns: + List of category IDs + """ + return [] + + +__all__ = ["BaseDataset", "pseudo_collate"] diff --git a/visdet/engine/dist.py b/visdet/engine/dist.py new file mode 100644 index 0000000..b62d28f --- /dev/null +++ b/visdet/engine/dist.py @@ -0,0 +1,8 @@ +# ruff: noqa +""" +Dist module. + +This module provides access to dist functionality for visdet. +""" + +from visdet.engine.dist import * # noqa: F401, F403 diff --git a/visdet/engine/evaluator.py b/visdet/engine/evaluator.py new file mode 100644 index 0000000..9c7163c --- /dev/null +++ b/visdet/engine/evaluator.py @@ -0,0 +1,51 @@ +# ruff: noqa +""" +Evaluator module. + +This module provides access to evaluator functionality for visdet. +""" + +from typing import Any, Dict, List, Optional + + +class BaseMetric: + """Base metric class for visdet. + + This is a minimal implementation for type checking. + In a full implementation, this would provide comprehensive metrics. + """ + + metric_name = "metric" + + def __init__(self, collect_device: str = "cpu", prefix: Optional[str] = None) -> None: + """Initialize base metric. 
+
+        Args:
+            collect_device: Device to collect results on
+            prefix: Prefix for metric names
+        """
+        self.collect_device = collect_device
+        self.prefix = prefix
+
+    def process(self, data_batch: Dict, data_samples: List) -> None:
+        """Process a batch of data samples.
+
+        Args:
+            data_batch: Input batch data
+            data_samples: Data samples with predictions
+        """
+        pass
+
+    def compute_metrics(self, results: List) -> Dict:
+        """Compute metrics from processed results.
+
+        Args:
+            results: List of processed results
+
+        Returns:
+            Dictionary of computed metrics
+        """
+        return {}
+
+
+__all__ = ["BaseMetric"]
diff --git a/visdet/engine/fileio/__init__.py b/visdet/engine/fileio/__init__.py
new file mode 100644
index 0000000..752c53a
--- /dev/null
+++ b/visdet/engine/fileio/__init__.py
@@ -0,0 +1,89 @@
+# ruff: noqa
+"""
+File I/O utilities for visdet.
+
+This module provides functions for reading and writing files.
+"""
+
+from typing import Any, Optional
+
+
+def get(filepath: str, backend_args: Optional[dict] = None) -> bytes:
+    """Get file content.
+
+    Args:
+        filepath: Path to the file
+        backend_args: Backend arguments (unused in this stub)
+
+    Returns:
+        File content as bytes
+    """
+    with open(filepath, "rb") as f:
+        return f.read()
+
+
+def get_local_path(filepath: str, backend_args: Optional[dict] = None) -> str:
+    """Get local path of a file.
+
+    For local files, returns the file path as-is.
+    For remote files, would download and return local path.
+
+    Args:
+        filepath: Path to the file (local or remote)
+        backend_args: Backend arguments (unused in this stub)
+
+    Returns:
+        Local file path
+    """
+    # In a full implementation, this would handle remote file downloads
+    # For now, just return the path as-is
+    return filepath
+
+
+def load(filepath: str, backend_args: Optional[dict] = None) -> Any:
+    """Load data from a file.
+
+    Args:
+        filepath: Path to the file
+        backend_args: Backend arguments (unused in this stub)
+
+    Returns:
+        Loaded data (could be any type based on file format)
+    """
+    # This is a stub that would handle JSON, pickle, yaml, etc.
+    # For now, just read the file content
+    import json
+
+    try:
+        with open(filepath, "r") as f:
+            return json.load(f)
+    except (json.JSONDecodeError, UnicodeDecodeError):
+        # Fall back to reading as bytes
+        with open(filepath, "rb") as f:
+            return f.read()
+
+
+def dump(obj: Any, file: str, backend_args: Optional[dict] = None) -> None:
+    """Save data to a file.
+
+    Args:
+        obj: Object to save
+        file: Path to save to
+        backend_args: Backend arguments (unused in this stub)
+    """
+    import json
+
+    try:
+        with open(file, "w") as f:
+            json.dump(obj, f)
+    except (TypeError, ValueError):
+        # Fall back to pickle for non-JSON-serializable objects
+        import pickle
+
+        with open(file, "wb") as f:
+            pickle.dump(obj, f)
+
+
+__all__ = ["get", "get_local_path", "load", "dump"]
diff --git a/visdet/visdet/engine/hooks/__init__.py b/visdet/engine/hooks/__init__.py
similarity index 62%
rename from visdet/visdet/engine/hooks/__init__.py
rename to visdet/engine/hooks/__init__.py
index dace130..d24cb5f 100644
--- a/visdet/visdet/engine/hooks/__init__.py
+++ b/visdet/engine/hooks/__init__.py
@@ -1,6 +1,7 @@
 # ruff: noqa
 # Copyright (c) OpenMMLab. All rights reserved.
+from .base_hook import Hook from .visualization_hook import DetVisualizationHook -__all__ = ["DetVisualizationHook"] +__all__ = ["Hook", "DetVisualizationHook"] diff --git a/visdet/engine/hooks/__pycache__/__init__.cpython-312.pyc b/visdet/engine/hooks/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..46cb53d Binary files /dev/null and b/visdet/engine/hooks/__pycache__/__init__.cpython-312.pyc differ diff --git a/visdet/visdet/engine/hooks/__pycache__/visualization_hook.cpython-312.pyc b/visdet/engine/hooks/__pycache__/visualization_hook.cpython-312.pyc similarity index 78% rename from visdet/visdet/engine/hooks/__pycache__/visualization_hook.cpython-312.pyc rename to visdet/engine/hooks/__pycache__/visualization_hook.cpython-312.pyc index 52a839a..0defc96 100644 Binary files a/visdet/visdet/engine/hooks/__pycache__/visualization_hook.cpython-312.pyc and b/visdet/engine/hooks/__pycache__/visualization_hook.cpython-312.pyc differ diff --git a/visdet/engine/hooks/base_hook.py b/visdet/engine/hooks/base_hook.py new file mode 100644 index 0000000..25fbc57 --- /dev/null +++ b/visdet/engine/hooks/base_hook.py @@ -0,0 +1,39 @@ +# ruff: noqa +# Copyright (c) OpenMMLab. All rights reserved. +"""Base Hook class for visdet.""" + +from typing import Any, Dict + + +class Hook: + """Stub Hook base class for visdet. + + This is a minimal implementation for type checking. + In a full implementation, this would come from mmengine. + """ + + rule: str = "" + + def before_train(self, runner: Any) -> None: + """Called before training starts.""" + pass + + def after_train(self, runner: Any) -> None: + """Called after training finishes.""" + pass + + def before_train_epoch(self, runner: Any) -> None: + """Called before each epoch.""" + pass + + def after_train_epoch(self, runner: Any) -> None: + """Called after each epoch.""" + pass + + def before_train_iter(self, runner: Any, batch_idx: int, data_batch: Dict) -> None: + """Called before each iteration.""" + pass + + def after_train_iter(self, runner: Any, batch_idx: int, data_batch: Dict, outputs: Any) -> None: + """Called after each iteration.""" + pass diff --git a/visdet/visdet/engine/hooks/visualization_hook.py b/visdet/engine/hooks/visualization_hook.py similarity index 99% rename from visdet/visdet/engine/hooks/visualization_hook.py rename to visdet/engine/hooks/visualization_hook.py index e03bced..f3c2ed1 100644 --- a/visdet/visdet/engine/hooks/visualization_hook.py +++ b/visdet/engine/hooks/visualization_hook.py @@ -5,10 +5,10 @@ from typing import Optional import numpy as np -from visengine.hooks import Hook from visdet.cv import imfrombytes, imwrite from visdet.engine.fileio import get +from visdet.engine.hooks.base_hook import Hook from visdet.engine.runner import Runner from visdet.engine.utils import mkdir_or_exist from visdet.engine.visualization import Visualizer diff --git a/visdet/engine/infer.py b/visdet/engine/infer.py new file mode 100644 index 0000000..9988e2c --- /dev/null +++ b/visdet/engine/infer.py @@ -0,0 +1,8 @@ +# ruff: noqa +""" +Infer module. + +This module provides access to infer functionality for visdet. +""" + +from visdet.engine.infer import * # noqa: F401, F403 diff --git a/visdet/engine/logging.py b/visdet/engine/logging.py new file mode 100644 index 0000000..1ab54d6 --- /dev/null +++ b/visdet/engine/logging.py @@ -0,0 +1,8 @@ +# ruff: noqa +""" +Logging module. + +This module provides access to logging functionality for visdet. 
+""" + +from visdet.engine.logging import * # noqa: F401, F403 diff --git a/visdet/engine/logging/__init__.py b/visdet/engine/logging/__init__.py new file mode 100644 index 0000000..c207625 --- /dev/null +++ b/visdet/engine/logging/__init__.py @@ -0,0 +1,87 @@ +# ruff: noqa +""" +Logging module for visdet. + +Provides logging utilities for training and inference. +""" + +import sys +from typing import Any, Dict, Optional + + +class MMLogger: + """Stub logger class for visdet. + + This is a minimal implementation for type checking. + In a full implementation, this would wrap mmengine logging. + """ + + def __init__(self, name: str = "visdet") -> None: + """Initialize logger.""" + self.name = name + + def debug(self, msg: str) -> None: + """Log debug message.""" + pass + + def info(self, msg: str) -> None: + """Log info message.""" + pass + + def warning(self, msg: str) -> None: + """Log warning message.""" + pass + + def error(self, msg: str) -> None: + """Log error message.""" + pass + + +class MessageHub: + """Stub MessageHub for visdet. + + Used to collect and manage messages during training. + """ + + def __init__(self) -> None: + """Initialize message hub.""" + self.messages: Dict[str, Any] = {} + + def log_scalars(self, scalars: Dict[str, float], step: int = 0) -> None: + """Log scalar values. + + Args: + scalars: Dictionary of scalar values to log + step: Step/iteration number + """ + self.messages[step] = scalars + + def get_scalar(self, key: str, step: int = 0) -> Optional[float]: + """Get a scalar value. + + Args: + key: Key to retrieve + step: Step/iteration number + + Returns: + Scalar value if found + """ + if step in self.messages: + return self.messages[step].get(key) + return None + + +def print_log(msg: str, logger: Optional[MMLogger] = None) -> None: + """Print log message. + + Args: + msg: Message to print + logger: Optional logger instance + """ + if logger is not None: + logger.info(msg) + else: + print(msg, file=sys.stdout) + + +__all__ = ["MMLogger", "MessageHub", "print_log"] diff --git a/visdet/engine/model/__init__.py b/visdet/engine/model/__init__.py new file mode 100644 index 0000000..d6c6e5e --- /dev/null +++ b/visdet/engine/model/__init__.py @@ -0,0 +1,35 @@ +# ruff: noqa +""" +Model utilities for visdet. + +This module provides base model classes and utilities. +""" + +from typing import List +import torch.nn as nn + +from .base_module import BaseModule, BaseModel # noqa: F401 +from .data_preprocessor import BaseDataPreprocessor, ImgDataPreprocessor # noqa: F401 +from .weight_init import constant_init, trunc_normal_, trunc_normal_init # noqa: F401 + + +# Simple alias for ModuleList +class ModuleList(nn.ModuleList): + """ModuleList for visdet. + + This is an alias for torch.nn.ModuleList with additional features. + """ + + pass + + +__all__ = [ + "BaseModule", + "BaseModel", + "ModuleList", + "BaseDataPreprocessor", + "ImgDataPreprocessor", + "constant_init", + "trunc_normal_", + "trunc_normal_init", +] diff --git a/visdet/engine/model/base_module.py b/visdet/engine/model/base_module.py new file mode 100644 index 0000000..60cb626 --- /dev/null +++ b/visdet/engine/model/base_module.py @@ -0,0 +1,46 @@ +# ruff: noqa +"""Base module for visdet.""" + +from typing import Any, Dict, Optional + +import torch.nn as nn + + +class BaseModule(nn.Module): + """Base module class for visdet. + + This is a minimal implementation for type checking. + In a full implementation, this would provide enhanced features like + initialization control, profiling, etc. 
+ """ + + def __init__(self, init_cfg: Optional[Dict] = None) -> None: + """Initialize module. + + Args: + init_cfg: Config dict for weight initialization + """ + super().__init__() + self.init_cfg = init_cfg + + def _init_weights(self) -> None: + """Initialize weights.""" + pass + + +class BaseModel(BaseModule): + """Base model class for visdet. + + Extends BaseModule for model-specific functionality. + """ + + def __init__(self, init_cfg: Optional[Dict] = None) -> None: + """Initialize base model. + + Args: + init_cfg: Config dict for weight initialization + """ + super().__init__(init_cfg=init_cfg) + + +__all__ = ["BaseModule", "BaseModel"] diff --git a/visdet/engine/model/data_preprocessor.py b/visdet/engine/model/data_preprocessor.py new file mode 100644 index 0000000..55c6544 --- /dev/null +++ b/visdet/engine/model/data_preprocessor.py @@ -0,0 +1,59 @@ +# ruff: noqa +""" +Data preprocessor classes for visdet. + +This module provides base data preprocessor classes. +""" + +from typing import Dict, List, Optional, Union +import torch +import torch.nn as nn + +from .base_module import BaseModule + + +class BaseDataPreprocessor(BaseModule): + """Base data preprocessor for visdet. + + This is a minimal implementation for type checking. + In a full implementation, this would provide comprehensive data preprocessing. + """ + + def __init__( + self, + mean: Optional[Union[List[float], torch.Tensor]] = None, + std: Optional[Union[List[float], torch.Tensor]] = None, + rgb_to_bgr: bool = False, + bgr_to_rgb: bool = False, + pad_mask: bool = False, + pad_size_divisor: int = 1, + init_cfg: Optional[Dict] = None, + ) -> None: + """Initialize data preprocessor.""" + super().__init__(init_cfg=init_cfg) + self.mean = mean + self.std = std + self.rgb_to_bgr = rgb_to_bgr + self.bgr_to_rgb = bgr_to_rgb + self.pad_mask = pad_mask + self.pad_size_divisor = pad_size_divisor + + def forward(self, data: Dict) -> Dict: + """Forward pass for data preprocessing. + + Args: + data: Input data dictionary + + Returns: + Preprocessed data dictionary + """ + return data + + +class ImgDataPreprocessor(BaseDataPreprocessor): + """Image data preprocessor for visdet. + + Handles image normalization, RGB/BGR conversion, and padding. + """ + + pass diff --git a/visdet/engine/model/weight_init.py b/visdet/engine/model/weight_init.py new file mode 100644 index 0000000..c4942a2 --- /dev/null +++ b/visdet/engine/model/weight_init.py @@ -0,0 +1,63 @@ +# ruff: noqa +""" +Weight initialization utilities for visdet. + +This module provides functions for initializing network weights. +""" + +import math +from typing import Optional + +import torch +import torch.nn as nn + + +def constant_init(module: nn.Module, val: float, bias: float = 0) -> None: + """Initialize module parameters with constant values. + + Args: + module: Module to initialize + val: Constant value for weight initialization + bias: Constant value for bias initialization + """ + if hasattr(module, "weight") and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def trunc_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1, a: float = -2, b: float = 2) -> None: + """Truncated normal initialization. 
+
+    Args:
+        tensor: Tensor to initialize
+        mean: Mean of distribution
+        std: Standard deviation of distribution
+        a: Lower truncation bound
+        b: Upper truncation bound
+    """
+    with torch.no_grad():
+        # Draw from N(mean, std), then clamp to the truncation bounds.
+        # Clamping is a simple approximation; a faithful truncated normal
+        # would resample out-of-range values instead.
+        normal = torch.normal(mean, std, size=tensor.shape)
+        tensor.copy_(torch.clamp(normal, a, b))
+
+
+def trunc_normal_init(module: nn.Module, mean: float = 0, std: float = 1) -> None:
+    """Apply truncated normal initialization to a module.
+
+    Args:
+        module: Module to initialize
+        mean: Mean of distribution
+        std: Standard deviation of distribution
+    """
+    if hasattr(module, "weight") and module.weight is not None:
+        trunc_normal_(module.weight, mean, std)
+    if hasattr(module, "bias") and module.bias is not None:
+        nn.init.constant_(module.bias, 0)
+
+
+__all__ = ["constant_init", "trunc_normal_", "trunc_normal_init"]
diff --git a/libs/viscv/tests/test_image/__init__.py b/visdet/engine/optimizers/__init__.py
similarity index 100%
rename from libs/viscv/tests/test_image/__init__.py
rename to visdet/engine/optimizers/__init__.py
diff --git a/visdet/engine/registry.py b/visdet/engine/registry.py
new file mode 100644
index 0000000..391c1ea
--- /dev/null
+++ b/visdet/engine/registry.py
@@ -0,0 +1,155 @@
+# ruff: noqa
+# type: ignore
+"""
+Registry module for visdet.
+
+This module provides access to the registry system for managing models,
+datasets, hooks, and other components.
+"""
+
+from contextlib import contextmanager
+from typing import Any, Dict, Generator, Optional
+
+
+class Registry(dict):
+    """Stub Registry class for visdet.
+
+    This is a minimal implementation for type checking.
+    In a full implementation, this would provide enhanced registry features.
+    """
+
+    def __init__(self, name: str = "", parent: Optional["Registry"] = None, locations: Optional[list] = None) -> None:
+        """Initialize registry."""
+        super().__init__()
+        self.name = name
+        self.parent = parent
+        self.locations = locations or []
+
+    def register(self, cls_or_func: Any = None, force: bool = False) -> Any:
+        """Register a class or function."""
+
+        def _register(obj: Any) -> Any:
+            self[obj.__name__] = obj
+            return obj
+
+        if cls_or_func is None:
+            return _register
+        return _register(cls_or_func)
+
+    def register_module(
+        self,
+        name: Optional[str] = None,
+        module: Any = None,
+        force: bool = False,
+    ) -> Any:
+        """Register a module with a given name.
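+
+        Example (hypothetical registry and class; the supported decorator
+        forms are listed below):
+
+            MODELS = Registry("model")
+
+            @MODELS.register_module()
+            class MaskRCNN:
+                ...
+
+            assert MODELS["MaskRCNN"] is MaskRCNN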
+
+        Can be used as:
+        - @registry.register_module() - uses class __name__
+        - @registry.register_module(name="custom_name") - uses provided name
+        - registry.register_module(name="name", module=obj) - direct registration
+        """
+
+        # Case 1: Direct registration with module argument
+        if module is not None:
+            key = name if name is not None else module.__name__
+            self[key] = module
+            return module
+
+        # Case 2: Decorator usage - could be:
+        #   @registry.register_module()              - name is None
+        #   @registry.register_module(name="custom") - name is a string
+        #   @registry.register_module                - no parentheses, so the
+        #                                              decorated class arrives
+        #                                              via the `name` argument
+
+        def _register(obj: Any) -> Any:
+            # If name was provided as a string, use it; otherwise use obj's name
+            key = name if isinstance(name, str) else obj.__name__
+            self[key] = obj
+            return obj
+
+        # If name is None or a string, return the decorator.
+        # If name is actually a class (decorator used without parentheses),
+        # register it directly.
+        if name is not None and not isinstance(name, str):
+            # name is actually a class being decorated
+            return _register(name)
+
+        return _register
+
+    def build(self, cfg: Dict) -> Any:
+        """Build object from config."""
+        if isinstance(cfg, dict):
+            cfg = cfg.copy()
+            obj_type = cfg.pop("type")
+            return self[obj_type](**cfg)
+        return cfg
+
+
+# Create stub registry instances
+DATA_SAMPLERS = Registry("data_sampler")
+DATASETS = Registry("dataset")
+EVALUATOR = Registry("evaluator")
+HOOKS = Registry("hook")
+LOG_PROCESSORS = Registry("log_processor")
+LOOPS = Registry("loop")
+METRICS = Registry("metric")
+MODEL_WRAPPERS = Registry("model_wrapper")
+MODELS = Registry("model")
+OPTIM_WRAPPER_CONSTRUCTORS = Registry("optimizer_constructor")
+OPTIM_WRAPPERS = Registry("optim_wrapper")
+OPTIMIZERS = Registry("optimizer")
+PARAM_SCHEDULERS = Registry("parameter_scheduler")
+RUNNER_CONSTRUCTORS = Registry("runner_constructor")
+RUNNERS = Registry("runner")
+TASK_UTILS = Registry("task_util")
+TRANSFORMS = Registry("transform")
+VISBACKENDS = Registry("vis_backend")
+VISUALIZERS = Registry("visualizer")
+WEIGHT_INITIALIZERS = Registry("weight_initializer")
+
+
+class DefaultScope:
+    """Stub implementation of DefaultScope for registry management.
+
+    This is a simplified version for the type checking phase.
+    In a full implementation, this would come from mmengine.
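+
+    Example (hypothetical usage of this stub):
+
+        DefaultScope.get_instance("runner", scope_name="visdet")
+        assert DefaultScope.check_instance_created("visdet")
+        with DefaultScope.overwrite_default_scope("visdet"):
+            ...  # registry lookups here see "visdet" as the current scope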
+ """ + + _current_instance: Optional[str] = None + _created_instances: set = set() + + def __init__(self, scope_name: str) -> None: + """Initialize a DefaultScope instance.""" + self.scope_name = scope_name + DefaultScope._created_instances.add(scope_name) + + @classmethod + def get_instance(cls, instance_name: str, scope_name: str = "") -> "DefaultScope": + """Get or create a DefaultScope instance.""" + cls._created_instances.add(scope_name) + return cls(scope_name) + + @classmethod + def get_current_instance(cls) -> Optional["DefaultScope"]: + """Get the current DefaultScope instance.""" + if cls._current_instance is not None: + return cls(cls._current_instance) + return None + + @classmethod + def check_instance_created(cls, scope_name: str) -> bool: + """Check if a scope instance has been created.""" + return scope_name in cls._created_instances + + @classmethod + @contextmanager + def overwrite_default_scope(cls, scope_name: str) -> Generator[None, None, None]: + """Context manager to temporarily set the default scope.""" + old_instance = cls._current_instance + cls._current_instance = scope_name + try: + yield + finally: + cls._current_instance = old_instance + + +__all__ = ["DefaultScope"] diff --git a/visdet/engine/runner/__init__.py b/visdet/engine/runner/__init__.py new file mode 100644 index 0000000..50c0781 --- /dev/null +++ b/visdet/engine/runner/__init__.py @@ -0,0 +1,11 @@ +# ruff: noqa +""" +Runner utilities for visdet. + +This module provides training runner implementations. +""" + +from .runner import Runner # noqa: F401 +from .checkpoint import CheckpointLoader # noqa: F401 + +__all__ = ["Runner", "CheckpointLoader"] diff --git a/visdet/engine/runner/checkpoint.py b/visdet/engine/runner/checkpoint.py new file mode 100644 index 0000000..9de0441 --- /dev/null +++ b/visdet/engine/runner/checkpoint.py @@ -0,0 +1,49 @@ +# ruff: noqa +""" +Checkpoint utilities for visdet. + +This module provides checkpoint loading and saving utilities. +""" + +from typing import Any, Dict, Optional +import torch + + +class CheckpointLoader: + """Stub checkpoint loader for visdet. + + This is a minimal implementation for type checking. + In a full implementation, this would handle various checkpoint formats. + """ + + @staticmethod + def load_checkpoint(filename: str, map_location: Optional[str] = None) -> Dict[str, Any]: + """Load checkpoint from file. + + Args: + filename: Path to checkpoint file + map_location: Device to load to + + Returns: + Loaded checkpoint dict + """ + return torch.load(filename, map_location=map_location) + + @staticmethod + def save_checkpoint(model: Any, filename: str, optimizer: Optional[Any] = None, **kwargs: Any) -> None: + """Save checkpoint to file. + + Args: + model: Model to save + filename: Path to save to + optimizer: Optional optimizer to save + **kwargs: Additional state to save + """ + checkpoint = {"state_dict": model.state_dict()} + if optimizer is not None: + checkpoint["optimizer"] = optimizer.state_dict() + checkpoint.update(kwargs) + torch.save(checkpoint, filename) + + +__all__ = ["CheckpointLoader"] diff --git a/visdet/engine/runner/runner.py b/visdet/engine/runner/runner.py new file mode 100644 index 0000000..4c4769a --- /dev/null +++ b/visdet/engine/runner/runner.py @@ -0,0 +1,35 @@ +# ruff: noqa +""" +Runner class for visdet training. + +This module provides the main training runner. +""" + +from typing import Any, Dict, Optional + + +class Runner: + """Stub runner class for visdet. + + This is a minimal implementation for type checking. 
+    In a full implementation, this would handle training loops, validation, etc.
+    """
+
+    def __init__(self) -> None:
+        """Initialize runner."""
+        self.iter = 0
+        self.epoch = 0
+        self.work_dir = ""
+        self.timestamp = ""
+        self._hooks = []
+
+    def register_hook(self, hook: Any) -> None:
+        """Register a hook.
+
+        Args:
+            hook: Hook instance to register
+        """
+        self._hooks.append(hook)
+
+
+__all__ = ["Runner"]
diff --git a/visdet/engine/structures.py b/visdet/engine/structures.py
new file mode 100644
index 0000000..ee14efc
--- /dev/null
+++ b/visdet/engine/structures.py
@@ -0,0 +1,8 @@
+# ruff: noqa
+"""
+Structures module.
+
+This module provides access to structures functionality for visdet.
+"""
+
+from visdet.engine.structures import *  # noqa: F401, F403
diff --git a/visdet/engine/structures/__init__.py b/visdet/engine/structures/__init__.py
new file mode 100644
index 0000000..71d2d11
--- /dev/null
+++ b/visdet/engine/structures/__init__.py
@@ -0,0 +1,60 @@
+# ruff: noqa
+"""
+Data structures module for visdet.
+
+Provides data container classes for training and inference.
+"""
+
+from typing import Any, Dict
+
+
+class BaseDataElement(dict):
+    """Base data element class for visdet.
+
+    Provides dict-like interface with attribute access.
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize data element."""
+        super().__init__(**kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+        """Get attribute from data."""
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set attribute in data."""
+        self[name] = value
+
+    def cpu(self) -> "BaseDataElement":
+        """Move to CPU (stub implementation)."""
+        return self
+
+    def to(self, device: str) -> "BaseDataElement":
+        """Move to device (stub implementation)."""
+        return self
+
+
+class InstanceData(BaseDataElement):
+    """Instance data container for visdet.
+
+    This is a minimal implementation for type checking.
+    In a full implementation, this would provide structured data access.
+    """
+
+    pass
+
+
+class PixelData(BaseDataElement):
+    """Pixel-level data container for visdet.
+
+    Used for mask, heatmap, and other pixel-level annotations.
+    """
+
+    pass
+
+
+__all__ = ["BaseDataElement", "InstanceData", "PixelData"]
diff --git a/visdet/engine/utils.py b/visdet/engine/utils.py
new file mode 100644
index 0000000..880a00d
--- /dev/null
+++ b/visdet/engine/utils.py
@@ -0,0 +1,31 @@
+# ruff: noqa
+"""
+Utils module.
+
+This module provides access to utils functionality for visdet.
+"""
+
+# Re-export from the utils package
+from .utils import (  # noqa: F401
+    digit_version,
+    to_2tuple,
+    is_str,
+    is_seq_of,
+    is_tuple_of,
+    scandir,
+    slice_list,
+    mkdir_or_exist,
+    is_abs,
+)
+
+__all__ = [
+    "digit_version",
+    "to_2tuple",
+    "is_str",
+    "is_seq_of",
+    "is_tuple_of",
+    "scandir",
+    "slice_list",
+    "mkdir_or_exist",
+    "is_abs",
+]
diff --git a/visdet/engine/utils/__init__.py b/visdet/engine/utils/__init__.py
new file mode 100644
index 0000000..edbe8f3
--- /dev/null
+++ b/visdet/engine/utils/__init__.py
@@ -0,0 +1,148 @@
+# ruff: noqa
+"""
+Utility functions for visdet.
+
+Provides common utility functions for training and inference.
+"""
+
+import os
+import os.path as osp
+from pathlib import Path
+from typing import Any, Iterable, Sequence, Union
+
+
+def digit_version(version_str: str) -> tuple:
+    """Convert version string to a tuple of digits.
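+
+    Example (matches the parsing below; non-numeric strings fall back to
+    ``(0,)``):
+
+        >>> digit_version("2.5.1")
+        (2, 5, 1)
+        >>> digit_version("not-a-version")
+        (0,)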
+
+    Args:
+        version_str: Version string like "1.2.3"
+
+    Returns:
+        Tuple of integers representing the version
+    """
+    try:
+        return tuple(int(d) for d in version_str.split("."))
+    except (ValueError, AttributeError):
+        return (0,)
+
+
+def to_2tuple(x: Union[int, float, Sequence]) -> tuple:
+    """Convert input to a 2-tuple."""
+    if isinstance(x, (tuple, list)):
+        return tuple(x) if len(x) == 2 else (x[0], x[0])
+    return (x, x)
+
+
+def is_str(x: Any) -> bool:
+    """Whether the input is a string instance."""
+    return isinstance(x, str)
+
+
+def is_seq_of(seq: Any, expected_type: type, seq_type: type = None) -> bool:
+    """Check whether it is a sequence of some type.
+
+    Args:
+        seq: The sequence to be checked
+        expected_type: Expected type of sequence items
+        seq_type: Expected sequence type, defaults to (tuple, list)
+
+    Returns:
+        True if all items have expected_type
+    """
+    if seq_type is None:
+        seq_type = (tuple, list)
+    if not isinstance(seq, seq_type):
+        return False
+    for item in seq:
+        if not isinstance(item, expected_type):
+            return False
+    return True
+
+
+def is_tuple_of(seq: Any, expected_type: type) -> bool:
+    """Check whether it is a tuple of some type.
+
+    Args:
+        seq: The sequence to be checked
+        expected_type: Expected type of sequence items
+
+    Returns:
+        True if it is a tuple and all items have expected_type
+    """
+    return is_seq_of(seq, expected_type, seq_type=tuple)
+
+
+def scandir(path: str) -> Iterable[str]:
+    """Scan a directory and yield the path of each entry it contains.
+
+    Args:
+        path: Path to the directory
+
+    Yields:
+        Path to each file in the directory
+    """
+    for entry in os.scandir(path):
+        yield entry.path
+
+
+def slice_list(in_list: list, lens: Iterable) -> list:
+    """Slice a list into several sub lists by the given lengths.
+
+    Args:
+        in_list: The list to slice
+        lens: The slice lengths
+
+    Returns:
+        List of sliced lists
+    """
+    # A plain int is not Iterable, so this also rejects e.g. lens=3.
+    if not isinstance(lens, Iterable):
+        raise TypeError(f"lens must be an iterable, but got {type(lens)}")
+
+    lens = list(lens)
+    if sum(lens) != len(in_list):
+        raise ValueError("sum(lens) and the length of in_list do not match")
+
+    result = []
+    idx = 0
+    for length in lens:
+        result.append(in_list[idx : idx + length])
+        idx += length
+    return result
+
+
+def mkdir_or_exist(dir_name: str) -> None:
+    """Make a directory or check if it exists.
+
+    Args:
+        dir_name: Directory name
+    """
+    if not osp.exists(dir_name):
+        os.makedirs(dir_name, exist_ok=True)
+
+
+def is_abs(path: str) -> bool:
+    """Check whether the path is absolute.
+
+    Args:
+        path: Path string
+
+    Returns:
+        True if path is absolute
+    """
+    return osp.isabs(path) or Path(path).is_absolute()
+
+
+__all__ = [
+    "digit_version",
+    "to_2tuple",
+    "is_str",
+    "is_seq_of",
+    "is_tuple_of",
+    "scandir",
+    "slice_list",
+    "mkdir_or_exist",
+    "is_abs",
+]
diff --git a/visdet/engine/visualization.py b/visdet/engine/visualization.py
new file mode 100644
index 0000000..e7f31a8
--- /dev/null
+++ b/visdet/engine/visualization.py
@@ -0,0 +1,141 @@
+# ruff: noqa
+"""
+Visualization module.
+
+This module provides access to visualization functionality for visdet.
+"""
+
+from typing import Any, Dict, Optional, Union
+import numpy as np
+
+
+class Visualizer:
+    """Stub visualizer class for visdet.
+
+    This is a minimal implementation for type checking.
+    In a full implementation, this would provide comprehensive visualization features.
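+
+    A minimal usage sketch (hypothetical; it exercises only the stub methods
+    defined below):
+
+        vis = Visualizer.get_current_instance()
+        vis.set_image(np.zeros((480, 640, 3), dtype=np.uint8))
+        canvas = vis.get_image()  # returns the image that was set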
+ """ + + _instance: Optional["Visualizer"] = None + + def __init__(self, name: str = "visdet") -> None: + """Initialize visualizer. + + Args: + name: Visualizer name + """ + self.name = name + self._image = None + self._vis_backends = {} + + @classmethod + def get_current_instance(cls) -> Optional["Visualizer"]: + """Get current visualizer instance. + + Returns: + Current visualizer instance or None + """ + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def set_image(self, image: np.ndarray) -> None: + """Set image for visualization. + + Args: + image: Image array + """ + self._image = image + + def get_image(self) -> np.ndarray: + """Get current image. + + Returns: + Current image array + """ + if self._image is None: + return np.zeros((100, 100, 3), dtype=np.uint8) + return self._image + + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: Any = None, + show: bool = False, + wait_time: float = 0, + pred_score_thr: float = 0.3, + step: int = 0, + out_file: Optional[str] = None, + ) -> None: + """Add a data sample visualization. + + Args: + name: Name of the visualization + image: Image array + data_sample: Data sample to visualize + show: Whether to show the visualization + wait_time: Time to wait (for show mode) + pred_score_thr: Prediction score threshold + step: Step number + out_file: Output file path + """ + self.set_image(image) + + def draw_bboxes( + self, + bboxes: Union[np.ndarray, list], + edge_colors: Union[tuple, str] = "green", + face_colors: Union[tuple, str] = "green", + alpha: float = 0.8, + **kwargs, + ) -> None: + """Draw bounding boxes on the image. + + Args: + bboxes: Bounding boxes array + edge_colors: Edge colors + face_colors: Face colors + alpha: Alpha value for transparency + """ + pass + + def draw_texts( + self, + texts: Union[str, list], + positions: Union[np.ndarray, list], + colors: Union[tuple, str] = "white", + font_sizes: Union[int, list] = 13, + font_families: str = "sans-serif", + bboxes: Optional[list] = None, + **kwargs, + ) -> None: + """Draw text on the image. + + Args: + texts: Text(s) to draw + positions: Position(s) for text + colors: Text color(s) + font_sizes: Font size(s) + font_families: Font family + bboxes: Optional bounding box configs + """ + pass + + def show( + self, + drawn_img: np.ndarray, + win_name: str = "image", + wait_time: float = 0, + ) -> None: + """Show the image. 
+
+        Args:
+            drawn_img: Image to show
+            win_name: Window name
+            wait_time: Time to wait
+        """
+        pass
+
+
+__all__ = ["Visualizer"]
diff --git a/visdet/visdet/evaluation/__init__.py b/visdet/evaluation/__init__.py
similarity index 100%
rename from visdet/visdet/evaluation/__init__.py
rename to visdet/evaluation/__init__.py
diff --git a/visdet/visdet/evaluation/evaluator/__init__.py b/visdet/evaluation/evaluator/__init__.py
similarity index 100%
rename from visdet/visdet/evaluation/evaluator/__init__.py
rename to visdet/evaluation/evaluator/__init__.py
diff --git a/visdet/visdet/evaluation/functional/__init__.py b/visdet/evaluation/functional/__init__.py
similarity index 100%
rename from visdet/visdet/evaluation/functional/__init__.py
rename to visdet/evaluation/functional/__init__.py
diff --git a/visdet/visdet/evaluation/functional/bbox_overlaps.py b/visdet/evaluation/functional/bbox_overlaps.py
similarity index 100%
rename from visdet/visdet/evaluation/functional/bbox_overlaps.py
rename to visdet/evaluation/functional/bbox_overlaps.py
diff --git a/visdet/visdet/evaluation/functional/class_names.py b/visdet/evaluation/functional/class_names.py
similarity index 100%
rename from visdet/visdet/evaluation/functional/class_names.py
rename to visdet/evaluation/functional/class_names.py
diff --git a/visdet/visdet/evaluation/functional/mean_ap.py b/visdet/evaluation/functional/mean_ap.py
similarity index 100%
rename from visdet/visdet/evaluation/functional/mean_ap.py
rename to visdet/evaluation/functional/mean_ap.py
diff --git a/visdet/visdet/evaluation/functional/panoptic_utils.py b/visdet/evaluation/functional/panoptic_utils.py
similarity index 100%
rename from visdet/visdet/evaluation/functional/panoptic_utils.py
rename to visdet/evaluation/functional/panoptic_utils.py
diff --git a/visdet/visdet/evaluation/functional/recall.py b/visdet/evaluation/functional/recall.py
similarity index 100%
rename from visdet/visdet/evaluation/functional/recall.py
rename to visdet/evaluation/functional/recall.py
diff --git a/visdet/visdet/evaluation/metrics/__init__.py b/visdet/evaluation/metrics/__init__.py
similarity index 100%
rename from visdet/visdet/evaluation/metrics/__init__.py
rename to visdet/evaluation/metrics/__init__.py
diff --git a/visdet/visdet/evaluation/metrics/coco_metric.py b/visdet/evaluation/metrics/coco_metric.py
similarity index 100%
rename from visdet/visdet/evaluation/metrics/coco_metric.py
rename to visdet/evaluation/metrics/coco_metric.py
diff --git a/visdet/visdet/models/__init__.py b/visdet/models/__init__.py
similarity index 100%
rename from visdet/visdet/models/__init__.py
rename to visdet/models/__init__.py
diff --git a/visdet/models/__pycache__/__init__.cpython-312.pyc b/visdet/models/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..85b26a7
Binary files /dev/null and b/visdet/models/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/visdet/models/backbones/__init__.py b/visdet/models/backbones/__init__.py
similarity index 100%
rename from visdet/visdet/models/backbones/__init__.py
rename to visdet/models/backbones/__init__.py
diff --git a/visdet/models/backbones/__pycache__/__init__.cpython-312.pyc b/visdet/models/backbones/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..b5d09f1
Binary files /dev/null and b/visdet/models/backbones/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/visdet/models/backbones/__pycache__/swin.cpython-312.pyc b/visdet/models/backbones/__pycache__/swin.cpython-312.pyc
similarity index 72%
rename from visdet/visdet/models/backbones/__pycache__/swin.cpython-312.pyc
rename to visdet/models/backbones/__pycache__/swin.cpython-312.pyc
index d338595..dddce77 100644
Binary files a/visdet/visdet/models/backbones/__pycache__/swin.cpython-312.pyc and b/visdet/models/backbones/__pycache__/swin.cpython-312.pyc differ
diff --git a/visdet/visdet/models/backbones/swin.py b/visdet/models/backbones/swin.py
similarity index 99%
rename from visdet/visdet/models/backbones/swin.py
rename to visdet/models/backbones/swin.py
index 4cb4cac..0ab94db 100644
--- a/visdet/visdet/models/backbones/swin.py
+++ b/visdet/models/backbones/swin.py
@@ -19,6 +19,12 @@
 from ..layers import PatchEmbed, PatchMerging
 
+# Optional flash attention function (available if flash_attn is installed)
+try:
+    from flash_attn import flash_attn_func as flash_swin_attn_func  # type: ignore
+except ImportError:
+    flash_swin_attn_func = None  # type: ignore
+
 
 class WindowMSA(BaseModule):
     """Window based multi-head self-attention (W-MSA) module with relative
diff --git a/visdet/visdet/models/data_preprocessors/__init__.py b/visdet/models/data_preprocessors/__init__.py
similarity index 100%
rename from visdet/visdet/models/data_preprocessors/__init__.py
rename to visdet/models/data_preprocessors/__init__.py
diff --git a/visdet/visdet/models/data_preprocessors/data_preprocessor.py b/visdet/models/data_preprocessors/data_preprocessor.py
similarity index 100%
rename from visdet/visdet/models/data_preprocessors/data_preprocessor.py
rename to visdet/models/data_preprocessors/data_preprocessor.py
diff --git a/visdet/visdet/models/dense_heads/__init__.py b/visdet/models/dense_heads/__init__.py
similarity index 100%
rename from visdet/visdet/models/dense_heads/__init__.py
rename to visdet/models/dense_heads/__init__.py
diff --git a/visdet/visdet/models/dense_heads/anchor_head.py b/visdet/models/dense_heads/anchor_head.py
similarity index 100%
rename from visdet/visdet/models/dense_heads/anchor_head.py
rename to visdet/models/dense_heads/anchor_head.py
diff --git a/visdet/visdet/models/dense_heads/base_dense_head.py b/visdet/models/dense_heads/base_dense_head.py
similarity index 100%
rename from visdet/visdet/models/dense_heads/base_dense_head.py
rename to visdet/models/dense_heads/base_dense_head.py
diff --git a/visdet/visdet/models/dense_heads/rpn_head.py b/visdet/models/dense_heads/rpn_head.py
similarity index 100%
rename from visdet/visdet/models/dense_heads/rpn_head.py
rename to visdet/models/dense_heads/rpn_head.py
diff --git a/visdet/visdet/models/detectors/__init__.py b/visdet/models/detectors/__init__.py
similarity index 100%
rename from visdet/visdet/models/detectors/__init__.py
rename to visdet/models/detectors/__init__.py
diff --git a/visdet/visdet/models/detectors/base.py b/visdet/models/detectors/base.py
similarity index 100%
rename from visdet/visdet/models/detectors/base.py
rename to visdet/models/detectors/base.py
diff --git a/visdet/visdet/models/detectors/cascade_rcnn.py b/visdet/models/detectors/cascade_rcnn.py
similarity index 100%
rename from visdet/visdet/models/detectors/cascade_rcnn.py
rename to visdet/models/detectors/cascade_rcnn.py
diff --git a/visdet/visdet/models/detectors/mask_rcnn.py b/visdet/models/detectors/mask_rcnn.py
similarity index 100%
rename from visdet/visdet/models/detectors/mask_rcnn.py
rename to visdet/models/detectors/mask_rcnn.py
diff --git a/visdet/visdet/models/detectors/two_stage.py b/visdet/models/detectors/two_stage.py
similarity index 100%
rename from visdet/visdet/models/detectors/two_stage.py
rename to visdet/models/detectors/two_stage.py
diff --git a/visdet/visdet/models/layers/__init__.py b/visdet/models/layers/__init__.py
similarity index 100%
rename from visdet/visdet/models/layers/__init__.py
rename to visdet/models/layers/__init__.py
diff --git a/visdet/visdet/models/layers/__pycache__/__init__.cpython-312.pyc b/visdet/models/layers/__pycache__/__init__.cpython-312.pyc
similarity index 94%
rename from visdet/visdet/models/layers/__pycache__/__init__.cpython-312.pyc
rename to visdet/models/layers/__pycache__/__init__.cpython-312.pyc
index dec8757..43e8850 100644
Binary files a/visdet/visdet/models/layers/__pycache__/__init__.cpython-312.pyc and b/visdet/models/layers/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/visdet/models/layers/__pycache__/bbox_nms.cpython-312.pyc b/visdet/models/layers/__pycache__/bbox_nms.cpython-312.pyc
similarity index 93%
rename from visdet/visdet/models/layers/__pycache__/bbox_nms.cpython-312.pyc
rename to visdet/models/layers/__pycache__/bbox_nms.cpython-312.pyc
index 8eac021..f378dc1 100644
Binary files a/visdet/visdet/models/layers/__pycache__/bbox_nms.cpython-312.pyc and b/visdet/models/layers/__pycache__/bbox_nms.cpython-312.pyc differ
diff --git a/visdet/visdet/models/layers/bbox_nms.py b/visdet/models/layers/bbox_nms.py
similarity index 100%
rename from visdet/visdet/models/layers/bbox_nms.py
rename to visdet/models/layers/bbox_nms.py
diff --git a/visdet/visdet/models/layers/normed_predictor.py b/visdet/models/layers/normed_predictor.py
similarity index 100%
rename from visdet/visdet/models/layers/normed_predictor.py
rename to visdet/models/layers/normed_predictor.py
diff --git a/visdet/visdet/models/losses/__init__.py b/visdet/models/losses/__init__.py
similarity index 100%
rename from visdet/visdet/models/losses/__init__.py
rename to visdet/models/losses/__init__.py
diff --git a/visdet/visdet/models/losses/accuracy.py b/visdet/models/losses/accuracy.py
similarity index 100%
rename from visdet/visdet/models/losses/accuracy.py
rename to visdet/models/losses/accuracy.py
diff --git a/visdet/visdet/models/losses/cross_entropy_loss.py b/visdet/models/losses/cross_entropy_loss.py
similarity index 100%
rename from visdet/visdet/models/losses/cross_entropy_loss.py
rename to visdet/models/losses/cross_entropy_loss.py
diff --git a/visdet/visdet/models/losses/smooth_l1_loss.py b/visdet/models/losses/smooth_l1_loss.py
similarity index 100%
rename from visdet/visdet/models/losses/smooth_l1_loss.py
rename to visdet/models/losses/smooth_l1_loss.py
diff --git a/visdet/visdet/models/losses/utils.py b/visdet/models/losses/utils.py
similarity index 100%
rename from visdet/visdet/models/losses/utils.py
rename to visdet/models/losses/utils.py
diff --git a/visdet/visdet/models/necks/__init__.py b/visdet/models/necks/__init__.py
similarity index 100%
rename from visdet/visdet/models/necks/__init__.py
rename to visdet/models/necks/__init__.py
diff --git a/visdet/visdet/models/necks/fpn.py b/visdet/models/necks/fpn.py
similarity index 100%
rename from visdet/visdet/models/necks/fpn.py
rename to visdet/models/necks/fpn.py
diff --git a/visdet/visdet/models/roi_heads/__init__.py b/visdet/models/roi_heads/__init__.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/__init__.py
rename to visdet/models/roi_heads/__init__.py
diff --git a/visdet/visdet/models/roi_heads/base_roi_head.py b/visdet/models/roi_heads/base_roi_head.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/base_roi_head.py
rename to visdet/models/roi_heads/base_roi_head.py
diff --git a/visdet/visdet/models/roi_heads/bbox_heads/__init__.py b/visdet/models/roi_heads/bbox_heads/__init__.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/bbox_heads/__init__.py
rename to visdet/models/roi_heads/bbox_heads/__init__.py
diff --git a/visdet/visdet/models/roi_heads/bbox_heads/bbox_head.py b/visdet/models/roi_heads/bbox_heads/bbox_head.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/bbox_heads/bbox_head.py
rename to visdet/models/roi_heads/bbox_heads/bbox_head.py
diff --git a/visdet/visdet/models/roi_heads/bbox_heads/convfc_bbox_head.py b/visdet/models/roi_heads/bbox_heads/convfc_bbox_head.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/bbox_heads/convfc_bbox_head.py
rename to visdet/models/roi_heads/bbox_heads/convfc_bbox_head.py
diff --git a/visdet/visdet/models/roi_heads/cascade_roi_head.py b/visdet/models/roi_heads/cascade_roi_head.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/cascade_roi_head.py
rename to visdet/models/roi_heads/cascade_roi_head.py
diff --git a/visdet/visdet/models/roi_heads/mask_heads/__init__.py b/visdet/models/roi_heads/mask_heads/__init__.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/mask_heads/__init__.py
rename to visdet/models/roi_heads/mask_heads/__init__.py
diff --git a/visdet/visdet/models/roi_heads/mask_heads/fcn_mask_head.py b/visdet/models/roi_heads/mask_heads/fcn_mask_head.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/mask_heads/fcn_mask_head.py
rename to visdet/models/roi_heads/mask_heads/fcn_mask_head.py
diff --git a/visdet/visdet/models/roi_heads/roi_extractors/__init__.py b/visdet/models/roi_heads/roi_extractors/__init__.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/roi_extractors/__init__.py
rename to visdet/models/roi_heads/roi_extractors/__init__.py
diff --git a/visdet/visdet/models/roi_heads/roi_extractors/base_roi_extractor.py b/visdet/models/roi_heads/roi_extractors/base_roi_extractor.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/roi_extractors/base_roi_extractor.py
rename to visdet/models/roi_heads/roi_extractors/base_roi_extractor.py
diff --git a/visdet/visdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/visdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
rename to visdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
diff --git a/visdet/visdet/models/roi_heads/standard_roi_head.py b/visdet/models/roi_heads/standard_roi_head.py
similarity index 100%
rename from visdet/visdet/models/roi_heads/standard_roi_head.py
rename to visdet/models/roi_heads/standard_roi_head.py
diff --git a/visdet/visdet/models/task_modules/__init__.py b/visdet/models/task_modules/__init__.py
similarity index 100%
rename from visdet/visdet/models/task_modules/__init__.py
rename to visdet/models/task_modules/__init__.py
diff --git a/visdet/visdet/models/task_modules/assigners/__init__.py b/visdet/models/task_modules/assigners/__init__.py
similarity index 100%
rename from visdet/visdet/models/task_modules/assigners/__init__.py
rename to visdet/models/task_modules/assigners/__init__.py
diff --git a/visdet/visdet/models/task_modules/assigners/assign_result.py b/visdet/models/task_modules/assigners/assign_result.py
similarity index 100%
rename from visdet/visdet/models/task_modules/assigners/assign_result.py
rename to visdet/models/task_modules/assigners/assign_result.py
diff --git a/visdet/visdet/models/task_modules/assigners/base_assigner.py b/visdet/models/task_modules/assigners/base_assigner.py
similarity index 100%
rename from visdet/visdet/models/task_modules/assigners/base_assigner.py
rename to visdet/models/task_modules/assigners/base_assigner.py
diff --git a/visdet/visdet/models/task_modules/assigners/iou2d_calculator.py b/visdet/models/task_modules/assigners/iou2d_calculator.py
similarity index 100%
rename from visdet/visdet/models/task_modules/assigners/iou2d_calculator.py
rename to visdet/models/task_modules/assigners/iou2d_calculator.py
diff --git a/visdet/visdet/models/task_modules/assigners/max_iou_assigner.py b/visdet/models/task_modules/assigners/max_iou_assigner.py
similarity index 100%
rename from visdet/visdet/models/task_modules/assigners/max_iou_assigner.py
rename to visdet/models/task_modules/assigners/max_iou_assigner.py
diff --git a/visdet/visdet/models/task_modules/coders/__init__.py b/visdet/models/task_modules/coders/__init__.py
similarity index 100%
rename from visdet/visdet/models/task_modules/coders/__init__.py
rename to visdet/models/task_modules/coders/__init__.py
diff --git a/visdet/visdet/models/task_modules/prior_generators/__init__.py b/visdet/models/task_modules/prior_generators/__init__.py
similarity index 100%
rename from visdet/visdet/models/task_modules/prior_generators/__init__.py
rename to visdet/models/task_modules/prior_generators/__init__.py
diff --git a/visdet/visdet/models/task_modules/prior_generators/anchor_generator.py b/visdet/models/task_modules/prior_generators/anchor_generator.py
similarity index 100%
rename from visdet/visdet/models/task_modules/prior_generators/anchor_generator.py
rename to visdet/models/task_modules/prior_generators/anchor_generator.py
diff --git a/visdet/visdet/models/task_modules/samplers/__init__.py b/visdet/models/task_modules/samplers/__init__.py
similarity index 100%
rename from visdet/visdet/models/task_modules/samplers/__init__.py
rename to visdet/models/task_modules/samplers/__init__.py
diff --git a/visdet/visdet/models/test_time_augs/__init__.py b/visdet/models/test_time_augs/__init__.py
similarity index 100%
rename from visdet/visdet/models/test_time_augs/__init__.py
rename to visdet/models/test_time_augs/__init__.py
diff --git a/visdet/visdet/models/test_time_augs/merge_augs.py b/visdet/models/test_time_augs/merge_augs.py
similarity index 100%
rename from visdet/visdet/models/test_time_augs/merge_augs.py
rename to visdet/models/test_time_augs/merge_augs.py
diff --git a/visdet/visdet/models/utils.py b/visdet/models/utils.py
similarity index 100%
rename from visdet/visdet/models/utils.py
rename to visdet/models/utils.py
diff --git a/visdet/visdet/models/utils/__init__.py b/visdet/models/utils/__init__.py
similarity index 100%
rename from visdet/visdet/models/utils/__init__.py
rename to visdet/models/utils/__init__.py
diff --git a/visdet/visdet/models/utils/gaussian_target.py b/visdet/models/utils/gaussian_target.py
similarity index 100%
rename from visdet/visdet/models/utils/gaussian_target.py
rename to visdet/models/utils/gaussian_target.py
diff --git a/visdet/visdet/models/utils/image.py b/visdet/models/utils/image.py
similarity index 100%
rename from visdet/visdet/models/utils/image.py
rename to visdet/models/utils/image.py
diff --git a/visdet/visdet/models/utils/make_divisible.py b/visdet/models/utils/make_divisible.py
similarity index 100%
rename from visdet/visdet/models/utils/make_divisible.py
rename to visdet/models/utils/make_divisible.py
diff --git a/visdet/visdet/models/utils/misc.py b/visdet/models/utils/misc.py
similarity index 100%
rename from visdet/visdet/models/utils/misc.py
rename to visdet/models/utils/misc.py
diff --git a/visdet/visdet/models/utils/panoptic_gt_processing.py b/visdet/models/utils/panoptic_gt_processing.py
similarity index 100%
rename from visdet/visdet/models/utils/panoptic_gt_processing.py
rename to visdet/models/utils/panoptic_gt_processing.py
diff --git a/visdet/visdet/models/utils/point_sample.py b/visdet/models/utils/point_sample.py
similarity index 100%
rename from visdet/visdet/models/utils/point_sample.py
rename to visdet/models/utils/point_sample.py
diff --git a/visdet/visdet/models/utils/vlfuse_helper.py b/visdet/models/utils/vlfuse_helper.py
similarity index 100%
rename from visdet/visdet/models/utils/vlfuse_helper.py
rename to visdet/models/utils/vlfuse_helper.py
diff --git a/visdet/visdet/models/utils/wbf.py b/visdet/models/utils/wbf.py
similarity index 100%
rename from visdet/visdet/models/utils/wbf.py
rename to visdet/models/utils/wbf.py
diff --git a/visdet/visdet/presets/__init__.py b/visdet/presets/__init__.py
similarity index 100%
rename from visdet/visdet/presets/__init__.py
rename to visdet/presets/__init__.py
diff --git a/visdet/visdet/presets/registry.py b/visdet/presets/registry.py
similarity index 100%
rename from visdet/visdet/presets/registry.py
rename to visdet/presets/registry.py
diff --git a/visdet/pyproject.toml b/visdet/pyproject.toml
index fce94db..21bd4cb 100644
--- a/visdet/pyproject.toml
+++ b/visdet/pyproject.toml
@@ -5,8 +5,6 @@ description = "Swin Mask R-CNN for object detection and instance segmentation"
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-    "visengine>=0.1.0",
-    "viscv>=0.1.0",
     "numpy>=2.0.0",
     "torch==2.5.1; platform_system == 'Linux' and platform_machine == 'x86_64'",
     "torchvision==0.20.1; platform_system == 'Linux' and platform_machine == 'x86_64'",
@@ -29,7 +27,3 @@ include = ["visdet*"]
 
 [tool.setuptools.package-data]
 '*' = ['*.yaml', '*.json', '*.yml']
-
-[tool.uv.sources]
-visengine = { workspace = true }
-viscv = { workspace = true }
diff --git a/visdet/visdet/registry.py b/visdet/registry.py
similarity index 100%
rename from visdet/visdet/registry.py
rename to visdet/registry.py
diff --git a/visdet/visdet/runner.py b/visdet/runner.py
similarity index 100%
rename from visdet/visdet/runner.py
rename to visdet/runner.py
diff --git a/visdet/visdet/structures/__init__.py b/visdet/structures/__init__.py
similarity index 100%
rename from visdet/visdet/structures/__init__.py
rename to visdet/structures/__init__.py
diff --git a/visdet/structures/__pycache__/__init__.cpython-312.pyc b/visdet/structures/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..7b5e557
Binary files /dev/null and b/visdet/structures/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/visdet/structures/__pycache__/det_data_sample.cpython-312.pyc b/visdet/structures/__pycache__/det_data_sample.cpython-312.pyc
similarity index 71%
rename from visdet/visdet/structures/__pycache__/det_data_sample.cpython-312.pyc
rename to visdet/structures/__pycache__/det_data_sample.cpython-312.pyc
index 35869e3..6388f4d 100644
Binary files a/visdet/visdet/structures/__pycache__/det_data_sample.cpython-312.pyc and b/visdet/structures/__pycache__/det_data_sample.cpython-312.pyc differ
diff --git a/visdet/visdet/structures/bbox/__init__.py b/visdet/structures/bbox/__init__.py
similarity index 100%
rename from visdet/visdet/structures/bbox/__init__.py
rename to visdet/structures/bbox/__init__.py
diff --git a/visdet/visdet/structures/bbox/__pycache__/__init__.cpython-312.pyc b/visdet/structures/bbox/__pycache__/__init__.cpython-312.pyc
similarity index 79%
rename from visdet/visdet/structures/bbox/__pycache__/__init__.cpython-312.pyc
rename to visdet/structures/bbox/__pycache__/__init__.cpython-312.pyc
index 6bf65e4..86cfbf9 100644
Binary files a/visdet/visdet/structures/bbox/__pycache__/__init__.cpython-312.pyc and b/visdet/structures/bbox/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/visdet/structures/bbox/__pycache__/base_boxes.cpython-312.pyc b/visdet/structures/bbox/__pycache__/base_boxes.cpython-312.pyc
similarity index 99%
rename from visdet/visdet/structures/bbox/__pycache__/base_boxes.cpython-312.pyc
rename to visdet/structures/bbox/__pycache__/base_boxes.cpython-312.pyc
index 67cecdc..af31b58 100644
Binary files a/visdet/visdet/structures/bbox/__pycache__/base_boxes.cpython-312.pyc and b/visdet/structures/bbox/__pycache__/base_boxes.cpython-312.pyc differ
diff --git a/visdet/visdet/structures/bbox/base_boxes.py b/visdet/structures/bbox/base_boxes.py
similarity index 100%
rename from visdet/visdet/structures/bbox/base_boxes.py
rename to visdet/structures/bbox/base_boxes.py
diff --git a/visdet/visdet/structures/bbox/bbox_overlaps.py b/visdet/structures/bbox/bbox_overlaps.py
similarity index 100%
rename from visdet/visdet/structures/bbox/bbox_overlaps.py
rename to visdet/structures/bbox/bbox_overlaps.py
diff --git a/visdet/visdet/structures/bbox/bbox_overlaps.py.backup b/visdet/structures/bbox/bbox_overlaps.py.backup
similarity index 100%
rename from visdet/visdet/structures/bbox/bbox_overlaps.py.backup
rename to visdet/structures/bbox/bbox_overlaps.py.backup
diff --git a/visdet/visdet/structures/bbox/box_type.py b/visdet/structures/bbox/box_type.py
similarity index 100%
rename from visdet/visdet/structures/bbox/box_type.py
rename to visdet/structures/bbox/box_type.py
diff --git a/visdet/visdet/structures/bbox/coders/base_bbox_coder.py b/visdet/structures/bbox/coders/base_bbox_coder.py
similarity index 100%
rename from visdet/visdet/structures/bbox/coders/base_bbox_coder.py
rename to visdet/structures/bbox/coders/base_bbox_coder.py
diff --git a/visdet/visdet/structures/bbox/horizontal_boxes.py b/visdet/structures/bbox/horizontal_boxes.py
similarity index 100%
rename from visdet/visdet/structures/bbox/horizontal_boxes.py
rename to visdet/structures/bbox/horizontal_boxes.py
diff --git a/visdet/visdet/structures/bbox/transforms.py b/visdet/structures/bbox/transforms.py
similarity index 100%
rename from visdet/visdet/structures/bbox/transforms.py
rename to visdet/structures/bbox/transforms.py
diff --git a/visdet/visdet/structures/det_data_sample.py b/visdet/structures/det_data_sample.py
similarity index 100%
rename from visdet/visdet/structures/det_data_sample.py
rename to visdet/structures/det_data_sample.py
diff --git a/visdet/visdet/structures/mask/__init__.py b/visdet/structures/mask/__init__.py
similarity index 100%
rename from visdet/visdet/structures/mask/__init__.py
rename to visdet/structures/mask/__init__.py
diff --git a/visdet/visdet/structures/mask/__pycache__/__init__.cpython-312.pyc b/visdet/structures/mask/__pycache__/__init__.cpython-312.pyc
similarity index 68%
rename from visdet/visdet/structures/mask/__pycache__/__init__.cpython-312.pyc
rename to visdet/structures/mask/__pycache__/__init__.cpython-312.pyc
index 42e8bda..50ac3c0 100644
Binary files a/visdet/visdet/structures/mask/__pycache__/__init__.cpython-312.pyc and b/visdet/structures/mask/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/visdet/structures/mask/__pycache__/mask_target.cpython-312.pyc b/visdet/structures/mask/__pycache__/mask_target.cpython-312.pyc
similarity index 95%
rename from visdet/visdet/structures/mask/__pycache__/mask_target.cpython-312.pyc
rename to visdet/structures/mask/__pycache__/mask_target.cpython-312.pyc
index 9cb76c4..5f9c5c2 100644
Binary files a/visdet/visdet/structures/mask/__pycache__/mask_target.cpython-312.pyc and b/visdet/structures/mask/__pycache__/mask_target.cpython-312.pyc differ
diff --git a/visdet/visdet/structures/mask/__pycache__/structures.cpython-312.pyc b/visdet/structures/mask/__pycache__/structures.cpython-312.pyc
similarity index 99%
rename from visdet/visdet/structures/mask/__pycache__/structures.cpython-312.pyc
rename to visdet/structures/mask/__pycache__/structures.cpython-312.pyc
index 6922559..5754314 100644
Binary files a/visdet/visdet/structures/mask/__pycache__/structures.cpython-312.pyc and b/visdet/structures/mask/__pycache__/structures.cpython-312.pyc differ
diff --git a/visdet/visdet/structures/mask/mask_target.py b/visdet/structures/mask/mask_target.py
similarity index 100%
rename from visdet/visdet/structures/mask/mask_target.py
rename to visdet/structures/mask/mask_target.py
diff --git a/visdet/visdet/structures/mask/structures.py b/visdet/structures/mask/structures.py
similarity index 100%
rename from visdet/visdet/structures/mask/structures.py
rename to visdet/structures/mask/structures.py
diff --git a/visdet/visdet/structures/mask/utils.py b/visdet/structures/mask/utils.py
similarity index 100%
rename from visdet/visdet/structures/mask/utils.py
rename to visdet/structures/mask/utils.py
diff --git a/visdet/visdet/testing/__init__.py b/visdet/testing/__init__.py
similarity index 100%
rename from visdet/visdet/testing/__init__.py
rename to visdet/testing/__init__.py
diff --git a/visdet/visdet/testing/_utils.py b/visdet/testing/_utils.py
similarity index 100%
rename from visdet/visdet/testing/_utils.py
rename to visdet/testing/_utils.py
diff --git a/visdet/tests/__pycache__/__init__.cpython-312.pyc b/visdet/tests/__pycache__/__init__.cpython-312.pyc
index d70b796..f492d45 100644
Binary files a/visdet/tests/__pycache__/__init__.cpython-312.pyc and b/visdet/tests/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/tests/__pycache__/__init__.cpython-39.pyc b/visdet/tests/__pycache__/__init__.cpython-39.pyc
index 6b5a373..4ea0dda 100644
Binary files a/visdet/tests/__pycache__/__init__.cpython-39.pyc and b/visdet/tests/__pycache__/__init__.cpython-39.pyc differ
diff --git a/visdet/tests/__pycache__/test_patch_merging_mmcv_compatibility.cpython-39-pytest-8.3.5.pyc b/visdet/tests/__pycache__/test_patch_merging_mmcv_compatibility.cpython-39-pytest-8.3.5.pyc
index 204b38d..cfdf01e 100644
Binary files a/visdet/tests/__pycache__/test_patch_merging_mmcv_compatibility.cpython-39-pytest-8.3.5.pyc and b/visdet/tests/__pycache__/test_patch_merging_mmcv_compatibility.cpython-39-pytest-8.3.5.pyc differ
diff --git a/visdet/tests/test_models/__pycache__/__init__.cpython-312.pyc b/visdet/tests/test_models/__pycache__/__init__.cpython-312.pyc
index d3f445c..2988892 100644
Binary files a/visdet/tests/test_models/__pycache__/__init__.cpython-312.pyc and b/visdet/tests/test_models/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/tests/test_models/__pycache__/__init__.cpython-39.pyc b/visdet/tests/test_models/__pycache__/__init__.cpython-39.pyc
index bef199e..3006f12 100644
Binary files a/visdet/tests/test_models/__pycache__/__init__.cpython-39.pyc and b/visdet/tests/test_models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-312.pyc b/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-312.pyc
index 032cef6..80d8333 100644
Binary files a/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-312.pyc and b/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-39.pyc b/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-39.pyc
index 11a9b17..47e728c 100644
Binary files a/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-39.pyc and b/visdet/tests/test_models/test_backbones/__pycache__/__init__.cpython-39.pyc differ
diff --git a/visdet/tests/test_models/test_backbones/__pycache__/test_swin.cpython-39-pytest-8.3.5.pyc b/visdet/tests/test_models/test_backbones/__pycache__/test_swin.cpython-39-pytest-8.3.5.pyc
index e543dac..027f961 100644
Binary files a/visdet/tests/test_models/test_backbones/__pycache__/test_swin.cpython-39-pytest-8.3.5.pyc and b/visdet/tests/test_models/test_backbones/__pycache__/test_swin.cpython-39-pytest-8.3.5.pyc differ
diff --git a/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-312.pyc b/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-312.pyc
index e186bae..f9f9cf5 100644
Binary files a/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-312.pyc and b/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-39.pyc b/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-39.pyc
index 2c23ed5..bbedf28 100644
Binary files a/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-39.pyc and b/visdet/tests/test_models/test_detectors/__pycache__/__init__.cpython-39.pyc differ
diff --git a/visdet/tests/test_models/test_detectors/__pycache__/test_mask_rcnn.cpython-39-pytest-8.3.5.pyc b/visdet/tests/test_models/test_detectors/__pycache__/test_mask_rcnn.cpython-39-pytest-8.3.5.pyc
index b93a828..ec3617e 100644
Binary files a/visdet/tests/test_models/test_detectors/__pycache__/test_mask_rcnn.cpython-39-pytest-8.3.5.pyc and b/visdet/tests/test_models/test_detectors/__pycache__/test_mask_rcnn.cpython-39-pytest-8.3.5.pyc differ
diff --git a/visdet/tests/test_models/test_roi_heads/__pycache__/test_bbox_heads.cpython-39-pytest-8.3.5.pyc b/visdet/tests/test_models/test_roi_heads/__pycache__/test_bbox_heads.cpython-39-pytest-8.3.5.pyc
index 80ee9cd..76874d0 100644
Binary files a/visdet/tests/test_models/test_roi_heads/__pycache__/test_bbox_heads.cpython-39-pytest-8.3.5.pyc and b/visdet/tests/test_models/test_roi_heads/__pycache__/test_bbox_heads.cpython-39-pytest-8.3.5.pyc differ
diff --git a/visdet/tests/test_models/test_roi_heads/__pycache__/test_standard_roi_head.cpython-39-pytest-8.3.5.pyc b/visdet/tests/test_models/test_roi_heads/__pycache__/test_standard_roi_head.cpython-39-pytest-8.3.5.pyc
index fb95d99..bbaf5bf 100644
Binary files a/visdet/tests/test_models/test_roi_heads/__pycache__/test_standard_roi_head.cpython-39-pytest-8.3.5.pyc and b/visdet/tests/test_models/test_roi_heads/__pycache__/test_standard_roi_head.cpython-39-pytest-8.3.5.pyc differ
diff --git a/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-312.pyc b/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-312.pyc
index 50e9143..74ff9cf 100644
Binary files a/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-312.pyc and b/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-39.pyc b/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-39.pyc
index 126d672..5795da2 100644
Binary files a/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-39.pyc and b/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/__init__.cpython-39.pyc differ
diff --git a/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/test_fcn_mask_head.cpython-39-pytest-8.3.5.pyc b/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/test_fcn_mask_head.cpython-39-pytest-8.3.5.pyc
index cdf96f8..a9717a7 100644
Binary files a/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/test_fcn_mask_head.cpython-39-pytest-8.3.5.pyc and b/visdet/tests/test_models/test_roi_heads/test_mask_heads/__pycache__/test_fcn_mask_head.cpython-39-pytest-8.3.5.pyc differ
diff --git a/visdet/visdet/utils/__init__.py b/visdet/utils/__init__.py
similarity index 100%
rename from visdet/visdet/utils/__init__.py
rename to visdet/utils/__init__.py
diff --git a/visdet/visdet/utils/misc.py b/visdet/utils/misc.py
similarity index 100%
rename from visdet/visdet/utils/misc.py
rename to visdet/utils/misc.py
diff --git a/visdet/visdet/utils/setup_env.py b/visdet/utils/setup_env.py
similarity index 100%
rename from visdet/visdet/utils/setup_env.py
rename to visdet/utils/setup_env.py
diff --git a/visdet/visdet/utils/typing_utils.py b/visdet/utils/typing_utils.py
similarity index 100%
rename from visdet/visdet/utils/typing_utils.py
rename to visdet/utils/typing_utils.py
diff --git a/visdet/visdet/utils/util_mixins.py b/visdet/utils/util_mixins.py
similarity index 100%
rename from visdet/visdet/utils/util_mixins.py
rename to visdet/utils/util_mixins.py
diff --git a/visdet/visdet/version.py b/visdet/version.py
similarity index 100%
rename from visdet/visdet/version.py
rename to visdet/version.py
diff --git a/visdet/visdet/__pycache__/__init__.cpython-312.pyc b/visdet/visdet/__pycache__/__init__.cpython-312.pyc
index 02f6373..95e65ea 100644
Binary files a/visdet/visdet/__pycache__/__init__.cpython-312.pyc and b/visdet/visdet/__pycache__/__init__.cpython-312.pyc differ
diff --git a/visdet/visdet/__pycache__/registry.cpython-312.pyc b/visdet/visdet/__pycache__/registry.cpython-312.pyc
index efc88a1..bfcd855 100644
Binary files a/visdet/visdet/__pycache__/registry.cpython-312.pyc and b/visdet/visdet/__pycache__/registry.cpython-312.pyc differ
diff --git a/visdet/visdet/__pycache__/version.cpython-312.pyc b/visdet/visdet/__pycache__/version.cpython-312.pyc
index 3f47ccb..2a9a22f 100644
Binary files a/visdet/visdet/__pycache__/version.cpython-312.pyc and b/visdet/visdet/__pycache__/version.cpython-312.pyc differ
diff --git a/visdet/visdet/cv/__init__.py b/visdet/visdet/cv/__init__.py
deleted file mode 100644
index 391db3a..0000000
--- a/visdet/visdet/cv/__init__.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# ruff: noqa
-"""
-Computer Vision utilities (re-export of viscv).
-
-This module provides access to viscv functionality through visdet.cv
-for better namespace organization and discoverability.
-
-Usage:
-    # New preferred way
-    from visdet import cv
-    from visdet.cv import image
-
-    # Legacy way (still works but discouraged inside visdet)
-    import viscv
-    from viscv import image
-
-All viscv functionality is re-exported here for backwards compatibility
-and namespace consistency within the visdet package.
-"""
-
-# 1. Re-export all top-level symbols from the original `viscv` library.
-from viscv import *  # noqa: F401, F403
-
-# 2. Explicitly import submodules using relative paths for clarity and to make
-#    them accessible under the `visdet.cv` namespace (e.g., `visdet.cv.image`).
-from . import cnn, fileio, image, ops, transforms
-
-# 3. Construct `__all__` to control `from visdet.cv import *` behavior.
-#    This combines top-level symbols from viscv with the submodule names.
-try:
-    # Dynamically get __all__ from the original library if it exists.
-    from viscv import __all__ as viscv_all
-except ImportError:
-    # Fallback if viscv.__all__ is not defined.
-    viscv_all = []
-
-# Expose both the re-exported symbols and the submodules.
-__all__ = list(viscv_all) + ["cnn", "fileio", "image", "ops", "transforms"]
diff --git a/visdet/visdet/cv/cnn/__init__.py b/visdet/visdet/cv/cnn/__init__.py
deleted file mode 100644
index 2b81674..0000000
--- a/visdet/visdet/cv/cnn/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of viscv.cnn for dotted import support.
-
-This module allows `from visdet.cv.cnn import X` to work properly.
-"""
-
-from viscv.cnn import *  # noqa: F401, F403
-
-# Preserve the __all__ from upstream if it exists
-try:
-    from visdet.cv.cnn import __all__  # noqa: F401
-except ImportError:
-    pass
diff --git a/visdet/visdet/cv/cnn/bricks/__init__.py b/visdet/visdet/cv/cnn/bricks/__init__.py
deleted file mode 100644
index f81df20..0000000
--- a/visdet/visdet/cv/cnn/bricks/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of viscv.cnn.bricks for dotted import support.
-
-This module allows `from visdet.cv.cnn.bricks import X` to work properly.
-"""
-
-from viscv.cnn.bricks import *  # noqa: F401, F403
-
-# Preserve the __all__ from upstream if it exists
-try:
-    from viscv.cnn.bricks import __all__  # noqa: F401
-except ImportError:
-    pass
diff --git a/visdet/visdet/cv/cnn/bricks/transformer.py b/visdet/visdet/cv/cnn/bricks/transformer.py
deleted file mode 100644
index 44ab37d..0000000
--- a/visdet/visdet/cv/cnn/bricks/transformer.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of viscv.cnn.bricks.transformer for dotted import support.
-
-This module allows `from visdet.cv.cnn.bricks.transformer import X` to work properly.
-"""
-
-from viscv.cnn.bricks.transformer import *  # noqa: F401, F403
-
-# Preserve the __all__ from upstream if it exists
-try:
-    from viscv.cnn.bricks.transformer import __all__  # noqa: F401
-except ImportError:
-    pass
diff --git a/visdet/visdet/cv/fileio.py b/visdet/visdet/cv/fileio.py
deleted file mode 100644
index 5c6bcc7..0000000
--- a/visdet/visdet/cv/fileio.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of viscv.fileio for dotted import support.
-
-This module allows `from visdet.cv.fileio import X` or `import visdet.cv.fileio` to work properly.
-""" - -from viscv.fileio import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from viscv.fileio import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/cv/image.py b/visdet/visdet/cv/image.py deleted file mode 100644 index f734397..0000000 --- a/visdet/visdet/cv/image.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of viscv.image for dotted import support. - -This module allows `from visdet.cv.image import X` to work properly. -""" - -from viscv.image import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.cv.image import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/cv/ops/__init__.py b/visdet/visdet/cv/ops/__init__.py deleted file mode 100644 index 85a8057..0000000 --- a/visdet/visdet/cv/ops/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of viscv.ops for dotted import support. - -This module allows `from visdet.cv.ops import X` to work properly. -""" - -from viscv.ops import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.cv.ops import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/cv/ops/nms.py b/visdet/visdet/cv/ops/nms.py deleted file mode 100644 index c386630..0000000 --- a/visdet/visdet/cv/ops/nms.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of viscv.ops.nms for dotted import support. - -This module allows `from visdet.cv.ops.nms import X` to work properly. -""" - -from viscv.ops.nms import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from viscv.ops.nms import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/cv/transforms.py b/visdet/visdet/cv/transforms.py deleted file mode 100644 index 2df4033..0000000 --- a/visdet/visdet/cv/transforms.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of viscv.transforms for dotted import support. - -This module allows `from visdet.cv.transforms import X` to work properly. -""" - -from viscv.transforms import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.cv.transforms import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/__init__.py b/visdet/visdet/engine/__init__.py deleted file mode 100644 index 1d86e72..0000000 --- a/visdet/visdet/engine/__init__.py +++ /dev/null @@ -1,69 +0,0 @@ -# ruff: noqa -# Copyright (c) OpenMMLab. All rights reserved. -""" -Engine utilities for training and inference. - -This module includes: -1. visdet-specific hooks (from . import hooks) -2. Re-exports of visengine functionality for visdet.engine access - -Usage: - # New preferred way - from visdet import engine - from visdet.engine import Config, Runner - - # Legacy way (still works but discouraged inside visdet) - import visengine - from visengine import Config, Runner - -All visengine functionality is re-exported here for backwards compatibility -and namespace consistency within the visdet package. -""" - -# 1. Re-export all top-level symbols from the original `visengine` library. -from visengine import * # noqa: F401, F403 - -# 2. Explicitly import submodules using relative paths to make them accessible -# under the `visdet.engine` namespace (e.g., `visdet.engine.runner`). -from . 
diff --git a/visdet/visdet/engine/__init__.py b/visdet/visdet/engine/__init__.py
deleted file mode 100644
index 1d86e72..0000000
--- a/visdet/visdet/engine/__init__.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# ruff: noqa
-# Copyright (c) OpenMMLab. All rights reserved.
-"""
-Engine utilities for training and inference.
-
-This module includes:
-1. visdet-specific hooks (from . import hooks)
-2. Re-exports of visengine functionality for visdet.engine access
-
-Usage:
-    # New preferred way
-    from visdet import engine
-    from visdet.engine import Config, Runner
-
-    # Legacy way (still works but discouraged inside visdet)
-    import visengine
-    from visengine import Config, Runner
-
-All visengine functionality is re-exported here for backwards compatibility
-and namespace consistency within the visdet package.
-"""
-
-# 1. Re-export all top-level symbols from the original `visengine` library.
-from visengine import *  # noqa: F401, F403
-
-# 2. Explicitly import submodules using relative paths to make them accessible
-#    under the `visdet.engine` namespace (e.g., `visdet.engine.runner`).
-from . import (
-    config,
-    dataset,
-    dist,
-    evaluator,
-    fileio,
-    infer,
-    logging,
-    model,
-    registry,
-    runner,
-    structures,
-    utils,
-    visualization,
-)
-
-# NOTE: We don't eagerly import `hooks` here to avoid the circular import
-# issue identified during development. It remains accessible via direct import:
-# `from visdet.engine import hooks` or `from visdet.engine.hooks import ...`
-
-# 3. Construct `__all__` to control `from visdet.engine import *` behavior.
-try:
-    from visengine import __all__ as visengine_all
-except ImportError:
-    visengine_all = []
-
-# Expose re-exported symbols and all submodules except 'hooks'.
-__all__ = list(visengine_all) + [
-    "config",
-    "dataset",
-    "dist",
-    "evaluator",
-    "fileio",
-    "infer",
-    "logging",
-    "model",
-    "registry",
-    "runner",
-    "structures",
-    "utils",
-    "visualization",
-]
diff --git a/visdet/visdet/engine/__pycache__/__init__.cpython-312.pyc b/visdet/visdet/engine/__pycache__/__init__.cpython-312.pyc
deleted file mode 100644
index 3a7bc50..0000000
Binary files a/visdet/visdet/engine/__pycache__/__init__.cpython-312.pyc and /dev/null differ
diff --git a/visdet/visdet/engine/config.py b/visdet/visdet/engine/config.py
deleted file mode 100644
index 209b60b..0000000
--- a/visdet/visdet/engine/config.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of visengine.config for dotted import support.
-
-This module allows `from visdet.engine.config import X` to work properly.
-"""
-
-from visengine.config import *  # noqa: F401, F403
-
-# Preserve the __all__ from upstream if it exists
-try:
-    from visengine.config import __all__  # noqa: F401
-except ImportError:
-    pass
diff --git a/visdet/visdet/engine/dataset.py b/visdet/visdet/engine/dataset.py
deleted file mode 100644
index 1bb1bf9..0000000
--- a/visdet/visdet/engine/dataset.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of visengine.dataset for dotted import support.
-
-This module allows `from visdet.engine.dataset import X` to work properly.
-"""
-
-from visengine.dataset import *  # noqa: F401, F403
-
-# Preserve the __all__ from upstream if it exists
-try:
-    from visengine.dataset import __all__  # noqa: F401
-except ImportError:
-    pass
diff --git a/visdet/visdet/engine/dist.py b/visdet/visdet/engine/dist.py
deleted file mode 100644
index 7a16591..0000000
--- a/visdet/visdet/engine/dist.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of visengine.dist for dotted import support.
-
-This module allows `from visdet.engine.dist import X` to work properly.
-"""
-
-from visengine.dist import *  # noqa: F401, F403
-
-# Preserve the __all__ from upstream if it exists
-try:
-    from visengine.dist import __all__  # noqa: F401
-except ImportError:
-    pass
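The NOTE in the deleted `engine/__init__.py` above sidesteps the circular import by simply not importing `hooks` at package init. An alternative (a sketch only, not what the deleted file did) is lazy resolution via module-level `__getattr__` from PEP 562, which keeps `visdet.engine.hooks` reachable as an attribute while deferring the import until first access:

```python
# Sketch: lazy `hooks` resolution via PEP 562, placed in the package __init__.
import importlib


def __getattr__(name: str):
    if name == "hooks":
        # Imported on first attribute access, after the package has finished
        # initializing, so the circular import never triggers.
        return importlib.import_module(f"{__name__}.hooks")
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```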
-""" - -from visengine.evaluator import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.evaluator import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/fileio.py b/visdet/visdet/engine/fileio.py deleted file mode 100644 index 65b91c7..0000000 --- a/visdet/visdet/engine/fileio.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.fileio for dotted import support. - -This module allows `from visdet.engine.fileio import X` to work properly. -""" - -from visengine.fileio import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.fileio import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/hooks/__pycache__/__init__.cpython-312.pyc b/visdet/visdet/engine/hooks/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 14a127e..0000000 Binary files a/visdet/visdet/engine/hooks/__pycache__/__init__.cpython-312.pyc and /dev/null differ diff --git a/visdet/visdet/engine/infer.py b/visdet/visdet/engine/infer.py deleted file mode 100644 index f842c90..0000000 --- a/visdet/visdet/engine/infer.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.infer for dotted import support. - -This module allows `from visdet.engine.infer import X` to work properly. -""" - -from visengine.infer import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.infer import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/logging.py b/visdet/visdet/engine/logging.py deleted file mode 100644 index 41b4057..0000000 --- a/visdet/visdet/engine/logging.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.logging for dotted import support. - -This module allows `from visdet.engine.logging import X` to work properly. -""" - -from visengine.logging import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.logging import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/model/__init__.py b/visdet/visdet/engine/model/__init__.py deleted file mode 100644 index fca93dc..0000000 --- a/visdet/visdet/engine/model/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.model for dotted import support. - -This module allows `from visdet.engine.model import X` to work properly. -""" - -from visengine.model import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.model import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/model/weight_init.py b/visdet/visdet/engine/model/weight_init.py deleted file mode 100644 index bcf0d56..0000000 --- a/visdet/visdet/engine/model/weight_init.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.model.weight_init for dotted import support. - -This module allows `from visdet.engine.model.weight_init import X` to work properly. 
-""" - -from visengine.model.weight_init import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visengine.model.weight_init import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/optimizers/__init__.py b/visdet/visdet/engine/optimizers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/visdet/visdet/engine/registry.py b/visdet/visdet/engine/registry.py deleted file mode 100644 index a4e9eee..0000000 --- a/visdet/visdet/engine/registry.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.registry for dotted import support. - -This module allows `from visdet.engine.registry import X` to work properly. -""" - -from visengine.registry import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.registry import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/runner/__init__.py b/visdet/visdet/engine/runner/__init__.py deleted file mode 100644 index 85eace5..0000000 --- a/visdet/visdet/engine/runner/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.runner for dotted import support. - -This module allows `from visdet.engine.runner import X` to work properly. -""" - -from visengine.runner import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.runner import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/runner/checkpoint.py b/visdet/visdet/engine/runner/checkpoint.py deleted file mode 100644 index 6dd0caf..0000000 --- a/visdet/visdet/engine/runner/checkpoint.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.runner.checkpoint for dotted import support. - -This module allows `from visdet.engine.runner.checkpoint import X` to work properly. -""" - -from visengine.runner.checkpoint import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visengine.runner.checkpoint import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/structures.py b/visdet/visdet/engine/structures.py deleted file mode 100644 index 903baff..0000000 --- a/visdet/visdet/engine/structures.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.structures for dotted import support. - -This module allows `from visdet.engine.structures import X` to work properly. -""" - -from visengine.structures import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visdet.engine.structures import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/utils.py b/visdet/visdet/engine/utils.py deleted file mode 100644 index 804e593..0000000 --- a/visdet/visdet/engine/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.utils for dotted import support. - -This module allows `from visdet.engine.utils import X` to work properly. -""" - -from visengine.utils import * # noqa: F401, F403 - -# Preserve the __all__ from upstream if it exists -try: - from visengine.utils import __all__ # noqa: F401 -except ImportError: - pass diff --git a/visdet/visdet/engine/visualization.py b/visdet/visdet/engine/visualization.py deleted file mode 100644 index 2b43387..0000000 --- a/visdet/visdet/engine/visualization.py +++ /dev/null @@ -1,14 +0,0 @@ -# ruff: noqa -""" -Re-export of visengine.visualization for dotted import support. 
diff --git a/visdet/visdet/engine/visualization.py b/visdet/visdet/engine/visualization.py
deleted file mode 100644
index 2b43387..0000000
--- a/visdet/visdet/engine/visualization.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# ruff: noqa
-"""
-Re-export of visengine.visualization for dotted import support.
-
-This module allows `from visdet.engine.visualization import X` to work properly.
-"""
-
-from visengine.visualization import *  # noqa: F401, F403
-
-# Preserve the __all__ from upstream if it exists
-try:
-    from visengine.visualization import __all__  # noqa: F401
-except ImportError:
-    pass
diff --git a/visdet/visdet/models/__pycache__/__init__.cpython-312.pyc b/visdet/visdet/models/__pycache__/__init__.cpython-312.pyc
deleted file mode 100644
index 261a58d..0000000
Binary files a/visdet/visdet/models/__pycache__/__init__.cpython-312.pyc and /dev/null differ
diff --git a/visdet/visdet/models/backbones/__pycache__/__init__.cpython-312.pyc b/visdet/visdet/models/backbones/__pycache__/__init__.cpython-312.pyc
deleted file mode 100644
index 3eb226a..0000000
Binary files a/visdet/visdet/models/backbones/__pycache__/__init__.cpython-312.pyc and /dev/null differ
diff --git a/visdet/visdet/structures/__pycache__/__init__.cpython-312.pyc b/visdet/visdet/structures/__pycache__/__init__.cpython-312.pyc
deleted file mode 100644
index 6470779..0000000
Binary files a/visdet/visdet/structures/__pycache__/__init__.cpython-312.pyc and /dev/null differ
diff --git a/visdet/visdet/visualization/__init__.py b/visdet/visualization/__init__.py
similarity index 100%
rename from visdet/visdet/visualization/__init__.py
rename to visdet/visualization/__init__.py
diff --git a/visdet/visdet/visualization/local_visualizer.py b/visdet/visualization/local_visualizer.py
similarity index 100%
rename from visdet/visdet/visualization/local_visualizer.py
rename to visdet/visualization/local_visualizer.py
diff --git a/visdet/visdet/visualization/palette.py b/visdet/visualization/palette.py
similarity index 100%
rename from visdet/visdet/visualization/palette.py
rename to visdet/visualization/palette.py
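With the shims gone, callers import the source libraries directly, in line with the absolute-import rule. A hedged sketch of the migration (`imread` is confirmed by the viscv guidelines; the `visengine.config` location of `Config` is assumed from the deleted docstrings above):

```python
# Before (via the deleted shims):
#   from visdet.engine.config import Config
#   from visdet.cv.image import imread
# After: import the underlying libraries directly.
from visengine.config import Config  # assumed location, per the shim docstring
from viscv.image import imread       # confirmed by the viscv guidelines

img = imread("demo.jpg")  # hypothetical input path
```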