Skip to content

Commit b9c3bf9

Browse files
dipannita08copybara-github
authored andcommitted
Add monitoring API to upload step deviation to Tensorboard.
PiperOrigin-RevId: 720236693
1 parent 68dabdd commit b9c3bf9

File tree

2 files changed

+324
-4
lines changed

2 files changed

+324
-4
lines changed

ml_goodput_measurement/src/monitoring.py

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_TENSORBOARD_GCS_SUBDIR = 'goodput'
1717
_TENSORBOARD_GOODPUT_LABEL = 'goodput'
1818
_TENSORBOARD_BADPUT_LABEL = 'badput'
19+
_TENSORBOARD_STEP_DEVIATION_LABEL = 'step_deviation'
1920

2021
logger = logging.getLogger(__name__)
2122

@@ -32,6 +33,9 @@ def __init__(
3233
monitoring_enabled: bool = False,
3334
pathway_enabled: bool = False,
3435
include_badput_breakdown=False,
36+
include_step_deviation=False,
37+
configured_ideal_step_time=None,
38+
step_deviation_interval_seconds=10,
3539
):
3640
"""Initializes the GoodputMonitor.
3741
@@ -47,6 +51,11 @@ def __init__(
4751
pathway_enabled: Whether the application is using Pathways.
4852
include_badput_breakdown: Whether to query and upload badput breakdown
4953
data to Tensorboard.
54+
include_step_deviation: Whether to query and upload step deviation data
55+
to Tensorboard.
56+
step_deviation_interval_seconds: The interval to query step deviation
57+
data.
58+
configured_ideal_step_time: The optionalideal step time configured by the user.
5059
"""
5160
if not monitoring_enabled:
5261
logger.info(
@@ -55,32 +64,48 @@ def __init__(
5564
)
5665
return
5766

67+
# Common configurations.
5868
self._job_name = job_name
5969
self._logger_name = logger_name
6070
self._tensorboard_dir = os.path.join(
6171
tensorboard_dir, _TENSORBOARD_GCS_SUBDIR
6272
)
73+
# Goodput configurations.
6374
self._upload_interval = upload_interval
75+
self._include_badput_breakdown = include_badput_breakdown
76+
77+
# Step deviation configurations.
78+
self._include_step_deviation = include_step_deviation
79+
self._step_deviation_interval_seconds = step_deviation_interval_seconds
80+
self._configured_ideal_step_time = configured_ideal_step_time
81+
82+
# Initialize the GoodputCalculator.
6483
self._goodput_calculator = GoodputCalculator(
6584
job_name=self._job_name,
6685
logger_name=self._logger_name,
6786
using_pathways=pathway_enabled,
6887
)
6988
self._writer = writer.SummaryWriter(self._tensorboard_dir)
70-
self._include_badput_breakdown = include_badput_breakdown
7189

72-
# Flag to signal the daemon thread if it exists when to initate
90+
# Goodput uploader flags to signal the daemon thread if it exists when to initate
7391
# shutdown and wait for termination.
7492
self._uploader_thread_running = False
7593
self._goodput_upload_thread = None
7694
self._termination_event = threading.Event()
7795
self._termination_event.clear()
7896

97+
# Step deviation threading flags.
98+
self._step_deviation_uploader_thread_running = False
99+
self._step_deviation_upload_thread = None
100+
self._step_deviation_termination_event = threading.Event()
101+
self._step_deviation_termination_event.clear()
102+
79103
def __del__(self):
80104
if self._uploader_thread_running:
81105
self.stop_goodput_uploader()
106+
self.stop_step_deviation_uploader()
82107

83-
def _write_to_tensorboard(
108+
def _write_goodput_data_to_tensorboard(
84109
self,
85110
job_goodput: float,
86111
badput_breakdown: dict[BadputType, float],
@@ -125,7 +150,7 @@ def _query_and_upload_goodput(self):
125150
include_badput_breakdown=self._include_badput_breakdown
126151
)
127152
)
128-
self._write_to_tensorboard(job_goodput, job_badput_breakdown, last_step)
153+
self._write_goodput_data_to_tensorboard(job_goodput, job_badput_breakdown, last_step)
129154
except Exception as e: # pylint: disable=broad-exception-caught
130155
logger.error(
131156
'Error while querying and uploading goodput to Tensorboard. This'
@@ -165,3 +190,81 @@ def stop_goodput_uploader(self):
165190
' be uploaded to Tensorboard.'
166191
)
167192
self._uploader_thread_running = False
193+
194+
def _write_step_deviation_to_tensorboard(
195+
self, step_deviation: dict[int, float]
196+
):
197+
if self._writer is not None:
198+
for step_count, step_deviation in step_deviation.items():
199+
self._writer.add_scalar(
200+
_TENSORBOARD_STEP_DEVIATION_LABEL,
201+
float(step_deviation),
202+
step_count,
203+
)
204+
self._writer.flush()
205+
206+
def _query_and_upload_step_deviation(self):
207+
"""Queries and uploads step deviation data to Tensorboard."""
208+
while not self._step_deviation_termination_event.is_set():
209+
time.sleep(self._step_deviation_interval_seconds)
210+
try:
211+
step_deviation = self._goodput_calculator.get_step_deviation(
212+
self._configured_ideal_step_time
213+
)
214+
except Exception as e: # pylint: disable=broad-exception-caught
215+
logger.error(
216+
'Error while querying step deviation to Tensorboard.'
217+
' This will not impact the workload. Error: %s',
218+
e,
219+
)
220+
continue
221+
try:
222+
self._write_step_deviation_to_tensorboard(step_deviation)
223+
except Exception as e: # pylint: disable=broad-exception-caught
224+
logger.error(
225+
'Error while writing step deviation to Tensorboard.'
226+
' This will not impact the workload. Error: %s',
227+
e,
228+
)
229+
230+
def start_step_deviation_uploader(self):
231+
"""Starts the step deviation uploader thread."""
232+
if not self._include_step_deviation:
233+
logger.info(
234+
'Step deviation monitoring is disabled. Returning without initializing'
235+
' step deviation uploader thread.'
236+
)
237+
return
238+
239+
if self._step_deviation_uploader_thread_running:
240+
raise RuntimeError('Step deviation uploader thread is already running.')
241+
242+
self._step_deviation_termination_event.clear()
243+
self._step_deviation_upload_thread = threading.Thread(
244+
target=self._query_and_upload_step_deviation, daemon=True
245+
)
246+
logger.info(
247+
'Starting step deviation query and uploader thread in the background'
248+
' for job: %s and logger: %s',
249+
self._job_name,
250+
self._logger_name,
251+
)
252+
self._step_deviation_upload_thread.start()
253+
self._step_deviation_uploader_thread_running = True
254+
255+
def stop_step_deviation_uploader(self):
256+
"""Stops the step deviation uploader thread."""
257+
if not self._step_deviation_uploader_thread_running:
258+
raise RuntimeError('Step deviation uploader thread is not running.')
259+
260+
self._step_deviation_termination_event.set()
261+
if self._step_deviation_upload_thread is not None:
262+
logger.info(
263+
'Waiting for step deviation query and uploader thread to complete.'
264+
)
265+
self._step_deviation_upload_thread.join()
266+
logger.info(
267+
'Step deviation query and uploader thread stopped. No more step'
268+
' deviation data will be uploaded to Tensorboard.'
269+
)
270+
self._step_deviation_uploader_thread_running = False
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""Tests to validate the monitoring module.
2+
3+
This module tests the GoodputMonitor class and its functionality, specifically
4+
the uploading of step deviation, goodput and badput data to Tensorboard.
5+
"""
6+
7+
from unittest import mock
8+
9+
from absl.testing import absltest
10+
from cloud_goodput.ml_goodput_measurement.src import monitoring
11+
12+
GoodputMonitor = monitoring.GoodputMonitor
13+
patch = mock.patch
14+
MagicMock = mock.MagicMock
15+
16+
_TEST_UPLOAD_INTERVAL = 1
17+
18+
19+
class GoodputMonitorTests(absltest.TestCase):
20+
"""Tests for the GoodputMonitor class."""
21+
22+
def setUp(self):
23+
super().setUp()
24+
self.job_name = 'test-run'
25+
self.logger_name = 'test-logger'
26+
self.tensorboard_dir = 'test-dir'
27+
28+
@patch('tensorboardX.writer.SummaryWriter')
29+
@patch('google.cloud.logging.Client')
30+
def test_goodput_monitor_init(self, mock_logger_client, mock_summary_writer):
31+
mock_summary_writer.return_value = MagicMock()
32+
mock_logger_client.return_value = MagicMock()
33+
goodput_monitor = GoodputMonitor(
34+
self.job_name,
35+
self.logger_name,
36+
self.tensorboard_dir,
37+
upload_interval=_TEST_UPLOAD_INTERVAL,
38+
monitoring_enabled=True,
39+
)
40+
# Objects should be initialized correctly.
41+
self.assertIsNotNone(goodput_monitor)
42+
self.assertIs(goodput_monitor._writer, mock_summary_writer.return_value)
43+
self.assertIsNotNone(goodput_monitor._goodput_calculator)
44+
45+
# Thread events should be initialized correctly.
46+
self.assertIsNotNone(goodput_monitor._step_deviation_termination_event)
47+
self.assertFalse(goodput_monitor._step_deviation_termination_event.is_set())
48+
self.assertFalse(goodput_monitor._step_deviation_uploader_thread_running)
49+
self.assertIsNotNone(goodput_monitor._termination_event)
50+
self.assertFalse(goodput_monitor._termination_event.is_set())
51+
self.assertFalse(goodput_monitor._uploader_thread_running)
52+
53+
@patch(
54+
'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_goodput_to_tensorboard'
55+
)
56+
@patch('tensorboardX.writer.SummaryWriter')
57+
@patch('google.cloud.logging.Client')
58+
async def test_goodput_monitor_start_goodput_uploader_success(
59+
self, mock_logger_client, mock_summary_writer, mock_goodput_to_tensorboard
60+
):
61+
mock_summary_writer.return_value = MagicMock()
62+
mock_goodput_to_tensorboard.return_value = MagicMock()
63+
mock_logger_client.return_value = MagicMock()
64+
goodput_monitor = monitoring.GoodputMonitor(
65+
self.job_name,
66+
self.logger_name,
67+
self.tensorboard_dir,
68+
upload_interval=_TEST_UPLOAD_INTERVAL,
69+
monitoring_enabled=True,
70+
)
71+
goodput_monitor.start_goodput_uploader()
72+
self.assertTrue(goodput_monitor._uploader_thread_running)
73+
self.assertIsNotNone(goodput_monitor._goodput_upload_thread)
74+
self.assertFalse(goodput_monitor._termination_event.is_set())
75+
mock_goodput_to_tensorboard.assert_called_once()
76+
mock_summary_writer.return_value.add_scalar.assert_called_once()
77+
goodput_monitor.stop_goodput_uploader()
78+
self.assertFalse(goodput_monitor._uploader_thread_running)
79+
self.assertIsNone(goodput_monitor._goodput_upload_thread)
80+
self.assertTrue(goodput_monitor._termination_event.is_set())
81+
82+
@patch(
83+
'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_goodput_to_tensorboard'
84+
)
85+
@patch('tensorboardX.writer.SummaryWriter')
86+
@patch('google.cloud.logging.Client')
87+
async def test_goodput_monitor_start_goodput_uploader_failure(
88+
self, mock_logger_client, mock_summary_writer, mock_goodput_to_tensorboard
89+
):
90+
mock_logger_client.return_value = MagicMock()
91+
mock_summary_writer.return_value = MagicMock()
92+
mock_goodput_to_tensorboard.side_effect = ValueError('Test Error')
93+
goodput_monitor = monitoring.GoodputMonitor(
94+
self.job_name,
95+
self.logger_name,
96+
self.tensorboard_dir,
97+
upload_interval=_TEST_UPLOAD_INTERVAL,
98+
monitoring_enabled=True,
99+
)
100+
goodput_monitor.start_goodput_uploader()
101+
self.assertTrue(goodput_monitor._uploader_thread_running)
102+
self.assertIsNotNone(goodput_monitor._goodput_upload_thread)
103+
self.assertFalse(goodput_monitor._termination_event.is_set())
104+
mock_goodput_to_tensorboard.assert_called_once()
105+
with self.assertRaisesRegex(ValueError, 'Test Error'):
106+
goodput_monitor._query_and_upload_goodput()
107+
mock_summary_writer.return_value.add_scalar.assert_not_called()
108+
goodput_monitor.stop_goodput_uploader()
109+
self.assertFalse(goodput_monitor._uploader_thread_running)
110+
self.assertIsNone(goodput_monitor._goodput_upload_thread)
111+
self.assertTrue(goodput_monitor._termination_event.is_set())
112+
113+
@patch(
114+
'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_badput_to_tensorboard'
115+
)
116+
@patch('tensorboardX.writer.SummaryWriter')
117+
@patch('google.cloud.logging.Client')
118+
async def test_goodput_monitor_start_badput_uploader_success(
119+
self, mock_logger_client, mock_summary_writer, mock_badput_to_tensorboard
120+
):
121+
mock_summary_writer.return_value = MagicMock()
122+
mock_badput_to_tensorboard.return_value = MagicMock()
123+
mock_logger_client.return_value = MagicMock()
124+
goodput_monitor = monitoring.GoodputMonitor(
125+
self.job_name,
126+
self.logger_name,
127+
self.tensorboard_dir,
128+
upload_interval=_TEST_UPLOAD_INTERVAL,
129+
monitoring_enabled=True,
130+
include_badput_breakdown=True,
131+
)
132+
133+
goodput_monitor.start_goodput_uploader()
134+
self.assertTrue(goodput_monitor._uploader_thread_running)
135+
self.assertIsNotNone(goodput_monitor._goodput_upload_thread)
136+
self.assertFalse(goodput_monitor._termination_event.is_set())
137+
self.assertTrue(goodput_monitor._include_badput_breakdown)
138+
139+
mock_badput_to_tensorboard.assert_called_once()
140+
mock_summary_writer.return_value.add_scalar.assert_called_once()
141+
142+
goodput_monitor.stop_goodput_uploader()
143+
self.assertFalse(goodput_monitor._uploader_thread_running)
144+
self.assertIsNone(goodput_monitor._goodput_upload_thread)
145+
self.assertTrue(goodput_monitor._termination_event.is_set())
146+
147+
@patch(
148+
'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_step_deviation_to_tensorboard'
149+
)
150+
@patch('tensorboardX.writer.SummaryWriter')
151+
@patch('google.cloud.logging.Client')
152+
async def test_goodput_monitor_start_step_deviation_uploader_success(
153+
self,
154+
mock_logger_client,
155+
mock_summary_writer,
156+
mock_step_deviation_to_tensorboard,
157+
):
158+
mock_logger_client.return_value = MagicMock()
159+
mock_summary_writer.return_value = MagicMock()
160+
mock_step_deviation_to_tensorboard.return_value = MagicMock()
161+
goodput_monitor = monitoring.GoodputMonitor(
162+
self.job_name,
163+
self.logger_name,
164+
self.tensorboard_dir,
165+
upload_interval=_TEST_UPLOAD_INTERVAL,
166+
monitoring_enabled=True,
167+
include_step_deviation=True,
168+
)
169+
goodput_monitor.start_step_deviation_uploader()
170+
self.assertTrue(goodput_monitor._step_deviation_uploader_thread_running)
171+
self.assertIsNotNone(goodput_monitor._step_deviation_upload_thread)
172+
self.assertFalse(goodput_monitor._step_deviation_termination_event.is_set())
173+
mock_step_deviation_to_tensorboard.assert_called_once()
174+
mock_summary_writer.return_value.add_scalar.assert_called_once()
175+
goodput_monitor.stop_step_deviation_uploader()
176+
self.assertFalse(goodput_monitor._step_deviation_uploader_thread_running)
177+
self.assertIsNone(goodput_monitor._step_deviation_upload_thread)
178+
self.assertTrue(goodput_monitor._step_deviation_termination_event.is_set())
179+
180+
@patch(
181+
'cloud_goodput.ml_goodput_measurement.src.monitoring.GoodputMonitor._write_step_deviation_to_tensorboard'
182+
)
183+
@patch('tensorboardX.writer.SummaryWriter')
184+
@patch('google.cloud.logging.Client')
185+
async def test_goodput_monitor_start_step_deviation_uploader_failure(
186+
self,
187+
mock_logger_client,
188+
mock_summary_writer,
189+
mock_query_and_upload_step_deviation,
190+
):
191+
mock_logger_client.return_value = MagicMock()
192+
mock_summary_writer.return_value = MagicMock()
193+
mock_query_and_upload_step_deviation.side_effect = ValueError('Test Error')
194+
goodput_monitor = monitoring.GoodputMonitor(
195+
self.job_name,
196+
self.logger_name,
197+
self.tensorboard_dir,
198+
upload_interval=_TEST_UPLOAD_INTERVAL,
199+
monitoring_enabled=True,
200+
include_step_deviation=True,
201+
)
202+
goodput_monitor.start_step_deviation_uploader()
203+
self.assertTrue(goodput_monitor._step_deviation_uploader_thread_running)
204+
self.assertIsNotNone(goodput_monitor._step_deviation_upload_thread)
205+
self.assertFalse(goodput_monitor._step_deviation_termination_event.is_set())
206+
mock_query_and_upload_step_deviation.assert_called_once()
207+
with self.assertRaisesRegex(ValueError, 'Test Error'):
208+
goodput_monitor._query_and_upload_step_deviation()
209+
mock_summary_writer.return_value.add_scalar.assert_not_called()
210+
goodput_monitor.stop_step_deviation_uploader()
211+
self.assertFalse(goodput_monitor._step_deviation_uploader_thread_running)
212+
self.assertIsNone(goodput_monitor._step_deviation_upload_thread)
213+
self.assertTrue(goodput_monitor._step_deviation_termination_event.is_set())
214+
215+
216+
if __name__ == '__main__':
217+
absltest.main()

0 commit comments

Comments
 (0)