diff --git a/Makefile b/Makefile
index 250da0e5..82b39559 100644
--- a/Makefile
+++ b/Makefile
@@ -139,7 +139,7 @@ check-version:
 	-s CHANGELOG.md \
 	-f preprocessing-pipeline-family.yaml release \
 	-f ${PACKAGE_NAME}/api/app.py release \
-	-f ${PACKAGE_NAME}/api/general.py release
+	-f ${PACKAGE_NAME}/api/endpoints.py release
 
 ## version-sync: update references to version with most recent version from CHANGELOG.md
 .PHONY: version-sync
@@ -148,4 +148,4 @@ version-sync:
 	-s CHANGELOG.md \
 	-f preprocessing-pipeline-family.yaml release \
 	-f ${PACKAGE_NAME}/api/app.py release \
-	-f ${PACKAGE_NAME}/api/general.py release
+	-f ${PACKAGE_NAME}/api/endpoints.py release
diff --git a/prepline_general/api/app.py b/prepline_general/api/app.py
index 0aba4e03..9f8df266 100644
--- a/prepline_general/api/app.py
+++ b/prepline_general/api/app.py
@@ -1,12 +1,12 @@
-from fastapi import FastAPI, Request, status, HTTPException
-from fastapi.datastructures import FormData
+from fastapi import FastAPI, Request, HTTPException
 from fastapi.responses import JSONResponse
-from fastapi.security import APIKeyHeader
+from fastapi.datastructures import FormData
 import logging
 import os
 
-from .general import router as general_router
-from .openapi import set_custom_openapi
+from prepline_general.api.endpoints import router as general_router
+from prepline_general.api.openapi import set_custom_openapi
+import prepline_general.api.logging  # noqa: F401 -- importing installs the uvicorn access-log filters
 
 logger = logging.getLogger("unstructured_api")
 
@@ -31,6 +31,8 @@
     openapi_tags=[{"name": "general"}],
 )
 
+app.include_router(general_router)
+
 # Note(austin) - This logger just dumps exceptions
 # We'd rather handle those below, so disable this in deployments
 uvicorn_logger = logging.getLogger("uvicorn.error")
@@ -62,7 +64,6 @@ async def error_handler(request: Request, e: Exception):
     allow_headers=["Content-Type"],
 )
 
-app.include_router(general_router)
 set_custom_openapi(app)
 
 
@@ -107,26 +108,4 @@ async def patched_get_form(
 # Replace the private method with our wrapper
 Request._get_form = patched_get_form  # type: ignore[assignment]
 
-
-# Filter out /healthcheck noise
-class HealthCheckFilter(logging.Filter):
-    def filter(self, record: logging.LogRecord) -> bool:
-        return record.getMessage().find("/healthcheck") == -1
-
-
-# Filter out /metrics noise
-class MetricsCheckFilter(logging.Filter):
-    def filter(self, record: logging.LogRecord) -> bool:
-        return record.getMessage().find("/metrics") == -1
-
-
-logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())
-logging.getLogger("uvicorn.access").addFilter(MetricsCheckFilter())
-
-
-@app.get("/healthcheck", status_code=status.HTTP_200_OK, include_in_schema=False)
-def healthcheck(request: Request):
-    return {"healthcheck": "HEALTHCHECK STATUS: EVERYTHING OK!"}
-
-
 logger.info("Started Unstructured API")
diff --git a/prepline_general/api/endpoints.py b/prepline_general/api/endpoints.py
new file mode 100644
index 00000000..4bfc990a
--- /dev/null
+++ b/prepline_general/api/endpoints.py
@@ -0,0 +1,185 @@
+from __future__ import annotations
+
+import io
+import json
+import os
+from typing import List, Sequence, Dict, Any, cast, Union, Optional
+
+import pandas as pd
+from fastapi import APIRouter, UploadFile, Depends, HTTPException
+from starlette import status
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse
+
+from prepline_general.api.general import (
+    ungz_file,
+    MultipartMixedResponse,
+    pipeline_api,
+)
+from prepline_general.api.validation import _validate_chunking_strategy, get_validated_mimetype
+from prepline_general.api.models.form_params import GeneralFormParams
+
+router = APIRouter()
+
+
+@router.post(
+    "/general/v0/general",
+    openapi_extra={"x-speakeasy-name-override": "partition"},
+    tags=["general"],
+    summary="Summary",
+    description="Description",
+    operation_id="partition_parameters",
+)
+@router.post("/general/v0.0.68/general", include_in_schema=False)
+def general_partition(
+    request: Request,
+    # cannot use annotated type here because of a bug described here:
+    # https://github.com/tiangolo/fastapi/discussions/10280
+    # The openapi metadata must be added separately in openapi.py file.
+    # TODO: Check if the bug is fixed and change the declaration to use Annotated[List[UploadFile], File(...)]
+    # For new parameters - add them in models/form_params.py
+    files: List[UploadFile],
+    form_params: GeneralFormParams = Depends(GeneralFormParams.as_form),
+):
+    # -- must have a valid API key --
+    if api_key_env := os.environ.get("UNSTRUCTURED_API_KEY"):
+        api_key = request.headers.get("unstructured-api-key")
+        if api_key != api_key_env:
+            raise HTTPException(
+                detail=f"API key {api_key} is invalid", status_code=status.HTTP_401_UNAUTHORIZED
+            )
+
+    content_type = request.headers.get("Accept")
+
+    # -- detect response content-type conflict when multiple files are uploaded --
+    if (
+        len(files) > 1
+        and content_type
+        and content_type
+        not in [
+            "*/*",
+            "multipart/mixed",
+            "application/json",
+            "text/csv",
+        ]
+    ):
+        raise HTTPException(
+            detail=f"Conflict in media type {content_type} with response type 'multipart/mixed'.\n",
+            status_code=status.HTTP_406_NOT_ACCEPTABLE,
+        )
+
+    # -- validate other arguments --
+    chunking_strategy = _validate_chunking_strategy(form_params.chunking_strategy)
+
+    # -- unzip any uploaded files that need it --
+    for idx, file in enumerate(files):
+        is_content_type_gz = file.content_type == "application/gzip"
+        is_extension_gz = file.filename and file.filename.endswith(".gz")
+        if is_content_type_gz or is_extension_gz:
+            files[idx] = ungz_file(file, form_params.gz_uncompressed_content_type)
+
+    return (
+        MultipartMixedResponse(
+            response_generator(files, request, form_params, chunking_strategy, is_multipart=True),
+            content_type=form_params.output_format,
+        )
+        if content_type == "multipart/mixed"
+        else (
+            list(
+                response_generator(
+                    files, request, form_params, chunking_strategy, is_multipart=False
+                )
+            )[0]
+            if len(files) == 1
+            else join_responses(
+                form_params,
+                list(
+                    response_generator(
+                        files, request, form_params, chunking_strategy, is_multipart=False
+                    )
+                ),
+            )
+        )
+    )
+
+
+def join_responses(
+    form_params: GeneralFormParams,
+    responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse],
+) -> List[str | List[Dict[str, Any]]] | PlainTextResponse:
+    """Consolidate partitionings from multiple documents into a single response payload."""
+    if form_params.output_format != "text/csv":
+        return cast(List[Union[str, List[Dict[str, Any]]]], responses)
+    responses = cast(List[PlainTextResponse], responses)
+    data = pd.read_csv(io.BytesIO(responses[0].body))  # pyright: ignore[reportUnknownMemberType]
+    if len(responses) > 1:
+        for resp in responses[1:]:
+            resp_data = pd.read_csv(  # pyright: ignore[reportUnknownMemberType]
+                io.BytesIO(resp.body)
+            )
+            data = data.merge(resp_data, how="outer")  # pyright: ignore[reportUnknownMemberType]
+    return PlainTextResponse(data.to_csv())
+
+
+def response_generator(
+    files: List[UploadFile],
+    request: Request,
+    form_params: GeneralFormParams,
+    chunking_strategy: Optional[str],
+    is_multipart: bool,
+):
+    for file in files:
+        file_content_type = get_validated_mimetype(file)
+        _file = file.file
+
+        response = pipeline_api(
+            _file,
+            request=request,
+            coordinates=form_params.coordinates,
+            encoding=form_params.encoding,
+            hi_res_model_name=form_params.hi_res_model_name,
+            include_page_breaks=form_params.include_page_breaks,
+            ocr_languages=form_params.ocr_languages,
+            pdf_infer_table_structure=form_params.pdf_infer_table_structure,
+            skip_infer_table_types=form_params.skip_infer_table_types,
+            strategy=form_params.strategy,
+            xml_keep_tags=form_params.xml_keep_tags,
+            response_type=form_params.output_format,
+            filename=str(file.filename),
+            file_content_type=file_content_type,
+            languages=form_params.languages,
+            extract_image_block_types=form_params.extract_image_block_types,
+            unique_element_ids=form_params.unique_element_ids,
+            # -- chunking options --
+            chunking_strategy=chunking_strategy,
+            combine_under_n_chars=form_params.combine_under_n_chars,
+            max_characters=form_params.max_characters,
+            multipage_sections=form_params.multipage_sections,
+            new_after_n_chars=form_params.new_after_n_chars,
+            overlap=form_params.overlap,
+            overlap_all=form_params.overlap_all,
+            starting_page_number=form_params.starting_page_number,
+        )
+
+        yield (
+            json.dumps(response)
+            if is_multipart and type(response) not in [str, bytes]
+            else (
+                PlainTextResponse(response)
+                if not is_multipart and form_params.output_format == "text/csv"
+                else response
+            )
+        )
+
+
+@router.get("/general/v0/general", include_in_schema=False)
+@router.get("/general/v0.0.68/general", include_in_schema=False)
+async def handle_invalid_get_request():
+    raise HTTPException(
+        status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Only POST requests are supported."
+    )
+
+
+@router.get("/healthcheck", status_code=status.HTTP_200_OK, include_in_schema=False)
+def healthcheck(request: Request):
+    return {"healthcheck": "HEALTHCHECK STATUS: EVERYTHING OK!"}
diff --git a/prepline_general/api/general.py b/prepline_general/api/general.py
index 640a83d9..9e6e97b6 100644
--- a/prepline_general/api/general.py
+++ b/prepline_general/api/general.py
@@ -11,51 +11,32 @@
 from base64 import b64encode
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from types import TracebackType
-from typing import IO, Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
+from typing import IO, Any, Dict, List, Mapping, Optional, Sequence, Tuple
 
 import backoff
-import pandas as pd
 import psutil
 import requests
 from fastapi import (
-    APIRouter,
-    Depends,
-    FastAPI,
     HTTPException,
-    Request,
     UploadFile,
-    status,
 )
-from fastapi.responses import PlainTextResponse, StreamingResponse
-from pypdf import PageObject, PdfReader, PdfWriter
-from pypdf.errors import FileNotDecryptedError, PdfReadError
+from fastapi.responses import StreamingResponse
+from pypdf import PdfReader, PageObject, PdfWriter
 from starlette.datastructures import Headers
+from starlette.requests import Request
 from starlette.types import Send
-
-from prepline_general.api.models.form_params import GeneralFormParams
 from unstructured.documents.elements import Element
 from unstructured.partition.auto import partition
-from unstructured.staging.base import (
-    convert_to_dataframe,
-    convert_to_isd,
-    elements_from_json,
-)
+from unstructured.staging.base import elements_from_json, convert_to_dataframe, convert_to_isd
 from unstructured_inference.models.base import UnknownModelException
 from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
 
-app = FastAPI()
-router = APIRouter()
-
-
-def is_compatible_response_type(media_type: str, response_type: type) -> bool:
-    """True when `response_type` can be converted to `media_type` for HTTP Response."""
-    return (
-        False
-        if media_type == "application/json" and response_type not in [dict, list]
-        else False if media_type == "text/csv" and response_type != str else True
-    )
-
+from prepline_general.api.memory_protection import ChipperMemoryProtection
+from prepline_general.api.validation import (
+    _check_pdf,
+    _validate_hi_res_model_name,
+    _validate_strategy,
+)
+
 
 logger = logging.getLogger("unstructured_api")
@@ -244,37 +225,6 @@ def partition_pdf_splits(
     return results
 
 
-is_chipper_processing = False
-
-
-class ChipperMemoryProtection:
-    """Chipper calls are expensive, and right now we can only do one call at a time.
-
-    If the model is in use, return a 503 error. The API should scale up and the user can try again
-    on a different server.
-    """
-
-    def __enter__(self):
-        global is_chipper_processing
-        if is_chipper_processing:
-            # Log here so we can track how often it happens
-            logger.error("Chipper is already is use")
-            raise HTTPException(
-                status_code=503, detail="Server is under heavy load. Please try again later."
-            )
-
-        is_chipper_processing = True
-
-    def __exit__(
-        self,
-        exc_type: Optional[type[BaseException]],
-        exc_value: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ):
-        global is_chipper_processing
-        is_chipper_processing = False
-
-
 def pipeline_api(
     file: IO[bytes],
     request: Request,
@@ -530,108 +480,11 @@ def _check_free_memory():
         )
 
 
-def _check_pdf(file: IO[bytes]):
-    """Check if the PDF file is encrypted, otherwise assume it is not a valid PDF."""
-    try:
-        pdf = PdfReader(file)
-
-        # This will raise if the file is encrypted
-        pdf.metadata
-        return pdf
-    except FileNotDecryptedError:
-        raise HTTPException(
-            status_code=400,
-            detail="File is encrypted. Please decrypt it with password.",
-        )
-    except PdfReadError:
-        raise HTTPException(status_code=422, detail="File does not appear to be a valid PDF")
-
-
-def _validate_strategy(strategy: str) -> str:
-    strategy = strategy.lower()
-    strategies = ["fast", "hi_res", "auto", "ocr_only"]
-    if strategy not in strategies:
-        raise HTTPException(
-            status_code=400, detail=f"Invalid strategy: {strategy}. Must be one of {strategies}"
-        )
-    return strategy
-
-
-def _validate_hi_res_model_name(
-    hi_res_model_name: Optional[str], show_coordinates: bool
-) -> Optional[str]:
-    # Make sure chipper aliases to the latest model
-    if hi_res_model_name and hi_res_model_name == "chipper":
-        hi_res_model_name = "chipperv2"
-
-    if hi_res_model_name and hi_res_model_name in CHIPPER_MODEL_TYPES and show_coordinates:
-        raise HTTPException(
-            status_code=400,
-            detail=f"coordinates aren't available when using the {hi_res_model_name} model type",
-        )
-    return hi_res_model_name
-
-
-def _validate_chunking_strategy(chunking_strategy: Optional[str]) -> Optional[str]:
-    """Raise on `chunking_strategy` is not a valid chunking strategy name.
-
-    Also provides case-insensitivity.
-    """
-    if chunking_strategy is None:
-        return None
-
-    chunking_strategy = chunking_strategy.lower()
-    available_strategies = ["basic", "by_title"]
-
-    if chunking_strategy not in available_strategies:
-        raise HTTPException(
-            status_code=400,
-            detail=(
-                f"Invalid chunking strategy: {chunking_strategy}. Must be one of"
-                f" {available_strategies}"
-            ),
-        )
-
-    return chunking_strategy
-
-
 def _set_pdf_infer_table_structure(pdf_infer_table_structure: bool, strategy: str) -> bool:
     """Avoids table inference in "fast" and "ocr_only" runs."""
     return strategy in ("hi_res", "auto") and pdf_infer_table_structure
 
 
-def get_validated_mimetype(file: UploadFile) -> Optional[str]:
-    """The MIME-type of `file`.
-
-    The mimetype is computed based on `file.content_type`, or the mimetypes lib if that's too
-    generic. If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and
-    return HTTP 400 for an invalid type.
-    """
-    content_type = file.content_type
-    filename = str(file.filename)  # -- "None" when file.filename is None --
-    if not content_type or content_type == "application/octet-stream":
-        content_type = mimetypes.guess_type(filename)[0]
-
-        # Some filetypes missing for this library, just hardcode them for now
-        if not content_type:
-            if filename.endswith(".md"):
-                content_type = "text/markdown"
-            elif filename.endswith(".msg"):
-                content_type = "message/rfc822"
-
-    allowed_mimetypes_str = os.environ.get("UNSTRUCTURED_ALLOWED_MIMETYPES")
-    if allowed_mimetypes_str is not None:
-        allowed_mimetypes = allowed_mimetypes_str.split(",")
-
-        if content_type not in allowed_mimetypes:
-            raise HTTPException(
-                status_code=400,
-                detail=(f"File type {content_type} is not supported."),
-            )
-
-    return content_type
-
-
 class MultipartMixedResponse(StreamingResponse):
     CRLF = b"\r\n"
 
@@ -701,148 +554,3 @@ def return_content_type(filename: str):
         filename=filename,
         headers=Headers({"content-type": return_content_type(filename)}),
     )
-
-
-@router.get("/general/v0/general", include_in_schema=False)
-@router.get("/general/v0.0.68/general", include_in_schema=False)
-async def handle_invalid_get_request():
-    raise HTTPException(
-        status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Only POST requests are supported."
-    )
-
-
-@router.post(
-    "/general/v0/general",
-    openapi_extra={"x-speakeasy-name-override": "partition"},
-    tags=["general"],
-    summary="Summary",
-    description="Description",
-    operation_id="partition_parameters",
-)
-@router.post("/general/v0.0.68/general", include_in_schema=False)
-def general_partition(
-    request: Request,
-    # cannot use annotated type here because of a bug described here:
-    # https://github.com/tiangolo/fastapi/discussions/10280
-    # The openapi metadata must be added separately in openapi.py file.
-    # TODO: Check if the bug is fixed and change the declaration to use Annoteted[List[UploadFile], File(...)]
-    # For new parameters - add them in models/form_params.py
-    files: List[UploadFile],
-    form_params: GeneralFormParams = Depends(GeneralFormParams.as_form),
-):
-    # -- must have a valid API key --
-    if api_key_env := os.environ.get("UNSTRUCTURED_API_KEY"):
-        api_key = request.headers.get("unstructured-api-key")
-        if api_key != api_key_env:
-            raise HTTPException(
-                detail=f"API key {api_key} is invalid", status_code=status.HTTP_401_UNAUTHORIZED
-            )
-
-    content_type = request.headers.get("Accept")
-
-    # -- detect response content-type conflict when multiple files are uploaded --
-    if (
-        len(files) > 1
-        and content_type
-        and content_type
-        not in [
-            "*/*",
-            "multipart/mixed",
-            "application/json",
-            "text/csv",
-        ]
-    ):
-        raise HTTPException(
-            detail=f"Conflict in media type {content_type} with response type 'multipart/mixed'.\n",
-            status_code=status.HTTP_406_NOT_ACCEPTABLE,
-        )
-
-    # -- validate other arguments --
-    chunking_strategy = _validate_chunking_strategy(form_params.chunking_strategy)
-
-    # -- unzip any uploaded files that need it --
-    for idx, file in enumerate(files):
-        is_content_type_gz = file.content_type == "application/gzip"
-        is_extension_gz = file.filename and file.filename.endswith(".gz")
-        if is_content_type_gz or is_extension_gz:
-            files[idx] = ungz_file(file, form_params.gz_uncompressed_content_type)
-
-    def response_generator(is_multipart: bool):
-        for file in files:
-            file_content_type = get_validated_mimetype(file)
-
-            _file = file.file
-
-            response = pipeline_api(
-                _file,
-                request=request,
-                coordinates=form_params.coordinates,
-                encoding=form_params.encoding,
-                hi_res_model_name=form_params.hi_res_model_name,
-                include_page_breaks=form_params.include_page_breaks,
-                ocr_languages=form_params.ocr_languages,
-                pdf_infer_table_structure=form_params.pdf_infer_table_structure,
-                skip_infer_table_types=form_params.skip_infer_table_types,
-                strategy=form_params.strategy,
-                xml_keep_tags=form_params.xml_keep_tags,
-                response_type=form_params.output_format,
-                filename=str(file.filename),
-                file_content_type=file_content_type,
-                languages=form_params.languages,
-                extract_image_block_types=form_params.extract_image_block_types,
-                unique_element_ids=form_params.unique_element_ids,
-                # -- chunking options --
-                chunking_strategy=chunking_strategy,
-                combine_under_n_chars=form_params.combine_under_n_chars,
-                max_characters=form_params.max_characters,
-                multipage_sections=form_params.multipage_sections,
-                new_after_n_chars=form_params.new_after_n_chars,
-                overlap=form_params.overlap,
-                overlap_all=form_params.overlap_all,
-                starting_page_number=form_params.starting_page_number,
-            )
-
-            yield (
-                json.dumps(response)
-                if is_multipart and type(response) not in [str, bytes]
-                else (
-                    PlainTextResponse(response)
-                    if not is_multipart and form_params.output_format == "text/csv"
-                    else response
-                )
-            )
-
-    def join_responses(
-        responses: Sequence[str | List[Dict[str, Any]] | PlainTextResponse]
-    ) -> List[str | List[Dict[str, Any]]] | PlainTextResponse:
-        """Consolidate partitionings from multiple documents into single response payload."""
-        if form_params.output_format != "text/csv":
-            return cast(List[Union[str, List[Dict[str, Any]]]], responses)
-        responses = cast(List[PlainTextResponse], responses)
-        data = pd.read_csv(  # pyright: ignore[reportUnknownMemberType]
-            io.BytesIO(responses[0].body)
-        )
-        if len(responses) > 1:
-            for resp in responses[1:]:
-                resp_data = pd.read_csv(  # pyright: ignore[reportUnknownMemberType]
-                    io.BytesIO(resp.body)
-                )
-                data = data.merge(  # pyright: ignore[reportUnknownMemberType]
-                    resp_data, how="outer"
-                )
-        return PlainTextResponse(data.to_csv())
-
-    return (
-        MultipartMixedResponse(
-            response_generator(is_multipart=True), content_type=form_params.output_format
-        )
-        if content_type == "multipart/mixed"
-        else (
-            list(response_generator(is_multipart=False))[0]
-            if len(files) == 1
-            else join_responses(list(response_generator(is_multipart=False)))
-        )
-    )
-
-
-app.include_router(router)
diff --git a/prepline_general/api/logging.py b/prepline_general/api/logging.py
new file mode 100644
index 00000000..696bb3fc
--- /dev/null
+++ b/prepline_general/api/logging.py
@@ -0,0 +1,17 @@
+import logging
+
+
+# Filter out /healthcheck noise
+class HealthCheckFilter(logging.Filter):
+    def filter(self, record: logging.LogRecord) -> bool:
+        return record.getMessage().find("/healthcheck") == -1
+
+
+# Filter out /metrics noise
+class MetricsCheckFilter(logging.Filter):
+    def filter(self, record: logging.LogRecord) -> bool:
+        return record.getMessage().find("/metrics") == -1
+
+
+logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())
+logging.getLogger("uvicorn.access").addFilter(MetricsCheckFilter())
diff --git a/prepline_general/api/memory_protection.py b/prepline_general/api/memory_protection.py
new file mode 100644
index 00000000..439acdb7
--- /dev/null
+++ b/prepline_general/api/memory_protection.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import logging
+from types import TracebackType
+from typing import Optional
+
+from fastapi import HTTPException
+
+logger = logging.getLogger("unstructured_api")
+
+
+is_chipper_processing = False
+
+
+class ChipperMemoryProtection:
+    """Chipper calls are expensive, and right now we can only do one call at a time.
+
+    If the model is in use, return a 503 error. The API should scale up and the user can try again
+    on a different server.
+    """
+
+    def __enter__(self):
+        global is_chipper_processing
+        if is_chipper_processing:
+            # Log here so we can track how often it happens
+            logger.error("Chipper is already in use")
+            raise HTTPException(
+                status_code=503, detail="Server is under heavy load. Please try again later."
+            )
+
+        is_chipper_processing = True
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ):
+        global is_chipper_processing
+        is_chipper_processing = False
diff --git a/prepline_general/api/validation.py b/prepline_general/api/validation.py
new file mode 100644
index 00000000..df32d578
--- /dev/null
+++ b/prepline_general/api/validation.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import mimetypes
+import os
+
+from typing import IO, Optional
+
+from fastapi import HTTPException, UploadFile
+from pypdf import PdfReader
+from pypdf.errors import FileNotDecryptedError, PdfReadError
+from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
+
+
+def _check_pdf(file: IO[bytes]):
+    """Check whether the PDF file is encrypted; if it can't be read at all, assume it is not a valid PDF."""
+    try:
+        pdf = PdfReader(file)
+
+        # This will raise if the file is encrypted
+        pdf.metadata
+        return pdf
+    except FileNotDecryptedError:
+        raise HTTPException(
+            status_code=400,
+            detail="File is encrypted. Please decrypt it with a password.",
+        )
+    except PdfReadError:
+        raise HTTPException(status_code=422, detail="File does not appear to be a valid PDF")
+
+
+def _validate_strategy(strategy: str) -> str:
+    strategy = strategy.lower()
+    strategies = ["fast", "hi_res", "auto", "ocr_only"]
+    if strategy not in strategies:
+        raise HTTPException(
+            status_code=400, detail=f"Invalid strategy: {strategy}. Must be one of {strategies}"
+        )
+    return strategy
+
+
+def _validate_hi_res_model_name(
+    hi_res_model_name: Optional[str], show_coordinates: bool
+) -> Optional[str]:
+    # Make sure chipper aliases to the latest model
+    if hi_res_model_name and hi_res_model_name == "chipper":
+        hi_res_model_name = "chipperv2"
+
+    if hi_res_model_name and hi_res_model_name in CHIPPER_MODEL_TYPES and show_coordinates:
+        raise HTTPException(
+            status_code=400,
+            detail=f"coordinates aren't available when using the {hi_res_model_name} model type",
+        )
+    return hi_res_model_name
+
+
+def _validate_chunking_strategy(chunking_strategy: Optional[str]) -> Optional[str]:
+    """Raise when `chunking_strategy` is not a valid chunking strategy name.
+
+    Also provides case-insensitivity.
+    """
+    if chunking_strategy is None:
+        return None
+
+    chunking_strategy = chunking_strategy.lower()
+    available_strategies = ["basic", "by_title"]
+
+    if chunking_strategy not in available_strategies:
+        raise HTTPException(
+            status_code=400,
+            detail=(
+                f"Invalid chunking strategy: {chunking_strategy}. Must be one of"
+                f" {available_strategies}"
+            ),
+        )
+
+    return chunking_strategy
+
+
+def get_validated_mimetype(file: UploadFile) -> Optional[str]:
+    """The MIME-type of `file`.
+
+    The mimetype is computed based on `file.content_type`, or the mimetypes lib if that's too
+    generic. If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and
+    return HTTP 400 for an invalid type.
+    """
+    content_type = file.content_type
+    filename = str(file.filename)  # -- "None" when file.filename is None --
+    if not content_type or content_type == "application/octet-stream":
+        content_type = mimetypes.guess_type(filename)[0]
+
+        # Some filetypes missing for this library, just hardcode them for now
+        if not content_type:
+            if filename.endswith(".md"):
+                content_type = "text/markdown"
+            elif filename.endswith(".msg"):
+                content_type = "message/rfc822"
+
+    allowed_mimetypes_str = os.environ.get("UNSTRUCTURED_ALLOWED_MIMETYPES")
+    if allowed_mimetypes_str is not None:
+        allowed_mimetypes = allowed_mimetypes_str.split(",")
+
+        if content_type not in allowed_mimetypes:
+            raise HTTPException(
+                status_code=400,
+                detail=(f"File type {content_type} is not supported."),
+            )
+
+    return content_type
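
Because this change is a pure code move, the extracted helpers should behave exactly as they did inside general.py. A white-box smoke test can pin that down; the sketch below is illustrative rather than part of the diff (the file name test_validation.py is hypothetical, and pytest is assumed to be available):

import pytest
from fastapi import HTTPException

from prepline_general.api.validation import (
    _validate_chunking_strategy,
    _validate_hi_res_model_name,
    _validate_strategy,
)


def test_strategy_is_case_insensitive():
    # _validate_strategy lowercases its input before checking membership
    assert _validate_strategy("HI_RES") == "hi_res"


def test_unknown_strategy_is_rejected():
    # Anything outside ["fast", "hi_res", "auto", "ocr_only"] raises a 400
    with pytest.raises(HTTPException) as exc_info:
        _validate_strategy("fastest")
    assert exc_info.value.status_code == 400


def test_chipper_aliases_to_latest_model():
    # "chipper" is rewritten to "chipperv2" when coordinates are not requested
    assert _validate_hi_res_model_name("chipper", show_coordinates=False) == "chipperv2"


def test_chunking_strategy_is_optional_and_case_insensitive():
    assert _validate_chunking_strategy(None) is None
    assert _validate_chunking_strategy("BY_TITLE") == "by_title"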
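
The route moves can be checked end to end the same way: the healthcheck and the GET guard now live on the router in endpoints.py and reach the app through app.include_router. Another hedged sketch, assuming Starlette's TestClient is usable (it requires httpx) and using a hypothetical test file:

from fastapi.testclient import TestClient

from prepline_general.api.app import app

client = TestClient(app)


def test_healthcheck_survived_the_move():
    # /healthcheck moved from app.py onto the endpoints.py router
    response = client.get("/healthcheck")
    assert response.status_code == 200
    assert response.json() == {"healthcheck": "HEALTHCHECK STATUS: EVERYTHING OK!"}


def test_get_on_partition_route_returns_405():
    # handle_invalid_get_request guards the POST-only partition paths
    response = client.get("/general/v0/general")
    assert response.status_code == 405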