|
1 | 1 | import { BEDROCK } from '../../globals';
|
| 2 | +import { StabilityAIImageGenerateV2Config } from '../stability-ai/imageGenerateV2'; |
2 | 3 | import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types';
|
3 | 4 | import { generateInvalidProviderResponseError } from '../utils';
|
4 | 5 | import { BedrockErrorResponseTransform } from './chatComplete';
|
5 | 6 | import { BedrockErrorResponse } from './embed';
|
6 | 7 |
|
7 |
| -export const BedrockStabilityAIImageGenerateConfig: ProviderConfig = { |
| 8 | +export const BedrockStabilityAIImageGenerateV1Config: ProviderConfig = { |
8 | 9 | prompt: {
|
9 | 10 | param: 'text_prompts',
|
10 | 11 | required: true,
|
@@ -47,29 +48,60 @@ interface ImageArtifact {
|
47 | 48 | seed: number;
|
48 | 49 | }
|
49 | 50 |
|
50 |
| -interface BedrockStabilityAIImageGenerateResponse { |
| 51 | +interface BedrockStabilityAIImageGenerateV1Response { |
51 | 52 | result: string;
|
52 | 53 | artifacts: ImageArtifact[];
|
53 | 54 | }
|
54 | 55 |
|
55 |
| -export const BedrockStabilityAIImageGenerateResponseTransform: ( |
56 |
| - response: BedrockStabilityAIImageGenerateResponse | BedrockErrorResponse, |
| 56 | +export const BedrockStabilityAIImageGenerateV1ResponseTransform: ( |
| 57 | + response: BedrockStabilityAIImageGenerateV1Response | BedrockErrorResponse, |
57 | 58 | responseStatus: number
|
58 | 59 | ) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
|
59 | 60 | if (responseStatus !== 200) {
|
60 |
| - const errorResposne = BedrockErrorResponseTransform( |
| 61 | + const errorResponse = BedrockErrorResponseTransform( |
61 | 62 | response as BedrockErrorResponse
|
62 | 63 | );
|
63 |
| - if (errorResposne) return errorResposne; |
| 64 | + if (errorResponse) return errorResponse; |
64 | 65 | }
|
65 | 66 |
|
66 | 67 | if ('artifacts' in response) {
|
67 | 68 | return {
|
68 |
| - created: `${new Date().getTime()}`, |
| 69 | + created: Math.floor(Date.now() / 1000), |
69 | 70 | data: response.artifacts.map((art) => ({ b64_json: art.base64 })),
|
70 | 71 | provider: BEDROCK,
|
71 | 72 | };
|
72 | 73 | }
|
73 | 74 |
|
74 | 75 | return generateInvalidProviderResponseError(response, BEDROCK);
|
75 | 76 | };
|
| 77 | + |
| 78 | +interface BedrockStabilityAIImageGenerateV2Response { |
| 79 | + seeds: number[]; |
| 80 | + finish_reasons: string[]; |
| 81 | + images: string[]; |
| 82 | +} |
| 83 | + |
| 84 | +export const BedrockStabilityAIImageGenerateV2Config = |
| 85 | + StabilityAIImageGenerateV2Config; |
| 86 | + |
| 87 | +export const BedrockStabilityAIImageGenerateV2ResponseTransform: ( |
| 88 | + response: BedrockStabilityAIImageGenerateV2Response | BedrockErrorResponse, |
| 89 | + responseStatus: number |
| 90 | +) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => { |
| 91 | + if (responseStatus !== 200) { |
| 92 | + const errorResponse = BedrockErrorResponseTransform( |
| 93 | + response as BedrockErrorResponse |
| 94 | + ); |
| 95 | + if (errorResponse) return errorResponse; |
| 96 | + } |
| 97 | + |
| 98 | + if ('images' in response) { |
| 99 | + return { |
| 100 | + created: Math.floor(Date.now() / 1000), |
| 101 | + data: response.images.map((image) => ({ b64_json: image })), |
| 102 | + provider: BEDROCK, |
| 103 | + }; |
| 104 | + } |
| 105 | + |
| 106 | + return generateInvalidProviderResponseError(response, BEDROCK); |
| 107 | +}; |
0 commit comments