diff --git a/.github/workflows/evals.yml b/.github/workflows/evals.yml
index 22f2114d..fd4d6b63 100644
--- a/.github/workflows/evals.yml
+++ b/.github/workflows/evals.yml
@@ -16,6 +16,9 @@ jobs:
       NEON_PROJECT_ID: ${{ vars.NEON_PROJECT_ID }}
       NEON_DATABASE: ${{ vars.NEON_DATABASE || 'neondb' }}
       NEON_ROLE: ${{ vars.NEON_ROLE || 'neondb_owner' }}
+      AI_GATEWAY_API_KEY: ${{ secrets.AI_GATEWAY_API_KEY || vars.AI_GATEWAY_API_KEY }}
+      AI_GATEWAY_MODEL: ${{ vars.AI_GATEWAY_MODEL }}
+      AI_GATEWAY_BASE_URL: ${{ vars.AI_GATEWAY_BASE_URL }}
     steps:
       - uses: actions/checkout@v4
 
@@ -95,11 +98,18 @@ jobs:
         if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
         run: npm ci
 
-      - name: Run evals
-        if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' && hashFiles('tests/**/*.spec.*') != '' }}
+      - name: Run core tests (blocking)
+        if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
+        env:
+          DATABASE_URL: ${{ steps.cs.outputs.DATABASE_URL }}
+        run: npm run test:core
+
+      - name: Run E2E NL→SQL tests (non-blocking)
+        if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
+        continue-on-error: true
         env:
           DATABASE_URL: ${{ steps.cs.outputs.DATABASE_URL }}
