From c9ec2cc7ef9b7d507fd33f3a98b1b875e8a42d71 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Wed, 11 Feb 2026 05:49:49 -0800 Subject: [PATCH 1/2] Fix progress bar --- ctgan/synthesizers/_utils.py | 11 +++++++++ ctgan/synthesizers/ctgan.py | 13 ++++++---- ctgan/synthesizers/tvae.py | 8 +++---- tests/integration/synthesizer/test_tvae.py | 4 ++-- tests/unit/synthesizer/test__utils.py | 28 +++++++++++++++++++++- tests/unit/synthesizer/test_tvae.py | 4 ++-- 6 files changed, 55 insertions(+), 13 deletions(-) diff --git a/ctgan/synthesizers/_utils.py b/ctgan/synthesizers/_utils.py index 78138c4a..e3c6c644 100644 --- a/ctgan/synthesizers/_utils.py +++ b/ctgan/synthesizers/_utils.py @@ -55,3 +55,14 @@ def _set_device(enable_gpu, device=None): def validate_and_set_device(enable_gpu, cuda): enable_gpu = get_enable_gpu_value(enable_gpu, cuda) return _set_device(enable_gpu) + + +def _format_score(score): + """Format a score as a fixed-length string ``±XX.XX``. + + Values are clipped to the range ``[-99.99, +99.99]`` so the result + is always exactly 6 characters. + """ + score = max(-99.99, min(99.99, score)) + sign = '+' if score >= 0 else '-' + return f'{sign}{abs(score):05.2f}' diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index d9398856..74a94f03 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -12,7 +12,7 @@ from ctgan.data_sampler import DataSampler from ctgan.data_transformer import DataTransformer from ctgan.errors import InvalidDataError -from ctgan.synthesizers._utils import _set_device, validate_and_set_device +from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device from ctgan.synthesizers.base import BaseSynthesizer, random_state @@ -379,8 +379,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None): epoch_iterator = tqdm(range(epochs), disable=(not self._verbose)) if self._verbose: - description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})' - epoch_iterator.set_description(description.format(gen=0, dis=0)) + description = 'Gen. ({gen}) | Discrim. ({dis})' + epoch_iterator.set_description( + description.format(gen=_format_score(0), dis=_format_score(0)) + ) steps_per_epoch = max(len(train_data) // self._batch_size, 1) for i in epoch_iterator: @@ -479,7 +481,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None): if self._verbose: epoch_iterator.set_description( - description.format(gen=generator_loss, dis=discriminator_loss) + description.format( + gen=_format_score(generator_loss), + dis=_format_score(discriminator_loss), + ) ) @random_state diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index 30baea3d..b307a418 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -10,7 +10,7 @@ from tqdm import tqdm from ctgan.data_transformer import DataTransformer -from ctgan.synthesizers._utils import _set_device, validate_and_set_device +from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device from ctgan.synthesizers.base import BaseSynthesizer, random_state @@ -161,8 +161,8 @@ def fit(self, train_data, discrete_columns=()): self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) iterator = tqdm(range(self.epochs), disable=(not self.verbose)) if self.verbose: - iterator_description = 'Loss: {loss:.3f}' - iterator.set_description(iterator_description.format(loss=0)) + iterator_description = 'Loss: {loss}' + iterator.set_description(iterator_description.format(loss=_format_score(0))) for i in iterator: loss_values = [] @@ -205,7 +205,7 @@ def fit(self, train_data, discrete_columns=()): if self.verbose: iterator.set_description( - iterator_description.format(loss=loss.detach().cpu().item()) + iterator_description.format(loss=_format_score(loss.detach().cpu().item())) ) @random_state diff --git a/tests/integration/synthesizer/test_tvae.py b/tests/integration/synthesizer/test_tvae.py index bbbf65d4..fcff6f56 100644 --- a/tests/integration/synthesizer/test_tvae.py +++ b/tests/integration/synthesizer/test_tvae.py @@ -120,5 +120,5 @@ def test_tvae_save(tmpdir, capsys): assert len(loss_values) == 10 assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'} assert all(loss_values['Batch'] == 0) - last_loss_val = loss_values['Loss'].iloc[-1] - assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out + last_loss_val = max(-99.99, min(99.99, loss_values['Loss'].iloc[-1])) + assert f'Loss: {last_loss_val:+06.2f}: 100%' in captured_out diff --git a/tests/unit/synthesizer/test__utils.py b/tests/unit/synthesizer/test__utils.py index 151663a8..6a322712 100644 --- a/tests/unit/synthesizer/test__utils.py +++ b/tests/unit/synthesizer/test__utils.py @@ -5,7 +5,12 @@ import pytest import torch -from ctgan.synthesizers._utils import _set_device, get_enable_gpu_value, validate_and_set_device +from ctgan.synthesizers._utils import ( + _format_score, + _set_device, + get_enable_gpu_value, + validate_and_set_device, +) def test__validate_gpu_parameter(): @@ -61,6 +66,27 @@ def test__set_device(): assert device_4 == torch.device('cpu') +@pytest.mark.parametrize( + 'score, expected', + [ + (0, '+00.00'), + (1.233434, '+01.23'), + (-0.93, '-00.93'), + (0.01, '+00.01'), + (-1.21, '-01.21'), + (99.99, '+99.99'), + (-99.99, '-99.99'), + (150, '+99.99'), + (-200, '-99.99'), + ], +) +def test__format_score(score, expected): + """Test the ``_format_score`` method.""" + result = _format_score(score) + assert result == expected + assert len(result) == 6 + + @patch('ctgan.synthesizers._utils._set_device') @patch('ctgan.synthesizers._utils.get_enable_gpu_value') def test_validate_and_set_device(mock_validate, mock_set_device): diff --git a/tests/unit/synthesizer/test_tvae.py b/tests/unit/synthesizer/test_tvae.py index 259b2618..c1d7e538 100644 --- a/tests/unit/synthesizer/test_tvae.py +++ b/tests/unit/synthesizer/test_tvae.py @@ -60,6 +60,6 @@ def mock_add(a, b): # Assert tqdm_mock.assert_called_once_with(range(epochs), disable=False) - assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000') - assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235') + assert iterator_mock.set_description.call_args_list[0] == call('Loss: +00.00') + assert iterator_mock.set_description.call_args_list[1] == call('Loss: +01.23') assert iterator_mock.set_description.call_count == 2 From 769fc676c0ff0c2900da93e75810f7d935ce7544 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Wed, 11 Feb 2026 06:32:28 -0800 Subject: [PATCH 2/2] Fix coverage --- tests/unit/synthesizer/test_ctgan.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit/synthesizer/test_ctgan.py b/tests/unit/synthesizer/test_ctgan.py index f131d983..1a936c4b 100644 --- a/tests/unit/synthesizer/test_ctgan.py +++ b/tests/unit/synthesizer/test_ctgan.py @@ -286,6 +286,21 @@ def test__cond_loss(self): assert (result - expected).abs() < 1e-3 + @patch('ctgan.synthesizers.ctgan._format_score') + def test_fit_verbose_calls_format_score(self, format_score_mock): + """Test that ``_format_score`` is called during verbose fitting.""" + # Setup + format_score_mock.side_effect = lambda x: f'+{abs(x):05.2f}' + data = pd.DataFrame({'col1': [0, 1, 2, 3, 4], 'col2': ['a', 'b', 'c', 'a', 'b']}) + + # Run + ctgan = CTGAN(epochs=1, verbose=True) + ctgan.fit(data, discrete_columns=['col2']) + + # Assert + assert format_score_mock.call_count == 4 + format_score_mock.assert_any_call(0) + def test__validate_discrete_columns(self): """Test `_validate_discrete_columns` if the discrete column doesn't exist.