diff --git a/docs/modules/file.md b/docs/modules/file.md
index cde73bb7f..45f077ca4 100644
--- a/docs/modules/file.md
+++ b/docs/modules/file.md
@@ -141,7 +141,7 @@ from erniebot_agent.file import GlobalFileManagerHandler
 async def demo_function():
     file_manager = GlobalFileManagerHandler().get()
     # 通过fileid搜索文件
-    file = file_manager.look_up_file_by_id(file_id='your_file_id')
+    file = await file_manager.look_up_file_by_id(file_id='your_file_id')
     # 读取file内容(bytes)
     file_content = await file.read_contents()
     # 写出到指定位置,your_willing_path需要具体到文件名
diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py
index 2768eaea7..8640d9ffe 100644
--- a/erniebot-agent/src/erniebot_agent/agents/agent.py
+++ b/erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -7,7 +7,6 @@
     Final,
     Iterable,
     List,
-    NoReturn,
     Optional,
     Sequence,
     Tuple,
@@ -32,7 +31,6 @@
 from erniebot_agent.memory.messages import Message, SystemMessage
 from erniebot_agent.tools.base import BaseTool
 from erniebot_agent.tools.tool_manager import ToolManager
-from erniebot_agent.utils.exceptions import FileError
 
 
 _PLUGINS_WO_FILE_IO: Final[Tuple[str]] = ("eChart",)
@@ -106,6 +104,7 @@ def __init__(
         self._file_manager = file_manager or get_default_file_manager()
         self._plugins = plugins
         self._init_file_needs_url()
+        self._is_running = False
 
     @final
     async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
@@ -119,8 +118,9 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
         Returns:
             Response from the agent.
         """
-        if files:
-            await self._ensure_managed_files(files)
+        if self._is_running:
+            raise RuntimeError("The agent is already running.")
+        self._is_running = True
         await self._callback_manager.on_run_start(agent=self, prompt=prompt)
         try:
             agent_resp = await self._run(prompt, files)
@@ -129,6 +129,8 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
             raise e
         else:
             await self._callback_manager.on_run_end(agent=self, response=agent_resp)
+        finally:
+            self._is_running = False
         return agent_resp
 
     @final
@@ -247,10 +249,10 @@ async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
         # XXX: Sniffing is less efficient and probably unnecessary.
         # Can we make a protocol to statically recognize file inputs and outputs
         # or can we have the tools introspect about this?
-        input_files = file_manager.sniff_and_extract_files_from_dict(parsed_tool_args)
+        input_files = await file_manager.sniff_and_extract_files_from_obj(parsed_tool_args)
         tool_ret = await tool(**parsed_tool_args)
         if isinstance(tool_ret, dict):
-            output_files = file_manager.sniff_and_extract_files_from_dict(tool_ret)
+            output_files = await file_manager.sniff_and_extract_files_from_obj(tool_ret)
         else:
             output_files = []
         tool_ret_json = json.dumps(tool_ret, ensure_ascii=False)
@@ -275,16 +277,3 @@ def _parse_tool_args(self, tool_args: str) -> Dict[str, Any]:
         if not isinstance(args_dict, dict):
             raise ValueError(f"`tool_args` cannot be interpreted as a dict. `tool_args`: {tool_args}")
         return args_dict
-
-    async def _ensure_managed_files(self, files: Sequence[File]) -> None:
-        def _raise_exception(file: File) -> NoReturn:
-            raise FileError(f"{repr(file)} is not managed by the file manager of the agent.")
-
-        file_manager = self.get_file_manager()
-        for file in files:
-            try:
-                managed_file = file_manager.look_up_file_by_id(file.id)
-            except FileError:
-                _raise_exception(file)
-            if file is not managed_file:
-                _raise_exception(file)
diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py
index 5e2d7746c..a474a4c0c 100644
--- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py
+++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -204,10 +204,10 @@ async def _step(
                 PluginStep(
                     info=output_message.plugin_info,
                     result=output_message.content,
-                    input_files=file_manager.sniff_and_extract_files_from_text(
-                        chat_history[-1].content
+                    input_files=await file_manager.sniff_and_extract_files_from_text(
+                        input_messages[-1].content
                     ),  # TODO: make sure this is correct.
-                    output_files=file_manager.sniff_and_extract_files_from_text(output_message.content),
+                    output_files=[],
                 ),
                 new_messages,
             )
diff --git a/erniebot-agent/src/erniebot_agent/agents/mixins/gradio_mixin.py b/erniebot-agent/src/erniebot_agent/agents/mixins/gradio_mixin.py
index e7bc23702..345b891e4 100644
--- a/erniebot-agent/src/erniebot_agent/agents/mixins/gradio_mixin.py
+++ b/erniebot-agent/src/erniebot_agent/agents/mixins/gradio_mixin.py
@@ -139,7 +139,7 @@ async def _upload(file, history):
                 history = history + [((single_file.name,), None)]
             size = len(file)
 
-            output_lis = file_manager.list_registered_files()
+            output_lis = await file_manager.list_files()
             item = ""
             for i in range(len(output_lis) - size):
                 item += f'<li>{str(output_lis[i]).strip("<>")}</li>'
diff --git a/erniebot-agent/src/erniebot_agent/file/__init__.py b/erniebot-agent/src/erniebot_agent/file/__init__.py
index 9c0f505cd..370331396 100644
--- a/erniebot-agent/src/erniebot_agent/file/__init__.py
+++ b/erniebot-agent/src/erniebot_agent/file/__init__.py
@@ -49,7 +49,7 @@
 
     >>> file_manager = GlobalFileManagerHandler().get()
    >>> local_file = await file_manager.create_file_from_path(file_path='your_path', file_type='local')
-    >>> file = file_manager.look_up_file_by_id(file_id='your_file_id')
+    >>> file = await file_manager.look_up_file_by_id(file_id='your_file_id')
     >>> file_content = await file.read_contents()  # get file content(bytes)
     >>> await local_file.write_contents_to('your_willing_path')  # save to location you want
 """
diff --git a/erniebot-agent/src/erniebot_agent/file/file_manager.py b/erniebot-agent/src/erniebot_agent/file/file_manager.py
index cb744a748..a00a52a88 100644
--- a/erniebot-agent/src/erniebot_agent/file/file_manager.py
+++ b/erniebot-agent/src/erniebot_agent/file/file_manager.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import contextlib
 import contextvars
 import logging
 import os
 import pathlib
 import tempfile
-import uuid
 from collections import deque
 from types import TracebackType
 from typing import (
@@ -28,7 +28,9 @@
     Generator,
     List,
     Literal,
+    Mapping,
     Optional,
+    Sequence,
     Type,
     Union,
     final,
@@ -111,12 +113,22 @@ def __init__(
             # This can be done lazily, but we need to be careful about race conditions.
             self._temp_dir = self._create_temp_dir()
             self._save_dir = pathlib.Path(self._temp_dir.name)
+            if not prune_on_close:
+                _logger.warning(
+                    "If `save_dir` is None, the temporary files will be automatically removed"
+                    " even if `prune_on_close` is not True."
+                )
         self._prune_on_close = prune_on_close
 
         self._file_registry: FileRegistry[File] = FileRegistry()
         self._fully_managed_files: Deque[File] = deque()
 
         self._closed = False
+        # XXX: Currently we lock every public method to prevent race conditions.
+        # However, in some cases, locking may not be necessary, given the
+        # assumption that each file has a unique ID. We should optimize the
+        # concurrency of these methods in the future.
+        self._lock = asyncio.Lock()
 
     @property
     def closed(self):
@@ -222,13 +234,15 @@ async def create_local_file_from_path(
             LocalFile: The created local file.
 
         """
-        file = await self._create_local_file_from_path(
-            pathlib.Path(file_path),
-            file_purpose,
-            file_metadata or {},
-        )
-        self._file_registry.register_file(file)
-        return file
+        async with self._lock:
+            self.ensure_not_closed()
+            file = await self._create_local_file_from_path(
+                pathlib.Path(file_path),
+                file_purpose,
+                file_metadata or {},
+            )
+            self._file_registry.register_file(file, allow_overwrite=False)
+            return file
 
     async def create_remote_file_from_path(
         self,
@@ -249,14 +263,16 @@
             RemoteFile: The created remote file.
 
         """
-        file = await self._create_remote_file_from_path(
-            pathlib.Path(file_path),
-            file_purpose,
-            file_metadata,
-        )
-        self._file_registry.register_file(file)
-        self._fully_managed_files.append(file)
-        return file
+        async with self._lock:
+            self.ensure_not_closed()
+            file = await self._create_remote_file_from_path(
+                pathlib.Path(file_path),
+                file_purpose,
+                file_metadata,
+            )
+            self._file_registry.register_file(file, allow_overwrite=False)
+            self._fully_managed_files.append(file)
+            return file
 
     @overload
     async def create_file_from_bytes(
@@ -317,37 +333,39 @@ async def create_file_from_bytes(
             Union[LocalFile, RemoteFile]: The created file.
 
         """
-        self.ensure_not_closed()
-        if file_type is None:
-            file_type = self._get_default_file_type()
-        file_path = self._get_unique_file_path(
-            prefix=pathlib.PurePath(filename).stem,
-            suffix=pathlib.PurePath(filename).suffix,
-        )
-        async_file_path = anyio.Path(file_path)
-        await async_file_path.touch()
-        should_remove_file = True
-        try:
-            async with await async_file_path.open("wb") as f:
-                await f.write(file_contents)
-            file: File
-            if file_type == "local":
-                file = await self._create_local_file_from_path(file_path, file_purpose, file_metadata)
-                should_remove_file = False
-            elif file_type == "remote":
-                file = await self._create_remote_file_from_path(
-                    file_path,
-                    file_purpose,
-                    file_metadata,
-                )
-            else:
-                raise ValueError(f"Unsupported file type: {file_type}")
-        finally:
-            if should_remove_file:
-                await async_file_path.unlink()
-        self._file_registry.register_file(file)
-        self._fully_managed_files.append(file)
-        return file
+        async with self._lock:
+            self.ensure_not_closed()
+            if file_type is None:
+                file_type = self._get_default_file_type()
+            fp, p = tempfile.mkstemp(
+                prefix=pathlib.PurePath(filename).stem,
+                suffix=pathlib.PurePath(filename).suffix,
+                dir=self._save_dir,
+            )
+            os.close(fp)
+            file_path = pathlib.Path(p)
+            async_file_path = anyio.Path(file_path)
+            should_remove_file = True
+            try:
+                await async_file_path.write_bytes(file_contents)
+                file: File
+                if file_type == "local":
+                    file = await self._create_local_file_from_path(file_path, file_purpose, file_metadata)
+                    should_remove_file = False
+                elif file_type == "remote":
+                    file = await self._create_remote_file_from_path(
+                        file_path,
+                        file_purpose,
+                        file_metadata,
+                    )
+                else:
+                    raise ValueError(f"Unsupported file type: {file_type}")
+            finally:
+                if should_remove_file:
+                    await async_file_path.unlink()
+            self._file_registry.register_file(file, allow_overwrite=False)
+            self._fully_managed_files.append(file)
+            return file
 
     async def retrieve_remote_file_by_id(self, file_id: str) -> RemoteFile:
         """
@@ -360,17 +378,20 @@
             RemoteFile: The retrieved remote file.
 
         """
-        self.ensure_not_closed()
-        file = await self._get_remote_file_client().retrieve_file(file_id)
-        self._file_registry.register_file(file)
-        return file
+        async with self._lock:
+            self.ensure_not_closed()
+            if self._file_registry.look_up_file(file_id) is not None:
+                raise FileError(f"File with ID {repr(file_id)} is already managed by the file manager.")
+            file = await self._get_remote_file_client().retrieve_file(file_id)
+            self._file_registry.register_file(file, allow_overwrite=False)
+            return file
 
     async def list_remote_files(self) -> List[RemoteFile]:
         self.ensure_not_closed()
         files = await self._get_remote_file_client().list_files()
         return files
 
-    def look_up_file_by_id(self, file_id: str) -> File:
+    async def look_up_file_by_id(self, file_id: str) -> File:
         """
         Look up a file by its ID.
 
@@ -384,16 +405,17 @@
             FileError: If the file with the specified ID is not found.
 
         """
+        async with self._lock:
+            return self.look_up_file_by_id_unsafe(file_id)
+
+    def look_up_file_by_id_unsafe(self, file_id: str) -> File:
         self.ensure_not_closed()
         file = self._file_registry.look_up_file(file_id)
         if file is None:
-            raise FileError(
-                f"File with ID {repr(file_id)} not found. "
-                "Please check if `file_id` is correct and the file is registered."
-            )
+            raise FileError(f"File with ID {repr(file_id)} not found. Please check if `file_id` is correct.")
         return file
 
-    def list_registered_files(self) -> List[File]:
+    async def list_registered_files(self) -> List[File]:
         """
         List remote files.
 
@@ -401,37 +423,29 @@
             List[RemoteFile]: The list of remote files.
 
         """
-        self.ensure_not_closed()
-        return self._file_registry.list_files()
+        async with self._lock:
+            self.ensure_not_closed()
+            return self._file_registry.list_files()
 
     async def prune(self) -> None:
-        """Clean local cache of file manager."""
-        while True:
-            try:
-                file = self._fully_managed_files.pop()
-            except IndexError:
-                break
-            if isinstance(file, RemoteFile):
-                # FIXME: Currently this is not supported.
-                # await file.delete()
-                pass
-            elif isinstance(file, LocalFile):
-                assert self._save_dir.resolve() in file.path.resolve().parents
-                await anyio.Path(file.path).unlink()
-            else:
-                assert_never()
-            self._file_registry.unregister_file(file)
+        async with self._lock:
+            self.ensure_not_closed()
+            await self._prune()
 
     async def close(self) -> None:
         """Delete the file manager and clean up its cache"""
-        if not self._closed:
-            if self._remote_file_client is not None:
-                await self._remote_file_client.close()
-            if self._prune_on_close:
-                await self.prune()
-            if self._temp_dir is not None:
-                self._clean_up_temp_dir(self._temp_dir)
-            self._closed = True
+        async with self._lock:
+            if not self._closed:
+                # TODO: Suppress errors?
+                if self._remote_file_client is not None:
+                    await self._remote_file_client.close()
+                if self._prune_on_close:
+                    await self._prune()
+                if self._temp_dir is not None:
+                    await asyncio.get_running_loop().run_in_executor(
+                        None, self._clean_up_temp_dir, self._temp_dir
+                    )
+                self._closed = True
 
     @contextlib.contextmanager
     def as_default_file_manager(self) -> Generator[None, None, None]:
@@ -441,57 +455,15 @@
         finally:
             _default_file_manager_var.reset(token)
 
-    def sniff_and_extract_files_from_list(self, list_: List[Any]) -> List[File]:
-        files: List[File] = []
-        for item in list_:
-            if not isinstance(item, str):
-                continue
-            if protocol.is_file_id(item):
-                file_id = item
-                try:
-                    file = self.look_up_file_by_id(file_id)
-                except FileError as e:
-                    raise FileError(f"An unregistered file with ID {repr(file_id)} was found.") from e
-                files.append(file)
-        return files
-
-    def sniff_and_extract_files_from_dict(self, dict_data: Dict[str, Any]) -> List[File]:
-        files = []
-
-        def try_get_file(string_value):
-            if not protocol.is_file_id(string_value):
-                return
-            try:
-                file = self.look_up_file_by_id(string_value)
-                files.append(file)
-            except FileError as e:
-                raise FileError(f"An unregistered file with ID {repr(string_value)} was found.") from e
-
-        for key, value in dict_data.items():
-            if isinstance(value, str):
-                try_get_file(value)
-            elif isinstance(value, list):
-                for item in value:
-                    if isinstance(item, str):
-                        try_get_file(item)
-                    elif isinstance(item, dict):
-                        files.extend(self.sniff_and_extract_files_from_dict(item))
-            elif isinstance(value, dict):
-                files.extend(self.sniff_and_extract_files_from_dict(value))
+    async def sniff_and_extract_files_from_obj(self, obj: object, *, recursive: bool = True) -> List[File]:
+        async with self._lock:
+            self.ensure_not_closed()
+            return await self._sniff_and_extract_files_from_obj(obj, recursive=recursive)
 
-        return files
-
-    def sniff_and_extract_files_from_text(self, text: str) -> List[File]:
-        file_ids = protocol.extract_file_ids(text)
-        files: List[File] = []
-        for file_id in file_ids:
-            if protocol.is_file_id(file_id):
-                try:
-                    file = self.look_up_file_by_id(file_id)
-                except FileError as e:
-                    raise FileError(f"An unregistered file with ID {repr(file_id)} was found.") from e
-                files.append(file)
-        return files
+    async def sniff_and_extract_files_from_text(self, text: str) -> List[File]:
+        async with self._lock:
+            self.ensure_not_closed()
+            return await self._sniff_and_extract_files_from_text(text)
 
     async def _create_local_file_from_path(
         self,
@@ -500,7 +472,7 @@ async def _create_local_file_from_path(
         file_metadata: Optional[Dict[str, Any]],
     ) -> LocalFile:
         return create_local_file_from_path(
-            pathlib.Path(file_path),
+            file_path,
             file_purpose,
             file_metadata or {},
         )
@@ -511,8 +483,7 @@ async def _create_remote_file_from_path(
         file_purpose: protocol.FilePurpose,
         file_metadata: Optional[Dict[str, Any]],
     ) -> RemoteFile:
-        file = await self._get_remote_file_client().upload_file(file_path, file_purpose, file_metadata or {})
-        return file
+        return await self._get_remote_file_client().upload_file(file_path, file_purpose, file_metadata or {})
 
     def _get_remote_file_client(self) -> RemoteFileClient:
         if self._remote_file_client is None:
@@ -520,19 +491,58 @@
         else:
             return self._remote_file_client
 
+    async def _prune(self) -> None:
+        while True:
+            try:
+                file = self._fully_managed_files.popleft()
+            except IndexError:
+                break
+            if isinstance(file, RemoteFile):
+                # FIXME: Currently this is not supported.
+                # await file.delete()
+                pass
+            elif isinstance(file, LocalFile):
+                assert self._save_dir.resolve() in file.path.resolve().parents
+                await anyio.Path(file.path).unlink()
+            else:
+                assert_never()
+            self._file_registry.unregister_file(file)
+
+    async def _sniff_and_extract_files_from_obj(self, obj: object, *, recursive: bool = True) -> List[File]:
+        files: List[File] = []
+        if isinstance(obj, str):
+            if protocol.is_file_id(obj):
+                file_id = obj
+                file = self._file_registry.look_up_file(file_id)
+                if file is not None:
+                    files.append(file)
+        else:
+            if recursive:
+                if isinstance(obj, Sequence):
+                    for item in obj:
+                        files.extend(await self._sniff_and_extract_files_from_obj(item, recursive=True))
+                elif isinstance(obj, Mapping):
+                    for item in obj.values():
+                        files.extend(await self._sniff_and_extract_files_from_obj(item, recursive=True))
+        return files
+
+    async def _sniff_and_extract_files_from_text(self, text: str) -> List[File]:
+        file_ids = protocol.extract_file_ids(text)
+        file_ids = list(set(file_ids))
+        files: List[File] = []
+        for file_id in file_ids:
+            if protocol.is_file_id(file_id):
+                file = self._file_registry.look_up_file(file_id)
+                if file is not None:
+                    files.append(file)
+        return files
+
     def _get_default_file_type(self) -> Literal["local", "remote"]:
         if self._remote_file_client is not None:
             return "remote"
         else:
             return "local"
 
-    def _get_unique_file_path(
-        self, prefix: Optional[str] = None, suffix: Optional[str] = None
-    ) -> pathlib.Path:
-        filename = f"{prefix or ''}{str(uuid.uuid4())}{suffix or ''}"
-        file_path = self._save_dir / filename
-        return file_path
-
     @staticmethod
     def _create_temp_dir() -> tempfile.TemporaryDirectory:
         temp_dir = tempfile.TemporaryDirectory()
diff --git a/erniebot-agent/src/erniebot_agent/tools/utils.py b/erniebot-agent/src/erniebot_agent/tools/utils.py
index 4ea51eae6..c095a7592 100644
--- a/erniebot-agent/src/erniebot_agent/tools/utils.py
+++ b/erniebot-agent/src/erniebot_agent/tools/utils.py
@@ -212,7 +212,7 @@ async def get_content_by_file_id(
     file_id: str, format: str, mime_type: str, file_manager: FileManager
 ) -> bytes:
     file_id = file_id.replace("", "").replace("", "")
-    file = file_manager.look_up_file_by_id(file_id)
+    file = await file_manager.look_up_file_by_id(file_id)
     byte_str = await file.read_contents()
     return byte_str
 
diff --git a/erniebot-agent/tests/integration_tests/agents/test_agent_with_plugins.py b/erniebot-agent/tests/integration_tests/agents/test_agent_with_plugins.py
index d9935c159..b029d26fc 100644
--- a/erniebot-agent/tests/integration_tests/agents/test_agent_with_plugins.py
+++ b/erniebot-agent/tests/integration_tests/agents/test_agent_with_plugins.py
@@ -32,7 +32,7 @@ async def __call__(self, input_file_id: str, repeat_times: int) -> Dict[str, Any
 
         input_file_id = input_file_id.split("")[0]
         file_manager = GlobalFileManagerHandler().get()
-        input_file = file_manager.look_up_file_by_id(input_file_id)
+        input_file = await file_manager.look_up_file_by_id(input_file_id)
         if input_file is None:
             raise RuntimeError("File not found")
         # text = (await input_file.read_contents())[:10]
diff --git a/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py b/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py
index f05e8fdf9..24c667b47 100644
--- a/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py
+++ b/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py
@@ -165,7 +165,7 @@ async def test_plugin_schema(self):
 
         self.assertIn("file", result)
         file_id = result["file"]
-        file = self.file_manager.look_up_file_by_id(file_id=file_id)
+        file = await self.file_manager.look_up_file_by_id(file_id=file_id)
         content = await file.read_contents()
         self.assertEqual(content.decode("utf-8"), self.content)
         self.assertIn("prompt", result)
diff --git a/erniebot-agent/tests/unit_tests/tools/test_schema.py b/erniebot-agent/tests/unit_tests/tools/test_schema.py
index 138c100be..5a0bd97c0 100644
--- a/erniebot-agent/tests/unit_tests/tools/test_schema.py
+++ b/erniebot-agent/tests/unit_tests/tools/test_schema.py
@@ -351,12 +351,12 @@ async def test_file_v1(self):
 
         self.assertEqual(len(result["file"]), 2)
         file_0, file_1 = result["file"][0], result["file"][1]
-        file: File = file_manager.look_up_file_by_id(file_0)
+        file: File = await file_manager.look_up_file_by_id(file_0)
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_0)
         self.assertEqual(file_content, file_content_from_file_manager)
 
-        file: File = file_manager.look_up_file_by_id(file_1)
+        file: File = await file_manager.look_up_file_by_id(file_1)
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_1)
         self.assertEqual(file_content, file_content_from_file_manager)
@@ -377,7 +377,7 @@ async def test_file_v2(self):
         result = await tool()
 
         file_manager = GlobalFileManagerHandler().get()
-        file: File = file_manager.look_up_file_by_id(result["file"])
+        file: File = await file_manager.look_up_file_by_id(result["file"])
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content)
         self.assertEqual(file_content, file_content_from_file_manager)
@@ -395,7 +395,7 @@ async def test_file_v3(self):
 
         result = await tool()
 
-        file: File = file_manager.look_up_file_by_id(result["file"]["file"])
+        file: File = await file_manager.look_up_file_by_id(result["file"]["file"])
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content)
         self.assertEqual(file_content, file_content_from_file_manager)
@@ -414,12 +414,12 @@ async def test_file_v4(self):
 
         file_0, file_1 = result["file"][0]["file"], result["file"][1]["file"]
 
-        file: File = file_manager.look_up_file_by_id(file_0)
+        file: File = await file_manager.look_up_file_by_id(file_0)
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_0)
         self.assertEqual(file_content, file_content_from_file_manager)
 
-        file: File = file_manager.look_up_file_by_id(file_1)
+        file: File = await file_manager.look_up_file_by_id(file_1)
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_1)
         self.assertEqual(file_content, file_content_from_file_manager)
@@ -447,12 +447,12 @@ async def test_file_v5(self):
 
         file_0, file_1 = result["file"]["file"][0]["file"], result["file"]["file"][1]["file"]
 
-        file: File = file_manager.look_up_file_by_id(file_0)
+        file: File = await file_manager.look_up_file_by_id(file_0)
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_0)
         self.assertEqual(file_content, file_content_from_file_manager)
 
-        file: File = file_manager.look_up_file_by_id(file_1)
+        file: File = await file_manager.look_up_file_by_id(file_1)
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_1)
         self.assertEqual(file_content, file_content_from_file_manager)
@@ -473,12 +473,12 @@ async def test_file_v6(self):
 
         result = await tool()
 
-        file: File = file_manager.look_up_file_by_id(result["first_file"])
+        file: File = await file_manager.look_up_file_by_id(result["first_file"])
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_0)
         self.assertEqual(file_content, file_content_from_file_manager)
 
-        file: File = file_manager.look_up_file_by_id(result["second_file"])
+        file: File = await file_manager.look_up_file_by_id(result["second_file"])
         file_content_from_file_manager = await file.read_contents()
         file_content = base64.b64decode(file_content_1)
         self.assertEqual(file_content, file_content_from_file_manager)
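
Taken together, the patch turns the `FileManager` lookup, listing, and sniffing helpers into coroutines guarded by an internal `asyncio.Lock`. Below is a minimal usage sketch of the updated API, mirroring the call pattern shown in `docs/modules/file.md` and `erniebot_agent/file/__init__.py`; the placeholder path, the tool-style argument dict, and running the coroutine with `asyncio.run` are illustrative assumptions, not part of the patch.

```python
import asyncio

from erniebot_agent.file import GlobalFileManagerHandler


async def demo():
    # As in the docs example above, obtain the process-wide file manager.
    file_manager = GlobalFileManagerHandler().get()

    # Creating files was already asynchronous before this change.
    local_file = await file_manager.create_file_from_path(file_path="your_path", file_type="local")

    # Lookup and listing are now coroutines and must be awaited.
    file = await file_manager.look_up_file_by_id(file_id=local_file.id)
    contents = await file.read_contents()  # bytes
    registered = await file_manager.list_registered_files()

    # Sniffing file IDs out of tool arguments now goes through the generic object-based helper.
    input_files = await file_manager.sniff_and_extract_files_from_obj({"input_file": local_file.id})
    return contents, registered, input_files


asyncio.run(demo())
```

Since every public method now serializes on the manager's internal lock, concurrent callers no longer need external synchronization when registering, looking up, or pruning files.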