From ff4716c4c306b345e69ce31dea14d6ad9cd96445 Mon Sep 17 00:00:00 2001 From: PR Bot Date: Tue, 24 Mar 2026 18:17:51 +0800 Subject: [PATCH] feat: add MiniMax as alternative LLM provider for data preprocessing Add configurable LLM provider support to the data preprocessing scripts (frame captioning and video captioning). Users can now choose between OpenAI (default) and MiniMax via --provider flag, or use any OpenAI-compatible API via --base_url and --model. Changes: - Add data_preprocess/llm_provider.py: shared provider config module with PROVIDER_PRESETS, create_client(), get_model_name() - Modify step2_1_GPT4V_frame_caption.py: use configurable client/model - Modify step3_1_GPT4V_video_caption_concise.py: same - Modify step3_1_GPT4V_video_caption_detail.py: same - Update run.sh: add PROVIDER variable - Add 26 unit tests + 3 integration tests - Update README with MiniMax usage docs --- README.md | 50 +++- data_preprocess/llm_provider.py | 115 +++++++++ data_preprocess/run.sh | 17 +- .../step2_1_GPT4V_frame_caption.py | 30 ++- .../step3_1_GPT4V_video_caption_concise.py | 24 +- .../step3_1_GPT4V_video_caption_detail.py | 24 +- data_preprocess/tests/__init__.py | 0 data_preprocess/tests/test_integration.py | 99 ++++++++ data_preprocess/tests/test_llm_provider.py | 233 ++++++++++++++++++ 9 files changed, 556 insertions(+), 36 deletions(-) create mode 100644 data_preprocess/llm_provider.py create mode 100644 data_preprocess/tests/__init__.py create mode 100644 data_preprocess/tests/test_integration.py create mode 100644 data_preprocess/tests/test_llm_provider.py diff --git a/README.md b/README.md index 4a8a987..9e66860 100644 --- a/README.md +++ b/README.md @@ -378,7 +378,55 @@ We found some plugins created by community developers. Thanks for their efforts: - Replicate Demo & Cloud API. [Replicate-MagicTime](https://replicate.com/camenduru/magictime) (by [@camenduru](https://twitter.com/camenduru)). - Jupyter Notebook. [Jupyter-MagicTime](https://github.com/camenduru/MagicTime-jupyter) (by [@ModelsLab](https://modelslab.com/)). -If you find related work, please let us know. +If you find related work, please let us know. + +## 🤖 LLM Provider for Data Preprocessing + +The data preprocessing scripts (`data_preprocess/`) support multiple LLM providers for video/frame captioning. By default, **OpenAI GPT-4V** is used, but you can switch to **[MiniMax](https://www.minimaxi.com/)** or any OpenAI-compatible API. + +### Using MiniMax + +[MiniMax](https://www.minimaxi.com/) provides powerful LLM models (MiniMax-M2.7, MiniMax-M2.5) with an OpenAI-compatible API, supporting both text and vision inputs. + +```bash +# Set your MiniMax API key +export MINIMAX_API_KEY="your-api-key" + +# Frame captioning with MiniMax +python data_preprocess/step2_1_GPT4V_frame_caption.py \ + --provider minimax \ + --image_directories ./step_1 \ + --output_file ./2_1_gpt_frames_caption.json + +# Video captioning (concise) with MiniMax +python data_preprocess/step3_1_GPT4V_video_caption_concise.py \ + --provider minimax \ + --input_file ./2_2_final_useful_gpt_frames_caption.json \ + --output_file ./3_1_gpt_video_caption.json + +# Or edit data_preprocess/run.sh and set PROVIDER="minimax" +``` + +### Using a Custom Provider + +You can also use any OpenAI-compatible API by specifying `--base_url` and `--model`: + +```bash +python data_preprocess/step2_1_GPT4V_frame_caption.py \ + --base_url https://your-api.example.com/v1 \ + --model your-model-name \ + --api_key your-api-key \ + --image_directories ./step_1 +``` + +### Available Provider Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--provider` | `openai` | LLM provider (`openai` or `minimax`) | +| `--base_url` | Provider default | Custom API base URL | +| `--model` | Provider default | Model name | +| `--api_key` | From env var | API key (or set `OPENAI_API_KEY` / `MINIMAX_API_KEY`) | ## 🐳 ChronoMagic Dataset ChronoMagic with 2265 metamorphic time-lapse videos, each accompanied by a detailed caption. We released the subset of ChronoMagic used to train MagicTime. The dataset can be downloaded at [HuggingFace Dataset](https://huggingface.co/datasets/BestWishYsh/ChronoMagic), or you can download it with the following command. Some samples can be found on our [Project Page](https://pku-yuangroup.github.io/MagicTime/). diff --git a/data_preprocess/llm_provider.py b/data_preprocess/llm_provider.py new file mode 100644 index 0000000..9d52ad3 --- /dev/null +++ b/data_preprocess/llm_provider.py @@ -0,0 +1,115 @@ +""" +LLM provider configuration for data preprocessing scripts. + +Supports multiple LLM providers via OpenAI-compatible APIs: +- OpenAI (default): GPT-4V, GPT-4o, etc. +- MiniMax: MiniMax-M2.7, MiniMax-M2.5, etc. +- Any OpenAI-compatible provider via --base_url and --model + +Usage: + from llm_provider import add_provider_args, create_client, get_model_name + + # In argument parser setup: + add_provider_args(parser) + + # In code: + args = parser.parse_args() + client = create_client(args) + model = get_model_name(args) +""" + +import os + +from openai import OpenAI + +# Provider presets: base_url, default model, env var for API key +PROVIDER_PRESETS = { + "openai": { + "base_url": None, # OpenAI SDK default + "default_model": "gpt-4-vision-preview", + "env_key": "OPENAI_API_KEY", + }, + "minimax": { + "base_url": "https://api.minimax.io/v1", + "default_model": "MiniMax-M2.7", + "env_key": "MINIMAX_API_KEY", + }, +} + + +def add_provider_args(parser): + """Add LLM provider arguments to an argparse parser.""" + parser.add_argument( + "--provider", + type=str, + default="openai", + choices=list(PROVIDER_PRESETS.keys()), + help="LLM provider to use (default: openai).", + ) + parser.add_argument( + "--base_url", + type=str, + default=None, + help="Custom API base URL (overrides provider default).", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model name (overrides provider default).", + ) + + +def _resolve_api_key(args): + """Resolve API key from args or environment variables.""" + # Explicit --api_key takes priority + api_key = getattr(args, "api_key", None) + if api_key: + return api_key + + # Check provider-specific env var + provider = getattr(args, "provider", "openai") + preset = PROVIDER_PRESETS.get(provider, PROVIDER_PRESETS["openai"]) + env_key = preset["env_key"] + api_key = os.environ.get(env_key) + if api_key: + return api_key + + # Fallback to OPENAI_API_KEY for any provider + return os.environ.get("OPENAI_API_KEY") + + +def create_client(args): + """Create an OpenAI-compatible client based on provider args.""" + provider = getattr(args, "provider", "openai") + preset = PROVIDER_PRESETS.get(provider, PROVIDER_PRESETS["openai"]) + + base_url = getattr(args, "base_url", None) or preset["base_url"] + api_key = _resolve_api_key(args) + + kwargs = {"api_key": api_key} + if base_url: + kwargs["base_url"] = base_url + + return OpenAI(**kwargs) + + +def get_model_name(args): + """Get the model name from args or provider defaults.""" + model = getattr(args, "model", None) + if model: + return model + + provider = getattr(args, "provider", "openai") + preset = PROVIDER_PRESETS.get(provider, PROVIDER_PRESETS["openai"]) + return preset["default_model"] + + +def clamp_temperature(temperature, provider="openai"): + """Clamp temperature to provider-specific valid range. + + MiniMax accepts temperature in (0.0, 1.0]. + """ + if provider == "minimax": + return max(0.01, min(temperature, 1.0)) + return temperature diff --git a/data_preprocess/run.sh b/data_preprocess/run.sh index 7e09ef7..f3d2201 100644 --- a/data_preprocess/run.sh +++ b/data_preprocess/run.sh @@ -6,6 +6,10 @@ OUTPUT_FOLDER_STEP_1="./step_1" API_KEY="XXX" NUM_WORKERS=8 +# LLM provider: "openai" (default) or "minimax" +# For MiniMax, set MINIMAX_API_KEY env var and change PROVIDER to "minimax" +PROVIDER="openai" + # File paths FRAME_CAPTION_FILE="./2_1_gpt_frames_caption.json" GROUP_FRAMES_FILE="./2_1_temp_group_frames.json" @@ -22,22 +26,25 @@ FINAL_CSV_FILE="./all_clean_data.csv" # Step 1: Extract and resize frames python step0_extract_frame_resize.py --input_folder "$INPUT_FOLDER" --output_folder "$OUTPUT_FOLDER_STEP_1" -# Step 2.1: Generate frame captions using GPT-4V +# Step 2.1: Generate frame captions using LLM (GPT-4V or MiniMax) python step2_1_GPT4V_frame_caption.py --api_key "$API_KEY" --num_workers "$NUM_WORKERS" \ - --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1" + --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1" \ + --provider "$PROVIDER" # Step 2.2: Preprocess frame captions python step2_2_preprocess_frame_caption.py --file_path "$FRAME_CAPTION_FILE" \ --updated_file_path "$UPDATED_FRAME_CAPTION_FILE" --unmatched_file_path "$UNMATCHED_FRAME_CAPTION_FILE" \ --unordered_file_path "$UNORDERED_FRAME_CAPTION_FILE" --final_useful_data_file_path "$FINAL_USEFUL_FRAME_CAPTION_FILE" -# Step 3.1: Generate concise video captions using GPT-4V +# Step 3.1: Generate concise video captions using LLM (GPT-4V or MiniMax) python step3_1_GPT4V_video_caption_concise.py --num_workers "$NUM_WORKERS" \ - --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" + --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" \ + --provider "$PROVIDER" # Optional: Generate detailed video captions (uncomment to enable) # python step3_1_GPT4V_video_caption_detail.py --num_workers "$NUM_WORKERS" \ -# --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" +# --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" \ +# --provider "$PROVIDER" # Step 3.2: Preprocess video captions python step3_2_preprocess_video_caption.py --file_path "$VIDEO_CAPTION_FILE" \ diff --git a/data_preprocess/step2_1_GPT4V_frame_caption.py b/data_preprocess/step2_1_GPT4V_frame_caption.py index e397080..72e0ef2 100644 --- a/data_preprocess/step2_1_GPT4V_frame_caption.py +++ b/data_preprocess/step2_1_GPT4V_frame_caption.py @@ -4,11 +4,12 @@ import base64 import argparse from tqdm import tqdm -from openai import OpenAI from threading import Lock from concurrent.futures import ThreadPoolExecutor, as_completed from tenacity import retry, wait_exponential, stop_after_attempt +from llm_provider import add_provider_args, create_client, get_model_name + txt_prompt = ''' Suppose you are a data annotator, specialized in generating captions for time-lapse videos. You will be supplied with eight key frames extracted from a video, each with a filename labeled with its position in the video sequence. Your task is to generate a caption for each frame, focusing on the primary subject and integrating all discernible elements. Note: These captions should be brief and concise, avoiding redundancy. @@ -165,8 +166,7 @@ def load_existing_results(file_path): return empty_data @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100)) -def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): - client = OpenAI(api_key=api_key) +def call_gpt(prompt, client, model_name="gpt-4-vision-preview"): chat_completion = client.chat.completions.create( model=model_name, messages=[ @@ -180,9 +180,9 @@ def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): print(chat_completion) return chat_completion.choices[0].message.content -def save_output(video_id, prompt, output_file, api_key): +def save_output(video_id, prompt, output_file, client, model_name): if not has_been_processed(video_id, output_file): - result = call_gpt(prompt, api_key=api_key) + result = call_gpt(prompt, client, model_name=model_name) with file_lock: with open(output_file, 'r+') as f: # Read the current data and update it @@ -193,7 +193,7 @@ def save_output(video_id, prompt, output_file, api_key): f.truncate() # Truncate file to new size print(f"Processed and saved output for Video ID {video_id}") -def main(num_workers, all_prompts, output_file, api_key): +def main(num_workers, all_prompts, output_file, client, model_name): # Load existing results existing_results = load_existing_results(output_file) @@ -204,12 +204,12 @@ def main(num_workers, all_prompts, output_file, api_key): return print(f"Processing {len(unprocessed_prompts)} unprocessed video IDs.") - + progress_bar = tqdm(total=len(unprocessed_prompts)) with ThreadPoolExecutor(max_workers=num_workers) as executor: future_to_index = { - executor.submit(save_output, video_id, prompt, output_file, api_key): video_id + executor.submit(save_output, video_id, prompt, output_file, client, model_name): video_id for video_id, prompt in unprocessed_prompts.items() } @@ -225,15 +225,21 @@ def main(num_workers, all_prompts, output_file, api_key): if __name__ == "__main__": # Set up argument parser parser = argparse.ArgumentParser(description="Process video frame captions.") - parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.") + parser.add_argument("--api_key", type=str, default=None, help="API key (or set OPENAI_API_KEY / MINIMAX_API_KEY env var).") parser.add_argument("--num_workers", type=int, default=6, help="Number of worker threads for processing.") parser.add_argument("--output_file", type=str, default="./2_1_gpt_frames_caption.json", help="Path to the output JSON file.") parser.add_argument("--group_frames_file", type=str, default="./2_1_temp_group_frames.json", help="Path to save grouped frame metadata.") parser.add_argument("--image_directories", type=str, nargs="+", default=["./step_1"], help="List of directories containing images.") - + add_provider_args(parser) + # Parse command-line arguments args = parser.parse_args() + # Create LLM client and get model name + client = create_client(args) + model_name = get_model_name(args) + print(f"Using provider: {args.provider}, model: {model_name}") + all_prompts = {} all_grouped_images = {} @@ -241,7 +247,7 @@ def main(num_workers, all_prompts, output_file, api_key): for directory in args.image_directories: filenames = get_image_filenames(directory) grouped_images = group_images_by_video_id(filenames) - + # Sort images within each video group for video_id in grouped_images: grouped_images[video_id].sort(key=extract_frame_number) @@ -257,4 +263,4 @@ def main(num_workers, all_prompts, output_file, api_key): json.dump(all_grouped_images, file, indent=4) # Execute main processing function - main(args.num_workers, all_prompts, args.output_file, args.api_key) \ No newline at end of file + main(args.num_workers, all_prompts, args.output_file, client, model_name) \ No newline at end of file diff --git a/data_preprocess/step3_1_GPT4V_video_caption_concise.py b/data_preprocess/step3_1_GPT4V_video_caption_concise.py index 54ab0eb..ddef994 100644 --- a/data_preprocess/step3_1_GPT4V_video_caption_concise.py +++ b/data_preprocess/step3_1_GPT4V_video_caption_concise.py @@ -2,11 +2,12 @@ import json import argparse from tqdm import tqdm -from openai import OpenAI from threading import Lock from tenacity import retry, wait_exponential, stop_after_attempt from concurrent.futures import ThreadPoolExecutor, as_completed +from llm_provider import add_provider_args, create_client, get_model_name + txt_prompt = ''' Imagine you're an expert data annotator with a specialization in summarizing time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a concise summary for the given time-lapse video. @@ -59,8 +60,7 @@ def load_existing_results(file_path): return empty_data @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100)) -def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): - client = OpenAI(api_key=api_key) +def call_gpt(prompt, client, model_name="gpt-4-vision-preview"): chat_completion = client.chat.completions.create( model=model_name, messages=[ @@ -73,9 +73,9 @@ def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): ) return chat_completion.choices[0].message.content -def save_output(video_id, prompt, output_file, api_key): +def save_output(video_id, prompt, output_file, client, model_name): if not has_been_processed(video_id, output_file): - result = call_gpt(prompt, api_key=api_key) + result = call_gpt(prompt, client, model_name=model_name) with file_lock: with open(output_file, 'r+') as f: # Read the current data and update it @@ -86,7 +86,7 @@ def save_output(video_id, prompt, output_file, api_key): f.truncate() # Truncate file to new size print(f"Processed and saved output for Video ID {video_id}") -def main(num_workers, all_prompts, output_file, api_key): +def main(num_workers, all_prompts, output_file, client, model_name): # Load existing results existing_results = load_existing_results(output_file) @@ -102,7 +102,7 @@ def main(num_workers, all_prompts, output_file, api_key): with ThreadPoolExecutor(max_workers=num_workers) as executor: future_to_index = { - executor.submit(save_output, video_id, prompt, output_file, api_key): video_id + executor.submit(save_output, video_id, prompt, output_file, client, model_name): video_id for video_id, prompt in unprocessed_prompts.items() } @@ -118,14 +118,20 @@ def main(num_workers, all_prompts, output_file, api_key): if __name__ == "__main__": # Set up argument parser parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.") - parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.") + parser.add_argument("--api_key", type=str, default=None, help="API key (or set OPENAI_API_KEY / MINIMAX_API_KEY env var).") parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.") parser.add_argument("--input_file", type=str, default="./2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.") parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.") + add_provider_args(parser) # Parse command-line arguments args = parser.parse_args() + # Create LLM client and get model name + client = create_client(args) + model_name = get_model_name(args) + print(f"Using provider: {args.provider}, model: {model_name}") + # Load data from the input file with open(args.input_file, 'r') as file: data = json.load(file) @@ -134,4 +140,4 @@ def main(num_workers, all_prompts, output_file, api_key): prompts = create_prompts(txt_prompt, data) # Execute main processing function - main(args.num_workers, prompts, args.output_file, args.api_key) \ No newline at end of file + main(args.num_workers, prompts, args.output_file, client, model_name) \ No newline at end of file diff --git a/data_preprocess/step3_1_GPT4V_video_caption_detail.py b/data_preprocess/step3_1_GPT4V_video_caption_detail.py index b6c2c7f..7fc38b2 100644 --- a/data_preprocess/step3_1_GPT4V_video_caption_detail.py +++ b/data_preprocess/step3_1_GPT4V_video_caption_detail.py @@ -2,11 +2,12 @@ import json import argparse from tqdm import tqdm -from openai import OpenAI from threading import Lock from concurrent.futures import ThreadPoolExecutor, as_completed from tenacity import retry, wait_exponential, stop_after_attempt +from llm_provider import add_provider_args, create_client, get_model_name + txt_prompt = ''' Imagine you are a data annotator, specialized in generating summaries for time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a succinct and precise summary for the given time-lapse video. Note: The summary should efficiently encapsulate all discernible elements, particularly emphasizing the primary subject. It is important to indicate whether the video pertains to a forward or reverse sequence. Additionally, integrate any time-related aspects from the video into the summary. @@ -61,8 +62,7 @@ def load_existing_results(file_path): return empty_data @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100)) -def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): - client = OpenAI(api_key=api_key) +def call_gpt(prompt, client, model_name="gpt-4-vision-preview"): chat_completion = client.chat.completions.create( model=model_name, messages=[ @@ -75,9 +75,9 @@ def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): ) return chat_completion.choices[0].message.content -def save_output(video_id, prompt, output_file, api_key): +def save_output(video_id, prompt, output_file, client, model_name): if not has_been_processed(video_id, output_file): - result = call_gpt(prompt, api_key=api_key) + result = call_gpt(prompt, client, model_name=model_name) with file_lock: with open(output_file, 'r+') as f: # Read the current data and update it @@ -88,7 +88,7 @@ def save_output(video_id, prompt, output_file, api_key): f.truncate() # Truncate file to new size print(f"Processed and saved output for Video ID {video_id}") -def main(num_workers, all_prompts, output_file, api_key): +def main(num_workers, all_prompts, output_file, client, model_name): # Load existing results existing_results = load_existing_results(output_file) @@ -104,7 +104,7 @@ def main(num_workers, all_prompts, output_file, api_key): with ThreadPoolExecutor(max_workers=num_workers) as executor: future_to_index = { - executor.submit(save_output, video_id, prompt, output_file, api_key): video_id + executor.submit(save_output, video_id, prompt, output_file, client, model_name): video_id for video_id, prompt in unprocessed_prompts.items() } @@ -120,14 +120,20 @@ def main(num_workers, all_prompts, output_file, api_key): if __name__ == "__main__": # Set up argument parser parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.") - parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.") + parser.add_argument("--api_key", type=str, default=None, help="API key (or set OPENAI_API_KEY / MINIMAX_API_KEY env var).") parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.") parser.add_argument("--input_file", type=str, default="2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.") parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.") + add_provider_args(parser) # Parse command-line arguments args = parser.parse_args() + # Create LLM client and get model name + client = create_client(args) + model_name = get_model_name(args) + print(f"Using provider: {args.provider}, model: {model_name}") + # Load data from the input file with open(args.input_file, 'r') as file: data = json.load(file) @@ -136,4 +142,4 @@ def main(num_workers, all_prompts, output_file, api_key): prompts = create_prompts(txt_prompt, data) # Execute main processing function - main(args.num_workers, prompts, args.output_file, args.api_key) \ No newline at end of file + main(args.num_workers, prompts, args.output_file, client, model_name) \ No newline at end of file diff --git a/data_preprocess/tests/__init__.py b/data_preprocess/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_preprocess/tests/test_integration.py b/data_preprocess/tests/test_integration.py new file mode 100644 index 0000000..4fbe673 --- /dev/null +++ b/data_preprocess/tests/test_integration.py @@ -0,0 +1,99 @@ +"""Integration tests for MiniMax LLM provider. + +These tests verify actual API connectivity with MiniMax. +They require MINIMAX_API_KEY to be set in the environment. +Skipped automatically when the key is not available. +""" + +import argparse +import os +import sys +import unittest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY") +SKIP_REASON = "MINIMAX_API_KEY not set" + + +@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON) +class TestMiniMaxTextCompletion(unittest.TestCase): + """Test MiniMax provider for text-based captioning.""" + + def test_minimax_text_completion(self): + from llm_provider import create_client, get_model_name + + args = argparse.Namespace( + provider="minimax", + base_url=None, + model="MiniMax-M2.7", + api_key=None, + ) + client = create_client(args) + model_name = get_model_name(args) + + response = client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": "Summarize in one sentence: A flower blooms from bud to full bloom over 7 days.", + } + ], + max_tokens=128, + ) + content = response.choices[0].message.content + self.assertIsInstance(content, str) + self.assertGreater(len(content), 10) + + def test_minimax_call_gpt_function(self): + """Test call_gpt from step3_1 with MiniMax client.""" + from llm_provider import create_client, get_model_name + from step3_1_GPT4V_video_caption_concise import call_gpt + + args = argparse.Namespace( + provider="minimax", + base_url=None, + model="MiniMax-M2.7", + api_key=None, + ) + client = create_client(args) + model_name = get_model_name(args) + + prompt = [ + { + "type": "text", + "text": "Describe this scene in 20 words: A plant grows from a seed to a small sprout.", + } + ] + result = call_gpt(prompt, client, model_name=model_name) + self.assertIsInstance(result, str) + self.assertGreater(len(result), 5) + + def test_minimax_call_gpt_detail_function(self): + """Test call_gpt from step3_1_detail with MiniMax client.""" + from llm_provider import create_client, get_model_name + from step3_1_GPT4V_video_caption_detail import call_gpt + + args = argparse.Namespace( + provider="minimax", + base_url=None, + model="MiniMax-M2.7", + api_key=None, + ) + client = create_client(args) + model_name = get_model_name(args) + + prompt = [ + { + "type": "text", + "text": "Summarize: A candle melts from tall to a small puddle of wax over 3 hours.", + } + ] + result = call_gpt(prompt, client, model_name=model_name) + self.assertIsInstance(result, str) + self.assertGreater(len(result), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/data_preprocess/tests/test_llm_provider.py b/data_preprocess/tests/test_llm_provider.py new file mode 100644 index 0000000..45f5ac1 --- /dev/null +++ b/data_preprocess/tests/test_llm_provider.py @@ -0,0 +1,233 @@ +"""Unit tests for llm_provider module.""" + +import argparse +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from llm_provider import ( + PROVIDER_PRESETS, + add_provider_args, + clamp_temperature, + create_client, + get_model_name, +) + + +class TestProviderPresets(unittest.TestCase): + """Test provider preset configuration.""" + + def test_openai_preset_exists(self): + self.assertIn("openai", PROVIDER_PRESETS) + + def test_minimax_preset_exists(self): + self.assertIn("minimax", PROVIDER_PRESETS) + + def test_openai_preset_values(self): + preset = PROVIDER_PRESETS["openai"] + self.assertIsNone(preset["base_url"]) + self.assertEqual(preset["default_model"], "gpt-4-vision-preview") + self.assertEqual(preset["env_key"], "OPENAI_API_KEY") + + def test_minimax_preset_values(self): + preset = PROVIDER_PRESETS["minimax"] + self.assertEqual(preset["base_url"], "https://api.minimax.io/v1") + self.assertEqual(preset["default_model"], "MiniMax-M2.7") + self.assertEqual(preset["env_key"], "MINIMAX_API_KEY") + + +class TestAddProviderArgs(unittest.TestCase): + """Test add_provider_args adds correct arguments.""" + + def test_adds_provider_arg(self): + parser = argparse.ArgumentParser() + add_provider_args(parser) + args = parser.parse_args(["--provider", "minimax"]) + self.assertEqual(args.provider, "minimax") + + def test_default_provider_is_openai(self): + parser = argparse.ArgumentParser() + add_provider_args(parser) + args = parser.parse_args([]) + self.assertEqual(args.provider, "openai") + + def test_adds_base_url_arg(self): + parser = argparse.ArgumentParser() + add_provider_args(parser) + args = parser.parse_args(["--base_url", "https://custom.api.com/v1"]) + self.assertEqual(args.base_url, "https://custom.api.com/v1") + + def test_adds_model_arg(self): + parser = argparse.ArgumentParser() + add_provider_args(parser) + args = parser.parse_args(["--model", "gpt-4o"]) + self.assertEqual(args.model, "gpt-4o") + + def test_invalid_provider_raises(self): + parser = argparse.ArgumentParser() + add_provider_args(parser) + with self.assertRaises(SystemExit): + parser.parse_args(["--provider", "invalid"]) + + +class TestGetModelName(unittest.TestCase): + """Test get_model_name resolution.""" + + def _make_args(self, provider="openai", model=None): + args = argparse.Namespace(provider=provider, model=model) + return args + + def test_openai_default_model(self): + args = self._make_args(provider="openai") + self.assertEqual(get_model_name(args), "gpt-4-vision-preview") + + def test_minimax_default_model(self): + args = self._make_args(provider="minimax") + self.assertEqual(get_model_name(args), "MiniMax-M2.7") + + def test_explicit_model_overrides_default(self): + args = self._make_args(provider="openai", model="gpt-4o") + self.assertEqual(get_model_name(args), "gpt-4o") + + def test_explicit_model_overrides_minimax_default(self): + args = self._make_args(provider="minimax", model="MiniMax-M2.5") + self.assertEqual(get_model_name(args), "MiniMax-M2.5") + + +class TestCreateClient(unittest.TestCase): + """Test create_client creates OpenAI client with correct params.""" + + @patch("llm_provider.OpenAI") + def test_openai_client_default(self, mock_openai_cls): + args = argparse.Namespace( + provider="openai", base_url=None, api_key="test-key" + ) + create_client(args) + mock_openai_cls.assert_called_once_with(api_key="test-key") + + @patch("llm_provider.OpenAI") + def test_minimax_client_sets_base_url(self, mock_openai_cls): + args = argparse.Namespace( + provider="minimax", base_url=None, api_key="mm-key" + ) + create_client(args) + mock_openai_cls.assert_called_once_with( + api_key="mm-key", base_url="https://api.minimax.io/v1" + ) + + @patch("llm_provider.OpenAI") + def test_custom_base_url_overrides_preset(self, mock_openai_cls): + args = argparse.Namespace( + provider="minimax", base_url="https://custom.api.com/v1", api_key="key" + ) + create_client(args) + mock_openai_cls.assert_called_once_with( + api_key="key", base_url="https://custom.api.com/v1" + ) + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "env-mm-key"}, clear=False) + @patch("llm_provider.OpenAI") + def test_minimax_env_key_auto_detected(self, mock_openai_cls): + args = argparse.Namespace( + provider="minimax", base_url=None, api_key=None + ) + create_client(args) + mock_openai_cls.assert_called_once_with( + api_key="env-mm-key", base_url="https://api.minimax.io/v1" + ) + + @patch.dict(os.environ, {"OPENAI_API_KEY": "env-oai-key"}, clear=False) + @patch("llm_provider.OpenAI") + def test_openai_env_key_fallback(self, mock_openai_cls): + args = argparse.Namespace( + provider="openai", base_url=None, api_key=None + ) + create_client(args) + mock_openai_cls.assert_called_once_with(api_key="env-oai-key") + + +class TestClampTemperature(unittest.TestCase): + """Test temperature clamping for different providers.""" + + def test_openai_no_clamping(self): + self.assertEqual(clamp_temperature(0.0, "openai"), 0.0) + self.assertEqual(clamp_temperature(2.0, "openai"), 2.0) + + def test_minimax_clamp_low(self): + self.assertEqual(clamp_temperature(0.0, "minimax"), 0.01) + + def test_minimax_clamp_high(self): + self.assertEqual(clamp_temperature(2.0, "minimax"), 1.0) + + def test_minimax_normal_range(self): + self.assertEqual(clamp_temperature(0.5, "minimax"), 0.5) + + def test_minimax_boundary_one(self): + self.assertEqual(clamp_temperature(1.0, "minimax"), 1.0) + + +class TestCallGptFunction(unittest.TestCase): + """Test call_gpt function signature compatibility.""" + + def test_call_gpt_with_client_and_model(self): + """Verify call_gpt accepts client and model_name parameters.""" + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "test caption" + mock_client.chat.completions.create.return_value = mock_response + + from step3_1_GPT4V_video_caption_concise import call_gpt + + result = call_gpt( + [{"type": "text", "text": "test prompt"}], + mock_client, + model_name="MiniMax-M2.7", + ) + self.assertEqual(result, "test caption") + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args + self.assertEqual(call_kwargs.kwargs["model"], "MiniMax-M2.7") + + +class TestSaveOutputFunction(unittest.TestCase): + """Test save_output function signature compatibility.""" + + def test_save_output_accepts_client_and_model(self): + """Verify save_output accepts client and model_name parameters.""" + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + from step3_1_GPT4V_video_caption_concise import save_output + + import inspect + sig = inspect.signature(save_output) + params = list(sig.parameters.keys()) + self.assertIn("client", params) + self.assertIn("model_name", params) + self.assertNotIn("api_key", params) + + +class TestMainFunction(unittest.TestCase): + """Test main function signature compatibility.""" + + def test_main_accepts_client_and_model(self): + """Verify main() accepts client and model_name parameters.""" + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + from step3_1_GPT4V_video_caption_concise import main + + import inspect + sig = inspect.signature(main) + params = list(sig.parameters.keys()) + self.assertIn("client", params) + self.assertIn("model_name", params) + self.assertNotIn("api_key", params) + + +if __name__ == "__main__": + unittest.main()