diff --git a/Makefile b/Makefile
index c91cb305..7357106a 100644
--- a/Makefile
+++ b/Makefile
@@ -13,6 +13,7 @@ lint:
poetry run ruff format --preview src/ tests/
poetry run black --preview --enable-unstable-feature=string_processing src/ tests/
poetry run ruff check --fix src/ tests/
+ poetry run mypy src/avalan
test:
poetry sync --extras test
diff --git a/poetry.lock b/poetry.lock
index fbe53c14..cdaaea54 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
[[package]]
name = "a2a-sdk"
@@ -589,6 +589,30 @@ files = [
{file = "blake3-1.0.7-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:d9046bb1e22a8607e1d0d7c3ff47e56e0a197c988502df4bf4d78563f3e9fe2c"},
{file = "blake3-1.0.7-cp313-cp313t-win32.whl", hash = "sha256:bd2f638bcc00fc09ce985ea3c642d45940e1eda198ab1f4b90cfdecbebbc9315"},
{file = "blake3-1.0.7-cp313-cp313t-win_amd64.whl", hash = "sha256:cb3aa1db14231c2ef0ec5acd805505ce128c39ffa510deb3384eed96fe4addcb"},
+ {file = "blake3-1.0.7-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:f7db997205aa420d59fb5639346e40beafb9c09252e2ec6efedca8f230f7520c"},
+ {file = "blake3-1.0.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:19afec6e276f3bc154541248d92b1ecb198af2ee920025f7ce521028f9a69d8b"},
+ {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:006a11bbba65a95e88ddc069cca751c8812fd144d582715eeea512452fdbe80d"},
+ {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7febeffdc8412fed105ca517cee641ac521fb9cfb750bf7e27a5cdf3ddf74a08"},
+ {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c032ce7c52b71015651c0abe9fe599aa2669e6be578aa17d5f993dc93373401"},
+ {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b81455f7d24b58fe26be037cc3854c28ea6eb3671ceab3b1ec0b1239aeb6fef"},
+ {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41b0127b0e7c8610054c421959dbe7140a81ac2c88fa9e099994fbaa529af3c1"},
+ {file = "blake3-1.0.7-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4755ca95b4114b629d8f3570bc661916d211d52d47f57ff70e9687377ab39cb9"},
+ {file = "blake3-1.0.7-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:8abe929cfd27b375e02e3dd7a690192fa4efecc52ef510df91ef01651ef08dc7"},
+ {file = "blake3-1.0.7-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:dd607eb5ad5a9b44ff62243759aa0af4085f6f43c9b01f503561a70da63e3b94"},
+ {file = "blake3-1.0.7-cp314-cp314-win32.whl", hash = "sha256:a51684d1f346e7680f7c244c25b0e279e3b297f1938126e4ea8e32425ea269f5"},
+ {file = "blake3-1.0.7-cp314-cp314-win_amd64.whl", hash = "sha256:a6a481719e28e2c61aafd4273d32663365d97613341b72fcdf2f6afbd426319b"},
+ {file = "blake3-1.0.7-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:daa8933cd7db19143bd6b59f7ac4c7c7446767d7b2c3a748a4559aa483275fa2"},
+ {file = "blake3-1.0.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:24074adfffffe0fa7a7dd930cc608d6e965e70306e2c1e14d412e29ec94fa360"},
+ {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dce6e6f03de2674f9860cf330d8a4fcdb63a60659435e5e31d72d174fc102d8e"},
+ {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e783f33d53a2de8d2ab845235dd53393d521b5e4a76c23d03e77e472266359d3"},
+ {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:782784aef18eb61f4ce8bf2b9506b7d90f0d183176b453345b221837a18041b7"},
+ {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6062122e77f40e3733cac2ef3f25e0fc7f555e352fe6f513f8404ad11dc69974"},
+ {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c2614bc9d69fd6067571f3bb37b3b07a6b86a56167553ad4784a3c508771f39"},
+ {file = "blake3-1.0.7-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6df2bd56c43bdeb6699d4af0a0dd0d77537d95cb4a5dde4b39ed6e54cc725d6"},
+ {file = "blake3-1.0.7-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:8b635cf4350caf459ecb335b32be622068423245bda457d5bc159106eb20f912"},
+ {file = "blake3-1.0.7-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:f96a685775f87ddf75ff495dc9698703268c66c170caca977347427ef8d52324"},
+ {file = "blake3-1.0.7-cp314-cp314t-win32.whl", hash = "sha256:0633b7d9bad87dc7fce545042353f2e056604d993f71d1dce666a9f5edc13e05"},
+ {file = "blake3-1.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:5e356daa0089968dc1ff1d0d112e7cc1700533441d8f30ae99f835a94dc8b0f3"},
{file = "blake3-1.0.7-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:ac2816aa95f27ae5c1888e04468d1091acca5f00c302e8884b600dd344bc80ac"},
{file = "blake3-1.0.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c67b69a513d2a73773bc0a0399a5b4d1602dd157b4960ce8bce9f4fd6327323"},
{file = "blake3-1.0.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:999f4817e651e63811a0ecdafbdbb60d2097f7f46aff7dd600fe3f1780e48c8e"},
@@ -1505,6 +1529,7 @@ files = [
{file = "faiss_cpu-1.11.0.post1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dc12b3f89cf48be3f2a20b37f310c3f1a7a5708fdf705f88d639339a24bb590b"},
{file = "faiss_cpu-1.11.0.post1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:773fa45aa98a210ab4e2c17c1b5fb45f6d7e9acb4979c9a0b320b678984428ac"},
{file = "faiss_cpu-1.11.0.post1-cp39-cp39-win_amd64.whl", hash = "sha256:6240c4b1551eedc07e76813c2e14a1583a1db6c319a92a3934bf212d0e4c7791"},
+ {file = "faiss_cpu-1.11.0.post1.tar.gz", hash = "sha256:06b1ea9ddec9e4d9a41c8ef7478d493b08d770e9a89475056e963081eed757d1"},
]
[package.dependencies]
@@ -2687,10 +2712,9 @@ files = [
name = "joblib"
version = "1.5.2"
description = "Lightweight pipelining with Python functions"
-optional = true
+optional = false
python-versions = ">=3.9"
groups = ["main"]
-markers = "extra == \"memory\" or extra == \"all\""
files = [
{file = "joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241"},
{file = "joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55"},
@@ -2785,6 +2809,93 @@ interegular = ["interegular (>=0.3.1,<0.4.0)"]
nearley = ["js2py"]
regex = ["regex"]
+[[package]]
+name = "librt"
+version = "0.7.7"
+description = "Mypyc runtime library"
+optional = false
+python-versions = ">=3.9"
+groups = ["dev"]
+markers = "platform_python_implementation != \"PyPy\""
+files = [
+ {file = "librt-0.7.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4836c5645f40fbdc275e5670819bde5ab5f2e882290d304e3c6ddab1576a6d0"},
+ {file = "librt-0.7.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae8aec43117a645a31e5f60e9e3a0797492e747823b9bda6972d521b436b4e8"},
+ {file = "librt-0.7.7-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:aea05f701ccd2a76b34f0daf47ca5068176ff553510b614770c90d76ac88df06"},
+ {file = "librt-0.7.7-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7b16ccaeff0ed4355dfb76fe1ea7a5d6d03b5ad27f295f77ee0557bc20a72495"},
+ {file = "librt-0.7.7-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c48c7e150c095d5e3cea7452347ba26094be905d6099d24f9319a8b475fcd3e0"},
+ {file = "librt-0.7.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4dcee2f921a8632636d1c37f1bbdb8841d15666d119aa61e5399c5268e7ce02e"},
+ {file = "librt-0.7.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:14ef0f4ac3728ffd85bfc58e2f2f48fb4ef4fa871876f13a73a7381d10a9f77c"},
+ {file = "librt-0.7.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e4ab69fa37f8090f2d971a5d2bc606c7401170dbdae083c393d6cbf439cb45b8"},
+ {file = "librt-0.7.7-cp310-cp310-win32.whl", hash = "sha256:4bf3cc46d553693382d2abf5f5bd493d71bb0f50a7c0beab18aa13a5545c8900"},
+ {file = "librt-0.7.7-cp310-cp310-win_amd64.whl", hash = "sha256:f0c8fe5aeadd8a0e5b0598f8a6ee3533135ca50fd3f20f130f9d72baf5c6ac58"},
+ {file = "librt-0.7.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a487b71fbf8a9edb72a8c7a456dda0184642d99cd007bc819c0b7ab93676a8ee"},
+ {file = "librt-0.7.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4d4efb218264ecf0f8516196c9e2d1a0679d9fb3bb15df1155a35220062eba8"},
+ {file = "librt-0.7.7-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b8bb331aad734b059c4b450cd0a225652f16889e286b2345af5e2c3c625c3d85"},
+ {file = "librt-0.7.7-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:467dbd7443bda08338fc8ad701ed38cef48194017554f4c798b0a237904b3f99"},
+ {file = "librt-0.7.7-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50d1d1ee813d2d1a3baf2873634ba506b263032418d16287c92ec1cc9c1a00cb"},
+ {file = "librt-0.7.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c7e5070cf3ec92d98f57574da0224f8c73faf1ddd6d8afa0b8c9f6e86997bc74"},
+ {file = "librt-0.7.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bdb9f3d865b2dafe7f9ad7f30ef563c80d0ddd2fdc8cc9b8e4f242f475e34d75"},
+ {file = "librt-0.7.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8185c8497d45164e256376f9da5aed2bb26ff636c798c9dabe313b90e9f25b28"},
+ {file = "librt-0.7.7-cp311-cp311-win32.whl", hash = "sha256:44d63ce643f34a903f09ff7ca355aae019a3730c7afd6a3c037d569beeb5d151"},
+ {file = "librt-0.7.7-cp311-cp311-win_amd64.whl", hash = "sha256:7d13cc340b3b82134f8038a2bfe7137093693dcad8ba5773da18f95ad6b77a8a"},
+ {file = "librt-0.7.7-cp311-cp311-win_arm64.whl", hash = "sha256:983de36b5a83fe9222f4f7dcd071f9b1ac6f3f17c0af0238dadfb8229588f890"},
+ {file = "librt-0.7.7-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2a85a1fc4ed11ea0eb0a632459ce004a2d14afc085a50ae3463cd3dfe1ce43fc"},
+ {file = "librt-0.7.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c87654e29a35938baead1c4559858f346f4a2a7588574a14d784f300ffba0efd"},
+ {file = "librt-0.7.7-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c9faaebb1c6212c20afd8043cd6ed9de0a47d77f91a6b5b48f4e46ed470703fe"},
+ {file = "librt-0.7.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1908c3e5a5ef86b23391448b47759298f87f997c3bd153a770828f58c2bb4630"},
+ {file = "librt-0.7.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dbc4900e95a98fc0729523be9d93a8fedebb026f32ed9ffc08acd82e3e181503"},
+ {file = "librt-0.7.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a7ea4e1fbd253e5c68ea0fe63d08577f9d288a73f17d82f652ebc61fa48d878d"},
+ {file = "librt-0.7.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ef7699b7a5a244b1119f85c5bbc13f152cd38240cbb2baa19b769433bae98e50"},
+ {file = "librt-0.7.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:955c62571de0b181d9e9e0a0303c8bc90d47670a5eff54cf71bf5da61d1899cf"},
+ {file = "librt-0.7.7-cp312-cp312-win32.whl", hash = "sha256:1bcd79be209313b270b0e1a51c67ae1af28adad0e0c7e84c3ad4b5cb57aaa75b"},
+ {file = "librt-0.7.7-cp312-cp312-win_amd64.whl", hash = "sha256:4353ee891a1834567e0302d4bd5e60f531912179578c36f3d0430f8c5e16b456"},
+ {file = "librt-0.7.7-cp312-cp312-win_arm64.whl", hash = "sha256:a76f1d679beccccdf8c1958e732a1dfcd6e749f8821ee59d7bec009ac308c029"},
+ {file = "librt-0.7.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f4a0b0a3c86ba9193a8e23bb18f100d647bf192390ae195d84dfa0a10fb6244"},
+ {file = "librt-0.7.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5335890fea9f9e6c4fdf8683061b9ccdcbe47c6dc03ab8e9b68c10acf78be78d"},
+ {file = "librt-0.7.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9b4346b1225be26def3ccc6c965751c74868f0578cbcba293c8ae9168483d811"},
+ {file = "librt-0.7.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a10b8eebdaca6e9fdbaf88b5aefc0e324b763a5f40b1266532590d5afb268a4c"},
+ {file = "librt-0.7.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:067be973d90d9e319e6eb4ee2a9b9307f0ecd648b8a9002fa237289a4a07a9e7"},
+ {file = "librt-0.7.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:23d2299ed007812cccc1ecef018db7d922733382561230de1f3954db28433977"},
+ {file = "librt-0.7.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6b6f8ea465524aa4c7420c7cc4ca7d46fe00981de8debc67b1cc2e9957bb5b9d"},
+ {file = "librt-0.7.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8df32a99cc46eb0ee90afd9ada113ae2cafe7e8d673686cf03ec53e49635439"},
+ {file = "librt-0.7.7-cp313-cp313-win32.whl", hash = "sha256:86f86b3b785487c7760247bcdac0b11aa8bf13245a13ed05206286135877564b"},
+ {file = "librt-0.7.7-cp313-cp313-win_amd64.whl", hash = "sha256:4862cb2c702b1f905c0503b72d9d4daf65a7fdf5a9e84560e563471e57a56949"},
+ {file = "librt-0.7.7-cp313-cp313-win_arm64.whl", hash = "sha256:0996c83b1cb43c00e8c87835a284f9057bc647abd42b5871e5f941d30010c832"},
+ {file = "librt-0.7.7-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:23daa1ab0512bafdd677eb1bfc9611d8ffbe2e328895671e64cb34166bc1b8c8"},
+ {file = "librt-0.7.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:558a9e5a6f3cc1e20b3168fb1dc802d0d8fa40731f6e9932dcc52bbcfbd37111"},
+ {file = "librt-0.7.7-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:2567cb48dc03e5b246927ab35cbb343376e24501260a9b5e30b8e255dca0d1d2"},
+ {file = "librt-0.7.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6066c638cdf85ff92fc6f932d2d73c93a0e03492cdfa8778e6d58c489a3d7259"},
+ {file = "librt-0.7.7-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a609849aca463074c17de9cda173c276eb8fee9e441053529e7b9e249dc8b8ee"},
+ {file = "librt-0.7.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:add4e0a000858fe9bb39ed55f31085506a5c38363e6eb4a1e5943a10c2bfc3d1"},
+ {file = "librt-0.7.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a3bfe73a32bd0bdb9a87d586b05a23c0a1729205d79df66dee65bb2e40d671ba"},
+ {file = "librt-0.7.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:0ecce0544d3db91a40f8b57ae26928c02130a997b540f908cefd4d279d6c5848"},
+ {file = "librt-0.7.7-cp314-cp314-win32.whl", hash = "sha256:8f7a74cf3a80f0c3b0ec75b0c650b2f0a894a2cec57ef75f6f72c1e82cdac61d"},
+ {file = "librt-0.7.7-cp314-cp314-win_amd64.whl", hash = "sha256:3d1fe2e8df3268dd6734dba33ededae72ad5c3a859b9577bc00b715759c5aaab"},
+ {file = "librt-0.7.7-cp314-cp314-win_arm64.whl", hash = "sha256:2987cf827011907d3dfd109f1be0d61e173d68b1270107bb0e89f2fca7f2ed6b"},
+ {file = "librt-0.7.7-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8e92c8de62b40bfce91d5e12c6e8b15434da268979b1af1a6589463549d491e6"},
+ {file = "librt-0.7.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f683dcd49e2494a7535e30f779aa1ad6e3732a019d80abe1309ea91ccd3230e3"},
+ {file = "librt-0.7.7-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9b15e5d17812d4d629ff576699954f74e2cc24a02a4fc401882dd94f81daba45"},
+ {file = "librt-0.7.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c084841b879c4d9b9fa34e5d5263994f21aea7fd9c6add29194dbb41a6210536"},
+ {file = "librt-0.7.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c8fb9966f84737115513fecbaf257f9553d067a7dd45a69c2c7e5339e6a8dc"},
+ {file = "librt-0.7.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:9b5fb1ecb2c35362eab2dbd354fd1efa5a8440d3e73a68be11921042a0edc0ff"},
+ {file = "librt-0.7.7-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:d1454899909d63cc9199a89fcc4f81bdd9004aef577d4ffc022e600c412d57f3"},
+ {file = "librt-0.7.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7ef28f2e7a016b29792fe0a2dd04dec75725b32a1264e390c366103f834a9c3a"},
+ {file = "librt-0.7.7-cp314-cp314t-win32.whl", hash = "sha256:5e419e0db70991b6ba037b70c1d5bbe92b20ddf82f31ad01d77a347ed9781398"},
+ {file = "librt-0.7.7-cp314-cp314t-win_amd64.whl", hash = "sha256:d6b7d93657332c817b8d674ef6bf1ab7796b4f7ce05e420fd45bd258a72ac804"},
+ {file = "librt-0.7.7-cp314-cp314t-win_arm64.whl", hash = "sha256:142c2cd91794b79fd0ce113bd658993b7ede0fe93057668c2f98a45ca00b7e91"},
+ {file = "librt-0.7.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c8ffe3431d98cc043a14e88b21288b5ec7ee12cb01260e94385887f285ef9389"},
+ {file = "librt-0.7.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e40d20ae1722d6b8ea6acf4597e789604649dcd9c295eb7361a28225bc2e9e12"},
+ {file = "librt-0.7.7-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f2cb63c49bc96847c3bb8dca350970e4dcd19936f391cfdfd057dcb37c4fa97e"},
+ {file = "librt-0.7.7-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8f2f8dcf5ab9f80fb970c6fd780b398efb2f50c1962485eb8d3ab07788595a48"},
+ {file = "librt-0.7.7-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a1f5cc41a570269d1be7a676655875e3a53de4992a9fa38efb7983e97cf73d7c"},
+ {file = "librt-0.7.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ff1fb2dfef035549565a4124998fadcb7a3d4957131ddf004a56edeb029626b3"},
+ {file = "librt-0.7.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ab2a2a9cd7d044e1a11ca64a86ad3361d318176924bbe5152fbc69f99be20b8c"},
+ {file = "librt-0.7.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ad3fc2d859a709baf9dd9607bb72f599b1cfb8a39eafd41307d0c3c4766763cb"},
+ {file = "librt-0.7.7-cp39-cp39-win32.whl", hash = "sha256:f83c971eb9d2358b6a18da51dc0ae00556ac7c73104dde16e9e14c15aaf685ca"},
+ {file = "librt-0.7.7-cp39-cp39-win_amd64.whl", hash = "sha256:264720fc288c86039c091a4ad63419a5d7cabbf1c1c9933336a957ed2483e570"},
+ {file = "librt-0.7.7.tar.gz", hash = "sha256:81d957b069fed1890953c3b9c3895c7689960f233eea9a1d9607f71ce7f00b2c"},
+]
+
[[package]]
name = "litellm"
version = "1.75.0"
@@ -3304,7 +3415,7 @@ description = "A framework for machine learning on Apple silicon."
optional = true
python-versions = ">=3.9"
groups = ["main"]
-markers = "(platform_machine == \"arm64\" or extra == \"test\" or extra == \"mlx\" or extra == \"apple\") and platform_system == \"Darwin\" and (extra == \"vllm\" or extra == \"nvidia\" or extra == \"test\" or extra == \"mlx\" or extra == \"apple\")"
+markers = "(extra == \"test\" or extra == \"mlx\" or extra == \"apple\" or platform_machine == \"arm64\") and platform_system == \"Darwin\" and (extra == \"vllm\" or extra == \"nvidia\" or extra == \"test\" or extra == \"mlx\" or extra == \"apple\")"
files = [
{file = "mlx_metal-0.28.0-py3-none-macosx_13_0_arm64.whl", hash = "sha256:ce08d40f1fad4f0b3bc87bfff5d603c7fe7dd141c082ba9ce9328b41e8f8d46b"},
{file = "mlx_metal-0.28.0-py3-none-macosx_14_0_arm64.whl", hash = "sha256:424142ab843e2ac0b14edb58cf88d96723823c565291f46ddeeaa072abcc991e"},
@@ -3640,6 +3751,67 @@ files = [
{file = "multidict-6.7.0.tar.gz", hash = "sha256:c6e99d9a65ca282e578dfea819cfa9c0a62b2499d8677392e09feaf305e9e6f5"},
]
+[[package]]
+name = "mypy"
+version = "1.19.1"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.9"
+groups = ["dev"]
+files = [
+ {file = "mypy-1.19.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f05aa3d375b385734388e844bc01733bd33c644ab48e9684faa54e5389775ec"},
+ {file = "mypy-1.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:022ea7279374af1a5d78dfcab853fe6a536eebfda4b59deab53cd21f6cd9f00b"},
+ {file = "mypy-1.19.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee4c11e460685c3e0c64a4c5de82ae143622410950d6be863303a1c4ba0e36d6"},
+ {file = "mypy-1.19.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de759aafbae8763283b2ee5869c7255391fbc4de3ff171f8f030b5ec48381b74"},
+ {file = "mypy-1.19.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ab43590f9cd5108f41aacf9fca31841142c786827a74ab7cc8a2eacb634e09a1"},
+ {file = "mypy-1.19.1-cp310-cp310-win_amd64.whl", hash = "sha256:2899753e2f61e571b3971747e302d5f420c3fd09650e1951e99f823bc3089dac"},
+ {file = "mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288"},
+ {file = "mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab"},
+ {file = "mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6"},
+ {file = "mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331"},
+ {file = "mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925"},
+ {file = "mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042"},
+ {file = "mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1"},
+ {file = "mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e"},
+ {file = "mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2"},
+ {file = "mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8"},
+ {file = "mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a"},
+ {file = "mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13"},
+ {file = "mypy-1.19.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e3157c7594ff2ef1634ee058aafc56a82db665c9438fd41b390f3bde1ab12250"},
+ {file = "mypy-1.19.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdb12f69bcc02700c2b47e070238f42cb87f18c0bc1fc4cdb4fb2bc5fd7a3b8b"},
+ {file = "mypy-1.19.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f859fb09d9583a985be9a493d5cfc5515b56b08f7447759a0c5deaf68d80506e"},
+ {file = "mypy-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9a6538e0415310aad77cb94004ca6482330fece18036b5f360b62c45814c4ef"},
+ {file = "mypy-1.19.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:da4869fc5e7f62a88f3fe0b5c919d1d9f7ea3cef92d3689de2823fd27e40aa75"},
+ {file = "mypy-1.19.1-cp313-cp313-win_amd64.whl", hash = "sha256:016f2246209095e8eda7538944daa1d60e1e8134d98983b9fc1e92c1fc0cb8dd"},
+ {file = "mypy-1.19.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:06e6170bd5836770e8104c8fdd58e5e725cfeb309f0a6c681a811f557e97eac1"},
+ {file = "mypy-1.19.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:804bd67b8054a85447c8954215a906d6eff9cabeabe493fb6334b24f4bfff718"},
+ {file = "mypy-1.19.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:21761006a7f497cb0d4de3d8ef4ca70532256688b0523eee02baf9eec895e27b"},
+ {file = "mypy-1.19.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:28902ee51f12e0f19e1e16fbe2f8f06b6637f482c459dd393efddd0ec7f82045"},
+ {file = "mypy-1.19.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:481daf36a4c443332e2ae9c137dfee878fcea781a2e3f895d54bd3002a900957"},
+ {file = "mypy-1.19.1-cp314-cp314-win_amd64.whl", hash = "sha256:8bb5c6f6d043655e055be9b542aa5f3bdd30e4f3589163e85f93f3640060509f"},
+ {file = "mypy-1.19.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7bcfc336a03a1aaa26dfce9fff3e287a3ba99872a157561cbfcebe67c13308e3"},
+ {file = "mypy-1.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b7951a701c07ea584c4fe327834b92a30825514c868b1f69c30445093fdd9d5a"},
+ {file = "mypy-1.19.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b13cfdd6c87fc3efb69ea4ec18ef79c74c3f98b4e5498ca9b85ab3b2c2329a67"},
+ {file = "mypy-1.19.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f28f99c824ecebcdaa2e55d82953e38ff60ee5ec938476796636b86afa3956e"},
+ {file = "mypy-1.19.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c608937067d2fc5a4dd1a5ce92fd9e1398691b8c5d012d66e1ddd430e9244376"},
+ {file = "mypy-1.19.1-cp39-cp39-win_amd64.whl", hash = "sha256:409088884802d511ee52ca067707b90c883426bd95514e8cfda8281dc2effe24"},
+ {file = "mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247"},
+ {file = "mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba"},
+]
+
+[package.dependencies]
+librt = {version = ">=0.6.2", markers = "platform_python_implementation != \"PyPy\""}
+mypy_extensions = ">=1.0.0"
+pathspec = ">=0.9.0"
+typing_extensions = ">=4.6.0"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+faster-cache = ["orjson"]
+install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
+reports = ["lxml"]
+
[[package]]
name = "mypy-extensions"
version = "1.1.0"
@@ -4445,10 +4617,9 @@ numpy = "*"
name = "pillow"
version = "11.3.0"
description = "Python Imaging Library (Fork)"
-optional = true
+optional = false
python-versions = ">=3.9"
groups = ["main"]
-markers = "extra == \"memory\" or extra == \"all\" or extra == \"test\" or extra == \"vision\" or extra == \"vendors\" or extra == \"vllm\" or extra == \"nvidia\""
files = [
{file = "pillow-11.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1b9c17fd4ace828b3003dfd1e30bff24863e0eb59b535e8f80194d9cc7ecf860"},
{file = "pillow-11.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:65dc69160114cdd0ca0f35cb434633c75e8e7fad4cf855177a05bf38678f73ad"},
@@ -4853,10 +5024,9 @@ files = [
name = "psutil"
version = "7.1.0"
description = "Cross-platform lib for process and system monitoring."
-optional = true
+optional = false
python-versions = ">=3.6"
groups = ["main"]
-markers = "extra == \"cpu\" or extra == \"all\" or extra == \"vllm\" or extra == \"nvidia\""
files = [
{file = "psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13"},
{file = "psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5"},
@@ -6763,10 +6933,9 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]
name = "scikit-learn"
version = "1.7.2"
description = "A set of python modules for machine learning and data mining"
-optional = true
+optional = false
python-versions = ">=3.10"
groups = ["main"]
-markers = "extra == \"memory\" or extra == \"all\""
files = [
{file = "scikit_learn-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b33579c10a3081d076ab403df4a4190da4f4432d443521674637677dc91e61f"},
{file = "scikit_learn-1.7.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:36749fb62b3d961b1ce4fedf08fa57a1986cd409eff2d783bca5d4b9b5fce51c"},
@@ -6820,10 +6989,9 @@ tests = ["matplotlib (>=3.5.0)", "mypy (>=1.15)", "numpydoc (>=1.2.0)", "pandas
name = "scipy"
version = "1.16.2"
description = "Fundamental algorithms for scientific computing in Python"
-optional = true
+optional = false
python-versions = ">=3.11"
groups = ["main"]
-markers = "extra == \"memory\" or extra == \"all\" or extra == \"vllm\" or extra == \"nvidia\""
files = [
{file = "scipy-1.16.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:6ab88ea43a57da1af33292ebd04b417e8e2eaf9d5aa05700be8d6e1b6501cd92"},
{file = "scipy-1.16.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c95e96c7305c96ede73a7389f46ccd6c659c4da5ef1b2789466baeaed3622b6e"},
@@ -7523,10 +7691,9 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"]
name = "threadpoolctl"
version = "3.6.0"
description = "threadpoolctl"
-optional = true
+optional = false
python-versions = ">=3.9"
groups = ["main"]
-markers = "extra == \"memory\" or extra == \"all\""
files = [
{file = "threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb"},
{file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"},
@@ -8052,7 +8219,7 @@ version = "4.15.0"
description = "Backported and Experimental Type Hints for Python 3.9+"
optional = false
python-versions = ">=3.9"
-groups = ["main"]
+groups = ["main", "dev"]
files = [
{file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"},
{file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"},
@@ -8756,4 +8923,4 @@ vllm = ["vllm"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.11,<3.13"
-content-hash = "8631e253579b9b6a7fa5206b39846e0f03b19fd2e160640d724b8e2f0a7d1572"
+content-hash = "e3bb6d7a0d3ad3018734503b81ce0a74228041d5d18542236730ed1dc8defd7f"
diff --git a/pyproject.toml b/pyproject.toml
index ca84a1b7..08e95c7a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -276,6 +276,74 @@ version = "1.3.10"
[tool.poetry.group.dev.dependencies]
ruff = "^0.11.11"
black = "^25.1.0"
+mypy = "^1.14.0"
+
+[tool.mypy]
+python_version = "3.11"
+files = ["src/avalan"]
+plugins = ["pydantic.mypy"]
+strict = false
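+# Gradual adoption: strict mode stays off while higher-signal checks are enabled individually.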
+warn_return_any = true
+warn_unused_ignores = false
+warn_redundant_casts = true
+warn_unused_configs = true
+show_error_codes = true
+show_column_numbers = true
+pretty = true
+check_untyped_defs = true
+disallow_untyped_defs = false
+disallow_incomplete_defs = false
+disallow_untyped_decorators = false
+
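+# Silence missing-import errors for these third-party modules in one place
+# instead of per import site.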
+[[tool.mypy.overrides]]
+module = [
+ "a2a.*",
+ "anthropic.*",
+ "bitsandbytes.*",
+ "boto3.*",
+ "botocore.*",
+ "diffusers.*",
+ "elasticsearch.*",
+ "faiss.*",
+ "google.*",
+ "huggingface_hub.*",
+ "imageio.*",
+ "jinja2.*",
+ "keyring.*",
+ "litellm.*",
+    "markdownify.*",
+    "markitdown.*",
+ "mcp.*",
+ "mlx.*",
+ "mlx_lm.*",
+ "numpy.*",
+ "openai.*",
+ "pandas.*",
+ "pgvector.*",
+ "PIL.*",
+ "playwright.*",
+ "psycopg.*",
+ "psycopg_pool.*",
+ "pydantic.*",
+ "RestrictedPython.*",
+ "rich.*",
+ "soundfile.*",
+ "sqlalchemy.*",
+ "sqlglot.*",
+ "sympy.*",
+ "tiktoken.*",
+ "torch.*",
+ "torchaudio.*",
+ "torchvision.*",
+ "tqdm.*",
+ "transformers.*",
+ "tree_sitter.*",
+ "tree_sitter_python.*",
+ "uvicorn.*",
+ "vllm.*",
+ "youtube_transcript_api.*",
+]
+ignore_missing_imports = true
[tool.poetry-dynamic-versioning]
enable = true
diff --git a/src/avalan/__init__.py b/src/avalan/__init__.py
index dcf9b406..51fb3b94 100644
--- a/src/avalan/__init__.py
+++ b/src/avalan/__init__.py
@@ -5,18 +5,18 @@
from packaging.version import Version, parse
-def _config() -> dict:
+def _config() -> dict[str, str]:
config = metadata("avalan")
package_version = metadata_version("avalan")
return {
- "name": config["Name"],
+ "name": str(config["Name"]),
"version": package_version,
- "license": config["License"],
+ "license": str(config["License"]),
"url": "https://avalan.ai",
}
-config = _config()
+config: dict[str, str] = _config()
def license() -> str:
diff --git a/src/avalan/agent/engine.py b/src/avalan/agent/engine.py
index 0b5fdba5..3d724ba6 100644
--- a/src/avalan/agent/engine.py
+++ b/src/avalan/agent/engine.py
@@ -76,7 +76,8 @@ async def input_token_count(self) -> int | None:
},
)
)
- count = self._model.input_token_count(
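+        # Not every engine exposes token counting; make the capability check explicit.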
+ assert hasattr(self._model, "input_token_count")
+ count: int | None = self._model.input_token_count(
self._last_prompt[0],
system_prompt=self._last_prompt[1],
developer_prompt=self._last_prompt[2],
@@ -158,6 +159,7 @@ async def __call__(
},
)
)
+ assert context.input is not None
output = await self._run(context, context.input, **run_args)
await self._event_manager.trigger(
Event(
@@ -246,11 +248,18 @@ async def _run(
if previous_message:
await self.sync_message(previous_message)
- for current_message in input_value:
+ for current_item in input_value:
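+            # Wrap raw content in a user Message so sync_message always receives one.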
+ current_message = (
+ current_item
+ if isinstance(current_item, Message)
+ else Message(role=MessageRole.USER, content=current_item)
+ )
await self.sync_message(current_message)
# Make recent memory the new model input
- input_value = [rm.message for rm in self._memory.recent_messages]
+ recent = self._memory.recent_messages
+ assert recent is not None
+ input_value = [rm.message for rm in recent]
# Have model generate output from input
@@ -289,7 +298,7 @@ async def _run(
tool=self._tool,
context=context,
)
- output = await self._model_manager(model_task)
+ output: TextGenerationResponse = await self._model_manager(model_task)
await self._event_manager.trigger(
Event(
type=EventType.MODEL_EXECUTE_AFTER,
@@ -344,6 +353,7 @@ async def sync_message(self, message: Message) -> None:
},
)
)
+ assert self._model.model_id is not None
await self._memory.append_message(
EngineMessage(
agent_id=self._id,
diff --git a/src/avalan/agent/loader.py b/src/avalan/agent/loader.py
index 704f2136..558df80c 100644
--- a/src/avalan/agent/loader.py
+++ b/src/avalan/agent/loader.py
@@ -63,6 +63,7 @@ class OrchestratorLoader:
)
_OPENAI_RESPONSES_ALIASES = frozenset({"response", "responses"})
+ _event_manager: EventManager
_hub: HuggingfaceHub
_logger: Logger
_participant_id: UUID
@@ -75,12 +76,19 @@ def __init__(
logger: Logger,
participant_id: UUID,
stack: AsyncExitStack,
+ event_manager: EventManager | None = None,
) -> None:
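+        # Share a caller-provided event manager when given; otherwise create a private one.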
+ self._event_manager = event_manager or EventManager()
self._hub = hub
self._logger = logger
self._participant_id = participant_id
self._stack = stack
+ @property
+ def event_manager(self) -> EventManager:
+ """Return the event manager instance."""
+ return self._event_manager
+
@staticmethod
def parse_permanent_store_value(
value: str,
@@ -372,7 +380,7 @@ async def from_file(
agent_config=agent_config,
uri=uri,
engine_config=engine_config,
- tools=enable_tools,
+ tools=enable_tools or [],
call_options=call_options,
template_vars=template_vars,
memory_permanent_message=memory_permanent_message,
@@ -495,7 +503,7 @@ async def from_settings(
_l("Loading event manager")
- event_manager = EventManager()
+ event_manager = self._event_manager
if settings.log_events:
def _log_event(event: Event) -> None:
@@ -607,6 +615,7 @@ def _log_event(event: Event) -> None:
assert settings.agent_id
+ agent: Orchestrator
if settings.orchestrator_type == "json":
assert settings.json_config is not None
agent = self._load_json_orchestrator(
diff --git a/src/avalan/agent/orchestrator/__init__.py b/src/avalan/agent/orchestrator/__init__.py
index 5e9ed47e..22b28a0e 100644
--- a/src/avalan/agent/orchestrator/__init__.py
+++ b/src/avalan/agent/orchestrator/__init__.py
@@ -6,6 +6,7 @@
Message,
MessageContentText,
MessageRole,
+ TransformerEngineSettings,
)
from ...entities import Modality as Modality
from ...event import Event, EventType
@@ -14,18 +15,24 @@
from ...model.call import ModelCallContext
from ...model.engine import Engine
from ...model.manager import ModelManager
+from ...model.response.text import (
+ TextGenerationResponse as TextGenerationResponse,
+)
from ...tool.manager import ToolManager
from .. import (
AgentOperation,
- EngineEnvironment,
InputType,
NoOperationAvailableException,
Specification,
)
+from .. import (
+ EngineEnvironment as EngineEnvironment,
+)
from ..engine import EngineAgent
from ..renderer import Renderer, TemplateEngineAgent
from .response.orchestrator_response import OrchestratorResponse
+from collections.abc import Callable, Coroutine
from contextlib import ExitStack
from dataclasses import asdict
from json import dumps
@@ -46,7 +53,7 @@ class Orchestrator:
_memory: MemoryManager
_tool: ToolManager
_event_manager: EventManager
- _engine_agents: dict[EngineEnvironment, EngineAgent] = {}
+ _engine_agents: dict[str, EngineAgent] = {}
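+    # Keyed by the JSON-serialized engine environment (see __aenter__).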
_engines_stack: ExitStack = ExitStack()
_operation_step: int | None = None
_model_ids: set[str] = set()
@@ -108,7 +115,9 @@ def id(self) -> UUID:
return self._id
@property
- def input_token_count(self) -> int | None:
+ def input_token_count(
+ self,
+ ) -> Callable[[], Coroutine[Any, Any, int | None]] | None:
return (
self._last_engine_agent.input_token_count
if self._last_engine_agent
@@ -243,7 +252,7 @@ async def __call__(self, input: Input, **kwargs) -> OrchestratorResponse:
return OrchestratorResponse(
messages,
- result,
+ result, # type: ignore[arg-type]
engine_agent,
operation,
engine_args,
@@ -264,7 +273,11 @@ async def __aenter__(self):
environment = operation.environment
environment_hash = dumps(asdict(environment))
if environment_hash not in self._engine_agents:
+ assert environment.engine_uri.model_id is not None
model_ids.append(environment.engine_uri.model_id)
+ assert isinstance(
+ environment.settings, TransformerEngineSettings
+ )
engine = self._model_manager.load_engine(
environment.engine_uri,
environment.settings,
@@ -359,18 +372,22 @@ def _input_messages(
else message.content
)
render_vars.update({"input": message_content})
- content = (
- self._renderer(self._user_template, **render_vars)
- if self._user_template
- else self._renderer.from_string(
+ if self._user_template:
+ content = self._renderer(
+ self._user_template, **render_vars
+ )
+ else:
+ assert self._user is not None
+ content = self._renderer.from_string(
self._user, template_vars=render_vars
)
- )
message = Message(role=message.role, content=content)
if isinstance(input, list):
- input[-1] = message
+ input[-1] = message # type: ignore[call-overload]
else:
input = message
- return input
+ # The return type must be Message | list[Message] per the signature
+ assert isinstance(input, (Message, list))
+ return input # type: ignore[return-value]
diff --git a/src/avalan/agent/orchestrator/orchestrators/default.py b/src/avalan/agent/orchestrator/orchestrators/default.py
index 44be78cb..4dca13b7 100644
--- a/src/avalan/agent/orchestrator/orchestrators/default.py
+++ b/src/avalan/agent/orchestrator/orchestrators/default.py
@@ -1,4 +1,10 @@
-from ....agent import AgentOperation, EngineEnvironment, Goal, Specification
+from ....agent import (
+ AgentOperation,
+ EngineEnvironment,
+ Goal,
+ Role,
+ Specification,
+)
from ....agent.orchestrator import Orchestrator
from ....entities import EngineUri, Modality, TransformerEngineSettings
from ....event.manager import EventManager
@@ -47,7 +53,7 @@ def __init__(
)
else:
specification = Specification(
- role=role,
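+                # A bare role string becomes a single-persona Role.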
+ role=Role(persona=[role]) if role else None,
goal=(
Goal(task=task, instructions=[instructions])
if task and instructions
@@ -66,7 +72,8 @@ def __init__(
AgentOperation(
specification=specification,
environment=EngineEnvironment(
- engine_uri=engine_uri, settings=settings
+ engine_uri=engine_uri,
+ settings=settings or TransformerEngineSettings(),
),
modality=Modality.TEXT_GENERATION,
),
diff --git a/src/avalan/agent/orchestrator/orchestrators/json.py b/src/avalan/agent/orchestrator/orchestrators/json.py
index ca96c862..42a1c85e 100644
--- a/src/avalan/agent/orchestrator/orchestrators/json.py
+++ b/src/avalan/agent/orchestrator/orchestrators/json.py
@@ -8,7 +8,11 @@
Specification,
)
from ....agent.orchestrator import Orchestrator
-from ....entities import Input, Modality, TransformerEngineSettings
+from ....entities import (
+ Input,
+ Modality,
+ TransformerEngineSettings,
+)
from ....event.manager import EventManager
from ....memory.manager import MemoryManager
from ....model.manager import ModelManager
@@ -17,6 +21,7 @@
from dataclasses import dataclass
from logging import Logger
from typing import Annotated, get_args, get_origin
+from uuid import UUID
@dataclass(frozen=True, kw_only=True, slots=True)
@@ -56,7 +61,9 @@ def __init__(
Property(
name=name,
data_type=data_type,
- description=description.strip(),
+ description=(
+ description.strip() if description else None
+ ),
)
)
else:
@@ -93,6 +100,7 @@ def __init__(
event_manager: EventManager,
output: type | list[Property],
*,
+ id: UUID | None = None,
role: str | None = None,
task: str | None = None,
instructions: str | None = None,
@@ -136,15 +144,20 @@ def __init__(
AgentOperation(
specification=specification,
environment=EngineEnvironment(
- engine_uri=engine_uri, settings=settings
+ engine_uri=engine_uri,
+ settings=settings or TransformerEngineSettings(),
),
modality=Modality.TEXT_GENERATION,
),
call_options=call_options,
+ id=id,
+ name=name,
user=user,
user_template=user_template,
)
- async def __call__(self, input: Input, **kwargs) -> str:
+ async def __call__( # type: ignore[override]
+ self, input: Input, **kwargs
+ ) -> str:
text_response = await super().__call__(input, **kwargs)
return await text_response.to_json()
diff --git a/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py b/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py
index 2e8ea9ae..391ace86 100644
--- a/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py
+++ b/src/avalan/agent/orchestrator/orchestrators/reasoning/cot.py
@@ -43,7 +43,7 @@ async def __aexit__(self, exc_type, exc_value, traceback):
def __getattr__(self, name):
return getattr(self._orchestrator, name)
- async def __call__(
+ async def __call__( # type: ignore[override]
self, input: Input, **kwargs
) -> ReasoningOrchestratorResponse:
template_vars = {}
diff --git a/src/avalan/agent/orchestrator/response/orchestrator_response.py b/src/avalan/agent/orchestrator/response/orchestrator_response.py
index 0236f497..16202b38 100644
--- a/src/avalan/agent/orchestrator/response/orchestrator_response.py
+++ b/src/avalan/agent/orchestrator/response/orchestrator_response.py
@@ -27,7 +27,7 @@
from json import dumps
from queue import Queue
from time import perf_counter
-from typing import Any, AsyncIterator, Callable
+from typing import Any, AsyncIterator, Callable, cast
from uuid import UUID
@@ -131,7 +131,10 @@ async def to(self, entity_class: type) -> Any:
def __aiter__(self) -> "OrchestratorResponse":
if self._event_manager:
self._response.add_done_callback(self._on_consumed)
- self._response_iterator = self._response.__aiter__()
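+        # The response's __aiter__ is loosely typed; narrow it for the token loop.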
+ self._response_iterator = cast(
+ AsyncIterator[Token | TokenDetail | Event],
+ self._response.__aiter__(),
+ )
self._calls = Queue()
self._parser_queue = Queue()
self._tool_context = ToolCallContext(
@@ -149,6 +152,7 @@ def __aiter__(self) -> "OrchestratorResponse":
async def __anext__(self) -> Token | TokenDetail | Event:
assert self._response_iterator
+ assert self._parser_queue is not None
if not self._parser_queue.empty():
return self._parser_queue.get()
@@ -162,9 +166,15 @@ async def __anext__(self) -> Token | TokenDetail | Event:
if not self._tool_call_events.empty():
event = self._tool_call_events.get()
assert event.type == EventType.TOOL_PROCESS
+ assert self._event_manager is not None
await self._event_manager.trigger(event)
- calls: list[ToolCall] = event.payload or []
+ payload = event.payload
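+            # TOOL_PROCESS payloads arrive as {"calls": [...]} or as a bare list.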
+ calls: list[ToolCall] = (
+ payload.get("calls", [])
+ if isinstance(payload, dict)
+ else (payload if isinstance(payload, list) else [])
+ )
if calls:
for call in calls:
assert isinstance(call, ToolCall)
@@ -191,7 +201,8 @@ async def __anext__(self) -> Token | TokenDetail | Event:
if self._event_manager:
await self._event_manager.trigger(execute_event)
- context = ToolCallContext(
+ assert self._tool_context is not None
+ tool_ctx = ToolCallContext(
input=self._tool_context.input,
agent_id=self._agent_id,
participant_id=self._participant_id,
@@ -200,13 +211,13 @@ async def __anext__(self) -> Token | TokenDetail | Event:
)
result = (
- await self._tool_manager(call, context)
+ await self._tool_manager(call, tool_ctx)
if self._tool_manager
else None
)
self._call_history.append(call)
- self._tool_context = context
+ self._tool_context = tool_ctx
end = perf_counter()
result_event = Event(
@@ -236,6 +247,7 @@ async def __anext__(self) -> Token | TokenDetail | Event:
tool_messages = []
for e in result_events:
+ assert e.payload is not None
tool_result = e.payload["result"]
tool_output = (
tool_result.message
@@ -256,8 +268,9 @@ async def __anext__(self) -> Token | TokenDetail | Event:
)
tool_result_output = dumps(
(
- asdict(tool_output)
+ asdict(tool_output) # type: ignore[arg-type]
if is_dataclass(tool_output)
+ and not isinstance(tool_output, type)
else tool_output
),
default=lambda o: (
@@ -293,8 +306,15 @@ async def __anext__(self) -> Token | TokenDetail | Event:
or isinstance(self._input, Message)
)
- messages = list(
- self._input if isinstance(self._input, list) else [self._input]
+ messages: list[Message] = list(
+ cast(
+ list[Message],
+ (
+ self._input
+ if isinstance(self._input, list)
+ else [self._input]
+ ),
+ )
)
messages.extend(tool_messages)
@@ -307,11 +327,12 @@ async def __anext__(self) -> Token | TokenDetail | Event:
"engine_args": self._engine_args,
},
)
+ assert self._event_manager is not None
await self._event_manager.trigger(event_tool_model_run)
context = self._make_child_context(messages)
inner_response = await self._engine_agent(context)
- assert inner_response
+ assert isinstance(inner_response, TextGenerationResponse)
self._response = inner_response
self.__aiter__()
@@ -334,12 +355,13 @@ async def __anext__(self) -> Token | TokenDetail | Event:
if isinstance(token, ToolCallToken) and token.call:
event = Event(
type=EventType.TOOL_PROCESS,
- payload=[token.call],
+ payload={"calls": [token.call]},
started=perf_counter(),
)
self._tool_process_events.put(event)
except StopAsyncIteration:
if self._tool_parser:
+ assert self._parser_queue is not None
for item in await self._tool_parser.flush():
if isinstance(item, Event):
self._tool_process_events.put(item)
@@ -419,7 +441,8 @@ async def _react(
)
self._call_history.append(call)
self._tool_context = context
- results.append(result)
+ if result is not None:
+ results.append(result)
if self._event_manager:
end = perf_counter()
@@ -492,8 +515,15 @@ async def _react_process(
or isinstance(self._input, Message)
)
- messages = list(
- self._input if isinstance(self._input, list) else [self._input]
+ messages: list[Message] = list(
+ cast(
+ list[Message],
+ (
+ self._input
+ if isinstance(self._input, list)
+ else [self._input]
+ ),
+ )
)
messages.extend(tool_messages)
@@ -514,11 +544,12 @@ async def _react_process(
"engine_args": self._engine_args,
},
)
+ assert self._event_manager is not None
await self._event_manager.trigger(event_tool_model_run)
context = self._make_child_context(messages)
response = await self._engine_agent(context)
- assert response
+ assert isinstance(response, TextGenerationResponse)
return response
async def _emit(
@@ -564,6 +595,7 @@ async def _emit(
else:
items = [item]
+ assert self._parser_queue is not None
for it in items:
if isinstance(it, Event):
if it.type == EventType.TOOL_PROCESS:
diff --git a/src/avalan/agent/renderer.py b/src/avalan/agent/renderer.py
index 1a5dc5df..30744dc8 100644
--- a/src/avalan/agent/renderer.py
+++ b/src/avalan/agent/renderer.py
@@ -56,10 +56,9 @@ def from_string(
self,
template: str,
template_vars: dict | None = None,
- encoding: str = "utf-8",
) -> str:
return (
- Template(template).render(**template_vars).encode(encoding)
+ Template(template).render(**template_vars)
if template_vars
else template
)
diff --git a/src/avalan/cli/__init__.py b/src/avalan/cli/__init__.py
index 5d9b9c8e..54be850d 100644
--- a/src/avalan/cli/__init__.py
+++ b/src/avalan/cli/__init__.py
@@ -2,10 +2,11 @@
from collections.abc import Iterator
from contextlib import contextmanager, nullcontext
-from io import UnsupportedOperation
+from io import TextIOWrapper, UnsupportedOperation
from json import dumps
from select import select
from sys import stdin
+from typing import Any
from rich.console import Console
from rich.live import Live
@@ -48,13 +49,8 @@ def confirm_tool_call(
live: Live | None = None,
) -> str:
prompt = "Execute tool call?"
- options = {
- "choices": ["y", "a", "n"],
- "default": "n",
- "console": console,
- "show_choices": True,
- "show_default": True,
- }
+ choices = ["y", "a", "n"]
+ default = "n"
with _pause_live(live) if live else nullcontext():
call_element = Syntax(
@@ -65,22 +61,30 @@ def confirm_tool_call(
if live:
prompt_element = Text.from_markup(prompt, style="prompt")
prompt_element.end = ""
- choices = "/".join(options["choices"])
+ choices_str = "/".join(choices)
prompt_element.append(" ")
- prompt_element.append(f"[{choices}]", "prompt.choices")
+ prompt_element.append(f"[{choices_str}]", "prompt.choices")
- if options["show_default"]:
- default = Text(f"({options['default']})", "prompt.default")
- prompt_element.append(" ")
- prompt_element.append(default)
+ default_text = Text(f"({default})", "prompt.default")
+ prompt_element.append(" ")
+ prompt_element.append(default_text)
console.print(prompt_element)
stdin_is_tty = stdin.isatty()
with open(tty_path) if not stdin_is_tty else nullcontext() as tty:
- if not stdin_is_tty:
- options["stream"] = tty
- return Prompt.ask(prompt, **options)
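+            # When stdin is piped, read the confirmation from the controlling TTY.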
+ stream: TextIOWrapper | None = (
+ tty if not stdin_is_tty and tty else None
+ )
+ return Prompt.ask(
+ prompt,
+ choices=choices,
+ default=default,
+ console=console,
+ show_choices=True,
+ show_default=True,
+ stream=stream,
+ )
def has_input(console: Console) -> bool:
@@ -123,10 +127,12 @@ def get_input(
if is_input_available and force_prompt
else nullcontext()
) as tty:
- kwargs = {}
- if is_input_available and force_prompt:
- kwargs["stream"] = tty
- input_string = PromptWithoutPrefix.ask(full_prompt, **kwargs)
+ stream: Any = (
+ tty if is_input_available and force_prompt else None
+ )
+ input_string = PromptWithoutPrefix.ask(
+ full_prompt, stream=stream
+ )
if strip_prompt:
input_string = input_string.strip()
except EOFError:
diff --git a/src/avalan/cli/__main__.py b/src/avalan/cli/__main__.py
index 7224dcf1..d6af372f 100644
--- a/src/avalan/cli/__main__.py
+++ b/src/avalan/cli/__main__.py
@@ -23,6 +23,7 @@
model_uninstall,
)
from ..cli.commands.tokenizer import tokenize
+from ..cli.theme import Theme
from ..cli.theme.fancy import FancyTheme
from ..entities import (
AttentionImplementation,
@@ -51,11 +52,16 @@
import gettext
import sys
-from argparse import ArgumentParser, Namespace, _SubParsersAction
+from argparse import (
+ ArgumentParser,
+ Namespace,
+ _ArgumentGroup,
+ _SubParsersAction,
+)
from asyncio import run as run_in_loop
from asyncio.exceptions import CancelledError
from dataclasses import fields
-from gettext import translation
+from gettext import GNUTranslations, NullTranslations, translation
from importlib.util import find_spec
from locale import getlocale
from logging import (
@@ -73,6 +79,7 @@
from pathlib import Path
from subprocess import run
from tomllib import load as toml_load
+from types import ModuleType
from typing import Optional, get_args, get_origin
from typing import get_args as get_type_args
from uuid import uuid4
@@ -81,7 +88,7 @@
from rich.console import Console
from rich.logging import RichHandler
from rich.prompt import Confirm, Prompt
-from rich.theme import Theme
+from rich.theme import Theme as RichTheme
from torch.cuda import device_count, is_available, set_device
from torch.distributed import destroy_process_group
from transformers.utils import (
@@ -112,7 +119,10 @@ def __init__(self, logger: Logger):
)
default_device = TransformerModel.get_default_device()
self._parser = CLI._create_parser(
- default_device, cache_dir, default_locales_path, default_locale
+ default_device,
+ cache_dir,
+ default_locales_path,
+ default_locale or "en_US",
)
@staticmethod
@@ -141,7 +151,7 @@ def _create_parser(
cache_dir: str,
default_locales_path: str,
default_locale: str,
- ):
+ ) -> ArgumentParser:
default_attention = CLI._default_attention(default_device)
global_parser = ArgumentParser(add_help=False)
global_parser.add_argument(
@@ -1637,7 +1647,7 @@ def _create_parser(
@staticmethod
def _get_translator(
app_name: str, locales_path: str, locale: str
- ) -> object:
+ ) -> GNUTranslations | NullTranslations | ModuleType:
"""Return translation object for ``locale`` or ``gettext`` fallback."""
try:
return translation(
@@ -1763,7 +1773,7 @@ def _add_agent_server_arguments(parser: ArgumentParser) -> ArgumentParser:
@staticmethod
def _add_agent_settings_arguments(
parser: ArgumentParser,
- ) -> ArgumentParser:
+ ) -> _ArgumentGroup:
group = parser.add_argument_group("inline agent settings")
group.add_argument("--engine-uri", type=str, help="Agent engine URI")
group.add_argument("--name", type=str, help="Agent name")
@@ -1885,7 +1895,7 @@ def _add_agent_settings_arguments(
@staticmethod
def _add_tool_settings_arguments(
parser: ArgumentParser, *, prefix: str, settings_cls: type
- ) -> ArgumentParser:
+ ) -> _ArgumentGroup:
"""Add dataclass based tool options to ``parser``."""
group = parser.add_argument_group(f"{prefix} tool settings")
@@ -1975,12 +1985,15 @@ async def __call__(self) -> None:
setattr(args, f"run_chat_{key}", value)
translator = CLI._get_translator(self._name, args.locales, args.locale)
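+        # The fallback translator is the gettext module itself, so resolve via getattr.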
+ gettext_fn = getattr(translator, "gettext", gettext.gettext)
+ ngettext_fn = getattr(translator, "ngettext", gettext.ngettext)
assert self._logger is not None and isinstance(self._logger, Logger)
- theme = FancyTheme(translator.gettext, translator.ngettext)
+ theme = FancyTheme(gettext_fn, ngettext_fn)
_ = theme._
console = Console(
- theme=Theme(styles=theme.get_styles()), record=args.record
+ theme=RichTheme(styles=theme.get_styles()),
+ record=args.record,
)
if args.help_full:
@@ -1991,14 +2004,14 @@ async def __call__(self) -> None:
if requires_token:
if not access_token:
- prompt_kwargs = {}
+ stream = None
if has_input(console):
try:
- prompt_kwargs["stream"] = open(args.tty)
+ stream = open(args.tty)
except OSError:
pass
access_token = Prompt.ask(
- theme.ask_access_token(), **prompt_kwargs
+ theme.ask_access_token(), stream=stream
)
assert access_token
else:
@@ -2091,13 +2104,15 @@ def filter(self, record: LogRecord) -> bool:
)
suggest_login = suggest_login and not has_input(console)
+ connecting_spinner = theme.get_spinner("connecting")
if args.login or (
suggest_login
and Confirm.ask(theme.ask_login_to_hub(), default=False)
):
+ assert connecting_spinner is not None
with console.status(
theme.logging_in(hub.domain),
- spinner=(theme.get_spinner("connecting")),
+ spinner=connecting_spinner,
refresh_per_second=self._REFRESH_RATE,
):
hub.login()
@@ -2108,7 +2123,7 @@ def filter(self, record: LogRecord) -> bool:
theme.welcome(
self._site.geturl(),
self._name,
- self._version,
+ str(self._version),
self._license,
user,
)
diff --git a/src/avalan/cli/commands/agent.py b/src/avalan/cli/commands/agent.py
index be1f8303..3190e0ff 100644
--- a/src/avalan/cli/commands/agent.py
+++ b/src/avalan/cli/commands/agent.py
@@ -5,6 +5,7 @@
)
from ...cli import confirm_tool_call, get_input, has_input
from ...cli.commands.model import token_generation
+from ...cli.theme import Theme
from ...entities import (
Backend,
GenerationCacheStrategy,
@@ -23,10 +24,10 @@
from argparse import Namespace
from contextlib import AsyncExitStack
-from dataclasses import fields
+from dataclasses import fields as dataclass_fields
from logging import Logger
from os.path import dirname, getmtime, join
-from typing import Iterable, Mapping
+from typing import Any, Iterable, Mapping, TypeVar, cast
from uuid import UUID, uuid4
from jinja2 import Environment, FileSystemLoader
@@ -34,7 +35,6 @@
from rich.live import Live
from rich.prompt import Confirm, Prompt
from rich.syntax import Syntax
-from rich.theme import Theme
def _parse_permanent_memory_items(
@@ -188,16 +188,19 @@ def get_orchestrator_settings(
)
+T = TypeVar("T")
+
+
def _tool_settings_from_mapping(
mapping: Mapping[str, object] | Namespace,
*,
prefix: str | None = None,
- settings_cls: type,
+ settings_cls: type[T],
open_files: bool = True,
-) -> object:
+) -> T | None:
"""Return tool settings from a mapping using dataclass ``settings_cls``."""
values: dict[str, object] = {}
- for field in fields(settings_cls):
+ for field in dataclass_fields(cast(Any, settings_cls)):
key = f"tool_{prefix}_{field.name}" if prefix else field.name
if isinstance(mapping, Namespace):
if hasattr(mapping, key):
@@ -231,9 +234,10 @@ def get_tool_settings(
args: Namespace,
*,
prefix: str,
- settings_cls: type,
+ settings_cls: type[T],
open_files: bool = True,
-) -> object:
+) -> T | None:
+ """Return tool settings instance from CLI arguments."""
return _tool_settings_from_mapping(
args, prefix=prefix, settings_cls=settings_cls, open_files=open_files
)
@@ -247,7 +251,7 @@ async def agent_message_search(
logger: Logger,
refresh_per_second: int,
) -> None:
- _, _i = theme._, theme.icons
+ _, _i = theme._, theme._icons
specs_path = args.specifications_file
engine_uri = getattr(args, "engine_uri", None)
@@ -267,7 +271,7 @@ async def agent_message_search(
input_string = get_input(
console,
- _i["user_input"] + " ",
+ (_i.get("user_input") or "") + " ",
echo_stdin=not args.no_repl,
is_quiet=args.quiet,
tty_path=tty_path,
@@ -277,6 +281,8 @@ async def agent_message_search(
limit = args.limit
+ spinner = theme.get_spinner("agent_loading")
+ assert spinner is not None
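+ # get_spinner is typed str | None; FancyTheme registers "agent_loading", so the assert narrows for mypy.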
async with AsyncExitStack() as stack:
loader = OrchestratorLoader(
hub=hub,
@@ -286,7 +292,7 @@ async def agent_message_search(
)
with console.status(
_("Loading agent..."),
- spinner=theme.get_spinner("agent_loading"),
+ spinner=spinner,
refresh_per_second=refresh_per_second,
):
if specs_path:
@@ -331,7 +337,8 @@ async def agent_message_search(
)
orchestrator = await stack.enter_async_context(orchestrator)
- assert orchestrator.engine_agent and orchestrator.engine.model_id
+ assert orchestrator.engine_agent
+ assert orchestrator.engine and orchestrator.engine.model_id
can_access = args.skip_hub_access_check or hub.can_access(
orchestrator.engine.model_id
@@ -380,7 +387,7 @@ async def agent_run(
logger: Logger,
refresh_per_second: int,
) -> None:
- _, _i = theme._, theme.icons
+ _, _i = theme._, theme._icons
specs_path = args.specifications_file
engine_uri = getattr(args, "engine_uri", None)
@@ -485,6 +492,7 @@ async def _init_orchestrator() -> Orchestrator:
not orchestrator.tool.is_empty
), "--tools-confirm requires tools"
+ permanent_message = orchestrator.memory.permanent_message
logger.debug(
"Agent loaded from %s, models used: %s, with recent message "
"memory: %s, with permanent message memory: %s",
@@ -492,15 +500,16 @@ async def _init_orchestrator() -> Orchestrator:
orchestrator.model_ids,
"yes" if orchestrator.memory.has_recent_message else "no",
(
- "yes, with session #"
- + str(orchestrator.memory.permanent_message.session_id)
+ "yes, with session #" + str(permanent_message.session_id)
if orchestrator.memory.has_permanent_message
+ and permanent_message
else "no"
),
)
if not args.quiet:
- assert orchestrator.engine_agent and orchestrator.engine.model_id
+ assert orchestrator.engine_agent
+ assert orchestrator.engine and orchestrator.engine.model_id
is_local = not isinstance(
orchestrator.engine, TextGenerationVendorModel
@@ -530,26 +539,30 @@ async def _init_orchestrator() -> Orchestrator:
else:
await orchestrator.memory.start_session()
+ recent_message = orchestrator.memory.recent_message
if (
load_recent_messages
and orchestrator.memory.has_recent_message
- and not orchestrator.memory.recent_message.is_empty
+ and recent_message
+ and not recent_message.is_empty
and not args.quiet
):
console.print(
theme.recent_messages(
participant_id,
orchestrator,
- orchestrator.memory.recent_message.data,
+ recent_message.data,
)
)
return orchestrator
+ spinner = theme.get_spinner("agent_loading")
+ assert spinner is not None
async with AsyncExitStack() as stack:
with console.status(
_("Loading agent..."),
- spinner=theme.get_spinner("agent_loading"),
+ spinner=spinner,
refresh_per_second=refresh_per_second,
):
orchestrator = await _init_orchestrator()
@@ -569,16 +582,21 @@ async def _init_orchestrator() -> Orchestrator:
specs_mtime = new_mtime
in_conversation = False
continue
+ current_recent_message = orchestrator.memory.recent_message
+ recent_size = (
+ str(current_recent_message.size)
+ if orchestrator.memory.has_recent_message
+ and current_recent_message
+ else "0"
+ )
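+ # Compute the size up front; the previous inline conditional grouped the string concatenation wrongly and dropped " messages" from the true branch.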
logger.debug(
"Waiting for new message to add to orchestrator's existing "
- + str(orchestrator.memory.recent_message.size)
- if orchestrator.memory
- and orchestrator.memory.has_recent_message
- else "0" + " messages"
+ + recent_size
+ + " messages"
)
input_string = get_input(
console,
- _i["user_input"] + " ",
+ (_i.get("user_input") or "") + " ",
echo_stdin=not args.no_repl,
force_prompt=in_conversation,
is_quiet=args.quiet,
@@ -596,7 +614,7 @@ async def _init_orchestrator() -> Orchestrator:
)
if not args.quiet and not args.stats:
- console.print(_i["agent_output"] + " ", end="")
+ console.print((_i.get("agent_output") or "") + " ", end="")
if args.quiet:
console.print(await output.to_str())
diff --git a/src/avalan/cli/commands/cache.py b/src/avalan/cli/commands/cache.py
index f507f060..ac9fe515 100644
--- a/src/avalan/cli/commands/cache.py
+++ b/src/avalan/cli/commands/cache.py
@@ -1,5 +1,6 @@
from ...cli import confirm
from ...cli.download import create_live_tqdm_class
+from ...cli.theme import Theme
from ...model.hubs import HubAccessDeniedException
from ...model.hubs.huggingface import HuggingfaceHub
@@ -7,7 +8,6 @@
from rich.console import Console
from rich.padding import Padding
-from rich.theme import Theme
def cache_delete(
@@ -27,7 +27,7 @@ def cache_delete(
console.print(theme.cache_delete(cache_deletion, False))
return
- console.print(theme.cache_delete(cache_deletion))
+ console.print(theme.cache_delete(cache_deletion, False))
if not args.delete and not confirm(console, theme.ask_delete_paths()):
return
execute_deletion()
diff --git a/src/avalan/cli/commands/memory.py b/src/avalan/cli/commands/memory.py
index 429ac6ee..a48f4629 100644
--- a/src/avalan/cli/commands/memory.py
+++ b/src/avalan/cli/commands/memory.py
@@ -1,6 +1,7 @@
from ...cli import get_input
from ...cli.commands import get_model_settings
from ...cli.commands.model import model_display
+from ...cli.theme import Theme
from ...entities import DistanceType, Modality, SearchMatch, Similarity
from ...memory.partitioner.code import CodePartitioner
from ...memory.partitioner.text import TextPartition, TextPartitioner
@@ -15,6 +16,7 @@
from io import BytesIO
from logging import Logger
from pathlib import Path
+from typing import Any, cast
from urllib.parse import urlparse
from uuid import UUID
@@ -23,7 +25,6 @@
from numpy import abs, corrcoef, dot, sum, vstack
from numpy.linalg import norm
from rich.console import Console
-from rich.theme import Theme
async def memory_document_index(
@@ -40,7 +41,7 @@ async def memory_document_index(
def transform(html: bytes) -> DocumentConverterResult:
return MarkItDown().convert_stream(BytesIO(html))
- _, _i = theme._, theme.icons
+ _, _i = theme._, cast(dict[str, Any], theme.icons)
model_id = args.model
source = args.source
participant_id = UUID(args.participant)
@@ -139,7 +140,7 @@ def transform(html: bytes) -> DocumentConverterResult:
args.encoding,
args.partition_max_tokens,
)
- partitions: list[TextPartition] = []
+ partitions = []
for cp in code_partitions:
embeddings = await stm(cp.data)
tokens = stm.token_count(cp.data)
@@ -184,7 +185,7 @@ async def memory_embeddings(
logger: Logger,
) -> None:
assert args.model
- _, _i = theme._, theme.icons
+ _, _i = theme._, cast(dict[str, Any], theme.icons)
model_id = args.model
display_partitions = (
args.display_partitions if not args.no_display_partitions else None
@@ -357,7 +358,7 @@ async def memory_embeddings(
index = IndexFlatL2(input_string_embeddings.shape[0])
- if partitioner:
+ if partitioner and knowledge_partitions is not None:
knowledge_stack = vstack(
[kp.embeddings for kp in knowledge_partitions]
).astype("float32", copy=False)
@@ -414,7 +415,7 @@ async def memory_search(
assert args.model and args.dsn and args.participant and args.namespace
assert args.function
- _, _i = theme._, theme.icons
+ _, _i = theme._, cast(dict[str, Any], theme.icons)
model_id = args.model
participant_id = UUID(args.participant)
namespace = args.namespace
diff --git a/src/avalan/cli/commands/model.py b/src/avalan/cli/commands/model.py
index 73e93605..465a0ebf 100644
--- a/src/avalan/cli/commands/model.py
+++ b/src/avalan/cli/commands/model.py
@@ -1,7 +1,11 @@
from ...agent import Specification
from ...agent.orchestrator import Orchestrator
+from ...agent.orchestrator.response.orchestrator_response import (
+ OrchestratorResponse,
+)
from ...cli import confirm, get_input, has_input
from ...cli.commands.cache import cache_delete, cache_download
+from ...cli.theme import Theme
from ...entities import (
GenerationSettings, # noqa: F401
Modality,
@@ -13,6 +17,7 @@
from ...event import TOOL_TYPES, Event, EventStats, EventType
from ...model.call import ModelCall, ModelCallContext
from ...model.criteria import KeywordStoppingCriteria # noqa: F401
+from ...model.engine import Engine
from ...model.hubs.huggingface import HuggingfaceHub
from ...model.manager import ModelManager
from ...model.nlp.sentence import SentenceTransformerModel
@@ -36,13 +41,13 @@
from datetime import datetime, timezone
from logging import Logger
from time import perf_counter
+from typing import Any, cast
from rich.console import Console, Group, RenderableType
from rich.live import Live
from rich.padding import Padding
from rich.prompt import Prompt
from rich.spinner import Spinner
-from rich.theme import Theme
def model_display(
@@ -101,6 +106,9 @@ def model_display(
)
)
elif model:
+ assert (
+ model.tokenizer_config is not None
+ ), "Tokenizer config is required for model display"
console.print(
Padding(
theme.model_display(
@@ -149,7 +157,7 @@ async def model_run(
logger: Logger,
) -> None:
assert args.model and args.device and args.max_new_tokens
- _, _i = theme._, theme.icons
+ _, _i = theme._, cast(dict[str, Any], theme.icons)
with ModelManager(hub, logger) as manager:
engine_uri = manager.parse_uri(args.model)
@@ -236,6 +244,13 @@ async def model_run(
console.print(theme.display_token_labels([output]))
elif operation.modality == Modality.TEXT_GENERATION:
+ assert isinstance(
+ operation.input, str
+ ), "Text generation requires string input"
+ text_params = operation.parameters["text"]
+ assert (
+ text_params is not None
+ ), "Text generation requires text parameters"
await token_generation(
args=args,
console=console,
@@ -247,7 +262,7 @@ async def model_run(
input_string=operation.input,
refresh_per_second=refresh_per_second,
response=output,
- dtokens_pick=operation.parameters["text"].pick_tokens,
+ dtokens_pick=text_params.pick_tokens or 0,
display_tokens=args.display_tokens or 0,
with_stats=not args.quiet,
tool_events_limit=args.display_tools_events,
@@ -291,9 +306,10 @@ async def model_search(
model_access: dict[str, bool] = {}
# Fetch matching models
+ spinner_name = theme.get_spinner("downloading")
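+ # Fall back to rich's built-in "dots" spinner when no spinner is registered for this name.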
with console.status(
_("Loading models..."),
- spinner=theme.get_spinner("downloading"),
+ spinner=spinner_name or "dots",
refresh_per_second=refresh_per_second,
):
models = [
@@ -313,9 +329,11 @@ async def model_search(
]
# Tasks to check model access
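+ # Binding model.id through a named helper avoids the lambda default-argument idiom.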
+ def _check_access(model_id: str) -> tuple[str, bool]:
+ return (model_id, hub.can_access(model_id))
+
tasks = [
- create_task(to_thread(lambda id=model.id: (id, hub.can_access(id))))
- for model in models
+ create_task(to_thread(_check_access, model.id)) for model in models
]
def _render(
@@ -371,9 +389,9 @@ async def token_generation(
logger: Logger,
orchestrator: Orchestrator | None,
event_stats: EventStats | None,
- lm: TextGenerationModel,
+ lm: TextGenerationModel | Engine | None,
input_string: str,
- response: TextGenerationResponse,
+ response: TextGenerationResponse | OrchestratorResponse,
*,
display_tokens: int,
dtokens_pick: int,
@@ -381,7 +399,7 @@ async def token_generation(
tool_events_limit: int | None,
with_stats: bool = True,
live_container: dict[str, Live | None] | None = None,
-):
+) -> None:
# If no statistics needed, return as early as possible
if not with_stats:
async for token in response:
@@ -513,7 +531,6 @@ async def _event_stream(
events_renderable = theme.events(
event_manager.history,
events_limit=6 if tool_view else 4,
- height=tools_height if tool_view else events_height,
include_tokens=False,
include_tools=tool_view,
include_tool_detect=False,
@@ -543,9 +560,9 @@ async def _token_stream(
logger: Logger,
orchestrator: Orchestrator | None,
event_stats: EventStats | None,
- lm: TextGenerationModel,
+ lm: TextGenerationModel | Engine | None,
input_string: str,
- response: TextGenerationResponse,
+ response: TextGenerationResponse | OrchestratorResponse,
*,
display_tokens: int,
dtokens_pick: int,
@@ -564,7 +581,7 @@ async def _token_stream(
start_thinking = (
args.start_thinking if hasattr(args, "start_thinking") else False
)
- tokens = []
+ tokens: list[Token] = []
answer_text_tokens: list[str] = []
thinking_text_tokens: list[str] = []
tool_text_tokens: list[str] = []
@@ -577,15 +594,21 @@ async def _token_stream(
100 if display_pause > 0 and display_tokens > 0 else 0
)
- input_token_count = (
- response.input_token_count
- if response.input_token_count
- else (
- orchestrator.input_token_count
- if orchestrator
- else lm.input_token_count(input_string)
- )
- )
+ assert lm is not None, "Language model must be provided"
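+ # Resolution order: the response's own count, then the orchestrator's, then the model tokenizer.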
+ input_token_count: int
+ if response.input_token_count:
+ input_token_count = response.input_token_count
+ elif orchestrator and orchestrator.input_token_count is not None:
+ orch_count = orchestrator.input_token_count
+ if callable(orch_count):
+ input_token_count = await orch_count() or 0
+ else:
+ input_token_count = orch_count # type: ignore[assignment]
+ else:
+ assert hasattr(
+ lm, "input_token_count"
+ ), "lm must have input_token_count method"
+ input_token_count = lm.input_token_count(input_string) # type: ignore[union-attr]
ttft: float | None = None
ttnt: float | None = None
last_current_dtoken: Token | None = None
@@ -609,13 +632,14 @@ async def _token_stream(
answer_text_tokens = []
tool_text_tokens = []
thinking_text_tokens = []
- inner_response = event.payload["response"]
- assert isinstance(inner_response, TextGenerationResponse)
- if inner_response.input_token_count:
- input_token_count = inner_response.input_token_count
+ if event.payload is not None:
+ inner_response = event.payload["response"]
+ assert isinstance(inner_response, TextGenerationResponse)
+ if inner_response.input_token_count:
+ input_token_count = inner_response.input_token_count
elif event.type == EventType.TOOL_RESULT:
tool_event_results.append(event)
- if "call" in event.payload:
+ if event.payload is not None and "call" in event.payload:
completed_call_ids.add(event.payload["call"].id)
else:
tool_event_calls.append(event)
@@ -640,16 +664,35 @@ async def _token_stream(
tool_running_spinner = None
if tool_event_calls or tool_event_results:
- tool_calling_names = [
- c.name
- for e in tool_event_calls
- for c in e.payload
- if c.id not in completed_call_ids
- ]
-
+ tool_calling_names: list[str] = []
+ for e in tool_event_calls:
+ if e.payload is None:
+ continue
+ # Handle TOOL_EXECUTE events with {"call": call}
+ if "call" in e.payload:
+ call = e.payload["call"]
+ if (
+ hasattr(call, "id")
+ and hasattr(call, "name")
+ and call.id not in completed_call_ids
+ ):
+ tool_calling_names.append(call.name)
+ # Handle TOOL_PROCESS events with {"calls": [call, ...]}
+ elif "calls" in e.payload:
+ calls = e.payload["calls"]
+ if calls:
+ for c in calls:
+ if (
+ hasattr(c, "id")
+ and hasattr(c, "name")
+ and c.id not in completed_call_ids
+ ):
+ tool_calling_names.append(c.name)
+
+ tool_spinner_name = theme.get_spinner("tool_running")
tool_running_spinner = (
Spinner(
- theme.get_spinner("tool_running"),
+ tool_spinner_name or "dots",
text="[cyan]"
+ theme._n(
"Running tool {tool_names}...",
@@ -687,7 +730,8 @@ async def _token_stream(
)
answer_height = getattr(args, "display_answer_height", 12)
- token_frames_promise = theme.tokens(
+ assert lm.model_id is not None, "Model ID must not be None"
+ token_frames_generator = theme.tokens(
lm.model_id,
lm.tokenizer_config.tokens if lm.tokenizer_config else None,
(
@@ -699,24 +743,32 @@ async def _token_stream(
args.display_probabilities if dtokens_pick > 0 else False,
dtokens_pick,
# Which tokens to mark as interesting
- lambda dtoken: (
+ (
(
- dtoken.probability < args.display_probabilities_maximum
- or len(
- [
- t
- for t in dtoken.tokens
- if t.id != dtoken.id
- and t.probability
- >= args.display_probabilities_sample_minimum
- ]
+ lambda dtoken: (
+ dtoken.probability is not None
+ and dtoken.probability
+ < args.display_probabilities_maximum
+ )
+ or (
+ hasattr(dtoken, "tokens")
+ and dtoken.tokens is not None
+ and len(
+ [
+ t
+ for t in dtoken.tokens
+ if t.id != dtoken.id
+ and t.probability is not None
+ and t.probability
+ >= args.display_probabilities_sample_minimum
+ ]
+ )
+ > 0
)
- > 0
)
if display_tokens
and args.display_probabilities
and args.display_probabilities_maximum > 0
- and args.display_probabilities_maximum > 0
else None
),
thinking_text_tokens,
@@ -729,7 +781,7 @@ async def _token_stream(
tool_event_calls,
tool_event_results,
tool_running_spinner,
- ttft,
+ ttft if ttft is not None else 0.0,
ttnt,
ttsr,
elapsed,
@@ -744,7 +796,7 @@ async def _token_stream(
)
token_frame_list = [
- token_frame async for token_frame in token_frames_promise
+ token_frame async for token_frame in token_frames_generator
]
token_frames = [token_frame_list[0]]
diff --git a/src/avalan/cli/commands/tokenizer.py b/src/avalan/cli/commands/tokenizer.py
index dad21d7c..fcf98a61 100644
--- a/src/avalan/cli/commands/tokenizer.py
+++ b/src/avalan/cli/commands/tokenizer.py
@@ -1,4 +1,5 @@
from ...cli import get_input
+from ...cli.theme import Theme
from ...entities import Token, TransformerEngineSettings
from ...model.hubs.huggingface import HuggingfaceHub
from ...model.nlp.text.generation import TextGenerationModel
@@ -7,7 +8,6 @@
from logging import Logger
from rich.console import Console
-from rich.theme import Theme
async def tokenize(
@@ -19,7 +19,7 @@ async def tokenize(
) -> list[Token] | None:
assert args.tokenizer
- _, _i, _n = theme._, theme.icons, theme._n
+ _, _i, _n = theme._, theme._icons, theme._n
tokenizer_name_or_path = args.tokenizer
with TextGenerationModel(
@@ -57,20 +57,21 @@ async def tokenize(
paths = lm.save_tokenizer(args.save)
total_files = len(paths)
console.print(theme.saved_tokenizer_files(args.save, total_files))
- return
+ return None
tty_path = getattr(args, "tty", "/dev/tty") or "/dev/tty"
+ user_input_icon = _i.get("user_input") or ""
input_string = get_input(
console,
- _i["user_input"] + " ",
+ user_input_icon + " ",
echo_stdin=not args.no_repl,
is_quiet=args.quiet,
tty_path=tty_path,
)
if input_string:
logger.debug("Loaded model %s", lm.config.__repr__())
- tokens = lm.tokenize(input_string)
+ tokens: list[Token] = lm.tokenize(input_string)
panel = theme.tokenizer_tokens(
tokens,
@@ -79,3 +80,5 @@ async def tokenize(
display_details=True,
)
console.print(panel)
+ return tokens
+ return None
diff --git a/src/avalan/cli/download.py b/src/avalan/cli/download.py
index cbe0a14a..55b37079 100644
--- a/src/avalan/cli/download.py
+++ b/src/avalan/cli/download.py
@@ -60,5 +60,5 @@ def display(self, *_, **__):
def reset(self, total=None):
if hasattr(self, "_progress"):
- self._progress.reset(total=total)
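+ # rich's Progress.reset() takes the task id as its first argument; omitting it raised a TypeError.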
+ self._progress.reset(task_id=self._task_id, total=total)
super().reset(total=total)
diff --git a/src/avalan/cli/theme/__init__.py b/src/avalan/cli/theme/__init__.py
index 2ab7bbe3..0466f62c 100644
--- a/src/avalan/cli/theme/__init__.py
+++ b/src/avalan/cli/theme/__init__.py
@@ -1,7 +1,6 @@
from ...agent.orchestrator import Orchestrator
from ...entities import (
EngineMessage,
- EngineMessageScored,
HubCache,
HubCacheDeletion,
ImageEntity,
@@ -14,29 +13,41 @@
TokenizerConfig,
User,
)
+from ...entities import (
+ EngineMessageScored as EngineMessageScored,
+)
from ...event import Event, EventStats
from ...memory.partitioner.text import TextPartition
from ...memory.permanent import Memory as Memory
from ...memory.permanent import PermanentMemoryPartition
from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
from dataclasses import fields
from datetime import datetime
from enum import StrEnum
from logging import Logger
-from typing import Callable, Generator, Literal
+from typing import Any, Callable, Literal, get_args
from uuid import UUID
from humanize import intcomma, intword, naturalsize, naturaltime
from numpy import ndarray
from rich.console import RenderableType
+from rich.spinner import Spinner as RichSpinner
Formatter = (
Callable[[datetime], str] | Callable[[float], str] | Callable[[int], str]
)
Formatters = dict[Literal["datetime", "number", "quantity", "size"], Formatter]
-Spinner = Literal["cache_accessing", "connecting", "thinking", "downloading"]
-Data = StrEnum(
+SpinnerName = Literal[
+ "agent_loading",
+ "cache_accessing",
+ "connecting",
+ "downloading",
+ "thinking",
+ "tool_running",
+]
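+# mypy cannot infer the members of a dynamically built StrEnum, hence the misc ignore below.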
+Data = StrEnum( # type: ignore[misc]
"Data",
{
**{field.name: field.name for field in fields(Model)},
@@ -48,10 +59,10 @@
class Theme(ABC):
- _all_spinners: dict[Spinner, str]
- _all_stylers: Stylers
- _all_styles: dict[Data, str]
- _icons: dict[Data, str]
+ _all_spinners: dict[SpinnerName, str | None]
+ _all_stylers: dict[str, Any]
+ _all_styles: dict[str, str]
+ _icons: dict[str, str | None]
_: Callable[[str], str]
@property
@@ -67,7 +78,7 @@ def quantity_data(self) -> list[str]:
return []
@property
- def spinners(self) -> dict[Spinner, str]:
+ def spinners(self) -> dict[SpinnerName, str]:
return {}
@property
@@ -97,7 +108,7 @@ def agent(
agent: Orchestrator,
*args,
models: list[Model | str],
- cans_access: bool | None,
+ can_access: bool | None,
) -> RenderableType:
raise NotImplementedError()
@@ -252,7 +263,7 @@ def recent_messages(
@abstractmethod
def saved_tokenizer_files(
- directory_path: str, total_files: int
+ self, directory_path: str, total_files: int
) -> RenderableType:
raise NotImplementedError()
@@ -261,8 +272,8 @@ def search_message_matches(
self,
participant_id: UUID,
agent: Orchestrator,
- messages: list[EngineMessageScored],
- ):
+ messages: list[EngineMessage],
+ ) -> RenderableType:
raise NotImplementedError()
@abstractmethod
@@ -284,8 +295,9 @@ def tokenizer_tokens(
dtokens: list[Token],
added_tokens: list[str] | None,
special_tokens: list[str] | None,
+ display_details: bool = False,
current_dtoken: Token | None = None,
- dtokens_selected: list[Token] = [],
+ dtokens_selected: list[Token] | None = None,
) -> RenderableType:
raise NotImplementedError()
@@ -334,6 +346,7 @@ async def tokens(
tool_events: list[Event] | None,
tool_event_calls: list[Event] | None,
tool_event_results: list[Event] | None,
+ tool_running_spinner: RichSpinner | None,
ttft: float,
ttnt: float | None,
ttsr: float | None,
@@ -355,8 +368,9 @@ async def tokens(
limit_tool_height: bool = True,
limit_answer_height: bool = False,
start_thinking: bool = False,
- ) -> Generator[tuple[Token | None, RenderableType], None, None]:
+ ) -> AsyncGenerator[tuple[Token | None, RenderableType], None]:
raise NotImplementedError()
+ yield # pragma: no cover - unreachable, but makes this an async generator
@abstractmethod
def welcome(
@@ -376,7 +390,7 @@ def __init__(
formatters: Formatters = {},
stylers: Stylers = {},
styles: dict[Data, str] = {},
- spinners: dict[Spinner, str] = {},
+ spinners: dict[SpinnerName, str] = {},
icons: dict[Data, str] = {},
quantity_data: list[str] = [],
):
@@ -392,20 +406,20 @@ def __init__(
**self.formatters,
}
- all_icons = {
+ all_icons: dict[str, str | None] = {
**{data: None for data in data_keys},
- **self.icons,
- **icons,
+ **{str(k): v for k, v in self.icons.items()},
+ **{str(k): v for k, v in icons.items()},
}
data_keys.extend(all_icons)
data_keys.extend(self.styles.keys())
- self._all_spinners = {
- **{spinner: None for spinner in Spinner.__args__},
+ self._all_spinners: dict[SpinnerName, str | None] = {
+ **{spinner: None for spinner in get_args(SpinnerName)},
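+ # Seed every literal spinner name with None so get_spinner lookups never raise KeyError.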
**self.spinners,
}
- self._all_stylers = {
+ self._all_stylers: dict[str, Any] = {
**{
data: lambda data, value, prefix=None, icon=True: "".join(
[
@@ -420,13 +434,13 @@ def __init__(
),
f"[{data}]",
(
- formatters["quantity"](value)
+ formatters["quantity"](value) # type: ignore[arg-type]
if data in quantity_data
else (
- formatters["datetime"](value)
+ formatters["datetime"](value) # type: ignore[arg-type]
if isinstance(value, datetime)
else (
- formatters["number"](value)
+ formatters["number"](value) # type: ignore[arg-type]
if isinstance(value, int)
or isinstance(value, float)
else value
@@ -438,17 +452,20 @@ def __init__(
)
for data in data_keys
},
- **self.stylers,
+ **{str(k): v for k, v in self.stylers.items()},
+ }
+ self._all_styles: dict[str, str] = {
+ **{data: "" for data in data_keys},
+ **{str(k): v for k, v in self.styles.items()},
}
- self._all_styles = {**{data: "" for data in data_keys}, **self.styles}
self._icons = all_icons
self._ = translator
self._n = translator_plurals
- def get_styles(self) -> dict[Data, str]:
+ def get_styles(self) -> dict[str, str]:
return self._all_styles
- def get_spinner(self, spinner_name: str) -> Spinner:
+ def get_spinner(self, spinner_name: SpinnerName) -> str | None:
return self._all_spinners[spinner_name]
def __call__(self, item: Model | str) -> RenderableType:
@@ -461,8 +478,8 @@ def _f(
prefix: str | None = None,
icon: bool | str = True,
) -> str:
- return (
- self._all_stylers[data](data, value, prefix, icon)
- if data in self._all_stylers and self._all_stylers[data]
- else str(value)
- )
+ styler = self._all_stylers.get(data)
+ if styler is not None:
+ result: str = styler(data, value, prefix, icon) # type: ignore[call-arg]
+ return result
+ return str(value)
diff --git a/src/avalan/cli/theme/fancy.py b/src/avalan/cli/theme/fancy.py
index 9c193c4c..db8dff09 100644
--- a/src/avalan/cli/theme/fancy.py
+++ b/src/avalan/cli/theme/fancy.py
@@ -1,8 +1,7 @@
from ...agent.orchestrator import Orchestrator
-from ...cli.theme import Data, Spinner, Theme
+from ...cli.theme import Data, SpinnerName, Theme
from ...entities import (
EngineMessage,
- EngineMessageScored,
HubCache,
HubCacheDeletion,
ImageEntity,
@@ -12,8 +11,10 @@
SentenceTransformerModelConfig,
Similarity,
Token,
+ TokenDetail,
TokenizerConfig,
ToolCallError,
+ ToolCallResult,
User,
)
from ...event import TOOL_TYPES, Event, EventStats, EventType
@@ -21,13 +22,14 @@
from ...memory.permanent import PermanentMemoryPartition
from ...utils import _j, _lf, to_json
-from datetime import datetime
+from collections.abc import AsyncGenerator
+from datetime import datetime, timedelta
from locale import format_string
from logging import Logger
from math import ceil, inf
from re import sub
from textwrap import wrap
-from typing import Callable, Generator
+from typing import Any, Callable, cast
from uuid import UUID
from humanize import (
@@ -53,94 +55,106 @@
TimeElapsedColumn,
)
from rich.rule import Rule
+from rich.spinner import Spinner as RichSpinner
from rich.table import Column, Table
from rich.text import Text
class FancyTheme(Theme):
+ """Fancy theme implementation with rich formatting and icons."""
+
@property
def icons(self) -> dict[Data, str]:
+ """Return mapping of data keys to emoji icons."""
return {
- "access_token_name": ":lock:",
- "agent_id": ":robot:",
- "agent_output": ":robot:",
- "avalan": ":heavy_large_circle:",
- "author": ":briefcase:",
- "bye": ":vulcan_salute:",
- "can_access": ":white_check_mark:",
- "checking_access": ":mag:",
- "created_at": ":calendar:",
- "disabled": ":cross_mark:",
- "download": ":floppy_disk:",
- "download_access_denied": ":exclamation_mark:",
- "download_finished": ":heavy_check_mark:",
- "downloads": ":floppy_disk:",
- "gated": ":key:",
- "inference": ":brain:",
- "input_token_count": ":laptop_computer:",
- "library_name": ":books:",
- "license": ":balance_scale:",
- "likes": ":orange_heart:",
- "memory": ":brain:",
- "model_id": ":name_badge:",
- "model_type": ":robot_face:",
- "no_access": ":no_entry_sign:",
- "parameters": ":abacus:",
- "pipeline_tag": ":gear:",
- "private": ":closed_lock_with_key:",
- "ranking": ":trophy:",
- "path_blobs": ":file_folder:",
- "path_refs": ":file_folder:",
- "path_repository": ":file_folder:",
- "path_snapshot": ":file_folder:",
- "session": ":card_index_dividers:",
- "task_id": ":robot:",
- "total_tokens": ":abacus:",
- "tokens_rate": ":high_voltage:",
- "events": ":bookmark_tabs:",
- "tool_calls": ":hammer:",
- "tool_call_results": ":package:",
- "ttft": ":seedling:",
- "ttnt": ":alarm_clock:",
- "ttsr": ":thinking_face:",
- "updated_at": ":calendar:",
- "user": ":hugging_face:",
- "user_input": ":speaking_head:",
- "tags": ":label:",
+ cast(Data, "access_token_name"): ":lock:",
+ cast(Data, "agent_id"): ":robot:",
+ cast(Data, "agent_output"): ":robot:",
+ cast(Data, "avalan"): ":heavy_large_circle:",
+ cast(Data, "author"): ":briefcase:",
+ cast(Data, "bye"): ":vulcan_salute:",
+ cast(Data, "can_access"): ":white_check_mark:",
+ cast(Data, "checking_access"): ":mag:",
+ cast(Data, "created_at"): ":calendar:",
+ cast(Data, "disabled"): ":cross_mark:",
+ cast(Data, "download"): ":floppy_disk:",
+ cast(Data, "download_access_denied"): ":exclamation_mark:",
+ cast(Data, "download_finished"): ":heavy_check_mark:",
+ cast(Data, "downloads"): ":floppy_disk:",
+ cast(Data, "gated"): ":key:",
+ cast(Data, "inference"): ":brain:",
+ cast(Data, "input_token_count"): ":laptop_computer:",
+ cast(Data, "library_name"): ":books:",
+ cast(Data, "license"): ":balance_scale:",
+ cast(Data, "likes"): ":orange_heart:",
+ cast(Data, "memory"): ":brain:",
+ cast(Data, "model_id"): ":name_badge:",
+ cast(Data, "model_type"): ":robot_face:",
+ cast(Data, "no_access"): ":no_entry_sign:",
+ cast(Data, "parameters"): ":abacus:",
+ cast(Data, "pipeline_tag"): ":gear:",
+ cast(Data, "private"): ":closed_lock_with_key:",
+ cast(Data, "ranking"): ":trophy:",
+ cast(Data, "path_blobs"): ":file_folder:",
+ cast(Data, "path_refs"): ":file_folder:",
+ cast(Data, "path_repository"): ":file_folder:",
+ cast(Data, "path_snapshot"): ":file_folder:",
+ cast(Data, "session"): ":card_index_dividers:",
+ cast(Data, "task_id"): ":robot:",
+ cast(Data, "total_tokens"): ":abacus:",
+ cast(Data, "tokens_rate"): ":high_voltage:",
+ cast(Data, "events"): ":bookmark_tabs:",
+ cast(Data, "tool_calls"): ":hammer:",
+ cast(Data, "tool_call_results"): ":package:",
+ cast(Data, "ttft"): ":seedling:",
+ cast(Data, "ttnt"): ":alarm_clock:",
+ cast(Data, "ttsr"): ":thinking_face:",
+ cast(Data, "updated_at"): ":calendar:",
+ cast(Data, "user"): ":hugging_face:",
+ cast(Data, "user_input"): ":speaking_head:",
+ cast(Data, "tags"): ":label:",
}
@property
def styles(self) -> dict[Data, str]:
+ """Return mapping of data keys to rich styles."""
return {
- "id": "bold",
- "can_access": "green",
- "checking_access": "bright_black blink",
- "created_at": "magenta",
- "downloads": "bright_black",
- "likes": "bright_black",
- "memory": "magenta",
- "memory_embedding_comparison": "dark_orange3",
- "memory_embedding_comparison_similarity": "dark_orange3",
- "memory_embedding_comparison_similarity_high": (
- "bold dark_olive_green3"
- ),
- "memory_embedding_comparison_similarity_middle": "orange_red1",
- "memory_embedding_comparison_similarity_low": "dark_red",
- "model_id": "cyan",
- "no_access": "bold red",
- "parameters": "bold cyan",
- "participant_id": "bold",
- "ranking": "bright_black",
- "session_id": "dark_orange3",
- "score": "dark_orange3",
- "tags": "gray30",
- "updated_at": "magenta",
- "user": "bold cyan",
- "version": "bold",
+ cast(Data, "id"): "bold",
+ cast(Data, "can_access"): "green",
+ cast(Data, "checking_access"): "bright_black blink",
+ cast(Data, "created_at"): "magenta",
+ cast(Data, "downloads"): "bright_black",
+ cast(Data, "likes"): "bright_black",
+ cast(Data, "memory"): "magenta",
+ cast(Data, "memory_embedding_comparison"): "dark_orange3",
+ cast(
+ Data, "memory_embedding_comparison_similarity"
+ ): "dark_orange3",
+ cast(
+ Data, "memory_embedding_comparison_similarity_high"
+ ): "bold dark_olive_green3",
+ cast(
+ Data, "memory_embedding_comparison_similarity_middle"
+ ): "orange_red1",
+ cast(
+ Data, "memory_embedding_comparison_similarity_low"
+ ): "dark_red",
+ cast(Data, "model_id"): "cyan",
+ cast(Data, "no_access"): "bold red",
+ cast(Data, "parameters"): "bold cyan",
+ cast(Data, "participant_id"): "bold",
+ cast(Data, "ranking"): "bright_black",
+ cast(Data, "session_id"): "dark_orange3",
+ cast(Data, "score"): "dark_orange3",
+ cast(Data, "tags"): "gray30",
+ cast(Data, "updated_at"): "magenta",
+ cast(Data, "user"): "bold cyan",
+ cast(Data, "version"): "bold",
}
@property
- def spinners(self) -> dict[Spinner, str]:
+ def spinners(self) -> dict[SpinnerName, str]:
+ """Return mapping of spinner names to rich spinner types."""
return {
"agent_loading": "dots12",
"cache_accessing": "bouncingBar",
@@ -152,6 +166,7 @@ def spinners(self) -> dict[Spinner, str]:
@property
def quantity_data(self) -> list[str]:
+ """Return list of data keys that should use quantity formatting."""
return ["likes"]
def action(
@@ -164,10 +179,15 @@ def action(
highlight: bool,
finished: bool,
) -> RenderableType:
+ """Render an action panel."""
_i = self._icons
description_color = (
"green" if finished else "white" if highlight else "gray62"
)
+ author_icon = _i.get(cast(Data, "author")) or ""
+ library_icon = _i.get(cast(Data, "library_name")) or ""
+ task_icon = _i.get(cast(Data, "task_id")) or ""
+ model_icon = _i.get(cast(Data, "model_id")) or ""
return Panel.fit(
Padding(
Group(
@@ -178,11 +198,11 @@ def action(
f"{description}[/{description_color}]"
),
(
- _i["author"]
+ author_icon
+ (
f" [bright_black]{author}[/bright_black]"
+ " · "
- + _i["library_name"]
+ + library_icon
+ f" [bright_black]{library_name}"
+ "[/bright_black]"
)
@@ -193,34 +213,39 @@ def action(
)
)
),
- title=_i["task_id"] + f" [cyan]{name}[/cyan]",
- subtitle=_i["model_id"]
- + f" [bright_black]{model_id}[/bright_black]",
+ title=(task_icon + f" [cyan]{name}[/cyan]"),
+ subtitle=(
+ model_icon + f" [bright_black]{model_id}[/bright_black]"
+ ),
box=box.DOUBLE if highlight else box.SQUARE,
)
def agent(
self,
agent: Orchestrator,
- *args,
+ *args: Any,
models: list[Model | str],
can_access: bool | None,
) -> RenderableType:
+ """Render an agent panel with model information."""
_, _f, _i = self._, self._f, self._icons
+ model_id_icon = _i.get(cast(Data, "model_id")) or ""
models_group = Group(
*_lf(
[
- _i["model_id"]
+ model_id_icon
+ " "
+ ", ".join(
[
(
_("{model_id} ({parameters})").format(
model_id=_f(
- "model_id", model.id, icon=False
+ cast(Data, "model_id"),
+ model.id,
+ icon=False,
),
parameters=_f(
- "parameters",
+ cast(Data, "parameters"),
_("{n} params").format(
n=self._parameter_count(
model.parameters
@@ -236,42 +261,52 @@ def agent(
]
),
_f(
- "memory",
+ cast(Data, "memory"),
_j(
", ",
- [
- (
- _("short-term message")
- if agent.memory.has_recent_message
- else None
- ),
- (
- _("long-term message ({driver})").format(
- driver=type(
- agent.memory.permanent_message
- ).__name__
- )
- if agent.memory.has_permanent_message
- else None
- ),
- ],
+ cast(
+ list[str],
+ [
+ (
+ _("short-term message")
+ if agent.memory.has_recent_message
+ else None
+ ),
+ (
+ _(
+ "long-term message ({driver})"
+ ).format(
+ driver=type(
+ agent.memory.permanent_message
+ ).__name__
+ )
+ if agent.memory.has_permanent_message
+ else None
+ ),
+ ],
+ ),
empty=_("stateless"),
),
),
(
_f(
- "session",
+ cast(Data, "session"),
" "
+ _("session: {session_id}").format(
session_id=_f(
- "session_id",
- str(
- agent.memory.permanent_message.session_id
+ cast(Data, "session_id"),
+ (
+ str(
+ agent.memory.permanent_message.session_id
+ )
+ if agent.memory.permanent_message
+ else ""
),
)
),
)
if agent.memory.has_permanent_message
+ and agent.memory.permanent_message
and agent.memory.permanent_message.has_session
else None
),
@@ -280,37 +315,48 @@ def agent(
)
return Panel(
models_group,
- title=_f("agent_id", agent.name if agent.name else str(agent.id)),
+ title=_f(
+ cast(Data, "agent_id"),
+ agent.name if agent.name else str(agent.id),
+ ),
box=box.DOUBLE,
)
def ask_access_token(self) -> str:
+ """Return prompt text for access token input."""
_ = self._
return _("Enter your Huggingface access token")
def ask_delete_paths(self) -> str:
+ """Return prompt text for delete paths confirmation."""
_ = self._
return _("Delete selected paths?")
def ask_login_to_hub(self) -> str:
+ """Return prompt text for hub login confirmation."""
_ = self._
return _("Login to huggingface?")
def ask_secret_password(self, key: str) -> str:
+ """Return prompt text for secret password input."""
_ = self._
return _("Enter secret for {key}").format(key=key)
def ask_override_secret(self, key: str) -> str:
+ """Return prompt text for secret override confirmation."""
_ = self._
return _("Secret {key} exists, override?").format(key=key)
def bye(self) -> RenderableType:
+ """Return goodbye message."""
_, _i = self._, self._icons
- return _i["bye"] + " " + _("bye :)")
+ bye_icon = _i.get(cast(Data, "bye")) or ""
+ return bye_icon + " " + _("bye :)")
def cache_delete(
- self, cache_deletion: HubCacheDeletion | None, deleted=False
+ self, cache_deletion: HubCacheDeletion | None, deleted: bool = False
) -> RenderableType:
+ """Render cache deletion summary."""
_, _f, _n, _i = self._, self._f, self._n, self._icons
if not cache_deletion or (
not cache_deletion.deletable_blobs
@@ -331,7 +377,9 @@ def cache_delete(
"{total_revisions} revisions for {model_id}",
total_revisions,
).format(
- model_id=_f("model_id", cache_deletion.model_id),
+ model_id=_f(
+ cast(Data, "model_id"), cache_deletion.model_id
+ ),
total_revisions=total_revisions,
disk_space=naturalsize(
cache_deletion.deletable_size_on_disk
@@ -360,7 +408,10 @@ def cache_delete(
if deletable_paths:
panel = Panel(
Group(
- *[_f(field_name, path) for path in deletable_paths]
+ *[
+ _f(cast(Data, field_name), path)
+ for path in deletable_paths
+ ]
),
title=title,
)
@@ -374,7 +425,9 @@ def cache_delete(
"{total_revisions} revisions for {model_id}",
total_revisions,
).format(
- model_id=_f("model_id", cache_deletion.model_id),
+ model_id=_f(
+ cast(Data, "model_id"), cache_deletion.model_id
+ ),
total_revisions=total_revisions,
disk_space=naturalsize(
cache_deletion.deletable_size_on_disk
@@ -391,6 +444,7 @@ def cache_list(
display_models: list[str] | None = None,
show_summary: bool = False,
) -> RenderableType:
+ """Render cache list table."""
_ = self._
if display_models and not show_summary:
@@ -450,12 +504,12 @@ def cache_list(
tables.append(Padding(table, pad=(1, 0, 1, 0)))
return Group(*tables)
else:
- display_models = (
+ filtered_models: list[HubCache] = (
[m for m in cached_models if m.model_id in display_models]
if display_models
else cached_models
)
- total_cache_size = sum([m.size_on_disk for m in display_models])
+ total_cache_size = sum([m.size_on_disk for m in filtered_models])
table = Table(
Column(header=_("Model"), justify="left", no_wrap=True),
Column(header=_("Revisions"), justify="left"),
@@ -473,7 +527,7 @@ def cache_list(
footer_style="bold cyan",
)
- for model_cache in display_models:
+ for model_cache in filtered_models:
summarized_revisions = [r[:6] for r in model_cache.revisions]
table.add_row(
model_cache.model_id,
@@ -490,7 +544,9 @@ def cache_list(
def download_access_denied(
self, model_id: str, model_url: str
) -> RenderableType:
+ """Render download access denied message."""
_, _i = self._, self._icons
+ access_denied_icon = _i.get(cast(Data, "download_access_denied")) or ""
return Group(
*_lf(
[
@@ -498,7 +554,7 @@ def download_access_denied(
" ".join(
[
"[bold red]"
- + _i["download_access_denied"]
+ + access_denied_icon
+ "[/bold red]",
"[red]"
+ _(
@@ -521,17 +577,21 @@ def download_access_denied(
)
def download_start(self, model_id: str) -> RenderableType:
+ """Render download start message."""
_, _i = self._, self._icons
+ download_icon = _i.get(cast(Data, "download")) or ""
return Group(
- _i["download"]
+ download_icon
+ " "
+ _("Downloading model {model_id}:").format(model_id=model_id),
"",
)
- def download_progress(self) -> tuple[str | RenderableType]:
- _ = self._
- return (
+ def download_progress(
+ self,
+ ) -> tuple[str | RenderableType]: # type: ignore[override]
+ """Return progress bar components for download."""
+ return ( # type: ignore[return-value]
SpinnerColumn(),
(
"[progress.description]{task.description}"
@@ -546,11 +606,13 @@ def download_progress(self) -> tuple[str | RenderableType]:
)
def download_finished(self, model_id: str, path: str) -> RenderableType:
+ """Render download finished message."""
_, _i = self._, self._icons
+ finished_icon = _i.get(cast(Data, "download_finished")) or ""
return Padding(
" ".join(
[
- "[bold green]" + _i["download_finished"] + "[/bold green]",
+ "[bold green]" + finished_icon + "[/bold green]",
_("Downloaded model {model_id} to {path}").format(
model_id=model_id, path=path
),
@@ -558,7 +620,7 @@ def download_finished(self, model_id: str, path: str) -> RenderableType:
)
)
- def events(
+ def events( # type: ignore[override]
self,
events: list[Event],
*,
@@ -569,7 +631,8 @@ def events(
include_tools: bool = True,
include_non_tools: bool = True,
tool_view: bool = False,
- ) -> RenderableType:
+ ) -> RenderableType | None:
+ """Render events panel."""
_ = self._
event_log = self._events_log(
@@ -580,7 +643,7 @@ def events(
include_tools=include_tools,
include_non_tools=include_non_tools,
)
- panel = (
+ panel: Panel | None = (
Panel(
_j("\n", event_log),
title=_("Tool calls") if tool_view else _("Events"),
@@ -598,6 +661,7 @@ def events(
return panel
def logging_in(self, domain: str) -> str:
+ """Return logging in message."""
_ = self._
return _("Logging in to {domain}...").format(domain=domain)
@@ -605,7 +669,7 @@ def memory_embeddings(
self,
input_string: str,
embeddings: ndarray,
- *args,
+ *args: Any,
total_tokens: int,
minv: float,
maxv: float,
@@ -619,6 +683,7 @@ def memory_embeddings(
partition: int | None = None,
total_partitions: int | None = None,
) -> RenderableType:
+ """Render memory embeddings table."""
_ = self._
assert (
@@ -664,13 +729,13 @@ def memory_embeddings(
):
peek_table.add_column(intcomma(i), justify="center")
- columns = []
- for i, v in enumerate(embeddings[:embedding_peek]):
+ columns: list[str] = []
+ for v in embeddings[:embedding_peek]:
columns.append(clamp(v, format="{:.4g}"))
columns.append("")
- for i, v in enumerate(embeddings[-embedding_peek:]):
+ for v in embeddings[-embedding_peek:]:
columns.append(clamp(v, format="{:.4g}"))
peek_table.add_row(*columns)
@@ -738,6 +803,7 @@ def memory_embeddings(
def memory_embeddings_comparison(
self, similarities: dict[str, Similarity], most_similar: str
) -> RenderableType:
+ """Render memory embeddings comparison table."""
assert similarities and most_similar
_, _f = self._, self._f
table = Table(
@@ -767,32 +833,32 @@ def memory_embeddings_comparison(
table.add_row(
_f(
- field_class,
+ cast(Data, field_class),
compare_string,
- icon=":trophy: " if is_most else None,
+ icon=":trophy: " if is_most else False,
),
_f(
- field_class,
+ cast(Data, field_class),
clamp(similarity.cosine_distance, format="{:.4g}"),
icon=False,
),
_f(
- field_class,
+ cast(Data, field_class),
clamp(similarity.l1_distance, format="{:.4g}"),
icon=False,
),
_f(
- field_class,
+ cast(Data, field_class),
clamp(similarity.l2_distance, format="{:.4g}"),
icon=False,
),
_f(
- field_class,
+ cast(Data, field_class),
clamp(similarity.inner_product, format="{:.4g}"),
icon=False,
),
_f(
- field_class,
+ cast(Data, field_class),
clamp(similarity.pearson, format="{:.4g}"),
icon=False,
),
@@ -802,9 +868,10 @@ def memory_embeddings_comparison(
def memory_embeddings_search(
self,
matches: list[SearchMatch],
- *args,
+ *args: Any,
match_preview_length: int = 300,
) -> RenderableType:
+ """Render memory embeddings search results."""
assert matches
_, _f = self._, self._f
table = Table(
@@ -830,12 +897,12 @@ def memory_embeddings_search(
is_most = i == 0
table.add_row(
_f(
- field_class,
+ cast(Data, field_class),
match.query,
- icon=":trophy: " if is_most else None,
+ icon=":trophy: " if is_most else False,
),
_f(
- field_class,
+ cast(Data, field_class),
(
match.match
if len(match.match) <= match_preview_length
@@ -844,7 +911,7 @@ def memory_embeddings_search(
icon=False,
),
_f(
- field_class,
+ cast(Data, field_class),
clamp(match.l2_distance, format="{:.4g}"),
icon=False,
),
@@ -852,9 +919,12 @@ def memory_embeddings_search(
return Align(table, align="center")
def memory_partitions(
- self, partitions: list[TextPartition], *args, display_partitions: int
+ self,
+ partitions: list[TextPartition],
+ *args: Any,
+ display_partitions: int,
) -> RenderableType:
- _ = self._
+ """Render memory partitions."""
total_partitions = len(partitions)
head_count: int = total_partitions
tail_count: int = 0
@@ -884,7 +954,7 @@ def memory_partitions(
if head_count and tail_count:
elements.append(
- Align(Padding(_("..."), pad=(0, 0, 1, 0)), align="center")
+ Align(Padding(self._("..."), pad=(0, 0, 1, 0)), align="center")
)
if tail_count:
@@ -909,11 +979,12 @@ def memory_partitions(
def model(
self,
model: Model,
- *args,
+ *args: Any,
can_access: bool | None = None,
expand: bool = False,
summary: bool = False,
) -> RenderableType:
+ """Render model panel."""
assert (not expand and not summary) or (
expand ^ summary
), "From expand and summary, only one can be set"
@@ -925,114 +996,179 @@ def model(
[
_j(
" · ",
- [
- _j(
- " ",
+ cast(
+ list[str],
+ [
+ _j(
+ " ",
+ cast(
+ list[str],
+ [
+ (
+ _f(
+ cast(
+ Data,
+ "checking_access",
+ ),
+ _("checking access"),
+ )
+ if can_access is None
+ else (
+ _f(
+ cast(
+ Data,
+ "can_access",
+ ),
+ _(
+ "access"
+ " granted"
+ ),
+ )
+ if can_access
+ else _f(
+ cast(
+ Data,
+ "no_access",
+ ),
+ _("access denied"),
+ )
+ )
+ ),
+ _f(
+ cast(Data, "author"),
+ model.author,
+ ),
+ (
+ _f(
+ cast(Data, "license"),
+ model.license,
+ )
+ if expand and model.license
+ else None
+ ),
+ (
+ _f(
+ cast(Data, "gated"),
+ _("gated"),
+ )
+ if model.gated
+ else None
+ ),
+ (
+ _f(
+ cast(Data, "private"),
+ _("private"),
+ )
+ if model.private
+ else None
+ ),
+ (
+ _f(
+ cast(Data, "disabled"),
+ _("disabled"),
+ )
+ if model.disabled
+ else None
+ ),
+ ],
+ ),
+ ),
+ (
+ (
+ (
+ _i.get(
+ cast(Data, "created_at")
+ )
+ or ""
+ )
+ + " "
+ + _j(
+ ", ",
+ cast(
+ list[str],
+ [
+ (
+ _f(
+ cast(
+ Data,
+ "created_at",
+ ),
+ model.created_at,
+ _("created: "),
+ icon=False,
+ )
+ if expand
+ else None
+ ),
+ _f(
+ cast(
+ Data,
+ "updated_at",
+ ),
+ model.updated_at,
+ _("updated: "),
+ icon=False,
+ ),
+ ],
+ ),
+ )
+ )
+ if not summary
+ else None
+ ),
+ ],
+ ),
+ ),
+ (
+ _j(
+ " · ",
+ cast(
+ list[str],
[
(
_f(
- "checking_access",
- _("checking access"),
+ cast(Data, "model_type"),
+ model.model_type,
)
- if can_access is None
- else (
- _f(
- "can_access",
- _("access granted"),
- )
- if can_access
- else _f(
- "no_access",
- _("access denied"),
+ + (
+ " ("
+ + ", ".join(
+ model.architectures
)
+ + ")"
+ if expand
+ and model.architectures
+ else ""
)
- ),
- _f("author", model.author),
- (
- _f("license", model.license)
- if expand and model.license
+ if model.model_type
else None
),
(
- _f("gated", _("gated"))
- if model.gated
+ _f(
+ cast(Data, "library_name"),
+ model.library_name,
+ )
+ if model.library_name
else None
),
(
- _f("private", _("private"))
- if model.private
+ _f(
+ cast(Data, "inference"),
+ model.inference,
+ )
+ if expand and model.inference
else None
),
(
- _f("disabled", _("disabled"))
- if model.disabled
+ _f(
+ cast(Data, "pipeline_tag"),
+ model.pipeline_tag,
+ )
+ if model.pipeline_tag
else None
),
],
),
- (
- (
- _i["created_at"]
- + " "
- + _j(
- ", ",
- [
- (
- _f(
- "created_at",
- model.created_at,
- _("created: "),
- icon=False,
- )
- if expand
- else None
- ),
- _f(
- "updated_at",
- model.updated_at,
- _("updated: "),
- icon=False,
- ),
- ],
- )
- )
- if not summary
- else None
- ),
- ],
- ),
- (
- _j(
- " · ",
- [
- (
- _f("model_type", model.model_type)
- + (
- " ("
- + ", ".join(model.architectures)
- + ")"
- if expand and model.architectures
- else ""
- )
- if model.model_type
- else None
- ),
- (
- _f("library_name", model.library_name)
- if model.library_name
- else None
- ),
- (
- _f("inference", model.inference)
- if expand and model.inference
- else None
- ),
- (
- _f("pipeline_tag", model.pipeline_tag)
- if model.pipeline_tag
- else None
- ),
- ],
)
if not summary
else None
@@ -1043,7 +1179,7 @@ def model(
else None
),
(
- _f("tags", " " + ", ".join(model.tags))
+ _f(cast(Data, "tags"), " " + ", ".join(model.tags))
if expand and model.tags
else None
),
@@ -1052,26 +1188,29 @@ def model(
),
# Model ID
title=(
- _f("model_id", model.id)
+ _f(cast(Data, "model_id"), model.id)
+ (
" "
+ _j(
" ",
- [
- _f(
- "parameters",
- self._parameter_count(model.parameters),
- ),
- (
+ cast(
+ list[str],
+ [
_f(
- "parameter_types",
- ", ".join(model.parameter_types),
- )
- if expand and model.parameter_types
- else None
- ),
- _("parameters") if expand else _("params"),
- ],
+ cast(Data, "parameters"),
+ self._parameter_count(model.parameters),
+ ),
+ (
+ _f(
+ cast(Data, "parameter_types"),
+ ", ".join(model.parameter_types),
+ )
+ if expand and model.parameter_types
+ else None
+ ),
+ _("parameters") if expand else _("params"),
+ ],
+ ),
)
)
if not summary
@@ -1081,19 +1220,26 @@ def model(
subtitle=(
_j(
" ",
- [
- (
- _f("downloads", model.downloads)
- if model.downloads
- else None
- ),
- _f("likes", model.likes) if model.likes else None,
- (
- _f("ranking", model.ranking)
- if model.ranking
- else None
- ),
- ],
+ cast(
+ list[str],
+ [
+ (
+ _f(cast(Data, "downloads"), model.downloads)
+ if model.downloads
+ else None
+ ),
+ (
+ _f(cast(Data, "likes"), model.likes)
+ if model.likes
+ else None
+ ),
+ (
+ _f(cast(Data, "ranking"), model.ranking)
+ if model.ranking
+ else None
+ ),
+ ],
+ ),
)
if expand
else None
@@ -1105,10 +1251,11 @@ def model_display(
self,
model_config: ModelConfig | SentenceTransformerModelConfig | None,
tokenizer_config: TokenizerConfig,
- *args,
+ *args: Any,
is_runnable: bool | None = None,
summary: bool = False,
) -> RenderableType:
+ """Render model display with config and tokenizer info."""
_ = self._
return Group(
*_lf(
@@ -1153,10 +1300,11 @@ def model_display(
def _sentence_transformer_model_config(
self,
config: SentenceTransformerModelConfig,
- *args,
+ *args: Any,
is_runnable: bool | None,
summary: bool,
) -> RenderableType:
+ """Render sentence transformer model config table."""
_ = self._
config_table = Table(
Column(header="", justify="right"),
@@ -1193,10 +1341,11 @@ def _sentence_transformer_model_config(
def _model_config(
self,
config: ModelConfig,
- *args,
+ *args: Any,
is_runnable: bool | None,
summary: bool,
) -> RenderableType:
+ """Render model config table."""
config_table = Table(
Column(header="", justify="right"),
Column(header="", justify="left"),
@@ -1215,10 +1364,11 @@ def _fill_model_config_table(
self,
config: ModelConfig,
config_table: Table,
- *args,
+ *args: Any,
is_runnable: bool | None,
summary: bool,
) -> Table:
+ """Fill model config table with configuration details."""
_ = self._
config_table.add_row(
_("Model type"), f"[bold]{config.model_type}[/bold]"
@@ -1246,24 +1396,26 @@ def _fill_model_config_table(
if not summary:
config_table.add_row(
_("Vocabulary size"),
- f"[magenta]{intcomma(config.vocab_size)}[/magenta]",
+ f"[magenta]{intcomma(config.vocab_size or 0)}[/magenta]",
)
config_table.add_row(
_("Hidden size"),
- f"[magenta]{intcomma(config.hidden_size)}[/magenta]",
+ f"[magenta]{intcomma(config.hidden_size or 0)}[/magenta]",
)
if not summary:
+ num_hidden = config.num_hidden_layers or 0
config_table.add_row(
_("Number of hidden layers"),
- f"[magenta]{intcomma(config.num_hidden_layers)}[/magenta]",
+ f"[magenta]{intcomma(num_hidden)}[/magenta]",
)
config_table.add_row(
_("Number of attention heads"),
- f"[magenta]{intcomma(config.num_attention_heads)}[/magenta]",
+ f"[magenta]{intcomma(config.num_attention_heads or 0)}"
+ "[/magenta]",
)
config_table.add_row(
_("Number of labels in last layer"),
- f"[magenta]{intcomma(config.num_labels)}[/magenta]",
+ f"[magenta]{intcomma(config.num_labels or 0)}[/magenta]",
)
if config.loss_type:
config_table.add_row(
@@ -1332,19 +1484,31 @@ def recent_messages(
participant_id: UUID,
agent: Orchestrator,
messages: list[EngineMessage],
- ):
+ ) -> RenderableType:
+ """Render recent messages panel."""
_, _f, _i = self._, self._f, self._icons
+ agent_output_icon = _i.get(cast(Data, "agent_output")) or ""
+ user_input_icon = _i.get(cast(Data, "user_input")) or ""
group = Group(
*_lf(
[
Panel(
- engine_message.message.content,
+ (
+ str(engine_message.message.content)
+ if engine_message.message.content
+ else ""
+ ),
title=(
- _i["agent_output"] + " " + _f("id", agent.name)
+ agent_output_icon
+ + " "
+ + _f(cast(Data, "id"), agent.name)
if engine_message.is_from_agent
- else _i["user_input"]
+ else user_input_icon
+ " "
- + _f("participant_id", participant_id)
+ + _f(
+ cast(Data, "participant_id"),
+ str(participant_id),
+ )
),
title_align="left",
expand=True,
@@ -1356,9 +1520,10 @@ def recent_messages(
)
return group
- def saved_tokenizer_files(
+ def saved_tokenizer_files( # type: ignore[override]
self, directory_path: str, total_files: int
) -> RenderableType:
+ """Render saved tokenizer files message."""
_n = self._n
return Padding(
_n(
@@ -1373,28 +1538,44 @@ def search_message_matches(
self,
participant_id: UUID,
agent: Orchestrator,
- messages: list[EngineMessageScored],
- ):
+ messages: list[EngineMessage],
+ ) -> RenderableType:
+ """Render search message matches panel."""
_, _f, _i = self._, self._f, self._icons
+ agent_output_icon = _i.get(cast(Data, "agent_output")) or ""
+ user_input_icon = _i.get(cast(Data, "user_input")) or ""
group = Group(
*_lf(
[
Panel(
- engine_message.message.content,
+ (
+ str(engine_message.message.content)
+ if engine_message.message.content
+ else ""
+ ),
title=(
- _i["agent_output"]
+ agent_output_icon
+ " "
- + _f("id", agent.name or str(agent.id))
+ + _f(
+ cast(Data, "id"),
+ agent.name or str(agent.id),
+ )
if engine_message.is_from_agent
- else _i["user_input"]
+ else user_input_icon
+ " "
- + _f("participant_id", participant_id)
+ + _f(
+ cast(Data, "participant_id"),
+ str(participant_id),
+ )
),
title_align="left",
subtitle=_("Matching score: {score}").format(
score=_f(
- "score",
- clamp(engine_message.score, format="{:.8g}"),
+ cast(Data, "score"),
+ clamp(
+ getattr(engine_message, "score", 0.0),
+ format="{:.8g}",
+ ),
)
),
subtitle_align="left",
@@ -1413,16 +1594,18 @@ def memory_search_matches(
namespace: str,
memories: list[PermanentMemoryPartition],
) -> RenderableType:
+ """Render memory search matches panel."""
_, _f, _i = self._, self._f, self._icons
+ memory_icon = _i.get(cast(Data, "memory")) or ""
group = Group(
*_lf(
[
Panel(
memory.data,
title=(
- _i["memory"]
+ memory_icon
+ " "
- + _f("id", str(memory.memory_id))
+ + _f(cast(Data, "id"), str(memory.memory_id))
),
title_align="left",
subtitle=_(
@@ -1430,10 +1613,13 @@ def memory_search_matches(
" Partition: {partition}"
).format(
participant=_f(
- "participant_id", str(participant_id)
+ cast(Data, "participant_id"),
+ str(participant_id),
+ ),
+ ns=_f(cast(Data, "id"), namespace),
+ partition=_f(
+ cast(Data, "number"), memory.partition
),
- ns=_f("id", namespace),
- partition=_f("number", memory.partition),
),
subtitle_align="left",
expand=True,
@@ -1446,8 +1632,9 @@ def memory_search_matches(
return group
def tokenizer_config(
- self, config: TokenizerConfig, *args, summary: bool = False
+ self, config: TokenizerConfig, *args: Any, summary: bool = False
) -> RenderableType:
+ """Render tokenizer config table."""
_ = self._
config_table = Table(
@@ -1469,12 +1656,13 @@ def tokenizer_config(
_("Added tokens"),
", ".join([f"[cyan]{t}[/cyan]" for t in config.tokens]),
)
- config_table.add_row(
- _("Special tokens"),
- ", ".join(
- [f"[cyan]{t}[/cyan]" for t in config.special_tokens]
- ),
- )
+ if config.special_tokens:
+ config_table.add_row(
+ _("Special tokens"),
+ ", ".join(
+ [f"[cyan]{t}[/cyan]" for t in config.special_tokens]
+ ),
+ )
config_table.add_row(
_("Maximum sequence length"),
f"[cyan]{config.tokenizer_model_max_length}[/cyan]",
@@ -1486,15 +1674,18 @@ def tokenizer_config(
return Align(config_table, align="center")
- def tokenizer_tokens(
+ def tokenizer_tokens( # type: ignore[override]
self,
dtokens: list[Token],
added_tokens: list[str] | None,
special_tokens: list[str] | None,
display_details: bool = False,
current_dtoken: Token | None = None,
- dtokens_selected: list[Token] = [],
+ dtokens_selected: list[Token] | None = None,
) -> RenderableType:
+ """Render tokenizer tokens panel."""
+ if dtokens_selected is None:
+ dtokens_selected = []
# Build token panels
compact_dtokens = True # For future configurability
token_panels = [
@@ -1508,7 +1699,7 @@ def tokenizer_tokens(
style=(
"white on dark_green"
if current_dtoken and dtoken == current_dtoken
- else None
+ else ""
),
)
),
@@ -1549,6 +1740,7 @@ def tokenizer_tokens(
def display_image_entities(
self, entities: list[ImageEntity], sort: bool
) -> RenderableType:
+ """Render image entities table."""
_ = self._
table = Table(
Column(header=_("Label"), justify="left"),
@@ -1569,20 +1761,21 @@ def display_image_entities(
for entity in entities:
score = (
- self._f("score", f"{entity.score:.2f}")
+ self._f(cast(Data, "score"), f"{entity.score:.2f}")
if entity.score is not None
else "-"
)
- box = (
+ entity_box = (
", ".join(f"{v:.2f}" for v in entity.box)
if entity.box
else "-"
)
- table.add_row(entity.label, score, box)
+ table.add_row(entity.label, score, entity_box)
return Align(table, align="center")
- def display_image_entity(self, entity: ImageEntity):
+ def display_image_entity(self, entity: ImageEntity) -> RenderableType:
+ """Render single image entity table."""
_ = self._
table = Table(
Column(header=_("Label"), justify="left"),
@@ -1598,6 +1791,7 @@ def display_image_entity(self, entity: ImageEntity):
def display_audio_labels(
self, audio_labels: dict[str, float]
) -> RenderableType:
+ """Render audio labels table."""
_ = self._
table = Table(
Column(header=_("Label"), justify="left"),
@@ -1609,13 +1803,12 @@ def display_audio_labels(
border_style="gray58",
)
for label, score in audio_labels.items():
- score_text = (
- self._f("score", f"{score:.2f}") if score is not None else "-"
- )
+ score_text = self._f(cast(Data, "score"), f"{score:.2f}")
table.add_row(label, score_text)
return Align(table, align="center")
def display_image_labels(self, labels: list[str]) -> RenderableType:
+ """Render image labels table."""
_ = self._
table = Table(
Column(header=_("Label"), justify="left"),
@@ -1632,6 +1825,7 @@ def display_image_labels(self, labels: list[str]) -> RenderableType:
def display_token_labels(
self, token_labels: list[dict[str, str]]
) -> RenderableType:
+ """Render token labels table."""
_ = self._
table = Table(
Column(header=_("Token"), justify="left"),
@@ -1665,7 +1859,7 @@ async def tokens(
tool_events: list[Event] | None,
tool_event_calls: list[Event] | None,
tool_event_results: list[Event] | None,
- tool_running_spinner: Spinner | None,
+ tool_running_spinner: RichSpinner | None,
ttft: float,
ttnt: float | None,
ttsr: float | None,
@@ -1687,7 +1881,8 @@ async def tokens(
limit_tool_height: bool = True,
limit_answer_height: bool = False,
start_thinking: bool = False,
- ) -> Generator[tuple[Token | None, RenderableType], None, None]:
+ ) -> AsyncGenerator[tuple[Token | None, RenderableType], None]:
+ """Generate token display panels asynchronously."""
_, _n, _f, _l = self._, self._n, self._f, logger.debug
pick_first = ceil(pick / 2) if pick > 1 else pick
@@ -1725,16 +1920,18 @@ async def tokens(
"\n".join(wrapped_section).rstrip() if wrapped_section else None
)
- dtokens = (
+ dtokens: list[Token] | None = (
tokens[-display_token_size:]
if display_token_size and tokens
else None
)
- dtokens_selected = (
+ dtokens_selected: list[TokenDetail] | None = (
[
dtoken
for dtoken in dtokens
- if focus_on_token_when and focus_on_token_when(dtoken)
+ if isinstance(dtoken, TokenDetail)
+ and focus_on_token_when
+ and focus_on_token_when(dtoken)
]
if dtokens
else None
@@ -1748,7 +1945,7 @@ async def tokens(
_lf(
[
_f(
- "input_token_count",
+ cast(Data, "input_token_count"),
_n(
"{total_tokens} token in",
"{total_tokens} tokens in",
@@ -1756,7 +1953,7 @@ async def tokens(
).format(total_tokens=input_token_count),
),
_f(
- "total_tokens",
+ cast(Data, "total_tokens"),
_n(
"{total_tokens} token out",
"{total_tokens} tokens out",
@@ -1765,7 +1962,7 @@ async def tokens(
),
(
_f(
- "ttft",
+ cast(Data, "ttft"),
_("ttft: {ttft} s").format(ttft=f"{ttft:.2f}"),
)
if ttft
@@ -1773,7 +1970,7 @@ async def tokens(
),
(
_f(
- "ttnt",
+ cast(Data, "ttnt"),
_("ttnt: {ttnt} s").format(ttnt=f"{ttnt:.1f}"),
)
if ttnt
@@ -1781,21 +1978,21 @@ async def tokens(
),
(
_f(
- "ttsr",
+ cast(Data, "ttsr"),
_("rt: {ttsr} s").format(ttsr=f"{ttsr:.1f}"),
)
if ttsr
else None
),
_f(
- "tokens_rate",
+ cast(Data, "tokens_rate"),
_("{tokens_rate} t/s").format(
tokens_rate=f"{total_tokens / elapsed:.2f}"
),
),
(
_f(
- "events",
+ cast(Data, "events"),
_n(
"{total} event",
"{total} events",
@@ -1807,7 +2004,7 @@ async def tokens(
),
(
_f(
- "tool_calls",
+ cast(Data, "tool_calls"),
_n(
"{total} tool call",
"{total} tool calls",
@@ -1825,7 +2022,7 @@ async def tokens(
),
(
_f(
- "tool_call_results",
+ cast(Data, "tool_call_results"),
_n(
"{total} result",
"{total} results",
@@ -1851,7 +2048,7 @@ async def tokens(
vertical="top",
),
title=_("{model_id} reasoning").format(
- model_id=_f("id", model_id)
+ model_id=_f(cast(Data, "id"), model_id)
),
title_align="left",
subtitle=progress_title if not wrapped_output else None,
@@ -1892,7 +2089,7 @@ async def tokens(
Align(wrapped_output, vertical="top"),
title=(
_("{model_id} response").format(
- model_id=_f("id", model_id)
+ model_id=_f(cast(Data, "id"), model_id)
)
if think_wrapped_output is None
else None
@@ -1915,8 +2112,11 @@ async def tokens(
tool_running_panel: RenderableType | None = None
- if tool_running_spinner and len(tool_event_calls) != len(
- tool_event_results
+ if (
+ tool_running_spinner
+ and tool_event_calls is not None
+ and tool_event_results is not None
+ and len(tool_event_calls) != len(tool_event_results)
):
tool_running_panel = Padding(
tool_running_spinner, pad=(1, 0, 1, 0)
@@ -1951,8 +2151,8 @@ async def tokens(
if display_token_size and tokens:
# Pick current token to highlight
- current_data = None
- current_dtoken: Token | None = None
+ current_data: list[float | None] | None = None
+ current_dtoken: TokenDetail | None = None
if display_probabilities and dtokens_selected:
current_selected_index = (
0
@@ -1986,31 +2186,39 @@ async def tokens(
f'Selected "{current_dtoken.token}" as '
"interesting token, with "
+ clamp(
- current_dtoken.probability, format="{:.4g}"
+ current_dtoken.probability or 0.0,
+ format="{:.4g}",
)
+ f"and {current_dtoken.tokens}"
)
tokens_panel = self.tokenizer_tokens(
- dtokens,
+ dtokens if dtokens else [],
added_tokens,
special_tokens,
display_details=False,
current_dtoken=current_dtoken,
- dtokens_selected=dtokens_selected,
+ dtokens_selected=cast(
+ list[Token] | None, dtokens_selected
+ ),
)
# Build bar chart with token alternative probabilities
chart = None
if display_probabilities:
- current_symmetric_indices = (
- FancyTheme._symmetric_indices(current_data)
+ current_symmetric_indices: list[int] | None = (
+ FancyTheme._symmetric_indices(
+ [v or 0.0 for v in current_data]
+ )
if current_data
else None
)
- current_symmetric_data = (
- [current_data[i] for i in current_symmetric_indices]
- if current_data
+ current_symmetric_data: list[float] | None = (
+ [
+ (current_data[i] or 0.0)
+ for i in current_symmetric_indices
+ ]
+ if current_data and current_symmetric_indices
else None
)
labels = (
@@ -2029,7 +2237,7 @@ async def tokens(
for level in range(chart_height, 0, -1):
chart_row = ""
for value in current_symmetric_data or [
- 0 for i in range(pick)
+ 0.0 for _ in range(pick)
]:
if value * chart_height >= level:
chart_row += "".join(
@@ -2058,7 +2266,8 @@ async def tokens(
if pick > 0 and current_dtoken and current_dtoken.tokens:
dtoken_tokens = current_dtoken.tokens
max_dtoken = max(
- dtoken_tokens, key=lambda dtoken: dtoken.probability
+ dtoken_tokens,
+ key=lambda dt: dt.probability or 0.0,
)
if pick_first is None or len(dtoken_tokens) <= pick_first:
@@ -2170,141 +2379,14 @@ def _events_log(
include_tools: bool,
include_non_tools: bool,
) -> list[str] | None:
+ """Generate event log entries."""
_, _n = self._, self._n
if not events or events_limit == 0:
return None
event_log: list[str] | None = _lf(
[
- (
- _(
- "Executing tool {tool} call #{call_id} with"
- " {total_arguments} arguments: {arguments}."
- ).format(
- tool="[gray78]"
- + event.payload["call"].name
- + "[/gray78]",
- call_id="[gray78]"
- + str(event.payload["call"].id)[:8]
- + "[/gray78]",
- total_arguments=len(
- event.payload["call"].arguments or []
- ),
- arguments="[gray78]"
- + (
- s
- if len(s := str(event.payload["call"].arguments))
- <= 50
- else s[:47] + "..."
- )
- + "[/gray78]",
- )
- if event.type == EventType.TOOL_EXECUTE
- else (
- _n(
- "Running ReACT model {model_id} with"
- " {total_messages} message",
- "Running ReACT model {model_id} with"
- " {total_messages} messages",
- len(event.payload["messages"]),
- ).format(
- model_id=event.payload["model_id"],
- total_messages=len(event.payload["messages"]),
- )
- if event.type == EventType.TOOL_MODEL_RUN
- else (
- _(
- "Got ReACT response from model {model_id}"
- ).format(model_id=event.payload["model_id"])
- if event.type == EventType.TOOL_MODEL_RESPONSE
- else (
- _n(
- "Executing {total_calls} tool: {calls}",
- "Executing {total_calls} tools: {calls}",
- len(event.payload),
- ).format(
- total_calls=len(event.payload),
- calls="[gray78]"
- + "[/gray78], [gray78]".join(
- [call.name for call in event.payload]
- )
- + "[/gray78]",
- )
- if event.type == EventType.TOOL_PROCESS
- else (
- _(
- "Executed tool {tool} call #{call_id}"
- " with {total_arguments} arguments."
- ' Got result "{result}" in'
- " {elapsed_with_unit}."
- ).format(
- tool="[gray78]"
- + event.payload["result"].call.name
- + "[/gray78]",
- elapsed_with_unit="[gray78]"
- + precisedelta(
- event.elapsed,
- minimum_unit="microseconds",
- )
- + "[/gray78]",
- call_id="[gray78]"
- + str(event.payload["result"].call.id)[
- :8
- ]
- + "[/gray78]",
- total_arguments=len(
- event.payload[
- "result"
- ].call.arguments
- or []
- ),
- result=(
- (
- "[red]"
- + event.payload[
- "result"
- ].message
- + "[/red]"
- )
- if isinstance(
- event.payload["result"],
- ToolCallError,
- )
- else (
- "[spring_green3]"
- + to_json(
- event.payload[
- "result"
- ].result
- )
- + "[/spring_green3]"
- )
- ),
- )
- if event.type == EventType.TOOL_RESULT
- and event.payload["result"]
- else (
- f"[{precisedelta(event.elapsed)}]"
- f" <{event.type}>: {event.payload}"
- if event.payload and event.elapsed
- else (
- f"[{datetime.utcfromtimestamp(event.started).isoformat(sep=' ', timespec='seconds')}] <{event.type}>: {event.payload}" # noqa: E501
- if event.payload and event.started
- else (
- f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}] <{event.type}>: {event.payload}" # noqa: E501
- if event.payload
- else (
- f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}]" # noqa: E501
- f" <{event.type}>"
- )
- )
- )
- )
- )
- )
- )
- ) # noqa: E501
- )
+ self._format_event(event, _, _n)
for event in events
if (
(
@@ -2332,12 +2414,132 @@ def _events_log(
return event_log
+ def _format_event(
+ self,
+ event: Event,
+ _: Callable[[str], str],
+ _n: Callable[[str, str, int], str],
+ ) -> str | None:
+ """Format a single event for display."""
+ payload = event.payload
+
+ if event.type == EventType.TOOL_EXECUTE and payload:
+ call = payload.get("call")
+ if call:
+ return _(
+ "Executing tool {tool} call #{call_id} with"
+ " {total_arguments} arguments: {arguments}."
+ ).format(
+ tool="[gray78]" + call.name + "[/gray78]",
+ call_id="[gray78]" + str(call.id)[:8] + "[/gray78]",
+ total_arguments=len(call.arguments or []),
+ arguments="[gray78]"
+ + (
+ s
+ if len(s := str(call.arguments)) <= 50
+ else s[:47] + "..."
+ )
+ + "[/gray78]",
+ )
+
+ if event.type == EventType.TOOL_MODEL_RUN and payload:
+ messages = payload.get("messages", [])
+ model_id = payload.get("model_id", "")
+ return _n(
+ "Running ReACT model {model_id} with {total_messages} message",
+ "Running ReACT model {model_id} with {total_messages}"
+ " messages",
+ len(messages),
+ ).format(
+ model_id=model_id,
+ total_messages=len(messages),
+ )
+
+ if event.type == EventType.TOOL_MODEL_RESPONSE and payload:
+ model_id = payload.get("model_id", "")
+ return _("Got ReACT response from model {model_id}").format(
+ model_id=model_id
+ )
+
+ if event.type == EventType.TOOL_PROCESS and payload:
+ calls: list[Any] = payload if isinstance(payload, list) else []
+ return _n(
+ "Executing {total_calls} tool: {calls}",
+ "Executing {total_calls} tools: {calls}",
+ len(calls),
+ ).format(
+ total_calls=len(calls),
+ calls="[gray78]"
+ + "[/gray78], [gray78]".join(
+ [c.name for c in calls if hasattr(c, "name")]
+ )
+ + "[/gray78]",
+ )
+
+ if event.type == EventType.TOOL_RESULT and payload:
+ result = payload.get("result")
+ if result:
+ call = getattr(result, "call", None)
+ if call:
+ elapsed_delta = (
+ timedelta(seconds=event.elapsed)
+ if event.elapsed is not None
+ else timedelta(seconds=0)
+ )
+ result_text: str
+ if isinstance(result, ToolCallError):
+ result_text = "[red]" + result.message + "[/red]"
+ elif isinstance(result, ToolCallResult):
+ result_text = (
+ "[spring_green3]"
+ + to_json(result.result)
+ + "[/spring_green3]"
+ )
+ else:
+ result_text = str(result)
+ return _(
+ "Executed tool {tool} call #{call_id} with"
+ ' {total_arguments} arguments. Got result "{result}"'
+ " in {elapsed_with_unit}."
+ ).format(
+ tool="[gray78]" + call.name + "[/gray78]",
+ elapsed_with_unit="[gray78]"
+ + precisedelta(
+ elapsed_delta,
+ minimum_unit="microseconds",
+ )
+ + "[/gray78]",
+ call_id="[gray78]" + str(call.id)[:8] + "[/gray78]",
+ total_arguments=len(call.arguments or []),
+ result=result_text,
+ )
+
+ # Default format for other event types
+ if payload and event.elapsed is not None:
+ elapsed_delta = timedelta(seconds=event.elapsed)
+ return f"[{precisedelta(elapsed_delta)}] <{event.type}>: {payload}"
+ if payload and event.started:
+ return (
+ f"[{datetime.fromtimestamp(event.started).isoformat(sep=' ', timespec='seconds')}]" # noqa: E501
+ f" <{event.type}>: {payload}"
+ )
+ if payload:
+ return (
+ f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}]"
+ f" <{event.type}>: {payload}"
+ )
+ return (
+ f"[{datetime.now().isoformat(sep=' ', timespec='seconds')}]"
+ f" <{event.type}>"
+ )
+
def _tokens_table(
self,
dbatch: list[Token],
current_dtoken: Token | None,
max_dtoken: Token | None,
- ):
+ ) -> Table:
+ """Build token alternatives table."""
_p = self._percentage
dtable_color = "gray58"
@@ -2368,7 +2570,8 @@ def _tokens_table(
table.add_row(
f"[gray50]#{dtoken.id}[/gray50]",
f"[{dtoken_color}]{dtoken.token}[/{dtoken_color}]",
- f"[{dtoken_color}]{_p(dtoken.probability)}[/{dtoken_color}]",
+ f"[{dtoken_color}]{_p(dtoken.probability or 0.0)}"
+ f"[/{dtoken_color}]",
)
return table
@@ -2380,31 +2583,44 @@ def welcome(
license: str,
user: User | None,
) -> RenderableType:
+ """Render welcome panel."""
_, _f, _i = self._, self._f, self._icons
+ avalan_icon = _i.get(cast(Data, "avalan")) or ""
+ license_icon = _i.get(cast(Data, "license")) or ""
license_text = _("{license} license").format(license=license)
return Padding(
Panel(
Padding(
_j(
" - ",
- [
- " ".join(
- [
- _i["avalan"]
- + f" [link={url}]{name}[/link]",
- f"[version]{version}[/version]",
- "[bright_black]"
- + _i["license"]
- + f" {license_text}[/bright_black]",
- ]
- ),
- _f("user", user.name) if user else None,
- (
- _f("access_token_name", user.access_token_name)
- if user
- else None
- ),
- ],
+ cast(
+ list[str],
+ [
+ " ".join(
+ [
+ avalan_icon
+ + f" [link={url}]{name}[/link]",
+ f"[version]{version}[/version]",
+ "[bright_black]"
+ + license_icon
+ + f" {license_text}[/bright_black]",
+ ]
+ ),
+ (
+ _f(cast(Data, "user"), user.name)
+ if user
+ else None
+ ),
+ (
+ _f(
+ cast(Data, "access_token_name"),
+ user.access_token_name,
+ )
+ if user
+ else None
+ ),
+ ],
+ ),
)
),
box=box.SQUARE,
@@ -2413,6 +2629,7 @@ def welcome(
)
def _parameter_count(self, parameters: int | None) -> str:
+ """Format parameter count for display."""
_ = self._
if not parameters:
return _("N/A")
@@ -2423,27 +2640,28 @@ def _parameter_count(self, parameters: int | None) -> str:
)
@staticmethod
- def _symmetric_indices(data: list[float]) -> list[float]:
- """Sorts data desc so that highest values in center lower at edge"""
+ def _symmetric_indices(data: list[float]) -> list[int]:
+ """Sort data desc so that highest values in center lower at edge."""
assert data
sorted_data = sorted(data, reverse=True)
n = len(sorted_data)
- result = [None] * n
+ result: list[int | None] = [None] * n
left = n // 2 - 1
right = n // 2
- for i, value in enumerate(sorted_data):
+ for i, _ in enumerate(sorted_data):
if i % 2 == 0:
result[left] = i
left -= 1
else:
result[right] = i
right += 1
- return result
+ return [r for r in result if r is not None]
@staticmethod
def _percentage(value: float) -> str:
+ """Format value as percentage."""
p = value * 100
return (
format_string("%d%%", p, grouping=True)
@@ -2455,6 +2673,7 @@ def _percentage(value: float) -> str:
def _wrap_lines(
text_tokens: list[str], width: int, skip_blank_lines: bool = False
) -> list[str]:
+ """Wrap text tokens to specified width."""
lines: list[str] = []
output = "".join(text_tokens)
for line in output.splitlines():
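The `_symmetric_indices` change above pins down the return type of the bar-chart layout helper. A minimal standalone sketch of the same center-out placement (names are illustrative, not the module's API):

```python
# Center-out layout used for the token-probability bar chart: values are
# assumed sorted descending, and rank i is placed alternating left/right
# of the middle so the tallest bars land in the center. For odd lengths
# the final rank wraps into the last slot via a negative index.
def symmetric_indices(data: list[float]) -> list[int]:
    assert data
    n = len(data)
    result: list[int | None] = [None] * n
    left, right = n // 2 - 1, n // 2
    for i in range(n):
        if i % 2 == 0:
            result[left] = i
            left -= 1
        else:
            result[right] = i
            right += 1
    return [r for r in result if r is not None]


probs = [0.1, 0.5, 0.2, 0.15, 0.05]
order = symmetric_indices(probs)
centered = [sorted(probs, reverse=True)[i] for i in order]
print(centered)  # [0.15, 0.5, 0.2, 0.1, 0.05] -- peak near the center
```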
diff --git a/src/avalan/compat.py b/src/avalan/compat.py
index 8e8aa82f..8e4b8ba7 100644
--- a/src/avalan/compat.py
+++ b/src/avalan/compat.py
@@ -1,16 +1,8 @@
-from __future__ import annotations
+from collections.abc import Callable
+from typing import TypeVar
-from typing import Callable, TypeVar
-
-try:
- from typing import override as _override
-except ImportError: # Python < 3.12
- T = TypeVar("T", bound=Callable[..., object])
-
- def _override(func: T) -> T: # type: ignore
- return func
+from typing_extensions import override as _override
+T = TypeVar("T", bound=Callable[..., object])
override = _override
-
-__all__ = ["override"]
diff --git a/src/avalan/deploy/aws.py b/src/avalan/deploy/aws.py
index 10222edf..a842d42c 100644
--- a/src/avalan/deploy/aws.py
+++ b/src/avalan/deploy/aws.py
@@ -1,8 +1,8 @@
from asyncio import AbstractEventLoop, get_running_loop
+from collections.abc import Awaitable, Callable
from concurrent.futures import ThreadPoolExecutor
-from typing import Callable
+from typing import Any, cast
-from boto3 import client
from boto3.session import Session
from botocore.exceptions import ClientError
@@ -12,22 +12,27 @@ class DeployError(Exception):
class AsyncClient:
+ """Async wrapper around a boto3 client."""
+
+ exceptions: Any
+
def __init__(
self,
- client: client,
+ client: Any,
loop: AbstractEventLoop | None = None,
executor: ThreadPoolExecutor | None = None,
- ):
+ ) -> None:
self._client = client
self._loop = loop or get_running_loop()
self._executor = executor or ThreadPoolExecutor()
+ self.exceptions = getattr(client, "exceptions", None)
- def __getattr__(self, name: str) -> Callable[..., any]:
+ def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
attr = getattr(self._client, name)
if not callable(attr):
- return attr
+ return cast(Callable[..., Awaitable[Any]], attr)
- async def fn(*args, **kwargs):
+ async def fn(*args: Any, **kwargs: Any) -> Any:
return await self._loop.run_in_executor(
self._executor, lambda: attr(*args, **kwargs)
)
@@ -36,17 +41,21 @@ async def fn(*args, **kwargs):
class Aws:
+ """AWS deployment manager for EC2 and RDS resources."""
+
_ec2: AsyncClient
_rds: AsyncClient
_session: Session
def __init__(
- self, settings: dict | None = None, token_pair: str | None = None
- ):
+ self,
+ settings: dict[str, Any] | None = None,
+ token_pair: str | None = None,
+ ) -> None:
if settings and "token_pair" in settings and not token_pair:
token_pair = settings.pop("token_pair")
- aws_settings = {}
+ aws_settings: dict[str, str] = {}
if token_pair:
access_key, secret_key = token_pair.split(":", 1)
@@ -63,13 +72,14 @@ def __init__(
self._rds = AsyncClient(self._session.client("rds"))
async def get_vpc_id(self, name: str) -> str:
+ """Return the VPC ID for a VPC with the given name."""
response = await self._ec2.describe_vpcs(
Filters=[{"Name": "tag:Name", "Values": [name]}]
)
vpcs = response.get("Vpcs", [])
if not vpcs:
raise DeployError(f"VPC {name!r} not found")
- return vpcs[0]["VpcId"]
+ return str(vpcs[0]["VpcId"])
async def create_vpc_if_missing(self, name: str, cidr: str) -> str:
"""Return an existing VPC id or create a new VPC."""
@@ -77,7 +87,7 @@ async def create_vpc_if_missing(self, name: str, cidr: str) -> str:
return await self.get_vpc_id(name)
except DeployError:
response = await self._ec2.create_vpc(CidrBlock=cidr)
- vpc_id = response["Vpc"]["VpcId"]
+ vpc_id = str(response["Vpc"]["VpcId"])
await self._ec2.create_tags(
Resources=[vpc_id], Tags=[{"Key": "Name", "Value": name}]
)
@@ -86,21 +96,23 @@ async def create_vpc_if_missing(self, name: str, cidr: str) -> str:
return vpc_id
async def get_security_group(self, name: str, vpc_id: str) -> str:
+ """Return the security group ID, creating one if necessary."""
response = await self._ec2.describe_security_groups(
Filters=[{"Name": "group-name", "Values": [name]}]
)
groups = response.get("SecurityGroups", [])
if groups:
- return groups[0]["GroupId"]
+ return str(groups[0]["GroupId"])
response = await self._ec2.create_security_group(
GroupName=name,
Description="avalan deployment",
VpcId=vpc_id,
)
- return response["GroupId"]
+ return str(response["GroupId"])
async def configure_security_group(self, group_id: str, port: int) -> None:
+ """Configure ingress rules for the security group."""
try:
await self._ec2.authorize_security_group_ingress(
GroupId=group_id,
@@ -116,9 +128,10 @@ async def configure_security_group(self, group_id: str, port: int) -> None:
async def create_rds_if_missing(
self, db_id: str, instance_class: str, sg_id: str, storage: int
) -> str:
+ """Create an RDS instance if it does not already exist."""
try:
await self._rds.describe_db_instances(DBInstanceIdentifier=db_id)
- except self._rds.exceptions.DBInstanceNotFoundFault:
+ except self._rds.exceptions.DBInstanceNotFoundFault: # type: ignore[misc]
await self._rds.create_db_instance(
DBInstanceIdentifier=db_id,
DBInstanceClass=instance_class,
@@ -143,13 +156,14 @@ async def create_instance_if_missing(
agent_path: str,
port: int,
) -> str:
+ """Create an EC2 instance if it does not already exist."""
user_data = self._create_user_data(agent_path, port)
response = await self._ec2.describe_instances(
Filters=[{"Name": "tag:Name", "Values": [instance_name]}]
)
reservations = response.get("Reservations", [])
if reservations:
- return reservations[0]["Instances"][0]["InstanceId"]
+ return str(reservations[0]["Instances"][0]["InstanceId"])
response = await self._ec2.describe_subnets(
Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]
@@ -172,7 +186,7 @@ async def create_instance_if_missing(
}
],
)
- return response["Instances"][0]["InstanceId"]
+ return str(response["Instances"][0]["InstanceId"])
def _create_user_data(self, agent_path: str, port: int) -> str:
cmd = f"avalan agent serve {agent_path} --host 0.0.0.0 --port {port}\n"
diff --git a/src/avalan/entities.py b/src/avalan/entities.py
index 945445a0..788ed697 100644
--- a/src/avalan/entities.py
+++ b/src/avalan/entities.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import StrEnum
-from typing import Literal, TypedDict, final
+from typing import Any, Literal, TypedDict, final
from uuid import UUID
from numpy import ndarray
@@ -636,7 +636,7 @@ class ModelConfig:
sep_token: str | None
state_size: int
# Additional keyword arguments to store for the current task
- task_specific_params: dict[str, any] | None
+ task_specific_params: dict[str, Any] | None
# The dtype of the weight. Since the config object is stored in plain
# text, this attribute contains just the floating type string without the
# torch
@@ -759,9 +759,9 @@ class OperationVisionParameters:
class OperationParameters(TypedDict, total=False):
- audio: OperationAudioParameters | None = None
- text: OperationTextParameters | None = None
- vision: OperationVisionParameters | None = None
+ audio: OperationAudioParameters | None
+ text: OperationTextParameters | None
+ vision: OperationVisionParameters | None
@final
@@ -896,7 +896,7 @@ class TransformerEngineSettings(EngineSettings):
low_cpu_mem_usage: bool = False
output_hidden_states: bool = False
special_tokens: list[str] | None = None
- state_dict: dict[str, Tensor] = None
+ state_dict: dict[str, Tensor] | None = None
tokens: list[str] | None = None
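Both entity fixes here are the same category of mypy error: illegal defaults. `TypedDict` bodies may not assign values at all (with `total=False` every key is already optional), and `state_dict: dict[str, Tensor] = None` needed `| None` to make its default type-correct. A minimal sketch of the `TypedDict` side:

```python
from typing import TypedDict


# With total=False every key is optional; "= None" defaults are not
# allowed in a TypedDict body, so optional keys are simply omitted.
class OperationParams(TypedDict, total=False):
    audio: str | None
    text: str | None


params: OperationParams = {"text": "hello"}  # "audio" simply absent
print(params.get("audio"))  # None
```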
diff --git a/src/avalan/event/manager.py b/src/avalan/event/manager.py
index 51c7aa8d..48988720 100644
--- a/src/avalan/event/manager.py
+++ b/src/avalan/event/manager.py
@@ -3,6 +3,7 @@
from asyncio import Event as EventSignal
from asyncio import Queue, TimeoutError, wait_for
from collections import defaultdict, deque
+from collections.abc import AsyncGenerator
from inspect import iscoroutine
from typing import Awaitable, Callable, Iterable
@@ -58,7 +59,7 @@ async def trigger(self, event: Event) -> None:
async def listen(
self, stop_signal: EventSignal | None = None, timeout: float = 0.2
- ):
+ ) -> AsyncGenerator[Event, None]:
while True:
try:
yield await wait_for(self._queue.get(), timeout=timeout)
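`listen` contains a `yield`, so it is an async generator and its annotation must be `AsyncGenerator[Event, None]` rather than a plain coroutine return type. A runnable sketch of the same shape with simplified names (the real method keeps polling across timeouts until a stop signal fires; this one stops on the first):

```python
import asyncio
from collections.abc import AsyncGenerator


async def listen(
    queue: asyncio.Queue, timeout: float = 0.2
) -> AsyncGenerator[str, None]:
    # Yield queued events until a timeout signals the stream is drained.
    while True:
        try:
            yield await asyncio.wait_for(queue.get(), timeout=timeout)
        except asyncio.TimeoutError:
            return


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put("tool_call")
    async for event in listen(queue):
        print(event)  # tool_call


asyncio.run(main())
```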
diff --git a/src/avalan/memory/__init__.py b/src/avalan/memory/__init__.py
index 5ae42d10..ac6a1cbf 100644
--- a/src/avalan/memory/__init__.py
+++ b/src/avalan/memory/__init__.py
@@ -32,7 +32,9 @@ async def search(self, query: str) -> list[T] | None:
class MessageMemory(MemoryStore[EngineMessage], ABC):
- def search(self, query: str) -> list[EngineMessage] | None:
+ def search( # type: ignore[override]
+ self, query: str
+ ) -> list[EngineMessage] | None:
raise NotImplementedError()
@@ -40,17 +42,17 @@ class RecentMessageMemory(MessageMemory):
_lock: Lock
_data: list[EngineMessage]
- def __init__(self, **kwargs) -> None:
+ def __init__(self, **kwargs: object) -> None:
self._lock = Lock()
- self.reset()
+ self._data = []
super().__init__(**kwargs)
@override
- def append(self, data: EngineMessage) -> None:
+ def append(self, data: EngineMessage) -> None: # type: ignore[override]
with self._lock:
self._data.append(data)
- def reset(self) -> None:
+ def reset(self) -> None: # type: ignore[override]
with self._lock:
self._data = []
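Replacing `self.reset()` with a direct `self._data = []` in the constructor is more than a style tweak: one plausible motivation is that `reset` is overridable, and memory classes elsewhere in this PR override it to raise `NotImplementedError`, so a constructor that calls it is fragile. A sketch of the hazard, with hypothetical classes:

```python
class Recent:
    def __init__(self) -> None:
        self._data: list[str] = []  # direct init: safe for subclasses

    def reset(self) -> None:
        self._data = []


class Permanent(Recent):
    def reset(self) -> None:  # this store forbids in-place resets
        raise NotImplementedError()


Permanent()  # constructing no longer trips the overridden reset()
```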
diff --git a/src/avalan/memory/partitioner/code.py b/src/avalan/memory/partitioner/code.py
index 8fd017d2..3bd16a02 100644
--- a/src/avalan/memory/partitioner/code.py
+++ b/src/avalan/memory/partitioner/code.py
@@ -1,5 +1,4 @@
from ...memory.partitioner import Encoding, PartitionerException
-from ...utils import _j
from dataclasses import dataclass
from logging import Logger
@@ -80,7 +79,9 @@ def partition(
if root_node.children and root_node.children[0].is_error:
error_node = root_node.children[0]
error_name = error_node.grammar_name
- error_message = error_node.text.decode(encoding)
+ error_text = error_node.text
+ assert error_text is not None, "Error node text cannot be None"
+ error_message = error_text.decode(encoding)
error_row, error_column = error_node.start_point
raise PartitionerException(
f'{error_name}: "{error_message}" at'
@@ -133,12 +134,16 @@ def _partition(
for child_node in node.children:
if child_node.type == "class_definition":
class_name_node = child_node.child_by_field_name("name")
- current_class_name = (
- class_name_node.text.decode(encoding)
- if class_name_node
- else None
- )
- class_id = _j(".", [current_namespace, current_class_name])
+ if class_name_node and class_name_node.text:
+ current_class_name = class_name_node.text.decode(encoding)
+ else:
+ current_class_name = None
+ class_parts: list[str] = []
+ if current_namespace:
+ class_parts.append(current_namespace)
+ if current_class_name:
+ class_parts.append(current_class_name)
+ class_id = ".".join(class_parts) if class_parts else ""
child_symbols = current_symbols + [
Symbol(symbol_type="class", id=class_id)
]
@@ -215,11 +220,10 @@ def _get_functions(
results = []
if node.type == "class_definition":
class_name_node = node.child_by_field_name("name")
- class_name = (
- class_name_node.text.decode(encoding)
- if class_name_node
- else None
- )
+ if class_name_node and class_name_node.text:
+ class_name = class_name_node.text.decode(encoding)
+ else:
+ class_name = None
for child in node.children:
functs = cls._get_functions(
current_namespace,
@@ -262,6 +266,9 @@ def _get_function_from_node(
)
params_node = node.child_by_field_name("parameters")
return_type_node = node.child_by_field_name("return_type")
+ return_type: str | None = None
+ if return_type_node and return_type_node.text:
+ return_type = return_type_node.text.decode(encoding)
return Function(
id=function_id,
namespace=current_namespace,
@@ -272,11 +279,7 @@ def _get_function_from_node(
if params_node
else None
),
- return_type=(
- return_type_node.text.decode(encoding)
- if return_type_node
- else None
- ),
+ return_type=return_type,
)
@staticmethod
@@ -291,53 +294,64 @@ def _get_function_id_and_name_from_node(
"async_function_definition",
)
function_name_node = node.child_by_field_name("name")
- assert function_name_node
+ assert function_name_node is not None
+ assert function_name_node.text is not None
function_name = function_name_node.text.decode(encoding)
assert function_name
- function_id = _j(
- ".", [current_namespace, current_class_name, function_name]
- )
+ id_parts: list[str] = []
+ if current_namespace:
+ id_parts.append(current_namespace)
+ if current_class_name:
+ id_parts.append(current_class_name)
+ id_parts.append(function_name)
+ function_id = ".".join(id_parts)
return function_id, function_name
@staticmethod
- def _get_parameters(
- node: Node, encoding: Encoding
- ) -> list[dict[str, str | None]]:
+ def _get_parameters(node: Node, encoding: Encoding) -> list[Parameter]:
assert node
- parameters = []
+ parameters: list[Parameter] = []
for child in node.children:
- if child.type in (
- "default_parameter",
- "typed_default_parameter",
- "typed_parameter",
- ):
- param_type = child.child_by_field_name("type").text.decode(
- encoding
- )
- param_name = None
- if child.named_child_count:
- for parameter_child in child.children:
- if parameter_child.type == "identifier":
- param_name = parameter_child.text.decode(encoding)
- break
- parameters.append(
- Parameter(
- parameter_type=child.type,
- name=param_name,
- type=param_type,
+ match child.type:
+ case (
+ "default_parameter"
+ | "typed_default_parameter"
+ | "typed_parameter"
+ ) as param_kind:
+ type_node = child.child_by_field_name("type")
+ param_type: str | None = None
+ if type_node and type_node.text:
+ param_type = type_node.text.decode(encoding)
+ param_name: str = ""
+ if child.named_child_count:
+ for parameter_child in child.children:
+ if parameter_child.type == "identifier":
+ if parameter_child.text:
+ param_name = parameter_child.text.decode(
+ encoding
+ )
+ break
+ parameters.append(
+ Parameter(
+ parameter_type=param_kind,
+ name=param_name,
+ type=param_type,
+ )
)
- )
- elif child.type in (
- "dictionary_splat_pattern",
- "identifier",
- "keyword_separator",
- ):
- parameters.append(
- Parameter(
- parameter_type=child.type,
- name=child.text.decode(encoding),
- type=None,
+ case (
+ "dictionary_splat_pattern"
+ | "identifier"
+ | "keyword_separator"
+ ) as simple_kind:
+ name: str = ""
+ if child.text:
+ name = child.text.decode(encoding)
+ parameters.append(
+ Parameter(
+ parameter_type=simple_kind,
+ name=name,
+ type=None,
+ )
)
- )
return parameters
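The `_j(".", [...])` joins are replaced with explicit part lists so that `None` segments (no namespace, no enclosing class) are skipped with proper typing. The shape of that logic as a tiny sketch:

```python
# Dotted-identifier builder: falsy parts are skipped so ids never gain
# leading or doubled dots. Illustrative helper, not the module's API.
def dotted_id(*parts: str | None) -> str:
    return ".".join(p for p in parts if p)


print(dotted_id(None, "MyClass", "method"))  # MyClass.method
print(dotted_id("pkg.mod", None, "func"))    # pkg.mod.func
print(dotted_id("pkg.mod", "Cls", "fn"))     # pkg.mod.Cls.fn
```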
diff --git a/src/avalan/memory/partitioner/text.py b/src/avalan/memory/partitioner/text.py
index c2f4997e..803b7d5a 100644
--- a/src/avalan/memory/partitioner/text.py
+++ b/src/avalan/memory/partitioner/text.py
@@ -60,8 +60,8 @@ def configure(
self._window_size = window_size
self._overlap_size = overlap_size
- @override
@property
+ @override
def sentence_model(self) -> Callable:
return self._model
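Swapping the decorator order matters for tooling: `@override` should sit closest to the function so the marker lands on the getter itself, with `@property` wrapping it. A runnable sketch, assuming `typing_extensions` is available:

```python
from typing_extensions import override


class Base:
    @property
    def sentence_model(self) -> str:
        return "base"


class Child(Base):
    @property
    @override  # innermost: marks the getter before property wraps it
    def sentence_model(self) -> str:
        return "child"


print(Child().sentence_model)  # child
```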
diff --git a/src/avalan/memory/permanent/__init__.py b/src/avalan/memory/permanent/__init__.py
index 0b96cf3d..3642c7f7 100644
--- a/src/avalan/memory/permanent/__init__.py
+++ b/src/avalan/memory/permanent/__init__.py
@@ -175,7 +175,7 @@ def has_session(self) -> bool:
def session_id(self) -> UUID | None:
return self._session_id
- def reset(self) -> None:
+ def reset(self) -> None: # type: ignore[override]
raise NotImplementedError()
async def reset_session(
@@ -198,7 +198,7 @@ async def continue_session(
)
@override
- def append(self, data: EngineMessage) -> None:
+ def append(self, data: EngineMessage) -> None: # type: ignore[override]
raise NotImplementedError()
@abstractmethod
@@ -263,15 +263,14 @@ def _build_message_with_partitions(
if message_id is None:
message_id = uuid4()
content = engine_message.message.content
- data = (
- content.text
- if isinstance(content, MessageContentText)
- else (
- content.image_url
- if isinstance(content, MessageContentImage)
- else str(content)
- )
- )
+ if isinstance(content, MessageContentText):
+ data = content.text
+ elif isinstance(content, MessageContentImage):
+ data = str(content.image_url)
+ elif content is None:
+ data = ""
+ else:
+ data = str(content)
message = PermanentMessage(
id=message_id,
@@ -326,10 +325,10 @@ def __init__(
super().__init__(**kwargs)
@override
- def append(self, data: EngineMessage) -> None:
+ def append(self, data: EngineMessage) -> None: # type: ignore[override]
raise NotImplementedError()
- def reset(self) -> None:
+ def reset(self) -> None: # type: ignore[override]
raise NotImplementedError()
@abstractmethod
diff --git a/src/avalan/memory/permanent/elasticsearch/message.py b/src/avalan/memory/permanent/elasticsearch/message.py
index 00eb8fc0..6991c365 100644
--- a/src/avalan/memory/permanent/elasticsearch/message.py
+++ b/src/avalan/memory/permanent/elasticsearch/message.py
@@ -51,8 +51,8 @@ async def create_instance(
sentence_model=sentence_model,
)
- async def create_session(
- self, *, agent_id: UUID, participant_id: UUID
+ async def create_session( # type: ignore[override]
+ self, agent_id: UUID, participant_id: UUID
) -> UUID:
now_utc = datetime.now(timezone.utc)
session = self._build_session(
@@ -75,7 +75,7 @@ async def create_session(
async def continue_session_and_get_id(
self,
- *,
+ *args: object,
agent_id: UUID,
participant_id: UUID,
session_id: UUID,
@@ -92,7 +92,7 @@ async def continue_session_and_get_id(
async def append_with_partitions(
self,
engine_message: EngineMessage,
- *,
+ *args: object,
partitions: list[TextPartition],
) -> None:
assert engine_message and partitions
@@ -143,7 +143,7 @@ async def get_recent_messages(
self,
session_id: UUID,
participant_id: UUID,
- *,
+ *args: object,
limit: int | None = None,
) -> list[EngineMessage]:
response = await self._call_client(
@@ -173,9 +173,9 @@ async def get_recent_messages(
)
return messages
- async def search_messages(
+ async def search_messages( # type: ignore[override]
self,
- *,
+ *args: object,
agent_id: UUID,
function: VectorFunction,
limit: int | None = None,
diff --git a/src/avalan/memory/permanent/elasticsearch/raw.py b/src/avalan/memory/permanent/elasticsearch/raw.py
index ec84caed..adc92c74 100644
--- a/src/avalan/memory/permanent/elasticsearch/raw.py
+++ b/src/avalan/memory/permanent/elasticsearch/raw.py
@@ -32,7 +32,7 @@ def __init__(
ElasticsearchMemory.__init__(
self, index=index, client=client, logger=logger
)
- PermanentMemory.__init__(self, sentence_model=None)
+ PermanentMemory.__init__(self, sentence_model=None) # type: ignore[arg-type]
@classmethod
async def create_instance(
@@ -47,11 +47,11 @@ async def create_instance(
memory = cls(index=index, client=es_client, logger=logger)
return memory
- async def append_with_partitions(
+ async def append_with_partitions( # type: ignore[override]
self,
namespace: str,
participant_id: UUID,
- *,
+ *args: object,
memory_type: MemoryType,
data: str,
identifier: str,
@@ -113,9 +113,9 @@ async def append_with_partitions(
},
)
- async def search_memories(
+ async def search_memories( # type: ignore[override]
self,
- *,
+ *args: object,
search_partitions: list[TextPartition],
participant_id: UUID,
namespace: str,
@@ -212,7 +212,7 @@ async def list_memories(
)
return memories
- async def search(
+ async def search( # type: ignore[override]
self, query: str
) -> list[PermanentMemoryPartition] | None:
raise NotImplementedError()
diff --git a/src/avalan/memory/permanent/pgsql/__init__.py b/src/avalan/memory/permanent/pgsql/__init__.py
index 74cf2de1..6ae926cf 100644
--- a/src/avalan/memory/permanent/pgsql/__init__.py
+++ b/src/avalan/memory/permanent/pgsql/__init__.py
@@ -9,7 +9,7 @@
from logging import Logger
from time import perf_counter
-from typing import TypeVar
+from typing import Any, TypeVar
from pgvector.psycopg import register_vector_async
from psycopg import AsyncConnection, AsyncCursor
@@ -18,10 +18,11 @@
from psycopg_pool import AsyncConnectionPool
T = TypeVar("T")
+E = TypeVar("E")
class BasePgsqlMemory(MemoryStore[T]):
- _database: AsyncConnection
+ _database: AsyncConnectionPool
_logger: Logger
def __init__(self, database: AsyncConnectionPool, logger: Logger):
@@ -45,8 +46,8 @@ async def _execute(
)
async def _fetch_all(
- self, entity: type[T], query: str, parameters: tuple
- ) -> list[T]:
+ self, entity: type[E], query: str, parameters: tuple
+ ) -> list[E]:
async with self._database.connection() as connection:
async with connection.cursor() as cursor:
await self._execute(cursor, query, parameters)
@@ -59,8 +60,8 @@ async def _fetch_all(
)
async def _fetch_one(
- self, entity: type[T], query: str, parameters: tuple
- ) -> T:
+ self, entity: type[E], query: str, parameters: tuple
+ ) -> E:
result = await self._try_fetch_one(entity, query, parameters)
if result is None:
raise RecordNotFoundException()
@@ -86,8 +87,8 @@ async def _has_one(self, query: str, parameters: tuple) -> bool:
return result is not None
async def _try_fetch_one(
- self, entity: type[T], query: str, parameters: tuple
- ) -> T | None:
+ self, entity: type[E], query: str, parameters: tuple
+ ) -> E | None:
async with self._database.connection() as connection:
async with connection.cursor() as cursor:
await self._execute(cursor, query, parameters)
@@ -96,8 +97,8 @@ async def _try_fetch_one(
return entity(**dict(result)) if result is not None else None
async def _update_and_fetch_one(
- self, entity: type[T], query: str, parameters: tuple
- ) -> T:
+ self, entity: type[E], query: str, parameters: tuple
+ ) -> E:
row = await self._update_and_fetch_row(query, parameters)
return entity(**row)
@@ -105,7 +106,9 @@ async def _update_and_fetch_field(
self, field: str, query: str, parameters: tuple
) -> str:
row = await self._update_and_fetch_row(query, parameters)
- return row[field]
+ value = row[field]
+ assert isinstance(value, str)
+ return value
async def _update_and_fetch_row(
self, query: str, parameters: tuple
@@ -135,33 +138,29 @@ async def create_instance_from_pool(
pool: AsyncConnectionPool,
*,
logger: Logger,
- **kwargs,
- ):
+ **kwargs: Any,
+ ) -> "PgsqlMemory[T]":
memory = cls(dsn=None, pool=pool, logger=logger, **kwargs)
return memory
def __init__(
self,
dsn: str | None,
- *args,
+ *args: object,
pool: AsyncConnectionPool | None = None,
composite_types: list[str] | None = None,
pool_minimum: int | None = None,
pool_maximum: int | None = None,
logger: Logger,
- **kwargs,
+ **kwargs: object,
):
- assert pool or (
- dsn
- and pool_minimum
- and pool_minimum
- and pool_minimum > 0
- and pool_maximum > pool_minimum
- )
-
if pool:
super().__init__(database=pool, logger=logger, **kwargs)
else:
+ assert dsn is not None
+ assert pool_minimum is not None and pool_minimum > 0
+ assert pool_maximum is not None and pool_maximum > pool_minimum
+
self._composite_types = composite_types
if "//" not in dsn:
@@ -176,8 +175,8 @@ def __init__(
)
super().__init__(database=database, logger=logger, **kwargs)
- async def _configure_connection(self, connection: AsyncConnection):
- connection.row_factory = dict_row
+ async def _configure_connection(self, connection: AsyncConnection) -> None:
+ connection.row_factory = dict_row # type: ignore[assignment]
await connection.set_autocommit(True)
if self._composite_types:
for composite_type_name in self._composite_types:
@@ -191,28 +190,33 @@ async def _configure_connection(self, connection: AsyncConnection):
@staticmethod
def _to_engine_messages(
messages: list[PermanentMessage] | list[PermanentMessageScored],
- *args,
+ *args: object,
limit: int | None,
reverse: bool = False,
scored: bool = False,
) -> list[EngineMessage] | list[EngineMessageScored]:
- engine_messages = [
- (
+ engine_messages: list[EngineMessage] | list[EngineMessageScored]
+ if scored:
+ scored_messages: list[EngineMessageScored] = [
EngineMessageScored(
agent_id=m.agent_id,
model_id=m.model_id,
message=Message(role=m.author, content=m.data),
- score=m.score,
+ score=m.score, # type: ignore[attr-defined, union-attr]
)
- if scored
- else EngineMessage(
+ for m in messages
+ ]
+ engine_messages = scored_messages
+ else:
+ unscored_messages: list[EngineMessage] = [
+ EngineMessage(
agent_id=m.agent_id,
model_id=m.model_id,
message=Message(role=m.author, content=m.data),
)
- )
- for m in messages
- ]
+ for m in messages
+ ]
+ engine_messages = unscored_messages
if reverse:
engine_messages.reverse()
if limit and len(engine_messages) > limit:
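Two things happen in this file: the fetch helpers gain a method-level `TypeVar` `E` so they can hydrate any row entity independently of the store's own record type `T`, and `_to_engine_messages` is split into two homogeneous comprehensions so each branch carries a single list type. A sketch of the `TypeVar` pattern:

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

T = TypeVar("T")  # the store's record type
E = TypeVar("E")  # per-call entity type for fetch helpers


class Store(Generic[T]):
    def fetch_one(self, entity: type[E], row: dict[str, Any]) -> E:
        # Hydrate an arbitrary entity from a row, independent of T.
        return entity(**row)


@dataclass
class Session:
    id: str


store: Store[int] = Store()
print(store.fetch_one(Session, {"id": "abc"}))  # Session(id='abc')
```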
diff --git a/src/avalan/memory/permanent/pgsql/message.py b/src/avalan/memory/permanent/pgsql/message.py
index d1d6366f..8ea35a3e 100644
--- a/src/avalan/memory/permanent/pgsql/message.py
+++ b/src/avalan/memory/permanent/pgsql/message.py
@@ -15,7 +15,7 @@
from pgvector.psycopg import Vector
-class PgsqlMessageMemory(
+class PgsqlMessageMemory( # type: ignore[misc]
PgsqlMemory[PermanentMessage], PermanentMessageMemory
):
"""PostgreSQL-backed implementation of :class:`PermanentMessageMemory`."""
@@ -44,8 +44,8 @@ async def create_instance(
await memory.open()
return memory
- async def create_session(
- self, *args, agent_id: UUID, participant_id: UUID
+ async def create_session( # type: ignore[override]
+ self, agent_id: UUID, participant_id: UUID
) -> UUID:
"""Create a new session for a participant."""
now_utc = datetime.now(timezone.utc)
@@ -81,13 +81,13 @@ async def create_session(
async def continue_session_and_get_id(
self,
- *args,
+ *args: object,
agent_id: UUID,
participant_id: UUID,
session_id: UUID,
) -> UUID:
"""Continue an existing session if it belongs to the participant."""
- session_id = await self._fetch_field(
+ fetched_session_id = await self._fetch_field(
"id",
"""
SELECT "sessions"."id"
@@ -99,13 +99,15 @@ async def continue_session_and_get_id(
""",
(str(agent_id), str(participant_id), str(session_id)),
)
- assert session_id
- return session_id if isinstance(session_id, UUID) else UUID(session_id)
+ assert fetched_session_id is not None
+ if isinstance(fetched_session_id, UUID):
+ return fetched_session_id
+ return UUID(fetched_session_id)
async def append_with_partitions(
self,
engine_message: EngineMessage,
- *args,
+ *args: object,
partitions: list[TextPartition],
) -> None:
"""Persist a message and its partitions."""
@@ -197,7 +199,7 @@ async def get_recent_messages(
self,
session_id: UUID,
participant_id: UUID,
- *args,
+ *args: object,
limit: int | None = None,
) -> list[EngineMessage]:
"""Retrieve recent messages for a session."""
@@ -227,11 +229,12 @@ async def get_recent_messages(
engine_messages = self._to_engine_messages(
messages, limit=limit_value, reverse=True
)
- return engine_messages
+ assert isinstance(engine_messages, list)
+ return engine_messages # type: ignore[return-value]
- async def search_messages(
+ async def search_messages( # type: ignore[override]
self,
- *args,
+ *args: object,
agent_id: UUID,
function: VectorFunction,
limit: int | None = None,
@@ -297,4 +300,5 @@ async def search_messages(
limit=limit_value,
scored=True,
)
- return engine_messages
+ assert isinstance(engine_messages, list)
+ return engine_messages # type: ignore[return-value]
diff --git a/src/avalan/memory/permanent/pgsql/raw.py b/src/avalan/memory/permanent/pgsql/raw.py
index 635ee4da..4473f03b 100644
--- a/src/avalan/memory/permanent/pgsql/raw.py
+++ b/src/avalan/memory/permanent/pgsql/raw.py
@@ -379,7 +379,10 @@ async def upsert_entity(
),
)
result = await cursor.fetchone()
- entity_id = result["id"] if result else None
+ assert result is not None
+ result_dict = dict(result)
+ entity_id = result_dict.get("id")
+ assert entity_id is not None
await cursor.execute(
"""
@@ -403,4 +406,6 @@ async def upsert_entity(
),
)
await cursor.close()
- return UUID(entity_id) if isinstance(entity_id, str) else entity_id
+    return (
+        entity_id if isinstance(entity_id, UUID) else UUID(str(entity_id))
+    )
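For the `upsert_entity` return, the two branches must not both call `UUID(...)`: psycopg commonly hands back `uuid` columns as `UUID` objects already, and `UUID(UUID(...))` raises. A defensive normalizer consistent with the corrected return (hypothetical helper):

```python
from uuid import UUID, uuid4


def as_uuid(value: UUID | str) -> UUID:
    # Pass UUIDs through; parse anything else from its string form.
    return value if isinstance(value, UUID) else UUID(str(value))


print(as_uuid(uuid4()))
print(as_uuid("12345678-1234-5678-1234-567812345678"))
```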
diff --git a/src/avalan/memory/permanent/s3vectors/message.py b/src/avalan/memory/permanent/s3vectors/message.py
index 80f9db00..74f73ee6 100644
--- a/src/avalan/memory/permanent/s3vectors/message.py
+++ b/src/avalan/memory/permanent/s3vectors/message.py
@@ -60,14 +60,14 @@ async def create_instance(
sentence_model=sentence_model,
)
- async def create_session(
- self, *, agent_id: UUID, participant_id: UUID
+ async def create_session( # type: ignore[override]
+ self, agent_id: UUID, participant_id: UUID
) -> UUID:
return uuid4()
async def continue_session_and_get_id(
self,
- *,
+ *args: object,
agent_id: UUID,
participant_id: UUID,
session_id: UUID,
@@ -77,7 +77,7 @@ async def continue_session_and_get_id(
async def append_with_partitions(
self,
engine_message: EngineMessage,
- *,
+ *args: object,
partitions: list[TextPartition],
) -> None:
assert engine_message and partitions
@@ -131,7 +131,7 @@ async def get_recent_messages(
self,
session_id: UUID,
participant_id: UUID,
- *,
+ *args: object,
limit: int | None = None,
) -> list[EngineMessage]:
prefix = f"{self._collection}/{session_id}/"
@@ -157,9 +157,9 @@ async def get_recent_messages(
)
return messages
- async def search_messages(
+ async def search_messages( # type: ignore[override]
self,
- *,
+ *args: object,
agent_id: UUID,
function: VectorFunction,
limit: int | None = None,
diff --git a/src/avalan/memory/permanent/s3vectors/raw.py b/src/avalan/memory/permanent/s3vectors/raw.py
index 1f53d75b..c4fc6c92 100644
--- a/src/avalan/memory/permanent/s3vectors/raw.py
+++ b/src/avalan/memory/permanent/s3vectors/raw.py
@@ -41,7 +41,7 @@ def __init__(
client=client,
logger=logger,
)
- PermanentMemory.__init__(self, sentence_model=None)
+ PermanentMemory.__init__(self, sentence_model=None) # type: ignore[arg-type]
@classmethod
async def create_instance(
@@ -62,11 +62,11 @@ async def create_instance(
)
return memory
- async def append_with_partitions(
+ async def append_with_partitions( # type: ignore[override]
self,
namespace: str,
participant_id: UUID,
- *,
+ *args: object,
memory_type: MemoryType,
data: str,
identifier: str,
@@ -131,9 +131,9 @@ async def append_with_partitions(
},
)
- async def search_memories(
+ async def search_memories( # type: ignore[override]
self,
- *,
+ *args: object,
search_partitions: list[TextPartition],
participant_id: UUID,
namespace: str,
@@ -230,7 +230,7 @@ async def list_memories(
)
return memories
- async def search(
+ async def search( # type: ignore[override]
self, query: str
) -> list[PermanentMemoryPartition] | None:
raise NotImplementedError()
diff --git a/src/avalan/memory/source.py b/src/avalan/memory/source.py
index ab91fdd2..68de9af1 100644
--- a/src/avalan/memory/source.py
+++ b/src/avalan/memory/source.py
@@ -89,11 +89,11 @@ async def _convert_bytes(
if self._is_pdf(url, content_type, data):
metadata = PdfReader(BytesIO(data)).metadata
- metadata_title = (
- metadata["/Title"]
- if metadata and "/Title" in metadata
- else None
- )
+ metadata_title: str | None = None
+ if metadata and "/Title" in metadata:
+ raw_title = metadata["/Title"]
+ if isinstance(raw_title, str):
+ metadata_title = raw_title
title = metadata_title or title or self._markdown_title(markdown)
description = description or self._markdown_description(markdown)
diff --git a/src/avalan/model/audio/__init__.py b/src/avalan/model/audio/__init__.py
index 2510aaab..a934e123 100644
--- a/src/avalan/model/audio/__init__.py
+++ b/src/avalan/model/audio/__init__.py
@@ -1,11 +1,10 @@
from ...model import TokenizerNotSupportedException
from ...model.engine import Engine
-from abc import ABC, abstractmethod
-from typing import Literal
+from abc import ABC
+from typing import Any
-from numpy import ndarray
-from PIL import Image
+from numpy.typing import NDArray
from torch import Tensor
from torchaudio import load
from torchaudio.functional import resample
@@ -16,13 +15,7 @@
class BaseAudioModel(Engine, ABC):
- @abstractmethod
- async def __call__(
- self,
- image_source: str | Image.Image,
- tensor_format: Literal["pt"] = "pt",
- ) -> str:
- raise NotImplementedError()
+ """Base class for audio models."""
def _load_tokenizer(
self, tokenizer_name_or_path: str | None, use_fast: bool = True
@@ -35,25 +28,28 @@ def _load_tokenizer_with_tokens(
raise TokenizerNotSupportedException()
def _resample_mono(self, audio_source: str, sampling_rate: int) -> Tensor:
- wave, wave_sampling_rate = load(audio_source)
+ wave_data: Tensor
+ wave_data, wave_sampling_rate = load(audio_source)
- if wave.shape[0] > 1:
+ if wave_data.shape[0] > 1:
# stereo -> mono
- wave = wave.mean(dim=0)
+ wave_data = wave_data.mean(dim=0) # type: ignore[operator]
else:
# already mono, just drop channel dim (samples,)
- wave = wave.squeeze(0)
+ wave_data = wave_data.squeeze(0) # type: ignore[operator]
if wave_sampling_rate != sampling_rate:
- wave = resample(
- wave.unsqueeze(0), wave_sampling_rate, sampling_rate
+ wave_data = resample(
+ wave_data.unsqueeze(0),
+ wave_sampling_rate,
+ sampling_rate, # type: ignore[operator]
).squeeze(0)
- return wave
+ return wave_data
- def _resample(self, audio_source: str, sampling_rate: int) -> ndarray:
+ def _resample(self, audio_source: str, sampling_rate: int) -> NDArray[Any]:
wave, wave_sampling_rate = load(audio_source)
if wave_sampling_rate != sampling_rate:
wave = resample(wave, wave_sampling_rate, sampling_rate)
- wave = wave.mean(0).numpy()
- return wave
+ result: NDArray[Any] = wave.mean(0).numpy() # type: ignore[union-attr]
+ return result
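The logic of `_resample_mono` is easiest to see outside the class: average stereo channels down, squeeze mono, then resample only if the rates differ. A sketch assuming `torch` and `torchaudio` are installed:

```python
from torch import Tensor, rand
from torchaudio.functional import resample


def to_mono(wave: Tensor, src_rate: int, dst_rate: int) -> Tensor:
    # (channels, samples): average stereo down, or drop the channel dim.
    wave = wave.mean(dim=0) if wave.shape[0] > 1 else wave.squeeze(0)
    if src_rate != dst_rate:
        wave = resample(wave.unsqueeze(0), src_rate, dst_rate).squeeze(0)
    return wave


stereo = rand(2, 44_100)  # one second of fake stereo audio
print(to_mono(stereo, 44_100, 16_000).shape)  # torch.Size([16000])
```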
diff --git a/src/avalan/model/audio/classification.py b/src/avalan/model/audio/classification.py
index f4dc18d9..84bcf59d 100644
--- a/src/avalan/model/audio/classification.py
+++ b/src/avalan/model/audio/classification.py
@@ -3,7 +3,7 @@
from ...model.engine import Engine
from ...model.vendor import TextGenerationVendor
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import inference_mode
@@ -15,26 +15,29 @@
class AudioClassificationModel(BaseAudioModel):
- _extractor: AutoFeatureExtractor
+ _extractor: Any # AutoFeatureExtractor with model-specific methods
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id, "model_id is required"
self._extractor = AutoFeatureExtractor.from_pretrained(self._model_id)
- model = AutoModelForAudioClassification.from_pretrained(
- self._model_id,
- device_map=self._device,
- tp_plan=Engine._get_tp_plan(self._settings.parallel),
- distributed_config=Engine._get_distributed_config(
- self._settings.distributed_config
- ),
- subfolder=self._settings.subfolder or "",
- ).to(self._device)
+ model: PreTrainedModel = (
+ AutoModelForAudioClassification.from_pretrained(
+ self._model_id,
+ device_map=self._device,
+ tp_plan=Engine._get_tp_plan(self._settings.parallel),
+ distributed_config=Engine._get_distributed_config(
+ self._settings.distributed_config
+ ),
+ subfolder=self._settings.subfolder or "",
+ ).to(self._device)
+ )
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
path: str,
*,
@@ -43,6 +46,8 @@ async def __call__(
tensor_format: Literal["pt"] = "pt",
) -> dict[str, float]:
assert path
+ assert self._model is not None, "Model must be loaded"
+ assert isinstance(self._model, PreTrainedModel)
wave = self._resample_mono(path, sampling_rate)
inputs = self._extractor(
@@ -52,11 +57,14 @@ async def __call__(
padding=padding,
).to(self._device)
- id2label = {int(k): v for k, v in self._model.config.id2label.items()}
+ id2label_raw = self._model.config.id2label or {}
+ id2label: dict[int, str] = {
+ int(k): str(v) for k, v in id2label_raw.items()
+ }
labels: dict[str, float] = {}
with inference_mode():
- logits = self._model(**inputs).logits
+ logits = self._model(**inputs).logits # type: ignore[operator]
probs = logits.softmax(dim=-1)[0]
for idx, p in sorted(enumerate(probs), key=lambda x: -x[1]):
labels[id2label[idx]] = p.item()
diff --git a/src/avalan/model/audio/generation.py b/src/avalan/model/audio/generation.py
index ba8c4344..32675fe0 100644
--- a/src/avalan/model/audio/generation.py
+++ b/src/avalan/model/audio/generation.py
@@ -3,7 +3,7 @@
from ...model.engine import Engine
from ...model.vendor import TextGenerationVendor
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import from_numpy, inference_mode
@@ -16,25 +16,28 @@
class AudioGenerationModel(BaseAudioModel):
- _processor: AutoProcessor
+ _processor: Any # AutoProcessor with Musicgen-specific methods
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id, "model_id is required"
self._processor = AutoProcessor.from_pretrained(self._model_id)
- model = MusicgenForConditionalGeneration.from_pretrained(
- self._model_id,
- device_map=self._device,
- tp_plan=Engine._get_tp_plan(self._settings.parallel),
- distributed_config=Engine._get_distributed_config(
- self._settings.distributed_config
- ),
- subfolder=self._settings.subfolder or "",
- ).to(self._device)
+ model: PreTrainedModel = (
+ MusicgenForConditionalGeneration.from_pretrained(
+ self._model_id,
+ device_map=self._device,
+ tp_plan=Engine._get_tp_plan(self._settings.parallel),
+ distributed_config=Engine._get_distributed_config(
+ self._settings.distributed_config
+ ),
+ subfolder=self._settings.subfolder or "",
+ ).to(self._device)
+ )
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
prompt: str,
path: str,
@@ -44,6 +47,8 @@ async def __call__(
tensor_format: Literal["pt"] = "pt",
) -> str:
assert path
+ assert self._model is not None, "Model must be loaded"
+ assert isinstance(self._model, PreTrainedModel)
inputs = self._processor(
text=[prompt], return_tensors=tensor_format, padding=padding
@@ -51,11 +56,11 @@ async def __call__(
inputs.to(self._device)
with inference_mode():
- audio_tokens = self._model.generate(
+ audio_tokens = self._model.generate( # type: ignore[attr-defined,operator]
**inputs, max_new_tokens=max_new_tokens
)
- sampling_rate = self._model.config.audio_encoder.sampling_rate
+ sampling_rate: int = self._model.config.audio_encoder.sampling_rate
waveform = audio_tokens[0, 0].cpu().numpy()
wave_tensor = from_numpy(waveform).unsqueeze(0)
save(path, wave_tensor, sampling_rate)
diff --git a/src/avalan/model/audio/speech.py b/src/avalan/model/audio/speech.py
index ca57e729..f4ea5eea 100644
--- a/src/avalan/model/audio/speech.py
+++ b/src/avalan/model/audio/speech.py
@@ -3,7 +3,7 @@
from ...model.engine import Engine
from ...model.vendor import TextGenerationVendor
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import inference_mode
@@ -15,17 +15,18 @@
class TextToSpeechModel(BaseAudioModel):
- _processor: AutoProcessor
+ _processor: Any # AutoProcessor with DiaModel-specific methods
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id, "model_id is required"
self._processor = AutoProcessor.from_pretrained(
self._model_id,
trust_remote_code=self._settings.trust_remote_code,
subfolder=self._settings.tokenizer_subfolder or "",
)
- model = DiaForConditionalGeneration.from_pretrained(
+ model: PreTrainedModel = DiaForConditionalGeneration.from_pretrained(
self._model_id,
trust_remote_code=self._settings.trust_remote_code,
device_map=self._device,
@@ -38,7 +39,7 @@ def _load_model(
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
prompt: str,
path: str,
@@ -50,6 +51,7 @@ async def __call__(
sampling_rate: int = 44_100,
tensor_format: Literal["pt"] = "pt",
) -> str:
+ assert self._model is not None, "Model must be loaded"
assert (not reference_path and not reference_text) or (
reference_path and reference_text
)
@@ -81,7 +83,8 @@ async def __call__(
)
with inference_mode():
- outputs = self._model.generate(
+ assert isinstance(self._model, PreTrainedModel)
+ outputs = self._model.generate( # type: ignore[operator]
**inputs, max_new_tokens=max_new_tokens
)
diff --git a/src/avalan/model/audio/speech_recognition.py b/src/avalan/model/audio/speech_recognition.py
index 3a536558..ff5dd30a 100644
--- a/src/avalan/model/audio/speech_recognition.py
+++ b/src/avalan/model/audio/speech_recognition.py
@@ -3,7 +3,7 @@
from ...model.engine import Engine
from ...model.vendor import TextGenerationVendor
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import argmax, inference_mode
@@ -15,11 +15,12 @@
class SpeechRecognitionModel(BaseAudioModel):
- _processor: AutoProcessor
+ _processor: Any # AutoProcessor with Wav2Vec2-specific methods
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id, "model_id is required"
self._processor = AutoProcessor.from_pretrained(
self._model_id,
trust_remote_code=self._settings.trust_remote_code,
@@ -27,10 +28,14 @@ def _load_model(
use_fast=True,
subfolder=self._settings.tokenizer_subfolder or "",
)
- model = AutoModelForCTC.from_pretrained(
+ # AutoProcessor for CTC models has a tokenizer attribute
+ pad_token_id = getattr(
+ getattr(self._processor, "tokenizer", None), "pad_token_id", None
+ )
+ model: PreTrainedModel = AutoModelForCTC.from_pretrained(
self._model_id,
trust_remote_code=self._settings.trust_remote_code,
- pad_token_id=self._processor.tokenizer.pad_token_id,
+ pad_token_id=pad_token_id,
ctc_loss_reduction="mean",
device_map=self._device,
tp_plan=Engine._get_tp_plan(self._settings.parallel),
@@ -43,12 +48,13 @@ def _load_model(
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
path: str,
sampling_rate: int = 16_000,
tensor_format: Literal["pt"] = "pt",
) -> str:
+ assert self._model is not None, "Model must be loaded"
audio = self._resample(path, sampling_rate)
inputs = self._processor(
audio,
@@ -57,7 +63,8 @@ async def __call__(
).to(self._device)
with inference_mode():
# shape (batch, time_steps, vocab_size)
+ assert isinstance(self._model, PreTrainedModel)
logits = self._model(inputs.input_values).logits
predicted_ids = argmax(logits, dim=-1)
- transcription = self._processor.batch_decode(predicted_ids)[0]
+ transcription: str = self._processor.batch_decode(predicted_ids)[0]
return transcription
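
Replacing `self._processor.tokenizer.pad_token_id` with a chained `getattr` is what lets mypy accept the dynamically resolved `AutoProcessor`: the outer call yields `None` instead of raising when either attribute is missing. The shape of that fallback with throwaway stand-in classes:

```python
class _Tokenizer:
    pad_token_id = 0

class _Processor:
    tokenizer = _Tokenizer()

# None at either step short-circuits to None instead of AttributeError.
pad_token_id = getattr(
    getattr(_Processor(), "tokenizer", None), "pad_token_id", None
)
print(pad_token_id)  # 0; would be None if either attribute were absent
```
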
diff --git a/src/avalan/model/criteria.py b/src/avalan/model/criteria.py
index a6094bfc..4a6b9d0e 100644
--- a/src/avalan/model/criteria.py
+++ b/src/avalan/model/criteria.py
@@ -1,23 +1,27 @@
from io import StringIO
from re import Pattern, compile, escape
+from typing import Any
-from transformers import AutoTokenizer
+from torch import Tensor
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.generation import StoppingCriteria
class KeywordStoppingCriteria(StoppingCriteria):
+ """Stop generation when specific keywords are detected in output."""
+
_buffer: StringIO
- _tokenizer: AutoTokenizer
- _pattern: Pattern
+ _tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast
+ _pattern: Pattern[str]
_keywords: list[str]
_keyword_count: int
def __init__(
self,
keywords: list[str],
- tokenizer: AutoTokenizer,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
all_must_be_present: bool = False,
- ):
+ ) -> None:
assert keywords
escaped_keywords = [escape(k) for k in keywords]
self._pattern = compile(
@@ -30,7 +34,9 @@ def __init__(
self._keywords = keywords
self._keyword_count = len(keywords)
- def __call__(self, input_ids, scores, **kwargs):
+ def __call__(
+ self, input_ids: Tensor, scores: Tensor, **kwargs: Any
+ ) -> bool:
token_id = input_ids[0][-1]
token = self._tokenizer.decode(token_id, skip_special_tokens=False)
self._buffer.write(token)
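
For reference, the pattern the criteria compiles is presumably an escaped alternation searched against the rolling decode buffer; the hunk elides the `compile(...)` body, so the join below is an assumption, and the keywords are toy values:

```python
from re import compile, escape

keywords = ["</answer>", "STOP"]  # toy keywords, not project defaults
pattern = compile("|".join(escape(k) for k in keywords))
buffer = "...streamed text so far STOP"
print(bool(pattern.search(buffer)))  # True -> generation should stop
```
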
diff --git a/src/avalan/model/engine.py b/src/avalan/model/engine.py
index a8a8416f..7728c922 100644
--- a/src/avalan/model/engine.py
+++ b/src/avalan/model/engine.py
@@ -53,12 +53,14 @@
class Engine(ABC):
+ """Base class for model engines."""
+
_device: str
_logger: Logger
_model_id: str | None
_settings: EngineSettings
- _transformers_logging_logger: Logger
- _transformers_logging_level: int
+ _transformers_logging_logger: Logger | None
+ _transformers_logging_level: int | None
_loaded_model: bool = False
_loaded_tokenizer: bool = False
_tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None
@@ -117,7 +119,7 @@ def _get_distributed_config(
) -> dict[str, object] | None:
if distributed_config is None:
return None
- config = {"enable_expert_parallel": False}
+ config: dict[str, object] = {"enable_expert_parallel": False}
config.update(distributed_config)
return config
@@ -139,16 +141,16 @@ def __init__(
if self._settings.device
else Engine.get_default_device()
)
- self._transformers_logging_logger = (
- transformers_logging.get_logger()
- if self._settings.change_transformers_logging_level
- else None
- )
- self._transformers_logging_level = (
- self._transformers_logging_logger.level
- if self._settings.change_transformers_logging_level
- else None
- )
+ if self._settings.change_transformers_logging_level:
+ self._transformers_logging_logger = (
+ transformers_logging.get_logger()
+ )
+ self._transformers_logging_level = (
+ self._transformers_logging_logger.level
+ )
+ else:
+ self._transformers_logging_logger = None
+ self._transformers_logging_level = None
auto_load_tokenizer = (
self.uses_tokenizer and self._settings.auto_load_tokenizer
@@ -239,11 +241,9 @@ def is_runnable(self, device: str | None = None) -> bool | None:
def _load_tokenizer_with_tokens(
self, tokenizer_name_or_path: str | None, use_fast: bool = True
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
- raise (
- TokenizerNotSupportedException()
- if not self.uses_tokenizer()
- else NotImplementedError()
- )
+ if not self.uses_tokenizer:
+ raise TokenizerNotSupportedException()
+ raise NotImplementedError()
def __enter__(self):
_l = self._log
@@ -365,12 +365,22 @@ def _load(
self._settings.cache_dir,
)
- if self._settings.enable_eval:
+ if self._settings.enable_eval and isinstance(
+ self._model, PreTrainedModel
+ ):
_l("Setting model %s in eval mode", self._model_id)
self._model.eval()
- if self._tokenizer and (
- self._settings.tokens or self._settings.special_tokens
+ settings = self._settings
+ has_tokens = (
+ hasattr(settings, "tokens") and settings.tokens # type: ignore[union-attr]
+ ) or (
+ hasattr(settings, "special_tokens") and settings.special_tokens # type: ignore[union-attr]
+ )
+ if (
+ self._tokenizer
+ and has_tokens
+ and isinstance(self._model, PreTrainedModel)
):
total_tokens = len(self._tokenizer)
_l(
@@ -407,20 +417,23 @@ def _load(
if self._model and not self._config:
config: ModelConfig | SentenceTransformerModelConfig | None = None
- mc = (
- self._model.config
- if hasattr(self._model, "config")
- else (
- self._model[0].auto_model.config
- if is_sentence_transformer
- else None
- )
- )
+ mc = None
+ if hasattr(self._model, "config"):
+ mc = self._model.config # type: ignore[union-attr]
+ elif is_sentence_transformer and hasattr(
+ self._model, "__getitem__"
+ ):
+ first_module = self._model[0] # type: ignore[index]
+ if hasattr(first_module, "auto_model"):
+ mc = first_module.auto_model.config
if mc:
+ attr_map = getattr(mc, "attribute_map", None)
+ keys_ignore = getattr(mc, "keys_to_ignore_at_inference", None)
+ torch_dt = getattr(mc, "torch_dtype", float32)
config = ModelConfig(
architectures=getattr(mc, "architectures", None),
- attribute_map=getattr(mc, "attribute_map", None),
+ attribute_map=attr_map if attr_map else {},
bos_token_id=getattr(mc, "bos_token_id", None),
bos_token=(
self._tokenizer.decode(mc.bos_token_id)
@@ -448,9 +461,7 @@ def _load(
else None
),
keys_to_ignore_at_inference=(
- mc.keys_to_ignore_at_inference
- if hasattr(mc, "keys_to_ignore_at_inference")
- else None
+ keys_ignore if keys_ignore else []
),
loss_type=(
mc.loss_type if hasattr(mc, "loss_type") else None
@@ -472,9 +483,9 @@ def _load(
else None
),
num_labels=getattr(mc, "num_labels", None),
- output_attentions=getattr(mc, "output_attentions", None),
+ output_attentions=getattr(mc, "output_attentions", False),
output_hidden_states=getattr(
- mc, "output_hidden_states", None
+ mc, "output_hidden_states", False
),
pad_token_id=getattr(mc, "pad_token_id", None),
pad_token=(
@@ -490,7 +501,7 @@ def _load(
else None
),
state_size=(
- len(self._model.state_dict().keys())
+ len(self._model.state_dict().keys()) # type: ignore[union-attr]
if hasattr(self._model, "state_dict")
and self._model.state_dict
else 0
@@ -498,31 +509,41 @@ def _load(
task_specific_params=getattr(
mc, "task_specific_params", None
),
- torch_dtype=(
- str(mc.torch_dtype)
- if hasattr(mc, "torch_dtype")
- else None
- ),
+ torch_dtype=torch_dt,
vocab_size=(
mc.vocab_size if hasattr(mc, "vocab_size") else None
),
tokenizer_class=getattr(mc, "tokenizer_class", None),
)
- if is_sentence_transformer and config:
+ if (
+ is_sentence_transformer
+ and config
+ and isinstance(config, ModelConfig)
+ ):
config = SentenceTransformerModelConfig(
- backend=self._model.backend,
- similarity_function=self._model.similarity_fn_name,
- truncate_dimension=self._model.truncate_dim,
+ backend=getattr(self._model, "backend", "torch"),
+ similarity_function=getattr(
+ self._model, "similarity_fn_name", None
+ ),
+ truncate_dimension=getattr(
+ self._model, "truncate_dim", None
+ ),
transformer_model_config=config,
)
self._config = config
if self._tokenizer and not self._tokenizer_config:
+ settings = self._settings
+ tokens_list = (
+ settings.tokens # type: ignore[union-attr]
+ if hasattr(settings, "tokens")
+ else None
+ )
self._tokenizer_config = TokenizerConfig(
name_or_path=self._tokenizer.name_or_path,
- tokens=self._settings.tokens,
+ tokens=tokens_list,
special_tokens=self._tokenizer.all_special_tokens,
tokenizer_model_max_length=getattr(
self._tokenizer, "model_max_length", 0
@@ -554,11 +575,11 @@ def _get_device_memory(device: str) -> int:
)
return cuda.get_device_properties(index).total_memory
- from psutil import virtual_memory
+ from psutil import virtual_memory # type: ignore[import-untyped]
if device == "mps" and mps.is_available():
- return virtual_memory().total
- return virtual_memory().total
+ return int(virtual_memory().total)
+ return int(virtual_memory().total)
def _log(self, message: str, *args: object) -> None:
self._logger.debug(
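
Unrolling the paired conditional expressions into one `if`/`else` is more than style: it keeps the two `Logger | None` / `int | None` fields in lockstep and gives mypy a single branch to infer both from. Schematically, with the standard `logging` module standing in for the transformers logger:

```python
from logging import Logger, getLogger

change_level = True  # stands in for settings.change_transformers_logging_level
logging_logger: Logger | None
logging_level: int | None
if change_level:
    logging_logger = getLogger("transformers")
    logging_level = logging_logger.level
else:
    logging_logger = None
    logging_level = None
print(type(logging_logger).__name__, logging_level)
```
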
diff --git a/src/avalan/model/hubs/huggingface.py b/src/avalan/model/hubs/huggingface.py
index 16885b9f..b4c32ea1 100644
--- a/src/avalan/model/hubs/huggingface.py
+++ b/src/avalan/model/hubs/huggingface.py
@@ -21,6 +21,8 @@
class HuggingfaceHub:
+ """Interface for interacting with the Hugging Face Hub."""
+
DEFAULT_ENDPOINT: str = "https://huggingface.co"
DEFAULT_CACHE_DIR: str = expanduser(
getenv("HF_HUB_CACHE") or "~/.cache/huggingface/hub"
@@ -52,7 +54,7 @@ def __init__(
endpoint=endpoint,
token=access_token,
library_name=name(),
- library_version=version(),
+ library_version=str(version()),
)
self._cache_dir = expanduser(cache_dir)
self._domain = urlparse(endpoint).netloc
@@ -60,7 +62,7 @@ def __init__(
def cache_delete(
self, model_id: str, revisions: list[str] | None = None
- ) -> (HubCacheDeletion | None, Callable[[], None] | None):
+ ) -> tuple[HubCacheDeletion | None, Callable[[], None] | None]:
scan_results = scan_cache_dir(self._cache_dir)
delete_revisions = [
revision.commit_hash
@@ -126,7 +128,9 @@ def cache_scan(
)
for info in scan_results.repos
],
- key=lambda m: m.size_on_disk if sort_models_by_size else m.name,
+ key=lambda m: (
+ m.size_on_disk if sort_models_by_size else m.model_id
+ ),
reverse=sort_models_by_size,
)
return model_caches
@@ -143,7 +147,7 @@ def download(
model_id: str,
*,
workers: int = 8,
- tqdm_class: type[tqdm] | Callable[..., tqdm] | None = None,
+ tqdm_class: type[tqdm] | None = None,
local_dir: str | None = None,
local_dir_use_symlinks: bool | None = None,
) -> str:
@@ -183,8 +187,8 @@ def model_url(self, model_id: str) -> str:
def models(
self,
filter: str | list[str] | None = None,
- name: str | list[str] | None = None,
- search: str | list[str] | None = None,
+ name: str | None = None,
+ search: str | None = None,
*,
library: str | list[str] | None = None,
author: str | None = None,
@@ -226,6 +230,7 @@ def user(self) -> User:
@staticmethod
def _model(model_info: ModelInfo) -> Model:
+ now = datetime.now()
model = Model(
id=model_info.id,
parameters=(
@@ -255,7 +260,7 @@ def _model(model_info: ModelInfo) -> Model:
else None
),
pipeline_tag=model_info.pipeline_tag,
- tags=model_info.tags,
+ tags=model_info.tags or [],
architectures=(
model_info.config["architectures"]
if model_info.config and "architectures" in model_info.config
@@ -279,14 +284,18 @@ def _model(model_info: ModelInfo) -> Model:
else None
),
gated=model_info.gated,
- private=model_info.private,
+ private=model_info.private or False,
disabled=model_info.disabled,
- last_downloads=model_info.downloads,
- downloads=model_info.downloads_all_time or model_info.downloads,
- likes=model_info.likes,
+ last_downloads=model_info.downloads or 0,
+ downloads=model_info.downloads_all_time
+ or model_info.downloads
+ or 0,
+ likes=model_info.likes or 0,
ranking=model_info.trending_score,
- author=model_info.author,
- created_at=model_info.created_at,
- updated_at=model_info.last_modified or model_info.created_at,
+ author=model_info.author or "",
+ created_at=model_info.created_at or now,
+ updated_at=model_info.last_modified
+ or model_info.created_at
+ or now,
)
return model
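
The `_model` hunks coerce `huggingface_hub`'s Optional metadata into the non-None types the `Model` entity declares. A reduced sketch of the idiom, with the usual caveat that `x or default` also swallows falsy-but-valid values (harmless here, since a download count of 0 maps to 0 anyway):

```python
from datetime import datetime

# Optional fields as the hub API may return them
downloads: int | None = None
author: str | None = None
created_at: datetime | None = None

now = datetime.now()
print(downloads or 0, repr(author or ""), (created_at or now).year)
```
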
diff --git a/src/avalan/model/manager.py b/src/avalan/model/manager.py
index 4977d4ac..e29755fa 100644
--- a/src/avalan/model/manager.py
+++ b/src/avalan/model/manager.py
@@ -47,7 +47,7 @@
from logging import Logger
from os import environ
from time import perf_counter
-from typing import TYPE_CHECKING, Any, TypeAlias, get_args
+from typing import TYPE_CHECKING, Any, TypeAlias, cast, get_args
from urllib.parse import parse_qsl, urlparse
if TYPE_CHECKING:
@@ -124,12 +124,13 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: Any | None,
) -> bool:
- return await self._stack.__aexit__(exc_type, exc_value, traceback)
+ result = await self._stack.__aexit__(exc_type, exc_value, traceback)
+ return result if result is not None else False
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
model_task: ModelCall,
- ):
+ ) -> Any:
modality = model_task.operation.modality
self._logger.info("ModelManager call process started for %s", modality)
@@ -246,7 +247,9 @@ def load(
weight_type: WeightType = "auto",
) -> ModelType:
if "backend" in engine_uri.params:
- backend = Backend(engine_uri.params["backend"])
+ backend_value = engine_uri.params["backend"]
+ assert isinstance(backend_value, str), "backend must be a string"
+ backend = Backend(backend_value)
engine_settings_args = dict(
base_url=base_url,
cache_dir=self._hub.cache_dir,
@@ -294,6 +297,9 @@ def load_engine(
if modality is Modality.EMBEDDING:
from ..model.nlp.sentence import SentenceTransformerModel
+ assert (
+ engine_uri.model_id is not None
+ ), "model_id is required for embedding modality"
model = SentenceTransformerModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -322,9 +328,14 @@ def parse_uri(uri: str) -> EngineUri:
f"Invalid scheme {parsed.scheme!r}, expected 'ai'"
)
- vendor = parsed.hostname
- if not vendor or vendor not in get_args(Vendor) or vendor == "local":
- vendor = None
+ vendor_raw = parsed.hostname
+ vendor: Vendor | None = None
+ if (
+ vendor_raw
+ and vendor_raw in get_args(Vendor)
+ and vendor_raw != "local"
+ ):
+ vendor = cast(Vendor, vendor_raw)
use_host = bool(vendor)
path_prefixed = parsed.path.startswith("/")
params: dict[str, str | int | float | bool] = {}
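
`parse_uri`'s vendor handling is the standard way to move a runtime string into a `Literal` alias: test membership against `get_args(...)`, then `cast`. Standalone, with a trimmed, hypothetical `Vendor` alias:

```python
from typing import Literal, cast, get_args

Vendor = Literal["openai", "anthropic", "local"]  # trimmed for the sketch

def parse_vendor(raw: str | None) -> Vendor | None:
    # The runtime membership check is what justifies the cast.
    if raw and raw in get_args(Vendor) and raw != "local":
        return cast(Vendor, raw)
    return None

print(parse_vendor("openai"), parse_vendor("local"), parse_vendor("nope"))
# openai None None
```
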
diff --git a/src/avalan/model/modalities/audio.py b/src/avalan/model/modalities/audio.py
index f626bd8e..84d68a56 100644
--- a/src/avalan/model/modalities/audio.py
+++ b/src/avalan/model/modalities/audio.py
@@ -33,6 +33,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for audio classification"
return AudioClassificationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -90,6 +93,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for speech recognition"
return SpeechRecognitionModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -147,6 +153,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for text to speech"
return TextToSpeechModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -182,19 +189,24 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
+ audio_params = operation.parameters["audio"]
assert (
- operation.parameters["audio"]
- and operation.parameters["audio"].path
- and operation.parameters["audio"].sampling_rate
+ audio_params
+ and audio_params.path
+ and audio_params.sampling_rate
+ and operation.generation_settings
)
+ assert isinstance(operation.input, str), "prompt must be a string"
+ max_tokens = operation.generation_settings.max_new_tokens
+ assert max_tokens is not None, "max_new_tokens is required"
return await model(
- path=operation.parameters["audio"].path,
prompt=operation.input,
- max_new_tokens=operation.generation_settings.max_new_tokens,
- reference_path=operation.parameters["audio"].reference_path,
- reference_text=operation.parameters["audio"].reference_text,
- sampling_rate=operation.parameters["audio"].sampling_rate,
+ path=audio_params.path,
+ max_new_tokens=max_tokens,
+ reference_path=audio_params.reference_path,
+ reference_text=audio_params.reference_text,
+ sampling_rate=audio_params.sampling_rate,
)
@@ -210,6 +222,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for audio generation"
return AudioGenerationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -243,14 +256,19 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
+ audio_params = operation.parameters["audio"]
assert (
operation.input
- and operation.parameters["audio"]
- and operation.parameters["audio"].path
+ and audio_params
+ and audio_params.path
+ and operation.generation_settings
)
+ assert isinstance(operation.input, str), "prompt must be a string"
+ max_tokens = operation.generation_settings.max_new_tokens
+ assert max_tokens is not None, "max_new_tokens is required"
return await model(
operation.input,
- operation.parameters["audio"].path,
- operation.generation_settings.max_new_tokens,
+ audio_params.path,
+ max_tokens,
)
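
Hoisting `operation.parameters["audio"]` into a local before the assert is what makes these handlers check cleanly: binding the lookup once lets a single assert cover every later use instead of re-subscripting (and re-asserting) per argument. Reduced to its essentials with a hypothetical `AudioParams` stand-in:

```python
from dataclasses import dataclass

@dataclass
class AudioParams:  # stand-in for the real parameters entity
    path: str | None = None
    sampling_rate: int | None = None

parameters: dict[str, AudioParams | None] = {
    "audio": AudioParams("out.wav", 44_100)
}
audio_params = parameters["audio"]  # bind once
assert audio_params and audio_params.path and audio_params.sampling_rate
print(audio_params.path, audio_params.sampling_rate)
```
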
diff --git a/src/avalan/model/modalities/registry.py b/src/avalan/model/modalities/registry.py
index a02d0cf1..fa4944c1 100644
--- a/src/avalan/model/modalities/registry.py
+++ b/src/avalan/model/modalities/registry.py
@@ -6,6 +6,7 @@
Input,
Modality,
Operation,
+ OperationParameters,
ReasoningSettings,
ReasoningTag,
TransformerEngineSettings,
@@ -15,12 +16,13 @@
from argparse import Namespace
from collections.abc import Callable
from contextlib import AsyncExitStack
-from inspect import isclass
from logging import Logger
from typing import Any, Protocol
class ModalityHandler(Protocol):
+ """Protocol for modality handler implementations."""
+
async def __call__(
self,
engine_uri: EngineUri,
@@ -46,16 +48,18 @@ def get_operation_from_arguments(
class ModalityRegistry:
+ """Registry for modality handlers."""
+
_handlers: dict[Modality, ModalityHandler] = {}
@classmethod
def register(
cls, modality: Modality
- ) -> Callable[[ModalityHandler | type], ModalityHandler]:
- def decorator(handler: ModalityHandler | type) -> ModalityHandler:
- cls._handlers[modality] = (
- handler() if isclass(handler) else handler
- )
+ ) -> Callable[[type[ModalityHandler]], type[ModalityHandler]]:
+ def decorator(
+ handler: type[ModalityHandler],
+ ) -> type[ModalityHandler]:
+ cls._handlers[modality] = handler()
return handler
return decorator
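
Dropping `isclass` and typing the decorator as `type[ModalityHandler] -> type[ModalityHandler]` trades runtime flexibility for a precise signature. The same shape in miniature, with a hypothetical `Handler` protocol:

```python
from typing import Callable, Protocol

class Handler(Protocol):  # mini stand-in for ModalityHandler
    def run(self) -> str: ...

_handlers: dict[str, Handler] = {}

def register(name: str) -> Callable[[type[Handler]], type[Handler]]:
    def decorator(handler: type[Handler]) -> type[Handler]:
        _handlers[name] = handler()  # instantiate eagerly, as above
        return handler
    return decorator

@register("text")
class TextHandler:
    def run(self) -> str:
        return "text"

print(_handlers["text"].run())  # text
```
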
@@ -130,11 +134,12 @@ def get_operation_from_arguments(
try:
handler = cls.get(modality)
except NotImplementedError:
+ empty_params: OperationParameters = {}
return Operation(
generation_settings=settings,
input=input_string,
modality=modality,
- parameters=None,
+ parameters=empty_params,
requires_input=False,
)
return handler.get_operation_from_arguments(
diff --git a/src/avalan/model/modalities/text.py b/src/avalan/model/modalities/text.py
index 4f42dc96..434add02 100644
--- a/src/avalan/model/modalities/text.py
+++ b/src/avalan/model/modalities/text.py
@@ -55,6 +55,8 @@ def _get_mlx_model() -> type[TextGenerationModel] | None:
@ModalityRegistry.register(Modality.TEXT_GENERATION)
class TextGenerationModality:
+ """Handler for text generation modality."""
+
def load_engine(
self,
engine_uri: EngineUri,
@@ -62,7 +64,8 @@ def load_engine(
logger: Logger,
exit_stack: AsyncExitStack,
) -> TextGenerationModel:
- model_load_args = dict(
+ assert engine_uri.model_id, "model_id is required for text generation"
+ model_load_args: dict[str, Any] = dict(
model_id=engine_uri.model_id,
settings=engine_settings,
logger=logger,
@@ -80,75 +83,71 @@ def load_engine(
return mlx_loader(**model_load_args)
case Backend.VLLM:
- from ..nlp.text.vllm import VllmModel as Loader
+ from ..nlp.text.vllm import VllmModel
- return Loader(**model_load_args)
+ return VllmModel(**model_load_args)
case _:
return TextGenerationModel(**model_load_args)
match engine_uri.vendor:
case "anthropic":
- from ..nlp.text.vendor.anthropic import (
- AnthropicModel as Loader,
- )
+ from ..nlp.text.vendor.anthropic import AnthropicModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return AnthropicModel(**model_load_args, exit_stack=exit_stack)
case "openai":
- from ..nlp.text.vendor.openai import OpenAIModel as Loader
+ from ..nlp.text.vendor.openai import OpenAIModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return OpenAIModel(**model_load_args, exit_stack=exit_stack)
case "bedrock":
- from ..nlp.text.vendor.bedrock import BedrockModel as Loader
+ from ..nlp.text.vendor.bedrock import BedrockModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return BedrockModel(**model_load_args, exit_stack=exit_stack)
case "openrouter":
- from ..nlp.text.vendor.openrouter import (
- OpenRouterModel as Loader,
- )
+ from ..nlp.text.vendor.openrouter import OpenRouterModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return OpenRouterModel(
+ **model_load_args, exit_stack=exit_stack
+ )
case "anyscale":
- from ..nlp.text.vendor.anyscale import AnyScaleModel as Loader
+ from ..nlp.text.vendor.anyscale import AnyScaleModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return AnyScaleModel(**model_load_args, exit_stack=exit_stack)
case "together":
- from ..nlp.text.vendor.together import TogetherModel as Loader
+ from ..nlp.text.vendor.together import TogetherModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return TogetherModel(**model_load_args, exit_stack=exit_stack)
case "deepseek":
- from ..nlp.text.vendor.deepseek import DeepSeekModel as Loader
+ from ..nlp.text.vendor.deepseek import DeepSeekModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return DeepSeekModel(**model_load_args, exit_stack=exit_stack)
case "deepinfra":
- from ..nlp.text.vendor.deepinfra import (
- DeepInfraModel as Loader,
- )
+ from ..nlp.text.vendor.deepinfra import DeepInfraModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return DeepInfraModel(**model_load_args, exit_stack=exit_stack)
case "groq":
- from ..nlp.text.vendor.groq import GroqModel as Loader
+ from ..nlp.text.vendor.groq import GroqModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return GroqModel(**model_load_args, exit_stack=exit_stack)
case "ollama":
- from ..nlp.text.vendor.ollama import OllamaModel as Loader
+ from ..nlp.text.vendor.ollama import OllamaModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return OllamaModel(**model_load_args, exit_stack=exit_stack)
case "huggingface":
- from ..nlp.text.vendor.huggingface import (
- HuggingfaceModel as Loader,
- )
+ from ..nlp.text.vendor.huggingface import HuggingfaceModel
- return Loader(**model_load_args, exit_stack=exit_stack)
- case "hyperbolic":
- from ..nlp.text.vendor.hyperbolic import (
- HyperbolicModel as Loader,
+ return HuggingfaceModel(
+ **model_load_args, exit_stack=exit_stack
)
+ case "hyperbolic":
+ from ..nlp.text.vendor.hyperbolic import HyperbolicModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return HyperbolicModel(
+ **model_load_args, exit_stack=exit_stack
+ )
case "litellm":
- from ..nlp.text.vendor.litellm import LiteLLMModel as Loader
+ from ..nlp.text.vendor.litellm import LiteLLMModel
- return Loader(**model_load_args, exit_stack=exit_stack)
+ return LiteLLMModel(**model_load_args, exit_stack=exit_stack)
raise NotImplementedError()
def get_operation_from_arguments(
@@ -159,7 +158,7 @@ def get_operation_from_arguments(
) -> Operation:
parameters = OperationParameters(
text=OperationTextParameters(
- manual_sampling=args.display_tokens or 0,
+ manual_sampling=bool(args.display_tokens),
pick_tokens=(
10
if args.display_tokens and args.display_tokens > 0
@@ -186,7 +185,9 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
- assert operation.input and operation.parameters["text"]
+ assert operation.input, "operation.input is required"
+ text_params = operation.parameters["text"]
+ assert text_params, "text parameters are required"
criteria = _stopping_criteria(operation, model)
mlx_model = _get_mlx_model()
@@ -194,21 +195,19 @@ async def __call__(
if engine_uri.is_local and not is_mlx:
return await model(
operation.input,
- system_prompt=operation.parameters["text"].system_prompt,
- developer_prompt=operation.parameters["text"].developer_prompt,
+ system_prompt=text_params.system_prompt,
+ developer_prompt=text_params.developer_prompt,
settings=operation.generation_settings,
stopping_criterias=[criteria] if criteria else None,
- manual_sampling=operation.parameters["text"].manual_sampling,
- pick=operation.parameters["text"].pick_tokens,
- skip_special_tokens=operation.parameters[
- "text"
- ].skip_special_tokens,
+ manual_sampling=text_params.manual_sampling or False,
+ pick=text_params.pick_tokens,
+ skip_special_tokens=text_params.skip_special_tokens or False,
tool=tool,
)
return await model(
operation.input,
- system_prompt=operation.parameters["text"].system_prompt,
- developer_prompt=operation.parameters["text"].developer_prompt,
+ system_prompt=text_params.system_prompt,
+ developer_prompt=text_params.developer_prompt,
settings=operation.generation_settings,
tool=tool,
)
@@ -226,6 +225,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for question answering"
return QuestionAnsweringModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -286,6 +288,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for sequence classification"
return SequenceClassificationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -302,7 +307,7 @@ def get_operation_from_arguments(
generation_settings=settings,
input=input_string,
modality=Modality.TEXT_SEQUENCE_CLASSIFICATION,
- parameters=None,
+ parameters=OperationParameters(),
requires_input=True,
)
@@ -329,6 +334,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for sequence to sequence"
return SequenceToSequenceModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -362,6 +370,7 @@ async def __call__(
tool: ToolManager | None = None,
) -> Any:
assert operation.input and operation.parameters["text"]
+ assert operation.generation_settings, "generation_settings is required"
criteria = _stopping_criteria(operation, model)
return await model(
operation.input,
@@ -382,6 +391,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for token classification"
return TokenClassificationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -437,6 +449,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for translation"
return TranslationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -472,22 +485,25 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
+ text_params = operation.parameters["text"]
assert (
operation.input
- and operation.parameters["text"]
- and operation.parameters["text"].language_source
- and operation.parameters["text"].language_destination
+ and text_params
+ and text_params.language_source
+ and text_params.language_destination
+ and operation.generation_settings
)
criteria = _stopping_criteria(operation, model)
+ skip_special_tokens = (
+ text_params.skip_special_tokens
+ if text_params.skip_special_tokens is not None
+ else True
+ )
return await model(
operation.input,
- source_language=operation.parameters["text"].language_source,
- destination_language=operation.parameters[
- "text"
- ].language_destination,
+ source_language=text_params.language_source,
+ destination_language=text_params.language_destination,
settings=operation.generation_settings,
stopping_criterias=[criteria] if criteria else None,
- skip_special_tokens=operation.parameters[
- "text"
- ].skip_special_tokens,
+ skip_special_tokens=skip_special_tokens,
)
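
The translation handler's `skip_special_tokens` dance preserves a tri-state: `None` means "fall back to the default of True", while an explicit `False` must survive. A quick contrast of the two idioms with toy values:

```python
flag: bool | None = False  # caller explicitly kept special tokens

wrong = flag or True                        # True: the explicit False is lost
right = flag if flag is not None else True  # False: tri-state preserved
print(wrong, right)
```
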
diff --git a/src/avalan/model/modalities/vision.py b/src/avalan/model/modalities/vision.py
index 2f0d30e5..366cb3f6 100644
--- a/src/avalan/model/modalities/vision.py
+++ b/src/avalan/model/modalities/vision.py
@@ -39,6 +39,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for encoder decoder"
return VisionEncoderDecoderModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -72,17 +73,19 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
- assert (
- operation.parameters["vision"]
- and operation.parameters["vision"].path
- )
+ vision_params = operation.parameters["vision"]
+ assert vision_params and vision_params.path
+ prompt = operation.input if isinstance(operation.input, str) else None
+ skip_special_tokens = (
+ vision_params.skip_special_tokens
+ if vision_params.skip_special_tokens is not None
+ else True
+ )
return await model(
- operation.parameters["vision"].path,
- prompt=operation.input,
- skip_special_tokens=operation.parameters[
- "vision"
- ].skip_special_tokens,
+ vision_params.path,
+ prompt=prompt,
+ skip_special_tokens=skip_special_tokens,
)
@@ -98,6 +101,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for image classification"
return ImageClassificationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -150,6 +156,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for image to text"
return ImageToTextModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -183,16 +190,17 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
- assert (
- operation.parameters["vision"]
- and operation.parameters["vision"].path
- )
+ vision_params = operation.parameters["vision"]
+ assert vision_params and vision_params.path
+ skip_special_tokens = (
+ vision_params.skip_special_tokens
+ if vision_params.skip_special_tokens is not None
+ else True
+ )
return await model(
- operation.parameters["vision"].path,
- skip_special_tokens=operation.parameters[
- "vision"
- ].skip_special_tokens,
+ vision_params.path,
+ skip_special_tokens=skip_special_tokens,
)
@@ -208,6 +216,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for image text to text"
return ImageTextToTextModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -247,18 +258,17 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
- assert (
- operation.parameters["vision"]
- and operation.parameters["vision"].path
- )
+ vision_params = operation.parameters["vision"]
+ assert vision_params and vision_params.path
+ assert isinstance(operation.input, str), "prompt must be a string"
return await model(
- operation.parameters["vision"].path,
+ vision_params.path,
operation.input,
- system_prompt=operation.parameters["vision"].system_prompt,
- developer_prompt=operation.parameters["vision"].developer_prompt,
+ system_prompt=vision_params.system_prompt,
+ developer_prompt=vision_params.developer_prompt,
settings=operation.generation_settings,
- width=operation.parameters["vision"].width,
+ width=vision_params.width,
)
@@ -274,6 +284,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for object detection"
return ObjectDetectionModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -335,6 +346,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for text to image"
return TextToImageModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -403,6 +415,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for text to animation"
return TextToAnimationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -471,6 +486,7 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert engine_uri.model_id, "model_id is required for text to video"
return TextToVideoModel(
model_id=engine_uri.model_id,
settings=engine_settings,
@@ -517,13 +533,10 @@ async def __call__(
operation: Operation,
tool: ToolManager | None = None,
) -> Any:
- assert (
- operation.input
- and operation.parameters["vision"]
- and operation.parameters["vision"].path
- )
vision = operation.parameters["vision"]
- kwargs = {
+ assert operation.input and vision and vision.path
+ assert isinstance(operation.input, str), "prompt must be a string"
+ kwargs: dict[str, Any] = {
"reference_path": vision.reference_path,
"negative_prompt": vision.negative_prompt,
"height": vision.height,
@@ -555,6 +568,9 @@ def load_engine(
_ = exit_stack
if not engine_uri.is_local:
raise NotImplementedError()
+ assert (
+ engine_uri.model_id
+ ), "model_id is required for semantic segmentation"
return SemanticSegmentationModel(
model_id=engine_uri.model_id,
settings=engine_settings,
diff --git a/src/avalan/model/nlp/__init__.py b/src/avalan/model/nlp/__init__.py
index 2a382069..74e5958a 100644
--- a/src/avalan/model/nlp/__init__.py
+++ b/src/avalan/model/nlp/__init__.py
@@ -4,6 +4,7 @@
from abc import ABC
from contextlib import nullcontext
+from typing import Any
from torch import (
Tensor,
@@ -21,7 +22,7 @@ def _generate_output(
settings: GenerationSettings,
stopping_criterias: list[StoppingCriteria] | None = None,
streamer: AsyncTextIteratorStreamer | None = None,
- ):
+ ) -> Any:
eos_token_id = (
settings.eos_token_id
if settings.eos_token_id
@@ -31,7 +32,8 @@ def _generate_output(
else None
)
)
- generation_kwargs = {
+ assert self._tokenizer is not None, "Tokenizer must be loaded"
+ generation_kwargs: dict[str, Any] = {
"bos_token_id": settings.bos_token_id,
"diversity_penalty": settings.diversity_penalty,
"do_sample": settings.do_sample,
@@ -73,35 +75,42 @@ def _generate_output(
if settings.use_inputs_attention_mask:
attention_mask = (
inputs.get("attention_mask", None)
- if isinstance(inputs, BatchEncoding)
+ if isinstance(inputs, (BatchEncoding, dict))
else getattr(inputs, "attention_mask", None)
)
if attention_mask is not None:
assert isinstance(attention_mask, Tensor)
- assert attention_mask.shape == inputs["input_ids"].shape
+ if isinstance(inputs, (BatchEncoding, dict)):
+ input_ids = inputs.get("input_ids")
+ if input_ids is not None:
+ assert attention_mask.shape == input_ids.shape
generation_kwargs["attention_mask"] = attention_mask
- if (
+ if isinstance(inputs, (BatchEncoding, dict)) and (
not settings.use_inputs_attention_mask
or attention_mask is not None
):
- inputs.pop("attention_mask", None)
+ inputs.pop("attention_mask", None) # type: ignore[union-attr]
if settings.forced_bos_token_id or settings.forced_eos_token_id:
del generation_kwargs["bos_token_id"]
del generation_kwargs["eos_token_id"]
+ assert self._model is not None, "Model must be loaded"
+ assert hasattr(
+ self._model, "generate"
+ ), "Model must support generate()"
with (
inference_mode()
if not settings.enable_gradient_calculation
else nullcontext()
):
outputs = (
- self._model.generate(
+ self._model.generate( # type: ignore[operator]
inputs, tokenizer=self._tokenizer, **generation_kwargs
)
- if isinstance(inputs, Tensor)
- else self._model.generate(
+ if isinstance(inputs, Tensor) # type: ignore[arg-type]
+ else self._model.generate( # type: ignore[operator]
**inputs, tokenizer=self._tokenizer, **generation_kwargs
)
)
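
`_generate_output` now has to treat `inputs` as `Tensor | dict | BatchEncoding`, so mask extraction only calls `.get()` behind an `isinstance` check, with `getattr` as the fallback for tensor-like inputs. In isolation, with plain objects standing in for `BatchEncoding`:

```python
from typing import Any

def get_attention_mask(inputs: Any) -> Any:
    if isinstance(inputs, dict):  # BatchEncoding behaves like a dict too
        return inputs.get("attention_mask", None)
    return getattr(inputs, "attention_mask", None)

print(get_attention_mask({"input_ids": [1, 2], "attention_mask": [1, 1]}))
print(get_attention_mask(object()))  # None: no such attribute
```
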
diff --git a/src/avalan/model/nlp/question.py b/src/avalan/model/nlp/question.py
index 3e4c4d00..0e7620f8 100644
--- a/src/avalan/model/nlp/question.py
+++ b/src/avalan/model/nlp/question.py
@@ -4,7 +4,7 @@
from ...model.nlp import BaseNLPModel
from ...model.vendor import TextGenerationVendor
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import argmax, inference_mode
@@ -24,7 +24,8 @@ def supports_token_streaming(self) -> bool:
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
- model = AutoModelForQuestionAnswering.from_pretrained(
+ assert self._model_id is not None, "Model ID must be set"
+ model: PreTrainedModel = AutoModelForQuestionAnswering.from_pretrained(
self._model_id,
cache_dir=self._settings.cache_dir,
subfolder=self._settings.subfolder or "",
@@ -42,7 +43,7 @@ def _load_model(
)
return model
- def _tokenize_input(
+ def _tokenize_input( # type: ignore[override]
self,
input: Input,
system_prompt: str | None,
@@ -56,14 +57,18 @@ def _tokenize_input(
+ f"{self._model_id} does not support chat "
+ "templates"
)
+ assert self._tokenizer is not None, "Tokenizer must be loaded"
+ assert self._model is not None, "Model must be loaded"
_l = self._log
_l(f"Tokenizing input {input}")
- inputs = self._tokenizer(input, context, return_tensors=tensor_format)
- inputs = inputs.to(self._model.device)
+ inputs: BatchEncoding = self._tokenizer(
+ input, context, return_tensors=tensor_format
+ )
+ inputs = inputs.to(self._model.device) # type: ignore[union-attr]
return inputs
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
*,
@@ -71,6 +76,7 @@ async def __call__(
system_prompt: str | None = None,
developer_prompt: str | None = None,
skip_special_tokens: bool = True,
+ **kwargs: Any,
) -> str:
assert self._tokenizer, (
f"Model {self._model} can't be executed "
@@ -87,13 +93,13 @@ async def __call__(
context=context,
)
with inference_mode():
- outputs = self._model(**inputs)
- start_answer_logits = outputs.start_logits
- end_answer_logits = outputs.end_logits
+ outputs = self._model(**inputs) # type: ignore[operator]
+ start_answer_logits = outputs.start_logits # type: ignore[union-attr]
+ end_answer_logits = outputs.end_logits # type: ignore[union-attr]
start = argmax(start_answer_logits)
end = argmax(end_answer_logits)
answer_ids = inputs["input_ids"][0, start : end + 1]
- answer = self._tokenizer.decode(
+ answer: str = self._tokenizer.decode(
answer_ids, skip_special_tokens=skip_special_tokens
)
return answer
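
The QA decode itself is unchanged by this hunk; with toy logits and hypothetical token ids, the span logic in `__call__` reduces to:

```python
from torch import argmax, tensor

input_ids = tensor([[101, 7592, 2088, 2003, 4248, 102]])  # toy ids
start_logits = tensor([[0.1, 2.0, 0.2, 0.1, 0.0, 0.0]])
end_logits = tensor([[0.0, 0.1, 0.3, 2.5, 0.2, 0.0]])
start = argmax(start_logits)  # position 1
end = argmax(end_logits)      # position 3
answer_ids = input_ids[0, start : end + 1]  # inclusive span
print(answer_ids.tolist())  # [7592, 2088, 2003]
```
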
diff --git a/src/avalan/model/nlp/sentence.py b/src/avalan/model/nlp/sentence.py
index b31daa59..e0c6d88c 100644
--- a/src/avalan/model/nlp/sentence.py
+++ b/src/avalan/model/nlp/sentence.py
@@ -2,15 +2,12 @@
from ...entities import Input
from ...model.engine import Engine
from ...model.nlp import BaseNLPModel
-from ...model.vendor import TextGenerationVendor
from contextlib import nullcontext
-from typing import Literal
+from typing import Any, Literal
-from diffusers import DiffusionPipeline
from numpy import ndarray
from torch import inference_mode
-from transformers import PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding
@@ -28,12 +25,13 @@ def uses_tokenizer(self) -> bool:
return True
def token_count(self, input: str) -> int:
+ assert self.tokenizer is not None, "Tokenizer must be loaded"
token_ids = self.tokenizer.encode(input, add_special_tokens=False)
return len(token_ids)
def _load_model(
self,
- ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ ) -> Any: # Returns SentenceTransformer which isn't a PreTrainedModel
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
@@ -61,7 +59,7 @@ def _load_model(
)
return model
- def _tokenize_input(
+ def _tokenize_input( # type: ignore[override]
self,
input: Input,
system_prompt: str | None,
@@ -73,9 +71,12 @@ def _tokenize_input(
raise NotImplementedError()
@override
- async def __call__(
- self, input: Input, *args, enable_gradient_calculation: bool = False
- ) -> ndarray:
+ async def __call__( # type: ignore[override]
+ self,
+ input: Input,
+ *args: Any,
+ enable_gradient_calculation: bool = False,
+ ) -> ndarray[Any, Any]:
assert self._model, (
f"Model {self._model} can't be executed, it "
+ "needs to be loaded first"
@@ -86,7 +87,8 @@ async def __call__(
if not enable_gradient_calculation
else nullcontext()
):
- embeddings = self._model.encode(
+ # self._model is SentenceTransformer at runtime
+ embeddings: ndarray[Any, Any] = self._model.encode( # type: ignore[union-attr, operator]
input, convert_to_numpy=True, show_progress_bar=False
)
return embeddings
diff --git a/src/avalan/model/nlp/sequence.py b/src/avalan/model/nlp/sequence.py
index e8e7b1b7..ba33e809 100644
--- a/src/avalan/model/nlp/sequence.py
+++ b/src/avalan/model/nlp/sequence.py
@@ -5,7 +5,7 @@
from ...model.vendor import TextGenerationVendor
from dataclasses import replace
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import Tensor, argmax, inference_mode, softmax
@@ -30,25 +30,28 @@ def supports_token_streaming(self) -> bool:
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
- model = AutoModelForSequenceClassification.from_pretrained(
- self._model_id,
- cache_dir=self._settings.cache_dir,
- subfolder=self._settings.subfolder or "",
- attn_implementation=self._settings.attention,
- trust_remote_code=self._settings.trust_remote_code,
- torch_dtype=Engine.weight(self._settings.weight_type),
- state_dict=self._settings.state_dict,
- local_files_only=self._settings.local_files_only,
- token=self._settings.access_token,
- device_map=self._device,
- tp_plan=Engine._get_tp_plan(self._settings.parallel),
- distributed_config=Engine._get_distributed_config(
- self._settings.distributed_config
- ),
+ assert self._model_id is not None, "Model ID must be set"
+ model: PreTrainedModel = (
+ AutoModelForSequenceClassification.from_pretrained(
+ self._model_id,
+ cache_dir=self._settings.cache_dir,
+ subfolder=self._settings.subfolder or "",
+ attn_implementation=self._settings.attention,
+ trust_remote_code=self._settings.trust_remote_code,
+ torch_dtype=Engine.weight(self._settings.weight_type),
+ state_dict=self._settings.state_dict,
+ local_files_only=self._settings.local_files_only,
+ token=self._settings.access_token,
+ device_map=self._device,
+ tp_plan=Engine._get_tp_plan(self._settings.parallel),
+ distributed_config=Engine._get_distributed_config(
+ self._settings.distributed_config
+ ),
+ )
)
return model
- def _tokenize_input(
+ def _tokenize_input( # type: ignore[override]
self,
input: Input,
system_prompt: str | None,
@@ -62,14 +65,20 @@ def _tokenize_input(
+ f"{self._model_id} does not support chat "
+ "templates"
)
+ assert self._tokenizer is not None, "Tokenizer must be loaded"
+ assert self._model is not None, "Model must be loaded"
_l = self._log
_l(f"Tokenizing input {input}")
- inputs = self._tokenizer(input, return_tensors=tensor_format)
- inputs = inputs.to(self._model.device)
+ inputs: BatchEncoding = self._tokenizer(
+ input, return_tensors=tensor_format
+ )
+ inputs = inputs.to(self._model.device) # type: ignore[union-attr]
return inputs
@override
- async def __call__(self, input: Input) -> str:
+ async def __call__( # type: ignore[override]
+ self, input: Input, **kwargs: Any
+ ) -> str:
assert self._tokenizer, (
f"Model {self._model} can't be executed "
+ "without a tokenizer loaded first"
@@ -80,11 +89,13 @@ async def __call__(self, input: Input) -> str:
)
inputs = self._tokenize_input(input, system_prompt=None, context=None)
with inference_mode():
- outputs = self._model(**inputs)
+ outputs = self._model(**inputs) # type: ignore[operator]
# logits shape (batch_size, num_labels)
- label_probs = softmax(outputs.logits, dim=-1)
- label_id = argmax(label_probs, dim=-1).item()
- label = self._model.config.id2label[label_id]
+ label_probs = softmax(outputs.logits, dim=-1) # type: ignore[union-attr]
+ label_id = int(argmax(label_probs, dim=-1).item())
+ id2label = self._model.config.id2label # type: ignore[union-attr]
+ assert id2label is not None, "Model config must have id2label"
+ label: str = id2label[label_id]
return label
@@ -100,7 +111,8 @@ def supports_token_streaming(self) -> bool:
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
- model = AutoModelForSeq2SeqLM.from_pretrained(
+ assert self._model_id is not None, "Model ID must be set"
+ model: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(
self._model_id,
cache_dir=self._settings.cache_dir,
subfolder=self._settings.subfolder or "",
@@ -118,7 +130,7 @@ def _load_model(
)
return model
- def _tokenize_input(
+ def _tokenize_input( # type: ignore[override]
self,
input: Input,
system_prompt: str | None,
@@ -132,18 +144,24 @@ def _tokenize_input(
+ f"{self._model_id} does not support chat "
+ "templates"
)
+ assert self._tokenizer is not None, "Tokenizer must be loaded"
+ assert self._model is not None, "Model must be loaded"
_l = self._log
_l(f"Tokenizing input {input}")
- inputs = self._tokenizer(input, return_tensors=tensor_format)
- inputs = inputs.to(self._model.device)
- return inputs["input_ids"]
+ inputs: BatchEncoding = self._tokenizer(
+ input, return_tensors=tensor_format
+ )
+ inputs = inputs.to(self._model.device) # type: ignore[union-attr]
+ input_ids: Tensor = inputs["input_ids"]
+ return input_ids
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
settings: GenerationSettings,
stopping_criterias: list[StoppingCriteria] | None = None,
+ **kwargs: Any,
) -> str:
assert self._tokenizer, (
f"Model {self._model} can't be executed "
@@ -166,12 +184,12 @@ async def __call__(
settings,
stopping_criterias,
)
- return self._tokenizer.decode(output_ids[0], skip_special_tokens=True)
+ return self._tokenizer.decode(output_ids[0], skip_special_tokens=True) # type: ignore[return-value]
class TranslationModel(SequenceToSequenceModel):
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
source_language: str,
@@ -179,6 +197,7 @@ async def __call__(
settings: GenerationSettings,
stopping_criterias: list[StoppingCriteria] | None = None,
skip_special_tokens: bool = True,
+ **kwargs: Any,
) -> str:
assert self._tokenizer, (
f"Model {self._model} can't be executed "
@@ -198,8 +217,8 @@ async def __call__(
self._tokenizer, "lang_code_to_id"
)
- previous_language = self._tokenizer.src_lang
- self._tokenizer.src_lang = source_language
+ previous_language: str = self._tokenizer.src_lang # type: ignore[attr-defined]
+ self._tokenizer.src_lang = source_language # type: ignore[attr-defined]
inputs = self._tokenize_input(input, system_prompt=None, context=None)
generation_settings = replace(
settings,
@@ -207,7 +226,7 @@ async def __call__(
repetition_penalty=1.0,
use_cache=True,
temperature=None,
- forced_bos_token_id=self._tokenizer.lang_code_to_id[
+ forced_bos_token_id=self._tokenizer.lang_code_to_id[ # type: ignore[attr-defined]
destination_language
],
)
@@ -216,8 +235,8 @@ async def __call__(
generation_settings,
stopping_criterias,
)
- text = self._tokenizer.decode(
+ text: str = self._tokenizer.decode(
output_ids[0], skip_special_tokens=skip_special_tokens
)
- self._tokenizer.src_lang = previous_language
+ self._tokenizer.src_lang = previous_language # type: ignore[attr-defined]
return text
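
Same story for sequence classification: wrapping `argmax(...).item()` in `int()` and asserting `id2label` is what satisfies mypy, while the runtime logic stays a softmax/argmax/lookup. With toy logits and a toy label map:

```python
from torch import argmax, softmax, tensor

logits = tensor([[0.2, 2.1, -0.3]])
id2label = {0: "negative", 1: "neutral", 2: "positive"}  # toy map
label_id = int(argmax(softmax(logits, dim=-1), dim=-1).item())
print(id2label[label_id])  # neutral
```
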
diff --git a/src/avalan/model/nlp/text/generation.py b/src/avalan/model/nlp/text/generation.py
index 181896fd..973c7a1f 100644
--- a/src/avalan/model/nlp/text/generation.py
+++ b/src/avalan/model/nlp/text/generation.py
@@ -25,7 +25,7 @@
from importlib.util import find_spec
from logging import Logger, getLogger
from threading import Thread
-from typing import AsyncGenerator, Literal
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal
from diffusers import DiffusionPipeline
from torch import Tensor, log_softmax, softmax, topk
@@ -41,15 +41,18 @@
from transformers.generation import StoppingCriteria
from transformers.tokenization_utils_base import BatchEncoding
+if TYPE_CHECKING:
+ from types import ModuleType
+
_TOOL_MESSAGE_PARSER = ToolCallParser()
class TextGenerationModel(BaseNLPModel):
_loaders: dict[TextGenerationLoaderClass, type[PreTrainedModel]] = {
- "auto": AutoModelForCausalLM,
- "gemma3": Gemma3ForConditionalGeneration,
- "gpt-oss": GptOssForCausalLM,
- "mistral3": Mistral3ForConditionalGeneration,
+ "auto": AutoModelForCausalLM, # type: ignore[dict-item]
+ "gemma3": Gemma3ForConditionalGeneration, # type: ignore[dict-item]
+ "gpt-oss": GptOssForCausalLM, # type: ignore[dict-item]
+ "mistral3": Mistral3ForConditionalGeneration, # type: ignore[dict-item]
}
def __init__(
@@ -115,11 +118,13 @@ def _load_model(
if model_args["quantization_config"] is None:
model_args.pop("quantization_config", None)
- model = loader.from_pretrained(self._model_id, **model_args)
+ model: PreTrainedModel = loader.from_pretrained(
+ self._model_id, **model_args
+ )
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
system_prompt: str | None = None,
@@ -210,7 +215,7 @@ async def _stream_generator(
_l = self._log
streamer = AsyncTextIteratorStreamer(
- self._tokenizer,
+ self._tokenizer, # type: ignore[arg-type]
skip_prompt=True,
decode_kwargs={"skip_special_tokens": skip_special_tokens},
)
@@ -250,10 +255,15 @@ def _string_output(
settings: GenerationSettings,
stopping_criterias: list[StoppingCriteria] | None,
skip_special_tokens: bool,
- **kwargs,
+ **kwargs: Any,
) -> str:
- input_length = inputs["input_ids"].shape[1]
+ assert isinstance(
+ inputs, dict
+ ), "inputs must be a dict for _string_output"
+ input_ids = inputs["input_ids"]
+ input_length: int = input_ids.shape[1] # type: ignore[union-attr]
outputs = self._generate_output(inputs, settings, stopping_criterias)
+ assert self._tokenizer is not None
return self._tokenizer.decode(
outputs[0][input_length:], skip_special_tokens=skip_special_tokens
)
@@ -274,12 +284,13 @@ async def _token_generator(
_l = self._log
+ entmax: ModuleType | None = None
enable_entmax = find_spec("entmax") and probability_distribution in [
"entmax",
"sparsemax",
]
if enable_entmax:
- import entmax
+ import entmax # type: ignore[import-not-found,no-redef]
_l(
f"Generating up to {settings.max_new_tokens} tokens "
@@ -297,11 +308,16 @@ async def _token_generator(
)
sequences = outputs.sequences[0]
scores = outputs.scores # list of logits for each generated token
- start = inputs["input_ids"].shape[1] # where generation began
+ assert isinstance(
+ inputs, dict
+ ), "inputs must be a dict for _token_generator"
+ input_ids = inputs["input_ids"]
+ start: int = input_ids.shape[1] # type: ignore[union-attr]
generated_sequences = sequences[start:]
_l(f"Generated {len(generated_sequences)} sequences")
+ assert self._tokenizer is not None
total_tokens = 0
for step, token_id in enumerate(generated_sequences):
_l(f"Got step {step} token {token_id}")
@@ -312,42 +328,43 @@ async def _token_generator(
logits = tensor[0] # first element in batch dimension
        # apply probability distribution over last tensor layer, vocab_size
+ temp = settings.temperature if settings.temperature else 1.0
logits_probs = (
log_softmax(logits, dim=-1)
if probability_distribution == "log_softmax"
else (
- gumbel_softmax(
- logits, tau=settings.temperature, hard=False, dim=-1
- )
+ gumbel_softmax(logits, tau=temp, hard=False, dim=-1)
if probability_distribution == "gumbel_softmax"
else (
entmax.sparsemax(logits, dim=-1)
if enable_entmax
+ and entmax is not None
and probability_distribution == "sparsemax"
else (
entmax.entmax15(logits, dim=-1)
if enable_entmax
+ and entmax is not None
and probability_distribution == "entmax"
- else softmax(logits / settings.temperature, dim=-1)
+ else softmax(logits / temp, dim=-1)
)
)
)
)
tokens: list[Token] | None = None
- if pick > 0:
+ if pick is not None and pick > 0:
picked_logits = topk(logits_probs, pick)
picked_logits_ids = picked_logits.indices.tolist()
picked_logits_probs = picked_logits.values.tolist()
tokens = [
Token(
- id=token_id,
+ id=tid,
token=self._tokenizer.decode(
- token_id, skip_special_tokens=skip_special_tokens
+ tid, skip_special_tokens=skip_special_tokens
),
probability=picked_logits_probs[i],
)
- for i, token_id in enumerate(picked_logits_ids)
+ for i, tid in enumerate(picked_logits_ids)
]
raw_token = TokenDetail(
@@ -371,7 +388,7 @@ async def _token_generator(
await sleep(0) # and just like that, a generator is an async generator
- def _tokenize_input(
+ def _tokenize_input( # type: ignore[override]
self,
input: Input,
system_prompt: str | None,
@@ -383,11 +400,15 @@ def _tokenize_input(
tool: ToolManager | None = None,
) -> dict[str, Tensor] | BatchEncoding | Tensor:
_l = self._log
+ assert self._tokenizer is not None
messages = self._messages(input, system_prompt, developer_prompt, tool)
+ tokenizer = self._tokenizer
def _format_content(
- content: str | MessageContent | list[MessageContent],
+ content: str | MessageContent | list[MessageContent] | None,
) -> str | list[dict[str, object]]:
+ if content is None:
+ return ""
if isinstance(content, str):
return content
@@ -395,14 +416,14 @@ def _format_content(
return content.text
if isinstance(content, MessageContentImage):
- if self._tokenizer.chat_template:
+ if tokenizer.chat_template:
return [
{"type": "image_url", "image_url": content.image_url}
]
return ""
if isinstance(content, list):
- if self._tokenizer.chat_template:
+ if tokenizer.chat_template:
blocks: list[dict[str, object]] = []
for c in content:
if isinstance(c, MessageContentImage):
@@ -426,7 +447,7 @@ def _format_content(
return str(content)
- template_messages = []
+ template_messages: list[dict[str, Any]] = []
for message in messages:
message_dict = asdict(message)
prepared = _TOOL_MESSAGE_PARSER.prepare_message_for_template(
@@ -441,7 +462,7 @@ def _format_content(
}
)
- if not self._tokenizer.chat_template:
+ if not tokenizer.chat_template:
_l("Model does not support template messages, so staying plain")
prompt = f"{system_prompt}\n\n" or ""
@@ -461,27 +482,30 @@ def _format_content(
else ""
)
)
- prompt += template_message["content"].strip() + "\n"
+ content = template_message["content"]
+ content_str = content if isinstance(content, str) else ""
+ prompt += content_str.strip() + "\n"
- inputs = self._tokenizer(
+ inputs: Any = tokenizer(
prompt, add_special_tokens=True, return_tensors=tensor_format
)
else:
_l(f"Got {len(template_messages)} template messages")
_l(f"Applying chat template to {len(template_messages)} messages")
- inputs = self._tokenizer.apply_chat_template(
- template_messages,
+ inputs = tokenizer.apply_chat_template(
+ template_messages, # type: ignore[arg-type]
chat_template=chat_template,
- tools=tool.json_schemas() if tool else None,
- **(chat_template_settings or {}),
+ tools=tool.json_schemas() if tool else None, # type: ignore[arg-type]
+ **(chat_template_settings or {}), # type: ignore[arg-type]
return_tensors=tensor_format,
)
- if hasattr(self._model, "device"):
- _l(f"Translating inputs to {self._model.device}")
- inputs = inputs.to(self._model.device)
- return inputs
+ if self._model is not None and hasattr(self._model, "device"):
+ device = getattr(self._model, "device", None)
+ _l(f"Translating inputs to {device}")
+ inputs = inputs.to(device)
+ return inputs # type: ignore[return-value,no-any-return]
def _messages(
self,
@@ -490,24 +514,27 @@ def _messages(
developer_prompt: str | None = None,
tool: ToolManager | None = None,
) -> list[Message]:
+ messages: list[Message]
if isinstance(input, str):
- input = Message(role=MessageRole.USER, content=input)
+ messages = [Message(role=MessageRole.USER, content=input)]
+ elif isinstance(input, Message):
+ messages = [input]
elif isinstance(input, list):
for m in input:
assert isinstance(m, Message)
- elif not isinstance(input, Message):
+ messages = list(input) # type: ignore[arg-type]
+ else:
raise ValueError(input)
- messages = [input] if not isinstance(input, list) else input
-
if developer_prompt:
messages = [
- Message(role=MessageRole.DEVELOPER, content=developer_prompt)
- ] + messages
+ Message(role=MessageRole.DEVELOPER, content=developer_prompt),
+ *messages,
+ ]
if system_prompt:
messages = [
- Message(role=MessageRole.SYSTEM, content=system_prompt)
- ] + messages
+ Message(role=MessageRole.SYSTEM, content=system_prompt),
+ *messages,
+ ]
- assert isinstance(messages, list)
return messages
diff --git a/src/avalan/model/nlp/text/mlxlm.py b/src/avalan/model/nlp/text/mlxlm.py
index 4b1e2399..f3144cda 100644
--- a/src/avalan/model/nlp/text/mlxlm.py
+++ b/src/avalan/model/nlp/text/mlxlm.py
@@ -5,6 +5,7 @@
TransformerEngineSettings,
)
from ....model.response.text import TextGenerationResponse
+from ....model.vendor import TextGenerationVendor
from ....tool.manager import ToolManager
from ...vendor import TextGenerationVendorStream
from .generation import TextGenerationModel
@@ -12,23 +13,26 @@
from asyncio import to_thread
from dataclasses import asdict, replace
from logging import Logger, getLogger
-from typing import AsyncGenerator, Callable, Literal
+from typing import Any, AsyncGenerator, Callable, Literal
+from diffusers import DiffusionPipeline
from mlx_lm import generate, load, stream_generate
from mlx_lm.sample_utils import make_sampler
from torch import Tensor
+from transformers import PreTrainedModel
class MlxLmStream(TextGenerationVendorStream):
"""Async wrapper around a synchronous token generator."""
- _SENTINEL = object()
+ _SENTINEL: object = object()
+ _iterator: Any
- def __init__(self, generator):
+ def __init__(self, generator: Any) -> None:
super().__init__(generator)
self._iterator = generator
- async def __anext__(self) -> str:
+ async def __anext__(self) -> Any:
sentinel = type(self)._SENTINEL
chunk = await to_thread(next, self._iterator, sentinel)
if chunk is sentinel:
@@ -52,13 +56,16 @@ def __init__(
def supports_sample_generation(self) -> bool:
return False
- def _load_model(self):
+ def _load_model(
+ self,
+ ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None
model, tokenizer = load(self._model_id)
- self._tokenizer = tokenizer
+ self._tokenizer = tokenizer # type: ignore[assignment]
self._loaded_tokenizer = True
- return model
+ return model # type: ignore[return-value]
- async def _stream_generator(
+ async def _stream_generator( # type: ignore[override]
self,
inputs: dict[str, Tensor] | Tensor,
settings: GenerationSettings,
@@ -68,8 +75,8 @@ async def _stream_generator(
inputs, settings, skip_special_tokens
)
iterator = stream_generate(
- self._model,
- self._tokenizer,
+ self._model, # type: ignore[arg-type]
+ self._tokenizer, # type: ignore[arg-type]
prompt,
sampler=sampler,
max_tokens=settings.max_new_tokens,
@@ -78,7 +85,7 @@ async def _stream_generator(
async for chunk in stream:
yield chunk.text
- def _string_output(
+ def _string_output( # type: ignore[override]
self,
inputs: dict[str, Tensor] | Tensor,
settings: GenerationSettings,
@@ -88,15 +95,15 @@ def _string_output(
inputs, settings, skip_special_tokens
)
return generate(
- self._model,
- self._tokenizer,
+ self._model, # type: ignore[arg-type]
+ self._tokenizer, # type: ignore[arg-type]
prompt,
sampler=sampler,
max_tokens=settings.max_new_tokens,
)
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
system_prompt: str | None = None,
@@ -118,11 +125,12 @@ async def __call__(
chat_template_settings=asdict(settings.chat_settings),
)
generation_settings = replace(settings, do_sample=False)
- output_fn = (
+ output_fn: Callable[..., Any] = (
self._stream_generator
if settings.use_async_generator
else self._string_output
)
+ assert self._tokenizer is not None
return TextGenerationResponse(
output_fn,
@@ -140,8 +148,8 @@ def _get_sampler_and_prompt(
inputs: dict[str, Tensor] | Tensor,
settings: GenerationSettings,
skip_special_tokens: bool,
- ) -> tuple[Callable, str]:
- sampler_settings = {
+ ) -> tuple[Callable[..., Any], str]:
+ sampler_settings: dict[str, Any] = {
"temp": settings.temperature,
"top_p": settings.top_p,
"top_k": settings.top_k,
@@ -150,7 +158,9 @@ def _get_sampler_and_prompt(
k: v for k, v in sampler_settings.items() if v is not None
}
sampler = make_sampler(**sampler_settings)
- prompt = self._tokenizer.decode(
+ assert self._tokenizer is not None
+ assert isinstance(inputs, dict), "inputs must be a dict"
+ prompt: str = self._tokenizer.decode(
inputs["input_ids"][0],
skip_special_tokens=skip_special_tokens,
)
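
MlxLmStream pumps a synchronous mlx-lm generator through to_thread with a sentinel default, so StopIteration never has to cross the thread boundary. A minimal standalone sketch of the same pattern, assuming only the standard library (the class and names here are illustrative, not the avalan API):

    from asyncio import run, to_thread
    from collections.abc import Iterator
    from typing import Any

    class AsyncStream:
        """Bridge a synchronous iterator into `async for`."""

        _SENTINEL: object = object()

        def __init__(self, iterator: Iterator[Any]) -> None:
            self._iterator = iterator

        def __aiter__(self) -> "AsyncStream":
            return self

        async def __anext__(self) -> Any:
            # next() runs in a worker thread; the sentinel default means
            # exhaustion never raises StopIteration across the thread.
            chunk = await to_thread(next, self._iterator, self._SENTINEL)
            if chunk is self._SENTINEL:
                raise StopAsyncIteration
            return chunk

    async def main() -> None:
        async for token in AsyncStream(iter(["a", "b", "c"])):
            print(token)

    run(main())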
diff --git a/src/avalan/model/nlp/text/vendor/__init__.py b/src/avalan/model/nlp/text/vendor/__init__.py
index c185ca5e..d4c7881c 100644
--- a/src/avalan/model/nlp/text/vendor/__init__.py
+++ b/src/avalan/model/nlp/text/vendor/__init__.py
@@ -13,7 +13,7 @@
from contextlib import AsyncExitStack
from dataclasses import replace
from logging import Logger, getLogger
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from tiktoken import encoding_for_model, get_encoding
@@ -59,12 +59,12 @@ def supports_token_streaming(self) -> bool:
def uses_tokenizer(self) -> bool:
return False
- def _tokenize_input(
+ def _tokenize_input( # type: ignore[override]
self,
input: Input,
context: str | None = None,
tensor_format: Literal["pt"] = "pt",
- **kwargs,
+ **kwargs: Any,
) -> dict[str, Tensor] | BatchEncoding | Tensor:
raise NotImplementedError()
@@ -74,6 +74,7 @@ def input_token_count(
system_prompt: str | None = None,
developer_prompt: str | None = None,
) -> int:
+ assert self._model_id is not None
try:
encoding = encoding_for_model(self._model_id)
except KeyError:
@@ -85,11 +86,14 @@ def input_token_count(
total_tokens = 0
for message in messages:
- total_tokens += len(encoding.encode(message.content or ""))
+ content = (
+ message.content if isinstance(message.content, str) else ""
+ )
+ total_tokens += len(encoding.encode(content))
return total_tokens
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
system_prompt: str | None = None,
@@ -98,15 +102,17 @@ async def __call__(
*,
tool: ToolManager | None = None,
) -> TextGenerationResponse:
+ gen_settings = settings or GenerationSettings()
messages = self._messages(input, system_prompt, developer_prompt, tool)
- streamer = await self._model(
+ assert self._model is not None
+ assert self._model_id is not None
+ streamer = await self._model( # type: ignore[operator]
self._model_id,
messages,
- settings,
+ gen_settings,
tool=tool,
- use_async_generator=settings.use_async_generator,
+ use_async_generator=gen_settings.use_async_generator,
)
- gen_settings = settings or GenerationSettings()
return TextGenerationResponse(
streamer,
logger=self._logger,
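
input_token_count resolves a tiktoken encoding for the model id and falls back to a named encoding when encoding_for_model raises KeyError. A standalone sketch, assuming tiktoken is installed; the fallback name "cl100k_base" is an assumption, since the hunk truncates the actual fallback:

    from tiktoken import encoding_for_model, get_encoding

    def count_tokens(model_id: str, contents: list[object]) -> int:
        try:
            encoding = encoding_for_model(model_id)
        except KeyError:
            # Unknown model ids map to a generic encoding (assumed name).
            encoding = get_encoding("cl100k_base")
        # Non-string content counts as empty, mirroring the patch.
        return sum(
            len(encoding.encode(c if isinstance(c, str) else ""))
            for c in contents
        )

    print(count_tokens("gpt-4o", ["hello world", None]))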
diff --git a/src/avalan/model/nlp/text/vendor/anthropic.py b/src/avalan/model/nlp/text/vendor/anthropic.py
index 837b3100..fc3e517d 100644
--- a/src/avalan/model/nlp/text/vendor/anthropic.py
+++ b/src/avalan/model/nlp/text/vendor/anthropic.py
@@ -18,7 +18,7 @@
from . import TextGenerationVendorModel
from contextlib import AsyncExitStack
-from typing import AsyncIterator
+from typing import Any, AsyncIterator
from anthropic import AsyncAnthropic
from anthropic.types import RawContentBlockDeltaEvent, RawMessageStopEvent
@@ -27,9 +27,9 @@
class AnthropicStream(TextGenerationVendorStream):
- def __init__(self, events: AsyncIterator):
+ def __init__(self, events: AsyncIterator[Any]) -> None:
async def generator() -> AsyncIterator[Token | TokenDetail | str]:
- tool_blocks: dict[int, dict] = {}
+ tool_blocks: dict[int, dict[str, Any]] = {}
async for event in events:
etype = getattr(event, "type", None)
@@ -100,10 +100,10 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]:
):
break
- super().__init__(generator())
+ super().__init__(generator()) # type: ignore[arg-type]
async def __anext__(self) -> Token | TokenDetail | str:
- return await self._generator.__anext__()
+ return await self._generator.__anext__() # type: ignore[no-any-return]
class AnthropicClient(TextGenerationVendor):
@@ -116,24 +116,24 @@ def __init__(
base_url: str | None = None,
*,
exit_stack: AsyncExitStack,
- ):
+ ) -> None:
self._client = AsyncAnthropic(api_key=api_key, base_url=base_url)
self._exit_stack = exit_stack
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
- model_id: str,
+ model_id: str | None,
messages: list[Message],
settings: GenerationSettings | None = None,
*,
tool: ToolManager | None = None,
use_async_generator: bool = True,
- ) -> AsyncIterator[Token | TokenDetail | str]:
+ ) -> AsyncIterator[Token | TokenDetail | str] | TextGenerationSingleStream:
settings = settings or GenerationSettings()
system_prompt = self._system_prompt(messages)
template_messages = self._template_messages(messages, ["system"])
- kwargs = {
+ kwargs: dict[str, Any] = {
"model": model_id,
"system": system_prompt,
"messages": template_messages,
@@ -152,32 +152,37 @@ async def __call__(
content = self._non_stream_response_content(response)
return TextGenerationSingleStream(content)
- def _template_messages(
+ def _template_messages( # type: ignore[override]
self,
messages: list[Message],
exclude_roles: list[TemplateMessageRole] | None = None,
- ) -> list[TemplateMessage]:
- tool_results = [
- message.tool_call_result or message.tool_call_error
+ ) -> list[TemplateMessage | dict[str, Any]]:
+ tool_results: list[ToolCallResult | ToolCallError] = [
+ message.tool_call_result or message.tool_call_error # type: ignore[misc]
for message in messages
if message.role == MessageRole.TOOL
and (message.tool_call_result or message.tool_call_error)
]
- do_exclude_roles = [*(exclude_roles or []), "tool"]
- messages = super()._template_messages(messages, do_exclude_roles)
- last_message = next(
+ do_exclude_roles: list[TemplateMessageRole] = [
+ *(exclude_roles or []),
+ "tool",
+ ]
+ template_messages: list[TemplateMessage | dict[str, Any]] = list(
+ super()._template_messages(messages, do_exclude_roles)
+ )
+ last_message: TemplateMessage | dict[str, Any] | None = next(
(
m
- for m in reversed(messages)
+ for m in reversed(template_messages)
if m["role"] == str(MessageRole.ASSISTANT)
),
None,
)
- last_message_index = (
- messages.index(last_message) if last_message else None
+ last_message_index: int | None = (
+ template_messages.index(last_message) if last_message else None
)
- if last_message_index:
- messages[last_message_index] = {
+ if last_message_index is not None and last_message is not None:
+ template_messages[last_message_index] = {
"role": last_message["role"],
"content": [
(
@@ -188,11 +193,11 @@ def _template_messages(
*[
{
"type": "tool_use",
- "id": r.call.id,
+ "id": r.call.id if r.call else None,
"name": TextGenerationVendor.encode_tool_name(
- r.call.name
+ r.call.name if r.call else ""
),
- "input": r.call.arguments,
+ "input": r.call.arguments if r.call else {},
}
for r in tool_results
],
@@ -200,7 +205,8 @@ def _template_messages(
}
for result in tool_results:
- content = {
+ assert result.call is not None
+ content: dict[str, Any] = {
"type": "tool_result",
"tool_use_id": result.call.id,
"content": to_json(
@@ -211,49 +217,55 @@ def _template_messages(
}
if isinstance(result, ToolCallError):
content["is_error"] = True
- result_message = TemplateMessage(role="user", content=[content])
- if last_message_index:
- messages.insert(last_message_index + 1, result_message)
+ result_message: dict[str, Any] = {
+ "role": "user",
+ "content": [content],
+ }
+ if last_message_index is not None:
+ template_messages.insert(
+ last_message_index + 1, result_message
+ )
else:
- messages.append(result_message)
+ template_messages.append(result_message)
# @TODO Ensure this doesn't happen from upstream
- if len(messages) > 1 and (messages[0] == messages[-1]):
- messages.pop()
+ if len(template_messages) > 1 and (
+ template_messages[0] == template_messages[-1]
+ ):
+ template_messages.pop()
- return messages
+ return template_messages
@staticmethod
- def _tool_schemas(tool: ToolManager) -> list[dict] | None:
+ def _tool_schemas(tool: ToolManager) -> list[dict[str, Any]] | None:
schemas = tool.json_schemas()
- return (
- [
- {
- "name": TextGenerationVendor.encode_tool_name(
- t["function"]["name"]
- ),
- "description": t["function"]["description"],
- "input_schema": {
- **t["function"]["parameters"],
- "additionalProperties": False,
- },
- }
- for t in tool.json_schemas()
- if t["type"] == "function"
- ]
- if schemas
- else None
- )
+ if not schemas:
+ return None
+ return [
+ {
+ "name": TextGenerationVendor.encode_tool_name(
+ t["function"]["name"]
+ ),
+ "description": t["function"]["description"],
+ "input_schema": {
+ **t["function"]["parameters"],
+ "additionalProperties": False,
+ },
+ }
+ for t in schemas
+ if t["type"] == "function"
+ ]
@staticmethod
def _non_stream_response_content(response: object) -> str:
- def _get(value: object, attribute: str) -> object | None:
+ def _get(value: object, attribute: str) -> Any:
if isinstance(value, dict):
return value.get(attribute)
return getattr(value, attribute, None)
parts: list[str] = []
- for block in _get(response, "content") or []:
+ content_blocks = _get(response, "content") or []
+ for block in content_blocks:
block_type = _get(block, "type")
if block_type == "text":
text = _get(block, "text")
@@ -262,10 +274,13 @@ def _get(value: object, attribute: str) -> object | None:
continue
if block_type == "tool_use":
+ block_id: str | None = _get(block, "id")
+ block_name: str | None = _get(block, "name")
+ block_input: dict[str, Any] | None = _get(block, "input")
token = TextGenerationVendor.build_tool_call_token(
- _get(block, "id"),
- _get(block, "name"),
- _get(block, "input"),
+ block_id,
+ block_name,
+ block_input,
)
parts.append(token.token)
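
Both the streaming and non-streaming Anthropic parsers read response blocks through a _get helper that treats dict keys and object attributes uniformly, which is why the same code handles SDK objects and plain-dict fixtures. The helper in isolation:

    from typing import Any

    def _get(value: object, attribute: str) -> Any:
        """Read `attribute` from a dict key or an object attribute."""
        if isinstance(value, dict):
            return value.get(attribute)
        return getattr(value, attribute, None)

    class Block:
        type = "text"
        text = "hi"

    assert _get(Block(), "text") == "hi"
    assert _get({"type": "text", "text": "hi"}, "text") == "hi"
    assert _get({}, "missing") is None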
diff --git a/src/avalan/model/nlp/text/vendor/bedrock.py b/src/avalan/model/nlp/text/vendor/bedrock.py
index 4ee84e8d..4e9f6991 100644
--- a/src/avalan/model/nlp/text/vendor/bedrock.py
+++ b/src/avalan/model/nlp/text/vendor/bedrock.py
@@ -24,7 +24,7 @@
from json import dumps
from typing import Any, AsyncIterator
-from aioboto3 import Session as Boto3Session
+from aioboto3 import Session as Boto3Session # type: ignore[import-not-found]
from diffusers import DiffusionPipeline
from transformers import PreTrainedModel
@@ -49,7 +49,7 @@ def _string(value: Any) -> str | None:
class BedrockStream(TextGenerationVendorStream):
- def __init__(self, events: AsyncIterator):
+ def __init__(self, events: AsyncIterator) -> None: # type: ignore[type-arg]
async def generator() -> AsyncIterator[Token | TokenDetail | str]:
tool_blocks: dict[int, dict[str, Any]] = {}
@@ -153,10 +153,12 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]:
if _get(event, "messageStop"):
break
- super().__init__(generator())
+ super().__init__(generator()) # type: ignore[arg-type]
async def __anext__(self) -> Token | TokenDetail | str:
- return await self._generator.__anext__()
+ result = await self._generator.__anext__()
+ assert isinstance(result, (Token, TokenDetail, str))
+ return result
class BedrockClient(TextGenerationVendor):
@@ -192,7 +194,7 @@ async def _client_instance(self) -> Any:
return self._client
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
model_id: str,
messages: list[Message],
@@ -278,7 +280,7 @@ def _response_text(self, response: dict[str, Any]) -> str:
parts.append(text_value)
return "".join(parts)
- def _template_messages(
+ def _template_messages( # type: ignore[override]
self,
messages: list[Message],
exclude_roles: list[TemplateMessageRole] | None = None,
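
BedrockStream.__anext__ now narrows the yielded chunk with a runtime isinstance assert instead of an ignore comment, which satisfies mypy and doubles as a sanity check. The same narrowing in miniature:

    from asyncio import run
    from typing import Any, AsyncIterator

    async def chunks() -> AsyncIterator[Any]:
        yield "hello"

    async def next_chunk(stream: AsyncIterator[Any]) -> str:
        chunk = await stream.__anext__()
        # isinstance narrows Any to str for mypy and fails fast at
        # runtime on unexpected chunk types.
        assert isinstance(chunk, str)
        return chunk

    print(run(next_chunk(chunks())))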
diff --git a/src/avalan/model/nlp/text/vendor/google.py b/src/avalan/model/nlp/text/vendor/google.py
index 385840be..93ab7d08 100644
--- a/src/avalan/model/nlp/text/vendor/google.py
+++ b/src/avalan/model/nlp/text/vendor/google.py
@@ -13,12 +13,13 @@
class GoogleStream(TextGenerationVendorStream):
- def __init__(self, stream: AsyncIterator[GenerateContentResponse]):
- super().__init__(stream)
+ def __init__(self, stream: AsyncIterator[GenerateContentResponse]) -> None:
+ super().__init__(stream) # type: ignore[arg-type]
async def __anext__(self) -> Token | TokenDetail | str:
chunk = await self._generator.__anext__()
- return chunk.text
+ text: str = chunk.text
+ return text
class GoogleClient(TextGenerationVendor):
@@ -28,7 +29,7 @@ def __init__(self, api_key: str):
self._client = Client(api_key=api_key)
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
model_id: str,
messages: list[Message],
@@ -51,7 +52,7 @@ async def __call__(
contents=contents,
)
- async def single_gen():
+ async def single_gen() -> AsyncIterator[Token | TokenDetail | str]:
yield response.text
return single_gen()
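
When use_async_generator is false, each vendor wraps its one-shot response text in a tiny async generator so callers keep a single `async for` code path. A runnable sketch of that wrapper:

    from asyncio import run
    from typing import AsyncIterator

    def as_stream(text: str) -> AsyncIterator[str]:
        async def single_gen() -> AsyncIterator[str]:
            yield text

        return single_gen()

    async def main() -> None:
        async for chunk in as_stream("full response"):
            print(chunk)

    run(main())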
diff --git a/src/avalan/model/nlp/text/vendor/huggingface.py b/src/avalan/model/nlp/text/vendor/huggingface.py
index 7be11b57..206ca780 100644
--- a/src/avalan/model/nlp/text/vendor/huggingface.py
+++ b/src/avalan/model/nlp/text/vendor/huggingface.py
@@ -9,7 +9,7 @@
from ....vendor import TextGenerationVendor, TextGenerationVendorStream
from . import TextGenerationVendorModel
-from typing import AsyncIterator
+from typing import Any, AsyncIterator
from diffusers import DiffusionPipeline
from huggingface_hub import AsyncInferenceClient
@@ -17,13 +17,13 @@
class HuggingfaceStream(TextGenerationVendorStream):
- def __init__(self, stream: AsyncIterator):
- super().__init__(stream.__aiter__())
+ def __init__(self, stream: AsyncIterator) -> None: # type: ignore[type-arg]
+ super().__init__(stream.__aiter__()) # type: ignore[arg-type]
async def __anext__(self) -> Token | TokenDetail | str:
chunk = await self._generator.__anext__()
delta = chunk.choices[0].delta
- text = getattr(delta, "content", None) or ""
+ text: str = getattr(delta, "content", None) or ""
return text
@@ -34,7 +34,7 @@ def __init__(self, api_key: str, base_url: str | None = None):
self._client = AsyncInferenceClient(token=api_key, base_url=base_url)
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
model_id: str,
messages: list[Message],
@@ -44,22 +44,32 @@ async def __call__(
use_async_generator: bool = True,
) -> AsyncIterator[Token | TokenDetail | str]:
settings = settings or GenerationSettings()
- template_messages = self._template_messages(messages)
+ template_messages: list[dict[Any, Any]] = [
+ {"role": m["role"], "content": m["content"]}
+ for m in self._template_messages(messages)
+ ]
+ stop: list[str] | None = None
+ if settings.stop_strings:
+ stop = (
+ [settings.stop_strings]
+ if isinstance(settings.stop_strings, str)
+ else settings.stop_strings
+ )
response = await self._client.chat_completion(
model=model_id,
messages=template_messages,
temperature=settings.temperature,
max_tokens=settings.max_new_tokens,
top_p=settings.top_p,
- stop=settings.stop_strings,
+ stop=stop,
stream=use_async_generator,
)
if use_async_generator:
- return HuggingfaceStream(response)
+ return HuggingfaceStream(response.__aiter__()) # type: ignore[arg-type, union-attr]
else:
- async def single_gen():
- yield response.choices[0].message.content or ""
+ async def single_gen() -> AsyncIterator[Token | TokenDetail | str]:
+ yield response.choices[0].message.content or "" # type: ignore[union-attr]
return single_gen()
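
The chat_completion call above needs stop as list[str] | None, while GenerationSettings.stop_strings may hold a bare string. The coercion, extracted into a standalone function:

    def normalize_stop(stop_strings: str | list[str] | None) -> list[str] | None:
        # Empty string, empty list, and None all mean "no stop strings".
        if not stop_strings:
            return None
        return [stop_strings] if isinstance(stop_strings, str) else stop_strings

    assert normalize_stop(None) is None
    assert normalize_stop("###") == ["###"]
    assert normalize_stop(["###", "END"]) == ["###", "END"]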
diff --git a/src/avalan/model/nlp/text/vendor/litellm.py b/src/avalan/model/nlp/text/vendor/litellm.py
index 801d60bb..b26ac86b 100644
--- a/src/avalan/model/nlp/text/vendor/litellm.py
+++ b/src/avalan/model/nlp/text/vendor/litellm.py
@@ -12,8 +12,8 @@
class LiteLLMStream(TextGenerationVendorStream):
- def __init__(self, stream: AsyncIterator):
- super().__init__(stream.__aiter__())
+ def __init__(self, stream: AsyncIterator) -> None: # type: ignore[type-arg]
+ super().__init__(stream.__aiter__()) # type: ignore[arg-type]
async def __anext__(self) -> Token | TokenDetail | str:
chunk = await self._generator.__anext__()
@@ -21,7 +21,7 @@ async def __anext__(self) -> Token | TokenDetail | str:
if isinstance(chunk, dict):
choice = chunk.get("choices", [{}])[0]
delta = choice.get("delta", {}) if isinstance(choice, dict) else {}
- text = delta.get("content", "")
+ text: str = delta.get("content", "")
else:
choice = chunk.choices[0]
delta = getattr(choice, "delta", None)
@@ -40,7 +40,7 @@ def __init__(
self._base_url = base_url or "http://localhost:4000"
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
model_id: str,
messages: list[Message],
@@ -62,7 +62,7 @@ async def __call__(
if use_async_generator:
return LiteLLMStream(result)
- async def single_gen():
+ async def single_gen() -> AsyncIterator[Token | TokenDetail | str]:
if isinstance(result, dict):
yield result["choices"][0]["message"]["content"]
else:
diff --git a/src/avalan/model/nlp/text/vendor/ollama.py b/src/avalan/model/nlp/text/vendor/ollama.py
index 5894b24a..7d62b24a 100644
--- a/src/avalan/model/nlp/text/vendor/ollama.py
+++ b/src/avalan/model/nlp/text/vendor/ollama.py
@@ -6,29 +6,30 @@
TokenDetail,
TransformerEngineSettings,
)
-from .....model.nlp.text.generation import TextGenerationModel
from .....tool.manager import ToolManager
from ....vendor import TextGenerationVendor, TextGenerationVendorStream
from . import TextGenerationVendorModel
+from contextlib import AsyncExitStack
from dataclasses import replace
from logging import Logger, getLogger
from typing import AsyncIterator
try:
- from ollama import AsyncClient
+ from ollama import AsyncClient # type: ignore[import-not-found]
except Exception: # pragma: no cover - ollama may not be installed
- AsyncClient = None
+ AsyncClient = None # type: ignore[misc, assignment]
class OllamaStream(TextGenerationVendorStream):
- def __init__(self, stream: AsyncIterator[dict]):
- super().__init__(stream)
+ def __init__(self, stream: AsyncIterator[dict]) -> None: # type: ignore[type-arg]
+ super().__init__(stream) # type: ignore[arg-type]
async def __anext__(self) -> Token | TokenDetail | str:
chunk = await self._generator.__anext__()
message = chunk.get("message", {}) if isinstance(chunk, dict) else {}
- return message.get("content", "")
+ content: str = message.get("content", "")
+ return content
class OllamaClient(TextGenerationVendor):
@@ -41,7 +42,7 @@ def __init__(self, base_url: str | None = None):
)
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
model_id: str,
messages: list[Message],
@@ -65,7 +66,7 @@ async def __call__(
stream=False,
)
- async def single_gen():
+ async def single_gen() -> AsyncIterator[Token | TokenDetail | str]:
yield response["message"]["content"]
return single_gen()
@@ -77,10 +78,12 @@ def __init__(
model_id: str,
settings: TransformerEngineSettings | None = None,
logger: Logger = getLogger(__name__),
+ *,
+ exit_stack: AsyncExitStack | None = None,
) -> None:
settings = settings or TransformerEngineSettings()
settings = replace(settings, enable_eval=False)
- TextGenerationModel.__init__(self, model_id, settings, logger)
+ super().__init__(model_id, settings, logger, exit_stack=exit_stack)
def _load_model(self):
return OllamaClient(base_url=self._settings.base_url)
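
The module guards its optional ollama import so importing the package without ollama installed still succeeds, deferring the failure to first use. A sketch of the guard; the `host` keyword on AsyncClient is an assumption about the ollama client, not shown in the hunk:

    try:
        from ollama import AsyncClient  # optional dependency
    except Exception:  # ollama may not be installed
        AsyncClient = None  # type: ignore[misc, assignment]

    def make_client(base_url: str | None = None) -> "AsyncClient":
        # The assert gates use, so the module-level import stays soft.
        assert AsyncClient is not None, "ollama is not installed"
        return AsyncClient(host=base_url) if base_url else AsyncClient()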
diff --git a/src/avalan/model/nlp/text/vendor/openai.py b/src/avalan/model/nlp/text/vendor/openai.py
index f4be59cc..7a129b17 100644
--- a/src/avalan/model/nlp/text/vendor/openai.py
+++ b/src/avalan/model/nlp/text/vendor/openai.py
@@ -7,6 +7,7 @@
ReasoningToken,
Token,
TokenDetail,
+ ToolCallError,
ToolCallResult,
ToolCallToken,
)
@@ -19,7 +20,7 @@
from . import TextGenerationVendorModel
from json import dumps
-from typing import AsyncIterator
+from typing import Any, AsyncIterator
from diffusers import DiffusionPipeline
from openai import AsyncOpenAI
@@ -27,12 +28,15 @@
class OpenAIStream(TextGenerationVendorStream):
- _TEXT_DELTA_EVENTS = {"response.text.delta", "response.output_text.delta"}
- _REASONING_DELTA_EVENTS = {"response.reasoning_text.delta"}
+ _TEXT_DELTA_EVENTS: set[str] = {
+ "response.text.delta",
+ "response.output_text.delta",
+ }
+ _REASONING_DELTA_EVENTS: set[str] = {"response.reasoning_text.delta"}
- def __init__(self, stream: AsyncIterator):
+ def __init__(self, stream: AsyncIterator[Any]) -> None:
async def generator() -> AsyncIterator[Token | TokenDetail | str]:
- tool_calls: dict[str, dict] = {}
+ tool_calls: dict[str, dict[str, Any]] = {}
async for event in stream:
etype = getattr(event, "type", None)
@@ -42,13 +46,14 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]:
if item:
custom = getattr(item, "custom_tool_call", None)
if custom:
- call_id = getattr(
+ call_id: str | None = getattr(
custom, "id", getattr(item, "id", None)
)
- tool_calls[call_id] = {
- "name": getattr(custom, "name", None),
- "args_fragments": [],
- }
+ if call_id is not None:
+ tool_calls[call_id] = {
+ "name": getattr(custom, "name", None),
+ "args_fragments": [],
+ }
continue
if (
@@ -79,11 +84,17 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]:
if etype == "response.output_item.done":
item = getattr(event, "item", None)
- call_id = getattr(item, "id", None) if item else None
- cached = tool_calls.pop(call_id, None)
+ done_call_id: str | None = (
+ getattr(item, "id", None) if item else None
+ )
+ cached = (
+ tool_calls.pop(done_call_id, None)
+ if done_call_id
+ else None
+ )
if cached:
yield TextGenerationVendor.build_tool_call_token(
- call_id,
+ done_call_id,
cached.get("name"),
"".join(cached["args_fragments"]) or None,
)
@@ -104,22 +115,22 @@ async def generator() -> AsyncIterator[Token | TokenDetail | str]:
continue
- super().__init__(generator())
+ super().__init__(generator()) # type: ignore[arg-type]
async def __anext__(self) -> Token | TokenDetail | str:
- return await self._generator.__anext__()
+ return await self._generator.__anext__() # type: ignore[no-any-return]
class OpenAIClient(TextGenerationVendor):
_client: AsyncOpenAI
- def __init__(self, api_key: str, base_url: str | None):
+ def __init__(self, api_key: str | None, base_url: str | None) -> None:
self._client = AsyncOpenAI(base_url=base_url, api_key=api_key)
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
- model_id: str,
+ model_id: str | None,
messages: list[Message],
settings: GenerationSettings | None = None,
*,
@@ -128,7 +139,7 @@ async def __call__(
use_async_generator: bool = True,
) -> AsyncIterator[Token | TokenDetail | str] | TextGenerationSingleStream:
template_messages = self._template_messages(messages)
- kwargs: dict = {
+ kwargs: dict[str, Any] = {
"extra_headers": {
"X-Title": "Avalan",
"HTTP-Referer": "https://github.com/avalan-ai/avalan",
@@ -161,21 +172,27 @@ async def __call__(
content = OpenAIClient._non_stream_response_content(client_stream)
return TextGenerationSingleStream(content)
- def _template_messages(
+ def _template_messages( # type: ignore[override]
self,
messages: list[Message],
exclude_roles: list[TemplateMessageRole] | None = None,
- ) -> list[TemplateMessage]:
- tool_results = [
- message.tool_call_result or message.tool_call_error
+ ) -> list[TemplateMessage | dict[str, Any]]:
+ tool_results: list[ToolCallResult | ToolCallError] = [
+ message.tool_call_result or message.tool_call_error # type: ignore[misc]
for message in messages
if message.role == MessageRole.TOOL
and (message.tool_call_result or message.tool_call_error)
]
- do_exclude_roles = [*(exclude_roles or []), "tool"]
- messages = super()._template_messages(messages, do_exclude_roles)
+ do_exclude_roles: list[TemplateMessageRole] = [
+ *(exclude_roles or []),
+ "tool",
+ ]
+ template_messages: list[TemplateMessage | dict[str, Any]] = list(
+ super()._template_messages(messages, do_exclude_roles)
+ )
for result in tool_results:
- call_message = {
+ assert result.call is not None
+ call_message: dict[str, Any] = {
"type": "function_call",
"name": TextGenerationVendor.encode_tool_name(
result.call.name
@@ -183,9 +200,9 @@ def _template_messages(
"call_id": result.call.id,
"arguments": dumps(result.call.arguments),
}
- messages.append(call_message)
+ template_messages.append(call_message)
- result_message = {
+ result_message: dict[str, Any] = {
"type": "function_call_output",
"call_id": result.call.id,
"output": to_json(
@@ -194,39 +211,38 @@ def _template_messages(
else {"error": result.message}
),
}
- messages.append(result_message)
- return messages
+ template_messages.append(result_message)
+ return template_messages
@staticmethod
- def _tool_schemas(tool: ToolManager) -> list[dict] | None:
+ def _tool_schemas(tool: ToolManager) -> list[dict[str, Any]] | None:
schemas = tool.json_schemas()
- return (
- [
- {
- "type": t["type"],
- **t["function"],
- **{
- "name": TextGenerationVendor.encode_tool_name(
- t["function"]["name"]
- )
- },
- }
- for t in tool.json_schemas()
- if t["type"] == "function"
- ]
- if schemas
- else None
- )
+ if not schemas:
+ return None
+ return [
+ {
+ "type": t["type"],
+ **t["function"],
+ **{
+ "name": TextGenerationVendor.encode_tool_name(
+ t["function"]["name"]
+ )
+ },
+ }
+ for t in schemas
+ if t["type"] == "function"
+ ]
@staticmethod
def _non_stream_response_content(response: object) -> str:
- def _get(value: object, attribute: str) -> object | None:
+ def _get(value: object, attribute: str) -> Any:
if isinstance(value, dict):
return value.get(attribute)
return getattr(value, attribute, None)
parts: list[str] = []
- for item in _get(response, "output") or []:
+ output_items = _get(response, "output") or []
+ for item in output_items:
item_type = _get(item, "type")
contents = _get(item, "content") or []
@@ -240,10 +256,15 @@ def _get(value: object, attribute: str) -> object | None:
if item_type in {"tool_call", "function_call"}:
call = _get(item, "call") or item
function = _get(call, "function") or call
+ call_id: str | None = _get(call, "id")
+ func_name: str | None = _get(function, "name")
+ func_args: str | dict[str, Any] | None = _get(
+ function, "arguments"
+ )
token = TextGenerationVendor.build_tool_call_token(
- _get(call, "id"),
- _get(function, "name"),
- _get(function, "arguments"),
+ call_id,
+ func_name,
+ func_args,
)
parts.append(token.token)
@@ -255,9 +276,9 @@ class OpenAINonStreamingResponse(TextGenerationResponse):
def __init__(
self,
- *args,
+ *args: Any,
static_response_text: str | None = None,
- **kwargs,
+ **kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._static_response_text = static_response_text
@@ -289,7 +310,7 @@ def _load_model(
)
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
system_prompt: str | None = None,
@@ -300,7 +321,9 @@ async def __call__(
) -> TextGenerationResponse:
generation_settings = settings or GenerationSettings()
messages = self._messages(input, system_prompt, developer_prompt, tool)
- streamer = await self._model(
+ assert self._model is not None
+ assert self._model_id is not None
+ streamer = await self._model( # type: ignore[operator]
self._model_id,
messages,
generation_settings,
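
OpenAIStream keys streamed tool calls by id, buffers argument fragments per call, and joins them when the output item completes. The bookkeeping reduces to a dict of fragment lists; the event shapes below are illustrative, not the OpenAI event schema:

    from typing import Any

    def collect_tool_calls(events: list[dict[str, Any]]) -> dict[str, str]:
        """Join streamed argument fragments per tool-call id."""
        fragments: dict[str, list[str]] = {}
        done: dict[str, str] = {}
        for event in events:
            call_id = event.get("id")
            if call_id is None:
                continue  # ids are required keys, as in the patch
            if event["type"] == "start":
                fragments[call_id] = []
            elif event["type"] == "delta":
                fragments[call_id].append(event["args"])
            elif event["type"] == "done":
                done[call_id] = "".join(fragments.pop(call_id, []))
        return done

    events = [
        {"type": "start", "id": "c1"},
        {"type": "delta", "id": "c1", "args": '{"query": '},
        {"type": "delta", "id": "c1", "args": '"weather"}'},
        {"type": "done", "id": "c1"},
    ]
    assert collect_tool_calls(events) == {"c1": '{"query": "weather"}'}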
diff --git a/src/avalan/model/nlp/text/vendor/openrouter.py b/src/avalan/model/nlp/text/vendor/openrouter.py
index 750c2b2d..4246355a 100644
--- a/src/avalan/model/nlp/text/vendor/openrouter.py
+++ b/src/avalan/model/nlp/text/vendor/openrouter.py
@@ -6,13 +6,15 @@
class OpenRouterClient(OpenAIClient):
- def __init__(self, api_key: str, base_url: str | None = None):
+ def __init__(
+ self, api_key: str | None, base_url: str | None = None
+ ) -> None:
super().__init__(
api_key=api_key,
base_url=base_url or "https://openrouter.ai/api/v1",
)
# Optional headers recommended by OpenRouter
- self._client.headers.update(
+ self._client.headers.update( # type: ignore[union-attr,attr-defined]
{
"HTTP-Referer": "https://github.com/avalan-ai/avalan",
"X-Title": "avalan",
diff --git a/src/avalan/model/nlp/text/vllm.py b/src/avalan/model/nlp/text/vllm.py
index 508a8d4c..00c95edd 100644
--- a/src/avalan/model/nlp/text/vllm.py
+++ b/src/avalan/model/nlp/text/vllm.py
@@ -5,23 +5,28 @@
TransformerEngineSettings,
)
from ....model.nlp.text.generation import TextGenerationModel
-from ....model.vendor import TextGenerationVendorStream
+from ....model.vendor import TextGenerationVendor, TextGenerationVendorStream
from ....tool.manager import ToolManager
from asyncio import to_thread
from dataclasses import asdict, replace
from logging import Logger, getLogger
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator
+
+from diffusers import DiffusionPipeline
+from transformers import PreTrainedModel
try:
from vllm import LLM, SamplingParams
except Exception: # pragma: no cover - vllm may not be installed
- LLM = None
- SamplingParams = None
+ LLM = None # type: ignore[assignment,misc]
+ SamplingParams = None # type: ignore[assignment,misc]
class VllmStream(TextGenerationVendorStream):
- def __init__(self, generator):
+ _iterator: Any
+
+ def __init__(self, generator: Any) -> None:
super().__init__(generator)
self._iterator = generator
@@ -48,17 +53,17 @@ def __init__(
def supports_sample_generation(self) -> bool:
return False
- def _load_model(self):
+ def _load_model(
+ self,
+ ) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
assert LLM, "vLLM is not available"
- return LLM(
+ return LLM( # type: ignore[return-value,no-any-return]
model=self._model_id,
tokenizer=self._settings.tokenizer_name_or_path or self._model_id,
trust_remote_code=self._settings.trust_remote_code,
)
- def _build_sampling_params(
- self, settings: GenerationSettings
- ) -> SamplingParams:
+ def _build_sampling_params(self, settings: GenerationSettings) -> Any:
assert SamplingParams, "vLLM is not available"
return SamplingParams(
temperature=settings.temperature,
@@ -85,32 +90,36 @@ def _prompt(
tool=tool,
chat_template_settings=chat_template_settings,
)
+ assert self._tokenizer is not None
+ assert isinstance(inputs, dict), "inputs must be a dict"
return self._tokenizer.decode(
inputs["input_ids"][0], skip_special_tokens=False
)
- async def _stream_generator(
+ async def _stream_generator( # type: ignore[override]
self,
prompt: str,
settings: GenerationSettings,
) -> AsyncGenerator[str, None]:
params = self._build_sampling_params(settings)
- iterator = self._model.generate([prompt], params, stream=True)
+ assert self._model is not None
+ iterator = self._model.generate([prompt], params, stream=True) # type: ignore[union-attr,operator]
stream = VllmStream(iter(iterator))
async for chunk in stream:
yield chunk
- def _string_output(
+ def _string_output( # type: ignore[override]
self,
prompt: str,
settings: GenerationSettings,
) -> str:
params = self._build_sampling_params(settings)
- results = list(self._model.generate([prompt], params))
+ assert self._model is not None
+ results = list(self._model.generate([prompt], params)) # type: ignore[union-attr,operator]
return results[0].outputs[0].text if results else ""
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
system_prompt: str | None = None,
@@ -129,5 +138,5 @@ async def __call__(
)
generation_settings = replace(settings, do_sample=False)
if settings.use_async_generator:
- return await self._stream_generator(prompt, generation_settings)
+ return await self._stream_generator(prompt, generation_settings) # type: ignore[misc,no-any-return]
return self._string_output(prompt, generation_settings)
diff --git a/src/avalan/model/nlp/token.py b/src/avalan/model/nlp/token.py
index 632b0bb2..4792981b 100644
--- a/src/avalan/model/nlp/token.py
+++ b/src/avalan/model/nlp/token.py
@@ -4,7 +4,7 @@
from ...model.nlp import BaseNLPModel
from ...model.vendor import TextGenerationVendor
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import argmax, inference_mode
@@ -26,23 +26,26 @@ def supports_token_streaming(self) -> bool:
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
- model = AutoModelForTokenClassification.from_pretrained(
- self._model_id,
- cache_dir=self._settings.cache_dir,
- subfolder=self._settings.subfolder or "",
- attn_implementation=self._settings.attention,
- trust_remote_code=self._settings.trust_remote_code,
- torch_dtype=Engine.weight(self._settings.weight_type),
- state_dict=self._settings.state_dict,
- local_files_only=self._settings.local_files_only,
- token=self._settings.access_token,
- device_map=self._device,
- tp_plan=Engine._get_tp_plan(self._settings.parallel),
- distributed_config=Engine._get_distributed_config(
- self._settings.distributed_config
- ),
+ assert self._model_id is not None, "Model ID must be set"
+ model: PreTrainedModel = (
+ AutoModelForTokenClassification.from_pretrained(
+ self._model_id,
+ cache_dir=self._settings.cache_dir,
+ subfolder=self._settings.subfolder or "",
+ attn_implementation=self._settings.attention,
+ trust_remote_code=self._settings.trust_remote_code,
+ torch_dtype=Engine.weight(self._settings.weight_type),
+ state_dict=self._settings.state_dict,
+ local_files_only=self._settings.local_files_only,
+ token=self._settings.access_token,
+ device_map=self._device,
+ tp_plan=Engine._get_tp_plan(self._settings.parallel),
+ distributed_config=Engine._get_distributed_config(
+ self._settings.distributed_config
+ ),
+ )
)
- labels = (
+ labels: dict[int, str] | None = (
getattr(model.config, "id2label", None)
if hasattr(model, "config")
else None
@@ -56,7 +59,7 @@ def _load_model(
)
return model
- def _tokenize_input(
+ def _tokenize_input( # type: ignore[override]
self,
input: Input,
system_prompt: str | None,
@@ -70,20 +73,25 @@ def _tokenize_input(
+ f"{self._model_id} does not support chat "
+ "templates"
)
+ assert self._tokenizer is not None, "Tokenizer must be loaded"
+ assert self._model is not None, "Model must be loaded"
_l = self._log
_l(f"Tokenizing input {input}")
- inputs = self._tokenizer(input, return_tensors=tensor_format)
- inputs = inputs.to(self._model.device)
+ inputs: BatchEncoding = self._tokenizer(
+ input, return_tensors=tensor_format
+ )
+ inputs = inputs.to(self._model.device) # type: ignore[union-attr]
return inputs
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
*,
labeled_only: bool = False,
system_prompt: str | None = None,
developer_prompt: str | None = None,
+ **kwargs: Any,
) -> dict[str, str]:
assert self._tokenizer, (
f"Model {self._model} can't be executed "
@@ -100,10 +108,10 @@ async def __call__(
context=None,
)
with inference_mode():
- outputs = self._model(**inputs)
+ outputs = self._model(**inputs) # type: ignore[operator]
# logits shape (1, seq_len, num_labels)
input_ids = inputs["input_ids"][0]
- label_ids = argmax(outputs.logits, dim=2)[0]
+ label_ids = argmax(outputs.logits, dim=2)[0] # type: ignore[union-attr]
if labeled_only and self._default_label_id is not None:
mask = label_ids != self._default_label_id
@@ -112,9 +120,10 @@ async def __call__(
assert input_ids.numel() == label_ids.numel()
tokens = self._tokenizer.convert_ids_to_tokens(input_ids)
- labels = [
- self._model.config.id2label[label_id.item()]
- for label_id in label_ids
+ id2label = self._model.config.id2label # type: ignore[union-attr]
+ assert id2label is not None, "Model config must have id2label"
+ labels_list: list[str] = [
+ id2label[int(label_id.item())] for label_id in label_ids
]
- tokens_to_labels = dict(zip(tokens, labels))
+ tokens_to_labels = dict(zip(tokens, labels_list))
return tokens_to_labels
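
The classification path takes the per-token argmax over logits and maps each label id through the model config's id2label table. The mapping itself is plain tensor work; a sketch with a fake logits tensor standing in for a real model:

    from torch import argmax, randn

    id2label = {0: "O", 1: "B-PER", 2: "I-PER"}
    logits = randn(1, 4, 3)  # (batch, seq_len, num_labels)
    label_ids = argmax(logits, dim=2)[0]
    labels = [id2label[int(i.item())] for i in label_ids]
    print(labels)  # e.g. ["O", "B-PER", "O", "I-PER"]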
diff --git a/src/avalan/model/response/parsers/tool.py b/src/avalan/model/response/parsers/tool.py
index aa821a04..6b4ecdd0 100644
--- a/src/avalan/model/response/parsers/tool.py
+++ b/src/avalan/model/response/parsers/tool.py
@@ -90,7 +90,9 @@ async def push(self, token_str: str) -> Iterable[Any]:
return result
event = Event(
- type=EventType.TOOL_PROCESS, payload=calls, started=perf_counter()
+ type=EventType.TOOL_PROCESS,
+ payload={"calls": calls},
+ started=perf_counter(),
)
self._buffer = StringIO()
diff --git a/src/avalan/model/response/text.py b/src/avalan/model/response/text.py
index 96afb0a3..3b7b6149 100644
--- a/src/avalan/model/response/text.py
+++ b/src/avalan/model/response/text.py
@@ -20,8 +20,12 @@
AsyncIterator,
Awaitable,
Callable,
+ Self,
+ TypeVar,
)
+T = TypeVar("T")
+
OutputGenerator = AsyncGenerator[Token | TokenDetail | str, None]
OutputFunction = Callable[..., OutputGenerator | str]
@@ -92,14 +96,17 @@ def _ensure_non_stream_prefetched(self) -> None:
if self._buffer.tell():
return
- result = self._output_fn(*self._args, **self._kwargs)
- if isinstance(result, TextGenerationSingleStream):
- result = result.content
-
- if isinstance(result, (Token, TokenDetail)):
- text = result.token
+ fn_result = self._output_fn(*self._args, **self._kwargs)
+ if isinstance(fn_result, TextGenerationSingleStream):
+ stream_content = fn_result.content
+ if isinstance(stream_content, (Token, TokenDetail)):
+ text = stream_content.token
+ else:
+ text = str(stream_content)
+ elif isinstance(fn_result, (Token, TokenDetail)):
+ text = fn_result.token
else:
- text = str(result)
+ text = str(fn_result)
self._prefetched_text = text
self._buffer = StringIO()
@@ -125,7 +132,11 @@ def can_think(self) -> bool:
@property
def is_thinking(self) -> bool:
- return self.can_think and self._reasoning_parser.is_thinking
+ return (
+ self.can_think
+ and self._reasoning_parser is not None
+ and self._reasoning_parser.is_thinking
+ )
def set_thinking(self, thinking: bool) -> None:
if self._reasoning_parser:
@@ -140,9 +151,16 @@ async def _trigger_consumed(self) -> None:
if iscoroutine(result):
await result
- def __aiter__(self):
+ def __aiter__(self) -> Self:
+ """Return iterator for async iteration over tokens.
+
+ Returns:
+ Self for async iteration.
+ """
# Create a fresh async generator each time we start iterating
- self._output = self._output_fn(*self._args, **self._kwargs)
+ fn_result = self._output_fn(*self._args, **self._kwargs)
+ if not isinstance(fn_result, str):
+ self._output = fn_result
return self
async def __anext__(self) -> Token | TokenDetail | str:
@@ -154,9 +172,10 @@ async def __anext__(self) -> Token | TokenDetail | str:
return self._parser_queue.get()
try:
+ assert self._output is not None
token = await self._output.__anext__()
except StopAsyncIteration:
- if self._reasoning_parser:
+ if self._reasoning_parser and self._parser_queue:
for it in await self._reasoning_parser.flush():
self._parser_queue.put(it)
if not self._parser_queue.empty():
@@ -180,7 +199,11 @@ async def __anext__(self) -> Token | TokenDetail | str:
await self._trigger_consumed()
raise StopAsyncIteration
+ assert self._parser_queue is not None
for it in items:
+ parsed: (
+ Token | TokenDetail | ReasoningToken | ToolCallToken | str
+ )
if isinstance(it, ReasoningToken):
token_id = (
token.id
@@ -188,7 +211,9 @@ async def __anext__(self) -> Token | TokenDetail | str:
else it.id
)
parsed = ReasoningToken(
- token=it.token, id=token_id, probability=it.probability
+ token=it.token,
+ id=token_id if token_id is not None else -1,
+ probability=it.probability,
)
elif isinstance(token, ToolCallToken):
parsed = ToolCallToken(
@@ -227,12 +252,14 @@ async def to_str(self) -> str:
await self._trigger_consumed()
return self._prefetched_text
- # Ensure buffer is filled, wether we were already iterating or not
+ # Ensure buffer is filled, whether we were already iterating or not
if not self._output:
self.__aiter__()
+ assert self._output is not None
async for token in self._output:
- self._buffer.write(token)
+ token_str = token if isinstance(token, str) else token.token
+ self._buffer.write(token_str)
self._output_token_count += 1
await self._trigger_consumed()
@@ -252,7 +279,15 @@ async def to_json(self) -> str:
continue
raise InvalidJsonResponseException(text)
- async def to(self, entity_class: type) -> any:
- json = await self.to_json()
- data = loads(json)
+ async def to(self, entity_class: type[T]) -> T:
+ """Convert JSON response to entity class instance.
+
+ Args:
+ entity_class: The class to instantiate with JSON data.
+
+ Returns:
+ Instance of entity_class populated with JSON data.
+ """
+ json_str = await self.to_json()
+ data = loads(json_str)
return entity_class(**data)
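
The new to(entity_class: type[T]) -> T signature gives callers a typed entity back from the JSON response instead of an untyped value. A minimal sketch of the same generic construction, using a hypothetical dataclass target:

    from dataclasses import dataclass
    from json import loads
    from typing import TypeVar

    T = TypeVar("T")

    def to_entity(json_str: str, entity_class: type[T]) -> T:
        # mypy infers the concrete return type from entity_class.
        data = loads(json_str)
        return entity_class(**data)

    @dataclass
    class Answer:
        text: str
        score: float

    answer = to_entity('{"text": "42", "score": 0.9}', Answer)
    assert answer.score == 0.9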
diff --git a/src/avalan/model/transformer.py b/src/avalan/model/transformer.py
index bdf3d26e..32b8ac9a 100644
--- a/src/avalan/model/transformer.py
+++ b/src/avalan/model/transformer.py
@@ -9,7 +9,7 @@
from logging import Logger, getLogger
from typing import Literal
-from tokenizers import AddedToken
+from tokenizers import AddedToken # type: ignore[import-untyped]
from torch import Tensor
from transformers import (
AutoTokenizer,
@@ -20,6 +20,10 @@
class TransformerModel(Engine, ABC):
+ """Base class for transformer-based models."""
+
+ _settings: TransformerEngineSettings
+
@property
def supports_sample_generation(self) -> bool:
return False
@@ -71,12 +75,12 @@ def tokenize(
not hasattr(self, "_loaded_tokenizer")
or not self._loaded_tokenizer
):
- self.load(
- load_model=False,
+ self._load(
load_tokenizer=True,
tokenizer_name_or_path=tokenizer_name_or_path,
)
+ assert self._tokenizer is not None
_l(f'Tokenizing text "{text}"')
token_ids = self._tokenizer.encode(text, add_special_tokens=True)
_l(f'Tokenized text "{text}" into {len(token_ids)} tokens')
@@ -89,7 +93,7 @@ def tokenize(
),
probability=1,
)
- for i, token_id in enumerate(token_ids)
+ for token_id in token_ids
]
def input_token_count(
@@ -98,7 +102,6 @@ def input_token_count(
system_prompt: str | None = None,
developer_prompt: str | None = None,
) -> int:
- _l = self._log
assert self._tokenizer, (
f"Model {self._model} can't be executed "
+ "without a tokenizer loaded first"
@@ -109,20 +112,27 @@ def input_token_count(
developer_prompt=developer_prompt,
context=None,
)
- return (
- len(inputs["input_ids"][0])
- if inputs and "input_ids" in inputs
- else 0
- )
+ if isinstance(inputs, Tensor):
+ return inputs.shape[-1] if inputs.dim() > 0 else 0
+ if isinstance(inputs, dict) and "input_ids" in inputs:
+ input_ids = inputs["input_ids"]
+ if isinstance(input_ids, Tensor):
+ return input_ids.shape[-1] if input_ids.dim() > 0 else 0
+ if isinstance(input_ids, list) and len(input_ids) > 0:
+ return len(input_ids[0])
+ return 0
def _load_tokenizer(
self, tokenizer_name_or_path: str | None, use_fast: bool
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
- return AutoTokenizer.from_pretrained(
- tokenizer_name_or_path or self._model_id,
- use_fast=use_fast,
- subfolder=self._settings.tokenizer_subfolder or "",
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = (
+ AutoTokenizer.from_pretrained(
+ tokenizer_name_or_path or self._model_id,
+ use_fast=use_fast,
+ subfolder=self._settings.tokenizer_subfolder or "",
+ )
)
+ return tokenizer
def _load_tokenizer_with_tokens(
self, tokenizer_name_or_path: str | None, use_fast: bool = True
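
input_token_count now reads the sequence length from either a bare Tensor or a BatchEncoding-style dict rather than assuming subscriptability. The shape handling, extracted into a standalone function (assuming torch is installed):

    from typing import Any

    from torch import Tensor, tensor

    def token_count(inputs: Tensor | dict[str, Any]) -> int:
        if isinstance(inputs, Tensor):
            return inputs.shape[-1] if inputs.dim() > 0 else 0
        input_ids = inputs.get("input_ids")
        if isinstance(input_ids, Tensor):
            return input_ids.shape[-1] if input_ids.dim() > 0 else 0
        if isinstance(input_ids, list) and input_ids:
            return len(input_ids[0])
        return 0

    assert token_count(tensor([[1, 2, 3]])) == 3
    assert token_count({"input_ids": [[1, 2, 3, 4]]}) == 4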
diff --git a/src/avalan/model/vendor.py b/src/avalan/model/vendor.py
index 874d01ec..65589164 100644
--- a/src/avalan/model/vendor.py
+++ b/src/avalan/model/vendor.py
@@ -8,15 +8,22 @@
ToolCallToken,
)
from ..tool.manager import ToolManager
-from .message import TemplateMessage, TemplateMessageRole
+from .message import (
+ TemplateMessage,
+ TemplateMessageContent,
+ TemplateMessageRole,
+)
from .stream import TextGenerationStream
from abc import ABC
from json import JSONDecodeError, dumps, loads
-from typing import AsyncGenerator
+from typing import AsyncGenerator, cast
+from uuid import uuid4
class TextGenerationVendor(ABC):
+ """Base class for text generation vendor implementations."""
+
async def __call__(
self,
model_id: str,
@@ -29,28 +36,26 @@ async def __call__(
raise NotImplementedError()
def _system_prompt(self, messages: list[Message]) -> str | None:
- return next(
- (
- message.content
- for message in messages
- if message.role == "system"
- ),
- None,
- )
+ for message in messages:
+ if message.role == "system" and isinstance(message.content, str):
+ return message.content
+ return None
def _template_messages(
self,
messages: list[Message],
exclude_roles: list[TemplateMessageRole] | None = None,
) -> list[TemplateMessage]:
- def _block(c: MessageContent) -> dict:
+ def _block(c: MessageContent) -> TemplateMessageContent:
if isinstance(c, MessageContentImage):
- return {"type": "image_url", "image_url": c.image_url}
- return {"type": "text", "text": c.text}
+ return TemplateMessageContent(
+ type="image_url", image_url=c.image_url
+ )
+ return TemplateMessageContent(type="text", text=c.text)
def _wrap(
content: str | MessageContent | list[MessageContent],
- ) -> str | list[dict]:
+ ) -> str | list[TemplateMessageContent]:
if isinstance(content, str):
return content
@@ -70,7 +75,9 @@ def _wrap(
if exclude_roles and msg.role in exclude_roles:
continue
- out.append({"role": str(msg.role), "content": _wrap(msg.content)})
+ content = msg.content if msg.content is not None else ""
+ role = cast(TemplateMessageRole, msg.role)
+ out.append({"role": role, "content": _wrap(content)})
return out
@@ -96,7 +103,7 @@ def build_tool_call_token(
args = {}
else:
args = arguments or {}
- call = ToolCall(id=call_id, name=name, arguments=args)
+ call = ToolCall(id=call_id or str(uuid4()), name=name, arguments=args)
token_json = dumps({"name": name, "arguments": args})
return ToolCallToken(
token=f"{token_json}", call=call
diff --git a/src/avalan/model/vision/classification.py b/src/avalan/model/vision/classification.py
index 2ac506a2..13c0fdb6 100644
--- a/src/avalan/model/vision/classification.py
+++ b/src/avalan/model/vision/classification.py
@@ -4,7 +4,7 @@
from ...model.vendor import TextGenerationVendor
from ...model.vision import BaseVisionModel
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from PIL import Image
@@ -18,36 +18,48 @@
# model predicts one of the 1000 ImageNet classes
class ImageClassificationModel(BaseVisionModel):
+ _processor: AutoImageProcessor
+
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None, "Model ID is required"
self._processor = AutoImageProcessor.from_pretrained(
self._model_id,
# default behavior in transformers v4.48
use_fast=True,
)
- model = AutoModelForImageClassification.from_pretrained(
- self._model_id,
- device_map=self._device,
- tp_plan=Engine._get_tp_plan(self._settings.parallel),
- distributed_config=Engine._get_distributed_config(
- self._settings.distributed_config
- ),
+ model: PreTrainedModel = (
+ AutoModelForImageClassification.from_pretrained(
+ self._model_id,
+ device_map=self._device,
+ tp_plan=Engine._get_tp_plan(self._settings.parallel),
+ distributed_config=Engine._get_distributed_config(
+ self._settings.distributed_config
+ ),
+ )
)
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
image_source: str | Image.Image,
tensor_format: Literal["pt"] = "pt",
) -> ImageEntity:
+ assert self._model is not None, "Model must be loaded"
+ assert isinstance(
+ self._model, PreTrainedModel
+ ), "Model must be PreTrainedModel"
image = BaseVisionModel._get_image(image_source)
- inputs = self._processor(image, return_tensors=tensor_format)
+ inputs: Any = self._processor( # type: ignore[operator]
+ image, return_tensors=tensor_format
+ )
inputs.to(self._device)
with inference_mode():
- logits = self._model(**inputs).logits
+ logits = self._model(**inputs).logits # type: ignore[operator]
label_index = logits.argmax(dim=1).item()
- return ImageEntity(label=self._model.config.id2label[label_index])
+ id2label: dict[int, str] = self._model.config.id2label # type: ignore[assignment]
+ return ImageEntity(label=id2label[label_index])
diff --git a/src/avalan/model/vision/decoder.py b/src/avalan/model/vision/decoder.py
index 5d7d8d56..674879c9 100644
--- a/src/avalan/model/vision/decoder.py
+++ b/src/avalan/model/vision/decoder.py
@@ -4,7 +4,7 @@
from ...model.vision import BaseVisionModel
from ...model.vision.text import ImageToTextModel
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from PIL import Image
@@ -22,11 +22,12 @@ class VisionEncoderDecoderModel(ImageToTextModel):
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None, "Model ID is required"
self._processor = AutoImageProcessor.from_pretrained(
self._model_id,
use_fast=True,
)
- model = VisionEncoderDecoderModelImpl.from_pretrained(
+ model: PreTrainedModel = VisionEncoderDecoderModelImpl.from_pretrained(
self._model_id,
device_map=self._device,
tp_plan=Engine._get_tp_plan(self._settings.parallel),
@@ -37,7 +38,7 @@ def _load_model(
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
image_source: str | Image.Image,
prompt: str | None,
@@ -55,8 +56,11 @@ async def __call__(
tensor_format=tensor_format,
)
+ assert self._model is not None, "Model must be loaded"
+ assert self._tokenizer is not None, "Tokenizer must be loaded"
+
image = BaseVisionModel._get_image(image_source)
- pixel_values = self._processor(
+ pixel_values: Any = self._processor( # type: ignore[operator]
image, return_tensors=tensor_format
).pixel_values.to(self._device)
decoder_input_ids = self._tokenizer(
@@ -64,10 +68,11 @@ async def __call__(
).input_ids.to(self._device)
with inference_mode():
- outputs = self._model.generate(
+ # self._model is VisionEncoderDecoderModel at runtime
+ outputs = self._model.generate( # type: ignore[union-attr, operator]
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
- max_length=self._model.decoder.config.max_position_embeddings,
+ max_length=self._model.decoder.config.max_position_embeddings, # type: ignore[union-attr]
early_stopping=early_stopping,
pad_token_id=self._tokenizer.pad_token_id,
eos_token_id=self._tokenizer.eos_token_id,
@@ -77,7 +82,8 @@ async def __call__(
return_dict_in_generate=True,
)
- output = self._tokenizer.batch_decode(
- outputs.sequences, skip_special_tokens=skip_special_tokens
+ output: str = self._tokenizer.batch_decode(
+ outputs.sequences, # type: ignore[union-attr]
+ skip_special_tokens=skip_special_tokens,
)[0]
return output
diff --git a/src/avalan/model/vision/detection.py b/src/avalan/model/vision/detection.py
index d0b110e0..00a04ece 100644
--- a/src/avalan/model/vision/detection.py
+++ b/src/avalan/model/vision/detection.py
@@ -6,7 +6,7 @@
from ...model.vision.classification import ImageClassificationModel
from logging import Logger, getLogger
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from PIL import Image
@@ -32,13 +32,14 @@ def __init__(
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None, "Model ID is required"
self._processor = AutoImageProcessor.from_pretrained(
self._model_id,
revision=self._revision,
# default behavior in transformers v4.48
use_fast=True,
)
- model = AutoModelForObjectDetection.from_pretrained(
+ model: PreTrainedModel = AutoModelForObjectDetection.from_pretrained(
self._model_id,
revision=self._revision,
device_map=self._device,
@@ -50,32 +51,38 @@ def _load_model(
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
image_source: str | Image.Image,
threshold: float | None = 0.3,
tensor_format: Literal["pt"] = "pt",
) -> list[ImageEntity]:
+ assert self._model is not None, "Model must be loaded"
+ assert isinstance(
+ self._model, PreTrainedModel
+ ), "Model must be PreTrainedModel"
image = BaseVisionModel._get_image(image_source)
- inputs = self._processor(images=image, return_tensors=tensor_format)
+ processor: Any = self._processor
+ inputs: Any = processor(images=image, return_tensors=tensor_format)
inputs.to(self._device)
with inference_mode():
- outputs = self._model(**inputs)
+ outputs = self._model(**inputs) # type: ignore[operator]
target_sizes = tensor([image.size[::-1]])
- results = self._processor.post_process_object_detection(
+ results = processor.post_process_object_detection(
outputs, target_sizes=target_sizes, threshold=threshold
)[0]
+ id2label: dict[int, str] = self._model.config.id2label # type: ignore[assignment]
entities = []
for score, label, box in zip(
results["scores"], results["labels"], results["boxes"]
):
- box = [round(i, 2) for i in box.tolist()]
+ box_coords = [round(i, 2) for i in box.tolist()]
entities.append(
ImageEntity(
- label=self._model.config.id2label[label.item()],
+ label=id2label[label.item()],
score=score.item(),
- box=box,
+ box=box_coords,
)
)
return entities
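
Note: `_load_model` is declared to return `PreTrainedModel | TextGenerationVendor | DiffusionPipeline`, so `self._model` is union-typed; the `assert isinstance(...)` added above narrows it before `.config` is touched. The same narrowing in isolation (function name hypothetical):

from transformers import PreTrainedModel

def label_for(model: object, idx: int) -> str:
    # isinstance narrows the union for mypy, so attribute access on
    # model.config is checked rather than rejected.
    assert isinstance(model, PreTrainedModel), "Model must be PreTrainedModel"
    return model.config.id2label[idx]
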
diff --git a/src/avalan/model/vision/diffusion/animation.py b/src/avalan/model/vision/diffusion/animation.py
index 6cff9652..fd15fc97 100644
--- a/src/avalan/model/vision/diffusion/animation.py
+++ b/src/avalan/model/vision/diffusion/animation.py
@@ -6,6 +6,7 @@
from dataclasses import replace
from logging import Logger, getLogger
+from typing import Any
from diffusers import (
AnimateDiffPipeline,
@@ -40,15 +41,17 @@ def __init__(
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None, "Model ID is required"
+ assert self._settings.checkpoint is not None, "Checkpoint is required"
dtype = Engine.weight(self._settings.weight_type)
- adapter = MotionAdapter().to(self._device, dtype)
+ adapter: Any = MotionAdapter().to(self._device, dtype) # type: ignore[attr-defined]
adapter.load_state_dict(
load_file(
hf_hub_download(self._model_id, self._settings.checkpoint),
device=self._device,
)
)
- pipe = AnimateDiffPipeline.from_pretrained(
+ pipe: DiffusionPipeline = AnimateDiffPipeline.from_pretrained(
self._settings.base_model_id,
motion_adapter=adapter,
torch_dtype=dtype,
@@ -57,15 +60,15 @@ def _load_model(
return pipe
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
path: str,
*,
- beta_schedule: BetaSchedule = "linear",
+ beta_schedule: BetaSchedule = BetaSchedule.LINEAR,
guidance_scale: float = 1.0,
steps: int = 4,
- timestep_spacing: TimestepSpacing = "trailing",
+ timestep_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
) -> str:
assert steps and steps in [
1,
@@ -73,21 +76,24 @@ async def __call__(
4,
8,
], f"Invalid number of steps: {steps}, can only be 1, 2, 4, or 8"
+ assert self._model is not None, "Model must be loaded"
+
+ model: Any = self._model
scheduler_settings = (timestep_spacing, beta_schedule)
if scheduler_settings not in self._schedulers:
scheduler = EulerDiscreteScheduler.from_config(
- self._model.scheduler.config,
- timestep_spacing=timestep_spacing,
- beta_schedule=beta_schedule,
+ model.scheduler.config,
+ timestep_spacing=timestep_spacing.value,
+ beta_schedule=beta_schedule.value,
)
self._schedulers[scheduler_settings] = scheduler
else:
scheduler = self._schedulers[scheduler_settings]
- self._model.scheduler = scheduler
+ model.scheduler = scheduler
with inference_mode():
- output = self._model(
+ output = model(
prompt=input if isinstance(input, str) else str(input),
guidance_scale=guidance_scale,
num_inference_steps=steps,
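
Note: swapping the string defaults for `BetaSchedule.LINEAR` and `TimestepSpacing.TRAILING` implies those entities are enums, with `.value` unwrapped at the scheduler call site. A sketch of the shape this assumes (member lists hypothetical):

from enum import Enum

class BetaSchedule(str, Enum):
    LINEAR = "linear"

class TimestepSpacing(str, Enum):
    TRAILING = "trailing"

def scheduler_kwargs(
    spacing: TimestepSpacing = TimestepSpacing.TRAILING,
    schedule: BetaSchedule = BetaSchedule.LINEAR,
) -> dict[str, str]:
    # .value hands plain strings to libraries that compare on str.
    return {"timestep_spacing": spacing.value, "beta_schedule": schedule.value}
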
diff --git a/src/avalan/model/vision/diffusion/image.py b/src/avalan/model/vision/diffusion/image.py
index cfcb13d2..946bdac6 100644
--- a/src/avalan/model/vision/diffusion/image.py
+++ b/src/avalan/model/vision/diffusion/image.py
@@ -11,7 +11,7 @@
from dataclasses import replace
from logging import Logger, getLogger
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from torch import inference_mode
@@ -19,7 +19,7 @@
class TextToImageModel(BaseVisionModel):
- _base: DiffusionPipeline
+ _base: Any
def __init__(
self,
@@ -38,7 +38,7 @@ def _load_model(
dtype = Engine.weight(self._settings.weight_type)
dtype_variant = self._settings.weight_type
- base = DiffusionPipeline.from_pretrained(
+ base: Any = DiffusionPipeline.from_pretrained(
self._model_id,
torch_dtype=dtype,
variant=dtype_variant,
@@ -46,7 +46,7 @@ def _load_model(
)
base.to(self._device)
- refiner = DiffusionPipeline.from_pretrained(
+ refiner: Any = DiffusionPipeline.from_pretrained(
self._settings.refiner_model_id,
text_encoder_2=base.text_encoder_2,
vae=base.vae,
@@ -58,10 +58,10 @@ def _load_model(
self._base = base
- return refiner
+ return refiner # type: ignore[no-any-return]
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
path: str,
@@ -81,7 +81,9 @@ async def __call__(
and n_steps
and output_type
)
+ assert self._model is not None, "Model must be loaded"
+ model: Any = self._model
with inference_mode():
image = self._base(
prompt=input if isinstance(input, str) else str(input),
@@ -89,7 +91,7 @@ async def __call__(
denoising_end=high_noise_frac,
output_type=output_type,
).images
- image = self._model(
+ image = model(
prompt=input if isinstance(input, str) else str(input),
num_inference_steps=n_steps,
denoising_start=high_noise_frac,
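
Note: diffusers pipelines expose checkpoint-specific attributes and call signatures, so this patch binds `self._model` and `self._base` to local `Any` names instead of adding per-line ignores. The trade-off in miniature (a sketch, not this repo's code):

from typing import Any

def denoise(pipeline: object, prompt: str, steps: int) -> Any:
    # One local Any annotation silences attribute/call checking for the
    # dynamic pipeline, replacing a `# type: ignore` on every usage line.
    pipe: Any = pipeline
    return pipe(prompt=prompt, num_inference_steps=steps)
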
diff --git a/src/avalan/model/vision/diffusion/video.py b/src/avalan/model/vision/diffusion/video.py
index 18411481..cb4ddcba 100644
--- a/src/avalan/model/vision/diffusion/video.py
+++ b/src/avalan/model/vision/diffusion/video.py
@@ -6,6 +6,7 @@
from dataclasses import replace
from logging import Logger, getLogger
+from typing import Any
from diffusers import DiffusionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
@@ -32,7 +33,7 @@ def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
dtype = Engine.weight(self._settings.weight_type)
- base_pipe = DiffusionPipeline.from_pretrained(
+ base_pipe: Any = DiffusionPipeline.from_pretrained(
self._model_id,
torch_dtype=dtype,
).to(self._device)
@@ -42,10 +43,10 @@ def _load_model(
torch_dtype=dtype,
).to(self._device)
base_pipe.vae.enable_tiling()
- return base_pipe
+ return base_pipe # type: ignore[no-any-return]
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
input: Input,
negative_prompt: str,
@@ -63,22 +64,24 @@ async def __call__(
width: int = 832,
steps: int = 30,
) -> str:
+ assert self._model is not None, "Model must be loaded"
image = load_image(reference_path)
video = load_video(export_to_video([image]))
condition = LTXVideoCondition(video=video, frame_index=0)
+ model: Any = self._model
down_h = int(height * downscale)
down_w = int(width * downscale)
down_h, down_w = (
TextToVideoModel._round_to_nearest_resolution_acceptable_by_vae(
down_h,
down_w,
- ratio=self._model.vae_spatial_compression_ratio,
+ ratio=model.vae_spatial_compression_ratio,
)
)
with inference_mode():
- latents = self._model(
+ latents = model(
conditions=[condition],
prompt=input if isinstance(input, str) else str(input),
negative_prompt=negative_prompt,
@@ -91,11 +94,12 @@ async def __call__(
).frames
upscaled_h, upscaled_w = down_h * 2, down_w * 2
- upscaled_latents = self._upsampler_pipe(
+ upsampler: Any = self._upsampler_pipe
+ upscaled_latents = upsampler(
latents=latents, output_type="latent"
).frames
- video = self._model(
+ video = model(
conditions=[condition],
prompt=input if isinstance(input, str) else str(input),
negative_prompt=negative_prompt,
diff --git a/src/avalan/model/vision/segmentation.py b/src/avalan/model/vision/segmentation.py
index 3cdd0652..710b84e7 100644
--- a/src/avalan/model/vision/segmentation.py
+++ b/src/avalan/model/vision/segmentation.py
@@ -3,7 +3,7 @@
from ...model.vendor import TextGenerationVendor
from ...model.vision import BaseVisionModel
-from typing import Literal
+from typing import Any, Literal
from diffusers import DiffusionPipeline
from PIL import Image
@@ -16,40 +16,50 @@
class SemanticSegmentationModel(BaseVisionModel):
+ _processor: AutoImageProcessor
+
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None, "Model ID is required"
self._processor = AutoImageProcessor.from_pretrained(
self._model_id,
# default behavior in transformers v4.48
use_fast=True,
)
- model = AutoModelForSemanticSegmentation.from_pretrained(
- self._model_id,
- device_map=self._device,
- tp_plan=Engine._get_tp_plan(self._settings.parallel),
- distributed_config=Engine._get_distributed_config(
- self._settings.distributed_config
- ),
+ model: PreTrainedModel = (
+ AutoModelForSemanticSegmentation.from_pretrained(
+ self._model_id,
+ device_map=self._device,
+ tp_plan=Engine._get_tp_plan(self._settings.parallel),
+ distributed_config=Engine._get_distributed_config(
+ self._settings.distributed_config
+ ),
+ )
)
model.eval()
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
image_source: str | Image.Image,
tensor_format: Literal["pt"] = "pt",
) -> list[str]:
+ assert self._model is not None, "Model must be loaded"
+ assert isinstance(
+ self._model, PreTrainedModel
+ ), "Model must be PreTrainedModel"
image = BaseVisionModel._get_image(image_source)
- inputs = self._processor(images=image, return_tensors=tensor_format)
+ inputs: Any = self._processor( # type: ignore[operator]
+ images=image, return_tensors=tensor_format
+ )
inputs.to(self._device)
with inference_mode():
- logits = self._model(**inputs).logits
+ logits = self._model(**inputs).logits # type: ignore[operator]
# shape (height, width) with class indices
mask = logits.argmax(dim=1)[0]
labels_tensor = unique(mask)
- labels = [
- self._model.config.id2label[idx.item()] for idx in labels_tensor
- ]
+ id2label: dict[int, str] = self._model.config.id2label # type: ignore[assignment]
+ labels = [id2label[idx.item()] for idx in labels_tensor]
return labels
diff --git a/src/avalan/model/vision/text.py b/src/avalan/model/vision/text.py
index 0fc28876..1c6b3144 100644
--- a/src/avalan/model/vision/text.py
+++ b/src/avalan/model/vision/text.py
@@ -1,7 +1,6 @@
from ...compat import override
from ...entities import (
GenerationSettings,
- ImageTextGenerationLoaderClass,
Input,
MessageRole,
)
@@ -10,7 +9,7 @@
from ...model.vendor import TextGenerationVendor
from ...model.vision import BaseVisionModel
-from typing import Literal
+from typing import Any, Literal, cast
from diffusers import DiffusionPipeline
from PIL import Image
@@ -33,12 +32,13 @@ class ImageToTextModel(TransformerModel):
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None, "Model ID is required"
self._processor = AutoImageProcessor.from_pretrained(
self._model_id,
# default behavior in transformers v4.48
use_fast=True,
)
- model = AutoModelForVision2Seq.from_pretrained(
+ model: PreTrainedModel = AutoModelForVision2Seq.from_pretrained(
self._model_id,
device_map=self._device,
tp_plan=Engine._get_tp_plan(self._settings.parallel),
@@ -58,48 +58,55 @@ def _tokenize_input(
raise NotImplementedError()
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
image_source: str | Image.Image,
*,
skip_special_tokens: bool = True,
tensor_format: Literal["pt"] = "pt",
) -> str:
+ assert self._model is not None, "Model must be loaded"
+ assert self._tokenizer is not None, "Tokenizer must be loaded"
image = BaseVisionModel._get_image(image_source)
- inputs = self._processor(images=image, return_tensors=tensor_format)
+ inputs: Any = self._processor( # type: ignore[operator]
+ images=image, return_tensors=tensor_format
+ )
inputs.to(self._device)
with inference_mode():
- output_ids = self._model.generate(**inputs)
+ model = cast(PreTrainedModel, self._model)
+ output_ids = model.generate(**inputs) # type: ignore[operator]
- output = self._tokenizer.decode(
+ output: str = self._tokenizer.decode(
output_ids[0], skip_special_tokens=skip_special_tokens
)
return output
class ImageTextToTextModel(ImageToTextModel):
- _loaders: dict[ImageTextGenerationLoaderClass, type[PreTrainedModel]] = {
- "auto": AutoModelForImageTextToText,
- "qwen2": Qwen2VLForConditionalGeneration,
- "gemma3": Gemma3ForConditionalGeneration,
+ _loaders: dict[str, type[PreTrainedModel]] = {
+ "auto": AutoModelForImageTextToText, # type: ignore[dict-item]
+ "qwen2": Qwen2VLForConditionalGeneration, # type: ignore[dict-item]
+ "gemma3": Gemma3ForConditionalGeneration, # type: ignore[dict-item]
}
def _load_model(
self,
) -> PreTrainedModel | TextGenerationVendor | DiffusionPipeline:
+ assert self._model_id is not None, "Model ID is required"
+ loader_class = self._settings.loader_class or "auto"
assert (
- self._settings.loader_class in self._loaders
- ), f"Unrecognized loader {self._settings.loader_class}"
+ loader_class in self._loaders
+ ), f"Unrecognized loader {loader_class}"
self._processor = AutoProcessor.from_pretrained(
self._model_id,
use_fast=True,
)
- loader = self._loaders[self._settings.loader_class]
- model = loader.from_pretrained(
+ loader = self._loaders[loader_class]
+ model: PreTrainedModel = loader.from_pretrained(
self._model_id,
torch_dtype=Engine.weight(self._settings.weight_type),
device_map=self._device,
@@ -111,7 +118,7 @@ def _load_model(
return model
@override
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
image_source: str | Image.Image,
prompt: str,
@@ -123,6 +130,8 @@ async def __call__(
skip_special_tokens: bool = True,
tensor_format: Literal["pt"] = "pt",
) -> str:
+ assert self._model is not None, "Model must be loaded"
+ assert settings is not None, "Generation settings must be provided"
image = BaseVisionModel._get_image(image_source).convert("RGB")
assert image.width
@@ -131,7 +140,7 @@ async def __call__(
height = int(ratio * image.height)
image = image.resize((width, height), Image.Resampling.LANCZOS)
- messages = []
+ messages: list[dict[str, Any]] = []
if system_prompt:
messages.append(
{
@@ -156,12 +165,13 @@ async def __call__(
}
)
- text = self._processor.apply_chat_template(
+ processor: Any = self._processor
+ text: str = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=settings.chat_settings.add_generation_prompt,
)
- inputs = self._processor(
+ inputs: Any = processor(
text=[text],
images=image,
videos=None,
@@ -173,14 +183,15 @@ async def __call__(
inputs.to(self._device)
with inference_mode():
- generated_ids = self._model.generate(
+ model = cast(PreTrainedModel, self._model)
+ generated_ids = model.generate( # type: ignore[operator]
**inputs, max_new_tokens=settings.max_new_tokens
)
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
- output_text = self._processor.batch_decode(
+ output_text: list[str] = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=False,
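
Note: here the narrowing uses `typing.cast` rather than `assert isinstance`; `cast` is erased at runtime, which is why the patch still asserts `self._model is not None` first. A minimal sketch:

from typing import cast

from transformers import PreTrainedModel

def model_type(model_attr: object) -> str:
    # cast() convinces mypy only; it performs no runtime check.
    model = cast(PreTrainedModel, model_attr)
    return model.config.model_type
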
diff --git a/src/avalan/secrets/aws.py b/src/avalan/secrets/aws.py
index a71d5fac..24562df3 100644
--- a/src/avalan/secrets/aws.py
+++ b/src/avalan/secrets/aws.py
@@ -1,22 +1,29 @@
from . import Secrets
+from typing import Any
+
from boto3 import client
class AwsSecrets(Secrets):
+ """Secrets backend using AWS Secrets Manager."""
+
_SERVICE = "secretsmanager"
- def __init__(self, aws_client: object | None = None):
- self._client = aws_client or client(self._SERVICE)
+ def __init__(self, aws_client: Any | None = None) -> None:
+ self._client: Any = aws_client or client(self._SERVICE)
def read(self, key: str) -> str | None:
- response = self._client.get_secret_value(SecretId=key)
+ """Return secret stored under *key*."""
+ response: dict[str, Any] = self._client.get_secret_value(SecretId=key)
return response.get("SecretString")
def write(self, key: str, secret: str) -> None:
+ """Store *secret* under *key*."""
self._client.put_secret_value(SecretId=key, SecretString=secret)
def delete(self, key: str) -> None:
+ """Remove secret associated with *key*."""
self._client.delete_secret(
SecretId=key,
ForceDeleteWithoutRecovery=True,
diff --git a/src/avalan/secrets/keyring.py b/src/avalan/secrets/keyring.py
index 1e52cc44..a1f584f9 100644
--- a/src/avalan/secrets/keyring.py
+++ b/src/avalan/secrets/keyring.py
@@ -4,8 +4,8 @@
from keyring import get_keyring
from keyring.backend import KeyringBackend
except Exception: # pragma: no cover - optional dependency
- get_keyring = None # type: ignore[assignment]
- KeyringBackend = object # type: ignore[assignment]
+ get_keyring = None # type: ignore[assignment, misc]
+ KeyringBackend = object # type: ignore[assignment, misc]
class KeyringSecrets(Secrets):
@@ -14,7 +14,7 @@ class KeyringSecrets(Secrets):
_SERVICE = "avalan"
def __init__(self, ring: KeyringBackend | None = None) -> None:
- if ring is None and get_keyring:
+ if ring is None and get_keyring is not None:
ring = get_keyring()
self._ring = ring
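
Note: `get_keyring` is either the imported function or `None` after the fallback, and functions are always truthy, so `if get_keyring:` is likely to trip mypy's truthy-function check; comparing against `None` is both precise and quiet. Sketch (names are stand-ins):

from collections.abc import Callable

get_backend: Callable[[], object] | None = None  # stand-in for get_keyring

def resolve(ring: object | None) -> object | None:
    if ring is None and get_backend is not None:  # not `and get_backend:`
        ring = get_backend()
    return ring
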
diff --git a/src/avalan/server/__init__.py b/src/avalan/server/__init__.py
index 113d4510..c46d6158 100644
--- a/src/avalan/server/__init__.py
+++ b/src/avalan/server/__init__.py
@@ -7,12 +7,12 @@
from .entities import OrchestratorContext
from .routers import mcp as mcp_router
-from collections.abc import AsyncIterator, Callable
+from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from importlib import import_module
from importlib.util import find_spec
from logging import Logger
-from typing import TYPE_CHECKING, Mapping
+from typing import TYPE_CHECKING, AsyncContextManager, Mapping, cast
from uuid import UUID, uuid4
from fastapi import FastAPI, Request
@@ -91,9 +91,9 @@ def _create_lifespan(
selected_protocols: Mapping[str, set[str]],
agent_id: UUID | None,
participant_id: UUID | None,
-) -> Callable[[FastAPI], AsyncIterator[None]]:
+) -> Callable[[FastAPI], AsyncContextManager[None]]:
@asynccontextmanager
- async def lifespan(app: FastAPI):
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info("Initializing app lifespan")
from os import environ
@@ -199,7 +199,7 @@ def _include_protocol_routers(
def _attach_lifespan(
- app: FastAPI, lifespan: Callable[[FastAPI], AsyncIterator[None]]
+ app: FastAPI, lifespan: Callable[[FastAPI], AsyncContextManager[None]]
) -> None:
existing = app.router.lifespan_context
@@ -208,7 +208,7 @@ def _attach_lifespan(
return
@asynccontextmanager
- async def combined(app_: FastAPI):
+ async def combined(app_: FastAPI) -> AsyncGenerator[None, None]:
async with existing(app_):
async with lifespan(app_):
yield
@@ -406,4 +406,4 @@ async def di_get_orchestrator(request: Request) -> Orchestrator:
request.app.state.agent_id = orchestrator.id
orchestrator = request.app.state.orchestrator
assert orchestrator is not None
- return orchestrator
+ return cast(Orchestrator, orchestrator)
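
Note: the `AsyncIterator` to `AsyncContextManager` change reflects what `@asynccontextmanager` actually returns: the decorated generator function becomes a factory of async context managers. A minimal sketch of the typing:

from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager
from typing import AsyncContextManager

from fastapi import FastAPI

def make_lifespan() -> Callable[[FastAPI], AsyncContextManager[None]]:
    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
        yield  # startup runs before the yield, shutdown after

    return lifespan
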
diff --git a/src/avalan/server/a2a/router.py b/src/avalan/server/a2a/router.py
index 777f4ef8..f2e6a6f6 100644
--- a/src/avalan/server/a2a/router.py
+++ b/src/avalan/server/a2a/router.py
@@ -24,7 +24,7 @@
from logging import Logger
from re import compile
from time import time
-from typing import TYPE_CHECKING, Any, Final, Iterable
+from typing import TYPE_CHECKING, Any, Final, Iterable, cast
from uuid import uuid4
if TYPE_CHECKING:
@@ -643,7 +643,7 @@ async def _ensure_answer_artifact(self) -> list[dict[str, Any]]:
async def _handle_tool_process(self, event: Event) -> list[dict[str, Any]]:
events: list[dict[str, Any]] = []
- payload = event.payload or []
+ payload: dict[str, Any] | list[Any] = event.payload or []
if isinstance(payload, dict):
calls: Iterable[ToolCall] = payload.get("calls", []) # type: ignore[assignment]
else:
@@ -908,11 +908,9 @@ async def _response_id(self) -> str | int | None:
overview = await self._store.get_task_overview(self._task_id)
metadata = overview.get("metadata") or {}
self._cached_response_id = metadata.get("jsonrpc_id")
- return (
- None
- if self._cached_response_id is _STREAM_RESPONSE_ID_UNSET
- else self._cached_response_id
- )
+ if self._cached_response_id is _STREAM_RESPONSE_ID_UNSET:
+ return None
+ return cast(str | int | None, self._cached_response_id)
async def _task_result(self, event: dict[str, Any]) -> a2a_types.Task:
overview = await self._store.get_task_overview(self._task_id)
@@ -1315,7 +1313,7 @@ def _call_identifier(item: Token | TokenDetail | Event | str) -> str | None:
return str(item.call.id)
if isinstance(item, Event):
if item.type is EventType.TOOL_PROCESS:
- payload = item.payload or []
+ payload: dict[str, Any] | list[Any] = item.payload or []
if isinstance(payload, dict):
candidates = payload.get("calls", [])
else:
diff --git a/src/avalan/server/routers/__init__.py b/src/avalan/server/routers/__init__.py
index 7688a4ce..4e55c145 100644
--- a/src/avalan/server/routers/__init__.py
+++ b/src/avalan/server/routers/__init__.py
@@ -1,4 +1,7 @@
from ...agent.orchestrator import Orchestrator
+from ...agent.orchestrator.response.orchestrator_response import (
+ OrchestratorResponse,
+)
from ...entities import (
GenerationSettings,
Message,
@@ -15,7 +18,7 @@
from logging import Logger
from time import time
-from uuid import uuid4
+from uuid import UUID, uuid4
from fastapi import HTTPException
@@ -24,7 +27,7 @@ async def orchestrate(
request: ChatCompletionRequest,
logger: Logger,
orchestrator: Orchestrator,
-):
+) -> tuple[OrchestratorResponse, UUID, int]:
messages = [
Message(role=req.role, content=to_message_content(req.content))
for req in request.messages
@@ -40,7 +43,7 @@ async def orchestrate(
timestamp = int(time())
settings = GenerationSettings(
- use_async_generator=request.stream,
+ use_async_generator=request.stream or False,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
stop_strings=request.stop,
@@ -59,13 +62,10 @@ async def orchestrate(
return response, response_id, timestamp
-def to_message_content(item):
- if isinstance(item, list):
- return [
- to_message_content(i)
- for i in item
- if isinstance(i, (ContentImage, ContentText, str))
- ]
+def _convert_single_content(
+ item: str | ContentText | ContentImage,
+) -> MessageContentText | MessageContentImage:
+ """Convert a single content item to message content type."""
if isinstance(item, ContentImage):
return MessageContentImage(type=item.type, image_url=item.image_url)
if isinstance(item, ContentText):
@@ -73,3 +73,22 @@ def to_message_content(item):
if isinstance(item, str):
return MessageContentText(type="text", text=item)
raise TypeError(f"Unsupported content type: {type(item).__name__}")
+
+
+def to_message_content(
+ item: str | list[ContentText | ContentImage],
+) -> (
+ MessageContentText
+ | MessageContentImage
+ | list[MessageContentText | MessageContentImage]
+):
+ """Convert request content to message content types."""
+ if isinstance(item, list):
+ return [
+ _convert_single_content(i)
+ for i in item
+ if isinstance(i, (ContentImage, ContentText, str))
+ ]
+ if not isinstance(item, (str, ContentText, ContentImage)):
+ raise TypeError(f"Unsupported content type: {type(item).__name__}")
+ return _convert_single_content(item)
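
Note: splitting `to_message_content` means the per-item helper has a precise non-list return type, so the list branch type-checks without a recursive union. The shape of the refactor, with stand-in types in place of the avalan entities:

ContentIn = str | int  # stand-in for str | ContentText | ContentImage
ItemOut = str          # stand-in for the MessageContent* union

def _convert_one(item: ContentIn) -> ItemOut:
    return str(item)

def convert(item: ContentIn | list[ContentIn]) -> ItemOut | list[ItemOut]:
    if isinstance(item, list):
        return [_convert_one(i) for i in item]
    return _convert_one(item)
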
diff --git a/src/avalan/server/routers/chat.py b/src/avalan/server/routers/chat.py
index 5610f727..1285782e 100644
--- a/src/avalan/server/routers/chat.py
+++ b/src/avalan/server/routers/chat.py
@@ -1,5 +1,11 @@
from ...agent.orchestrator import Orchestrator
-from ...entities import MessageRole, ReasoningToken, ToolCallToken
+from ...entities import (
+ MessageRole,
+ ReasoningToken,
+ Token,
+ TokenDetail,
+ ToolCallToken,
+)
from ...event import Event
from ...server.entities import (
ChatCompletionChoice,
@@ -31,7 +37,7 @@ async def create_chat_completion(
request: ChatCompletionRequest,
logger: Logger = Depends(di_get_logger),
orchestrator: Orchestrator = Depends(di_get_orchestrator),
-):
+) -> ChatCompletionResponse | StreamingResponse:
assert orchestrator and isinstance(orchestrator, Orchestrator)
assert logger and isinstance(logger, Logger)
assert request and request.messages
@@ -60,11 +66,16 @@ async def generate_chunks():
async for token in response:
if isinstance(token, Event):
continue
- elif isinstance(token, (ReasoningToken, ToolCallToken)):
- token = token.token
+ token_text: str
+ if isinstance(token, (ReasoningToken, ToolCallToken)):
+ token_text = token.token
+ elif isinstance(token, (Token, TokenDetail)):
+ token_text = token.token
+ else:
+ token_text = str(token)
choice = ChatCompletionChunkChoice(
- delta=ChatCompletionChunkChoiceDelta(content=token)
+ delta=ChatCompletionChunkChoiceDelta(content=token_text)
)
chunk = ChatCompletionChunk(
id=str(response_id),
@@ -92,13 +103,13 @@ async def generate_chunks():
choices = [
ChatCompletionChoice(
index=i,
- message=ChatMessage(role=str(MessageRole.ASSISTANT), content=text),
+ message=ChatMessage(role=MessageRole.ASSISTANT, content=text),
finish_reason="stop",
)
for i in range(request.n or 1)
]
usage = ChatCompletionUsage()
- response = ChatCompletionResponse(
+ chat_response = ChatCompletionResponse(
id=str(response_id),
created=timestamp,
model=request.model,
@@ -106,9 +117,9 @@ async def generate_chunks():
usage=usage,
)
logger.debug(
- "Generated chat completion response #%s %r", response_id, response
+ "Generated chat completion response #%s %r", response_id, chat_response
)
await orchestrator.sync_messages()
- return response
+ return chat_response
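
Note: the `token_text: str` variable (and the `chat_response` rename below it) exists because re-binding one name from a Token union to `str` gives that name two incompatible types under mypy. Sketch:

def texts(tokens: list[object]) -> list[str]:
    out: list[str] = []
    for token in tokens:
        # A separately annotated name instead of `token = str(token)`.
        token_text: str = token if isinstance(token, str) else str(token)
        out.append(token_text)
    return out
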
diff --git a/src/avalan/server/routers/mcp.py b/src/avalan/server/routers/mcp.py
index a4d053a7..d64fd464 100644
--- a/src/avalan/server/routers/mcp.py
+++ b/src/avalan/server/routers/mcp.py
@@ -1,5 +1,6 @@
from ...agent.orchestrator import Orchestrator
from ...entities import (
+ MessageRole,
ReasoningToken,
Token,
TokenDetail,
@@ -345,7 +346,7 @@ async def _expect_jsonrpc_message(
return message, messages
-def _server_info(request: Request) -> dict[str, str]:
+def _server_info(request: Request) -> dict[str, JSONValue]:
app = request.app
name = getattr(app, "title", None) or "avalan"
version = getattr(app, "version", None)
@@ -397,7 +398,11 @@ def _build_chat_request(
model_id = _default_model_id(orchestrator)
return ChatCompletionRequest(
model=model_id,
- messages=[ChatMessage(role="user", content=tool_request.input_string)],
+ messages=[
+ ChatMessage(
+ role=MessageRole.USER, content=tool_request.input_string
+ )
+ ],
stream=True,
)
@@ -409,7 +414,7 @@ async def _start_tool_streaming_response(
request_id: str | int,
tool_request: MCPToolRequest,
progress_token: str,
-) -> StreamingResponse:
+) -> StreamingResponse | JSONResponse:
chat_request = _build_chat_request(tool_request, orchestrator)
response, response_uuid, timestamp = await orchestrate(
chat_request, logger, orchestrator
@@ -580,7 +585,7 @@ def _handle_list_tools_message(
return JSONResponse(payload)
-def _collect_tool_descriptions(request: Request) -> list[dict[str, JSONValue]]:
+def _collect_tool_descriptions(request: Request) -> list[JSONValue]:
name = cast(str, getattr(request.app.state, "mcp_tool_name", "run"))
description = cast(
str,
@@ -613,7 +618,7 @@ def _extract_call_arguments(
raise HTTPException(
status_code=400, detail="Invalid tool arguments"
)
- return cast(dict[str, JSONValue], arguments)
+ return arguments
raise HTTPException(
status_code=400, detail=f'Unsupported MCP method "{method}"'
@@ -655,7 +660,9 @@ async def _stream_mcp_response(
resources: dict[str, MCPResource] = {}
finished_normally = False
- def emit(message: JSONObject) -> Iterator[bytes]:
+ def emit(
+ message: JSONObject | JSONRPCResult | JSONRPCError,
+ ) -> Iterator[bytes]:
encoded = dumps(message, separators=(",", ":")) + "\n"
yield encoded.encode("utf-8")
@@ -697,18 +704,18 @@ def emit(message: JSONObject) -> Iterator[bytes]:
continue
if isinstance(item, ToolCallToken):
- notification = _tool_call_token_notification(item)
- if notification is not None:
- for payload in emit(notification):
+ tool_notification = _tool_call_token_notification(item)
+ if tool_notification is not None:
+ for payload in emit(tool_notification):
yield payload
continue
text = _token_text(item)
if text:
+ answer_chunks.append(text)
if isinstance(item, Token):
- answer_chunks.append(text)
- notification: JSONObject = {
+ answer_notification: JSONObject = {
"jsonrpc": "2.0",
"method": "notifications/message",
"params": {
@@ -720,8 +727,7 @@ def emit(message: JSONObject) -> Iterator[bytes]:
},
}
else:
- answer_chunks.append(text)
- notification: JSONObject = {
+ answer_notification = {
"jsonrpc": "2.0",
"method": "notifications/progress",
"params": {
@@ -732,7 +738,7 @@ def emit(message: JSONObject) -> Iterator[bytes]:
},
},
}
- for payload in emit(notification):
+ for payload in emit(answer_notification):
yield payload
finished_normally = not cancel_event.is_set()
@@ -757,15 +763,15 @@ def emit(message: JSONObject) -> Iterator[bytes]:
await _close_response_iterator(response)
for resource in resources.values():
closed = await resource_store.close(resource.id)
- notification = _resource_notification(closed)
- for payload in emit(notification):
+ resource_notification = _resource_notification(closed)
+ for payload in emit(resource_notification):
yield payload
- error_message: JSONRPCError = {
+ cancel_error: JSONRPCError = {
"jsonrpc": "2.0",
"id": request_id,
"error": {"code": -32000, "message": "Request cancelled"},
}
- for payload in emit(error_message):
+ for payload in emit(cancel_error):
yield payload
await orchestrator.sync_messages()
return
@@ -830,11 +836,13 @@ def _token_text(item: ResponseItem) -> str:
async def _close_response_iterator(response: StreamResponse) -> None:
iterator = getattr(response, "_response_iterator", None)
- if iterator and hasattr(iterator, "aclose"):
- try:
- await cast(AsyncIterator[object], iterator).aclose()
- except Exception: # pragma: no cover - best effort cleanup
- pass
+ if iterator is not None:
+ aclose = getattr(iterator, "aclose", None)
+ if aclose is not None:
+ try:
+ await aclose()
+ except Exception: # pragma: no cover - best effort cleanup
+ pass
def _tool_call_token_notification(
@@ -858,7 +866,7 @@ def _tool_call_token_notification(
delta: dict[str, JSONValue] = {
"id": str(token.call.id),
"name": token.call.name,
- "arguments": token.call.arguments,
+ "arguments": cast(JSONValue, token.call.arguments),
}
return {
"jsonrpc": "2.0",
@@ -886,7 +894,7 @@ async def _tool_event_notifications(
if item is None:
return
- tool_call_id = item["id"]
+ tool_call_id = cast(str, item["id"])
if event.type is EventType.TOOL_PROCESS:
tool_summaries[tool_call_id] = {
@@ -910,7 +918,7 @@ async def _tool_event_notifications(
}
return
- tool_summary = tool_summaries.setdefault(
+ tool_summary: dict[str, JSONValue] = tool_summaries.setdefault(
tool_call_id,
{
"id": tool_call_id,
@@ -938,7 +946,8 @@ async def _tool_event_notifications(
},
}
- message = cast(dict[str, JSONValue], payload["params"]["message"])
+ params_dict = cast(dict[str, JSONValue], payload["params"])
+ message = cast(dict[str, JSONValue], params_dict["message"])
if "error" in item:
message["error"] = item["error"]
@@ -946,8 +955,9 @@ async def _tool_event_notifications(
elif "result" in item:
message["resultDelta"] = item["result"]
tool_summary["result"] = item["result"]
+ result_value = item["result"]
for resource_key, payload2 in _extract_append_streams(
- tool_call_id, item["result"]
+ tool_call_id, result_value
).items():
name, text = payload2
resource = resources.get(resource_key)
@@ -959,7 +969,11 @@ async def _tool_event_notifications(
resource = await resource_store.append(resource.id, text)
resources[resource_key] = resource
yield _resource_notification(resource)
- tool_summary.setdefault("resources", []).append(
+ resources_list = cast(
+ list[dict[str, JSONValue]],
+ tool_summary.setdefault("resources", []),
+ )
+ resources_list.append(
{
"uri": resource.uri,
"name": name,
@@ -1021,11 +1035,11 @@ def _tool_call_event_item(event: Event) -> dict[str, JSONValue] | None:
return {
"id": str(tool_result.call.id),
"name": tool_result.name,
- "arguments": tool_result.arguments,
+ "arguments": cast(JSONValue, tool_result.arguments),
"error": tool_result.message,
}
if isinstance(tool_result, ToolCallResult):
- result: JSONValue = (
+ result_value: JSONValue = (
tool_result.result
if isinstance(
tool_result.result, (dict, list, str, int, float, bool)
@@ -1035,8 +1049,8 @@ def _tool_call_event_item(event: Event) -> dict[str, JSONValue] | None:
return {
"id": str(tool_result.call.id),
"name": tool_result.name,
- "arguments": tool_result.arguments,
- "result": result,
+ "arguments": cast(JSONValue, tool_result.arguments),
+ "result": result_value,
}
if isinstance(event.payload, list) and event.payload:
call = event.payload[0]
@@ -1051,7 +1065,7 @@ def _tool_call_event_item(event: Event) -> dict[str, JSONValue] | None:
return {
"id": str(call.id),
"name": call.name,
- "arguments": call.arguments,
+ "arguments": cast(JSONValue, call.arguments),
}
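
Note: `_close_response_iterator` above now fetches `aclose` with `getattr` instead of casting the iterator, so the call site needs no `AsyncIterator` cast at all. The pattern in isolation:

async def close_quietly(iterator: object | None) -> None:
    if iterator is not None:
        aclose = getattr(iterator, "aclose", None)
        if aclose is not None:
            try:
                await aclose()
            except Exception:  # best-effort cleanup
                pass
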
diff --git a/src/avalan/server/routers/responses.py b/src/avalan/server/routers/responses.py
index f14fba2a..606f57a9 100644
--- a/src/avalan/server/routers/responses.py
+++ b/src/avalan/server/routers/responses.py
@@ -8,7 +8,7 @@
ToolCallToken,
)
from ...event import Event, EventType
-from ...server.entities import ResponsesRequest
+from ...server.entities import ChatCompletionRequest, ResponsesRequest
from ...utils import to_json
from .. import di_get_logger, di_get_orchestrator
from ..sse import sse_headers, sse_message
@@ -16,6 +16,7 @@
from enum import Enum, auto
from logging import Logger
+from typing import Any
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
@@ -30,18 +31,31 @@ class ResponseState(Enum):
router = APIRouter(tags=["responses"])
-@router.post("/responses")
+@router.post("/responses", response_model=None)
async def create_response(
request: ResponsesRequest,
logger: Logger = Depends(di_get_logger),
orchestrator: Orchestrator = Depends(di_get_orchestrator),
-):
+) -> StreamingResponse | dict[str, Any]:
+ """Create a response using the OpenAI Responses API format."""
assert orchestrator and isinstance(orchestrator, Orchestrator)
assert logger and isinstance(logger, Logger)
assert request and request.messages
+ chat_request = ChatCompletionRequest(
+ model=request.model,
+ messages=request.messages,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ n=request.n,
+ stream=request.stream,
+ stop=request.stop,
+ max_tokens=request.max_tokens,
+ response_format=request.response_format,
+ )
+
response, response_id, timestamp = await orchestrate(
- request, logger, orchestrator
+ chat_request, logger, orchestrator
)
if request.stream:
@@ -69,19 +83,20 @@ async def generate():
tool_call_id: str | None = None
async for token in response:
- is_event = isinstance(token, Event)
- if is_event and token.type not in (
- EventType.TOOL_PROCESS,
- EventType.TOOL_RESULT,
- ):
- continue
+ if isinstance(token, Event):
+ if token.type not in (
+ EventType.TOOL_PROCESS,
+ EventType.TOOL_RESULT,
+ ):
+ continue
call_id: str | None = None
- if is_event and token.type in (
+ if isinstance(token, Event) and token.type in (
EventType.TOOL_PROCESS,
EventType.TOOL_RESULT,
):
- call_id = _tool_call_event_item(token)["id"]
+ event_item = _tool_call_event_item(token)
+ call_id = str(event_item["id"])
elif (
isinstance(token, ToolCallToken) and token.call is not None
):
@@ -259,8 +274,6 @@ def _switch_state(
current_tool_call_id: str | None,
new_tool_call_id: str | None,
) -> list[str]:
- new_state: ResponseState | None
-
events: list[str] = []
changed = state is not new_state or (
state is ResponseState.TOOL_CALLING
@@ -295,7 +308,15 @@ def _switch_state(
def _new_state(
- token: ReasoningToken | ToolCallToken | Token | TokenDetail | str | None,
+ token: (
+ ReasoningToken
+ | ToolCallToken
+ | Token
+ | TokenDetail
+ | Event
+ | str
+ | None
+ ),
) -> ResponseState | None:
if isinstance(token, ReasoningToken):
new_state = ResponseState.REASONING
@@ -407,16 +428,36 @@ def _content_part_done(id: str | None = None) -> str:
return sse_message(to_json(data), event="response.content_part.done")
-def _tool_call_event_item(event: Event) -> dict:
- tool_result = (
- event.payload["result"]
- if event.type == EventType.TOOL_RESULT and "result" in event.payload
- else None
- )
- tool_call = (
- tool_result.call if tool_result is not None else event.payload[0]
- )
- item = {
+def _tool_call_event_item(event: Event) -> dict[str, Any]:
+ """Extract tool call item from an event payload."""
+ payload = event.payload
+ assert payload is not None, "Event payload must be provided"
+
+ tool_result: Any = None
+ if (
+ event.type == EventType.TOOL_RESULT
+ and isinstance(payload, dict)
+ and "result" in payload
+ ):
+ tool_result = payload["result"]
+
+ tool_call: Any
+ if tool_result is not None and hasattr(tool_result, "call"):
+ tool_call = tool_result.call
+ elif isinstance(payload, list) and payload:
+ tool_call = payload[0]
+ elif isinstance(payload, dict):
+ calls = payload.get("calls")
+ if isinstance(calls, list) and calls:
+ tool_call = calls[0]
+ else:
+ call = payload.get("call")
+ assert call is not None, "Event must have calls or call field"
+ tool_call = call
+ else:
+ raise ValueError("Invalid event payload type")
+
+ item: dict[str, Any] = {
"type": "function_call",
"id": str(tool_call.id),
"name": tool_call.name,
diff --git a/src/avalan/tool/__init__.py b/src/avalan/tool/__init__.py
index d4a0f2d3..832374c0 100644
--- a/src/avalan/tool/__init__.py
+++ b/src/avalan/tool/__init__.py
@@ -1,11 +1,14 @@
-from __future__ import annotations
-
from abc import ABC
from collections.abc import Callable, Sequence
-from contextlib import AsyncExitStack, ContextDecorator
+from contextlib import (
+ AbstractAsyncContextManager,
+ AbstractContextManager,
+ AsyncExitStack,
+ ContextDecorator,
+)
from inspect import Signature, isfunction, signature
-from types import FunctionType
-from typing import get_type_hints
+from types import FunctionType, TracebackType
+from typing import Any, Union, cast, get_type_hints
from transformers.utils import get_json_schema
@@ -44,19 +47,20 @@ def _get_signature(
return_annotation=function_signature.return_annotation,
)
- async def __aenter__(self) -> "ToolSet":
+ async def __aenter__(self) -> "Tool":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- traceback: BaseException | None,
+ traceback: TracebackType | None,
) -> bool:
if self._exit_stack:
- return await self._exit_stack.__aexit__(
+ result = await self._exit_stack.__aexit__(
exc_type, exc_value, traceback
)
+ return result if result is not None else False
return True
@@ -65,14 +69,14 @@ class ToolSet(ContextDecorator):
_namespace: str | None
_exit_stack: AsyncExitStack
- _tools: Sequence[Callable]
+ _tools: list[Callable[..., Any]]
@property
def namespace(self) -> str | None:
return self._namespace
@property
- def tools(self) -> Sequence[Callable]:
+ def tools(self) -> list[Callable[..., Any]]:
return self._tools
def __init__(
@@ -80,15 +84,15 @@ def __init__(
*,
exit_stack: AsyncExitStack | None = None,
namespace: str | None = None,
- tools: Sequence[Callable | "ToolSet"],
+ tools: Sequence[Union[Callable[..., Any], "ToolSet"]],
):
self._namespace = namespace
self._exit_stack = exit_stack or AsyncExitStack()
- self._tools = tools
+ self._tools = list(tools)
exclude_type_names = ["self", "context"]
- for i, tool in enumerate(self.tools):
+ for i, tool in enumerate(self._tools):
if (
not isfunction(tool)
and callable(tool)
@@ -102,12 +106,16 @@ def __init__(
if type_name not in exclude_type_names
}
tool.__annotations__ = type_hints
- tool.__signature__ = Tool._get_signature(
- tool.__call__, exclude_type_names
+ setattr(
+ tool,
+ "__signature__",
+ Tool._get_signature(
+ cast(FunctionType, tool.__call__), exclude_type_names
+ ),
)
if not tool.__doc__ and tool.__call__.__doc__:
tool.__doc__ = tool.__call__.__doc__
- self.tools[i] = tool
+ self._tools[i] = tool
def with_enabled_tools(self, enable_tools: list[str]) -> "ToolSet":
prefix = f"{self.namespace}." if self.namespace else ""
@@ -128,18 +136,25 @@ def with_enabled_tools(self, enable_tools: list[str]) -> "ToolSet":
async def __aenter__(self) -> "ToolSet":
for tool in self.tools:
if hasattr(tool, "__aenter__"):
- await self._exit_stack.enter_async_context(tool)
+ await self._exit_stack.enter_async_context(
+ cast(AbstractAsyncContextManager[Any, bool | None], tool)
+ )
elif hasattr(tool, "__enter__"):
- self._exit_stack.enter_context(tool)
+ self._exit_stack.enter_context(
+ cast(AbstractContextManager[Any, bool | None], tool)
+ )
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- traceback: BaseException | None,
+ traceback: TracebackType | None,
) -> bool:
- return await self._exit_stack.__aexit__(exc_type, exc_value, traceback)
+ result = await self._exit_stack.__aexit__(
+ exc_type, exc_value, traceback
+ )
+ return result if result is not None else False
def json_schemas(self, prefix: str | None = None) -> list[dict] | None:
schemas = []
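
Note: `AsyncExitStack.__aexit__` is annotated as returning `bool` in typeshed, but the context-manager protocol allows `bool | None`; normalizing through `result if result is not None else False` keeps a declared `-> bool` honest either way, and `TracebackType | None` is the correct third parameter. Sketch:

from contextlib import AsyncExitStack
from types import TracebackType

class Managed:
    def __init__(self) -> None:
        self._stack = AsyncExitStack()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        result = await self._stack.__aexit__(exc_type, exc_value, traceback)
        return result if result is not None else False
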
diff --git a/src/avalan/tool/browser.py b/src/avalan/tool/browser.py
index 2f20555c..55dd2f81 100644
--- a/src/avalan/tool/browser.py
+++ b/src/avalan/tool/browser.py
@@ -7,7 +7,8 @@
from dataclasses import dataclass
from email.message import EmailMessage
from io import BytesIO, TextIOBase
-from typing import TYPE_CHECKING, Literal, final
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Literal, final
if TYPE_CHECKING:
from faiss import IndexFlatL2
@@ -117,7 +118,9 @@ def __init__(
self._partitioner = partitioner
self.__name__ = "open"
- async def __call__(self, url: str, *, context: ToolCallContext) -> str:
+ async def __call__( # type: ignore[override]
+ self, url: str, *, context: ToolCallContext
+ ) -> str:
content = await self._read(url)
if (
@@ -185,9 +188,9 @@ async def __call__(self, url: str, *, context: ToolCallContext) -> str:
else knowledge_chunk
)
- content = knowledge_match
+ content = str(knowledge_match) if knowledge_match else ""
- return content
+ return str(content)
async def _read(self, url: str) -> str:
if (
@@ -201,14 +204,15 @@ async def _read(self, url: str) -> str:
return content
if not self._browser:
+ client: Any = self._client
browser_type = (
- self._client.chromium
+ client.chromium
if self._settings.engine == "chromium"
else (
- self._client.firefox
+ client.firefox
if self._settings.engine == "firefox"
else (
- self._client.webkit
+ client.webkit
if self._settings.engine == "webkit"
else None
)
@@ -254,6 +258,7 @@ async def _read(self, url: str) -> str:
)
response = await self._page.goto(url)
+ assert response is not None, f"Failed to load {url}"
contents: str = await self._page.content()
content_type_header = response.headers.get("content-type", None)
assert content_type_header
@@ -262,12 +267,18 @@ async def _read(self, url: str) -> str:
m["content-type"] = content_type_header
maintype = m.get_content_maintype() or "text"
assert maintype == "text"
- encoding = (m.get_param("charset") or "utf-8").lower()
+ charset_param = m.get_param("charset")
+ encoding = (
+ str(charset_param).lower()
+ if charset_param and not isinstance(charset_param, tuple)
+ else "utf-8"
+ )
mime_type = m.get_content_type()
byte_stream = BytesIO(contents.encode(encoding))
- result = self._md.convert_stream(byte_stream, mime_type=mime_type)
- content = result.text_content
- return content
+ assert self._md is not None, "MarkItDown must be initialized"
+ md_result = self._md.convert_stream(byte_stream, mime_type=mime_type)
+ page_content: str = md_result.text_content or ""
+ return page_content
def with_client(self, client: "PlaywrightContextManager") -> "BrowserTool":
self._client = client
@@ -278,7 +289,7 @@ async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- traceback: BaseException | None,
+ traceback: TracebackType | None,
) -> bool:
if self._page:
await self._page.close()
@@ -313,13 +324,20 @@ def __init__(
self._client = async_playwright()
tools = [BrowserTool(settings, self._client, partitioner=partitioner)]
- return super().__init__(
+ super().__init__(
exit_stack=exit_stack, namespace=namespace, tools=tools
)
@override
async def __aenter__(self) -> "BrowserToolSet":
- self._client = await self._exit_stack.enter_async_context(self._client)
+ assert (
+ self._client is not None
+ ), "Playwright client must be initialized"
+ playwright_instance: Any = await self._exit_stack.enter_async_context(
+ self._client
+ )
for i, tool in enumerate(self._tools):
- self._tools[i] = tool.with_client(self._client)
- return await super().__aenter__()
+ if hasattr(tool, "with_client"):
+ self._tools[i] = tool.with_client(playwright_instance)
+ await super().__aenter__()
+ return self
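
Note: the tuple check added in `_read` guards against `EmailMessage.get_param`, which can return `None`, a plain string, or an RFC 2231 `(charset, language, value)` tuple. The decode path in isolation:

from email.message import EmailMessage

def charset_of(content_type_header: str) -> str:
    m = EmailMessage()
    m["content-type"] = content_type_header
    param = m.get_param("charset")
    if param and not isinstance(param, tuple):
        return str(param).lower()
    return "utf-8"

# e.g. charset_of("text/html; charset=ISO-8859-1") -> "iso-8859-1"
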
diff --git a/src/avalan/tool/code.py b/src/avalan/tool/code.py
index 4c6f9dbb..bfc633ab 100644
--- a/src/avalan/tool/code.py
+++ b/src/avalan/tool/code.py
@@ -5,6 +5,7 @@
from asyncio import create_subprocess_exec
from asyncio.subprocess import PIPE
from contextlib import AsyncExitStack
+from typing import Any
try:
from RestrictedPython import (
@@ -37,10 +38,10 @@ def __init__(self) -> None:
super().__init__()
self.__name__ = "run"
- async def __call__(
- self, code: str, *args: any, context: ToolCallContext, **kwargs: any
+ async def __call__( # type: ignore[override]
+ self, code: str, *args: Any, context: ToolCallContext, **kwargs: Any
) -> str:
- locals_dict = {}
+ locals_dict: dict[str, Any] = {}
byte_code = compile_restricted(
code,
filename="",
@@ -90,7 +91,7 @@ def __init__(self) -> None:
super().__init__()
self.__name__ = "search.ast.grep"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
pattern: str,
*,
diff --git a/src/avalan/tool/database/__init__.py b/src/avalan/tool/database/__init__.py
index 673773fb..2f94bcd0 100644
--- a/src/avalan/tool/database/__init__.py
+++ b/src/avalan/tool/database/__init__.py
@@ -6,6 +6,7 @@
from asyncio import sleep
from dataclasses import dataclass
from re import compile as regex_compile
+from types import TracebackType
from typing import Any, Literal, final
try:
@@ -22,22 +23,22 @@
_SQLALCHEMY_AVAILABLE = True
except ImportError:
_SQLALCHEMY_AVAILABLE = False
- MetaData = None # type: ignore[assignment]
- event = None # type: ignore[assignment]
- func = None # type: ignore[assignment]
- select = None # type: ignore[assignment]
- text = None # type: ignore[assignment]
- SATable = None # type: ignore[assignment]
- sqlalchemy_inspect = None # type: ignore[assignment]
- Connection = None # type: ignore[assignment]
- Inspector = None # type: ignore[assignment]
- NoSuchTableError = None # type: ignore[assignment]
- SQLAlchemyError = None # type: ignore[assignment]
- AsyncEngine = None # type: ignore[assignment]
- create_async_engine = None # type: ignore[assignment]
- Select = None # type: ignore[assignment]
- ColumnElement = None # type: ignore[assignment]
- TextClause = None # type: ignore[assignment]
+ MetaData = None # type: ignore[misc, assignment]
+ event = None # type: ignore[misc, assignment]
+ func = None # type: ignore[misc, assignment]
+ select = None # type: ignore[misc, assignment]
+ text = None # type: ignore[misc, assignment]
+ SATable = None # type: ignore[misc, assignment]
+ sqlalchemy_inspect = None # type: ignore[misc, assignment]
+ Connection = None # type: ignore[misc, assignment]
+ Inspector = None # type: ignore[misc, assignment]
+ NoSuchTableError = None # type: ignore[misc, assignment]
+ SQLAlchemyError = None # type: ignore[misc, assignment]
+ AsyncEngine = None # type: ignore[misc, assignment]
+ create_async_engine = None # type: ignore[misc, assignment]
+ Select = None # type: ignore[misc, assignment]
+ ColumnElement = None # type: ignore[misc, assignment]
+ TextClause = None # type: ignore[misc, assignment]
try:
from sqlglot import exp, parse, parse_one
@@ -45,9 +46,9 @@
_SQLGLOT_AVAILABLE = True
except ImportError:
_SQLGLOT_AVAILABLE = False
- exp = None # type: ignore[assignment]
- parse = None # type: ignore[assignment]
- parse_one = None # type: ignore[assignment]
+ exp = None # type: ignore[misc, assignment]
+ parse = None # type: ignore[misc, assignment]
+ parse_one = None # type: ignore[misc, assignment]
@final
@@ -257,6 +258,8 @@ def _apply_identifier_case(self, connection: Connection, sql: str) -> str:
if self._normalizer is None:
return sql
+ normalizer = self._normalizer
+
self._ensure_sqlglot_available()
inspector = inspect(connection)
_, schemas = self._schemas(connection, inspector)
@@ -296,7 +299,7 @@ def normalize_table(node: exp.Expression) -> exp.Expression:
if isinstance(schema_ident, exp.Identifier)
else None
)
- key = self._normalizer.normalize(name)
+ key = normalizer.normalize(name)
lookup = f"{schema}.{key}" if schema else key
actual = replacements.get(lookup) or replacements.get(key)
if actual:
@@ -387,7 +390,7 @@ def _schemas(
if dialect == "postgresql":
sys = {"information_schema", "pg_catalog"}
- schemas = [
+ schemas: list[str | None] = [
s
for s in inspector.get_schema_names()
if s not in sys and not (s or "").startswith("pg_")
@@ -396,9 +399,14 @@ def _schemas(
schemas.append(default_schema)
return default_schema, schemas
- all_schemas = inspector.get_schema_names() or (
- [default_schema] if default_schema is not None else [None]
- )
+ raw_schemas = inspector.get_schema_names()
+ all_schemas: list[str | None]
+ if raw_schemas:
+ all_schemas = list(raw_schemas)
+ elif default_schema is not None:
+ all_schemas = [default_schema]
+ else:
+ all_schemas = [None]
sys_filters = {
"mysql": {
@@ -421,9 +429,10 @@ def _schemas(
schemas = [s for s in all_schemas if s not in sys]
if not schemas:
- schemas = (
- [default_schema] if default_schema is not None else [None]
- )
+ if default_schema is not None:
+ schemas = [default_schema]
+ else:
+ schemas = [None]
seen: set[str | None] = set()
uniq: list[str | None] = []
@@ -442,7 +451,7 @@ async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- traceback: BaseException | None,
+ traceback: TracebackType | None,
) -> bool:
return await super().__aexit__(exc_type, exc_value, traceback)
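
Note: the widened ignore comments pair two mypy error codes: rebinding an imported class to `None` is reported as "Cannot assign to a type" (`misc`) on top of the value mismatch (`assignment`). The optional-dependency fallback in isolation:

try:
    from sqlalchemy.engine import Connection
except ImportError:
    Connection = None  # type: ignore[misc, assignment]

def require_sqlalchemy() -> None:
    if Connection is None:
        raise RuntimeError("sqlalchemy is required for this tool")
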
diff --git a/src/avalan/tool/database/count.py b/src/avalan/tool/database/count.py
index 5e4783e1..47e61e52 100644
--- a/src/avalan/tool/database/count.py
+++ b/src/avalan/tool/database/count.py
@@ -44,7 +44,7 @@ def _split_schema_and_table(qualified: str) -> tuple[str | None, str]:
return (sch or None), tbl
return None, qualified
- async def __call__(
+ async def __call__( # type: ignore[override]
self, table_name: str, *, context: ToolCallContext
) -> int:
assert table_name, "table_name must not be empty"
diff --git a/src/avalan/tool/database/inspect.py b/src/avalan/tool/database/inspect.py
index 9fac3d4f..2f707bce 100644
--- a/src/avalan/tool/database/inspect.py
+++ b/src/avalan/tool/database/inspect.py
@@ -38,7 +38,7 @@ def __init__(
)
self.__name__ = "inspect"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
table_names: list[str],
schema: str | None = None,
diff --git a/src/avalan/tool/database/keys.py b/src/avalan/tool/database/keys.py
index 2b471f7b..21001cf1 100644
--- a/src/avalan/tool/database/keys.py
+++ b/src/avalan/tool/database/keys.py
@@ -8,6 +8,8 @@
TableKey,
)
+from typing import Any, cast
+
class DatabaseKeysTool(DatabaseTool):
"""List primary and unique keys defined on a table.
@@ -37,7 +39,7 @@ def __init__(
)
self.__name__ = "keys"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
table_name: str,
schema: str | None = None,
@@ -71,36 +73,38 @@ def _collect(
keys: list[TableKey] = []
- pk = (
- inspector.get_pk_constraint(actual_table, schema=resolved_schema)
- or {}
- )
- pk_columns = tuple(
- pk.get("constrained_columns") or pk.get("column_names") or []
+ pk = cast(
+ dict[str, Any] | None,
+ inspector.get_pk_constraint(actual_table, schema=resolved_schema),
)
- if pk_columns:
- keys.append(
- TableKey(
- type="primary",
- name=pk.get("name"),
- columns=pk_columns,
+ if pk:
+ pk_col_raw = pk.get("constrained_columns") or []
+ pk_columns: tuple[str, ...] = tuple(str(c) for c in pk_col_raw)
+ if pk_columns:
+ keys.append(
+ TableKey(
+ type="primary",
+ name=pk.get("name"),
+ columns=pk_columns,
+ )
)
- )
- unique_constraints = (
+ unique_constraints = cast(
+ list[dict[str, Any]],
inspector.get_unique_constraints(
actual_table,
schema=resolved_schema,
)
- or []
+ or [],
)
for constraint in unique_constraints:
- columns = tuple(
+ col_raw = (
constraint.get("column_names")
or constraint.get("constrained_columns")
or []
)
+ columns: tuple[str, ...] = tuple(str(c) for c in col_raw)
if not columns:
continue
keys.append(
diff --git a/src/avalan/tool/database/kill.py b/src/avalan/tool/database/kill.py
index b9637212..3179cf9c 100644
--- a/src/avalan/tool/database/kill.py
+++ b/src/avalan/tool/database/kill.py
@@ -35,7 +35,7 @@ def __init__(
)
self.__name__ = "kill"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
task_id: str,
*,
diff --git a/src/avalan/tool/database/locks.py b/src/avalan/tool/database/locks.py
index 37690916..0313e166 100644
--- a/src/avalan/tool/database/locks.py
+++ b/src/avalan/tool/database/locks.py
@@ -39,7 +39,7 @@ def __init__(
)
self.__name__ = "locks"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
*,
context: ToolCallContext,
diff --git a/src/avalan/tool/database/plan.py b/src/avalan/tool/database/plan.py
index acda2e87..2c5667f7 100644
--- a/src/avalan/tool/database/plan.py
+++ b/src/avalan/tool/database/plan.py
@@ -37,7 +37,7 @@ def __init__(
)
self.__name__ = "plan"
- async def __call__(
+ async def __call__( # type: ignore[override]
self, sql: str, *, context: ToolCallContext
) -> QueryPlan:
await self._sleep_if_configured()
diff --git a/src/avalan/tool/database/relationships.py b/src/avalan/tool/database/relationships.py
index 3dc1315e..d4800b8f 100644
--- a/src/avalan/tool/database/relationships.py
+++ b/src/avalan/tool/database/relationships.py
@@ -38,7 +38,7 @@ def __init__(
)
self.__name__ = "relationships"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
table_name: str,
*,
diff --git a/src/avalan/tool/database/run.py b/src/avalan/tool/database/run.py
index ac1985e1..41a97b0e 100644
--- a/src/avalan/tool/database/run.py
+++ b/src/avalan/tool/database/run.py
@@ -35,7 +35,7 @@ def __init__(
)
self.__name__ = "run"
- async def __call__(
+ async def __call__( # type: ignore[override]
self, sql: str, *, context: ToolCallContext
) -> list[dict[str, Any]]:
await self._sleep_if_configured()
diff --git a/src/avalan/tool/database/sample.py b/src/avalan/tool/database/sample.py
index 2c2f7422..b642e9bc 100644
--- a/src/avalan/tool/database/sample.py
+++ b/src/avalan/tool/database/sample.py
@@ -54,7 +54,7 @@ def _split_schema_and_table(qualified: str) -> tuple[str | None, str]:
return (schema or None), table
return None, qualified
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
table_name: str,
*,
diff --git a/src/avalan/tool/database/size.py b/src/avalan/tool/database/size.py
index 3dd92759..26e8480f 100644
--- a/src/avalan/tool/database/size.py
+++ b/src/avalan/tool/database/size.py
@@ -10,7 +10,7 @@
text,
)
-from typing import Any
+from typing import Any, Literal
class DatabaseSizeTool(DatabaseTool):
@@ -46,7 +46,7 @@ def _split_schema_and_table(qualified: str) -> tuple[str | None, str]:
return (schema or None), table
return None, qualified
- async def __call__(
+ async def __call__( # type: ignore[override]
self, table_name: str, *, context: ToolCallContext
) -> TableSize:
assert table_name, "table_name must not be empty"
@@ -220,10 +220,10 @@ def _collect_sqlite(
table_row.get("size") if table_row else None
)
- index_rows: list[dict[str, Any]] = []
+ index_rows: list[Any] = []
try:
pragma = table_name.replace("'", "''")
- index_rows = (
+ index_rows = list(
connection.execute(text(f"PRAGMA index_list('{pragma}')"))
.mappings()
.all()
@@ -420,7 +420,11 @@ def _collect_mssql(
metrics.append(self._metric("total", total_bytes))
return metrics
- def _metric(self, category: str, value: int | None) -> TableSizeMetric:
+ def _metric(
+ self,
+ category: Literal["data", "indexes", "total", "toast", "lob", "free"],
+ value: int | None,
+ ) -> TableSizeMetric:
return TableSizeMetric(
category=category,
bytes=value,
diff --git a/src/avalan/tool/database/tables.py b/src/avalan/tool/database/tables.py
index 59c8c1e5..fd0d8936 100644
--- a/src/avalan/tool/database/tables.py
+++ b/src/avalan/tool/database/tables.py
@@ -34,7 +34,7 @@ def __init__(
)
self.__name__ = "tables"
- async def __call__(
+ async def __call__( # type: ignore[override]
self, *, context: ToolCallContext
) -> dict[str | None, list[str]]:
await self._sleep_if_configured()
diff --git a/src/avalan/tool/database/tasks.py b/src/avalan/tool/database/tasks.py
index 84f91204..515f6d0e 100644
--- a/src/avalan/tool/database/tasks.py
+++ b/src/avalan/tool/database/tasks.py
@@ -40,7 +40,7 @@ def __init__(
)
self.__name__ = "tasks"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
*,
running_for: int | None = None,
diff --git a/src/avalan/tool/database/toolset.py b/src/avalan/tool/database/toolset.py
index 08156b2e..55a80325 100644
--- a/src/avalan/tool/database/toolset.py
+++ b/src/avalan/tool/database/toolset.py
@@ -20,6 +20,7 @@
from .settings import DatabaseToolSettings
from contextlib import AsyncExitStack
+from types import TracebackType
class DatabaseToolSet(ToolSet):
@@ -131,11 +132,11 @@ def __init__(
async def __aexit__(
self,
exc_type: type[BaseException] | None,
- exc: BaseException | None,
- tb: BaseException | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
) -> bool:
try:
if self._engine is not None:
await self._engine.dispose()
finally:
- return await super().__aexit__(exc_type, exc, tb)
+ return await super().__aexit__(exc_type, exc_value, traceback)
diff --git a/src/avalan/tool/manager.py b/src/avalan/tool/manager.py
index 2ce29e62..b11a749f 100644
--- a/src/avalan/tool/manager.py
+++ b/src/avalan/tool/manager.py
@@ -13,6 +13,8 @@
from collections.abc import Callable, Sequence
from contextlib import AsyncExitStack, ContextDecorator
+from types import TracebackType
+from typing import Any
from uuid import uuid4
@@ -60,10 +62,13 @@ def tool_format(self) -> ToolFormat | None:
"""Return the tool format configured for this manager."""
return self._parser.tool_format
- def json_schemas(self) -> list[dict] | None:
- schemas = []
- for toolset in self._toolsets:
- schemas.extend(toolset.json_schemas())
+ def json_schemas(self) -> list[dict[str, Any]] | None:
+ schemas: list[dict[str, Any]] = []
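+        # Both the toolset list and each per-toolset schema list may be
+        # empty or None, so guard before extending.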
+ if self._toolsets:
+ for toolset in self._toolsets:
+ toolset_schemas = toolset.json_schemas()
+ if toolset_schemas:
+ schemas.extend(toolset_schemas)
return schemas
def __init__(
@@ -110,24 +115,33 @@ def tool_call_status(
return self._parser.tool_call_status(buffer)
def get_calls(self, text: str) -> list[ToolCall] | None:
- return self._parser(text)
+ result = self._parser(text)
+ if result is None:
+ return None
+ if isinstance(result, list):
+ return result
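+        # A bare (name, arguments) tuple from the parser is normalized
+        # into a single ToolCall with a freshly generated id.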
+ name, arguments = result
+ return [ToolCall(id=uuid4(), name=name, arguments=arguments)]
async def __aenter__(self) -> "ToolManager":
if self._toolsets:
- for i, toolset in enumerate(self._toolsets):
- toolset = await self._stack.enter_async_context(toolset)
- self._toolsets[i] = toolset
- return self
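+            # Store what enter_async_context returns so later calls use
+            # the entered toolset instances.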
+ entered_toolsets: list[ToolSet] = []
+ for toolset in self._toolsets:
+ entered = await self._stack.enter_async_context(toolset)
+ entered_toolsets.append(entered)
+ self._toolsets = entered_toolsets
+ return self # type: ignore[return-value]
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- traceback: BaseException | None,
+ traceback: TracebackType | None,
) -> bool:
- return await self._stack.__aexit__(exc_type, exc_value, traceback)
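+        # The exit stack may yield None; coerce it to the declared bool.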
+ result = await self._stack.__aexit__(exc_type, exc_value, traceback)
+ return result if result is not None else False
- async def __call__(
+ async def __call__( # type: ignore[override]
self, call: ToolCall, context: ToolCallContext
) -> ToolCallResult | ToolCallError | None:
"""Execute a single tool call and return the result."""
@@ -153,10 +167,10 @@ async def __call__(
if self._settings.filters:
for f in self._settings.filters:
- namespace = None
- func = f
+ namespace: str | None = None
+ func: Callable[[ToolCall, ToolCallContext], Any] = f # type: ignore[assignment]
if isinstance(f, ToolFilter):
- func = f.func
+ func = f.func # type: ignore[assignment]
namespace = f.namespace
if not self._matches_namespace(call.name, namespace):
continue
@@ -184,14 +198,14 @@ async def __call__(
if self._settings.transformers:
for t in self._settings.transformers:
- namespace = None
- func = t
+ t_namespace: str | None = None
+ t_func: Callable[..., Any] = t # type: ignore[assignment]
if isinstance(t, ToolTransformer):
- func = t.func
- namespace = t.namespace
- if not self._matches_namespace(call.name, namespace):
+ t_func = t.func # type: ignore[assignment]
+ t_namespace = t.namespace
+ if not self._matches_namespace(call.name, t_namespace):
continue
- transformed = func(call, context, result)
+ transformed = t_func(call, context, result)
if transformed is not None:
result = transformed
diff --git a/src/avalan/tool/math.py b/src/avalan/tool/math.py
index f48afbd2..98089d4e 100644
--- a/src/avalan/tool/math.py
+++ b/src/avalan/tool/math.py
@@ -21,7 +21,9 @@ def __init__(self) -> None:
super().__init__()
self.__name__ = "calculator"
- async def __call__(self, expression: str, context: ToolCallContext) -> str:
+ async def __call__( # type: ignore[override]
+ self, expression: str, context: ToolCallContext
+ ) -> str:
result = sympify(expression, evaluate=True)
return str(result)
diff --git a/src/avalan/tool/mcp.py b/src/avalan/tool/mcp.py
index 569f8588..4bcb9fb7 100644
--- a/src/avalan/tool/mcp.py
+++ b/src/avalan/tool/mcp.py
@@ -31,7 +31,7 @@ def __init__(
self._client_params = client_params or {}
self._call_params = call_params or {}
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
uri: str,
name: str,
@@ -39,15 +39,16 @@ async def __call__(
*,
context: ToolCallContext,
) -> list[object]:
- from mcp import Client
+ from mcp import Client # type: ignore[attr-defined]
assert uri
assert name
async with Client(uri, **self._client_params) as client:
- return await client.call_tool(
+ result: list[object] = await client.call_tool(
name, arguments or {}, **self._call_params
)
+ return result
class McpToolSet(ToolSet):
diff --git a/src/avalan/tool/memory.py b/src/avalan/tool/memory.py
index c083b3a2..9b155682 100644
--- a/src/avalan/tool/memory.py
+++ b/src/avalan/tool/memory.py
@@ -3,7 +3,6 @@
from ..memory.manager import MemoryManager
from ..memory.permanent import (
Memory,
- PermanentMemoryPartition,
PermanentMemoryStore,
VectorFunction,
)
@@ -30,7 +29,9 @@ def __init__(self, memory_manager: MemoryManager) -> None:
self._memory_manager = memory_manager
self.__name__ = "message.read"
- async def __call__(self, search: str, context: ToolCallContext) -> str:
+ async def __call__( # type: ignore[override]
+ self, search: str, context: ToolCallContext
+ ) -> str:
if (
not context.agent_id
or not context.session_id
@@ -48,7 +49,10 @@ async def __call__(self, search: str, context: ToolCallContext) -> str:
limit=1,
)
if results and results[0].message:
- return results[0].message.content
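+            # Message content may be structured data; only plain strings
+            # are returned to the caller.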
+ content = results[0].message.content
+ if isinstance(content, str):
+ return content
+ return MessageReadTool._NOT_FOUND
return MessageReadTool._NOT_FOUND
@@ -78,13 +82,13 @@ def __init__(
self._function = function
self.__name__ = "read"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
namespace: str,
search: str,
*,
context: ToolCallContext,
- ) -> list[PermanentMemoryPartition]:
+ ) -> list[str]:
"""Return memory partitions that match the search query."""
if (
not namespace
@@ -126,7 +130,7 @@ def __init__(self, memory_manager: MemoryManager) -> None:
self._memory_manager = memory_manager
self.__name__ = "list"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
namespace: str,
*,
@@ -160,7 +164,7 @@ def __init__(self, memory_manager: MemoryManager) -> None:
self._memory_manager = memory_manager
self.__name__ = "stores"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
*,
context: ToolCallContext,
diff --git a/src/avalan/tool/parser.py b/src/avalan/tool/parser.py
index 1b80eece..cd2491d8 100644
--- a/src/avalan/tool/parser.py
+++ b/src/avalan/tool/parser.py
@@ -39,8 +39,10 @@ def tool_format(self) -> ToolFormat | None:
"""Return the tool format used by the parser."""
return self._tool_format
- def __call__(self, text: str) -> list[ToolCall] | None:
- calls = (
+ def __call__(
+ self, text: str
+ ) -> list[ToolCall] | tuple[str, dict[str, Any]] | None:
+ calls: list[ToolCall] | tuple[str, dict[str, Any]] | None = (
self._parse_json(text)
if self._tool_format is ToolFormat.JSON
else (
@@ -152,7 +154,7 @@ def extract_structured_message(
def message_tool_calls(self, text: str) -> list[dict[str, object]]:
"""Return tool calls extracted from ``text`` in message format."""
- parsed = None
+ parsed: list[ToolCall] | tuple[str, dict[str, Any]] | None = None
if "<|call|>" in text and "<|channel|>" in text:
parsed = self._parse_harmony(text)
elif self._tool_format:
@@ -272,7 +274,7 @@ def _merge_thinking(
if existing_thinking:
combined = "\n\n".join(
- part for part in (existing_thinking, thinking) if part
+ str(part) for part in (existing_thinking, thinking) if part
)
else:
combined = thinking
@@ -379,7 +381,7 @@ def _parse_harmony(self, text: str) -> list[ToolCall] | None:
return tool_calls
return None
- def _parse_tag(self, text: str) -> tuple[str, dict[str, Any]] | None:
+ def _parse_tag(self, text: str) -> list[ToolCall] | None:
tool_calls: list[ToolCall] = []
if self._eos_token:
@@ -389,6 +391,8 @@ def _parse_tag(self, text: str) -> tuple[str, dict[str, Any]] | None:
for element in root.findall(".//tool_call"):
tool_call = None
try:
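+                    # Elements without text cannot carry a JSON payload.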
+ if element.text is None:
+ continue
json_text = element.text.strip()
try:
diff --git a/src/avalan/tool/search_engine.py b/src/avalan/tool/search_engine.py
index 2ec37c1a..160be6a0 100644
--- a/src/avalan/tool/search_engine.py
+++ b/src/avalan/tool/search_engine.py
@@ -15,7 +15,9 @@ class SearchEngineTool(Tool):
def __init__(self) -> None:
self.__name__ = "search"
- async def __call__(self, query: str, engine: str) -> str:
+ async def __call__( # type: ignore[override]
+ self, query: str, engine: str
+ ) -> str:
return (
"The weather is nice and warm, with 23 degrees celsius, clear"
" skies, and winds under 11 kmh."
diff --git a/src/avalan/tool/youtube.py b/src/avalan/tool/youtube.py
index 65629da4..484eaf04 100644
--- a/src/avalan/tool/youtube.py
+++ b/src/avalan/tool/youtube.py
@@ -33,7 +33,7 @@ def __init__(
self._proxy = proxy
self.__name__ = "transcript"
- async def __call__(
+ async def __call__( # type: ignore[override]
self,
video_id: str,
*,
diff --git a/src/avalan/utils.py b/src/avalan/utils.py
index 71597bc9..95de3658 100644
--- a/src/avalan/utils.py
+++ b/src/avalan/utils.py
@@ -28,8 +28,8 @@ def logger_replace(logger: Logger, logger_names: list[str]) -> None:
def to_json(item: Any) -> str:
- def _default(o):
- if is_dataclass(o):
+ def _default(o: Any) -> Any:
+ if is_dataclass(o) and not isinstance(o, type):
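+            # asdict() accepts only dataclass instances, not the
+            # dataclass types themselves.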
return asdict(o)
elif isinstance(o, (Decimal, UUID)):
return str(o)
diff --git a/tests/agent/default_orchestrator_test.py b/tests/agent/default_orchestrator_test.py
index 1be08b00..961437a4 100644
--- a/tests/agent/default_orchestrator_test.py
+++ b/tests/agent/default_orchestrator_test.py
@@ -6,7 +6,7 @@
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
-from avalan.agent import Goal, InputType, OutputType, Specification
+from avalan.agent import Goal, InputType, OutputType, Role, Specification
from avalan.agent.orchestrator.orchestrators.default import DefaultOrchestrator
from avalan.agent.orchestrator.response.orchestrator_response import (
OrchestratorResponse,
@@ -69,7 +69,7 @@ def test_initialization(self):
op = orch.operations[0]
self.assertIs(op.environment.engine_uri, engine_uri)
self.assertIs(op.environment.settings, settings)
- self.assertEqual(op.specification.role, "assistant")
+ self.assertEqual(op.specification.role, Role(persona=["assistant"]))
self.assertEqual(op.specification.goal.task, "do")
self.assertEqual(op.specification.goal.instructions, ["something"])
self.assertEqual(op.specification.rules, ["a", "b"])
@@ -206,7 +206,7 @@ def output_fn(*args, **kwargs):
self.assertEqual(
context.specification,
Specification(
- role="assistant",
+ role=Role(persona=["assistant"]),
goal=Goal(task="do", instructions=["something"]),
rules=None,
input_type=InputType.TEXT,
@@ -309,7 +309,7 @@ async def test_user_string_rendering(self):
context = agent_mock.await_args.args[0]
self.assertIsInstance(context.input, Message)
- self.assertEqual(context.input.content, b"hello world Bob")
+ self.assertEqual(context.input.content, "hello world Bob")
async def test_user_template_rendering(self):
with TemporaryDirectory() as tmp:
diff --git a/tests/agent/orchestrator_response_additional_test.py b/tests/agent/orchestrator_response_additional_test.py
index 0b62a1e3..e2ffeea5 100644
--- a/tests/agent/orchestrator_response_additional_test.py
+++ b/tests/agent/orchestrator_response_additional_test.py
@@ -139,7 +139,7 @@ async def output_gen():
second = await iterator.__anext__()
self.assertEqual(second.type, EventType.TOOL_PROCESS)
- self.assertEqual(second.payload, [call])
+ self.assertEqual(second.payload, {"calls": [call]})
third = await iterator.__anext__()
self.assertEqual(third.type, EventType.TOOL_RESULT)
diff --git a/tests/agent/orchestrator_test.py b/tests/agent/orchestrator_test.py
index a04e9314..430b0545 100644
--- a/tests/agent/orchestrator_test.py
+++ b/tests/agent/orchestrator_test.py
@@ -230,7 +230,7 @@ def _message(self, text: str) -> Message:
def test_user_string_transformation(self):
orchestrator, specification = self._create_orchestrator()
message = orchestrator._input_messages(specification, "world")
- self.assertEqual(message.content, b"hello world Bob")
+ self.assertEqual(message.content, "hello world Bob")
def test_user_list_strings_transformation(self):
orchestrator, specification = self._create_orchestrator()
@@ -241,13 +241,13 @@ def test_user_message_transformation(self):
orchestrator, specification = self._create_orchestrator()
msg = self._message("earth")
message = orchestrator._input_messages(specification, msg)
- self.assertEqual(message.content, b"hello earth Bob")
+ self.assertEqual(message.content, "hello earth Bob")
def test_user_list_messages_transformation(self):
orchestrator, specification = self._create_orchestrator()
msg = self._message("moon")
messages = orchestrator._input_messages(specification, [msg])
- self.assertEqual(messages[0].content, b"hello moon Bob")
+ self.assertEqual(messages[0].content, "hello moon Bob")
class OrchestratorUserTemplateTransformationOptionsTestCase(unittest.TestCase):
@@ -348,7 +348,7 @@ def _create_orchestrator(self):
def test_user_string_transformation(self):
orchestrator, specification = self._create_orchestrator()
message = orchestrator._input_messages(specification, "world")
- self.assertEqual(message.content, b"hello world Bob")
+ self.assertEqual(message.content, "hello world Bob")
class OrchestratorSettingsTemplateVarsUserTemplateTestCase(unittest.TestCase):
diff --git a/tests/agent/renderer_test.py b/tests/agent/renderer_test.py
index d0ca13af..9d1a86c0 100644
--- a/tests/agent/renderer_test.py
+++ b/tests/agent/renderer_test.py
@@ -59,9 +59,7 @@ def test_custom_template_path(self):
def test_from_string(self):
renderer = self.Renderer()
tmpl = "Hi {{name}}"
- self.assertEqual(
- renderer.from_string(tmpl, {"name": "Ada"}), b"Hi Ada"
- )
+ self.assertEqual(renderer.from_string(tmpl, {"name": "Ada"}), "Hi Ada")
self.assertEqual(renderer.from_string(tmpl), tmpl)
def test_template_caching(self):
diff --git a/tests/agent/template_engine_agent_test.py b/tests/agent/template_engine_agent_test.py
index 34a5d9b4..6fe840cc 100644
--- a/tests/agent/template_engine_agent_test.py
+++ b/tests/agent/template_engine_agent_test.py
@@ -89,9 +89,9 @@ def test_prepare_call_no_template_vars(self):
"agent.md",
name="Bob",
roles=["assistant"],
- task=b"do",
- instructions=[b"instr"],
- rules=[b"rule"],
+ task="do",
+ instructions=["instr"],
+ rules=["rule"],
)
self.assertEqual(result["settings"], spec.settings)
self.assertEqual(result["system_prompt"], expected_prompt)
@@ -108,10 +108,10 @@ def test_prepare_call_with_template_vars(self):
expected_prompt = self.renderer(
"agent.md",
name="Bob",
- roles=[b"role run"],
- task=b"do run",
- instructions=[b"inst run"],
- rules=[b"rule run"],
+ roles=["role run"],
+ task="do run",
+ instructions=["inst run"],
+ rules=["rule run"],
)
self.assertEqual(result["system_prompt"], expected_prompt)
@@ -127,10 +127,10 @@ def test_prepare_call_with_settings_template_vars(self):
expected_prompt = self.renderer(
"agent.md",
name="Bob",
- roles=[b"role run"],
- task=b"do run",
- instructions=[b"inst run"],
- rules=[b"rule run"],
+ roles=["role run"],
+ task="do run",
+ instructions=["inst run"],
+ rules=["rule run"],
)
self.assertEqual(result["system_prompt"], expected_prompt)
@@ -228,9 +228,9 @@ async def test_call_invokes_run_with_prepared_arguments(self):
"agent.md",
name="Bob",
roles=["assistant"],
- task=b"do",
- instructions=[b"ins"],
- rules=[b"r"],
+ task="do",
+ instructions=["ins"],
+ rules=["r"],
)
agent._run = AsyncMock(return_value="out")
diff --git a/tests/cli/agent_test.py b/tests/cli/agent_test.py
index 84b47f0a..865beca0 100644
--- a/tests/cli/agent_test.py
+++ b/tests/cli/agent_test.py
@@ -72,7 +72,7 @@ def setUp(self):
self.console.status.return_value = status_cm
self.theme = MagicMock()
self.theme._ = lambda s: s
- self.theme.icons = {"user_input": ">"}
+ self.theme._icons = {"user_input": ">", "agent_output": "<"}
self.theme.get_spinner.return_value = "sp"
self.theme.agent.return_value = "agent_panel"
self.theme.search_message_matches.return_value = "matches_panel"
@@ -896,7 +896,7 @@ def setUp(self):
self.console.status.return_value = status_cm
self.theme = MagicMock()
self.theme._ = lambda s: s
- self.theme.icons = {"user_input": ">", "agent_output": "<"}
+ self.theme._icons = {"user_input": ">", "agent_output": "<"}
self.theme.get_spinner.return_value = "sp"
self.theme.agent.return_value = "agent_panel"
self.theme.recent_messages.return_value = "recent_panel"
@@ -2467,7 +2467,7 @@ def setUp(self):
self.console.status.return_value = status_cm
self.theme = MagicMock()
self.theme._ = lambda s: s
- self.theme.icons = {"user_input": ">", "agent_output": "<"}
+ self.theme._icons = {"user_input": ">", "agent_output": "<"}
self.theme.get_spinner.return_value = "sp"
self.theme.agent.return_value = "agent_panel"
self.theme.recent_messages.return_value = "recent_panel"
diff --git a/tests/cli/cache_test.py b/tests/cli/cache_test.py
index 548a7bed..8d8f09b8 100644
--- a/tests/cli/cache_test.py
+++ b/tests/cli/cache_test.py
@@ -40,7 +40,7 @@ def test_execute_deletion_confirmed(self):
)
self.hub.cache_delete.assert_called_once_with("m", "r")
self.assertEqual(
- self.theme.cache_delete.call_args_list[0].args, (cache_del,)
+ self.theme.cache_delete.call_args_list[0].args, (cache_del, False)
)
self.assertEqual(
self.theme.cache_delete.call_args_list[1].args, (cache_del, True)
@@ -65,7 +65,9 @@ def test_deletion_cancelled(self):
execute.assert_not_called()
confirm.assert_called_once()
self.assertEqual(len(self.theme.cache_delete.call_args_list), 1)
- self.assertEqual(self.theme.cache_delete.call_args.args, (cache_del,))
+ self.assertEqual(
+ self.theme.cache_delete.call_args.args, (cache_del, False)
+ )
self.assertEqual(len(self.console.print.call_args_list), 1)
diff --git a/tests/cli/input_test.py b/tests/cli/input_test.py
index ba81c030..bdbb48b6 100644
--- a/tests/cli/input_test.py
+++ b/tests/cli/input_test.py
@@ -43,7 +43,7 @@ def test_prompt_when_no_stdin(self):
):
result = get_input(self.console, "prompt")
- ask.assert_called_once_with("prompt ")
+ ask.assert_called_once_with("prompt ", stream=None)
self.assertEqual(result, "value")
self.assertEqual(
self.console.print.call_args_list, [call(""), call("")]
@@ -123,9 +123,10 @@ def test_confirm_tool_call(self):
"Execute tool call?",
choices=["y", "a", "n"],
default="n",
+ console=console,
show_choices=True,
show_default=True,
- console=console,
+ stream=None,
)
self.assertEqual(result, "y")
@@ -206,9 +207,10 @@ def ask_side_effect(*args, **kwargs):
"Execute tool call?",
choices=["y", "a", "n"],
default="n",
+ console=console,
show_choices=True,
show_default=True,
- console=console,
+ stream=None,
)
self.assertEqual(live.auto_refresh, True)
self.assertEqual(result, "y")
diff --git a/tests/cli/model_search_additional_test.py b/tests/cli/model_search_additional_test.py
index b12fc3c3..0e1094f8 100644
--- a/tests/cli/model_search_additional_test.py
+++ b/tests/cli/model_search_additional_test.py
@@ -73,7 +73,7 @@ async def test_model_search_updates_access(self):
live.__exit__.return_value = False
async def fake_to_thread(fn, *a, **kw):
- return fn()
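+        # Forward all arguments, mirroring asyncio.to_thread.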
+ return fn(*a, **kw)
with (
patch.object(model_cmds, "Live", return_value=live),
diff --git a/tests/cli/model_test.py b/tests/cli/model_test.py
index ba21918d..77ff4f37 100644
--- a/tests/cli/model_test.py
+++ b/tests/cli/model_test.py
@@ -824,7 +824,7 @@ async def inner_gen():
events = [
model_cmds.Event(
type=model_cmds.EventType.TOOL_EXECUTE,
- payload=[call],
+ payload={"call": call},
),
token_a,
model_cmds.Event(
@@ -4242,7 +4242,7 @@ def fake_group(*items):
return items
async def to_thread_stub(fn, *a, **kw):
- return fn()
+ return fn(*a, **kw)
with (
patch.object(model_cmds, "Live", return_value=live),
@@ -4317,7 +4317,9 @@ async def test_model_search_open(self):
patch.object(model_cmds, "Live", return_value=live),
patch.object(model_cmds, "Group", side_effect=lambda *i: i),
patch.object(
- model_cmds, "to_thread", side_effect=lambda f, *a, **k: f()
+ model_cmds,
+ "to_thread",
+ side_effect=lambda f, *a, **k: f(*a, **k),
),
):
await model_cmds.model_search(args, console, theme, hub, 5)
diff --git a/tests/cli/theme_test.py b/tests/cli/theme_test.py
index 90443e43..4157c5fb 100644
--- a/tests/cli/theme_test.py
+++ b/tests/cli/theme_test.py
@@ -319,7 +319,7 @@ def test_base_methods_raise(self):
self.theme, "n", "d", "a", "id", "lib", True, False
),
lambda: Theme.agent(
- self.theme, SimpleNamespace(), models=[], cans_access=None
+ self.theme, SimpleNamespace(), models=[], can_access=None
),
lambda: Theme.events(self.theme, []),
lambda: Theme.ask_access_token(self.theme),
@@ -356,7 +356,7 @@ def test_base_methods_raise(self):
lambda: Theme.recent_messages(
self.theme, "id", SimpleNamespace(), []
),
- lambda: Theme.saved_tokenizer_files("/d", 0),
+ lambda: Theme.saved_tokenizer_files(self.theme, "/d", 0),
lambda: Theme.search_message_matches(
self.theme, "id", SimpleNamespace(), []
),
@@ -378,7 +378,7 @@ def test_base_methods_raise(self):
call()
async def run_tokens():
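+            # Theme.tokens is an async generator; iterate it rather than
+            # awaiting the call.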
- await Theme.tokens(
+ async for _ in Theme.tokens(
self.theme,
model_id="m",
added_tokens=None,
@@ -396,13 +396,15 @@ async def run_tokens():
tool_events=None,
tool_event_calls=None,
tool_event_results=None,
+ tool_running_spinner=None,
ttft=0.0,
ttnt=0.0,
ttsr=0.0,
elapsed=0.0,
console_width=80,
logger=SimpleNamespace(),
- )
+ ):
+ pass
with self.assertRaises(NotImplementedError):
import asyncio
diff --git a/tests/cli/tokenizer_test.py b/tests/cli/tokenizer_test.py
index 5df733a4..23c40dd0 100644
--- a/tests/cli/tokenizer_test.py
+++ b/tests/cli/tokenizer_test.py
@@ -1,7 +1,6 @@
import importlib
import sys
from argparse import Namespace
-from dataclasses import dataclass
from types import ModuleType, SimpleNamespace
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, call, patch
@@ -17,39 +16,15 @@ def setUpClass(cls):
cls._saved_modules = {
name: sys.modules.get(name)
for name in [
- "avalan.entities",
"avalan.model.hubs.huggingface",
"avalan.model.nlp.text.generation",
]
}
- # Stub avalan.entities
- entities = ModuleType("avalan.entities")
-
- @dataclass(frozen=True, kw_only=True)
- class Token:
- id: int
- token: str
- probability: float | None = None
-
- class TransformerEngineSettings:
- def __init__(
- self,
- device=None,
- cache_dir=None,
- **kwargs,
- ) -> None:
- self.device = device
- self.cache_dir = cache_dir
- self.distributed_config = None
- for key, value in kwargs.items():
- setattr(self, key, value)
-
- backend = "transformers"
+ # Import real entities for TransformerEngineSettings
+ from avalan.entities import TransformerEngineSettings
- entities.Token = Token
- entities.TransformerEngineSettings = TransformerEngineSettings
- sys.modules["avalan.entities"] = entities
+ cls.TransformerEngineSettings = TransformerEngineSettings
# Stub avalan.model.hubs.huggingface
hubs = ModuleType("avalan.model.hubs.huggingface")
@@ -93,7 +68,6 @@ def save_tokenizer(self, path):
cls.tokenizer_mod = importlib.import_module(
"avalan.cli.commands.tokenizer"
)
- cls.TransformerEngineSettings = TransformerEngineSettings
@classmethod
def tearDownClass(cls):
@@ -124,7 +98,7 @@ def setUp(self):
self.theme = MagicMock()
self.theme._ = lambda s: s
self.theme._n = lambda s, p, n: s if n == 1 else p
- self.theme.icons = {"user_input": ">"}
+ self.theme._icons = {"user_input": ">"}
self.theme.tokenizer_config.return_value = "cfg_panel"
self.theme.saved_tokenizer_files.return_value = "save_panel"
self.theme.tokenizer_tokens.return_value = "tokens_panel"
@@ -199,7 +173,7 @@ async def test_tokenize_input(self):
dummy_model.tokenize.assert_called_once_with("hello")
get_input.assert_called_once_with(
self.console,
- self.theme.icons["user_input"] + " ",
+ self.theme._icons["user_input"] + " ",
echo_stdin=not args.no_repl,
is_quiet=args.quiet,
tty_path="/dev/tty",
diff --git a/tests/compat_test.py b/tests/compat_test.py
index 02c030e5..13a0bd3d 100644
--- a/tests/compat_test.py
+++ b/tests/compat_test.py
@@ -1,6 +1,5 @@
import importlib
import sys
-import typing
from unittest import TestCase
@@ -10,26 +9,16 @@ def _reload_module(self):
del sys.modules["avalan.compat"]
return importlib.import_module("avalan.compat")
- def test_override_fallback_when_missing(self):
- if hasattr(typing, "override"):
- delattr(typing, "override")
+ def test_override_from_typing_extensions(self):
+ """Test that override is imported from typing_extensions."""
compat = self._reload_module()
- self.assertEqual(compat.override.__module__, "avalan.compat")
+ self.assertIn(
+ compat.override.__module__, {"typing", "typing_extensions"}
+ )
def func():
return 1
- self.assertIs(compat.override(func), func)
-
- def test_override_uses_typing_override_when_available(self):
- def sentinel(func):
- func.sentinel = True
- return func
-
- typing.override = sentinel
- try:
- compat = self._reload_module()
- self.assertIs(compat.override, sentinel)
- finally:
- delattr(typing, "override")
- self._reload_module()
+ # The override decorator should return the function unchanged
+ decorated = compat.override(func)
+ self.assertEqual(decorated(), 1)
diff --git a/tests/model/audio/audio_classification_test.py b/tests/model/audio/audio_classification_test.py
index 2e11bcc5..5c8c0382 100644
--- a/tests/model/audio/audio_classification_test.py
+++ b/tests/model/audio/audio_classification_test.py
@@ -92,6 +92,10 @@ def item(self):
"avalan.model.audio.classification.inference_mode",
return_value=nullcontext(),
) as inf_mock,
+ patch(
+ "avalan.model.audio.classification.isinstance",
+ return_value=True,
+ ),
):
extractor_instance = MagicMock()
extractor_call = MagicMock(return_value="inputs")
diff --git a/tests/model/audio/base_audio_model_test.py b/tests/model/audio/base_audio_model_test.py
index 7e837bc5..cade371e 100644
--- a/tests/model/audio/base_audio_model_test.py
+++ b/tests/model/audio/base_audio_model_test.py
@@ -1,5 +1,4 @@
from logging import Logger
-from typing import Literal
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch
@@ -12,21 +11,17 @@ class DummyAudioModel(BaseAudioModel):
def _load_model(self):
return MagicMock()
- async def __call__(
- self, image_source: str | object, tensor_format: Literal["pt"] = "pt"
- ) -> str:
- return await super().__call__(image_source, tensor_format)
+ async def __call__(self, path: str) -> str:
+ return path
class BaseAudioModelTestCase(IsolatedAsyncioTestCase):
- async def test_base_methods_raise(self):
+ async def test_tokenizer_methods_raise(self):
model = DummyAudioModel(
None,
EngineSettings(auto_load_model=False),
logger=MagicMock(spec=Logger),
)
- with self.assertRaises(NotImplementedError):
- await model("img")
with self.assertRaises(TokenizerNotSupportedException):
model._load_tokenizer(None)
with self.assertRaises(TokenizerNotSupportedException):
diff --git a/tests/model/engine_additional_test.py b/tests/model/engine_additional_test.py
index 9e9a6f24..9f8bf036 100644
--- a/tests/model/engine_additional_test.py
+++ b/tests/model/engine_additional_test.py
@@ -6,6 +6,7 @@
class DummyEngine(Engine):
+ @property
def uses_tokenizer(self) -> bool:
return False
diff --git a/tests/model/model_manager_operation_test.py b/tests/model/model_manager_operation_test.py
index 23bab4e9..889237f6 100644
--- a/tests/model/model_manager_operation_test.py
+++ b/tests/model/model_manager_operation_test.py
@@ -61,12 +61,12 @@ def test_all_modalities(self):
Modality.AUDIO_TEXT_TO_SPEECH: (self._check_audio, True),
Modality.AUDIO_GENERATION: (self._check_audio, True),
Modality.EMBEDDING: (
- lambda op: self.assertIsNone(op.parameters),
+ lambda op: self.assertEqual(op.parameters, {}),
False,
),
Modality.TEXT_QUESTION_ANSWERING: (self._check_text, True),
Modality.TEXT_SEQUENCE_CLASSIFICATION: (
- lambda op: self.assertIsNone(op.parameters),
+ lambda op: self.assertEqual(op.parameters, {}),
True,
),
Modality.TEXT_SEQUENCE_TO_SEQUENCE: (self._check_text, True),
@@ -95,7 +95,7 @@ def test_unknown_modality(self):
FakeModality.UNKNOWN, self.args, None
)
self.assertEqual(op.modality, FakeModality.UNKNOWN)
- self.assertIsNone(op.parameters)
+ self.assertEqual(op.parameters, {})
self.assertFalse(op.requires_input)
def test_text_to_video_with_steps(self):
diff --git a/tests/model/nlp/text_test.py b/tests/model/nlp/text_test.py
index e346cef3..0aea14cf 100644
--- a/tests/model/nlp/text_test.py
+++ b/tests/model/nlp/text_test.py
@@ -149,12 +149,11 @@ def test_tokenize(self):
lambda i, skip_special_tokens=False: f"t{i}"
)
self.model._loaded_tokenizer = False
- self.model.load = MagicMock()
+ self.model._load = MagicMock()
result = self.model.tokenize("hi")
- self.model.load.assert_called_once_with(
- load_model=False,
+ self.model._load.assert_called_once_with(
load_tokenizer=True,
tokenizer_name_or_path=None,
)
diff --git a/tests/model/response_parsers_additional_test.py b/tests/model/response_parsers_additional_test.py
index 4a8b1c75..c612060e 100644
--- a/tests/model/response_parsers_additional_test.py
+++ b/tests/model/response_parsers_additional_test.py
@@ -168,7 +168,9 @@ def status_side_effect(text: str):
)
event = next(item for item in output if isinstance(item, Event))
self.assertEqual(event.type, EventType.TOOL_PROCESS)
- self.assertEqual(event.payload, [SimpleNamespace(name="call")])
+ self.assertEqual(
+ event.payload, {"calls": [SimpleNamespace(name="call")]}
+ )
trigger_event = event_manager.trigger.await_args_list[0].args[0]
self.assertEqual(trigger_event.type, EventType.TOOL_DETECT)
diff --git a/tests/model/text_generation_modality_vendors_test.py b/tests/model/text_generation_modality_vendors_test.py
index b54dec99..f4ff1640 100644
--- a/tests/model/text_generation_modality_vendors_test.py
+++ b/tests/model/text_generation_modality_vendors_test.py
@@ -9,6 +9,9 @@
from avalan.entities import EngineUri, TransformerEngineSettings
from avalan.model.modalities.text import TextGenerationModality
+# Vendors whose loaders do not accept the exit_stack parameter
+VENDORS_WITHOUT_EXIT_STACK: set[str] = set()
+
@pytest.mark.parametrize(
"vendor,class_name",
@@ -50,10 +53,17 @@ def test_load_engine_per_vendor(vendor: str, class_name: str) -> None:
result = TextGenerationModality().load_engine(
engine_uri, settings, logger, exit_stack
)
- loader.assert_called_once_with(
- model_id="model",
- settings=settings,
- logger=logger,
- exit_stack=exit_stack,
- )
+ if vendor in VENDORS_WITHOUT_EXIT_STACK:
+ loader.assert_called_once_with(
+ model_id="model",
+ settings=settings,
+ logger=logger,
+ )
+ else:
+ loader.assert_called_once_with(
+ model_id="model",
+ settings=settings,
+ logger=logger,
+ exit_stack=exit_stack,
+ )
assert result is loader.return_value
diff --git a/tests/model/text_modalities_full_test.py b/tests/model/text_modalities_full_test.py
index 6c2ab639..3f4c5852 100644
--- a/tests/model/text_modalities_full_test.py
+++ b/tests/model/text_modalities_full_test.py
@@ -254,7 +254,7 @@ def test_text_generation_get_operation_from_arguments() -> None:
GenerationSettings(),
)
text_params = operation.parameters["text"]
- assert text_params.manual_sampling == 5
+ assert text_params.manual_sampling is True
assert text_params.pick_tokens == 10
assert text_params.skip_special_tokens is True
assert text_params.stop_on_keywords == ["STOP"]
@@ -447,7 +447,7 @@ def test_text_sequence_classification_get_operation_from_arguments() -> None:
GenerationSettings(),
)
)
- assert operation.parameters is None
+ assert operation.parameters == {}
assert operation.input == "text"
diff --git a/tests/model/vendor_tool_call_token_test.py b/tests/model/vendor_tool_call_token_test.py
index fa4f97fb..6b8021cc 100644
--- a/tests/model/vendor_tool_call_token_test.py
+++ b/tests/model/vendor_tool_call_token_test.py
@@ -1,4 +1,5 @@
from unittest import TestCase
+from uuid import UUID
from avalan.entities import ToolCall, ToolCallToken
from avalan.model.vendor import TextGenerationVendor
@@ -26,11 +27,16 @@ def test_build_tool_call_token_handles_invalid_json_str(self) -> None:
tool_name="tool",
arguments='{"a": }',
)
- expected = ToolCallToken(
- token='{"name": "tool", "arguments": {}}',
- call=ToolCall(id=None, name="tool", arguments={}),
+ # When call_id is None, a UUID is generated
+ self.assertEqual(
+ token.token,
+ '{"name": "tool", "arguments": {}}',
)
- self.assertEqual(token, expected)
+        assert token.call is not None  # narrows the Optional for the checks below
+ self.assertEqual(token.call.name, "tool")
+ self.assertEqual(token.call.arguments, {})
+ # Verify a valid UUID was generated
+ UUID(token.call.id) # This will raise if not a valid UUID
def test_build_tool_call_token_from_dict(self) -> None:
token = TextGenerationVendor.build_tool_call_token(
diff --git a/tests/utils_test.py b/tests/utils_test.py
index 407f6aa7..b1d2728a 100644
--- a/tests/utils_test.py
+++ b/tests/utils_test.py
@@ -135,5 +135,5 @@ def reset(self, *_, **__):
super_close.assert_called_once()
bar.reset(total=2)
- bar._progress.reset.assert_called_once_with(total=2)
+ bar._progress.reset.assert_called_once_with(task_id=1, total=2)
super_reset.assert_called_once_with(total=2)