diff --git a/v2/tutorials/ml/batch_image_pipeline.py b/v2/tutorials/ml/batch_image_pipeline.py new file mode 100644 index 00000000..56e7d700 --- /dev/null +++ b/v2/tutorials/ml/batch_image_pipeline.py @@ -0,0 +1,368 @@ +# /// script +# requires-python = "==3.13" +# dependencies = [ +# "flyte>=2.0.0b35", +# "torch>=2.0", +# "torchvision>=0.15", +# "Pillow>=10.0", +# "httpx", +# "async-lru", +# "datasets>=2.18", +# ] +# main = "batch_image_pipeline" +# params = "dataset_name='beans', split='test', max_images=200" +# /// + +""" +Batch Image Classification Pipeline +==================================== + +Demonstrates a 3-stage async pipeline that maximizes GPU utilization by +overlapping I/O, CPU preprocessing, and GPU inference using +``InferencePipeline`` from ``flyte.extras``. + +Architecture:: + + [I/O: Download Images] Runs on preprocess_executor (16 threads) + | + [CPU: Resize + Normalize] Same executor — PIL/torchvision release the GIL + | + [GPU: model.forward()] DynamicBatcher batches items, runs on gpu_pool (1 thread) + | + [Decode Labels + Confidence] Event loop (lightweight) + +Key patterns: +- ``InferencePipeline`` wires preprocess → DynamicBatcher → postprocess +- ``alru_cache`` singletons for model + pipeline (shared across concurrent tasks) +- ``ReusePolicy`` keeps warm containers with loaded models +- Multiple concurrent tasks on the same replica all feed one pipeline → bigger GPU batches + +Usage:: + + flyte run batch_image_pipeline.py classify_dataset +""" + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from io import BytesIO + +import httpx +import torch +import torchvision.models as models +import torchvision.transforms as T +from async_lru import alru_cache +from PIL import Image + +import flyte +import flyte.io +from flyte.extras import InferencePipeline + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Thread pools (module-level singletons, shared across concurrent tasks) +# --------------------------------------------------------------------------- + +# I/O + CPU preprocessing share a pool — both release the GIL. +# A dedicated single-thread GPU pool prevents contention. 
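+# Routing every CUDA call (model load, warmup, forward passes) through the
+# single GPU thread serializes device work, so batches originating from
+# different concurrent task streams never compete for the GPU mid-batch.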
+_io_cpu_pool = ThreadPoolExecutor(max_workers=16, thread_name_prefix="io-cpu") +_gpu_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="gpu") + +# --------------------------------------------------------------------------- +# Image & environments +# --------------------------------------------------------------------------- + +image = flyte.Image.from_uv_script( + __file__, name="batch_image_pipeline_image" +).with_pip_packages("unionai-reuse>=0.1.9") + +worker = flyte.TaskEnvironment( + name="image_pipeline_worker", + image=image, + resources=flyte.Resources(cpu=4, memory="8Gi", gpu="T4:1"), + reusable=flyte.ReusePolicy( + replicas=3, + concurrency=4, # 4 concurrent tasks per replica → 12 streams feeding 3 GPUs + idle_ttl=120, + scaledown_ttl=120, + ), +) + +driver = flyte.TaskEnvironment( + name="image_pipeline_driver", + image=image, + resources=flyte.Resources(cpu=2, memory="4Gi"), + depends_on=[worker], +) + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + + +@dataclass +class ImageItem: + """A single image to classify.""" + url: str + image_id: str + + +@dataclass +class ClassificationResult: + """Final output after postprocessing.""" + image_id: str + url: str + top_label: str + confidence: float + top5: list[tuple[str, float]] + + +# --------------------------------------------------------------------------- +# Model loading (process-level singleton via alru_cache) +# --------------------------------------------------------------------------- + +_preprocess_transform = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) + + +@alru_cache(maxsize=1) +async def _load_model(): + """Load ResNet-50 once per process. Shared across all concurrent tasks.""" + loop = asyncio.get_running_loop() + + def _load(): + model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2) + model.eval() + if torch.cuda.is_available(): + model = model.half().cuda() + # dynamic=False + reduce-overhead enables CUDA graphs for fixed shapes. + # ResNet-50 input is always 224x224, only batch dim varies. + model = torch.compile(model, dynamic=False, mode="reduce-overhead") + # Warmup at all plausible batch sizes to avoid JIT spikes at runtime + for bs in [1, 4, 8, 16, 32]: + dummy = torch.randn(bs, 3, 224, 224, dtype=torch.float16, device="cuda") + with torch.no_grad(): + model(dummy) + return model + + model = await loop.run_in_executor(_gpu_pool, _load) + logger.warning("Model loaded on device: %s", "cuda" if torch.cuda.is_available() else "cpu") + return model + + +_IMAGENET_LABELS: list[str] = models.ResNet50_Weights.IMAGENET1K_V2.meta["categories"] + + +# --------------------------------------------------------------------------- +# Pipeline stage functions +# --------------------------------------------------------------------------- + +# Shared HTTP client for downloading images (created per-process) +_http_client: httpx.AsyncClient | None = None + + +def _get_http_client() -> httpx.AsyncClient: + global _http_client + if _http_client is None: + _http_client = httpx.AsyncClient(timeout=30, follow_redirects=True) + return _http_client + + +async def preprocess(item: ImageItem) -> torch.Tensor: + """Download an image and apply torchvision transforms. + + The download is async (httpx). The PIL resize/normalize runs on + ``_io_cpu_pool`` to avoid blocking the event loop. 
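+    Returns a normalized ``[3, 224, 224]`` float32 tensor, which the GPU
+    stage later stacks into a batch.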
+ """ + client = _get_http_client() + resp = await client.get(item.url) + resp.raise_for_status() + + loop = asyncio.get_running_loop() + tensor = await loop.run_in_executor( + _io_cpu_pool, + lambda: _preprocess_transform(Image.open(BytesIO(resp.content)).convert("RGB")), + ) + return tensor + + +@dataclass +class Top5Result: + """Top-5 predictions computed on GPU, transferred as small tensors.""" + probs: torch.Tensor # [5] float + indices: torch.Tensor # [5] int + + +async def inference_batch(batch: list[torch.Tensor]) -> list[Top5Result]: + """Run model.forward() on a batch of preprocessed tensors. + + Stacks individual tensors, moves to GPU, runs inference, computes + top-5 on-device (200x less D2H data than full logits), then + transfers only the small result tensors back to CPU. + """ + model = await _load_model() + loop = asyncio.get_running_loop() + + def _forward(): + stacked = torch.stack(batch).half() + if torch.cuda.is_available(): + # pin_memory + non_blocking enables async H2D transfer + stacked = stacked.pin_memory().to("cuda", non_blocking=True) + with torch.no_grad(): + logits = model(stacked) + # Compute top-5 on GPU to minimize D2H transfer ([N,5] vs [N,1000]) + probs = torch.softmax(logits.float(), dim=1) + top5_probs, top5_idx = torch.topk(probs, 5, dim=1) + return [ + Top5Result(probs=top5_probs[i].cpu(), indices=top5_idx[i].cpu()) + for i in range(len(batch)) + ] + + return await loop.run_in_executor(_gpu_pool, _forward) + + +def postprocess(item: ImageItem, result: Top5Result) -> ClassificationResult: + """Decode top-5 indices into human-readable labels.""" + top5 = [ + (_IMAGENET_LABELS[idx], prob) + for idx, prob in zip(result.indices.tolist(), result.probs.tolist()) + ] + return ClassificationResult( + image_id=item.image_id, + url=item.url, + top_label=top5[0][0], + confidence=top5[0][1], + top5=top5, + ) + + +# --------------------------------------------------------------------------- +# Pipeline singleton (shared across concurrent tasks on a replica) +# --------------------------------------------------------------------------- + + +@alru_cache(maxsize=1) +async def get_pipeline() -> InferencePipeline[ImageItem, torch.Tensor, Top5Result, ClassificationResult]: + pipeline = InferencePipeline( + preprocess_fn=preprocess, + inference_fn=inference_batch, + postprocess_fn=postprocess, + target_batch_cost=32, # 1 cost per image (uniform size after resize) + max_batch_size=32, + min_batch_size=8, # avoid pathologically small batches (T4 throughput + # drops ~15x at batch=1 vs batch=32 for ResNet-50) + batch_timeout_s=0.15, # slightly longer to accumulate larger batches + max_queue_size=1_000, + pipeline_depth=16, # up to 16 images preprocessed ahead of GPU + ) + await pipeline.start() + return pipeline + + +# --------------------------------------------------------------------------- +# Worker task +# --------------------------------------------------------------------------- + + +@worker.task(cache="auto", retries=3) +async def classify_images(image_urls: list[str], chunk_id: str) -> list[dict]: + """Classify a chunk of images through the 3-stage pipeline. + + Multiple concurrent calls on the same replica share one pipeline + singleton, so the DynamicBatcher inside sees items from all streams. 
+ """ + pipeline = await get_pipeline() + + items = [ + ImageItem(url=url, image_id=f"{chunk_id}_{i}") + for i, url in enumerate(image_urls) + ] + + results = await pipeline.run_all(items) + + logger.info( + "[%s] %d images classified | GPU utilization: %.1f%% | avg batch: %.1f", + chunk_id, + len(results), + pipeline.stats.utilization * 100, + pipeline.stats.avg_batch_size, + ) + + return [ + { + "image_id": r.image_id, + "url": r.url, + "top_label": r.top_label, + "confidence": r.confidence, + "top5": [(label, round(conf, 4)) for label, conf in r.top5], + } + for r in results + ] + + +# --------------------------------------------------------------------------- +# Driver task +# --------------------------------------------------------------------------- + + +@driver.task(cache="auto") +async def classify_dataset( + dataset_name: str = "beans", + split: str = "test", + max_images: int = 200, + chunk_size: int = 50, +) -> list[dict]: + """Load images from a HuggingFace dataset, fan out to GPU workers. + + Each chunk becomes a separate task call, routed to a warm replica. + All concurrent tasks on the same replica share one InferencePipeline, + keeping the GPU saturated. + """ + from datasets import load_dataset + + ds = load_dataset(dataset_name, split=split) + if max_images: + ds = ds.select(range(min(max_images, len(ds)))) + + # Upload images and collect URLs + import tempfile, os + image_urls = [] + for i, row in enumerate(ds): + img = row["image"] + path = os.path.join(tempfile.gettempdir(), f"img_{i:05d}.jpg") + img.convert("RGB").save(path) + f = await flyte.io.File.from_local(path) + image_urls.append(f.remote_path) + + print(f"Uploaded {len(image_urls)} images, chunking into groups of {chunk_size}") + + # Fan out to workers + tasks = [] + for i in range(0, len(image_urls), chunk_size): + chunk = image_urls[i : i + chunk_size] + chunk_id = f"chunk_{i // chunk_size:03d}" + with flyte.group(f"classify-{chunk_id}"): + tasks.append(asyncio.create_task( + classify_images(chunk, chunk_id) + )) + + all_results = await asyncio.gather(*tasks) + flat = [r for chunk_results in all_results for r in chunk_results] + print(f"Classified {len(flat)} images total") + return flat + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + flyte.init_from_config() + run = flyte.run(classify_dataset, dataset_name="beans", split="test", max_images=200) + print(run.url) diff --git a/v2/tutorials/ml/batch_llm_pipeline.py b/v2/tutorials/ml/batch_llm_pipeline.py new file mode 100644 index 00000000..a01bb385 --- /dev/null +++ b/v2/tutorials/ml/batch_llm_pipeline.py @@ -0,0 +1,450 @@ +# /// script +# requires-python = "==3.13" +# dependencies = [ +# "flyte>=2.0.0b35", +# "torch>=2.0", +# "transformers>=4.41", +# "accelerate", +# "async-lru", +# "huggingface-hub>=0.24", +# "hf-transfer", +# ] +# main = "batch_llm_pipeline" +# params = "jsonl_path='prompts.jsonl', max_new_tokens=128" +# /// + +""" +Batch LLM Inference Pipeline +============================= + +Demonstrates batch inference with a small LLM (Qwen2.5-0.5B) using +``InferencePipeline`` from ``flyte.extras``. 
+ +Architecture:: + + [I/O: Read JSONL lines] Async file read + | + [CPU: Tokenize + estimate tokens] preprocess_executor (4 threads) + | + [GPU: model.generate()] DynamicBatcher with token budgeting, gpu_pool (1 thread) + | + [CPU: Decode + simple eval] Event loop (lightweight) + +The preprocessing tokenizes each prompt and estimates its token count so +the DynamicBatcher can assemble token-budgeted GPU batches. This prevents +OOM from batches with too many total tokens while still filling each batch +as much as possible. + +Usage:: + + flyte run batch_llm_pipeline.py batch_generate +""" + +import asyncio +import json +import logging +import os +import tempfile +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass + +import torch +from async_lru import alru_cache +from transformers import AutoModelForCausalLM, AutoTokenizer + +import flyte +import flyte.io +from flyte.extras import InferencePipeline + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Thread pools +# --------------------------------------------------------------------------- + +_cpu_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix="cpu") +_gpu_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="gpu") + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +MAX_INPUT_TOKENS = 512 # truncate inputs longer than this +TARGET_BATCH_TOKENS = 4096 # token budget per GPU batch + +# --------------------------------------------------------------------------- +# Image & environments +# --------------------------------------------------------------------------- + +image = flyte.Image.from_uv_script( + __file__, name="batch_llm_pipeline_image" +).with_pip_packages("unionai-reuse>=0.1.9") + +worker = flyte.TaskEnvironment( + name="llm_pipeline_worker", + image=image, + resources=flyte.Resources(cpu=4, memory="16Gi", gpu="T4:1"), + env_vars={"HF_HUB_ENABLE_HF_TRANSFER": "1"}, + reusable=flyte.ReusePolicy( + replicas=2, + concurrency=4, # 4 concurrent task streams per replica + idle_ttl=120, + scaledown_ttl=120, + ), +) + +driver = flyte.TaskEnvironment( + name="llm_pipeline_driver", + image=image, + resources=flyte.Resources(cpu=2, memory="4Gi"), + depends_on=[worker], +) + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + + +@dataclass +class PromptItem: + """A single JSONL line with a prompt and optional expected answer.""" + prompt: str + expected: str # expected substring for simple eval (empty if none) + line_idx: int + num_tokens: int = 0 # populated after tokenization + + +@dataclass +class TokenizedPrompt: + """Tokenized prompt ready for GPU inference.""" + input_ids: torch.Tensor # [seq_len] + attention_mask: torch.Tensor # [seq_len] + num_tokens: int + + def estimate_cost(self) -> int: + """Token count as cost — DynamicBatcher uses this for budgeting.""" + return self.num_tokens + + +@dataclass +class GenerationResult: + """Final output after postprocessing.""" + line_idx: int + prompt: str + response: str + num_input_tokens: int + num_output_tokens: int + expected: str + match: bool | None # None if no expected answer + + +# --------------------------------------------------------------------------- +# Model loading 
(process-level singleton) +# --------------------------------------------------------------------------- + + +@alru_cache(maxsize=1) +async def _load_model_and_tokenizer(): + loop = asyncio.get_running_loop() + + def _load(): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=torch.float16, + device_map="auto" if torch.cuda.is_available() else "cpu", + ) + model.eval() + + # Warmup at realistic batch size + generation length to pre-allocate KV cache + if torch.cuda.is_available(): + dummy = tokenizer(["warmup " * 20] * 16, return_tensors="pt", padding=True) + dummy = {k: v.to(model.device) for k, v in dummy.items()} + with torch.no_grad(): + model.generate(**dummy, max_new_tokens=128, pad_token_id=tokenizer.pad_token_id) + + logger.warning("Model %s loaded on %s", MODEL_NAME, model.device) + return model, tokenizer + + return await loop.run_in_executor(_gpu_pool, _load) + + +# --------------------------------------------------------------------------- +# Pipeline stage functions +# --------------------------------------------------------------------------- + + +async def preprocess(item: PromptItem) -> TokenizedPrompt: + """Tokenize a prompt on the CPU threadpool. + + Truncates to MAX_INPUT_TOKENS and returns the token count for + cost-based batch budgeting. + """ + _, tokenizer = await _load_model_and_tokenizer() + loop = asyncio.get_running_loop() + + def _tokenize(): + encoded = tokenizer( + item.prompt, + return_tensors="pt", + truncation=True, + max_length=MAX_INPUT_TOKENS, + ) + num_tokens = encoded["input_ids"].shape[1] + # Store actual token count on the item for postprocessing + item.num_tokens = num_tokens + return TokenizedPrompt( + input_ids=encoded["input_ids"].squeeze(0), + attention_mask=encoded["attention_mask"].squeeze(0), + num_tokens=num_tokens, + ) + + return await loop.run_in_executor(_cpu_pool, _tokenize) + + +async def inference_batch( + batch: list[TokenizedPrompt], + max_new_tokens: int = 128, +) -> list[str]: + """Run model.generate() on a batch of tokenized prompts. + + Pads the batch to uniform length, generates on GPU, decodes back + to text. Returns only the generated portion (not the input). 
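+    Because the tokenizer was loaded with ``padding_side="left"``, every
+    prompt ends at the same position, so slicing ``outputs`` at
+    ``input_ids.shape[1]`` strips exactly the padded prompt for every row.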
+ """ + model, tokenizer = await _load_model_and_tokenizer() + loop = asyncio.get_running_loop() + + def _generate(): + # Use tokenizer.pad() for correct left-padding (handles edge cases + # and avoids the pad_token_id=0 fallback that can corrupt BOS tokens) + padded = tokenizer.pad( + {"input_ids": [t.input_ids for t in batch], + "attention_mask": [t.attention_mask for t in batch]}, + padding=True, + return_tensors="pt", + ) + input_ids = padded["input_ids"].to(model.device) + attention_mask = padded["attention_mask"].to(model.device) + + with torch.no_grad(): + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, # early termination for short answers + ) + + # Decode only the generated tokens (strip input prefix) + generated = outputs[:, input_ids.shape[1]:] + texts = tokenizer.batch_decode(generated, skip_special_tokens=True) + return texts + + return await loop.run_in_executor(_gpu_pool, _generate) + + +def postprocess(item: PromptItem, response: str) -> GenerationResult: + """Decode and run simple eval: check if expected substring is present.""" + match = None + if item.expected: + match = item.expected.lower() in response.lower() + + return GenerationResult( + line_idx=item.line_idx, + prompt=item.prompt, + response=response.strip(), + num_input_tokens=item.num_tokens, + num_output_tokens=len(response.split()), + expected=item.expected, + match=match, + ) + + +# --------------------------------------------------------------------------- +# Pipeline singleton +# --------------------------------------------------------------------------- + + +@alru_cache(maxsize=1) +async def get_pipeline() -> InferencePipeline[PromptItem, TokenizedPrompt, str, GenerationResult]: + pipeline = InferencePipeline( + preprocess_fn=preprocess, + inference_fn=inference_batch, + postprocess_fn=postprocess, + target_batch_cost=TARGET_BATCH_TOKENS, + max_batch_size=16, + min_batch_size=4, # avoid tiny batches that underutilize the GPU + batch_timeout_s=0.2, + max_queue_size=500, + pipeline_depth=8, + ) + await pipeline.start() + return pipeline + + +# --------------------------------------------------------------------------- +# Worker task +# --------------------------------------------------------------------------- + + +@worker.task(cache="auto", retries=2) +async def generate_responses( + jsonl_file: flyte.io.File, + chunk_id: str, + max_new_tokens: int = 128, +) -> list[dict]: + """Process a chunk of JSONL prompts through the LLM pipeline. + + Each line is expected to be a JSON object with at least a "prompt" + field, and optionally an "expected" field for simple eval. 
+ """ + pipeline = await get_pipeline() + + # Download and parse JSONL + local_path = await jsonl_file.download() + items = [] + with open(local_path) as f: + for i, line in enumerate(f): + line = line.strip() + if not line: + continue + data = json.loads(line) + items.append(PromptItem( + prompt=data["prompt"], + expected=data.get("expected", ""), + line_idx=i, + )) + + results = await pipeline.run_all(items) + + # Log stats + matches = [r for r in results if r.match is True] + total_with_expected = [r for r in results if r.match is not None] + accuracy = len(matches) / len(total_with_expected) if total_with_expected else 0 + + logger.info( + "[%s] %d prompts | GPU util: %.1f%% | avg batch: %.1f tokens | accuracy: %.1f%% (%d/%d)", + chunk_id, + len(results), + pipeline.stats.utilization * 100, + pipeline.stats.avg_batch_cost, + accuracy * 100, + len(matches), + len(total_with_expected), + ) + + return [ + { + "line_idx": r.line_idx, + "prompt": r.prompt[:200], # truncate for readability + "response": r.response, + "num_output_tokens": r.num_output_tokens, + "expected": r.expected, + "match": r.match, + } + for r in results + ] + + +# --------------------------------------------------------------------------- +# Driver task +# --------------------------------------------------------------------------- + + +@driver.task(cache="auto") +async def batch_generate( + jsonl_path: str = "prompts.jsonl", + max_new_tokens: int = 128, + chunk_size: int = 50, +) -> list[dict]: + """Generate LLM responses for all prompts in a JSONL file. + + If no JSONL file is provided, creates a demo dataset with sample prompts. + """ + # Create demo JSONL if needed + if jsonl_path == "prompts.jsonl": + jsonl_path = _create_demo_jsonl() + + # Read and chunk the JSONL + with open(jsonl_path) as f: + lines = [l.strip() for l in f if l.strip()] + + print(f"Loaded {len(lines)} prompts from {jsonl_path}") + + # Split into chunk files and upload + tasks = [] + for i in range(0, len(lines), chunk_size): + chunk_lines = lines[i : i + chunk_size] + chunk_id = f"chunk_{i // chunk_size:03d}" + + # Write chunk to temp file and upload + chunk_path = os.path.join(tempfile.gettempdir(), f"{chunk_id}.jsonl") + with open(chunk_path, "w") as f: + f.write("\n".join(chunk_lines) + "\n") + chunk_file = await flyte.io.File.from_local(chunk_path) + + with flyte.group(f"generate-{chunk_id}"): + tasks.append(asyncio.create_task( + generate_responses(chunk_file, chunk_id, max_new_tokens) + )) + + all_results = await asyncio.gather(*tasks) + flat = [r for chunk_results in all_results for r in chunk_results] + + # Summary + matches = sum(1 for r in flat if r.get("match") is True) + total_eval = sum(1 for r in flat if r.get("match") is not None) + print(f"\nCompleted {len(flat)} generations") + if total_eval: + print(f"Eval accuracy: {matches}/{total_eval} ({matches/total_eval*100:.1f}%)") + + return flat + + +def _create_demo_jsonl() -> str: + """Create a small demo JSONL with diverse prompts for testing.""" + prompts = [ + {"prompt": "What is the capital of France?", "expected": "Paris"}, + {"prompt": "What is 2 + 2?", "expected": "4"}, + {"prompt": "Translate 'hello' to Spanish.", "expected": "hola"}, + {"prompt": "What color is the sky on a clear day?", "expected": "blue"}, + {"prompt": "What is the largest planet in our solar system?", "expected": "Jupiter"}, + {"prompt": "What is the chemical symbol for water?", "expected": "H2O"}, + {"prompt": "Who wrote Romeo and Juliet?", "expected": "Shakespeare"}, + {"prompt": "What is the speed of 
light in km/s approximately?", "expected": "300"}, + {"prompt": "Name the first element on the periodic table.", "expected": "Hydrogen"}, + {"prompt": "What year did World War II end?", "expected": "1945"}, + {"prompt": "What is the square root of 144?", "expected": "12"}, + {"prompt": "What continent is Brazil in?", "expected": "South America"}, + {"prompt": "What is the boiling point of water in Celsius?", "expected": "100"}, + {"prompt": "Who painted the Mona Lisa?", "expected": "Vinci"}, + {"prompt": "What is the longest river in the world?", "expected": "Nile"}, + {"prompt": "Summarize the theory of relativity in one sentence."}, + {"prompt": "Explain how a neural network learns."}, + {"prompt": "Write a haiku about programming."}, + {"prompt": "What are the benefits of async programming in Python?"}, + {"prompt": "Describe the difference between TCP and UDP."}, + ] + # Repeat to get enough volume for meaningful batching + all_prompts = prompts * 5 # 100 prompts + + path = os.path.join(tempfile.gettempdir(), "demo_prompts.jsonl") + with open(path, "w") as f: + for p in all_prompts: + f.write(json.dumps(p) + "\n") + return path + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + flyte.init_from_config() + run = flyte.run(batch_generate, max_new_tokens=128) + print(run.url)
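+
+# To run against your own prompts instead of the generated demo data, pass
+# the path through to the driver (the file must be readable inside the
+# driver container), e.g.:
+#   run = flyte.run(batch_generate, jsonl_path="/path/to/prompts.jsonl")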