diff --git a/requirements/lightning.txt b/requirements/lightning.txt
index 1d92d1b..2f6a61e 100644
--- a/requirements/lightning.txt
+++ b/requirements/lightning.txt
@@ -1,4 +1,3 @@
 # this sets the requirements contains if you go with main lightning
-# in 2.0.7 we have removed lightning.pytorch.overrides.base._LightningPrecisionModuleWrapperBase
 
 lightning >=2.0.0, <2.2.0
diff --git a/src/lightning_graphcore/strategy.py b/src/lightning_graphcore/strategy.py
index f28dd03..6a701d3 100644
--- a/src/lightning_graphcore/strategy.py
+++ b/src/lightning_graphcore/strategy.py
@@ -228,9 +228,7 @@ def _convert_to_poptorch_loader(
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader
 
-        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1
-        )
+        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         return _reinstantiate_wrapped_cls(dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs)
 
diff --git a/src/lightning_graphcore/utils.py b/src/lightning_graphcore/utils.py
index 9d56a15..d789ffd 100644
--- a/src/lightning_graphcore/utils.py
+++ b/src/lightning_graphcore/utils.py
@@ -20,15 +20,13 @@
 if package_available("lightning"):
     from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
     from lightning.pytorch import LightningModule
-    from lightning.pytorch.overrides.base import _LightningPrecisionModuleWrapperBase
 elif package_available("pytorch_lightning"):
     from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
     from pytorch_lightning import LightningModule
-    from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 
 
 class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(self, forward_module: Union[LightningModule, _LightningPrecisionModuleWrapperBase]) -> None:
+    def __init__(self, forward_module: LightningModule) -> None:
         """Wrap the user's LightningModule and redirect the forward call to the appropriate `*_step()` methods.
 
         Inheriting classes may also modify the inputs or outputs of forward.
diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index c9241b8..1998664 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -54,7 +54,6 @@ def test_fail_if_no_ipus(_, tmpdir):  # noqa: PT019
         Trainer(default_root_dir=tmpdir, accelerator=IPUAccelerator(), devices=1)
 
 
-@pytest.mark.xfail()  # todo
 def test_accelerator_selected(tmpdir):
     assert IPUAccelerator.is_available()
     trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1)
@@ -62,7 +61,7 @@
 
 
 def test_warning_if_ipus_not_used():
-    with pytest.warns(UserWarning, match="IPU available but not used. Set `accelerator` and `devices`"):
+    with pytest.warns(UserWarning):
         Trainer(accelerator="cpu")
 
 
@@ -82,18 +81,7 @@ def test_all_stages(tmpdir, devices):
     trainer.predict(model)
 
 
-@pytest.mark.parametrize(
-    "devices",
-    [
-        1,
-        pytest.param(
-            4,
-            marks=pytest.mark.xfail(  # fixme
-                AssertionError, reason="Invalid batch dimension: In the input torch.Size([1, 32]), ..."
-            ),
-        ),
-    ],
-)
+@pytest.mark.parametrize("devices", [1, 4])
 def test_inference_only(tmpdir, devices):
     model = IPUModel()
 
@@ -341,7 +329,6 @@ def test_clip_gradients_fails(tmpdir):
     trainer.fit(model)
 
 
-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...")  # todo
 def test_autoreport(tmpdir):
     """Ensure autoreport dumps to a file."""
     model = IPUModel()
@@ -358,7 +345,6 @@
     assert os.path.isfile(autoreport_path + "training/profile.pop")
 
 
-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...")  # todo
 def test_manual_poptorch_dataloader(tmpdir):
     model_options = poptorch.Options()
 
@@ -390,7 +376,6 @@ def train_dataloader(self):
     assert dataloader.drop_last  # was kept
 
 
-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...")  # todo
 def test_manual_poptorch_opts(tmpdir):
     """Ensure if the user passes manual poptorch Options, we run with the correct object."""
     model = IPUModel()
@@ -573,7 +558,6 @@ def test_accelerator_ipu_with_devices():
     assert trainer.num_devices == 8
 
 
-@pytest.mark.xfail(AssertionError, reason="not implemented on PL side")
 def test_accelerator_auto_with_devices_ipu():
     trainer = Trainer(accelerator="auto", devices=8)
     assert isinstance(trainer.accelerator, IPUAccelerator)
@@ -618,7 +602,6 @@ def test_poptorch_models_at_different_stages(tmpdir):
     assert list(trainer.strategy.poptorch_models) == [stage]
 
 
-@pytest.mark.xfail(AssertionError, reason="not implemented on PL side")
 def test_devices_auto_choice_ipu():
     trainer = Trainer(accelerator="auto", devices="auto")
     assert trainer.num_devices == 4
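
Note on the strategy.py change above: `_get_dataloader_init_args_and_kwargs` loses its fourth argument, but the extracted args are still replayed through `poptorch.DataLoader` with the strategy's options prepended. A minimal sketch of that shape, assuming a working Poplar SDK; the dataset and `opts` below are placeholders, not values from this repo:

import poptorch
import torch
from torch.utils.data import DataLoader, TensorDataset

# placeholder dataset; `vanilla` stands in for the loader a user would
# return from a *_dataloader() hook
dataset = TensorDataset(torch.rand(64, 32))
vanilla = DataLoader(dataset, batch_size=4, shuffle=True)

# stands in for the strategy's training_opts / inference_opts
opts = poptorch.Options()

# roughly what _convert_to_poptorch_loader produces from `vanilla`: the same
# init args re-applied through poptorch.DataLoader, which takes Options first
ipu_loader = poptorch.DataLoader(opts, dataset, batch_size=4, shuffle=True)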
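
Likewise for the utils.py change: since upstream Lightning removed `_LightningPrecisionModuleWrapperBase`, the wrapper now accepts a plain `LightningModule` only. A sketch of the narrowed signature in use; `ToyModule` is a hypothetical stand-in, and instantiating the wrapper directly is only for illustration:

import torch
from lightning.pytorch import LightningModule

from lightning_graphcore.utils import _LightningModuleWrapperBase


class ToyModule(LightningModule):  # hypothetical stand-in module
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()


# forward_module is now typed as LightningModule only; the Union with the
# removed precision wrapper is no longer expressible
wrapped = _LightningModuleWrapperBase(forward_module=ToyModule())
assert isinstance(wrapped, torch.nn.Module)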