Skip to content

feat(model): Implement safe model deletion with trash functionality #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/art/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import dev
from .local import LocalBackend
from .model import Model, TrainableModel
from .model import Model, TrainableModel, list_trash, restore_from_trash, empty_trash, delete_model
from .trajectories import TrajectoryGroup
from .types import TrainConfig
from .utils.deploy_model import LoRADeploymentProvider
Expand All @@ -35,6 +35,7 @@ def is_port_available(port: int) -> bool:
)
return


# Reset the custom __new__ and __init__ methods for TrajectoryGroup
def __new__(cls, *args: Any, **kwargs: Any) -> TrajectoryGroup:
return pydantic.BaseModel.__new__(cls)
Expand All @@ -45,6 +46,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
TrajectoryGroup.__new__ = __new__ # type: ignore
TrajectoryGroup.__init__ = __init__


backend = LocalBackend()
app = FastAPI()

Expand Down Expand Up @@ -147,3 +149,38 @@ async def _experimental_deploy(
)

uvicorn.run(app, host=host, port=port, loop="asyncio")

@app.command()
def delete_model(project_name: str, model_name: str):
"""Move a model to the trash."""
try:
delete_model(project_name, model_name)
print(f"Model '{model_name}' in project '{project_name}' moved to trash.")
except FileNotFoundError as e:
print(e)

@app.command()
def list_trash(project_name: str):
"""List models in the trash."""
trashed_models = list_trash(project_name)
if not trashed_models:
print("Trash is empty.")
else:
print("Models in trash:")
for model_name in trashed_models:
print(f"- {model_name}")

@app.command()
def restore_model(project_name: str, model_name: str):
"""Restore a model from the trash."""
try:
restore_from_trash(project_name, model_name)
print(f"Model '{model_name}' in project '{project_name}' restored.")
except FileNotFoundError as e:
print(e)

@app.command()
def empty_trash(project_name: str):
"""Permanently delete all models in the trash."""
empty_trash(project_name)
print("Trash has been emptied.")
66 changes: 58 additions & 8 deletions src/art/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import httpx
import os
import shutil
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pydantic import BaseModel
from typing import TYPE_CHECKING, cast, Generic, Iterable, Optional, overload, TypeVar
Expand All @@ -8,6 +10,7 @@
from .openai import patch_openai
from .trajectories import Trajectory, TrajectoryGroup
from .types import TrainConfig
from .utils.output_dirs import get_default_art_path, get_models_dir, get_model_dir

if TYPE_CHECKING:
from art.backend import Backend
Expand All @@ -30,22 +33,26 @@ class Model(

You can instantiate a prompted model like so:

``python model = art.Model(
name="gpt-4.1", project="my-project",
```python
model = art.Model(
name="gpt-4.1",
project="my-project",
inference_api_key=os.getenv("OPENAI_API_KEY"),
inference_base_url="https://api.openai.com/v1/",
)
``
```

Or, if you're pointing at OpenRouter:

``python model = art.Model(
name="gemini-2.5-pro", project="my-project",
```python
model = art.Model(
name="gemini-2.5-pro",
project="my-project",
inference_api_key=os.getenv("OPENROUTER_API_KEY"),
inference_base_url="https://openrouter.ai/api/v1",
inference_model_name="google/gemini-2.5-pro-preview-03-25",
)
``
```

For trainable (`art.TrainableModel`) models the inference values will be
populated automatically by `model.register(api)` so you generally don't need
Expand Down Expand Up @@ -227,9 +234,9 @@ async def log(
)


# ---------------------------------------------------------------------------
# --------------------------------------------------------------------------
# Trainable models
# ---------------------------------------------------------------------------
# --------------------------------------------------------------------------


class TrainableModel(Model[ModelConfig], Generic[ModelConfig]):
Expand Down Expand Up @@ -354,3 +361,46 @@ async def train(
self, list(trajectory_groups), config, _config or {}, verbose
):
pass


def _get_trash_dir(project_name: str) -> str:
art_path = get_default_art_path()
return os.path.join(art_path, "trash", project_name, "models")


def move_to_trash(project_name: str, model_name: str):
model_dir = get_model_dir(Model(name=model_name, project=project_name, config=None))
if not os.path.exists(model_dir):
raise FileNotFoundError(f"Model '{model_name}' not found in project '{project_name}'.")

trash_dir = _get_trash_dir(project_name)
os.makedirs(trash_dir, exist_ok=True)

shutil.move(model_dir, os.path.join(trash_dir, model_name))


def list_trash(project_name: str) -> list[str]:
trash_dir = _get_trash_dir(project_name)
if not os.path.exists(trash_dir):
return []
return os.listdir(trash_dir)


def restore_from_trash(project_name: str, model_name: str):
trash_dir = _get_trash_dir(project_name)
trashed_model_path = os.path.join(trash_dir, model_name)
if not os.path.exists(trashed_model_path):
raise FileNotFoundError(f"Model '{model_name}' not found in trash.")

models_dir = get_models_dir(project_name)
shutil.move(trashed_model_path, os.path.join(models_dir, model_name))


def empty_trash(project_name: str):
trash_dir = _get_trash_dir(project_name)
if os.path.exists(trash_dir):
shutil.rmtree(trash_dir)


def delete_model(project_name: str, model_name: str):
move_to_trash(project_name, model_name)
3 changes: 3 additions & 0 deletions src/art/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def __iter__(self) -> Iterator[Trajectory]:
def __len__(self) -> int:
return len(self.trajectories)



@overload
def __new__(
cls,
Expand Down Expand Up @@ -226,3 +228,4 @@ def __await__(self):
exceptions=exceptions,
)
return group