Skip to content

Commit

Permalink
Aryn connectors for reading and writing docsets (#1147)
Browse files Browse the repository at this point in the history
* Add Aryn reader and writer

* Add new reader and writer for Aryn

* Remove old tests

* Address reviewer comments

* Fix mypy

* Fix lint
  • Loading branch information
austintlee authored Feb 4, 2025
1 parent 5dd3d7e commit 24dcec6
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 0 deletions.
79 changes: 79 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/ArynReader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
from dataclasses import dataclass
from typing import Any

import requests
from requests import Response

from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data import Document
from sycamore.data.element import create_element


@dataclass
class ArynClientParams(BaseDBReader.ClientParams):
def __init__(self, aryn_url: str, api_key: str, **kwargs):
self.aryn_url = aryn_url
assert self.aryn_url is not None, "Aryn URL is required"
self.api_key = api_key
assert self.api_key is not None, "API key is required"
self.kwargs = kwargs


@dataclass
class ArynQueryParams(BaseDBReader.QueryParams):
def __init__(self, docset_id: str):
self.docset_id = docset_id


class ArynQueryResponse(BaseDBReader.QueryResponse):
def __init__(self, docs: list[dict[str, Any]]):
self.docs = docs

def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
docs = []
for doc in self.docs:
elements = doc.get("elements", [])
_doc = Document(**doc)
_doc.data["elements"] = [create_element(**element) for element in elements]
docs.append(_doc)

return docs


class ArynClient(BaseDBReader.Client):
def __init__(self, client_params: ArynClientParams, **kwargs):
self.aryn_url = client_params.aryn_url
self.api_key = client_params.api_key
self.kwargs = kwargs

def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryResponse":
assert isinstance(query_params, ArynQueryParams)
headers = {"Authorization": f"Bearer {self.api_key}"}
response: Response = requests.post(
f"{self.aryn_url}/docsets/{query_params.docset_id}/read", stream=True, headers=headers
)
assert response.status_code == 200
docs = []
print(f"Reading from docset: {query_params.docset_id}")
for chunk in response.iter_lines():
# print(f"\n{chunk}\n")
doc = json.loads(chunk)
docs.append(doc)

return ArynQueryResponse(docs)

def check_target_presence(self, query_params: "BaseDBReader.QueryParams") -> bool:
return True

@classmethod
def from_client_params(cls, params: "BaseDBReader.ClientParams") -> "ArynClient":
assert isinstance(params, ArynClientParams)
return cls(params)


class ArynReader(BaseDBReader):
Client = ArynClient
Record = ArynQueryResponse
ClientParams = ArynClientParams
QueryParams = ArynQueryParams
74 changes: 74 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/ArynWriter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from dataclasses import dataclass
from typing import Optional, Mapping

import requests

from sycamore.connectors.base_writer import BaseDBWriter
from sycamore.data import Document


@dataclass
class ArynWriterClientParams(BaseDBWriter.ClientParams):
def __init__(self, aryn_url: str, api_key: str, **kwargs):
self.aryn_url = aryn_url
assert self.aryn_url is not None, "Aryn URL is required"
self.api_key = api_key
assert self.api_key is not None, "API key is required"
self.kwargs = kwargs


@dataclass
class ArynWriterTargetParams(BaseDBWriter.TargetParams):
def __init__(self, docset_id: Optional[str] = None):
self.docset_id = docset_id

def compatible_with(self, other: "BaseDBWriter.TargetParams") -> bool:
return True


class ArynWriterRecord(BaseDBWriter.Record):
def __init__(self, doc: Document):
self.doc = doc

@classmethod
def from_doc(cls, document: Document, target_params: "BaseDBWriter.TargetParams") -> "ArynWriterRecord":
return cls(document)


class ArynWriterClient(BaseDBWriter.Client):
def __init__(self, client_params: ArynWriterClientParams, **kwargs):
self.aryn_url = client_params.aryn_url
self.api_key = client_params.api_key
self.kwargs = kwargs

@classmethod
def from_client_params(cls, params: "BaseDBWriter.ClientParams") -> "BaseDBWriter.Client":
assert isinstance(params, ArynWriterClientParams)
return cls(params)

def write_many_records(self, records: list["BaseDBWriter.Record"], target_params: "BaseDBWriter.TargetParams"):
assert isinstance(target_params, ArynWriterTargetParams)
docset_id = target_params.docset_id

headers = {"Authorization": f"Bearer {self.api_key}"}

for record in records:
assert isinstance(record, ArynWriterRecord)
doc = record.doc
files: Mapping = {"doc": doc.serialize()}
requests.post(
url=f"{self.aryn_url}/docsets/write", params={"docset_id": docset_id}, files=files, headers=headers
)

def create_target_idempotent(self, target_params: "BaseDBWriter.TargetParams"):
pass

def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams"):
pass


class ArynWriter(BaseDBWriter):
Client = ArynWriterClient
Record = ArynWriterRecord
ClientParams = ArynWriterClientParams
TargetParams = ArynWriterTargetParams
29 changes: 29 additions & 0 deletions lib/sycamore/sycamore/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sycamore.data import Document
from sycamore.connectors.file import ArrowScan, BinaryScan, DocScan, PandasScan, JsonScan, JsonDocumentScan
from sycamore.connectors.file.file_scan import FileMetadataProvider
from sycamore.utils.aryn_config import ArynConfig
from sycamore.utils.import_utils import requires_modules


Expand Down Expand Up @@ -632,3 +633,31 @@ def qdrant(self, client_params: dict, query_params: dict, **kwargs) -> DocSet:
**kwargs,
)
return DocSet(self._context, wr)

