Skip to content

Commit 98f36d3

Browse files
committed
Merge branch 'main' of https://github.com/Akshay-66/gateway
2 parents be1e892 + 08121d6 commit 98f36d3

24 files changed

+572
-95
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ Make your AI app more <ins>reliable</ins> and <ins>forward compatible</ins>, whi
264264
&nbsp; Secure Key Management - for role-based access control and tracking <br>
265265
&nbsp; Simple & Semantic Caching - to serve repeat queries faster & save costs <br>
266266
&nbsp; Access Control & Inbound Rules - to control which IPs and Geos can connect to your deployments <br>
267-
&nbsp; PII Redaction - to automatically remove sensitive data from your requests to prevent indavertent exposure <br>
267+
&nbsp; PII Redaction - to automatically remove sensitive data from your requests to prevent inadvertent exposure <br>
268268
&nbsp; SOC2, ISO, HIPAA, GDPR Compliances - for best security practices <br>
269269
&nbsp; Professional Support - along with feature prioritization <br>
270270

plugins/Medical-Advice-Detection/jest.config.js

-6
This file was deleted.

src/globals.ts

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export const HEADER_KEYS: Record<string, string> = {
1313
CUSTOM_HOST: `x-${POWERED_BY}-custom-host`,
1414
REQUEST_TIMEOUT: `x-${POWERED_BY}-request-timeout`,
1515
STRICT_OPEN_AI_COMPLIANCE: `x-${POWERED_BY}-strict-open-ai-compliance`,
16+
CONTENT_TYPE: `Content-Type`,
1617
};
1718

1819
export const RESPONSE_HEADER_KEYS: Record<string, string> = {

src/handlers/handlerUtils.ts

+22-4
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import {
1111
OPEN_AI,
1212
AZURE_AI_INFERENCE,
1313
ANTHROPIC,
14-
MULTIPART_FORM_DATA_ENDPOINTS,
1514
CONTENT_TYPES,
1615
HUGGING_FACE,
16+
STABILITY_AI,
1717
} from '../globals';
1818
import Providers from '../providers';
1919
import { ProviderAPIConfig, endpointStrings } from '../providers/types';
@@ -524,6 +524,7 @@ export async function tryPost(
524524
fn,
525525
transformedRequestBody,
526526
transformedRequestUrl: url,
527+
gatewayRequestBody: params,
527528
});
528529

529530
// Construct the base object for the POST request
@@ -535,9 +536,10 @@ export async function tryPost(
535536
requestHeaders
536537
);
537538

538-
fetchOptions.body = MULTIPART_FORM_DATA_ENDPOINTS.includes(fn)
539-
? (transformedRequestBody as FormData)
540-
: JSON.stringify(transformedRequestBody);
539+
fetchOptions.body =
540+
headers[HEADER_KEYS.CONTENT_TYPE] === CONTENT_TYPES.MULTIPART_FORM_DATA
541+
? (transformedRequestBody as FormData)
542+
: JSON.stringify(transformedRequestBody);
541543

542544
providerOption.retry = {
543545
attempts: providerOption.retry?.attempts ?? 0,
@@ -1012,6 +1014,14 @@ export function constructConfigFromRequestHeaders(
10121014
azureModelName: requestHeaders[`x-${POWERED_BY}-azure-model-name`],
10131015
};
10141016

1017+
const stabilityAiConfig = {
1018+
stabilityClientId: requestHeaders[`x-${POWERED_BY}-stability-client-id`],
1019+
stabilityClientUserId:
1020+
requestHeaders[`x-${POWERED_BY}-stability-client-user-id`],
1021+
stabilityClientVersion:
1022+
requestHeaders[`x-${POWERED_BY}-stability-client-version`],
1023+
};
1024+
10151025
const azureAiInferenceConfig = {
10161026
azureDeploymentName:
10171027
requestHeaders[`x-${POWERED_BY}-azure-deployment-name`],
@@ -1128,6 +1138,12 @@ export function constructConfigFromRequestHeaders(
11281138
...anthropicConfig,
11291139
};
11301140
}
1141+
if (parsedConfigJson.provider === STABILITY_AI) {
1142+
parsedConfigJson = {
1143+
...parsedConfigJson,
1144+
...stabilityAiConfig,
1145+
};
1146+
}
11311147
}
11321148
return convertKeysToCamelCase(parsedConfigJson, [
11331149
'override_params',
@@ -1158,6 +1174,8 @@ export function constructConfigFromRequestHeaders(
11581174
huggingfaceConfig),
11591175
mistralFimCompletion:
11601176
requestHeaders[`x-${POWERED_BY}-mistral-fim-completion`],
1177+
...(requestHeaders[`x-${POWERED_BY}-provider`] === STABILITY_AI &&
1178+
stabilityAiConfig),
11611179
};
11621180
}
11631181

