Skip to content

Commit fb4918d

Browse files
committed
Clean up code and add docstrings
1 parent 564dab0 commit fb4918d

File tree

1 file changed

+51
-23
lines changed

1 file changed

+51
-23
lines changed

tests/stability/test_adaptive_scaling.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@
1111

1212

1313
@pytest.mark.stability
14-
@pytest.mark.parametrize("minimum", (0, 1))
14+
@pytest.mark.parametrize("minimum,threshold", [(0, 240), (1, 120)])
1515
@pytest.mark.parametrize(
1616
"scatter",
1717
(
1818
False,
1919
pytest.param(True, marks=[pytest.mark.xfail(reason="dask/distributed#6686")]),
2020
),
2121
)
22-
def test_scale_up_on_task_load(minimum, scatter):
22+
def test_scale_up_on_task_load(minimum, threshold, scatter):
23+
"""Tests that adaptive scaling reacts within a reasonable amount of time to
24+
an increased task load and scales up.
25+
"""
2326
maximum = 10
2427
with Cluster(
2528
name=f"test_adaptive_scaling-{uuid.uuid4().hex}",
@@ -47,7 +50,7 @@ def clog(x: int, ev: Event) -> int:
4750
client.wait_for_workers(n_workers=maximum, timeout=TIMEOUT_THRESHOLD)
4851
start = time.monotonic()
4952
duration = end - start
50-
assert duration < 360
53+
assert duration < threshold, duration
5154
assert len(adapt.log) <= 2
5255
assert adapt.log[-1][1] == {"status": "up", "n": maximum}
5356
ev_fan_out.set()
@@ -58,6 +61,10 @@ def clog(x: int, ev: Event) -> int:
5861
@pytest.mark.stability
5962
@pytest.mark.parametrize("minimum", (0, 1))
6063
def test_adapt_to_changing_workload(minimum: int):
64+
"""Tests that adaptive scaling reacts within a reasonable amount of time to
65+
a varying task load and scales up or down. This also asserts that no recomputation
66+
is caused by the scaling.
67+
"""
6168
maximum = 10
6269
fan_out_size = 100
6370
with Cluster(
@@ -70,41 +77,58 @@ def test_adapt_to_changing_workload(minimum: int):
7077
adapt = cluster.adapt(minimum=minimum, maximum=maximum)
7178
assert len(adapt.log) == 0
7279

80+
@delayed
7381
def clog(x: int, ev: Event, sem: Semaphore, **kwargs) -> int:
7482
# Ensure that no recomputation happens by decrementing a countdown on a semaphore
75-
acquired = sem.acquire(timeout=0.1)
76-
assert acquired is True
83+
acquired = sem.acquire(timeout=1)
84+
assert acquired is True, "Could not acquire semaphore, likely recomputation happened."
7785
ev.wait()
7886
return x
7987

88+
def workload(
89+
fan_out_size,
90+
ev_fan_out,
91+
sem_fan_out,
92+
ev_barrier,
93+
sem_barrier,
94+
ev_final_fan_out,
95+
sem_final_fan_out,
96+
):
97+
fan_out = [
98+
clog(i, ev=ev_fan_out, sem=sem_fan_out) for i in range(fan_out_size)
99+
]
100+
barrier = clog(delayed(sum)(fan_out), ev=ev_barrier, sem=sem_barrier)
101+
final_fan_out = [
102+
clog(i, ev=ev_final_fan_out, sem=sem_final_fan_out, barrier=barrier)
103+
for i in range(fan_out_size)
104+
]
105+
return final_fan_out
106+
80107
sem_fan_out = Semaphore(name="fan-out", max_leases=fan_out_size)
81108
ev_fan_out = Event(name="fan-out", client=client)
82-
83-
fut = client.map(
84-
clog, range(fan_out_size), ev=ev_fan_out, sem=sem_fan_out
85-
)
86-
87-
fut = client.submit(sum, fut)
88109
sem_barrier = Semaphore(name="barrier", max_leases=1)
89110
ev_barrier = Event(name="barrier", client=client)
90-
fut = client.submit(clog, fut, ev=ev_barrier, sem=sem_barrier)
91-
92111
sem_final_fan_out = Semaphore(name="final-fan-out", max_leases=fan_out_size)
93112
ev_final_fan_out = Event(name="final-fan-out", client=client)
94-
fut = client.map(
95-
clog,
96-
range(fan_out_size),
97-
ev=ev_final_fan_out,
98-
sem=sem_final_fan_out,
99-
barrier=fut,
113+
114+
fut = client.compute(
115+
workload(
116+
fan_out_size=fan_out_size,
117+
ev_fan_out=ev_fan_out,
118+
sem_fan_out=sem_fan_out,
119+
ev_barrier=ev_barrier,
120+
sem_barrier=sem_barrier,
121+
ev_final_fan_out=ev_final_fan_out,
122+
sem_final_fan_out=sem_final_fan_out,
123+
)
100124
)
101125

102126
# Scale up to maximum
103127
start = time.monotonic()
104128
client.wait_for_workers(n_workers=maximum, timeout=TIMEOUT_THRESHOLD)
105129
end = time.monotonic()
106130
duration_first_scale_up = end - start
107-
assert duration_first_scale_up < 420
131+
assert duration_first_scale_up < 120
108132
assert len(cluster.observed) == maximum
109133
assert adapt.log[-1][1]["status"] == "up"
110134

@@ -117,7 +141,7 @@ def clog(x: int, ev: Event, sem: Semaphore, **kwargs) -> int:
117141
time.sleep(0.1)
118142
end = time.monotonic()
119143
duration_first_scale_down = end - start
120-
assert duration_first_scale_down < 420
144+
assert duration_first_scale_down < 330
121145
assert len(cluster.observed) == 1
122146
assert adapt.log[-1][1]["status"] == "down"
123147

@@ -127,7 +151,7 @@ def clog(x: int, ev: Event, sem: Semaphore, **kwargs) -> int:
127151
client.wait_for_workers(n_workers=maximum, timeout=TIMEOUT_THRESHOLD)
128152
end = time.monotonic()
129153
duration_second_scale_up = end - start
130-
assert duration_second_scale_up < 420
154+
assert duration_second_scale_up < 120
131155
assert len(cluster.observed) == maximum
132156
assert adapt.log[-1][1]["status"] == "up"
133157

@@ -143,7 +167,7 @@ def clog(x: int, ev: Event, sem: Semaphore, **kwargs) -> int:
143167
time.sleep(0.1)
144168
end = time.monotonic()
145169
duration_second_scale_down = end - start
146-
assert duration_second_scale_down < 420
170+
assert duration_second_scale_down < 330
147171
assert len(cluster.observed) == minimum
148172
assert adapt.log[-1][1]["status"] == "down"
149173
return (
@@ -160,6 +184,10 @@ def clog(x: int, ev: Event, sem: Semaphore, **kwargs) -> int:
160184
@pytest.mark.stability
161185
@pytest.mark.parametrize("minimum", (0, 1))
162186
def test_adapt_to_memory_intensive_workload(minimum):
187+
"""Tests that adaptive scaling reacts within a reasonable amount of time to a varying task and memory load.
188+
189+
Note: This test currently results in spilling and very long runtimes.
190+
"""
163191
maximum = 10
164192
with Cluster(
165193
name=f"test_adaptive_scaling-{uuid.uuid4().hex}",

0 commit comments

Comments
 (0)