Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions .github/workflows/evals.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ jobs:
NEON_PROJECT_ID: ${{ vars.NEON_PROJECT_ID }}
NEON_DATABASE: ${{ vars.NEON_DATABASE || 'neondb' }}
NEON_ROLE: ${{ vars.NEON_ROLE || 'neondb_owner' }}
AI_GATEWAY_API_KEY: ${{ secrets.AI_GATEWAY_API_KEY || vars.AI_GATEWAY_API_KEY }}
AI_GATEWAY_MODEL: ${{ vars.AI_GATEWAY_MODEL }}
AI_GATEWAY_BASE_URL: ${{ vars.AI_GATEWAY_BASE_URL }}

steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -95,11 +98,18 @@ jobs:
if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
run: npm ci

- name: Run evals
if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' && hashFiles('tests/**/*.spec.*') != '' }}
- name: Run core tests (blocking)
if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
env:
DATABASE_URL: ${{ steps.cs.outputs.DATABASE_URL }}
run: npm run test:core

- name: Run E2E NL→SQL tests (non-blocking)
if: ${{ env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
continue-on-error: true
env:
DATABASE_URL: ${{ steps.cs.outputs.DATABASE_URL }}
run: npm test
run: npm run test:e2e

- name: Delete Neon branch
if: ${{ always() && env.NEON_API_KEY != '' && env.NEON_PROJECT_ID != '' }}
Expand Down
11 changes: 10 additions & 1 deletion db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,17 @@ export async function runQuery({
`SET LOCAL statement_timeout TO '${Math.max(1000, Math.min(timeoutMs, 60000))}ms'`,
);

// Determine the highest positional parameter index used in the input SQL (e.g., $1, $2, ...)
const matches = [...safeSql.matchAll(/\$(\d+)/g)];
const maxInSql = matches.length
? Math.max(...matches.map((m) => Number(m[1]) || 0))
: 0;
const base = Math.max(maxInSql, Array.isArray(params) ? params.length : 0);
const limitIdx = base + 1;
const offsetIdx = base + 2;

// Wrap the query to enforce limit/offset without trying to parse the SQL
const wrapped = `select * from ( ${safeSql} ) as t limit $${params.length + 1} offset $${params.length + 2}`;
const wrapped = `select * from ( ${safeSql} ) as t limit $${limitIdx} offset $${offsetIdx}`;
const res = await client.query({
text: wrapped,
values: [...params, clampedLimit, clampedOffset],
Expand Down
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"type": "module",
"private": true,
"scripts": {
"test": "vitest run"
"test": "vitest run",
"test:core": "vitest run tests/contract.spec.ts tests/scenario.spec.ts",
"test:e2e": "vitest run tests/e2e_nl_sql.spec.ts"
},
"dependencies": {
"ai": "^5.0.44",
Expand Down
2 changes: 1 addition & 1 deletion tests/contract.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ describe("contract: runQuery safety and limit/offset enforcement", () => {
// The last structured call should be the wrapped SELECT
const last = queryCalls.find((c) => typeof c !== "string") as any;
expect(last.text).toMatch(
/select \* from \( select 1 as x \) as t limit \$1 offset \$2/i,
/select \* from \(\s*select 1 as x\s*\) as t limit \$(\d+) offset \$(\d+)/i,
);
// Clamped
expect(last.values?.[0]).toBe(2000);
Expand Down
174 changes: 174 additions & 0 deletions tests/e2e_nl_sql.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import { it, expect } from "vitest";
import { runQuery } from "../db.ts";
import { buildSystemPrompt } from "../prompt.ts";

// AI gateway settings, read once from the environment when the module loads.
const {
  AI_GATEWAY_API_KEY: GATEWAY_KEY,
  AI_GATEWAY_BASE_URL: GATEWAY_BASE,
  AI_GATEWAY_MODEL: GATEWAY_MODEL,
} = process.env;

// The E2E suite needs a database URL plus all three gateway settings;
// an unset or empty value for any of them disables it.
const HAS_E2E = [
  process.env.DATABASE_URL,
  GATEWAY_KEY,
  GATEWAY_BASE,
  GATEWAY_MODEL,
].every(Boolean);

// Register tests normally when configured, otherwise register them as skipped.
const itE2E = HAS_E2E ? it : it.skip;

/**
 * Ask the AI gateway to turn a natural-language request into one SQL statement.
 *
 * Posts the system prompt from `buildSystemPrompt()` (plus extra formatting
 * instructions) and the user prompt to `<GATEWAY_BASE>/chat/completions`, then
 * extracts the SQL from the first fenced block marked `sql` in the reply.
 * If the reply also contains a fenced block marked `json` holding either an
 * array or an object with a `params` array, that array is returned as the
 * positional parameters; otherwise `params` is empty.
 *
 * The SQL may still contain `$n` placeholders with no matching params; callers
 * are expected to supply fallback params per test in that case.
 *
 * @param prompt - Natural-language request sent as the user message.
 * @returns The trimmed SQL text and any parameters found in the reply.
 * @throws Error `gateway_error: ...` when the HTTP response is not ok.
 * @throws Error `no_sql_block_found` when the reply has no `sql` fenced block.
 * @throws Error `not_select_sql: ...` when the SQL does not start with SELECT or WITH.
 */
async function nlToSql(
  prompt: string,
): Promise<{ sql: string; params: unknown[] }> {
  const system =
    buildSystemPrompt() +
    "\n\nAdditional instructions:" +
    "\n- Only respond with a single SQL statement." +
    "\n- Output the SQL inside a fenced code block marked 'sql'." +
    "\n- Do not include explanations." +
    "\n- Do NOT use parameter placeholders like $1, $2. Inline literal values (with proper quoting) directly in the SQL.";

  // Strip trailing slashes so a base URL such as ".../v1/" does not produce
  // ".../v1//chat/completions".
  const endpoint = `${String(GATEWAY_BASE).replace(/\/+$/, "")}/chat/completions`;

  const res = await fetch(endpoint, {
    method: "POST",
    headers: {
      "content-type": "application/json",
      authorization: `Bearer ${GATEWAY_KEY}`,
    },
    body: JSON.stringify({
      model: GATEWAY_MODEL,
      temperature: 0,
      messages: [
        { role: "system", content: system },
        { role: "user", content: prompt },
      ],
    }),
  });

  if (!res.ok) {
    // Best effort: include the response body in the error when it is readable.
    const text = await res.text().catch(() => "");
    throw new Error(`gateway_error: ${res.status} ${res.statusText} ${text}`);
  }

  // Only the fields read below are described; everything else is ignored.
  const data = (await res.json()) as {
    choices?: Array<{ message?: { content?: unknown } }>;
  } | null;
  const rawContent = data?.choices?.[0]?.message?.content;
  // Non-string content is treated like an empty reply, so it surfaces as
  // `no_sql_block_found` instead of a TypeError from `.match`.
  const content = typeof rawContent === "string" ? rawContent : "";

  const sqlMatch = content.match(/```sql\s*([\s\S]*?)```/i);
  const jsonMatch = content.match(/```json\s*([\s\S]*?)```/i);
  if (!sqlMatch) throw new Error("no_sql_block_found");
  const sql = sqlMatch[1].trim();

  let params: unknown[] = [];
  if (jsonMatch) {
    try {
      const parsed: unknown = JSON.parse(jsonMatch[1]);
      if (Array.isArray(parsed)) {
        params = parsed;
      } else if (typeof parsed === "object" && parsed !== null) {
        const nested = (parsed as { params?: unknown }).params;
        if (Array.isArray(nested)) params = nested;
      }
    } catch {
      // Malformed JSON block: deliberately fall back to "no params".
      params = [];
    }
  }

  // Only read-only statements (SELECT or a WITH query) are accepted.
  if (!/^\s*with\s+|^\s*select\s+/i.test(sql)) {
    throw new Error(`not_select_sql: ${sql.slice(0, 160)}`);
  }
  return { sql, params };
}

// Per-test timeout in milliseconds, passed as the final argument of each E2E test.
const T = 60 * 1000;

// Fixed timestamp window (ISO-8601, "Z" suffix) used by the field_changes count test.
const A_START = "2025-01-10T00:00:00Z";
const A_END = "2025-01-12T23:59:59Z";

// E2E 1: count field changes in window
itE2E(
  "e2e: Proj A field_changes in fixed window = 4",
  async () => {
    const { sql, params } = await nlToSql(
      `MUST use table field_changes. MUST filter project_name = 'Proj A'. MUST restrict changed_at between '${A_START}' and '${A_END}'. For project \"Proj A\", how many field changes occurred between ${A_START} and ${A_END}? Return a single row with a numeric count.`,
    );

    // Highest $n placeholder referenced by the generated SQL (0 when none).
    // The previous check used /\\$\\d+/, which looks for a literal backslash
    // and can never match "$1", so the fallback below was never applied.
    const maxPlaceholder = Math.max(
      0,
      ...[...sql.matchAll(/\$(\d+)/g)].map((m) => Number(m[1]) || 0),
    );

    // The model is told to inline literals. If it emitted placeholders anyway
    // without returning params, bind fallback values in prompt order, passing
    // only as many as the SQL actually references.
    const p: unknown[] =
      params.length === 0 && maxPlaceholder > 0
        ? ["Proj A", A_START, A_END].slice(0, maxPlaceholder)
        : params;

    // Debugging: log the SQL query and the params that are actually bound
    console.log(`Running SQL: ${sql}`, p);
    const res = await runQuery({
      sql,
      params: p,
      limit: 2000,
      timeoutMs: 60000,
    });

    // Prefer a numeric column from the first row. Otherwise accept a bigint or
    // an all-digits string. Otherwise fall back to the number of returned rows.
    // NOTE(review): node-postgres returns int8 (e.g. count(*)) as a string by
    // default; confirm whether db.ts installs a type parser for it.
    const values = Object.values(res.rows?.[0] ?? {});
    const numeric =
      values.find((v) => typeof v === "number") ??
      values.find(
        (v) =>
          typeof v === "bigint" || (typeof v === "string" && /^\d+$/.test(v)),
      );
    const count = numeric !== undefined ? Number(numeric) : res.rowCount;
    expect(count).toBe(4);
  },
  T,
);

// E2E 2: current status for ITEM_A_1
itE2E(
  "e2e: Proj A ITEM_A_1 Status is Done",
  async () => {
    const { sql, params } = await nlToSql(
      `MUST use table current_field_values. MUST filter project_name = 'Proj A' AND item_node_id = 'ITEM_A_1' AND field_name = 'Status'. Return a single row with only the status value.`,
    );

    // Highest $n placeholder referenced by the generated SQL (0 when none).
    // The previous check used /\\$\\d+/, which looks for a literal backslash
    // and can never match "$1", so the fallback below was never applied.
    const maxPlaceholder = Math.max(
      0,
      ...[...sql.matchAll(/\$(\d+)/g)].map((m) => Number(m[1]) || 0),
    );

    // If the model emitted placeholders without returning params, bind
    // fallback values in the order the filters appear in the prompt, passing
    // only as many as the SQL actually references.
    const p: unknown[] =
      params.length === 0 && maxPlaceholder > 0
        ? ["Proj A", "ITEM_A_1", "Status"].slice(0, maxPlaceholder)
        : params;

    console.log(`Running SQL: ${sql}`, p);
    const res = await runQuery({
      sql,
      params: p,
      limit: 50,
      // runQuery clamps the statement timeout to at most 60000 ms, so a
      // larger value here would have no effect.
      timeoutMs: 60000,
    });
    console.log(`Query result:`, res); // additional debug

    // The status is taken from the first string-valued column of the first row.
    const firstRow = res.rows?.[0] ?? {};
    const textVal = Object.values(firstRow).find(
      (v): v is string => typeof v === "string",
    );
    expect(textVal).toBe("Done");
  },
  T,
);

// E2E 3: deletion events list
itE2E(
  "e2e: Proj A has one deletion event",
  async () => {
    const { sql, params } = await nlToSql(
      `MUST use table field_changes. MUST filter project_name = 'Proj A' AND field_name = '_item_deleted'. Return old_value and new_value columns only.`,
    );

    // Highest $n placeholder referenced by the generated SQL (0 when none).
    // The previous check used /\\$\\d+/, which looks for a literal backslash
    // and can never match "$1", so the fallback below was never applied.
    const maxPlaceholder = Math.max(
      0,
      ...[...sql.matchAll(/\$(\d+)/g)].map((m) => Number(m[1]) || 0),
    );

    // If the model emitted placeholders without returning params, bind
    // fallback values in the order the filters appear in the prompt, passing
    // only as many as the SQL actually references.
    const p: unknown[] =
      params.length === 0 && maxPlaceholder > 0
        ? ["Proj A", "_item_deleted"].slice(0, maxPlaceholder)
        : params;

    console.log(`Running SQL: ${sql}`, p);
    const res = await runQuery({
      sql,
      params: p,
      limit: 50,
      // runQuery clamps the statement timeout to at most 60000 ms, so a
      // larger value here would have no effect.
      timeoutMs: 60000,
    });
    console.debug(`Deletion check result:`, res); // additional debug
    expect(res.rowCount).toBeGreaterThanOrEqual(1);

    // At least one row must have old_value === true and a null/undefined new_value.
    const ok = res.rows.some(
      (r) => r?.old_value === true && r?.new_value == null,
    );
    expect(ok).toBe(true);
  },
  T,
);
Loading