diff --git a/README.md b/README.md
index 4a8a987..9e66860 100644
--- a/README.md
+++ b/README.md
@@ -378,7 +378,55 @@ We found some plugins created by community developers. Thanks for their efforts:
 - Replicate Demo & Cloud API. [Replicate-MagicTime](https://replicate.com/camenduru/magictime) (by [@camenduru](https://twitter.com/camenduru)).
 - Jupyter Notebook. [Jupyter-MagicTime](https://github.com/camenduru/MagicTime-jupyter) (by [@ModelsLab](https://modelslab.com/)).
 
-If you find related work, please let us know.
+If you find related work, please let us know.
+
+## 🤖 LLM Provider for Data Preprocessing
+
+The data preprocessing scripts (`data_preprocess/`) support multiple LLM providers for video/frame captioning. By default, **OpenAI GPT-4V** is used, but you can switch to **[MiniMax](https://www.minimaxi.com/)** or any OpenAI-compatible API.
+
+### Using MiniMax
+
+[MiniMax](https://www.minimaxi.com/) provides powerful LLM models (MiniMax-M2.7, MiniMax-M2.5) with an OpenAI-compatible API, supporting both text and vision inputs.
+
+```bash
+# Set your MiniMax API key
+export MINIMAX_API_KEY="your-api-key"
+
+# Frame captioning with MiniMax
+python data_preprocess/step2_1_GPT4V_frame_caption.py \
+    --provider minimax \
+    --image_directories ./step_1 \
+    --output_file ./2_1_gpt_frames_caption.json
+
+# Video captioning (concise) with MiniMax
+python data_preprocess/step3_1_GPT4V_video_caption_concise.py \
+    --provider minimax \
+    --input_file ./2_2_final_useful_gpt_frames_caption.json \
+    --output_file ./3_1_gpt_video_caption.json
+
+# Or edit data_preprocess/run.sh and set PROVIDER="minimax"
+```
+
+### Using a Custom Provider
+
+You can also use any OpenAI-compatible API by specifying `--base_url` and `--model`:
+
+```bash
+python data_preprocess/step2_1_GPT4V_frame_caption.py \
+    --base_url https://your-api.example.com/v1 \
+    --model your-model-name \
+    --api_key your-api-key \
+    --image_directories ./step_1
+```
+
+### Available Provider Arguments
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--provider` | `openai` | LLM provider (`openai` or `minimax`) |
+| `--base_url` | Provider default | Custom API base URL |
+| `--model` | Provider default | Model name |
+| `--api_key` | From env var | API key (or set `OPENAI_API_KEY` / `MINIMAX_API_KEY`) |
 
 ## 🐳 ChronoMagic Dataset
 
 ChronoMagic with 2265 metamorphic time-lapse videos, each accompanied by a detailed caption. We released the subset of ChronoMagic used to train MagicTime. The dataset can be downloaded at [HuggingFace Dataset](https://huggingface.co/datasets/BestWishYsh/ChronoMagic), or you can download it with the following command. Some samples can be found on our [Project Page](https://pku-yuangroup.github.io/MagicTime/).
diff --git a/data_preprocess/llm_provider.py b/data_preprocess/llm_provider.py
new file mode 100644
index 0000000..9d52ad3
--- /dev/null
+++ b/data_preprocess/llm_provider.py
@@ -0,0 +1,115 @@
+"""
+LLM provider configuration for data preprocessing scripts.
+
+Supports multiple LLM providers via OpenAI-compatible APIs:
+- OpenAI (default): GPT-4V, GPT-4o, etc.
+- MiniMax: MiniMax-M2.7, MiniMax-M2.5, etc.
+- Any OpenAI-compatible provider via --base_url and --model
+
+Usage:
+    from llm_provider import add_provider_args, create_client, get_model_name
+
+    # In argument parser setup:
+    add_provider_args(parser)
+
+    # In code:
+    args = parser.parse_args()
+    client = create_client(args)
+    model = get_model_name(args)
+"""
+
+import os
+
+from openai import OpenAI
+
+# Provider presets: base_url, default model, env var for API key
+PROVIDER_PRESETS = {
+    "openai": {
+        "base_url": None,  # OpenAI SDK default
+        "default_model": "gpt-4-vision-preview",
+        "env_key": "OPENAI_API_KEY",
+    },
+    "minimax": {
+        "base_url": "https://api.minimax.io/v1",
+        "default_model": "MiniMax-M2.7",
+        "env_key": "MINIMAX_API_KEY",
+    },
+}
+
+
+def add_provider_args(parser):
+    """Add LLM provider arguments to an argparse parser."""
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="openai",
+        choices=list(PROVIDER_PRESETS.keys()),
+        help="LLM provider to use (default: openai).",
+    )
+    parser.add_argument(
+        "--base_url",
+        type=str,
+        default=None,
+        help="Custom API base URL (overrides provider default).",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+        help="Model name (overrides provider default).",
+    )
+
+
+def _resolve_api_key(args):
+    """Resolve API key from args or environment variables."""
+    # Explicit --api_key takes priority
+    api_key = getattr(args, "api_key", None)
+    if api_key:
+        return api_key
+
+    # Check provider-specific env var
+    provider = getattr(args, "provider", "openai")
+    preset = PROVIDER_PRESETS.get(provider, PROVIDER_PRESETS["openai"])
+    env_key = preset["env_key"]
+    api_key = os.environ.get(env_key)
+    if api_key:
+        return api_key
+
+    # Fallback to OPENAI_API_KEY for any provider
+    return os.environ.get("OPENAI_API_KEY")
+
+
+def create_client(args):
+    """Create an OpenAI-compatible client based on provider args."""
+    provider = getattr(args, "provider", "openai")
+    preset = PROVIDER_PRESETS.get(provider, PROVIDER_PRESETS["openai"])
+
+    base_url = getattr(args, "base_url", None) or preset["base_url"]
+    api_key = _resolve_api_key(args)
+
+    kwargs = {"api_key": api_key}
+    if base_url:
+        kwargs["base_url"] = base_url
+
+    return OpenAI(**kwargs)
+
+
+def get_model_name(args):
+    """Get the model name from args or provider defaults."""
+    model = getattr(args, "model", None)
+    if model:
+        return model
+
+    provider = getattr(args, "provider", "openai")
+    preset = PROVIDER_PRESETS.get(provider, PROVIDER_PRESETS["openai"])
+    return preset["default_model"]
+
+
+def clamp_temperature(temperature, provider="openai"):
+    """Clamp temperature to provider-specific valid range.
+
+    MiniMax accepts temperature in (0.0, 1.0].
+    """
+    if provider == "minimax":
+        return max(0.01, min(temperature, 1.0))
+    return temperature
diff --git a/data_preprocess/run.sh b/data_preprocess/run.sh
index 7e09ef7..f3d2201 100644
--- a/data_preprocess/run.sh
+++ b/data_preprocess/run.sh
@@ -6,6 +6,10 @@ OUTPUT_FOLDER_STEP_1="./step_1"
 API_KEY="XXX"
 NUM_WORKERS=8
 
+# LLM provider: "openai" (default) or "minimax"
+# For MiniMax: set PROVIDER="minimax" and put your MiniMax key in API_KEY above (an explicit --api_key overrides MINIMAX_API_KEY)
+PROVIDER="openai"
+
 # File paths
 FRAME_CAPTION_FILE="./2_1_gpt_frames_caption.json"
 GROUP_FRAMES_FILE="./2_1_temp_group_frames.json"
@@ -22,22 +26,25 @@ FINAL_CSV_FILE="./all_clean_data.csv"
 # Step 1: Extract and resize frames
 python step0_extract_frame_resize.py --input_folder "$INPUT_FOLDER" --output_folder "$OUTPUT_FOLDER_STEP_1"
 
-# Step 2.1: Generate frame captions using GPT-4V
+# Step 2.1: Generate frame captions using LLM (GPT-4V or MiniMax)
 python step2_1_GPT4V_frame_caption.py --api_key "$API_KEY" --num_workers "$NUM_WORKERS" \
-    --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1"
+    --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1" \
+    --provider "$PROVIDER"
 
 # Step 2.2: Preprocess frame captions
 python step2_2_preprocess_frame_caption.py --file_path "$FRAME_CAPTION_FILE" \
    --updated_file_path "$UPDATED_FRAME_CAPTION_FILE" --unmatched_file_path "$UNMATCHED_FRAME_CAPTION_FILE" \
    --unordered_file_path "$UNORDERED_FRAME_CAPTION_FILE" --final_useful_data_file_path "$FINAL_USEFUL_FRAME_CAPTION_FILE"
 
-# Step 3.1: Generate concise video captions using GPT-4V
+# Step 3.1: Generate concise video captions using LLM (GPT-4V or MiniMax)
 python step3_1_GPT4V_video_caption_concise.py --num_workers "$NUM_WORKERS" \
-    --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE"
+    --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" \
+    --provider "$PROVIDER"
 
 # Optional: Generate detailed video captions (uncomment to enable)
 # python step3_1_GPT4V_video_caption_detail.py --num_workers "$NUM_WORKERS" \
-#     --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE"
+#     --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" \
+#     --provider "$PROVIDER"
 
 # Step 3.2: Preprocess video captions
 python step3_2_preprocess_video_caption.py --file_path "$VIDEO_CAPTION_FILE" \
diff --git a/data_preprocess/step2_1_GPT4V_frame_caption.py b/data_preprocess/step2_1_GPT4V_frame_caption.py
index e397080..72e0ef2 100644
--- a/data_preprocess/step2_1_GPT4V_frame_caption.py
+++ b/data_preprocess/step2_1_GPT4V_frame_caption.py
@@ -4,11 +4,12 @@ import base64
 import argparse
 
 from tqdm import tqdm
-from openai import OpenAI
 from threading import Lock
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from tenacity import retry, wait_exponential, stop_after_attempt
 
+from llm_provider import add_provider_args, create_client, get_model_name
+
 txt_prompt = '''
 Suppose you are a data annotator, specialized in generating captions for time-lapse videos. You will be supplied with eight key frames extracted from a video, each with a filename labeled with its position in the video sequence. Your task is to generate a caption for each frame, focusing on the primary subject and integrating all discernible elements.
 Note: These captions should be brief and concise, avoiding redundancy.
@@ -165,8 +166,7 @@ def load_existing_results(file_path):
     return empty_data
 
 @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100))
-def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
-    client = OpenAI(api_key=api_key)
+def call_gpt(prompt, client, model_name="gpt-4-vision-preview"):
     chat_completion = client.chat.completions.create(
         model=model_name,
         messages=[
@@ -180,9 +180,9 @@ def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
     print(chat_completion)
     return chat_completion.choices[0].message.content
 
-def save_output(video_id, prompt, output_file, api_key):
+def save_output(video_id, prompt, output_file, client, model_name):
     if not has_been_processed(video_id, output_file):
-        result = call_gpt(prompt, api_key=api_key)
+        result = call_gpt(prompt, client, model_name=model_name)
         with file_lock:
             with open(output_file, 'r+') as f:
                 # Read the current data and update it
@@ -193,7 +193,7 @@ def save_output(video_id, prompt, output_file, api_key):
                 f.truncate()  # Truncate file to new size
         print(f"Processed and saved output for Video ID {video_id}")
 
-def main(num_workers, all_prompts, output_file, api_key):
+def main(num_workers, all_prompts, output_file, client, model_name):
     # Load existing results
     existing_results = load_existing_results(output_file)
 
@@ -204,12 +204,12 @@ def main(num_workers, all_prompts, output_file, api_key):
         return
 
     print(f"Processing {len(unprocessed_prompts)} unprocessed video IDs.")
-    
+
     progress_bar = tqdm(total=len(unprocessed_prompts))
 
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         future_to_index = {
-            executor.submit(save_output, video_id, prompt, output_file, api_key): video_id
+            executor.submit(save_output, video_id, prompt, output_file, client, model_name): video_id
             for video_id, prompt in unprocessed_prompts.items()
         }
 
@@ -225,15 +225,21 @@ if __name__ == "__main__":
     # Set up argument parser
     parser = argparse.ArgumentParser(description="Process video frame captions.")
-    parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.")
+    parser.add_argument("--api_key", type=str, default=None, help="API key (or set OPENAI_API_KEY / MINIMAX_API_KEY env var).")
     parser.add_argument("--num_workers", type=int, default=6, help="Number of worker threads for processing.")
     parser.add_argument("--output_file", type=str, default="./2_1_gpt_frames_caption.json", help="Path to the output JSON file.")
     parser.add_argument("--group_frames_file", type=str, default="./2_1_temp_group_frames.json", help="Path to save grouped frame metadata.")
     parser.add_argument("--image_directories", type=str, nargs="+", default=["./step_1"], help="List of directories containing images.")
-    
+    add_provider_args(parser)
+
     # Parse command-line arguments
     args = parser.parse_args()
 
+    # Create LLM client and get model name
+    client = create_client(args)
+    model_name = get_model_name(args)
+    print(f"Using provider: {args.provider}, model: {model_name}")
+
     all_prompts = {}
     all_grouped_images = {}
 
@@ -241,7 +247,7 @@ if __name__ == "__main__":
     for directory in args.image_directories:
         filenames = get_image_filenames(directory)
         grouped_images = group_images_by_video_id(filenames)
-        
+
         # Sort images within each video group
         for video_id in grouped_images:
             grouped_images[video_id].sort(key=extract_frame_number)
@@ -257,4 +263,4 @@ if __name__ == "__main__":
         json.dump(all_grouped_images, file, indent=4)
 
     # Execute main processing function
-    main(args.num_workers, all_prompts, args.output_file, args.api_key)
\ No newline at end of file
+    main(args.num_workers, all_prompts, args.output_file, client, model_name)
\ No newline at end of file
diff --git a/data_preprocess/step3_1_GPT4V_video_caption_concise.py b/data_preprocess/step3_1_GPT4V_video_caption_concise.py
index 54ab0eb..ddef994 100644
--- a/data_preprocess/step3_1_GPT4V_video_caption_concise.py
+++ b/data_preprocess/step3_1_GPT4V_video_caption_concise.py
@@ -2,11 +2,12 @@ import json
 import argparse
 
 from tqdm import tqdm
-from openai import OpenAI
 from threading import Lock
 from tenacity import retry, wait_exponential, stop_after_attempt
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
+from llm_provider import add_provider_args, create_client, get_model_name
+
 txt_prompt = '''
 Imagine you're an expert data annotator with a specialization in summarizing time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a concise summary for the given time-lapse video.
@@ -59,8 +60,7 @@ def load_existing_results(file_path):
     return empty_data
 
 @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100))
-def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
-    client = OpenAI(api_key=api_key)
+def call_gpt(prompt, client, model_name="gpt-4-vision-preview"):
     chat_completion = client.chat.completions.create(
         model=model_name,
         messages=[
@@ -73,9 +73,9 @@ def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
     )
     return chat_completion.choices[0].message.content
 
-def save_output(video_id, prompt, output_file, api_key):
+def save_output(video_id, prompt, output_file, client, model_name):
     if not has_been_processed(video_id, output_file):
-        result = call_gpt(prompt, api_key=api_key)
+        result = call_gpt(prompt, client, model_name=model_name)
         with file_lock:
             with open(output_file, 'r+') as f:
                 # Read the current data and update it
@@ -86,7 +86,7 @@ def save_output(video_id, prompt, output_file, api_key):
                 f.truncate()  # Truncate file to new size
         print(f"Processed and saved output for Video ID {video_id}")
 
-def main(num_workers, all_prompts, output_file, api_key):
+def main(num_workers, all_prompts, output_file, client, model_name):
     # Load existing results
     existing_results = load_existing_results(output_file)
 
@@ -102,7 +102,7 @@ def main(num_workers, all_prompts, output_file, api_key):
 
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         future_to_index = {
-            executor.submit(save_output, video_id, prompt, output_file, api_key): video_id
+            executor.submit(save_output, video_id, prompt, output_file, client, model_name): video_id
             for video_id, prompt in unprocessed_prompts.items()
         }
 
@@ -118,14 +118,20 @@ if __name__ == "__main__":
     # Set up argument parser
     parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.")
-    parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.")
+    parser.add_argument("--api_key", type=str, default=None, help="API key (or set OPENAI_API_KEY / MINIMAX_API_KEY env var).")
     parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.")
     parser.add_argument("--input_file", type=str, default="./2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.")
     parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.")
+    add_provider_args(parser)
 
     # Parse command-line arguments
     args = parser.parse_args()
 
+    # Create LLM client and get model name
+    client = create_client(args)
+    model_name = get_model_name(args)
+    print(f"Using provider: {args.provider}, model: {model_name}")
+
     # Load data from the input file
     with open(args.input_file, 'r') as file:
         data = json.load(file)
@@ -134,4 +140,4 @@ if __name__ == "__main__":
     prompts = create_prompts(txt_prompt, data)
 
     # Execute main processing function
-    main(args.num_workers, prompts, args.output_file, args.api_key)
\ No newline at end of file
+    main(args.num_workers, prompts, args.output_file, client, model_name)
\ No newline at end of file
diff --git a/data_preprocess/step3_1_GPT4V_video_caption_detail.py b/data_preprocess/step3_1_GPT4V_video_caption_detail.py
index b6c2c7f..7fc38b2 100644
--- a/data_preprocess/step3_1_GPT4V_video_caption_detail.py
+++ b/data_preprocess/step3_1_GPT4V_video_caption_detail.py
@@ -2,11 +2,12 @@ import json
 import argparse
 
 from tqdm import tqdm
-from openai import OpenAI
 from threading import Lock
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from tenacity import retry, wait_exponential, stop_after_attempt
 
+from llm_provider import add_provider_args, create_client, get_model_name
+
 txt_prompt = '''
 Imagine you are a data annotator, specialized in generating summaries for time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a succinct and precise summary for the given time-lapse video.
 Note: The summary should efficiently encapsulate all discernible elements, particularly emphasizing the primary subject. It is important to indicate whether the video pertains to a forward or reverse sequence. Additionally, integrate any time-related aspects from the video into the summary.
@@ -61,8 +62,7 @@ def load_existing_results(file_path):
     return empty_data
 
 @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100))
-def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
-    client = OpenAI(api_key=api_key)
+def call_gpt(prompt, client, model_name="gpt-4-vision-preview"):
     chat_completion = client.chat.completions.create(
         model=model_name,
         messages=[
@@ -75,9 +75,9 @@ def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
     )
     return chat_completion.choices[0].message.content
 
-def save_output(video_id, prompt, output_file, api_key):
+def save_output(video_id, prompt, output_file, client, model_name):
     if not has_been_processed(video_id, output_file):
-        result = call_gpt(prompt, api_key=api_key)
+        result = call_gpt(prompt, client, model_name=model_name)
         with file_lock:
             with open(output_file, 'r+') as f:
                 # Read the current data and update it
@@ -88,7 +88,7 @@ def save_output(video_id, prompt, output_file, api_key):
                 f.truncate()  # Truncate file to new size
         print(f"Processed and saved output for Video ID {video_id}")
 
-def main(num_workers, all_prompts, output_file, api_key):
+def main(num_workers, all_prompts, output_file, client, model_name):
     # Load existing results
     existing_results = load_existing_results(output_file)
 
@@ -104,7 +104,7 @@ def main(num_workers, all_prompts, output_file, api_key):
 
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         future_to_index = {
-            executor.submit(save_output, video_id, prompt, output_file, api_key): video_id
+            executor.submit(save_output, video_id, prompt, output_file, client, model_name): video_id
             for video_id, prompt in unprocessed_prompts.items()
         }
 
@@ -120,14 +120,20 @@ if __name__ == "__main__":
     # Set up argument parser
     parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.")
-    parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.")
+    parser.add_argument("--api_key", type=str, default=None, help="API key (or set OPENAI_API_KEY / MINIMAX_API_KEY env var).")
     parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.")
    parser.add_argument("--input_file", type=str, default="2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.")
     parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.")
+    add_provider_args(parser)
 
     # Parse command-line arguments
     args = parser.parse_args()
 
+    # Create LLM client and get model name
+    client = create_client(args)
+    model_name = get_model_name(args)
+    print(f"Using provider: {args.provider}, model: {model_name}")
+
     # Load data from the input file
     with open(args.input_file, 'r') as file:
         data = json.load(file)
@@ -136,4 +142,4 @@ if __name__ == "__main__":
     prompts = create_prompts(txt_prompt, data)
 
     # Execute main processing function
-    main(args.num_workers, prompts, args.output_file, args.api_key)
\ No newline at end of file
+    main(args.num_workers, prompts, args.output_file, client, model_name)
\ No newline at end of file
diff --git a/data_preprocess/tests/__init__.py b/data_preprocess/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data_preprocess/tests/test_integration.py b/data_preprocess/tests/test_integration.py
new file mode 100644
index 0000000..4fbe673
--- /dev/null
+++ b/data_preprocess/tests/test_integration.py
@@ -0,0 +1,99 @@
+"""Integration tests for MiniMax LLM provider.
+
+These tests verify actual API connectivity with MiniMax.
+They require MINIMAX_API_KEY to be set in the environment.
+Skipped automatically when the key is not available.
+"""
+
+import argparse
+import os
+import sys
+import unittest
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY")
+SKIP_REASON = "MINIMAX_API_KEY not set"
+
+
+@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON)
+class TestMiniMaxTextCompletion(unittest.TestCase):
+    """Test MiniMax provider for text-based captioning."""
+
+    def test_minimax_text_completion(self):
+        from llm_provider import create_client, get_model_name
+
+        args = argparse.Namespace(
+            provider="minimax",
+            base_url=None,
+            model="MiniMax-M2.7",
+            api_key=None,
+        )
+        client = create_client(args)
+        model_name = get_model_name(args)
+
+        response = client.chat.completions.create(
+            model=model_name,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Summarize in one sentence: A flower blooms from bud to full bloom over 7 days.",
+                }
+            ],
+            max_tokens=128,
+        )
+        content = response.choices[0].message.content
+        self.assertIsInstance(content, str)
+        self.assertGreater(len(content), 10)
+
+    def test_minimax_call_gpt_function(self):
+        """Test call_gpt from step3_1 with MiniMax client."""
+        from llm_provider import create_client, get_model_name
+        from step3_1_GPT4V_video_caption_concise import call_gpt
+
+        args = argparse.Namespace(
+            provider="minimax",
+            base_url=None,
+            model="MiniMax-M2.7",
+            api_key=None,
+        )
+        client = create_client(args)
+        model_name = get_model_name(args)
+
+        prompt = [
+            {
+                "type": "text",
+                "text": "Describe this scene in 20 words: A plant grows from a seed to a small sprout.",
+            }
+        ]
+        result = call_gpt(prompt, client, model_name=model_name)
+        self.assertIsInstance(result, str)
+        self.assertGreater(len(result), 5)
+
+    def test_minimax_call_gpt_detail_function(self):
+        """Test call_gpt from step3_1_detail with MiniMax client."""
+        from llm_provider import create_client, get_model_name
+        from step3_1_GPT4V_video_caption_detail import call_gpt
+
+        args = argparse.Namespace(
+            provider="minimax",
+            base_url=None,
+            model="MiniMax-M2.7",
+            api_key=None,
+        )
+        client = create_client(args)
+        model_name = get_model_name(args)
+
+        prompt = [
+            {
+                "type": "text",
+                "text": "Summarize: A candle melts from tall to a small puddle of wax over 3 hours.",
+            }
+        ]
+        result = call_gpt(prompt, client, model_name=model_name)
+        self.assertIsInstance(result, str)
+        self.assertGreater(len(result), 5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/data_preprocess/tests/test_llm_provider.py b/data_preprocess/tests/test_llm_provider.py
new file mode 100644
index 0000000..45f5ac1
--- /dev/null
+++ b/data_preprocess/tests/test_llm_provider.py
@@ -0,0 +1,233 @@
+"""Unit tests for llm_provider module."""
+
+import argparse
+import os
+import sys
+import unittest
+from unittest.mock import MagicMock, patch
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from llm_provider import (
+    PROVIDER_PRESETS,
+    add_provider_args,
+    clamp_temperature,
+    create_client,
+    get_model_name,
+)
+
+
+class TestProviderPresets(unittest.TestCase):
+    """Test provider preset configuration."""
+
+    def test_openai_preset_exists(self):
+        self.assertIn("openai", PROVIDER_PRESETS)
+
+    def test_minimax_preset_exists(self):
+        self.assertIn("minimax", PROVIDER_PRESETS)
+
+    def test_openai_preset_values(self):
+        preset = PROVIDER_PRESETS["openai"]
+        self.assertIsNone(preset["base_url"])
+        self.assertEqual(preset["default_model"], "gpt-4-vision-preview")
+        self.assertEqual(preset["env_key"], "OPENAI_API_KEY")
+
+    def test_minimax_preset_values(self):
+        preset = PROVIDER_PRESETS["minimax"]
+        self.assertEqual(preset["base_url"], "https://api.minimax.io/v1")
+        self.assertEqual(preset["default_model"], "MiniMax-M2.7")
+        self.assertEqual(preset["env_key"], "MINIMAX_API_KEY")
+
+
+class TestAddProviderArgs(unittest.TestCase):
+    """Test add_provider_args adds correct arguments."""
+
+    def test_adds_provider_arg(self):
+        parser = argparse.ArgumentParser()
+        add_provider_args(parser)
+        args = parser.parse_args(["--provider", "minimax"])
+        self.assertEqual(args.provider, "minimax")
+
+    def test_default_provider_is_openai(self):
+        parser = argparse.ArgumentParser()
+        add_provider_args(parser)
+        args = parser.parse_args([])
+        self.assertEqual(args.provider, "openai")
+
+    def test_adds_base_url_arg(self):
+        parser = argparse.ArgumentParser()
+        add_provider_args(parser)
+        args = parser.parse_args(["--base_url", "https://custom.api.com/v1"])
+        self.assertEqual(args.base_url, "https://custom.api.com/v1")
+
+    def test_adds_model_arg(self):
+        parser = argparse.ArgumentParser()
+        add_provider_args(parser)
+        args = parser.parse_args(["--model", "gpt-4o"])
+        self.assertEqual(args.model, "gpt-4o")
+
+    def test_invalid_provider_raises(self):
+        parser = argparse.ArgumentParser()
+        add_provider_args(parser)
+        with self.assertRaises(SystemExit):
+            parser.parse_args(["--provider", "invalid"])
+
+
+class TestGetModelName(unittest.TestCase):
+    """Test get_model_name resolution."""
+
+    def _make_args(self, provider="openai", model=None):
+        args = argparse.Namespace(provider=provider, model=model)
+        return args
+
+    def test_openai_default_model(self):
+        args = self._make_args(provider="openai")
+        self.assertEqual(get_model_name(args), "gpt-4-vision-preview")
+
+    def test_minimax_default_model(self):
+        args = self._make_args(provider="minimax")
+        self.assertEqual(get_model_name(args), "MiniMax-M2.7")
+
+    def test_explicit_model_overrides_default(self):
+        args = self._make_args(provider="openai", model="gpt-4o")
+        self.assertEqual(get_model_name(args), "gpt-4o")
+
+    def test_explicit_model_overrides_minimax_default(self):
+        args = self._make_args(provider="minimax", model="MiniMax-M2.5")
+        self.assertEqual(get_model_name(args), "MiniMax-M2.5")
+
+
+class TestCreateClient(unittest.TestCase):
+    """Test create_client creates OpenAI client with correct params."""
+
+    @patch("llm_provider.OpenAI")
+    def test_openai_client_default(self, mock_openai_cls):
+        args = argparse.Namespace(
+            provider="openai", base_url=None, api_key="test-key"
+        )
+        create_client(args)
+        mock_openai_cls.assert_called_once_with(api_key="test-key")
+
+    @patch("llm_provider.OpenAI")
+    def test_minimax_client_sets_base_url(self, mock_openai_cls):
+        args = argparse.Namespace(
+            provider="minimax", base_url=None, api_key="mm-key"
+        )
+        create_client(args)
+        mock_openai_cls.assert_called_once_with(
+            api_key="mm-key", base_url="https://api.minimax.io/v1"
+        )
+
+    @patch("llm_provider.OpenAI")
+    def test_custom_base_url_overrides_preset(self, mock_openai_cls):
+        args = argparse.Namespace(
+            provider="minimax", base_url="https://custom.api.com/v1", api_key="key"
+        )
+        create_client(args)
+        mock_openai_cls.assert_called_once_with(
+            api_key="key", base_url="https://custom.api.com/v1"
+        )
+
+    @patch.dict(os.environ, {"MINIMAX_API_KEY": "env-mm-key"}, clear=False)
+    @patch("llm_provider.OpenAI")
+    def test_minimax_env_key_auto_detected(self, mock_openai_cls):
+        args = argparse.Namespace(
+            provider="minimax", base_url=None, api_key=None
+        )
+        create_client(args)
+        mock_openai_cls.assert_called_once_with(
+            api_key="env-mm-key", base_url="https://api.minimax.io/v1"
+        )
+
+    @patch.dict(os.environ, {"OPENAI_API_KEY": "env-oai-key"}, clear=False)
+    @patch("llm_provider.OpenAI")
+    def test_openai_env_key_fallback(self, mock_openai_cls):
+        args = argparse.Namespace(
+            provider="openai", base_url=None, api_key=None
+        )
+        create_client(args)
+        mock_openai_cls.assert_called_once_with(api_key="env-oai-key")
+
+
+class TestClampTemperature(unittest.TestCase):
+    """Test temperature clamping for different providers."""
+
+    def test_openai_no_clamping(self):
+        self.assertEqual(clamp_temperature(0.0, "openai"), 0.0)
+        self.assertEqual(clamp_temperature(2.0, "openai"), 2.0)
+
+    def test_minimax_clamp_low(self):
+        self.assertEqual(clamp_temperature(0.0, "minimax"), 0.01)
+
+    def test_minimax_clamp_high(self):
+        self.assertEqual(clamp_temperature(2.0, "minimax"), 1.0)
+
+    def test_minimax_normal_range(self):
+        self.assertEqual(clamp_temperature(0.5, "minimax"), 0.5)
+
+    def test_minimax_boundary_one(self):
+        self.assertEqual(clamp_temperature(1.0, "minimax"), 1.0)
+
+
+class TestCallGptFunction(unittest.TestCase):
+    """Test call_gpt function signature compatibility."""
+
+    def test_call_gpt_with_client_and_model(self):
+        """Verify call_gpt accepts client and model_name parameters."""
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+        mock_client = MagicMock()
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].message.content = "test caption"
+        mock_client.chat.completions.create.return_value = mock_response
+
+        from step3_1_GPT4V_video_caption_concise import call_gpt
+
+        result = call_gpt(
+            [{"type": "text", "text": "test prompt"}],
+            mock_client,
+            model_name="MiniMax-M2.7",
+        )
+        self.assertEqual(result, "test caption")
+        mock_client.chat.completions.create.assert_called_once()
+        call_kwargs = mock_client.chat.completions.create.call_args
+        self.assertEqual(call_kwargs.kwargs["model"], "MiniMax-M2.7")
+
+
+class TestSaveOutputFunction(unittest.TestCase):
+    """Test save_output function signature compatibility."""
+
+    def test_save_output_accepts_client_and_model(self):
+        """Verify save_output accepts client and model_name parameters."""
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+        from step3_1_GPT4V_video_caption_concise import save_output
+
+        import inspect
+        sig = inspect.signature(save_output)
+        params = list(sig.parameters.keys())
+        self.assertIn("client", params)
+        self.assertIn("model_name", params)
+        self.assertNotIn("api_key", params)
+
+
+class TestMainFunction(unittest.TestCase):
+    """Test main function signature compatibility."""
+
+    def test_main_accepts_client_and_model(self):
+        """Verify main() accepts client and model_name parameters."""
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+        from step3_1_GPT4V_video_caption_concise import main
+
+        import inspect
+        sig = inspect.signature(main)
+        params = list(sig.parameters.keys())
+        self.assertIn("client", params)
+        self.assertIn("model_name", params)
+        self.assertNotIn("api_key", params)
+
+
+if __name__ == "__main__":
+    unittest.main()