Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aryn connectors for reading and writing docsets #1147

Merged
merged 6 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/ArynReader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
import os
from dataclasses import dataclass
from typing import Optional, Any

import requests
from requests import Response

from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data import Document, Element
from sycamore.data.element import create_element
from sycamore.utils.aryn_config import ArynConfig


@dataclass
class ArynClientParams(BaseDBReader.ClientParams):
    """Connection settings used when reading from an Aryn instance.

    Args:
        aryn_url: Base URL of the Aryn service.
        api_key: API key used to authenticate requests.
        **kwargs: Extra options, stored unchanged on ``self.kwargs``.

    Raises:
        ValueError: If ``aryn_url`` or ``api_key`` is ``None`` or empty.
    """

    # NOTE(review): @dataclass derives __eq__/__repr__ from annotated class
    # fields only, and none are declared here -- confirm the decorator is
    # still wanted alongside this hand-written __init__.
    def __init__(self, aryn_url: str, api_key: str, **kwargs):
        # Raise explicitly instead of asserting: asserts are removed under
        # ``python -O``, and an ``is not None`` test accepts the empty string
        # that the config lookup yields when nothing is configured.
        if not aryn_url:
            raise ValueError("Aryn URL is required")
        if not api_key:
            raise ValueError("API key is required")
        self.aryn_url = aryn_url
        self.api_key = api_key
        self.kwargs = kwargs


@dataclass
class ArynQueryParams(BaseDBReader.QueryParams):
    """Query settings that identify which Aryn docset to read."""

    def __init__(self, docset_id: str):
        # The id is stored exactly as given; no validation happens here.
        self.docset_id = docset_id


class ArynQueryResponse(BaseDBReader.QueryResponse):
    """Raw docset records fetched from Aryn, convertible to ``Document``s."""

    def __init__(self, docs: list[dict[str, Any]]):
        self.docs = docs

    def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
        """Convert every stored record dict into a ``Document``.

        ``query_params`` is accepted but not used by this implementation.
        """
        # NOTE(review): a collaborator suggested that these lists should some
        # day become iterators/generators (all the way to the ray dataset
        # construction); that applies to all readers and is deliberately not
        # done here.

        def _build(record: dict[str, Any]) -> Document:
            # Grab the raw element dicts first, then rebuild them as Element
            # objects on the freshly constructed document.
            raw_elements = record.get("elements", [])
            document = Document(**record)
            document.data["elements"] = [create_element(**raw) for raw in raw_elements]
            return document

        return [_build(record) for record in self.docs]


class ArynClient(BaseDBReader.Client):
    """HTTP client that reads the documents of an Aryn docset."""

    def __init__(self, client_params: ArynClientParams, **kwargs):
        self.aryn_url = client_params.aryn_url
        self.api_key = client_params.api_key
        self.kwargs = kwargs

    def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryResponse":
        """Stream a docset from Aryn and collect its records.

        Sends a POST to ``{aryn_url}/docsets/{docset_id}/read`` with the API
        key as a bearer token and parses each non-empty response line as one
        JSON document.

        Args:
            query_params: Must be an ``ArynQueryParams`` naming the docset.

        Returns:
            An ``ArynQueryResponse`` wrapping the parsed record dicts.

        Raises:
            requests.HTTPError: If the service does not answer with HTTP 200.
        """
        assert isinstance(query_params, ArynQueryParams)
        headers = {"Authorization": f"Bearer {self.api_key}"}
        url = f"{self.aryn_url}/docsets/{query_params.docset_id}/read"

        # The context manager releases the streamed connection even when
        # parsing fails part-way through.
        with requests.post(url, stream=True, headers=headers) as response:
            # Explicit raise rather than ``assert``: asserts are stripped
            # under ``python -O`` and carry no diagnostic information.
            if response.status_code != 200:
                raise requests.HTTPError(
                    f"Reading docset {query_params.docset_id} failed with HTTP {response.status_code}",
                    response=response,
                )
            print(f"Reading from docset: {query_params.docset_id}")
            # iter_lines() can yield empty keep-alive lines; json.loads would
            # fail on those, so they are skipped.
            docs = [json.loads(line) for line in response.iter_lines() if line]

        return ArynQueryResponse(docs)

    def check_target_presence(self, query_params: "BaseDBReader.QueryParams") -> bool:
        """Always report the target as present; no request is made."""
        return True

    @classmethod
    def from_client_params(cls, params: "BaseDBReader.ClientParams") -> "ArynClient":
        """Build a client from ``ArynClientParams``."""
        assert isinstance(params, ArynClientParams)
        return cls(params)


