
Commit 3f5ebab

pashankagsiddh authored and committed
Complete Hybrid inference impl
Fix languageCode parameter in action_code_url (#8912)
* Fix languageCode parameter in action_code_url
* Add changeset

Vaihi add langmodel types. (#8927)
* Adding LanguageModel types. These are based off https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl
* Adding LanguageModel types.
* Remove bunch of exports
* yarn formatted
* after lint

Define HybridParams (#8935)
Co-authored-by: Erik Eldridge <[email protected]>

Adding smoke test for new hybrid params (#8937)
* Adding smoke test for new hybrid params
* Use the existing name of the model params input
---------
Co-authored-by: Erik Eldridge <[email protected]>

Moving to in-cloud naming (#8938)
Co-authored-by: Erik Eldridge <[email protected]>

Moving to string type for the inference mode (#8941)

Define ChromeAdapter class (#8942)
Co-authored-by: Erik Eldridge <[email protected]>

VinF Hybrid Inference: Implement ChromeAdapter (rebased) (#8943)

Adding count token impl (#8950)

VinF Hybrid Inference #4: ChromeAdapter in stream methods (rebased) (#8949)

Define values for Availability enum (#8951)

VinF Hybrid Inference: narrow Chrome input type (#8953)

Add image inference support (#8954)
* Adding image based input for inference
* adding image as input to create language model object

disable count tokens api for on-device inference (#8962)

VinF Hybrid Inference: throw if only_on_device and model is unavailable (#8965)
1 parent 475c81a commit 3f5ebab

16 files changed (+1216 −102 lines)
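Pieced together from the call sites in the diffs below, the hybrid options that getGenerativeModel now accepts look roughly like the following sketch. It is inferred, not copied from the committed type definitions; in particular, the onDeviceParams type is an assumption (in this excerpt it is only ever forwarded to ChromeAdapter).

// Sketch of the hybrid-inference options, inferred from api.ts and api.test.ts below.
import { ModelParams } from 'firebase/vertexai';

// Per "Moving to string type for the inference mode (#8941)" the mode is a plain
// string; 'only_on_device' and 'prefer_on_device' are the values exercised below.
type InferenceMode = string;

interface HybridParams {
  mode: InferenceMode;          // presence of `mode` is what marks the params as hybrid
  inCloudParams?: ModelParams;  // falls back to GenerativeModel.DEFAULT_HYBRID_IN_CLOUD_MODEL
  onDeviceParams?: unknown;     // handed to ChromeAdapter for the on-device model (assumed type)
}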

e2e/sample-apps/modular.js (+21 −15)

@@ -58,7 +58,7 @@ import {
   onValue,
   off
 } from 'firebase/database';
-import { getGenerativeModel, getVertexAI, VertexAI } from 'firebase/vertexai';
+import { getGenerativeModel, getVertexAI } from 'firebase/vertexai';
 import { getDataConnect, DataConnect } from 'firebase/data-connect';
 
 /**
@@ -313,9 +313,15 @@ function callPerformance(app) {
 async function callVertexAI(app) {
   console.log('[VERTEXAI] start');
   const vertexAI = getVertexAI(app);
-  const model = getGenerativeModel(vertexAI, { model: 'gemini-1.5-flash' });
-  const result = await model.countTokens('abcdefg');
-  console.log(`[VERTEXAI] counted tokens: ${result.totalTokens}`);
+  const model = getGenerativeModel(vertexAI, {
+    mode: 'only_on_device'
+  });
+  const singleResult = await model.generateContent([
+    { text: 'describe the following:' },
+    { text: 'the mojave desert' }
+  ]);
+  console.log(`Generated text: ${singleResult.response.text()}`);
+  console.log(`[VERTEXAI] end`);
 }
 
 /**
@@ -341,18 +347,18 @@ async function main() {
   const app = initializeApp(config);
   setLogLevel('warn');
 
-  callAppCheck(app);
-  await authLogin(app);
-  await callStorage(app);
-  await callFirestore(app);
-  await callDatabase(app);
-  await callMessaging(app);
-  callAnalytics(app);
-  callPerformance(app);
-  await callFunctions(app);
+  // callAppCheck(app);
+  // await authLogin(app);
+  // await callStorage(app);
+  // await callFirestore(app);
+  // await callDatabase(app);
+  // await callMessaging(app);
+  // callAnalytics(app);
+  // callPerformance(app);
+  // await callFunctions(app);
   await callVertexAI(app);
-  callDataConnect(app);
-  await authLogout(app);
+  // callDataConnect(app);
+  // await authLogout(app);
   console.log('DONE');
 }
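The sample now exercises only the 'only_on_device' path, which, per "VinF Hybrid Inference: throw if only_on_device and model is unavailable (#8965)", is expected to throw when the browser has no on-device model. For a run that should still succeed in that case, a 'prefer_on_device' variant along these lines is a plausible alternative; this is a sketch under the assumption that 'prefer_on_device' falls back to the in-cloud model, which the diff above does not itself show.

import { FirebaseApp } from 'firebase/app';
import { getGenerativeModel, getVertexAI } from 'firebase/vertexai';

// Hypothetical 'prefer_on_device' variant of callVertexAI: assumed to use the
// on-device model when Chrome exposes one and the default in-cloud model otherwise.
async function callVertexAIPreferOnDevice(app: FirebaseApp): Promise<void> {
  console.log('[VERTEXAI] start');
  const vertexAI = getVertexAI(app);
  const model = getGenerativeModel(vertexAI, { mode: 'prefer_on_device' });
  const result = await model.generateContent([{ text: 'describe the mojave desert' }]);
  console.log(`Generated text: ${result.response.text()}`);
  console.log('[VERTEXAI] end');
}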

packages/vertexai/src/api.test.ts (+15)

@@ -101,6 +101,21 @@ describe('Top level API', () => {
     expect(genModel).to.be.an.instanceOf(GenerativeModel);
     expect(genModel.model).to.equal('publishers/google/models/my-model');
   });
+  it('getGenerativeModel with HybridParams sets a default model', () => {
+    const genModel = getGenerativeModel(fakeAI, {
+      mode: 'only_on_device'
+    });
+    expect(genModel.model).to.equal(
+      `publishers/google/models/${GenerativeModel.DEFAULT_HYBRID_IN_CLOUD_MODEL}`
+    );
+  });
+  it('getGenerativeModel with HybridParams honors a model override', () => {
+    const genModel = getGenerativeModel(fakeAI, {
+      mode: 'prefer_on_device',
+      inCloudParams: { model: 'my-model' }
+    });
+    expect(genModel.model).to.equal('publishers/google/models/my-model');
+  });
   it('getImagenModel throws if no model is provided', () => {
     try {
       getImagenModel(fakeAI, {} as ImagenModelParams);

packages/vertexai/src/api.ts (+26 −3)

@@ -23,6 +23,7 @@ import { AIService } from './service';
 import { AI, AIOptions, VertexAI, VertexAIOptions } from './public-types';
 import {
   ImagenModelParams,
+  HybridParams,
   ModelParams,
   RequestOptions,
   AIErrorCode
@@ -31,6 +32,8 @@ import { AIError } from './errors';
 import { AIModel, GenerativeModel, ImagenModel } from './models';
 import { encodeInstanceIdentifier } from './helpers';
 import { GoogleAIBackend, VertexAIBackend } from './backend';
+import { ChromeAdapter } from './methods/chrome-adapter';
+import { LanguageModel } from './types/language-model';
 
 export { ChatSession } from './methods/chat-session';
 export * from './requests/schema-builder';
@@ -138,16 +141,36 @@ export function getAI(
  */
 export function getGenerativeModel(
   ai: AI,
-  modelParams: ModelParams,
+  modelParams: ModelParams | HybridParams,
   requestOptions?: RequestOptions
 ): GenerativeModel {
-  if (!modelParams.model) {
+  // Uses the existence of HybridParams.mode to clarify the type of the modelParams input.
+  const hybridParams = modelParams as HybridParams;
+  let inCloudParams: ModelParams;
+  if (hybridParams.mode) {
+    inCloudParams = hybridParams.inCloudParams || {
+      model: GenerativeModel.DEFAULT_HYBRID_IN_CLOUD_MODEL
+    };
+  } else {
+    inCloudParams = modelParams as ModelParams;
+  }
+
+  if (!inCloudParams.model) {
     throw new AIError(
       AIErrorCode.NO_MODEL,
       `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`
     );
   }
-  return new GenerativeModel(ai, modelParams, requestOptions);
+  return new GenerativeModel(
+    ai,
+    inCloudParams,
+    new ChromeAdapter(
+      window.LanguageModel as LanguageModel,
+      hybridParams.mode,
+      hybridParams.onDeviceParams
+    ),
+    requestOptions
+  );
 }
 
 /**
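getGenerativeModel now uses the presence of mode to distinguish HybridParams from ModelParams, substitutes DEFAULT_HYBRID_IN_CLOUD_MODEL when inCloudParams is omitted, and constructs a ChromeAdapter around window.LanguageModel, the Chrome Prompt API entry point the new LanguageModel types are modeled on. chrome-adapter.ts is among the 16 changed files but is not shown in this excerpt; from this call site alone, its constructor presumably looks something like the sketch below (parameter names and optionality are assumptions).

// Stand-in for the type added in packages/vertexai/src/types/language-model.ts.
type LanguageModel = unknown;

// Assumed constructor shape of the new ChromeAdapter, inferred only from the
// call in getGenerativeModel above.
class ChromeAdapter {
  constructor(
    private languageModelProvider?: LanguageModel, // window.LanguageModel, when the browser exposes it
    private mode?: string,                         // inference mode, e.g. 'only_on_device' or 'prefer_on_device'
    private onDeviceParams?: unknown               // options used to create the on-device model (assumed)
  ) {}
}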

packages/vertexai/src/methods/chat-session.test.ts (+16 −3)

@@ -24,6 +24,7 @@ import { GenerateContentStreamResult } from '../types';
 import { ChatSession } from './chat-session';
 import { ApiSettings } from '../types/internal';
 import { VertexAIBackend } from '../backend';
+import { ChromeAdapter } from './chrome-adapter';
 
 use(sinonChai);
 use(chaiAsPromised);
@@ -46,7 +47,11 @@ describe('ChatSession', () => {
       generateContentMethods,
       'generateContent'
     ).rejects('generateContent failed');
-    const chatSession = new ChatSession(fakeApiSettings, 'a-model');
+    const chatSession = new ChatSession(
+      fakeApiSettings,
+      'a-model',
+      new ChromeAdapter()
+    );
     await expect(chatSession.sendMessage('hello')).to.be.rejected;
     expect(generateContentStub).to.be.calledWith(
       fakeApiSettings,
@@ -63,7 +68,11 @@ describe('ChatSession', () => {
       generateContentMethods,
       'generateContentStream'
     ).rejects('generateContentStream failed');
-    const chatSession = new ChatSession(fakeApiSettings, 'a-model');
+    const chatSession = new ChatSession(
+      fakeApiSettings,
+      'a-model',
+      new ChromeAdapter()
+    );
     await expect(chatSession.sendMessageStream('hello')).to.be.rejected;
     expect(generateContentStreamStub).to.be.calledWith(
       fakeApiSettings,
@@ -82,7 +91,11 @@ describe('ChatSession', () => {
       generateContentMethods,
       'generateContentStream'
     ).resolves({} as unknown as GenerateContentStreamResult);
-    const chatSession = new ChatSession(fakeApiSettings, 'a-model');
+    const chatSession = new ChatSession(
+      fakeApiSettings,
+      'a-model',
+      new ChromeAdapter()
+    );
     await chatSession.sendMessageStream('hello');
     expect(generateContentStreamStub).to.be.calledWith(
       fakeApiSettings,

packages/vertexai/src/methods/chat-session.ts (+4)

@@ -30,6 +30,7 @@ import { validateChatHistory } from './chat-session-helpers';
 import { generateContent, generateContentStream } from './generate-content';
 import { ApiSettings } from '../types/internal';
 import { logger } from '../logger';
+import { ChromeAdapter } from './chrome-adapter';
 
 /**
  * Do not log a message for this error.
@@ -50,6 +51,7 @@ export class ChatSession {
   constructor(
     apiSettings: ApiSettings,
     public model: string,
+    private chromeAdapter: ChromeAdapter,
     public params?: StartChatParams,
     public requestOptions?: RequestOptions
   ) {
@@ -95,6 +97,7 @@
           this._apiSettings,
           this.model,
           generateContentRequest,
+          this.chromeAdapter,
           this.requestOptions
         )
       )
@@ -146,6 +149,7 @@
       this._apiSettings,
       this.model,
       generateContentRequest,
+      this.chromeAdapter,
       this.requestOptions
     );
 
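ChatSession now threads its ChromeAdapter into both generateContent and generateContentStream. Those functions live in generate-content.ts, which this commit also changes but which is not shown in this excerpt; judging from the squashed commit titles, they presumably ask the adapter whether on-device inference can serve the request and fall back to the in-cloud path otherwise, roughly as sketched below. The isAvailable method name on ChromeAdapter and both helper functions are assumptions, not the committed implementation.

import {
  GenerateContentRequest,
  GenerateContentResult,
  RequestOptions
} from '../types';
import { ApiSettings } from '../types/internal';
import { ChromeAdapter } from './chrome-adapter';

// Hypothetical helpers standing in for the on-device and in-cloud request paths.
declare function generateContentOnDevice(
  chromeAdapter: ChromeAdapter,
  params: GenerateContentRequest
): Promise<GenerateContentResult>;
declare function generateContentOnCloud(
  apiSettings: ApiSettings,
  model: string,
  params: GenerateContentRequest,
  requestOptions?: RequestOptions
): Promise<GenerateContentResult>;

// Sketch of the dispatch generate-content.ts could perform with the adapter it
// now receives; the parameter order matches the call sites in chat-session.ts above.
export async function generateContent(
  apiSettings: ApiSettings,
  model: string,
  params: GenerateContentRequest,
  chromeAdapter: ChromeAdapter,
  requestOptions?: RequestOptions
): Promise<GenerateContentResult> {
  // isAvailable() is assumed here; the real adapter API is defined in chrome-adapter.ts.
  if (await chromeAdapter.isAvailable(params)) {
    return generateContentOnDevice(chromeAdapter, params);
  }
  return generateContentOnCloud(apiSettings, model, params, requestOptions);
}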
