From 01f96d5c9a22d62d52449b9e37657bf38a669c94 Mon Sep 17 00:00:00 2001
From: Mariano Iglesias
Date: Fri, 9 Jan 2026 16:20:32 -0300
Subject: [PATCH 1/2] Adding mypy

---
 Makefile | 1 +
 poetry.lock | 199 ++-
 pyproject.toml | 68 +
 src/avalan/__init__.py | 8 +-
 src/avalan/agent/engine.py | 18 +-
 src/avalan/agent/loader.py | 13 +-
 src/avalan/agent/orchestrator/__init__.py | 39 +-
 .../orchestrator/orchestrators/default.py | 13 +-
 .../agent/orchestrator/orchestrators/json.py | 21 +-
 .../orchestrators/reasoning/cot.py | 2 +-
 .../response/orchestrator_response.py | 62 +-
 src/avalan/agent/renderer.py | 3 +-
 src/avalan/cli/__init__.py | 48 +-
 src/avalan/cli/__main__.py | 45 +-
 src/avalan/cli/commands/agent.py | 68 +-
 src/avalan/cli/commands/cache.py | 4 +-
 src/avalan/cli/commands/memory.py | 13 +-
 src/avalan/cli/commands/model.py | 152 ++-
 src/avalan/cli/commands/tokenizer.py | 13 +-
 src/avalan/cli/download.py | 2 +-
 src/avalan/cli/theme/__init__.py | 85 +-
 src/avalan/cli/theme/fancy.py | 1203 ++++++++++-------
 src/avalan/compat.py | 16 +-
 src/avalan/deploy/aws.py | 48 +-
 src/avalan/entities.py | 12 +-
 src/avalan/event/manager.py | 3 +-
 src/avalan/memory/__init__.py | 12 +-
 src/avalan/memory/partitioner/code.py | 128 +-
 src/avalan/memory/partitioner/text.py | 2 +-
 src/avalan/memory/permanent/__init__.py | 25 +-
 .../memory/permanent/elasticsearch/message.py | 14 +-
 .../memory/permanent/elasticsearch/raw.py | 12 +-
 src/avalan/memory/permanent/pgsql/__init__.py | 72 +-
 src/avalan/memory/permanent/pgsql/message.py | 30 +-
 src/avalan/memory/permanent/pgsql/raw.py | 9 +-
 .../memory/permanent/s3vectors/message.py | 14 +-
 src/avalan/memory/permanent/s3vectors/raw.py | 12 +-
 src/avalan/memory/source.py | 10 +-
 src/avalan/model/audio/__init__.py | 38 +-
 src/avalan/model/audio/classification.py | 36 +-
 src/avalan/model/audio/generation.py | 33 +-
 src/avalan/model/audio/speech.py | 13 +-
 src/avalan/model/audio/speech_recognition.py | 19 +-
 src/avalan/model/criteria.py | 18 +-
 src/avalan/model/engine.py | 121 +-
 src/avalan/model/hubs/huggingface.py | 37 +-
 src/avalan/model/manager.py | 27 +-
 src/avalan/model/modalities/audio.py | 42 +-
 src/avalan/model/modalities/registry.py | 19 +-
 src/avalan/model/modalities/text.py | 138 +-
 src/avalan/model/modalities/vision.py | 78 +-
 src/avalan/model/nlp/__init__.py | 27 +-
 src/avalan/model/nlp/question.py | 26 +-
 src/avalan/model/nlp/sentence.py | 22 +-
 src/avalan/model/nlp/sequence.py | 93 +-
 src/avalan/model/nlp/text/generation.py | 117 +-
 src/avalan/model/nlp/text/mlxlm.py | 46 +-
 src/avalan/model/nlp/text/vendor/__init__.py | 24 +-
 src/avalan/model/nlp/text/vendor/anthropic.py | 129 +-
 src/avalan/model/nlp/text/vendor/bedrock.py | 14 +-
 src/avalan/model/nlp/text/vendor/google.py | 11 +-
 .../model/nlp/text/vendor/huggingface.py | 30 +-
 src/avalan/model/nlp/text/vendor/litellm.py | 10 +-
 src/avalan/model/nlp/text/vendor/ollama.py | 21 +-
 src/avalan/model/nlp/text/vendor/openai.py | 137 +-
 .../model/nlp/text/vendor/openrouter.py | 6 +-
 src/avalan/model/nlp/text/vllm.py | 41 +-
 src/avalan/model/nlp/token.py | 63 +-
 src/avalan/model/response/parsers/tool.py | 4 +-
 src/avalan/model/response/text.py | 69 +-
 src/avalan/model/transformer.py | 38 +-
 src/avalan/model/vendor.py | 39 +-
 src/avalan/model/vision/classification.py | 36 +-
 src/avalan/model/vision/decoder.py | 22 +-
 src/avalan/model/vision/detection.py | 25 +-
 .../model/vision/diffusion/animation.py | 26 +-
 src/avalan/model/vision/diffusion/image.py | 16 +-
 src/avalan/model/vision/diffusion/video.py | 18 +-
 src/avalan/model/vision/segmentation.py | 38 +-
 src/avalan/model/vision/text.py | 53 +-
 src/avalan/secrets/aws.py | 13 +-
 src/avalan/secrets/keyring.py | 6 +-
 src/avalan/server/__init__.py | 14 +-
 src/avalan/server/a2a/router.py | 14 +-
 src/avalan/server/routers/__init__.py | 39 +-
 src/avalan/server/routers/chat.py | 29 +-
 src/avalan/server/routers/mcp.py | 82 +-
 src/avalan/server/routers/responses.py | 91 +-
 src/avalan/tool/__init__.py | 55 +-
 src/avalan/tool/browser.py | 50 +-
 src/avalan/tool/code.py | 9 +-
 src/avalan/tool/database/__init__.py | 65 +-
 src/avalan/tool/database/count.py | 2 +-
 src/avalan/tool/database/inspect.py | 2 +-
 src/avalan/tool/database/keys.py | 38 +-
 src/avalan/tool/database/kill.py | 2 +-
 src/avalan/tool/database/locks.py | 2 +-
 src/avalan/tool/database/plan.py | 2 +-
 src/avalan/tool/database/relationships.py | 2 +-
 src/avalan/tool/database/run.py | 2 +-
 src/avalan/tool/database/sample.py | 2 +-
 src/avalan/tool/database/size.py | 14 +-
 src/avalan/tool/database/tables.py | 2 +-
 src/avalan/tool/database/tasks.py | 2 +-
 src/avalan/tool/database/toolset.py | 7 +-
 src/avalan/tool/manager.py | 56 +-
 src/avalan/tool/math.py | 4 +-
 src/avalan/tool/mcp.py | 7 +-
 src/avalan/tool/memory.py | 18 +-
 src/avalan/tool/parser.py | 14 +-
 src/avalan/tool/search_engine.py | 4 +-
 src/avalan/tool/youtube.py | 2 +-
 src/avalan/utils.py | 4 +-
 tests/agent/default_orchestrator_test.py | 8 +-
 .../orchestrator_response_additional_test.py | 2 +-
 tests/agent/orchestrator_test.py | 8 +-
 tests/agent/renderer_test.py | 4 +-
 tests/agent/template_engine_agent_test.py | 28 +-
 tests/cli/agent_test.py | 6 +-
 tests/cli/cache_test.py | 6 +-
 tests/cli/input_test.py | 8 +-
 tests/cli/model_search_additional_test.py | 2 +-
 tests/cli/model_test.py | 8 +-
 tests/cli/theme_test.py | 10 +-
 tests/cli/tokenizer_test.py | 36 +-
 tests/compat_test.py | 25 +-
 .../model/audio/audio_classification_test.py | 4 +
 tests/model/audio/base_audio_model_test.py | 11 +-
 tests/model/engine_additional_test.py | 1 +
 tests/model/model_manager_operation_test.py | 6 +-
 tests/model/nlp/text_test.py | 5 +-
 .../model/response_parsers_additional_test.py | 4 +-
 .../text_generation_modality_vendors_test.py | 22 +-
 tests/model/text_modalities_full_test.py | 4 +-
 tests/model/vendor_tool_call_token_test.py | 14 +-
 tests/utils_test.py | 2 +-
 136 files changed, 3213 insertions(+), 1985 deletions(-)

diff --git a/Makefile b/Makefile
index c91cb305..7357106a 100644
--- a/Makefile
+++ b/Makefile
@@ -13,6 +13,7 @@ lint:
 	poetry run ruff format --preview src/ tests/
 	poetry run black --preview --enable-unstable-feature=string_processing src/ tests/
 	poetry run ruff check --fix src/ tests/
+	poetry run mypy src/avalan
 
 test:
 	poetry sync --extras test
diff --git a/poetry.lock b/poetry.lock
index fbe53c14..cdaaea54 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
[[package]] name = "a2a-sdk" @@ -589,6 +589,30 @@ files = [ {file = "blake3-1.0.7-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:d9046bb1e22a8607e1d0d7c3ff47e56e0a197c988502df4bf4d78563f3e9fe2c"}, {file = "blake3-1.0.7-cp313-cp313t-win32.whl", hash = "sha256:bd2f638bcc00fc09ce985ea3c642d45940e1eda198ab1f4b90cfdecbebbc9315"}, {file = "blake3-1.0.7-cp313-cp313t-win_amd64.whl", hash = "sha256:cb3aa1db14231c2ef0ec5acd805505ce128c39ffa510deb3384eed96fe4addcb"}, + {file = "blake3-1.0.7-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:f7db997205aa420d59fb5639346e40beafb9c09252e2ec6efedca8f230f7520c"}, + {file = "blake3-1.0.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:19afec6e276f3bc154541248d92b1ecb198af2ee920025f7ce521028f9a69d8b"}, + {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:006a11bbba65a95e88ddc069cca751c8812fd144d582715eeea512452fdbe80d"}, + {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7febeffdc8412fed105ca517cee641ac521fb9cfb750bf7e27a5cdf3ddf74a08"}, + {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c032ce7c52b71015651c0abe9fe599aa2669e6be578aa17d5f993dc93373401"}, + {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b81455f7d24b58fe26be037cc3854c28ea6eb3671ceab3b1ec0b1239aeb6fef"}, + {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41b0127b0e7c8610054c421959dbe7140a81ac2c88fa9e099994fbaa529af3c1"}, + {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4755ca95b4114b629d8f3570bc661916d211d52d47f57ff70e9687377ab39cb9"}, + {file = "blake3-1.0.7-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:8abe929cfd27b375e02e3dd7a690192fa4efecc52ef510df91ef01651ef08dc7"}, + {file = "blake3-1.0.7-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:dd607eb5ad5a9b44ff62243759aa0af4085f6f43c9b01f503561a70da63e3b94"}, + {file = "blake3-1.0.7-cp314-cp314-win32.whl", hash = "sha256:a51684d1f346e7680f7c244c25b0e279e3b297f1938126e4ea8e32425ea269f5"}, + {file = "blake3-1.0.7-cp314-cp314-win_amd64.whl", hash = "sha256:a6a481719e28e2c61aafd4273d32663365d97613341b72fcdf2f6afbd426319b"}, + {file = "blake3-1.0.7-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:daa8933cd7db19143bd6b59f7ac4c7c7446767d7b2c3a748a4559aa483275fa2"}, + {file = "blake3-1.0.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:24074adfffffe0fa7a7dd930cc608d6e965e70306e2c1e14d412e29ec94fa360"}, + {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dce6e6f03de2674f9860cf330d8a4fcdb63a60659435e5e31d72d174fc102d8e"}, + {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e783f33d53a2de8d2ab845235dd53393d521b5e4a76c23d03e77e472266359d3"}, + {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:782784aef18eb61f4ce8bf2b9506b7d90f0d183176b453345b221837a18041b7"}, + {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6062122e77f40e3733cac2ef3f25e0fc7f555e352fe6f513f8404ad11dc69974"}, + {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c2614bc9d69fd6067571f3bb37b3b07a6b86a56167553ad4784a3c508771f39"}, + {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:d6df2bd56c43bdeb6699d4af0a0dd0d77537d95cb4a5dde4b39ed6e54cc725d6"}, + {file = "blake3-1.0.7-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:8b635cf4350caf459ecb335b32be622068423245bda457d5bc159106eb20f912"}, + {file = "blake3-1.0.7-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:f96a685775f87ddf75ff495dc9698703268c66c170caca977347427ef8d52324"}, + {file = "blake3-1.0.7-cp314-cp314t-win32.whl", hash = "sha256:0633b7d9bad87dc7fce545042353f2e056604d993f71d1dce666a9f5edc13e05"}, + {file = "blake3-1.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:5e356daa0089968dc1ff1d0d112e7cc1700533441d8f30ae99f835a94dc8b0f3"}, {file = "blake3-1.0.7-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:ac2816aa95f27ae5c1888e04468d1091acca5f00c302e8884b600dd344bc80ac"}, {file = "blake3-1.0.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c67b69a513d2a73773bc0a0399a5b4d1602dd157b4960ce8bce9f4fd6327323"}, {file = "blake3-1.0.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:999f4817e651e63811a0ecdafbdbb60d2097f7f46aff7dd600fe3f1780e48c8e"}, @@ -1505,6 +1529,7 @@ files = [ {file = "faiss_cpu-1.11.0.post1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dc12b3f89cf48be3f2a20b37f310c3f1a7a5708fdf705f88d639339a24bb590b"}, {file = "faiss_cpu-1.11.0.post1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:773fa45aa98a210ab4e2c17c1b5fb45f6d7e9acb4979c9a0b320b678984428ac"}, {file = "faiss_cpu-1.11.0.post1-cp39-cp39-win_amd64.whl", hash = "sha256:6240c4b1551eedc07e76813c2e14a1583a1db6c319a92a3934bf212d0e4c7791"}, + {file = "faiss_cpu-1.11.0.post1.tar.gz", hash = "sha256:06b1ea9ddec9e4d9a41c8ef7478d493b08d770e9a89475056e963081eed757d1"}, ] [package.dependencies] @@ -2687,10 +2712,9 @@ files = [ name = "joblib" version = "1.5.2" description = "Lightweight pipelining with Python functions" -optional = true +optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"memory\" or extra == \"all\"" files = [ {file = "joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241"}, {file = "joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55"}, @@ -2785,6 +2809,93 @@ interegular = ["interegular (>=0.3.1,<0.4.0)"] nearley = ["js2py"] regex = ["regex"] +[[package]] +name = "librt" +version = "0.7.7" +description = "Mypyc runtime library" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "platform_python_implementation != \"PyPy\"" +files = [ + {file = "librt-0.7.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4836c5645f40fbdc275e5670819bde5ab5f2e882290d304e3c6ddab1576a6d0"}, + {file = "librt-0.7.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae8aec43117a645a31e5f60e9e3a0797492e747823b9bda6972d521b436b4e8"}, + {file = "librt-0.7.7-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:aea05f701ccd2a76b34f0daf47ca5068176ff553510b614770c90d76ac88df06"}, + {file = "librt-0.7.7-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7b16ccaeff0ed4355dfb76fe1ea7a5d6d03b5ad27f295f77ee0557bc20a72495"}, + {file = "librt-0.7.7-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c48c7e150c095d5e3cea7452347ba26094be905d6099d24f9319a8b475fcd3e0"}, + {file = "librt-0.7.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4dcee2f921a8632636d1c37f1bbdb8841d15666d119aa61e5399c5268e7ce02e"}, + {file = 
"librt-0.7.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:14ef0f4ac3728ffd85bfc58e2f2f48fb4ef4fa871876f13a73a7381d10a9f77c"}, + {file = "librt-0.7.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e4ab69fa37f8090f2d971a5d2bc606c7401170dbdae083c393d6cbf439cb45b8"}, + {file = "librt-0.7.7-cp310-cp310-win32.whl", hash = "sha256:4bf3cc46d553693382d2abf5f5bd493d71bb0f50a7c0beab18aa13a5545c8900"}, + {file = "librt-0.7.7-cp310-cp310-win_amd64.whl", hash = "sha256:f0c8fe5aeadd8a0e5b0598f8a6ee3533135ca50fd3f20f130f9d72baf5c6ac58"}, + {file = "librt-0.7.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a487b71fbf8a9edb72a8c7a456dda0184642d99cd007bc819c0b7ab93676a8ee"}, + {file = "librt-0.7.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4d4efb218264ecf0f8516196c9e2d1a0679d9fb3bb15df1155a35220062eba8"}, + {file = "librt-0.7.7-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b8bb331aad734b059c4b450cd0a225652f16889e286b2345af5e2c3c625c3d85"}, + {file = "librt-0.7.7-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:467dbd7443bda08338fc8ad701ed38cef48194017554f4c798b0a237904b3f99"}, + {file = "librt-0.7.7-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50d1d1ee813d2d1a3baf2873634ba506b263032418d16287c92ec1cc9c1a00cb"}, + {file = "librt-0.7.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c7e5070cf3ec92d98f57574da0224f8c73faf1ddd6d8afa0b8c9f6e86997bc74"}, + {file = "librt-0.7.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bdb9f3d865b2dafe7f9ad7f30ef563c80d0ddd2fdc8cc9b8e4f242f475e34d75"}, + {file = "librt-0.7.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8185c8497d45164e256376f9da5aed2bb26ff636c798c9dabe313b90e9f25b28"}, + {file = "librt-0.7.7-cp311-cp311-win32.whl", hash = "sha256:44d63ce643f34a903f09ff7ca355aae019a3730c7afd6a3c037d569beeb5d151"}, + {file = "librt-0.7.7-cp311-cp311-win_amd64.whl", hash = "sha256:7d13cc340b3b82134f8038a2bfe7137093693dcad8ba5773da18f95ad6b77a8a"}, + {file = "librt-0.7.7-cp311-cp311-win_arm64.whl", hash = "sha256:983de36b5a83fe9222f4f7dcd071f9b1ac6f3f17c0af0238dadfb8229588f890"}, + {file = "librt-0.7.7-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2a85a1fc4ed11ea0eb0a632459ce004a2d14afc085a50ae3463cd3dfe1ce43fc"}, + {file = "librt-0.7.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c87654e29a35938baead1c4559858f346f4a2a7588574a14d784f300ffba0efd"}, + {file = "librt-0.7.7-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c9faaebb1c6212c20afd8043cd6ed9de0a47d77f91a6b5b48f4e46ed470703fe"}, + {file = "librt-0.7.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1908c3e5a5ef86b23391448b47759298f87f997c3bd153a770828f58c2bb4630"}, + {file = "librt-0.7.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dbc4900e95a98fc0729523be9d93a8fedebb026f32ed9ffc08acd82e3e181503"}, + {file = "librt-0.7.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a7ea4e1fbd253e5c68ea0fe63d08577f9d288a73f17d82f652ebc61fa48d878d"}, + {file = "librt-0.7.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ef7699b7a5a244b1119f85c5bbc13f152cd38240cbb2baa19b769433bae98e50"}, + {file = "librt-0.7.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:955c62571de0b181d9e9e0a0303c8bc90d47670a5eff54cf71bf5da61d1899cf"}, + {file = "librt-0.7.7-cp312-cp312-win32.whl", hash = 
"sha256:1bcd79be209313b270b0e1a51c67ae1af28adad0e0c7e84c3ad4b5cb57aaa75b"}, + {file = "librt-0.7.7-cp312-cp312-win_amd64.whl", hash = "sha256:4353ee891a1834567e0302d4bd5e60f531912179578c36f3d0430f8c5e16b456"}, + {file = "librt-0.7.7-cp312-cp312-win_arm64.whl", hash = "sha256:a76f1d679beccccdf8c1958e732a1dfcd6e749f8821ee59d7bec009ac308c029"}, + {file = "librt-0.7.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f4a0b0a3c86ba9193a8e23bb18f100d647bf192390ae195d84dfa0a10fb6244"}, + {file = "librt-0.7.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5335890fea9f9e6c4fdf8683061b9ccdcbe47c6dc03ab8e9b68c10acf78be78d"}, + {file = "librt-0.7.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9b4346b1225be26def3ccc6c965751c74868f0578cbcba293c8ae9168483d811"}, + {file = "librt-0.7.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a10b8eebdaca6e9fdbaf88b5aefc0e324b763a5f40b1266532590d5afb268a4c"}, + {file = "librt-0.7.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:067be973d90d9e319e6eb4ee2a9b9307f0ecd648b8a9002fa237289a4a07a9e7"}, + {file = "librt-0.7.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:23d2299ed007812cccc1ecef018db7d922733382561230de1f3954db28433977"}, + {file = "librt-0.7.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6b6f8ea465524aa4c7420c7cc4ca7d46fe00981de8debc67b1cc2e9957bb5b9d"}, + {file = "librt-0.7.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8df32a99cc46eb0ee90afd9ada113ae2cafe7e8d673686cf03ec53e49635439"}, + {file = "librt-0.7.7-cp313-cp313-win32.whl", hash = "sha256:86f86b3b785487c7760247bcdac0b11aa8bf13245a13ed05206286135877564b"}, + {file = "librt-0.7.7-cp313-cp313-win_amd64.whl", hash = "sha256:4862cb2c702b1f905c0503b72d9d4daf65a7fdf5a9e84560e563471e57a56949"}, + {file = "librt-0.7.7-cp313-cp313-win_arm64.whl", hash = "sha256:0996c83b1cb43c00e8c87835a284f9057bc647abd42b5871e5f941d30010c832"}, + {file = "librt-0.7.7-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:23daa1ab0512bafdd677eb1bfc9611d8ffbe2e328895671e64cb34166bc1b8c8"}, + {file = "librt-0.7.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:558a9e5a6f3cc1e20b3168fb1dc802d0d8fa40731f6e9932dcc52bbcfbd37111"}, + {file = "librt-0.7.7-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:2567cb48dc03e5b246927ab35cbb343376e24501260a9b5e30b8e255dca0d1d2"}, + {file = "librt-0.7.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6066c638cdf85ff92fc6f932d2d73c93a0e03492cdfa8778e6d58c489a3d7259"}, + {file = "librt-0.7.7-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a609849aca463074c17de9cda173c276eb8fee9e441053529e7b9e249dc8b8ee"}, + {file = "librt-0.7.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:add4e0a000858fe9bb39ed55f31085506a5c38363e6eb4a1e5943a10c2bfc3d1"}, + {file = "librt-0.7.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a3bfe73a32bd0bdb9a87d586b05a23c0a1729205d79df66dee65bb2e40d671ba"}, + {file = "librt-0.7.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:0ecce0544d3db91a40f8b57ae26928c02130a997b540f908cefd4d279d6c5848"}, + {file = "librt-0.7.7-cp314-cp314-win32.whl", hash = "sha256:8f7a74cf3a80f0c3b0ec75b0c650b2f0a894a2cec57ef75f6f72c1e82cdac61d"}, + {file = "librt-0.7.7-cp314-cp314-win_amd64.whl", hash = 
"sha256:3d1fe2e8df3268dd6734dba33ededae72ad5c3a859b9577bc00b715759c5aaab"}, + {file = "librt-0.7.7-cp314-cp314-win_arm64.whl", hash = "sha256:2987cf827011907d3dfd109f1be0d61e173d68b1270107bb0e89f2fca7f2ed6b"}, + {file = "librt-0.7.7-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8e92c8de62b40bfce91d5e12c6e8b15434da268979b1af1a6589463549d491e6"}, + {file = "librt-0.7.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f683dcd49e2494a7535e30f779aa1ad6e3732a019d80abe1309ea91ccd3230e3"}, + {file = "librt-0.7.7-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9b15e5d17812d4d629ff576699954f74e2cc24a02a4fc401882dd94f81daba45"}, + {file = "librt-0.7.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c084841b879c4d9b9fa34e5d5263994f21aea7fd9c6add29194dbb41a6210536"}, + {file = "librt-0.7.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c8fb9966f84737115513fecbaf257f9553d067a7dd45a69c2c7e5339e6a8dc"}, + {file = "librt-0.7.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:9b5fb1ecb2c35362eab2dbd354fd1efa5a8440d3e73a68be11921042a0edc0ff"}, + {file = "librt-0.7.7-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:d1454899909d63cc9199a89fcc4f81bdd9004aef577d4ffc022e600c412d57f3"}, + {file = "librt-0.7.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7ef28f2e7a016b29792fe0a2dd04dec75725b32a1264e390c366103f834a9c3a"}, + {file = "librt-0.7.7-cp314-cp314t-win32.whl", hash = "sha256:5e419e0db70991b6ba037b70c1d5bbe92b20ddf82f31ad01d77a347ed9781398"}, + {file = "librt-0.7.7-cp314-cp314t-win_amd64.whl", hash = "sha256:d6b7d93657332c817b8d674ef6bf1ab7796b4f7ce05e420fd45bd258a72ac804"}, + {file = "librt-0.7.7-cp314-cp314t-win_arm64.whl", hash = "sha256:142c2cd91794b79fd0ce113bd658993b7ede0fe93057668c2f98a45ca00b7e91"}, + {file = "librt-0.7.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c8ffe3431d98cc043a14e88b21288b5ec7ee12cb01260e94385887f285ef9389"}, + {file = "librt-0.7.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e40d20ae1722d6b8ea6acf4597e789604649dcd9c295eb7361a28225bc2e9e12"}, + {file = "librt-0.7.7-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f2cb63c49bc96847c3bb8dca350970e4dcd19936f391cfdfd057dcb37c4fa97e"}, + {file = "librt-0.7.7-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8f2f8dcf5ab9f80fb970c6fd780b398efb2f50c1962485eb8d3ab07788595a48"}, + {file = "librt-0.7.7-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a1f5cc41a570269d1be7a676655875e3a53de4992a9fa38efb7983e97cf73d7c"}, + {file = "librt-0.7.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ff1fb2dfef035549565a4124998fadcb7a3d4957131ddf004a56edeb029626b3"}, + {file = "librt-0.7.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ab2a2a9cd7d044e1a11ca64a86ad3361d318176924bbe5152fbc69f99be20b8c"}, + {file = "librt-0.7.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ad3fc2d859a709baf9dd9607bb72f599b1cfb8a39eafd41307d0c3c4766763cb"}, + {file = "librt-0.7.7-cp39-cp39-win32.whl", hash = "sha256:f83c971eb9d2358b6a18da51dc0ae00556ac7c73104dde16e9e14c15aaf685ca"}, + {file = "librt-0.7.7-cp39-cp39-win_amd64.whl", hash = "sha256:264720fc288c86039c091a4ad63419a5d7cabbf1c1c9933336a957ed2483e570"}, + {file = "librt-0.7.7.tar.gz", hash = "sha256:81d957b069fed1890953c3b9c3895c7689960f233eea9a1d9607f71ce7f00b2c"}, +] + [[package]] 
name = "litellm" version = "1.75.0" @@ -3304,7 +3415,7 @@ description = "A framework for machine learning on Apple silicon." optional = true python-versions = ">=3.9" groups = ["main"] -markers = "(platform_machine == \"arm64\" or extra == \"test\" or extra == \"mlx\" or extra == \"apple\") and platform_system == \"Darwin\" and (extra == \"vllm\" or extra == \"nvidia\" or extra == \"test\" or extra == \"mlx\" or extra == \"apple\")" +markers = "(extra == \"test\" or extra == \"mlx\" or extra == \"apple\" or platform_machine == \"arm64\") and platform_system == \"Darwin\" and (extra == \"vllm\" or extra == \"nvidia\" or extra == \"test\" or extra == \"mlx\" or extra == \"apple\")" files = [ {file = "mlx_metal-0.28.0-py3-none-macosx_13_0_arm64.whl", hash = "sha256:ce08d40f1fad4f0b3bc87bfff5d603c7fe7dd141c082ba9ce9328b41e8f8d46b"}, {file = "mlx_metal-0.28.0-py3-none-macosx_14_0_arm64.whl", hash = "sha256:424142ab843e2ac0b14edb58cf88d96723823c565291f46ddeeaa072abcc991e"}, @@ -3640,6 +3751,67 @@ files = [ {file = "multidict-6.7.0.tar.gz", hash = "sha256:c6e99d9a65ca282e578dfea819cfa9c0a62b2499d8677392e09feaf305e9e6f5"}, ] +[[package]] +name = "mypy" +version = "1.19.1" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "mypy-1.19.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f05aa3d375b385734388e844bc01733bd33c644ab48e9684faa54e5389775ec"}, + {file = "mypy-1.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:022ea7279374af1a5d78dfcab853fe6a536eebfda4b59deab53cd21f6cd9f00b"}, + {file = "mypy-1.19.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee4c11e460685c3e0c64a4c5de82ae143622410950d6be863303a1c4ba0e36d6"}, + {file = "mypy-1.19.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de759aafbae8763283b2ee5869c7255391fbc4de3ff171f8f030b5ec48381b74"}, + {file = "mypy-1.19.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ab43590f9cd5108f41aacf9fca31841142c786827a74ab7cc8a2eacb634e09a1"}, + {file = "mypy-1.19.1-cp310-cp310-win_amd64.whl", hash = "sha256:2899753e2f61e571b3971747e302d5f420c3fd09650e1951e99f823bc3089dac"}, + {file = "mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288"}, + {file = "mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab"}, + {file = "mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6"}, + {file = "mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331"}, + {file = "mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925"}, + {file = "mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042"}, + {file = "mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1"}, + {file = "mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e"}, + {file = 
"mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2"}, + {file = "mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8"}, + {file = "mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a"}, + {file = "mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13"}, + {file = "mypy-1.19.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e3157c7594ff2ef1634ee058aafc56a82db665c9438fd41b390f3bde1ab12250"}, + {file = "mypy-1.19.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdb12f69bcc02700c2b47e070238f42cb87f18c0bc1fc4cdb4fb2bc5fd7a3b8b"}, + {file = "mypy-1.19.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f859fb09d9583a985be9a493d5cfc5515b56b08f7447759a0c5deaf68d80506e"}, + {file = "mypy-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9a6538e0415310aad77cb94004ca6482330fece18036b5f360b62c45814c4ef"}, + {file = "mypy-1.19.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:da4869fc5e7f62a88f3fe0b5c919d1d9f7ea3cef92d3689de2823fd27e40aa75"}, + {file = "mypy-1.19.1-cp313-cp313-win_amd64.whl", hash = "sha256:016f2246209095e8eda7538944daa1d60e1e8134d98983b9fc1e92c1fc0cb8dd"}, + {file = "mypy-1.19.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:06e6170bd5836770e8104c8fdd58e5e725cfeb309f0a6c681a811f557e97eac1"}, + {file = "mypy-1.19.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:804bd67b8054a85447c8954215a906d6eff9cabeabe493fb6334b24f4bfff718"}, + {file = "mypy-1.19.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:21761006a7f497cb0d4de3d8ef4ca70532256688b0523eee02baf9eec895e27b"}, + {file = "mypy-1.19.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:28902ee51f12e0f19e1e16fbe2f8f06b6637f482c459dd393efddd0ec7f82045"}, + {file = "mypy-1.19.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:481daf36a4c443332e2ae9c137dfee878fcea781a2e3f895d54bd3002a900957"}, + {file = "mypy-1.19.1-cp314-cp314-win_amd64.whl", hash = "sha256:8bb5c6f6d043655e055be9b542aa5f3bdd30e4f3589163e85f93f3640060509f"}, + {file = "mypy-1.19.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7bcfc336a03a1aaa26dfce9fff3e287a3ba99872a157561cbfcebe67c13308e3"}, + {file = "mypy-1.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b7951a701c07ea584c4fe327834b92a30825514c868b1f69c30445093fdd9d5a"}, + {file = "mypy-1.19.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b13cfdd6c87fc3efb69ea4ec18ef79c74c3f98b4e5498ca9b85ab3b2c2329a67"}, + {file = "mypy-1.19.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f28f99c824ecebcdaa2e55d82953e38ff60ee5ec938476796636b86afa3956e"}, + {file = "mypy-1.19.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c608937067d2fc5a4dd1a5ce92fd9e1398691b8c5d012d66e1ddd430e9244376"}, + {file = "mypy-1.19.1-cp39-cp39-win_amd64.whl", hash = "sha256:409088884802d511ee52ca067707b90c883426bd95514e8cfda8281dc2effe24"}, + {file = "mypy-1.19.1-py3-none-any.whl", hash = 
"sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247"}, + {file = "mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba"}, +] + +[package.dependencies] +librt = {version = ">=0.6.2", markers = "platform_python_implementation != \"PyPy\""} +mypy_extensions = ">=1.0.0" +pathspec = ">=0.9.0" +typing_extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.1.0" @@ -4445,10 +4617,9 @@ numpy = "*" name = "pillow" version = "11.3.0" description = "Python Imaging Library (Fork)" -optional = true +optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"memory\" or extra == \"all\" or extra == \"test\" or extra == \"vision\" or extra == \"vendors\" or extra == \"vllm\" or extra == \"nvidia\"" files = [ {file = "pillow-11.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1b9c17fd4ace828b3003dfd1e30bff24863e0eb59b535e8f80194d9cc7ecf860"}, {file = "pillow-11.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:65dc69160114cdd0ca0f35cb434633c75e8e7fad4cf855177a05bf38678f73ad"}, @@ -4853,10 +5024,9 @@ files = [ name = "psutil" version = "7.1.0" description = "Cross-platform lib for process and system monitoring." -optional = true +optional = false python-versions = ">=3.6" groups = ["main"] -markers = "extra == \"cpu\" or extra == \"all\" or extra == \"vllm\" or extra == \"nvidia\"" files = [ {file = "psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13"}, {file = "psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5"}, @@ -6763,10 +6933,9 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] name = "scikit-learn" version = "1.7.2" description = "A set of python modules for machine learning and data mining" -optional = true +optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"memory\" or extra == \"all\"" files = [ {file = "scikit_learn-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b33579c10a3081d076ab403df4a4190da4f4432d443521674637677dc91e61f"}, {file = "scikit_learn-1.7.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:36749fb62b3d961b1ce4fedf08fa57a1986cd409eff2d783bca5d4b9b5fce51c"}, @@ -6820,10 +6989,9 @@ tests = ["matplotlib (>=3.5.0)", "mypy (>=1.15)", "numpydoc (>=1.2.0)", "pandas name = "scipy" version = "1.16.2" description = "Fundamental algorithms for scientific computing in Python" -optional = true +optional = false python-versions = ">=3.11" groups = ["main"] -markers = "extra == \"memory\" or extra == \"all\" or extra == \"vllm\" or extra == \"nvidia\"" files = [ {file = "scipy-1.16.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:6ab88ea43a57da1af33292ebd04b417e8e2eaf9d5aa05700be8d6e1b6501cd92"}, {file = "scipy-1.16.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c95e96c7305c96ede73a7389f46ccd6c659c4da5ef1b2789466baeaed3622b6e"}, @@ -7523,10 +7691,9 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] name = "threadpoolctl" version = "3.6.0" description = "threadpoolctl" -optional = true +optional = false python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"memory\" or extra == \"all\"" files = [ {file = "threadpoolctl-3.6.0-py3-none-any.whl", hash = 
"sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb"}, {file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"}, @@ -8052,7 +8219,7 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, @@ -8756,4 +8923,4 @@ vllm = ["vllm"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.13" -content-hash = "8631e253579b9b6a7fa5206b39846e0f03b19fd2e160640d724b8e2f0a7d1572" +content-hash = "e3bb6d7a0d3ad3018734503b81ce0a74228041d5d18542236730ed1dc8defd7f" diff --git a/pyproject.toml b/pyproject.toml index ca84a1b7..08e95c7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -276,6 +276,74 @@ version = "1.3.10" [tool.poetry.group.dev.dependencies] ruff = "^0.11.11" black = "^25.1.0" +mypy = "^1.14.0" + +[tool.mypy] +python_version = "3.11" +files = ["src/avalan"] +plugins = ["pydantic.mypy"] +strict = false +warn_return_any = true +warn_unused_ignores = false +warn_redundant_casts = true +warn_unused_configs = true +show_error_codes = true +show_column_numbers = true +pretty = true +check_untyped_defs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +disallow_untyped_decorators = false + +[[tool.mypy.overrides]] +module = [ + "a2a.*", + "anthropic.*", + "bitsandbytes.*", + "boto3.*", + "botocore.*", + "diffusers.*", + "elasticsearch.*", + "faiss.*", + "google.*", + "huggingface_hub.*", + "imageio.*", + "jinja2.*", + "keyring.*", + "litellm.*", + "markitdown.*", + "markdownify.*", + "mcp.*", + "mlx.*", + "mlx_lm.*", + "numpy.*", + "openai.*", + "pandas.*", + "pgvector.*", + "PIL.*", + "playwright.*", + "psycopg.*", + "psycopg_pool.*", + "pydantic.*", + "RestrictedPython.*", + "rich.*", + "soundfile.*", + "sqlalchemy.*", + "sqlglot.*", + "sympy.*", + "tiktoken.*", + "torch.*", + "torchaudio.*", + "torchvision.*", + "tqdm.*", + "transformers.*", + "tree_sitter.*", + "tree_sitter_python.*", + "uvicorn.*", + "vllm.*", + "youtube_transcript_api.*", +] +ignore_missing_imports = true [tool.poetry-dynamic-versioning] enable = true diff --git a/src/avalan/__init__.py b/src/avalan/__init__.py index dcf9b406..51fb3b94 100644 --- a/src/avalan/__init__.py +++ b/src/avalan/__init__.py @@ -5,18 +5,18 @@ from packaging.version import Version, parse -def _config() -> dict: +def _config() -> dict[str, str]: config = metadata("avalan") package_version = metadata_version("avalan") return { - "name": config["Name"], + "name": str(config["Name"]), "version": package_version, - "license": config["License"], + "license": str(config["License"]), "url": "https://avalan.ai", } -config = _config() +config: dict[str, str] = _config() def license() -> str: diff --git a/src/avalan/agent/engine.py b/src/avalan/agent/engine.py index 0b5fdba5..3d724ba6 100644 --- a/src/avalan/agent/engine.py +++ b/src/avalan/agent/engine.py @@ -76,7 +76,8 @@ async def input_token_count(self) -> int | None: }, ) ) - count = self._model.input_token_count( + assert hasattr(self._model, "input_token_count") + count: int | None = self._model.input_token_count( self._last_prompt[0], system_prompt=self._last_prompt[1], 
developer_prompt=self._last_prompt[2], @@ -158,6 +159,7 @@ async def __call__( }, ) ) + assert context.input is not None output = await self._run(context, context.input, **run_args) await self._event_manager.trigger( Event( @@ -246,11 +248,18 @@ async def _run( if previous_message: await self.sync_message(previous_message) - for current_message in input_value: + for current_item in input_value: + current_message = ( + current_item + if isinstance(current_item, Message) + else Message(role=MessageRole.USER, content=current_item) + ) await self.sync_message(current_message) # Make recent memory the new model input - input_value = [rm.message for rm in self._memory.recent_messages] + recent = self._memory.recent_messages + assert recent is not None + input_value = [rm.message for rm in recent] # Have model generate output from input @@ -289,7 +298,7 @@ async def _run( tool=self._tool, context=context, ) - output = await self._model_manager(model_task) + output: TextGenerationResponse = await self._model_manager(model_task) await self._event_manager.trigger( Event( type=EventType.MODEL_EXECUTE_AFTER, @@ -344,6 +353,7 @@ async def sync_message(self, message: Message) -> None: }, ) ) + assert self._model.model_id is not None await self._memory.append_message( EngineMessage( agent_id=self._id, diff --git a/src/avalan/agent/loader.py b/src/avalan/agent/loader.py index 704f2136..558df80c 100644 --- a/src/avalan/agent/loader.py +++ b/src/avalan/agent/loader.py @@ -63,6 +63,7 @@ class OrchestratorLoader: ) _OPENAI_RESPONSES_ALIASES = frozenset({"response", "responses"}) + _event_manager: EventManager _hub: HuggingfaceHub _logger: Logger _participant_id: UUID @@ -75,12 +76,19 @@ def __init__( logger: Logger, participant_id: UUID, stack: AsyncExitStack, + event_manager: EventManager | None = None, ) -> None: + self._event_manager = event_manager or EventManager() self._hub = hub self._logger = logger self._participant_id = participant_id self._stack = stack + @property + def event_manager(self) -> EventManager: + """Return the event manager instance.""" + return self._event_manager + @staticmethod def parse_permanent_store_value( value: str, @@ -372,7 +380,7 @@ async def from_file( agent_config=agent_config, uri=uri, engine_config=engine_config, - tools=enable_tools, + tools=enable_tools or [], call_options=call_options, template_vars=template_vars, memory_permanent_message=memory_permanent_message, @@ -495,7 +503,7 @@ async def from_settings( _l("Loading event manager") - event_manager = EventManager() + event_manager = self._event_manager if settings.log_events: def _log_event(event: Event) -> None: @@ -607,6 +615,7 @@ def _log_event(event: Event) -> None: assert settings.agent_id + agent: Orchestrator if settings.orchestrator_type == "json": assert settings.json_config is not None agent = self._load_json_orchestrator( diff --git a/src/avalan/agent/orchestrator/__init__.py b/src/avalan/agent/orchestrator/__init__.py index 5e9ed47e..22b28a0e 100644 --- a/src/avalan/agent/orchestrator/__init__.py +++ b/src/avalan/agent/orchestrator/__init__.py @@ -6,6 +6,7 @@ Message, MessageContentText, MessageRole, + TransformerEngineSettings, ) from ...entities import Modality as Modality from ...event import Event, EventType @@ -14,18 +15,24 @@ from ...model.call import ModelCallContext from ...model.engine import Engine from ...model.manager import ModelManager +from ...model.response.text import ( + TextGenerationResponse as TextGenerationResponse, +) from ...tool.manager import ToolManager from .. 
import ( AgentOperation, - EngineEnvironment, InputType, NoOperationAvailableException, Specification, ) +from .. import ( + EngineEnvironment as EngineEnvironment, +) from ..engine import EngineAgent from ..renderer import Renderer, TemplateEngineAgent from .response.orchestrator_response import OrchestratorResponse +from collections.abc import Callable, Coroutine from contextlib import ExitStack from dataclasses import asdict from json import dumps @@ -46,7 +53,7 @@ class Orchestrator: _memory: MemoryManager _tool: ToolManager _event_manager: EventManager - _engine_agents: dict[EngineEnvironment, EngineAgent] = {} + _engine_agents: dict[str, EngineAgent] = {} _engines_stack: ExitStack = ExitStack() _operation_step: int | None = None _model_ids: set[str] = set() @@ -108,7 +115,9 @@ def id(self) -> UUID: return self._id @property - def input_token_count(self) -> int | None: + def input_token_count( + self, + ) -> Callable[[], Coroutine[Any, Any, int | None]] | None: return ( self._last_engine_agent.input_token_count if self._last_engine_agent @@ -243,7 +252,7 @@ async def __call__(self, input: Input, **kwargs) -> OrchestratorResponse: return OrchestratorResponse( messages, - result, + result, # type: ignore[arg-type] engine_agent, operation, engine_args, @@ -264,7 +273,11 @@ async def __aenter__(self): environment = operation.environment environment_hash = dumps(asdict(environment)) if environment_hash not in self._engine_agents: + assert environment.engine_uri.model_id is not None model_ids.append(environment.engine_uri.model_id) + assert isinstance( + environment.settings, TransformerEngineSettings + ) engine = self._model_manager.load_engine( environment.engine_uri, environment.settings, @@ -359,18 +372,22 @@ def _input_messages( else message.content ) render_vars.update({"input": message_content}) - content = ( - self._renderer(self._user_template, **render_vars) - if self._user_template - else self._renderer.from_string( + if self._user_template: + content = self._renderer( + self._user_template, **render_vars + ) + else: + assert self._user is not None + content = self._renderer.from_string( self._user, template_vars=render_vars ) - ) message = Message(role=message.role, content=content) if isinstance(input, list): - input[-1] = message + input[-1] = message # type: ignore[call-overload] else: input = message - return input + # The return type must be Message | list[Message] per the signature + assert isinstance(input, (Message, list)) + return input # type: ignore[return-value] diff --git a/src/avalan/agent/orchestrator/orchestrators/default.py b/src/avalan/agent/orchestrator/orchestrators/default.py index 44be78cb..4dca13b7 100644 --- a/src/avalan/agent/orchestrator/orchestrators/default.py +++ b/src/avalan/agent/orchestrator/orchestrators/default.py @@ -1,4 +1,10 @@ -from ....agent import AgentOperation, EngineEnvironment, Goal, Specification +from ....agent import ( + AgentOperation, + EngineEnvironment, + Goal, + Role, + Specification, +) from ....agent.orchestrator import Orchestrator from ....entities import EngineUri, Modality, TransformerEngineSettings from ....event.manager import EventManager @@ -47,7 +53,7 @@ def __init__( ) else: specification = Specification( - role=role, + role=Role(persona=[role]) if role else None, goal=( Goal(task=task, instructions=[instructions]) if task and instructions @@ -66,7 +72,8 @@ def __init__( AgentOperation( specification=specification, environment=EngineEnvironment( - engine_uri=engine_uri, settings=settings + engine_uri=engine_uri, + 
settings=settings or TransformerEngineSettings(), ), modality=Modality.TEXT_GENERATION, ), diff --git a/src/avalan/agent/orchestrator/orchestrators/json.py b/src/avalan/agent/orchestrator/orchestrators/json.py index ca96c862..42a1c85e 100644 --- a/src/avalan/agent/orchestrator/orchestrators/json.py +++ b/src/avalan/agent/orchestrator/orchestrators/json.py @@ -8,7 +8,11 @@ Specification, ) from ....agent.orchestrator import Orchestrator -from ....entities import Input, Modality, TransformerEngineSettings +from ....entities import ( + Input, + Modality, + TransformerEngineSettings, +) from ....event.manager import EventManager from ....memory.manager import MemoryManager from ....model.manager import ModelManager @@ -17,6 +21,7 @@ from dataclasses import dataclass from logging import Logger from typing import Annotated, get_args, get_origin +from uuid import UUID @dataclass(frozen=True, kw_only=True, slots=True) @@ -56,7 +61,9 @@ def __init__( Property( name=name, data_type=data_type, - description=description.strip(), + description=( + description.strip() if description else None + ), ) ) else: @@ -93,6 +100,7 @@ def __init__( event_manager: EventManager, output: type | list[Property], *, + id: UUID | None = None, role: str | None = None, task: str | None = None, instructions: str | None = None, @@ -136,15 +144,20 @@ def __init__( AgentOperation( specification=specification, environment=EngineEnvironment( - engine_uri=engine_uri, settings=settings + engine_uri=engine_uri, + settings=settings or TransformerEngineSettings(), ), modality=Modality.TEXT_GENERATION, ), call_options=call_options, + id=id, + name=name, user=user, user_template=user_template, ) - async def __call__(self, input: Input, **kwargs) -> str: + async def __call__( # type: ignore[override] + self, input: Input, **kwargs + ) -> str: text_response = await super().__call__(input, **kwargs) return await text_response.to_json() diff --git a/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py b/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py index 2e8ea9ae..391ace86 100644 --- a/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py +++ b/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py @@ -43,7 +43,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): def __getattr__(self, name): return getattr(self._orchestrator, name) - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, **kwargs ) -> ReasoningOrchestratorResponse: template_vars = {} diff --git a/src/avalan/agent/orchestrator/response/orchestrator_response.py b/src/avalan/agent/orchestrator/response/orchestrator_response.py index 0236f497..16202b38 100644 --- a/src/avalan/agent/orchestrator/response/orchestrator_response.py +++ b/src/avalan/agent/orchestrator/response/orchestrator_response.py @@ -27,7 +27,7 @@ from json import dumps from queue import Queue from time import perf_counter -from typing import Any, AsyncIterator, Callable +from typing import Any, AsyncIterator, Callable, cast from uuid import UUID @@ -131,7 +131,10 @@ async def to(self, entity_class: type) -> Any: def __aiter__(self) -> "OrchestratorResponse": if self._event_manager: self._response.add_done_callback(self._on_consumed) - self._response_iterator = self._response.__aiter__() + self._response_iterator = cast( + AsyncIterator[Token | TokenDetail | Event], + self._response.__aiter__(), + ) self._calls = Queue() self._parser_queue = Queue() self._tool_context = ToolCallContext( @@ -149,6 +152,7 @@ def __aiter__(self) -> 
"OrchestratorResponse": async def __anext__(self) -> Token | TokenDetail | Event: assert self._response_iterator + assert self._parser_queue is not None if not self._parser_queue.empty(): return self._parser_queue.get() @@ -162,9 +166,15 @@ async def __anext__(self) -> Token | TokenDetail | Event: if not self._tool_call_events.empty(): event = self._tool_call_events.get() assert event.type == EventType.TOOL_PROCESS + assert self._event_manager is not None await self._event_manager.trigger(event) - calls: list[ToolCall] = event.payload or [] + payload = event.payload + calls: list[ToolCall] = ( + payload.get("calls", []) + if isinstance(payload, dict) + else (payload if isinstance(payload, list) else []) + ) if calls: for call in calls: assert isinstance(call, ToolCall) @@ -191,7 +201,8 @@ async def __anext__(self) -> Token | TokenDetail | Event: if self._event_manager: await self._event_manager.trigger(execute_event) - context = ToolCallContext( + assert self._tool_context is not None + tool_ctx = ToolCallContext( input=self._tool_context.input, agent_id=self._agent_id, participant_id=self._participant_id, @@ -200,13 +211,13 @@ async def __anext__(self) -> Token | TokenDetail | Event: ) result = ( - await self._tool_manager(call, context) + await self._tool_manager(call, tool_ctx) if self._tool_manager else None ) self._call_history.append(call) - self._tool_context = context + self._tool_context = tool_ctx end = perf_counter() result_event = Event( @@ -236,6 +247,7 @@ async def __anext__(self) -> Token | TokenDetail | Event: tool_messages = [] for e in result_events: + assert e.payload is not None tool_result = e.payload["result"] tool_output = ( tool_result.message @@ -256,8 +268,9 @@ async def __anext__(self) -> Token | TokenDetail | Event: ) tool_result_output = dumps( ( - asdict(tool_output) + asdict(tool_output) # type: ignore[arg-type] if is_dataclass(tool_output) + and not isinstance(tool_output, type) else tool_output ), default=lambda o: ( @@ -293,8 +306,15 @@ async def __anext__(self) -> Token | TokenDetail | Event: or isinstance(self._input, Message) ) - messages = list( - self._input if isinstance(self._input, list) else [self._input] + messages: list[Message] = list( + cast( + list[Message], + ( + self._input + if isinstance(self._input, list) + else [self._input] + ), + ) ) messages.extend(tool_messages) @@ -307,11 +327,12 @@ async def __anext__(self) -> Token | TokenDetail | Event: "engine_args": self._engine_args, }, ) + assert self._event_manager is not None await self._event_manager.trigger(event_tool_model_run) context = self._make_child_context(messages) inner_response = await self._engine_agent(context) - assert inner_response + assert isinstance(inner_response, TextGenerationResponse) self._response = inner_response self.__aiter__() @@ -334,12 +355,13 @@ async def __anext__(self) -> Token | TokenDetail | Event: if isinstance(token, ToolCallToken) and token.call: event = Event( type=EventType.TOOL_PROCESS, - payload=[token.call], + payload={"calls": [token.call]}, started=perf_counter(), ) self._tool_process_events.put(event) except StopAsyncIteration: if self._tool_parser: + assert self._parser_queue is not None for item in await self._tool_parser.flush(): if isinstance(item, Event): self._tool_process_events.put(item) @@ -419,7 +441,8 @@ async def _react( ) self._call_history.append(call) self._tool_context = context - results.append(result) + if result is not None: + results.append(result) if self._event_manager: end = perf_counter() @@ -492,8 +515,15 @@ async def 
_react_process( or isinstance(self._input, Message) ) - messages = list( - self._input if isinstance(self._input, list) else [self._input] + messages: list[Message] = list( + cast( + list[Message], + ( + self._input + if isinstance(self._input, list) + else [self._input] + ), + ) ) messages.extend(tool_messages) @@ -514,11 +544,12 @@ async def _react_process( "engine_args": self._engine_args, }, ) + assert self._event_manager is not None await self._event_manager.trigger(event_tool_model_run) context = self._make_child_context(messages) response = await self._engine_agent(context) - assert response + assert isinstance(response, TextGenerationResponse) return response async def _emit( @@ -564,6 +595,7 @@ async def _emit( else: items = [item] + assert self._parser_queue is not None for it in items: if isinstance(it, Event): if it.type == EventType.TOOL_PROCESS: diff --git a/src/avalan/agent/renderer.py b/src/avalan/agent/renderer.py index 1a5dc5df..30744dc8 100644 --- a/src/avalan/agent/renderer.py +++ b/src/avalan/agent/renderer.py @@ -56,10 +56,9 @@ def from_string( self, template: str, template_vars: dict | None = None, - encoding: str = "utf-8", ) -> str: return ( - Template(template).render(**template_vars).encode(encoding) + Template(template).render(**template_vars) if template_vars else template ) diff --git a/src/avalan/cli/__init__.py b/src/avalan/cli/__init__.py index 5d9b9c8e..54be850d 100644 --- a/src/avalan/cli/__init__.py +++ b/src/avalan/cli/__init__.py @@ -2,10 +2,11 @@ from collections.abc import Iterator from contextlib import contextmanager, nullcontext -from io import UnsupportedOperation +from io import TextIOWrapper, UnsupportedOperation from json import dumps from select import select from sys import stdin +from typing import Any from rich.console import Console from rich.live import Live @@ -48,13 +49,8 @@ def confirm_tool_call( live: Live | None = None, ) -> str: prompt = "Execute tool call?" 
- options = { - "choices": ["y", "a", "n"], - "default": "n", - "console": console, - "show_choices": True, - "show_default": True, - } + choices = ["y", "a", "n"] + default = "n" with _pause_live(live) if live else nullcontext(): call_element = Syntax( @@ -65,22 +61,30 @@ def confirm_tool_call( if live: prompt_element = Text.from_markup(prompt, style="prompt") prompt_element.end = "" - choices = "/".join(options["choices"]) + choices_str = "/".join(choices) prompt_element.append(" ") - prompt_element.append(f"[{choices}]", "prompt.choices") + prompt_element.append(f"[{choices_str}]", "prompt.choices") - if options["show_default"]: - default = Text(f"({options['default']})", "prompt.default") - prompt_element.append(" ") - prompt_element.append(default) + default_text = Text(f"({default})", "prompt.default") + prompt_element.append(" ") + prompt_element.append(default_text) console.print(prompt_element) stdin_is_tty = stdin.isatty() with open(tty_path) if not stdin_is_tty else nullcontext() as tty: - if not stdin_is_tty: - options["stream"] = tty - return Prompt.ask(prompt, **options) + stream: TextIOWrapper | None = ( + tty if not stdin_is_tty and tty else None + ) + return Prompt.ask( + prompt, + choices=choices, + default=default, + console=console, + show_choices=True, + show_default=True, + stream=stream, + ) def has_input(console: Console) -> bool: @@ -123,10 +127,12 @@ def get_input( if is_input_available and force_prompt else nullcontext() ) as tty: - kwargs = {} - if is_input_available and force_prompt: - kwargs["stream"] = tty - input_string = PromptWithoutPrefix.ask(full_prompt, **kwargs) + stream: Any = ( + tty if is_input_available and force_prompt else None + ) + input_string = PromptWithoutPrefix.ask( + full_prompt, stream=stream + ) if strip_prompt: input_string = input_string.strip() except EOFError: diff --git a/src/avalan/cli/__main__.py b/src/avalan/cli/__main__.py index 7224dcf1..d6af372f 100644 --- a/src/avalan/cli/__main__.py +++ b/src/avalan/cli/__main__.py @@ -23,6 +23,7 @@ model_uninstall, ) from ..cli.commands.tokenizer import tokenize +from ..cli.theme import Theme from ..cli.theme.fancy import FancyTheme from ..entities import ( AttentionImplementation, @@ -51,11 +52,16 @@ import gettext import sys -from argparse import ArgumentParser, Namespace, _SubParsersAction +from argparse import ( + ArgumentParser, + Namespace, + _ArgumentGroup, + _SubParsersAction, +) from asyncio import run as run_in_loop from asyncio.exceptions import CancelledError from dataclasses import fields -from gettext import translation +from gettext import GNUTranslations, NullTranslations, translation from importlib.util import find_spec from locale import getlocale from logging import ( @@ -73,6 +79,7 @@ from pathlib import Path from subprocess import run from tomllib import load as toml_load +from types import ModuleType from typing import Optional, get_args, get_origin from typing import get_args as get_type_args from uuid import uuid4 @@ -81,7 +88,7 @@ from rich.console import Console from rich.logging import RichHandler from rich.prompt import Confirm, Prompt -from rich.theme import Theme +from rich.theme import Theme as RichTheme from torch.cuda import device_count, is_available, set_device from torch.distributed import destroy_process_group from transformers.utils import ( @@ -112,7 +119,10 @@ def __init__(self, logger: Logger): ) default_device = TransformerModel.get_default_device() self._parser = CLI._create_parser( - default_device, cache_dir, default_locales_path, default_locale 
+ default_device, + cache_dir, + default_locales_path, + default_locale or "en_US", ) @staticmethod @@ -141,7 +151,7 @@ def _create_parser( cache_dir: str, default_locales_path: str, default_locale: str, - ): + ) -> ArgumentParser: default_attention = CLI._default_attention(default_device) global_parser = ArgumentParser(add_help=False) global_parser.add_argument( @@ -1637,7 +1647,7 @@ def _create_parser( @staticmethod def _get_translator( app_name: str, locales_path: str, locale: str - ) -> object: + ) -> GNUTranslations | NullTranslations | ModuleType: """Return translation object for ``locale`` or ``gettext`` fallback.""" try: return translation( @@ -1763,7 +1773,7 @@ def _add_agent_server_arguments(parser: ArgumentParser) -> ArgumentParser: @staticmethod def _add_agent_settings_arguments( parser: ArgumentParser, - ) -> ArgumentParser: + ) -> _ArgumentGroup: group = parser.add_argument_group("inline agent settings") group.add_argument("--engine-uri", type=str, help="Agent engine URI") group.add_argument("--name", type=str, help="Agent name") @@ -1885,7 +1895,7 @@ def _add_agent_settings_arguments( @staticmethod def _add_tool_settings_arguments( parser: ArgumentParser, *, prefix: str, settings_cls: type - ) -> ArgumentParser: + ) -> _ArgumentGroup: """Add dataclass based tool options to ``parser``.""" group = parser.add_argument_group(f"{prefix} tool settings") @@ -1975,12 +1985,15 @@ async def __call__(self) -> None: setattr(args, f"run_chat_{key}", value) translator = CLI._get_translator(self._name, args.locales, args.locale) + gettext_fn = getattr(translator, "gettext", gettext.gettext) + ngettext_fn = getattr(translator, "ngettext", gettext.ngettext) assert self._logger is not None and isinstance(self._logger, Logger) - theme = FancyTheme(translator.gettext, translator.ngettext) + theme = FancyTheme(gettext_fn, ngettext_fn) _ = theme._ console = Console( - theme=Theme(styles=theme.get_styles()), record=args.record + theme=RichTheme(styles=theme.get_styles()), + record=args.record, ) if args.help_full: @@ -1991,14 +2004,14 @@ async def __call__(self) -> None: if requires_token: if not access_token: - prompt_kwargs = {} + stream = None if has_input(console): try: - prompt_kwargs["stream"] = open(args.tty) + stream = open(args.tty) except OSError: pass access_token = Prompt.ask( - theme.ask_access_token(), **prompt_kwargs + theme.ask_access_token(), stream=stream ) assert access_token else: @@ -2091,13 +2104,15 @@ def filter(self, record: LogRecord) -> bool: ) suggest_login = suggest_login and not has_input(console) + connecting_spinner = theme.get_spinner("connecting") if args.login or ( suggest_login and Confirm.ask(theme.ask_login_to_hub(), default=False) ): + assert connecting_spinner is not None with console.status( theme.logging_in(hub.domain), - spinner=(theme.get_spinner("connecting")), + spinner=connecting_spinner, refresh_per_second=self._REFRESH_RATE, ): hub.login() @@ -2108,7 +2123,7 @@ def filter(self, record: LogRecord) -> bool: theme.welcome( self._site.geturl(), self._name, - self._version, + str(self._version), self._license, user, ) diff --git a/src/avalan/cli/commands/agent.py b/src/avalan/cli/commands/agent.py index be1f8303..3190e0ff 100644 --- a/src/avalan/cli/commands/agent.py +++ b/src/avalan/cli/commands/agent.py @@ -5,6 +5,7 @@ ) from ...cli import confirm_tool_call, get_input, has_input from ...cli.commands.model import token_generation +from ...cli.theme import Theme from ...entities import ( Backend, GenerationCacheStrategy, @@ -23,10 +24,10 @@ from argparse 
import Namespace from contextlib import AsyncExitStack -from dataclasses import fields +from dataclasses import fields as dataclass_fields from logging import Logger from os.path import dirname, getmtime, join -from typing import Iterable, Mapping +from typing import Any, Iterable, Mapping, TypeVar, cast from uuid import UUID, uuid4 from jinja2 import Environment, FileSystemLoader @@ -34,7 +35,6 @@ from rich.live import Live from rich.prompt import Confirm, Prompt from rich.syntax import Syntax -from rich.theme import Theme def _parse_permanent_memory_items( @@ -188,16 +188,19 @@ def get_orchestrator_settings( ) +T = TypeVar("T") + + def _tool_settings_from_mapping( mapping: Mapping[str, object] | Namespace, *, prefix: str | None = None, - settings_cls: type, + settings_cls: type[T], open_files: bool = True, -) -> object: +) -> T | None: """Return tool settings from a mapping using dataclass ``settings_cls``.""" values: dict[str, object] = {} - for field in fields(settings_cls): + for field in dataclass_fields(cast(Any, settings_cls)): key = f"tool_{prefix}_{field.name}" if prefix else field.name if isinstance(mapping, Namespace): if hasattr(mapping, key): @@ -231,9 +234,10 @@ def get_tool_settings( args: Namespace, *, prefix: str, - settings_cls: type, + settings_cls: type[T], open_files: bool = True, -) -> object: +) -> T | None: + """Return tool settings instance from CLI arguments.""" return _tool_settings_from_mapping( args, prefix=prefix, settings_cls=settings_cls, open_files=open_files ) @@ -247,7 +251,7 @@ async def agent_message_search( logger: Logger, refresh_per_second: int, ) -> None: - _, _i = theme._, theme.icons + _, _i = theme._, theme._icons specs_path = args.specifications_file engine_uri = getattr(args, "engine_uri", None) @@ -267,7 +271,7 @@ async def agent_message_search( input_string = get_input( console, - _i["user_input"] + " ", + (_i.get("user_input") or "") + " ", echo_stdin=not args.no_repl, is_quiet=args.quiet, tty_path=tty_path, @@ -277,6 +281,8 @@ async def agent_message_search( limit = args.limit + spinner = theme.get_spinner("agent_loading") + assert spinner is not None async with AsyncExitStack() as stack: loader = OrchestratorLoader( hub=hub, @@ -286,7 +292,7 @@ async def agent_message_search( ) with console.status( _("Loading agent..."), - spinner=theme.get_spinner("agent_loading"), + spinner=spinner, refresh_per_second=refresh_per_second, ): if specs_path: @@ -331,7 +337,8 @@ async def agent_message_search( ) orchestrator = await stack.enter_async_context(orchestrator) - assert orchestrator.engine_agent and orchestrator.engine.model_id + assert orchestrator.engine_agent + assert orchestrator.engine and orchestrator.engine.model_id can_access = args.skip_hub_access_check or hub.can_access( orchestrator.engine.model_id @@ -380,7 +387,7 @@ async def agent_run( logger: Logger, refresh_per_second: int, ) -> None: - _, _i = theme._, theme.icons + _, _i = theme._, theme._icons specs_path = args.specifications_file engine_uri = getattr(args, "engine_uri", None) @@ -485,6 +492,7 @@ async def _init_orchestrator() -> Orchestrator: not orchestrator.tool.is_empty ), "--tools-confirm requires tools" + permanent_message = orchestrator.memory.permanent_message logger.debug( "Agent loaded from %s, models used: %s, with recent message " "memory: %s, with permanent message memory: %s", @@ -492,15 +500,16 @@ async def _init_orchestrator() -> Orchestrator: orchestrator.model_ids, "yes" if orchestrator.memory.has_recent_message else "no", ( - "yes, with session #" - + 
str(orchestrator.memory.permanent_message.session_id) + "yes, with session #" + str(permanent_message.session_id) if orchestrator.memory.has_permanent_message + and permanent_message else "no" ), ) if not args.quiet: - assert orchestrator.engine_agent and orchestrator.engine.model_id + assert orchestrator.engine_agent + assert orchestrator.engine and orchestrator.engine.model_id is_local = not isinstance( orchestrator.engine, TextGenerationVendorModel @@ -530,26 +539,30 @@ async def _init_orchestrator() -> Orchestrator: else: await orchestrator.memory.start_session() + recent_message = orchestrator.memory.recent_message if ( load_recent_messages and orchestrator.memory.has_recent_message - and not orchestrator.memory.recent_message.is_empty + and recent_message + and not recent_message.is_empty and not args.quiet ): console.print( theme.recent_messages( participant_id, orchestrator, - orchestrator.memory.recent_message.data, + recent_message.data, ) ) return orchestrator + spinner = theme.get_spinner("agent_loading") + assert spinner is not None async with AsyncExitStack() as stack: with console.status( _("Loading agent..."), - spinner=theme.get_spinner("agent_loading"), + spinner=spinner, refresh_per_second=refresh_per_second, ): orchestrator = await _init_orchestrator() @@ -569,16 +582,21 @@ async def _init_orchestrator() -> Orchestrator: specs_mtime = new_mtime in_conversation = False continue + current_recent_message = orchestrator.memory.recent_message + recent_size = ( + str(current_recent_message.size) + if orchestrator.memory.has_recent_message + and current_recent_message + else "0" + ) logger.debug( "Waiting for new message to add to orchestrator's existing " - + str(orchestrator.memory.recent_message.size) - if orchestrator.memory - and orchestrator.memory.has_recent_message - else "0" + " messages" + + recent_size + + " messages" ) input_string = get_input( console, - _i["user_input"] + " ", + (_i.get("user_input") or "") + " ", echo_stdin=not args.no_repl, force_prompt=in_conversation, is_quiet=args.quiet, @@ -596,7 +614,7 @@ async def _init_orchestrator() -> Orchestrator: ) if not args.quiet and not args.stats: - console.print(_i["agent_output"] + " ", end="") + console.print((_i.get("agent_output") or "") + " ", end="") if args.quiet: console.print(await output.to_str()) diff --git a/src/avalan/cli/commands/cache.py b/src/avalan/cli/commands/cache.py index f507f060..ac9fe515 100644 --- a/src/avalan/cli/commands/cache.py +++ b/src/avalan/cli/commands/cache.py @@ -1,5 +1,6 @@ from ...cli import confirm from ...cli.download import create_live_tqdm_class +from ...cli.theme import Theme from ...model.hubs import HubAccessDeniedException from ...model.hubs.huggingface import HuggingfaceHub @@ -7,7 +8,6 @@ from rich.console import Console from rich.padding import Padding -from rich.theme import Theme def cache_delete( @@ -27,7 +27,7 @@ def cache_delete( console.print(theme.cache_delete(cache_deletion, False)) return - console.print(theme.cache_delete(cache_deletion)) + console.print(theme.cache_delete(cache_deletion, False)) if not args.delete and not confirm(console, theme.ask_delete_paths()): return execute_deletion() diff --git a/src/avalan/cli/commands/memory.py b/src/avalan/cli/commands/memory.py index 429ac6ee..a48f4629 100644 --- a/src/avalan/cli/commands/memory.py +++ b/src/avalan/cli/commands/memory.py @@ -1,6 +1,7 @@ from ...cli import get_input from ...cli.commands import get_model_settings from ...cli.commands.model import model_display +from ...cli.theme import Theme 
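The `_tool_settings_from_mapping` hunk above replaces an untyped `settings_cls: type` with a `TypeVar`, so mypy can give each call site a concrete settings type instead of `object`. A minimal standalone sketch of that `type[T] -> T | None` pattern; the `BrowserSettings` dataclass, its fields, and the mapping keys are hypothetical, for illustration only:

from dataclasses import dataclass, fields
from typing import Any, TypeVar, cast

T = TypeVar("T")


@dataclass
class BrowserSettings:
    # Hypothetical settings dataclass, not part of the patch.
    headless: bool = True
    timeout: int = 30


def settings_from(
    mapping: dict[str, object], settings_cls: type[T]
) -> T | None:
    # fields() is not typed for a bare type[T], hence the cast(Any, ...)
    # workaround the patch also uses; callers still get T | None back.
    values = {
        field.name: mapping[field.name]
        for field in fields(cast(Any, settings_cls))
        if field.name in mapping
    }
    return settings_cls(**values) if values else None


# mypy infers BrowserSettings | None here, with no cast at the call site.
print(settings_from({"headless": False}, BrowserSettings))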
from ...entities import DistanceType, Modality, SearchMatch, Similarity from ...memory.partitioner.code import CodePartitioner from ...memory.partitioner.text import TextPartition, TextPartitioner @@ -15,6 +16,7 @@ from io import BytesIO from logging import Logger from pathlib import Path +from typing import Any, cast from urllib.parse import urlparse from uuid import UUID @@ -23,7 +25,6 @@ from numpy import abs, corrcoef, dot, sum, vstack from numpy.linalg import norm from rich.console import Console -from rich.theme import Theme async def memory_document_index( @@ -40,7 +41,7 @@ async def memory_document_index( def transform(html: bytes) -> DocumentConverterResult: return MarkItDown().convert_stream(BytesIO(html)) - _, _i = theme._, theme.icons + _, _i = theme._, cast(dict[str, Any], theme.icons) model_id = args.model source = args.source participant_id = UUID(args.participant) @@ -139,7 +140,7 @@ def transform(html: bytes) -> DocumentConverterResult: args.encoding, args.partition_max_tokens, ) - partitions: list[TextPartition] = [] + partitions = [] for cp in code_partitions: embeddings = await stm(cp.data) tokens = stm.token_count(cp.data) @@ -184,7 +185,7 @@ async def memory_embeddings( logger: Logger, ) -> None: assert args.model - _, _i = theme._, theme.icons + _, _i = theme._, cast(dict[str, Any], theme.icons) model_id = args.model display_partitions = ( args.display_partitions if not args.no_display_partitions else None @@ -357,7 +358,7 @@ async def memory_embeddings( index = IndexFlatL2(input_string_embeddings.shape[0]) - if partitioner: + if partitioner and knowledge_partitions is not None: knowledge_stack = vstack( [kp.embeddings for kp in knowledge_partitions] ).astype("float32", copy=False) @@ -414,7 +415,7 @@ async def memory_search( assert args.model and args.dsn and args.participant and args.namespace assert args.function - _, _i = theme._, theme.icons + _, _i = theme._, cast(dict[str, Any], theme.icons) model_id = args.model participant_id = UUID(args.participant) namespace = args.namespace diff --git a/src/avalan/cli/commands/model.py b/src/avalan/cli/commands/model.py index 73e93605..465a0ebf 100644 --- a/src/avalan/cli/commands/model.py +++ b/src/avalan/cli/commands/model.py @@ -1,7 +1,11 @@ from ...agent import Specification from ...agent.orchestrator import Orchestrator +from ...agent.orchestrator.response.orchestrator_response import ( + OrchestratorResponse, +) from ...cli import confirm, get_input, has_input from ...cli.commands.cache import cache_delete, cache_download +from ...cli.theme import Theme from ...entities import ( GenerationSettings, # noqa: F401 Modality, @@ -13,6 +17,7 @@ from ...event import TOOL_TYPES, Event, EventStats, EventType from ...model.call import ModelCall, ModelCallContext from ...model.criteria import KeywordStoppingCriteria # noqa: F401 +from ...model.engine import Engine from ...model.hubs.huggingface import HuggingfaceHub from ...model.manager import ModelManager from ...model.nlp.sentence import SentenceTransformerModel @@ -36,13 +41,13 @@ from datetime import datetime, timezone from logging import Logger from time import perf_counter +from typing import Any, cast from rich.console import Console, Group, RenderableType from rich.live import Live from rich.padding import Padding from rich.prompt import Prompt from rich.spinner import Spinner -from rich.theme import Theme def model_display( @@ -101,6 +106,9 @@ def model_display( ) ) elif model: + assert ( + model.tokenizer_config is not None + ), "Tokenizer config is required for 
model display" console.print( Padding( theme.model_display( @@ -149,7 +157,7 @@ async def model_run( logger: Logger, ) -> None: assert args.model and args.device and args.max_new_tokens - _, _i = theme._, theme.icons + _, _i = theme._, cast(dict[str, Any], theme.icons) with ModelManager(hub, logger) as manager: engine_uri = manager.parse_uri(args.model) @@ -236,6 +244,13 @@ async def model_run( console.print(theme.display_token_labels([output])) elif operation.modality == Modality.TEXT_GENERATION: + assert isinstance( + operation.input, str + ), "Text generation requires string input" + text_params = operation.parameters["text"] + assert ( + text_params is not None + ), "Text generation requires text parameters" await token_generation( args=args, console=console, @@ -247,7 +262,7 @@ async def model_run( input_string=operation.input, refresh_per_second=refresh_per_second, response=output, - dtokens_pick=operation.parameters["text"].pick_tokens, + dtokens_pick=text_params.pick_tokens or 0, display_tokens=args.display_tokens or 0, with_stats=not args.quiet, tool_events_limit=args.display_tools_events, @@ -291,9 +306,10 @@ async def model_search( model_access: dict[str, bool] = {} # Fetch matching models + spinner_name = theme.get_spinner("downloading") with console.status( _("Loading models..."), - spinner=theme.get_spinner("downloading"), + spinner=spinner_name or "dots", refresh_per_second=refresh_per_second, ): models = [ @@ -313,9 +329,11 @@ async def model_search( ] # Tasks to check model access + def _check_access(model_id: str) -> tuple[str, bool]: + return (model_id, hub.can_access(model_id)) + tasks = [ - create_task(to_thread(lambda id=model.id: (id, hub.can_access(id)))) - for model in models + create_task(to_thread(_check_access, model.id)) for model in models ] def _render( @@ -371,9 +389,9 @@ async def token_generation( logger: Logger, orchestrator: Orchestrator | None, event_stats: EventStats | None, - lm: TextGenerationModel, + lm: TextGenerationModel | Engine | None, input_string: str, - response: TextGenerationResponse, + response: TextGenerationResponse | OrchestratorResponse, *, display_tokens: int, dtokens_pick: int, @@ -381,7 +399,7 @@ async def token_generation( tool_events_limit: int | None, with_stats: bool = True, live_container: dict[str, Live | None] | None = None, -): +) -> None: # If no statistics needed, return as early as possible if not with_stats: async for token in response: @@ -513,7 +531,6 @@ async def _event_stream( events_renderable = theme.events( event_manager.history, events_limit=6 if tool_view else 4, - height=tools_height if tool_view else events_height, include_tokens=False, include_tools=tool_view, include_tool_detect=False, @@ -543,9 +560,9 @@ async def _token_stream( logger: Logger, orchestrator: Orchestrator | None, event_stats: EventStats | None, - lm: TextGenerationModel, + lm: TextGenerationModel | Engine | None, input_string: str, - response: TextGenerationResponse, + response: TextGenerationResponse | OrchestratorResponse, *, display_tokens: int, dtokens_pick: int, @@ -564,7 +581,7 @@ async def _token_stream( start_thinking = ( args.start_thinking if hasattr(args, "start_thinking") else False ) - tokens = [] + tokens: list[Token] = [] answer_text_tokens: list[str] = [] thinking_text_tokens: list[str] = [] tool_text_tokens: list[str] = [] @@ -577,15 +594,21 @@ async def _token_stream( 100 if display_pause > 0 and display_tokens > 0 else 0 ) - input_token_count = ( - response.input_token_count - if response.input_token_count - else ( - 
orchestrator.input_token_count - if orchestrator - else lm.input_token_count(input_string) - ) - ) + assert lm is not None, "Language model must be provided" + input_token_count: int + if response.input_token_count: + input_token_count = response.input_token_count + elif orchestrator and orchestrator.input_token_count is not None: + orch_count = orchestrator.input_token_count + if callable(orch_count): + input_token_count = await orch_count() or 0 + else: + input_token_count = orch_count # type: ignore[assignment] + else: + assert hasattr( + lm, "input_token_count" + ), "lm must have input_token_count method" + input_token_count = lm.input_token_count(input_string) # type: ignore[union-attr] ttft: float | None = None ttnt: float | None = None last_current_dtoken: Token | None = None @@ -609,13 +632,14 @@ async def _token_stream( answer_text_tokens = [] tool_text_tokens = [] thinking_text_tokens = [] - inner_response = event.payload["response"] - assert isinstance(inner_response, TextGenerationResponse) - if inner_response.input_token_count: - input_token_count = inner_response.input_token_count + if event.payload is not None: + inner_response = event.payload["response"] + assert isinstance(inner_response, TextGenerationResponse) + if inner_response.input_token_count: + input_token_count = inner_response.input_token_count elif event.type == EventType.TOOL_RESULT: tool_event_results.append(event) - if "call" in event.payload: + if event.payload is not None and "call" in event.payload: completed_call_ids.add(event.payload["call"].id) else: tool_event_calls.append(event) @@ -640,16 +664,35 @@ async def _token_stream( tool_running_spinner = None if tool_event_calls or tool_event_results: - tool_calling_names = [ - c.name - for e in tool_event_calls - for c in e.payload - if c.id not in completed_call_ids - ] - + tool_calling_names: list[str] = [] + for e in tool_event_calls: + if e.payload is None: + continue + # Handle TOOL_EXECUTE events with {"call": call} + if "call" in e.payload: + call = e.payload["call"] + if ( + hasattr(call, "id") + and hasattr(call, "name") + and call.id not in completed_call_ids + ): + tool_calling_names.append(call.name) + # Handle TOOL_PROCESS events with {"calls": [call, ...]} + elif "calls" in e.payload: + calls = e.payload["calls"] + if calls: + for c in calls: + if ( + hasattr(c, "id") + and hasattr(c, "name") + and c.id not in completed_call_ids + ): + tool_calling_names.append(c.name) + + tool_spinner_name = theme.get_spinner("tool_running") tool_running_spinner = ( Spinner( - theme.get_spinner("tool_running"), + tool_spinner_name or "dots", text="[cyan]" + theme._n( "Running tool {tool_names}...", @@ -687,7 +730,8 @@ async def _token_stream( ) answer_height = getattr(args, "display_answer_height", 12) - token_frames_promise = theme.tokens( + assert lm.model_id is not None, "Model ID must not be None" + token_frames_generator = theme.tokens( lm.model_id, lm.tokenizer_config.tokens if lm.tokenizer_config else None, ( @@ -699,24 +743,32 @@ async def _token_stream( args.display_probabilities if dtokens_pick > 0 else False, dtokens_pick, # Which tokens to mark as interesting - lambda dtoken: ( + ( ( - dtoken.probability < args.display_probabilities_maximum - or len( - [ - t - for t in dtoken.tokens - if t.id != dtoken.id - and t.probability - >= args.display_probabilities_sample_minimum - ] + lambda dtoken: ( + dtoken.probability is not None + and dtoken.probability + < args.display_probabilities_maximum + ) + or ( + hasattr(dtoken, "tokens") + and dtoken.tokens 
is not None + and len( + [ + t + for t in dtoken.tokens + if t.id != dtoken.id + and t.probability is not None + and t.probability + >= args.display_probabilities_sample_minimum + ] + ) + > 0 ) - > 0 ) if display_tokens and args.display_probabilities and args.display_probabilities_maximum > 0 - and args.display_probabilities_maximum > 0 else None ), thinking_text_tokens, @@ -729,7 +781,7 @@ async def _token_stream( tool_event_calls, tool_event_results, tool_running_spinner, - ttft, + ttft if ttft is not None else 0.0, ttnt, ttsr, elapsed, @@ -744,7 +796,7 @@ async def _token_stream( ) token_frame_list = [ - token_frame async for token_frame in token_frames_promise + token_frame async for token_frame in token_frames_generator ] token_frames = [token_frame_list[0]] diff --git a/src/avalan/cli/commands/tokenizer.py b/src/avalan/cli/commands/tokenizer.py index dad21d7c..fcf98a61 100644 --- a/src/avalan/cli/commands/tokenizer.py +++ b/src/avalan/cli/commands/tokenizer.py @@ -1,4 +1,5 @@ from ...cli import get_input +from ...cli.theme import Theme from ...entities import Token, TransformerEngineSettings from ...model.hubs.huggingface import HuggingfaceHub from ...model.nlp.text.generation import TextGenerationModel @@ -7,7 +8,6 @@ from logging import Logger from rich.console import Console -from rich.theme import Theme async def tokenize( @@ -19,7 +19,7 @@ async def tokenize( ) -> list[Token] | None: assert args.tokenizer - _, _i, _n = theme._, theme.icons, theme._n + _, _i, _n = theme._, theme._icons, theme._n tokenizer_name_or_path = args.tokenizer with TextGenerationModel( @@ -57,20 +57,21 @@ async def tokenize( paths = lm.save_tokenizer(args.save) total_files = len(paths) console.print(theme.saved_tokenizer_files(args.save, total_files)) - return + return None tty_path = getattr(args, "tty", "/dev/tty") or "/dev/tty" + user_input_icon = _i.get("user_input") or "" input_string = get_input( console, - _i["user_input"] + " ", + user_input_icon + " ", echo_stdin=not args.no_repl, is_quiet=args.quiet, tty_path=tty_path, ) if input_string: logger.debug("Loaded model %s", lm.config.__repr__()) - tokens = lm.tokenize(input_string) + tokens: list[Token] = lm.tokenize(input_string) panel = theme.tokenizer_tokens( tokens, @@ -79,3 +80,5 @@ async def tokenize( display_details=True, ) console.print(panel) + return tokens + return None diff --git a/src/avalan/cli/download.py b/src/avalan/cli/download.py index cbe0a14a..55b37079 100644 --- a/src/avalan/cli/download.py +++ b/src/avalan/cli/download.py @@ -60,5 +60,5 @@ def display(self, *_, **__): def reset(self, total=None): if hasattr(self, "_progress"): - self._progress.reset(total=total) + self._progress.reset(task_id=self._task_id, total=total) super().reset(total=total) diff --git a/src/avalan/cli/theme/__init__.py b/src/avalan/cli/theme/__init__.py index 2ab7bbe3..0466f62c 100644 --- a/src/avalan/cli/theme/__init__.py +++ b/src/avalan/cli/theme/__init__.py @@ -1,7 +1,6 @@ from ...agent.orchestrator import Orchestrator from ...entities import ( EngineMessage, - EngineMessageScored, HubCache, HubCacheDeletion, ImageEntity, @@ -14,29 +13,41 @@ TokenizerConfig, User, ) +from ...entities import ( + EngineMessageScored as EngineMessageScored, +) from ...event import Event, EventStats from ...memory.partitioner.text import TextPartition from ...memory.permanent import Memory as Memory from ...memory.permanent import PermanentMemoryPartition from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator from dataclasses import fields from 
datetime import datetime from enum import StrEnum from logging import Logger -from typing import Callable, Generator, Literal +from typing import Any, Callable, Literal, get_args from uuid import UUID from humanize import intcomma, intword, naturalsize, naturaltime from numpy import ndarray from rich.console import RenderableType +from rich.spinner import Spinner as RichSpinner Formatter = ( Callable[[datetime], str] | Callable[[float], str] | Callable[[int], str] ) Formatters = dict[Literal["datetime", "number", "quantity", "size"], Formatter] -Spinner = Literal["cache_accessing", "connecting", "thinking", "downloading"] -Data = StrEnum( +SpinnerName = Literal[ + "agent_loading", + "cache_accessing", + "connecting", + "downloading", + "thinking", + "tool_running", +] +Data = StrEnum( # type: ignore[misc] "Data", { **{field.name: field.name for field in fields(Model)}, @@ -48,10 +59,10 @@ class Theme(ABC): - _all_spinners: dict[Spinner, str] - _all_stylers: Stylers - _all_styles: dict[Data, str] - _icons: dict[Data, str] + _all_spinners: dict[SpinnerName, str | None] + _all_stylers: dict[str, Any] + _all_styles: dict[str, str] + _icons: dict[str, str | None] _: Callable[[str], str] @property @@ -67,7 +78,7 @@ def quantity_data(self) -> list[str]: return [] @property - def spinners(self) -> dict[Spinner, str]: + def spinners(self) -> dict[SpinnerName, str]: return {} @property @@ -97,7 +108,7 @@ def agent( agent: Orchestrator, *args, models: list[Model | str], - cans_access: bool | None, + can_access: bool | None, ) -> RenderableType: raise NotImplementedError() @@ -252,7 +263,7 @@ def recent_messages( @abstractmethod def saved_tokenizer_files( - directory_path: str, total_files: int + self, directory_path: str, total_files: int ) -> RenderableType: raise NotImplementedError() @@ -261,8 +272,8 @@ def search_message_matches( self, participant_id: UUID, agent: Orchestrator, - messages: list[EngineMessageScored], - ): + messages: list[EngineMessage], + ) -> RenderableType: raise NotImplementedError() @abstractmethod @@ -284,8 +295,9 @@ def tokenizer_tokens( dtokens: list[Token], added_tokens: list[str] | None, special_tokens: list[str] | None, + display_details: bool = False, current_dtoken: Token | None = None, - dtokens_selected: list[Token] = [], + dtokens_selected: list[Token] | None = None, ) -> RenderableType: raise NotImplementedError() @@ -334,6 +346,7 @@ async def tokens( tool_events: list[Event] | None, tool_event_calls: list[Event] | None, tool_event_results: list[Event] | None, + tool_running_spinner: RichSpinner | None, ttft: float, ttnt: float | None, ttsr: float | None, @@ -355,8 +368,9 @@ async def tokens( limit_tool_height: bool = True, limit_answer_height: bool = False, start_thinking: bool = False, - ) -> Generator[tuple[Token | None, RenderableType], None, None]: + ) -> AsyncGenerator[tuple[Token | None, RenderableType], None]: raise NotImplementedError() + yield # pragma: no cover - Makes this an async generator @abstractmethod def welcome( @@ -376,7 +390,7 @@ def __init__( formatters: Formatters = {}, stylers: Stylers = {}, styles: dict[Data, str] = {}, - spinners: dict[Spinner, str] = {}, + spinners: dict[SpinnerName, str] = {}, icons: dict[Data, str] = {}, quantity_data: list[str] = [], ): @@ -392,20 +406,20 @@ def __init__( **self.formatters, } - all_icons = { + all_icons: dict[str, str | None] = { **{data: None for data in data_keys}, - **self.icons, - **icons, + **{str(k): v for k, v in self.icons.items()}, + **{str(k): v for k, v in icons.items()}, } 
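The `SpinnerName` literal above replaces the old free-form `Spinner` alias, and `get_args(SpinnerName)` (used just below to build `_all_spinners`) seeds a `None` default for every member, so lookups are total and unknown names are rejected by mypy instead of surfacing as a runtime `KeyError`. A standalone sketch of that `Literal`/`get_args` pairing; the member names and spinner values here are illustrative, not the project's actual table:

from typing import Literal, get_args

SpinnerName = Literal["connecting", "downloading", "thinking"]

# Seed every literal member with None, then overlay the real values;
# the dict therefore always has an entry for each SpinnerName member.
spinners: dict[SpinnerName, str | None] = {
    name: None for name in get_args(SpinnerName)
}
spinners["connecting"] = "bouncingBar"


def get_spinner(name: SpinnerName) -> str | None:
    return spinners[name]


print(get_spinner("connecting"))  # -> bouncingBar
# get_spinner("unknown") fails type checking before it can fail at runtime.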
data_keys.extend(all_icons) data_keys.extend(self.styles.keys()) - self._all_spinners = { - **{spinner: None for spinner in Spinner.__args__}, + self._all_spinners: dict[SpinnerName, str | None] = { + **{spinner: None for spinner in get_args(SpinnerName)}, **self.spinners, } - self._all_stylers = { + self._all_stylers: dict[str, Any] = { **{ data: lambda data, value, prefix=None, icon=True: "".join( [ @@ -420,13 +434,13 @@ def __init__( ), f"[{data}]", ( - formatters["quantity"](value) + formatters["quantity"](value) # type: ignore[arg-type] if data in quantity_data else ( - formatters["datetime"](value) + formatters["datetime"](value) # type: ignore[arg-type] if isinstance(value, datetime) else ( - formatters["number"](value) + formatters["number"](value) # type: ignore[arg-type] if isinstance(value, int) or isinstance(value, float) else value @@ -438,17 +452,20 @@ def __init__( ) for data in data_keys }, - **self.stylers, + **{str(k): v for k, v in self.stylers.items()}, + } + self._all_styles: dict[str, str] = { + **{data: "" for data in data_keys}, + **{str(k): v for k, v in self.styles.items()}, } - self._all_styles = {**{data: "" for data in data_keys}, **self.styles} self._icons = all_icons self._ = translator self._n = translator_plurals - def get_styles(self) -> dict[Data, str]: + def get_styles(self) -> dict[str, str]: return self._all_styles - def get_spinner(self, spinner_name: str) -> Spinner: + def get_spinner(self, spinner_name: SpinnerName) -> str | None: return self._all_spinners[spinner_name] def __call__(self, item: Model | str) -> RenderableType: @@ -461,8 +478,8 @@ def _f( prefix: str | None = None, icon: bool | str = True, ) -> str: - return ( - self._all_stylers[data](data, value, prefix, icon) - if data in self._all_stylers and self._all_stylers[data] - else str(value) - ) + styler = self._all_stylers.get(data) + if styler is not None: + result: str = styler(data, value, prefix, icon) # type: ignore[call-arg] + return result + return str(value) diff --git a/src/avalan/cli/theme/fancy.py b/src/avalan/cli/theme/fancy.py index 9c193c4c..db8dff09 100644 --- a/src/avalan/cli/theme/fancy.py +++ b/src/avalan/cli/theme/fancy.py @@ -1,8 +1,7 @@ from ...agent.orchestrator import Orchestrator -from ...cli.theme import Data, Spinner, Theme +from ...cli.theme import Data, SpinnerName, Theme from ...entities import ( EngineMessage, - EngineMessageScored, HubCache, HubCacheDeletion, ImageEntity, @@ -12,8 +11,10 @@ SentenceTransformerModelConfig, Similarity, Token, + TokenDetail, TokenizerConfig, ToolCallError, + ToolCallResult, User, ) from ...event import TOOL_TYPES, Event, EventStats, EventType @@ -21,13 +22,14 @@ from ...memory.permanent import PermanentMemoryPartition from ...utils import _j, _lf, to_json -from datetime import datetime +from collections.abc import AsyncGenerator +from datetime import datetime, timedelta from locale import format_string from logging import Logger from math import ceil, inf from re import sub from textwrap import wrap -from typing import Callable, Generator +from typing import Any, Callable, cast from uuid import UUID from humanize import ( @@ -53,94 +55,106 @@ TimeElapsedColumn, ) from rich.rule import Rule +from rich.spinner import Spinner as RichSpinner from rich.table import Column, Table from rich.text import Text class FancyTheme(Theme): + """Fancy theme implementation with rich formatting and icons.""" + @property def icons(self) -> dict[Data, str]: + """Return mapping of data keys to emoji icons.""" return { - "access_token_name": 
":lock:", - "agent_id": ":robot:", - "agent_output": ":robot:", - "avalan": ":heavy_large_circle:", - "author": ":briefcase:", - "bye": ":vulcan_salute:", - "can_access": ":white_check_mark:", - "checking_access": ":mag:", - "created_at": ":calendar:", - "disabled": ":cross_mark:", - "download": ":floppy_disk:", - "download_access_denied": ":exclamation_mark:", - "download_finished": ":heavy_check_mark:", - "downloads": ":floppy_disk:", - "gated": ":key:", - "inference": ":brain:", - "input_token_count": ":laptop_computer:", - "library_name": ":books:", - "license": ":balance_scale:", - "likes": ":orange_heart:", - "memory": ":brain:", - "model_id": ":name_badge:", - "model_type": ":robot_face:", - "no_access": ":no_entry_sign:", - "parameters": ":abacus:", - "pipeline_tag": ":gear:", - "private": ":closed_lock_with_key:", - "ranking": ":trophy:", - "path_blobs": ":file_folder:", - "path_refs": ":file_folder:", - "path_repository": ":file_folder:", - "path_snapshot": ":file_folder:", - "session": ":card_index_dividers:", - "task_id": ":robot:", - "total_tokens": ":abacus:", - "tokens_rate": ":high_voltage:", - "events": ":bookmark_tabs:", - "tool_calls": ":hammer:", - "tool_call_results": ":package:", - "ttft": ":seedling:", - "ttnt": ":alarm_clock:", - "ttsr": ":thinking_face:", - "updated_at": ":calendar:", - "user": ":hugging_face:", - "user_input": ":speaking_head:", - "tags": ":label:", + cast(Data, "access_token_name"): ":lock:", + cast(Data, "agent_id"): ":robot:", + cast(Data, "agent_output"): ":robot:", + cast(Data, "avalan"): ":heavy_large_circle:", + cast(Data, "author"): ":briefcase:", + cast(Data, "bye"): ":vulcan_salute:", + cast(Data, "can_access"): ":white_check_mark:", + cast(Data, "checking_access"): ":mag:", + cast(Data, "created_at"): ":calendar:", + cast(Data, "disabled"): ":cross_mark:", + cast(Data, "download"): ":floppy_disk:", + cast(Data, "download_access_denied"): ":exclamation_mark:", + cast(Data, "download_finished"): ":heavy_check_mark:", + cast(Data, "downloads"): ":floppy_disk:", + cast(Data, "gated"): ":key:", + cast(Data, "inference"): ":brain:", + cast(Data, "input_token_count"): ":laptop_computer:", + cast(Data, "library_name"): ":books:", + cast(Data, "license"): ":balance_scale:", + cast(Data, "likes"): ":orange_heart:", + cast(Data, "memory"): ":brain:", + cast(Data, "model_id"): ":name_badge:", + cast(Data, "model_type"): ":robot_face:", + cast(Data, "no_access"): ":no_entry_sign:", + cast(Data, "parameters"): ":abacus:", + cast(Data, "pipeline_tag"): ":gear:", + cast(Data, "private"): ":closed_lock_with_key:", + cast(Data, "ranking"): ":trophy:", + cast(Data, "path_blobs"): ":file_folder:", + cast(Data, "path_refs"): ":file_folder:", + cast(Data, "path_repository"): ":file_folder:", + cast(Data, "path_snapshot"): ":file_folder:", + cast(Data, "session"): ":card_index_dividers:", + cast(Data, "task_id"): ":robot:", + cast(Data, "total_tokens"): ":abacus:", + cast(Data, "tokens_rate"): ":high_voltage:", + cast(Data, "events"): ":bookmark_tabs:", + cast(Data, "tool_calls"): ":hammer:", + cast(Data, "tool_call_results"): ":package:", + cast(Data, "ttft"): ":seedling:", + cast(Data, "ttnt"): ":alarm_clock:", + cast(Data, "ttsr"): ":thinking_face:", + cast(Data, "updated_at"): ":calendar:", + cast(Data, "user"): ":hugging_face:", + cast(Data, "user_input"): ":speaking_head:", + cast(Data, "tags"): ":label:", } @property def styles(self) -> dict[Data, str]: + """Return mapping of data keys to rich styles.""" return { - "id": "bold", - "can_access": 
"green", - "checking_access": "bright_black blink", - "created_at": "magenta", - "downloads": "bright_black", - "likes": "bright_black", - "memory": "magenta", - "memory_embedding_comparison": "dark_orange3", - "memory_embedding_comparison_similarity": "dark_orange3", - "memory_embedding_comparison_similarity_high": ( - "bold dark_olive_green3" - ), - "memory_embedding_comparison_similarity_middle": "orange_red1", - "memory_embedding_comparison_similarity_low": "dark_red", - "model_id": "cyan", - "no_access": "bold red", - "parameters": "bold cyan", - "participant_id": "bold", - "ranking": "bright_black", - "session_id": "dark_orange3", - "score": "dark_orange3", - "tags": "gray30", - "updated_at": "magenta", - "user": "bold cyan", - "version": "bold", + cast(Data, "id"): "bold", + cast(Data, "can_access"): "green", + cast(Data, "checking_access"): "bright_black blink", + cast(Data, "created_at"): "magenta", + cast(Data, "downloads"): "bright_black", + cast(Data, "likes"): "bright_black", + cast(Data, "memory"): "magenta", + cast(Data, "memory_embedding_comparison"): "dark_orange3", + cast( + Data, "memory_embedding_comparison_similarity" + ): "dark_orange3", + cast( + Data, "memory_embedding_comparison_similarity_high" + ): "bold dark_olive_green3", + cast( + Data, "memory_embedding_comparison_similarity_middle" + ): "orange_red1", + cast( + Data, "memory_embedding_comparison_similarity_low" + ): "dark_red", + cast(Data, "model_id"): "cyan", + cast(Data, "no_access"): "bold red", + cast(Data, "parameters"): "bold cyan", + cast(Data, "participant_id"): "bold", + cast(Data, "ranking"): "bright_black", + cast(Data, "session_id"): "dark_orange3", + cast(Data, "score"): "dark_orange3", + cast(Data, "tags"): "gray30", + cast(Data, "updated_at"): "magenta", + cast(Data, "user"): "bold cyan", + cast(Data, "version"): "bold", } @property - def spinners(self) -> dict[Spinner, str]: + def spinners(self) -> dict[SpinnerName, str]: + """Return mapping of spinner names to rich spinner types.""" return { "agent_loading": "dots12", "cache_accessing": "bouncingBar", @@ -152,6 +166,7 @@ def spinners(self) -> dict[Spinner, str]: @property def quantity_data(self) -> list[str]: + """Return list of data keys that should use quantity formatting.""" return ["likes"] def action( @@ -164,10 +179,15 @@ def action( highlight: bool, finished: bool, ) -> RenderableType: + """Render an action panel.""" _i = self._icons description_color = ( "green" if finished else "white" if highlight else "gray62" ) + author_icon = _i.get(cast(Data, "author")) or "" + library_icon = _i.get(cast(Data, "library_name")) or "" + task_icon = _i.get(cast(Data, "task_id")) or "" + model_icon = _i.get(cast(Data, "model_id")) or "" return Panel.fit( Padding( Group( @@ -178,11 +198,11 @@ def action( f"{description}[/{description_color}]" ), ( - _i["author"] + author_icon + ( f" [bright_black]{author}[/bright_black]" + " · " - + _i["library_name"] + + library_icon + f" [bright_black]{library_name}" + "[/bright_black]" ) @@ -193,34 +213,39 @@ def action( ) ) ), - title=_i["task_id"] + f" [cyan]{name}[/cyan]", - subtitle=_i["model_id"] - + f" [bright_black]{model_id}[/bright_black]", + title=(task_icon + f" [cyan]{name}[/cyan]"), + subtitle=( + model_icon + f" [bright_black]{model_id}[/bright_black]" + ), box=box.DOUBLE if highlight else box.SQUARE, ) def agent( self, agent: Orchestrator, - *args, + *args: Any, models: list[Model | str], can_access: bool | None, ) -> RenderableType: + """Render an agent panel with model information.""" _, _f, _i 
= self._, self._f, self._icons + model_id_icon = _i.get(cast(Data, "model_id")) or "" models_group = Group( *_lf( [ - _i["model_id"] + model_id_icon + " " + ", ".join( [ ( _("{model_id} ({parameters})").format( model_id=_f( - "model_id", model.id, icon=False + cast(Data, "model_id"), + model.id, + icon=False, ), parameters=_f( - "parameters", + cast(Data, "parameters"), _("{n} params").format( n=self._parameter_count( model.parameters @@ -236,42 +261,52 @@ def agent( ] ), _f( - "memory", + cast(Data, "memory"), _j( ", ", - [ - ( - _("short-term message") - if agent.memory.has_recent_message - else None - ), - ( - _("long-term message ({driver})").format( - driver=type( - agent.memory.permanent_message - ).__name__ - ) - if agent.memory.has_permanent_message - else None - ), - ], + cast( + list[str], + [ + ( + _("short-term message") + if agent.memory.has_recent_message + else None + ), + ( + _( + "long-term message ({driver})" + ).format( + driver=type( + agent.memory.permanent_message + ).__name__ + ) + if agent.memory.has_permanent_message + else None + ), + ], + ), empty=_("stateless"), ), ), ( _f( - "session", + cast(Data, "session"), " " + _("session: {session_id}").format( session_id=_f( - "session_id", - str( - agent.memory.permanent_message.session_id + cast(Data, "session_id"), + ( + str( + agent.memory.permanent_message.session_id + ) + if agent.memory.permanent_message + else "" ), ) ), ) if agent.memory.has_permanent_message + and agent.memory.permanent_message and agent.memory.permanent_message.has_session else None ), @@ -280,37 +315,48 @@ def agent( ) return Panel( models_group, - title=_f("agent_id", agent.name if agent.name else str(agent.id)), + title=_f( + cast(Data, "agent_id"), + agent.name if agent.name else str(agent.id), + ), box=box.DOUBLE, ) def ask_access_token(self) -> str: + """Return prompt text for access token input.""" _ = self._ return _("Enter your Huggingface access token") def ask_delete_paths(self) -> str: + """Return prompt text for delete paths confirmation.""" _ = self._ return _("Delete selected paths?") def ask_login_to_hub(self) -> str: + """Return prompt text for hub login confirmation.""" _ = self._ return _("Login to huggingface?") def ask_secret_password(self, key: str) -> str: + """Return prompt text for secret password input.""" _ = self._ return _("Enter secret for {key}").format(key=key) def ask_override_secret(self, key: str) -> str: + """Return prompt text for secret override confirmation.""" _ = self._ return _("Secret {key} exists, override?").format(key=key) def bye(self) -> RenderableType: + """Return goodbye message.""" _, _i = self._, self._icons - return _i["bye"] + " " + _("bye :)") + bye_icon = _i.get(cast(Data, "bye")) or "" + return bye_icon + " " + _("bye :)") def cache_delete( - self, cache_deletion: HubCacheDeletion | None, deleted=False + self, cache_deletion: HubCacheDeletion | None, deleted: bool = False ) -> RenderableType: + """Render cache deletion summary.""" _, _f, _n, _i = self._, self._f, self._n, self._icons if not cache_deletion or ( not cache_deletion.deletable_blobs @@ -331,7 +377,9 @@ def cache_delete( "{total_revisions} revisions for {model_id}", total_revisions, ).format( - model_id=_f("model_id", cache_deletion.model_id), + model_id=_f( + cast(Data, "model_id"), cache_deletion.model_id + ), total_revisions=total_revisions, disk_space=naturalsize( cache_deletion.deletable_size_on_disk @@ -360,7 +408,10 @@ def cache_delete( if deletable_paths: panel = Panel( Group( - *[_f(field_name, path) for path in 
deletable_paths] + *[ + _f(cast(Data, field_name), path) + for path in deletable_paths + ] ), title=title, ) @@ -374,7 +425,9 @@ def cache_delete( "{total_revisions} revisions for {model_id}", total_revisions, ).format( - model_id=_f("model_id", cache_deletion.model_id), + model_id=_f( + cast(Data, "model_id"), cache_deletion.model_id + ), total_revisions=total_revisions, disk_space=naturalsize( cache_deletion.deletable_size_on_disk @@ -391,6 +444,7 @@ def cache_list( display_models: list[str] | None = None, show_summary: bool = False, ) -> RenderableType: + """Render cache list table.""" _ = self._ if display_models and not show_summary: @@ -450,12 +504,12 @@ def cache_list( tables.append(Padding(table, pad=(1, 0, 1, 0))) return Group(*tables) else: - display_models = ( + filtered_models: list[HubCache] = ( [m for m in cached_models if m.model_id in display_models] if display_models else cached_models ) - total_cache_size = sum([m.size_on_disk for m in display_models]) + total_cache_size = sum([m.size_on_disk for m in filtered_models]) table = Table( Column(header=_("Model"), justify="left", no_wrap=True), Column(header=_("Revisions"), justify="left"), @@ -473,7 +527,7 @@ def cache_list( footer_style="bold cyan", ) - for model_cache in display_models: + for model_cache in filtered_models: summarized_revisions = [r[:6] for r in model_cache.revisions] table.add_row( model_cache.model_id, @@ -490,7 +544,9 @@ def cache_list( def download_access_denied( self, model_id: str, model_url: str ) -> RenderableType: + """Render download access denied message.""" _, _i = self._, self._icons + access_denied_icon = _i.get(cast(Data, "download_access_denied")) or "" return Group( *_lf( [ @@ -498,7 +554,7 @@ def download_access_denied( " ".join( [ "[bold red]" - + _i["download_access_denied"] + + access_denied_icon + "[/bold red]", "[red]" + _( @@ -521,17 +577,21 @@ def download_access_denied( ) def download_start(self, model_id: str) -> RenderableType: + """Render download start message.""" _, _i = self._, self._icons + download_icon = _i.get(cast(Data, "download")) or "" return Group( - _i["download"] + download_icon + " " + _("Downloading model {model_id}:").format(model_id=model_id), "", ) - def download_progress(self) -> tuple[str | RenderableType]: - _ = self._ - return ( + def download_progress( + self, + ) -> tuple[str | RenderableType]: # type: ignore[override] + """Return progress bar components for download.""" + return ( # type: ignore[return-value] SpinnerColumn(), ( "[progress.description]{task.description}" @@ -546,11 +606,13 @@ def download_progress(self) -> tuple[str | RenderableType]: ) def download_finished(self, model_id: str, path: str) -> RenderableType: + """Render download finished message.""" _, _i = self._, self._icons + finished_icon = _i.get(cast(Data, "download_finished")) or "" return Padding( " ".join( [ - "[bold green]" + _i["download_finished"] + "[/bold green]", + "[bold green]" + finished_icon + "[/bold green]", _("Downloaded model {model_id} to {path}").format( model_id=model_id, path=path ), @@ -558,7 +620,7 @@ def download_finished(self, model_id: str, path: str) -> RenderableType: ) ) - def events( + def events( # type: ignore[override] self, events: list[Event], *, @@ -569,7 +631,8 @@ def events( include_tools: bool = True, include_non_tools: bool = True, tool_view: bool = False, - ) -> RenderableType: + ) -> RenderableType | None: + """Render events panel.""" _ = self._ event_log = self._events_log( @@ -580,7 +643,7 @@ def events( include_tools=include_tools, 
include_non_tools=include_non_tools, ) - panel = ( + panel: Panel | None = ( Panel( _j("\n", event_log), title=_("Tool calls") if tool_view else _("Events"), @@ -598,6 +661,7 @@ def events( return panel def logging_in(self, domain: str) -> str: + """Return logging in message.""" _ = self._ return _("Logging in to {domain}...").format(domain=domain) @@ -605,7 +669,7 @@ def memory_embeddings( self, input_string: str, embeddings: ndarray, - *args, + *args: Any, total_tokens: int, minv: float, maxv: float, @@ -619,6 +683,7 @@ def memory_embeddings( partition: int | None = None, total_partitions: int | None = None, ) -> RenderableType: + """Render memory embeddings table.""" _ = self._ assert ( @@ -664,13 +729,13 @@ def memory_embeddings( ): peek_table.add_column(intcomma(i), justify="center") - columns = [] - for i, v in enumerate(embeddings[:embedding_peek]): + columns: list[str] = [] + for v in embeddings[:embedding_peek]: columns.append(clamp(v, format="{:.4g}")) columns.append("") - for i, v in enumerate(embeddings[-embedding_peek:]): + for v in embeddings[-embedding_peek:]: columns.append(clamp(v, format="{:.4g}")) peek_table.add_row(*columns) @@ -738,6 +803,7 @@ def memory_embeddings( def memory_embeddings_comparison( self, similarities: dict[str, Similarity], most_similar: str ) -> RenderableType: + """Render memory embeddings comparison table.""" assert similarities and most_similar _, _f = self._, self._f table = Table( @@ -767,32 +833,32 @@ def memory_embeddings_comparison( table.add_row( _f( - field_class, + cast(Data, field_class), compare_string, - icon=":trophy: " if is_most else None, + icon=":trophy: " if is_most else False, ), _f( - field_class, + cast(Data, field_class), clamp(similarity.cosine_distance, format="{:.4g}"), icon=False, ), _f( - field_class, + cast(Data, field_class), clamp(similarity.l1_distance, format="{:.4g}"), icon=False, ), _f( - field_class, + cast(Data, field_class), clamp(similarity.l2_distance, format="{:.4g}"), icon=False, ), _f( - field_class, + cast(Data, field_class), clamp(similarity.inner_product, format="{:.4g}"), icon=False, ), _f( - field_class, + cast(Data, field_class), clamp(similarity.pearson, format="{:.4g}"), icon=False, ), @@ -802,9 +868,10 @@ def memory_embeddings_comparison( def memory_embeddings_search( self, matches: list[SearchMatch], - *args, + *args: Any, match_preview_length: int = 300, ) -> RenderableType: + """Render memory embeddings search results.""" assert matches _, _f = self._, self._f table = Table( @@ -830,12 +897,12 @@ def memory_embeddings_search( is_most = i == 0 table.add_row( _f( - field_class, + cast(Data, field_class), match.query, - icon=":trophy: " if is_most else None, + icon=":trophy: " if is_most else False, ), _f( - field_class, + cast(Data, field_class), ( match.match if len(match.match) <= match_preview_length @@ -844,7 +911,7 @@ def memory_embeddings_search( icon=False, ), _f( - field_class, + cast(Data, field_class), clamp(match.l2_distance, format="{:.4g}"), icon=False, ), @@ -852,9 +919,12 @@ def memory_embeddings_search( return Align(table, align="center") def memory_partitions( - self, partitions: list[TextPartition], *args, display_partitions: int + self, + partitions: list[TextPartition], + *args: Any, + display_partitions: int, ) -> RenderableType: - _ = self._ + """Render memory partitions.""" total_partitions = len(partitions) head_count: int = total_partitions tail_count: int = 0 @@ -884,7 +954,7 @@ def memory_partitions( if head_count and tail_count: elements.append( - 
Align(Padding(_("..."), pad=(0, 0, 1, 0)), align="center") + Align(Padding(self._("..."), pad=(0, 0, 1, 0)), align="center") ) if tail_count: @@ -909,11 +979,12 @@ def memory_partitions( def model( self, model: Model, - *args, + *args: Any, can_access: bool | None = None, expand: bool = False, summary: bool = False, ) -> RenderableType: + """Render model panel.""" assert (not expand and not summary) or ( expand ^ summary ), "From expand and summary, only one can be set" @@ -925,114 +996,179 @@ def model( [ _j( " · ", - [ - _j( - " ", + cast( + list[str], + [ + _j( + " ", + cast( + list[str], + [ + ( + _f( + cast( + Data, + "checking_access", + ), + _("checking access"), + ) + if can_access is None + else ( + _f( + cast( + Data, + "can_access", + ), + _( + "access" + " granted" + ), + ) + if can_access + else _f( + cast( + Data, + "no_access", + ), + _("access denied"), + ) + ) + ), + _f( + cast(Data, "author"), + model.author, + ), + ( + _f( + cast(Data, "license"), + model.license, + ) + if expand and model.license + else None + ), + ( + _f( + cast(Data, "gated"), + _("gated"), + ) + if model.gated + else None + ), + ( + _f( + cast(Data, "private"), + _("private"), + ) + if model.private + else None + ), + ( + _f( + cast(Data, "disabled"), + _("disabled"), + ) + if model.disabled + else None + ), + ], + ), + ), + ( + ( + ( + _i.get( + cast(Data, "created_at") + ) + or "" + ) + + " " + + _j( + ", ", + cast( + list[str], + [ + ( + _f( + cast( + Data, + "created_at", + ), + model.created_at, + _("created: "), + icon=False, + ) + if expand + else None + ), + _f( + cast( + Data, + "updated_at", + ), + model.updated_at, + _("updated: "), + icon=False, + ), + ], + ), + ) + ) + if not summary + else None + ), + ], + ), + ), + ( + _j( + " · ", + cast( + list[str], [ ( _f( - "checking_access", - _("checking access"), + cast(Data, "model_type"), + model.model_type, ) - if can_access is None - else ( - _f( - "can_access", - _("access granted"), - ) - if can_access - else _f( - "no_access", - _("access denied"), + + ( + " (" + + ", ".join( + model.architectures ) + + ")" + if expand + and model.architectures + else "" ) - ), - _f("author", model.author), - ( - _f("license", model.license) - if expand and model.license + if model.model_type else None ), ( - _f("gated", _("gated")) - if model.gated + _f( + cast(Data, "library_name"), + model.library_name, + ) + if model.library_name else None ), ( - _f("private", _("private")) - if model.private + _f( + cast(Data, "inference"), + model.inference, + ) + if expand and model.inference else None ), ( - _f("disabled", _("disabled")) - if model.disabled + _f( + cast(Data, "pipeline_tag"), + model.pipeline_tag, + ) + if model.pipeline_tag else None ), ], ), - ( - ( - _i["created_at"] - + " " - + _j( - ", ", - [ - ( - _f( - "created_at", - model.created_at, - _("created: "), - icon=False, - ) - if expand - else None - ), - _f( - "updated_at", - model.updated_at, - _("updated: "), - icon=False, - ), - ], - ) - ) - if not summary - else None - ), - ], - ), - ( - _j( - " · ", - [ - ( - _f("model_type", model.model_type) - + ( - " (" - + ", ".join(model.architectures) - + ")" - if expand and model.architectures - else "" - ) - if model.model_type - else None - ), - ( - _f("library_name", model.library_name) - if model.library_name - else None - ), - ( - _f("inference", model.inference) - if expand and model.inference - else None - ), - ( - _f("pipeline_tag", model.pipeline_tag) - if model.pipeline_tag - else None - ), - ], ) if not summary else None @@ -1043,7 
+1179,7 @@ def model( else None ), ( - _f("tags", " " + ", ".join(model.tags)) + _f(cast(Data, "tags"), " " + ", ".join(model.tags)) if expand and model.tags else None ), @@ -1052,26 +1188,29 @@ def model( ), # Model ID title=( - _f("model_id", model.id) + _f(cast(Data, "model_id"), model.id) + ( " " + _j( " ", - [ - _f( - "parameters", - self._parameter_count(model.parameters), - ), - ( + cast( + list[str], + [ _f( - "parameter_types", - ", ".join(model.parameter_types), - ) - if expand and model.parameter_types - else None - ), - _("parameters") if expand else _("params"), - ], + cast(Data, "parameters"), + self._parameter_count(model.parameters), + ), + ( + _f( + cast(Data, "parameter_types"), + ", ".join(model.parameter_types), + ) + if expand and model.parameter_types + else None + ), + _("parameters") if expand else _("params"), + ], + ), ) ) if not summary @@ -1081,19 +1220,26 @@ def model( subtitle=( _j( " ", - [ - ( - _f("downloads", model.downloads) - if model.downloads - else None - ), - _f("likes", model.likes) if model.likes else None, - ( - _f("ranking", model.ranking) - if model.ranking - else None - ), - ], + cast( + list[str], + [ + ( + _f(cast(Data, "downloads"), model.downloads) + if model.downloads + else None + ), + ( + _f(cast(Data, "likes"), model.likes) + if model.likes + else None + ), + ( + _f(cast(Data, "ranking"), model.ranking) + if model.ranking + else None + ), + ], + ), ) if expand else None @@ -1105,10 +1251,11 @@ def model_display( self, model_config: ModelConfig | SentenceTransformerModelConfig | None, tokenizer_config: TokenizerConfig, - *args, + *args: Any, is_runnable: bool | None = None, summary: bool = False, ) -> RenderableType: + """Render model display with config and tokenizer info.""" _ = self._ return Group( *_lf( @@ -1153,10 +1300,11 @@ def model_display( def _sentence_transformer_model_config( self, config: SentenceTransformerModelConfig, - *args, + *args: Any, is_runnable: bool | None, summary: bool, ) -> RenderableType: + """Render sentence transformer model config table.""" _ = self._ config_table = Table( Column(header="", justify="right"), @@ -1193,10 +1341,11 @@ def _sentence_transformer_model_config( def _model_config( self, config: ModelConfig, - *args, + *args: Any, is_runnable: bool | None, summary: bool, ) -> RenderableType: + """Render model config table.""" config_table = Table( Column(header="", justify="right"), Column(header="", justify="left"), @@ -1215,10 +1364,11 @@ def _fill_model_config_table( self, config: ModelConfig, config_table: Table, - *args, + *args: Any, is_runnable: bool | None, summary: bool, ) -> Table: + """Fill model config table with configuration details.""" _ = self._ config_table.add_row( _("Model type"), f"[bold]{config.model_type}[/bold]" @@ -1246,24 +1396,26 @@ def _fill_model_config_table( if not summary: config_table.add_row( _("Vocabulary size"), - f"[magenta]{intcomma(config.vocab_size)}[/magenta]", + f"[magenta]{intcomma(config.vocab_size or 0)}[/magenta]", ) config_table.add_row( _("Hidden size"), - f"[magenta]{intcomma(config.hidden_size)}[/magenta]", + f"[magenta]{intcomma(config.hidden_size or 0)}[/magenta]", ) if not summary: + num_hidden = config.num_hidden_layers or 0 config_table.add_row( _("Number of hidden layers"), - f"[magenta]{intcomma(config.num_hidden_layers)}[/magenta]", + f"[magenta]{intcomma(num_hidden)}[/magenta]", ) config_table.add_row( _("Number of attention heads"), - f"[magenta]{intcomma(config.num_attention_heads)}[/magenta]", + 
f"[magenta]{intcomma(config.num_attention_heads or 0)}" + "[/magenta]", ) config_table.add_row( _("Number of labels in last layer"), - f"[magenta]{intcomma(config.num_labels)}[/magenta]", + f"[magenta]{intcomma(config.num_labels or 0)}[/magenta]", ) if config.loss_type: config_table.add_row( @@ -1332,19 +1484,31 @@ def recent_messages( participant_id: UUID, agent: Orchestrator, messages: list[EngineMessage], - ): + ) -> RenderableType: + """Render recent messages panel.""" _, _f, _i = self._, self._f, self._icons + agent_output_icon = _i.get(cast(Data, "agent_output")) or "" + user_input_icon = _i.get(cast(Data, "user_input")) or "" group = Group( *_lf( [ Panel( - engine_message.message.content, + ( + str(engine_message.message.content) + if engine_message.message.content + else "" + ), title=( - _i["agent_output"] + " " + _f("id", agent.name) + agent_output_icon + + " " + + _f(cast(Data, "id"), agent.name) if engine_message.is_from_agent - else _i["user_input"] + else user_input_icon + " " - + _f("participant_id", participant_id) + + _f( + cast(Data, "participant_id"), + str(participant_id), + ) ), title_align="left", expand=True, @@ -1356,9 +1520,10 @@ def recent_messages( ) return group - def saved_tokenizer_files( + def saved_tokenizer_files( # type: ignore[override] self, directory_path: str, total_files: int ) -> RenderableType: + """Render saved tokenizer files message.""" _n = self._n return Padding( _n( @@ -1373,28 +1538,44 @@ def search_message_matches( self, participant_id: UUID, agent: Orchestrator, - messages: list[EngineMessageScored], - ): + messages: list[EngineMessage], + ) -> RenderableType: + """Render search message matches panel.""" _, _f, _i = self._, self._f, self._icons + agent_output_icon = _i.get(cast(Data, "agent_output")) or "" + user_input_icon = _i.get(cast(Data, "user_input")) or "" group = Group( *_lf( [ Panel( - engine_message.message.content, + ( + str(engine_message.message.content) + if engine_message.message.content + else "" + ), title=( - _i["agent_output"] + agent_output_icon + " " - + _f("id", agent.name or str(agent.id)) + + _f( + cast(Data, "id"), + agent.name or str(agent.id), + ) if engine_message.is_from_agent - else _i["user_input"] + else user_input_icon + " " - + _f("participant_id", participant_id) + + _f( + cast(Data, "participant_id"), + str(participant_id), + ) ), title_align="left", subtitle=_("Matching score: {score}").format( score=_f( - "score", - clamp(engine_message.score, format="{:.8g}"), + cast(Data, "score"), + clamp( + getattr(engine_message, "score", 0.0), + format="{:.8g}", + ), ) ), subtitle_align="left", @@ -1413,16 +1594,18 @@ def memory_search_matches( namespace: str, memories: list[PermanentMemoryPartition], ) -> RenderableType: + """Render memory search matches panel.""" _, _f, _i = self._, self._f, self._icons + memory_icon = _i.get(cast(Data, "memory")) or "" group = Group( *_lf( [ Panel( memory.data, title=( - _i["memory"] + memory_icon + " " - + _f("id", str(memory.memory_id)) + + _f(cast(Data, "id"), str(memory.memory_id)) ), title_align="left", subtitle=_( @@ -1430,10 +1613,13 @@ def memory_search_matches( " Partition: {partition}" ).format( participant=_f( - "participant_id", str(participant_id) + cast(Data, "participant_id"), + str(participant_id), + ), + ns=_f(cast(Data, "id"), namespace), + partition=_f( + cast(Data, "number"), memory.partition ), - ns=_f("id", namespace), - partition=_f("number", memory.partition), ), subtitle_align="left", expand=True, @@ -1446,8 +1632,9 @@ def memory_search_matches( return 
group def tokenizer_config( - self, config: TokenizerConfig, *args, summary: bool = False + self, config: TokenizerConfig, *args: Any, summary: bool = False ) -> RenderableType: + """Render tokenizer config table.""" _ = self._ config_table = Table( @@ -1469,12 +1656,13 @@ def tokenizer_config( _("Added tokens"), ", ".join([f"[cyan]{t}[/cyan]" for t in config.tokens]), ) - config_table.add_row( - _("Special tokens"), - ", ".join( - [f"[cyan]{t}[/cyan]" for t in config.special_tokens] - ), - ) + if config.special_tokens: + config_table.add_row( + _("Special tokens"), + ", ".join( + [f"[cyan]{t}[/cyan]" for t in config.special_tokens] + ), + ) config_table.add_row( _("Maximum sequence length"), f"[cyan]{config.tokenizer_model_max_length}[/cyan]", @@ -1486,15 +1674,18 @@ def tokenizer_config( return Align(config_table, align="center") - def tokenizer_tokens( + def tokenizer_tokens( # type: ignore[override] self, dtokens: list[Token], added_tokens: list[str] | None, special_tokens: list[str] | None, display_details: bool = False, current_dtoken: Token | None = None, - dtokens_selected: list[Token] = [], + dtokens_selected: list[Token] | None = None, ) -> RenderableType: + """Render tokenizer tokens panel.""" + if dtokens_selected is None: + dtokens_selected = [] # Build token panels compact_dtokens = True # For future configurability token_panels = [ @@ -1508,7 +1699,7 @@ def tokenizer_tokens( style=( "white on dark_green" if current_dtoken and dtoken == current_dtoken - else None + else "" ), ) ), @@ -1549,6 +1740,7 @@ def tokenizer_tokens( def display_image_entities( self, entities: list[ImageEntity], sort: bool ) -> RenderableType: + """Render image entities table.""" _ = self._ table = Table( Column(header=_("Label"), justify="left"), @@ -1569,20 +1761,21 @@ def display_image_entities( for entity in entities: score = ( - self._f("score", f"{entity.score:.2f}") + self._f(cast(Data, "score"), f"{entity.score:.2f}") if entity.score is not None else "-" ) - box = ( + entity_box = ( ", ".join(f"{v:.2f}" for v in entity.box) if entity.box else "-" ) - table.add_row(entity.label, score, box) + table.add_row(entity.label, score, entity_box) return Align(table, align="center") - def display_image_entity(self, entity: ImageEntity): + def display_image_entity(self, entity: ImageEntity) -> RenderableType: + """Render single image entity table.""" _ = self._ table = Table( Column(header=_("Label"), justify="left"), @@ -1598,6 +1791,7 @@ def display_image_entity(self, entity: ImageEntity): def display_audio_labels( self, audio_labels: dict[str, float] ) -> RenderableType: + """Render audio labels table.""" _ = self._ table = Table( Column(header=_("Label"), justify="left"), @@ -1609,13 +1803,12 @@ def display_audio_labels( border_style="gray58", ) for label, score in audio_labels.items(): - score_text = ( - self._f("score", f"{score:.2f}") if score is not None else "-" - ) + score_text = self._f(cast(Data, "score"), f"{score:.2f}") table.add_row(label, score_text) return Align(table, align="center") def display_image_labels(self, labels: list[str]) -> RenderableType: + """Render image labels table.""" _ = self._ table = Table( Column(header=_("Label"), justify="left"), @@ -1632,6 +1825,7 @@ def display_image_labels(self, labels: list[str]) -> RenderableType: def display_token_labels( self, token_labels: list[dict[str, str]] ) -> RenderableType: + """Render token labels table.""" _ = self._ table = Table( Column(header=_("Token"), justify="left"), @@ -1665,7 +1859,7 @@ async def tokens( tool_events: 
list[Event] | None, tool_event_calls: list[Event] | None, tool_event_results: list[Event] | None, - tool_running_spinner: Spinner | None, + tool_running_spinner: RichSpinner | None, ttft: float, ttnt: float | None, ttsr: float | None, @@ -1687,7 +1881,8 @@ async def tokens( limit_tool_height: bool = True, limit_answer_height: bool = False, start_thinking: bool = False, - ) -> Generator[tuple[Token | None, RenderableType], None, None]: + ) -> AsyncGenerator[tuple[Token | None, RenderableType], None]: + """Generate token display panels asynchronously.""" _, _n, _f, _l = self._, self._n, self._f, logger.debug pick_first = ceil(pick / 2) if pick > 1 else pick @@ -1725,16 +1920,18 @@ async def tokens( "\n".join(wrapped_section).rstrip() if wrapped_section else None ) - dtokens = ( + dtokens: list[Token] | None = ( tokens[-display_token_size:] if display_token_size and tokens else None ) - dtokens_selected = ( + dtokens_selected: list[TokenDetail] | None = ( [ dtoken for dtoken in dtokens - if focus_on_token_when and focus_on_token_when(dtoken) + if isinstance(dtoken, TokenDetail) + and focus_on_token_when + and focus_on_token_when(dtoken) ] if dtokens else None @@ -1748,7 +1945,7 @@ async def tokens( _lf( [ _f( - "input_token_count", + cast(Data, "input_token_count"), _n( "{total_tokens} token in", "{total_tokens} tokens in", @@ -1756,7 +1953,7 @@ async def tokens( ).format(total_tokens=input_token_count), ), _f( - "total_tokens", + cast(Data, "total_tokens"), _n( "{total_tokens} token out", "{total_tokens} tokens out", @@ -1765,7 +1962,7 @@ async def tokens( ), ( _f( - "ttft", + cast(Data, "ttft"), _("ttft: {ttft} s").format(ttft=f"{ttft:.2f}"), ) if ttft @@ -1773,7 +1970,7 @@ async def tokens( ), ( _f( - "ttnt", + cast(Data, "ttnt"), _("ttnt: {ttnt} s").format(ttnt=f"{ttnt:.1f}"), ) if ttnt @@ -1781,21 +1978,21 @@ async def tokens( ), ( _f( - "ttsr", + cast(Data, "ttsr"), _("rt: {ttsr} s").format(ttsr=f"{ttsr:.1f}"), ) if ttsr else None ), _f( - "tokens_rate", + cast(Data, "tokens_rate"), _("{tokens_rate} t/s").format( tokens_rate=f"{total_tokens / elapsed:.2f}" ), ), ( _f( - "events", + cast(Data, "events"), _n( "{total} event", "{total} events", @@ -1807,7 +2004,7 @@ async def tokens( ), ( _f( - "tool_calls", + cast(Data, "tool_calls"), _n( "{total} tool call", "{total} tool calls", @@ -1825,7 +2022,7 @@ async def tokens( ), ( _f( - "tool_call_results", + cast(Data, "tool_call_results"), _n( "{total} result", "{total} results", @@ -1851,7 +2048,7 @@ async def tokens( vertical="top", ), title=_("{model_id} reasoning").format( - model_id=_f("id", model_id) + model_id=_f(cast(Data, "id"), model_id) ), title_align="left", subtitle=progress_title if not wrapped_output else None, @@ -1892,7 +2089,7 @@ async def tokens( Align(wrapped_output, vertical="top"), title=( _("{model_id} response").format( - model_id=_f("id", model_id) + model_id=_f(cast(Data, "id"), model_id) ) if think_wrapped_output is None else None @@ -1915,8 +2112,11 @@ async def tokens( tool_running_panel: RenderableType | None = None - if tool_running_spinner and len(tool_event_calls) != len( - tool_event_results + if ( + tool_running_spinner + and tool_event_calls is not None + and tool_event_results is not None + and len(tool_event_calls) != len(tool_event_results) ): tool_running_panel = Padding( tool_running_spinner, pad=(1, 0, 1, 0) @@ -1951,8 +2151,8 @@ async def tokens( if display_token_size and tokens: # Pick current token to highlight - current_data = None - current_dtoken: Token | None = None + current_data: list[float | 
None] | None = None + current_dtoken: TokenDetail | None = None if display_probabilities and dtokens_selected: current_selected_index = ( 0 @@ -1986,31 +2186,39 @@ async def tokens( f'Selected "{current_dtoken.token}" as ' "interesting token, with " + clamp( - current_dtoken.probability, format="{:.4g}" + current_dtoken.probability or 0.0, + format="{:.4g}", ) + f"and {current_dtoken.tokens}" ) tokens_panel = self.tokenizer_tokens( - dtokens, + dtokens if dtokens else [], added_tokens, special_tokens, display_details=False, current_dtoken=current_dtoken, - dtokens_selected=dtokens_selected, + dtokens_selected=cast( + list[Token] | None, dtokens_selected + ), ) # Build bar chart with token alternative probabilities chart = None if display_probabilities: - current_symmetric_indices = ( - FancyTheme._symmetric_indices(current_data) + current_symmetric_indices: list[int] | None = ( + FancyTheme._symmetric_indices( + [v or 0.0 for v in current_data] + ) if current_data else None ) - current_symmetric_data = ( - [current_data[i] for i in current_symmetric_indices] - if current_data + current_symmetric_data: list[float] | None = ( + [ + (current_data[i] or 0.0) + for i in current_symmetric_indices + ] + if current_data and current_symmetric_indices else None ) labels = ( @@ -2029,7 +2237,7 @@ async def tokens( for level in range(chart_height, 0, -1): chart_row = "" for value in current_symmetric_data or [ - 0 for i in range(pick) + 0.0 for _ in range(pick) ]: if value * chart_height >= level: chart_row += "".join( @@ -2058,7 +2266,8 @@ async def tokens( if pick > 0 and current_dtoken and current_dtoken.tokens: dtoken_tokens = current_dtoken.tokens max_dtoken = max( - dtoken_tokens, key=lambda dtoken: dtoken.probability + dtoken_tokens, + key=lambda dt: dt.probability or 0.0, ) if pick_first is None or len(dtoken_tokens) <= pick_first: @@ -2170,141 +2379,14 @@ def _events_log( include_tools: bool, include_non_tools: bool, ) -> list[str] | None: + """Generate event log entries.""" _, _n = self._, self._n if not events or events_limit == 0: return None event_log: list[str] | None = _lf( [ - ( - _( - "Executing tool {tool} call #{call_id} with" - " {total_arguments} arguments: {arguments}." - ).format( - tool="[gray78]" - + event.payload["call"].name - + "[/gray78]", - call_id="[gray78]" - + str(event.payload["call"].id)[:8] - + "[/gray78]", - total_arguments=len( - event.payload["call"].arguments or [] - ), - arguments="[gray78]" - + ( - s - if len(s := str(event.payload["call"].arguments)) - <= 50 - else s[:47] + "..." - ) - + "[/gray78]", - ) - if event.type == EventType.TOOL_EXECUTE - else ( - _n( - "Running ReACT model {model_id} with" - " {total_messages} message", - "Running ReACT model {model_id} with" - " {total_messages} messages", - len(event.payload["messages"]), - ).format( - model_id=event.payload["model_id"], - total_messages=len(event.payload["messages"]), - ) - if event.type == EventType.TOOL_MODEL_RUN - else ( - _( - "Got ReACT response from model {model_id}" - ).format(model_id=event.payload["model_id"]) - if event.type == EventType.TOOL_MODEL_RESPONSE - else ( - _n( - "Executing {total_calls} tool: {calls}", - "Executing {total_calls} tools: {calls}", - len(event.payload), - ).format( - total_calls=len(event.payload), - calls="[gray78]" - + "[/gray78], [gray78]".join( - [call.name for call in event.payload] - ) - + "[/gray78]", - ) - if event.type == EventType.TOOL_PROCESS - else ( - _( - "Executed tool {tool} call #{call_id}" - " with {total_arguments} arguments." 
- ' Got result "{result}" in' - " {elapsed_with_unit}." - ).format( - tool="[gray78]" - + event.payload["result"].call.name - + "[/gray78]", - elapsed_with_unit="[gray78]" - + precisedelta( - event.elapsed, - minimum_unit="microseconds", - ) - + "[/gray78]", - call_id="[gray78]" - + str(event.payload["result"].call.id)[ - :8 - ] - + "[/gray78]", - total_arguments=len( - event.payload[ - "result" - ].call.arguments - or [] - ), - result=( - ( - "[red]" - + event.payload[ - "result" - ].message - + "[/red]" - ) - if isinstance( - event.payload["result"], - ToolCallError, - ) - else ( - "[spring_green3]" - + to_json( - event.payload[ - "result" - ].result - ) - + "[/spring_green3]" - ) - ), - ) - if event.type == EventType.TOOL_RESULT - and event.payload["result"] - else ( - f"[{precisedelta(event.elapsed)}]" - f" <{event.type}>: {event.payload}" - if event.payload and event.elapsed - else ( - f"[{datetime.utcfromtimestamp(event.started).isoformat(sep=' ', timespec='seconds')}] <{event.type}>: {event.payload}" # noqa: E501 - if event.payload and event.started - else ( - f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}] <{event.type}>: {event.payload}" # noqa: E501 - if event.payload - else ( - f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}]" # noqa: E501 - f" <{event.type}>" - ) - ) - ) - ) - ) - ) - ) - ) # noqa: E501 - ) + self._format_event(event, _, _n) for event in events if ( ( @@ -2332,12 +2414,132 @@ def _events_log( return event_log + def _format_event( + self, + event: Event, + _: Callable[[str], str], + _n: Callable[[str, str, int], str], + ) -> str | None: + """Format a single event for display.""" + payload = event.payload + + if event.type == EventType.TOOL_EXECUTE and payload: + call = payload.get("call") + if call: + return _( + "Executing tool {tool} call #{call_id} with" + " {total_arguments} arguments: {arguments}." + ).format( + tool="[gray78]" + call.name + "[/gray78]", + call_id="[gray78]" + str(call.id)[:8] + "[/gray78]", + total_arguments=len(call.arguments or []), + arguments="[gray78]" + + ( + s + if len(s := str(call.arguments)) <= 50 + else s[:47] + "..." 
+ ) + + "[/gray78]", + ) + + if event.type == EventType.TOOL_MODEL_RUN and payload: + messages = payload.get("messages", []) + model_id = payload.get("model_id", "") + return _n( + "Running ReACT model {model_id} with {total_messages} message", + "Running ReACT model {model_id} with {total_messages}" + " messages", + len(messages), + ).format( + model_id=model_id, + total_messages=len(messages), + ) + + if event.type == EventType.TOOL_MODEL_RESPONSE and payload: + model_id = payload.get("model_id", "") + return _("Got ReACT response from model {model_id}").format( + model_id=model_id + ) + + if event.type == EventType.TOOL_PROCESS and payload: + calls: list[Any] = payload if isinstance(payload, list) else [] + return _n( + "Executing {total_calls} tool: {calls}", + "Executing {total_calls} tools: {calls}", + len(calls), + ).format( + total_calls=len(calls), + calls="[gray78]" + + "[/gray78], [gray78]".join( + [c.name for c in calls if hasattr(c, "name")] + ) + + "[/gray78]", + ) + + if event.type == EventType.TOOL_RESULT and payload: + result = payload.get("result") + if result: + call = getattr(result, "call", None) + if call: + elapsed_delta = ( + timedelta(seconds=event.elapsed) + if event.elapsed is not None + else timedelta(seconds=0) + ) + result_text: str + if isinstance(result, ToolCallError): + result_text = "[red]" + result.message + "[/red]" + elif isinstance(result, ToolCallResult): + result_text = ( + "[spring_green3]" + + to_json(result.result) + + "[/spring_green3]" + ) + else: + result_text = str(result) + return _( + "Executed tool {tool} call #{call_id} with" + ' {total_arguments} arguments. Got result "{result}"' + " in {elapsed_with_unit}." + ).format( + tool="[gray78]" + call.name + "[/gray78]", + elapsed_with_unit="[gray78]" + + precisedelta( + elapsed_delta, + minimum_unit="microseconds", + ) + + "[/gray78]", + call_id="[gray78]" + str(call.id)[:8] + "[/gray78]", + total_arguments=len(call.arguments or []), + result=result_text, + ) + + # Default format for other event types + if payload and event.elapsed is not None: + elapsed_delta = timedelta(seconds=event.elapsed) + return f"[{precisedelta(elapsed_delta)}] <{event.type}>: {payload}" + if payload and event.started: + return ( + f"[{datetime.fromtimestamp(event.started).isoformat(sep=' ', timespec='seconds')}]" # noqa: E501 + f" <{event.type}>: {payload}" + ) + if payload: + return ( + f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}]" + f" <{event.type}>: {payload}" + ) + return ( + f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}]" + f" <{event.type}>" + ) + def _tokens_table( self, dbatch: list[Token], current_dtoken: Token | None, max_dtoken: Token | None, - ): + ) -> Table: + """Build token alternatives table.""" _p = self._percentage dtable_color = "gray58" @@ -2368,7 +2570,8 @@ def _tokens_table( table.add_row( f"[gray50]#{dtoken.id}[/gray50]", f"[{dtoken_color}]{dtoken.token}[/{dtoken_color}]", - f"[{dtoken_color}]{_p(dtoken.probability)}[/{dtoken_color}]", + f"[{dtoken_color}]{_p(dtoken.probability or 0.0)}" + f"[/{dtoken_color}]", ) return table @@ -2380,31 +2583,44 @@ def welcome( license: str, user: User | None, ) -> RenderableType: + """Render welcome panel.""" _, _f, _i = self._, self._f, self._icons + avalan_icon = _i.get(cast(Data, "avalan")) or "" + license_icon = _i.get(cast(Data, "license")) or "" license_text = _("{license} license").format(license=license) return Padding( Panel( Padding( _j( " - ", - [ - " ".join( - [ - _i["avalan"] - + f" [link={url}]{name}[/link]", - 
f"[version]{version}[/version]", - "[bright_black]" - + _i["license"] - + f" {license_text}[/bright_black]", - ] - ), - _f("user", user.name) if user else None, - ( - _f("access_token_name", user.access_token_name) - if user - else None - ), - ], + cast( + list[str], + [ + " ".join( + [ + avalan_icon + + f" [link={url}]{name}[/link]", + f"[version]{version}[/version]", + "[bright_black]" + + license_icon + + f" {license_text}[/bright_black]", + ] + ), + ( + _f(cast(Data, "user"), user.name) + if user + else None + ), + ( + _f( + cast(Data, "access_token_name"), + user.access_token_name, + ) + if user + else None + ), + ], + ), ) ), box=box.SQUARE, @@ -2413,6 +2629,7 @@ def welcome( ) def _parameter_count(self, parameters: int | None) -> str: + """Format parameter count for display.""" _ = self._ if not parameters: return _("N/A") @@ -2423,27 +2640,28 @@ def _parameter_count(self, parameters: int | None) -> str: ) @staticmethod - def _symmetric_indices(data: list[float]) -> list[float]: - """Sorts data desc so that highest values in center lower at edge""" + def _symmetric_indices(data: list[float]) -> list[int]: + """Sort data desc so that highest values in center lower at edge.""" assert data sorted_data = sorted(data, reverse=True) n = len(sorted_data) - result = [None] * n + result: list[int | None] = [None] * n left = n // 2 - 1 right = n // 2 - for i, value in enumerate(sorted_data): + for i, _ in enumerate(sorted_data): if i % 2 == 0: result[left] = i left -= 1 else: result[right] = i right += 1 - return result + return [r for r in result if r is not None] @staticmethod def _percentage(value: float) -> str: + """Format value as percentage.""" p = value * 100 return ( format_string("%d%%", p, grouping=True) @@ -2455,6 +2673,7 @@ def _percentage(value: float) -> str: def _wrap_lines( text_tokens: list[str], width: int, skip_blank_lines: bool = False ) -> list[str]: + """Wrap text tokens to specified width.""" lines: list[str] = [] output = "".join(text_tokens) for line in output.splitlines(): diff --git a/src/avalan/compat.py b/src/avalan/compat.py index 8e8aa82f..8e4b8ba7 100644 --- a/src/avalan/compat.py +++ b/src/avalan/compat.py @@ -1,16 +1,8 @@ -from __future__ import annotations +from collections.abc import Callable +from typing import TypeVar -from typing import Callable, TypeVar - -try: - from typing import override as _override -except ImportError: # Python < 3.12 - T = TypeVar("T", bound=Callable[..., object]) - - def _override(func: T) -> T: # type: ignore - return func +from typing_extensions import override as _override +T = TypeVar("T", bound=Callable[..., object]) override = _override - -__all__ = ["override"] diff --git a/src/avalan/deploy/aws.py b/src/avalan/deploy/aws.py index 10222edf..a842d42c 100644 --- a/src/avalan/deploy/aws.py +++ b/src/avalan/deploy/aws.py @@ -1,8 +1,8 @@ from asyncio import AbstractEventLoop, get_running_loop +from collections.abc import Awaitable, Callable from concurrent.futures import ThreadPoolExecutor -from typing import Callable +from typing import Any, cast -from boto3 import client from boto3.session import Session from botocore.exceptions import ClientError @@ -12,22 +12,27 @@ class DeployError(Exception): class AsyncClient: + """Async wrapper around a boto3 client.""" + + exceptions: Any + def __init__( self, - client: client, + client: Any, loop: AbstractEventLoop | None = None, executor: ThreadPoolExecutor | None = None, - ): + ) -> None: self._client = client self._loop = loop or get_running_loop() self._executor = executor or 
ThreadPoolExecutor() + self.exceptions = getattr(client, "exceptions", None) - def __getattr__(self, name: str) -> Callable[..., any]: + def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]: attr = getattr(self._client, name) if not callable(attr): - return attr + return cast(Callable[..., Awaitable[Any]], attr) - async def fn(*args, **kwargs): + async def fn(*args: Any, **kwargs: Any) -> Any: return await self._loop.run_in_executor( self._executor, lambda: attr(*args, **kwargs) ) @@ -36,17 +41,21 @@ async def fn(*args, **kwargs): class Aws: + """AWS deployment manager for EC2 and RDS resources.""" + _ec2: AsyncClient _rds: AsyncClient _session: Session def __init__( - self, settings: dict | None = None, token_pair: str | None = None - ): + self, + settings: dict[str, Any] | None = None, + token_pair: str | None = None, + ) -> None: if settings and "token_pair" in settings and not token_pair: token_pair = settings.pop("token_pair") - aws_settings = {} + aws_settings: dict[str, str] = {} if token_pair: access_key, secret_key = token_pair.split(":", 1) @@ -63,13 +72,14 @@ def __init__( self._rds = AsyncClient(self._session.client("rds")) async def get_vpc_id(self, name: str) -> str: + """Return the VPC ID for a VPC with the given name.""" response = await self._ec2.describe_vpcs( Filters=[{"Name": "tag:Name", "Values": [name]}] ) vpcs = response.get("Vpcs", []) if not vpcs: raise DeployError(f"VPC {name!r} not found") - return vpcs[0]["VpcId"] + return str(vpcs[0]["VpcId"]) async def create_vpc_if_missing(self, name: str, cidr: str) -> str: """Return an existing VPC id or create a new VPC.""" @@ -77,7 +87,7 @@ async def create_vpc_if_missing(self, name: str, cidr: str) -> str: return await self.get_vpc_id(name) except DeployError: response = await self._ec2.create_vpc(CidrBlock=cidr) - vpc_id = response["Vpc"]["VpcId"] + vpc_id = str(response["Vpc"]["VpcId"]) await self._ec2.create_tags( Resources=[vpc_id], Tags=[{"Key": "Name", "Value": name}] ) @@ -86,21 +96,23 @@ async def create_vpc_if_missing(self, name: str, cidr: str) -> str: return vpc_id async def get_security_group(self, name: str, vpc_id: str) -> str: + """Return the security group ID, creating one if necessary.""" response = await self._ec2.describe_security_groups( Filters=[{"Name": "group-name", "Values": [name]}] ) groups = response.get("SecurityGroups", []) if groups: - return groups[0]["GroupId"] + return str(groups[0]["GroupId"]) response = await self._ec2.create_security_group( GroupName=name, Description="avalan deployment", VpcId=vpc_id, ) - return response["GroupId"] + return str(response["GroupId"]) async def configure_security_group(self, group_id: str, port: int) -> None: + """Configure ingress rules for the security group.""" try: await self._ec2.authorize_security_group_ingress( GroupId=group_id, @@ -116,9 +128,10 @@ async def configure_security_group(self, group_id: str, port: int) -> None: async def create_rds_if_missing( self, db_id: str, instance_class: str, sg_id: str, storage: int ) -> str: + """Create an RDS instance if it does not already exist.""" try: await self._rds.describe_db_instances(DBInstanceIdentifier=db_id) - except self._rds.exceptions.DBInstanceNotFoundFault: + except self._rds.exceptions.DBInstanceNotFoundFault: # type: ignore[misc] await self._rds.create_db_instance( DBInstanceIdentifier=db_id, DBInstanceClass=instance_class, @@ -143,13 +156,14 @@ async def create_instance_if_missing( agent_path: str, port: int, ) -> str: + """Create an EC2 instance if it does not already 
exist.""" user_data = self._create_user_data(agent_path, port) response = await self._ec2.describe_instances( Filters=[{"Name": "tag:Name", "Values": [instance_name]}] ) reservations = response.get("Reservations", []) if reservations: - return reservations[0]["Instances"][0]["InstanceId"] + return str(reservations[0]["Instances"][0]["InstanceId"]) response = await self._ec2.describe_subnets( Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] @@ -172,7 +186,7 @@ async def create_instance_if_missing( } ], ) - return response["Instances"][0]["InstanceId"] + return str(response["Instances"][0]["InstanceId"]) def _create_user_data(self, agent_path: str, port: int) -> str: cmd = f"avalan agent serve {agent_path} --host 0.0.0.0 --port {port}\n" diff --git a/src/avalan/entities.py b/src/avalan/entities.py index 945445a0..788ed697 100644 --- a/src/avalan/entities.py +++ b/src/avalan/entities.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import StrEnum -from typing import Literal, TypedDict, final +from typing import Any, Literal, TypedDict, final from uuid import UUID from numpy import ndarray @@ -636,7 +636,7 @@ class ModelConfig: sep_token: str | None state_size: int # Additional keyword arguments to store for the current task - task_specific_params: dict[str, any] | None + task_specific_params: dict[str, Any] | None # The dtype of the weight. Since the config object is stored in plain # text, this attribute contains just the floating type string without the # torch @@ -759,9 +759,9 @@ class OperationVisionParameters: class OperationParameters(TypedDict, total=False): - audio: OperationAudioParameters | None = None - text: OperationTextParameters | None = None - vision: OperationVisionParameters | None = None + audio: OperationAudioParameters | None + text: OperationTextParameters | None + vision: OperationVisionParameters | None @final @@ -896,7 +896,7 @@ class TransformerEngineSettings(EngineSettings): low_cpu_mem_usage: bool = False output_hidden_states: bool = False special_tokens: list[str] | None = None - state_dict: dict[str, Tensor] = None + state_dict: dict[str, Tensor] | None = None tokens: list[str] | None = None diff --git a/src/avalan/event/manager.py b/src/avalan/event/manager.py index 51c7aa8d..48988720 100644 --- a/src/avalan/event/manager.py +++ b/src/avalan/event/manager.py @@ -3,6 +3,7 @@ from asyncio import Event as EventSignal from asyncio import Queue, TimeoutError, wait_for from collections import defaultdict, deque +from collections.abc import AsyncGenerator from inspect import iscoroutine from typing import Awaitable, Callable, Iterable @@ -58,7 +59,7 @@ async def trigger(self, event: Event) -> None: async def listen( self, stop_signal: EventSignal | None = None, timeout: float = 0.2 - ): + ) -> AsyncGenerator[Event, None]: while True: try: yield await wait_for(self._queue.get(), timeout=timeout) diff --git a/src/avalan/memory/__init__.py b/src/avalan/memory/__init__.py index 5ae42d10..ac6a1cbf 100644 --- a/src/avalan/memory/__init__.py +++ b/src/avalan/memory/__init__.py @@ -32,7 +32,9 @@ async def search(self, query: str) -> list[T] | None: class MessageMemory(MemoryStore[EngineMessage], ABC): - def search(self, query: str) -> list[EngineMessage] | None: + def search( # type: ignore[override] + self, query: str + ) -> list[EngineMessage] | None: raise NotImplementedError() @@ -40,17 +42,17 @@ class RecentMessageMemory(MessageMemory): _lock: Lock _data: list[EngineMessage] - def __init__(self, **kwargs) -> None: + def 
__init__(self, **kwargs: object) -> None: self._lock = Lock() - self.reset() + self._data = [] super().__init__(**kwargs) @override - def append(self, data: EngineMessage) -> None: + def append(self, data: EngineMessage) -> None: # type: ignore[override] with self._lock: self._data.append(data) - def reset(self) -> None: + def reset(self) -> None: # type: ignore[override] with self._lock: self._data = [] diff --git a/src/avalan/memory/partitioner/code.py b/src/avalan/memory/partitioner/code.py index 8fd017d2..3bd16a02 100644 --- a/src/avalan/memory/partitioner/code.py +++ b/src/avalan/memory/partitioner/code.py @@ -1,5 +1,4 @@ from ...memory.partitioner import Encoding, PartitionerException -from ...utils import _j from dataclasses import dataclass from logging import Logger @@ -80,7 +79,9 @@ def partition( if root_node.children and root_node.children[0].is_error: error_node = root_node.children[0] error_name = error_node.grammar_name - error_message = error_node.text.decode(encoding) + error_text = error_node.text + assert error_text is not None, "Error node text cannot be None" + error_message = error_text.decode(encoding) error_row, error_column = error_node.start_point raise PartitionerException( f'{error_name}: "{error_message}" at' @@ -133,12 +134,16 @@ def _partition( for child_node in node.children: if child_node.type == "class_definition": class_name_node = child_node.child_by_field_name("name") - current_class_name = ( - class_name_node.text.decode(encoding) - if class_name_node - else None - ) - class_id = _j(".", [current_namespace, current_class_name]) + if class_name_node and class_name_node.text: + current_class_name = class_name_node.text.decode(encoding) + else: + current_class_name = None + class_parts: list[str] = [] + if current_namespace: + class_parts.append(current_namespace) + if current_class_name: + class_parts.append(current_class_name) + class_id = ".".join(class_parts) if class_parts else "" child_symbols = current_symbols + [ Symbol(symbol_type="class", id=class_id) ] @@ -215,11 +220,10 @@ def _get_functions( results = [] if node.type == "class_definition": class_name_node = node.child_by_field_name("name") - class_name = ( - class_name_node.text.decode(encoding) - if class_name_node - else None - ) + if class_name_node and class_name_node.text: + class_name = class_name_node.text.decode(encoding) + else: + class_name = None for child in node.children: functs = cls._get_functions( current_namespace, @@ -262,6 +266,9 @@ def _get_function_from_node( ) params_node = node.child_by_field_name("parameters") return_type_node = node.child_by_field_name("return_type") + return_type: str | None = None + if return_type_node and return_type_node.text: + return_type = return_type_node.text.decode(encoding) return Function( id=function_id, namespace=current_namespace, @@ -272,11 +279,7 @@ def _get_function_from_node( if params_node else None ), - return_type=( - return_type_node.text.decode(encoding) - if return_type_node - else None - ), + return_type=return_type, ) @staticmethod @@ -291,53 +294,64 @@ def _get_function_id_and_name_from_node( "async_function_definition", ) function_name_node = node.child_by_field_name("name") - assert function_name_node + assert function_name_node is not None + assert function_name_node.text is not None function_name = function_name_node.text.decode(encoding) assert function_name - function_id = _j( - ".", [current_namespace, current_class_name, function_name] - ) + id_parts: list[str] = [] + if current_namespace: + 
id_parts.append(current_namespace) + if current_class_name: + id_parts.append(current_class_name) + id_parts.append(function_name) + function_id = ".".join(id_parts) return function_id, function_name @staticmethod - def _get_parameters( - node: Node, encoding: Encoding - ) -> list[dict[str, str | None]]: + def _get_parameters(node: Node, encoding: Encoding) -> list[Parameter]: assert node - parameters = [] + parameters: list[Parameter] = [] for child in node.children: - if child.type in ( - "default_parameter", - "typed_default_parameter", - "typed_parameter", - ): - param_type = child.child_by_field_name("type").text.decode( - encoding - ) - param_name = None - if child.named_child_count: - for parameter_child in child.children: - if parameter_child.type == "identifier": - param_name = parameter_child.text.decode(encoding) - break - parameters.append( - Parameter( - parameter_type=child.type, - name=param_name, - type=param_type, + match child.type: + case ( + "default_parameter" + | "typed_default_parameter" + | "typed_parameter" + ) as param_kind: + type_node = child.child_by_field_name("type") + param_type: str | None = None + if type_node and type_node.text: + param_type = type_node.text.decode(encoding) + param_name: str = "" + if child.named_child_count: + for parameter_child in child.children: + if parameter_child.type == "identifier": + if parameter_child.text: + param_name = parameter_child.text.decode( + encoding + ) + break + parameters.append( + Parameter( + parameter_type=param_kind, + name=param_name, + type=param_type, + ) ) - ) - elif child.type in ( - "dictionary_splat_pattern", - "identifier", - "keyword_separator", - ): - parameters.append( - Parameter( - parameter_type=child.type, - name=child.text.decode(encoding), - type=None, + case ( + "dictionary_splat_pattern" + | "identifier" + | "keyword_separator" + ) as simple_kind: + name: str = "" + if child.text: + name = child.text.decode(encoding) + parameters.append( + Parameter( + parameter_type=simple_kind, + name=name, + type=None, + ) ) - ) return parameters diff --git a/src/avalan/memory/partitioner/text.py b/src/avalan/memory/partitioner/text.py index c2f4997e..803b7d5a 100644 --- a/src/avalan/memory/partitioner/text.py +++ b/src/avalan/memory/partitioner/text.py @@ -60,8 +60,8 @@ def configure( self._window_size = window_size self._overlap_size = overlap_size - @override @property + @override def sentence_model(self) -> Callable: return self._model diff --git a/src/avalan/memory/permanent/__init__.py b/src/avalan/memory/permanent/__init__.py index 0b96cf3d..3642c7f7 100644 --- a/src/avalan/memory/permanent/__init__.py +++ b/src/avalan/memory/permanent/__init__.py @@ -175,7 +175,7 @@ def has_session(self) -> bool: def session_id(self) -> UUID | None: return self._session_id - def reset(self) -> None: + def reset(self) -> None: # type: ignore[override] raise NotImplementedError() async def reset_session( @@ -198,7 +198,7 @@ async def continue_session( ) @override - def append(self, data: EngineMessage) -> None: + def append(self, data: EngineMessage) -> None: # type: ignore[override] raise NotImplementedError() @abstractmethod @@ -263,15 +263,14 @@ def _build_message_with_partitions( if message_id is None: message_id = uuid4() content = engine_message.message.content - data = ( - content.text - if isinstance(content, MessageContentText) - else ( - content.image_url - if isinstance(content, MessageContentImage) - else str(content) - ) - ) + if isinstance(content, MessageContentText): + data = content.text + elif 
isinstance(content, MessageContentImage): + data = str(content.image_url) + elif content is None: + data = "" + else: + data = str(content) message = PermanentMessage( id=message_id, @@ -326,10 +325,10 @@ def __init__( super().__init__(**kwargs) @override - def append(self, data: EngineMessage) -> None: + def append(self, data: EngineMessage) -> None: # type: ignore[override] raise NotImplementedError() - def reset(self) -> None: + def reset(self) -> None: # type: ignore[override] raise NotImplementedError() @abstractmethod diff --git a/src/avalan/memory/permanent/elasticsearch/message.py b/src/avalan/memory/permanent/elasticsearch/message.py index 00eb8fc0..6991c365 100644 --- a/src/avalan/memory/permanent/elasticsearch/message.py +++ b/src/avalan/memory/permanent/elasticsearch/message.py @@ -51,8 +51,8 @@ async def create_instance( sentence_model=sentence_model, ) - async def create_session( - self, *, agent_id: UUID, participant_id: UUID + async def create_session( # type: ignore[override] + self, agent_id: UUID, participant_id: UUID ) -> UUID: now_utc = datetime.now(timezone.utc) session = self._build_session( @@ -75,7 +75,7 @@ async def create_session( async def continue_session_and_get_id( self, - *, + *args: object, agent_id: UUID, participant_id: UUID, session_id: UUID, @@ -92,7 +92,7 @@ async def continue_session_and_get_id( async def append_with_partitions( self, engine_message: EngineMessage, - *, + *args: object, partitions: list[TextPartition], ) -> None: assert engine_message and partitions @@ -143,7 +143,7 @@ async def get_recent_messages( self, session_id: UUID, participant_id: UUID, - *, + *args: object, limit: int | None = None, ) -> list[EngineMessage]: response = await self._call_client( @@ -173,9 +173,9 @@ async def get_recent_messages( ) return messages - async def search_messages( + async def search_messages( # type: ignore[override] self, - *, + *args: object, agent_id: UUID, function: VectorFunction, limit: int | None = None, diff --git a/src/avalan/memory/permanent/elasticsearch/raw.py b/src/avalan/memory/permanent/elasticsearch/raw.py index ec84caed..adc92c74 100644 --- a/src/avalan/memory/permanent/elasticsearch/raw.py +++ b/src/avalan/memory/permanent/elasticsearch/raw.py @@ -32,7 +32,7 @@ def __init__( ElasticsearchMemory.__init__( self, index=index, client=client, logger=logger ) - PermanentMemory.__init__(self, sentence_model=None) + PermanentMemory.__init__(self, sentence_model=None) # type: ignore[arg-type] @classmethod async def create_instance( @@ -47,11 +47,11 @@ async def create_instance( memory = cls(index=index, client=es_client, logger=logger) return memory - async def append_with_partitions( + async def append_with_partitions( # type: ignore[override] self, namespace: str, participant_id: UUID, - *, + *args: object, memory_type: MemoryType, data: str, identifier: str, @@ -113,9 +113,9 @@ async def append_with_partitions( }, ) - async def search_memories( + async def search_memories( # type: ignore[override] self, - *, + *args: object, search_partitions: list[TextPartition], participant_id: UUID, namespace: str, @@ -212,7 +212,7 @@ async def list_memories( ) return memories - async def search( + async def search( # type: ignore[override] self, query: str ) -> list[PermanentMemoryPartition] | None: raise NotImplementedError() diff --git a/src/avalan/memory/permanent/pgsql/__init__.py b/src/avalan/memory/permanent/pgsql/__init__.py index 74cf2de1..6ae926cf 100644 --- a/src/avalan/memory/permanent/pgsql/__init__.py +++ 
b/src/avalan/memory/permanent/pgsql/__init__.py @@ -9,7 +9,7 @@ from logging import Logger from time import perf_counter -from typing import TypeVar +from typing import Any, TypeVar from pgvector.psycopg import register_vector_async from psycopg import AsyncConnection, AsyncCursor @@ -18,10 +18,11 @@ from psycopg_pool import AsyncConnectionPool T = TypeVar("T") +E = TypeVar("E") class BasePgsqlMemory(MemoryStore[T]): - _database: AsyncConnection + _database: AsyncConnectionPool _logger: Logger def __init__(self, database: AsyncConnectionPool, logger: Logger): @@ -45,8 +46,8 @@ async def _execute( ) async def _fetch_all( - self, entity: type[T], query: str, parameters: tuple - ) -> list[T]: + self, entity: type[E], query: str, parameters: tuple + ) -> list[E]: async with self._database.connection() as connection: async with connection.cursor() as cursor: await self._execute(cursor, query, parameters) @@ -59,8 +60,8 @@ async def _fetch_all( ) async def _fetch_one( - self, entity: type[T], query: str, parameters: tuple - ) -> T: + self, entity: type[E], query: str, parameters: tuple + ) -> E: result = await self._try_fetch_one(entity, query, parameters) if result is None: raise RecordNotFoundException() @@ -86,8 +87,8 @@ async def _has_one(self, query: str, parameters: tuple) -> bool: return result is not None async def _try_fetch_one( - self, entity: type[T], query: str, parameters: tuple - ) -> T | None: + self, entity: type[E], query: str, parameters: tuple + ) -> E | None: async with self._database.connection() as connection: async with connection.cursor() as cursor: await self._execute(cursor, query, parameters) @@ -96,8 +97,8 @@ async def _try_fetch_one( return entity(**dict(result)) if result is not None else None async def _update_and_fetch_one( - self, entity: type[T], query: str, parameters: tuple - ) -> T: + self, entity: type[E], query: str, parameters: tuple + ) -> E: row = await self._update_and_fetch_row(query, parameters) return entity(**row) @@ -105,7 +106,9 @@ async def _update_and_fetch_field( self, field: str, query: str, parameters: tuple ) -> str: row = await self._update_and_fetch_row(query, parameters) - return row[field] + value = row[field] + assert isinstance(value, str) + return value async def _update_and_fetch_row( self, query: str, parameters: tuple @@ -135,33 +138,29 @@ async def create_instance_from_pool( pool: AsyncConnectionPool, *, logger: Logger, - **kwargs, - ): + **kwargs: Any, + ) -> "PgsqlMemory[T]": memory = cls(dsn=None, pool=pool, logger=logger, **kwargs) return memory def __init__( self, dsn: str | None, - *args, + *args: object, pool: AsyncConnectionPool | None = None, composite_types: list[str] | None = None, pool_minimum: int | None = None, pool_maximum: int | None = None, logger: Logger, - **kwargs, + **kwargs: object, ): - assert pool or ( - dsn - and pool_minimum - and pool_minimum - and pool_minimum > 0 - and pool_maximum > pool_minimum - ) - if pool: super().__init__(database=pool, logger=logger, **kwargs) else: + assert dsn is not None + assert pool_minimum is not None and pool_minimum > 0 + assert pool_maximum is not None and pool_maximum > pool_minimum + self._composite_types = composite_types if "//" not in dsn: @@ -176,8 +175,8 @@ def __init__( ) super().__init__(database=database, logger=logger, **kwargs) - async def _configure_connection(self, connection: AsyncConnection): - connection.row_factory = dict_row + async def _configure_connection(self, connection: AsyncConnection) -> None: + connection.row_factory = dict_row # type: 
ignore[assignment] await connection.set_autocommit(True) if self._composite_types: for composite_type_name in self._composite_types: @@ -191,28 +190,33 @@ async def _configure_connection(self, connection: AsyncConnection): @staticmethod def _to_engine_messages( messages: list[PermanentMessage] | list[PermanentMessageScored], - *args, + *args: object, limit: int | None, reverse: bool = False, scored: bool = False, ) -> list[EngineMessage] | list[EngineMessageScored]: - engine_messages = [ - ( + engine_messages: list[EngineMessage] | list[EngineMessageScored] + if scored: + scored_messages: list[EngineMessageScored] = [ EngineMessageScored( agent_id=m.agent_id, model_id=m.model_id, message=Message(role=m.author, content=m.data), - score=m.score, + score=m.score, # type: ignore[attr-defined, union-attr] ) - if scored - else EngineMessage( + for m in messages + ] + engine_messages = scored_messages + else: + unscored_messages: list[EngineMessage] = [ + EngineMessage( agent_id=m.agent_id, model_id=m.model_id, message=Message(role=m.author, content=m.data), ) - ) - for m in messages - ] + for m in messages + ] + engine_messages = unscored_messages if reverse: engine_messages.reverse() if limit and len(engine_messages) > limit: diff --git a/src/avalan/memory/permanent/pgsql/message.py b/src/avalan/memory/permanent/pgsql/message.py index d1d6366f..8ea35a3e 100644 --- a/src/avalan/memory/permanent/pgsql/message.py +++ b/src/avalan/memory/permanent/pgsql/message.py @@ -15,7 +15,7 @@ from pgvector.psycopg import Vector -class PgsqlMessageMemory( +class PgsqlMessageMemory( # type: ignore[misc] PgsqlMemory[PermanentMessage], PermanentMessageMemory ): """PostgreSQL-backed implementation of :class:`PermanentMessageMemory`.""" @@ -44,8 +44,8 @@ async def create_instance( await memory.open() return memory - async def create_session( - self, *args, agent_id: UUID, participant_id: UUID + async def create_session( # type: ignore[override] + self, agent_id: UUID, participant_id: UUID ) -> UUID: """Create a new session for a participant.""" now_utc = datetime.now(timezone.utc) @@ -81,13 +81,13 @@ async def create_session( async def continue_session_and_get_id( self, - *args, + *args: object, agent_id: UUID, participant_id: UUID, session_id: UUID, ) -> UUID: """Continue an existing session if it belongs to the participant.""" - session_id = await self._fetch_field( + fetched_session_id = await self._fetch_field( "id", """ SELECT "sessions"."id" @@ -99,13 +99,15 @@ async def continue_session_and_get_id( """, (str(agent_id), str(participant_id), str(session_id)), ) - assert session_id - return session_id if isinstance(session_id, UUID) else UUID(session_id) + assert fetched_session_id is not None + if isinstance(fetched_session_id, UUID): + return fetched_session_id + return UUID(fetched_session_id) async def append_with_partitions( self, engine_message: EngineMessage, - *args, + *args: object, partitions: list[TextPartition], ) -> None: """Persist a message and its partitions.""" @@ -197,7 +199,7 @@ async def get_recent_messages( self, session_id: UUID, participant_id: UUID, - *args, + *args: object, limit: int | None = None, ) -> list[EngineMessage]: """Retrieve recent messages for a session.""" @@ -227,11 +229,12 @@ async def get_recent_messages( engine_messages = self._to_engine_messages( messages, limit=limit_value, reverse=True ) - return engine_messages + assert isinstance(engine_messages, list) + return engine_messages # type: ignore[return-value] - async def search_messages( + async def search_messages( # 
type: ignore[override] self, - *args, + *args: object, agent_id: UUID, function: VectorFunction, limit: int | None = None, @@ -297,4 +300,5 @@ async def search_messages( limit=limit_value, scored=True, ) - return engine_messages + assert isinstance(engine_messages, list) + return engine_messages # type: ignore[return-value] diff --git a/src/avalan/memory/permanent/pgsql/raw.py b/src/avalan/memory/permanent/pgsql/raw.py index 635ee4da..4473f03b 100644 --- a/src/avalan/memory/permanent/pgsql/raw.py +++ b/src/avalan/memory/permanent/pgsql/raw.py @@ -379,7 +379,10 @@ async def upsert_entity( ), ) result = await cursor.fetchone() - entity_id = result["id"] if result else None + assert result is not None + result_dict = dict(result) + entity_id = result_dict.get("id") + assert entity_id is not None await cursor.execute( """ @@ -403,4 +406,6 @@ async def upsert_entity( ), ) await cursor.close() - return UUID(entity_id) if isinstance(entity_id, str) else entity_id + return ( + entity_id if isinstance(entity_id, UUID) else UUID(str(entity_id)) + ) diff --git a/src/avalan/memory/permanent/s3vectors/message.py b/src/avalan/memory/permanent/s3vectors/message.py index 80f9db00..74f73ee6 100644 --- a/src/avalan/memory/permanent/s3vectors/message.py +++ b/src/avalan/memory/permanent/s3vectors/message.py @@ -60,14 +60,14 @@ async def create_instance( sentence_model=sentence_model, ) - async def create_session( - self, *, agent_id: UUID, participant_id: UUID + async def create_session( # type: ignore[override] + self, agent_id: UUID, participant_id: UUID ) -> UUID: return uuid4() async def continue_session_and_get_id( self, - *, + *args: object, agent_id: UUID, participant_id: UUID, session_id: UUID, @@ -77,7 +77,7 @@ async def continue_session_and_get_id( async def append_with_partitions( self, engine_message: EngineMessage, - *, + *args: object, partitions: list[TextPartition], ) -> None: assert engine_message and partitions @@ -131,7 +131,7 @@ async def get_recent_messages( self, session_id: UUID, participant_id: UUID, - *, + *args: object, limit: int | None = None, ) -> list[EngineMessage]: prefix = f"{self._collection}/{session_id}/" @@ -157,9 +157,9 @@ async def get_recent_messages( ) return messages - async def search_messages( + async def search_messages( # type: ignore[override] self, - *, + *args: object, agent_id: UUID, function: VectorFunction, limit: int | None = None, diff --git a/src/avalan/memory/permanent/s3vectors/raw.py b/src/avalan/memory/permanent/s3vectors/raw.py index 1f53d75b..c4fc6c92 100644 --- a/src/avalan/memory/permanent/s3vectors/raw.py +++ b/src/avalan/memory/permanent/s3vectors/raw.py @@ -41,7 +41,7 @@ def __init__( client=client, logger=logger, ) - PermanentMemory.__init__(self, sentence_model=None) + PermanentMemory.__init__(self, sentence_model=None) # type: ignore[arg-type] @classmethod async def create_instance( @@ -62,11 +62,11 @@ async def create_instance( ) return memory - async def append_with_partitions( + async def append_with_partitions( # type: ignore[override] self, namespace: str, participant_id: UUID, - *, + *args: object, memory_type: MemoryType, data: str, identifier: str, @@ -131,9 +131,9 @@ async def append_with_partitions( }, ) - async def search_memories( + async def search_memories( # type: ignore[override] self, - *, + *args: object, search_partitions: list[TextPartition], participant_id: UUID, namespace: str, @@ -230,7 +230,7 @@ async def list_memories( ) return memories
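+    # Free-text search is not supported for raw vector memories; use
+    # search_memories() with precomputed partitions instead.
-    async def search(
+    async def search(  # type: ignore[override]
         self, query: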
str ) -> list[PermanentMemoryPartition] | None: raise NotImplementedError() diff --git a/src/avalan/memory/source.py b/src/avalan/memory/source.py index ab91fdd2..68de9af1 100644 --- a/src/avalan/memory/source.py +++ b/src/avalan/memory/source.py @@ -89,11 +89,11 @@ async def _convert_bytes( if self._is_pdf(url, content_type, data): metadata = PdfReader(BytesIO(data)).metadata - metadata_title = ( - metadata["/Title"] - if metadata and "/Title" in metadata - else None - ) + metadata_title: str | None = None + if metadata and "/Title" in metadata: + raw_title = metadata["/Title"] + if isinstance(raw_title, str): + metadata_title = raw_title title = metadata_title or title or self._markdown_title(markdown) description = description or self._markdown_description(markdown) diff --git a/src/avalan/model/audio/__init__.py b/src/avalan/model/audio/__init__.py index 2510aaab..a934e123 100644 --- a/src/avalan/model/audio/__init__.py +++ b/src/avalan/model/audio/__init__.py @@ -1,11 +1,10 @@ from ...model import TokenizerNotSupportedException from ...model.engine import Engine -from abc import ABC, abstractmethod -from typing import Literal +from abc import ABC +from typing import Any -from numpy import ndarray -from PIL import Image +from numpy.typing import NDArray from torch import Tensor from torchaudio import load from torchaudio.functional import resample @@ -16,13 +15,7 @@ class BaseAudioModel(Engine, ABC): - @abstractmethod - async def __call__( - self, - image_source: str | Image.Image, - tensor_format: Literal["pt"] = "pt", - ) -> str: - raise NotImplementedError() + """Base class for audio models.""" def _load_tokenizer( self, tokenizer_name_or_path: str | None, use_fast: bool = True @@ -35,25 +28,28 @@ def _load_tokenizer_with_tokens( raise TokenizerNotSupportedException() def _resample_mono(self, audio_source: str, sampling_rate: int) -> Tensor: - wave, wave_sampling_rate = load(audio_source) + wave_data: Tensor + wave_data, wave_sampling_rate = load(audio_source) - if wave.shape[0] > 1: + if wave_data.shape[0] > 1: # stereo -> mono - wave = wave.mean(dim=0) + wave_data = wave_data.mean(dim=0) # type: ignore[operator] else: # already mono, just drop channel dim (samples,) - wave = wave.squeeze(0) + wave_data = wave_data.squeeze(0) # type: ignore[operator] if wave_sampling_rate != sampling_rate: - wave = resample( - wave.unsqueeze(0), wave_sampling_rate, sampling_rate + wave_data = resample( + wave_data.unsqueeze(0), + wave_sampling_rate, + sampling_rate, # type: ignore[operator] ).squeeze(0) - return wave + return wave_data - def _resample(self, audio_source: str, sampling_rate: int) -> ndarray: + def _resample(self, audio_source: str, sampling_rate: int) -> NDArray[Any]: wave, wave_sampling_rate = load(audio_source) if wave_sampling_rate != sampling_rate: wave = resample(wave, wave_sampling_rate, sampling_rate) - wave = wave.mean(0).numpy() - return wave + result: NDArray[Any] = wave.mean(0).numpy() # type: ignore[union-attr] + return result diff --git a/src/avalan/model/audio/classification.py b/src/avalan/model/audio/classification.py index f4dc18d9..84bcf59d 100644 --- a/src/avalan/model/audio/classification.py +++ b/src/avalan/model/audio/classification.py @@ -3,7 +3,7 @@ from ...model.engine import Engine from ...model.vendor import TextGenerationVendor -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import inference_mode @@ -15,26 +15,29 @@ class AudioClassificationModel(BaseAudioModel): - _extractor: 
AutoFeatureExtractor + _extractor: Any # AutoFeatureExtractor with model-specific methods def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id, "model_id is required" self._extractor = AutoFeatureExtractor.from_pretrained(self._model_id) - model = AutoModelForAudioClassification.from_pretrained( - self._model_id, - device_map=self._device, - tp_plan=Engine._get_tp_plan(self._settings.parallel), - distributed_config=Engine._get_distributed_config( - self._settings.distributed_config - ), - subfolder=self._settings.subfolder or "", - ).to(self._device) + model: PreTrainedModel = ( + AutoModelForAudioClassification.from_pretrained( + self._model_id, + device_map=self._device, + tp_plan=Engine._get_tp_plan(self._settings.parallel), + distributed_config=Engine._get_distributed_config( + self._settings.distributed_config + ), + subfolder=self._settings.subfolder or "", + ).to(self._device) + ) return model @override - async def __call__( + async def __call__( # type: ignore[override] self, path: str, *, @@ -43,6 +46,8 @@ async def __call__( tensor_format: Literal["pt"] = "pt", ) -> dict[str, float]: assert path + assert self._model is not None, "Model must be loaded" + assert isinstance(self._model, PreTrainedModel) wave = self._resample_mono(path, sampling_rate) inputs = self._extractor( @@ -52,11 +57,14 @@ async def __call__( padding=padding, ).to(self._device) - id2label = {int(k): v for k, v in self._model.config.id2label.items()} + id2label_raw = self._model.config.id2label or {} + id2label: dict[int, str] = { + int(k): str(v) for k, v in id2label_raw.items() + } labels: dict[str, float] = {} with inference_mode(): - logits = self._model(**inputs).logits + logits = self._model(**inputs).logits # type: ignore[operator] probs = logits.softmax(dim=-1)[0] for idx, p in sorted(enumerate(probs), key=lambda x: -x[1]): labels[id2label[idx]] = p.item() diff --git a/src/avalan/model/audio/generation.py b/src/avalan/model/audio/generation.py index ba8c4344..32675fe0 100644 --- a/src/avalan/model/audio/generation.py +++ b/src/avalan/model/audio/generation.py @@ -3,7 +3,7 @@ from ...model.engine import Engine from ...model.vendor import TextGenerationVendor -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import from_numpy, inference_mode @@ -16,25 +16,28 @@ class AudioGenerationModel(BaseAudioModel): - _processor: AutoProcessor + _processor: Any # AutoProcessor with Musicgen-specific methods def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id, "model_id is required" self._processor = AutoProcessor.from_pretrained(self._model_id) - model = MusicgenForConditionalGeneration.from_pretrained( - self._model_id, - device_map=self._device, - tp_plan=Engine._get_tp_plan(self._settings.parallel), - distributed_config=Engine._get_distributed_config( - self._settings.distributed_config - ), - subfolder=self._settings.subfolder or "", - ).to(self._device) + model: PreTrainedModel = ( + MusicgenForConditionalGeneration.from_pretrained( + self._model_id, + device_map=self._device, + tp_plan=Engine._get_tp_plan(self._settings.parallel), + distributed_config=Engine._get_distributed_config( + self._settings.distributed_config + ), + subfolder=self._settings.subfolder or "", + ).to(self._device) + ) return model @override - async def __call__( + async def __call__( # type: ignore[override] self, prompt: str, path: str, @@ -44,6 +47,8 @@ async 
def __call__( tensor_format: Literal["pt"] = "pt", ) -> str: assert path + assert self._model is not None, "Model must be loaded" + assert isinstance(self._model, PreTrainedModel) inputs = self._processor( text=[prompt], return_tensors=tensor_format, padding=padding @@ -51,11 +56,11 @@ async def __call__( inputs.to(self._device) with inference_mode(): - audio_tokens = self._model.generate( + audio_tokens = self._model.generate( # type: ignore[attr-defined,operator] **inputs, max_new_tokens=max_new_tokens ) - sampling_rate = self._model.config.audio_encoder.sampling_rate + sampling_rate: int = self._model.config.audio_encoder.sampling_rate waveform = audio_tokens[0, 0].cpu().numpy() wave_tensor = from_numpy(waveform).unsqueeze(0) save(path, wave_tensor, sampling_rate) diff --git a/src/avalan/model/audio/speech.py b/src/avalan/model/audio/speech.py index ca57e729..f4ea5eea 100644 --- a/src/avalan/model/audio/speech.py +++ b/src/avalan/model/audio/speech.py @@ -3,7 +3,7 @@ from ...model.engine import Engine from ...model.vendor import TextGenerationVendor -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import inference_mode @@ -15,17 +15,18 @@ class TextToSpeechModel(BaseAudioModel): - _processor: AutoProcessor + _processor: Any # AutoProcessor with DiaModel-specific methods def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id, "model_id is required" self._processor = AutoProcessor.from_pretrained( self._model_id, trust_remote_code=self._settings.trust_remote_code, subfolder=self._settings.tokenizer_subfolder or "", ) - model = DiaForConditionalGeneration.from_pretrained( + model: PreTrainedModel = DiaForConditionalGeneration.from_pretrained( self._model_id, trust_remote_code=self._settings.trust_remote_code, device_map=self._device, @@ -38,7 +39,7 @@ def _load_model( return model @override - async def __call__( + async def __call__( # type: ignore[override] self, prompt: str, path: str, @@ -50,6 +51,7 @@ async def __call__( sampling_rate: int = 44_100, tensor_format: Literal["pt"] = "pt", ) -> str: + assert self._model is not None, "Model must be loaded" assert (not reference_path and not reference_text) or ( reference_path and reference_text ) @@ -81,7 +83,8 @@ async def __call__( ) with inference_mode(): - outputs = self._model.generate( + assert isinstance(self._model, PreTrainedModel) + outputs = self._model.generate( # type: ignore[operator] **inputs, max_new_tokens=max_new_tokens ) diff --git a/src/avalan/model/audio/speech_recognition.py b/src/avalan/model/audio/speech_recognition.py index 3a536558..ff5dd30a 100644 --- a/src/avalan/model/audio/speech_recognition.py +++ b/src/avalan/model/audio/speech_recognition.py @@ -3,7 +3,7 @@ from ...model.engine import Engine from ...model.vendor import TextGenerationVendor -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import argmax, inference_mode @@ -15,11 +15,12 @@ class SpeechRecognitionModel(BaseAudioModel): - _processor: AutoProcessor + _processor: Any # AutoProcessor with Wav2Vec2-specific methods def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id, "model_id is required" self._processor = AutoProcessor.from_pretrained( self._model_id, trust_remote_code=self._settings.trust_remote_code, @@ -27,10 +28,14 @@ def _load_model( use_fast=True, subfolder=self._settings.tokenizer_subfolder or "", ) - 
model = AutoModelForCTC.from_pretrained( + # AutoProcessor for CTC models has a tokenizer attribute + pad_token_id = getattr( + getattr(self._processor, "tokenizer", None), "pad_token_id", None + ) + model: PreTrainedModel = AutoModelForCTC.from_pretrained( self._model_id, trust_remote_code=self._settings.trust_remote_code, - pad_token_id=self._processor.tokenizer.pad_token_id, + pad_token_id=pad_token_id, ctc_loss_reduction="mean", device_map=self._device, tp_plan=Engine._get_tp_plan(self._settings.parallel), @@ -43,12 +48,13 @@ def _load_model( return model @override - async def __call__( + async def __call__( # type: ignore[override] self, path: str, sampling_rate: int = 16_000, tensor_format: Literal["pt"] = "pt", ) -> str: + assert self._model is not None, "Model must be loaded" audio = self._resample(path, sampling_rate) inputs = self._processor( audio, @@ -57,7 +63,8 @@ async def __call__( ).to(self._device) with inference_mode(): # shape (batch, time_steps, vocab_size) + assert isinstance(self._model, PreTrainedModel) logits = self._model(inputs.input_values).logits predicted_ids = argmax(logits, dim=-1) - transcription = self._processor.batch_decode(predicted_ids)[0] + transcription: str = self._processor.batch_decode(predicted_ids)[0] return transcription diff --git a/src/avalan/model/criteria.py b/src/avalan/model/criteria.py index a6094bfc..4a6b9d0e 100644 --- a/src/avalan/model/criteria.py +++ b/src/avalan/model/criteria.py @@ -1,23 +1,27 @@ from io import StringIO from re import Pattern, compile, escape +from typing import Any -from transformers import AutoTokenizer +from torch import Tensor +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.generation import StoppingCriteria class KeywordStoppingCriteria(StoppingCriteria): + """Stop generation when specific keywords are detected in output.""" + _buffer: StringIO - _tokenizer: AutoTokenizer - _pattern: Pattern + _tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast + _pattern: Pattern[str] _keywords: list[str] _keyword_count: int def __init__( self, keywords: list[str], - tokenizer: AutoTokenizer, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, all_must_be_present: bool = False, - ): + ) -> None: assert keywords escaped_keywords = [escape(k) for k in keywords] self._pattern = compile( @@ -30,7 +34,9 @@ def __init__( self._keywords = keywords self._keyword_count = len(keywords) - def __call__(self, input_ids, scores, **kwargs): + def __call__( + self, input_ids: Tensor, scores: Tensor, **kwargs: Any + ) -> bool: token_id = input_ids[0][-1] token = self._tokenizer.decode(token_id, skip_special_tokens=False) self._buffer.write(token) diff --git a/src/avalan/model/engine.py b/src/avalan/model/engine.py index a8a8416f..7728c922 100644 --- a/src/avalan/model/engine.py +++ b/src/avalan/model/engine.py @@ -53,12 +53,14 @@ class Engine(ABC): + """Base class for model engines.""" + _device: str _logger: Logger _model_id: str | None _settings: EngineSettings - _transformers_logging_logger: Logger - _transformers_logging_level: int + _transformers_logging_logger: Logger | None + _transformers_logging_level: int | None _loaded_model: bool = False _loaded_tokenizer: bool = False _tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None @@ -117,7 +119,7 @@ def _get_distributed_config( ) -> dict[str, object] | None: if distributed_config is None: return None - config = {"enable_expert_parallel": False} + config: dict[str, object] = {"enable_expert_parallel": False} 
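# A minimal standalone sketch (hypothetical helper name) of the widening
# pattern applied in the annotation above: from the literal alone, mypy
# infers dict[str, bool], so a later update() carrying non-bool values
# would be rejected without the explicit dict[str, object] annotation.
def _merge_distributed_defaults(
    overrides: dict[str, object],
) -> dict[str, object]:
    config: dict[str, object] = {"enable_expert_parallel": False}
    config.update(overrides)  # accepts str/int/bool values once widened
    return config

assert _merge_distributed_defaults({"tp_size": 8})["tp_size"] == 8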
config.update(distributed_config) return config @@ -139,16 +141,16 @@ def __init__( if self._settings.device else Engine.get_default_device() ) - self._transformers_logging_logger = ( - transformers_logging.get_logger() - if self._settings.change_transformers_logging_level - else None - ) - self._transformers_logging_level = ( - self._transformers_logging_logger.level - if self._settings.change_transformers_logging_level - else None - ) + if self._settings.change_transformers_logging_level: + self._transformers_logging_logger = ( + transformers_logging.get_logger() + ) + self._transformers_logging_level = ( + self._transformers_logging_logger.level + ) + else: + self._transformers_logging_logger = None + self._transformers_logging_level = None auto_load_tokenizer = ( self.uses_tokenizer and self._settings.auto_load_tokenizer @@ -239,11 +241,9 @@ def is_runnable(self, device: str | None = None) -> bool | None: def _load_tokenizer_with_tokens( self, tokenizer_name_or_path: str | None, use_fast: bool = True ) -> PreTrainedTokenizer | PreTrainedTokenizerFast: - raise ( - TokenizerNotSupportedException() - if not self.uses_tokenizer() - else NotImplementedError() - ) + if not self.uses_tokenizer: + raise TokenizerNotSupportedException() + raise NotImplementedError() def __enter__(self): _l = self._log @@ -365,12 +365,22 @@ def _load( self._settings.cache_dir, ) - if self._settings.enable_eval: + if self._settings.enable_eval and isinstance( + self._model, PreTrainedModel + ): _l("Setting model %s in eval mode", self._model_id) self._model.eval() - if self._tokenizer and ( - self._settings.tokens or self._settings.special_tokens + settings = self._settings + has_tokens = ( + hasattr(settings, "tokens") and settings.tokens # type: ignore[union-attr] + ) or ( + hasattr(settings, "special_tokens") and settings.special_tokens # type: ignore[union-attr] + ) + if ( + self._tokenizer + and has_tokens + and isinstance(self._model, PreTrainedModel) ): total_tokens = len(self._tokenizer) _l( @@ -407,20 +417,23 @@ def _load( if self._model and not self._config: config: ModelConfig | SentenceTransformerModelConfig | None = None - mc = ( - self._model.config - if hasattr(self._model, "config") - else ( - self._model[0].auto_model.config - if is_sentence_transformer - else None - ) - ) + mc = None + if hasattr(self._model, "config"): + mc = self._model.config # type: ignore[union-attr] + elif is_sentence_transformer and hasattr( + self._model, "__getitem__" + ): + first_module = self._model[0] # type: ignore[index] + if hasattr(first_module, "auto_model"): + mc = first_module.auto_model.config if mc: + attr_map = getattr(mc, "attribute_map", None) + keys_ignore = getattr(mc, "keys_to_ignore_at_inference", None) + torch_dt = getattr(mc, "torch_dtype", float32) config = ModelConfig( architectures=getattr(mc, "architectures", None), - attribute_map=getattr(mc, "attribute_map", None), + attribute_map=attr_map if attr_map else {}, bos_token_id=getattr(mc, "bos_token_id", None), bos_token=( self._tokenizer.decode(mc.bos_token_id) @@ -448,9 +461,7 @@ def _load( else None ), keys_to_ignore_at_inference=( - mc.keys_to_ignore_at_inference - if hasattr(mc, "keys_to_ignore_at_inference") - else None + keys_ignore if keys_ignore else [] ), loss_type=( mc.loss_type if hasattr(mc, "loss_type") else None @@ -472,9 +483,9 @@ def _load( else None ), num_labels=getattr(mc, "num_labels", None), - output_attentions=getattr(mc, "output_attentions", None), + output_attentions=getattr(mc, "output_attentions", False), 
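# A short sketch of the getattr-default pattern used throughout this
# hunk, with a hypothetical config class: when the target field is
# non-Optional, a None fallback would place None into a bool/dict field,
# so each fallback supplies a well-typed default (False, {}, []) up front.
from dataclasses import dataclass, field

@dataclass
class _SketchConfig:
    output_attentions: bool
    attribute_map: dict[str, str] = field(default_factory=dict)

def _from_source(mc: object) -> _SketchConfig:
    return _SketchConfig(
        output_attentions=getattr(mc, "output_attentions", False),
        attribute_map=getattr(mc, "attribute_map", None) or {},
    )

assert _from_source(object()).attribute_map == {}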
output_hidden_states=getattr( - mc, "output_hidden_states", None + mc, "output_hidden_states", False ), pad_token_id=getattr(mc, "pad_token_id", None), pad_token=( @@ -490,7 +501,7 @@ def _load( else None ), state_size=( - len(self._model.state_dict().keys()) + len(self._model.state_dict().keys()) # type: ignore[union-attr] if hasattr(self._model, "state_dict") and self._model.state_dict else 0 @@ -498,31 +509,41 @@ def _load( task_specific_params=getattr( mc, "task_specific_params", None ), - torch_dtype=( - str(mc.torch_dtype) - if hasattr(mc, "torch_dtype") - else None - ), + torch_dtype=torch_dt, vocab_size=( mc.vocab_size if hasattr(mc, "vocab_size") else None ), tokenizer_class=getattr(mc, "tokenizer_class", None), ) - if is_sentence_transformer and config: + if ( + is_sentence_transformer + and config + and isinstance(config, ModelConfig) + ): config = SentenceTransformerModelConfig( - backend=self._model.backend, - similarity_function=self._model.similarity_fn_name, - truncate_dimension=self._model.truncate_dim, + backend=getattr(self._model, "backend", "torch"), + similarity_function=getattr( + self._model, "similarity_fn_name", None + ), + truncate_dimension=getattr( + self._model, "truncate_dim", None + ), transformer_model_config=config, ) self._config = config if self._tokenizer and not self._tokenizer_config: + settings = self._settings + tokens_list = ( + settings.tokens # type: ignore[union-attr] + if hasattr(settings, "tokens") + else None + ) self._tokenizer_config = TokenizerConfig( name_or_path=self._tokenizer.name_or_path, - tokens=self._settings.tokens, + tokens=tokens_list, special_tokens=self._tokenizer.all_special_tokens, tokenizer_model_max_length=getattr( self._tokenizer, "model_max_length", 0 @@ -554,11 +575,11 @@ def _get_device_memory(device: str) -> int: ) return cuda.get_device_properties(index).total_memory - from psutil import virtual_memory + from psutil import virtual_memory # type: ignore[import-untyped] if device == "mps" and mps.is_available(): - return virtual_memory().total - return virtual_memory().total + return int(virtual_memory().total) + return int(virtual_memory().total) def _log(self, message: str, *args: object) -> None: self._logger.debug( diff --git a/src/avalan/model/hubs/huggingface.py b/src/avalan/model/hubs/huggingface.py index 16885b9f..b4c32ea1 100644 --- a/src/avalan/model/hubs/huggingface.py +++ b/src/avalan/model/hubs/huggingface.py @@ -21,6 +21,8 @@ class HuggingfaceHub: + """Interface for interacting with the Hugging Face Hub.""" + DEFAULT_ENDPOINT: str = "https://huggingface.co" DEFAULT_CACHE_DIR: str = expanduser( getenv("HF_HUB_CACHE") or "~/.cache/huggingface/hub" @@ -52,7 +54,7 @@ def __init__( endpoint=endpoint, token=access_token, library_name=name(), - library_version=version(), + library_version=str(version()), ) self._cache_dir = expanduser(cache_dir) self._domain = urlparse(endpoint).netloc @@ -60,7 +62,7 @@ def __init__( def cache_delete( self, model_id: str, revisions: list[str] | None = None - ) -> (HubCacheDeletion | None, Callable[[], None] | None): + ) -> tuple[HubCacheDeletion | None, Callable[[], None] | None]: scan_results = scan_cache_dir(self._cache_dir) delete_revisions = [ revision.commit_hash @@ -126,7 +128,9 @@ def cache_scan( ) for info in scan_results.repos ], - key=lambda m: m.size_on_disk if sort_models_by_size else m.name, + key=lambda m: ( + m.size_on_disk if sort_models_by_size else m.model_id + ), reverse=sort_models_by_size, ) return model_caches @@ -143,7 +147,7 @@ def download( model_id: 
str, *, workers: int = 8, - tqdm_class: type[tqdm] | Callable[..., tqdm] | None = None, + tqdm_class: type[tqdm] | None = None, local_dir: str | None = None, local_dir_use_symlinks: bool | None = None, ) -> str: @@ -183,8 +187,8 @@ def model_url(self, model_id: str) -> str: def models( self, filter: str | list[str] | None = None, - name: str | list[str] | None = None, - search: str | list[str] | None = None, + name: str | None = None, + search: str | None = None, *, library: str | list[str] | None = None, author: str | None = None, @@ -226,6 +230,7 @@ def user(self) -> User: @staticmethod def _model(model_info: ModelInfo) -> Model: + now = datetime.now() model = Model( id=model_info.id, parameters=( @@ -255,7 +260,7 @@ def _model(model_info: ModelInfo) -> Model: else None ), pipeline_tag=model_info.pipeline_tag, - tags=model_info.tags, + tags=model_info.tags or [], architectures=( model_info.config["architectures"] if model_info.config and "architectures" in model_info.config @@ -279,14 +284,18 @@ def _model(model_info: ModelInfo) -> Model: else None ), gated=model_info.gated, - private=model_info.private, + private=model_info.private or False, disabled=model_info.disabled, - last_downloads=model_info.downloads, - downloads=model_info.downloads_all_time or model_info.downloads, - likes=model_info.likes, + last_downloads=model_info.downloads or 0, + downloads=model_info.downloads_all_time + or model_info.downloads + or 0, + likes=model_info.likes or 0, ranking=model_info.trending_score, - author=model_info.author, - created_at=model_info.created_at, - updated_at=model_info.last_modified or model_info.created_at, + author=model_info.author or "", + created_at=model_info.created_at or now, + updated_at=model_info.last_modified + or model_info.created_at + or now, ) return model diff --git a/src/avalan/model/manager.py b/src/avalan/model/manager.py index 4977d4ac..e29755fa 100644 --- a/src/avalan/model/manager.py +++ b/src/avalan/model/manager.py @@ -47,7 +47,7 @@ from logging import Logger from os import environ from time import perf_counter -from typing import TYPE_CHECKING, Any, TypeAlias, get_args +from typing import TYPE_CHECKING, Any, TypeAlias, cast, get_args from urllib.parse import parse_qsl, urlparse if TYPE_CHECKING: @@ -124,12 +124,13 @@ async def __aexit__( exc_value: BaseException | None, traceback: Any | None, ) -> bool: - return await self._stack.__aexit__(exc_type, exc_value, traceback) + result = await self._stack.__aexit__(exc_type, exc_value, traceback) + return result if result is not None else False - async def __call__( + async def __call__( # type: ignore[override] self, model_task: ModelCall, - ): + ) -> Any: modality = model_task.operation.modality self._logger.info("ModelManager call process started for %s", modality) @@ -246,7 +247,9 @@ def load( weight_type: WeightType = "auto", ) -> ModelType: if "backend" in engine_uri.params: - backend = Backend(engine_uri.params["backend"]) + backend_value = engine_uri.params["backend"] + assert isinstance(backend_value, str), "backend must be a string" + backend = Backend(backend_value) engine_settings_args = dict( base_url=base_url, cache_dir=self._hub.cache_dir, @@ -294,6 +297,9 @@ def load_engine( if modality is Modality.EMBEDDING: from ..model.nlp.sentence import SentenceTransformerModel + assert ( + engine_uri.model_id is not None + ), "model_id is required for embedding modality" model = SentenceTransformerModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -322,9 +328,14 @@ def parse_uri(uri: str) -> 
EngineUri: f"Invalid scheme {parsed.scheme!r}, expected 'ai'" ) - vendor = parsed.hostname - if not vendor or vendor not in get_args(Vendor) or vendor == "local": - vendor = None + vendor_raw = parsed.hostname + vendor: Vendor | None = None + if ( + vendor_raw + and vendor_raw in get_args(Vendor) + and vendor_raw != "local" + ): + vendor = cast(Vendor, vendor_raw) use_host = bool(vendor) path_prefixed = parsed.path.startswith("/") params: dict[str, str | int | float | bool] = {} diff --git a/src/avalan/model/modalities/audio.py b/src/avalan/model/modalities/audio.py index f626bd8e..84d68a56 100644 --- a/src/avalan/model/modalities/audio.py +++ b/src/avalan/model/modalities/audio.py @@ -33,6 +33,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for audio classification" return AudioClassificationModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -90,6 +93,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for speech recognition" return SpeechRecognitionModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -147,6 +153,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for text to speech" return TextToSpeechModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -182,19 +189,24 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: + audio_params = operation.parameters["audio"] assert ( - operation.parameters["audio"] - and operation.parameters["audio"].path - and operation.parameters["audio"].sampling_rate + audio_params + and audio_params.path + and audio_params.sampling_rate + and operation.generation_settings ) + assert isinstance(operation.input, str), "prompt must be a string" + max_tokens = operation.generation_settings.max_new_tokens + assert max_tokens is not None, "max_new_tokens is required" return await model( - path=operation.parameters["audio"].path, prompt=operation.input, - max_new_tokens=operation.generation_settings.max_new_tokens, - reference_path=operation.parameters["audio"].reference_path, - reference_text=operation.parameters["audio"].reference_text, - sampling_rate=operation.parameters["audio"].sampling_rate, + path=audio_params.path, + max_new_tokens=max_tokens, + reference_path=audio_params.reference_path, + reference_text=audio_params.reference_text, + sampling_rate=audio_params.sampling_rate, ) @@ -210,6 +222,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for audio generation" return AudioGenerationModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -243,14 +256,19 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: + audio_params = operation.parameters["audio"] assert ( operation.input - and operation.parameters["audio"] - and operation.parameters["audio"].path + and audio_params + and audio_params.path + and operation.generation_settings ) + assert isinstance(operation.input, str), "prompt must be a string" + max_tokens = operation.generation_settings.max_new_tokens + assert max_tokens is not None, "max_new_tokens is required" return await model( operation.input, - operation.parameters["audio"].path, - operation.generation_settings.max_new_tokens, + 
audio_params.path, + max_tokens, ) diff --git a/src/avalan/model/modalities/registry.py b/src/avalan/model/modalities/registry.py index a02d0cf1..fa4944c1 100644 --- a/src/avalan/model/modalities/registry.py +++ b/src/avalan/model/modalities/registry.py @@ -6,6 +6,7 @@ Input, Modality, Operation, + OperationParameters, ReasoningSettings, ReasoningTag, TransformerEngineSettings, @@ -15,12 +16,13 @@ from argparse import Namespace from collections.abc import Callable from contextlib import AsyncExitStack -from inspect import isclass from logging import Logger from typing import Any, Protocol class ModalityHandler(Protocol): + """Protocol for modality handler implementations.""" + async def __call__( self, engine_uri: EngineUri, @@ -46,16 +48,18 @@ def get_operation_from_arguments( class ModalityRegistry: + """Registry for modality handlers.""" + _handlers: dict[Modality, ModalityHandler] = {} @classmethod def register( cls, modality: Modality - ) -> Callable[[ModalityHandler | type], ModalityHandler]: - def decorator(handler: ModalityHandler | type) -> ModalityHandler: - cls._handlers[modality] = ( - handler() if isclass(handler) else handler - ) + ) -> Callable[[type[ModalityHandler]], type[ModalityHandler]]: + def decorator( + handler: type[ModalityHandler], + ) -> type[ModalityHandler]: + cls._handlers[modality] = handler() return handler return decorator @@ -130,11 +134,12 @@ def get_operation_from_arguments( try: handler = cls.get(modality) except NotImplementedError: + empty_params: OperationParameters = {} return Operation( generation_settings=settings, input=input_string, modality=modality, - parameters=None, + parameters=empty_params, requires_input=False, ) return handler.get_operation_from_arguments( diff --git a/src/avalan/model/modalities/text.py b/src/avalan/model/modalities/text.py index 4f42dc96..434add02 100644 --- a/src/avalan/model/modalities/text.py +++ b/src/avalan/model/modalities/text.py @@ -55,6 +55,8 @@ def _get_mlx_model() -> type[TextGenerationModel] | None: @ModalityRegistry.register(Modality.TEXT_GENERATION) class TextGenerationModality: + """Handler for text generation modality.""" + def load_engine( self, engine_uri: EngineUri, @@ -62,7 +64,8 @@ def load_engine( logger: Logger, exit_stack: AsyncExitStack, ) -> TextGenerationModel: - model_load_args = dict( + assert engine_uri.model_id, "model_id is required for text generation" + model_load_args: dict[str, Any] = dict( model_id=engine_uri.model_id, settings=engine_settings, logger=logger, @@ -80,75 +83,71 @@ def load_engine( return mlx_loader(**model_load_args) case Backend.VLLM: - from ..nlp.text.vllm import VllmModel as Loader + from ..nlp.text.vllm import VllmModel - return Loader(**model_load_args) + return VllmModel(**model_load_args) case _: return TextGenerationModel(**model_load_args) match engine_uri.vendor: case "anthropic": - from ..nlp.text.vendor.anthropic import ( - AnthropicModel as Loader, - ) + from ..nlp.text.vendor.anthropic import AnthropicModel - return Loader(**model_load_args, exit_stack=exit_stack) + return AnthropicModel(**model_load_args, exit_stack=exit_stack) case "openai": - from ..nlp.text.vendor.openai import OpenAIModel as Loader + from ..nlp.text.vendor.openai import OpenAIModel - return Loader(**model_load_args, exit_stack=exit_stack) + return OpenAIModel(**model_load_args, exit_stack=exit_stack) case "bedrock": - from ..nlp.text.vendor.bedrock import BedrockModel as Loader + from ..nlp.text.vendor.bedrock import BedrockModel - return Loader(**model_load_args, 
exit_stack=exit_stack) + return BedrockModel(**model_load_args, exit_stack=exit_stack) case "openrouter": - from ..nlp.text.vendor.openrouter import ( - OpenRouterModel as Loader, - ) + from ..nlp.text.vendor.openrouter import OpenRouterModel - return Loader(**model_load_args, exit_stack=exit_stack) + return OpenRouterModel( + **model_load_args, exit_stack=exit_stack + ) case "anyscale": - from ..nlp.text.vendor.anyscale import AnyScaleModel as Loader + from ..nlp.text.vendor.anyscale import AnyScaleModel - return Loader(**model_load_args, exit_stack=exit_stack) + return AnyScaleModel(**model_load_args, exit_stack=exit_stack) case "together": - from ..nlp.text.vendor.together import TogetherModel as Loader + from ..nlp.text.vendor.together import TogetherModel - return Loader(**model_load_args, exit_stack=exit_stack) + return TogetherModel(**model_load_args, exit_stack=exit_stack) case "deepseek": - from ..nlp.text.vendor.deepseek import DeepSeekModel as Loader + from ..nlp.text.vendor.deepseek import DeepSeekModel - return Loader(**model_load_args, exit_stack=exit_stack) + return DeepSeekModel(**model_load_args, exit_stack=exit_stack) case "deepinfra": - from ..nlp.text.vendor.deepinfra import ( - DeepInfraModel as Loader, - ) + from ..nlp.text.vendor.deepinfra import DeepInfraModel - return Loader(**model_load_args, exit_stack=exit_stack) + return DeepInfraModel(**model_load_args, exit_stack=exit_stack) case "groq": - from ..nlp.text.vendor.groq import GroqModel as Loader + from ..nlp.text.vendor.groq import GroqModel - return Loader(**model_load_args, exit_stack=exit_stack) + return GroqModel(**model_load_args, exit_stack=exit_stack) case "ollama": - from ..nlp.text.vendor.ollama import OllamaModel as Loader + from ..nlp.text.vendor.ollama import OllamaModel - return Loader(**model_load_args, exit_stack=exit_stack) + return OllamaModel(**model_load_args, exit_stack=exit_stack) case "huggingface": - from ..nlp.text.vendor.huggingface import ( - HuggingfaceModel as Loader, - ) + from ..nlp.text.vendor.huggingface import HuggingfaceModel - return Loader(**model_load_args, exit_stack=exit_stack) - case "hyperbolic": - from ..nlp.text.vendor.hyperbolic import ( - HyperbolicModel as Loader, + return HuggingfaceModel( + **model_load_args, exit_stack=exit_stack ) + case "hyperbolic": + from ..nlp.text.vendor.hyperbolic import HyperbolicModel - return Loader(**model_load_args, exit_stack=exit_stack) + return HyperbolicModel( + **model_load_args, exit_stack=exit_stack + ) case "litellm": - from ..nlp.text.vendor.litellm import LiteLLMModel as Loader + from ..nlp.text.vendor.litellm import LiteLLMModel - return Loader(**model_load_args, exit_stack=exit_stack) + return LiteLLMModel(**model_load_args, exit_stack=exit_stack) raise NotImplementedError() def get_operation_from_arguments( @@ -159,7 +158,7 @@ def get_operation_from_arguments( ) -> Operation: parameters = OperationParameters( text=OperationTextParameters( - manual_sampling=args.display_tokens or 0, + manual_sampling=bool(args.display_tokens), pick_tokens=( 10 if args.display_tokens and args.display_tokens > 0 @@ -186,7 +185,9 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: - assert operation.input and operation.parameters["text"] + assert operation.input, "operation.input is required" + text_params = operation.parameters["text"] + assert text_params, "text parameters are required" criteria = _stopping_criteria(operation, model) mlx_model = _get_mlx_model() @@ -194,21 +195,19 @@ async def __call__( 
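# A compact sketch (hypothetical types) of the bind-then-assert narrowing
# these hunks rely on: mypy narrows plain locals, not repeated subscript
# expressions like operation.parameters["text"], so the value is bound
# once and asserted before use.
from dataclasses import dataclass

@dataclass
class _TextParams:
    system_prompt: str | None = None

def _system_prompt(parameters: dict[str, _TextParams | None]) -> str:
    text_params = parameters["text"]  # bind the subscript once
    assert text_params, "text parameters are required"
    return text_params.system_prompt or ""  # narrowed to _TextParams

assert _system_prompt({"text": _TextParams("hi")}) == "hi"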
if engine_uri.is_local and not is_mlx: return await model( operation.input, - system_prompt=operation.parameters["text"].system_prompt, - developer_prompt=operation.parameters["text"].developer_prompt, + system_prompt=text_params.system_prompt, + developer_prompt=text_params.developer_prompt, settings=operation.generation_settings, stopping_criterias=[criteria] if criteria else None, - manual_sampling=operation.parameters["text"].manual_sampling, - pick=operation.parameters["text"].pick_tokens, - skip_special_tokens=operation.parameters[ - "text" - ].skip_special_tokens, + manual_sampling=text_params.manual_sampling or False, + pick=text_params.pick_tokens, + skip_special_tokens=text_params.skip_special_tokens or False, tool=tool, ) return await model( operation.input, - system_prompt=operation.parameters["text"].system_prompt, - developer_prompt=operation.parameters["text"].developer_prompt, + system_prompt=text_params.system_prompt, + developer_prompt=text_params.developer_prompt, settings=operation.generation_settings, tool=tool, ) @@ -226,6 +225,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for question answering" return QuestionAnsweringModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -286,6 +288,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for sequence classification" return SequenceClassificationModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -302,7 +307,7 @@ def get_operation_from_arguments( generation_settings=settings, input=input_string, modality=Modality.TEXT_SEQUENCE_CLASSIFICATION, - parameters=None, + parameters=OperationParameters(), requires_input=True, ) @@ -329,6 +334,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for sequence to sequence" return SequenceToSequenceModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -362,6 +370,7 @@ async def __call__( tool: ToolManager | None = None, ) -> Any: assert operation.input and operation.parameters["text"] + assert operation.generation_settings, "generation_settings is required" criteria = _stopping_criteria(operation, model) return await model( operation.input, @@ -382,6 +391,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for token classification" return TokenClassificationModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -437,6 +449,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for translation" return TranslationModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -472,22 +485,25 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: + text_params = operation.parameters["text"] assert ( operation.input - and operation.parameters["text"] - and operation.parameters["text"].language_source - and operation.parameters["text"].language_destination + and text_params + and text_params.language_source + and text_params.language_destination + and operation.generation_settings ) criteria = _stopping_criteria(operation, model) + skip_special_tokens = ( + text_params.skip_special_tokens + if 
text_params.skip_special_tokens is not None + else True + ) return await model( operation.input, - source_language=operation.parameters["text"].language_source, - destination_language=operation.parameters[ - "text" - ].language_destination, + source_language=text_params.language_source, + destination_language=text_params.language_destination, settings=operation.generation_settings, stopping_criterias=[criteria] if criteria else None, - skip_special_tokens=operation.parameters[ - "text" - ].skip_special_tokens, + skip_special_tokens=skip_special_tokens, ) diff --git a/src/avalan/model/modalities/vision.py b/src/avalan/model/modalities/vision.py index 2f0d30e5..366cb3f6 100644 --- a/src/avalan/model/modalities/vision.py +++ b/src/avalan/model/modalities/vision.py @@ -39,6 +39,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for encoder decoder" return VisionEncoderDecoderModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -72,17 +73,19 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: - assert ( - operation.parameters["vision"] - and operation.parameters["vision"].path - ) + vision_params = operation.parameters["vision"] + assert vision_params and vision_params.path + prompt = operation.input if isinstance(operation.input, str) else None + skip_special_tokens = ( + vision_params.skip_special_tokens + if vision_params.skip_special_tokens is not None + else True + ) return await model( - operation.parameters["vision"].path, - prompt=operation.input, - skip_special_tokens=operation.parameters[ - "vision" - ].skip_special_tokens, + vision_params.path, + prompt=prompt, + skip_special_tokens=skip_special_tokens, ) @@ -98,6 +101,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for image classification" return ImageClassificationModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -150,6 +156,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for image to text" return ImageToTextModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -183,16 +190,17 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: - assert ( - operation.parameters["vision"] - and operation.parameters["vision"].path - ) + vision_params = operation.parameters["vision"] + assert vision_params and vision_params.path + skip_special_tokens = ( + vision_params.skip_special_tokens + if vision_params.skip_special_tokens is not None + else True + ) return await model( - operation.parameters["vision"].path, - skip_special_tokens=operation.parameters[ - "vision" - ].skip_special_tokens, + vision_params.path, + skip_special_tokens=skip_special_tokens, ) @@ -208,6 +216,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for image text to text" return ImageTextToTextModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -247,18 +258,17 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: - assert ( - operation.parameters["vision"] - and operation.parameters["vision"].path - ) + vision_params = operation.parameters["vision"] + assert vision_params and vision_params.path + assert 
isinstance(operation.input, str), "prompt must be a string" return await model( - operation.parameters["vision"].path, + vision_params.path, operation.input, - system_prompt=operation.parameters["vision"].system_prompt, - developer_prompt=operation.parameters["vision"].developer_prompt, + system_prompt=vision_params.system_prompt, + developer_prompt=vision_params.developer_prompt, settings=operation.generation_settings, - width=operation.parameters["vision"].width, + width=vision_params.width, ) @@ -274,6 +284,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for object detection" return ObjectDetectionModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -335,6 +346,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for text to image" return TextToImageModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -403,6 +415,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for text to animation" return TextToAnimationModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -471,6 +486,7 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert engine_uri.model_id, "model_id is required for text to video" return TextToVideoModel( model_id=engine_uri.model_id, settings=engine_settings, @@ -517,13 +533,10 @@ async def __call__( operation: Operation, tool: ToolManager | None = None, ) -> Any: - assert ( - operation.input - and operation.parameters["vision"] - and operation.parameters["vision"].path - ) vision = operation.parameters["vision"] - kwargs = { + assert operation.input and vision and vision.path + assert isinstance(operation.input, str), "prompt must be a string" + kwargs: dict[str, Any] = { "reference_path": vision.reference_path, "negative_prompt": vision.negative_prompt, "height": vision.height, @@ -555,6 +568,9 @@ def load_engine( _ = exit_stack if not engine_uri.is_local: raise NotImplementedError() + assert ( + engine_uri.model_id + ), "model_id is required for semantic segmentation" return SemanticSegmentationModel( model_id=engine_uri.model_id, settings=engine_settings, diff --git a/src/avalan/model/nlp/__init__.py b/src/avalan/model/nlp/__init__.py index 2a382069..74e5958a 100644 --- a/src/avalan/model/nlp/__init__.py +++ b/src/avalan/model/nlp/__init__.py @@ -4,6 +4,7 @@ from abc import ABC from contextlib import nullcontext +from typing import Any from torch import ( Tensor, @@ -21,7 +22,7 @@ def _generate_output( settings: GenerationSettings, stopping_criterias: list[StoppingCriteria] | None = None, streamer: AsyncTextIteratorStreamer | None = None, - ): + ) -> Any: eos_token_id = ( settings.eos_token_id if settings.eos_token_id @@ -31,7 +32,8 @@ def _generate_output( else None ) ) - generation_kwargs = { + assert self._tokenizer is not None, "Tokenizer must be loaded" + generation_kwargs: dict[str, Any] = { "bos_token_id": settings.bos_token_id, "diversity_penalty": settings.diversity_penalty, "do_sample": settings.do_sample, @@ -73,35 +75,42 @@ def _generate_output( if settings.use_inputs_attention_mask: attention_mask = ( inputs.get("attention_mask", None) - if isinstance(inputs, BatchEncoding) + if isinstance(inputs, (BatchEncoding, dict)) else getattr(inputs, "attention_mask", None) ) if attention_mask is not 
None: assert isinstance(attention_mask, Tensor) - assert attention_mask.shape == inputs["input_ids"].shape + if isinstance(inputs, (BatchEncoding, dict)): + input_ids = inputs.get("input_ids") + if input_ids is not None: + assert attention_mask.shape == input_ids.shape generation_kwargs["attention_mask"] = attention_mask - if ( + if isinstance(inputs, (BatchEncoding, dict)) and ( not settings.use_inputs_attention_mask or attention_mask is not None ): - inputs.pop("attention_mask", None) + inputs.pop("attention_mask", None) # type: ignore[union-attr] if settings.forced_bos_token_id or settings.forced_eos_token_id: del generation_kwargs["bos_token_id"] del generation_kwargs["eos_token_id"] + assert self._model is not None, "Model must be loaded" + assert hasattr( + self._model, "generate" + ), "Model must support generate()" with ( inference_mode() if not settings.enable_gradient_calculation else nullcontext() ): outputs = ( - self._model.generate( + self._model.generate( # type: ignore[operator] inputs, tokenizer=self._tokenizer, **generation_kwargs ) - if isinstance(inputs, Tensor) - else self._model.generate( + if isinstance(inputs, Tensor) # type: ignore[arg-type] + else self._model.generate( # type: ignore[operator] **inputs, tokenizer=self._tokenizer, **generation_kwargs ) ) diff --git a/src/avalan/model/nlp/question.py b/src/avalan/model/nlp/question.py index 3e4c4d00..0e7620f8 100644 --- a/src/avalan/model/nlp/question.py +++ b/src/avalan/model/nlp/question.py @@ -4,7 +4,7 @@ from ...model.nlp import BaseNLPModel from ...model.vendor import TextGenerationVendor -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import argmax, inference_mode @@ -24,7 +24,8 @@ def supports_token_streaming(self) -> bool: def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: - model = AutoModelForQuestionAnswering.from_pretrained( + assert self._model_id is not None, "Model ID must be set" + model: PreTrainedModel = AutoModelForQuestionAnswering.from_pretrained( self._model_id, cache_dir=self._settings.cache_dir, subfolder=self._settings.subfolder or "", @@ -42,7 +43,7 @@ def _load_model( ) return model - def _tokenize_input( + def _tokenize_input( # type: ignore[override] self, input: Input, system_prompt: str | None, @@ -56,14 +57,18 @@ def _tokenize_input( + f"{self._model_id} does not support chat " + "templates" ) + assert self._tokenizer is not None, "Tokenizer must be loaded" + assert self._model is not None, "Model must be loaded" _l = self._log _l(f"Tokenizing input {input}") - inputs = self._tokenizer(input, context, return_tensors=tensor_format) - inputs = inputs.to(self._model.device) + inputs: BatchEncoding = self._tokenizer( + input, context, return_tensors=tensor_format + ) + inputs = inputs.to(self._model.device) # type: ignore[union-attr] return inputs @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, *, @@ -71,6 +76,7 @@ async def __call__( system_prompt: str | None = None, developer_prompt: str | None = None, skip_special_tokens: bool = True, + **kwargs: Any, ) -> str: assert self._tokenizer, ( f"Model {self._model} can't be executed " @@ -87,13 +93,13 @@ async def __call__( context=context, ) with inference_mode(): - outputs = self._model(**inputs) - start_answer_logits = outputs.start_logits - end_answer_logits = outputs.end_logits + outputs = self._model(**inputs) # type: ignore[operator] + start_answer_logits = outputs.start_logits # type: 
ignore[union-attr] + end_answer_logits = outputs.end_logits # type: ignore[union-attr] start = argmax(start_answer_logits) end = argmax(end_answer_logits) answer_ids = inputs["input_ids"][0, start : end + 1] - answer = self._tokenizer.decode( + answer: str = self._tokenizer.decode( answer_ids, skip_special_tokens=skip_special_tokens ) return answer diff --git a/src/avalan/model/nlp/sentence.py b/src/avalan/model/nlp/sentence.py index b31daa59..e0c6d88c 100644 --- a/src/avalan/model/nlp/sentence.py +++ b/src/avalan/model/nlp/sentence.py @@ -2,15 +2,12 @@ from ...entities import Input from ...model.engine import Engine from ...model.nlp import BaseNLPModel -from ...model.vendor import TextGenerationVendor from contextlib import nullcontext -from typing import Literal +from typing import Any, Literal -from diffusers import DiffusionPipeline from numpy import ndarray from torch import inference_mode -from transformers import PreTrainedModel from transformers.tokenization_utils_base import BatchEncoding @@ -28,12 +25,13 @@ def uses_tokenizer(self) -> bool: return True def token_count(self, input: str) -> int: + assert self.tokenizer is not None, "Tokenizer must be loaded" token_ids = self.tokenizer.encode(input, add_special_tokens=False) return len(token_ids) def _load_model( self, - ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + ) -> Any: # Returns SentenceTransformer which isn't a PreTrainedModel from sentence_transformers import SentenceTransformer model = SentenceTransformer( @@ -61,7 +59,7 @@ def _load_model( ) return model - def _tokenize_input( + def _tokenize_input( # type: ignore[override] self, input: Input, system_prompt: str | None, @@ -73,9 +71,12 @@ def _tokenize_input( raise NotImplementedError() @override - async def __call__( - self, input: Input, *args, enable_gradient_calculation: bool = False - ) -> ndarray: + async def __call__( # type: ignore[override] + self, + input: Input, + *args: Any, + enable_gradient_calculation: bool = False, + ) -> ndarray[Any, Any]: assert self._model, ( f"Model {self._model} can't be executed, it " + "needs to be loaded first" @@ -86,7 +87,8 @@ async def __call__( if not enable_gradient_calculation else nullcontext() ): - embeddings = self._model.encode( + # self._model is SentenceTransformer at runtime + embeddings: ndarray[Any, Any] = self._model.encode( # type: ignore[union-attr, operator] input, convert_to_numpy=True, show_progress_bar=False ) return embeddings diff --git a/src/avalan/model/nlp/sequence.py b/src/avalan/model/nlp/sequence.py index e8e7b1b7..ba33e809 100644 --- a/src/avalan/model/nlp/sequence.py +++ b/src/avalan/model/nlp/sequence.py @@ -5,7 +5,7 @@ from ...model.vendor import TextGenerationVendor from dataclasses import replace -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import Tensor, argmax, inference_mode, softmax @@ -30,25 +30,28 @@ def supports_token_streaming(self) -> bool: def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: - model = AutoModelForSequenceClassification.from_pretrained( - self._model_id, - cache_dir=self._settings.cache_dir, - subfolder=self._settings.subfolder or "", - attn_implementation=self._settings.attention, - trust_remote_code=self._settings.trust_remote_code, - torch_dtype=Engine.weight(self._settings.weight_type), - state_dict=self._settings.state_dict, - local_files_only=self._settings.local_files_only, - token=self._settings.access_token, - device_map=self._device, - 
tp_plan=Engine._get_tp_plan(self._settings.parallel), - distributed_config=Engine._get_distributed_config( - self._settings.distributed_config - ), + assert self._model_id is not None, "Model ID must be set" + model: PreTrainedModel = ( + AutoModelForSequenceClassification.from_pretrained( + self._model_id, + cache_dir=self._settings.cache_dir, + subfolder=self._settings.subfolder or "", + attn_implementation=self._settings.attention, + trust_remote_code=self._settings.trust_remote_code, + torch_dtype=Engine.weight(self._settings.weight_type), + state_dict=self._settings.state_dict, + local_files_only=self._settings.local_files_only, + token=self._settings.access_token, + device_map=self._device, + tp_plan=Engine._get_tp_plan(self._settings.parallel), + distributed_config=Engine._get_distributed_config( + self._settings.distributed_config + ), + ) ) return model - def _tokenize_input( + def _tokenize_input( # type: ignore[override] self, input: Input, system_prompt: str | None, @@ -62,14 +65,20 @@ def _tokenize_input( + f"{self._model_id} does not support chat " + "templates" ) + assert self._tokenizer is not None, "Tokenizer must be loaded" + assert self._model is not None, "Model must be loaded" _l = self._log _l(f"Tokenizing input {input}") - inputs = self._tokenizer(input, return_tensors=tensor_format) - inputs = inputs.to(self._model.device) + inputs: BatchEncoding = self._tokenizer( + input, return_tensors=tensor_format + ) + inputs = inputs.to(self._model.device) # type: ignore[union-attr] return inputs @override - async def __call__(self, input: Input) -> str: + async def __call__( # type: ignore[override] + self, input: Input, **kwargs: Any + ) -> str: assert self._tokenizer, ( f"Model {self._model} can't be executed " + "without a tokenizer loaded first" @@ -80,11 +89,13 @@ async def __call__(self, input: Input) -> str: ) inputs = self._tokenize_input(input, system_prompt=None, context=None) with inference_mode(): - outputs = self._model(**inputs) + outputs = self._model(**inputs) # type: ignore[operator] # logits shape (batch_size, num_labels) - label_probs = softmax(outputs.logits, dim=-1) - label_id = argmax(label_probs, dim=-1).item() - label = self._model.config.id2label[label_id] + label_probs = softmax(outputs.logits, dim=-1) # type: ignore[union-attr] + label_id = int(argmax(label_probs, dim=-1).item()) + id2label = self._model.config.id2label # type: ignore[union-attr] + assert id2label is not None, "Model config must have id2label" + label: str = id2label[label_id] return label @@ -100,7 +111,8 @@ def supports_token_streaming(self) -> bool: def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: - model = AutoModelForSeq2SeqLM.from_pretrained( + assert self._model_id is not None, "Model ID must be set" + model: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained( self._model_id, cache_dir=self._settings.cache_dir, subfolder=self._settings.subfolder or "", @@ -118,7 +130,7 @@ def _load_model( ) return model - def _tokenize_input( + def _tokenize_input( # type: ignore[override] self, input: Input, system_prompt: str | None, @@ -132,18 +144,24 @@ def _tokenize_input( + f"{self._model_id} does not support chat " + "templates" ) + assert self._tokenizer is not None, "Tokenizer must be loaded" + assert self._model is not None, "Model must be loaded" _l = self._log _l(f"Tokenizing input {input}") - inputs = self._tokenizer(input, return_tensors=tensor_format) - inputs = inputs.to(self._model.device) - return inputs["input_ids"] + inputs: 
BatchEncoding = self._tokenizer( + input, return_tensors=tensor_format + ) + inputs = inputs.to(self._model.device) # type: ignore[union-attr] + input_ids: Tensor = inputs["input_ids"] + return input_ids @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, settings: GenerationSettings, stopping_criterias: list[StoppingCriteria] | None = None, + **kwargs: Any, ) -> str: assert self._tokenizer, ( f"Model {self._model} can't be executed " @@ -166,12 +184,12 @@ async def __call__( settings, stopping_criterias, ) - return self._tokenizer.decode(output_ids[0], skip_special_tokens=True) + return self._tokenizer.decode(output_ids[0], skip_special_tokens=True) # type: ignore[return-value] class TranslationModel(SequenceToSequenceModel): @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, source_language: str, @@ -179,6 +197,7 @@ async def __call__( settings: GenerationSettings, stopping_criterias: list[StoppingCriteria] | None = None, skip_special_tokens: bool = True, + **kwargs: Any, ) -> str: assert self._tokenizer, ( f"Model {self._model} can't be executed " @@ -198,8 +217,8 @@ async def __call__( self._tokenizer, "lang_code_to_id" ) - previous_language = self._tokenizer.src_lang - self._tokenizer.src_lang = source_language + previous_language: str = self._tokenizer.src_lang # type: ignore[attr-defined] + self._tokenizer.src_lang = source_language # type: ignore[attr-defined] inputs = self._tokenize_input(input, system_prompt=None, context=None) generation_settings = replace( settings, @@ -207,7 +226,7 @@ async def __call__( repetition_penalty=1.0, use_cache=True, temperature=None, - forced_bos_token_id=self._tokenizer.lang_code_to_id[ + forced_bos_token_id=self._tokenizer.lang_code_to_id[ # type: ignore[attr-defined] destination_language ], ) @@ -216,8 +235,8 @@ async def __call__( generation_settings, stopping_criterias, ) - text = self._tokenizer.decode( + text: str = self._tokenizer.decode( output_ids[0], skip_special_tokens=skip_special_tokens ) - self._tokenizer.src_lang = previous_language + self._tokenizer.src_lang = previous_language # type: ignore[attr-defined] return text diff --git a/src/avalan/model/nlp/text/generation.py b/src/avalan/model/nlp/text/generation.py index 181896fd..973c7a1f 100644 --- a/src/avalan/model/nlp/text/generation.py +++ b/src/avalan/model/nlp/text/generation.py @@ -25,7 +25,7 @@ from importlib.util import find_spec from logging import Logger, getLogger from threading import Thread -from typing import AsyncGenerator, Literal +from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal from diffusers import DiffusionPipeline from torch import Tensor, log_softmax, softmax, topk @@ -41,15 +41,18 @@ from transformers.generation import StoppingCriteria from transformers.tokenization_utils_base import BatchEncoding +if TYPE_CHECKING: + from types import ModuleType + _TOOL_MESSAGE_PARSER = ToolCallParser() class TextGenerationModel(BaseNLPModel): _loaders: dict[TextGenerationLoaderClass, type[PreTrainedModel]] = { - "auto": AutoModelForCausalLM, - "gemma3": Gemma3ForConditionalGeneration, - "gpt-oss": GptOssForCausalLM, - "mistral3": Mistral3ForConditionalGeneration, + "auto": AutoModelForCausalLM, # type: ignore[dict-item] + "gemma3": Gemma3ForConditionalGeneration, # type: ignore[dict-item] + "gpt-oss": GptOssForCausalLM, # type: ignore[dict-item] + "mistral3": Mistral3ForConditionalGeneration, # type: ignore[dict-item] } def __init__( @@ -115,11 +118,13 @@ def 
_load_model(
         if model_args["quantization_config"] is None:
             model_args.pop("quantization_config", None)
 
-        model = loader.from_pretrained(self._model_id, **model_args)
+        model: PreTrainedModel = loader.from_pretrained(
+            self._model_id, **model_args
+        )
         return model
 
     @override
-    async def __call__(
+    async def __call__(  # type: ignore[override]
         self,
         input: Input,
         system_prompt: str | None = None,
@@ -210,7 +215,7 @@ async def _stream_generator(
         _l = self._log
 
         streamer = AsyncTextIteratorStreamer(
-            self._tokenizer,
+            self._tokenizer,  # type: ignore[arg-type]
             skip_prompt=True,
             decode_kwargs={"skip_special_tokens": skip_special_tokens},
         )
@@ -250,10 +255,15 @@ def _string_output(
         settings: GenerationSettings,
         stopping_criterias: list[StoppingCriteria] | None,
         skip_special_tokens: bool,
-        **kwargs,
+        **kwargs: Any,
     ) -> str:
-        input_length = inputs["input_ids"].shape[1]
+        assert isinstance(
+            inputs, (BatchEncoding, dict)
+        ), "inputs must be dict-like for _string_output"
+        input_ids = inputs["input_ids"]
+        input_length: int = input_ids.shape[1]  # type: ignore[union-attr]
         outputs = self._generate_output(inputs, settings, stopping_criterias)
+        assert self._tokenizer is not None
         return self._tokenizer.decode(
             outputs[0][input_length:], skip_special_tokens=skip_special_tokens
         )
@@ -274,12 +284,13 @@ async def _token_generator(
         _l = self._log
 
+        entmax: ModuleType | None = None
         enable_entmax = find_spec("entmax") and probability_distribution in [
             "entmax",
             "sparsemax",
         ]
         if enable_entmax:
-            import entmax
+            import entmax  # type: ignore[import-not-found,no-redef]
 
         _l(
             f"Generating up to {settings.max_new_tokens} tokens "
@@ -297,11 +308,16 @@
         )
         sequences = outputs.sequences[0]
         scores = outputs.scores  # list of logits for each generated token
-        start = inputs["input_ids"].shape[1]  # where generation began
+        assert isinstance(
+            inputs, (BatchEncoding, dict)
+        ), "inputs must be dict-like for _token_generator"
+        input_ids = inputs["input_ids"]
+        start: int = input_ids.shape[1]  # type: ignore[union-attr]
         generated_sequences = sequences[start:]
 
         _l(f"Generated {len(generated_sequences)} sequences")
 
+        assert self._tokenizer is not None
         total_tokens = 0
         for step, token_id in enumerate(generated_sequences):
             _l(f"Got step {step} token {token_id}")
@@ -312,42 +328,43 @@
             logits = tensor[0]  # first element in batch dimension
 
             # apply probabilty distribution over last tensor layer, vocab_size
+            temp = settings.temperature if settings.temperature else 1.0
             logits_probs = (
                 log_softmax(logits, dim=-1)
                 if probability_distribution == "log_softmax"
                 else (
-                    gumbel_softmax(
-                        logits, tau=settings.temperature, hard=False, dim=-1
-                    )
+                    gumbel_softmax(logits, tau=temp, hard=False, dim=-1)
                     if probability_distribution == "gumbel_softmax"
                     else (
                         entmax.sparsemax(logits, dim=-1)
                         if enable_entmax
+                        and entmax is not None
                         and probability_distribution == "sparsemax"
                         else (
                             entmax.entmax15(logits, dim=-1)
                             if enable_entmax
+                            and entmax is not None
                             and probability_distribution == "entmax"
-                            else softmax(logits / settings.temperature, dim=-1)
+                            else softmax(logits / temp, dim=-1)
                         )
                     )
                 )
             )
 
             tokens: list[Token] | None = None
-            if pick > 0:
+            if pick is not None and pick > 0:
                 picked_logits = topk(logits_probs, pick)
                 picked_logits_ids = picked_logits.indices.tolist()
                 picked_logits_probs = picked_logits.values.tolist()
                 tokens = [
                     Token(
-                        id=token_id,
+                        id=tid,
                         token=self._tokenizer.decode(
-                            token_id, skip_special_tokens=skip_special_tokens
+                            tid, skip_special_tokens=skip_special_tokens
                         ),
                         probability=picked_logits_probs[i],
                     )
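# A self-contained sketch of the optional-dependency typing used above
# for entmax: annotating the module slot as ModuleType | None lets both
# the guarded import and the None fallback type-check, and guarding on
# "entmax is not None" (not only find_spec) narrows it for mypy.
from importlib.util import find_spec
from types import ModuleType

entmax: ModuleType | None = None
if find_spec("entmax"):
    import entmax  # type: ignore[import-not-found,no-redef]

def _sparsemax_ready() -> bool:
    # True only when the optional package is importable in this env.
    return entmax is not None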
- for i, token_id in enumerate(picked_logits_ids) + for i, tid in enumerate(picked_logits_ids) ] raw_token = TokenDetail( @@ -371,7 +388,7 @@ async def _token_generator( await sleep(0) # and just like that, a generator is an async generator - def _tokenize_input( + def _tokenize_input( # type: ignore[override] self, input: Input, system_prompt: str | None, @@ -383,11 +400,15 @@ def _tokenize_input( tool: ToolManager | None = None, ) -> dict[str, Tensor] | BatchEncoding | Tensor: _l = self._log + assert self._tokenizer is not None messages = self._messages(input, system_prompt, developer_prompt, tool) + tokenizer = self._tokenizer def _format_content( - content: str | MessageContent | list[MessageContent], + content: str | MessageContent | list[MessageContent] | None, ) -> str | list[dict[str, object]]: + if content is None: + return "" if isinstance(content, str): return content @@ -395,14 +416,14 @@ def _format_content( return content.text if isinstance(content, MessageContentImage): - if self._tokenizer.chat_template: + if tokenizer.chat_template: return [ {"type": "image_url", "image_url": content.image_url} ] return "" if isinstance(content, list): - if self._tokenizer.chat_template: + if tokenizer.chat_template: blocks: list[dict[str, object]] = [] for c in content: if isinstance(c, MessageContentImage): @@ -426,7 +447,7 @@ def _format_content( return str(content) - template_messages = [] + template_messages: list[dict[str, Any]] = [] for message in messages: message_dict = asdict(message) prepared = _TOOL_MESSAGE_PARSER.prepare_message_for_template( @@ -441,7 +462,7 @@ def _format_content( } ) - if not self._tokenizer.chat_template: + if not tokenizer.chat_template: _l("Model does not support template messages, so staying plain") prompt = f"{system_prompt}\n\n" or "" @@ -461,27 +482,30 @@ def _format_content( else "" ) ) - prompt += template_message["content"].strip() + "\n" + content = template_message["content"] + content_str = content if isinstance(content, str) else "" + prompt += content_str.strip() + "\n" - inputs = self._tokenizer( + inputs: Any = tokenizer( prompt, add_special_tokens=True, return_tensors=tensor_format ) else: _l(f"Got {len(template_messages)} template messages") _l(f"Applying chat template to {len(template_messages)} messages") - inputs = self._tokenizer.apply_chat_template( - template_messages, + inputs = tokenizer.apply_chat_template( + template_messages, # type: ignore[arg-type] chat_template=chat_template, - tools=tool.json_schemas() if tool else None, - **(chat_template_settings or {}), + tools=tool.json_schemas() if tool else None, # type: ignore[arg-type] + **(chat_template_settings or {}), # type: ignore[arg-type] return_tensors=tensor_format, ) - if hasattr(self._model, "device"): - _l(f"Translating inputs to {self._model.device}") - inputs = inputs.to(self._model.device) - return inputs + if self._model is not None and hasattr(self._model, "device"): + device = getattr(self._model, "device", None) + _l(f"Translating inputs to {device}") + inputs = inputs.to(device) + return inputs # type: ignore[return-value,no-any-return] def _messages( self, @@ -490,24 +514,27 @@ def _messages( developer_prompt: str | None = None, tool: ToolManager | None = None, ) -> list[Message]: + messages: list[Message] if isinstance(input, str): - input = Message(role=MessageRole.USER, content=input) + messages = [Message(role=MessageRole.USER, content=input)] + elif isinstance(input, Message): + messages = [input] elif isinstance(input, list): for m in input: assert 
isinstance(m, Message) - elif not isinstance(input, Message): + messages = list(input) # type: ignore[arg-type] + else: raise ValueError(input) - messages = [input] if not isinstance(input, list) else input - if developer_prompt: messages = [ - Message(role=MessageRole.DEVELOPER, content=developer_prompt) - ] + messages + Message(role=MessageRole.DEVELOPER, content=developer_prompt), + *messages, + ] if system_prompt: messages = [ - Message(role=MessageRole.SYSTEM, content=system_prompt) - ] + messages + Message(role=MessageRole.SYSTEM, content=system_prompt), + *messages, + ] - assert isinstance(messages, list) return messages diff --git a/src/avalan/model/nlp/text/mlxlm.py b/src/avalan/model/nlp/text/mlxlm.py index 4b1e2399..f3144cda 100644 --- a/src/avalan/model/nlp/text/mlxlm.py +++ b/src/avalan/model/nlp/text/mlxlm.py @@ -5,6 +5,7 @@ TransformerEngineSettings, ) from ....model.response.text import TextGenerationResponse +from ....model.vendor import TextGenerationVendor from ....tool.manager import ToolManager from ...vendor import TextGenerationVendorStream from .generation import TextGenerationModel @@ -12,23 +13,26 @@ from asyncio import to_thread from dataclasses import asdict, replace from logging import Logger, getLogger -from typing import AsyncGenerator, Callable, Literal +from typing import Any, AsyncGenerator, Callable, Literal +from diffusers import DiffusionPipeline from mlx_lm import generate, load, stream_generate from mlx_lm.sample_utils import make_sampler from torch import Tensor +from transformers import PreTrainedModel class MlxLmStream(TextGenerationVendorStream): """Async wrapper around a synchronous token generator.""" - _SENTINEL = object() + _SENTINEL: object = object() + _iterator: Any - def __init__(self, generator): + def __init__(self, generator: Any) -> None: super().__init__(generator) self._iterator = generator - async def __anext__(self) -> str: + async def __anext__(self) -> Any: sentinel = type(self)._SENTINEL chunk = await to_thread(next, self._iterator, sentinel) if chunk is sentinel: @@ -52,13 +56,16 @@ def __init__( def supports_sample_generation(self) -> bool: return False - def _load_model(self): + def _load_model( + self, + ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id is not None model, tokenizer = load(self._model_id) - self._tokenizer = tokenizer + self._tokenizer = tokenizer # type: ignore[assignment] self._loaded_tokenizer = True - return model + return model # type: ignore[return-value] - async def _stream_generator( + async def _stream_generator( # type: ignore[override] self, inputs: dict[str, Tensor] | Tensor, settings: GenerationSettings, @@ -68,8 +75,8 @@ async def _stream_generator( inputs, settings, skip_special_tokens ) iterator = stream_generate( - self._model, - self._tokenizer, + self._model, # type: ignore[arg-type] + self._tokenizer, # type: ignore[arg-type] prompt, sampler=sampler, max_tokens=settings.max_new_tokens, @@ -78,7 +85,7 @@ async def _stream_generator( async for chunk in stream: yield chunk.text - def _string_output( + def _string_output( # type: ignore[override] self, inputs: dict[str, Tensor] | Tensor, settings: GenerationSettings, @@ -88,15 +95,15 @@ def _string_output( inputs, settings, skip_special_tokens ) return generate( - self._model, - self._tokenizer, + self._model, # type: ignore[arg-type] + self._tokenizer, # type: ignore[arg-type] prompt, sampler=sampler, max_tokens=settings.max_new_tokens, ) @override - async def __call__( + async def __call__( # type: 
ignore[override] self, input: Input, system_prompt: str | None = None, @@ -118,11 +125,12 @@ async def __call__( chat_template_settings=asdict(settings.chat_settings), ) generation_settings = replace(settings, do_sample=False) - output_fn = ( + output_fn: Callable[..., Any] = ( self._stream_generator if settings.use_async_generator else self._string_output ) + assert self._tokenizer is not None return TextGenerationResponse( output_fn, @@ -140,8 +148,8 @@ def _get_sampler_and_prompt( inputs: dict[str, Tensor] | Tensor, settings: GenerationSettings, skip_special_tokens: bool, - ) -> tuple[Callable, str]: - sampler_settings = { + ) -> tuple[Callable[..., Any], str]: + sampler_settings: dict[str, Any] = { "temp": settings.temperature, "top_p": settings.top_p, "top_k": settings.top_k, @@ -150,7 +158,9 @@ def _get_sampler_and_prompt( k: v for k, v in sampler_settings.items() if v is not None } sampler = make_sampler(**sampler_settings) - prompt = self._tokenizer.decode( + assert self._tokenizer is not None + assert isinstance(inputs, dict), "inputs must be a dict" + prompt: str = self._tokenizer.decode( inputs["input_ids"][0], skip_special_tokens=skip_special_tokens, ) diff --git a/src/avalan/model/nlp/text/vendor/__init__.py b/src/avalan/model/nlp/text/vendor/__init__.py index c185ca5e..d4c7881c 100644 --- a/src/avalan/model/nlp/text/vendor/__init__.py +++ b/src/avalan/model/nlp/text/vendor/__init__.py @@ -13,7 +13,7 @@ from contextlib import AsyncExitStack from dataclasses import replace from logging import Logger, getLogger -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from tiktoken import encoding_for_model, get_encoding @@ -59,12 +59,12 @@ def supports_token_streaming(self) -> bool: def uses_tokenizer(self) -> bool: return False - def _tokenize_input( + def _tokenize_input( # type: ignore[override] self, input: Input, context: str | None = None, tensor_format: Literal["pt"] = "pt", - **kwargs, + **kwargs: Any, ) -> dict[str, Tensor] | BatchEncoding | Tensor: raise NotImplementedError() @@ -74,6 +74,7 @@ def input_token_count( system_prompt: str | None = None, developer_prompt: str | None = None, ) -> int: + assert self._model_id is not None try: encoding = encoding_for_model(self._model_id) except KeyError: @@ -85,11 +86,14 @@ def input_token_count( total_tokens = 0 for message in messages: - total_tokens += len(encoding.encode(message.content or "")) + content = ( + message.content if isinstance(message.content, str) else "" + ) + total_tokens += len(encoding.encode(content)) return total_tokens @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, system_prompt: str | None = None, @@ -98,15 +102,17 @@ async def __call__( *, tool: ToolManager | None = None, ) -> TextGenerationResponse: + gen_settings = settings or GenerationSettings() messages = self._messages(input, system_prompt, developer_prompt, tool) - streamer = await self._model( + assert self._model is not None + assert self._model_id is not None + streamer = await self._model( # type: ignore[operator] self._model_id, messages, - settings, + gen_settings, tool=tool, - use_async_generator=settings.use_async_generator, + use_async_generator=gen_settings.use_async_generator, ) - gen_settings = settings or GenerationSettings() return TextGenerationResponse( streamer, logger=self._logger, diff --git a/src/avalan/model/nlp/text/vendor/anthropic.py b/src/avalan/model/nlp/text/vendor/anthropic.py index 837b3100..fc3e517d 100644 --- 
a/src/avalan/model/nlp/text/vendor/anthropic.py +++ b/src/avalan/model/nlp/text/vendor/anthropic.py @@ -18,7 +18,7 @@ from . import TextGenerationVendorModel from contextlib import AsyncExitStack -from typing import AsyncIterator +from typing import Any, AsyncIterator from anthropic import AsyncAnthropic from anthropic.types import RawContentBlockDeltaEvent, RawMessageStopEvent @@ -27,9 +27,9 @@ class AnthropicStream(TextGenerationVendorStream): - def __init__(self, events: AsyncIterator): + def __init__(self, events: AsyncIterator[Any]) -> None: async def generator() -> AsyncIterator[Token | TokenDetail | str]: - tool_blocks: dict[int, dict] = {} + tool_blocks: dict[int, dict[str, Any]] = {} async for event in events: etype = getattr(event, "type", None) @@ -100,10 +100,10 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]: ): break - super().__init__(generator()) + super().__init__(generator()) # type: ignore[arg-type] async def __anext__(self) -> Token | TokenDetail | str: - return await self._generator.__anext__() + return await self._generator.__anext__() # type: ignore[no-any-return] class AnthropicClient(TextGenerationVendor): @@ -116,24 +116,24 @@ def __init__( base_url: str | None = None, *, exit_stack: AsyncExitStack, - ): + ) -> None: self._client = AsyncAnthropic(api_key=api_key, base_url=base_url) self._exit_stack = exit_stack @override - async def __call__( + async def __call__( # type: ignore[override] self, - model_id: str, + model_id: str | None, messages: list[Message], settings: GenerationSettings | None = None, *, tool: ToolManager | None = None, use_async_generator: bool = True, - ) -> AsyncIterator[Token | TokenDetail | str]: + ) -> AsyncIterator[Token | TokenDetail | str] | TextGenerationSingleStream: settings = settings or GenerationSettings() system_prompt = self._system_prompt(messages) template_messages = self._template_messages(messages, ["system"]) - kwargs = { + kwargs: dict[str, Any] = { "model": model_id, "system": system_prompt, "messages": template_messages, @@ -152,32 +152,37 @@ async def __call__( content = self._non_stream_response_content(response) return TextGenerationSingleStream(content) - def _template_messages( + def _template_messages( # type: ignore[override] self, messages: list[Message], exclude_roles: list[TemplateMessageRole] | None = None, - ) -> list[TemplateMessage]: - tool_results = [ - message.tool_call_result or message.tool_call_error + ) -> list[TemplateMessage | dict[str, Any]]: + tool_results: list[ToolCallResult | ToolCallError] = [ + message.tool_call_result or message.tool_call_error # type: ignore[misc] for message in messages if message.role == MessageRole.TOOL and (message.tool_call_result or message.tool_call_error) ] - do_exclude_roles = [*(exclude_roles or []), "tool"] - messages = super()._template_messages(messages, do_exclude_roles) - last_message = next( + do_exclude_roles: list[TemplateMessageRole] = [ + *(exclude_roles or []), + "tool", + ] + template_messages: list[TemplateMessage | dict[str, Any]] = list( + super()._template_messages(messages, do_exclude_roles) + ) + last_message: TemplateMessage | dict[str, Any] | None = next( ( m - for m in reversed(messages) + for m in reversed(template_messages) if m["role"] == str(MessageRole.ASSISTANT) ), None, ) - last_message_index = ( - messages.index(last_message) if last_message else None + last_message_index: int | None = ( + template_messages.index(last_message) if last_message else None ) - if last_message_index: - messages[last_message_index] = { + 
if last_message_index is not None and last_message is not None: + template_messages[last_message_index] = { "role": last_message["role"], "content": [ ( @@ -188,11 +193,11 @@ def _template_messages( *[ { "type": "tool_use", - "id": r.call.id, + "id": r.call.id if r.call else None, "name": TextGenerationVendor.encode_tool_name( - r.call.name + r.call.name if r.call else "" ), - "input": r.call.arguments, + "input": r.call.arguments if r.call else {}, } for r in tool_results ], @@ -200,7 +205,8 @@ def _template_messages( } for result in tool_results: - content = { + assert result.call is not None + content: dict[str, Any] = { "type": "tool_result", "tool_use_id": result.call.id, "content": to_json( @@ -211,49 +217,55 @@ def _template_messages( } if isinstance(result, ToolCallError): content["is_error"] = True - result_message = TemplateMessage(role="user", content=[content]) - if last_message_index: - messages.insert(last_message_index + 1, result_message) + result_message: dict[str, Any] = { + "role": "user", + "content": [content], + } + if last_message_index is not None: + template_messages.insert( + last_message_index + 1, result_message + ) else: - messages.append(result_message) + template_messages.append(result_message) # @TODO Ensure this doesn't happen from upstream - if len(messages) > 1 and (messages[0] == messages[-1]): - messages.pop() + if len(template_messages) > 1 and ( + template_messages[0] == template_messages[-1] + ): + template_messages.pop() - return messages + return template_messages @staticmethod - def _tool_schemas(tool: ToolManager) -> list[dict] | None: + def _tool_schemas(tool: ToolManager) -> list[dict[str, Any]] | None: schemas = tool.json_schemas() - return ( - [ - { - "name": TextGenerationVendor.encode_tool_name( - t["function"]["name"] - ), - "description": t["function"]["description"], - "input_schema": { - **t["function"]["parameters"], - "additionalProperties": False, - }, - } - for t in tool.json_schemas() - if t["type"] == "function" - ] - if schemas - else None - ) + if not schemas: + return None + return [ + { + "name": TextGenerationVendor.encode_tool_name( + t["function"]["name"] + ), + "description": t["function"]["description"], + "input_schema": { + **t["function"]["parameters"], + "additionalProperties": False, + }, + } + for t in schemas + if t["type"] == "function" + ] @staticmethod def _non_stream_response_content(response: object) -> str: - def _get(value: object, attribute: str) -> object | None: + def _get(value: object, attribute: str) -> Any: if isinstance(value, dict): return value.get(attribute) return getattr(value, attribute, None) parts: list[str] = [] - for block in _get(response, "content") or []: + content_blocks = _get(response, "content") or [] + for block in content_blocks: block_type = _get(block, "type") if block_type == "text": text = _get(block, "text") @@ -262,10 +274,13 @@ def _get(value: object, attribute: str) -> object | None: continue if block_type == "tool_use": + block_id: str | None = _get(block, "id") + block_name: str | None = _get(block, "name") + block_input: dict[str, Any] | None = _get(block, "input") token = TextGenerationVendor.build_tool_call_token( - _get(block, "id"), - _get(block, "name"), - _get(block, "input"), + block_id, + block_name, + block_input, ) parts.append(token.token) diff --git a/src/avalan/model/nlp/text/vendor/bedrock.py b/src/avalan/model/nlp/text/vendor/bedrock.py index 4ee84e8d..4e9f6991 100644 --- a/src/avalan/model/nlp/text/vendor/bedrock.py +++ 
b/src/avalan/model/nlp/text/vendor/bedrock.py @@ -24,7 +24,7 @@ from json import dumps from typing import Any, AsyncIterator -from aioboto3 import Session as Boto3Session +from aioboto3 import Session as Boto3Session # type: ignore[import-not-found] from diffusers import DiffusionPipeline from transformers import PreTrainedModel @@ -49,7 +49,7 @@ def _string(value: Any) -> str | None: class BedrockStream(TextGenerationVendorStream): - def __init__(self, events: AsyncIterator): + def __init__(self, events: AsyncIterator) -> None: # type: ignore[type-arg] async def generator() -> AsyncIterator[Token | TokenDetail | str]: tool_blocks: dict[int, dict[str, Any]] = {} @@ -153,10 +153,12 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]: if _get(event, "messageStop"): break - super().__init__(generator()) + super().__init__(generator()) # type: ignore[arg-type] async def __anext__(self) -> Token | TokenDetail | str: - return await self._generator.__anext__() + result = await self._generator.__anext__() + assert isinstance(result, (Token, TokenDetail, str)) + return result class BedrockClient(TextGenerationVendor): @@ -192,7 +194,7 @@ async def _client_instance(self) -> Any: return self._client @override - async def __call__( + async def __call__( # type: ignore[override] self, model_id: str, messages: list[Message], @@ -278,7 +280,7 @@ def _response_text(self, response: dict[str, Any]) -> str: parts.append(text_value) return "".join(parts) - def _template_messages( + def _template_messages( # type: ignore[override] self, messages: list[Message], exclude_roles: list[TemplateMessageRole] | None = None, diff --git a/src/avalan/model/nlp/text/vendor/google.py b/src/avalan/model/nlp/text/vendor/google.py index 385840be..93ab7d08 100644 --- a/src/avalan/model/nlp/text/vendor/google.py +++ b/src/avalan/model/nlp/text/vendor/google.py @@ -13,12 +13,13 @@ class GoogleStream(TextGenerationVendorStream): - def __init__(self, stream: AsyncIterator[GenerateContentResponse]): - super().__init__(stream) + def __init__(self, stream: AsyncIterator[GenerateContentResponse]) -> None: + super().__init__(stream) # type: ignore[arg-type] async def __anext__(self) -> Token | TokenDetail | str: chunk = await self._generator.__anext__() - return chunk.text + text: str = chunk.text + return text class GoogleClient(TextGenerationVendor): @@ -28,7 +29,7 @@ def __init__(self, api_key: str): self._client = Client(api_key=api_key) @override - async def __call__( + async def __call__( # type: ignore[override] self, model_id: str, messages: list[Message], @@ -51,7 +52,7 @@ async def __call__( contents=contents, ) - async def single_gen(): + async def single_gen() -> AsyncIterator[Token | TokenDetail | str]: yield response.text return single_gen() diff --git a/src/avalan/model/nlp/text/vendor/huggingface.py b/src/avalan/model/nlp/text/vendor/huggingface.py index 7be11b57..206ca780 100644 --- a/src/avalan/model/nlp/text/vendor/huggingface.py +++ b/src/avalan/model/nlp/text/vendor/huggingface.py @@ -9,7 +9,7 @@ from ....vendor import TextGenerationVendor, TextGenerationVendorStream from . 
import TextGenerationVendorModel -from typing import AsyncIterator +from typing import Any, AsyncIterator from diffusers import DiffusionPipeline from huggingface_hub import AsyncInferenceClient @@ -17,13 +17,13 @@ class HuggingfaceStream(TextGenerationVendorStream): - def __init__(self, stream: AsyncIterator): - super().__init__(stream.__aiter__()) + def __init__(self, stream: AsyncIterator) -> None: # type: ignore[type-arg] + super().__init__(stream.__aiter__()) # type: ignore[arg-type] async def __anext__(self) -> Token | TokenDetail | str: chunk = await self._generator.__anext__() delta = chunk.choices[0].delta - text = getattr(delta, "content", None) or "" + text: str = getattr(delta, "content", None) or "" return text @@ -34,7 +34,7 @@ def __init__(self, api_key: str, base_url: str | None = None): self._client = AsyncInferenceClient(token=api_key, base_url=base_url) @override - async def __call__( + async def __call__( # type: ignore[override] self, model_id: str, messages: list[Message], @@ -44,22 +44,32 @@ async def __call__( use_async_generator: bool = True, ) -> AsyncIterator[Token | TokenDetail | str]: settings = settings or GenerationSettings() - template_messages = self._template_messages(messages) + template_messages: list[dict[Any, Any]] = [ + {"role": m["role"], "content": m["content"]} + for m in self._template_messages(messages) + ] + stop: list[str] | None = None + if settings.stop_strings: + stop = ( + [settings.stop_strings] + if isinstance(settings.stop_strings, str) + else settings.stop_strings + ) response = await self._client.chat_completion( model=model_id, messages=template_messages, temperature=settings.temperature, max_tokens=settings.max_new_tokens, top_p=settings.top_p, - stop=settings.stop_strings, + stop=stop, stream=use_async_generator, ) if use_async_generator: - return HuggingfaceStream(response) + return HuggingfaceStream(response.__aiter__()) # type: ignore[arg-type, union-attr] else: - async def single_gen(): - yield response.choices[0].message.content or "" + async def single_gen() -> AsyncIterator[Token | TokenDetail | str]: + yield response.choices[0].message.content or "" # type: ignore[union-attr] return single_gen() diff --git a/src/avalan/model/nlp/text/vendor/litellm.py b/src/avalan/model/nlp/text/vendor/litellm.py index 801d60bb..b26ac86b 100644 --- a/src/avalan/model/nlp/text/vendor/litellm.py +++ b/src/avalan/model/nlp/text/vendor/litellm.py @@ -12,8 +12,8 @@ class LiteLLMStream(TextGenerationVendorStream): - def __init__(self, stream: AsyncIterator): - super().__init__(stream.__aiter__()) + def __init__(self, stream: AsyncIterator) -> None: # type: ignore[type-arg] + super().__init__(stream.__aiter__()) # type: ignore[arg-type] async def __anext__(self) -> Token | TokenDetail | str: chunk = await self._generator.__anext__() @@ -21,7 +21,7 @@ async def __anext__(self) -> Token | TokenDetail | str: if isinstance(chunk, dict): choice = chunk.get("choices", [{}])[0] delta = choice.get("delta", {}) if isinstance(choice, dict) else {} - text = delta.get("content", "") + text: str = delta.get("content", "") else: choice = chunk.choices[0] delta = getattr(choice, "delta", None) @@ -40,7 +40,7 @@ def __init__( self._base_url = base_url or "http://localhost:4000" @override - async def __call__( + async def __call__( # type: ignore[override] self, model_id: str, messages: list[Message], @@ -62,7 +62,7 @@ async def __call__( if use_async_generator: return LiteLLMStream(result) - async def single_gen(): + async def single_gen() -> AsyncIterator[Token | 
TokenDetail | str]: if isinstance(result, dict): yield result["choices"][0]["message"]["content"] else: diff --git a/src/avalan/model/nlp/text/vendor/ollama.py b/src/avalan/model/nlp/text/vendor/ollama.py index 5894b24a..7d62b24a 100644 --- a/src/avalan/model/nlp/text/vendor/ollama.py +++ b/src/avalan/model/nlp/text/vendor/ollama.py @@ -6,29 +6,30 @@ TokenDetail, TransformerEngineSettings, ) -from .....model.nlp.text.generation import TextGenerationModel from .....tool.manager import ToolManager from ....vendor import TextGenerationVendor, TextGenerationVendorStream from . import TextGenerationVendorModel +from contextlib import AsyncExitStack from dataclasses import replace from logging import Logger, getLogger from typing import AsyncIterator try: - from ollama import AsyncClient + from ollama import AsyncClient # type: ignore[import-not-found] except Exception: # pragma: no cover - ollama may not be installed - AsyncClient = None + AsyncClient = None # type: ignore[misc, assignment] class OllamaStream(TextGenerationVendorStream): - def __init__(self, stream: AsyncIterator[dict]): - super().__init__(stream) + def __init__(self, stream: AsyncIterator[dict]) -> None: # type: ignore[type-arg] + super().__init__(stream) # type: ignore[arg-type] async def __anext__(self) -> Token | TokenDetail | str: chunk = await self._generator.__anext__() message = chunk.get("message", {}) if isinstance(chunk, dict) else {} - return message.get("content", "") + content: str = message.get("content", "") + return content class OllamaClient(TextGenerationVendor): @@ -41,7 +42,7 @@ def __init__(self, base_url: str | None = None): ) @override - async def __call__( + async def __call__( # type: ignore[override] self, model_id: str, messages: list[Message], @@ -65,7 +66,7 @@ async def __call__( stream=False, ) - async def single_gen(): + async def single_gen() -> AsyncIterator[Token | TokenDetail | str]: yield response["message"]["content"] return single_gen() @@ -77,10 +78,12 @@ def __init__( model_id: str, settings: TransformerEngineSettings | None = None, logger: Logger = getLogger(__name__), + *, + exit_stack: AsyncExitStack | None = None, ) -> None: settings = settings or TransformerEngineSettings() settings = replace(settings, enable_eval=False) - TextGenerationModel.__init__(self, model_id, settings, logger) + super().__init__(model_id, settings, logger, exit_stack=exit_stack) def _load_model(self): return OllamaClient(base_url=self._settings.base_url) diff --git a/src/avalan/model/nlp/text/vendor/openai.py b/src/avalan/model/nlp/text/vendor/openai.py index f4be59cc..7a129b17 100644 --- a/src/avalan/model/nlp/text/vendor/openai.py +++ b/src/avalan/model/nlp/text/vendor/openai.py @@ -7,6 +7,7 @@ ReasoningToken, Token, TokenDetail, + ToolCallError, ToolCallResult, ToolCallToken, ) @@ -19,7 +20,7 @@ from . 
import TextGenerationVendorModel from json import dumps -from typing import AsyncIterator +from typing import Any, AsyncIterator from diffusers import DiffusionPipeline from openai import AsyncOpenAI @@ -27,12 +28,15 @@ class OpenAIStream(TextGenerationVendorStream): - _TEXT_DELTA_EVENTS = {"response.text.delta", "response.output_text.delta"} - _REASONING_DELTA_EVENTS = {"response.reasoning_text.delta"} + _TEXT_DELTA_EVENTS: set[str] = { + "response.text.delta", + "response.output_text.delta", + } + _REASONING_DELTA_EVENTS: set[str] = {"response.reasoning_text.delta"} - def __init__(self, stream: AsyncIterator): + def __init__(self, stream: AsyncIterator[Any]) -> None: async def generator() -> AsyncIterator[Token | TokenDetail | str]: - tool_calls: dict[str, dict] = {} + tool_calls: dict[str, dict[str, Any]] = {} async for event in stream: etype = getattr(event, "type", None) @@ -42,13 +46,14 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]: if item: custom = getattr(item, "custom_tool_call", None) if custom: - call_id = getattr( + call_id: str | None = getattr( custom, "id", getattr(item, "id", None) ) - tool_calls[call_id] = { - "name": getattr(custom, "name", None), - "args_fragments": [], - } + if call_id is not None: + tool_calls[call_id] = { + "name": getattr(custom, "name", None), + "args_fragments": [], + } continue if ( @@ -79,11 +84,17 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]: if etype == "response.output_item.done": item = getattr(event, "item", None) - call_id = getattr(item, "id", None) if item else None - cached = tool_calls.pop(call_id, None) + done_call_id: str | None = ( + getattr(item, "id", None) if item else None + ) + cached = ( + tool_calls.pop(done_call_id, None) + if done_call_id + else None + ) if cached: yield TextGenerationVendor.build_tool_call_token( - call_id, + done_call_id, cached.get("name"), "".join(cached["args_fragments"]) or None, ) @@ -104,22 +115,22 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]: continue - super().__init__(generator()) + super().__init__(generator()) # type: ignore[arg-type] async def __anext__(self) -> Token | TokenDetail | str: - return await self._generator.__anext__() + return await self._generator.__anext__() # type: ignore[no-any-return] class OpenAIClient(TextGenerationVendor): _client: AsyncOpenAI - def __init__(self, api_key: str, base_url: str | None): + def __init__(self, api_key: str | None, base_url: str | None) -> None: self._client = AsyncOpenAI(base_url=base_url, api_key=api_key) @override - async def __call__( + async def __call__( # type: ignore[override] self, - model_id: str, + model_id: str | None, messages: list[Message], settings: GenerationSettings | None = None, *, @@ -128,7 +139,7 @@ async def __call__( use_async_generator: bool = True, ) -> AsyncIterator[Token | TokenDetail | str] | TextGenerationSingleStream: template_messages = self._template_messages(messages) - kwargs: dict = { + kwargs: dict[str, Any] = { "extra_headers": { "X-Title": "Avalan", "HTTP-Referer": "https://github.com/avalan-ai/avalan", @@ -161,21 +172,27 @@ async def __call__( content = OpenAIClient._non_stream_response_content(client_stream) return TextGenerationSingleStream(content) - def _template_messages( + def _template_messages( # type: ignore[override] self, messages: list[Message], exclude_roles: list[TemplateMessageRole] | None = None, - ) -> list[TemplateMessage]: - tool_results = [ - message.tool_call_result or message.tool_call_error + ) -> 
list[TemplateMessage | dict[str, Any]]: + tool_results: list[ToolCallResult | ToolCallError] = [ + message.tool_call_result or message.tool_call_error # type: ignore[misc] for message in messages if message.role == MessageRole.TOOL and (message.tool_call_result or message.tool_call_error) ] - do_exclude_roles = [*(exclude_roles or []), "tool"] - messages = super()._template_messages(messages, do_exclude_roles) + do_exclude_roles: list[TemplateMessageRole] = [ + *(exclude_roles or []), + "tool", + ] + template_messages: list[TemplateMessage | dict[str, Any]] = list( + super()._template_messages(messages, do_exclude_roles) + ) for result in tool_results: - call_message = { + assert result.call is not None + call_message: dict[str, Any] = { "type": "function_call", "name": TextGenerationVendor.encode_tool_name( result.call.name @@ -183,9 +200,9 @@ def _template_messages( "call_id": result.call.id, "arguments": dumps(result.call.arguments), } - messages.append(call_message) + template_messages.append(call_message) - result_message = { + result_message: dict[str, Any] = { "type": "function_call_output", "call_id": result.call.id, "output": to_json( @@ -194,39 +211,38 @@ def _template_messages( else {"error": result.message} ), } - messages.append(result_message) - return messages + template_messages.append(result_message) + return template_messages @staticmethod - def _tool_schemas(tool: ToolManager) -> list[dict] | None: + def _tool_schemas(tool: ToolManager) -> list[dict[str, Any]] | None: schemas = tool.json_schemas() - return ( - [ - { - "type": t["type"], - **t["function"], - **{ - "name": TextGenerationVendor.encode_tool_name( - t["function"]["name"] - ) - }, - } - for t in tool.json_schemas() - if t["type"] == "function" - ] - if schemas - else None - ) + if not schemas: + return None + return [ + { + "type": t["type"], + **t["function"], + **{ + "name": TextGenerationVendor.encode_tool_name( + t["function"]["name"] + ) + }, + } + for t in schemas + if t["type"] == "function" + ] @staticmethod def _non_stream_response_content(response: object) -> str: - def _get(value: object, attribute: str) -> object | None: + def _get(value: object, attribute: str) -> Any: if isinstance(value, dict): return value.get(attribute) return getattr(value, attribute, None) parts: list[str] = [] - for item in _get(response, "output") or []: + output_items = _get(response, "output") or [] + for item in output_items: item_type = _get(item, "type") contents = _get(item, "content") or [] @@ -240,10 +256,15 @@ def _get(value: object, attribute: str) -> object | None: if item_type in {"tool_call", "function_call"}: call = _get(item, "call") or item function = _get(call, "function") or call + call_id: str | None = _get(call, "id") + func_name: str | None = _get(function, "name") + func_args: str | dict[str, Any] | None = _get( + function, "arguments" + ) token = TextGenerationVendor.build_tool_call_token( - _get(call, "id"), - _get(function, "name"), - _get(function, "arguments"), + call_id, + func_name, + func_args, ) parts.append(token.token) @@ -255,9 +276,9 @@ class OpenAINonStreamingResponse(TextGenerationResponse): def __init__( self, - *args, + *args: Any, static_response_text: str | None = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self._static_response_text = static_response_text @@ -289,7 +310,7 @@ def _load_model( ) @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, system_prompt: str | None = None, @@ -300,7 +321,9 @@ 
async def __call__( ) -> TextGenerationResponse: generation_settings = settings or GenerationSettings() messages = self._messages(input, system_prompt, developer_prompt, tool) - streamer = await self._model( + assert self._model is not None + assert self._model_id is not None + streamer = await self._model( # type: ignore[operator] self._model_id, messages, generation_settings, diff --git a/src/avalan/model/nlp/text/vendor/openrouter.py b/src/avalan/model/nlp/text/vendor/openrouter.py index 750c2b2d..4246355a 100644 --- a/src/avalan/model/nlp/text/vendor/openrouter.py +++ b/src/avalan/model/nlp/text/vendor/openrouter.py @@ -6,13 +6,15 @@ class OpenRouterClient(OpenAIClient): - def __init__(self, api_key: str, base_url: str | None = None): + def __init__( + self, api_key: str | None, base_url: str | None = None + ) -> None: super().__init__( api_key=api_key, base_url=base_url or "https://openrouter.ai/api/v1", ) # Optional headers recommended by OpenRouter - self._client.headers.update( + self._client.headers.update( # type: ignore[union-attr,attr-defined] { "HTTP-Referer": "https://github.com/avalan-ai/avalan", "X-Title": "avalan", diff --git a/src/avalan/model/nlp/text/vllm.py b/src/avalan/model/nlp/text/vllm.py index 508a8d4c..00c95edd 100644 --- a/src/avalan/model/nlp/text/vllm.py +++ b/src/avalan/model/nlp/text/vllm.py @@ -5,23 +5,28 @@ TransformerEngineSettings, ) from ....model.nlp.text.generation import TextGenerationModel -from ....model.vendor import TextGenerationVendorStream +from ....model.vendor import TextGenerationVendor, TextGenerationVendorStream from ....tool.manager import ToolManager from asyncio import to_thread from dataclasses import asdict, replace from logging import Logger, getLogger -from typing import AsyncGenerator +from typing import Any, AsyncGenerator + +from diffusers import DiffusionPipeline +from transformers import PreTrainedModel try: from vllm import LLM, SamplingParams except Exception: # pragma: no cover - vllm may not be installed - LLM = None - SamplingParams = None + LLM = None # type: ignore[assignment,misc] + SamplingParams = None # type: ignore[assignment,misc] class VllmStream(TextGenerationVendorStream): - def __init__(self, generator): + _iterator: Any + + def __init__(self, generator: Any) -> None: super().__init__(generator) self._iterator = generator @@ -48,17 +53,17 @@ def __init__( def supports_sample_generation(self) -> bool: return False - def _load_model(self): + def _load_model( + self, + ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: assert LLM, "vLLM is not available" - return LLM( + return LLM( # type: ignore[return-value,no-any-return] model=self._model_id, tokenizer=self._settings.tokenizer_name_or_path or self._model_id, trust_remote_code=self._settings.trust_remote_code, ) - def _build_sampling_params( - self, settings: GenerationSettings - ) -> SamplingParams: + def _build_sampling_params(self, settings: GenerationSettings) -> Any: assert SamplingParams, "vLLM is not available" return SamplingParams( temperature=settings.temperature, @@ -85,32 +90,36 @@ def _prompt( tool=tool, chat_template_settings=chat_template_settings, ) + assert self._tokenizer is not None + assert isinstance(inputs, dict), "inputs must be a dict" return self._tokenizer.decode( inputs["input_ids"][0], skip_special_tokens=False ) - async def _stream_generator( + async def _stream_generator( # type: ignore[override] self, prompt: str, settings: GenerationSettings, ) -> AsyncGenerator[str, None]: params = 
self._build_sampling_params(settings) - iterator = self._model.generate([prompt], params, stream=True) + assert self._model is not None + iterator = self._model.generate([prompt], params, stream=True) # type: ignore[union-attr,operator] stream = VllmStream(iter(iterator)) async for chunk in stream: yield chunk - def _string_output( + def _string_output( # type: ignore[override] self, prompt: str, settings: GenerationSettings, ) -> str: params = self._build_sampling_params(settings) - results = list(self._model.generate([prompt], params)) + assert self._model is not None + results = list(self._model.generate([prompt], params)) # type: ignore[union-attr,operator] return results[0].outputs[0].text if results else "" @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, system_prompt: str | None = None, @@ -129,5 +138,5 @@ async def __call__( ) generation_settings = replace(settings, do_sample=False) if settings.use_async_generator: - return await self._stream_generator(prompt, generation_settings) + return self._stream_generator(prompt, generation_settings) # awaiting an async generator raises TypeError; return it directly return self._string_output(prompt, generation_settings) diff --git a/src/avalan/model/nlp/token.py b/src/avalan/model/nlp/token.py index 632b0bb2..4792981b 100644 --- a/src/avalan/model/nlp/token.py +++ b/src/avalan/model/nlp/token.py @@ -4,7 +4,7 @@ from ...model.nlp import BaseNLPModel from ...model.vendor import TextGenerationVendor -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import argmax, inference_mode @@ -26,23 +26,26 @@ def supports_token_streaming(self) -> bool: def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: - model = AutoModelForTokenClassification.from_pretrained( - self._model_id, - cache_dir=self._settings.cache_dir, - subfolder=self._settings.subfolder or "", - attn_implementation=self._settings.attention, - trust_remote_code=self._settings.trust_remote_code, - torch_dtype=Engine.weight(self._settings.weight_type), - state_dict=self._settings.state_dict, - local_files_only=self._settings.local_files_only, - token=self._settings.access_token, - device_map=self._device, - tp_plan=Engine._get_tp_plan(self._settings.parallel), - distributed_config=Engine._get_distributed_config( - self._settings.distributed_config - ), + assert self._model_id is not None, "Model ID must be set" + model: PreTrainedModel = ( + AutoModelForTokenClassification.from_pretrained( + self._model_id, + cache_dir=self._settings.cache_dir, + subfolder=self._settings.subfolder or "", + attn_implementation=self._settings.attention, + trust_remote_code=self._settings.trust_remote_code, + torch_dtype=Engine.weight(self._settings.weight_type), + state_dict=self._settings.state_dict, + local_files_only=self._settings.local_files_only, + token=self._settings.access_token, + device_map=self._device, + tp_plan=Engine._get_tp_plan(self._settings.parallel), + distributed_config=Engine._get_distributed_config( + self._settings.distributed_config + ), + ) ) - labels = ( + labels: dict[int, str] | None = ( getattr(model.config, "id2label", None) if hasattr(model, "config") else None @@ -56,7 +59,7 @@ def _load_model( ) return model - def _tokenize_input( + def _tokenize_input( # type: ignore[override] self, input: Input, system_prompt: str | None, @@ -70,20 +73,25 @@ def _tokenize_input( + f"{self._model_id} does not support chat " + "templates" ) + assert self._tokenizer is not
None, "Tokenizer must be loaded" + assert self._model is not None, "Model must be loaded" _l = self._log _l(f"Tokenizing input {input}") - inputs = self._tokenizer(input, return_tensors=tensor_format) - inputs = inputs.to(self._model.device) + inputs: BatchEncoding = self._tokenizer( + input, return_tensors=tensor_format + ) + inputs = inputs.to(self._model.device) # type: ignore[union-attr] return inputs @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, *, labeled_only: bool = False, system_prompt: str | None = None, developer_prompt: str | None = None, + **kwargs: Any, ) -> dict[str, str]: assert self._tokenizer, ( f"Model {self._model} can't be executed " @@ -100,10 +108,10 @@ async def __call__( context=None, ) with inference_mode(): - outputs = self._model(**inputs) + outputs = self._model(**inputs) # type: ignore[operator] # logits shape (1, seq_len, num_labels) input_ids = inputs["input_ids"][0] - label_ids = argmax(outputs.logits, dim=2)[0] + label_ids = argmax(outputs.logits, dim=2)[0] # type: ignore[union-attr] if labeled_only and self._default_label_id is not None: mask = label_ids != self._default_label_id @@ -112,9 +120,10 @@ async def __call__( assert input_ids.numel() == label_ids.numel() tokens = self._tokenizer.convert_ids_to_tokens(input_ids) - labels = [ - self._model.config.id2label[label_id.item()] - for label_id in label_ids + id2label = self._model.config.id2label # type: ignore[union-attr] + assert id2label is not None, "Model config must have id2label" + labels_list: list[str] = [ + id2label[int(label_id.item())] for label_id in label_ids ] - tokens_to_labels = dict(zip(tokens, labels)) + tokens_to_labels = dict(zip(tokens, labels_list)) return tokens_to_labels diff --git a/src/avalan/model/response/parsers/tool.py b/src/avalan/model/response/parsers/tool.py index aa821a04..6b4ecdd0 100644 --- a/src/avalan/model/response/parsers/tool.py +++ b/src/avalan/model/response/parsers/tool.py @@ -90,7 +90,9 @@ async def push(self, token_str: str) -> Iterable[Any]: return result event = Event( - type=EventType.TOOL_PROCESS, payload=calls, started=perf_counter() + type=EventType.TOOL_PROCESS, + payload={"calls": calls}, + started=perf_counter(), ) self._buffer = StringIO() diff --git a/src/avalan/model/response/text.py b/src/avalan/model/response/text.py index 96afb0a3..3b7b6149 100644 --- a/src/avalan/model/response/text.py +++ b/src/avalan/model/response/text.py @@ -20,8 +20,12 @@ AsyncIterator, Awaitable, Callable, + Self, + TypeVar, ) +T = TypeVar("T") + OutputGenerator = AsyncGenerator[Token | TokenDetail | str, None] OutputFunction = Callable[..., OutputGenerator | str] @@ -92,14 +96,17 @@ def _ensure_non_stream_prefetched(self) -> None: if self._buffer.tell(): return - result = self._output_fn(*self._args, **self._kwargs) - if isinstance(result, TextGenerationSingleStream): - result = result.content - - if isinstance(result, (Token, TokenDetail)): - text = result.token + fn_result = self._output_fn(*self._args, **self._kwargs) + if isinstance(fn_result, TextGenerationSingleStream): + stream_content = fn_result.content + if isinstance(stream_content, (Token, TokenDetail)): + text = stream_content.token + else: + text = str(stream_content) + elif isinstance(fn_result, (Token, TokenDetail)): + text = fn_result.token else: - text = str(result) + text = str(fn_result) self._prefetched_text = text self._buffer = StringIO() @@ -125,7 +132,11 @@ def can_think(self) -> bool: @property def is_thinking(self) -> bool: - return 
self.can_think and self._reasoning_parser.is_thinking + return ( + self.can_think + and self._reasoning_parser is not None + and self._reasoning_parser.is_thinking + ) def set_thinking(self, thinking: bool) -> None: if self._reasoning_parser: @@ -140,9 +151,16 @@ async def _trigger_consumed(self) -> None: if iscoroutine(result): await result - def __aiter__(self): + def __aiter__(self) -> Self: + """Return iterator for async iteration over tokens. + + Returns: + Self for async iteration. + """ # Create a fresh async generator each time we start iterating - self._output = self._output_fn(*self._args, **self._kwargs) + fn_result = self._output_fn(*self._args, **self._kwargs) + if not isinstance(fn_result, str): + self._output = fn_result return self async def __anext__(self) -> Token | TokenDetail | str: @@ -154,9 +172,10 @@ async def __anext__(self) -> Token | TokenDetail | str: return self._parser_queue.get() try: + assert self._output is not None token = await self._output.__anext__() except StopAsyncIteration: - if self._reasoning_parser: + if self._reasoning_parser and self._parser_queue: for it in await self._reasoning_parser.flush(): self._parser_queue.put(it) if not self._parser_queue.empty(): @@ -180,7 +199,11 @@ async def __anext__(self) -> Token | TokenDetail | str: await self._trigger_consumed() raise StopAsyncIteration + assert self._parser_queue is not None for it in items: + parsed: ( + Token | TokenDetail | ReasoningToken | ToolCallToken | str + ) if isinstance(it, ReasoningToken): token_id = ( token.id @@ -188,7 +211,9 @@ async def __anext__(self) -> Token | TokenDetail | str: else it.id ) parsed = ReasoningToken( - token=it.token, id=token_id, probability=it.probability + token=it.token, + id=token_id if token_id is not None else -1, + probability=it.probability, ) elif isinstance(token, ToolCallToken): parsed = ToolCallToken( @@ -227,12 +252,14 @@ async def to_str(self) -> str: await self._trigger_consumed() return self._prefetched_text - # Ensure buffer is filled, wether we were already iterating or not + # Ensure buffer is filled, whether we were already iterating or not if not self._output: self.__aiter__() + assert self._output is not None async for token in self._output: - self._buffer.write(token) + token_str = token if isinstance(token, str) else token.token + self._buffer.write(token_str) self._output_token_count += 1 await self._trigger_consumed() @@ -252,7 +279,15 @@ async def to_json(self) -> str: continue raise InvalidJsonResponseException(text) - async def to(self, entity_class: type) -> any: - json = await self.to_json() - data = loads(json) + async def to(self, entity_class: type[T]) -> T: + """Convert JSON response to entity class instance. + + Args: + entity_class: The class to instantiate with JSON data. + + Returns: + Instance of entity_class populated with JSON data. 
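+
+        Example (an illustrative sketch only; ``Person`` is a
+        hypothetical entity class, not part of this codebase)::
+
+            from dataclasses import dataclass
+
+            @dataclass
+            class Person:
+                name: str
+                age: int
+
+            # assumes the model replied with a JSON object such as
+            # {"name": "Ada", "age": 36}
+            person = await response.to(Person)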
+ """ + json_str = await self.to_json() + data = loads(json_str) return entity_class(**data) diff --git a/src/avalan/model/transformer.py b/src/avalan/model/transformer.py index bdf3d26e..32b8ac9a 100644 --- a/src/avalan/model/transformer.py +++ b/src/avalan/model/transformer.py @@ -9,7 +9,7 @@ from logging import Logger, getLogger from typing import Literal -from tokenizers import AddedToken +from tokenizers import AddedToken # type: ignore[import-untyped] from torch import Tensor from transformers import ( AutoTokenizer, @@ -20,6 +20,10 @@ class TransformerModel(Engine, ABC): + """Base class for transformer-based models.""" + + _settings: TransformerEngineSettings + @property def supports_sample_generation(self) -> bool: return False @@ -71,12 +75,12 @@ def tokenize( not hasattr(self, "_loaded_tokenizer") or not self._loaded_tokenizer ): - self.load( - load_model=False, + self._load( load_tokenizer=True, tokenizer_name_or_path=tokenizer_name_or_path, ) + assert self._tokenizer is not None _l(f'Tokenizing text "{text}"') token_ids = self._tokenizer.encode(text, add_special_tokens=True) _l(f'Tokenized text "{text}" into {len(token_ids)} tokens') @@ -89,7 +93,7 @@ def tokenize( ), probability=1, ) - for i, token_id in enumerate(token_ids) + for token_id in token_ids ] def input_token_count( @@ -98,7 +102,6 @@ def input_token_count( system_prompt: str | None = None, developer_prompt: str | None = None, ) -> int: - _l = self._log assert self._tokenizer, ( f"Model {self._model} can't be executed " + "without a tokenizer loaded first" @@ -109,20 +112,27 @@ def input_token_count( developer_prompt=developer_prompt, context=None, ) - return ( - len(inputs["input_ids"][0]) - if inputs and "input_ids" in inputs - else 0 - ) + if isinstance(inputs, Tensor): + return inputs.shape[-1] if inputs.dim() > 0 else 0 + if isinstance(inputs, dict) and "input_ids" in inputs: + input_ids = inputs["input_ids"] + if isinstance(input_ids, Tensor): + return input_ids.shape[-1] if input_ids.dim() > 0 else 0 + if isinstance(input_ids, list) and len(input_ids) > 0: + return len(input_ids[0]) + return 0 def _load_tokenizer( self, tokenizer_name_or_path: str | None, use_fast: bool ) -> PreTrainedTokenizer | PreTrainedTokenizerFast: - return AutoTokenizer.from_pretrained( - tokenizer_name_or_path or self._model_id, - use_fast=use_fast, - subfolder=self._settings.tokenizer_subfolder or "", + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = ( + AutoTokenizer.from_pretrained( + tokenizer_name_or_path or self._model_id, + use_fast=use_fast, + subfolder=self._settings.tokenizer_subfolder or "", + ) ) + return tokenizer def _load_tokenizer_with_tokens( self, tokenizer_name_or_path: str | None, use_fast: bool = True diff --git a/src/avalan/model/vendor.py b/src/avalan/model/vendor.py index 874d01ec..65589164 100644 --- a/src/avalan/model/vendor.py +++ b/src/avalan/model/vendor.py @@ -8,15 +8,22 @@ ToolCallToken, ) from ..tool.manager import ToolManager -from .message import TemplateMessage, TemplateMessageRole +from .message import ( + TemplateMessage, + TemplateMessageContent, + TemplateMessageRole, +) from .stream import TextGenerationStream from abc import ABC from json import JSONDecodeError, dumps, loads -from typing import AsyncGenerator +from typing import AsyncGenerator, cast +from uuid import uuid4 class TextGenerationVendor(ABC): + """Base class for text generation vendor implementations.""" + async def __call__( self, model_id: str, @@ -29,28 +36,26 @@ async def __call__( raise NotImplementedError() def 
_system_prompt(self, messages: list[Message]) -> str | None: - return next( - ( - message.content - for message in messages - if message.role == "system" - ), - None, - ) + for message in messages: + if message.role == "system" and isinstance(message.content, str): + return message.content + return None def _template_messages( self, messages: list[Message], exclude_roles: list[TemplateMessageRole] | None = None, ) -> list[TemplateMessage]: - def _block(c: MessageContent) -> dict: + def _block(c: MessageContent) -> TemplateMessageContent: if isinstance(c, MessageContentImage): - return {"type": "image_url", "image_url": c.image_url} - return {"type": "text", "text": c.text} + return TemplateMessageContent( + type="image_url", image_url=c.image_url + ) + return TemplateMessageContent(type="text", text=c.text) def _wrap( content: str | MessageContent | list[MessageContent], - ) -> str | list[dict]: + ) -> str | list[TemplateMessageContent]: if isinstance(content, str): return content @@ -70,7 +75,9 @@ def _wrap( if exclude_roles and msg.role in exclude_roles: continue - out.append({"role": str(msg.role), "content": _wrap(msg.content)}) + content = msg.content if msg.content is not None else "" + role = cast(TemplateMessageRole, msg.role) + out.append({"role": role, "content": _wrap(content)}) return out @@ -96,7 +103,7 @@ def build_tool_call_token( args = {} else: args = arguments or {} - call = ToolCall(id=call_id, name=name, arguments=args) + call = ToolCall(id=call_id or str(uuid4()), name=name, arguments=args) token_json = dumps({"name": name, "arguments": args}) return ToolCallToken( token=f"{token_json}", call=call diff --git a/src/avalan/model/vision/classification.py b/src/avalan/model/vision/classification.py index 2ac506a2..13c0fdb6 100644 --- a/src/avalan/model/vision/classification.py +++ b/src/avalan/model/vision/classification.py @@ -4,7 +4,7 @@ from ...model.vendor import TextGenerationVendor from ...model.vision import BaseVisionModel -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from PIL import Image @@ -18,36 +18,48 @@ # model predicts one of the 1000 ImageNet classes class ImageClassificationModel(BaseVisionModel): + _processor: AutoImageProcessor + def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id is not None, "Model ID is required" self._processor = AutoImageProcessor.from_pretrained( self._model_id, # default behavior in transformers v4.48 use_fast=True, ) - model = AutoModelForImageClassification.from_pretrained( - self._model_id, - device_map=self._device, - tp_plan=Engine._get_tp_plan(self._settings.parallel), - distributed_config=Engine._get_distributed_config( - self._settings.distributed_config - ), + model: PreTrainedModel = ( + AutoModelForImageClassification.from_pretrained( + self._model_id, + device_map=self._device, + tp_plan=Engine._get_tp_plan(self._settings.parallel), + distributed_config=Engine._get_distributed_config( + self._settings.distributed_config + ), + ) ) return model @override - async def __call__( + async def __call__( # type: ignore[override] self, image_source: str | Image.Image, tensor_format: Literal["pt"] = "pt", ) -> ImageEntity: + assert self._model is not None, "Model must be loaded" + assert isinstance( + self._model, PreTrainedModel + ), "Model must be PreTrainedModel" image = BaseVisionModel._get_image(image_source) - inputs = self._processor(image, return_tensors=tensor_format) + inputs: Any = self._processor( # 
type: ignore[operator] + image, return_tensors=tensor_format + ) inputs.to(self._device) with inference_mode(): - logits = self._model(**inputs).logits + logits = self._model(**inputs).logits # type: ignore[operator] label_index = logits.argmax(dim=1).item() - return ImageEntity(label=self._model.config.id2label[label_index]) + id2label: dict[int, str] = self._model.config.id2label # type: ignore[assignment] + return ImageEntity(label=id2label[label_index]) diff --git a/src/avalan/model/vision/decoder.py b/src/avalan/model/vision/decoder.py index 5d7d8d56..674879c9 100644 --- a/src/avalan/model/vision/decoder.py +++ b/src/avalan/model/vision/decoder.py @@ -4,7 +4,7 @@ from ...model.vision import BaseVisionModel from ...model.vision.text import ImageToTextModel -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from PIL import Image @@ -22,11 +22,12 @@ class VisionEncoderDecoderModel(ImageToTextModel): def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id is not None, "Model ID is required" self._processor = AutoImageProcessor.from_pretrained( self._model_id, use_fast=True, ) - model = VisionEncoderDecoderModelImpl.from_pretrained( + model: PreTrainedModel = VisionEncoderDecoderModelImpl.from_pretrained( self._model_id, device_map=self._device, tp_plan=Engine._get_tp_plan(self._settings.parallel), @@ -37,7 +38,7 @@ def _load_model( return model @override - async def __call__( + async def __call__( # type: ignore[override] self, image_source: str | Image.Image, prompt: str | None, @@ -55,8 +56,11 @@ async def __call__( tensor_format=tensor_format, ) + assert self._model is not None, "Model must be loaded" + assert self._tokenizer is not None, "Tokenizer must be loaded" + image = BaseVisionModel._get_image(image_source) - pixel_values = self._processor( + pixel_values: Any = self._processor( # type: ignore[operator] image, return_tensors=tensor_format ).pixel_values.to(self._device) decoder_input_ids = self._tokenizer( @@ -64,10 +68,11 @@ async def __call__( ).input_ids.to(self._device) with inference_mode(): - outputs = self._model.generate( + # self._model is VisionEncoderDecoderModel at runtime + outputs = self._model.generate( # type: ignore[union-attr, operator] pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, - max_length=self._model.decoder.config.max_position_embeddings, + max_length=self._model.decoder.config.max_position_embeddings, # type: ignore[union-attr] early_stopping=early_stopping, pad_token_id=self._tokenizer.pad_token_id, eos_token_id=self._tokenizer.eos_token_id, @@ -77,7 +82,8 @@ async def __call__( return_dict_in_generate=True, ) - output = self._tokenizer.batch_decode( - outputs.sequences, skip_special_tokens=skip_special_tokens + output: str = self._tokenizer.batch_decode( + outputs.sequences, # type: ignore[union-attr] + skip_special_tokens=skip_special_tokens, )[0] return output diff --git a/src/avalan/model/vision/detection.py b/src/avalan/model/vision/detection.py index d0b110e0..00a04ece 100644 --- a/src/avalan/model/vision/detection.py +++ b/src/avalan/model/vision/detection.py @@ -6,7 +6,7 @@ from ...model.vision.classification import ImageClassificationModel from logging import Logger, getLogger -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from PIL import Image @@ -32,13 +32,14 @@ def __init__( def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + 
assert self._model_id is not None, "Model ID is required" self._processor = AutoImageProcessor.from_pretrained( self._model_id, revision=self._revision, # default behavior in transformers v4.48 use_fast=True, ) - model = AutoModelForObjectDetection.from_pretrained( + model: PreTrainedModel = AutoModelForObjectDetection.from_pretrained( self._model_id, revision=self._revision, device_map=self._device, @@ -50,32 +51,38 @@ def _load_model( return model @override - async def __call__( + async def __call__( # type: ignore[override] self, image_source: str | Image.Image, threshold: float | None = 0.3, tensor_format: Literal["pt"] = "pt", ) -> list[ImageEntity]: + assert self._model is not None, "Model must be loaded" + assert isinstance( + self._model, PreTrainedModel + ), "Model must be PreTrainedModel" image = BaseVisionModel._get_image(image_source) - inputs = self._processor(images=image, return_tensors=tensor_format) + processor: Any = self._processor + inputs: Any = processor(images=image, return_tensors=tensor_format) inputs.to(self._device) with inference_mode(): - outputs = self._model(**inputs) + outputs = self._model(**inputs) # type: ignore[operator] target_sizes = tensor([image.size[::-1]]) - results = self._processor.post_process_object_detection( + results = processor.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=threshold )[0] + id2label: dict[int, str] = self._model.config.id2label # type: ignore[assignment] entities = [] for score, label, box in zip( results["scores"], results["labels"], results["boxes"] ): - box = [round(i, 2) for i in box.tolist()] + box_coords = [round(i, 2) for i in box.tolist()] entities.append( ImageEntity( - label=self._model.config.id2label[label.item()], + label=id2label[label.item()], score=score.item(), - box=box, + box=box_coords, ) ) return entities diff --git a/src/avalan/model/vision/diffusion/animation.py b/src/avalan/model/vision/diffusion/animation.py index 6cff9652..fd15fc97 100644 --- a/src/avalan/model/vision/diffusion/animation.py +++ b/src/avalan/model/vision/diffusion/animation.py @@ -6,6 +6,7 @@ from dataclasses import replace from logging import Logger, getLogger +from typing import Any from diffusers import ( AnimateDiffPipeline, @@ -40,15 +41,17 @@ def __init__( def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id is not None, "Model ID is required" + assert self._settings.checkpoint is not None, "Checkpoint is required" dtype = Engine.weight(self._settings.weight_type) - adapter = MotionAdapter().to(self._device, dtype) + adapter: Any = MotionAdapter().to(self._device, dtype) # type: ignore[attr-defined] adapter.load_state_dict( load_file( hf_hub_download(self._model_id, self._settings.checkpoint), device=self._device, ) ) - pipe = AnimateDiffPipeline.from_pretrained( + pipe: DiffusionPipeline = AnimateDiffPipeline.from_pretrained( self._settings.base_model_id, motion_adapter=adapter, torch_dtype=dtype, @@ -57,15 +60,15 @@ def _load_model( return pipe @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, path: str, *, - beta_schedule: BetaSchedule = "linear", + beta_schedule: BetaSchedule = BetaSchedule.LINEAR, guidance_scale: float = 1.0, steps: int = 4, - timestep_spacing: TimestepSpacing = "trailing", + timestep_spacing: TimestepSpacing = TimestepSpacing.TRAILING, ) -> str: assert steps and steps in [ 1, @@ -73,21 +76,24 @@ async def __call__( 4, 8, ], f"Invalid number of steps: {steps}, can only 
be 1, 2, 4, or 8" + assert self._model is not None, "Model must be loaded" + + model: Any = self._model scheduler_settings = (timestep_spacing, beta_schedule) if scheduler_settings not in self._schedulers: scheduler = EulerDiscreteScheduler.from_config( - self._model.scheduler.config, - timestep_spacing=timestep_spacing, - beta_schedule=beta_schedule, + model.scheduler.config, + timestep_spacing=timestep_spacing.value, + beta_schedule=beta_schedule.value, ) self._schedulers[scheduler_settings] = scheduler else: scheduler = self._schedulers[scheduler_settings] - self._model.scheduler = scheduler + model.scheduler = scheduler with inference_mode(): - output = self._model( + output = model( prompt=input if isinstance(input, str) else str(input), guidance_scale=guidance_scale, num_inference_steps=steps, diff --git a/src/avalan/model/vision/diffusion/image.py b/src/avalan/model/vision/diffusion/image.py index cfcb13d2..946bdac6 100644 --- a/src/avalan/model/vision/diffusion/image.py +++ b/src/avalan/model/vision/diffusion/image.py @@ -11,7 +11,7 @@ from dataclasses import replace from logging import Logger, getLogger -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from torch import inference_mode @@ -19,7 +19,7 @@ class TextToImageModel(BaseVisionModel): - _base: DiffusionPipeline + _base: Any def __init__( self, @@ -38,7 +38,7 @@ def _load_model( dtype = Engine.weight(self._settings.weight_type) dtype_variant = self._settings.weight_type - base = DiffusionPipeline.from_pretrained( + base: Any = DiffusionPipeline.from_pretrained( self._model_id, torch_dtype=dtype, variant=dtype_variant, @@ -46,7 +46,7 @@ def _load_model( ) base.to(self._device) - refiner = DiffusionPipeline.from_pretrained( + refiner: Any = DiffusionPipeline.from_pretrained( self._settings.refiner_model_id, text_encoder_2=base.text_encoder_2, vae=base.vae, @@ -58,10 +58,10 @@ def _load_model( self._base = base - return refiner + return refiner # type: ignore[no-any-return] @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, path: str, @@ -81,7 +81,9 @@ async def __call__( and n_steps and output_type ) + assert self._model is not None, "Model must be loaded" + model: Any = self._model with inference_mode(): image = self._base( prompt=input if isinstance(input, str) else str(input), @@ -89,7 +91,7 @@ async def __call__( denoising_end=high_noise_frac, output_type=output_type, ).images - image = self._model( + image = model( prompt=input if isinstance(input, str) else str(input), num_inference_steps=n_steps, denoising_start=high_noise_frac, diff --git a/src/avalan/model/vision/diffusion/video.py b/src/avalan/model/vision/diffusion/video.py index 18411481..cb4ddcba 100644 --- a/src/avalan/model/vision/diffusion/video.py +++ b/src/avalan/model/vision/diffusion/video.py @@ -6,6 +6,7 @@ from dataclasses import replace from logging import Logger, getLogger +from typing import Any from diffusers import DiffusionPipeline from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition @@ -32,7 +33,7 @@ def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: dtype = Engine.weight(self._settings.weight_type) - base_pipe = DiffusionPipeline.from_pretrained( + base_pipe: Any = DiffusionPipeline.from_pretrained( self._model_id, torch_dtype=dtype, ).to(self._device) @@ -42,10 +43,10 @@ def _load_model( torch_dtype=dtype, ).to(self._device) base_pipe.vae.enable_tiling() - return base_pipe + return 
base_pipe # type: ignore[no-any-return] @override - async def __call__( + async def __call__( # type: ignore[override] self, input: Input, negative_prompt: str, @@ -63,22 +64,24 @@ async def __call__( width: int = 832, steps: int = 30, ) -> str: + assert self._model is not None, "Model must be loaded" image = load_image(reference_path) video = load_video(export_to_video([image])) condition = LTXVideoCondition(video=video, frame_index=0) + model: Any = self._model down_h = int(height * downscale) down_w = int(width * downscale) down_h, down_w = ( TextToVideoModel._round_to_nearest_resolution_acceptable_by_vae( down_h, down_w, - ratio=self._model.vae_spatial_compression_ratio, + ratio=model.vae_spatial_compression_ratio, ) ) with inference_mode(): - latents = self._model( + latents = model( conditions=[condition], prompt=input if isinstance(input, str) else str(input), negative_prompt=negative_prompt, @@ -91,11 +94,12 @@ async def __call__( ).frames upscaled_h, upscaled_w = down_h * 2, down_w * 2 - upscaled_latents = self._upsampler_pipe( + upsampler: Any = self._upsampler_pipe + upscaled_latents = upsampler( latents=latents, output_type="latent" ).frames - video = self._model( + video = model( conditions=[condition], prompt=input if isinstance(input, str) else str(input), negative_prompt=negative_prompt, diff --git a/src/avalan/model/vision/segmentation.py b/src/avalan/model/vision/segmentation.py index 3cdd0652..710b84e7 100644 --- a/src/avalan/model/vision/segmentation.py +++ b/src/avalan/model/vision/segmentation.py @@ -3,7 +3,7 @@ from ...model.vendor import TextGenerationVendor from ...model.vision import BaseVisionModel -from typing import Literal +from typing import Any, Literal from diffusers import DiffusionPipeline from PIL import Image @@ -16,40 +16,50 @@ class SemanticSegmentationModel(BaseVisionModel): + _processor: AutoImageProcessor + def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id is not None, "Model ID is required" self._processor = AutoImageProcessor.from_pretrained( self._model_id, # default behavior in transformers v4.48 use_fast=True, ) - model = AutoModelForSemanticSegmentation.from_pretrained( - self._model_id, - device_map=self._device, - tp_plan=Engine._get_tp_plan(self._settings.parallel), - distributed_config=Engine._get_distributed_config( - self._settings.distributed_config - ), + model: PreTrainedModel = ( + AutoModelForSemanticSegmentation.from_pretrained( + self._model_id, + device_map=self._device, + tp_plan=Engine._get_tp_plan(self._settings.parallel), + distributed_config=Engine._get_distributed_config( + self._settings.distributed_config + ), + ) ) model.eval() return model @override - async def __call__( + async def __call__( # type: ignore[override] self, image_source: str | Image.Image, tensor_format: Literal["pt"] = "pt", ) -> list[str]: + assert self._model is not None, "Model must be loaded" + assert isinstance( + self._model, PreTrainedModel + ), "Model must be PreTrainedModel" image = BaseVisionModel._get_image(image_source) - inputs = self._processor(images=image, return_tensors=tensor_format) + inputs: Any = self._processor( # type: ignore[operator] + images=image, return_tensors=tensor_format + ) inputs.to(self._device) with inference_mode(): - logits = self._model(**inputs).logits + logits = self._model(**inputs).logits # type: ignore[operator] # shape (height, width) with class indices mask = logits.argmax(dim=1)[0] labels_tensor = unique(mask) - labels = [ - 
self._model.config.id2label[idx.item()] for idx in labels_tensor - ] + id2label: dict[int, str] = self._model.config.id2label # type: ignore[assignment] + labels = [id2label[idx.item()] for idx in labels_tensor] return labels diff --git a/src/avalan/model/vision/text.py b/src/avalan/model/vision/text.py index 0fc28876..1c6b3144 100644 --- a/src/avalan/model/vision/text.py +++ b/src/avalan/model/vision/text.py @@ -1,7 +1,6 @@ from ...compat import override from ...entities import ( GenerationSettings, - ImageTextGenerationLoaderClass, Input, MessageRole, ) @@ -10,7 +9,7 @@ from ...model.vendor import TextGenerationVendor from ...model.vision import BaseVisionModel -from typing import Literal +from typing import Any, Literal, cast from diffusers import DiffusionPipeline from PIL import Image @@ -33,12 +32,13 @@ class ImageToTextModel(TransformerModel): def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id is not None, "Model ID is required" self._processor = AutoImageProcessor.from_pretrained( self._model_id, # default behavior in transformers v4.48 use_fast=True, ) - model = AutoModelForVision2Seq.from_pretrained( + model: PreTrainedModel = AutoModelForVision2Seq.from_pretrained( self._model_id, device_map=self._device, tp_plan=Engine._get_tp_plan(self._settings.parallel), @@ -58,48 +58,55 @@ def _tokenize_input( raise NotImplementedError() @override - async def __call__( + async def __call__( # type: ignore[override] self, image_source: str | Image.Image, *, skip_special_tokens: bool = True, tensor_format: Literal["pt"] = "pt", ) -> str: + assert self._model is not None, "Model must be loaded" + assert self._tokenizer is not None, "Tokenizer must be loaded" image = BaseVisionModel._get_image(image_source) - inputs = self._processor(images=image, return_tensors=tensor_format) + inputs: Any = self._processor( # type: ignore[operator] + images=image, return_tensors=tensor_format + ) inputs.to(self._device) with inference_mode(): - output_ids = self._model.generate(**inputs) + model = cast(PreTrainedModel, self._model) + output_ids = model.generate(**inputs) # type: ignore[operator] - output = self._tokenizer.decode( + output: str = self._tokenizer.decode( output_ids[0], skip_special_tokens=skip_special_tokens ) return output class ImageTextToTextModel(ImageToTextModel): - _loaders: dict[ImageTextGenerationLoaderClass, type[PreTrainedModel]] = { - "auto": AutoModelForImageTextToText, - "qwen2": Qwen2VLForConditionalGeneration, - "gemma3": Gemma3ForConditionalGeneration, + _loaders: dict[str, type[PreTrainedModel]] = { + "auto": AutoModelForImageTextToText, # type: ignore[dict-item] + "qwen2": Qwen2VLForConditionalGeneration, # type: ignore[dict-item] + "gemma3": Gemma3ForConditionalGeneration, # type: ignore[dict-item] } def _load_model( self, ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline: + assert self._model_id is not None, "Model ID is required" + loader_class = self._settings.loader_class or "auto" assert ( - self._settings.loader_class in self._loaders - ), f"Unrecognized loader {self._settings.loader_class}" + loader_class in self._loaders + ), f"Unrecognized loader {loader_class}" self._processor = AutoProcessor.from_pretrained( self._model_id, use_fast=True, ) - loader = self._loaders[self._settings.loader_class] - model = loader.from_pretrained( + loader = self._loaders[loader_class] + model: PreTrainedModel = loader.from_pretrained( self._model_id, torch_dtype=Engine.weight(self._settings.weight_type), 
device_map=self._device, @@ -111,7 +118,7 @@ def _load_model( return model @override - async def __call__( + async def __call__( # type: ignore[override] self, image_source: str | Image.Image, prompt: str, @@ -123,6 +130,8 @@ async def __call__( skip_special_tokens: bool = True, tensor_format: Literal["pt"] = "pt", ) -> str: + assert self._model is not None, "Model must be loaded" + assert settings is not None, "Generation settings must be provided" image = BaseVisionModel._get_image(image_source).convert("RGB") assert image.width @@ -131,7 +140,7 @@ async def __call__( height = int(ratio * image.height) image = image.resize((width, height), Image.Resampling.LANCZOS) - messages = [] + messages: list[dict[str, Any]] = [] if system_prompt: messages.append( { @@ -156,12 +165,13 @@ async def __call__( } ) - text = self._processor.apply_chat_template( + processor: Any = self._processor + text: str = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=settings.chat_settings.add_generation_prompt, ) - inputs = self._processor( + inputs: Any = processor( text=[text], images=image, videos=None, @@ -173,14 +183,15 @@ async def __call__( inputs.to(self._device) with inference_mode(): - generated_ids = self._model.generate( + model = cast(PreTrainedModel, self._model) + generated_ids = model.generate( # type: ignore[operator] **inputs, max_new_tokens=settings.max_new_tokens ) generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - output_text = self._processor.batch_decode( + output_text: list[str] = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False, diff --git a/src/avalan/secrets/aws.py b/src/avalan/secrets/aws.py index a71d5fac..24562df3 100644 --- a/src/avalan/secrets/aws.py +++ b/src/avalan/secrets/aws.py @@ -1,22 +1,29 @@ from . 
import Secrets +from typing import Any + from boto3 import client class AwsSecrets(Secrets): + """Secrets backend using AWS Secrets Manager.""" + _SERVICE = "secretsmanager" - def __init__(self, aws_client: object | None = None): - self._client = aws_client or client(self._SERVICE) + def __init__(self, aws_client: Any | None = None) -> None: + self._client: Any = aws_client or client(self._SERVICE) def read(self, key: str) -> str | None: - response = self._client.get_secret_value(SecretId=key) + """Return secret stored under *key*.""" + response: dict[str, Any] = self._client.get_secret_value(SecretId=key) return response.get("SecretString") def write(self, key: str, secret: str) -> None: + """Store *secret* under *key*.""" self._client.put_secret_value(SecretId=key, SecretString=secret) def delete(self, key: str) -> None: + """Remove secret associated with *key*.""" self._client.delete_secret( SecretId=key, ForceDeleteWithoutRecovery=True, diff --git a/src/avalan/secrets/keyring.py b/src/avalan/secrets/keyring.py index 1e52cc44..a1f584f9 100644 --- a/src/avalan/secrets/keyring.py +++ b/src/avalan/secrets/keyring.py @@ -4,8 +4,8 @@ from keyring import get_keyring from keyring.backend import KeyringBackend except Exception: # pragma: no cover - optional dependency - get_keyring = None # type: ignore[assignment] - KeyringBackend = object # type: ignore[assignment] + get_keyring = None # type: ignore[assignment, misc] + KeyringBackend = object # type: ignore[assignment, misc] class KeyringSecrets(Secrets): @@ -14,7 +14,7 @@ class KeyringSecrets(Secrets): _SERVICE = "avalan" def __init__(self, ring: KeyringBackend | None = None) -> None: - if ring is None and get_keyring: + if ring is None and get_keyring is not None: ring = get_keyring() self._ring = ring diff --git a/src/avalan/server/__init__.py b/src/avalan/server/__init__.py index 113d4510..c46d6158 100644 --- a/src/avalan/server/__init__.py +++ b/src/avalan/server/__init__.py @@ -7,12 +7,12 @@ from .entities import OrchestratorContext from .routers import mcp as mcp_router -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager from importlib import import_module from importlib.util import find_spec from logging import Logger -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING, AsyncContextManager, Mapping, cast from uuid import UUID, uuid4 from fastapi import FastAPI, Request @@ -91,9 +91,9 @@ def _create_lifespan( selected_protocols: Mapping[str, set[str]], agent_id: UUID | None, participant_id: UUID | None, -) -> Callable[[FastAPI], AsyncIterator[None]]: +) -> Callable[[FastAPI], AsyncContextManager[None]]: @asynccontextmanager - async def lifespan(app: FastAPI): + async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: logger.info("Initializing app lifespan") from os import environ @@ -199,7 +199,7 @@ def _include_protocol_routers( def _attach_lifespan( - app: FastAPI, lifespan: Callable[[FastAPI], AsyncIterator[None]] + app: FastAPI, lifespan: Callable[[FastAPI], AsyncContextManager[None]] ) -> None: existing = app.router.lifespan_context @@ -208,7 +208,7 @@ def _attach_lifespan( return @asynccontextmanager - async def combined(app_: FastAPI): + async def combined(app_: FastAPI) -> AsyncGenerator[None, None]: async with existing(app_): async with lifespan(app_): yield @@ -406,4 +406,4 @@ async def di_get_orchestrator(request: Request) -> Orchestrator: request.app.state.agent_id = 
orchestrator.id orchestrator = request.app.state.orchestrator assert orchestrator is not None - return orchestrator + return cast(Orchestrator, orchestrator) diff --git a/src/avalan/server/a2a/router.py b/src/avalan/server/a2a/router.py index 777f4ef8..f2e6a6f6 100644 --- a/src/avalan/server/a2a/router.py +++ b/src/avalan/server/a2a/router.py @@ -24,7 +24,7 @@ from logging import Logger from re import compile from time import time -from typing import TYPE_CHECKING, Any, Final, Iterable +from typing import TYPE_CHECKING, Any, Final, Iterable, cast from uuid import uuid4 if TYPE_CHECKING: @@ -643,7 +643,7 @@ async def _ensure_answer_artifact(self) -> list[dict[str, Any]]: async def _handle_tool_process(self, event: Event) -> list[dict[str, Any]]: events: list[dict[str, Any]] = [] - payload = event.payload or [] + payload: dict[str, Any] | list[Any] = event.payload or [] if isinstance(payload, dict): calls: Iterable[ToolCall] = payload.get("calls", []) # type: ignore[assignment] else: @@ -908,11 +908,9 @@ async def _response_id(self) -> str | int | None: overview = await self._store.get_task_overview(self._task_id) metadata = overview.get("metadata") or {} self._cached_response_id = metadata.get("jsonrpc_id") - return ( - None - if self._cached_response_id is _STREAM_RESPONSE_ID_UNSET - else self._cached_response_id - ) + if self._cached_response_id is _STREAM_RESPONSE_ID_UNSET: + return None + return cast(str | int | None, self._cached_response_id) async def _task_result(self, event: dict[str, Any]) -> a2a_types.Task: overview = await self._store.get_task_overview(self._task_id) @@ -1315,7 +1313,7 @@ def _call_identifier(item: Token | TokenDetail | Event | str) -> str | None: return str(item.call.id) if isinstance(item, Event): if item.type is EventType.TOOL_PROCESS: - payload = item.payload or [] + payload: dict[str, Any] | list[Any] = item.payload or [] if isinstance(payload, dict): candidates = payload.get("calls", []) else: diff --git a/src/avalan/server/routers/__init__.py b/src/avalan/server/routers/__init__.py index 7688a4ce..4e55c145 100644 --- a/src/avalan/server/routers/__init__.py +++ b/src/avalan/server/routers/__init__.py @@ -1,4 +1,7 @@ from ...agent.orchestrator import Orchestrator +from ...agent.orchestrator.response.orchestrator_response import ( + OrchestratorResponse, +) from ...entities import ( GenerationSettings, Message, @@ -15,7 +18,7 @@ from logging import Logger from time import time -from uuid import uuid4 +from uuid import UUID, uuid4 from fastapi import HTTPException @@ -24,7 +27,7 @@ async def orchestrate( request: ChatCompletionRequest, logger: Logger, orchestrator: Orchestrator, -): +) -> tuple[OrchestratorResponse, UUID, int]: messages = [ Message(role=req.role, content=to_message_content(req.content)) for req in request.messages @@ -40,7 +43,7 @@ async def orchestrate( timestamp = int(time()) settings = GenerationSettings( - use_async_generator=request.stream, + use_async_generator=request.stream or False, temperature=request.temperature, max_new_tokens=request.max_tokens, stop_strings=request.stop, @@ -59,13 +62,10 @@ async def orchestrate( return response, response_id, timestamp -def to_message_content(item): - if isinstance(item, list): - return [ - to_message_content(i) - for i in item - if isinstance(i, (ContentImage, ContentText, str)) - ] +def _convert_single_content( + item: str | ContentText | ContentImage, +) -> MessageContentText | MessageContentImage: + """Convert a single content item to message content type.""" if isinstance(item, 
ContentImage): return MessageContentImage(type=item.type, image_url=item.image_url) if isinstance(item, ContentText): @@ -73,3 +73,22 @@ def to_message_content(item): if isinstance(item, str): return MessageContentText(type="text", text=item) raise TypeError(f"Unsupported content type: {type(item).__name__}") + + +def to_message_content( + item: str | list[ContentText | ContentImage], +) -> ( + MessageContentText + | MessageContentImage + | list[MessageContentText | MessageContentImage] +): + """Convert request content to message content types.""" + if isinstance(item, list): + return [ + _convert_single_content(i) + for i in item + if isinstance(i, (ContentImage, ContentText, str)) + ] + if not isinstance(item, (str, ContentText, ContentImage)): + raise TypeError(f"Unsupported content type: {type(item).__name__}") + return _convert_single_content(item) diff --git a/src/avalan/server/routers/chat.py b/src/avalan/server/routers/chat.py index 5610f727..1285782e 100644 --- a/src/avalan/server/routers/chat.py +++ b/src/avalan/server/routers/chat.py @@ -1,5 +1,11 @@ from ...agent.orchestrator import Orchestrator -from ...entities import MessageRole, ReasoningToken, ToolCallToken +from ...entities import ( + MessageRole, + ReasoningToken, + Token, + TokenDetail, + ToolCallToken, +) from ...event import Event from ...server.entities import ( ChatCompletionChoice, @@ -31,7 +37,7 @@ async def create_chat_completion( request: ChatCompletionRequest, logger: Logger = Depends(di_get_logger), orchestrator: Orchestrator = Depends(di_get_orchestrator), -): +) -> ChatCompletionResponse | StreamingResponse: assert orchestrator and isinstance(orchestrator, Orchestrator) assert logger and isinstance(logger, Logger) assert request and request.messages @@ -60,11 +66,16 @@ async def generate_chunks(): async for token in response: if isinstance(token, Event): continue - elif isinstance(token, (ReasoningToken, ToolCallToken)): - token = token.token + token_text: str + if isinstance(token, (ReasoningToken, ToolCallToken)): + token_text = token.token + elif isinstance(token, (Token, TokenDetail)): + token_text = token.token + else: + token_text = str(token) choice = ChatCompletionChunkChoice( - delta=ChatCompletionChunkChoiceDelta(content=token) + delta=ChatCompletionChunkChoiceDelta(content=token_text) ) chunk = ChatCompletionChunk( id=str(response_id), @@ -92,13 +103,13 @@ async def generate_chunks(): choices = [ ChatCompletionChoice( index=i, - message=ChatMessage(role=str(MessageRole.ASSISTANT), content=text), + message=ChatMessage(role=MessageRole.ASSISTANT, content=text), finish_reason="stop", ) for i in range(request.n or 1) ] usage = ChatCompletionUsage() - response = ChatCompletionResponse( + chat_response = ChatCompletionResponse( id=str(response_id), created=timestamp, model=request.model, @@ -106,9 +117,9 @@ async def generate_chunks(): usage=usage, ) logger.debug( - "Generated chat completion response #%s %r", response_id, response + "Generated chat completion response #%s %r", response_id, chat_response ) await orchestrator.sync_messages() - return response + return chat_response diff --git a/src/avalan/server/routers/mcp.py b/src/avalan/server/routers/mcp.py index a4d053a7..d64fd464 100644 --- a/src/avalan/server/routers/mcp.py +++ b/src/avalan/server/routers/mcp.py @@ -1,5 +1,6 @@ from ...agent.orchestrator import Orchestrator from ...entities import ( + MessageRole, ReasoningToken, Token, TokenDetail, @@ -345,7 +346,7 @@ async def _expect_jsonrpc_message( return message, messages -def 
_server_info(request: Request) -> dict[str, str]: +def _server_info(request: Request) -> dict[str, JSONValue]: app = request.app name = getattr(app, "title", None) or "avalan" version = getattr(app, "version", None) @@ -397,7 +398,11 @@ def _build_chat_request( model_id = _default_model_id(orchestrator) return ChatCompletionRequest( model=model_id, - messages=[ChatMessage(role="user", content=tool_request.input_string)], + messages=[ + ChatMessage( + role=MessageRole.USER, content=tool_request.input_string + ) + ], stream=True, ) @@ -409,7 +414,7 @@ async def _start_tool_streaming_response( request_id: str | int, tool_request: MCPToolRequest, progress_token: str, -) -> StreamingResponse: +) -> StreamingResponse | JSONResponse: chat_request = _build_chat_request(tool_request, orchestrator) response, response_uuid, timestamp = await orchestrate( chat_request, logger, orchestrator @@ -580,7 +585,7 @@ def _handle_list_tools_message( return JSONResponse(payload) -def _collect_tool_descriptions(request: Request) -> list[dict[str, JSONValue]]: +def _collect_tool_descriptions(request: Request) -> list[JSONValue]: name = cast(str, getattr(request.app.state, "mcp_tool_name", "run")) description = cast( str, @@ -613,7 +618,7 @@ def _extract_call_arguments( raise HTTPException( status_code=400, detail="Invalid tool arguments" ) - return cast(dict[str, JSONValue], arguments) + return arguments raise HTTPException( status_code=400, detail=f'Unsupported MCP method "{method}"' @@ -655,7 +660,9 @@ async def _stream_mcp_response( resources: dict[str, MCPResource] = {} finished_normally = False - def emit(message: JSONObject) -> Iterator[bytes]: + def emit( + message: JSONObject | JSONRPCResult | JSONRPCError, + ) -> Iterator[bytes]: encoded = dumps(message, separators=(",", ":")) + "\n" yield encoded.encode("utf-8") @@ -697,18 +704,18 @@ def emit(message: JSONObject) -> Iterator[bytes]: continue if isinstance(item, ToolCallToken): - notification = _tool_call_token_notification(item) - if notification is not None: - for payload in emit(notification): + tool_notification = _tool_call_token_notification(item) + if tool_notification is not None: + for payload in emit(tool_notification): yield payload continue text = _token_text(item) if text: + answer_chunks.append(text) if isinstance(item, Token): - answer_chunks.append(text) - notification: JSONObject = { + answer_notification: JSONObject = { "jsonrpc": "2.0", "method": "notifications/message", "params": { @@ -720,8 +727,7 @@ def emit(message: JSONObject) -> Iterator[bytes]: }, } else: - answer_chunks.append(text) - notification: JSONObject = { + answer_notification = { "jsonrpc": "2.0", "method": "notifications/progress", "params": { @@ -732,7 +738,7 @@ def emit(message: JSONObject) -> Iterator[bytes]: }, }, } - for payload in emit(notification): + for payload in emit(answer_notification): yield payload finished_normally = not cancel_event.is_set() @@ -757,15 +763,15 @@ def emit(message: JSONObject) -> Iterator[bytes]: await _close_response_iterator(response) for resource in resources.values(): closed = await resource_store.close(resource.id) - notification = _resource_notification(closed) - for payload in emit(notification): + resource_notification = _resource_notification(closed) + for payload in emit(resource_notification): yield payload - error_message: JSONRPCError = { + cancel_error: JSONRPCError = { "jsonrpc": "2.0", "id": request_id, "error": {"code": -32000, "message": "Request cancelled"}, } - for payload in emit(error_message): + for payload in 
emit(cancel_error): yield payload await orchestrator.sync_messages() return @@ -830,11 +836,13 @@ def _token_text(item: ResponseItem) -> str: async def _close_response_iterator(response: StreamResponse) -> None: iterator = getattr(response, "_response_iterator", None) - if iterator and hasattr(iterator, "aclose"): - try: - await cast(AsyncIterator[object], iterator).aclose() - except Exception: # pragma: no cover - best effort cleanup - pass + if iterator is not None: + aclose = getattr(iterator, "aclose", None) + if aclose is not None: + try: + await aclose() + except Exception: # pragma: no cover - best effort cleanup + pass def _tool_call_token_notification( @@ -858,7 +866,7 @@ def _tool_call_token_notification( delta: dict[str, JSONValue] = { "id": str(token.call.id), "name": token.call.name, - "arguments": token.call.arguments, + "arguments": cast(JSONValue, token.call.arguments), } return { "jsonrpc": "2.0", @@ -886,7 +894,7 @@ async def _tool_event_notifications( if item is None: return - tool_call_id = item["id"] + tool_call_id = cast(str, item["id"]) if event.type is EventType.TOOL_PROCESS: tool_summaries[tool_call_id] = { @@ -910,7 +918,7 @@ async def _tool_event_notifications( } return - tool_summary = tool_summaries.setdefault( + tool_summary: dict[str, JSONValue] = tool_summaries.setdefault( tool_call_id, { "id": tool_call_id, @@ -938,7 +946,8 @@ async def _tool_event_notifications( }, } - message = cast(dict[str, JSONValue], payload["params"]["message"]) + params_dict = cast(dict[str, JSONValue], payload["params"]) + message = cast(dict[str, JSONValue], params_dict["message"]) if "error" in item: message["error"] = item["error"] @@ -946,8 +955,9 @@ async def _tool_event_notifications( elif "result" in item: message["resultDelta"] = item["result"] tool_summary["result"] = item["result"] + result_value = item["result"] for resource_key, payload2 in _extract_append_streams( - tool_call_id, item["result"] + tool_call_id, result_value ).items(): name, text = payload2 resource = resources.get(resource_key) @@ -959,7 +969,11 @@ async def _tool_event_notifications( resource = await resource_store.append(resource.id, text) resources[resource_key] = resource yield _resource_notification(resource) - tool_summary.setdefault("resources", []).append( + resources_list = cast( + list[dict[str, JSONValue]], + tool_summary.setdefault("resources", []), + ) + resources_list.append( { "uri": resource.uri, "name": name, @@ -1021,11 +1035,11 @@ def _tool_call_event_item(event: Event) -> dict[str, JSONValue] | None: return { "id": str(tool_result.call.id), "name": tool_result.name, - "arguments": tool_result.arguments, + "arguments": cast(JSONValue, tool_result.arguments), "error": tool_result.message, } if isinstance(tool_result, ToolCallResult): - result: JSONValue = ( + result_value: JSONValue = ( tool_result.result if isinstance( tool_result.result, (dict, list, str, int, float, bool) @@ -1035,8 +1049,8 @@ def _tool_call_event_item(event: Event) -> dict[str, JSONValue] | None: return { "id": str(tool_result.call.id), "name": tool_result.name, - "arguments": tool_result.arguments, - "result": result, + "arguments": cast(JSONValue, tool_result.arguments), + "result": result_value, } if isinstance(event.payload, list) and event.payload: call = event.payload[0] @@ -1051,7 +1065,7 @@ def _tool_call_event_item(event: Event) -> dict[str, JSONValue] | None: return { "id": str(call.id), "name": call.name, - "arguments": call.arguments, + "arguments": cast(JSONValue, call.arguments), } diff --git 
a/src/avalan/server/routers/responses.py b/src/avalan/server/routers/responses.py index f14fba2a..606f57a9 100644 --- a/src/avalan/server/routers/responses.py +++ b/src/avalan/server/routers/responses.py @@ -8,7 +8,7 @@ ToolCallToken, ) from ...event import Event, EventType -from ...server.entities import ResponsesRequest +from ...server.entities import ChatCompletionRequest, ResponsesRequest from ...utils import to_json from .. import di_get_logger, di_get_orchestrator from ..sse import sse_headers, sse_message @@ -16,6 +16,7 @@ from enum import Enum, auto from logging import Logger +from typing import Any from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse @@ -30,18 +31,31 @@ class ResponseState(Enum): router = APIRouter(tags=["responses"]) -@router.post("/responses") +@router.post("/responses", response_model=None) async def create_response( request: ResponsesRequest, logger: Logger = Depends(di_get_logger), orchestrator: Orchestrator = Depends(di_get_orchestrator), -): +) -> StreamingResponse | dict[str, Any]: + """Create a response using the OpenAI Responses API format.""" assert orchestrator and isinstance(orchestrator, Orchestrator) assert logger and isinstance(logger, Logger) assert request and request.messages + chat_request = ChatCompletionRequest( + model=request.model, + messages=request.messages, + temperature=request.temperature, + top_p=request.top_p, + n=request.n, + stream=request.stream, + stop=request.stop, + max_tokens=request.max_tokens, + response_format=request.response_format, + ) + response, response_id, timestamp = await orchestrate( - request, logger, orchestrator + chat_request, logger, orchestrator ) if request.stream: @@ -69,19 +83,20 @@ async def generate(): tool_call_id: str | None = None async for token in response: - is_event = isinstance(token, Event) - if is_event and token.type not in ( - EventType.TOOL_PROCESS, - EventType.TOOL_RESULT, - ): - continue + if isinstance(token, Event): + if token.type not in ( + EventType.TOOL_PROCESS, + EventType.TOOL_RESULT, + ): + continue call_id: str | None = None - if is_event and token.type in ( + if isinstance(token, Event) and token.type in ( EventType.TOOL_PROCESS, EventType.TOOL_RESULT, ): - call_id = _tool_call_event_item(token)["id"] + event_item = _tool_call_event_item(token) + call_id = str(event_item["id"]) elif ( isinstance(token, ToolCallToken) and token.call is not None ): @@ -259,8 +274,6 @@ def _switch_state( current_tool_call_id: str | None, new_tool_call_id: str | None, ) -> list[str]: - new_state: ResponseState | None - events: list[str] = [] changed = state is not new_state or ( state is ResponseState.TOOL_CALLING @@ -295,7 +308,15 @@ def _switch_state( def _new_state( - token: ReasoningToken | ToolCallToken | Token | TokenDetail | str | None, + token: ( + ReasoningToken + | ToolCallToken + | Token + | TokenDetail + | Event + | str + | None + ), ) -> ResponseState | None: if isinstance(token, ReasoningToken): new_state = ResponseState.REASONING @@ -407,16 +428,36 @@ def _content_part_done(id: str | None = None) -> str: return sse_message(to_json(data), event="response.content_part.done") -def _tool_call_event_item(event: Event) -> dict: - tool_result = ( - event.payload["result"] - if event.type == EventType.TOOL_RESULT and "result" in event.payload - else None - ) - tool_call = ( - tool_result.call if tool_result is not None else event.payload[0] - ) - item = { +def _tool_call_event_item(event: Event) -> dict[str, Any]: + """Extract tool call item from an event 
payload.""" + payload = event.payload + assert payload is not None, "Event payload must be provided" + + tool_result: Any = None + if ( + event.type == EventType.TOOL_RESULT + and isinstance(payload, dict) + and "result" in payload + ): + tool_result = payload["result"] + + tool_call: Any + if tool_result is not None and hasattr(tool_result, "call"): + tool_call = tool_result.call + elif isinstance(payload, list) and payload: + tool_call = payload[0] + elif isinstance(payload, dict): + calls = payload.get("calls") + if isinstance(calls, list) and calls: + tool_call = calls[0] + else: + call = payload.get("call") + assert call is not None, "Event must have calls or call field" + tool_call = call + else: + raise ValueError("Invalid event payload type") + + item: dict[str, Any] = { "type": "function_call", "id": str(tool_call.id), "name": tool_call.name, diff --git a/src/avalan/tool/__init__.py b/src/avalan/tool/__init__.py index d4a0f2d3..832374c0 100644 --- a/src/avalan/tool/__init__.py +++ b/src/avalan/tool/__init__.py @@ -1,11 +1,14 @@ -from __future__ import annotations - from abc import ABC from collections.abc import Callable, Sequence -from contextlib import AsyncExitStack, ContextDecorator +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + AsyncExitStack, + ContextDecorator, +) from inspect import Signature, isfunction, signature -from types import FunctionType -from typing import get_type_hints +from types import FunctionType, TracebackType +from typing import Any, Union, cast, get_type_hints from transformers.utils import get_json_schema @@ -44,19 +47,20 @@ def _get_signature( return_annotation=function_signature.return_annotation, ) - async def __aenter__(self) -> "ToolSet": + async def __aenter__(self) -> "Tool": return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - traceback: BaseException | None, + traceback: TracebackType | None, ) -> bool: if self._exit_stack: - return await self._exit_stack.__aexit__( + result = await self._exit_stack.__aexit__( exc_type, exc_value, traceback ) + return result if result is not None else False return True @@ -65,14 +69,14 @@ class ToolSet(ContextDecorator): _namespace: str | None _exit_stack: AsyncExitStack - _tools: Sequence[Callable] + _tools: list[Callable[..., Any]] @property def namespace(self) -> str | None: return self._namespace @property - def tools(self) -> Sequence[Callable]: + def tools(self) -> list[Callable[..., Any]]: return self._tools def __init__( @@ -80,15 +84,15 @@ def __init__( *, exit_stack: AsyncExitStack | None = None, namespace: str | None = None, - tools: Sequence[Callable | "ToolSet"], + tools: Sequence[Union[Callable[..., Any], "ToolSet"]], ): self._namespace = namespace self._exit_stack = exit_stack or AsyncExitStack() - self._tools = tools + self._tools = list(tools) exclude_type_names = ["self", "context"] - for i, tool in enumerate(self.tools): + for i, tool in enumerate(self._tools): if ( not isfunction(tool) and callable(tool) @@ -102,12 +106,16 @@ def __init__( if type_name not in exclude_type_names } tool.__annotations__ = type_hints - tool.__signature__ = Tool._get_signature( - tool.__call__, exclude_type_names + setattr( + tool, + "__signature__", + Tool._get_signature( + cast(FunctionType, tool.__call__), exclude_type_names + ), ) if not tool.__doc__ and tool.__call__.__doc__: tool.__doc__ = tool.__call__.__doc__ - self.tools[i] = tool + self._tools[i] = tool def with_enabled_tools(self, enable_tools: list[str]) 
-> "ToolSet": prefix = f"{self.namespace}." if self.namespace else "" @@ -128,18 +136,25 @@ def with_enabled_tools(self, enable_tools: list[str]) -> "ToolSet": async def __aenter__(self) -> "ToolSet": for tool in self.tools: if hasattr(tool, "__aenter__"): - await self._exit_stack.enter_async_context(tool) + await self._exit_stack.enter_async_context( + cast(AbstractAsyncContextManager[Any, bool | None], tool) + ) elif hasattr(tool, "__enter__"): - self._exit_stack.enter_context(tool) + self._exit_stack.enter_context( + cast(AbstractContextManager[Any, bool | None], tool) + ) return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - traceback: BaseException | None, + traceback: TracebackType | None, ) -> bool: - return await self._exit_stack.__aexit__(exc_type, exc_value, traceback) + result = await self._exit_stack.__aexit__( + exc_type, exc_value, traceback + ) + return result if result is not None else False def json_schemas(self, prefix: str | None = None) -> list[dict] | None: schemas = [] diff --git a/src/avalan/tool/browser.py b/src/avalan/tool/browser.py index 2f20555c..55dd2f81 100644 --- a/src/avalan/tool/browser.py +++ b/src/avalan/tool/browser.py @@ -7,7 +7,8 @@ from dataclasses import dataclass from email.message import EmailMessage from io import BytesIO, TextIOBase -from typing import TYPE_CHECKING, Literal, final +from types import TracebackType +from typing import TYPE_CHECKING, Any, Literal, final if TYPE_CHECKING: from faiss import IndexFlatL2 @@ -117,7 +118,9 @@ def __init__( self._partitioner = partitioner self.__name__ = "open" - async def __call__(self, url: str, *, context: ToolCallContext) -> str: + async def __call__( # type: ignore[override] + self, url: str, *, context: ToolCallContext + ) -> str: content = await self._read(url) if ( @@ -185,9 +188,9 @@ async def __call__(self, url: str, *, context: ToolCallContext) -> str: else knowledge_chunk ) - content = knowledge_match + content = str(knowledge_match) if knowledge_match else "" - return content + return str(content) async def _read(self, url: str) -> str: if ( @@ -201,14 +204,15 @@ async def _read(self, url: str) -> str: return content if not self._browser: + client: Any = self._client browser_type = ( - self._client.chromium + client.chromium if self._settings.engine == "chromium" else ( - self._client.firefox + client.firefox if self._settings.engine == "firefox" else ( - self._client.webkit + client.webkit if self._settings.engine == "webkit" else None ) @@ -254,6 +258,7 @@ async def _read(self, url: str) -> str: ) response = await self._page.goto(url) + assert response is not None, f"Failed to load {url}" contents: str = await self._page.content() content_type_header = response.headers.get("content-type", None) assert content_type_header @@ -262,12 +267,18 @@ async def _read(self, url: str) -> str: m["content-type"] = content_type_header maintype = m.get_content_maintype() or "text" assert maintype == "text" - encoding = (m.get_param("charset") or "utf-8").lower() + charset_param = m.get_param("charset") + encoding = ( + str(charset_param).lower() + if charset_param and not isinstance(charset_param, tuple) + else "utf-8" + ) mime_type = m.get_content_type() byte_stream = BytesIO(contents.encode(encoding)) - result = self._md.convert_stream(byte_stream, mime_type=mime_type) - content = result.text_content - return content + assert self._md is not None, "MarkItDown must be initialized" + md_result = self._md.convert_stream(byte_stream, 
mime_type=mime_type) + page_content: str = md_result.text_content or "" + return page_content def with_client(self, client: "PlaywrightContextManager") -> "BrowserTool": self._client = client @@ -278,7 +289,7 @@ async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - traceback: BaseException | None, + traceback: TracebackType | None, ) -> bool: if self._page: await self._page.close() @@ -313,13 +324,20 @@ def __init__( self._client = async_playwright() tools = [BrowserTool(settings, self._client, partitioner=partitioner)] - return super().__init__( + super().__init__( exit_stack=exit_stack, namespace=namespace, tools=tools ) @override async def __aenter__(self) -> "BrowserToolSet": - self._client = await self._exit_stack.enter_async_context(self._client) + assert ( + self._client is not None + ), "Playwright client must be initialized" + playwright_instance: Any = await self._exit_stack.enter_async_context( + self._client + ) for i, tool in enumerate(self._tools): - self._tools[i] = tool.with_client(self._client) - return await super().__aenter__() + if hasattr(tool, "with_client"): + self._tools[i] = tool.with_client(playwright_instance) + await super().__aenter__() + return self diff --git a/src/avalan/tool/code.py b/src/avalan/tool/code.py index 4c6f9dbb..bfc633ab 100644 --- a/src/avalan/tool/code.py +++ b/src/avalan/tool/code.py @@ -5,6 +5,7 @@ from asyncio import create_subprocess_exec from asyncio.subprocess import PIPE from contextlib import AsyncExitStack +from typing import Any try: from RestrictedPython import ( @@ -37,10 +38,10 @@ def __init__(self) -> None: super().__init__() self.__name__ = "run" - async def __call__( - self, code: str, *args: any, context: ToolCallContext, **kwargs: any + async def __call__( # type: ignore[override] + self, code: str, *args: Any, context: ToolCallContext, **kwargs: Any ) -> str: - locals_dict = {} + locals_dict: dict[str, Any] = {} byte_code = compile_restricted( code, filename="", @@ -90,7 +91,7 @@ def __init__(self) -> None: super().__init__() self.__name__ = "search.ast.grep" - async def __call__( + async def __call__( # type: ignore[override] self, pattern: str, *, diff --git a/src/avalan/tool/database/__init__.py b/src/avalan/tool/database/__init__.py index 673773fb..2f94bcd0 100644 --- a/src/avalan/tool/database/__init__.py +++ b/src/avalan/tool/database/__init__.py @@ -6,6 +6,7 @@ from asyncio import sleep from dataclasses import dataclass from re import compile as regex_compile +from types import TracebackType from typing import Any, Literal, final try: @@ -22,22 +23,22 @@ _SQLALCHEMY_AVAILABLE = True except ImportError: _SQLALCHEMY_AVAILABLE = False - MetaData = None # type: ignore[assignment] - event = None # type: ignore[assignment] - func = None # type: ignore[assignment] - select = None # type: ignore[assignment] - text = None # type: ignore[assignment] - SATable = None # type: ignore[assignment] - sqlalchemy_inspect = None # type: ignore[assignment] - Connection = None # type: ignore[assignment] - Inspector = None # type: ignore[assignment] - NoSuchTableError = None # type: ignore[assignment] - SQLAlchemyError = None # type: ignore[assignment] - AsyncEngine = None # type: ignore[assignment] - create_async_engine = None # type: ignore[assignment] - Select = None # type: ignore[assignment] - ColumnElement = None # type: ignore[assignment] - TextClause = None # type: ignore[assignment] + MetaData = None # type: ignore[misc, assignment] + event = None # type: ignore[misc, assignment] + 
func = None # type: ignore[misc, assignment] + select = None # type: ignore[misc, assignment] + text = None # type: ignore[misc, assignment] + SATable = None # type: ignore[misc, assignment] + sqlalchemy_inspect = None # type: ignore[misc, assignment] + Connection = None # type: ignore[misc, assignment] + Inspector = None # type: ignore[misc, assignment] + NoSuchTableError = None # type: ignore[misc, assignment] + SQLAlchemyError = None # type: ignore[misc, assignment] + AsyncEngine = None # type: ignore[misc, assignment] + create_async_engine = None # type: ignore[misc, assignment] + Select = None # type: ignore[misc, assignment] + ColumnElement = None # type: ignore[misc, assignment] + TextClause = None # type: ignore[misc, assignment] try: from sqlglot import exp, parse, parse_one @@ -45,9 +46,9 @@ _SQLGLOT_AVAILABLE = True except ImportError: _SQLGLOT_AVAILABLE = False - exp = None # type: ignore[assignment] - parse = None # type: ignore[assignment] - parse_one = None # type: ignore[assignment] + exp = None # type: ignore[misc, assignment] + parse = None # type: ignore[misc, assignment] + parse_one = None # type: ignore[misc, assignment] @final @@ -257,6 +258,8 @@ def _apply_identifier_case(self, connection: Connection, sql: str) -> str: if self._normalizer is None: return sql + normalizer = self._normalizer + self._ensure_sqlglot_available() inspector = inspect(connection) _, schemas = self._schemas(connection, inspector) @@ -296,7 +299,7 @@ def normalize_table(node: exp.Expression) -> exp.Expression: if isinstance(schema_ident, exp.Identifier) else None ) - key = self._normalizer.normalize(name) + key = normalizer.normalize(name) lookup = f"{schema}.{key}" if schema else key actual = replacements.get(lookup) or replacements.get(key) if actual: @@ -387,7 +390,7 @@ def _schemas( if dialect == "postgresql": sys = {"information_schema", "pg_catalog"} - schemas = [ + schemas: list[str | None] = [ s for s in inspector.get_schema_names() if s not in sys and not (s or "").startswith("pg_") @@ -396,9 +399,14 @@ def _schemas( schemas.append(default_schema) return default_schema, schemas - all_schemas = inspector.get_schema_names() or ( - [default_schema] if default_schema is not None else [None] - ) + raw_schemas = inspector.get_schema_names() + all_schemas: list[str | None] + if raw_schemas: + all_schemas = list(raw_schemas) + elif default_schema is not None: + all_schemas = [default_schema] + else: + all_schemas = [None] sys_filters = { "mysql": { @@ -421,9 +429,10 @@ def _schemas( schemas = [s for s in all_schemas if s not in sys] if not schemas: - schemas = ( - [default_schema] if default_schema is not None else [None] - ) + if default_schema is not None: + schemas = [default_schema] + else: + schemas = [None] seen: set[str | None] = set() uniq: list[str | None] = [] @@ -442,7 +451,7 @@ async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - traceback: BaseException | None, + traceback: TracebackType | None, ) -> bool: return await super().__aexit__(exc_type, exc_value, traceback) diff --git a/src/avalan/tool/database/count.py b/src/avalan/tool/database/count.py index 5e4783e1..47e61e52 100644 --- a/src/avalan/tool/database/count.py +++ b/src/avalan/tool/database/count.py @@ -44,7 +44,7 @@ def _split_schema_and_table(qualified: str) -> tuple[str | None, str]: return (sch or None), tbl return None, qualified - async def __call__( + async def __call__( # type: ignore[override] self, table_name: str, *, context: ToolCallContext ) -> int: assert 
table_name, "table_name must not be empty" diff --git a/src/avalan/tool/database/inspect.py b/src/avalan/tool/database/inspect.py index 9fac3d4f..2f707bce 100644 --- a/src/avalan/tool/database/inspect.py +++ b/src/avalan/tool/database/inspect.py @@ -38,7 +38,7 @@ def __init__( ) self.__name__ = "inspect" - async def __call__( + async def __call__( # type: ignore[override] self, table_names: list[str], schema: str | None = None, diff --git a/src/avalan/tool/database/keys.py b/src/avalan/tool/database/keys.py index 2b471f7b..21001cf1 100644 --- a/src/avalan/tool/database/keys.py +++ b/src/avalan/tool/database/keys.py @@ -8,6 +8,8 @@ TableKey, ) +from typing import Any, cast + class DatabaseKeysTool(DatabaseTool): """List primary and unique keys defined on a table. @@ -37,7 +39,7 @@ def __init__( ) self.__name__ = "keys" - async def __call__( + async def __call__( # type: ignore[override] self, table_name: str, schema: str | None = None, @@ -71,36 +73,38 @@ def _collect( keys: list[TableKey] = [] - pk = ( - inspector.get_pk_constraint(actual_table, schema=resolved_schema) - or {} - ) - pk_columns = tuple( - pk.get("constrained_columns") or pk.get("column_names") or [] + pk = cast( + dict[str, Any] | None, + inspector.get_pk_constraint(actual_table, schema=resolved_schema), ) - if pk_columns: - keys.append( - TableKey( - type="primary", - name=pk.get("name"), - columns=pk_columns, + if pk: + pk_col_raw = pk.get("constrained_columns") or [] + pk_columns: tuple[str, ...] = tuple(str(c) for c in pk_col_raw) + if pk_columns: + keys.append( + TableKey( + type="primary", + name=pk.get("name"), + columns=pk_columns, + ) ) - ) - unique_constraints = ( + unique_constraints = cast( + list[dict[str, Any]], inspector.get_unique_constraints( actual_table, schema=resolved_schema, ) - or [] + or [], ) for constraint in unique_constraints: - columns = tuple( + col_raw = ( constraint.get("column_names") or constraint.get("constrained_columns") or [] ) + columns: tuple[str, ...] 
= tuple(str(c) for c in col_raw) if not columns: continue keys.append( diff --git a/src/avalan/tool/database/kill.py b/src/avalan/tool/database/kill.py index b9637212..3179cf9c 100644 --- a/src/avalan/tool/database/kill.py +++ b/src/avalan/tool/database/kill.py @@ -35,7 +35,7 @@ def __init__( ) self.__name__ = "kill" - async def __call__( + async def __call__( # type: ignore[override] self, task_id: str, *, diff --git a/src/avalan/tool/database/locks.py b/src/avalan/tool/database/locks.py index 37690916..0313e166 100644 --- a/src/avalan/tool/database/locks.py +++ b/src/avalan/tool/database/locks.py @@ -39,7 +39,7 @@ def __init__( ) self.__name__ = "locks" - async def __call__( + async def __call__( # type: ignore[override] self, *, context: ToolCallContext, diff --git a/src/avalan/tool/database/plan.py b/src/avalan/tool/database/plan.py index acda2e87..2c5667f7 100644 --- a/src/avalan/tool/database/plan.py +++ b/src/avalan/tool/database/plan.py @@ -37,7 +37,7 @@ def __init__( ) self.__name__ = "plan" - async def __call__( + async def __call__( # type: ignore[override] self, sql: str, *, context: ToolCallContext ) -> QueryPlan: await self._sleep_if_configured() diff --git a/src/avalan/tool/database/relationships.py b/src/avalan/tool/database/relationships.py index 3dc1315e..d4800b8f 100644 --- a/src/avalan/tool/database/relationships.py +++ b/src/avalan/tool/database/relationships.py @@ -38,7 +38,7 @@ def __init__( ) self.__name__ = "relationships" - async def __call__( + async def __call__( # type: ignore[override] self, table_name: str, *, diff --git a/src/avalan/tool/database/run.py b/src/avalan/tool/database/run.py index ac1985e1..41a97b0e 100644 --- a/src/avalan/tool/database/run.py +++ b/src/avalan/tool/database/run.py @@ -35,7 +35,7 @@ def __init__( ) self.__name__ = "run" - async def __call__( + async def __call__( # type: ignore[override] self, sql: str, *, context: ToolCallContext ) -> list[dict[str, Any]]: await self._sleep_if_configured() diff --git a/src/avalan/tool/database/sample.py b/src/avalan/tool/database/sample.py index 2c2f7422..b642e9bc 100644 --- a/src/avalan/tool/database/sample.py +++ b/src/avalan/tool/database/sample.py @@ -54,7 +54,7 @@ def _split_schema_and_table(qualified: str) -> tuple[str | None, str]: return (schema or None), table return None, qualified - async def __call__( + async def __call__( # type: ignore[override] self, table_name: str, *, diff --git a/src/avalan/tool/database/size.py b/src/avalan/tool/database/size.py index 3dd92759..26e8480f 100644 --- a/src/avalan/tool/database/size.py +++ b/src/avalan/tool/database/size.py @@ -10,7 +10,7 @@ text, ) -from typing import Any +from typing import Any, Literal class DatabaseSizeTool(DatabaseTool): @@ -46,7 +46,7 @@ def _split_schema_and_table(qualified: str) -> tuple[str | None, str]: return (schema or None), table return None, qualified - async def __call__( + async def __call__( # type: ignore[override] self, table_name: str, *, context: ToolCallContext ) -> TableSize: assert table_name, "table_name must not be empty" @@ -220,10 +220,10 @@ def _collect_sqlite( table_row.get("size") if table_row else None ) - index_rows: list[dict[str, Any]] = [] + index_rows: list[Any] = [] try: pragma = table_name.replace("'", "''") - index_rows = ( + index_rows = list( connection.execute(text(f"PRAGMA index_list('{pragma}')")) .mappings() .all() @@ -420,7 +420,11 @@ def _collect_mssql( metrics.append(self._metric("total", total_bytes)) return metrics - def _metric(self, category: str, value: int | None) -> 
TableSizeMetric: + def _metric( + self, + category: Literal["data", "indexes", "total", "toast", "lob", "free"], + value: int | None, + ) -> TableSizeMetric: return TableSizeMetric( category=category, bytes=value, diff --git a/src/avalan/tool/database/tables.py b/src/avalan/tool/database/tables.py index 59c8c1e5..fd0d8936 100644 --- a/src/avalan/tool/database/tables.py +++ b/src/avalan/tool/database/tables.py @@ -34,7 +34,7 @@ def __init__( ) self.__name__ = "tables" - async def __call__( + async def __call__( # type: ignore[override] self, *, context: ToolCallContext ) -> dict[str | None, list[str]]: await self._sleep_if_configured() diff --git a/src/avalan/tool/database/tasks.py b/src/avalan/tool/database/tasks.py index 84f91204..515f6d0e 100644 --- a/src/avalan/tool/database/tasks.py +++ b/src/avalan/tool/database/tasks.py @@ -40,7 +40,7 @@ def __init__( ) self.__name__ = "tasks" - async def __call__( + async def __call__( # type: ignore[override] self, *, running_for: int | None = None, diff --git a/src/avalan/tool/database/toolset.py b/src/avalan/tool/database/toolset.py index 08156b2e..55a80325 100644 --- a/src/avalan/tool/database/toolset.py +++ b/src/avalan/tool/database/toolset.py @@ -20,6 +20,7 @@ from .settings import DatabaseToolSettings from contextlib import AsyncExitStack +from types import TracebackType class DatabaseToolSet(ToolSet): @@ -131,11 +132,11 @@ def __init__( async def __aexit__( self, exc_type: type[BaseException] | None, - exc: BaseException | None, - tb: BaseException | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> bool: try: if self._engine is not None: await self._engine.dispose() finally: - return await super().__aexit__(exc_type, exc, tb) + return await super().__aexit__(exc_type, exc_value, traceback) diff --git a/src/avalan/tool/manager.py b/src/avalan/tool/manager.py index 2ce29e62..b11a749f 100644 --- a/src/avalan/tool/manager.py +++ b/src/avalan/tool/manager.py @@ -13,6 +13,8 @@ from collections.abc import Callable, Sequence from contextlib import AsyncExitStack, ContextDecorator +from types import TracebackType +from typing import Any from uuid import uuid4 @@ -60,10 +62,13 @@ def tool_format(self) -> ToolFormat | None: """Return the tool format configured for this manager.""" return self._parser.tool_format - def json_schemas(self) -> list[dict] | None: - schemas = [] - for toolset in self._toolsets: - schemas.extend(toolset.json_schemas()) + def json_schemas(self) -> list[dict[str, Any]] | None: + schemas: list[dict[str, Any]] = [] + if self._toolsets: + for toolset in self._toolsets: + toolset_schemas = toolset.json_schemas() + if toolset_schemas: + schemas.extend(toolset_schemas) return schemas def __init__( @@ -110,24 +115,33 @@ def tool_call_status( return self._parser.tool_call_status(buffer) def get_calls(self, text: str) -> list[ToolCall] | None: - return self._parser(text) + result = self._parser(text) + if result is None: + return None + if isinstance(result, list): + return result + name, arguments = result + return [ToolCall(id=uuid4(), name=name, arguments=arguments)] async def __aenter__(self) -> "ToolManager": if self._toolsets: - for i, toolset in enumerate(self._toolsets): - toolset = await self._stack.enter_async_context(toolset) - self._toolsets[i] = toolset - return self + entered_toolsets: list[ToolSet] = [] + for toolset in self._toolsets: + entered = await self._stack.enter_async_context(toolset) + entered_toolsets.append(entered) + self._toolsets = entered_toolsets + return self # type: 
ignore[return-value] async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - traceback: BaseException | None, + traceback: TracebackType | None, ) -> bool: - return await self._stack.__aexit__(exc_type, exc_value, traceback) + result = await self._stack.__aexit__(exc_type, exc_value, traceback) + return result if result is not None else False - async def __call__( + async def __call__( # type: ignore[override] self, call: ToolCall, context: ToolCallContext ) -> ToolCallResult | ToolCallError | None: """Execute a single tool call and return the result.""" @@ -153,10 +167,10 @@ async def __call__( if self._settings.filters: for f in self._settings.filters: - namespace = None - func = f + namespace: str | None = None + func: Callable[[ToolCall, ToolCallContext], Any] = f # type: ignore[assignment] if isinstance(f, ToolFilter): - func = f.func + func = f.func # type: ignore[assignment] namespace = f.namespace if not self._matches_namespace(call.name, namespace): continue @@ -184,14 +198,14 @@ async def __call__( if self._settings.transformers: for t in self._settings.transformers: - namespace = None - func = t + t_namespace: str | None = None + t_func: Callable[..., Any] = t # type: ignore[assignment] if isinstance(t, ToolTransformer): - func = t.func - namespace = t.namespace - if not self._matches_namespace(call.name, namespace): + t_func = t.func # type: ignore[assignment] + t_namespace = t.namespace + if not self._matches_namespace(call.name, t_namespace): continue - transformed = func(call, context, result) + transformed = t_func(call, context, result) if transformed is not None: result = transformed diff --git a/src/avalan/tool/math.py b/src/avalan/tool/math.py index f48afbd2..98089d4e 100644 --- a/src/avalan/tool/math.py +++ b/src/avalan/tool/math.py @@ -21,7 +21,9 @@ def __init__(self) -> None: super().__init__() self.__name__ = "calculator" - async def __call__(self, expression: str, context: ToolCallContext) -> str: + async def __call__( # type: ignore[override] + self, expression: str, context: ToolCallContext + ) -> str: result = sympify(expression, evaluate=True) return str(result) diff --git a/src/avalan/tool/mcp.py b/src/avalan/tool/mcp.py index 569f8588..4bcb9fb7 100644 --- a/src/avalan/tool/mcp.py +++ b/src/avalan/tool/mcp.py @@ -31,7 +31,7 @@ def __init__( self._client_params = client_params or {} self._call_params = call_params or {} - async def __call__( + async def __call__( # type: ignore[override] self, uri: str, name: str, @@ -39,15 +39,16 @@ async def __call__( *, context: ToolCallContext, ) -> list[object]: - from mcp import Client + from mcp import Client # type: ignore[attr-defined] assert uri assert name async with Client(uri, **self._client_params) as client: - return await client.call_tool( + result: list[object] = await client.call_tool( name, arguments or {}, **self._call_params ) + return result class McpToolSet(ToolSet): diff --git a/src/avalan/tool/memory.py b/src/avalan/tool/memory.py index c083b3a2..9b155682 100644 --- a/src/avalan/tool/memory.py +++ b/src/avalan/tool/memory.py @@ -3,7 +3,6 @@ from ..memory.manager import MemoryManager from ..memory.permanent import ( Memory, - PermanentMemoryPartition, PermanentMemoryStore, VectorFunction, ) @@ -30,7 +29,9 @@ def __init__(self, memory_manager: MemoryManager) -> None: self._memory_manager = memory_manager self.__name__ = "message.read" - async def __call__(self, search: str, context: ToolCallContext) -> str: + async def __call__( # type: ignore[override] + self, 
search: str, context: ToolCallContext + ) -> str: if ( not context.agent_id or not context.session_id @@ -48,7 +49,10 @@ async def __call__(self, search: str, context: ToolCallContext) -> str: limit=1, ) if results and results[0].message: - return results[0].message.content + content = results[0].message.content + if isinstance(content, str): + return content + return MessageReadTool._NOT_FOUND return MessageReadTool._NOT_FOUND @@ -78,13 +82,13 @@ def __init__( self._function = function self.__name__ = "read" - async def __call__( + async def __call__( # type: ignore[override] self, namespace: str, search: str, *, context: ToolCallContext, - ) -> list[PermanentMemoryPartition]: + ) -> list[str]: """Return memory partitions that match the search query.""" if ( not namespace @@ -126,7 +130,7 @@ def __init__(self, memory_manager: MemoryManager) -> None: self._memory_manager = memory_manager self.__name__ = "list" - async def __call__( + async def __call__( # type: ignore[override] self, namespace: str, *, @@ -160,7 +164,7 @@ def __init__(self, memory_manager: MemoryManager) -> None: self._memory_manager = memory_manager self.__name__ = "stores" - async def __call__( + async def __call__( # type: ignore[override] self, *, context: ToolCallContext, diff --git a/src/avalan/tool/parser.py b/src/avalan/tool/parser.py index 1b80eece..cd2491d8 100644 --- a/src/avalan/tool/parser.py +++ b/src/avalan/tool/parser.py @@ -39,8 +39,10 @@ def tool_format(self) -> ToolFormat | None: """Return the tool format used by the parser.""" return self._tool_format - def __call__(self, text: str) -> list[ToolCall] | None: - calls = ( + def __call__( + self, text: str + ) -> list[ToolCall] | tuple[str, dict[str, Any]] | None: + calls: list[ToolCall] | tuple[str, dict[str, Any]] | None = ( self._parse_json(text) if self._tool_format is ToolFormat.JSON else ( @@ -152,7 +154,7 @@ def extract_structured_message( def message_tool_calls(self, text: str) -> list[dict[str, object]]: """Return tool calls extracted from ``text`` in message format.""" - parsed = None + parsed: list[ToolCall] | tuple[str, dict[str, Any]] | None = None if "<|call|>" in text and "<|channel|>" in text: parsed = self._parse_harmony(text) elif self._tool_format: @@ -272,7 +274,7 @@ def _merge_thinking( if existing_thinking: combined = "\n\n".join( - part for part in (existing_thinking, thinking) if part + str(part) for part in (existing_thinking, thinking) if part ) else: combined = thinking @@ -379,7 +381,7 @@ def _parse_harmony(self, text: str) -> list[ToolCall] | None: return tool_calls return None - def _parse_tag(self, text: str) -> tuple[str, dict[str, Any]] | None: + def _parse_tag(self, text: str) -> list[ToolCall] | None: tool_calls: list[ToolCall] = [] if self._eos_token: @@ -389,6 +391,8 @@ def _parse_tag(self, text: str) -> tuple[str, dict[str, Any]] | None: for element in root.findall(".//tool_call"): tool_call = None try: + if element.text is None: + continue json_text = element.text.strip() try: diff --git a/src/avalan/tool/search_engine.py b/src/avalan/tool/search_engine.py index 2ec37c1a..160be6a0 100644 --- a/src/avalan/tool/search_engine.py +++ b/src/avalan/tool/search_engine.py @@ -15,7 +15,9 @@ class SearchEngineTool(Tool): def __init__(self) -> None: self.__name__ = "search" - async def __call__(self, query: str, engine: str) -> str: + async def __call__( # type: ignore[override] + self, query: str, engine: str + ) -> str: return ( "The weather is nice and warm, with 23 degrees celsius, clear" " skies, and winds under 11 kmh." 
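
Note on the recurring pattern in the tool hunks above: each Tool subclass deliberately narrows the base __call__ signature (the calculator takes an expression, the MCP client takes a URI and tool name), so every override carries a targeted "# type: ignore[override]" rather than being widened back to untyped *args. The manager-side change gives __aexit__ the standard TracebackType third parameter and coerces a possible None result so the declared bool return holds. The following is a minimal, self-contained sketch of both ideas; Tool, CalculatorTool, and ToolManager here are illustrative stand-ins, not avalan's actual classes.

from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any


class Tool:
    """Illustrative base class: accepts anything, subclasses specialize."""

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError()


class CalculatorTool(Tool):
    # Narrowing the parameter list violates Liskov substitutability, which
    # mypy reports as an override error; the per-site ignore documents a
    # deliberate trade-off instead of hiding a real bug behind a blanket
    # ignore.
    async def __call__(  # type: ignore[override]
        self, expression: str
    ) -> str:
        return expression


class ToolManager:
    def __init__(self) -> None:
        self._stack = AsyncExitStack()

    async def __aenter__(self) -> "ToolManager":
        await self._stack.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        # The async context manager protocol allows __aexit__ to return
        # bool | None, but this method promises a plain bool, so a None
        # from the delegated stack is coerced to False.
        result = await self._stack.__aexit__(exc_type, exc_value, traceback)
        return result if result is not None else False

Keeping every override as untyped *args would also satisfy mypy, but it would discard the per-tool signatures that call sites depend on; the per-site ignore is the smaller compromise.
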
diff --git a/src/avalan/tool/youtube.py b/src/avalan/tool/youtube.py index 65629da4..484eaf04 100644 --- a/src/avalan/tool/youtube.py +++ b/src/avalan/tool/youtube.py @@ -33,7 +33,7 @@ def __init__( self._proxy = proxy self.__name__ = "transcript" - async def __call__( + async def __call__( # type: ignore[override] self, video_id: str, *, diff --git a/src/avalan/utils.py b/src/avalan/utils.py index 71597bc9..95de3658 100644 --- a/src/avalan/utils.py +++ b/src/avalan/utils.py @@ -28,8 +28,8 @@ def logger_replace(logger: Logger, logger_names: list[str]) -> None: def to_json(item: Any) -> str: - def _default(o): - if is_dataclass(o): + def _default(o: Any) -> Any: + if is_dataclass(o) and not isinstance(o, type): return asdict(o) elif isinstance(o, (Decimal, UUID)): return str(o) diff --git a/tests/agent/default_orchestrator_test.py b/tests/agent/default_orchestrator_test.py index 1be08b00..961437a4 100644 --- a/tests/agent/default_orchestrator_test.py +++ b/tests/agent/default_orchestrator_test.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 -from avalan.agent import Goal, InputType, OutputType, Specification +from avalan.agent import Goal, InputType, OutputType, Role, Specification from avalan.agent.orchestrator.orchestrators.default import DefaultOrchestrator from avalan.agent.orchestrator.response.orchestrator_response import ( OrchestratorResponse, @@ -69,7 +69,7 @@ def test_initialization(self): op = orch.operations[0] self.assertIs(op.environment.engine_uri, engine_uri) self.assertIs(op.environment.settings, settings) - self.assertEqual(op.specification.role, "assistant") + self.assertEqual(op.specification.role, Role(persona=["assistant"])) self.assertEqual(op.specification.goal.task, "do") self.assertEqual(op.specification.goal.instructions, ["something"]) self.assertEqual(op.specification.rules, ["a", "b"]) @@ -206,7 +206,7 @@ def output_fn(*args, **kwargs): self.assertEqual( context.specification, Specification( - role="assistant", + role=Role(persona=["assistant"]), goal=Goal(task="do", instructions=["something"]), rules=None, input_type=InputType.TEXT, @@ -309,7 +309,7 @@ async def test_user_string_rendering(self): context = agent_mock.await_args.args[0] self.assertIsInstance(context.input, Message) - self.assertEqual(context.input.content, b"hello world Bob") + self.assertEqual(context.input.content, "hello world Bob") async def test_user_template_rendering(self): with TemporaryDirectory() as tmp: diff --git a/tests/agent/orchestrator_response_additional_test.py b/tests/agent/orchestrator_response_additional_test.py index 0b62a1e3..e2ffeea5 100644 --- a/tests/agent/orchestrator_response_additional_test.py +++ b/tests/agent/orchestrator_response_additional_test.py @@ -139,7 +139,7 @@ async def output_gen(): second = await iterator.__anext__() self.assertEqual(second.type, EventType.TOOL_PROCESS) - self.assertEqual(second.payload, [call]) + self.assertEqual(second.payload, {"calls": [call]}) third = await iterator.__anext__() self.assertEqual(third.type, EventType.TOOL_RESULT) diff --git a/tests/agent/orchestrator_test.py b/tests/agent/orchestrator_test.py index a04e9314..430b0545 100644 --- a/tests/agent/orchestrator_test.py +++ b/tests/agent/orchestrator_test.py @@ -230,7 +230,7 @@ def _message(self, text: str) -> Message: def test_user_string_transformation(self): orchestrator, specification = self._create_orchestrator() message = orchestrator._input_messages(specification, "world") - self.assertEqual(message.content, b"hello 
world Bob") + self.assertEqual(message.content, "hello world Bob") def test_user_list_strings_transformation(self): orchestrator, specification = self._create_orchestrator() @@ -241,13 +241,13 @@ def test_user_message_transformation(self): orchestrator, specification = self._create_orchestrator() msg = self._message("earth") message = orchestrator._input_messages(specification, msg) - self.assertEqual(message.content, b"hello earth Bob") + self.assertEqual(message.content, "hello earth Bob") def test_user_list_messages_transformation(self): orchestrator, specification = self._create_orchestrator() msg = self._message("moon") messages = orchestrator._input_messages(specification, [msg]) - self.assertEqual(messages[0].content, b"hello moon Bob") + self.assertEqual(messages[0].content, "hello moon Bob") class OrchestratorUserTemplateTransformationOptionsTestCase(unittest.TestCase): @@ -348,7 +348,7 @@ def _create_orchestrator(self): def test_user_string_transformation(self): orchestrator, specification = self._create_orchestrator() message = orchestrator._input_messages(specification, "world") - self.assertEqual(message.content, b"hello world Bob") + self.assertEqual(message.content, "hello world Bob") class OrchestratorSettingsTemplateVarsUserTemplateTestCase(unittest.TestCase): diff --git a/tests/agent/renderer_test.py b/tests/agent/renderer_test.py index d0ca13af..9d1a86c0 100644 --- a/tests/agent/renderer_test.py +++ b/tests/agent/renderer_test.py @@ -59,9 +59,7 @@ def test_custom_template_path(self): def test_from_string(self): renderer = self.Renderer() tmpl = "Hi {{name}}" - self.assertEqual( - renderer.from_string(tmpl, {"name": "Ada"}), b"Hi Ada" - ) + self.assertEqual(renderer.from_string(tmpl, {"name": "Ada"}), "Hi Ada") self.assertEqual(renderer.from_string(tmpl), tmpl) def test_template_caching(self): diff --git a/tests/agent/template_engine_agent_test.py b/tests/agent/template_engine_agent_test.py index 34a5d9b4..6fe840cc 100644 --- a/tests/agent/template_engine_agent_test.py +++ b/tests/agent/template_engine_agent_test.py @@ -89,9 +89,9 @@ def test_prepare_call_no_template_vars(self): "agent.md", name="Bob", roles=["assistant"], - task=b"do", - instructions=[b"instr"], - rules=[b"rule"], + task="do", + instructions=["instr"], + rules=["rule"], ) self.assertEqual(result["settings"], spec.settings) self.assertEqual(result["system_prompt"], expected_prompt) @@ -108,10 +108,10 @@ def test_prepare_call_with_template_vars(self): expected_prompt = self.renderer( "agent.md", name="Bob", - roles=[b"role run"], - task=b"do run", - instructions=[b"inst run"], - rules=[b"rule run"], + roles=["role run"], + task="do run", + instructions=["inst run"], + rules=["rule run"], ) self.assertEqual(result["system_prompt"], expected_prompt) @@ -127,10 +127,10 @@ def test_prepare_call_with_settings_template_vars(self): expected_prompt = self.renderer( "agent.md", name="Bob", - roles=[b"role run"], - task=b"do run", - instructions=[b"inst run"], - rules=[b"rule run"], + roles=["role run"], + task="do run", + instructions=["inst run"], + rules=["rule run"], ) self.assertEqual(result["system_prompt"], expected_prompt) @@ -228,9 +228,9 @@ async def test_call_invokes_run_with_prepared_arguments(self): "agent.md", name="Bob", roles=["assistant"], - task=b"do", - instructions=[b"ins"], - rules=[b"r"], + task="do", + instructions=["ins"], + rules=["r"], ) agent._run = AsyncMock(return_value="out") diff --git a/tests/cli/agent_test.py b/tests/cli/agent_test.py index 84b47f0a..865beca0 100644 --- 
a/tests/cli/agent_test.py +++ b/tests/cli/agent_test.py @@ -72,7 +72,7 @@ def setUp(self): self.console.status.return_value = status_cm self.theme = MagicMock() self.theme._ = lambda s: s - self.theme.icons = {"user_input": ">"} + self.theme._icons = {"user_input": ">", "agent_output": "<"} self.theme.get_spinner.return_value = "sp" self.theme.agent.return_value = "agent_panel" self.theme.search_message_matches.return_value = "matches_panel" @@ -896,7 +896,7 @@ def setUp(self): self.console.status.return_value = status_cm self.theme = MagicMock() self.theme._ = lambda s: s - self.theme.icons = {"user_input": ">", "agent_output": "<"} + self.theme._icons = {"user_input": ">", "agent_output": "<"} self.theme.get_spinner.return_value = "sp" self.theme.agent.return_value = "agent_panel" self.theme.recent_messages.return_value = "recent_panel" @@ -2467,7 +2467,7 @@ def setUp(self): self.console.status.return_value = status_cm self.theme = MagicMock() self.theme._ = lambda s: s - self.theme.icons = {"user_input": ">", "agent_output": "<"} + self.theme._icons = {"user_input": ">", "agent_output": "<"} self.theme.get_spinner.return_value = "sp" self.theme.agent.return_value = "agent_panel" self.theme.recent_messages.return_value = "recent_panel" diff --git a/tests/cli/cache_test.py b/tests/cli/cache_test.py index 548a7bed..8d8f09b8 100644 --- a/tests/cli/cache_test.py +++ b/tests/cli/cache_test.py @@ -40,7 +40,7 @@ def test_execute_deletion_confirmed(self): ) self.hub.cache_delete.assert_called_once_with("m", "r") self.assertEqual( - self.theme.cache_delete.call_args_list[0].args, (cache_del,) + self.theme.cache_delete.call_args_list[0].args, (cache_del, False) ) self.assertEqual( self.theme.cache_delete.call_args_list[1].args, (cache_del, True) @@ -65,7 +65,9 @@ def test_deletion_cancelled(self): execute.assert_not_called() confirm.assert_called_once() self.assertEqual(len(self.theme.cache_delete.call_args_list), 1) - self.assertEqual(self.theme.cache_delete.call_args.args, (cache_del,)) + self.assertEqual( + self.theme.cache_delete.call_args.args, (cache_del, False) + ) self.assertEqual(len(self.console.print.call_args_list), 1) diff --git a/tests/cli/input_test.py b/tests/cli/input_test.py index ba81c030..bdbb48b6 100644 --- a/tests/cli/input_test.py +++ b/tests/cli/input_test.py @@ -43,7 +43,7 @@ def test_prompt_when_no_stdin(self): ): result = get_input(self.console, "prompt") - ask.assert_called_once_with("prompt ") + ask.assert_called_once_with("prompt ", stream=None) self.assertEqual(result, "value") self.assertEqual( self.console.print.call_args_list, [call(""), call("")] @@ -123,9 +123,10 @@ def test_confirm_tool_call(self): "Execute tool call?", choices=["y", "a", "n"], default="n", + console=console, show_choices=True, show_default=True, - console=console, + stream=None, ) self.assertEqual(result, "y") @@ -206,9 +207,10 @@ def ask_side_effect(*args, **kwargs): "Execute tool call?", choices=["y", "a", "n"], default="n", + console=console, show_choices=True, show_default=True, - console=console, + stream=None, ) self.assertEqual(live.auto_refresh, True) self.assertEqual(result, "y") diff --git a/tests/cli/model_search_additional_test.py b/tests/cli/model_search_additional_test.py index b12fc3c3..0e1094f8 100644 --- a/tests/cli/model_search_additional_test.py +++ b/tests/cli/model_search_additional_test.py @@ -73,7 +73,7 @@ async def test_model_search_updates_access(self): live.__exit__.return_value = False async def fake_to_thread(fn, *a, **kw): - return fn() + return fn(*a, **kw) with ( 
patch.object(model_cmds, "Live", return_value=live), diff --git a/tests/cli/model_test.py b/tests/cli/model_test.py index ba21918d..77ff4f37 100644 --- a/tests/cli/model_test.py +++ b/tests/cli/model_test.py @@ -824,7 +824,7 @@ async def inner_gen(): events = [ model_cmds.Event( type=model_cmds.EventType.TOOL_EXECUTE, - payload=[call], + payload={"call": call}, ), token_a, model_cmds.Event( @@ -4242,7 +4242,7 @@ def fake_group(*items): return items async def to_thread_stub(fn, *a, **kw): - return fn() + return fn(*a, **kw) with ( patch.object(model_cmds, "Live", return_value=live), @@ -4317,7 +4317,9 @@ async def test_model_search_open(self): patch.object(model_cmds, "Live", return_value=live), patch.object(model_cmds, "Group", side_effect=lambda *i: i), patch.object( - model_cmds, "to_thread", side_effect=lambda f, *a, **k: f() + model_cmds, + "to_thread", + side_effect=lambda f, *a, **k: f(*a, **k), ), ): await model_cmds.model_search(args, console, theme, hub, 5) diff --git a/tests/cli/theme_test.py b/tests/cli/theme_test.py index 90443e43..4157c5fb 100644 --- a/tests/cli/theme_test.py +++ b/tests/cli/theme_test.py @@ -319,7 +319,7 @@ def test_base_methods_raise(self): self.theme, "n", "d", "a", "id", "lib", True, False ), lambda: Theme.agent( - self.theme, SimpleNamespace(), models=[], cans_access=None + self.theme, SimpleNamespace(), models=[], can_access=None ), lambda: Theme.events(self.theme, []), lambda: Theme.ask_access_token(self.theme), @@ -356,7 +356,7 @@ def test_base_methods_raise(self): lambda: Theme.recent_messages( self.theme, "id", SimpleNamespace(), [] ), - lambda: Theme.saved_tokenizer_files("/d", 0), + lambda: Theme.saved_tokenizer_files(self.theme, "/d", 0), lambda: Theme.search_message_matches( self.theme, "id", SimpleNamespace(), [] ), @@ -378,7 +378,7 @@ def test_base_methods_raise(self): call() async def run_tokens(): - await Theme.tokens( + async for _ in Theme.tokens( self.theme, model_id="m", added_tokens=None, @@ -396,13 +396,15 @@ async def run_tokens(): tool_events=None, tool_event_calls=None, tool_event_results=None, + tool_running_spinner=None, ttft=0.0, ttnt=0.0, ttsr=0.0, elapsed=0.0, console_width=80, logger=SimpleNamespace(), - ) + ): + pass with self.assertRaises(NotImplementedError): import asyncio diff --git a/tests/cli/tokenizer_test.py b/tests/cli/tokenizer_test.py index 5df733a4..23c40dd0 100644 --- a/tests/cli/tokenizer_test.py +++ b/tests/cli/tokenizer_test.py @@ -1,7 +1,6 @@ import importlib import sys from argparse import Namespace -from dataclasses import dataclass from types import ModuleType, SimpleNamespace from unittest import IsolatedAsyncioTestCase from unittest.mock import MagicMock, call, patch @@ -17,39 +16,15 @@ def setUpClass(cls): cls._saved_modules = { name: sys.modules.get(name) for name in [ - "avalan.entities", "avalan.model.hubs.huggingface", "avalan.model.nlp.text.generation", ] } - # Stub avalan.entities - entities = ModuleType("avalan.entities") - - @dataclass(frozen=True, kw_only=True) - class Token: - id: int - token: str - probability: float | None = None - - class TransformerEngineSettings: - def __init__( - self, - device=None, - cache_dir=None, - **kwargs, - ) -> None: - self.device = device - self.cache_dir = cache_dir - self.distributed_config = None - for key, value in kwargs.items(): - setattr(self, key, value) - - backend = "transformers" + # Import real entities for TransformerEngineSettings + from avalan.entities import TransformerEngineSettings - entities.Token = Token - entities.TransformerEngineSettings = 
TransformerEngineSettings - sys.modules["avalan.entities"] = entities + cls.TransformerEngineSettings = TransformerEngineSettings # Stub avalan.model.hubs.huggingface hubs = ModuleType("avalan.model.hubs.huggingface") @@ -93,7 +68,6 @@ def save_tokenizer(self, path): cls.tokenizer_mod = importlib.import_module( "avalan.cli.commands.tokenizer" ) - cls.TransformerEngineSettings = TransformerEngineSettings @classmethod def tearDownClass(cls): @@ -124,7 +98,7 @@ def setUp(self): self.theme = MagicMock() self.theme._ = lambda s: s self.theme._n = lambda s, p, n: s if n == 1 else p - self.theme.icons = {"user_input": ">"} + self.theme._icons = {"user_input": ">"} self.theme.tokenizer_config.return_value = "cfg_panel" self.theme.saved_tokenizer_files.return_value = "save_panel" self.theme.tokenizer_tokens.return_value = "tokens_panel" @@ -199,7 +173,7 @@ async def test_tokenize_input(self): dummy_model.tokenize.assert_called_once_with("hello") get_input.assert_called_once_with( self.console, - self.theme.icons["user_input"] + " ", + self.theme._icons["user_input"] + " ", echo_stdin=not args.no_repl, is_quiet=args.quiet, tty_path="/dev/tty", diff --git a/tests/compat_test.py b/tests/compat_test.py index 02c030e5..b31a21ff 100644 --- a/tests/compat_test.py +++ b/tests/compat_test.py @@ -1,6 +1,5 @@ import importlib import sys -import typing from unittest import TestCase @@ -10,26 +9,14 @@ def _reload_module(self): del sys.modules["avalan.compat"] return importlib.import_module("avalan.compat") - def test_override_fallback_when_missing(self): - if hasattr(typing, "override"): - delattr(typing, "override") + def test_override_from_typing_extensions(self): + """Test that override is imported from typing_extensions.""" compat = self._reload_module() - self.assertEqual(compat.override.__module__, "avalan.compat") + self.assertEqual(compat.override.__module__, "typing_extensions") def func(): return 1 - self.assertIs(compat.override(func), func) - - def test_override_uses_typing_override_when_available(self): - def sentinel(func): - func.sentinel = True - return func - - typing.override = sentinel - try: - compat = self._reload_module() - self.assertIs(compat.override, sentinel) - finally: - delattr(typing, "override") - self._reload_module() + # The override decorator should return the function unchanged + decorated = compat.override(func) + self.assertEqual(decorated(), 1) diff --git a/tests/model/audio/audio_classification_test.py b/tests/model/audio/audio_classification_test.py index 2e11bcc5..5c8c0382 100644 --- a/tests/model/audio/audio_classification_test.py +++ b/tests/model/audio/audio_classification_test.py @@ -92,6 +92,10 @@ def item(self): "avalan.model.audio.classification.inference_mode", return_value=nullcontext(), ) as inf_mock, + patch( + "avalan.model.audio.classification.isinstance", + return_value=True, + ), ): extractor_instance = MagicMock() extractor_call = MagicMock(return_value="inputs") diff --git a/tests/model/audio/base_audio_model_test.py b/tests/model/audio/base_audio_model_test.py index 7e837bc5..cade371e 100644 --- a/tests/model/audio/base_audio_model_test.py +++ b/tests/model/audio/base_audio_model_test.py @@ -1,5 +1,4 @@ from logging import Logger -from typing import Literal from unittest import IsolatedAsyncioTestCase from unittest.mock import MagicMock, patch @@ -12,21 +11,17 @@ class DummyAudioModel(BaseAudioModel): def _load_model(self): return MagicMock() - async def __call__( - self, image_source: str | object, tensor_format: Literal["pt"] = "pt" - ) -> str: - 
return await super().__call__(image_source, tensor_format) + async def __call__(self, path: str) -> str: + return path class BaseAudioModelTestCase(IsolatedAsyncioTestCase): - async def test_base_methods_raise(self): + async def test_tokenizer_methods_raise(self): model = DummyAudioModel( None, EngineSettings(auto_load_model=False), logger=MagicMock(spec=Logger), ) - with self.assertRaises(NotImplementedError): - await model("img") with self.assertRaises(TokenizerNotSupportedException): model._load_tokenizer(None) with self.assertRaises(TokenizerNotSupportedException): diff --git a/tests/model/engine_additional_test.py b/tests/model/engine_additional_test.py index 9e9a6f24..9f8bf036 100644 --- a/tests/model/engine_additional_test.py +++ b/tests/model/engine_additional_test.py @@ -6,6 +6,7 @@ class DummyEngine(Engine): + @property def uses_tokenizer(self) -> bool: return False diff --git a/tests/model/model_manager_operation_test.py b/tests/model/model_manager_operation_test.py index 23bab4e9..889237f6 100644 --- a/tests/model/model_manager_operation_test.py +++ b/tests/model/model_manager_operation_test.py @@ -61,12 +61,12 @@ def test_all_modalities(self): Modality.AUDIO_TEXT_TO_SPEECH: (self._check_audio, True), Modality.AUDIO_GENERATION: (self._check_audio, True), Modality.EMBEDDING: ( - lambda op: self.assertIsNone(op.parameters), + lambda op: self.assertEqual(op.parameters, {}), False, ), Modality.TEXT_QUESTION_ANSWERING: (self._check_text, True), Modality.TEXT_SEQUENCE_CLASSIFICATION: ( - lambda op: self.assertIsNone(op.parameters), + lambda op: self.assertEqual(op.parameters, {}), True, ), Modality.TEXT_SEQUENCE_TO_SEQUENCE: (self._check_text, True), @@ -95,7 +95,7 @@ def test_unknown_modality(self): FakeModality.UNKNOWN, self.args, None ) self.assertEqual(op.modality, FakeModality.UNKNOWN) - self.assertIsNone(op.parameters) + self.assertEqual(op.parameters, {}) self.assertFalse(op.requires_input) def test_text_to_video_with_steps(self): diff --git a/tests/model/nlp/text_test.py b/tests/model/nlp/text_test.py index e346cef3..0aea14cf 100644 --- a/tests/model/nlp/text_test.py +++ b/tests/model/nlp/text_test.py @@ -149,12 +149,11 @@ def test_tokenize(self): lambda i, skip_special_tokens=False: f"t{i}" ) self.model._loaded_tokenizer = False - self.model.load = MagicMock() + self.model._load = MagicMock() result = self.model.tokenize("hi") - self.model.load.assert_called_once_with( - load_model=False, + self.model._load.assert_called_once_with( load_tokenizer=True, tokenizer_name_or_path=None, ) diff --git a/tests/model/response_parsers_additional_test.py b/tests/model/response_parsers_additional_test.py index 4a8b1c75..c612060e 100644 --- a/tests/model/response_parsers_additional_test.py +++ b/tests/model/response_parsers_additional_test.py @@ -168,7 +168,9 @@ def status_side_effect(text: str): ) event = next(item for item in output if isinstance(item, Event)) self.assertEqual(event.type, EventType.TOOL_PROCESS) - self.assertEqual(event.payload, [SimpleNamespace(name="call")]) + self.assertEqual( + event.payload, {"calls": [SimpleNamespace(name="call")]} + ) trigger_event = event_manager.trigger.await_args_list[0].args[0] self.assertEqual(trigger_event.type, EventType.TOOL_DETECT) diff --git a/tests/model/text_generation_modality_vendors_test.py b/tests/model/text_generation_modality_vendors_test.py index b54dec99..f4ff1640 100644 --- a/tests/model/text_generation_modality_vendors_test.py +++ b/tests/model/text_generation_modality_vendors_test.py @@ -9,6 +9,9 @@ from avalan.entities 
import EngineUri, TransformerEngineSettings from avalan.model.modalities.text import TextGenerationModality +# Vendors that don't accept exit_stack parameter +VENDORS_WITHOUT_EXIT_STACK: set[str] = set() + @pytest.mark.parametrize( "vendor,class_name", @@ -50,10 +53,17 @@ def test_load_engine_per_vendor(vendor: str, class_name: str) -> None: result = TextGenerationModality().load_engine( engine_uri, settings, logger, exit_stack ) - loader.assert_called_once_with( - model_id="model", - settings=settings, - logger=logger, - exit_stack=exit_stack, - ) + if vendor in VENDORS_WITHOUT_EXIT_STACK: + loader.assert_called_once_with( + model_id="model", + settings=settings, + logger=logger, + ) + else: + loader.assert_called_once_with( + model_id="model", + settings=settings, + logger=logger, + exit_stack=exit_stack, + ) assert result is loader.return_value diff --git a/tests/model/text_modalities_full_test.py b/tests/model/text_modalities_full_test.py index 6c2ab639..3f4c5852 100644 --- a/tests/model/text_modalities_full_test.py +++ b/tests/model/text_modalities_full_test.py @@ -254,7 +254,7 @@ def test_text_generation_get_operation_from_arguments() -> None: GenerationSettings(), ) text_params = operation.parameters["text"] - assert text_params.manual_sampling == 5 + assert text_params.manual_sampling is True assert text_params.pick_tokens == 10 assert text_params.skip_special_tokens is True assert text_params.stop_on_keywords == ["STOP"] @@ -447,7 +447,7 @@ def test_text_sequence_classification_get_operation_from_arguments() -> None: GenerationSettings(), ) ) - assert operation.parameters is None + assert operation.parameters == {} assert operation.input == "text" diff --git a/tests/model/vendor_tool_call_token_test.py b/tests/model/vendor_tool_call_token_test.py index fa4f97fb..6b8021cc 100644 --- a/tests/model/vendor_tool_call_token_test.py +++ b/tests/model/vendor_tool_call_token_test.py @@ -1,4 +1,5 @@ from unittest import TestCase +from uuid import UUID from avalan.entities import ToolCall, ToolCallToken from avalan.model.vendor import TextGenerationVendor @@ -26,11 +27,16 @@ def test_build_tool_call_token_handles_invalid_json_str(self) -> None: tool_name="tool", arguments='{"a": }', ) - expected = ToolCallToken( - token='{"name": "tool", "arguments": {}}', - call=ToolCall(id=None, name="tool", arguments={}), + # When call_id is None, a UUID is generated + self.assertEqual( + token.token, + '{"name": "tool", "arguments": {}}', ) - self.assertEqual(token, expected) + assert token.call is not None + self.assertEqual(token.call.name, "tool") + self.assertEqual(token.call.arguments, {}) + # Verify a valid UUID was generated + UUID(token.call.id) # This will raise if not a valid UUID def test_build_tool_call_token_from_dict(self) -> None: token = TextGenerationVendor.build_tool_call_token( diff --git a/tests/utils_test.py b/tests/utils_test.py index 407f6aa7..b1d2728a 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -135,5 +135,5 @@ def reset(self, *_, **__): super_close.assert_called_once() bar.reset(total=2) - bar._progress.reset.assert_called_once_with(total=2) + bar._progress.reset.assert_called_once_with(task_id=1, total=2) super_reset.assert_called_once_with(total=2) From 4bccc059dcd0056077be8cd1a4f129613a98797f Mon Sep 17 00:00:00 2001 From: Mariano Iglesias Date: Sat, 10 Jan 2026 10:38:45 -0300 Subject: [PATCH 2/2] Fix compat override test (#887) --- tests/compat_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/compat_test.py 
b/tests/compat_test.py index b31a21ff..13a0bd3d 100644 --- a/tests/compat_test.py +++ b/tests/compat_test.py @@ -12,7 +12,9 @@ def _reload_module(self): def test_override_from_typing_extensions(self): """Test that override is imported from typing_extensions.""" compat = self._reload_module() - self.assertEqual(compat.override.__module__, "typing_extensions") + self.assertIn( + compat.override.__module__, {"typing", "typing_extensions"} + ) def func(): return 1
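
For context on this final fix: typing.override only exists on Python 3.12 and later, and recent typing_extensions releases re-export it there, so the decorator's __module__ resolves to "typing" on new interpreters and to "typing_extensions" on older ones. That is exactly why the assertion is relaxed to accept either value. A short illustrative sketch follows; Base and Child are hypothetical examples, not code from this repository.

from typing_extensions import override


class Base:
    def greet(self) -> str:
        return "base"


class Child(Base):
    # With @override, renaming Base.greet without updating this method
    # becomes a type-checker error instead of a silent shadowing bug.
    @override
    def greet(self) -> str:
        return "child"


# The decorator returns the same function object at runtime (it only tags
# it with __override__), and its source module depends on the interpreter.
assert Child().greet() == "child"
assert override.__module__ in {"typing", "typing_extensions"}
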