src/providers/bedrock/constants.ts

+5
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,8 @@ export const MISTRAL_CONTROL_TOKENS = {
3434
MIDDLE: '[MIDDLE]',
3535
SUFFIX: '[SUFFIX]',
3636
};
37+
38+
export const BEDROCK_STABILITY_V1_MODELS = [
39+
'stable-diffusion-xl-v0',
40+
'stable-diffusion-xl-v1',
41+
];

src/providers/bedrock/imageGenerate.ts

+39-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import { BEDROCK } from '../../globals';
2+
import { StabilityAIImageGenerateV2Config } from '../stability-ai/imageGenerateV2';
23
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types';
34
import { generateInvalidProviderResponseError } from '../utils';
45
import { BedrockErrorResponseTransform } from './chatComplete';
56
import { BedrockErrorResponse } from './embed';
67

7-
export const BedrockStabilityAIImageGenerateConfig: ProviderConfig = {
8+
export const BedrockStabilityAIImageGenerateV1Config: ProviderConfig = {
89
prompt: {
910
param: 'text_prompts',
1011
required: true,
@@ -47,29 +48,60 @@ interface ImageArtifact {
4748
seed: number;
4849
}
4950

50-
interface BedrockStabilityAIImageGenerateResponse {
51+
interface BedrockStabilityAIImageGenerateV1Response {
5152
result: string;
5253
artifacts: ImageArtifact[];
5354
}
5455

55-
export const BedrockStabilityAIImageGenerateResponseTransform: (
56-
response: BedrockStabilityAIImageGenerateResponse | BedrockErrorResponse,
56+
export const BedrockStabilityAIImageGenerateV1ResponseTransform: (
57+
response: BedrockStabilityAIImageGenerateV1Response | BedrockErrorResponse,
5758
responseStatus: number
5859
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
5960
if (responseStatus !== 200) {
60-
const errorResposne = BedrockErrorResponseTransform(
61+
const errorResponse = BedrockErrorResponseTransform(
6162
response as BedrockErrorResponse
6263
);
63-
if (errorResposne) return errorResposne;
64+
if (errorResponse) return errorResponse;
6465
}
6566

6667
if ('artifacts' in response) {
6768
return {
68-
created: `${new Date().getTime()}`,
69+
created: Math.floor(Date.now() / 1000),
6970
data: response.artifacts.map((art) => ({ b64_json: art.base64 })),
7071
provider: BEDROCK,
7172
};
7273
}
7374

7475
return generateInvalidProviderResponseError(response, BEDROCK);
7576
};
77+
78+
interface BedrockStabilityAIImageGenerateV2Response {
79+
seeds: number[];
80+
finish_reasons: string[];
81+
images: string[];
82+
}
83+
84+
export const BedrockStabilityAIImageGenerateV2Config =
85+
StabilityAIImageGenerateV2Config;
86+
87+
export const BedrockStabilityAIImageGenerateV2ResponseTransform: (
88+
response: BedrockStabilityAIImageGenerateV2Response | BedrockErrorResponse,
89+
responseStatus: number
90+
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
91+
if (responseStatus !== 200) {
92+
const errorResponse = BedrockErrorResponseTransform(
93+
response as BedrockErrorResponse
94+
);
95+
if (errorResponse) return errorResponse;
96+
}
97+
98+
if ('images' in response) {
99+
return {
100+
created: Math.floor(Date.now() / 1000),
101+
data: response.images.map((image) => ({ b64_json: image })),
102+
provider: BEDROCK,
103+
};
104+
}
105+
106+
return generateInvalidProviderResponseError(response, BEDROCK);
107+
};

src/providers/bedrock/index.ts

+19-6
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,18 @@ import {
4242
BedrockTitanCompleteResponseTransform,
4343
BedrockTitanCompleteStreamChunkTransform,
4444
} from './complete';
45+
import { BEDROCK_STABILITY_V1_MODELS } from './constants';
4546
import {
4647
BedrockCohereEmbedConfig,
4748
BedrockCohereEmbedResponseTransform,
4849
BedrockTitanEmbedConfig,
4950
BedrockTitanEmbedResponseTransform,
5051
} from './embed';
5152
import {
52-
BedrockStabilityAIImageGenerateConfig,
53-
BedrockStabilityAIImageGenerateResponseTransform,
53+
BedrockStabilityAIImageGenerateV1Config,
54+
BedrockStabilityAIImageGenerateV1ResponseTransform,
55+
BedrockStabilityAIImageGenerateV2Config,
56+
BedrockStabilityAIImageGenerateV2ResponseTransform,
5457
} from './imageGenerate';
5558

5659
const BedrockConfig: ProviderConfigs = {
@@ -63,8 +66,9 @@ const BedrockConfig: ProviderConfigs = {
6366
// To remove the region in case its a cross-region inference profile ID
6467
// https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-support.html
6568
const providerModel = params.model.replace(/^(us\.|eu\.)/, '');
66-
const provider = providerModel?.split('.')[0];
67-
const model = providerModel?.split('.')[1];
69+
const providerModelArray = providerModel.split('.');
70+
const provider = providerModelArray[0];
71+
const model = providerModelArray.slice(1).join('.');
6872
switch (provider) {
6973
case ANTHROPIC:
7074
return {
@@ -148,11 +152,20 @@ const BedrockConfig: ProviderConfigs = {
148152
},
149153
};
150154
case 'stability':
155+
if (model && BEDROCK_STABILITY_V1_MODELS.includes(model)) {
156+
return {
157+
imageGenerate: BedrockStabilityAIImageGenerateV1Config,
158+
api: BedrockAPIConfig,
159+
responseTransforms: {
160+
imageGenerate: BedrockStabilityAIImageGenerateV1ResponseTransform,
161+
},
162+
};
163+
}
151164
return {
152-
imageGenerate: BedrockStabilityAIImageGenerateConfig,
165+
imageGenerate: BedrockStabilityAIImageGenerateV2Config,
153166
api: BedrockAPIConfig,
154167
responseTransforms: {
155-
imageGenerate: BedrockStabilityAIImageGenerateResponseTransform,
168+
imageGenerate: BedrockStabilityAIImageGenerateV2ResponseTransform,
156169
},
157170
};
158171
default:

src/providers/fireworks-ai/imageGenerate.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ export const FireworksAIImageGenerateResponseTransform: (
100100
}
101101
if (response instanceof Array) {
102102
return {
103-
created: `${new Date().getTime()}`, // Corrected method call
103+
created: Math.floor(Date.now() / 1000), // Corrected method call
104104
data: response?.map((r) => ({
105105
b64_json: r.base64,
106106
seed: r.seed,

src/providers/google-vertex-ai/api.ts

+4
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ export const GoogleApiConfig: ProviderAPIConfig = {
6363
'embed',
6464
`${projectRoute}/publishers/${provider}/models/${model}:predict`,
6565
],
66+
[
67+
'imageGenerate',
68+
`${projectRoute}/publishers/${provider}/models/${model}:predict`,
69+
],
6670
]);
6771

6872
switch (provider) {

0 commit comments

Comments
 (0)