diff --git a/projects/multi_gpu/README.md b/projects/multi_gpu/README.md index 812f8b49..51907575 100644 --- a/projects/multi_gpu/README.md +++ b/projects/multi_gpu/README.md @@ -31,7 +31,7 @@ python server.py ### 2. 启动客户端 以下代码展示了客户端的使用方式,可根据需求修改配置: ```python -files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 jpg/jpeg、png、pdf 文件 +files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 pdf、jpg/jpeg、png、doc、docx、ppt、pptx 文件 n_jobs = np.clip(len(files), 1, 8) # 设置并发线程数,此处最大为 8,可根据自身修改 results = Parallel(n_jobs, prefer='threads', verbose=10)( delayed(do_parse)(p) for p in files diff --git a/projects/multi_gpu/client.py b/projects/multi_gpu/client.py index 3e1c70b1..6cd3cd20 100644 --- a/projects/multi_gpu/client.py +++ b/projects/multi_gpu/client.py @@ -31,7 +31,7 @@ def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs): if __name__ == '__main__': - files = ['small_ocr.pdf'] + files = ['demo/small_ocr.pdf'] n_jobs = np.clip(len(files), 1, 8) results = Parallel(n_jobs, prefer='threads', verbose=10)( delayed(do_parse)(p) for p in files diff --git a/projects/multi_gpu/server.py b/projects/multi_gpu/server.py index ea339a95..30be5abe 100644 --- a/projects/multi_gpu/server.py +++ b/projects/multi_gpu/server.py @@ -1,18 +1,22 @@ import os +import uuid +import shutil +import tempfile +import gc import fitz import torch import base64 +import filetype import litserve as ls -from uuid import uuid4 +from pathlib import Path from fastapi import HTTPException -from filetype import guess_extension -from magic_pdf.tools.common import do_parse +from magic_pdf.tools.cli import do_parse, convert_file_to_pdf from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton class MinerUAPI(ls.LitAPI): def __init__(self, output_dir='/tmp'): - self.output_dir = output_dir + self.output_dir = Path(output_dir) def setup(self, device): if device.startswith('cuda'): @@ -27,7 +31,7 @@ def setup(self, device): def decode_request(self, request): file = request['file'] - file = self.to_pdf(file) + file = self.cvt2pdf(file) opts = request.get('kwargs', {}) opts.setdefault('debug_able', False) opts.setdefault('parse_method', 'auto') @@ -35,9 +39,12 @@ def decode_request(self, request): def predict(self, inputs): try: - do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1]) - return pdf_name + pdf_name = str(uuid.uuid4()) + output_dir = self.output_dir.joinpath(pdf_name) + do_parse(self.output_dir, pdf_name, inputs[0], [], **inputs[1]) + return output_dir except Exception as e: + shutil.rmtree(output_dir, ignore_errors=True) raise HTTPException(status_code=500, detail=str(e)) finally: self.clean_memory() @@ -46,21 +53,34 @@ def encode_response(self, response): return {'output_dir': response} def clean_memory(self): - import gc if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() - def to_pdf(self, file_base64): + def cvt2pdf(self, file_base64): try: + temp_dir = Path(tempfile.mkdtemp()) + temp_file = temp_dir.joinpath('tmpfile') file_bytes = base64.b64decode(file_base64) - file_ext = guess_extension(file_bytes) - with fitz.open(stream=file_bytes, filetype=file_ext) as f: - if f.is_pdf: return f.tobytes() - return f.convert_to_pdf() + file_ext = filetype.guess_extension(file_bytes) + + if file_ext in ['pdf', 'jpg', 'png', 'doc', 'docx', 'ppt', 'pptx']: + if file_ext == 'pdf': + return file_bytes + elif file_ext in ['jpg', 'png']: + with fitz.open(stream=file_bytes, filetype=file_ext) as f: + return f.convert_to_pdf() + else: + temp_file.write_bytes(file_bytes) + convert_file_to_pdf(temp_file, temp_dir) + return temp_file.with_suffix('.pdf').read_bytes() + else: + raise Exception('Unsupported file format') except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + finally: + shutil.rmtree(temp_dir, ignore_errors=True) if __name__ == '__main__': diff --git a/projects/multi_gpu/small_ocr.pdf b/projects/multi_gpu/small_ocr.pdf deleted file mode 100644 index 2ab92332..00000000 Binary files a/projects/multi_gpu/small_ocr.pdf and /dev/null differ