Skip to content

Commit b8b49d3

Browse files
committed
fix fetch batch
Signed-off-by: Praneeth Bedapudi <[email protected]>
1 parent 37df5d5 commit b8b49d3

File tree

4 files changed

+61
-36
lines changed

4 files changed

+61
-36
lines changed

fastdeploy/_infer.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def create_response(self, unique_id, response, is_compressed, input_type):
109109
pass
110110

111111
if is_compressed:
112-
_utils.logger.debug(f"{unique_id}: compressing response")
113112
response = self._compressor.compress(response)
113+
_utils.logger.debug(f"{unique_id}: response compressed")
114114

115115
return success, response
116116

@@ -224,6 +224,7 @@ def get_responses_for_unique_ids(self, unique_ids, is_compresseds, input_types):
224224
all_responses = {}
225225

226226
updations = {}
227+
still_processing = []
227228

228229
for unique_id, is_compressed, input_type in zip(
229230
unique_ids, is_compresseds, input_types
@@ -232,24 +233,21 @@ def get_responses_for_unique_ids(self, unique_ids, is_compresseds, input_types):
232233

233234
if current_results["timedout_in_queue"]:
234235
_utils.logger.warning(f"{unique_id}: timedout in queue")
236+
updations[unique_id] = {
237+
"-1.predicted_at": time.time(),
238+
}
235239
all_responses[unique_id] = self.get_timeout_response(
236240
unique_id, is_compressed, input_type
237241
)
242+
_utils.logger.debug(f"{unique_id}: timedout in queue response created")
238243

239244
elif (
240245
current_results["last_predictor_success"] is True
241246
and current_results["last_predictor_sequence"]
242247
== _utils.LAST_PREDICTOR_SEQUENCE
243248
):
244249
updations[unique_id] = {
245-
**{
246-
"-1.predicted_at": time.time(),
247-
"-1.outputs": None,
248-
},
249-
**{
250-
f"{__}.outputs": None
251-
for __ in _utils.PREDICTOR_SEQUENCE_TO_FILES
252-
},
250+
"-1.predicted_at": time.time(),
253251
}
254252

255253
all_responses[unique_id] = self.create_response(
@@ -265,10 +263,14 @@ def get_responses_for_unique_ids(self, unique_ids, is_compresseds, input_types):
265263
is_compressed,
266264
input_type,
267265
)
266+
_utils.logger.debug(f"{unique_id}: response created")
268267
elif current_results["last_predictor_success"] is False:
269268
_utils.logger.warning(
270269
f"{unique_id}: predictor failed at {current_results['last_predictor_sequence']}"
271270
)
271+
updations[unique_id] = {
272+
"-1.predicted_at": time.time(),
273+
}
272274
all_responses[unique_id] = self.create_response(
273275
unique_id,
274276
{
@@ -280,8 +282,15 @@ def get_responses_for_unique_ids(self, unique_ids, is_compresseds, input_types):
280282
is_compressed,
281283
input_type,
282284
)
285+
_utils.logger.debug(f"{unique_id}: failed response created")
286+
287+
else:
288+
still_processing.append(unique_id)
283289

284290
if updations:
285291
_utils.MAIN_INDEX.update(updations)
286292

293+
if still_processing:
294+
_utils.logger.debug(f"Still processing: {still_processing}")
295+
287296
return all_responses

fastdeploy/_loop.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,16 @@ def process_batch(predictor, input_batch, optimal_batch_size):
7575
return results, last_predictor_success, received_at, predicted_at
7676

7777

78-
def fetch_batch(main_index, predictor_sequence, optimal_batch_size):
78+
def fetch_batch(
79+
main_index,
80+
predictor_sequence,
81+
optimal_batch_size,
82+
max_wait_time_for_batch_collection,
83+
):
7984
unique_id_wise_input_count = {}
8085
input_batch = []
8186
current_batch_length = 0
87+
batch_collection_started_at = time.time()
8288

8389
while current_batch_length < optimal_batch_size:
8490
to_process = main_index.search(
@@ -98,9 +104,6 @@ def fetch_batch(main_index, predictor_sequence, optimal_batch_size):
98104
},
99105
)
100106

101-
if not to_process: # No more items to process
102-
break
103-
104107
for unique_id, data in to_process.items():
105108
outputs = data[f"{predictor_sequence - 1}.outputs"]
106109

@@ -110,6 +113,22 @@ def fetch_batch(main_index, predictor_sequence, optimal_batch_size):
110113
input_batch.extend(outputs)
111114
current_batch_length += input_count
112115

116+
if current_batch_length == 0:
117+
time.sleep(max_wait_time_for_batch_collection / 2)
118+
continue
119+
120+
elif (
121+
time.time() - batch_collection_started_at
122+
< max_wait_time_for_batch_collection
123+
and current_batch_length / optimal_batch_size < 0.9
124+
):
125+
time.sleep(max_wait_time_for_batch_collection / 2)
126+
continue
127+
128+
else:
129+
# finished collecting batch
130+
break
131+
113132
return unique_id_wise_input_count, input_batch
114133

115134

@@ -156,7 +175,7 @@ def start_loop(
156175

157176
optimal_batch_size = predictor_info["optimal_batch_size"]
158177
time_per_example = predictor_info["time_per_example"]
159-
max_wait_time_for_batch_collection = max(0.003, time_per_example * 0.25)
178+
max_wait_time_for_batch_collection = max(0.003, time_per_example * 0.51)
160179

161180
_utils.logger.info(
162181
f"""{predictor_name}
@@ -167,23 +186,27 @@ def start_loop(
167186
"""
168187
)
169188

170-
last_batch_collection_started_at = 0
171-
172189
while True:
173190
"""
174191
Set timedout_in_queue to True for all the predictions that have been in the queue for more than timeout_time seconds
175192
and delete older than 7 seconds predictions that have finished prediction
176193
"""
177194

178-
_utils.MAIN_INDEX.search(
195+
timedout_in_queue_unique_ids = _utils.MAIN_INDEX.search(
179196
query={
180197
"-1.predicted_at": 0,
181198
"-1.received_at": {"$lt": time.time() - timeout_time},
182199
"timedout_in_queue": {"$ne": True},
183200
},
184201
update={"timedout_in_queue": True},
202+
select_keys=[],
185203
)
186204

205+
if timedout_in_queue_unique_ids:
206+
_utils.logger.warning(
207+
f"{_utils.MAIN_INDEX.count()} in queue, set timedout_in_queue to True for {list(timedout_in_queue_unique_ids)} unique_ids"
208+
)
209+
187210
_utils.MAIN_INDEX.delete(
188211
query={
189212
"$and": [
@@ -194,23 +217,12 @@ def start_loop(
194217
)
195218

196219
unique_id_wise_input_count, input_batch = fetch_batch(
197-
_utils.MAIN_INDEX, predictor_sequence, optimal_batch_size
220+
_utils.MAIN_INDEX,
221+
predictor_sequence,
222+
optimal_batch_size,
223+
max_wait_time_for_batch_collection,
198224
)
199225

200-
current_batch_length = len(input_batch)
201-
202-
if current_batch_length == 0:
203-
time.sleep(max_wait_time_for_batch_collection)
204-
continue
205-
206-
if (
207-
time.time() - last_batch_collection_started_at
208-
< max_wait_time_for_batch_collection
209-
and current_batch_length / optimal_batch_size < 0.9
210-
):
211-
time.sleep(max_wait_time_for_batch_collection / 2)
212-
continue
213-
214226
_utils.logger.debug(f"Processing batch {unique_id_wise_input_count}")
215227

216228
results, last_predictor_success, received_at, predicted_at = process_batch(
@@ -223,11 +235,12 @@ def start_loop(
223235
last_predictor_success,
224236
received_at,
225237
predicted_at,
226-
current_batch_length,
238+
len(input_batch),
227239
)
228240
_utils.MAIN_INDEX.update(unique_id_wise_results)
241+
229242
_utils.logger.debug(
230-
f"Updated results predictor {predictor_sequence}: list({unique_id_wise_results})"
243+
f"Updated results predictor {predictor_sequence}: {list(unique_id_wise_results)}"
231244
)
232245

233246
last_batch_collection_started_at = time.time()

fastdeploy/_rest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
class AsyncResponseHandler:
27-
def __init__(self, check_interval=0.01):
27+
def __init__(self, check_interval=0.003):
2828
self.pending_requests = {}
2929
self.check_interval = check_interval
3030
self.lock = threading.Lock()
@@ -75,6 +75,9 @@ def _response_checker(self):
7575
input_types.append(data["input_type"])
7676

7777
if unique_ids:
78+
_utils.logger.debug(
79+
f"Checking responses for unique_ids: {unique_ids}"
80+
)
7881
try:
7982
responses = self.infer.get_responses_for_unique_ids(
8083
unique_ids=unique_ids,

fastdeploy/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
logging.basicConfig(
44
format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
55
datefmt="%Y-%m-%d:%H:%M:%S",
6-
level=logging.INFO,
6+
level=logging.DEBUG,
77
)
88

99
logger = logging.getLogger(__name__)

0 commit comments

Comments
 (0)