Skip to content

Commit 4ef0a07

Browse files
committed
✨ Added ability to modify Safety Settings
1 parent 7b3f41f commit 4ef0a07

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

src/index.ts

+36-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Command } from "./types";
1+
import { Command, HarmCategory, SafetyThreshold } from "./types";
22

33
import type {
44
ChatAskOptions,
@@ -15,7 +15,7 @@ import type {
1515
QueryResponseMap,
1616
} from "./types";
1717

18-
import { getFileType, handleReader, pairToMessage } from "./utils";
18+
import { SafetyError, getFileType, handleReader, pairToMessage } from "./utils";
1919

2020
const uploadFile = async ({
2121
file,
@@ -162,6 +162,7 @@ class Gemini {
162162

163163
static TEXT = "text" as const;
164164
static JSON = "json" as const;
165+
static SafetyThreshold = SafetyThreshold;
165166

166167
constructor(key: string, options: Partial<GeminiOptions> = {}) {
167168
if (!options.fetch && typeof fetch !== "function") {
@@ -281,7 +282,7 @@ class Gemini {
281282
<F extends Format>(format: F = Gemini.TEXT as F) =>
282283
(response: GeminiResponse): FormatType<F> => {
283284
if (response.candidates[0].finishReason === "SAFETY") {
284-
throw new Error(
285+
throw new SafetyError(
285286
`Your prompt was blocked by Google. Here are the Harm Categories: \n${JSON.stringify(
286287
response.candidates[0].safetyRatings,
287288
null,
@@ -337,10 +338,35 @@ class Gemini {
337338
maxOutputTokens: 2048,
338339
data: [],
339340
messages: [],
341+
safetySettings: {
342+
hate: Gemini.SafetyThreshold.BLOCK_SOME,
343+
sexual: Gemini.SafetyThreshold.BLOCK_SOME,
344+
harassment: Gemini.SafetyThreshold.BLOCK_SOME,
345+
dangerous: Gemini.SafetyThreshold.BLOCK_SOME,
346+
},
340347
},
341348
...options,
342349
};
343350

351+
const safety_settings = [
352+
{
353+
category: HarmCategory.HateSpeech,
354+
threshold: parsedOptions.safetySettings.hate,
355+
},
356+
{
357+
category: HarmCategory.SexuallyExplicit,
358+
threshold: parsedOptions.safetySettings.sexual,
359+
},
360+
{
361+
category: HarmCategory.Harassment,
362+
threshold: parsedOptions.safetySettings.harassment,
363+
},
364+
{
365+
category: HarmCategory.DangerousContent,
366+
threshold: parsedOptions.safetySettings.dangerous,
367+
},
368+
];
369+
344370
const command = parsedOptions.stream
345371
? Command.StreamGenerate
346372
: Command.Generate;
@@ -378,6 +404,7 @@ class Gemini {
378404
topP: parsedOptions.topP,
379405
topK: parsedOptions.topK,
380406
},
407+
safety_settings,
381408
};
382409

383410
const response: Response = await this.query(
@@ -440,6 +467,12 @@ class Chat {
440467
...{
441468
data: [],
442469
format: Gemini.TEXT as F,
470+
safetySettings: {
471+
hate: Gemini.SafetyThreshold.BLOCK_SOME,
472+
sexual: Gemini.SafetyThreshold.BLOCK_SOME,
473+
harassment: Gemini.SafetyThreshold.BLOCK_SOME,
474+
dangerous: Gemini.SafetyThreshold.BLOCK_SOME,
475+
},
443476
},
444477
...options,
445478
};

src/types.ts

+25
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ export enum Command {
6060
Count = "countTokens",
6161
}
6262

63+
export enum HarmCategory {
64+
HateSpeech = "HARM_CATEGORY_HATE_SPEECH",
65+
SexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT",
66+
Harassment = "HARM_CATEGORY_HARASSMENT",
67+
DangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT",
68+
}
69+
6370
/**
6471
* The body used for the API call to generateContent or streamGenerateContent
6572
*/
@@ -71,6 +78,7 @@ type GenerateContentBody = {
7178
topP: number;
7279
topK: number;
7380
};
81+
safety_settings: { category: HarmCategory; threshold: SafetyThreshold }[];
7482
};
7583

7684
/**
@@ -143,6 +151,12 @@ export type CommandOptionMap<F extends Format = TextFormat> = {
143151
maxOutputTokens: number;
144152
model: string;
145153
data: Buffer[];
154+
safetySettings: {
155+
hate: SafetyThreshold;
156+
sexual: SafetyThreshold;
157+
harassment: SafetyThreshold;
158+
dangerous: SafetyThreshold;
159+
};
146160
messages: ([string, string] | Message)[];
147161
stream?(stream: CommandResponseMap<F>[Command.StreamGenerate]): void;
148162
};
@@ -154,6 +168,17 @@ export type CommandOptionMap<F extends Format = TextFormat> = {
154168
};
155169
};
156170

171+
export enum SafetyThreshold {
172+
// Content with NEGLIGIBLE will be allowed.
173+
BLOCK_MOST = "BLOCK_LOW_AND_ABOVE",
174+
// Content with NEGLIGIBLE and LOW will be allowed.
175+
BLOCK_SOME = "BLOCK_MEDIUM_AND_ABOVE",
176+
// Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
177+
BLOCK_FEW = "BLOCK_ONLY_HIGH",
178+
// All content will be allowed.
179+
BLOCK_NONE = "BLOCK_NONE",
180+
}
181+
157182
export type FormatType<T> = T extends JSONFormat ? GeminiResponse : string;
158183

159184
export type ChatOptions = {

0 commit comments

Comments
 (0)