diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index a733f48659..3d80a6fd42 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -14,8 +14,9 @@ * limitations under the License. */ -import { z, type EvaluatorAction, type Genkit } from 'genkit'; +import { z, type EvaluatorAction } from 'genkit'; import type { BaseEvalDataPoint } from 'genkit/evaluator'; +import { evaluator } from 'genkit/plugin'; import { runInNewSpan } from 'genkit/tracing'; import type { GoogleAuth } from 'google-auth-library'; import { @@ -24,8 +25,7 @@ import { type ChecksEvaluationMetricConfig, } from './metrics'; -export function checksEvaluators( - ai: Genkit, +export function checksEvaluator( auth: GoogleAuth, metrics: ChecksEvaluationMetric[], projectId: string @@ -42,7 +42,7 @@ export function checksEvaluators( } ); - return createPolicyEvaluator(projectId, auth, ai, policy_configs); + return createPolicyEvaluator(projectId, auth, policy_configs); } const ResponseSchema = z.object({ @@ -58,10 +58,9 @@ const ResponseSchema = z.object({ function createPolicyEvaluator( projectId: string, auth: GoogleAuth, - ai: Genkit, policy_config: ChecksEvaluationMetricConfig[] ): EvaluatorAction { - return ai.defineEvaluator( + return evaluator( { name: 'checks/guardrails', displayName: 'checks/guardrails', @@ -83,7 +82,6 @@ function createPolicyEvaluator( }; const response = await checksEvalInstance( - ai, projectId, auth, partialRequest, @@ -109,14 +107,12 @@ function createPolicyEvaluator( } async function checksEvalInstance( - ai: Genkit, projectId: string, auth: GoogleAuth, partialRequest: any, responseSchema: ResponseType ): Promise> { return await runInNewSpan( - ai, { metadata: { name: 'EvaluationService#evaluateInstances', diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index 93a3a6506d..d6dd9bf3b5 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -14,12 +14,11 @@ * limitations under the License. */ -import type { Genkit } from 'genkit'; import { logger } from 'genkit/logging'; import type { ModelMiddleware } from 'genkit/model'; -import { genkitPlugin, type GenkitPlugin } from 'genkit/plugin'; +import { genkitPluginV2, type GenkitPluginV2 } from 'genkit/plugin'; import { GoogleAuth, type GoogleAuthOptions } from 'google-auth-library'; -import { checksEvaluators } from './evaluation.js'; +import { checksEvaluator } from './evaluation.js'; import { ChecksEvaluationMetricType, type ChecksEvaluationMetric, @@ -44,26 +43,52 @@ const CLOUD_PLATFROM_OAUTH_SCOPE = const CHECKS_OAUTH_SCOPE = 'https://www.googleapis.com/auth/checks'; -/** - * Add Google Checks evaluators. - */ -export function checks(options?: PluginOptions): GenkitPlugin { - return genkitPlugin('checks', async (ai: Genkit) => { - const googleAuth = inititializeAuth(options?.googleAuthOptions); +export async function getProjectId( + googleAuth: GoogleAuth, + options?: PluginOptions +): Promise { + const projectId = options?.projectId || (await googleAuth.getProjectId()); - const projectId = options?.projectId || (await googleAuth.getProjectId()); + if (!projectId) { + throw new Error( + `Checks Plugin is missing the 'projectId' configuration. Please set the 'GCLOUD_PROJECT' environment variable or explicitly pass 'projectId' into Genkit config.` + ); + } - if (!projectId) { - throw new Error( - `Checks Plugin is missing the 'projectId' configuration. Please set the 'GCLOUD_PROJECT' environment variable or explicitly pass 'projectId' into genkit config.` - ); - } + return projectId; +} - const metrics = - options?.evaluation && options.evaluation.metrics.length > 0 - ? options.evaluation.metrics - : []; - checksEvaluators(ai, googleAuth, metrics, projectId); +/** + * Add Google Checks evaluators. + */ +export function checks(options?: PluginOptions): GenkitPluginV2 { + const googleAuth = inititializeAuth(options?.googleAuthOptions); + + const metrics = + options?.evaluation && options.evaluation.metrics.length > 0 + ? options.evaluation.metrics + : []; + + return genkitPluginV2({ + name: 'checks', + init: async () => { + return [ + checksEvaluator( + googleAuth, + metrics, + await getProjectId(googleAuth, options) + ), + ]; + }, + list: async () => { + return [ + checksEvaluator( + googleAuth, + metrics, + await getProjectId(googleAuth, options) + ).__action, + ]; + }, }); }