@@ -17,6 +17,7 @@
 from uuid import UUID
 
 from chromadb.api.types import (
+    EMBEDDING_KEY,
     URI,
     Schema,
     SparseVectorIndexConfig,
@@ -741,3 +742,232 @@ def _sparse_embed(
         if is_query:
             return sparse_embedding_function.embed_query(input=input)
         return sparse_embedding_function(input=input)
+
+    def _embed_knn_string_queries(self, knn: Any) -> Any:
+        """Embed string queries in Knn objects using the appropriate embedding function.
+
+        Args:
+            knn: A Knn object that may have a string query
+
+        Returns:
+            A Knn object with the string query replaced by an embedding
+
+        Raises:
+            ValueError: If the query is a string but no embedding function is available
+        """
+        from chromadb.execution.expression.operator import Knn
+
+        if not isinstance(knn, Knn):
+            return knn
+
+        # If query is not a string, nothing to do
+        if not isinstance(knn.query, str):
+            return knn
+
+        query_text = knn.query
+        key = knn.key
+
+        # Handle main embedding field
+        if key == EMBEDDING_KEY:
+            # Use the collection's main embedding function
+            embedding = self._embed(input=[query_text], is_query=True)
+            if not embedding or len(embedding) != 1:
+                raise ValueError(
+                    "Embedding function returned unexpected number of embeddings"
+                )
+            # Return a new Knn with the embedded query
+            return Knn(
+                query=embedding[0],
+                key=knn.key,
+                limit=knn.limit,
+                default=knn.default,
+                return_rank=knn.return_rank,
+            )
+
+        # Handle metadata field with potential sparse embedding
+        schema = self.schema
+        if schema is None or key not in schema.key_overrides:
+            raise ValueError(
+                f"Cannot embed string query for key '{key}': "
+                f"key not found in schema. Please provide an embedded vector or "
+                f"configure an embedding function for this key in the schema."
+            )
+
+        value_type = schema.key_overrides[key]
+
+        # Check for sparse vector with embedding function
+        if value_type.sparse_vector is not None:
+            sparse_index = value_type.sparse_vector.sparse_vector_index
+            if sparse_index is not None and sparse_index.enabled:
+                config = sparse_index.config
+                if config.embedding_function is not None:
+                    embedding_func = config.embedding_function
+                    if not isinstance(embedding_func, SparseEmbeddingFunction):
+                        embedding_func = cast(
+                            SparseEmbeddingFunction[Any], embedding_func
+                        )
+                    validate_sparse_embedding_function(embedding_func)
+
+                    # Embed the query
+                    sparse_embedding = self._sparse_embed(
+                        input=[query_text],
+                        sparse_embedding_function=embedding_func,
+                        is_query=True,
+                    )
+
+                    if not sparse_embedding or len(sparse_embedding) != 1:
+                        raise ValueError(
+                            "Sparse embedding function returned unexpected number of embeddings"
+                        )
+
+                    # Return a new Knn with the sparse embedding
+                    return Knn(
+                        query=sparse_embedding[0],
+                        key=knn.key,
+                        limit=knn.limit,
+                        default=knn.default,
+                        return_rank=knn.return_rank,
+                    )
+
+        # Check for dense vector with embedding function (float_list)
+        if value_type.float_list is not None:
+            vector_index = value_type.float_list.vector_index
+            if vector_index is not None and vector_index.enabled:
+                config = vector_index.config
+                if config.embedding_function is not None:
+                    embedding_func = config.embedding_function
+                    validate_embedding_function(embedding_func)
+
+                    # Embed the query using the schema's embedding function
+                    try:
+                        embeddings = embedding_func.embed_query(input=[query_text])
+                    except AttributeError:
+                        # Fallback if embed_query doesn't exist
+                        embeddings = embedding_func([query_text])
+
+                    if not embeddings or len(embeddings) != 1:
+                        raise ValueError(
+                            "Embedding function returned unexpected number of embeddings"
+                        )
+
+                    # Return a new Knn with the dense embedding
+                    return Knn(
+                        query=embeddings[0],
+                        key=knn.key,
+                        limit=knn.limit,
+                        default=knn.default,
+                        return_rank=knn.return_rank,
+                    )
+
+        raise ValueError(
+            f"Cannot embed string query for key '{key}': "
+            f"no embedding function configured for this key in the schema. "
+            f"Please provide an embedded vector or configure an embedding function."
+        )
+
+    def _embed_rank_string_queries(self, rank: Any) -> Any:
+        """Recursively embed string queries in Rank expressions.
+
+        Args:
+            rank: A Rank expression that may contain Knn objects with string queries
+
+        Returns:
+            A Rank expression with all string queries embedded
+        """
+        # Import here to avoid circular dependency
+        from chromadb.execution.expression.operator import (
+            Knn,
+            Abs,
+            Div,
+            Exp,
+            Log,
+            Max,
+            Min,
+            Mul,
+            Sub,
+            Sum,
+            Val,
+            Rrf,
+        )
+
+        if rank is None:
+            return None
+
+        # Base case: Knn - embed if it has a string query
+        if isinstance(rank, Knn):
+            return self._embed_knn_string_queries(rank)
+
+        # Base case: Val - no embedding needed
+        if isinstance(rank, Val):
+            return rank
+
+        # Recursive cases: walk through child ranks
+        if isinstance(rank, Abs):
+            return Abs(self._embed_rank_string_queries(rank.rank))
+
+        if isinstance(rank, Div):
+            return Div(
+                self._embed_rank_string_queries(rank.left),
+                self._embed_rank_string_queries(rank.right),
+            )
+
+        if isinstance(rank, Exp):
+            return Exp(self._embed_rank_string_queries(rank.rank))
+
+        if isinstance(rank, Log):
+            return Log(self._embed_rank_string_queries(rank.rank))
+
+        if isinstance(rank, Max):
+            return Max([self._embed_rank_string_queries(r) for r in rank.ranks])
+
+        if isinstance(rank, Min):
+            return Min([self._embed_rank_string_queries(r) for r in rank.ranks])
+
+        if isinstance(rank, Mul):
+            return Mul([self._embed_rank_string_queries(r) for r in rank.ranks])
+
+        if isinstance(rank, Sub):
+            return Sub(
+                self._embed_rank_string_queries(rank.left),
+                self._embed_rank_string_queries(rank.right),
+            )
+
+        if isinstance(rank, Sum):
+            return Sum([self._embed_rank_string_queries(r) for r in rank.ranks])
+
+        if isinstance(rank, Rrf):
+            return Rrf(
+                ranks=[self._embed_rank_string_queries(r) for r in rank.ranks],
+                k=rank.k,
+                weights=rank.weights,
+                normalize=rank.normalize,
+            )
+
+        # Unknown rank type - return as is
+        return rank
+
+    def _embed_search_string_queries(self, search: Any) -> Any:
+        """Embed string queries in a Search object.
+
+        Args:
+            search: A Search object that may contain Knn objects with string queries
+
+        Returns:
+            A Search object with all string queries embedded
+        """
+        # Import here to avoid circular dependency
+        from chromadb.execution.expression.plan import Search
+
+        if not isinstance(search, Search):
+            return search
+
+        # Embed the rank expression if it exists
+        embedded_rank = self._embed_rank_string_queries(search._rank)
+
+        # Create a new Search with the embedded rank
+        return Search(
+            where=search._where,
+            rank=embedded_rank,
+            limit=search._limit,
+            select=search._select,
+        )
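
For context, here is a minimal usage sketch of the kind of search expression these helpers rewrite. It is not part of the diff: the metadata key name, the omitted keyword arguments, and the eventual call site are assumptions; only the imported classes and the keyword arguments visible in the hunk above come from the PR.

```python
# Editor-supplied sketch, not part of this PR. It builds a Rank tree whose Knn
# leaves carry raw text instead of precomputed vectors, which is exactly the
# shape _embed_rank_string_queries / _embed_knn_string_queries walk and rewrite.
from chromadb.api.types import EMBEDDING_KEY
from chromadb.execution.expression.operator import Knn, Rrf
from chromadb.execution.expression.plan import Search

# Dense leg: key is the collection's main embedding field, so the helper would
# route the text through self._embed(...). Unlisted Knn parameters are left at
# their assumed defaults.
dense = Knn(query="how do I reset my password?", key=EMBEDDING_KEY, limit=20)

# Sparse leg: "sparse_text" is a hypothetical metadata key whose schema entry
# is assumed to define a sparse vector index with an embedding function, which
# would send the same text through self._sparse_embed(...).
sparse = Knn(query="how do I reset my password?", key="sparse_text", limit=20)

# Reciprocal rank fusion of the two legs (k/weights/normalize left at their
# assumed defaults). The recursion rebuilds Rrf.ranks with each string query
# replaced by its embedding.
fused = Rrf(ranks=[dense, sparse])

# _embed_search_string_queries would return an equivalent Search whose Knn
# leaves hold embeddings; the public entry point that invokes it is not shown
# in this hunk.
search = Search(rank=fused)
```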