diff --git a/README.md b/README.md index da49910..f5d53c9 100644 --- a/README.md +++ b/README.md @@ -68,9 +68,21 @@ Configuration is managed through two files: ```dotenv # .env - # Required: Your OpenAI API Key + # Required: Your OpenAI API Key (used for both OpenAI and custom providers) OPENAI_API_KEY="sk-..." + # Optional: Embedding provider (defaults to "openai") + PROVIDER="openai" # or "custom" + + # Optional: Custom embedding model (defaults based on provider) + EMBEDDING_MODEL="text-embedding-3-large" # or your preferred model + + # Required if using custom provider: Custom endpoint URL + CUSTOM_ENDPOINT="http://localhost:8000/v1/embeddings" + + # Optional: Vector size of custom embedding model + EMBEDDING_VECTOR_SIZE=1024 + # Required for GitHub sources GITHUB_PERSONAL_ACCESS_TOKEN="ghp_..." @@ -84,19 +96,25 @@ Configuration is managed through two files: 2. **`config.yaml` file:** This file defines the sources to process and how to handle them. Create a `config.yaml` file (or use a different name and pass it as an argument). + **Embedding Provider Configuration:** + + Embedding providers are now configured via environment variables: + - `OPENAI_API_KEY`: API key used for both providers + - `PROVIDER`: Set to "openai" (default) or "custom" + - `EMBEDDING_MODEL`: Model to use (default: "text-embedding-3-large") + - `EMBEDDING_VECTOR_SIZE`: Vector size of the custom embedding model (default: 3072) + - `CUSTOM_ENDPOINT`: Required when using custom provider (e.g., "http://localhost:8000/v1/embeddings") + **Structure:** * `sources`: An array of source configurations. * `type`: Either `'website'`, `'github'`, `'local_directory'`, or `'zendesk'` - For websites (`type: 'website'`): * `url`: The starting URL for crawling the documentation site. * `sitemap_url`: (Optional) URL to the site's XML sitemap for discovering additional pages not linked in navigation. - For GitHub repositories (`type: 'github'`): * `repo`: Repository name in the format `'owner/repo'` (e.g., `'istio/istio'`). * `start_date`: (Optional) Starting date to fetch issues from (e.g., `'2025-01-01'`). - For local directories (`type: 'local_directory'`): * `path`: Path to the local directory to process. * `include_extensions`: (Optional) Array of file extensions to include (e.g., `['.md', '.txt', '.pdf']`). Defaults to `['.md', '.txt', '.html', '.htm', '.pdf']`. @@ -104,7 +122,6 @@ Configuration is managed through two files: * `recursive`: (Optional) Whether to traverse subdirectories (defaults to `true`). * `url_rewrite_prefix` (Optional) URL prefix to rewrite `file://` URLs (e.g., `https://mydomain.com`) * `encoding`: (Optional) File encoding to use (defaults to `'utf8'`). Note: PDF files are processed as binary and this setting doesn't apply to them. - For Zendesk (`type: 'zendesk'`): * `zendesk_subdomain`: Your Zendesk subdomain (e.g., `'mycompany'` for mycompany.zendesk.com). * `email`: Your Zendesk admin email address. 
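Taken together, the new variables make the custom path configurable entirely from the environment. As a sketch, a `.env` for a local OpenAI-compatible proxy (e.g., LiteLLM) might look like the following; the endpoint, model, and vector size here are illustrative placeholders, not values mandated by the code:

```dotenv
# Hypothetical .env for the custom provider (example values only)
OPENAI_API_KEY="sk-..."                               # also sent as the Bearer token to CUSTOM_ENDPOINT
PROVIDER="custom"
EMBEDDING_MODEL="text-embedding-ada-002"              # whichever model the proxy serves
CUSTOM_ENDPOINT="http://localhost:8000/v1/embeddings" # full embeddings URL, not just a base host
EMBEDDING_VECTOR_SIZE=1536                            # must match the model's output dimension (default is 3072)
```

Note that the SQLite `vec_items` table and Qdrant collections are created with the `EMBEDDING_VECTOR_SIZE` dimension, so it has to match what the chosen embedding model actually emits.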
@@ -131,6 +148,21 @@ Configuration is managed through two files: **Example (`config.yaml`):** ```yaml + # Example with OpenAI embedding provider (default) + embedding_config: + provider: "openai" + openai: + api_key_env: "OPENAI_API_KEY" + + # Example with custom embedding provider (LiteLLM) + # embedding_config: + # provider: "custom" + # custom: + # endpoint: "http://localhost:8000/v1/embeddings" + # model: "text-embedding-ada-002" + # api_key_env: "LITELLM_API_KEY" + # timeout: 30000 + sources: # Website source example - type: 'website' @@ -155,7 +187,6 @@ Configuration is managed through two files: type: 'sqlite' params: db_path: './istio-issues.db' - # Local directory source example - type: 'local_directory' product_name: 'project-docs' @@ -168,7 +199,6 @@ Configuration is managed through two files: type: 'sqlite' params: db_path: './project-docs.db' - # Zendesk example - type: 'zendesk' product_name: 'MyCompany' @@ -186,7 +216,6 @@ Configuration is managed through two files: type: 'sqlite' params: db_path: './zendesk-kb.db' - # Qdrant example - type: 'website' product_name: 'Istio' diff --git a/database.ts b/database.ts index eb9f810..e4eb997 100644 --- a/database.ts +++ b/database.ts @@ -5,11 +5,11 @@ import * as sqliteVec from "sqlite-vec"; import { QdrantClient } from '@qdrant/js-client-rest'; import { Logger } from './logger'; import { Utils } from './utils'; -import { - SourceConfig, - DatabaseConnection, - SqliteDB, - QdrantDB, +import { + SourceConfig, + DatabaseConnection, + SqliteDB, + QdrantDB, DocumentChunk, SqliteDatabaseParams, QdrantDatabaseParams, @@ -17,23 +17,36 @@ import { } from './types'; export class DatabaseManager { + private static getEmbeddingSize(): number { + const envSize = process.env.EMBEDDING_VECTOR_SIZE; + if (envSize) { + const parsed = parseInt(envSize, 10); + if (isNaN(parsed) || parsed <= 0) { + throw new Error(`Invalid EMBEDDING_VECTOR_SIZE: ${envSize}. 
Must be a positive integer.`); + } + return parsed; + } + return 3072; // Default value + } + static async initDatabase(config: SourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('database'); const dbConfig = config.database_config; - + if (dbConfig.type === 'sqlite') { const params = dbConfig.params as SqliteDatabaseParams; const dbPath = params.db_path || path.join(process.cwd(), `${config.product_name.replace(/\s+/g, '_')}-${config.version}.db`); - + logger.info(`Opening SQLite database at ${dbPath}`); - + const db = new BetterSqlite3(dbPath, { allowExtension: true } as any); sqliteVec.load(db); - logger.debug(`Creating vec_items table if it doesn't exist`); + const embeddingSize = this.getEmbeddingSize(); + logger.debug(`Creating vec_items table if it doesn't exist with embedding size: ${embeddingSize}`); db.exec(` CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0( - embedding FLOAT[3072], + embedding FLOAT[${embeddingSize}], product_name TEXT, version TEXT, heading_hierarchy TEXT, @@ -51,7 +64,7 @@ export class DatabaseManager { const qdrantUrl = params.qdrant_url || 'http://localhost:6333'; const qdrantPort = params.qdrant_port || 443; const collectionName = params.collection_name || `${config.product_name.toLowerCase().replace(/\s+/g, '_')}_${config.version}`; - + logger.info(`Connecting to Qdrant at ${qdrantUrl}:${qdrantPort}, collection: ${collectionName}`); const qdrantClient = new QdrantClient({ url: qdrantUrl, apiKey: process.env.QDRANT_API_KEY, port: qdrantPort }); @@ -72,16 +85,17 @@ export class DatabaseManager { const collectionExists = collections.collections.some( (collection: any) => collection.name === collectionName ); - + if (collectionExists) { logger.info(`Collection ${collectionName} already exists`); return; } - - logger.info(`Creating new collection ${collectionName}`); + + const embeddingSize = this.getEmbeddingSize(); + logger.info(`Creating new collection ${collectionName} with embedding size: ${embeddingSize}`); await qdrantClient.createCollection(collectionName, { vectors: { - size: 3072, + size: embeddingSize, distance: "Cosine", }, }); @@ -90,9 +104,9 @@ export class DatabaseManager { if (error instanceof Error) { const errorMsg = error.message.toLowerCase(); const errorString = JSON.stringify(error).toLowerCase(); - + if ( - errorMsg.includes("already exists") || + errorMsg.includes("already exists") || errorString.includes("already exists") || (error as any)?.status === 409 || errorString.includes("conflict") @@ -101,7 +115,7 @@ export class DatabaseManager { return; } } - + logger.error(`Error creating Qdrant collection:`, error); logger.warn(`Continuing with existing collection...`); } @@ -127,12 +141,12 @@ export class DatabaseManager { static async getLastRunDate(dbConnection: DatabaseConnection, repo: string, defaultDate: string, logger: Logger): Promise { const metadataKey = `last_run_${repo.replace('/', '_')}`; - + try { if (dbConnection.type === 'sqlite') { const stmt = dbConnection.db.prepare('SELECT value FROM vec_metadata WHERE key = ?'); const result = stmt.get(metadataKey) as { value: string } | undefined; - + if (result) { logger.info(`Retrieved last run date for ${repo}: ${result.value}`); return result.value; @@ -141,7 +155,7 @@ export class DatabaseManager { // Generate a UUID for this repo's metadata const metadataUUID = Utils.generateMetadataUUID(repo); logger.debug(`Looking up metadata with UUID: ${metadataUUID}`); - + try { // Try to retrieve the metadata point for this repo const response = await 
dbConnection.client.retrieve(dbConnection.collectionName, { @@ -149,7 +163,7 @@ export class DatabaseManager { with_payload: true, with_vector: false }); - + if (response.length > 0 && response[0].payload?.metadata_value) { const lastRunDate = response[0].payload.metadata_value as string; logger.info(`Retrieved last run date for ${repo}: ${lastRunDate}`); @@ -162,14 +176,14 @@ export class DatabaseManager { } catch (error) { logger.warn(`Error retrieving last run date:`, error); } - + logger.info(`No saved run date found for ${repo}, using default: ${defaultDate}`); return defaultDate; } static async updateLastRunDate(dbConnection: DatabaseConnection, repo: string, logger: Logger): Promise { const now = new Date().toISOString(); - + try { if (dbConnection.type === 'sqlite') { const metadataKey = `last_run_${repo.replace('/', '_')}`; @@ -183,13 +197,13 @@ export class DatabaseManager { // Generate UUID for this repo's metadata const metadataUUID = Utils.generateMetadataUUID(repo); const metadataKey = `last_run_${repo.replace('/', '_')}`; - + logger.debug(`Using UUID: ${metadataUUID} for metadata`); - + // Generate a dummy embedding (all zeros) - const dummyEmbeddingSize = 3072; // Same size as your content embeddings + const dummyEmbeddingSize = this.getEmbeddingSize(); // Same size as your content embeddings const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0); - + // Create a point with special metadata payload const metadataPoint = { id: metadataUUID, @@ -204,12 +218,12 @@ export class DatabaseManager { url: 'metadata://' + repo } }; - + await dbConnection.client.upsert(dbConnection.collectionName, { wait: true, points: [metadataPoint] }); - + logger.info(`Updated last run date for ${repo} to ${now}`); } } catch (error) { @@ -236,7 +250,7 @@ export class DatabaseManager { static insertVectorsSQLite(db: Database, chunk: DocumentChunk, embedding: number[], logger: Logger, chunkHash?: string) { const { insertStmt, updateStmt } = this.prepareSQLiteStatements(db); const hash = chunkHash || Utils.generateHash(chunk.content); - + const transaction = db.transaction(() => { const params = [ new Float32Array(embedding), @@ -272,9 +286,9 @@ export class DatabaseManager { } catch (e) { pointId = crypto.randomUUID(); } - + const hash = chunkHash || Utils.generateHash(chunk.content); - + const pointItem = { id: pointId, vector: embedding, @@ -376,9 +390,9 @@ export class DatabaseManager { } static removeObsoleteFilesSQLite( - db: Database, - processedFiles: Set, - pathConfig: { path: string; url_rewrite_prefix?: string } | string, + db: Database, + processedFiles: Set, + pathConfig: { path: string; url_rewrite_prefix?: string } | string, logger: Logger ) { const getChunksForPathStmt = db.prepare(` @@ -386,10 +400,10 @@ export class DatabaseManager { WHERE url LIKE ? 
|| '%' `); const deleteChunkStmt = db.prepare(`DELETE FROM vec_items WHERE chunk_id = ?`); - + // Determine if we're using URL rewriting or direct file paths const isRewriteMode = typeof pathConfig === 'object' && pathConfig.url_rewrite_prefix; - + // Set up the URL prefix for searching let urlPrefix: string; if (isRewriteMode) { @@ -402,19 +416,19 @@ export class DatabaseManager { const cleanedDirPrefix = dirPrefix.replace(/^\.\/+/, ''); urlPrefix = `file://${cleanedDirPrefix}`; } - + logger.debug(`Searching for chunks with URL prefix: ${urlPrefix}`); const existingChunks = getChunksForPathStmt.all(urlPrefix) as { chunk_id: string; url: string }[]; let deletedCount = 0; - + const transaction = db.transaction(() => { for (const { chunk_id, url } of existingChunks) { // Skip if it's not from our URL prefix (safety check) if (!url.startsWith(urlPrefix)) continue; - + let filePath: string; let shouldDelete = false; - + if (isRewriteMode) { // URL rewrite mode: extract relative path and construct full file path const config = pathConfig as { path: string; url_rewrite_prefix?: string }; @@ -426,7 +440,7 @@ export class DatabaseManager { filePath = url.substring(7); // Remove 'file://' prefix shouldDelete = !processedFiles.has(filePath); } - + if (shouldDelete) { logger.debug(`Deleting obsolete chunk from SQLite: ${chunk_id.substring(0, 8)}... (File not processed: ${filePath})`); deleteChunkStmt.run(chunk_id); @@ -435,21 +449,21 @@ export class DatabaseManager { } }); transaction(); - + logger.info(`Deleted ${deletedCount} obsolete chunks from SQLite for URL prefix ${urlPrefix}`); } static async removeObsoleteFilesQdrant( - db: QdrantDB, - processedFiles: Set, - pathConfig: { path: string; url_rewrite_prefix?: string } | string, + db: QdrantDB, + processedFiles: Set, + pathConfig: { path: string; url_rewrite_prefix?: string } | string, logger: Logger ) { const { client, collectionName } = db; try { // Determine if we're using URL rewriting or direct file paths const isRewriteMode = typeof pathConfig === 'object' && pathConfig.url_rewrite_prefix; - + // Set up the URL prefix for searching let urlPrefix: string; if (isRewriteMode) { @@ -462,7 +476,7 @@ export class DatabaseManager { const cleanedDirPrefix = dirPrefix.replace(/^\.\/+/, ''); urlPrefix = `file://${cleanedDirPrefix}`; } - + logger.debug(`Checking for obsolete chunks with URL prefix: ${urlPrefix}`); const response = await client.scroll(collectionName, { limit: 10000, @@ -487,7 +501,7 @@ export class DatabaseManager { ] } }); - + const obsoletePointIds = response.points .filter((point: any) => { const url = point.payload?.url; @@ -495,13 +509,13 @@ export class DatabaseManager { if (point.payload?.is_metadata === true) { return false; } - + if (!url || !url.startsWith(urlPrefix)) { return false; } - + let filePath: string; - + if (isRewriteMode) { // URL rewrite mode: extract relative path and construct full file path const config = pathConfig as { path: string; url_rewrite_prefix?: string }; @@ -511,11 +525,11 @@ export class DatabaseManager { // Direct file path mode: remove file:// prefix to match with processedFiles filePath = url.startsWith('file://') ? 
url.substring(7) : ''; } - + return filePath && !processedFiles.has(filePath); }) .map((point: any) => point.id); - + if (obsoletePointIds.length > 0) { await client.delete(collectionName, { points: obsoletePointIds, @@ -528,4 +542,4 @@ export class DatabaseManager { logger.error(`Error removing obsolete chunks from Qdrant:`, error); } } -} \ No newline at end of file +} diff --git a/doc2vec.ts b/doc2vec.ts index 4a9d8bc..09996a5 100644 --- a/doc2vec.ts +++ b/doc2vec.ts @@ -6,17 +6,18 @@ import * as yaml from 'js-yaml'; import * as fs from 'fs'; import * as path from 'path'; import { Buffer } from 'buffer'; -import { OpenAI } from "openai"; import * as dotenv from "dotenv"; import { Logger, LogLevel } from './logger'; import { Utils } from './utils'; import { DatabaseManager } from './database'; import { ContentProcessor } from './content-processor'; -import { - Config, - SourceConfig, - GithubSourceConfig, - WebsiteSourceConfig, +import { EmbeddingProviderFactory } from './embedding-factory'; +import { EmbeddingProvider } from './embedding-provider'; +import { + Config, + SourceConfig, + GithubSourceConfig, + WebsiteSourceConfig, LocalDirectorySourceConfig, ZendeskSourceConfig, DatabaseConnection, @@ -29,7 +30,7 @@ dotenv.config(); class Doc2Vec { private config: Config; - private openai: OpenAI; + private embeddingProvider: EmbeddingProvider; private contentProcessor: ContentProcessor; private logger: Logger; @@ -40,10 +41,14 @@ class Doc2Vec { useColor: true, prettyPrint: true }); - + this.logger.info('Initializing Doc2Vec'); this.config = this.loadConfig(configPath); - this.openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY }); + + // Initialize embedding provider based on environment variables + this.embeddingProvider = EmbeddingProviderFactory.createProvider(this.logger); + this.logger.info(`Using embedding provider: ${this.embeddingProvider.getProviderName()}`); + this.contentProcessor = new ContentProcessor(this.logger); } @@ -51,9 +56,9 @@ class Doc2Vec { try { const logger = this.logger.child('config'); logger.info(`Loading configuration from ${configPath}`); - + let configFile = fs.readFileSync(configPath, 'utf8'); - + // Substitute environment variables in the format ${VAR_NAME} configFile = configFile.replace(/\$\{([^}]+)\}/g, (match, varName) => { const envValue = process.env[varName]; @@ -64,9 +69,9 @@ class Doc2Vec { logger.debug(`Substituted ${match} with environment variable value`); return envValue; }); - + let config = yaml.load(configFile) as any; - + const typedConfig = config as Config; logger.info(`Configuration loaded successfully, found ${typedConfig.sources.length} sources`); return typedConfig; @@ -78,12 +83,12 @@ class Doc2Vec { public async run(): Promise { this.logger.section('PROCESSING SOURCES'); - + for (const sourceConfig of this.config.sources) { const sourceLogger = this.logger.child(`source:${sourceConfig.product_name}`); - + sourceLogger.info(`Processing ${sourceConfig.type} source for ${sourceConfig.product_name}@${sourceConfig.version}`); - + if (sourceConfig.type === 'github') { await this.processGithubRepo(sourceConfig, sourceLogger); } else if (sourceConfig.type === 'website') { @@ -96,17 +101,17 @@ class Doc2Vec { sourceLogger.error(`Unknown source type: ${(sourceConfig as any).type}`); } } - + this.logger.section('PROCESSING COMPLETE'); } private async fetchAndProcessGitHubIssues(repo: string, sourceConfig: GithubSourceConfig, dbConnection: DatabaseConnection, logger: Logger): Promise { const [owner, repoName] = repo.split('/'); const 
GITHUB_API_URL = `https://api.github.com/repos/${owner}/${repoName}/issues`; - + // Initialize metadata storage if needed await DatabaseManager.initDatabaseMetadata(dbConnection, logger); - + // Get the last run date from the database const startDate = sourceConfig.start_date || '2025-01-01'; const lastRunDate = await DatabaseManager.getLastRunDate(dbConnection, repo, `${startDate}T00:00:00Z`, logger); @@ -193,31 +198,31 @@ class Doc2Vec { const processIssue = async (issue: any): Promise => { const issueNumber = issue.number; const url = `https://github.com/${repo}/issues/${issueNumber}`; - + logger.info(`Processing issue #${issueNumber}`); - + // Generate markdown for the issue const markdown = await generateMarkdownForIssue(issue); - + // Chunk the markdown content const issueConfig = { ...sourceConfig, product_name: sourceConfig.product_name || repo, max_size: sourceConfig.max_size || Infinity }; - + const chunks = await this.contentProcessor.chunkMarkdown(markdown, issueConfig, url); logger.info(`Issue #${issueNumber}: Created ${chunks.length} chunks`); - + // Process and store each chunk immediately for (const chunk of chunks) { const chunkHash = Utils.generateHash(chunk.content); const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; @@ -252,7 +257,7 @@ class Doc2Vec { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; } - + const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length) { await DatabaseManager.storeChunkInQdrant(dbConnection, chunk, embeddings[0], chunkHash); @@ -279,36 +284,36 @@ class Doc2Vec { // Update the last run date in the database after processing all issues await DatabaseManager.updateLastRunDate(dbConnection, repo, logger); - + logger.info(`Successfully processed ${issues.length} issues`); } private async processGithubRepo(config: GithubSourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for GitHub repo: ${config.repo}`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); - + // Initialize metadata storage await DatabaseManager.initDatabaseMetadata(dbConnection, logger); - + logger.section('GITHUB ISSUES'); - + // Process GitHub issues await this.fetchAndProcessGitHubIssues(config.repo, config, dbConnection, logger); - + logger.info(`Finished processing GitHub repo: ${config.repo}`); } private async processWebsite(config: WebsiteSourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for website: ${config.url}`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); const validChunkIds: Set = new Set(); const visitedUrls: Set = new Set(); const urlPrefix = Utils.getUrlPrefix(config.url); - + logger.section('CRAWL AND EMBEDDING'); const crawlResult = await this.contentProcessor.crawlWebsite(config.url, config, async (url, content) => { @@ -399,7 +404,7 @@ class Doc2Vec { logger.info(`Found ${validChunkIds.size} valid chunks across processed pages for ${config.url}`); logger.section('CLEANUP'); - + if (crawlResult.hasNetworkErrors) { logger.warn('Skipping cleanup due to network errors encountered 
during crawling. This prevents removal of valid chunks when the site is temporarily unreachable.'); } else { @@ -418,28 +423,28 @@ class Doc2Vec { private async processLocalDirectory(config: LocalDirectorySourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for local directory: ${config.path}`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); const validChunkIds: Set = new Set(); const processedFiles: Set = new Set(); - + logger.section('FILE SCANNING AND EMBEDDING'); - + await this.contentProcessor.processDirectory( - config.path, - config, + config.path, + config, async (filePath, content) => { processedFiles.add(filePath); - + logger.info(`Processing content from ${filePath} (${content.length} chars)`); try { // Generate URL based on configuration let fileUrl: string; - + if (config.url_rewrite_prefix) { // Replace local path with URL prefix const relativePath = path.relative(config.path, filePath).replace(/\\/g, '/'); - + // If relativePath starts with '..', it means the file is outside the base directory if (relativePath.startsWith('..')) { // For files outside the configured path, use the default file:// scheme @@ -448,10 +453,10 @@ class Doc2Vec { } else { // For files inside the configured path, rewrite the URL // Handle trailing slashes in the URL prefix to avoid double slashes - const prefix = config.url_rewrite_prefix.endsWith('/') - ? config.url_rewrite_prefix.slice(0, -1) + const prefix = config.url_rewrite_prefix.endsWith('/') + ? config.url_rewrite_prefix.slice(0, -1) : config.url_rewrite_prefix; - + fileUrl = `${prefix}/${relativePath}`; logger.debug(`URL rewritten: ${filePath} -> ${fileUrl}`); } @@ -459,26 +464,26 @@ class Doc2Vec { // Use default file:// URL fileUrl = `file://${filePath}`; } - + const chunks = await this.contentProcessor.chunkMarkdown(content, config, fileUrl); logger.info(`Created ${chunks.length} chunks`); - + if (chunks.length > 0) { const chunkProgress = logger.progress(`Embedding chunks for ${filePath}`, chunks.length); - + for (let i = 0; i < chunks.length; i++) { const chunk = chunks[i]; validChunkIds.add(chunk.metadata.chunk_id); - + const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + let needsEmbedding = true; const chunkHash = Utils.generateHash(chunk.content); - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { needsEmbedding = false; chunkProgress.update(1, `Skipping unchanged chunk ${chunkId}`); @@ -495,13 +500,13 @@ class Doc2Vec { } catch (e) { pointId = crypto.randomUUID(); } - + const existingPoints = await dbConnection.client.retrieve(dbConnection.collectionName, { ids: [pointId], with_payload: true, with_vector: false, }); - + if (existingPoints.length > 0 && existingPoints[0].payload && existingPoints[0].payload.hash === chunkHash) { needsEmbedding = false; chunkProgress.update(1, `Skipping unchanged chunk ${chunkId}`); @@ -511,7 +516,7 @@ class Doc2Vec { logger.error(`Error checking existing point in Qdrant:`, error); } } - + if (needsEmbedding) { const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length > 0) { @@ -529,16 +534,16 @@ class Doc2Vec { } } } - + chunkProgress.complete(); } } catch (error) { logger.error(`Error during chunking or embedding for ${filePath}:`, 
error); } - }, + }, logger ); - + logger.section('CLEANUP'); if (dbConnection.type === 'sqlite') { logger.info(`Running SQLite cleanup for local directory ${config.path}`); @@ -547,43 +552,43 @@ class Doc2Vec { logger.info(`Running Qdrant cleanup for local directory ${config.path} in collection ${dbConnection.collectionName}`); await DatabaseManager.removeObsoleteFilesQdrant(dbConnection, processedFiles, config, logger); } - + logger.info(`Finished processing local directory: ${config.path}`); } private async processZendesk(config: ZendeskSourceConfig, parentLogger: Logger): Promise { const logger = parentLogger.child('process'); logger.info(`Starting processing for Zendesk: ${config.zendesk_subdomain}.zendesk.com`); - + const dbConnection = await DatabaseManager.initDatabase(config, logger); - + // Initialize metadata storage await DatabaseManager.initDatabaseMetadata(dbConnection, logger); - + const fetchTickets = config.fetch_tickets !== false; // default true const fetchArticles = config.fetch_articles !== false; // default true - + if (fetchTickets) { logger.section('ZENDESK TICKETS'); await this.fetchAndProcessZendeskTickets(config, dbConnection, logger); } - + if (fetchArticles) { logger.section('ZENDESK ARTICLES'); await this.fetchAndProcessZendeskArticles(config, dbConnection, logger); } - + logger.info(`Finished processing Zendesk: ${config.zendesk_subdomain}.zendesk.com`); } private async fetchAndProcessZendeskTickets(config: ZendeskSourceConfig, dbConnection: DatabaseConnection, logger: Logger): Promise { const baseUrl = `https://${config.zendesk_subdomain}.zendesk.com/api/v2`; const auth = Buffer.from(`${config.email}/token:${config.api_token}`).toString('base64'); - + // Get the last run date from the database const startDate = config.start_date || `${new Date().getFullYear()}-01-01`; const lastRunDate = await DatabaseManager.getLastRunDate(dbConnection, `zendesk_tickets_${config.zendesk_subdomain}`, `${startDate}T00:00:00Z`, logger); - + const fetchWithRetry = async (url: string, retries = 3): Promise => { for (let attempt = 0; attempt < retries; attempt++) { try { @@ -593,14 +598,14 @@ class Doc2Vec { 'Content-Type': 'application/json', }, }); - + if (response.status === 429) { const retryAfter = parseInt(response.headers['retry-after'] || '60'); logger.warn(`Rate limited, waiting ${retryAfter}s before retry`); await new Promise(res => setTimeout(res, retryAfter * 1000)); continue; } - + return response.data; } catch (error: any) { logger.error(`Zendesk API error (attempt ${attempt + 1}):`, error.message); @@ -619,26 +624,26 @@ class Doc2Vec { md += `- **Assignee:** ${ticket.assignee_id || 'Unassigned'}\n`; md += `- **Created:** ${new Date(ticket.created_at).toDateString()}\n`; md += `- **Updated:** ${new Date(ticket.updated_at).toDateString()}\n`; - + if (ticket.tags && ticket.tags.length > 0) { md += `- **Tags:** ${ticket.tags.map((tag: string) => `\`${tag}\``).join(', ')}\n`; } - + // Handle ticket description const description = ticket.description || ''; const cleanDescription = description || '_No description._'; md += `\n## Description\n\n${cleanDescription}\n\n`; - + if (comments && comments.length > 0) { md += `## Comments\n\n`; for (const comment of comments) { if (comment.public) { md += `### ${comment.author_id} - ${new Date(comment.created_at).toDateString()}\n\n`; - + // Handle comment body const rawBody = comment.plain_body || comment.html_body || comment.body || ''; const commentBody = rawBody.replace(/ /g, " ") || '_No content._'; - + md += 
`${commentBody}\n\n---\n\n`; } } @@ -652,36 +657,36 @@ class Doc2Vec { const processTicket = async (ticket: any): Promise => { const ticketId = ticket.id; const url = `https://${config.zendesk_subdomain}.zendesk.com/agent/tickets/${ticketId}`; - + logger.info(`Processing ticket #${ticketId}`); - + // Fetch ticket comments const commentsUrl = `${baseUrl}/tickets/${ticketId}/comments.json`; const commentsData = await fetchWithRetry(commentsUrl); const comments = commentsData?.comments || []; - + // Generate markdown for the ticket const markdown = generateMarkdownForTicket(ticket, comments); - + // Chunk the markdown content const ticketConfig = { ...config, product_name: config.product_name || `zendesk_${config.zendesk_subdomain}`, max_size: config.max_size || Infinity }; - + const chunks = await this.contentProcessor.chunkMarkdown(markdown, ticketConfig, url); logger.info(`Ticket #${ticketId}: Created ${chunks.length} chunks`); - + // Process and store each chunk for (const chunk of chunks) { const chunkHash = Utils.generateHash(chunk.content); const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; @@ -716,7 +721,7 @@ class Doc2Vec { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; } - + const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length) { await DatabaseManager.storeChunkInQdrant(dbConnection, chunk, embeddings[0], chunkHash); @@ -732,27 +737,27 @@ class Doc2Vec { }; logger.info(`Fetching Zendesk tickets updated since ${lastRunDate}`); - + // Build query parameters const statusFilter = config.ticket_status || ['new', 'open', 'pending', 'hold', 'solved']; const query = `updated>${lastRunDate.split('T')[0]} status:${statusFilter.join(',status:')}`; - + let nextPage = `${baseUrl}/search.json?query=${encodeURIComponent(query)}&sort_by=updated_at&sort_order=asc`; let totalTickets = 0; - + while (nextPage) { const data = await fetchWithRetry(nextPage); const tickets = data.results || []; - + logger.info(`Processing batch of ${tickets.length} tickets`); - + for (const ticket of tickets) { await processTicket(ticket); totalTickets++; } - + nextPage = data.next_page; - + if (nextPage) { logger.debug(`Fetching next page: ${nextPage}`); // Rate limiting: wait between requests @@ -762,18 +767,18 @@ class Doc2Vec { // Update the last run date in the database await DatabaseManager.updateLastRunDate(dbConnection, `zendesk_tickets_${config.zendesk_subdomain}`, logger); - + logger.info(`Successfully processed ${totalTickets} tickets`); } private async fetchAndProcessZendeskArticles(config: ZendeskSourceConfig, dbConnection: DatabaseConnection, logger: Logger): Promise { const baseUrl = `https://${config.zendesk_subdomain}.zendesk.com/api/v2/help_center`; const auth = Buffer.from(`${config.email}/token:${config.api_token}`).toString('base64'); - + // Get the start date for filtering const startDate = config.start_date || `${new Date().getFullYear()}-01-01`; const startDateObj = new Date(startDate); - + const fetchWithRetry = async (url: string, retries = 3): Promise => { for (let attempt = 0; attempt < retries; attempt++) { try { @@ -783,14 +788,14 @@ class Doc2Vec { 'Content-Type': 'application/json', }, }); - + if 
(response.status === 429) { const retryAfter = parseInt(response.headers['retry-after'] || '60'); logger.warn(`Rate limited, waiting ${retryAfter}s before retry`); await new Promise(res => setTimeout(res, retryAfter * 1000)); continue; } - + return response.data; } catch (error: any) { logger.error(`Zendesk API error (attempt ${attempt + 1}):`, error.message); @@ -808,11 +813,11 @@ class Doc2Vec { md += `- **Updated:** ${new Date(article.updated_at).toDateString()}\n`; md += `- **Vote Sum:** ${article.vote_sum || 0}\n`; md += `- **Vote Count:** ${article.vote_count || 0}\n`; - + if (article.label_names && article.label_names.length > 0) { md += `- **Labels:** ${article.label_names.map((label: string) => `\`${label}\``).join(', ')}\n`; } - + // Handle article content - convert HTML to markdown const articleBody = article.body || ''; let cleanContent = '_No content._'; @@ -825,7 +830,7 @@ class Doc2Vec { cleanContent = articleBody; } } - + md += `\n## Content\n\n${cleanContent}\n`; return md; @@ -834,31 +839,31 @@ class Doc2Vec { const processArticle = async (article: any): Promise => { const articleId = article.id; const url = article.html_url || `https://${config.zendesk_subdomain}.zendesk.com/hc/articles/${articleId}`; - + logger.info(`Processing article #${articleId}: ${article.title}`); - + // Generate markdown for the article const markdown = generateMarkdownForArticle(article); - + // Chunk the markdown content const articleConfig = { ...config, product_name: config.product_name || `zendesk_${config.zendesk_subdomain}`, max_size: config.max_size || Infinity }; - + const chunks = await this.contentProcessor.chunkMarkdown(markdown, articleConfig, url); logger.info(`Article #${articleId}: Created ${chunks.length} chunks`); - + // Process and store each chunk (similar to ticket processing) for (const chunk of chunks) { const chunkHash = Utils.generateHash(chunk.content); const chunkId = chunk.metadata.chunk_id.substring(0, 8) + '...'; - + if (dbConnection.type === 'sqlite') { const { checkHashStmt } = DatabaseManager.prepareSQLiteStatements(dbConnection.db); const existing = checkHashStmt.get(chunk.metadata.chunk_id) as { hash: string } | undefined; - + if (existing && existing.hash === chunkHash) { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; @@ -893,7 +898,7 @@ class Doc2Vec { logger.info(`Skipping unchanged chunk: ${chunkId}`); continue; } - + const embeddings = await this.createEmbeddings([chunk.content]); if (embeddings.length) { await DatabaseManager.storeChunkInQdrant(dbConnection, chunk, embeddings[0], chunkHash); @@ -909,20 +914,20 @@ class Doc2Vec { }; logger.info(`Fetching Zendesk help center articles updated since ${startDate}`); - + let nextPage = `${baseUrl}/articles.json`; let totalArticles = 0; let processedArticles = 0; - + while (nextPage) { const data = await fetchWithRetry(nextPage); const articles = data.articles || []; - + logger.info(`Processing batch of ${articles.length} articles`); - + for (const article of articles) { totalArticles++; - + // Check if article was updated since the start date const updatedAt = new Date(article.updated_at); if (updatedAt >= startDateObj) { @@ -932,33 +937,21 @@ class Doc2Vec { logger.debug(`Skipping article #${article.id} (updated ${article.updated_at}, before ${startDate})`); } } - + nextPage = data.next_page; - + if (nextPage) { logger.debug(`Fetching next page: ${nextPage}`); // Rate limiting: wait between requests await new Promise(res => setTimeout(res, 1000)); } } - + logger.info(`Successfully processed 
${processedArticles} of ${totalArticles} articles (filtered by date >= ${startDate})`);
     }
 
     private async createEmbeddings(texts: string[]): Promise<number[][]> {
-        const logger = this.logger.child('embeddings');
-        try {
-            logger.debug(`Creating embeddings for ${texts.length} texts`);
-            const response = await this.openai.embeddings.create({
-                model: "text-embedding-3-large",
-                input: texts,
-            });
-            logger.debug(`Successfully created ${response.data.length} embeddings`);
-            return response.data.map(d => d.embedding);
-        } catch (error) {
-            logger.error('Failed to create embeddings:', error);
-            return [];
-        }
+        return await this.embeddingProvider.createEmbeddings(texts);
     }
 }
 
@@ -970,4 +963,4 @@ if (require.main === module) {
     }
     const doc2Vec = new Doc2Vec(configPath);
     doc2Vec.run().catch(console.error);
-}
\ No newline at end of file
+}
diff --git a/embedding-factory.ts b/embedding-factory.ts
new file mode 100644
index 0000000..9599e8e
--- /dev/null
+++ b/embedding-factory.ts
@@ -0,0 +1,49 @@
+import { Logger } from './logger';
+import {
+    EmbeddingProvider,
+    EmbeddingConfig,
+    OpenAIEmbeddingProvider,
+    CustomEmbeddingProvider
+} from './embedding-provider';
+
+/**
+ * Factory class for creating embedding providers
+ */
+export class EmbeddingProviderFactory {
+    /**
+     * Creates an embedding provider based on environment variables
+     * @param logger Logger instance
+     * @returns Configured embedding provider
+     */
+    static createProvider(logger: Logger): EmbeddingProvider {
+        const factoryLogger = logger.child('embedding-factory');
+
+        // Get provider from PROVIDER environment variable, default to openai
+        const provider = (process.env.PROVIDER || 'openai').toLowerCase();
+
+        factoryLogger.info(`Creating embedding provider: ${provider}`);
+
+        switch (provider) {
+            case 'openai':
+                return new OpenAIEmbeddingProvider(logger);
+
+            case 'custom':
+                const endpoint = process.env.CUSTOM_ENDPOINT;
+                if (!endpoint) {
+                    throw new Error('CUSTOM_ENDPOINT environment variable is required when using custom provider');
+                }
+
+                // Validate endpoint URL format
+                try {
+                    new URL(endpoint);
+                } catch (error) {
+                    throw new Error(`Invalid custom embedding endpoint URL: ${endpoint}`);
+                }
+
+                return new CustomEmbeddingProvider(endpoint, logger);
+
+            default:
+                throw new Error(`Unknown embedding provider: ${provider}. Must be 'openai' or 'custom'`);
+        }
+    }
+}
diff --git a/embedding-provider.ts b/embedding-provider.ts
new file mode 100644
index 0000000..851bc62
--- /dev/null
+++ b/embedding-provider.ts
@@ -0,0 +1,174 @@
+import axios from 'axios';
+import { OpenAI } from 'openai';
+import { Logger } from './logger';
+
+/**
+ * Abstract interface for embedding providers
+ */
+export interface EmbeddingProvider {
+    createEmbeddings(texts: string[]): Promise<number[][]>;
+    getProviderName(): string;
+}
+
+/**
+ * Configuration for embedding providers
+ */
+export interface EmbeddingConfig {
+    provider: 'openai' | 'custom';
+    endpoint?: string; // For custom provider
+    model?: string;    // Model to use
+    timeout?: number;  // Timeout for custom provider
+}
+
+/**
+ * OpenAI embedding provider implementation
+ */
+export class OpenAIEmbeddingProvider implements EmbeddingProvider {
+    private openai: OpenAI;
+    private logger: Logger;
+    private model: string;
+
+    constructor(logger: Logger) {
+        this.logger = logger.child('openai-embeddings');
+
+        const apiKey = process.env.OPENAI_API_KEY;
+        if (!apiKey) {
+            throw new Error('OpenAI API key not found in environment variable: OPENAI_API_KEY');
+        }
+
+        // Use EMBEDDING_MODEL env var or default to text-embedding-3-large
+        this.model = process.env.EMBEDDING_MODEL || 'text-embedding-3-large';
+
+        this.openai = new OpenAI({ apiKey });
+        this.logger.info(`Initialized OpenAI embedding provider with model: ${this.model}`);
+    }
+
+    async createEmbeddings(texts: string[]): Promise<number[][]> {
+        const maxRetries = 3;
+        const baseDelay = 1000; // 1 second
+
+        for (let attempt = 1; attempt <= maxRetries; attempt++) {
+            try {
+                this.logger.debug(`Creating embeddings for ${texts.length} texts (attempt ${attempt}/${maxRetries})`);
+
+                const response = await this.openai.embeddings.create({
+                    model: this.model,
+                    input: texts,
+                });
+
+                this.logger.debug(`Successfully created ${response.data.length} embeddings`);
+                return response.data.map(d => d.embedding);
+
+            } catch (error: any) {
+                this.logger.warn(`OpenAI embedding attempt ${attempt} failed:`, error.message);
+
+                if (attempt === maxRetries) {
+                    this.logger.error('All OpenAI embedding attempts failed');
+                    throw error;
+                }
+
+                // Exponential backoff
+                const delay = baseDelay * Math.pow(2, attempt - 1);
+                this.logger.debug(`Retrying in ${delay}ms...`);
+                await new Promise(resolve => setTimeout(resolve, delay));
+            }
+        }
+
+        return [];
+    }
+
+    getProviderName(): string {
+        return 'openai';
+    }
+}
+
+/**
+ * Custom endpoint embedding provider implementation (OpenAI-compatible)
+ */
+export class CustomEmbeddingProvider implements EmbeddingProvider {
+    private endpoint: string;
+    private model: string;
+    private apiKey?: string;
+    private timeout: number;
+    private logger: Logger;
+
+    constructor(endpoint: string, logger: Logger) {
+        this.logger = logger.child('custom-embeddings');
+
+        this.endpoint = endpoint;
+        this.timeout = 30000; // 30 seconds default
+
+        // Use OPENAI_API_KEY for authentication (same as OpenAI provider)
+        this.apiKey = process.env.OPENAI_API_KEY;
+        if (!this.apiKey) {
+            throw new Error('OpenAI API key not found in environment variable: OPENAI_API_KEY');
+        }
+
+        // Use EMBEDDING_MODEL env var or default to text-embedding-3-large
+        this.model = process.env.EMBEDDING_MODEL || 'text-embedding-3-large';
+
+        this.logger.info(`Initialized custom embedding provider: ${this.endpoint} with model: ${this.model}`);
+    }
+
+    async createEmbeddings(texts: string[]): Promise<number[][]> {
+        const maxRetries = 3;
+        const baseDelay = 1000; // 1 second
+ for (let attempt = 1; attempt <= maxRetries; attempt++) { + try { + this.logger.debug(`Creating embeddings for ${texts.length} texts (attempt ${attempt}/${maxRetries})`); + + const headers: Record = { + 'Content-Type': 'application/json', + }; + + if (this.apiKey) { + headers['Authorization'] = `Bearer ${this.apiKey}`; + } + + const requestBody = { + model: this.model, + input: texts, + }; + + const response = await axios.post(this.endpoint, requestBody, { + headers, + timeout: this.timeout, + }); + + if (!response.data || !response.data.data) { + throw new Error('Invalid response format from custom embedding endpoint'); + } + + const embeddings = response.data.data.map((item: any) => { + if (!item.embedding || !Array.isArray(item.embedding)) { + throw new Error('Invalid embedding format in response'); + } + return item.embedding; + }); + + this.logger.debug(`Successfully created ${embeddings.length} embeddings`); + return embeddings; + + } catch (error: any) { + this.logger.warn(`Custom embedding attempt ${attempt} failed:`, error.message); + + if (attempt === maxRetries) { + this.logger.error('All custom embedding attempts failed'); + throw error; + } + + // Exponential backoff + const delay = baseDelay * Math.pow(2, attempt - 1); + this.logger.debug(`Retrying in ${delay}ms...`); + await new Promise(resolve => setTimeout(resolve, delay)); + } + } + + return []; + } + + getProviderName(): string { + return 'custom'; + } +} diff --git a/mcp/package-lock.json b/mcp/package-lock.json index 3f496f6..c63824c 100644 --- a/mcp/package-lock.json +++ b/mcp/package-lock.json @@ -1,6 +1,6 @@ { "name": "sqlite-vec-mcp-server", - "version": "1.0.0", + "version": "1.1.0", "lockfileVersion": 3, "requires": true, "packages": { @@ -13,6 +13,7 @@ "@azure/openai": "^2.0.0", "@google/generative-ai": "^0.24.1", "@modelcontextprotocol/sdk": "^1.12.1", + "axios": "^1.12.2", "better-sqlite3": "^11.8.1", "dotenv": "^16.4.7", "express": "^5.1.0", @@ -424,6 +425,17 @@ "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==" }, + "node_modules/axios": { + "version": "1.12.2", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.12.2.tgz", + "integrity": "sha512-vMJzPewAlRyOgxV2dU0Cuz2O8zzzx9VYtbJOaBgXFeLc4IV/Eg50n4LowmehOOR61S8ZMpc2K5Sa7g6A4jfkUw==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.15.6", + "form-data": "^4.0.4", + "proxy-from-env": "^1.1.0" + } + }, "node_modules/base64-js": { "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", @@ -972,14 +984,36 @@ "node": ">= 0.8" } }, + "node_modules/follow-redirects": { + "version": "1.15.11", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz", + "integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, "node_modules/form-data": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz", - "integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", + "integrity": 
"sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", + "license": "MIT", "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", "mime-types": "^2.1.12" }, "engines": { @@ -1671,6 +1705,12 @@ "node": ">= 0.10" } }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", + "license": "MIT" + }, "node_modules/pump": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.2.tgz", diff --git a/mcp/package.json b/mcp/package.json index 55a8ad7..f02170d 100644 --- a/mcp/package.json +++ b/mcp/package.json @@ -1,6 +1,6 @@ { "name": "sqlite-vec-mcp-server", - "version": "1.0.0", + "version": "1.1.0", "description": "MCP Server for querying documentation with sqlite-vec", "main": "build/index.js", "type": "module", @@ -28,6 +28,7 @@ "@azure/openai": "^2.0.0", "@google/generative-ai": "^0.24.1", "@modelcontextprotocol/sdk": "^1.12.1", + "axios": "^1.12.2", "better-sqlite3": "^11.8.1", "dotenv": "^16.4.7", "express": "^5.1.0", diff --git a/mcp/src/index.ts b/mcp/src/index.ts index ecc12df..1a441a1 100644 --- a/mcp/src/index.ts +++ b/mcp/src/index.ts @@ -15,6 +15,7 @@ import * as sqliteVec from "sqlite-vec"; import Database, { Database as DatabaseType } from "better-sqlite3"; import { OpenAI } from 'openai'; import { GoogleGenerativeAI } from '@google/generative-ai'; +import axios from 'axios'; import path from 'path'; import { fileURLToPath } from 'url'; import fs from 'fs'; // Import fs for checking file existence @@ -26,7 +27,7 @@ const __dirname = path.dirname(__filename); // Provider configuration // Note: Anthropic does not provide an embeddings API, only text generation -// Supported providers: 'openai', 'azure', 'gemini' +// Supported providers: 'openai', 'azure', 'gemini', 'custom' const embeddingProvider = process.env.EMBEDDING_PROVIDER || 'openai'; // OpenAI configuration @@ -43,6 +44,9 @@ const azureDeploymentName = process.env.AZURE_OPENAI_DEPLOYMENT_NAME || 'text-em const geminiApiKey = process.env.GEMINI_API_KEY; const geminiModel = process.env.GEMINI_MODEL || 'gemini-embedding-001'; +// Custom endpoint configuration +const customEndpoint = process.env.CUSTOM_ENDPOINT; + const dbDir = process.env.SQLITE_DB_DIR || __dirname; // Default to current dir if not set if (!fs.existsSync(dbDir)) { @@ -71,8 +75,14 @@ if (strictMode) { process.exit(1); } break; + case 'custom': + if (!customEndpoint || !openAIApiKey) { + console.error("Error: CUSTOM_ENDPOINT and OPENAI_API_KEY environment variables are required for custom provider."); + process.exit(1); + } + break; default: - console.error(`Error: Unknown embedding provider '${embeddingProvider}'. Supported providers: openai, azure, gemini`); + console.error(`Error: Unknown embedding provider '${embeddingProvider}'. 
Supported providers: openai, azure, gemini, custom`); console.error("Note: Anthropic does not provide an embeddings API, only text generation models."); process.exit(1); } @@ -104,14 +114,14 @@ async function createEmbeddings(text: string): Promise { } return response.data[0].embedding; } - + case 'azure': { - const azure = new AzureOpenAI({ - apiKey: azureApiKey, - endpoint: azureEndpoint, - deployment: azureDeploymentName, - apiVersion: azureApiVersion, - }); + const azure = new AzureOpenAI({ + apiKey: azureApiKey, + endpoint: azureEndpoint, + deployment: azureDeploymentName, + apiVersion: azureApiVersion, + }); const response = await azure.embeddings.create({ model: azureDeploymentName, // Use deployment name for Azure @@ -122,7 +132,7 @@ async function createEmbeddings(text: string): Promise { } return response.data[0].embedding; } - + case 'gemini': { const genAI = new GoogleGenerativeAI(geminiApiKey!); const model = genAI.getGenerativeModel({ model: geminiModel }); @@ -132,8 +142,24 @@ async function createEmbeddings(text: string): Promise { } return result.embedding.values; } + + case 'custom': { + const response = await axios.post(`${customEndpoint}/embeddings`, { + model: openAIModel, + input: text, + }, { + headers: { + 'Authorization': `Bearer ${openAIApiKey}`, + 'Content-Type': 'application/json', + }, + }); + if (!response.data?.data?.[0]?.embedding) { + throw new Error("Failed to get embedding from custom endpoint response."); + } + return response.data.data[0].embedding; + } default: - throw new Error(`Unsupported embedding provider: ${embeddingProvider}. Supported providers: openai, azure, gemini`); + throw new Error(`Unsupported embedding provider: ${embeddingProvider}. Supported providers: openai, azure, gemini, custom`); } } catch (error) { @@ -161,30 +187,30 @@ function queryCollection(queryEmbedding: number[], filter: { product_name: strin distance FROM vec_items WHERE embedding MATCH @query_embedding`; - + if (filter.product_name) query += ` AND product_name = @product_name`; if (filter.version) query += ` AND version = @version`; - + query += ` ORDER BY distance LIMIT @top_k;`; - + const stmt = db.prepare(query); console.error(`[DB ${dbPath}] Query prepared. Executing...`); const startTime = Date.now(); const rows = stmt.all({ - query_embedding: new Float32Array(queryEmbedding), - product_name: filter.product_name, - version: filter.version, - top_k: topK, + query_embedding: new Float32Array(queryEmbedding), + product_name: filter.product_name, + version: filter.version, + top_k: topK, }); const duration = Date.now() - startTime; console.error(`[DB ${dbPath}] Query executed in ${duration}ms. Found ${rows.length} rows.`); - + rows.forEach((row: any) => { - delete row.embedding; + delete row.embedding; }) - + return rows as QueryResult[]; } catch (error) { console.error(`Error querying collection in ${dbPath}:`, error); @@ -224,9 +250,9 @@ const queryDocumentationToolHandler = async ({ queryText, productName, version, const results = await queryDocumentation(queryText, productName, version, limit); if (results.length === 0) { - return { - content: [{ type: "text" as const, text: `No relevant documentation found for "${queryText}" in product "${productName}" ${version ? `(version ${version})` : ''}.` }], - }; + return { + content: [{ type: "text" as const, text: `No relevant documentation found for "${queryText}" in product "${productName}" ${version ? 
`(version ${version})` : ''}.` }], + }; } const formattedResults = results.map((r, index) => @@ -270,12 +296,12 @@ server.tool( async function main() { const transport_type = process.env.TRANSPORT_TYPE || 'http'; let webserver: any = null; // Store server reference for proper shutdown - + // Common graceful shutdown handler const createGracefulShutdownHandler = (transportCleanup: () => Promise) => { return async (signal: string) => { console.error(`Received ${signal}, initiating graceful shutdown...`); - + const shutdownTimeout = parseInt(process.env.SHUTDOWN_TIMEOUT || '5000', 10); const forceExitTimeout = setTimeout(() => { console.error(`Shutdown timeout (${shutdownTimeout}ms) exceeded, force exiting...`); @@ -311,32 +337,32 @@ async function main() { } }; }; - + if (transport_type === 'stdio') { // Stdio transport for direct communication console.error("Starting MCP server with stdio transport..."); const transport = new StdioServerTransport(); await server.connect(transport); console.error("MCP server connected via stdio."); - + // Add shutdown handler for stdio transport const shutdownHandler = createGracefulShutdownHandler(async () => { console.error('Closing stdio transport...'); // StdioServerTransport doesn't have a close method, but we can clean up the connection // The transport will be cleaned up when the process exits }); - + process.on('SIGTERM', () => shutdownHandler('SIGTERM')); process.on('SIGINT', () => shutdownHandler('SIGINT')); - + } else if (transport_type === 'sse') { // SSE transport for backward compatibility console.error("Starting MCP server with SSE transport..."); - + const app = express(); - + // Storage for SSE transports by session ID - const sseTransports: {[sessionId: string]: SSEServerTransport} = {}; + const sseTransports: { [sessionId: string]: SSEServerTransport } = {}; app.get("/sse", async (_: Request, res: Response) => { console.error('Received SSE connection request'); @@ -370,18 +396,18 @@ async function main() { console.error(`MCP server is running on port ${PORT} with SSE transport`); console.error(`Connect to: http://localhost:${PORT}/sse`); }); - + webserver.keepAliveTimeout = 3000; - + // Keep the process alive webserver.on('error', (error: any) => { console.error('HTTP server error:', error); }); - + // Handle server shutdown with proper SIGTERM/SIGINT support const shutdownHandler = createGracefulShutdownHandler(async () => { console.error('Closing SSE transports...'); - + // Close all active SSE transports for (const [sessionId, transport] of Object.entries(sseTransports)) { try { @@ -393,19 +419,19 @@ async function main() { } } }); - + process.on('SIGTERM', () => shutdownHandler('SIGTERM')); process.on('SIGINT', () => shutdownHandler('SIGINT')); - + } else if (transport_type === 'http') { // Streamable HTTP transport for web-based communication console.error("Starting MCP server with HTTP transport..."); - + const app = express(); - + const transports: Map = new Map(); const servers: Map = new Map(); - + // Handle POST requests for MCP initialization and method calls app.post('/mcp', async (req: Request, res: Response) => { console.error('Received MCP POST request'); @@ -563,20 +589,20 @@ async function main() { app.get("/health", (_: Request, res: Response) => { res.status(200).send("OK"); }); - + const PORT = process.env.PORT || 3001; webserver = app.listen(PORT, () => { console.error(`MCP server is running on port ${PORT} with HTTP transport`); console.error(`Connect to: http://localhost:${PORT}/mcp`); }); - + 
webserver.keepAliveTimeout = 3000; - + // Keep the process alive webserver.on('error', (error: any) => { console.error('HTTP server error:', error); }); - + // Handle server shutdown with proper SIGTERM/SIGINT support and timeout const shutdownHandler = createGracefulShutdownHandler(async () => { console.error('Closing HTTP transports and servers...'); @@ -585,17 +611,17 @@ async function main() { const transportClosePromises = Array.from(transports.entries()).map(async ([sessionId, transport]) => { try { console.error(`Closing transport and server for session ${sessionId}`); - + // Add timeout to individual transport close operations const closeTimeout = new Promise((_, reject) => { setTimeout(() => reject(new Error(`Transport close timeout for ${sessionId}`)), 2000); }); - + await Promise.race([ transport.close(), closeTimeout ]); - + transports.delete(sessionId); servers.delete(sessionId); console.error(`Transport and server closed for session ${sessionId}`); @@ -611,10 +637,10 @@ async function main() { await Promise.allSettled(transportClosePromises); console.error('All transports and servers cleanup completed'); }); - + process.on('SIGTERM', () => shutdownHandler('SIGTERM')); process.on('SIGINT', () => shutdownHandler('SIGINT')); - + } else { console.error(`Unknown transport type: ${transport_type}. Use 'stdio', 'sse', or 'http'.`); process.exit(1); diff --git a/package-lock.json b/package-lock.json index 4afa738..d4235d7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "doc2vec", - "version": "1.1.1", - "lockfileVersion": 3, + "version": "1.4.0", + "lockfileVersion": 4, "requires": true, "packages": { "": { "name": "doc2vec", - "version": "1.1.1", + "version": "1.3.0", "license": "ISC", "dependencies": { "@mozilla/readability": "^0.4.4", diff --git a/package.json b/package.json index 58cffbb..1e82c11 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "doc2vec", - "version": "1.3.0", + "version": "1.4.0", "type": "commonjs", "description": "", "main": "dist/doc2vec.js",
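
Since the indexer now resolves its provider through `EmbeddingProviderFactory` instead of an inline OpenAI call, a cheap way to validate a deployment before running a full crawl is a standalone smoke test. A minimal sketch, assuming the module layout above; the `Logger` constructor options mirror what `doc2vec.ts` appears to pass and are assumptions, and the script itself is not part of this PR:

```typescript
// smoke-embeddings.ts -- hypothetical test script, not included in this PR.
import * as dotenv from 'dotenv';
import { Logger, LogLevel } from './logger';
import { EmbeddingProviderFactory } from './embedding-factory';

dotenv.config();

async function main() {
    // Assumed constructor signature, matching the options doc2vec.ts uses at startup.
    const logger = new Logger('smoke', { level: LogLevel.DEBUG, useColor: true, prettyPrint: true });

    // Reads PROVIDER / CUSTOM_ENDPOINT / EMBEDDING_MODEL / OPENAI_API_KEY from the
    // environment and throws early on a missing or malformed CUSTOM_ENDPOINT.
    const provider = EmbeddingProviderFactory.createProvider(logger);
    logger.info(`Active provider: ${provider.getProviderName()}`);

    const vectors = await provider.createEmbeddings(['hello world', 'doc2vec smoke test']);

    // DatabaseManager sizes the vec_items table and Qdrant collections from
    // EMBEDDING_VECTOR_SIZE (default 3072), so verify the dimension up front.
    const expected = parseInt(process.env.EMBEDDING_VECTOR_SIZE || '3072', 10);
    for (const v of vectors) {
        if (v.length !== expected) {
            throw new Error(`Got ${v.length}-dim vectors but EMBEDDING_VECTOR_SIZE is ${expected}`);
        }
    }
    logger.info(`OK: ${vectors.length} embeddings, ${expected} dims each`);
}

main().catch((err) => {
    console.error(err);
    process.exit(1);
});
```

Checking the dimension here is deliberate: a mismatch between the model's output and the configured vector size would otherwise only surface once chunks are written to the database.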