def aryn(
self, docset_id: str, aryn_api_key: Optional[str] = None, aryn_url: Optional[str] = None, **kwargs
) -> DocSet:
"""
Reads the contents of an Aryn docset into a DocSet.
Args:
docset_id: The ID of the Aryn docset to read from.
aryn_api_key: (Optional) The Aryn API key to use for authentication.
aryn_url: (Optional) The URL of the Aryn instance to read from.
kwargs: Keyword arguments to pass to the underlying execution engine.
"""
from sycamore.connectors.aryn.ArynReader import (
ArynReader,
ArynClientParams,
ArynQueryParams,
)

if aryn_api_key is None:
aryn_api_key = ArynConfig.get_aryn_api_key()
if aryn_url is None:
aryn_url = ArynConfig.get_aryn_url()

dr = ArynReader(
client_params=ArynClientParams(aryn_url, aryn_api_key), query_params=ArynQueryParams(docset_id), **kwargs
)
return DocSet(self._context, dr)
Empty file.
Empty file.
8 changes: 8 additions & 0 deletions lib/sycamore/sycamore/utils/aryn_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def get_aryn_api_key(cls, config_path: str = "") -> str:

return cls._get_aryn_config(config_path).get("aryn_token", "")

@classmethod
def get_aryn_url(cls, config_path: str = "") -> str:
aryn_url = os.environ.get("ARYN_URL")
if aryn_url:
return aryn_url

return cls._get_aryn_config(config_path).get("aryn_url", "")

@classmethod
def _get_aryn_config(cls, config_path: str = "") -> Dict[Any, Any]:
config_path = config_path or os.environ.get("ARYN_CONFIG") or _DEFAULT_PATH
Expand Down
49 changes: 49 additions & 0 deletions lib/sycamore/sycamore/writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Any, Callable, Optional, Union, TYPE_CHECKING

import requests
from pyarrow.fs import FileSystem

from sycamore.context import Context, ExecMode, context_params
Expand All @@ -10,6 +11,7 @@
from sycamore.executor import Execution
from sycamore.plan_nodes import Node
from sycamore.docset import DocSet
from sycamore.utils.aryn_config import ArynConfig
from sycamore.utils.import_utils import requires_modules

from mypy_boto3_s3.client import S3Client
Expand Down Expand Up @@ -800,6 +802,53 @@ def json(

self._maybe_execute(node, True)

def aryn(
self,
docset_id: Optional[str] = None,
name: Optional[str] = None,
aryn_api_key: Optional[str] = None,
aryn_url: Optional[str] = None,
**kwargs,
) -> Optional["DocSet"]:
"""
Writes all documents of a DocSet to Aryn.
Args:
docset_id: The id of the docset to write to. If not provided, a new docset will be created.
create_new_docset: If true, a new docset will be created. If false, the docset with the provided
id will be used.
name: The name of the new docset to create. Required if create_new_docset is true.
aryn_api_key: The api key to use for authentication. If not provided, the api key from the config
file will be used.
aryn_url: The url of the Aryn instance to write to. If not provided, the url from the config file
will be used.
"""

from sycamore.connectors.aryn.ArynWriter import (
ArynWriter,
ArynWriterClientParams,
ArynWriterTargetParams,
)

if aryn_api_key is None:
aryn_api_key = ArynConfig.get_aryn_api_key()
if aryn_url is None:
aryn_url = ArynConfig.get_aryn_url()

if docset_id is None and name is None:
raise ValueError("Either docset_id or name must be provided")

if docset_id is None and name is not None:
headers = {"Authorization": f"Bearer {aryn_api_key}"}
res = requests.post(url=f"{aryn_url}/docsets", data={"name": name}, headers=headers)
docset_id = res.json()["docset_id"]

client_params = ArynWriterClientParams(aryn_url, aryn_api_key)
target_params = ArynWriterTargetParams(docset_id)
ds = ArynWriter(self.plan, client_params=client_params, target_params=target_params, **kwargs)

return self._maybe_execute(ds, True)

def _maybe_execute(self, node: Node, execute: bool) -> Optional[DocSet]:
ds = DocSet(self.context, node)
if not execute:
Expand Down

0 comments on commit 24dcec6

Please sign in to comment.