Skip to content

Commit 3edb84e

Browse files
authored
regression test (#234)
Summary: - add a test that uses fixtures to validate against previous runs of the test - the fixutres can be written using `WRITE_FIXTURE=true pytest -vs ./torchft/test_diloco_mocked_updates.py` - when writing fixtures, the test also numerically validates the implementation of streaming diloco
1 parent 8170a4b commit 3edb84e

9 files changed

+3769
-1
lines changed

test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_updates_0.json

Lines changed: 546 additions & 0 deletions
Large diffs are not rendered by default.

test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_updates_1.json

Lines changed: 546 additions & 0 deletions
Large diffs are not rendered by default.

test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_updates_2.json

Lines changed: 546 additions & 0 deletions
Large diffs are not rendered by default.

test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_updates_3.json

Lines changed: 546 additions & 0 deletions
Large diffs are not rendered by default.

test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_updates_4.json

Lines changed: 546 additions & 0 deletions
Large diffs are not rendered by default.

test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_updates_5.json

Lines changed: 546 additions & 0 deletions
Large diffs are not rendered by default.
File renamed without changes.

torchft/diloco_regression_test.py

Lines changed: 492 additions & 0 deletions
Large diffs are not rendered by default.

torchft/local_sgd_integ_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.distributed.pipelining import SplitPoint, pipeline
1919
from torch.distributed.tensor import DTensor, Replicate
2020

21+
from torchft._test.diloco_trainer import DiLoCoTrainer, MultiMyModel
2122
from torchft._torchft import LighthouseServer
2223
from torchft.device_mesh import ft_init_device_mesh
2324
from torchft.local_sgd import DiLoCo, LocalSGD
@@ -33,7 +34,6 @@
3334
ProcessGroupBabyNCCL,
3435
ProcessGroupGloo,
3536
)
36-
from torchft.test.diloco_trainer import DiLoCoTrainer, MultiMyModel
3737

3838
logger: logging.Logger = logging.getLogger(__name__)
3939

0 commit comments

Comments
 (0)