Skip to content
This repository was archived by the owner on Mar 1, 2024. It is now read-only.

Commit e77d1a5

Browse files
feat: sync mongo to SimpleMongoReader of llama-index (#624)
1 parent b9d5689 commit e77d1a5

File tree

1 file changed: 58 additions and 26 deletions.

llama_hub/mongo/base.py

Lines changed: 58 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""Mongo client."""
22

3-
from typing import Dict, List, Optional
3+
from typing import Dict, List, Optional, Union
44

55
from llama_index.readers.base import BaseReader
66
from llama_index.readers.schema.base import Document
@@ -14,59 +14,91 @@ class SimpleMongoReader(BaseReader):
1414
Args:
1515
host (str): Mongo host.
1616
port (int): Mongo port.
17-
max_docs (int): Maximum number of documents to load.
18-
1917
"""
2018

2119
def __init__(
2220
self,
2321
host: Optional[str] = None,
2422
port: Optional[int] = None,
2523
uri: Optional[str] = None,
26-
max_docs: int = 1000,
2724
) -> None:
2825
"""Initialize with parameters."""
2926
try:
30-
import pymongo # noqa: F401
31-
from pymongo import MongoClient # noqa: F401
32-
except ImportError:
27+
from pymongo import MongoClient
28+
except ImportError as err:
3329
raise ImportError(
3430
"`pymongo` package not found, please run `pip install pymongo`"
35-
)
31+
) from err
32+
33+
client: MongoClient
3634
if uri:
37-
if uri is None:
38-
raise ValueError("Either `host` and `port` or `uri` must be provided.")
39-
self.client: MongoClient = MongoClient(uri)
35+
client = MongoClient(uri)
36+
elif host and port:
37+
client = MongoClient(host, port)
4038
else:
41-
if host is None or port is None:
42-
raise ValueError("Either `host` and `port` or `uri` must be provided.")
43-
self.client = MongoClient(host, port)
44-
self.max_docs = max_docs
39+
raise ValueError("Either `host` and `port` or `uri` must be provided.")
40+
41+
self.client = client
42+
43+
def _flatten(self, texts: List[Union[str, List[str]]]) -> List[str]:
44+
result = []
45+
for text in texts:
46+
result += text if isinstance(text, list) else [text]
47+
return result
4548

4649
def load_data(
47-
self, db_name: str, collection_name: str, query_dict: Optional[Dict] = None
50+
self,
51+
db_name: str,
52+
collection_name: str,
53+
field_names: List[str] = ["text"],
54+
separator: str = "",
55+
query_dict: Optional[Dict] = None,
56+
max_docs: int = 0,
57+
metadata_names: Optional[List[str]] = None,
4858
) -> List[Document]:
4959
"""Load data from the input directory.
5060
5161
Args:
5262
db_name (str): name of the database.
5363
collection_name (str): name of the collection.
54-
query_dict (Optional[Dict]): query to filter documents.
64+
field_names(List[str]): names of the fields to be concatenated.
65+
Defaults to ["text"]
66+
separator (str): separator to be used between fields.
67+
Defaults to ""
68+
query_dict (Optional[Dict]): query to filter documents. Read more
69+
at [official docs](https://www.mongodb.com/docs/manual/reference/method/db.collection.find/#std-label-method-find-query)
5570
Defaults to None
71+
max_docs (int): maximum number of documents to load.
72+
Defaults to 0 (no limit)
73+
metadata_names (Optional[List[str]]): names of the fields to be added
74+
to the metadata attribute of the Document. Defaults to None
5675
5776
Returns:
5877
List[Document]: A list of documents.
59-
6078
"""
61-
documents = []
6279
db = self.client[db_name]
63-
if query_dict is None:
64-
cursor = db[collection_name].find()
65-
else:
66-
cursor = db[collection_name].find(query_dict)
80+
cursor = db[collection_name].find(filter=query_dict or {}, limit=max_docs)
6781

82+
documents = []
6883
for item in cursor:
69-
if "text" not in item:
70-
raise ValueError("`text` field not found in Mongo document.")
71-
documents.append(Document(text=item["text"]))
84+
try:
85+
texts = [item[name] for name in field_names]
86+
except KeyError as err:
87+
raise ValueError(
88+
f"{err.args[0]} field not found in Mongo document."
89+
) from err
90+
91+
texts = self._flatten(texts)
92+
text = separator.join(texts)
93+
94+
if metadata_names is None:
95+
documents.append(Document(text=text))
96+
else:
97+
try:
98+
metadata = {name: item[name] for name in metadata_names}
99+
except KeyError as err:
100+
raise ValueError(
101+
f"{err.args[0]} field not found in Mongo document."
102+
) from err
103+
documents.append(Document(text=text, metadata=metadata))
72104
return documents

0 commit comments

Comments
 (0)