diff --git a/Dockerfile b/Dockerfile
index ab818ef5..6b384e8d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,6 +2,8 @@
# Build argument for custom certificates directory
ARG CUSTOM_CERT_DIR="certs"
+# Build argument for AST chunking (default to false)
+ARG AST_CHUNKING=false
FROM node:20-alpine3.22 AS node_base
@@ -67,6 +69,15 @@ ENV PATH="/opt/venv/bin:$PATH"
COPY --from=py_deps /opt/venv /opt/venv
COPY api/ ./api/
+# Configure AST chunking based on build argument
+RUN if [ "$AST_CHUNKING" = "true" ]; then \
+ echo "š Enabling AST chunking during build..."; \
+ cd /app/api && python enable_ast.py enable; \
+ else \
+ echo "š Using default text chunking..."; \
+ cd /app/api && python enable_ast.py disable; \
+ fi
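+# Example (illustrative) build invocations; the image tag is arbitrary:
+#   docker build --build-arg AST_CHUNKING=true -t deepwiki .
+#   docker build -t deepwiki .    # keeps the default text chunking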
+
# Copy Node app
COPY --from=node_builder /app/public ./public
COPY --from=node_builder /app/.next/standalone ./
diff --git a/api/api.py b/api/api.py
index d40e73f9..5683f3bd 100644
--- a/api/api.py
+++ b/api/api.py
@@ -9,20 +9,215 @@
from pydantic import BaseModel, Field
import google.generativeai as genai
import asyncio
+from collections import defaultdict
+import fnmatch
# Configure logging
from api.logging_config import setup_logging
+from api.config import load_repo_config
setup_logging()
logger = logging.getLogger(__name__)
# Initialize FastAPI app
-app = FastAPI(
- title="Streaming API",
- description="API for streaming chat completions"
+app = FastAPI()
+
+# Configure CORS
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # In production, specify your frontend domain
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
)
+
+# Pydantic models for wiki pages
+class WikiPage(BaseModel):
+ id: str
+ title: str
+ content: str
+ related_pages: List[str] = []
+
+
+# ============================================================================
+# INTELLIGENT FILE CHUNKING SYSTEM
+# ============================================================================
+
+def should_exclude_dir(dir_name: str, excluded_patterns: List[str]) -> bool:
+ """Check if directory should be excluded based on patterns."""
+ # Always exclude hidden directories and common build/cache dirs
+ if dir_name.startswith('.'):
+ return True
+ if dir_name in ['__pycache__', 'node_modules', '.venv', 'venv', 'env',
+ 'image-cache', 'dist', 'build', 'target', 'out']:
+ return True
+
+ # Check against user-defined patterns
+ for pattern in excluded_patterns:
+ pattern_clean = pattern.strip('./').rstrip('/')
+ if fnmatch.fnmatch(dir_name, pattern_clean):
+ return True
+ return False
+
+
+def should_exclude_file(file_name: str, excluded_patterns: List[str]) -> bool:
+ """Check if file should be excluded based on patterns."""
+ # Always exclude hidden files and common files
+ if file_name.startswith('.') or file_name == '__init__.py' or file_name == '.DS_Store':
+ return True
+
+ # Check against user-defined patterns
+ for pattern in excluded_patterns:
+ if fnmatch.fnmatch(file_name, pattern):
+ return True
+ return False
+
+
+def collect_all_files(path: str, config: Dict) -> tuple[List[str], str]:
+ """
+    Collect all files from the repository, respecting the configured exclude patterns.
+ Also finds and reads README.md during the same walk.
+
+ Args:
+ path: Root directory path
+ config: Configuration with excluded_dirs and excluded_files
+
+ Returns:
+ Tuple of (list of relative file paths, README content string)
+ """
+ all_files = []
+ readme_content = ""
+ excluded_dirs = config.get('excluded_dirs', [])
+ excluded_files = config.get('excluded_files', [])
+
+ logger.info(f"Collecting files from {path}")
+ logger.info(f"Excluded dirs: {len(excluded_dirs)} patterns")
+ logger.info(f"Excluded files: {len(excluded_files)} patterns")
+
+ for root, dirs, files in os.walk(path):
+ # Filter directories in-place
+ dirs[:] = [d for d in dirs if not should_exclude_dir(d, excluded_dirs)]
+
+ for file in files:
+ if not should_exclude_file(file, excluded_files):
+ rel_dir = os.path.relpath(root, path)
+ rel_file = os.path.join(rel_dir, file) if rel_dir != '.' else file
+ all_files.append(rel_file)
+
+ # Find README.md (case-insensitive) during the same walk
+ if file.lower() == 'readme.md' and not readme_content:
+ try:
+ with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
+ readme_content = f.read()
+ logger.info(f"Found README.md at: {rel_file}")
+ except Exception as e:
+ logger.warning(f"Could not read README.md at {rel_file}: {str(e)}")
+
+ logger.info(f"Collected {len(all_files)} files after filtering")
+ return all_files, readme_content
+
+
+def group_files_by_directory(files: List[str]) -> Dict[str, List[str]]:
+ """Group files by their parent directory."""
+ by_dir = defaultdict(list)
+
+ for file_path in files:
+ dir_name = os.path.dirname(file_path)
+ if not dir_name:
+ dir_name = "root"
+ by_dir[dir_name].append(file_path)
+
+ return dict(by_dir)
+
+
+def create_file_chunks(files: List[str], max_files_per_chunk: int = 500) -> List[Dict[str, Any]]:
+ """
+ Create intelligent chunks of files grouped by directory.
+ Ensures no chunk exceeds max_files_per_chunk by splitting large directories.
+
+ Args:
+ files: List of all file paths
+ max_files_per_chunk: Maximum files per chunk
+
+ Returns:
+ List of chunk dictionaries with metadata
+ """
+ # Group by directory
+ by_dir = group_files_by_directory(files)
+
+ chunks = []
+ current_chunk_files = []
+ current_chunk_dirs = []
+
+ for dir_name, dir_files in sorted(by_dir.items()):
+ # Handle large directories that exceed max_files_per_chunk on their own
+ if len(dir_files) > max_files_per_chunk:
+ # First, save current chunk if it has files
+ if current_chunk_files:
+ chunks.append({
+ 'files': current_chunk_files[:],
+ 'directories': current_chunk_dirs[:],
+ 'file_count': len(current_chunk_files)
+ })
+ current_chunk_files = []
+ current_chunk_dirs = []
+
+ # Split large directory across multiple chunks
+ logger.warning(f"Directory '{dir_name}' has {len(dir_files)} files, splitting across multiple chunks")
+ for i in range(0, len(dir_files), max_files_per_chunk):
+ chunk_slice = dir_files[i:i + max_files_per_chunk]
+ chunks.append({
+ 'files': chunk_slice,
+ 'directories': [f"{dir_name} (part {i//max_files_per_chunk + 1})"],
+ 'file_count': len(chunk_slice)
+ })
+ else:
+ # Normal case: check if adding this directory would exceed limit
+ if current_chunk_files and len(current_chunk_files) + len(dir_files) > max_files_per_chunk:
+ # Save current chunk and start new one
+ chunks.append({
+ 'files': current_chunk_files[:],
+ 'directories': current_chunk_dirs[:],
+ 'file_count': len(current_chunk_files)
+ })
+ current_chunk_files = []
+ current_chunk_dirs = []
+
+ # Add directory to current chunk
+ current_chunk_files.extend(dir_files)
+ current_chunk_dirs.append(dir_name)
+
+ # Add final chunk if it has files
+ if current_chunk_files:
+ chunks.append({
+ 'files': current_chunk_files,
+ 'directories': current_chunk_dirs,
+ 'file_count': len(current_chunk_files)
+ })
+
+ logger.info(f"Created {len(chunks)} chunks from {len(files)} files")
+ for i, chunk in enumerate(chunks):
+ logger.info(f" Chunk {i+1}: {chunk['file_count']} files across {len(chunk['directories'])} directories")
+
+ return chunks
+
+
+def format_chunk_as_tree(chunk: Dict[str, Any]) -> str:
+ """Format a chunk of files as a tree string."""
+ files = chunk['files']
+ tree_lines = sorted(files)
+
+ # Add chunk metadata
+ chunk_info = f"# Chunk contains {len(files)} files from {len(chunk['directories'])} directories\n"
+ chunk_info += f"# Directories: {', '.join(chunk['directories'][:5])}"
+ if len(chunk['directories']) > 5:
+ chunk_info += f" ... and {len(chunk['directories']) - 5} more"
+ chunk_info += "\n\n"
+
+ return chunk_info + '\n'.join(tree_lines)
+
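+# Illustrative sketch of how the helpers above compose (path and filter config are hypothetical):
+#   files, readme = collect_all_files("/repos/myrepo", {"excluded_dirs": [], "excluded_files": []})
+#   chunks = create_file_chunks(files, max_files_per_chunk=500)
+#   trees = [format_chunk_as_tree(c) for c in chunks]
+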
# Configure CORS
app.add_middleware(
CORSMiddleware,
@@ -273,8 +468,19 @@ async def export_wiki(request: WikiExportRequest):
raise HTTPException(status_code=500, detail=error_msg)
@app.get("/local_repo/structure")
-async def get_local_repo_structure(path: str = Query(None, description="Path to local repository")):
- """Return the file tree and README content for a local repository."""
+async def get_local_repo_structure(
+ path: str = Query(None, description="Path to local repository"),
+ chunk_size: int = Query(500, description="Maximum files per chunk"),
+ return_chunks: bool = Query(False, description="Return chunked structure for large repos")
+):
+ """
+ Return the file tree and README content for a local repository.
+
+ Now supports intelligent chunking for large repositories:
+    - Collects all files, respecting the configured exclude patterns
+ - Groups files by directory
+ - Returns chunks if repository is large
+ """
if not path:
return JSONResponse(
status_code=400,
@@ -288,30 +494,48 @@ async def get_local_repo_structure(path: str = Query(None, description="Path to
)
try:
- logger.info(f"Processing local repository at: {path}")
- file_tree_lines = []
- readme_content = ""
-
- for root, dirs, files in os.walk(path):
- # Exclude hidden dirs/files and virtual envs
- dirs[:] = [d for d in dirs if not d.startswith('.') and d != '__pycache__' and d != 'node_modules' and d != '.venv']
- for file in files:
- if file.startswith('.') or file == '__init__.py' or file == '.DS_Store':
- continue
- rel_dir = os.path.relpath(root, path)
- rel_file = os.path.join(rel_dir, file) if rel_dir != '.' else file
- file_tree_lines.append(rel_file)
- # Find README.md (case-insensitive)
- if file.lower() == 'readme.md' and not readme_content:
- try:
- with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
- readme_content = f.read()
- except Exception as e:
- logger.warning(f"Could not read README.md: {str(e)}")
- readme_content = ""
-
- file_tree_str = '\n'.join(sorted(file_tree_lines))
- return {"file_tree": file_tree_str, "readme": readme_content}
+ logger.info(f"Processing local repository at: {path} (chunk_size={chunk_size}, return_chunks={return_chunks})")
+
+ # Load configuration from repo.json (imported at the top)
+ config_data = load_repo_config()
+ file_filters = config_data.get('file_filters', {})
+
+ # Collect ALL files respecting patterns and find README in one pass
+ all_files, readme_content = collect_all_files(path, file_filters)
+
+ # Decide whether to chunk based on repository size
+ total_files = len(all_files)
+ logger.info(f"Total files collected: {total_files}")
+
+ if return_chunks or total_files > chunk_size:
+ # Create intelligent chunks
+ chunks = create_file_chunks(all_files, max_files_per_chunk=chunk_size)
+
+ return {
+ "chunked": True,
+ "total_files": total_files,
+ "chunk_count": len(chunks),
+ "chunks": [
+ {
+ "chunk_id": i,
+ "file_count": chunk['file_count'],
+ "directories": chunk['directories'],
+ "file_tree": format_chunk_as_tree(chunk)
+ }
+ for i, chunk in enumerate(chunks)
+ ],
+ "readme": readme_content
+ }
+ else:
+ # Small repo, return as single tree
+ file_tree_str = '\n'.join(sorted(all_files))
+ return {
+ "chunked": False,
+ "total_files": total_files,
+ "file_tree": file_tree_str,
+ "readme": readme_content
+ }
+
except Exception as e:
logger.error(f"Error processing local repository: {str(e)}")
return JSONResponse(
diff --git a/api/ast_chunker.py b/api/ast_chunker.py
new file mode 100644
index 00000000..eb78023b
--- /dev/null
+++ b/api/ast_chunker.py
@@ -0,0 +1,528 @@
+"""
+AST-based chunking for code files using LlamaIndex-inspired approach.
+This provides semantic chunking that respects code structure.
+"""
+
+import ast
+import logging
+from typing import List, Dict, Any, Optional, Union
+from pathlib import Path
+from dataclasses import dataclass
+
+# Use a simple token counter if tiktoken is not available
+try:
+ import tiktoken
+ TIKTOKEN_AVAILABLE = True
+except ImportError:
+ TIKTOKEN_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class CodeChunk:
+ """Represents a semantically meaningful chunk of code."""
+ content: str
+ chunk_type: str # 'function', 'class', 'module', 'import_block', 'comment_block'
+ name: Optional[str] = None # function/class name
+ start_line: int = 0
+ end_line: int = 0
+ file_path: str = ""
+ dependencies: List[str] = None # imported modules/functions this chunk depends on
+
+ def __post_init__(self):
+ if self.dependencies is None:
+ self.dependencies = []
+
+
+class ASTChunker:
+ """AST-based code chunker that creates semantically meaningful chunks."""
+
+ def __init__(self,
+ max_chunk_size: int = 2000,
+ min_chunk_size: int = 100,
+ overlap_lines: int = 5,
+ preserve_structure: bool = True):
+ """
+ Initialize AST chunker.
+
+ Args:
+ max_chunk_size: Maximum tokens per chunk
+ min_chunk_size: Minimum tokens per chunk
+ overlap_lines: Lines of overlap between chunks
+ preserve_structure: Whether to keep related code together
+ """
+ self.max_chunk_size = max_chunk_size
+ self.min_chunk_size = min_chunk_size
+ self.overlap_lines = overlap_lines
+ self.preserve_structure = preserve_structure
+
+ # Initialize token encoder with fallback
+ if TIKTOKEN_AVAILABLE:
+ self.encoding = tiktoken.get_encoding("cl100k_base")
+ else:
+ # Simple fallback: approximate 4 chars per token
+ self.encoding = None
+
+ def chunk_file(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk a single file based on its type and structure."""
+ file_ext = Path(file_path).suffix.lower()
+
+ # Route to appropriate chunker based on file type
+ if file_ext == '.py':
+ return self._chunk_python(file_path, content)
+ elif file_ext in ['.js', '.ts', '.jsx', '.tsx']:
+ return self._chunk_javascript(file_path, content)
+ elif file_ext in ['.java', '.kt']:
+ return self._chunk_java_kotlin(file_path, content)
+ elif file_ext in ['.cpp', '.cc', '.cxx', '.c', '.h', '.hpp']:
+ return self._chunk_cpp(file_path, content)
+ elif file_ext in ['.rs']:
+ return self._chunk_rust(file_path, content)
+ elif file_ext in ['.go']:
+ return self._chunk_go(file_path, content)
+ elif file_ext in ['.md', '.rst', '.txt']:
+ return self._chunk_markdown(file_path, content)
+ elif file_ext in ['.json', '.yaml', '.yml', '.toml']:
+ return self._chunk_config(file_path, content)
+ else:
+ # Fall back to text-based chunking
+ return self._chunk_text(file_path, content)
+
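+    # Illustrative use of the routing above (file name and source text are hypothetical):
+    #   chunker = ASTChunker(max_chunk_size=2000)
+    #   chunks = chunker.chunk_file("utils/math_helpers.py", source_text)
+    #   -> a list of CodeChunk objects typed as 'import_block', 'function', 'class', ...
+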
+ def _chunk_python(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk Python code using AST analysis."""
+ chunks = []
+
+ try:
+ tree = ast.parse(content)
+ lines = content.split('\n')
+
+ # Group imports at the top
+ imports = []
+ other_nodes = []
+
+            # Iterate only over top-level nodes so imports/defs nested inside
+            # functions or classes are not double-counted as separate chunks
+            for node in tree.body:
+                if isinstance(node, (ast.Import, ast.ImportFrom)):
+                    imports.append(node)
+                elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+                    other_nodes.append(node)
+
+ # Create import chunk if imports exist
+ if imports:
+ import_lines = []
+ for imp in imports:
+ if hasattr(imp, 'lineno'):
+ import_lines.extend(range(imp.lineno - 1,
+ getattr(imp, 'end_lineno', imp.lineno)))
+
+ if import_lines:
+ import_content = '\n'.join(lines[min(import_lines):max(import_lines)+1])
+ chunks.append(CodeChunk(
+ content=import_content,
+ chunk_type='import_block',
+ start_line=min(import_lines) + 1,
+ end_line=max(import_lines) + 1,
+ file_path=file_path
+ ))
+
+ # Process classes and functions
+ for node in other_nodes:
+ if isinstance(node, ast.ClassDef):
+ chunk = self._extract_class_chunk(node, lines, file_path)
+ if chunk:
+ chunks.append(chunk)
+ elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ chunk = self._extract_function_chunk(node, lines, file_path)
+ if chunk:
+ chunks.append(chunk)
+
+ # Handle module-level code
+ module_code = self._extract_module_level_code(tree, lines, file_path)
+ if module_code:
+ chunks.extend(module_code)
+
+ except SyntaxError as e:
+ logger.warning(f"Could not parse Python file {file_path}: {e}")
+ # Fall back to text chunking
+ return self._chunk_text(file_path, content)
+
+ return self._optimize_chunks(chunks)
+
+ def _extract_class_chunk(self, node: ast.ClassDef,
+ lines: List[str], file_path: str) -> Optional[CodeChunk]:
+ """Extract a class and its methods as a chunk."""
+ start_line = node.lineno - 1
+ end_line = getattr(node, 'end_lineno', node.lineno) - 1
+
+ class_content = '\n'.join(lines[start_line:end_line + 1])
+
+ # Check if chunk is too large, split methods if needed
+ token_count = self._count_tokens(class_content)
+
+ if token_count > self.max_chunk_size:
+ # Split into method chunks
+ method_chunks = []
+ for item in node.body:
+ if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ method_chunk = self._extract_function_chunk(item, lines, file_path)
+ if method_chunk:
+ method_chunks.append(method_chunk)
+ return method_chunks if method_chunks else None
+
+ # Extract dependencies (imports used in class)
+ dependencies = self._extract_dependencies(node)
+
+ return CodeChunk(
+ content=class_content,
+ chunk_type='class',
+ name=node.name,
+ start_line=start_line + 1,
+ end_line=end_line + 1,
+ file_path=file_path,
+ dependencies=dependencies
+ )
+
+ def _extract_function_chunk(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef],
+ lines: List[str], file_path: str) -> Optional[CodeChunk]:
+ """Extract a function as a chunk."""
+ start_line = node.lineno - 1
+ end_line = getattr(node, 'end_lineno', node.lineno) - 1
+
+ function_content = '\n'.join(lines[start_line:end_line + 1])
+
+ # Extract dependencies
+ dependencies = self._extract_dependencies(node)
+
+ return CodeChunk(
+ content=function_content,
+ chunk_type='function',
+ name=node.name,
+ start_line=start_line + 1,
+ end_line=end_line + 1,
+ file_path=file_path,
+ dependencies=dependencies
+ )
+
+ def _extract_dependencies(self, node) -> List[str]:
+ """Extract dependencies (imported names) used in this node."""
+ dependencies = []
+
+ for child in ast.walk(node):
+ if isinstance(child, ast.Name):
+ dependencies.append(child.id)
+ elif isinstance(child, ast.Attribute):
+ # Handle module.function calls
+ if isinstance(child.value, ast.Name):
+ dependencies.append(f"{child.value.id}.{child.attr}")
+
+ return list(set(dependencies)) # Remove duplicates
+
+ def _extract_module_level_code(self, tree: ast.AST,
+ lines: List[str], file_path: str) -> List[CodeChunk]:
+ """Extract module-level code that's not in classes or functions."""
+ chunks = []
+
+ # Find lines not covered by classes/functions
+ covered_lines = set()
+
+ for node in ast.walk(tree):
+ if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef,
+ ast.Import, ast.ImportFrom)):
+ start = node.lineno - 1
+ end = getattr(node, 'end_lineno', node.lineno) - 1
+ covered_lines.update(range(start, end + 1))
+
+ # Group uncovered lines into chunks
+ uncovered_lines = []
+ for i, line in enumerate(lines):
+ if i not in covered_lines and line.strip():
+ uncovered_lines.append((i, line))
+
+ if uncovered_lines:
+ # Group consecutive lines
+ current_chunk_lines = []
+ current_start = None
+
+ for line_num, line in uncovered_lines:
+ if current_start is None:
+ current_start = line_num
+ current_chunk_lines = [line]
+ elif line_num == current_start + len(current_chunk_lines):
+ current_chunk_lines.append(line)
+ else:
+ # Gap found, save current chunk and start new one
+ if current_chunk_lines:
+ chunk_content = '\n'.join(current_chunk_lines)
+ if self._count_tokens(chunk_content) > self.min_chunk_size:
+ chunks.append(CodeChunk(
+ content=chunk_content,
+ chunk_type='module',
+ start_line=current_start + 1,
+ end_line=current_start + len(current_chunk_lines),
+ file_path=file_path
+ ))
+
+ current_start = line_num
+ current_chunk_lines = [line]
+
+ # Add final chunk
+ if current_chunk_lines:
+ chunk_content = '\n'.join(current_chunk_lines)
+ if self._count_tokens(chunk_content) > self.min_chunk_size:
+ chunks.append(CodeChunk(
+ content=chunk_content,
+ chunk_type='module',
+ start_line=current_start + 1,
+ end_line=current_start + len(current_chunk_lines),
+ file_path=file_path
+ ))
+
+ return chunks
+
+ def _chunk_javascript(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk JavaScript/TypeScript using regex patterns (simplified AST)."""
+ # This is a simplified implementation
+ # For production, you'd want to use a proper JS/TS parser like babel or esprima
+ chunks = []
+ lines = content.split('\n')
+
+ # Find function and class boundaries using regex
+ import re
+
+ function_pattern = r'^(export\s+)?(async\s+)?function\s+(\w+)|^(export\s+)?const\s+(\w+)\s*=\s*(async\s+)?\('
+ class_pattern = r'^(export\s+)?class\s+(\w+)'
+
+ current_chunk = []
+ current_type = 'module'
+ current_name = None
+ brace_count = 0
+ in_function = False
+
+ for i, line in enumerate(lines):
+ if re.match(function_pattern, line.strip()):
+ # Start of function
+ if current_chunk:
+ chunks.append(self._create_js_chunk(
+ current_chunk, current_type, current_name, file_path
+ ))
+ current_chunk = [line]
+ current_type = 'function'
+ match = re.match(function_pattern, line.strip())
+ current_name = match.group(3) or match.group(5)
+ brace_count = line.count('{') - line.count('}')
+ in_function = True
+ elif re.match(class_pattern, line.strip()):
+ # Start of class
+ if current_chunk:
+ chunks.append(self._create_js_chunk(
+ current_chunk, current_type, current_name, file_path
+ ))
+ current_chunk = [line]
+ current_type = 'class'
+ match = re.match(class_pattern, line.strip())
+ current_name = match.group(2)
+ brace_count = line.count('{') - line.count('}')
+ in_function = True
+ else:
+ current_chunk.append(line)
+ if in_function:
+ brace_count += line.count('{') - line.count('}')
+ if brace_count <= 0:
+ # End of function/class
+ chunks.append(self._create_js_chunk(
+ current_chunk, current_type, current_name, file_path
+ ))
+ current_chunk = []
+ current_type = 'module'
+ current_name = None
+ in_function = False
+
+ # Add remaining content
+ if current_chunk:
+ chunks.append(self._create_js_chunk(
+ current_chunk, current_type, current_name, file_path
+ ))
+
+ return self._optimize_chunks(chunks)
+
+ def _create_js_chunk(self, lines: List[str], chunk_type: str,
+ name: Optional[str], file_path: str) -> CodeChunk:
+ """Create a JavaScript chunk from lines."""
+ content = '\n'.join(lines)
+ return CodeChunk(
+ content=content,
+ chunk_type=chunk_type,
+ name=name,
+ file_path=file_path
+ )
+
+ def _chunk_markdown(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk Markdown by headers and sections."""
+ chunks = []
+ lines = content.split('\n')
+
+ current_chunk = []
+ current_header = None
+ header_level = 0
+
+ for line in lines:
+ if line.startswith('#'):
+ # Header found
+ if current_chunk:
+ chunks.append(CodeChunk(
+ content='\n'.join(current_chunk),
+ chunk_type='section',
+ name=current_header,
+ file_path=file_path
+ ))
+
+ current_header = line.strip('#').strip()
+ header_level = len(line) - len(line.lstrip('#'))
+ current_chunk = [line]
+ else:
+ current_chunk.append(line)
+
+ # Add final chunk
+ if current_chunk:
+ chunks.append(CodeChunk(
+ content='\n'.join(current_chunk),
+ chunk_type='section',
+ name=current_header,
+ file_path=file_path
+ ))
+
+ return chunks
+
+ def _chunk_config(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk configuration files as single units (they're usually coherent)."""
+ return [CodeChunk(
+ content=content,
+ chunk_type='config',
+ name=Path(file_path).name,
+ file_path=file_path
+ )]
+
+ def _chunk_text(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Fall back to simple text chunking."""
+ chunks = []
+ lines = content.split('\n')
+
+ current_chunk = []
+ current_tokens = 0
+
+ for line in lines:
+ line_tokens = self._count_tokens(line)
+
+ if current_tokens + line_tokens > self.max_chunk_size and current_chunk:
+ chunks.append(CodeChunk(
+ content='\n'.join(current_chunk),
+ chunk_type='text',
+ file_path=file_path
+ ))
+ current_chunk = []
+ current_tokens = 0
+
+ current_chunk.append(line)
+ current_tokens += line_tokens
+
+ if current_chunk:
+ chunks.append(CodeChunk(
+ content='\n'.join(current_chunk),
+ chunk_type='text',
+ file_path=file_path
+ ))
+
+ return chunks
+
+ def _optimize_chunks(self, chunks: List[CodeChunk]) -> List[CodeChunk]:
+ """Optimize chunks by merging small ones and splitting large ones."""
+ optimized = []
+
+ for chunk in chunks:
+ token_count = self._count_tokens(chunk.content)
+
+ if token_count > self.max_chunk_size:
+ # Split large chunk
+ split_chunks = self._split_large_chunk(chunk)
+ optimized.extend(split_chunks)
+ elif token_count < self.min_chunk_size and optimized:
+ # Merge with previous chunk if it won't exceed max size
+ prev_chunk = optimized[-1]
+ prev_tokens = self._count_tokens(prev_chunk.content)
+
+ if prev_tokens + token_count <= self.max_chunk_size:
+ # Merge chunks
+ merged_content = f"{prev_chunk.content}\n\n{chunk.content}"
+ optimized[-1] = CodeChunk(
+ content=merged_content,
+ chunk_type='merged',
+ file_path=chunk.file_path,
+ dependencies=list(set(prev_chunk.dependencies + chunk.dependencies))
+ )
+ else:
+ optimized.append(chunk)
+ else:
+ optimized.append(chunk)
+
+ return optimized
+
+ def _split_large_chunk(self, chunk: CodeChunk) -> List[CodeChunk]:
+ """Split a chunk that's too large."""
+ lines = chunk.content.split('\n')
+ split_chunks = []
+
+ current_lines = []
+ current_tokens = 0
+
+ for line in lines:
+ line_tokens = self._count_tokens(line)
+
+ if current_tokens + line_tokens > self.max_chunk_size and current_lines:
+ split_chunks.append(CodeChunk(
+ content='\n'.join(current_lines),
+ chunk_type=f"{chunk.chunk_type}_split",
+ name=f"{chunk.name}_part_{len(split_chunks) + 1}" if chunk.name else None,
+ file_path=chunk.file_path,
+ dependencies=chunk.dependencies
+ ))
+ current_lines = []
+ current_tokens = 0
+
+ current_lines.append(line)
+ current_tokens += line_tokens
+
+ if current_lines:
+ split_chunks.append(CodeChunk(
+ content='\n'.join(current_lines),
+ chunk_type=f"{chunk.chunk_type}_split",
+ name=f"{chunk.name}_part_{len(split_chunks) + 1}" if chunk.name else None,
+ file_path=chunk.file_path,
+ dependencies=chunk.dependencies
+ ))
+
+ return split_chunks
+
+ def _count_tokens(self, text: str) -> int:
+ """Count tokens in text with fallback for when tiktoken is not available."""
+ if TIKTOKEN_AVAILABLE and self.encoding:
+ return len(self.encoding.encode(text))
+ else:
+ # Simple approximation: 4 characters per token
+ return len(text) // 4
+
+ # Placeholder methods for other languages
+ def _chunk_java_kotlin(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk Java/Kotlin code."""
+ # Simplified implementation - in production, use proper parsers
+ return self._chunk_text(file_path, content)
+
+ def _chunk_cpp(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk C++ code."""
+ return self._chunk_text(file_path, content)
+
+ def _chunk_rust(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk Rust code."""
+ return self._chunk_text(file_path, content)
+
+ def _chunk_go(self, file_path: str, content: str) -> List[CodeChunk]:
+ """Chunk Go code."""
+ return self._chunk_text(file_path, content)
\ No newline at end of file
diff --git a/api/ast_integration.py b/api/ast_integration.py
new file mode 100644
index 00000000..7b9c5304
--- /dev/null
+++ b/api/ast_integration.py
@@ -0,0 +1,301 @@
+"""
+Integration layer for AST chunking with existing data pipeline.
+This bridges the AST chunker with the current adalflow TextSplitter interface.
+"""
+
+from typing import List, Dict, Any, Union
+import logging
+from pathlib import Path
+
+from adalflow.core.component import Component
+from adalflow.core.document import Document
+
+from .ast_chunker import ASTChunker, CodeChunk
+
+logger = logging.getLogger(__name__)
+
+
+class ASTTextSplitter(Component):
+ """
+ AST-aware text splitter that integrates with adalflow pipeline.
+ Provides both AST chunking for code files and fallback text chunking.
+ """
+
+ def __init__(self,
+ split_by: str = "ast",
+ chunk_size: int = 2000,
+ chunk_overlap: int = 100,
+ min_chunk_size: int = 100,
+ overlap_lines: int = 5,
+ preserve_structure: bool = True,
+ fallback_to_text: bool = True):
+ """
+ Initialize AST text splitter.
+
+ Args:
+ split_by: Splitting strategy - "ast", "word", "character"
+ chunk_size: Maximum tokens per chunk
+ chunk_overlap: Token overlap between chunks (for text mode)
+ min_chunk_size: Minimum tokens per chunk
+ overlap_lines: Lines of overlap for AST chunks
+ preserve_structure: Whether to keep code structures together
+ fallback_to_text: Fall back to text splitting for non-code files
+ """
+ super().__init__()
+ self.split_by = split_by
+ self.chunk_size = chunk_size
+ self.chunk_overlap = chunk_overlap
+ self.min_chunk_size = min_chunk_size
+ self.overlap_lines = overlap_lines
+ self.preserve_structure = preserve_structure
+ self.fallback_to_text = fallback_to_text
+
+ # Initialize AST chunker
+ self.ast_chunker = ASTChunker(
+ max_chunk_size=chunk_size,
+ min_chunk_size=min_chunk_size,
+ overlap_lines=overlap_lines,
+ preserve_structure=preserve_structure
+ )
+
+ # Initialize text chunker for fallback
+ if fallback_to_text:
+ from adalflow.core.text_splitter import TextSplitter
+ self.text_splitter = TextSplitter(
+ split_by="word" if split_by == "ast" else split_by,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap
+ )
+
+ def call(self, documents: List[Document]) -> List[Document]:
+ """
+ Split documents using AST-aware chunking.
+
+ Args:
+ documents: Input documents to split
+
+ Returns:
+ List of document chunks
+ """
+ result_docs = []
+
+ for doc in documents:
+ try:
+ if self.split_by == "ast":
+ chunks = self._ast_split_document(doc)
+ else:
+ # Use traditional text splitting
+ chunks = self.text_splitter.call([doc])
+
+ result_docs.extend(chunks)
+
+ except Exception as e:
+ logger.error(f"Error splitting document {doc.id}: {e}")
+ if self.fallback_to_text and hasattr(self, 'text_splitter'):
+ # Fallback to text splitting
+ chunks = self.text_splitter.call([doc])
+ result_docs.extend(chunks)
+ else:
+ # Keep original document if all else fails
+ result_docs.append(doc)
+
+ return result_docs
+
+ def _ast_split_document(self, document: Document) -> List[Document]:
+ """Split a single document using AST chunking."""
+ # Extract file path from document metadata
+ file_path = document.meta_data.get('file_path', '')
+ if not file_path:
+ file_path = document.meta_data.get('source', 'unknown')
+
+ # Use AST chunker
+ code_chunks = self.ast_chunker.chunk_file(file_path, document.text)
+
+ # Convert CodeChunk objects to Document objects
+ result_docs = []
+ for i, chunk in enumerate(code_chunks):
+ # Create enhanced metadata
+ enhanced_metadata = document.meta_data.copy()
+ enhanced_metadata.update({
+ 'chunk_id': i,
+ 'chunk_type': chunk.chunk_type,
+ 'chunk_name': chunk.name,
+ 'start_line': chunk.start_line,
+ 'end_line': chunk.end_line,
+ 'dependencies': chunk.dependencies,
+ 'file_path': chunk.file_path,
+ 'original_doc_id': document.id
+ })
+
+ # Create new document
+ chunk_doc = Document(
+ text=chunk.content,
+ id=f"{document.id}_chunk_{i}",
+ meta_data=enhanced_metadata
+ )
+
+ result_docs.append(chunk_doc)
+
+ return result_docs
+
+ def get_chunk_metadata(self, chunk_doc: Document) -> Dict[str, Any]:
+ """Extract chunk-specific metadata for enhanced retrieval."""
+ return {
+ 'chunk_type': chunk_doc.meta_data.get('chunk_type', 'unknown'),
+ 'chunk_name': chunk_doc.meta_data.get('chunk_name'),
+ 'start_line': chunk_doc.meta_data.get('start_line', 0),
+ 'end_line': chunk_doc.meta_data.get('end_line', 0),
+ 'dependencies': chunk_doc.meta_data.get('dependencies', []),
+ 'file_path': chunk_doc.meta_data.get('file_path', ''),
+ 'language': self._detect_language(chunk_doc.meta_data.get('file_path', ''))
+ }
+
+ def _detect_language(self, file_path: str) -> str:
+ """Detect programming language from file extension."""
+ ext = Path(file_path).suffix.lower()
+
+ language_map = {
+ '.py': 'python',
+ '.js': 'javascript',
+ '.ts': 'typescript',
+ '.jsx': 'javascript',
+ '.tsx': 'typescript',
+ '.java': 'java',
+ '.kt': 'kotlin',
+ '.cpp': 'cpp',
+ '.cc': 'cpp',
+ '.cxx': 'cpp',
+ '.c': 'c',
+ '.h': 'c',
+ '.hpp': 'cpp',
+ '.rs': 'rust',
+ '.go': 'go',
+ '.rb': 'ruby',
+ '.php': 'php',
+ '.cs': 'csharp',
+ '.swift': 'swift',
+ '.md': 'markdown',
+ '.rst': 'restructuredtext',
+ '.json': 'json',
+ '.yaml': 'yaml',
+ '.yml': 'yaml',
+ '.toml': 'toml',
+ '.xml': 'xml',
+ '.html': 'html',
+ '.css': 'css',
+ '.scss': 'scss',
+ '.sass': 'sass'
+ }
+
+ return language_map.get(ext, 'text')
+
+
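+# Illustrative sketch of the splitter in isolation (document contents are hypothetical):
+#   splitter = ASTTextSplitter(split_by="ast", chunk_size=2000, min_chunk_size=100)
+#   chunk_docs = splitter.call([Document(text=source_code, id="doc-1", meta_data={"file_path": "app.py"})])
+#   for d in chunk_docs:
+#       print(d.meta_data["chunk_type"], d.meta_data.get("chunk_name"))
+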
+class EnhancedRAGRetriever:
+ """Enhanced retriever that uses AST chunk metadata for better results."""
+
+ def __init__(self, base_retriever):
+ """Initialize with base retriever (FAISS, etc.)."""
+ self.base_retriever = base_retriever
+
+ def retrieve(self, query: str, top_k: int = 5,
+ filter_by_type: List[str] = None,
+ prefer_functions: bool = False,
+ prefer_classes: bool = False) -> List[Document]:
+ """
+ Enhanced retrieval with AST-aware filtering.
+
+ Args:
+ query: Search query
+ top_k: Number of results to return
+ filter_by_type: Filter by chunk types (e.g., ['function', 'class'])
+ prefer_functions: Boost function chunks in results
+ prefer_classes: Boost class chunks in results
+ """
+ # Get base results
+ base_results = self.base_retriever.retrieve(query, top_k * 2) # Get more for filtering
+
+ # Apply AST-aware filtering and ranking
+ filtered_results = []
+
+ for doc in base_results:
+ chunk_type = doc.meta_data.get('chunk_type', 'unknown')
+
+ # Apply type filter
+ if filter_by_type and chunk_type not in filter_by_type:
+ continue
+
+ # Calculate boost score
+ boost_score = 1.0
+ if prefer_functions and chunk_type == 'function':
+ boost_score = 1.5
+ elif prefer_classes and chunk_type == 'class':
+ boost_score = 1.5
+
+ # Add boost to similarity score if available
+ if hasattr(doc, 'similarity_score'):
+ doc.similarity_score *= boost_score
+
+ filtered_results.append(doc)
+
+ # Sort by similarity score and return top_k
+        if filtered_results and hasattr(filtered_results[0], 'similarity_score'):
+ filtered_results.sort(key=lambda x: x.similarity_score, reverse=True)
+
+ return filtered_results[:top_k]
+
+ def retrieve_related_code(self, query: str, top_k: int = 5) -> Dict[str, List[Document]]:
+ """
+ Retrieve related code organized by type.
+
+ Returns:
+ Dictionary with keys: 'functions', 'classes', 'imports', 'modules'
+ """
+ all_results = self.base_retriever.retrieve(query, top_k * 4)
+
+ organized_results = {
+ 'functions': [],
+ 'classes': [],
+ 'imports': [],
+ 'modules': [],
+ 'other': []
+ }
+
+ for doc in all_results:
+ chunk_type = doc.meta_data.get('chunk_type', 'other')
+
+ if chunk_type == 'function':
+ organized_results['functions'].append(doc)
+ elif chunk_type == 'class':
+ organized_results['classes'].append(doc)
+ elif chunk_type == 'import_block':
+ organized_results['imports'].append(doc)
+ elif chunk_type == 'module':
+ organized_results['modules'].append(doc)
+ else:
+ organized_results['other'].append(doc)
+
+ # Limit each category
+ for key in organized_results:
+ organized_results[key] = organized_results[key][:top_k//4 + 1]
+
+ return organized_results
+
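+# Illustrative retrieval calls (retriever, queries, and top_k values are hypothetical):
+#   enhanced = EnhancedRAGRetriever(base_retriever)
+#   docs = enhanced.retrieve("where is the config loaded?", top_k=5,
+#                            filter_by_type=["function", "class"], prefer_functions=True)
+#   grouped = enhanced.retrieve_related_code("config loading", top_k=8)
+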
+
+def create_ast_config() -> Dict[str, Any]:
+ """Create default AST chunking configuration."""
+ return {
+ "text_splitter": {
+ "split_by": "ast",
+ "chunk_size": 2000,
+ "chunk_overlap": 100,
+ "min_chunk_size": 100,
+ "overlap_lines": 5,
+ "preserve_structure": True,
+ "fallback_to_text": True
+ },
+ "retrieval": {
+ "prefer_functions": True,
+ "prefer_classes": True,
+ "boost_code_chunks": 1.5
+ }
+ }
\ No newline at end of file
diff --git a/api/config/embedder.ast.json b/api/config/embedder.ast.json
new file mode 100644
index 00000000..90940bf1
--- /dev/null
+++ b/api/config/embedder.ast.json
@@ -0,0 +1,57 @@
+{
+ "embedder_ollama": {
+ "client_class": "OllamaClient",
+ "model_kwargs": {
+ "model": "nomic-embed-text"
+ }
+ },
+ "retriever": {
+ "top_k": 20
+ },
+ "text_splitter": {
+ "split_by": "ast",
+ "chunk_size": 2000,
+ "chunk_overlap": 100,
+ "min_chunk_size": 100,
+ "overlap_lines": 5,
+ "preserve_structure": true,
+ "fallback_to_text": true
+ },
+ "retrieval": {
+ "prefer_functions": true,
+ "prefer_classes": true,
+ "boost_code_chunks": 1.5
+ },
+ "language_support": {
+ "python": {
+ "enabled": true,
+ "max_function_size": 1500,
+ "max_class_size": 3000,
+ "preserve_docstrings": true
+ },
+ "javascript": {
+ "enabled": true,
+ "max_function_size": 1500,
+ "parse_jsx": true
+ },
+ "typescript": {
+ "enabled": true,
+ "max_function_size": 1500,
+ "parse_tsx": true
+ },
+ "markdown": {
+ "enabled": true,
+ "split_by_headers": true,
+ "max_section_size": 2000
+ },
+ "config_files": {
+ "enabled": true,
+ "treat_as_single_chunk": true
+ }
+ },
+ "fallback": {
+ "split_by": "word",
+ "chunk_size": 350,
+ "chunk_overlap": 100
+ }
+}
\ No newline at end of file
diff --git a/api/data_pipeline.py b/api/data_pipeline.py
index fcea34ce..81ce9c6b 100644
--- a/api/data_pipeline.py
+++ b/api/data_pipeline.py
@@ -393,7 +393,21 @@ def prepare_data_pipeline(embedder_type: str = None, is_ollama_embedder: bool =
if embedder_type is None:
embedder_type = get_embedder_type()
- splitter = TextSplitter(**configs["text_splitter"])
+ # Choose splitter based on configuration
+ split_by = configs.get("text_splitter", {}).get("split_by", "word")
+ if split_by == "ast":
+ # Use AST splitter for better code understanding
+ try:
+ from .ast_integration import ASTTextSplitter
+ splitter = ASTTextSplitter(**configs["text_splitter"])
+            print("Using AST-based chunking for better code understanding")
+        except ImportError as e:
+            print(f"AST chunking not available, falling back to text: {e}")
+ splitter = TextSplitter(**configs["text_splitter"])
+ else:
+ # Use traditional text splitter
+ splitter = TextSplitter(**configs["text_splitter"])
+
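+    # Illustrative: a "text_splitter" config of {"split_by": "ast", "chunk_size": 2000, ...}
+    # (as in embedder.ast.json) selects the AST splitter; any other split_by value keeps
+    # the default TextSplitter behavior.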
embedder_config = get_embedder_config()
embedder = get_embedder(embedder_type=embedder_type)
diff --git a/api/enable_ast.py b/api/enable_ast.py
new file mode 100644
index 00000000..74328a1b
--- /dev/null
+++ b/api/enable_ast.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+"""
+Simple script to enable/disable AST chunking in DeepWiki.
+"""
+
+import json
+import shutil
+import os
+import sys
+
+def enable_ast_chunking():
+ """Enable AST-based chunking."""
+ embedder_config = "config/embedder.json"
+ ast_config = "config/embedder.ast.json"
+ backup_config = "config/embedder.json.backup"
+
+ # Check if AST config exists
+ if not os.path.exists(ast_config):
+ print(f"ā AST config not found: {ast_config}")
+ return False
+
+ # Backup current config
+ if os.path.exists(embedder_config):
+ shutil.copy2(embedder_config, backup_config)
+ print(f"ā
Backed up current config to {backup_config}")
+
+ # Load current config to preserve embedder settings
+ with open(embedder_config, 'r') as f:
+ current_config = json.load(f)
+ else:
+ current_config = {}
+
+ # Load AST config
+ with open(ast_config, 'r') as f:
+ ast_config_data = json.load(f)
+
+ # Merge embedder settings from current config into AST config
+ if 'embedder_ollama' in current_config:
+ ast_config_data['embedder_ollama'] = current_config['embedder_ollama']
+ if 'retriever' in current_config:
+ ast_config_data['retriever'] = current_config['retriever']
+
+ # Write merged config
+ with open(embedder_config, 'w') as f:
+ json.dump(ast_config_data, f, indent=2)
+
+ print(f"ā
Enabled AST chunking with preserved embedder settings")
+
+ # Verify the switch
+ with open(embedder_config, 'r') as f:
+ config = json.load(f)
+ split_by = config.get('text_splitter', {}).get('split_by', 'unknown')
+ print(f"ā
Current chunking mode: {split_by}")
+
+ return True
+
+def disable_ast_chunking():
+ """Disable AST chunking and restore previous config."""
+ embedder_config = "config/embedder.json"
+ backup_config = "config/embedder.json.backup"
+
+ if os.path.exists(backup_config):
+ shutil.copy2(backup_config, embedder_config)
+ print(f"ā
Restored previous config from {backup_config}")
+ else:
+ # Create default text config
+ default_config = {
+ "embedder_ollama": {
+ "client_class": "OllamaClient",
+ "model_kwargs": {
+ "model": "nomic-embed-text"
+ }
+ },
+ "retriever": {
+ "top_k": 20
+ },
+ "text_splitter": {
+ "split_by": "word",
+ "chunk_size": 350,
+ "chunk_overlap": 100
+ }
+ }
+
+ with open(embedder_config, 'w') as f:
+ json.dump(default_config, f, indent=2)
+
+ print(f"ā
Created default text chunking config")
+
+ # Verify the switch
+ with open(embedder_config, 'r') as f:
+ config = json.load(f)
+ split_by = config.get('text_splitter', {}).get('split_by', 'unknown')
+ print(f"ā
Current chunking mode: {split_by}")
+
+def check_status():
+ """Check current chunking status."""
+ embedder_config = "config/embedder.json"
+
+ if not os.path.exists(embedder_config):
+ print("ā No embedder config found")
+ return
+
+ with open(embedder_config, 'r') as f:
+ config = json.load(f)
+ split_by = config.get('text_splitter', {}).get('split_by', 'word')
+ chunk_size = config.get('text_splitter', {}).get('chunk_size', 0)
+
+ print(f"\nš Current Configuration:")
+ print(f" Chunking mode: {split_by}")
+ print(f" Chunk size: {chunk_size}")
+
+ if split_by == "ast":
+ print(" Status: š AST chunking ENABLED")
+ print(" Benefits: Semantic code understanding, function/class boundaries preserved")
+ else:
+ print(" Status: š Traditional text chunking")
+ print(" Note: Consider enabling AST chunking for better code understanding")
+
+def main():
+ if len(sys.argv) < 2:
+ print("Usage: python enable_ast.py [enable|disable|status]")
+ print("\nCommands:")
+ print(" enable - Enable AST-based chunking")
+ print(" disable - Disable AST chunking (restore text chunking)")
+ print(" status - Show current chunking status")
+ sys.exit(1)
+
+ command = sys.argv[1].lower()
+
+ if command == "enable":
+ enable_ast_chunking()
+ elif command == "disable":
+ disable_ast_chunking()
+ elif command == "status":
+ check_status()
+ else:
+ print(f"Unknown command: {command}")
+ sys.exit(1)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/api/enhanced_pipeline.py b/api/enhanced_pipeline.py
new file mode 100644
index 00000000..fd050caf
--- /dev/null
+++ b/api/enhanced_pipeline.py
@@ -0,0 +1,324 @@
+"""
+Enhanced data pipeline with AST chunking support.
+This patch integrates AST chunking into the existing adalflow pipeline.
+"""
+
+import os
+from typing import List, Dict, Any, Optional
+import logging
+
+from adalflow.core.document import Document
+from adalflow.core.db import LocalDB
+import adalflow as adal
+
+from .ast_integration import ASTTextSplitter, EnhancedRAGRetriever
+from .config import get_embedder_config, get_embedder_type
+from .data_pipeline import get_embedder, OllamaDocumentProcessor, ToEmbeddings
+
+logger = logging.getLogger(__name__)
+
+
+def prepare_enhanced_data_pipeline(embedder_type: str = None,
+ is_ollama_embedder: bool = None,
+ use_ast_chunking: bool = True) -> adal.Sequential:
+ """
+ Prepare enhanced data pipeline with optional AST chunking support.
+
+ Args:
+ embedder_type: The embedder type ('openai', 'google', 'ollama')
+ is_ollama_embedder: DEPRECATED. Use embedder_type instead
+ use_ast_chunking: Whether to use AST-based chunking
+
+ Returns:
+ Sequential pipeline with splitter and embedder
+ """
+ # Load configuration
+ configs = get_embedder_config()
+
+ # Handle legacy parameter
+ if is_ollama_embedder is not None:
+ embedder_type = 'ollama' if is_ollama_embedder else None
+
+ # Determine embedder type if not specified
+ if embedder_type is None:
+ embedder_type = get_embedder_type()
+
+ # Choose splitter based on configuration
+ if use_ast_chunking and configs.get("text_splitter", {}).get("split_by") == "ast":
+ # Use AST splitter
+ splitter_config = configs["text_splitter"]
+ splitter = ASTTextSplitter(**splitter_config)
+ logger.info("Using AST-based text splitting")
+ else:
+ # Use traditional text splitter
+ from adalflow.core.text_splitter import TextSplitter
+ splitter = TextSplitter(**configs["text_splitter"])
+ logger.info("Using traditional text splitting")
+
+ # Get embedder configuration and instance
+ embedder_config = get_embedder_config()
+ embedder = get_embedder(embedder_type=embedder_type)
+
+ # Choose appropriate processor based on embedder type
+ if embedder_type == 'ollama':
+ # Use Ollama document processor for single-document processing
+ embedder_transformer = OllamaDocumentProcessor(embedder=embedder)
+ else:
+ # Use batch processing for OpenAI and Google embedders
+ batch_size = embedder_config.get("batch_size", 500)
+ embedder_transformer = ToEmbeddings(
+ embedder=embedder, batch_size=batch_size
+ )
+
+ # Create sequential pipeline
+ data_transformer = adal.Sequential(
+ splitter, embedder_transformer
+ )
+
+ return data_transformer
+
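+# Illustrative wiring (embedder type and database path are hypothetical):
+#   pipeline = prepare_enhanced_data_pipeline(embedder_type="openai", use_ast_chunking=True)
+#   db = transform_documents_and_save_to_enhanced_db(docs, "~/.adalflow/databases/myrepo.pkl",
+#                                                    embedder_type="openai")
+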
+
+def transform_documents_and_save_to_enhanced_db(
+ documents: List[Document],
+ db_path: str,
+ embedder_type: str = None,
+ is_ollama_embedder: bool = None,
+ use_ast_chunking: bool = None
+) -> LocalDB:
+ """
+ Enhanced document transformation with AST chunking support.
+
+ Args:
+ documents: List of Document objects
+ db_path: Path to the local database file
+ embedder_type: The embedder type ('openai', 'google', 'ollama')
+ is_ollama_embedder: DEPRECATED. Use embedder_type instead
+ use_ast_chunking: Whether to use AST chunking (auto-detected if None)
+
+ Returns:
+ LocalDB instance with processed documents
+ """
+ # Auto-detect AST chunking preference if not specified
+ if use_ast_chunking is None:
+ configs = get_embedder_config()
+ use_ast_chunking = configs.get("text_splitter", {}).get("split_by") == "ast"
+
+ # Get the enhanced data transformer
+ data_transformer = prepare_enhanced_data_pipeline(
+ embedder_type=embedder_type,
+ is_ollama_embedder=is_ollama_embedder,
+ use_ast_chunking=use_ast_chunking
+ )
+
+ # Save the documents to a local database
+ db = LocalDB()
+ db.register_transformer(transformer=data_transformer, key="split_and_embed")
+ db.load(documents)
+ db.transform(key="split_and_embed")
+
+ # Ensure directory exists and save
+ os.makedirs(os.path.dirname(db_path), exist_ok=True)
+ db.save_state(filepath=db_path)
+
+ # Log chunking statistics
+ _log_chunking_stats(db, use_ast_chunking)
+
+ return db
+
+
+def _log_chunking_stats(db: LocalDB, used_ast: bool):
+ """Log statistics about the chunking process."""
+ try:
+ documents = db.get_transformed_data("split_and_embed")
+ if not documents:
+ return
+
+ total_chunks = len(documents)
+
+ if used_ast:
+ # Analyze AST chunk types
+ chunk_types = {}
+ languages = {}
+
+ for doc in documents:
+ chunk_type = doc.meta_data.get('chunk_type', 'unknown')
+ language = doc.meta_data.get('language', 'unknown')
+
+ chunk_types[chunk_type] = chunk_types.get(chunk_type, 0) + 1
+ languages[language] = languages.get(language, 0) + 1
+
+ logger.info(f"AST chunking created {total_chunks} chunks")
+ logger.info(f"Chunk types: {dict(chunk_types)}")
+ logger.info(f"Languages: {dict(languages)}")
+
+ # Log function and class statistics
+ function_chunks = chunk_types.get('function', 0)
+ class_chunks = chunk_types.get('class', 0)
+
+ if function_chunks > 0:
+ logger.info(f"Found {function_chunks} function chunks")
+ if class_chunks > 0:
+ logger.info(f"Found {class_chunks} class chunks")
+ else:
+ logger.info(f"Traditional chunking created {total_chunks} chunks")
+
+ except Exception as e:
+ logger.error(f"Error logging chunking stats: {e}")
+
+
+def create_enhanced_retriever(db: LocalDB,
+ embedder_type: str = None) -> EnhancedRAGRetriever:
+ """
+ Create an enhanced retriever with AST-aware capabilities.
+
+ Args:
+ db: LocalDB instance with processed documents
+ embedder_type: The embedder type used
+
+ Returns:
+ EnhancedRAGRetriever instance
+ """
+ # Get base retriever from database
+ base_retriever = db.get_retriever() # This would need to be implemented in LocalDB
+
+ # Create enhanced retriever
+ enhanced_retriever = EnhancedRAGRetriever(base_retriever)
+
+ return enhanced_retriever
+
+
+# Configuration utilities
+def switch_to_ast_chunking():
+ """Switch the system to use AST chunking."""
+ import json
+ import shutil
+
+ # Backup current config
+ embedder_config_path = "api/config/embedder.json"
+ backup_path = f"{embedder_config_path}.backup"
+
+ if os.path.exists(embedder_config_path):
+ shutil.copy2(embedder_config_path, backup_path)
+ logger.info(f"Backed up current config to {backup_path}")
+
+ # Copy AST config
+ ast_config_path = "api/config/embedder.ast.json"
+ if os.path.exists(ast_config_path):
+ shutil.copy2(ast_config_path, embedder_config_path)
+ logger.info("Switched to AST chunking configuration")
+ else:
+ logger.error(f"AST config file not found: {ast_config_path}")
+
+
+def switch_to_text_chunking():
+ """Switch back to traditional text chunking."""
+    import json
+    import shutil
+
+ embedder_config_path = "api/config/embedder.json"
+ backup_path = f"{embedder_config_path}.backup"
+
+ if os.path.exists(backup_path):
+ shutil.copy2(backup_path, embedder_config_path)
+ logger.info("Switched back to text chunking configuration")
+ else:
+ # Create default text config
+ default_config = {
+ "text_splitter": {
+ "split_by": "word",
+ "chunk_size": 350,
+ "chunk_overlap": 100
+ }
+ }
+
+ with open(embedder_config_path, 'w') as f:
+ json.dump(default_config, f, indent=2)
+
+ logger.info("Created default text chunking configuration")
+
+
+def get_chunking_mode() -> str:
+ """Get current chunking mode."""
+ try:
+ configs = get_embedder_config()
+ split_by = configs.get("text_splitter", {}).get("split_by", "word")
+ return "ast" if split_by == "ast" else "text"
+ except Exception:
+ return "text" # Default fallback
+
+
+# Example usage for testing
+def test_ast_chunking():
+ """Test function to demonstrate AST chunking capabilities."""
+
+ # Sample Python code for testing
+ sample_code = '''
+import os
+import sys
+from typing import List, Dict
+
+class DataProcessor:
+ """A class for processing data."""
+
+ def __init__(self, config: Dict):
+ self.config = config
+ self.data = []
+
+ def load_data(self, file_path: str) -> List[Dict]:
+ """Load data from file."""
+ with open(file_path, 'r') as f:
+ return json.load(f)
+
+ def process_data(self, data: List[Dict]) -> List[Dict]:
+ """Process the loaded data."""
+ processed = []
+ for item in data:
+ processed_item = self._process_item(item)
+ processed.append(processed_item)
+ return processed
+
+ def _process_item(self, item: Dict) -> Dict:
+ """Process a single item."""
+ return {**item, 'processed': True}
+
+def main():
+ """Main function."""
+ processor = DataProcessor({'debug': True})
+ data = processor.load_data('data.json')
+ result = processor.process_data(data)
+ print(f"Processed {len(result)} items")
+
+if __name__ == "__main__":
+ main()
+'''
+
+ # Create test document
+ test_doc = Document(
+ text=sample_code,
+ id="test_python_file",
+ meta_data={'file_path': 'test.py'}
+ )
+
+ # Test AST chunking
+ from .ast_integration import ASTTextSplitter
+
+ ast_splitter = ASTTextSplitter(
+ split_by="ast",
+ chunk_size=2000,
+ min_chunk_size=100
+ )
+
+ chunks = ast_splitter.call([test_doc])
+
+ print(f"\nAST Chunking Results:")
+ print(f"Created {len(chunks)} chunks from test file")
+
+ for i, chunk in enumerate(chunks):
+ chunk_type = chunk.meta_data.get('chunk_type', 'unknown')
+ chunk_name = chunk.meta_data.get('chunk_name', 'unnamed')
+ print(f"\nChunk {i+1}: {chunk_type} - {chunk_name}")
+ print(f"Lines {chunk.meta_data.get('start_line', 0)}-{chunk.meta_data.get('end_line', 0)}")
+ print(f"Content preview: {chunk.text[:100]}...")
+
+
+if __name__ == "__main__":
+ test_ast_chunking()
\ No newline at end of file
diff --git a/api/prompts.py b/api/prompts.py
index 61ef0a4d..fcc41c9e 100644
--- a/api/prompts.py
+++ b/api/prompts.py
@@ -1,5 +1,56 @@
"""Module containing all prompts used in the DeepWiki project."""
+# System prompt for XML Wiki Structure Generation
+WIKI_STRUCTURE_SYSTEM_PROMPT = r"""
+You are an expert code analyst tasked with analyzing a repository and creating a structured wiki outline.
+
+CRITICAL XML FORMATTING INSTRUCTIONS:
+- You MUST return ONLY valid XML with NO additional text before or after
+- DO NOT wrap the XML in markdown code blocks (no ``` or ```xml)
+- DO NOT include any explanation or commentary
+- Start directly with <wiki_structure> and end with </wiki_structure>
+- Ensure all XML tags are properly closed
+- Use proper XML escaping for special characters (& < > " ')
+
+XML STRUCTURE REQUIREMENTS:
+- The root element must be <wiki_structure>
+- Include a <title> element for the wiki title
+- Include a <description> element for the repository description
+- For comprehensive mode: Include a <sections> element containing section hierarchies
+- Include a <pages> element containing all wiki pages
+- Each page must have: id, title, description, importance, relevant_files, related_pages
+
+Example XML structure (comprehensive mode):
+<wiki_structure>
+  <title>Repository Wiki</title>
+  <description>A comprehensive guide</description>
+  <sections>
+    <section id="section-1">
+      <title>Overview</title>
+      <pages>
+        <page_ref>page-1</page_ref>
+      </pages>
+    </section>
+  </sections>
+  <pages>
+    <page id="page-1">
+      <title>Introduction</title>
+      <description>Overview of the project</description>
+      <importance>high</importance>
+      <relevant_files>
+        <file_path>README.md</file_path>
+      </relevant_files>
+      <related_pages>
+        <related>page-2</related>
+      </related_pages>
+      <parent_section>section-1</parent_section>
+    </page>
+  </pages>
+</wiki_structure>
+
+IMPORTANT: Your entire response must be valid XML. Do not include any text outside the tags.
+"""
+
# System prompt for RAG
RAG_SYSTEM_PROMPT = r"""
You are a code assistant which answers user questions on a Github Repo.
diff --git a/api/websocket_wiki.py b/api/websocket_wiki.py
index 2a7cce9e..906ce0c5 100644
--- a/api/websocket_wiki.py
+++ b/api/websocket_wiki.py
@@ -9,12 +9,14 @@
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from pydantic import BaseModel, Field
+from api.api import get_local_repo_structure
from api.config import get_model_config, configs, OPENROUTER_API_KEY, OPENAI_API_KEY
from api.data_pipeline import count_tokens, get_file_content
from api.openai_client import OpenAIClient
from api.openrouter_client import OpenRouterClient
from api.azureai_client import AzureAIClient
from api.dashscope_client import DashscopeClient
+from api.prompts import WIKI_STRUCTURE_SYSTEM_PROMPT
from api.rag import RAG
# Configure logging
@@ -49,6 +51,135 @@ class ChatCompletionRequest(BaseModel):
included_dirs: Optional[str] = Field(None, description="Comma-separated list of directories to include exclusively")
included_files: Optional[str] = Field(None, description="Comma-separated list of file patterns to include exclusively")
+
+async def handle_response_stream(
+ response,
+ websocket: WebSocket,
+ is_structure_generation: bool,
+ provider: str
+) -> None:
+ """
+ Handle streaming or accumulated response based on context.
+
+ This helper function eliminates code duplication between different providers
+ by handling both streaming (for chat) and accumulating (for wiki structure) modes.
+
+ Args:
+ response: Async iterator from the model's response
+ websocket: Active WebSocket connection
+ is_structure_generation: If True, accumulate full response before sending
+ provider: Provider name for provider-specific text extraction
+ """
+ if is_structure_generation:
+ # Accumulate full response for wiki structure generation
+ logger.info("Accumulating full response for wiki structure generation")
+ full_response = ""
+ chunk_count = 0
+
+ async for chunk in response:
+ chunk_count += 1
+
+ # Extract text based on provider and response format
+ text = None
+ if provider == "ollama":
+ # Ollama-specific extraction
+ if hasattr(chunk, 'message') and hasattr(chunk.message, 'content'):
+ text = chunk.message.content
+ elif hasattr(chunk, 'response'):
+ text = chunk.response
+ elif hasattr(chunk, 'text'):
+ text = chunk.text
+ else:
+ text = str(chunk)
+
+ logger.info(f"Chunk {chunk_count}: type={type(chunk)}, text_len={len(text) if text else 0}, starts_with={text[:50] if text else 'None'}")
+
+ # Filter out metadata chunks
+ if text and not text.startswith('model=') and not text.startswith('created_at='):
+ text = text.replace('', '').replace('', '')
+                        text = text.replace('<think>', '').replace('</think>', '')
+ logger.info(f"Added to full_response, new length: {len(full_response)}")
+ elif provider in ["openai", "azure"]:
+ # OpenAI/Azure-style responses with choices and delta
+ choices = getattr(chunk, "choices", [])
+ if len(choices) > 0:
+ delta = getattr(choices[0], "delta", None)
+ if delta is not None:
+ text = getattr(delta, "content", None)
+ if text:
+ full_response += text
+ else:
+ # OpenRouter and other providers - simpler text extraction
+ if isinstance(chunk, str):
+ text = chunk
+ elif hasattr(chunk, 'content'):
+ text = chunk.content
+ elif hasattr(chunk, 'text'):
+ text = chunk.text
+ else:
+ text = str(chunk)
+
+ if text:
+ full_response += text
+
+ # Strip markdown code blocks if present
+ cleaned_response = full_response.strip()
+ if cleaned_response.startswith('```xml'):
+ cleaned_response = cleaned_response[6:] # Remove ```xml
+ elif cleaned_response.startswith('```'):
+ cleaned_response = cleaned_response[3:] # Remove ```
+ if cleaned_response.endswith('```'):
+ cleaned_response = cleaned_response[:-3] # Remove trailing ```
+ cleaned_response = cleaned_response.strip()
+
+ # Send the complete response at once
+ logger.info(f"Total chunks processed: {chunk_count}, Sending complete XML structure ({len(cleaned_response)} chars)")
+ logger.info(f"First 500 chars of response: {cleaned_response[:500]}")
+ await websocket.send_text(cleaned_response)
+ else:
+ # Stream response chunks as they arrive for regular chat
+ async for chunk in response:
+ # Extract text based on provider
+ text = None
+ if provider == "ollama":
+ # Ollama-specific extraction
+ if hasattr(chunk, 'message') and hasattr(chunk.message, 'content'):
+ text = chunk.message.content
+ elif hasattr(chunk, 'response'):
+ text = chunk.response
+ elif hasattr(chunk, 'text'):
+ text = chunk.text
+ else:
+ text = str(chunk)
+
+ # Filter out metadata chunks and remove thinking tags
+ if text and not text.startswith('model=') and not text.startswith('created_at='):
+                    text = text.replace('<think>', '').replace('</think>', '')
+ await websocket.send_text(text)
+ elif provider in ["openai", "azure"]:
+ # OpenAI/Azure-style responses with choices and delta
+ choices = getattr(chunk, "choices", [])
+ if len(choices) > 0:
+ delta = getattr(choices[0], "delta", None)
+ if delta is not None:
+ text = getattr(delta, "content", None)
+ if text:
+ await websocket.send_text(text)
+ else:
+ # OpenRouter and other providers
+ if isinstance(chunk, str):
+ text = chunk
+ elif hasattr(chunk, 'content'):
+ text = chunk.content
+ elif hasattr(chunk, 'text'):
+ text = chunk.text
+ else:
+ text = str(chunk)
+
+ if text:
+ await websocket.send_text(text)
+
+
async def handle_websocket_chat(websocket: WebSocket):
"""
Handle WebSocket connection for chat completions.
@@ -69,8 +200,9 @@ async def handle_websocket_chat(websocket: WebSocket):
tokens = count_tokens(last_message.content, request.provider == "ollama")
logger.info(f"Request size: {tokens} tokens")
if tokens > 8000:
- logger.warning(f"Request exceeds recommended token limit ({tokens} > 7500)")
+ logger.warning(f"Request exceeds recommended token limit ({tokens} > 8000)")
input_too_large = True
+ logger.info("Input is large - RAG retrieval will be used to reduce context size")
# Create a new RAG instance for this request
try:
@@ -178,11 +310,19 @@ async def handle_websocket_chat(websocket: WebSocket):
# Get the query from the last message
query = last_message.content
- # Only retrieve documents if input is not too large
+ # Use RAG retrieval to get relevant context
+ # RAG is ALWAYS used when we have embedded documents available
+ # For large inputs (>8K tokens), RAG is ESSENTIAL to reduce context to manageable size
+ # For small inputs, RAG still helps focus on most relevant content
context_text = ""
retrieved_documents = None
- if not input_too_large:
+ # Always attempt RAG retrieval when retriever is prepared
+ logger.info(f"Checking RAG availability: request_rag={request_rag}, type={type(request_rag)}")
+ use_rag = request_rag is not None
+ logger.info(f"use_rag={use_rag}")
+
+ if use_rag:
try:
# If filePath exists, modify the query for RAG to focus on the file
rag_query = query
@@ -194,8 +334,15 @@ async def handle_websocket_chat(websocket: WebSocket):
# Try to perform RAG retrieval
try:
# This will use the actual RAG implementation
+ logger.info(f"Calling RAG with query: {rag_query[:100]}...")
retrieved_documents = request_rag(rag_query, language=request.language)
-
+ logger.info(f"RAG returned: {type(retrieved_documents)}, length: {len(retrieved_documents) if retrieved_documents else 0}")
+
+ if retrieved_documents and len(retrieved_documents) > 0:
+ logger.info(f"First result type: {type(retrieved_documents[0])}")
+ if hasattr(retrieved_documents[0], 'documents'):
+ logger.info(f"Documents found: {len(retrieved_documents[0].documents)}")
+
if retrieved_documents and retrieved_documents[0].documents:
# Format context for the prompt in a more structured way
documents = retrieved_documents[0].documents
@@ -243,8 +390,19 @@ async def handle_websocket_chat(websocket: WebSocket):
supported_langs = configs["lang_config"]["supported_languages"]
language_name = supported_langs.get(language_code, "English")
+ # Detect if this is a wiki structure generation request
+ is_structure_generation = "create a wiki structure" in query.lower() or \
+ ("analyze this github repository" in query.lower() and
+ "wiki structure" in query.lower())
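+    # Example of a triggering query (illustrative): "Analyze this GitHub repository and
+    # create a wiki structure for it" matches the checks above and takes the XML path;
+    # ordinary chat questions fall through to the regular streaming path.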
+
+ if is_structure_generation:
+ logger.info("Detected wiki structure generation request - will use XML format")
+
# Create system prompt
- if is_deep_research:
+ if is_structure_generation:
+ # Use the XML structure prompt imported at the top
+ system_prompt = WIKI_STRUCTURE_SYSTEM_PROMPT
+ elif is_deep_research:
# Check if this is the first iteration
is_first_iteration = research_iteration == 1
@@ -426,10 +584,14 @@ async def handle_websocket_chat(websocket: WebSocket):
     prompt += f"<query>\n{query}\n</query>\n\nAssistant: "
+ logger.info(f"About to get model config for provider={request.provider}, model={request.model}")
model_config = get_model_config(request.provider, request.model)["model_kwargs"]
+ logger.info(f"Got model_config: {model_config}")
if request.provider == "ollama":
prompt += " /no_think"
+
+ logger.debug("Entering Ollama provider block")
model = OllamaClient()
model_kwargs = {
@@ -441,12 +603,18 @@ async def handle_websocket_chat(websocket: WebSocket):
"num_ctx": model_config["num_ctx"]
}
}
+ logger.debug(f"Created model_kwargs for Ollama: {model_kwargs}")
api_kwargs = model.convert_inputs_to_api_kwargs(
input=prompt,
model_kwargs=model_kwargs,
model_type=ModelType.LLM
)
+ logger.debug(f"api_kwargs before model fix: {api_kwargs}")
+
+ # WORKAROUND: Force the model name in api_kwargs as convert_inputs_to_api_kwargs seems to override it
+ api_kwargs["model"] = model_config["model"]
+ logger.debug(f"api_kwargs after forcing model to {model_config['model']}: {api_kwargs}")
elif request.provider == "openrouter":
logger.info(f"Using OpenRouter with model: {request.model}")
@@ -544,12 +712,15 @@ async def handle_websocket_chat(websocket: WebSocket):
if request.provider == "ollama":
# Get the response and handle it properly using the previously created api_kwargs
response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Ollama
- async for chunk in response:
- text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
- if text and not text.startswith('model=') and not text.startswith('created_at='):
-                        text = text.replace('<think>', '').replace('</think>', '')
- await websocket.send_text(text)
+
+ # Use shared helper to handle streaming or accumulation
+ await handle_response_stream(
+ response=response,
+ websocket=websocket,
+ is_structure_generation=is_structure_generation,
+ provider="ollama"
+ )
+
# Explicitly close the WebSocket connection after the response is complete
await websocket.close()
elif request.provider == "openrouter":
@@ -557,9 +728,15 @@ async def handle_websocket_chat(websocket: WebSocket):
# Get the response and handle it properly using the previously created api_kwargs
logger.info("Making OpenRouter API call")
response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from OpenRouter
- async for chunk in response:
- await websocket.send_text(chunk)
+
+ # Use shared helper to handle streaming or accumulation
+ await handle_response_stream(
+ response=response,
+ websocket=websocket,
+ is_structure_generation=is_structure_generation,
+ provider="openrouter"
+ )
+
# Explicitly close the WebSocket connection after the response is complete
await websocket.close()
except Exception as e_openrouter:
@@ -573,15 +750,15 @@ async def handle_websocket_chat(websocket: WebSocket):
# Get the response and handle it properly using the previously created api_kwargs
logger.info("Making Openai API call")
response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Openai
- async for chunk in response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- await websocket.send_text(text)
+
+ # Use shared helper to handle streaming or accumulation
+ await handle_response_stream(
+ response=response,
+ websocket=websocket,
+ is_structure_generation=is_structure_generation,
+ provider="openai"
+ )
+
# Explicitly close the WebSocket connection after the response is complete
await websocket.close()
except Exception as e_openai:
@@ -595,15 +772,15 @@ async def handle_websocket_chat(websocket: WebSocket):
# Get the response and handle it properly using the previously created api_kwargs
logger.info("Making Azure AI API call")
response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Azure AI
- async for chunk in response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- await websocket.send_text(text)
+
+ # Use shared helper to handle streaming or accumulation
+ await handle_response_stream(
+ response=response,
+ websocket=websocket,
+ is_structure_generation=is_structure_generation,
+ provider="azure"
+ )
+
# Explicitly close the WebSocket connection after the response is complete
await websocket.close()
except Exception as e_azure:
@@ -767,3 +944,399 @@ async def handle_websocket_chat(websocket: WebSocket):
await websocket.close()
except:
pass
+
+
+# ============================================================================
+# CHUNKED WIKI GENERATION FOR LARGE REPOSITORIES
+# ============================================================================
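+# Rough flow (sketch): handle_chunked_wiki_generation() fetches the chunked repository
+# structure, calls process_wiki_chunk() once per chunk, then merge_wiki_structures()
+# combines the partial XML results before the final structure is sent over the WebSocket.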
+
+async def process_wiki_chunk(
+ chunk_data: Dict[str, Any],
+ chunk_id: int,
+ total_chunks: int,
+ request: ChatCompletionRequest,
+ readme_content: str
+) -> str:
+ """
+ Process a single chunk of files to generate partial wiki structure.
+
+ Args:
+        chunk_data: Dictionary with chunk info (file_count, directories, file_tree)
+ chunk_id: Index of this chunk
+ total_chunks: Total number of chunks
+ request: Original chat completion request
+ readme_content: README content for context
+
+ Returns:
+ XML string with partial wiki structure for this chunk
+ """
+ logger.info(f"Processing chunk {chunk_id + 1}/{total_chunks} with {chunk_data['file_count']} files")
+
+ # Create focused query for this chunk
+ chunk_dirs = chunk_data.get('directories', [])
+ chunk_query = f"""Analyze chunk {chunk_id + 1} of {total_chunks} for this repository.
+
+This chunk contains {chunk_data['file_count']} files from these directories: {', '.join(chunk_dirs[:10])}
+
+Generate a partial wiki structure for ONLY the files in this chunk. Focus on:
+1. Identifying the purpose of these files/directories
+2. Their role in the overall system
+3. How they relate to each other
+
+Return the result in the same XML format, but only include pages relevant to this chunk.
+
+<file_tree>
+{chunk_data.get('file_tree', '')}
+</file_tree>
+
+<readme>
+{readme_content}
+</readme>"""
+
+ try:
+ # Use RAG to get relevant context for this chunk's files
+ retrieved_documents = None
+ if request_rag:
+ try:
+ # Create a focused query for RAG retrieval based on chunk directories
+ rag_query = f"Information about: {', '.join(chunk_dirs[:5])}"
+ logger.info(f"RAG query for chunk {chunk_id + 1}: {rag_query}")
+
+ retrieved_documents = request_rag(rag_query, language="en")
+
+ if retrieved_documents and retrieved_documents[0].documents:
+ documents = retrieved_documents[0].documents
+ logger.info(f"Retrieved {len(documents)} documents for chunk {chunk_id + 1}")
+
+ # Group documents by file path
+ docs_by_file = {}
+ for doc in documents:
+ file_path = doc.meta_data.get('file_path', 'unknown')
+ if file_path not in docs_by_file:
+ docs_by_file[file_path] = []
+ docs_by_file[file_path].append(doc)
+
+ # Add context to query
+ context_parts = []
+ for file_path, docs in docs_by_file.items():
+ header = f"## File Path: {file_path}\n\n"
+ content = "\n\n".join([doc.text for doc in docs])
+ context_parts.append(f"{header}{content}")
+
+ context_text = "\n\n" + "-" * 10 + "\n\n".join(context_parts)
+ chunk_query = f"\n{context_text}\n\n\n{chunk_query}"
+
+ except Exception as e:
+ logger.warning(f"RAG retrieval failed for chunk {chunk_id + 1}: {str(e)}")
+
+ # Get model configuration
+ model_config = get_model_config(request.provider, request.model)["model_kwargs"]
+
+ # Initialize the appropriate model client based on provider
+ if request.provider == "ollama":
+ from api.ollama_patch import OllamaClient
+ model = OllamaClient()
+ model_kwargs = {
+ "model": model_config["model"],
+ "stream": False, # Non-streaming for chunk processing
+ "options": {
+ "temperature": model_config["temperature"],
+ "top_p": model_config["top_p"],
+ "num_ctx": model_config["num_ctx"]
+ }
+ }
+ api_kwargs = model.convert_inputs_to_api_kwargs(
+ input=chunk_query,
+ model_kwargs=model_kwargs,
+ model_type=ModelType.LLM
+ )
+ api_kwargs["model"] = model_config["model"]
+
+ elif request.provider == "openrouter":
+ from api.openrouter_client import OpenRouterClient
+ model = OpenRouterClient()
+ model_kwargs = {
+ "model": request.model,
+ "stream": False,
+ "temperature": model_config["temperature"]
+ }
+ if "top_p" in model_config:
+ model_kwargs["top_p"] = model_config["top_p"]
+ api_kwargs = model.convert_inputs_to_api_kwargs(
+ input=chunk_query,
+ model_kwargs=model_kwargs,
+ model_type=ModelType.LLM
+ )
+
+ elif request.provider == "openai":
+ from api.openai_client import OpenAIClient
+ model = OpenAIClient()
+ model_kwargs = {
+ "model": request.model,
+ "stream": False,
+ "temperature": model_config["temperature"]
+ }
+ if "top_p" in model_config:
+ model_kwargs["top_p"] = model_config["top_p"]
+ api_kwargs = model.convert_inputs_to_api_kwargs(
+ input=chunk_query,
+ model_kwargs=model_kwargs,
+ model_type=ModelType.LLM
+ )
+ else:
+ # Fallback: use OpenAI-compatible endpoint
+ from api.openai_client import OpenAIClient
+ model = OpenAIClient()
+ model_kwargs = {
+ "model": request.model,
+ "stream": False,
+ "temperature": model_config.get("temperature", 0.7)
+ }
+ api_kwargs = model.convert_inputs_to_api_kwargs(
+ input=chunk_query,
+ model_kwargs=model_kwargs,
+ model_type=ModelType.LLM
+ )
+
+        # Call the model for this chunk and await the result directly
+        result = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+
+ # Collect the response
+ full_response = ""
+ async for chunk in result:
+ # Extract text based on response format
+ if hasattr(chunk, 'message') and hasattr(chunk.message, 'content'):
+ text = chunk.message.content
+ elif hasattr(chunk, 'response'):
+ text = chunk.response
+ elif hasattr(chunk, 'text'):
+ text = chunk.text
+ else:
+ text = str(chunk)
+
+ if text:
+ full_response += text
+
+ # Strip markdown code blocks if present
+ cleaned_response = full_response.strip()
+ if cleaned_response.startswith('```xml'):
+ cleaned_response = cleaned_response[6:]
+ elif cleaned_response.startswith('```'):
+ cleaned_response = cleaned_response[3:]
+ if cleaned_response.endswith('```'):
+ cleaned_response = cleaned_response[:-3]
+
+ logger.info(f"Chunk {chunk_id + 1} processed successfully, response length: {len(cleaned_response)}")
+ return cleaned_response.strip()
+
+ except Exception as e:
+ logger.error(f"Error processing chunk {chunk_id + 1}: {str(e)}")
+ # Return a minimal valid XML structure on error
+        return f"""<wiki_structure>
+    <error>Error processing chunk {chunk_id + 1}: {str(e)}</error>
+</wiki_structure>"""
+
+
+def merge_wiki_structures(partial_wikis: List[str]) -> str:
+ """
+ Merge multiple partial wiki structures into a single cohesive wiki.
+
+ Args:
+ partial_wikis: List of XML strings, each containing partial wiki structure
+
+ Returns:
+ Combined XML wiki structure
+ """
+ logger.info(f"Merging {len(partial_wikis)} partial wiki structures")
+
+ try:
+ from xml.etree import ElementTree as ET
+
+ # Parse all partial wikis
+ all_pages = []
+ all_sections = []
+ titles = []
+ descriptions = []
+
+ for i, partial_xml in enumerate(partial_wikis):
+ try:
+ # Parse the XML
+ root = ET.fromstring(partial_xml)
+
+ # Extract title and description
+ title_elem = root.find('title')
+ if title_elem is not None and title_elem.text:
+ titles.append(title_elem.text)
+
+ desc_elem = root.find('description')
+ if desc_elem is not None and desc_elem.text:
+ descriptions.append(desc_elem.text)
+
+ # Extract pages
+ pages_elem = root.find('pages')
+ if pages_elem is not None:
+ for page in pages_elem.findall('page'):
+ # Add chunk information to page
+ page.set('source_chunk', str(i + 1))
+ all_pages.append(page)
+
+ # Extract sections if present
+ sections_elem = root.find('sections')
+ if sections_elem is not None:
+ for section in sections_elem.findall('section'):
+ section.set('source_chunk', str(i + 1))
+ all_sections.append(section)
+
+ except ET.ParseError as e:
+ logger.error(f"Failed to parse partial wiki {i + 1}: {str(e)}")
+ continue
+
+ # Deduplicate pages by ID
+ unique_pages = {}
+ for page in all_pages:
+ page_id = page.get('id', f'page-{len(unique_pages) + 1}')
+ if page_id not in unique_pages:
+ unique_pages[page_id] = page
+ else:
+ # Merge information if duplicate found
+ logger.debug(f"Duplicate page ID found: {page_id}, keeping first occurrence")
+
+ # Build merged structure
+ merged_root = ET.Element('wiki_structure')
+
+ # Use first non-empty title or generate one
+ title_elem = ET.SubElement(merged_root, 'title')
+ title_elem.text = titles[0] if titles else "Repository Documentation"
+
+ # Combine descriptions
+ desc_elem = ET.SubElement(merged_root, 'description')
+ if descriptions:
+ desc_elem.text = descriptions[0] # Use first description as primary
+ else:
+ desc_elem.text = "Comprehensive documentation generated from repository analysis"
+
+ # Add sections if any were found
+ if all_sections:
+ sections_container = ET.SubElement(merged_root, 'sections')
+ for section in all_sections:
+ sections_container.append(section)
+
+ # Add all unique pages
+ pages_container = ET.SubElement(merged_root, 'pages')
+ for page_id, page in unique_pages.items():
+ pages_container.append(page)
+
+ # Convert back to string with proper formatting
+ xml_string = ET.tostring(merged_root, encoding='unicode', method='xml')
+
+ logger.info(f"Successfully merged {len(partial_wikis)} wikis into {len(unique_pages)} unique pages")
+ return xml_string
+
+ except Exception as e:
+ logger.error(f"Error merging wiki structures: {str(e)}")
+
+ # Fallback: Simple concatenation with wrapper
+ merged = "\n"
+ merged += " Repository Documentation\n"
+ merged += " Combined documentation from multiple repository sections\n"
+ merged += " \n"
+
+ # Try to extract individual pages from each partial
+ page_counter = 1
+ for i, partial in enumerate(partial_wikis):
+ try:
+ # Simple extraction of page elements
+ if ']*>.*?', partial, re.DOTALL)
+ for page in pages:
+ # Ensure page has an ID
+ if 'id=' not in page:
+ page = page.replace('\n'
+ merged += f' Chunk {i+1} Processing Note\n'
+ merged += f' Could not merge chunk {i+1} properly\n'
+ merged += f' low\n'
+ merged += f' \n'
+
+ merged += " \n"
+ merged += ""
+
+ return merged
+
+
+async def handle_chunked_wiki_generation(
+ websocket: WebSocket,
+ repo_path: str,
+ request: ChatCompletionRequest
+) -> None:
+ """
+ Handle wiki generation for large repositories using chunked processing.
+
+ This function:
+ 1. Fetches repository structure with chunking enabled
+ 2. Processes each chunk separately with RAG
+ 3. Merges partial results into final wiki structure
+ 4. Sends progress updates via WebSocket
+
+ Args:
+ websocket: Active WebSocket connection
+ repo_path: Path to the repository
+ request: Original chat completion request
+ """
+ try:
+ logger.info(f"Starting chunked wiki generation for {repo_path}")
+        await websocket.send_text("Analyzing large repository structure...\n")
+
+ # Call get_local_repo_structure directly instead of making HTTP request
+ repo_info = await get_local_repo_structure(
+ path=repo_path,
+ return_chunks=True,
+ chunk_size=500
+ )
+
+ if not repo_info.get('chunked'):
+ # Small repo, process normally
+ await websocket.send_text("Repository is small enough to process in one go.\n")
+ return
+
+ chunks = repo_info.get('chunks', [])
+ readme = repo_info.get('readme', '')
+ total_files = repo_info.get('total_files', 0)
+
+        await websocket.send_text(f"Repository has {total_files} files split into {len(chunks)} chunks\n")
+        await websocket.send_text("Processing each chunk with RAG...\n\n")
+
+ partial_wikis = []
+ for i, chunk in enumerate(chunks):
+            await websocket.send_text(f"Chunk {i + 1}/{len(chunks)}: {chunk['file_count']} files from {len(chunk['directories'])} directories\n")
+
+ # Process this chunk
+ partial_wiki = await process_wiki_chunk(chunk, i, len(chunks), request, readme)
+ partial_wikis.append(partial_wiki)
+
+            await websocket.send_text(f"Completed chunk {i + 1}/{len(chunks)}\n\n")
+
+        await websocket.send_text("Merging all chunks into final wiki structure...\n")
+
+ # Merge all partial wikis
+ final_wiki = merge_wiki_structures(partial_wikis)
+
+        await websocket.send_text("\nFinal wiki structure:\n\n")
+ await websocket.send_text(final_wiki)
+
+ logger.info("Chunked wiki generation completed successfully")
+
+ except Exception as e:
+ logger.error(f"Error in chunked wiki generation: {str(e)}")
+        await websocket.send_text(f"\nError: {str(e)}\n")
+ finally:
+ await websocket.close()
diff --git a/docker-compose.yml b/docker-compose.yml
index 15f9cdb6..f8c9a07b 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -3,6 +3,8 @@ services:
build:
context: .
dockerfile: Dockerfile
+ args:
+ AST_CHUNKING: ${AST_CHUNKING:-false} # Set to 'true' to enable AST chunking during build
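+        # e.g. rebuild with AST chunking enabled: AST_CHUNKING=true docker compose up --build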
ports:
- "${PORT:-8001}:${PORT:-8001}" # API port
- "3000:3000" # Next.js port
@@ -17,6 +19,7 @@ services:
volumes:
- ~/.adalflow:/root/.adalflow # Persist repository and embedding data
- ./api/logs:/app/api/logs # Persist log files across container restarts
+ - ./api/config:/app/api/config:ro # Mount config files (read-only)
# Resource limits for docker-compose up (not Swarm mode)
mem_limit: 6g
mem_reservation: 2g
@@ -27,3 +30,5 @@ services:
timeout: 10s
retries: 3
start_period: 30s
+ extra_hosts:
+ - "host.docker.internal:host-gateway" # Allow access to host machine from Docker