class ArynReader(BaseDBReader):
    """Reader that wires the Aryn client, params and record types together."""

    # Parameter types first, then the client and the record type it produces.
    ClientParams = ArynClientParams
    QueryParams = ArynQueryParams
    Client = ArynClient
    Record = ArynQueryResponse
79 changes: 79 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/ArynWriter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os
from dataclasses import dataclass
from typing import Optional, Mapping

import requests

from sycamore.connectors.base_writer import BaseDBWriter
from sycamore.data import Document
from sycamore.utils.aryn_config import ArynConfig


@dataclass
class ArynWriterClientParams(BaseDBWriter.ClientParams):
    """Connection settings used when writing to an Aryn instance.

    Args:
        aryn_url: Base URL of the Aryn service.
        api_key: API key used to authenticate requests.
        **kwargs: Extra options, stored unchanged on ``self.kwargs``.

    Raises:
        ValueError: If ``aryn_url`` or ``api_key`` is ``None`` or empty.
    """

    # NOTE(review): @dataclass derives __eq__/__repr__ from annotated class
    # fields only, and none are declared here -- confirm the decorator is
    # still wanted alongside this hand-written __init__.
    def __init__(self, aryn_url: str, api_key: str, **kwargs):
        # Raise explicitly instead of asserting: asserts are removed under
        # ``python -O``, and an ``is not None`` test accepts an empty string.
        if not aryn_url:
            raise ValueError("Aryn URL is required")
        if not api_key:
            raise ValueError("API key is required")
        self.aryn_url = aryn_url
        self.api_key = api_key
        self.kwargs = kwargs


@dataclass
class ArynWriterTargetParams(BaseDBWriter.TargetParams):
    """Target settings naming the Aryn docset that receives documents."""

    def __init__(self, docset_id: Optional[str] = None):
        # ``None`` is allowed and stored as-is.
        self.docset_id = docset_id

    def compatible_with(self, other: "BaseDBWriter.TargetParams") -> bool:
        """Report compatibility with ``other``; this is always ``True``."""
        return True


class ArynWriterRecord(BaseDBWriter.Record):
    """Writer record that simply wraps one ``Document``."""

    def __init__(self, doc: Document):
        self.doc = doc

    @classmethod
    def from_doc(cls, document: Document, target_params: "BaseDBWriter.TargetParams") -> "ArynWriterRecord":
        """Wrap ``document`` in a record; ``target_params`` is not used."""
        record = cls(document)
        return record


class ArynWriterClient(BaseDBWriter.Client):
    """HTTP client that uploads documents to an Aryn docset."""

    def __init__(self, client_params: ArynWriterClientParams, **kwargs):
        self.aryn_url = client_params.aryn_url
        self.api_key = client_params.api_key
        self.kwargs = kwargs

    @classmethod
    def from_client_params(cls, params: "BaseDBWriter.ClientParams") -> "BaseDBWriter.Client":
        """Build a client from ``ArynWriterClientParams``."""
        assert isinstance(params, ArynWriterClientParams)
        return cls(params)

    def write_many_records(self, records: list["BaseDBWriter.Record"], target_params: "BaseDBWriter.TargetParams"):
        """Upload each record's document to the target docset.

        Every document is serialized and sent in its own POST to
        ``{aryn_url}/docsets/write`` with the docset id as a query parameter
        and the API key as a bearer token.

        Args:
            records: ``ArynWriterRecord`` instances to upload.
            target_params: Must be an ``ArynWriterTargetParams``.

        Raises:
            requests.HTTPError: If the service answers a write with a 4xx/5xx
                status. Previously such failures were silently ignored.
        """
        assert isinstance(target_params, ArynWriterTargetParams)

        # Request pieces that do not change between documents.
        url = f"{self.aryn_url}/docsets/write"
        params = {"docset_id": target_params.docset_id}
        headers = {"Authorization": f"Bearer {self.api_key}"}

        for record in records:
            assert isinstance(record, ArynWriterRecord)
            files: Mapping = {"doc": record.doc.serialize()}
            response = requests.post(url=url, params=params, files=files, headers=headers)
            # Surface rejected writes instead of dropping documents silently.
            response.raise_for_status()

    def create_target_idempotent(self, target_params: "BaseDBWriter.TargetParams"):
        """Do nothing; this client does not create the target here."""
        pass

    def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams"):
        """Do nothing and return ``None``; no lookup is performed."""
        pass


