diff --git a/.github/workflows/quick-test.yml b/.github/workflows/quick-test.yml index 21f849d..9651c2a 100644 --- a/.github/workflows/quick-test.yml +++ b/.github/workflows/quick-test.yml @@ -22,17 +22,32 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} + key: ${{ runner.os }}-pip-quick-${{ hashFiles('**/requirements-ci.txt') }} restore-keys: | - ${{ runner.os }}-pip- + ${{ runner.os }}-pip-quick- + + - name: Free up disk space + run: | + # Remove unnecessary packages and clean up + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -h - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pytest-cov - # Install only essential dependencies for quick tests - pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install transformers jiwer librosa soundfile + pip install --no-cache-dir pytest pytest-cov + # Install only essential dependencies for quick tests with CPU-only PyTorch + pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu + pip install --no-cache-dir transformers jiwer librosa soundfile + + - name: Clean up pip cache + run: | + pip cache purge || true + df -h - name: Run quick unit tests run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ec4030..e9a4138 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,20 +26,38 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} + key: ${{ runner.os }}-pip-ci-${{ hashFiles('**/requirements-ci.txt') }} restore-keys: | - ${{ runner.os }}-pip- + ${{ runner.os }}-pip-ci- + + - name: Free up disk space + run: | + # Remove unnecessary packages and clean up + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -h - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y ffmpeg libsndfile1 + sudo apt-get clean + sudo rm -rf /var/lib/apt/lists/* - name: Install Python dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install pytest pytest-cov pytest-mock + # Use CI requirements with CPU-only PyTorch to save space + pip install --no-cache-dir -r requirements-ci.txt + pip install --no-cache-dir pytest pytest-cov pytest-mock + + - name: Clean up pip cache + run: | + pip cache purge || true + df -h - name: Run unit tests run: | @@ -123,11 +141,27 @@ jobs: with: python-version: '3.9' + - name: Free up disk space + run: | + # Remove unnecessary packages and clean up + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -h + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install pytest pytest-mock requests + # Use CI requirements with CPU-only PyTorch to save space + pip install --no-cache-dir -r requirements-ci.txt + pip install --no-cache-dir pytest pytest-mock requests + + - name: Clean up pip cache + run: | + pip cache purge || true + df -h - name: Start API server in background run: | diff --git a/.github/workflows/weekly-full-test.yml 
b/.github/workflows/weekly-full-test.yml index 9fe2bf7..f260e96 100644 --- a/.github/workflows/weekly-full-test.yml +++ b/.github/workflows/weekly-full-test.yml @@ -20,16 +20,34 @@ jobs: with: python-version: '3.9' + - name: Free up disk space + run: | + # Remove unnecessary packages and clean up + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -h + - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y ffmpeg libsndfile1 + sudo apt-get clean + sudo rm -rf /var/lib/apt/lists/* - name: Install all dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install pytest pytest-cov pytest-html pytest-xdist + # Use CI requirements with CPU-only PyTorch to save space + pip install --no-cache-dir -r requirements-ci.txt + pip install --no-cache-dir pytest pytest-cov pytest-html pytest-xdist + + - name: Clean up pip cache + run: | + pip cache purge || true + df -h - name: Run all tests in parallel run: | diff --git a/Project Planning.pdf b/Project Planning.pdf new file mode 100644 index 0000000..bfe21ba Binary files /dev/null and b/Project Planning.pdf differ diff --git a/requirements-ci.txt b/requirements-ci.txt new file mode 100644 index 0000000..2ac0cb4 --- /dev/null +++ b/requirements-ci.txt @@ -0,0 +1,59 @@ +# CI-specific requirements with CPU-only PyTorch to save disk space +# Use CPU-only PyTorch (much smaller than CUDA version) +--extra-index-url https://download.pytorch.org/whl/cpu +torch>=2.0.0 +torchvision>=0.15.0 +torchaudio>=2.0.0 + +# Core ML libraries +transformers>=4.35.0 +accelerate>=0.24.0 +datasets>=2.14.0 +# Skip bitsandbytes and peft for CI (not needed for most tests) + +# Audio processing +librosa>=0.10.0 +soundfile>=0.12.0 +pydub>=0.25.0 +audioread>=3.0.0 + +# Evaluation +jiwer>=3.0.0 + +# Google Cloud (optional for CI, but keep for compatibility) +google-cloud-storage>=2.10.0 +gcsfs>=2023.6.0 + +# Data processing +pandas>=2.0.0 +numpy>=1.24.0 +scikit-learn>=1.3.0 +scipy>=1.11.0 + +# Visualization (minimal for CI) +matplotlib>=3.7.0 +seaborn>=0.12.0 + +# Experiment Tracking (optional for CI) +wandb>=0.16.0 + +# Utilities +tqdm>=4.65.0 +pyyaml>=6.0 +python-dotenv>=1.0.0 + +# Ollama for LLM inference (Llama 2/3) +ollama>=0.1.0 + +# Development +pytest>=7.4.0 +black>=23.0.0 +flake8>=6.0.0 + +fastapi>=0.104.0 +uvicorn>=0.24.0 +--only-binary=:all: numba +python-multipart>=0.0.5 + +importlib-metadata + diff --git a/src/agent/adaptive_scheduler.py b/src/agent/adaptive_scheduler.py index c7058a1..9a22def 100644 --- a/src/agent/adaptive_scheduler.py +++ b/src/agent/adaptive_scheduler.py @@ -441,24 +441,38 @@ def _load_history(self): self.error_samples_collected = data.get('error_samples_collected', 0) # Load fine-tuning history - self.fine_tuning_history = [ - FineTuningEvent(**event_data) - for event_data in data.get('fine_tuning_history', []) - ] + self.fine_tuning_history = [] + for event_data in data.get('fine_tuning_history', []): + try: + event = FineTuningEvent(**event_data) + self.fine_tuning_history.append(event) + except (TypeError, ValueError) as e: + logger.warning(f"Failed to load fine-tuning event: {e}") + continue # Load performance history (limited to window size) performance_data = data.get('performance_history', []) for metric_data in performance_data[-self.performance_window_size:]: - metric = PerformanceMetrics( - 
timestamp=datetime.fromisoformat(metric_data['timestamp']), - error_count=metric_data['error_count'], - accuracy=metric_data['accuracy'], - wer=metric_data.get('wer'), - cer=metric_data.get('cer'), - inference_time=metric_data.get('inference_time', 0.0), - cost_per_inference=metric_data.get('cost_per_inference', 0.0) - ) - self.performance_history.append(metric) + try: + # Validate timestamp before parsing + timestamp_str = metric_data.get('timestamp') + if not timestamp_str: + logger.warning("Missing timestamp in performance metric, skipping") + continue + + metric = PerformanceMetrics( + timestamp=datetime.fromisoformat(timestamp_str), + error_count=metric_data['error_count'], + accuracy=metric_data['accuracy'], + wer=metric_data.get('wer'), + cer=metric_data.get('cer'), + inference_time=metric_data.get('inference_time', 0.0), + cost_per_inference=metric_data.get('cost_per_inference', 0.0) + ) + self.performance_history.append(metric) + except (TypeError, ValueError, KeyError) as e: + logger.warning(f"Failed to load performance metric: {e}") + continue logger.info(f"Loaded scheduler history from {self.history_path}") except Exception as e: diff --git a/src/agent/agent.py b/src/agent/agent.py index c61d084..3fdf6cd 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -59,16 +59,17 @@ def __init__( self.llm_corrector = LlamaLLMCorrector( model_name=llm_model_name or "llama3.2:3b", use_quantization=False, # Not used for Ollama - fast_mode=True # Kept for compatibility + fast_mode=True, # Kept for compatibility + raise_on_error=False # Don't raise, just mark as unavailable ) if self.llm_corrector.is_available(): logger.info("✅ LLM corrector initialized successfully") else: logger.warning("⚠️ LLM not available, using rule-based correction only") - self.llm_corrector = None + # Keep llm_corrector object but it will be unavailable except Exception as e: - logger.error(f"❌ Failed to initialize LLM: {e}") - raise # Fail and alert if Ollama is not available + logger.warning(f"⚠️ Failed to initialize LLM: {e}. 
Continuing without LLM correction.") + self.llm_corrector = None # Initialize adaptive scheduler and fine-tuner (Week 3) self.enable_adaptive_fine_tuning = enable_adaptive_fine_tuning @@ -194,31 +195,30 @@ def transcribe_with_agent( # Record corrections for learning for error in errors: - if error.suggested_correction or correction_method.startswith("llama"): - self.self_learner.record_error( - error_type=error.error_type, - transcript=transcript, - context={ - 'audio_length': audio_length_seconds, - 'confidence': baseline_result.get('confidence'), - 'correction_method': correction_method - }, - correction=corrected_transcript if correction_method.startswith("llama") else error.suggested_correction - ) - - # Step 5: Record errors for learning (even if not corrected) - error_count = len(errors) - for error in errors: - self.self_learner.record_error( - error_type=error.error_type, - transcript=transcript, - context={ + # Build comprehensive context + context = { 'audio_path': audio_path, # Store path for fine-tuning 'audio_length': audio_length_seconds, 'confidence': baseline_result.get('confidence'), 'error_confidence': error.confidence } - ) + + # Add correction info if available + correction = None + if error.suggested_correction or correction_method.startswith("llama"): + context['correction_method'] = correction_method + correction = corrected_transcript if correction_method.startswith("llama") else error.suggested_correction + + # Record error with all context (single recording per error) + self.self_learner.record_error( + error_type=error.error_type, + transcript=transcript, + context=context, + correction=correction + ) + + # Step 5: Calculate error count for metrics + error_count = len(errors) # Step 6: Update adaptive scheduler and check for fine-tuning trigger (Week 3) fine_tuning_triggered = False @@ -264,7 +264,7 @@ def transcribe_with_agent( 'error_threshold': self.error_detector.min_confidence_threshold, 'auto_correction_enabled': enable_auto_correction, 'correction_method': correction_method, - 'llm_available': self.llm_corrector.is_available() if self.llm_corrector else False, + 'llm_available': self.llm_corrector.is_available() if (self.llm_corrector and hasattr(self.llm_corrector, 'is_available')) else False, 'adaptive_fine_tuning_enabled': self.enable_adaptive_fine_tuning, 'fine_tuning_triggered': fine_tuning_triggered } diff --git a/src/agent/error_detector.py b/src/agent/error_detector.py index fc824e2..e3af916 100644 --- a/src/agent/error_detector.py +++ b/src/agent/error_detector.py @@ -109,9 +109,9 @@ def detect_errors( # 3. 
Length anomaly detection (based on audio length) if audio_length_seconds: - expected_length = audio_length_seconds * 2.5 # ~2.5 chars per second average + expected_length = max(1, audio_length_seconds * 2.5) # ~2.5 chars per second average, minimum 1 actual_length = len(transcript) - ratio = actual_length / expected_length if expected_length > 0 else 1.0 + ratio = actual_length / expected_length if ratio > self.max_length_ratio: errors.append(ErrorSignal( diff --git a/src/agent/fine_tuner.py b/src/agent/fine_tuner.py index ae14dd7..e0a2a15 100644 --- a/src/agent/fine_tuner.py +++ b/src/agent/fine_tuner.py @@ -87,10 +87,10 @@ def __getitem__(self, idx): # Process text labels with self.processor.as_target_processor(): label_ids = self.processor( - sample['corrected_transcript'], - padding=True, - return_tensors="pt" - ) + sample['corrected_transcript'], + padding=True, + return_tensors="pt" + ) return { 'input_values': inputs.input_values.squeeze(0), diff --git a/src/agent/llm_corrector.py b/src/agent/llm_corrector.py index 8b8cce9..794ee34 100644 --- a/src/agent/llm_corrector.py +++ b/src/agent/llm_corrector.py @@ -26,7 +26,8 @@ def __init__( ollama_base_url: str = "http://localhost:11434", device: Optional[str] = None, use_quantization: bool = False, # Not used for Ollama, kept for compatibility - fast_mode: bool = True # Not used for Ollama, kept for compatibility + fast_mode: bool = True, # Not used for Ollama, kept for compatibility + raise_on_error: bool = False # If False, mark as unavailable instead of raising ): """ Initialize Ollama LLM corrector. @@ -37,6 +38,7 @@ def __init__( device: Not used for Ollama (kept for compatibility) use_quantization: Not used for Ollama (kept for compatibility) fast_mode: Not used for Ollama (kept for compatibility) + raise_on_error: If True, raise exceptions on initialization failure. If False, mark as unavailable. """ self.model_name = model_name self.ollama_base_url = ollama_base_url @@ -49,12 +51,20 @@ def __init__( try: self.ollama = OllamaLLM( model_name=model_name, - base_url=ollama_base_url + base_url=ollama_base_url, + raise_on_error=raise_on_error ) - logger.info(f"✅ Ollama LLM corrector initialized successfully with model: {model_name}") + if self.ollama.is_available(): + logger.info(f"✅ Ollama LLM corrector initialized successfully with model: {model_name}") + else: + logger.warning(f"⚠️ Ollama LLM corrector initialized but unavailable (server not running or model not found)") except Exception as e: - logger.error(f"Failed to initialize Ollama LLM: {e}") - raise # Fail and alert if Ollama is not available + if raise_on_error: + logger.error(f"Failed to initialize Ollama LLM: {e}") + raise # Fail and alert if Ollama is not available and raise_on_error=True + else: + logger.warning(f"Failed to initialize Ollama LLM: {e}. 
LLM correction will be unavailable.") + self.ollama = None def correct_transcript( self, @@ -267,7 +277,7 @@ def improve_transcript( def is_available(self) -> bool: """Check if Ollama LLM is available.""" - return self.ollama is not None and self.ollama.is_available() + return self.ollama is not None and hasattr(self.ollama, 'is_available') and self.ollama.is_available() def get_model_info(self) -> Dict: """Get information about the loaded model.""" diff --git a/src/agent/ollama_llm.py b/src/agent/ollama_llm.py index 637c886..900c2d2 100644 --- a/src/agent/ollama_llm.py +++ b/src/agent/ollama_llm.py @@ -32,7 +32,8 @@ class OllamaLLM: def __init__( self, model_name: str = "llama3.2:3b", - base_url: str = "http://localhost:11434" + base_url: str = "http://localhost:11434", + raise_on_error: bool = False ): """ Initialize Ollama LLM. @@ -40,21 +41,31 @@ def __init__( Args: model_name: Ollama model name (e.g., "llama3.2:3b", "llama3.1:8b", "llama2:7b") base_url: Ollama server URL (default: http://localhost:11434) + raise_on_error: If True, raise exceptions on initialization failure. If False, mark as unavailable. Raises: - ImportError: If Ollama package is not installed - ConnectionError: If Ollama server is not running - ValueError: If specified model is not available + ImportError: If Ollama package is not installed and raise_on_error=True + ConnectionError: If Ollama server is not running and raise_on_error=True + ValueError: If specified model is not available and raise_on_error=True """ if not OLLAMA_AVAILABLE: - raise ImportError( - "Ollama package not found. Install with: pip install ollama\n" - "Then install Ollama: https://ollama.ai/download" - ) + if raise_on_error: + raise ImportError( + "Ollama package not found. Install with: pip install ollama\n" + "Then install Ollama: https://ollama.ai/download" + ) + else: + logger.warning("Ollama package not found. Ollama LLM will be unavailable.") + self.model_name = model_name + self.base_url = base_url + self.client = None + self._available = False + return self.model_name = model_name self.base_url = base_url self.client = None + self._available = False logger.info(f"Initializing Ollama LLM with model: {model_name}") @@ -64,12 +75,20 @@ def __init__( ollama.list() # This will fail if server is not running logger.info("✓ Ollama server connection successful") except Exception as e: - raise ConnectionError( - f"Ollama server is not running or not accessible at {base_url}.\n" - f"Error: {e}\n" - f"Please start Ollama server with: ollama serve\n" - f"Or install Ollama from: https://ollama.ai/download" - ) + if raise_on_error: + raise ConnectionError( + f"Ollama server is not running or not accessible at {base_url}.\n" + f"Error: {e}\n" + f"Please start Ollama server with: ollama serve\n" + f"Or install Ollama from: https://ollama.ai/download" + ) + else: + logger.warning( + f"Ollama server is not running or not accessible at {base_url}. " + f"Ollama LLM will be unavailable. Error: {e}" + ) + self._available = False + return # Check if model is available try: @@ -138,30 +157,55 @@ def __init__( break if not model_found: - raise ValueError( - f"Model '{model_name}' is not available in Ollama.\n" - f"Available models: {', '.join(available_models) if available_models else 'None (no models installed)'}\n" - f"Please pull the model with: ollama pull {model_name}\n" - f"Supported models: llama3.2:3b, llama3.1:8b, llama2:7b\n" - f"Or use one of the available models above." 
- ) + if raise_on_error: + raise ValueError( + f"Model '{model_name}' is not available in Ollama.\n" + f"Available models: {', '.join(available_models) if available_models else 'None (no models installed)'}\n" + f"Please pull the model with: ollama pull {model_name}\n" + f"Supported models: llama3.2:3b, llama3.1:8b, llama2:7b\n" + f"Or use one of the available models above." + ) + else: + logger.warning( + f"Model '{model_name}' is not available in Ollama. " + f"Available models: {', '.join(available_models) if available_models else 'None (no models installed)'}. " + f"Ollama LLM will be unavailable." + ) + self._available = False + return logger.info(f"✓ Model '{model_name}' is available (matched: {matched_model})") + self._available = True except ValueError as e: - # Re-raise ValueError as-is (it has helpful messages) - raise + # Re-raise ValueError as-is if raise_on_error is True + if raise_on_error: + raise + else: + logger.warning(f"Model validation failed: {e}. Ollama LLM will be unavailable.") + self._available = False + return except KeyError as e: - raise RuntimeError( - f"Failed to parse Ollama model list: unexpected structure (key: {e})\n" - f"Please ensure Ollama is properly installed and running.\n" - f"You can verify by running: ollama list" - ) + if raise_on_error: + raise RuntimeError( + f"Failed to parse Ollama model list: unexpected structure (key: {e})\n" + f"Please ensure Ollama is properly installed and running.\n" + f"You can verify by running: ollama list" + ) + else: + logger.warning(f"Failed to parse Ollama model list: {e}. Ollama LLM will be unavailable.") + self._available = False + return except Exception as e: - raise RuntimeError( - f"Failed to check model availability: {e}\n" - f"Please ensure Ollama is properly installed and running.\n" - f"You can verify by running: ollama list" - ) + if raise_on_error: + raise RuntimeError( + f"Failed to check model availability: {e}\n" + f"Please ensure Ollama is properly installed and running.\n" + f"You can verify by running: ollama list" + ) + else: + logger.warning(f"Failed to check model availability: {e}. Ollama LLM will be unavailable.") + self._available = False + return logger.info(f"✅ Ollama LLM initialized successfully with model: {model_name}") @@ -179,14 +223,30 @@ def generate( Returns: Generated text + + Raises: + RuntimeError: If Ollama is not available or generation fails """ + if not self.is_available(): + raise RuntimeError("Ollama LLM is not available. Check initialization or start Ollama server.") + try: response = ollama.generate( model=self.model_name, prompt=prompt, **kwargs ) - return response.get('response', '') + + # Validate response type + if not isinstance(response, dict): + raise RuntimeError(f"Unexpected response type: {type(response)}. Expected dict.") + + # Extract and validate result + result = response.get('response', '') + if not result: + logger.warning("Ollama returned empty response") + + return result except Exception as e: logger.error(f"Ollama generation failed: {e}") raise RuntimeError(f"Failed to generate text with Ollama: {e}") @@ -205,14 +265,34 @@ def chat( Returns: Generated response + + Raises: + RuntimeError: If Ollama is not available or chat fails """ + if not self.is_available(): + raise RuntimeError("Ollama LLM is not available. 
Check initialization or start Ollama server.") + try: response = ollama.chat( model=self.model_name, messages=messages, **kwargs ) - return response.get('message', {}).get('content', '') + + # Validate response type + if not isinstance(response, dict): + raise RuntimeError(f"Unexpected response type: {type(response)}. Expected dict.") + + # Extract and validate result + message = response.get('message', {}) + if not isinstance(message, dict): + raise RuntimeError(f"Unexpected message type: {type(message)}. Expected dict.") + + result = message.get('content', '') + if not result: + logger.warning("Ollama returned empty chat response") + + return result except Exception as e: logger.error(f"Ollama chat failed: {e}") raise RuntimeError(f"Failed to chat with Ollama: {e}") @@ -224,6 +304,11 @@ def is_available(self) -> bool: Returns: True if Ollama is available, False otherwise """ + # If _available attribute exists, use it (set during initialization) + if hasattr(self, '_available'): + return self._available + + # Fallback: check if Ollama is available if not OLLAMA_AVAILABLE: return False diff --git a/src/agent_api.py b/src/agent_api.py index d87743d..41e86f8 100644 --- a/src/agent_api.py +++ b/src/agent_api.py @@ -5,14 +5,12 @@ from fastapi import FastAPI, UploadFile, File, HTTPException, Body from pydantic import BaseModel -import tempfile -import os -import librosa from typing import Optional -import time from src.baseline_model import BaselineSTTModel from src.agent import STTAgent +from src.utils.api_helpers import handle_audio_upload, transcribe_with_timing, transcribe_agent_with_timing +from src.utils.file_utils import cleanup_temp_file app = FastAPI(title="STT Agent API", version="2.0.0") @@ -64,25 +62,24 @@ async def transcribe(file: UploadFile = File(...)): curl -X POST "http://localhost:8000/transcribe" \\ -F "file=@audio.wav" """ + tmp_path = None try: # Save uploaded file temporarily - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - content = await file.read() - tmp.write(content) - tmp_path = tmp.name + tmp_path = await handle_audio_upload(file) - # Transcribe - start = time.time() - result = baseline_model.transcribe(tmp_path) - result["inference_time_seconds"] = time.time() - start - - # Cleanup - os.remove(tmp_path) + # Transcribe with timing + result = transcribe_with_timing(baseline_model, tmp_path) return result + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + finally: + # Cleanup + if tmp_path: + cleanup_temp_file(tmp_path) @app.post("/agent/transcribe") @@ -101,34 +98,28 @@ async def agent_transcribe( curl -X POST "http://localhost:8000/agent/transcribe?auto_correction=true" \\ -F "file=@audio.wav" """ + tmp_path = None try: # Save uploaded file temporarily - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - content = await file.read() - tmp.write(content) - tmp_path = tmp.name - - # Get audio length for error detection - try: - audio, sr = librosa.load(tmp_path, sr=16000) - audio_length = len(audio) / sr - except: - audio_length = None + tmp_path = await handle_audio_upload(file) - # Transcribe with agent - result = agent.transcribe_with_agent( - audio_path=tmp_path, - audio_length_seconds=audio_length, + # Transcribe with agent (includes audio length calculation) + result = transcribe_agent_with_timing( + agent, + tmp_path, enable_auto_correction=auto_correction ) - # Cleanup - os.remove(tmp_path) - return result + except HTTPException: + raise except Exception as 
e: raise HTTPException(status_code=500, detail=str(e)) + finally: + # Cleanup + if tmp_path: + cleanup_temp_file(tmp_path) @app.post("/agent/feedback") diff --git a/src/agent_evaluation/agent_evaluator.py b/src/agent_evaluation/agent_evaluator.py index b02f44d..c086d54 100644 --- a/src/agent_evaluation/agent_evaluator.py +++ b/src/agent_evaluation/agent_evaluator.py @@ -12,8 +12,8 @@ import time from datetime import datetime -from jiwer import wer, cer import numpy as np +from src.evaluation.metrics import STTEvaluator logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -79,6 +79,7 @@ def __init__(self, self.output_dir.mkdir(parents=True, exist_ok=True) self.results: List[CorrectionResult] = [] + self.evaluator = STTEvaluator() logger.info("Agent Evaluator initialized") @@ -117,10 +118,10 @@ def evaluate_correction(self, corrected_cer = 0.0 if reference_transcript: - original_wer = wer(reference_transcript, original_transcript) - original_cer = cer(reference_transcript, original_transcript) - corrected_wer = wer(reference_transcript, corrected_transcript) - corrected_cer = cer(reference_transcript, corrected_transcript) + original_wer = self.evaluator.calculate_wer(reference_transcript, original_transcript) + original_cer = self.evaluator.calculate_cer(reference_transcript, original_transcript) + corrected_wer = self.evaluator.calculate_wer(reference_transcript, corrected_transcript) + corrected_cer = self.evaluator.calculate_cer(reference_transcript, corrected_transcript) # Calculate improvements (negative means correction made it worse) wer_improvement = original_wer - corrected_wer diff --git a/src/baseline_model.py b/src/baseline_model.py index 9bb1846..b564174 100644 --- a/src/baseline_model.py +++ b/src/baseline_model.py @@ -178,18 +178,29 @@ def transcribe(self, audio_path: str) -> Dict[str, str]: inputs = self.processor(audio, sampling_rate=sr, return_tensors="pt") with torch.no_grad(): input_features = inputs["input_features"].to(self.device) + + # Prepare generation kwargs with modern parameters to avoid deprecation warnings + # Using task and language parameters instead of deprecated forced_decoder_ids + generate_kwargs = { + "input_features": input_features, + "max_new_tokens": 128, + "task": "transcribe", # Explicitly set task to avoid forced_decoder_ids deprecation + "language": None, # None = auto-detect language (set to "en" for English-only transcription) + } + + # Set pad_token_id to eos_token_id to avoid attention mask warning + # This tells the model how to handle padding during generation + if hasattr(self.model.config, 'eos_token_id') and self.model.config.eos_token_id is not None: + generate_kwargs["pad_token_id"] = self.model.config.eos_token_id + if self.device.startswith("cuda"): - predicted_ids = self.model.generate( - input_features, - max_new_tokens=128, - num_beams=5, - use_cache=True - ) - else: - predicted_ids = self.model.generate( - input_features, - max_new_tokens=128 - ) + generate_kwargs.update({ + "num_beams": 5, + "use_cache": True + }) + + predicted_ids = self.model.generate(**generate_kwargs) + transcript = self.processor.batch_decode( predicted_ids, skip_special_tokens=True diff --git a/src/control_panel_api.py b/src/control_panel_api.py index b50384e..a95eee7 100644 --- a/src/control_panel_api.py +++ b/src/control_panel_api.py @@ -8,12 +8,10 @@ from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel -import tempfile import os import time import 
asyncio import random -import librosa from typing import Optional, List, Dict, Any from pathlib import Path import json @@ -27,7 +25,8 @@ from src.data.finetuning_orchestrator import FinetuningConfig from src.evaluation.metrics import STTEvaluator from src.agent.llm_corrector import LlamaLLMCorrector -from jiwer import wer, cer +from src.utils.api_helpers import handle_audio_upload, transcribe_with_timing, transcribe_agent_with_timing +from src.utils.file_utils import cleanup_temp_file, load_audio_duration from src.constants import ( MIN_SAMPLES_FOR_FINETUNING, RECOMMENDED_SAMPLES_FOR_FINETUNING, @@ -776,6 +775,7 @@ async def transcribe_baseline( Transcribe audio with baseline model only (no LLM correction) Faster than agent mode since no LLM processing is involved """ + tmp_path = None try: # Get the appropriate model instance stt_model, _ = get_model_and_agent(model) @@ -783,14 +783,11 @@ async def transcribe_baseline( # Verify which model is actually being used logger.info(f"📝 Transcribing with model: {model} -> Actual: {stt_model.model_name}, Path: {stt_model.model_path}") - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - content = await file.read() - tmp.write(content) - tmp_path = tmp.name + # Save uploaded file temporarily + tmp_path = await handle_audio_upload(file) - start = time.time() - result = stt_model.transcribe(tmp_path) - result["inference_time_seconds"] = time.time() - start + # Transcribe with timing + result = transcribe_with_timing(stt_model, tmp_path) result["model_used"] = model result["model_name"] = stt_model.model_name # Include actual model name result["model_path"] = stt_model.model_path # Include model path for verification @@ -806,11 +803,15 @@ async def transcribe_baseline( error_score = result.get("error_detection", {}).get("error_score", 0.0) perf_counters["sum_error_scores"] += error_score - os.remove(tmp_path) return result + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") + finally: + if tmp_path: + cleanup_temp_file(tmp_path) @app.post("/api/transcribe/agent") @@ -833,27 +834,15 @@ async def transcribe_agent( logger.info(f"📝 Transcribing with agent. 
Model: {model} -> Actual: {stt_model.model_name}, Path: {stt_model.model_path}") logger.info(f" Agent's baseline_model: {stt_agent.baseline_model.model_name}, Path: {stt_agent.baseline_model.model_path}") - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - content = await file.read() - tmp.write(content) - tmp_path = tmp.name + # Save uploaded file temporarily + tmp_path = await handle_audio_upload(file) - # Get audio length - try: - audio, sr = librosa.load(tmp_path, sr=16000) - audio_length = len(audio) / sr - except Exception as e: - print(f"Warning: Could not load audio for length calculation: {e}") - audio_length = None - - # Transcribe with agent (this includes STT + LLM correction if enabled) - start = time.time() - result = stt_agent.transcribe_with_agent( - audio_path=tmp_path, - audio_length_seconds=audio_length, + # Transcribe with agent (includes audio length calculation) + result = transcribe_agent_with_timing( + stt_agent, + tmp_path, enable_auto_correction=auto_correction ) - result["inference_time_seconds"] = result.get("inference_time_seconds", time.time() - start) result["model_used"] = model result["model_name"] = stt_model.model_name # Include actual model name for verification @@ -919,8 +908,6 @@ async def transcribe_agent( except Exception as e: print(f"Failed to record error: {e}") - os.remove(tmp_path) - # Update perf counters perf_counters["total_inferences"] += 1 perf_counters["total_inference_time"] += result.get("inference_time_seconds", 0.0) @@ -931,10 +918,15 @@ async def transcribe_agent( return result + except HTTPException: + raise except Exception as e: import traceback traceback.print_exc() raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") + finally: + if tmp_path: + cleanup_temp_file(tmp_path) # ==================== AGENT MANAGEMENT ==================== diff --git a/src/data/finetuning_orchestrator.py b/src/data/finetuning_orchestrator.py index 163e818..8253164 100644 --- a/src/data/finetuning_orchestrator.py +++ b/src/data/finetuning_orchestrator.py @@ -284,7 +284,8 @@ def create_finetuning_job( Returns: FinetuningJob instance """ - job_id = f"ft_job_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + # Use microseconds to ensure unique job IDs even when created in rapid succession + job_id = f"ft_job_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}" job = FinetuningJob( job_id=job_id, diff --git a/src/data/model_validator.py b/src/data/model_validator.py index ea2de97..7e9b441 100644 --- a/src/data/model_validator.py +++ b/src/data/model_validator.py @@ -54,7 +54,22 @@ def to_dict(self) -> Dict: result = asdict(self) if self.per_sample_results is None: result['per_sample_results'] = [] - return result + + # Convert numpy types to native Python types for JSON serialization + def convert_numpy_types(obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.bool_): + return bool(obj) + elif isinstance(obj, dict): + return {k: convert_numpy_types(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_numpy_types(item) for item in obj] + return obj + + return convert_numpy_types(result) @dataclass @@ -372,9 +387,9 @@ def _test_significance( t_stat, p_value = stats.ttest_1samp(improvements, 0) # One-tailed test (we care if model is better, not just different) - p_value = p_value / 2 if t_stat > 0 else 1.0 + p_value = float(p_value / 2 if t_stat > 0 else 1.0) - is_significant = p_value < 
self.config.significance_alpha + is_significant = bool(p_value < self.config.significance_alpha) return is_significant, p_value diff --git a/src/inference_api.py b/src/inference_api.py index 886aa63..c1d4538 100644 --- a/src/inference_api.py +++ b/src/inference_api.py @@ -4,10 +4,9 @@ """ from fastapi import FastAPI, UploadFile, File, HTTPException -import tempfile -import os from src.baseline_model import BaselineSTTModel -import time +from src.utils.api_helpers import handle_audio_upload, transcribe_with_timing +from src.utils.file_utils import cleanup_temp_file app = FastAPI(title="STT Baseline API") @@ -29,25 +28,24 @@ async def transcribe(file: UploadFile = File(...)): curl -X POST "http://localhost:8000/transcribe" \\ -F "file=@audio.wav" """ + tmp_path = None try: # Save uploaded file temporarily - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - content = await file.read() - tmp.write(content) - tmp_path = tmp.name + tmp_path = await handle_audio_upload(file) - # Transcribe - start = time.time() - result = model.transcribe(tmp_path) - result["inference_time_seconds"] = time.time() - start - - # Cleanup - os.remove(tmp_path) + # Transcribe with timing + result = transcribe_with_timing(model, tmp_path) return result + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + finally: + # Cleanup + if tmp_path: + cleanup_temp_file(tmp_path) @app.get("/health") async def health(): diff --git a/src/utils/api_helpers.py b/src/utils/api_helpers.py new file mode 100644 index 0000000..b4658ea --- /dev/null +++ b/src/utils/api_helpers.py @@ -0,0 +1,83 @@ +""" +Common API helper functions to reduce duplication in FastAPI endpoints. +""" + +import time +from typing import Dict, Any +from fastapi import UploadFile, HTTPException +import librosa + +from ..utils.file_utils import save_uploaded_file, cleanup_temp_file, load_audio_duration + + +async def handle_audio_upload(file: UploadFile, suffix: str = ".wav") -> str: + """ + Handle audio file upload and return temporary file path. + + Args: + file: Uploaded file from FastAPI + suffix: File suffix (default: ".wav") + + Returns: + Path to temporary file + + Raises: + HTTPException: If file reading fails + """ + try: + content = await file.read() + return save_uploaded_file(content, suffix=suffix) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to read uploaded file: {str(e)}") + + +def transcribe_with_timing(model, audio_path: str, **kwargs) -> Dict[str, Any]: + """ + Transcribe audio and add timing information. + + Args: + model: Model with transcribe method + audio_path: Path to audio file + **kwargs: Additional arguments to pass to transcribe method + + Returns: + Transcription result with inference_time_seconds added + """ + start_time = time.time() + result = model.transcribe(audio_path, **kwargs) + result["inference_time_seconds"] = time.time() - start_time + return result + + +def transcribe_agent_with_timing(agent, audio_path: str, enable_auto_correction: bool = True, **kwargs) -> Dict[str, Any]: + """ + Transcribe with agent and add timing information. 
+ + Args: + agent: Agent with transcribe_with_agent method + audio_path: Path to audio file + enable_auto_correction: Whether to enable auto correction + **kwargs: Additional arguments to pass to transcribe_with_agent method + + Returns: + Transcription result with timing information + """ + # Get audio length if not provided + audio_length = kwargs.pop('audio_length_seconds', None) + if audio_length is None: + audio_length = load_audio_duration(audio_path) + + start_time = time.time() + result = agent.transcribe_with_agent( + audio_path=audio_path, + audio_length_seconds=audio_length, + enable_auto_correction=enable_auto_correction, + **kwargs + ) + + # Add timing if not already present + if "inference_time_seconds" not in result: + result["inference_time_seconds"] = time.time() - start_time + + return result + diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py new file mode 100644 index 0000000..0e9e1a3 --- /dev/null +++ b/src/utils/file_utils.py @@ -0,0 +1,61 @@ +""" +Common file handling utilities used across the codebase. +Only includes functions that are actually used to avoid unnecessary abstraction. +""" + +import os +import tempfile +from typing import Optional +import librosa +import logging + +logger = logging.getLogger(__name__) + + +def load_audio_duration(audio_path: str, sample_rate: int = 16000) -> Optional[float]: + """ + Load audio file and return its duration in seconds. + + Args: + audio_path: Path to audio file + sample_rate: Target sample rate (default: 16000) + + Returns: + Audio duration in seconds, or None if loading fails + """ + try: + audio, sr = librosa.load(audio_path, sr=sample_rate) + return len(audio) / sr + except Exception as e: + logger.warning(f"Could not load audio for duration calculation: {e}") + return None + + +def save_uploaded_file(file_content: bytes, suffix: str = ".wav") -> str: + """ + Save uploaded file content to a temporary file. + + Args: + file_content: File content as bytes + suffix: File suffix (default: ".wav") + + Returns: + Path to temporary file + """ + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: + tmp.write(file_content) + return tmp.name + + +def cleanup_temp_file(file_path: str) -> None: + """ + Safely remove a temporary file. 
+ + Args: + file_path: Path to temporary file + """ + try: + if os.path.exists(file_path): + os.remove(file_path) + except Exception as e: + logger.warning(f"Could not remove temporary file {file_path}: {e}") diff --git a/tests/test_finetuning_orchestrator.py b/tests/test_finetuning_orchestrator.py index aad2be8..09c41b2 100644 --- a/tests/test_finetuning_orchestrator.py +++ b/tests/test_finetuning_orchestrator.py @@ -163,6 +163,8 @@ def test_prepare_dataset_for_job(self, mock_pipeline, orchestrator): 'is_valid': True, 'issues': [] } + # Mock output_dir as a Path object to support path operations + mock_pipeline_instance.output_dir = Path("/tmp/test_output") orchestrator.dataset_pipeline = mock_pipeline_instance # Mock version control @@ -202,6 +204,8 @@ def test_trigger_finetuning_force(self, mock_pipeline, orchestrator, mock_data_m 'is_valid': True, 'issues': [] } + # Mock output_dir as a Path object to support path operations + mock_pipeline_instance.output_dir = Path("/tmp/test_output") orchestrator.dataset_pipeline = mock_pipeline_instance # Mock version control diff --git a/tests/test_integration.py b/tests/test_integration.py index 6aed978..eeda4d0 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -15,14 +15,31 @@ import shutil +def check_ollama_available(): + """Check if Ollama is available for testing.""" + try: + import ollama + # Try to connect to Ollama server + ollama.list() + return True + except (ImportError, Exception): + return False + + class TestCompleteWorkflow: """Test complete workflow integration""" def setup_method(self): - """Setup test components""" + """Setup test components - requires Ollama""" + if not check_ollama_available(): + pytest.skip("Ollama not available - install with: pip install ollama and start server: ollama serve") + self.test_dir = tempfile.mkdtemp() self.baseline_model = BaselineSTTModel(model_name="whisper") - self.agent = STTAgent(baseline_model=self.baseline_model) + try: + self.agent = STTAgent(baseline_model=self.baseline_model) + except (ImportError, RuntimeError) as e: + pytest.skip(f"Ollama not available: {e}") self.data_system = IntegratedDataManagementSystem( base_dir=self.test_dir, use_gcs=False @@ -30,7 +47,7 @@ def setup_method(self): def teardown_method(self): """Cleanup test directory""" - if Path(self.test_dir).exists(): + if hasattr(self, 'test_dir') and Path(self.test_dir).exists(): shutil.rmtree(self.test_dir) @pytest.mark.skipif(not Path("data/test_audio/test_1.wav").exists(), @@ -186,10 +203,16 @@ class TestAgentDataIntegration: """Test integration between agent and data management""" def setup_method(self): - """Setup test components""" + """Setup test components - requires Ollama""" + if not check_ollama_available(): + pytest.skip("Ollama not available - install with: pip install ollama and start server: ollama serve") + self.test_dir = tempfile.mkdtemp() self.baseline_model = BaselineSTTModel(model_name="whisper") - self.agent = STTAgent(baseline_model=self.baseline_model) + try: + self.agent = STTAgent(baseline_model=self.baseline_model) + except (ImportError, RuntimeError) as e: + pytest.skip(f"Ollama not available: {e}") self.data_system = IntegratedDataManagementSystem( base_dir=self.test_dir, use_gcs=False @@ -197,7 +220,7 @@ def setup_method(self): def teardown_method(self): """Cleanup""" - if Path(self.test_dir).exists(): + if hasattr(self, 'test_dir') and Path(self.test_dir).exists(): shutil.rmtree(self.test_dir) def test_agent_stats_and_data_stats_consistency(self):
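
The changes to src/agent/ollama_llm.py, src/agent/llm_corrector.py, and src/agent/agent.py thread a raise_on_error flag through the stack so that a missing Ollama package, server, or model degrades to rule-based correction instead of aborting agent initialization. Below is a minimal sketch of the intended calling pattern, using only the constructor arguments and the is_available() check that appear in this diff; the surrounding logging is illustrative, not part of the change.

import logging

from src.agent.llm_corrector import LlamaLLMCorrector

logger = logging.getLogger(__name__)

# raise_on_error=False: a missing Ollama package, server, or model no longer
# raises during construction; the corrector object exists but reports
# is_available() == False, and callers fall back to rule-based correction.
corrector = LlamaLLMCorrector(model_name="llama3.2:3b", raise_on_error=False)

# Mirrors the availability check now used in STTAgent.transcribe_with_agent.
llm_available = (
    corrector is not None
    and hasattr(corrector, "is_available")
    and corrector.is_available()
)

if llm_available:
    logger.info("LLM correction enabled")
else:
    logger.warning("LLM unavailable, using rule-based correction only")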
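src/agent_api.py, src/inference_api.py, and src/control_panel_api.py now share the upload and cleanup helpers from src/utils/api_helpers.py and src/utils/file_utils.py. A minimal sketch of an endpoint built on those helpers follows, assuming the signatures introduced in this diff; DummyModel is a hypothetical stand-in for BaselineSTTModel.

from fastapi import FastAPI, File, HTTPException, UploadFile

from src.utils.api_helpers import handle_audio_upload, transcribe_with_timing
from src.utils.file_utils import cleanup_temp_file

app = FastAPI()


class DummyModel:
    """Hypothetical stand-in for BaselineSTTModel; returns a fixed transcript."""

    def transcribe(self, audio_path: str) -> dict:
        return {"transcript": "hello world", "audio_path": audio_path}


model = DummyModel()


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    tmp_path = None
    try:
        tmp_path = await handle_audio_upload(file)      # save upload to a temp .wav
        return transcribe_with_timing(model, tmp_path)  # adds inference_time_seconds
    except HTTPException:
        raise                                           # keep the helper's 4xx status
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if tmp_path:
            cleanup_temp_file(tmp_path)                 # best-effort temp file removal

The finally block is the point of the refactor: the temporary file is removed whether transcription succeeds, the upload is rejected, or the model raises.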
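ValidationResult.to_dict() in src/data/model_validator.py now converts numpy scalar types to built-in Python types before JSON serialization. The same conversion is shown below as a standalone function purely for illustration; in the diff it is a nested helper inside to_dict(), and the sample result dict here is made up.

import json

import numpy as np


def convert_numpy_types(obj):
    """Recursively convert numpy scalars (and containers of them) to built-in types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    return obj


result = {
    "n_samples": np.int64(30),
    "is_significant": np.bool_(True),
    "p_value": np.float64(0.021),
}

# json.dumps(result) would raise TypeError on the np.int64 / np.bool_ values;
# the converted dict serializes cleanly.
print(json.dumps(convert_numpy_types(result)))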