|
17 | 17 | from uuid import UUID |
18 | 18 |
|
19 | 19 | from chromadb.api.types import ( |
| 20 | + EMBEDDING_KEY, |
20 | 21 | URI, |
21 | 22 | Schema, |
22 | 23 | SparseVectorIndexConfig, |
@@ -741,3 +742,232 @@ def _sparse_embed( |
741 | 742 | if is_query: |
742 | 743 | return sparse_embedding_function.embed_query(input=input) |
743 | 744 | return sparse_embedding_function(input=input) |
| 745 | + |
def _embed_knn_string_queries(self, knn: Any) -> Any:
    """Replace a string query on a Knn with the corresponding embedding.

    Anything that is not a Knn, and any Knn whose query is not a ``str``,
    is returned unchanged. Otherwise the text is embedded according to
    ``knn.key``:

    * ``EMBEDDING_KEY``: embedded through ``self._embed``.
    * any other key: looked up in ``self.schema.key_overrides``. An enabled
      sparse vector index with an embedding function is tried first, then
      an enabled dense (``float_list``) vector index with an embedding
      function.

    Args:
        knn: A Knn object that may have a string query.

    Returns:
        The input unchanged when there is nothing to embed, otherwise a new
        Knn carrying the embedded query and the same ``key``, ``limit``,
        ``default`` and ``return_rank`` as the input.

    Raises:
        ValueError: If the key is not present in the schema, if no usable
            embedding function is configured for the key, or if an
            embedding function does not return exactly one embedding.
    """
    # Imported at call time rather than at module level; presumably this
    # avoids a circular import with the expression package.
    from chromadb.execution.expression.operator import Knn

    if not isinstance(knn, Knn) or not isinstance(knn.query, str):
        return knn

    query_text = knn.query
    key = knn.key

    def _with_query(vector: Any) -> Any:
        # Copy of the incoming Knn with only the query swapped out.
        return Knn(
            query=vector,
            key=knn.key,
            limit=knn.limit,
            default=knn.default,
            return_rank=knn.return_rank,
        )

    def _only(vectors: Any, message: str) -> Any:
        # Compare lengths instead of relying on truthiness: an array-like
        # result is ambiguous (or wrong) under ``not vectors``.
        if vectors is None or len(vectors) != 1:
            raise ValueError(message)
        return vectors[0]

    # Main embedding field: use the collection's own embedding function.
    if key == EMBEDDING_KEY:
        return _with_query(
            _only(
                self._embed(input=[query_text], is_query=True),
                "Embedding function returned unexpected number of embeddings",
            )
        )

    # Any other key must be described by the schema.
    schema = self.schema
    if schema is None or key not in schema.key_overrides:
        raise ValueError(
            f"Cannot embed string query for key '{key}': "
            f"key not found in schema. Please provide an embedded vector or "
            f"configure an embedding function for this key in the schema."
        )

    value_type = schema.key_overrides[key]

    # Sparse vector index with an embedding function takes precedence.
    if value_type.sparse_vector is not None:
        sparse_index = value_type.sparse_vector.sparse_vector_index
        if (
            sparse_index is not None
            and sparse_index.enabled
            and sparse_index.config.embedding_function is not None
        ):
            embedding_func = sparse_index.config.embedding_function
            if not isinstance(embedding_func, SparseEmbeddingFunction):
                # Typing aid only; cast() does nothing at runtime.
                embedding_func = cast(
                    SparseEmbeddingFunction[Any], embedding_func
                )
            validate_sparse_embedding_function(embedding_func)

            sparse_embedding = self._sparse_embed(
                input=[query_text],
                sparse_embedding_function=embedding_func,
                is_query=True,
            )
            return _with_query(
                _only(
                    sparse_embedding,
                    "Sparse embedding function returned unexpected number of embeddings",
                )
            )

    # Dense vector index (float_list) with an embedding function.
    if value_type.float_list is not None:
        vector_index = value_type.float_list.vector_index
        if (
            vector_index is not None
            and vector_index.enabled
            and vector_index.config.embedding_function is not None
        ):
            embedding_func = vector_index.config.embedding_function
            validate_embedding_function(embedding_func)

            # Probe for embed_query up front. Wrapping the call itself in
            # ``except AttributeError`` would also hide AttributeErrors
            # raised inside a real embed_query implementation and silently
            # embed the query a second time via __call__.
            embed_query = getattr(embedding_func, "embed_query", None)
            if callable(embed_query):
                embeddings = embed_query(input=[query_text])
            else:
                embeddings = embedding_func([query_text])

            return _with_query(
                _only(
                    embeddings,
                    "Embedding function returned unexpected number of embeddings",
                )
            )

    raise ValueError(
        f"Cannot embed string query for key '{key}': "
        f"no embedding function configured for this key in the schema. "
        f"Please provide an embedded vector or configure an embedding function."
    )
| 867 | + |
def _embed_rank_string_queries(self, rank: Any) -> Any:
    """Walk a Rank expression and embed every string Knn query found in it.

    Args:
        rank: A Rank expression (or None) whose tree may contain Knn nodes
            with string queries.

    Returns:
        None for a None input; a Val unchanged; a Knn passed through
        ``_embed_knn_string_queries``; for the known composite nodes, a new
        node of the same type built from the processed children; any other
        value as is.
    """
    # Deferred import to avoid a circular dependency
    from chromadb.execution.expression.operator import (
        Abs,
        Div,
        Exp,
        Knn,
        Log,
        Max,
        Min,
        Mul,
        Rrf,
        Sub,
        Sum,
        Val,
    )

    if rank is None:
        return None

    # Leaves: a Knn gets embedded, a Val has nothing to embed.
    if isinstance(rank, Knn):
        return self._embed_knn_string_queries(rank)
    if isinstance(rank, Val):
        return rank

    descend = self._embed_rank_string_queries

    # Composite nodes, listed in the order they are tested, tagged by how
    # their children are stored: "one" -> .rank, "two" -> .left/.right,
    # "many" -> .ranks.
    composite_shapes = (
        (Abs, "one"),
        (Div, "two"),
        (Exp, "one"),
        (Log, "one"),
        (Max, "many"),
        (Min, "many"),
        (Mul, "many"),
        (Sub, "two"),
        (Sum, "many"),
    )
    for node_type, shape in composite_shapes:
        if not isinstance(rank, node_type):
            continue
        if shape == "one":
            return node_type(descend(rank.rank))
        if shape == "two":
            return node_type(descend(rank.left), descend(rank.right))
        return node_type([descend(child) for child in rank.ranks])

    # Rrf carries extra settings that are copied over alongside the
    # processed children.
    if isinstance(rank, Rrf):
        return Rrf(
            ranks=[descend(child) for child in rank.ranks],
            k=rank.k,
            weights=rank.weights,
            normalize=rank.normalize,
        )

    # Unrecognised rank type: hand it back untouched.
    return rank
| 948 | + |
def _embed_search_string_queries(self, search: Any) -> Any:
    """Return a Search whose rank expression has its string queries embedded.

    Args:
        search: A Search object whose rank may contain Knn objects with
            string queries.

    Returns:
        The input unchanged when it is not a Search; otherwise a new Search
        built from the same where/limit/select values and the rank produced
        by ``_embed_rank_string_queries``.
    """
    # Deferred import to avoid a circular dependency
    from chromadb.execution.expression.plan import Search

    if isinstance(search, Search):
        # Process the rank first (it may be None), then rebuild the Search
        # around it.
        rank_with_vectors = self._embed_rank_string_queries(search._rank)
        return Search(
            where=search._where,
            rank=rank_with_vectors,
            limit=search._limit,
            select=search._select,
        )

    return search
0 commit comments