class ArynWriter(BaseDBWriter):
    """Writer that wires the Aryn client, params and record types together."""

    # Parameter types first, then the client and the record type it consumes.
    ClientParams = ArynWriterClientParams
    TargetParams = ArynWriterTargetParams
    Client = ArynWriterClient
    Record = ArynWriterRecord
26 changes: 26 additions & 0 deletions lib/sycamore/sycamore/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from pyarrow.filesystem import FileSystem

from sycamore.connectors.doc_reconstruct import DocumentReconstructor
from sycamore.connectors.aryn.ArynReader import ArynClientParams, ArynQueryParams
from sycamore.context import context_params
from sycamore.plan_nodes import Node
from sycamore import Context, DocSet
from sycamore.data import Document
from sycamore.connectors.file import ArrowScan, BinaryScan, DocScan, PandasScan, JsonScan, JsonDocumentScan
from sycamore.connectors.file.file_scan import FileMetadataProvider
from sycamore.utils.aryn_config import ArynConfig
from sycamore.utils.import_utils import requires_modules


Expand Down Expand Up @@ -632,3 +634,27 @@ def qdrant(self, client_params: dict, query_params: dict, **kwargs) -> DocSet:
**kwargs,
)
return DocSet(self._context, wr)

def aryn(
    self, docset_id: str, aryn_api_key: Optional[str] = None, aryn_url: Optional[str] = None, **kwargs
) -> DocSet:
    """
    Reads the contents of an Aryn docset into a DocSet.

    Args:
        docset_id: The ID of the Aryn docset to read from.
        aryn_api_key: (Optional) The Aryn API key to use for authentication. If not provided, it is
            looked up through ArynConfig.
        aryn_url: (Optional) The URL of the Aryn instance to read from. If not provided, it is looked
            up through ArynConfig.
        kwargs: Keyword arguments to pass to the underlying execution engine.

    Raises:
        ValueError: If no API key or URL is given and none can be resolved from the configuration.
    """
    from sycamore.connectors.aryn.ArynReader import (
        ArynReader,
        ArynClientParams,
        ArynQueryParams,
    )

    if aryn_api_key is None:
        aryn_api_key = ArynConfig.get_aryn_api_key()
    if aryn_url is None:
        aryn_url = ArynConfig.get_aryn_url()

    # The config lookups fall back to "" when nothing is configured; fail here
    # with a clear message instead of letting the HTTP call fail obscurely.
    if not aryn_api_key:
        raise ValueError("An Aryn API key is required; pass aryn_api_key or configure one")
    if not aryn_url:
        raise ValueError("An Aryn URL is required; pass aryn_url or configure one")

    client_params = ArynClientParams(aryn_url, aryn_api_key)
    query_params = ArynQueryParams(docset_id)
    dr = ArynReader(client_params=client_params, query_params=query_params, **kwargs)
    return DocSet(self._context, dr)
8 changes: 8 additions & 0 deletions lib/sycamore/sycamore/utils/aryn_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def get_aryn_api_key(cls, config_path: str = "") -> str:

return cls._get_aryn_config(config_path).get("aryn_token", "")