-        run: npm test
+        run: npm run test:e2e
 
       - name: Delete Neon branch
         if: ${{ always() && env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
diff --git a/db.ts b/db.ts
index dc182a4e..41c935d6 100644
--- a/db.ts
+++ b/db.ts
@@ -173,8 +173,17 @@ export async function runQuery({
       `SET LOCAL statement_timeout TO '${Math.max(1000, Math.min(timeoutMs, 60000))}ms'`,
     );
 
+    // Determine the highest positional parameter index used in the input SQL (e.g., $1, $2, ...)
+    const matches = [...safeSql.matchAll(/\$(\d+)/g)];
+    const maxInSql = matches.length
+      ? Math.max(...matches.map((m) => Number(m[1]) || 0))
+      : 0;
+    const base = Math.max(maxInSql, Array.isArray(params) ? params.length : 0);
+    const limitIdx = base + 1;
+    const offsetIdx = base + 2;
+
     // Wrap the query to enforce limit/offset without trying to parse the SQL
-    const wrapped = `select * from ( ${safeSql} ) as t limit $${params.length + 1} offset $${params.length + 2}`;
+    const wrapped = `select * from ( ${safeSql} ) as t limit $${limitIdx} offset $${offsetIdx}`;
     const res = await client.query({
       text: wrapped,
       values: [...params, clampedLimit, clampedOffset],
diff --git a/package.json b/package.json
index 81994a87..484a98ca 100644
--- a/package.json
+++ b/package.json
@@ -4,7 +4,9 @@
   "type": "module",
   "private": true,
   "scripts": {
-    "test": "vitest run"
+    "test": "vitest run",
+    "test:core": "vitest run tests/contract.spec.ts tests/scenario.spec.ts",
+    "test:e2e": "vitest run tests/e2e_nl_sql.spec.ts"
   },
   "dependencies": {
     "ai": "^5.0.44",
diff --git a/tests/contract.spec.ts b/tests/contract.spec.ts
index 21c52236..de32e47d 100644
--- a/tests/contract.spec.ts
+++ b/tests/contract.spec.ts
@@ -59,7 +59,7 @@ describe("contract: runQuery safety and limit/offset enforcement", () => {
     // The last structured call should be the wrapped SELECT
     const last = queryCalls.find((c) => typeof c !== "string") as any;
     expect(last.text).toMatch(
-      /select \* from \( select 1 as x \) as t limit \$1 offset \$2/i,
+      /select \* from \(\s*select 1 as x\s*\) as t limit \$(\d+) offset \$(\d+)/i,
     );
     // Clamped
     expect(last.values?.[0]).toBe(2000);
diff --git a/tests/e2e_nl_sql.spec.ts b/tests/e2e_nl_sql.spec.ts
new file mode 100644
index 00000000..3ee22ca2
--- /dev/null
+++ b/tests/e2e_nl_sql.spec.ts
@@ -0,0 +1,185 @@
+import { it, expect } from "vitest";
+import { runQuery } from "../db.ts";
+import { buildSystemPrompt } from "../prompt.ts";
+
+const GATEWAY_KEY = process.env.AI_GATEWAY_API_KEY;
+const GATEWAY_BASE = process.env.AI_GATEWAY_BASE_URL;
+const GATEWAY_MODEL = process.env.AI_GATEWAY_MODEL;
+
+// Run these tests only when the database and the gateway are all configured.
+const HAS_E2E =
+  !!process.env.DATABASE_URL &&
+  !!GATEWAY_KEY &&
+  !!GATEWAY_BASE &&
+  !!GATEWAY_MODEL;
+const itE2E = HAS_E2E ? it : it.skip;
+
+/**
+ * Asks the gateway's chat-completions endpoint to turn a natural-language
+ * prompt into a single SQL statement.
+ *
+ * @param prompt - Natural-language request, sent as the user message.
+ * @returns The SQL from the reply's fenced `sql` block and the params from an
+ *   optional fenced `json` block (a bare array or `{ params: [...] }`). Params
+ *   are `[]` when absent or unparsable; each test supplies its own fallback
+ *   values if the SQL still uses placeholders.
+ * @throws Error `gateway_error` on a non-OK response, `no_sql_block_found`
+ *   when the reply has no `sql` block, `not_select_sql` when the statement
+ *   does not begin with SELECT or WITH.
+ */
+async function nlToSql(
+  prompt: string,
+): Promise<{ sql: string; params: unknown[] }> {
+  const system =
+    buildSystemPrompt() +
+    "\n\nAdditional instructions:" +
+    "\n- Only respond with a single SQL statement." +
+    "\n- Output the SQL inside a fenced code block marked 'sql'." +
+    "\n- Do not include explanations." +
+    "\n- Do NOT use parameter placeholders like $1, $2. Inline literal values (with proper quoting) directly in the SQL.";
+
+  const res = await fetch(`${GATEWAY_BASE}/chat/completions`, {
+    method: "POST",
+    headers: {
+      "content-type": "application/json",
+      authorization: `Bearer ${GATEWAY_KEY}`,
+    },
+    body: JSON.stringify({
+      model: GATEWAY_MODEL,
+      temperature: 0,
+      messages: [
+        { role: "system", content: system },
+        { role: "user", content: prompt },
+      ],
+    }),
+  });
+
+  if (!res.ok) {
+    const text = await res.text().catch(() => "");
+    throw new Error(`gateway_error: ${res.status} ${res.statusText} ${text}`);
+  }
+
+  const data = (await res.json()) as any;
+  const content = data?.choices?.[0]?.message?.content ?? "";
+  const sqlMatch = content.match(/```sql\s*([\s\S]*?)```/i);
+  const jsonMatch = content.match(/```json\s*([\s\S]*?)```/i);
+  if (!sqlMatch) throw new Error("no_sql_block_found");
+  const sql = sqlMatch[1].trim();
+  let params: unknown[] = [];
+  if (jsonMatch) {
+    try {
+      const parsed = JSON.parse(jsonMatch[1]);
+      params = Array.isArray(parsed)
+        ? parsed
+        : Array.isArray((parsed as any)?.params)
+          ? (parsed as any).params
+          : [];
+    } catch {
+      // Best effort: a malformed json block just means "no params".
+      params = [];
+    }
+  }
+  if (!/^\s*with\s+|^\s*select\s+/i.test(sql)) {
+    throw new Error(`not_select_sql: ${sql.slice(0, 160)}`);
+  }
+  return { sql, params };
+}
+
+// Matches a positional placeholder such as $1 or $2 anywhere in the SQL.
+const PLACEHOLDER_RE = /\$\d+/;
+
+// Per-test timeout (ms) for the E2E calls.
+const T = 60000;
+
+const A_START = "2025-01-10T00:00:00Z";
+const A_END = "2025-01-12T23:59:59Z";
+
+// E2E 1: count field changes in window
+itE2E(
+  "e2e: Proj A field_changes in fixed window = 4",
+  async () => {
+    const { sql, params } = await nlToSql(
+      `MUST use table field_changes. MUST filter project_name = 'Proj A'. MUST restrict changed_at between '${A_START}' and '${A_END}'. For project \"Proj A\", how many field changes occurred between ${A_START} and ${A_END}? Return a single row with a numeric count.`,
+    );
+    let p = params;
+    if ((!p || p.length === 0) && PLACEHOLDER_RE.test(sql)) {
+      p = ["Proj A", A_START, A_END];
+    }
+    // Debugging: log the SQL query and the params actually sent
+    console.log(`Running SQL: ${sql}`, p);
+    const res = await runQuery({
+      sql,
+      params: p,
+      limit: 2000,
+      timeoutMs: 60000,
+    });
+    const count = (() => {
+      const row = res.rows?.[0] ?? {};
+      const byKey = Object.values(row).find((v) => typeof v === "number");
+      return typeof byKey === "number" ? byKey : res.rowCount;
+    })();
+    expect(count).toBe(4);
+  },
+  T,
+);
+
+// E2E 2: current status for ITEM_A_1
+itE2E(
+  "e2e: Proj A ITEM_A_1 Status is Done",
+  async () => {
+    const { sql, params } = await nlToSql(
+      `MUST use table current_field_values. MUST filter project_name = 'Proj A' AND item_node_id = 'ITEM_A_1' AND field_name = 'Status'. Return a single row with only the status value.`,
+    );
+    let p = params;
+    if ((!p || p.length === 0) && PLACEHOLDER_RE.test(sql)) {
+      p = ["Proj A", "ITEM_A_1"];
+    }
+    console.log(`Running SQL: ${sql}`, p);
+    const res = await runQuery({
+      sql,
+      params: p,
+      limit: 50,
+      timeoutMs: 120000,
+    }); // NOTE: runQuery clamps statement_timeout to at most 60000 ms
+    console.log(`Query result:`, res); // additional debug
+    const textVal = (() => {
+      const row = res.rows?.[0] ?? {};
+      const str = Object.values(row).find((v) => typeof v === "string") as
+        | string
+        | undefined;
+      return str;
+    })();
+    expect(textVal).toBe("Done");
+  },
+  T,
+);
+
+// E2E 3: deletion events list
+itE2E(
+  "e2e: Proj A has one deletion event",
+  async () => {
+    const { sql, params } = await nlToSql(
+      `MUST use table field_changes. MUST filter project_name = 'Proj A' AND field_name = '_item_deleted'. Return old_value and new_value columns only.`,
+    );
+    let p = params;
+    if ((!p || p.length === 0) && PLACEHOLDER_RE.test(sql)) {
+      p = ["Proj A"];
+    }
+    console.log(`Running SQL: ${sql}`, p);
+    const res = await runQuery({
+      sql,
+      params: p,
+      limit: 50,
+      timeoutMs: 120000,
+    }); // NOTE: runQuery clamps statement_timeout to at most 60000 ms
+    console.debug(`Deletion check result:`, res); // additional debug
+    expect(res.rowCount).toBeGreaterThanOrEqual(1);
+    const ok = res.rows.some(
+      (r) =>
+        r?.old_value === true &&
+        (r?.new_value === null || r?.new_value === undefined),
+    );
+    expect(ok).toBe(true);
+  },
+  T,
+);