diff --git a/Dockerfile b/Dockerfile index ab818ef5..6b384e8d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,8 @@ # Build argument for custom certificates directory ARG CUSTOM_CERT_DIR="certs" +# Build argument for AST chunking (default to false) +ARG AST_CHUNKING=false FROM node:20-alpine3.22 AS node_base @@ -67,6 +69,15 @@ ENV PATH="/opt/venv/bin:$PATH" COPY --from=py_deps /opt/venv /opt/venv COPY api/ ./api/ +# Configure AST chunking based on build argument +RUN if [ "$AST_CHUNKING" = "true" ]; then \ + echo "šŸš€ Enabling AST chunking during build..."; \ + cd /app/api && python enable_ast.py enable; \ + else \ + echo "šŸ“ Using default text chunking..."; \ + cd /app/api && python enable_ast.py disable; \ + fi + # Copy Node app COPY --from=node_builder /app/public ./public COPY --from=node_builder /app/.next/standalone ./ diff --git a/api/api.py b/api/api.py index d40e73f9..5683f3bd 100644 --- a/api/api.py +++ b/api/api.py @@ -9,20 +9,215 @@ from pydantic import BaseModel, Field import google.generativeai as genai import asyncio +from collections import defaultdict +import fnmatch # Configure logging from api.logging_config import setup_logging +from api.config import load_repo_config setup_logging() logger = logging.getLogger(__name__) # Initialize FastAPI app -app = FastAPI( - title="Streaming API", - description="API for streaming chat completions" +app = FastAPI() + +# Configure CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # In production, specify your frontend domain + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) + +# Pydantic models for wiki pages +class WikiPage(BaseModel): + id: str + title: str + content: str + related_pages: List[str] = [] + + +# ============================================================================ +# INTELLIGENT FILE CHUNKING SYSTEM +# ============================================================================ + +def should_exclude_dir(dir_name: str, excluded_patterns: List[str]) -> bool: + """Check if directory should be excluded based on patterns.""" + # Always exclude hidden directories and common build/cache dirs + if dir_name.startswith('.'): + return True + if dir_name in ['__pycache__', 'node_modules', '.venv', 'venv', 'env', + 'image-cache', 'dist', 'build', 'target', 'out']: + return True + + # Check against user-defined patterns + for pattern in excluded_patterns: + pattern_clean = pattern.strip('./').rstrip('/') + if fnmatch.fnmatch(dir_name, pattern_clean): + return True + return False + + +def should_exclude_file(file_name: str, excluded_patterns: List[str]) -> bool: + """Check if file should be excluded based on patterns.""" + # Always exclude hidden files and common files + if file_name.startswith('.') or file_name == '__init__.py' or file_name == '.DS_Store': + return True + + # Check against user-defined patterns + for pattern in excluded_patterns: + if fnmatch.fnmatch(file_name, pattern): + return True + return False + + +def collect_all_files(path: str, config: Dict) -> tuple[List[str], str]: + """ + Collect ALL files from repository respecting include/exclude patterns. + Also finds and reads README.md during the same walk. 
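+    Hidden directories and common build/cache folders (such as node_modules and
+    __pycache__) are always skipped, in addition to the configured exclude patterns;
+    hidden files, __init__.py and .DS_Store are likewise always excluded.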
+ + Args: + path: Root directory path + config: Configuration with excluded_dirs and excluded_files + + Returns: + Tuple of (list of relative file paths, README content string) + """ + all_files = [] + readme_content = "" + excluded_dirs = config.get('excluded_dirs', []) + excluded_files = config.get('excluded_files', []) + + logger.info(f"Collecting files from {path}") + logger.info(f"Excluded dirs: {len(excluded_dirs)} patterns") + logger.info(f"Excluded files: {len(excluded_files)} patterns") + + for root, dirs, files in os.walk(path): + # Filter directories in-place + dirs[:] = [d for d in dirs if not should_exclude_dir(d, excluded_dirs)] + + for file in files: + if not should_exclude_file(file, excluded_files): + rel_dir = os.path.relpath(root, path) + rel_file = os.path.join(rel_dir, file) if rel_dir != '.' else file + all_files.append(rel_file) + + # Find README.md (case-insensitive) during the same walk + if file.lower() == 'readme.md' and not readme_content: + try: + with open(os.path.join(root, file), 'r', encoding='utf-8') as f: + readme_content = f.read() + logger.info(f"Found README.md at: {rel_file}") + except Exception as e: + logger.warning(f"Could not read README.md at {rel_file}: {str(e)}") + + logger.info(f"Collected {len(all_files)} files after filtering") + return all_files, readme_content + + +def group_files_by_directory(files: List[str]) -> Dict[str, List[str]]: + """Group files by their parent directory.""" + by_dir = defaultdict(list) + + for file_path in files: + dir_name = os.path.dirname(file_path) + if not dir_name: + dir_name = "root" + by_dir[dir_name].append(file_path) + + return dict(by_dir) + + +def create_file_chunks(files: List[str], max_files_per_chunk: int = 500) -> List[Dict[str, Any]]: + """ + Create intelligent chunks of files grouped by directory. + Ensures no chunk exceeds max_files_per_chunk by splitting large directories. 
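+    Directories are packed together until adding the next one would exceed the
+    limit; a single directory larger than the limit is split into numbered "part"
+    chunks of its own.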
+ + Args: + files: List of all file paths + max_files_per_chunk: Maximum files per chunk + + Returns: + List of chunk dictionaries with metadata + """ + # Group by directory + by_dir = group_files_by_directory(files) + + chunks = [] + current_chunk_files = [] + current_chunk_dirs = [] + + for dir_name, dir_files in sorted(by_dir.items()): + # Handle large directories that exceed max_files_per_chunk on their own + if len(dir_files) > max_files_per_chunk: + # First, save current chunk if it has files + if current_chunk_files: + chunks.append({ + 'files': current_chunk_files[:], + 'directories': current_chunk_dirs[:], + 'file_count': len(current_chunk_files) + }) + current_chunk_files = [] + current_chunk_dirs = [] + + # Split large directory across multiple chunks + logger.warning(f"Directory '{dir_name}' has {len(dir_files)} files, splitting across multiple chunks") + for i in range(0, len(dir_files), max_files_per_chunk): + chunk_slice = dir_files[i:i + max_files_per_chunk] + chunks.append({ + 'files': chunk_slice, + 'directories': [f"{dir_name} (part {i//max_files_per_chunk + 1})"], + 'file_count': len(chunk_slice) + }) + else: + # Normal case: check if adding this directory would exceed limit + if current_chunk_files and len(current_chunk_files) + len(dir_files) > max_files_per_chunk: + # Save current chunk and start new one + chunks.append({ + 'files': current_chunk_files[:], + 'directories': current_chunk_dirs[:], + 'file_count': len(current_chunk_files) + }) + current_chunk_files = [] + current_chunk_dirs = [] + + # Add directory to current chunk + current_chunk_files.extend(dir_files) + current_chunk_dirs.append(dir_name) + + # Add final chunk if it has files + if current_chunk_files: + chunks.append({ + 'files': current_chunk_files, + 'directories': current_chunk_dirs, + 'file_count': len(current_chunk_files) + }) + + logger.info(f"Created {len(chunks)} chunks from {len(files)} files") + for i, chunk in enumerate(chunks): + logger.info(f" Chunk {i+1}: {chunk['file_count']} files across {len(chunk['directories'])} directories") + + return chunks + + +def format_chunk_as_tree(chunk: Dict[str, Any]) -> str: + """Format a chunk of files as a tree string.""" + files = chunk['files'] + tree_lines = sorted(files) + + # Add chunk metadata + chunk_info = f"# Chunk contains {len(files)} files from {len(chunk['directories'])} directories\n" + chunk_info += f"# Directories: {', '.join(chunk['directories'][:5])}" + if len(chunk['directories']) > 5: + chunk_info += f" ... and {len(chunk['directories']) - 5} more" + chunk_info += "\n\n" + + return chunk_info + '\n'.join(tree_lines) + # Configure CORS app.add_middleware( CORSMiddleware, @@ -273,8 +468,19 @@ async def export_wiki(request: WikiExportRequest): raise HTTPException(status_code=500, detail=error_msg) @app.get("/local_repo/structure") -async def get_local_repo_structure(path: str = Query(None, description="Path to local repository")): - """Return the file tree and README content for a local repository.""" +async def get_local_repo_structure( + path: str = Query(None, description="Path to local repository"), + chunk_size: int = Query(500, description="Maximum files per chunk"), + return_chunks: bool = Query(False, description="Return chunked structure for large repos") +): + """ + Return the file tree and README content for a local repository. 
+ + Now supports intelligent chunking for large repositories: + - Collects ALL files respecting include/exclude patterns + - Groups files by directory + - Returns chunks if repository is large + """ if not path: return JSONResponse( status_code=400, @@ -288,30 +494,48 @@ async def get_local_repo_structure(path: str = Query(None, description="Path to ) try: - logger.info(f"Processing local repository at: {path}") - file_tree_lines = [] - readme_content = "" - - for root, dirs, files in os.walk(path): - # Exclude hidden dirs/files and virtual envs - dirs[:] = [d for d in dirs if not d.startswith('.') and d != '__pycache__' and d != 'node_modules' and d != '.venv'] - for file in files: - if file.startswith('.') or file == '__init__.py' or file == '.DS_Store': - continue - rel_dir = os.path.relpath(root, path) - rel_file = os.path.join(rel_dir, file) if rel_dir != '.' else file - file_tree_lines.append(rel_file) - # Find README.md (case-insensitive) - if file.lower() == 'readme.md' and not readme_content: - try: - with open(os.path.join(root, file), 'r', encoding='utf-8') as f: - readme_content = f.read() - except Exception as e: - logger.warning(f"Could not read README.md: {str(e)}") - readme_content = "" - - file_tree_str = '\n'.join(sorted(file_tree_lines)) - return {"file_tree": file_tree_str, "readme": readme_content} + logger.info(f"Processing local repository at: {path} (chunk_size={chunk_size}, return_chunks={return_chunks})") + + # Load configuration from repo.json (imported at the top) + config_data = load_repo_config() + file_filters = config_data.get('file_filters', {}) + + # Collect ALL files respecting patterns and find README in one pass + all_files, readme_content = collect_all_files(path, file_filters) + + # Decide whether to chunk based on repository size + total_files = len(all_files) + logger.info(f"Total files collected: {total_files}") + + if return_chunks or total_files > chunk_size: + # Create intelligent chunks + chunks = create_file_chunks(all_files, max_files_per_chunk=chunk_size) + + return { + "chunked": True, + "total_files": total_files, + "chunk_count": len(chunks), + "chunks": [ + { + "chunk_id": i, + "file_count": chunk['file_count'], + "directories": chunk['directories'], + "file_tree": format_chunk_as_tree(chunk) + } + for i, chunk in enumerate(chunks) + ], + "readme": readme_content + } + else: + # Small repo, return as single tree + file_tree_str = '\n'.join(sorted(all_files)) + return { + "chunked": False, + "total_files": total_files, + "file_tree": file_tree_str, + "readme": readme_content + } + except Exception as e: logger.error(f"Error processing local repository: {str(e)}") return JSONResponse( diff --git a/api/ast_chunker.py b/api/ast_chunker.py new file mode 100644 index 00000000..eb78023b --- /dev/null +++ b/api/ast_chunker.py @@ -0,0 +1,528 @@ +""" +AST-based chunking for code files using LlamaIndex-inspired approach. +This provides semantic chunking that respects code structure. 
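+
+Illustrative usage (a sketch; the ASTChunker class defined below is the actual API,
+and source_text stands in for the file contents):
+
+    chunker = ASTChunker(max_chunk_size=2000, min_chunk_size=100)
+    chunks = chunker.chunk_file("example.py", source_text)
+    for chunk in chunks:
+        print(chunk.chunk_type, chunk.name, chunk.start_line, chunk.end_line)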
+""" + +import ast +import logging +from typing import List, Dict, Any, Optional, Union +from pathlib import Path +from dataclasses import dataclass + +# Use a simple token counter if tiktoken is not available +try: + import tiktoken + TIKTOKEN_AVAILABLE = True +except ImportError: + TIKTOKEN_AVAILABLE = False + +logger = logging.getLogger(__name__) + +@dataclass +class CodeChunk: + """Represents a semantically meaningful chunk of code.""" + content: str + chunk_type: str # 'function', 'class', 'module', 'import_block', 'comment_block' + name: Optional[str] = None # function/class name + start_line: int = 0 + end_line: int = 0 + file_path: str = "" + dependencies: List[str] = None # imported modules/functions this chunk depends on + + def __post_init__(self): + if self.dependencies is None: + self.dependencies = [] + + +class ASTChunker: + """AST-based code chunker that creates semantically meaningful chunks.""" + + def __init__(self, + max_chunk_size: int = 2000, + min_chunk_size: int = 100, + overlap_lines: int = 5, + preserve_structure: bool = True): + """ + Initialize AST chunker. + + Args: + max_chunk_size: Maximum tokens per chunk + min_chunk_size: Minimum tokens per chunk + overlap_lines: Lines of overlap between chunks + preserve_structure: Whether to keep related code together + """ + self.max_chunk_size = max_chunk_size + self.min_chunk_size = min_chunk_size + self.overlap_lines = overlap_lines + self.preserve_structure = preserve_structure + + # Initialize token encoder with fallback + if TIKTOKEN_AVAILABLE: + self.encoding = tiktoken.get_encoding("cl100k_base") + else: + # Simple fallback: approximate 4 chars per token + self.encoding = None + + def chunk_file(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk a single file based on its type and structure.""" + file_ext = Path(file_path).suffix.lower() + + # Route to appropriate chunker based on file type + if file_ext == '.py': + return self._chunk_python(file_path, content) + elif file_ext in ['.js', '.ts', '.jsx', '.tsx']: + return self._chunk_javascript(file_path, content) + elif file_ext in ['.java', '.kt']: + return self._chunk_java_kotlin(file_path, content) + elif file_ext in ['.cpp', '.cc', '.cxx', '.c', '.h', '.hpp']: + return self._chunk_cpp(file_path, content) + elif file_ext in ['.rs']: + return self._chunk_rust(file_path, content) + elif file_ext in ['.go']: + return self._chunk_go(file_path, content) + elif file_ext in ['.md', '.rst', '.txt']: + return self._chunk_markdown(file_path, content) + elif file_ext in ['.json', '.yaml', '.yml', '.toml']: + return self._chunk_config(file_path, content) + else: + # Fall back to text-based chunking + return self._chunk_text(file_path, content) + + def _chunk_python(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk Python code using AST analysis.""" + chunks = [] + + try: + tree = ast.parse(content) + lines = content.split('\n') + + # Group imports at the top + imports = [] + other_nodes = [] + + for node in ast.walk(tree): + if isinstance(node, (ast.Import, ast.ImportFrom)): + imports.append(node) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + other_nodes.append(node) + + # Create import chunk if imports exist + if imports: + import_lines = [] + for imp in imports: + if hasattr(imp, 'lineno'): + import_lines.extend(range(imp.lineno - 1, + getattr(imp, 'end_lineno', imp.lineno))) + + if import_lines: + import_content = '\n'.join(lines[min(import_lines):max(import_lines)+1]) + chunks.append(CodeChunk( + 
content=import_content, + chunk_type='import_block', + start_line=min(import_lines) + 1, + end_line=max(import_lines) + 1, + file_path=file_path + )) + + # Process classes and functions + for node in other_nodes: + if isinstance(node, ast.ClassDef): + chunk = self._extract_class_chunk(node, lines, file_path) + if chunk: + chunks.append(chunk) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + chunk = self._extract_function_chunk(node, lines, file_path) + if chunk: + chunks.append(chunk) + + # Handle module-level code + module_code = self._extract_module_level_code(tree, lines, file_path) + if module_code: + chunks.extend(module_code) + + except SyntaxError as e: + logger.warning(f"Could not parse Python file {file_path}: {e}") + # Fall back to text chunking + return self._chunk_text(file_path, content) + + return self._optimize_chunks(chunks) + + def _extract_class_chunk(self, node: ast.ClassDef, + lines: List[str], file_path: str) -> Optional[CodeChunk]: + """Extract a class and its methods as a chunk.""" + start_line = node.lineno - 1 + end_line = getattr(node, 'end_lineno', node.lineno) - 1 + + class_content = '\n'.join(lines[start_line:end_line + 1]) + + # Check if chunk is too large, split methods if needed + token_count = self._count_tokens(class_content) + + if token_count > self.max_chunk_size: + # Split into method chunks + method_chunks = [] + for item in node.body: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + method_chunk = self._extract_function_chunk(item, lines, file_path) + if method_chunk: + method_chunks.append(method_chunk) + return method_chunks if method_chunks else None + + # Extract dependencies (imports used in class) + dependencies = self._extract_dependencies(node) + + return CodeChunk( + content=class_content, + chunk_type='class', + name=node.name, + start_line=start_line + 1, + end_line=end_line + 1, + file_path=file_path, + dependencies=dependencies + ) + + def _extract_function_chunk(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef], + lines: List[str], file_path: str) -> Optional[CodeChunk]: + """Extract a function as a chunk.""" + start_line = node.lineno - 1 + end_line = getattr(node, 'end_lineno', node.lineno) - 1 + + function_content = '\n'.join(lines[start_line:end_line + 1]) + + # Extract dependencies + dependencies = self._extract_dependencies(node) + + return CodeChunk( + content=function_content, + chunk_type='function', + name=node.name, + start_line=start_line + 1, + end_line=end_line + 1, + file_path=file_path, + dependencies=dependencies + ) + + def _extract_dependencies(self, node) -> List[str]: + """Extract dependencies (imported names) used in this node.""" + dependencies = [] + + for child in ast.walk(node): + if isinstance(child, ast.Name): + dependencies.append(child.id) + elif isinstance(child, ast.Attribute): + # Handle module.function calls + if isinstance(child.value, ast.Name): + dependencies.append(f"{child.value.id}.{child.attr}") + + return list(set(dependencies)) # Remove duplicates + + def _extract_module_level_code(self, tree: ast.AST, + lines: List[str], file_path: str) -> List[CodeChunk]: + """Extract module-level code that's not in classes or functions.""" + chunks = [] + + # Find lines not covered by classes/functions + covered_lines = set() + + for node in ast.walk(tree): + if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef, + ast.Import, ast.ImportFrom)): + start = node.lineno - 1 + end = getattr(node, 'end_lineno', node.lineno) - 1 + 
covered_lines.update(range(start, end + 1)) + + # Group uncovered lines into chunks + uncovered_lines = [] + for i, line in enumerate(lines): + if i not in covered_lines and line.strip(): + uncovered_lines.append((i, line)) + + if uncovered_lines: + # Group consecutive lines + current_chunk_lines = [] + current_start = None + + for line_num, line in uncovered_lines: + if current_start is None: + current_start = line_num + current_chunk_lines = [line] + elif line_num == current_start + len(current_chunk_lines): + current_chunk_lines.append(line) + else: + # Gap found, save current chunk and start new one + if current_chunk_lines: + chunk_content = '\n'.join(current_chunk_lines) + if self._count_tokens(chunk_content) > self.min_chunk_size: + chunks.append(CodeChunk( + content=chunk_content, + chunk_type='module', + start_line=current_start + 1, + end_line=current_start + len(current_chunk_lines), + file_path=file_path + )) + + current_start = line_num + current_chunk_lines = [line] + + # Add final chunk + if current_chunk_lines: + chunk_content = '\n'.join(current_chunk_lines) + if self._count_tokens(chunk_content) > self.min_chunk_size: + chunks.append(CodeChunk( + content=chunk_content, + chunk_type='module', + start_line=current_start + 1, + end_line=current_start + len(current_chunk_lines), + file_path=file_path + )) + + return chunks + + def _chunk_javascript(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk JavaScript/TypeScript using regex patterns (simplified AST).""" + # This is a simplified implementation + # For production, you'd want to use a proper JS/TS parser like babel or esprima + chunks = [] + lines = content.split('\n') + + # Find function and class boundaries using regex + import re + + function_pattern = r'^(export\s+)?(async\s+)?function\s+(\w+)|^(export\s+)?const\s+(\w+)\s*=\s*(async\s+)?\(' + class_pattern = r'^(export\s+)?class\s+(\w+)' + + current_chunk = [] + current_type = 'module' + current_name = None + brace_count = 0 + in_function = False + + for i, line in enumerate(lines): + if re.match(function_pattern, line.strip()): + # Start of function + if current_chunk: + chunks.append(self._create_js_chunk( + current_chunk, current_type, current_name, file_path + )) + current_chunk = [line] + current_type = 'function' + match = re.match(function_pattern, line.strip()) + current_name = match.group(3) or match.group(5) + brace_count = line.count('{') - line.count('}') + in_function = True + elif re.match(class_pattern, line.strip()): + # Start of class + if current_chunk: + chunks.append(self._create_js_chunk( + current_chunk, current_type, current_name, file_path + )) + current_chunk = [line] + current_type = 'class' + match = re.match(class_pattern, line.strip()) + current_name = match.group(2) + brace_count = line.count('{') - line.count('}') + in_function = True + else: + current_chunk.append(line) + if in_function: + brace_count += line.count('{') - line.count('}') + if brace_count <= 0: + # End of function/class + chunks.append(self._create_js_chunk( + current_chunk, current_type, current_name, file_path + )) + current_chunk = [] + current_type = 'module' + current_name = None + in_function = False + + # Add remaining content + if current_chunk: + chunks.append(self._create_js_chunk( + current_chunk, current_type, current_name, file_path + )) + + return self._optimize_chunks(chunks) + + def _create_js_chunk(self, lines: List[str], chunk_type: str, + name: Optional[str], file_path: str) -> CodeChunk: + """Create a JavaScript chunk from lines.""" + 
content = '\n'.join(lines) + return CodeChunk( + content=content, + chunk_type=chunk_type, + name=name, + file_path=file_path + ) + + def _chunk_markdown(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk Markdown by headers and sections.""" + chunks = [] + lines = content.split('\n') + + current_chunk = [] + current_header = None + header_level = 0 + + for line in lines: + if line.startswith('#'): + # Header found + if current_chunk: + chunks.append(CodeChunk( + content='\n'.join(current_chunk), + chunk_type='section', + name=current_header, + file_path=file_path + )) + + current_header = line.strip('#').strip() + header_level = len(line) - len(line.lstrip('#')) + current_chunk = [line] + else: + current_chunk.append(line) + + # Add final chunk + if current_chunk: + chunks.append(CodeChunk( + content='\n'.join(current_chunk), + chunk_type='section', + name=current_header, + file_path=file_path + )) + + return chunks + + def _chunk_config(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk configuration files as single units (they're usually coherent).""" + return [CodeChunk( + content=content, + chunk_type='config', + name=Path(file_path).name, + file_path=file_path + )] + + def _chunk_text(self, file_path: str, content: str) -> List[CodeChunk]: + """Fall back to simple text chunking.""" + chunks = [] + lines = content.split('\n') + + current_chunk = [] + current_tokens = 0 + + for line in lines: + line_tokens = self._count_tokens(line) + + if current_tokens + line_tokens > self.max_chunk_size and current_chunk: + chunks.append(CodeChunk( + content='\n'.join(current_chunk), + chunk_type='text', + file_path=file_path + )) + current_chunk = [] + current_tokens = 0 + + current_chunk.append(line) + current_tokens += line_tokens + + if current_chunk: + chunks.append(CodeChunk( + content='\n'.join(current_chunk), + chunk_type='text', + file_path=file_path + )) + + return chunks + + def _optimize_chunks(self, chunks: List[CodeChunk]) -> List[CodeChunk]: + """Optimize chunks by merging small ones and splitting large ones.""" + optimized = [] + + for chunk in chunks: + token_count = self._count_tokens(chunk.content) + + if token_count > self.max_chunk_size: + # Split large chunk + split_chunks = self._split_large_chunk(chunk) + optimized.extend(split_chunks) + elif token_count < self.min_chunk_size and optimized: + # Merge with previous chunk if it won't exceed max size + prev_chunk = optimized[-1] + prev_tokens = self._count_tokens(prev_chunk.content) + + if prev_tokens + token_count <= self.max_chunk_size: + # Merge chunks + merged_content = f"{prev_chunk.content}\n\n{chunk.content}" + optimized[-1] = CodeChunk( + content=merged_content, + chunk_type='merged', + file_path=chunk.file_path, + dependencies=list(set(prev_chunk.dependencies + chunk.dependencies)) + ) + else: + optimized.append(chunk) + else: + optimized.append(chunk) + + return optimized + + def _split_large_chunk(self, chunk: CodeChunk) -> List[CodeChunk]: + """Split a chunk that's too large.""" + lines = chunk.content.split('\n') + split_chunks = [] + + current_lines = [] + current_tokens = 0 + + for line in lines: + line_tokens = self._count_tokens(line) + + if current_tokens + line_tokens > self.max_chunk_size and current_lines: + split_chunks.append(CodeChunk( + content='\n'.join(current_lines), + chunk_type=f"{chunk.chunk_type}_split", + name=f"{chunk.name}_part_{len(split_chunks) + 1}" if chunk.name else None, + file_path=chunk.file_path, + dependencies=chunk.dependencies + )) + current_lines = 
[] + current_tokens = 0 + + current_lines.append(line) + current_tokens += line_tokens + + if current_lines: + split_chunks.append(CodeChunk( + content='\n'.join(current_lines), + chunk_type=f"{chunk.chunk_type}_split", + name=f"{chunk.name}_part_{len(split_chunks) + 1}" if chunk.name else None, + file_path=chunk.file_path, + dependencies=chunk.dependencies + )) + + return split_chunks + + def _count_tokens(self, text: str) -> int: + """Count tokens in text with fallback for when tiktoken is not available.""" + if TIKTOKEN_AVAILABLE and self.encoding: + return len(self.encoding.encode(text)) + else: + # Simple approximation: 4 characters per token + return len(text) // 4 + + # Placeholder methods for other languages + def _chunk_java_kotlin(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk Java/Kotlin code.""" + # Simplified implementation - in production, use proper parsers + return self._chunk_text(file_path, content) + + def _chunk_cpp(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk C++ code.""" + return self._chunk_text(file_path, content) + + def _chunk_rust(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk Rust code.""" + return self._chunk_text(file_path, content) + + def _chunk_go(self, file_path: str, content: str) -> List[CodeChunk]: + """Chunk Go code.""" + return self._chunk_text(file_path, content) \ No newline at end of file diff --git a/api/ast_integration.py b/api/ast_integration.py new file mode 100644 index 00000000..7b9c5304 --- /dev/null +++ b/api/ast_integration.py @@ -0,0 +1,301 @@ +""" +Integration layer for AST chunking with existing data pipeline. +This bridges the AST chunker with the current adalflow TextSplitter interface. +""" + +from typing import List, Dict, Any, Union +import logging +from pathlib import Path + +from adalflow.core.component import Component +from adalflow.core.document import Document + +from .ast_chunker import ASTChunker, CodeChunk + +logger = logging.getLogger(__name__) + + +class ASTTextSplitter(Component): + """ + AST-aware text splitter that integrates with adalflow pipeline. + Provides both AST chunking for code files and fallback text chunking. + """ + + def __init__(self, + split_by: str = "ast", + chunk_size: int = 2000, + chunk_overlap: int = 100, + min_chunk_size: int = 100, + overlap_lines: int = 5, + preserve_structure: bool = True, + fallback_to_text: bool = True): + """ + Initialize AST text splitter. 
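+        With split_by="ast", code documents are routed through ASTChunker and split
+        along function/class boundaries; when fallback_to_text is True, other
+        split_by values and parsing failures fall back to adalflow's standard
+        TextSplitter.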
+ + Args: + split_by: Splitting strategy - "ast", "word", "character" + chunk_size: Maximum tokens per chunk + chunk_overlap: Token overlap between chunks (for text mode) + min_chunk_size: Minimum tokens per chunk + overlap_lines: Lines of overlap for AST chunks + preserve_structure: Whether to keep code structures together + fallback_to_text: Fall back to text splitting for non-code files + """ + super().__init__() + self.split_by = split_by + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.min_chunk_size = min_chunk_size + self.overlap_lines = overlap_lines + self.preserve_structure = preserve_structure + self.fallback_to_text = fallback_to_text + + # Initialize AST chunker + self.ast_chunker = ASTChunker( + max_chunk_size=chunk_size, + min_chunk_size=min_chunk_size, + overlap_lines=overlap_lines, + preserve_structure=preserve_structure + ) + + # Initialize text chunker for fallback + if fallback_to_text: + from adalflow.core.text_splitter import TextSplitter + self.text_splitter = TextSplitter( + split_by="word" if split_by == "ast" else split_by, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + + def call(self, documents: List[Document]) -> List[Document]: + """ + Split documents using AST-aware chunking. + + Args: + documents: Input documents to split + + Returns: + List of document chunks + """ + result_docs = [] + + for doc in documents: + try: + if self.split_by == "ast": + chunks = self._ast_split_document(doc) + else: + # Use traditional text splitting + chunks = self.text_splitter.call([doc]) + + result_docs.extend(chunks) + + except Exception as e: + logger.error(f"Error splitting document {doc.id}: {e}") + if self.fallback_to_text and hasattr(self, 'text_splitter'): + # Fallback to text splitting + chunks = self.text_splitter.call([doc]) + result_docs.extend(chunks) + else: + # Keep original document if all else fails + result_docs.append(doc) + + return result_docs + + def _ast_split_document(self, document: Document) -> List[Document]: + """Split a single document using AST chunking.""" + # Extract file path from document metadata + file_path = document.meta_data.get('file_path', '') + if not file_path: + file_path = document.meta_data.get('source', 'unknown') + + # Use AST chunker + code_chunks = self.ast_chunker.chunk_file(file_path, document.text) + + # Convert CodeChunk objects to Document objects + result_docs = [] + for i, chunk in enumerate(code_chunks): + # Create enhanced metadata + enhanced_metadata = document.meta_data.copy() + enhanced_metadata.update({ + 'chunk_id': i, + 'chunk_type': chunk.chunk_type, + 'chunk_name': chunk.name, + 'start_line': chunk.start_line, + 'end_line': chunk.end_line, + 'dependencies': chunk.dependencies, + 'file_path': chunk.file_path, + 'original_doc_id': document.id + }) + + # Create new document + chunk_doc = Document( + text=chunk.content, + id=f"{document.id}_chunk_{i}", + meta_data=enhanced_metadata + ) + + result_docs.append(chunk_doc) + + return result_docs + + def get_chunk_metadata(self, chunk_doc: Document) -> Dict[str, Any]: + """Extract chunk-specific metadata for enhanced retrieval.""" + return { + 'chunk_type': chunk_doc.meta_data.get('chunk_type', 'unknown'), + 'chunk_name': chunk_doc.meta_data.get('chunk_name'), + 'start_line': chunk_doc.meta_data.get('start_line', 0), + 'end_line': chunk_doc.meta_data.get('end_line', 0), + 'dependencies': chunk_doc.meta_data.get('dependencies', []), + 'file_path': chunk_doc.meta_data.get('file_path', ''), + 'language': 
self._detect_language(chunk_doc.meta_data.get('file_path', '')) + } + + def _detect_language(self, file_path: str) -> str: + """Detect programming language from file extension.""" + ext = Path(file_path).suffix.lower() + + language_map = { + '.py': 'python', + '.js': 'javascript', + '.ts': 'typescript', + '.jsx': 'javascript', + '.tsx': 'typescript', + '.java': 'java', + '.kt': 'kotlin', + '.cpp': 'cpp', + '.cc': 'cpp', + '.cxx': 'cpp', + '.c': 'c', + '.h': 'c', + '.hpp': 'cpp', + '.rs': 'rust', + '.go': 'go', + '.rb': 'ruby', + '.php': 'php', + '.cs': 'csharp', + '.swift': 'swift', + '.md': 'markdown', + '.rst': 'restructuredtext', + '.json': 'json', + '.yaml': 'yaml', + '.yml': 'yaml', + '.toml': 'toml', + '.xml': 'xml', + '.html': 'html', + '.css': 'css', + '.scss': 'scss', + '.sass': 'sass' + } + + return language_map.get(ext, 'text') + + +class EnhancedRAGRetriever: + """Enhanced retriever that uses AST chunk metadata for better results.""" + + def __init__(self, base_retriever): + """Initialize with base retriever (FAISS, etc.).""" + self.base_retriever = base_retriever + + def retrieve(self, query: str, top_k: int = 5, + filter_by_type: List[str] = None, + prefer_functions: bool = False, + prefer_classes: bool = False) -> List[Document]: + """ + Enhanced retrieval with AST-aware filtering. + + Args: + query: Search query + top_k: Number of results to return + filter_by_type: Filter by chunk types (e.g., ['function', 'class']) + prefer_functions: Boost function chunks in results + prefer_classes: Boost class chunks in results + """ + # Get base results + base_results = self.base_retriever.retrieve(query, top_k * 2) # Get more for filtering + + # Apply AST-aware filtering and ranking + filtered_results = [] + + for doc in base_results: + chunk_type = doc.meta_data.get('chunk_type', 'unknown') + + # Apply type filter + if filter_by_type and chunk_type not in filter_by_type: + continue + + # Calculate boost score + boost_score = 1.0 + if prefer_functions and chunk_type == 'function': + boost_score = 1.5 + elif prefer_classes and chunk_type == 'class': + boost_score = 1.5 + + # Add boost to similarity score if available + if hasattr(doc, 'similarity_score'): + doc.similarity_score *= boost_score + + filtered_results.append(doc) + + # Sort by similarity score and return top_k + if hasattr(filtered_results[0], 'similarity_score'): + filtered_results.sort(key=lambda x: x.similarity_score, reverse=True) + + return filtered_results[:top_k] + + def retrieve_related_code(self, query: str, top_k: int = 5) -> Dict[str, List[Document]]: + """ + Retrieve related code organized by type. 
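+        Grouping relies on the chunk_type metadata that the AST splitter attaches
+        to each document.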
+ + Returns: + Dictionary with keys: 'functions', 'classes', 'imports', 'modules' + """ + all_results = self.base_retriever.retrieve(query, top_k * 4) + + organized_results = { + 'functions': [], + 'classes': [], + 'imports': [], + 'modules': [], + 'other': [] + } + + for doc in all_results: + chunk_type = doc.meta_data.get('chunk_type', 'other') + + if chunk_type == 'function': + organized_results['functions'].append(doc) + elif chunk_type == 'class': + organized_results['classes'].append(doc) + elif chunk_type == 'import_block': + organized_results['imports'].append(doc) + elif chunk_type == 'module': + organized_results['modules'].append(doc) + else: + organized_results['other'].append(doc) + + # Limit each category + for key in organized_results: + organized_results[key] = organized_results[key][:top_k//4 + 1] + + return organized_results + + +def create_ast_config() -> Dict[str, Any]: + """Create default AST chunking configuration.""" + return { + "text_splitter": { + "split_by": "ast", + "chunk_size": 2000, + "chunk_overlap": 100, + "min_chunk_size": 100, + "overlap_lines": 5, + "preserve_structure": True, + "fallback_to_text": True + }, + "retrieval": { + "prefer_functions": True, + "prefer_classes": True, + "boost_code_chunks": 1.5 + } + } \ No newline at end of file diff --git a/api/config/embedder.ast.json b/api/config/embedder.ast.json new file mode 100644 index 00000000..90940bf1 --- /dev/null +++ b/api/config/embedder.ast.json @@ -0,0 +1,57 @@ +{ + "embedder_ollama": { + "client_class": "OllamaClient", + "model_kwargs": { + "model": "nomic-embed-text" + } + }, + "retriever": { + "top_k": 20 + }, + "text_splitter": { + "split_by": "ast", + "chunk_size": 2000, + "chunk_overlap": 100, + "min_chunk_size": 100, + "overlap_lines": 5, + "preserve_structure": true, + "fallback_to_text": true + }, + "retrieval": { + "prefer_functions": true, + "prefer_classes": true, + "boost_code_chunks": 1.5 + }, + "language_support": { + "python": { + "enabled": true, + "max_function_size": 1500, + "max_class_size": 3000, + "preserve_docstrings": true + }, + "javascript": { + "enabled": true, + "max_function_size": 1500, + "parse_jsx": true + }, + "typescript": { + "enabled": true, + "max_function_size": 1500, + "parse_tsx": true + }, + "markdown": { + "enabled": true, + "split_by_headers": true, + "max_section_size": 2000 + }, + "config_files": { + "enabled": true, + "treat_as_single_chunk": true + } + }, + "fallback": { + "split_by": "word", + "chunk_size": 350, + "chunk_overlap": 100 + } +} \ No newline at end of file diff --git a/api/data_pipeline.py b/api/data_pipeline.py index fcea34ce..81ce9c6b 100644 --- a/api/data_pipeline.py +++ b/api/data_pipeline.py @@ -393,7 +393,21 @@ def prepare_data_pipeline(embedder_type: str = None, is_ollama_embedder: bool = if embedder_type is None: embedder_type = get_embedder_type() - splitter = TextSplitter(**configs["text_splitter"]) + # Choose splitter based on configuration + split_by = configs.get("text_splitter", {}).get("split_by", "word") + if split_by == "ast": + # Use AST splitter for better code understanding + try: + from .ast_integration import ASTTextSplitter + splitter = ASTTextSplitter(**configs["text_splitter"]) + print("šŸš€ Using AST-based chunking for better code understanding") + except ImportError as e: + print(f"āš ļø AST chunking not available, falling back to text: {e}") + splitter = TextSplitter(**configs["text_splitter"]) + else: + # Use traditional text splitter + splitter = TextSplitter(**configs["text_splitter"]) + embedder_config 
= get_embedder_config() embedder = get_embedder(embedder_type=embedder_type) diff --git a/api/enable_ast.py b/api/enable_ast.py new file mode 100644 index 00000000..74328a1b --- /dev/null +++ b/api/enable_ast.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +""" +Simple script to enable/disable AST chunking in DeepWiki. +""" + +import json +import shutil +import os +import sys + +def enable_ast_chunking(): + """Enable AST-based chunking.""" + embedder_config = "config/embedder.json" + ast_config = "config/embedder.ast.json" + backup_config = "config/embedder.json.backup" + + # Check if AST config exists + if not os.path.exists(ast_config): + print(f"āŒ AST config not found: {ast_config}") + return False + + # Backup current config + if os.path.exists(embedder_config): + shutil.copy2(embedder_config, backup_config) + print(f"āœ… Backed up current config to {backup_config}") + + # Load current config to preserve embedder settings + with open(embedder_config, 'r') as f: + current_config = json.load(f) + else: + current_config = {} + + # Load AST config + with open(ast_config, 'r') as f: + ast_config_data = json.load(f) + + # Merge embedder settings from current config into AST config + if 'embedder_ollama' in current_config: + ast_config_data['embedder_ollama'] = current_config['embedder_ollama'] + if 'retriever' in current_config: + ast_config_data['retriever'] = current_config['retriever'] + + # Write merged config + with open(embedder_config, 'w') as f: + json.dump(ast_config_data, f, indent=2) + + print(f"āœ… Enabled AST chunking with preserved embedder settings") + + # Verify the switch + with open(embedder_config, 'r') as f: + config = json.load(f) + split_by = config.get('text_splitter', {}).get('split_by', 'unknown') + print(f"āœ… Current chunking mode: {split_by}") + + return True + +def disable_ast_chunking(): + """Disable AST chunking and restore previous config.""" + embedder_config = "config/embedder.json" + backup_config = "config/embedder.json.backup" + + if os.path.exists(backup_config): + shutil.copy2(backup_config, embedder_config) + print(f"āœ… Restored previous config from {backup_config}") + else: + # Create default text config + default_config = { + "embedder_ollama": { + "client_class": "OllamaClient", + "model_kwargs": { + "model": "nomic-embed-text" + } + }, + "retriever": { + "top_k": 20 + }, + "text_splitter": { + "split_by": "word", + "chunk_size": 350, + "chunk_overlap": 100 + } + } + + with open(embedder_config, 'w') as f: + json.dump(default_config, f, indent=2) + + print(f"āœ… Created default text chunking config") + + # Verify the switch + with open(embedder_config, 'r') as f: + config = json.load(f) + split_by = config.get('text_splitter', {}).get('split_by', 'unknown') + print(f"āœ… Current chunking mode: {split_by}") + +def check_status(): + """Check current chunking status.""" + embedder_config = "config/embedder.json" + + if not os.path.exists(embedder_config): + print("āŒ No embedder config found") + return + + with open(embedder_config, 'r') as f: + config = json.load(f) + split_by = config.get('text_splitter', {}).get('split_by', 'word') + chunk_size = config.get('text_splitter', {}).get('chunk_size', 0) + + print(f"\nšŸ“Š Current Configuration:") + print(f" Chunking mode: {split_by}") + print(f" Chunk size: {chunk_size}") + + if split_by == "ast": + print(" Status: šŸš€ AST chunking ENABLED") + print(" Benefits: Semantic code understanding, function/class boundaries preserved") + else: + print(" Status: šŸ“ Traditional text chunking") + print(" Note: 
Consider enabling AST chunking for better code understanding") + +def main(): + if len(sys.argv) < 2: + print("Usage: python enable_ast.py [enable|disable|status]") + print("\nCommands:") + print(" enable - Enable AST-based chunking") + print(" disable - Disable AST chunking (restore text chunking)") + print(" status - Show current chunking status") + sys.exit(1) + + command = sys.argv[1].lower() + + if command == "enable": + enable_ast_chunking() + elif command == "disable": + disable_ast_chunking() + elif command == "status": + check_status() + else: + print(f"Unknown command: {command}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/api/enhanced_pipeline.py b/api/enhanced_pipeline.py new file mode 100644 index 00000000..fd050caf --- /dev/null +++ b/api/enhanced_pipeline.py @@ -0,0 +1,324 @@ +""" +Enhanced data pipeline with AST chunking support. +This patch integrates AST chunking into the existing adalflow pipeline. +""" + +import os +from typing import List, Dict, Any, Optional +import logging + +from adalflow.core.document import Document +from adalflow.core.db import LocalDB +import adalflow as adal + +from .ast_integration import ASTTextSplitter, EnhancedRAGRetriever +from .config import get_embedder_config, get_embedder_type +from .data_pipeline import get_embedder, OllamaDocumentProcessor, ToEmbeddings + +logger = logging.getLogger(__name__) + + +def prepare_enhanced_data_pipeline(embedder_type: str = None, + is_ollama_embedder: bool = None, + use_ast_chunking: bool = True) -> adal.Sequential: + """ + Prepare enhanced data pipeline with optional AST chunking support. + + Args: + embedder_type: The embedder type ('openai', 'google', 'ollama') + is_ollama_embedder: DEPRECATED. Use embedder_type instead + use_ast_chunking: Whether to use AST-based chunking + + Returns: + Sequential pipeline with splitter and embedder + """ + # Load configuration + configs = get_embedder_config() + + # Handle legacy parameter + if is_ollama_embedder is not None: + embedder_type = 'ollama' if is_ollama_embedder else None + + # Determine embedder type if not specified + if embedder_type is None: + embedder_type = get_embedder_type() + + # Choose splitter based on configuration + if use_ast_chunking and configs.get("text_splitter", {}).get("split_by") == "ast": + # Use AST splitter + splitter_config = configs["text_splitter"] + splitter = ASTTextSplitter(**splitter_config) + logger.info("Using AST-based text splitting") + else: + # Use traditional text splitter + from adalflow.core.text_splitter import TextSplitter + splitter = TextSplitter(**configs["text_splitter"]) + logger.info("Using traditional text splitting") + + # Get embedder configuration and instance + embedder_config = get_embedder_config() + embedder = get_embedder(embedder_type=embedder_type) + + # Choose appropriate processor based on embedder type + if embedder_type == 'ollama': + # Use Ollama document processor for single-document processing + embedder_transformer = OllamaDocumentProcessor(embedder=embedder) + else: + # Use batch processing for OpenAI and Google embedders + batch_size = embedder_config.get("batch_size", 500) + embedder_transformer = ToEmbeddings( + embedder=embedder, batch_size=batch_size + ) + + # Create sequential pipeline + data_transformer = adal.Sequential( + splitter, embedder_transformer + ) + + return data_transformer + + +def transform_documents_and_save_to_enhanced_db( + documents: List[Document], + db_path: str, + embedder_type: str = None, + 
is_ollama_embedder: bool = None, + use_ast_chunking: bool = None +) -> LocalDB: + """ + Enhanced document transformation with AST chunking support. + + Args: + documents: List of Document objects + db_path: Path to the local database file + embedder_type: The embedder type ('openai', 'google', 'ollama') + is_ollama_embedder: DEPRECATED. Use embedder_type instead + use_ast_chunking: Whether to use AST chunking (auto-detected if None) + + Returns: + LocalDB instance with processed documents + """ + # Auto-detect AST chunking preference if not specified + if use_ast_chunking is None: + configs = get_embedder_config() + use_ast_chunking = configs.get("text_splitter", {}).get("split_by") == "ast" + + # Get the enhanced data transformer + data_transformer = prepare_enhanced_data_pipeline( + embedder_type=embedder_type, + is_ollama_embedder=is_ollama_embedder, + use_ast_chunking=use_ast_chunking + ) + + # Save the documents to a local database + db = LocalDB() + db.register_transformer(transformer=data_transformer, key="split_and_embed") + db.load(documents) + db.transform(key="split_and_embed") + + # Ensure directory exists and save + os.makedirs(os.path.dirname(db_path), exist_ok=True) + db.save_state(filepath=db_path) + + # Log chunking statistics + _log_chunking_stats(db, use_ast_chunking) + + return db + + +def _log_chunking_stats(db: LocalDB, used_ast: bool): + """Log statistics about the chunking process.""" + try: + documents = db.get_transformed_data("split_and_embed") + if not documents: + return + + total_chunks = len(documents) + + if used_ast: + # Analyze AST chunk types + chunk_types = {} + languages = {} + + for doc in documents: + chunk_type = doc.meta_data.get('chunk_type', 'unknown') + language = doc.meta_data.get('language', 'unknown') + + chunk_types[chunk_type] = chunk_types.get(chunk_type, 0) + 1 + languages[language] = languages.get(language, 0) + 1 + + logger.info(f"AST chunking created {total_chunks} chunks") + logger.info(f"Chunk types: {dict(chunk_types)}") + logger.info(f"Languages: {dict(languages)}") + + # Log function and class statistics + function_chunks = chunk_types.get('function', 0) + class_chunks = chunk_types.get('class', 0) + + if function_chunks > 0: + logger.info(f"Found {function_chunks} function chunks") + if class_chunks > 0: + logger.info(f"Found {class_chunks} class chunks") + else: + logger.info(f"Traditional chunking created {total_chunks} chunks") + + except Exception as e: + logger.error(f"Error logging chunking stats: {e}") + + +def create_enhanced_retriever(db: LocalDB, + embedder_type: str = None) -> EnhancedRAGRetriever: + """ + Create an enhanced retriever with AST-aware capabilities. 
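+    Note: this relies on LocalDB exposing a get_retriever() helper; as the inline
+    comment below indicates, that method may still need to be implemented.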
+ + Args: + db: LocalDB instance with processed documents + embedder_type: The embedder type used + + Returns: + EnhancedRAGRetriever instance + """ + # Get base retriever from database + base_retriever = db.get_retriever() # This would need to be implemented in LocalDB + + # Create enhanced retriever + enhanced_retriever = EnhancedRAGRetriever(base_retriever) + + return enhanced_retriever + + +# Configuration utilities +def switch_to_ast_chunking(): + """Switch the system to use AST chunking.""" + import json + import shutil + + # Backup current config + embedder_config_path = "api/config/embedder.json" + backup_path = f"{embedder_config_path}.backup" + + if os.path.exists(embedder_config_path): + shutil.copy2(embedder_config_path, backup_path) + logger.info(f"Backed up current config to {backup_path}") + + # Copy AST config + ast_config_path = "api/config/embedder.ast.json" + if os.path.exists(ast_config_path): + shutil.copy2(ast_config_path, embedder_config_path) + logger.info("Switched to AST chunking configuration") + else: + logger.error(f"AST config file not found: {ast_config_path}") + + +def switch_to_text_chunking(): + """Switch back to traditional text chunking.""" + import json + + embedder_config_path = "api/config/embedder.json" + backup_path = f"{embedder_config_path}.backup" + + if os.path.exists(backup_path): + shutil.copy2(backup_path, embedder_config_path) + logger.info("Switched back to text chunking configuration") + else: + # Create default text config + default_config = { + "text_splitter": { + "split_by": "word", + "chunk_size": 350, + "chunk_overlap": 100 + } + } + + with open(embedder_config_path, 'w') as f: + json.dump(default_config, f, indent=2) + + logger.info("Created default text chunking configuration") + + +def get_chunking_mode() -> str: + """Get current chunking mode.""" + try: + configs = get_embedder_config() + split_by = configs.get("text_splitter", {}).get("split_by", "word") + return "ast" if split_by == "ast" else "text" + except Exception: + return "text" # Default fallback + + +# Example usage for testing +def test_ast_chunking(): + """Test function to demonstrate AST chunking capabilities.""" + + # Sample Python code for testing + sample_code = ''' +import os +import sys +from typing import List, Dict + +class DataProcessor: + """A class for processing data.""" + + def __init__(self, config: Dict): + self.config = config + self.data = [] + + def load_data(self, file_path: str) -> List[Dict]: + """Load data from file.""" + with open(file_path, 'r') as f: + return json.load(f) + + def process_data(self, data: List[Dict]) -> List[Dict]: + """Process the loaded data.""" + processed = [] + for item in data: + processed_item = self._process_item(item) + processed.append(processed_item) + return processed + + def _process_item(self, item: Dict) -> Dict: + """Process a single item.""" + return {**item, 'processed': True} + +def main(): + """Main function.""" + processor = DataProcessor({'debug': True}) + data = processor.load_data('data.json') + result = processor.process_data(data) + print(f"Processed {len(result)} items") + +if __name__ == "__main__": + main() +''' + + # Create test document + test_doc = Document( + text=sample_code, + id="test_python_file", + meta_data={'file_path': 'test.py'} + ) + + # Test AST chunking + from .ast_integration import ASTTextSplitter + + ast_splitter = ASTTextSplitter( + split_by="ast", + chunk_size=2000, + min_chunk_size=100 + ) + + chunks = ast_splitter.call([test_doc]) + + print(f"\nAST Chunking Results:") + 
print(f"Created {len(chunks)} chunks from test file") + + for i, chunk in enumerate(chunks): + chunk_type = chunk.meta_data.get('chunk_type', 'unknown') + chunk_name = chunk.meta_data.get('chunk_name', 'unnamed') + print(f"\nChunk {i+1}: {chunk_type} - {chunk_name}") + print(f"Lines {chunk.meta_data.get('start_line', 0)}-{chunk.meta_data.get('end_line', 0)}") + print(f"Content preview: {chunk.text[:100]}...") + + +if __name__ == "__main__": + test_ast_chunking() \ No newline at end of file diff --git a/api/prompts.py b/api/prompts.py index 61ef0a4d..fcc41c9e 100644 --- a/api/prompts.py +++ b/api/prompts.py @@ -1,5 +1,56 @@ """Module containing all prompts used in the DeepWiki project.""" +# System prompt for XML Wiki Structure Generation +WIKI_STRUCTURE_SYSTEM_PROMPT = r""" +You are an expert code analyst tasked with analyzing a repository and creating a structured wiki outline. + +CRITICAL XML FORMATTING INSTRUCTIONS: +- You MUST return ONLY valid XML with NO additional text before or after +- DO NOT wrap the XML in markdown code blocks (no ``` or ```xml) +- DO NOT include any explanation or commentary +- Start directly with and end with +- Ensure all XML tags are properly closed +- Use proper XML escaping for special characters (& < > " ') + +XML STRUCTURE REQUIREMENTS: +- The root element must be +- Include a element for the wiki title +- Include a <description> element for the repository description +- For comprehensive mode: Include a <sections> element containing section hierarchies +- Include a <pages> element containing all wiki pages +- Each page must have: id, title, description, importance, relevant_files, related_pages + +Example XML structure (comprehensive mode): +<wiki_structure> + <title>Repository Wiki + A comprehensive guide + +
+  <sections>
+    <section id="section-1">
+      <title>Overview</title>
+      <pages>
+        <page_ref>page-1</page_ref>
+      </pages>
+    </section>
+  </sections>
+  <pages>
+    <page id="page-1">
+      <title>Introduction</title>
+      <description>Overview of the project</description>
+      <importance>high</importance>
+      <relevant_files>
+        <file_path>README.md</file_path>
+      </relevant_files>
+      <related_pages>
+        <related>page-2</related>
+      </related_pages>
+      <parent_section>section-1</parent_section>
+    </page>
+  </pages>
+</wiki_structure>
+ +IMPORTANT: Your entire response must be valid XML. Do not include any text outside the tags. +""" + # System prompt for RAG RAG_SYSTEM_PROMPT = r""" You are a code assistant which answers user questions on a Github Repo. diff --git a/api/websocket_wiki.py b/api/websocket_wiki.py index 2a7cce9e..906ce0c5 100644 --- a/api/websocket_wiki.py +++ b/api/websocket_wiki.py @@ -9,12 +9,14 @@ from fastapi import WebSocket, WebSocketDisconnect, HTTPException from pydantic import BaseModel, Field +from api.api import get_local_repo_structure from api.config import get_model_config, configs, OPENROUTER_API_KEY, OPENAI_API_KEY from api.data_pipeline import count_tokens, get_file_content from api.openai_client import OpenAIClient from api.openrouter_client import OpenRouterClient from api.azureai_client import AzureAIClient from api.dashscope_client import DashscopeClient +from api.prompts import WIKI_STRUCTURE_SYSTEM_PROMPT from api.rag import RAG # Configure logging @@ -49,6 +51,135 @@ class ChatCompletionRequest(BaseModel): included_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to include exclusively") included_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to include exclusively") + +async def handle_response_stream( + response, + websocket: WebSocket, + is_structure_generation: bool, + provider: str +) -> None: + """ + Handle streaming or accumulated response based on context. + + This helper function eliminates code duplication between different providers + by handling both streaming (for chat) and accumulating (for wiki structure) modes. + + Args: + response: Async iterator from the model's response + websocket: Active WebSocket connection + is_structure_generation: If True, accumulate full response before sending + provider: Provider name for provider-specific text extraction + """ + if is_structure_generation: + # Accumulate full response for wiki structure generation + logger.info("Accumulating full response for wiki structure generation") + full_response = "" + chunk_count = 0 + + async for chunk in response: + chunk_count += 1 + + # Extract text based on provider and response format + text = None + if provider == "ollama": + # Ollama-specific extraction + if hasattr(chunk, 'message') and hasattr(chunk.message, 'content'): + text = chunk.message.content + elif hasattr(chunk, 'response'): + text = chunk.response + elif hasattr(chunk, 'text'): + text = chunk.text + else: + text = str(chunk) + + logger.info(f"Chunk {chunk_count}: type={type(chunk)}, text_len={len(text) if text else 0}, starts_with={text[:50] if text else 'None'}") + + # Filter out metadata chunks + if text and not text.startswith('model=') and not text.startswith('created_at='): + text = text.replace('', '').replace('', '') + full_response += text + logger.info(f"Added to full_response, new length: {len(full_response)}") + elif provider in ["openai", "azure"]: + # OpenAI/Azure-style responses with choices and delta + choices = getattr(chunk, "choices", []) + if len(choices) > 0: + delta = getattr(choices[0], "delta", None) + if delta is not None: + text = getattr(delta, "content", None) + if text: + full_response += text + else: + # OpenRouter and other providers - simpler text extraction + if isinstance(chunk, str): + text = chunk + elif hasattr(chunk, 'content'): + text = chunk.content + elif hasattr(chunk, 'text'): + text = chunk.text + else: + text = str(chunk) + + if text: + full_response += text + + # Strip markdown code blocks if present + 
cleaned_response = full_response.strip() + if cleaned_response.startswith('```xml'): + cleaned_response = cleaned_response[6:] # Remove ```xml + elif cleaned_response.startswith('```'): + cleaned_response = cleaned_response[3:] # Remove ``` + if cleaned_response.endswith('```'): + cleaned_response = cleaned_response[:-3] # Remove trailing ``` + cleaned_response = cleaned_response.strip() + + # Send the complete response at once + logger.info(f"Total chunks processed: {chunk_count}, Sending complete XML structure ({len(cleaned_response)} chars)") + logger.info(f"First 500 chars of response: {cleaned_response[:500]}") + await websocket.send_text(cleaned_response) + else: + # Stream response chunks as they arrive for regular chat + async for chunk in response: + # Extract text based on provider + text = None + if provider == "ollama": + # Ollama-specific extraction + if hasattr(chunk, 'message') and hasattr(chunk.message, 'content'): + text = chunk.message.content + elif hasattr(chunk, 'response'): + text = chunk.response + elif hasattr(chunk, 'text'): + text = chunk.text + else: + text = str(chunk) + + # Filter out metadata chunks and remove thinking tags + if text and not text.startswith('model=') and not text.startswith('created_at='): + text = text.replace('', '').replace('', '') + await websocket.send_text(text) + elif provider in ["openai", "azure"]: + # OpenAI/Azure-style responses with choices and delta + choices = getattr(chunk, "choices", []) + if len(choices) > 0: + delta = getattr(choices[0], "delta", None) + if delta is not None: + text = getattr(delta, "content", None) + if text: + await websocket.send_text(text) + else: + # OpenRouter and other providers + if isinstance(chunk, str): + text = chunk + elif hasattr(chunk, 'content'): + text = chunk.content + elif hasattr(chunk, 'text'): + text = chunk.text + else: + text = str(chunk) + + if text: + await websocket.send_text(text) + + async def handle_websocket_chat(websocket: WebSocket): """ Handle WebSocket connection for chat completions. 
@@ -69,8 +200,9 @@ async def handle_websocket_chat(websocket: WebSocket): tokens = count_tokens(last_message.content, request.provider == "ollama") logger.info(f"Request size: {tokens} tokens") if tokens > 8000: - logger.warning(f"Request exceeds recommended token limit ({tokens} > 7500)") + logger.warning(f"Request exceeds recommended token limit ({tokens} > 8000)") input_too_large = True + logger.info("Input is large - RAG retrieval will be used to reduce context size") # Create a new RAG instance for this request try: @@ -178,11 +310,19 @@ async def handle_websocket_chat(websocket: WebSocket): # Get the query from the last message query = last_message.content - # Only retrieve documents if input is not too large + # Use RAG retrieval to get relevant context + # RAG is ALWAYS used when we have embedded documents available + # For large inputs (>8K tokens), RAG is ESSENTIAL to reduce context to manageable size + # For small inputs, RAG still helps focus on most relevant content context_text = "" retrieved_documents = None - if not input_too_large: + # Always attempt RAG retrieval when retriever is prepared + logger.info(f"Checking RAG availability: request_rag={request_rag}, type={type(request_rag)}") + use_rag = request_rag is not None + logger.info(f"use_rag={use_rag}") + + if use_rag: try: # If filePath exists, modify the query for RAG to focus on the file rag_query = query @@ -194,8 +334,15 @@ async def handle_websocket_chat(websocket: WebSocket): # Try to perform RAG retrieval try: # This will use the actual RAG implementation + logger.info(f"Calling RAG with query: {rag_query[:100]}...") retrieved_documents = request_rag(rag_query, language=request.language) - + logger.info(f"RAG returned: {type(retrieved_documents)}, length: {len(retrieved_documents) if retrieved_documents else 0}") + + if retrieved_documents and len(retrieved_documents) > 0: + logger.info(f"First result type: {type(retrieved_documents[0])}") + if hasattr(retrieved_documents[0], 'documents'): + logger.info(f"Documents found: {len(retrieved_documents[0].documents)}") + if retrieved_documents and retrieved_documents[0].documents: # Format context for the prompt in a more structured way documents = retrieved_documents[0].documents @@ -243,8 +390,19 @@ async def handle_websocket_chat(websocket: WebSocket): supported_langs = configs["lang_config"]["supported_languages"] language_name = supported_langs.get(language_code, "English") + # Detect if this is a wiki structure generation request + is_structure_generation = "create a wiki structure" in query.lower() or \ + ("analyze this github repository" in query.lower() and + "wiki structure" in query.lower()) + + if is_structure_generation: + logger.info("Detected wiki structure generation request - will use XML format") + # Create system prompt - if is_deep_research: + if is_structure_generation: + # Use the XML structure prompt imported at the top + system_prompt = WIKI_STRUCTURE_SYSTEM_PROMPT + elif is_deep_research: # Check if this is the first iteration is_first_iteration = research_iteration == 1 @@ -426,10 +584,14 @@ async def handle_websocket_chat(websocket: WebSocket): prompt += f"\n{query}\n\n\nAssistant: " + logger.info(f"About to get model config for provider={request.provider}, model={request.model}") model_config = get_model_config(request.provider, request.model)["model_kwargs"] + logger.info(f"Got model_config: {model_config}") if request.provider == "ollama": prompt += " /no_think" + + logger.debug("Entering Ollama provider block") model = OllamaClient() 
model_kwargs = { @@ -441,12 +603,18 @@ async def handle_websocket_chat(websocket: WebSocket): "num_ctx": model_config["num_ctx"] } } + logger.debug(f"Created model_kwargs for Ollama: {model_kwargs}") api_kwargs = model.convert_inputs_to_api_kwargs( input=prompt, model_kwargs=model_kwargs, model_type=ModelType.LLM ) + logger.debug(f"api_kwargs before model fix: {api_kwargs}") + + # WORKAROUND: Force the model name in api_kwargs as convert_inputs_to_api_kwargs seems to override it + api_kwargs["model"] = model_config["model"] + logger.debug(f"api_kwargs after forcing model to {model_config['model']}: {api_kwargs}") elif request.provider == "openrouter": logger.info(f"Using OpenRouter with model: {request.model}") @@ -544,12 +712,15 @@ async def handle_websocket_chat(websocket: WebSocket): if request.provider == "ollama": # Get the response and handle it properly using the previously created api_kwargs response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM) - # Handle streaming response from Ollama - async for chunk in response: - text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk) - if text and not text.startswith('model=') and not text.startswith('created_at='): - text = text.replace('', '').replace('', '') - await websocket.send_text(text) + + # Use shared helper to handle streaming or accumulation + await handle_response_stream( + response=response, + websocket=websocket, + is_structure_generation=is_structure_generation, + provider="ollama" + ) + # Explicitly close the WebSocket connection after the response is complete await websocket.close() elif request.provider == "openrouter": @@ -557,9 +728,15 @@ async def handle_websocket_chat(websocket: WebSocket): # Get the response and handle it properly using the previously created api_kwargs logger.info("Making OpenRouter API call") response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM) - # Handle streaming response from OpenRouter - async for chunk in response: - await websocket.send_text(chunk) + + # Use shared helper to handle streaming or accumulation + await handle_response_stream( + response=response, + websocket=websocket, + is_structure_generation=is_structure_generation, + provider="openrouter" + ) + # Explicitly close the WebSocket connection after the response is complete await websocket.close() except Exception as e_openrouter: @@ -573,15 +750,15 @@ async def handle_websocket_chat(websocket: WebSocket): # Get the response and handle it properly using the previously created api_kwargs logger.info("Making Openai API call") response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM) - # Handle streaming response from Openai - async for chunk in response: - choices = getattr(chunk, "choices", []) - if len(choices) > 0: - delta = getattr(choices[0], "delta", None) - if delta is not None: - text = getattr(delta, "content", None) - if text is not None: - await websocket.send_text(text) + + # Use shared helper to handle streaming or accumulation + await handle_response_stream( + response=response, + websocket=websocket, + is_structure_generation=is_structure_generation, + provider="openai" + ) + # Explicitly close the WebSocket connection after the response is complete await websocket.close() except Exception as e_openai: @@ -595,15 +772,15 @@ async def handle_websocket_chat(websocket: WebSocket): # Get the response and handle it properly using the previously created api_kwargs logger.info("Making Azure AI API call") response = await 
model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM) - # Handle streaming response from Azure AI - async for chunk in response: - choices = getattr(chunk, "choices", []) - if len(choices) > 0: - delta = getattr(choices[0], "delta", None) - if delta is not None: - text = getattr(delta, "content", None) - if text is not None: - await websocket.send_text(text) + + # Use shared helper to handle streaming or accumulation + await handle_response_stream( + response=response, + websocket=websocket, + is_structure_generation=is_structure_generation, + provider="azure" + ) + # Explicitly close the WebSocket connection after the response is complete await websocket.close() except Exception as e_azure: @@ -767,3 +944,399 @@ async def handle_websocket_chat(websocket: WebSocket): await websocket.close() except: pass + + +# ============================================================================ +# CHUNKED WIKI GENERATION FOR LARGE REPOSITORIES +# ============================================================================ + +async def process_wiki_chunk( + chunk_data: Dict[str, Any], + chunk_id: int, + total_chunks: int, + request: ChatCompletionRequest, + readme_content: str +) -> str: + """ + Process a single chunk of files to generate partial wiki structure. + + Args: + chunk_data: Dictionary with chunk info (files, directories, file_count) + chunk_id: Index of this chunk + total_chunks: Total number of chunks + request: Original chat completion request + readme_content: README content for context + + Returns: + XML string with partial wiki structure for this chunk + """ + logger.info(f"Processing chunk {chunk_id + 1}/{total_chunks} with {chunk_data['file_count']} files") + + # Create focused query for this chunk + chunk_dirs = chunk_data.get('directories', []) + chunk_query = f"""Analyze chunk {chunk_id + 1} of {total_chunks} for this repository. + +This chunk contains {chunk_data['file_count']} files from these directories: {', '.join(chunk_dirs[:10])} + +Generate a partial wiki structure for ONLY the files in this chunk. Focus on: +1. Identifying the purpose of these files/directories +2. Their role in the overall system +3. How they relate to each other + +Return the result in the same XML format, but only include pages relevant to this chunk. 
+
+<file_tree>
+{chunk_data.get('file_tree', '')}
+</file_tree>
+
+<readme>
+{readme_content}
+</readme>
+"""
+
+    try:
+        # Use RAG to get relevant context for this chunk's files
+        retrieved_documents = None
+        if request_rag:
+            try:
+                # Create a focused query for RAG retrieval based on chunk directories
+                rag_query = f"Information about: {', '.join(chunk_dirs[:5])}"
+                logger.info(f"RAG query for chunk {chunk_id + 1}: {rag_query}")
+
+                retrieved_documents = request_rag(rag_query, language="en")
+
+                if retrieved_documents and retrieved_documents[0].documents:
+                    documents = retrieved_documents[0].documents
+                    logger.info(f"Retrieved {len(documents)} documents for chunk {chunk_id + 1}")
+
+                    # Group documents by file path
+                    docs_by_file = {}
+                    for doc in documents:
+                        file_path = doc.meta_data.get('file_path', 'unknown')
+                        if file_path not in docs_by_file:
+                            docs_by_file[file_path] = []
+                        docs_by_file[file_path].append(doc)
+
+                    # Add context to query
+                    context_parts = []
+                    for file_path, docs in docs_by_file.items():
+                        header = f"## File Path: {file_path}\n\n"
+                        content = "\n\n".join([doc.text for doc in docs])
+                        context_parts.append(f"{header}{content}")
+
+                    context_text = "\n\n" + "-" * 10 + "\n\n".join(context_parts)
+                    chunk_query = f"<context>\n{context_text}\n</context>\n\n{chunk_query}"
+
+            except Exception as e:
+                logger.warning(f"RAG retrieval failed for chunk {chunk_id + 1}: {str(e)}")
+
+        # Get model configuration
+        model_config = get_model_config(request.provider, request.model)["model_kwargs"]
+
+        # Initialize the appropriate model client based on provider
+        if request.provider == "ollama":
+            from api.ollama_patch import OllamaClient
+            model = OllamaClient()
+            model_kwargs = {
+                "model": model_config["model"],
+                "stream": False,  # Non-streaming for chunk processing
+                "options": {
+                    "temperature": model_config["temperature"],
+                    "top_p": model_config["top_p"],
+                    "num_ctx": model_config["num_ctx"]
+                }
+            }
+            api_kwargs = model.convert_inputs_to_api_kwargs(
+                input=chunk_query,
+                model_kwargs=model_kwargs,
+                model_type=ModelType.LLM
+            )
+            api_kwargs["model"] = model_config["model"]
+
+        elif request.provider == "openrouter":
+            from api.openrouter_client import OpenRouterClient
+            model = OpenRouterClient()
+            model_kwargs = {
+                "model": request.model,
+                "stream": False,
+                "temperature": model_config["temperature"]
+            }
+            if "top_p" in model_config:
+                model_kwargs["top_p"] = model_config["top_p"]
+            api_kwargs = model.convert_inputs_to_api_kwargs(
+                input=chunk_query,
+                model_kwargs=model_kwargs,
+                model_type=ModelType.LLM
+            )
+
+        elif request.provider == "openai":
+            from api.openai_client import OpenAIClient
+            model = OpenAIClient()
+            model_kwargs = {
+                "model": request.model,
+                "stream": False,
+                "temperature": model_config["temperature"]
+            }
+            if "top_p" in model_config:
+                model_kwargs["top_p"] = model_config["top_p"]
+            api_kwargs = model.convert_inputs_to_api_kwargs(
+                input=chunk_query,
+                model_kwargs=model_kwargs,
+                model_type=ModelType.LLM
+            )
+        else:
+            # Fallback: use OpenAI-compatible endpoint
+            from api.openai_client import OpenAIClient
+            model = OpenAIClient()
+            model_kwargs = {
+                "model": request.model,
+                "stream": False,
+                "temperature": model_config.get("temperature", 0.7)
+            }
+            api_kwargs = model.convert_inputs_to_api_kwargs(
+                input=chunk_query,
+                model_kwargs=model_kwargs,
+                model_type=ModelType.LLM
+            )
+
+        # Call the model for this chunk
+        result = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+
+        # Collect the response
+        full_response = 
"" + async for chunk in result: + # Extract text based on response format + if hasattr(chunk, 'message') and hasattr(chunk.message, 'content'): + text = chunk.message.content + elif hasattr(chunk, 'response'): + text = chunk.response + elif hasattr(chunk, 'text'): + text = chunk.text + else: + text = str(chunk) + + if text: + full_response += text + + # Strip markdown code blocks if present + cleaned_response = full_response.strip() + if cleaned_response.startswith('```xml'): + cleaned_response = cleaned_response[6:] + elif cleaned_response.startswith('```'): + cleaned_response = cleaned_response[3:] + if cleaned_response.endswith('```'): + cleaned_response = cleaned_response[:-3] + + logger.info(f"Chunk {chunk_id + 1} processed successfully, response length: {len(cleaned_response)}") + return cleaned_response.strip() + + except Exception as e: + logger.error(f"Error processing chunk {chunk_id + 1}: {str(e)}") + # Return a minimal valid XML structure on error + return f""" + Error processing chunk {chunk_id + 1}: {str(e)} +""" + + +def merge_wiki_structures(partial_wikis: List[str]) -> str: + """ + Merge multiple partial wiki structures into a single cohesive wiki. + + Args: + partial_wikis: List of XML strings, each containing partial wiki structure + + Returns: + Combined XML wiki structure + """ + logger.info(f"Merging {len(partial_wikis)} partial wiki structures") + + try: + from xml.etree import ElementTree as ET + + # Parse all partial wikis + all_pages = [] + all_sections = [] + titles = [] + descriptions = [] + + for i, partial_xml in enumerate(partial_wikis): + try: + # Parse the XML + root = ET.fromstring(partial_xml) + + # Extract title and description + title_elem = root.find('title') + if title_elem is not None and title_elem.text: + titles.append(title_elem.text) + + desc_elem = root.find('description') + if desc_elem is not None and desc_elem.text: + descriptions.append(desc_elem.text) + + # Extract pages + pages_elem = root.find('pages') + if pages_elem is not None: + for page in pages_elem.findall('page'): + # Add chunk information to page + page.set('source_chunk', str(i + 1)) + all_pages.append(page) + + # Extract sections if present + sections_elem = root.find('sections') + if sections_elem is not None: + for section in sections_elem.findall('section'): + section.set('source_chunk', str(i + 1)) + all_sections.append(section) + + except ET.ParseError as e: + logger.error(f"Failed to parse partial wiki {i + 1}: {str(e)}") + continue + + # Deduplicate pages by ID + unique_pages = {} + for page in all_pages: + page_id = page.get('id', f'page-{len(unique_pages) + 1}') + if page_id not in unique_pages: + unique_pages[page_id] = page + else: + # Merge information if duplicate found + logger.debug(f"Duplicate page ID found: {page_id}, keeping first occurrence") + + # Build merged structure + merged_root = ET.Element('wiki_structure') + + # Use first non-empty title or generate one + title_elem = ET.SubElement(merged_root, 'title') + title_elem.text = titles[0] if titles else "Repository Documentation" + + # Combine descriptions + desc_elem = ET.SubElement(merged_root, 'description') + if descriptions: + desc_elem.text = descriptions[0] # Use first description as primary + else: + desc_elem.text = "Comprehensive documentation generated from repository analysis" + + # Add sections if any were found + if all_sections: + sections_container = ET.SubElement(merged_root, 'sections') + for section in all_sections: + sections_container.append(section) + + # Add all unique pages + 
pages_container = ET.SubElement(merged_root, 'pages')
+        for page_id, page in unique_pages.items():
+            pages_container.append(page)
+
+        # Convert back to string with proper formatting
+        xml_string = ET.tostring(merged_root, encoding='unicode', method='xml')
+
+        logger.info(f"Successfully merged {len(partial_wikis)} wikis into {len(unique_pages)} unique pages")
+        return xml_string
+
+    except Exception as e:
+        logger.error(f"Error merging wiki structures: {str(e)}")
+
+        # Fallback: Simple concatenation with wrapper
+        import re  # used for the page extraction below
+        merged = "<wiki_structure>\n"
+        merged += "  <title>Repository Documentation</title>\n"
+        merged += "  <description>Combined documentation from multiple repository sections</description>\n"
+        merged += "  <pages>\n"
+
+        # Try to extract individual pages from each partial
+        page_counter = 1
+        for i, partial in enumerate(partial_wikis):
+            try:
+                # Simple extraction of page elements
+                pages = re.findall(r'<page[^>]*>.*?</page>', partial, re.DOTALL)
+                for page in pages:
+                    # Ensure page has an ID
+                    if 'id=' not in page:
+                        page = page.replace('<page', f'<page id="page-{page_counter}"', 1)
+                        page_counter += 1
+                    merged += f"    {page}\n"
+            except Exception:
+                # Could not extract pages from this chunk; add a placeholder page
+                merged += f'    <page id="chunk-{i + 1}">\n'
+                merged += f'      <title>Chunk {i+1} Processing Note</title>\n'
+                merged += f'      <description>Could not merge chunk {i+1} properly</description>\n'
+                merged += f'      <importance>low</importance>\n'
+                merged += f'    </page>\n'
+
+        merged += "  </pages>\n"
+        merged += "</wiki_structure>"
+
+        return merged
+
+
+async def handle_chunked_wiki_generation(
+    websocket: WebSocket,
+    repo_path: str,
+    request: ChatCompletionRequest
+) -> None:
+    """
+    Handle wiki generation for large repositories using chunked processing.
+
+    This function:
+    1. Fetches repository structure with chunking enabled
+    2. Processes each chunk separately with RAG
+    3. Merges partial results into final wiki structure
+    4. Sends progress updates via WebSocket
+
+    Args:
+        websocket: Active WebSocket connection
+        repo_path: Path to the repository
+        request: Original chat completion request
+    """
+    try:
+        logger.info(f"Starting chunked wiki generation for {repo_path}")
+        await websocket.send_text("šŸ”„ Analyzing large repository structure...\n")
+
+        # Call get_local_repo_structure directly instead of making HTTP request
+        repo_info = await get_local_repo_structure(
+            path=repo_path,
+            return_chunks=True,
+            chunk_size=500
+        )
+
+        if not repo_info.get('chunked'):
+            # Small repo, process normally
+            await websocket.send_text("Repository is small enough to process in one go.\n")
+            return
+
+        chunks = repo_info.get('chunks', [])
+        readme = repo_info.get('readme', '')
+        total_files = repo_info.get('total_files', 0)
+
+        await websocket.send_text(f"šŸ“Š Repository has {total_files} files split into {len(chunks)} chunks\n")
+        await websocket.send_text("šŸ” Processing each chunk with RAG...\n\n")
+
+        partial_wikis = []
+        for i, chunk in enumerate(chunks):
+            await websocket.send_text(f"ā³ Chunk {i + 1}/{len(chunks)}: {chunk['file_count']} files from {len(chunk['directories'])} directories\n")
+
+            # Process this chunk
+            partial_wiki = await process_wiki_chunk(chunk, i, len(chunks), request, readme)
+            partial_wikis.append(partial_wiki)
+
+            await websocket.send_text(f"āœ… Completed chunk {i + 1}/{len(chunks)}\n\n")
+
+        await websocket.send_text("šŸ”— Merging all chunks into final wiki structure...\n")
+
+        # Merge all partial wikis
+        final_wiki = merge_wiki_structures(partial_wikis)
+
+        await websocket.send_text("\nšŸ“ Final wiki structure:\n\n")
+        await websocket.send_text(final_wiki)
+
+        logger.info("Chunked wiki generation completed successfully")
+
+    except Exception as e:
+        logger.error(f"Error in chunked wiki generation: {str(e)}")
+        await websocket.send_text(f"\nāŒ Error: {str(e)}\n")
+    finally:
+        await websocket.close()
diff --git a/docker-compose.yml b/docker-compose.yml
index 
15f9cdb6..f8c9a07b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,6 +3,8 @@ services: build: context: . dockerfile: Dockerfile + args: + AST_CHUNKING: ${AST_CHUNKING:-false} # Set to 'true' to enable AST chunking during build ports: - "${PORT:-8001}:${PORT:-8001}" # API port - "3000:3000" # Next.js port @@ -17,6 +19,7 @@ services: volumes: - ~/.adalflow:/root/.adalflow # Persist repository and embedding data - ./api/logs:/app/api/logs # Persist log files across container restarts + - ./api/config:/app/api/config:ro # Mount config files (read-only) # Resource limits for docker-compose up (not Swarm mode) mem_limit: 6g mem_reservation: 2g @@ -27,3 +30,5 @@ services: timeout: 10s retries: 3 start_period: 30s + extra_hosts: + - "host.docker.internal:host-gateway" # Allow access to host machine from Docker
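
Note on response cleanup: the same post-processing (dropping Ollama metadata chunks, stripping <think> tags, and removing surrounding ```xml fences) now appears in handle_response_stream, the regular streaming path, and process_wiki_chunk. A possible follow-up is to factor it into small helpers; the sketch below is illustrative only and not part of this patch, and the helper names are invented:

from typing import Optional

def clean_stream_chunk(text: Optional[str]) -> Optional[str]:
    """Drop Ollama metadata chunks and strip thinking tags from a single streamed chunk."""
    if not text or text.startswith('model=') or text.startswith('created_at='):
        return None
    return text.replace('<think>', '').replace('</think>', '')

def strip_markdown_fences(full_response: str) -> str:
    """Remove surrounding ```xml / ``` fences before the accumulated XML is sent or parsed."""
    cleaned = full_response.strip()
    if cleaned.startswith('```xml'):
        cleaned = cleaned[6:]
    elif cleaned.startswith('```'):
        cleaned = cleaned[3:]
    if cleaned.endswith('```'):
        cleaned = cleaned[:-3]
    return cleaned.strip()

handle_response_stream and process_wiki_chunk could then call clean_stream_chunk on each chunk and strip_markdown_fences once on the accumulated structure, keeping the three code paths in sync.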
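
For reference, merge_wiki_structures expects each partial result to be rooted at <wiki_structure> with <title>, <description> and <pages>/<page id="..."> children, and it keeps the first page seen for each id. A minimal sketch of that deduplication behaviour, using made-up page ids and titles:

from xml.etree import ElementTree as ET

# Two partial structures as chunk processing might return them (contents invented).
partial_a = """<wiki_structure>
  <title>Example Repo</title>
  <description>Backend pages</description>
  <pages>
    <page id="page-api"><title>API Layer</title></page>
  </pages>
</wiki_structure>"""

partial_b = """<wiki_structure>
  <title>Example Repo</title>
  <description>Frontend pages</description>
  <pages>
    <page id="page-api"><title>API Layer (duplicate)</title></page>
    <page id="page-ui"><title>UI Layer</title></page>
  </pages>
</wiki_structure>"""

# First occurrence of each page id wins, mirroring the merge logic in the patch.
unique_pages = {}
for partial in (partial_a, partial_b):
    for page in ET.fromstring(partial).find('pages').findall('page'):
        unique_pages.setdefault(page.get('id'), page)

print(sorted(unique_pages))  # ['page-api', 'page-ui']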
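
End to end, a client drives the wiki-structure path by sending a ChatCompletionRequest-shaped JSON message over the chat WebSocket and reading messages until the server closes the connection; for structure generation the server accumulates the model output and sends the XML in a single message. The sketch below is an assumption-laden illustration: the endpoint path, model name, and exact payload fields may differ from the deployed server.

import asyncio
import json
import websockets  # third-party client library, for illustration only

async def request_wiki_structure() -> str:
    uri = "ws://localhost:8001/ws/chat"  # assumed endpoint path; port matches docker-compose default
    payload = {
        # Field names follow ChatCompletionRequest; values are placeholders.
        "provider": "ollama",
        "model": "qwen3:8b",
        "language": "en",
        "messages": [{
            "role": "user",
            "content": "Analyze this GitHub repository and create a wiki structure",
        }],
    }
    received = []
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps(payload))
        async for message in ws:  # progress text (if any) followed by the final XML
            received.append(message)
    return "".join(received)

# asyncio.run(request_wiki_structure())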