Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make process_prompt Cancellable Outside Downloads #407

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions exo/download/hf/hf_shard_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

async def ensure_shard(self, shard: Shard) -> Path:
"""Ensure a shard is downloaded and return its path. Downloads are protected from cancellation."""
if shard in self.completed_downloads:
if DEBUG >= 2: print(f"Using completed download for {shard}")
return self.completed_downloads[shard]

if self.quick_check:
repo_root = get_repo_root(shard.model_id)
snapshots_dir = repo_root/"snapshots"
Expand All @@ -29,10 +32,13 @@ async def ensure_shard(self, shard: Shard) -> Path:
most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
return most_recent_dir

# If a download on this shard is already in progress, keep that one
for active_shard in self.active_downloads:
if active_shard == shard:
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
# Check for active download
if shard in self.active_downloads:
if DEBUG >= 2: print(f"Using existing download for {shard}")
try:
return await self.active_downloads[shard]
except asyncio.CancelledError:
if DEBUG >= 2: print(f"Ignoring cancellation for existing download of {shard}")
return await self.active_downloads[shard]

# Cancel any downloads for this model_id on a different shard
Expand All @@ -50,17 +56,42 @@ async def ensure_shard(self, shard: Shard) -> Path:
traceback.print_exc()
self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}

# Start new download
download_task = asyncio.create_task(self._download_shard(shard))
# Start new protected download
event = asyncio.Event()
result = None
error = None

async def protected_download():
nonlocal result, error
try:
if DEBUG >= 2: print(f"Starting protected download for {shard}")
result = await self._download_shard(shard)
self.completed_downloads[shard] = result
if DEBUG >= 2: print(f"Download completed for {shard}: {result}")
return result
except Exception as e:
if DEBUG >= 2: print(f"Error in download for {shard}: {e}")
error = e
raise
finally:
event.set()

download_task = asyncio.create_task(protected_download())
self.active_downloads[shard] = download_task

try:
path = await download_task
self.completed_downloads[shard] = path
return path
if DEBUG >= 2: print(f"Waiting for download to complete for {shard}")
try:
return await download_task
except asyncio.CancelledError:
if DEBUG >= 2: print(f"Ignoring cancellation and waiting for download to complete for {shard}")
await event.wait()
if error:
raise error
return result
finally:
# Ensure the task is removed even if an exception occurs
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
if shard in self.active_downloads:
if DEBUG >= 2: print(f"Cleaning up download task for {shard}")
if shard in self.active_downloads and self.active_downloads[shard] is download_task:
self.active_downloads.pop(shard)

async def _download_shard(self, shard: Shard) -> Path:
Expand Down
135 changes: 135 additions & 0 deletions test/test_hf_shard_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import pytest
import asyncio
from pathlib import Path
from unittest import mock
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.shard import Shard
from exo import DEBUG

class MockPath:
def __init__(self, exists=True):
self.exists = lambda: exists
self.iterdir = lambda: []

def __truediv__(self, other):
return self

mock_was_cancelled = False

async def mock_download_shard(self, shard):
"""Mock the _download_shard method to simulate a long download"""
global mock_was_cancelled
try:
if DEBUG >= 2: print(f"Starting mock download for {shard}")
await asyncio.sleep(0.5) # Simulate download time
if DEBUG >= 2: print(f"Mock download completed for {shard}")
return Path("/mock/download/path")
except asyncio.CancelledError:
mock_was_cancelled = True
if DEBUG >= 2: print(f"Mock download was cancelled for {shard}, completing anyway")
# Continue despite cancellation
await asyncio.sleep(0.5)
if DEBUG >= 2: print(f"Mock download completed after cancellation for {shard}")
return Path("/mock/download/path")

@pytest.mark.asyncio
async def test_download_protection():
"""Test that downloads are protected from cancellation"""
global mock_was_cancelled
mock_was_cancelled = False

with mock.patch('exo.download.hf.hf_shard_download.get_repo_root', return_value=MockPath(exists=False)), \
mock.patch.object(HFShardDownloader, '_download_shard', mock_download_shard):

downloader = HFShardDownloader()
shard = Shard(model_id="test-model", start_layer=0, end_layer=1, n_layers=1)

# Create a future we'll use to control the test flow
download_complete = asyncio.Future()

async def do_download():
try:
result = await downloader.ensure_shard(shard)
download_complete.set_result(result)
except Exception as e:
if not download_complete.done():
download_complete.set_exception(e)

# Start the download
task = asyncio.create_task(do_download())

# Give it a moment to start
await asyncio.sleep(0.1)

# Try to cancel the task
if DEBUG >= 2: print("Attempting to cancel download task")
task.cancel()

# Wait for result with timeout
try:
result = await asyncio.wait_for(download_complete, timeout=2.0)
assert isinstance(result, Path), "Should return a Path"
assert shard not in downloader.active_downloads, "Download task should be cleaned up"
assert shard in downloader.completed_downloads, "Download should be marked as completed"
if DEBUG >= 2: print("Download completed successfully despite cancellation")
except asyncio.TimeoutError:
pytest.fail("Download did not complete in time")
except Exception as e:
pytest.fail(f"Download failed with error: {e}")

@pytest.mark.asyncio
async def test_multiple_downloads():
"""Test handling multiple downloads for the same shard"""
with mock.patch('exo.download.hf.hf_shard_download.get_repo_root', return_value=MockPath(exists=False)), \
mock.patch.object(HFShardDownloader, '_download_shard', mock_download_shard):

downloader = HFShardDownloader()
shard = Shard(model_id="test-model", start_layer=0, end_layer=1, n_layers=1)

# Start both downloads with a small delay between them
if DEBUG >= 2: print("Starting first download")
download1 = asyncio.create_task(downloader.ensure_shard(shard))

await asyncio.sleep(0.2) # Give first download time to start

if DEBUG >= 2: print("Starting second download")
download2 = asyncio.create_task(downloader.ensure_shard(shard))

# Wait for both downloads to complete
if DEBUG >= 2: print("Waiting for downloads to complete")

path1 = await download1
if DEBUG >= 2: print(f"First download completed with path: {path1}")

path2 = await download2
if DEBUG >= 2: print(f"Second download completed with path: {path2}")

# Verify results
assert isinstance(path1, Path), "First download should return a Path"
assert isinstance(path2, Path), "Second download should return a Path"
assert path1 == path2, "Multiple downloads should return same path"
assert shard not in downloader.active_downloads, "Download tasks should be cleaned up"
assert shard in downloader.completed_downloads, "Download should be marked as completed"

@pytest.mark.asyncio
async def test_download_error_handling():
"""Test that errors during download are handled properly"""

async def mock_download_error(self, shard):
await asyncio.sleep(0.1) # Simulate some work
raise Exception("Download failed")

with mock.patch('exo.download.hf.hf_shard_download.get_repo_root', return_value=MockPath(exists=False)), \
mock.patch.object(HFShardDownloader, '_download_shard', mock_download_error):

downloader = HFShardDownloader()
shard = Shard(model_id="test-model", start_layer=0, end_layer=1, n_layers=1)

with pytest.raises(Exception) as exc_info:
await downloader.ensure_shard(shard)

assert str(exc_info.value) == "Download failed"
assert shard not in downloader.active_downloads, "Failed download should be cleaned up"

if __name__ == "__main__":
pytest.main([__file__, "-v"])