Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/graphrag-cache/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
92 changes: 92 additions & 0 deletions packages/graphrag-cache/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# GraphRAG Cache

## Basic

```python
import asyncio
from graphrag_storage import StorageConfig, create_storage, StorageType
from graphrag_cache import CacheConfig, create_cache, CacheType

async def run():
cache = create_cache(
CacheConfig(
type=CacheType.Json,
storage=StorageConfig(
type=StorageType.File,
base_dir="cache"
)
),
)

await cache.set("my_key", {"some": "object to cache"})
print(await cache.get("my_key"))

if __name__ == "__main__":
asyncio.run(run())
```

## Custom Cache

```python
import asyncio
from typing import Any
from graphrag_storage import Storage
from graphrag_cache import Cache, CacheConfig, create_cache, register_cache

class MyCache(Cache):
def __init__(self, some_setting: str, optional_setting: str = "default setting", **kwargs: Any):
# Validate settings and initialize
# View the JsonCache implementation to see how to create a cache that relies on a Storage provider.
...

# Implement the rest of the Cache interface
...

register_cache("MyCache", MyCache)

async def run():
cache = create_cache(
CacheConfig(
type="MyCache",
some_setting="My Setting"
)
)

# Or use the factory directly to instantiate with a dict instead of using
# CacheConfig + create_cache
# from graphrag_cache.cache_factory import cache_factory
# cache = cache_factory.create(strategy="MyCache", init_args={"some_setting": "My Setting"})

await cache.set("my_key", {"some": "object to cache"})
print(await cache.get("my_key"))

if __name__ == "__main__":
asyncio.run(run())
```

### Details

By default, `create_cache` comes with the following cache providers registered, which correspond to the entries in the `CacheType` enum.

- `JsonCache`
- `MemoryCache`
- `NoopCache`

The preregistration happens dynamically, e.g., `JsonCache` is only imported and registered the first time you request it with `create_cache(CacheConfig(type=CacheType.Json, ...))`. There is no need to manually import and register builtin cache providers when using `create_cache`.

If you want a clean factory with no preregistered cache providers then directly import `cache_factory` and bypass using `create_cache`. The downside is that `cache_factory.create` uses a dict for init args instead of the strongly typed `CacheConfig` used with `create_cache`.

```python
from graphrag_cache.cache_factory import cache_factory
from graphrag_cache.json_cache import JsonCache

# cache_factory has no preregistered providers so you must register any
# providers you plan on using.
# May also register a custom implementation, see above for example.
cache_factory.register("my_cache_impl", JsonCache)

cache = cache_factory.create(strategy="my_cache_impl", init_args={"some_setting": "..."})

...

```
17 changes: 17 additions & 0 deletions packages/graphrag-cache/graphrag_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The GraphRAG Cache package."""

from graphrag_cache.cache import Cache
from graphrag_cache.cache_config import CacheConfig
from graphrag_cache.cache_factory import create_cache, register_cache
from graphrag_cache.cache_type import CacheType

# Names re-exported as the public API of the graphrag_cache package.
__all__ = [
    "Cache",
    "CacheConfig",
    "CacheType",
    "create_cache",
    "register_cache",
]
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'PipelineCache' model."""
"""Abstract base class for cache."""

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from typing import Any


class PipelineCache(metaclass=ABCMeta):
class Cache(ABC):
"""Provide a cache interface for the pipeline."""

@abstractmethod
def __init__(self, **kwargs: Any) -> None:
"""Create a cache instance."""

@abstractmethod
async def get(self, key: str) -> Any:
"""Get the value for the given key.
Expand Down Expand Up @@ -59,7 +63,7 @@ async def clear(self) -> None:
"""Clear the cache."""

@abstractmethod
def child(self, name: str) -> PipelineCache:
def child(self, name: str) -> Cache:
"""Create a child cache with the given name.

