diff --git a/@types/express.d.ts b/@types/express.d.ts index 50664af..16f9b82 100644 --- a/@types/express.d.ts +++ b/@types/express.d.ts @@ -1,3 +1,4 @@ +import {UserModel} from '../src/models/user.model.ts'; import 'express'; declare global { diff --git a/src/app.ts b/src/app.ts index 59baa36..0e30a85 100644 --- a/src/app.ts +++ b/src/app.ts @@ -18,7 +18,7 @@ import {PrismaSessionStore} from '@quixo3/prisma-session-store'; import {prisma} from './db.config.js'; import {BaseError} from './errors.js'; import swaggerDocument from '../swagger/openapi.json' assert {type: 'json'}; - +import {labelDetectionController} from './controllers/tags-ai.controller.js'; dotenv.config(); const app = express(); @@ -89,10 +89,12 @@ app.use(passport.session()); app.use('/oauth2', authRouter); app.use('/memo', memoFolderRouter); app.use('/challenge', challengeRouter); +app.post('/image/ai', labelDetectionController); app.get('/', (req: Request, res: Response) => { res.send('Sweepic'); }); + // Error handling middleware (수정된 부분) const errorHandler: ErrorRequestHandler = (err, req, res, next) => { if (res.headersSent) { diff --git a/src/controllers/tags-ai.controller.ts b/src/controllers/tags-ai.controller.ts new file mode 100644 index 0000000..6d95490 --- /dev/null +++ b/src/controllers/tags-ai.controller.ts @@ -0,0 +1,131 @@ +import {Request, Response} from 'express'; +import {detectLabels} from '../services/tags-ai.service.js'; +import {StatusCodes} from 'http-status-codes'; +import { + DataValidationError, + LabelDetectionError, + LabelNotFoundError, +} from '../errors.js'; + +export const labelDetectionController = async ( + req: Request, + res: Response, +): Promise => { + /* + #swagger.tags = ['label-detection'] + #swagger.summary = '이미지 라벨링' + #swagger.description = 'Base64 데이터를 JSON으로 받아 이미지를 분석하여 상위 3개의 라벨과 정확도를 반환합니다.' + #swagger.requestBody = { + required: true, + content: { + "application/json": { + schema: { + type: "object", + properties: { + base64_image: { + type: "string", + description: "Base64 인코딩된 이미지 데이터", + example: "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD..." + } + } + } + } + } + } + #swagger.responses[200] = { + description: "라벨링 결과 반환", + content: { + "application/json": { + schema: { + type: "object", + properties: { + topLabels: { + type: "array", + items: { + type: "object", + properties: { + description: { type: "string", example: "Mountain" }, + score: { type: "number", example: 0.95 } + } + } + } + } + } + } + } + } + #swagger.responses[400] = { + description: "잘못된 요청 데이터", + content: { + "application/json": { + schema: { + type: "object", + properties: { + error: { type: "string", example: "Base64 이미지 데이터가 제공되지 않았습니다." } + } + } + } + } + } + #swagger.responses[500] = { + description: "서버 내부 오류", + content: { + "application/json": { + schema: { + type: "object", + properties: { + error: { type: "string", example: "라벨링 중 오류가 발생했습니다." } + } + } + } + } + } + */ + + try { + // Base64 이미지 데이터가 요청에 포함되었는지 확인 + const {base64_image} = req.body; + + if (!base64_image) { + throw new DataValidationError({ + reason: 'Base64 이미지 데이터가 제공되지 않았습니다.', + }); + } + + // Base64 데이터에서 MIME 타입 제거 + const base64Data = base64_image.replace(/^data:image\/\w+;base64,/, ''); + + // 서비스 호출 + const labels = await detectLabels(base64Data); + + // 라벨 반환 + res.status(StatusCodes.OK).json({topLabels: labels}); + } catch (error) { + console.error('Error in labelDetectionController:', error); + + // 커스텀 에러 처리 + if ( + error instanceof DataValidationError || + error instanceof LabelNotFoundError + ) { + res.status(error.statusCode).json({ + errorCode: error.code, + reason: error.message, + details: error.details, + }); + } else if (error instanceof LabelDetectionError) { + res.status(error.statusCode).json({ + errorCode: error.code, + reason: error.message, + details: error.details, + }); + } else { + // 기타 예상치 못한 에러 처리 + res.status(StatusCodes.INTERNAL_SERVER_ERROR).json({ + errorCode: 'unknown', + reason: '예상치 못한 서버 오류가 발생했습니다.', + details: null, + }); + } + } +}; diff --git a/src/errors.ts b/src/errors.ts index ac2531b..88df8f8 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -122,3 +122,16 @@ export class PhotoDataNotFoundError extends BaseError { super(404, 'PHO-404', '사진 데이터가 없습니다.', details); } } + +// 라벨링 관련 에러 (LBL-Labeling) +export class LabelDetectionError extends BaseError { + constructor(details?: ErrorDetails) { + super(500, 'LBL-500', '라벨링 처리 중 오류가 발생했습니다.', details); + } +} + +export class LabelNotFoundError extends BaseError { + constructor(details?: ErrorDetails) { + super(404, 'LBL-404', '이미지에서 라벨을 감지하지 못했습니다.', details); + } +} diff --git a/src/routers/memo.router.ts b/src/routers/memo.router.ts index 858226d..6045a58 100644 --- a/src/routers/memo.router.ts +++ b/src/routers/memo.router.ts @@ -12,6 +12,7 @@ import { handlerMemoSearch, handlerMemoTextImageList, } from '../controllers/memo-folder.controller.js'; + memoFolderRouter.post('/folders', handlerMemoFolderAdd); memoFolderRouter.post( '/image-format/folders', diff --git a/src/services/memo-ocrService.ts b/src/services/memo-ocrService.ts index a46d5f5..4155bff 100644 --- a/src/services/memo-ocrService.ts +++ b/src/services/memo-ocrService.ts @@ -1,6 +1,5 @@ import {ImageAnnotatorClient} from '@google-cloud/vision'; -import path from 'path'; -import {fileURLToPath} from 'url'; + import {folderRepository} from '../repositories/memo-OCR.repositoy.js'; import {OCRRequest} from '../models/memo-OCR.model.js'; import { @@ -8,15 +7,11 @@ import { FolderDuplicateError, PhotoDataNotFoundError, } from '../errors.js'; -// ES Module 환경에서 __dirname 대체 -const __filename = fileURLToPath(import.meta.url); -const __dirname = path.dirname(__filename); + +import path from 'path'; // Google Cloud Vision 클라이언트 초기화 -const keyFilename = path.resolve( - __dirname, - '../../sweepicai-00d515e813ea.json', -); +const keyFilename = path.resolve('../sweepicai-00d515e813ea.json'); const visionClient = new ImageAnnotatorClient({keyFilename}); diff --git a/src/services/tags-ai.service.ts b/src/services/tags-ai.service.ts new file mode 100644 index 0000000..56bd127 --- /dev/null +++ b/src/services/tags-ai.service.ts @@ -0,0 +1,41 @@ +import {ImageAnnotatorClient} from '@google-cloud/vision'; +import {LabelDetectionError, LabelNotFoundError} from '../errors.js'; + +import path from 'path'; + +const keyFilename = path.resolve('../sweepicai-00d515e813ea.json'); + +const visionClient = new ImageAnnotatorClient({keyFilename}); +export const detectLabels = async ( + base64Image: string, +): Promise<{description: string; score: number}[]> => { + try { + // Vision API 호출 + const [result] = await visionClient.labelDetection({ + image: {content: base64Image}, + }); + + const labels = result.labelAnnotations; + + // 라벨이 없으면 LabelNotFoundError 발생 + if (!labels || labels.length === 0) { + throw new LabelNotFoundError({ + reason: '이미지에서 라벨을 감지하지 못했습니다.', + }); + } + + // 상위 3개의 라벨 반환 + return labels + .sort((a, b) => (b.score || 0) - (a.score || 0)) // 정확도 내림차순 정렬 + .slice(0, 3) // 상위 3개만 선택 + .map(label => ({ + description: label.description || 'Unknown', + score: label.score || 0, + })); + } catch (error) { + console.error('Error in detectLabels service:', error); + throw new LabelDetectionError({ + reason: '라벨링 처리 중 오류가 발생했습니다.', + }); + } +};