diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index 802a1a0b..a87d7da5 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -88,8 +88,21 @@ async def train( ) -> AsyncIterator[dict[str, float]]: # Get the packed tensors from disk packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) - # Wait for existing batches to finish - await self.results_queue.join() + # Wait for existing batches to finish, with timeout to prevent deadlock + try: + await asyncio.wait_for(self.results_queue.join(), timeout=10.0) + except asyncio.TimeoutError: + # Recover from deadlock by draining queue + drained = 0 + while True: + try: + self.results_queue.get_nowait() + self.results_queue.task_done() + drained += 1 + except asyncio.QueueEmpty: + break + if verbose and drained > 0: + print(f"Warning: Drained {drained} lingering result(s) from queue") # If we haven't already, start the training task if self._train_task is None: self._train_task = asyncio.create_task(