Args:
Expand Down
26 changes: 26 additions & 0 deletions packages/graphrag-cache/graphrag_cache/cache_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Cache configuration model."""

from graphrag_storage import StorageConfig
from pydantic import BaseModel, ConfigDict, Field

from graphrag_cache.cache_type import CacheType


class CacheConfig(BaseModel):
    """The configuration section for cache.

    Extra fields are permitted so that settings for custom cache
    implementations can be supplied alongside the builtin fields.
    """

    model_config = ConfigDict(extra="allow")
    """Allow extra fields to support custom cache implementations."""

    # Typed as ``str`` rather than ``CacheType`` so that the names of custom
    # registered caches can be used in addition to the builtin values.
    # The description lists the actual string values of the CacheType members.
    type: str = Field(
        description="The cache type to use. Builtin types include 'json', 'memory', and 'none'.",
        default=CacheType.Json,
    )

    storage: StorageConfig | None = Field(
        description="The storage configuration to use for file-based caches such as 'json'.",
        default=None,
    )
83 changes: 83 additions & 0 deletions packages/graphrag-cache/graphrag_cache/cache_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License


"""Cache factory implementation."""

from collections.abc import Callable

from graphrag_common.factory import Factory, ServiceScope
from graphrag_storage import Storage

from graphrag_cache.cache import Cache
from graphrag_cache.cache_config import CacheConfig
from graphrag_cache.cache_type import CacheType


class CacheFactory(Factory[Cache]):
    """A factory class for cache implementations.

    Specializes the generic ``Factory`` for ``Cache`` objects; it adds no
    behavior of its own.
    """


# Module-level factory instance shared by register_cache and create_cache.
cache_factory = CacheFactory()


def register_cache(
    cache_type: str,
    cache_initializer: Callable[..., Cache],
    scope: ServiceScope = "transient",
) -> None:
    """Register a custom cache implementation.

    Args
    ----
        - cache_type: str
            The cache id to register.
        - cache_initializer: Callable[..., Cache]
            The cache initializer to register.
        - scope: ServiceScope
            The service scope passed through to the factory registration.
            Defaults to "transient".
    """
    cache_factory.register(cache_type, cache_initializer, scope)


def create_cache(config: CacheConfig, storage: Storage | None = None) -> Cache:
"""Create a cache implementation based on the given configuration.

Args
----
- config: CacheConfig
The cache configuration to use.
- storage: Storage | None
The storage implementation to use for file-based caches such as 'Json'.

Returns
-------
Cache
The created cache implementation.
"""
config_model = config.model_dump()
cache_strategy = config.type

if cache_strategy not in cache_factory:
match cache_strategy:
case "json":
from graphrag_cache.json_cache import JsonCache

register_cache(CacheType.Json, JsonCache)

case "memory":
from graphrag_cache.memory_cache import MemoryCache

register_cache(CacheType.Memory, MemoryCache)

case "noop":
from graphrag_cache.noop_cache import NoopCache

register_cache(CacheType.Noop, NoopCache)

case _:
msg = f"CacheConfig.type '{cache_strategy}' is not registered in the CacheFactory. Registered types: {', '.join(cache_factory.keys())}."
raise ValueError(msg)

if storage:
config_model["storage"] = storage

return cache_factory.create(strategy=cache_strategy, init_args=config_model)
15 changes: 15 additions & 0 deletions packages/graphrag-cache/graphrag_cache/cache_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License


"""Builtin cache implementation types."""

from enum import StrEnum


class CacheType(StrEnum):
    """Enum for cache types.

    Each member's value is the string identifier of a builtin cache
    implementation.
    """

    Json = "json"
    Memory = "memory"
    # NOTE: the member is named Noop, but its string value is "none".
    Noop = "none"
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,35 @@
import json
from typing import Any

from graphrag_storage import Storage
from graphrag_storage import Storage, StorageConfig, create_storage

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache.cache import Cache


class JsonPipelineCache(PipelineCache):
class JsonCache(Cache):
"""File pipeline cache class definition."""

_storage: Storage
_encoding: str

def __init__(self, storage: Storage, encoding="utf-8"):
def __init__(
self,
storage: Storage | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Init method definition."""
self._storage = storage
self._encoding = encoding
if storage is None:
msg = "JsonCache requires either a Storage instance to be provided or a StorageConfig to create one."
raise ValueError(msg)
if isinstance(storage, Storage):
self._storage = storage
else:
self._storage = create_storage(StorageConfig(**storage))

async def get(self, key: str) -> str | None:
async def get(self, key: str) -> Any | None:
"""Get method definition."""
if await self.has(key):
try:
data = await self._storage.get(key, encoding=self._encoding)
data = await self._storage.get(key)
data = json.loads(data)
except UnicodeDecodeError:
await self._storage.delete(key)
Expand All @@ -44,9 +52,7 @@ async def set(self, key: str, value: Any, debug_data: dict | None = None) -> Non
if value is None:
return
data = {"result": value, **(debug_data or {})}
await self._storage.set(
key, json.dumps(data, ensure_ascii=False), encoding=self._encoding
)
await self._storage.set(key, json.dumps(data, ensure_ascii=False))

async def has(self, key: str) -> bool:
"""Has method definition."""
Expand All @@ -61,6 +67,6 @@ async def clear(self) -> None:
"""Clear method definition."""
await self._storage.clear()

def child(self, name: str) -> "JsonPipelineCache":
def child(self, name: str) -> "Cache":
"""Child method definition."""
return JsonPipelineCache(self._storage.child(name), encoding=self._encoding)
return JsonCache(storage=self._storage.child(name))
Loading
Loading