Dev #125

12 changes: 8 additions & 4 deletions safetytooling/apis/batch_api.py
@@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Literal
from typing import Literal, Callable

from safetytooling.apis.inference.anthropic import ANTHROPIC_MODELS, AnthropicModelBatch
from safetytooling.apis.inference.cache_manager import get_cache_manager
@@ -110,6 +110,7 @@ async def __call__(
use_cache: bool = True,
max_tokens: int | None = None,
chunk: int | None = None,
batch_id_callback: Callable[[str], None] | None = None,
**kwargs,
) -> tuple[list[LLMResponse], str]:
"""Make batch API requests for the specified model and prompts.
@@ -121,6 +122,7 @@
use_cache: Whether to use caching. Ignored if no_cache was set during initialization.
max_tokens: Maximum number of tokens to generate per prompt. Setting a different max_tokens for each prompt is not yet supported here, although it is in principle supported by the underlying API.
chunk: Maximum number of prompts to send in a single batch. If None, will use a default value based on the model.
batch_id_callback: Optional callback invoked with the batch ID (a string) as soon as each batch is created, e.g. to persist batch IDs for later retrieval.
**kwargs: Additional arguments to pass to the model. "seed" should be used for preventing cached responses from being used.

Returns:
@@ -182,13 +184,14 @@ async def __call__(
else:
chunks = [uncached_prompts[i : i + chunk] for i in range(0, len(uncached_prompts), chunk)]

async def process_chunk(chunk_prompts) -> tuple[list[LLMResponse], str]:
async def process_chunk(chunk_prompts, callback_fn: Callable[[str], None] | None) -> tuple[list[LLMResponse], str]:
print(f"Batch kwargs: {kwargs}")
if model_id in ANTHROPIC_MODELS:
return await self._anthropic_batch(
model_id=model_id,
prompts=chunk_prompts,
max_tokens=kwargs.get("max_tokens", None),
batch_id_callback=callback_fn,
**{k: v for k, v in kwargs.items() if k != "seed" and k != "max_tokens"},
)
elif model_id in GPT_CHAT_MODELS:
@@ -197,6 +200,7 @@ async def process_chunk(chunk_prompts) -> tuple[list[LLMResponse], str]:
prompts=chunk_prompts,
log_dir=log_dir,
max_tokens=kwargs.get("max_tokens", None),
#batch_id_callback=callback_fn,
**{k: v for k, v in kwargs.items() if k != "max_tokens"},
)
else:
@@ -207,7 +211,7 @@ async def process_chunk(chunk_prompts) -> tuple[list[LLMResponse], str]:
all_responses = []
all_batch_ids = []

results = await asyncio.gather(*[process_chunk(chunk) for chunk in chunks])
results = await asyncio.gather(*[process_chunk(chunk, batch_id_callback) for chunk in chunks])

for chunk_prompts, (responses, batch_id) in zip(chunks, results):
all_responses.extend(responses)
@@ -221,7 +225,7 @@ async def process_chunk(chunk_prompts) -> tuple[list[LLMResponse], str]:

else:
# Process all uncached prompts in a single batch
responses, batch_id = await process_chunk(uncached_prompts)
responses, batch_id = await process_chunk(uncached_prompts, batch_id_callback)
all_responses = responses
all_batch_ids = [batch_id]

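Usage sketch (illustrative, not part of the diff): one way a caller could use the new `batch_id_callback` parameter is to persist each batch ID as soon as the batch is created, so an interrupted run can still be tracked or resumed from the provider console. The prompt construction and model name are placeholders, and the sketch assumes API keys are already configured in the environment.

```python
import asyncio

from safetytooling.apis.batch_api import BatchInferenceAPI
from safetytooling.data_models import ChatMessage, MessageRole, Prompt


def record_batch_id(batch_id: str) -> None:
    # Append each newly created batch ID to a local file for later inspection.
    with open("batch_ids.txt", "a") as f:
        f.write(batch_id + "\n")


async def main() -> None:
    api = BatchInferenceAPI()
    prompts = [Prompt(messages=[ChatMessage(role=MessageRole.user, content="Hello")])]
    responses, batch_id = await api(
        model_id="claude-3-5-sonnet-20241022",  # placeholder Anthropic model ID
        prompts=prompts,
        max_tokens=64,
        batch_id_callback=record_batch_id,
    )
    print(batch_id, len(responses))


asyncio.run(main())
```
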
46 changes: 28 additions & 18 deletions safetytooling/apis/finetuning/together/run.py
@@ -5,6 +5,7 @@
import json
import logging
import os
import subprocess
from pathlib import Path

import simple_parsing
@@ -30,7 +31,7 @@ async def upload_finetuning_file(
LOGGER.info(f"Uploading {file_path} to Together...")
if not file_path.exists():
raise FileNotFoundError(f"File {file_path} does not exist")
if file_path.suffix != ".jsonl":
if file_path.suffix not in [".jsonl", ".parquet"]:
raise ValueError(f"File {file_path} must have a .jsonl extension")
response = client.files.upload(file=str(file_path))
file_id = response.id
@@ -83,13 +84,18 @@ async def main(cfg: TogetherFTConfig, verbose: bool = True):
validation_file=val_file_id,
model=cfg.model,
n_epochs=cfg.n_epochs,
n_checkpoints=cfg.n_epochs,
batch_size=cfg.batch_size,
learning_rate=cfg.learning_rate,
lora=cfg.lora,
suffix=cfg.suffix,
lora_r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
lora_dropout=cfg.lora_dropout,
lora_trainable_modules=cfg.lora_trainable_modules,
wandb_api_key=os.environ.get("WANDB_API_KEY"),
wandb_project_name=cfg.wandb_project_name,
from_checkpoint=cfg.from_checkpoint,
)
LOGGER.info(f"Started fine-tuning job: {ft_job.id}")

@@ -113,24 +119,26 @@ async def main(cfg: TogetherFTConfig, verbose: bool = True):
)

# Upload files to wandb
upload_file_to_wandb(
file_path=cfg.train_file,
name=train_file_id,
type="together-finetune-training-file",
description="Training file for finetuning",
wandb_run=wrun,
)

if cfg.val_file is not None:
upload_file_to_wandb(
file_path=cfg.val_file,
name=val_file_id,
type="together-finetune-validation-file",
description="Validation file for finetuning",
wandb_run=wrun,
)
# Currently turned off
#upload_file_to_wandb(
#file_path=cfg.train_file,
#name=train_file_id,
#type="together-finetune-training-file",
#description="Training file for finetuning",
#wandb_run=wrun,
#)

#if cfg.val_file is not None:
#upload_file_to_wandb(
#file_path=cfg.val_file,
#name=val_file_id,
#type="together-finetune-validation-file",
#description="Validation file for finetuning",
#wandb_run=wrun,
#)

LOGGER.info("Waiting for fine-tuning job to finish...")
status = None
while True:
status = client.fine_tuning.retrieve(ft_job.id) # https://docs.together.ai/reference/get_fine-tunes-id
if wrun is not None:
@@ -151,7 +159,7 @@ async def main(cfg: TogetherFTConfig, verbose: bool = True):
LOGGER.info(f"Fine-tuning job finished with id <ft_id>{ft_job.id}</ft_id>")

if cfg.save_folder is not None:
assert ft_job.output_name is not None, "Output name is None"
ft_job.output_name = status.output_name
ft_job.output_name = ft_job.output_name.replace("/", "|")

if cfg.save_folder.endswith("/"):
@@ -188,7 +196,9 @@ class TogetherFTConfig(FinetuneConfig):
lora_r: int = 8
lora_alpha: int = 8
lora_dropout: float = 0.0
lora_trainable_modules: str = "all-linear"
suffix: str = "" # Together suffix to append to the model name
from_checkpoint: str | None = None

save_model: bool = False

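Illustrative sketch (not part of the diff) of how the new config fields might be set when launching a Together fine-tune from Python. File paths and the model name are placeholders, and any FinetuneConfig fields not shown are assumed to have usable defaults.

```python
import asyncio
from pathlib import Path

from safetytooling.apis.finetuning.together.run import TogetherFTConfig, main

cfg = TogetherFTConfig(
    train_file=Path("data/train.jsonl"),  # .jsonl or .parquet are now both accepted
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    n_epochs=1,
    lora=True,
    lora_trainable_modules="all-linear",  # new field, defaults to "all-linear"
    from_checkpoint=None,                 # or a prior job/checkpoint ID to resume from
)
asyncio.run(main(cfg, verbose=True))
```
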
10 changes: 10 additions & 0 deletions safetytooling/apis/inference/anthropic.py
@@ -4,6 +4,7 @@
import time
from pathlib import Path
from traceback import format_exc
from typing import Callable

import anthropic.types
from anthropic import AsyncAnthropic
@@ -274,6 +275,7 @@ async def __call__(
prompts: list[Prompt],
max_tokens: int,
log_dir: Path | None = None,
batch_id_callback: Callable[[str], None] | None = None,
**kwargs,
) -> tuple[list[LLMResponse], str]:
assert max_tokens is not None, "Anthropic batch API requires max_tokens to be specified"
@@ -286,6 +288,14 @@
requests = self.prompts_to_requests(model_id, prompts, max_tokens, **kwargs)
batch_response = self.create_message_batch(requests=requests)
batch_id = batch_response.id

# Invoke the callback with the batch_id
if batch_id_callback and batch_id:
try:
batch_id_callback(batch_id)
except Exception as e:
LOGGER.error(f"Error in batch_id_callback for batch {batch_id}: {e}")

if log_dir is not None:
log_file = log_dir / f"batch_id_{batch_id}.json"
log_file.parent.mkdir(parents=True, exist_ok=True)
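
Design note: the callback is wrapped in a try/except and failures are only logged, so a buggy `batch_id_callback` cannot abort an otherwise successful batch submission; when `log_dir` is provided, the batch ID is still written to disk afterwards as before.
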
55 changes: 30 additions & 25 deletions safetytooling/apis/inference/api.py
@@ -41,7 +41,6 @@
from .openai.moderation import OpenAIModerationModel
from .openai.s2s import OpenAIS2SModel, S2SRateLimiter
from .openai.utils import COMPLETION_MODELS, GPT_CHAT_MODELS, S2S_MODELS
from .openrouter import OPENROUTER_MODELS, OpenRouterChatModel
from .opensource.batch_inference import BATCHED_MODELS, BatchModel
from .runpod_vllm import VLLM_MODELS, VLLMChatModel
from .together import TOGETHER_MODELS, TogetherChatModel
@@ -54,6 +53,8 @@
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
DEEPSEEK_MODELS = {"deepseek-chat", "deepseek-reasoner"}

OPENROUTER_MODELS = {"meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.3-70B-Instruct"}


class InferenceAPI:
"""
@@ -86,8 +87,9 @@ def __init__(
anthropic_api_key: str | None = None,
openai_api_key: str | None = None,
print_prompt_and_response: bool = False,
use_vllm_if_model_not_found: bool = False,
use_provider_if_model_not_found: str | None = None,
vllm_base_url: str = "http://localhost:8000/v1/chat/completions",
vllm_api_key: str | None = None,
no_cache: bool = False,
oai_embedding_batch_size: int = 2048,
):
@@ -104,16 +106,17 @@ def __init__(
self.anthropic_num_threads = anthropic_num_threads
self.openai_fraction_rate_limit = openai_fraction_rate_limit
self.openai_base_url = openai_base_url
self.openrouter_num_threads = openrouter_num_threads
# limit openai api calls to stop async jamming
self.deepseek_semaphore = asyncio.Semaphore(deepseek_num_threads)
self.openai_semaphore = asyncio.Semaphore(openai_num_threads)
self.gemini_semaphore = asyncio.Semaphore(gemini_num_threads)
self.openai_s2s_semaphore = asyncio.Semaphore(openai_s2s_num_threads)
self.openrouter_semaphore = asyncio.Semaphore(openrouter_num_threads)
self.gemini_recitation_rate_check_volume = gemini_recitation_rate_check_volume
self.gemini_recitation_rate_threshold = gemini_recitation_rate_threshold
self.gray_swan_num_threads = gray_swan_num_threads
self.together_num_threads = together_num_threads
self.openrouter_num_threads = openrouter_num_threads
self.huggingface_num_threads = huggingface_num_threads
self.vllm_num_threads = vllm_num_threads
self.deepseek_num_threads = deepseek_num_threads
@@ -124,8 +127,9 @@ def __init__(
self.n_calls = 0
self.gpt_4o_rate_limiter = S2SRateLimiter(self.gpt4o_s2s_rpm_cap)
self.print_prompt_and_response = print_prompt_and_response
self.use_vllm_if_model_not_found = use_vllm_if_model_not_found
self.use_provider_if_model_not_found = use_provider_if_model_not_found
self.vllm_base_url = vllm_base_url
self.vllm_api_key = vllm_api_key
# can also set via env var
if os.environ.get("SAFETYTOOLING_PRINT_PROMPTS", "false").lower() == "true":
self.print_prompt_and_response = True
@@ -167,6 +171,13 @@ def __init__(
openai_api_key=openai_api_key,
)

self._openrouter = OpenAIChatModel(
frac_rate_limit=self.openai_fraction_rate_limit,
prompt_history_dir=self.prompt_history_dir,
base_url="https://openrouter.ai/api/v1",
openai_api_key=os.environ.get("OPENROUTER_API_KEY", None),
)

self._openai_moderation = OpenAIModerationModel()

self._openai_embedding = OpenAIEmbeddingModel(batch_size=oai_embedding_batch_size)
@@ -196,12 +207,6 @@ def __init__(
api_key=os.environ.get("TOGETHER_API_KEY", None),
)

self._openrouter = OpenRouterChatModel(
num_threads=self.openrouter_num_threads,
prompt_history_dir=self.prompt_history_dir,
api_key=os.environ.get("OPENROUTER_API_KEY", None),
)

self._gemini_vertex = GeminiVertexAIModel(prompt_history_dir=self.prompt_history_dir)
self._gemini_genai = GeminiModel(
prompt_history_dir=self.prompt_history_dir,
@@ -214,6 +219,7 @@ def __init__(
num_threads=vllm_num_threads,
prompt_history_dir=self.prompt_history_dir,
vllm_base_url=self.vllm_base_url,
vllm_api_key=self.vllm_api_key,
)

# DeepSeek uses the OpenAI API
@@ -254,9 +260,9 @@ def __init__(
"batch_gpu": self._batch_models,
"openai_s2s": self._openai_s2s,
"together": self._together,
"openrouter": self._openrouter,
"vllm": self._vllm,
"deepseek": self._deepseek,
"openrouter": self._openrouter,
}

@classmethod
Expand Down Expand Up @@ -305,16 +311,16 @@ def model_id_to_class(
return self._openai_s2s
elif model_id in TOGETHER_MODELS or model_id.startswith("scalesafetyresearch"):
return self._together
elif model_id in OPENROUTER_MODELS or model_id.startswith("openrouter/"):
return self._openrouter
elif model_id in VLLM_MODELS:
return self._vllm
elif model_id in DEEPSEEK_MODELS:
return self._deepseek
elif self.use_vllm_if_model_not_found:
return self._vllm
elif model_id in OPENROUTER_MODELS:
return self._openrouter
elif self.use_provider_if_model_not_found is not None:
return self.provider_to_class[self.use_provider_if_model_not_found]
raise ValueError(
f"Invalid model id: {model_id}. Pass openai_completion, openai_chat, anthropic, huggingface, gemini, batch_gpu, openai_s2s, together, openrouter, vllm, or deepseek to force a provider."
f"Invalid model id: {model_id}. Pass openai_completion, openai_chat, anthropic, huggingface, gemini, batch_gpu, openai_s2s, together, vllm, or deepseek to force a provider."
)

async def check_rate_limit(self, wait_time=60):
@@ -379,7 +385,7 @@ async def __call__(
audio_out_dir: str | Path = None,
print_prompt_and_response: bool = False,
n: int = 1,
max_attempts_per_api_call: int = 10,
max_attempts_per_api_call: int = 40,
num_candidates_per_completion: int = 1,
is_valid: Callable[[str], bool] = lambda _: True,
insufficient_valids_behaviour: Literal["error", "continue", "pad_invalids", "retry"] = "retry",
@@ -416,9 +422,12 @@ async def __call__(
"""
if self.no_cache:
use_cache = False

model_class = self.model_id_to_class(model_id, gemini_use_vertexai, force_provider)

#if isinstance(model_class, VLLMChatModel):
# use_cache = False

num_candidates = num_candidates_per_completion * n

assert "top_logprobs" not in kwargs, "top_logprobs is not supported, pass an integer with `logprobs` instead"
@@ -471,11 +480,7 @@ async def __call__(
# If prompt is a single prompt and there is no cached result, simply return the original prompt for regular processing
prompt = prompt

if (
isinstance(model_class, AnthropicChatModel)
or isinstance(model_class, HuggingFaceModel)
or isinstance(model_class, OpenRouterChatModel)
):
if isinstance(model_class, AnthropicChatModel) or isinstance(model_class, HuggingFaceModel):
if isinstance(model_class, HuggingFaceModel):
kwargs["model_url"] = huggingface_model_url
# Anthropic chat doesn't support generating multiple candidates at once, so we have to do it manually
@@ -564,12 +569,12 @@ async def __call__(
)
else:
# At this point, the request should be for DeepSeek or OpenAI, which use the same API.
expected_chat_models = [OpenAIChatModel, OpenAICompletionModel]
expected_chat_models = [OpenAIChatModel, OpenAICompletionModel, VLLMChatModel]
if not any(isinstance(model_class, model) for model in expected_chat_models):
raise RuntimeError(
f"Got unexpected ChatModel class: {model_class}. Make sure to implement logic to handle InferenceAPI.__call__ for your custom ChatModel. Or, add your ChatModel class to expected_chat_models above."
)
if model_class.base_url == DEEPSEEK_BASE_URL:
if not isinstance(model_class, VLLMChatModel) and model_class.base_url == DEEPSEEK_BASE_URL:
candidate_responses = []
for _ in range(num_candidates): # DeepSeek doesn't support multiple completions.
async with self.deepseek_semaphore:
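
Sketch of the new fallback behaviour (illustrative, not part of the diff): with `use_provider_if_model_not_found`, an unrecognised model ID can be routed to any key of `provider_to_class`, for example a local vLLM server. The URL and API key below are placeholders.

```python
from safetytooling.apis.inference.api import InferenceAPI

api = InferenceAPI(
    use_provider_if_model_not_found="vllm",
    vllm_base_url="http://localhost:8000/v1/chat/completions",
    vllm_api_key="EMPTY",  # placeholder; use whatever the vLLM server expects
)
```
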
4 changes: 2 additions & 2 deletions safetytooling/apis/inference/cache_manager.py
@@ -182,7 +182,7 @@ def process_cached_responses(

if cached_result is not None:
cache_file, _ = self.get_cache_file(prompt=individual_prompt, params=params)
LOGGER.info(f"Loaded cache for prompt from {cache_file}")
#LOGGER.info(f"Loaded cache for prompt from {cache_file}")

prop_empty_completions = sum(
1 for response in cached_result.responses if response.completion == ""
@@ -390,7 +390,7 @@ def process_cached_responses(

if cached_result is not None:
cache_dir, _ = self.get_cache_file(prompt=individual_prompt, params=params)
LOGGER.info(f"Loaded cache for prompt from {cache_dir}")
# LOGGER.info(f"Loaded cache for prompt from {cache_dir}")

prop_empty_completions = sum(
1 for response in cached_result.responses if response.completion == ""