From fc707b8cfbaa8ac007efac5ddb53d65602d88510 Mon Sep 17 00:00:00 2001
From: Godwinh19
Date: Sat, 1 Apr 2023 01:48:15 +0100
Subject: [PATCH 1/3] add tests and remove redundant imports

---
 optimizers/test_optimizer.py | 44 ++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 optimizers/test_optimizer.py

diff --git a/optimizers/test_optimizer.py b/optimizers/test_optimizer.py
new file mode 100644
index 0000000..5ef1107
--- /dev/null
+++ b/optimizers/test_optimizer.py
@@ -0,0 +1,44 @@
+import unittest
+import torch
+from tml.optimizers.config import LearningRate, OptimizerConfig
+from .optimizer import compute_lr, LRShim, get_optimizer_class, build_optimizer
+
+
+class TestComputeLR(unittest.TestCase):
+  def test_constant_lr(self):
+    lr_config = LearningRate(constant=0.1)
+    lr = compute_lr(lr_config, step=0)
+    self.assertAlmostEqual(lr, 0.1)
+
+  def test_piecewise_constant_lr(self):
+    lr_config = LearningRate(piecewise_constant={"learning_rate_boundaries": [10, 20], "learning_rate_values": [0.1, 0.01, 0.001]})
+    lr = compute_lr(lr_config, step=5)
+    self.assertAlmostEqual(lr, 0.1)
+    lr = compute_lr(lr_config, step=15)
+    self.assertAlmostEqual(lr, 0.01)
+    lr = compute_lr(lr_config, step=25)
+    self.assertAlmostEqual(lr, 0.001)
+
+
+class TestLRShim(unittest.TestCase):
+  def setUp(self):
+    self.optimizer = torch.optim.SGD([torch.randn(10, 10)], lr=0.1)
+    self.lr_dict = {"ALL_PARAMS": LearningRate(constant=0.1)}
+
+  def test_get_lr(self):
+    lr_scheduler = LRShim(self.optimizer, self.lr_dict)
+    lr = lr_scheduler.get_lr()
+    self.assertEqual(lr, [0.1])
+
+
+class TestBuildOptimizer(unittest.TestCase):
+  def test_build_optimizer(self):
+    model = torch.nn.Linear(10, 1)
+    optimizer_config = OptimizerConfig(sgd={"lr": 0.1})
+    optimizer, scheduler = build_optimizer(model, optimizer_config)
+    self.assertIsInstance(optimizer, torch.optim.SGD)
+    self.assertIsInstance(scheduler, LRShim)
+
+
+if __name__ == "__main__":
+  unittest.main()

From 103e56f17c0a765d7a0a6de7cc2075e93605832c Mon Sep 17 00:00:00 2001
From: Godwinh19
Date: Sat, 1 Apr 2023 02:04:02 +0100
Subject: [PATCH 2/3] add tests; update roundrobin func

---
 metrics/test_aggregation.py | 42 +++++++++++++++++++++++++++++++++++++
 reader/utils.py             |  3 +--
 2 files changed, 43 insertions(+), 2 deletions(-)
 create mode 100644 metrics/test_aggregation.py

diff --git a/metrics/test_aggregation.py b/metrics/test_aggregation.py
new file mode 100644
index 0000000..c35c989
--- /dev/null
+++ b/metrics/test_aggregation.py
@@ -0,0 +1,42 @@
+import torch
+import unittest
+from aggregation import StableMean
+
+
+class TestStableMean(unittest.TestCase):
+
+  def setUp(self):
+    self.metric = StableMean()
+
+  def test_compute_empty(self):
+    result = self.metric.compute()
+    self.assertEqual(result, torch.tensor(0.0))
+
+  def test_compute_single_value(self):
+    self.metric.update(torch.tensor(1.0))
+    result = self.metric.compute()
+    self.assertEqual(result, torch.tensor(1.0))
+
+  def test_compute_weighted_single_value(self):
+    self.metric.update(torch.tensor(1.0), weight=torch.tensor(2.0))
+    result = self.metric.compute()
+    self.assertEqual(result, torch.tensor(1.0))
+
+  def test_compute_multiple_values(self):
+    self.metric.update(torch.tensor(1.0))
+    self.metric.update(torch.tensor(2.0))
+    self.metric.update(torch.tensor(3.0))
+    result = self.metric.compute()
+    self.assertEqual(result, torch.tensor(2.0))
+
+  def test_compute_weighted_multiple_values(self):
+    self.metric.update(torch.tensor(1.0), weight=torch.tensor(1.0))
+    self.metric.update(torch.tensor(2.0), weight=torch.tensor(2.0))
+    self.metric.update(torch.tensor(3.0), weight=torch.tensor(3.0))
+    result = self.metric.compute()
+    # Weighted mean: (1*1 + 2*2 + 3*3) / (1 + 2 + 3) = 7/3
+    self.assertAlmostEqual(result.item(), 7 / 3, places=5)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/reader/utils.py b/reader/utils.py
index fc0e34c..0de5868 100644
--- a/reader/utils.py
+++ b/reader/utils.py
@@ -21,8 +21,7 @@ def roundrobin(*iterables):
   while num_active:
     try:
       for _next in nexts:
-        result = _next()
-        yield result
+        yield _next()
     except StopIteration:
       # Remove the iterator we just exhausted from the cycle.
       num_active -= 1
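The reader/utils.py hunk above matches the standard itertools "roundrobin" recipe, with the temporary result variable folded into a direct yield. For context, here is a self-contained sketch of that recipe incorporating the patched line; the parts outside the hunk are taken from the recipe and its context lines, not confirmed by the diff itself:

  from itertools import cycle, islice

  def roundrobin(*iterables):
    # roundrobin('ABC', 'D', 'EF') yields A D E B F C: take one item from
    # each iterable in turn, dropping iterables as they are exhausted.
    num_active = len(iterables)
    nexts = cycle(iter(it).__next__ for it in iterables)
    while num_active:
      try:
        for _next in nexts:
          yield _next()
      except StopIteration:
        # Remove the iterator we just exhausted from the cycle.
        num_active -= 1
        nexts = cycle(islice(nexts, num_active))

  print(list(roundrobin("ABC", "D", "EF")))  # ['A', 'D', 'E', 'B', 'F', 'C']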
From 4f8bf389d8a6662faf74d0503a3de8f980483311 Mon Sep 17 00:00:00 2001
From: Godwin H <51889272+Godwinh19@users.noreply.github.com>
Date: Sat, 1 Apr 2023 02:13:17 +0100
Subject: [PATCH 3/3] Update test_optimizer.py

---
 optimizers/test_optimizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimizers/test_optimizer.py b/optimizers/test_optimizer.py
index 5ef1107..e396c41 100644
--- a/optimizers/test_optimizer.py
+++ b/optimizers/test_optimizer.py
@@ -1,7 +1,7 @@
 import unittest
 import torch
 from tml.optimizers.config import LearningRate, OptimizerConfig
-from .optimizer import compute_lr, LRShim, get_optimizer_class, build_optimizer
+from optimizer import compute_lr, LRShim, get_optimizer_class, build_optimizer
 
 
 class TestComputeLR(unittest.TestCase):
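Background on this last change: "from .optimizer import ..." is a relative import, which Python only resolves when the module is loaded as part of a package. Running the test file directly (for example, python optimizers/test_optimizer.py) fails with "ImportError: attempted relative import with no known parent package", whereas the absolute "from optimizer import ..." resolves because the script's own directory is placed on sys.path, which is presumably the motivation for the switch.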