@classmethod
def get_aryn_url(cls, config_path: str = "") -> str:
    """Return the Aryn URL from ``ARYN_URL``, else from the config's ``aryn_url``.

    A non-empty ``ARYN_URL`` environment variable wins. Otherwise the value of
    the ``aryn_url`` config key is returned, or ``""`` when the key is absent.
    """
    # ``or`` short-circuits, so the config is only consulted when the
    # environment variable is missing or empty.
    return os.environ.get("ARYN_URL") or cls._get_aryn_config(config_path).get("aryn_url", "")

@classmethod
def _get_aryn_config(cls, config_path: str = "") -> Dict[Any, Any]:
config_path = config_path or os.environ.get("ARYN_CONFIG") or _DEFAULT_PATH
Expand Down
50 changes: 50 additions & 0 deletions lib/sycamore/sycamore/writer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import logging
from typing import Any, Callable, Optional, Union, TYPE_CHECKING

import requests
from pyarrow.fs import FileSystem

from sycamore.connectors.aryn.ArynReader import ArynClientParams
from sycamore.connectors.aryn.ArynWriter import ArynWriterTargetParams
from sycamore.context import Context, ExecMode, context_params
from sycamore.connectors.common import HostAndPort
from sycamore.connectors.file.file_writer import default_doc_to_bytes, default_filename, FileWriter, JsonWriter
from sycamore.data import Document
from sycamore.executor import Execution
from sycamore.plan_nodes import Node
from sycamore.docset import DocSet
from sycamore.utils.aryn_config import ArynConfig
from sycamore.utils.import_utils import requires_modules

from mypy_boto3_s3.client import S3Client
Expand Down Expand Up @@ -800,6 +804,52 @@ def json(

self._maybe_execute(node, True)

def aryn(
    self,
    docset_id: Optional[str] = None,
    name: Optional[str] = None,
    aryn_api_key: Optional[str] = None,
    aryn_url: Optional[str] = None,
    **kwargs,
) -> Optional["DocSet"]:
    """
    Writes all documents of a DocSet to Aryn.

    Args:
        docset_id: The id of the docset to write to. If not provided, a new docset named ``name`` is
            created and written to.
        name: The name of the new docset to create. Required when docset_id is not provided; ignored
            when docset_id is provided.
        aryn_api_key: The api key to use for authentication. If not provided, it is looked up through
            ArynConfig.
        aryn_url: The url of the Aryn instance to write to. If not provided, it is looked up through
            ArynConfig.
        kwargs: Keyword arguments to pass to the underlying ArynWriter node.

    Raises:
        ValueError: If neither docset_id nor name is provided, or if no API key or URL can be resolved.
        requests.HTTPError: If creating the new docset is rejected by the service.
    """

    from sycamore.connectors.aryn.ArynWriter import (
        ArynWriter,
        ArynWriterClientParams,
        ArynWriterTargetParams,
    )

    if docset_id is None and name is None:
        raise ValueError("Either docset_id or name must be provided")

    if aryn_api_key is None:
        aryn_api_key = ArynConfig.get_aryn_api_key()
    if aryn_url is None:
        aryn_url = ArynConfig.get_aryn_url()

    # The config lookups fall back to "" when nothing is configured; fail here
    # with a clear message instead of letting the HTTP call fail obscurely.
    if not aryn_api_key:
        raise ValueError("An Aryn API key is required; pass aryn_api_key or configure one")
    if not aryn_url:
        raise ValueError("An Aryn URL is required; pass aryn_url or configure one")

    if docset_id is None:
        # ``name`` is known to be set here because of the check above.
        headers = {"Authorization": f"Bearer {aryn_api_key}"}
        res = requests.post(url=f"{aryn_url}/docsets", data={"name": name}, headers=headers)
        # Check the status before parsing so a rejected request is reported
        # as an HTTP error rather than a JSON/KeyError on the error body.
        res.raise_for_status()
        docset_id = res.json()["docset_id"]

    client_params = ArynWriterClientParams(aryn_url, aryn_api_key)
    target_params = ArynWriterTargetParams(docset_id)
    ds = ArynWriter(self.plan, client_params=client_params, target_params=target_params, **kwargs)

    return self._maybe_execute(ds, True)

def _maybe_execute(self, node: Node, execute: bool) -> Optional[DocSet]:
ds = DocSet(self.context, node)
if not execute:
Expand Down
Loading