
Conversation

mhauskn-dr (Contributor) commented Jun 14, 2025:

This PR adds a caching layer for the retriever: {query+retriever-params} --> retrieved_nodes.

Additionally, a Ray cache actor is introduced to cache at the Ray level, in addition to the local-filesystem and Amazon S3 cache layers.

The PR also contains a variety of other updates, which I've tried to explain in the comments below. The biggest change is that RAGFlow now inherits from RetrieverFlow and, instead of directly calling query()/aquery(), calls retrieve()/aretrieve() followed by synthesize()/asynthesize(). This change enables query caching across all flow types.

Given the scope of this change, it deserves thorough testing. I have run retriever-only and normal studies, which don't raise errors, but I have not done more detailed tests, such as verifying that the exact set of returned documents is the same.

    with local_cache() as cache:
        logger.info(f"Storing index to {cache.directory}")
        cache.add(cache_key, serialized_obj)
        cache.set(cache_key, serialized_obj)
mhauskn-dr (Contributor, Author) commented Jun 14, 2025:
cache.set() always overwrites the existing cache entry, while cache.add() will not overwrite if the cache_key is already present. In testing on my local machine, I was getting cache misses for indexes with valid cached keys whose cached value was None (I'm not totally sure why). A None cache value results in a cache miss, and cache.add() then refuses to update the entry because the key already exists, so the bad value persists and keeps causing misses into the future. Switching to cache.set() resolved this for me by always writing the value, and I'm now getting cache hits with local caching.
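To illustrate the failure mode, here is a minimal dict-based sketch (MiniCache is illustrative, not the real diskcache API, though diskcache's Cache.add/Cache.set behave analogously) of why a stale None entry survives add() but not set():

```python
class MiniCache:
    """Toy cache mimicking diskcache's add/set semantics."""

    def __init__(self):
        self._data = {}

    def add(self, key, value):
        # add() only writes when the key is absent
        if key in self._data:
            return False
        self._data[key] = value
        return True

    def set(self, key, value):
        # set() always overwrites
        self._data[key] = value
        return True

    def get(self, key):
        return self._data.get(key)


cache = MiniCache()
cache.set("index-key", None)           # a bad entry somehow got cached
assert cache.get("index-key") is None  # looks like a miss to the caller

cache.add("index-key", b"serialized")  # add() refuses: key already exists
assert cache.get("index-key") is None  # stale None persists, miss forever

cache.set("index-key", b"serialized")  # set() overwrites the bad entry
assert cache.get("index-key") == b"serialized"
```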

    if params["reranker_enabled"]:
        params.update(**self.reranker.sample(trial))
    else:
        params["reranker_enabled"] = False
mhauskn-dr (Contributor, Author) commented Jun 14, 2025:
Here and elsewhere: it's more general to query defaults() for the value of this parameter than to hard-code a value.
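A minimal sketch of the pattern being suggested (DEFAULTS and sample_reranker are illustrative names, not the real SearchSpace API): the fallback comes from a defaults mapping instead of being hard-coded at the call site.

```python
DEFAULTS = {"reranker_enabled": False}


def sample_reranker(params, trial_sample=None):
    if params.get("reranker_enabled"):
        params.update(trial_sample or {})
    else:
        # query defaults rather than hard-coding False here
        params["reranker_enabled"] = DEFAULTS["reranker_enabled"]
    return params


assert sample_reranker({})["reranker_enabled"] is False
```

If the default ever changes, only the DEFAULTS mapping needs updating, not every branch that falls back to it.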

        params.update(**self.lats_rag_agent.sample(trial))
    else:
        params.update(**self.lats_rag_agent.defaults())
        params.update(**self.lats_rag_agent.sample(trial))
mhauskn-dr (Contributor, Author):

I think this line was a mistake: it always sampled the lats agent, even when it should have used defaults.

        "few_shot_enabled", self.few_shot_enabled
    )
    params["few_shot_enabled"] = few_shot_enabled
    params.update(**self.few_shot_retriever.sample(trial))
mhauskn-dr (Contributor, Author):

I think we were always sampling few_shot_enabled and few_shot_retriever even when they weren't in the parameters. Updated the code to sample few_shot only when it is in the parameters.

    def defaults(self) -> ParamDict:
        return {
            **self._defaults(),
            **self._custom_defaults,
mhauskn-dr (Contributor, Author) commented Jun 14, 2025:
This method of updating defaults with custom_defaults does not work because of the way sub-component sampling is implemented. For example, suppose our custom_defaults has hyde_enabled=True and hyde_llm=gpt-4o-mini. SearchSpace.sample() would pick up hyde_enabled=True from defaults, then call params.update(**self.hyde.defaults()), which would override our hyde_llm param with whatever the default hyde_llm was.

Therefore I've changed how custom_defaults are integrated: they are now applied at the end of the sampling process to override any sampled values (see below).
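A hypothetical sketch of the ordering problem and the fix (sample() and the parameter names are illustrative, not the real SearchSpace code): the sub-component defaults merge last and clobber the custom value, so custom_defaults must be applied at the very end.

```python
def sample(defaults, hyde_defaults, custom_defaults):
    params = {}
    params.update(defaults)           # may pick up hyde_enabled=True
    if params.get("hyde_enabled"):
        params.update(hyde_defaults)  # would clobber a custom hyde_llm here
    params.update(custom_defaults)    # fix: custom values win in the end
    return params


params = sample(
    defaults={"hyde_enabled": True},
    hyde_defaults={"hyde_llm": "default-llm"},
    custom_defaults={"hyde_enabled": True, "hyde_llm": "gpt-4o-mini"},
)
assert params["hyde_llm"] == "gpt-4o-mini"  # custom value survives
```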

Contributor commented:

I think this is not correct. The lines above only mean that custom defaults take precedence over standard defaults. If hyde is set, defaults are not used; values are sampled instead. After your change, you would override whatever was set, even sampled values, with custom defaults. That's not the original idea: defaults are used when a parameter is not specified, which is important for block optimization. Maybe some good unit tests could help clarify this.


    # Retrieval cache constants and key builder
    RETRIEVAL_CACHE_PREFIX = "retrieval_cache"
    RETRIEVER_CACHE_VERSION = 1
mhauskn-dr (Contributor, Author) commented Jun 14, 2025:

Lots of code in this file is a near duplicate of syftr/retrievers/storage.py. It would be good to refactor both into a generalized caching mechanism covering both embeddings and retrieved documents.
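One direction for such a refactor is a shared key builder that hashes the query, the retriever params, and a version number together. A hedged sketch (build_cache_key is a hypothetical helper name; the two constants appear in the snippet above):

```python
import hashlib
import json

RETRIEVAL_CACHE_PREFIX = "retrieval_cache"
RETRIEVER_CACHE_VERSION = 1


def build_cache_key(query, retriever_params):
    # Include the version so any format change invalidates old entries,
    # and sort keys so semantically equal param dicts hash identically.
    payload = json.dumps(
        {"query": query, "params": retriever_params, "v": RETRIEVER_CACHE_VERSION},
        sort_keys=True,
    )
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    return f"{RETRIEVAL_CACHE_PREFIX}/{digest}.pkl"
```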

            few_shot_examples=examples,
        )

    def retrieve(self, query: str) -> T.List[NodeWithScore]:
mhauskn-dr (Contributor, Author):

retrieve() and aretrieve() are now inherited from RetrieverFlow.

        self, query: str, invocation_id: str
    ) -> T.Tuple[CompletionResponse, float]:
        start_time = time.perf_counter()
        response = await self.query_engine.aquery(query)
mhauskn-dr (Contributor, Author) commented Jun 16, 2025:

The LlamaIndex code for RetrieverQueryEngine._aquery() is:

    @dispatcher.span
    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = await self.aretrieve(query_bundle)

            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

So I believe my refactoring is identical, minus the callbacks.

In the case of LlamaIndex's TransformQueryEngine (e.g. for HyDE), the query transform runs first, then the query engine is called with the updated QueryBundle:

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        query_bundle = self._query_transform.run(
            query_bundle, metadata=self._transform_metadata
        )
        return await self._query_engine.aquery(query_bundle)

The same thing happens for retrieve/aretrieve():

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        query_bundle = self._query_transform.run(
            query_bundle, metadata=self._transform_metadata
        )
        return self._query_engine.retrieve(query_bundle)

Therefore, in the PR's code, I believe the correct transformations are still applied when we call aretrieve() followed by asynthesize().
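The equivalence being claimed can be sketched with a toy engine (FakeEngine is illustrative, not the LlamaIndex API): aquery() is, minus the callbacks, exactly aretrieve() followed by asynthesize(), so splitting the call preserves the response while exposing the retrieval step for caching.

```python
import asyncio


class FakeEngine:
    async def aretrieve(self, query):
        return [f"node-for-{query}"]           # the cacheable step

    async def asynthesize(self, query, nodes):
        return f"answer({query}, {nodes[0]})"

    async def aquery(self, query):
        # aquery is just retrieve + synthesize composed
        nodes = await self.aretrieve(query)
        return await self.asynthesize(query, nodes)


async def check():
    engine = FakeEngine()
    nodes = await engine.aretrieve("q")        # cache lookup would happen here
    split = await engine.asynthesize("q", nodes)
    whole = await engine.aquery("q")
    assert split == whole                      # identical responses


asyncio.run(check())
```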

syftr/flows.py Outdated
    ) -> T.Tuple[CompletionResponse, float]:
        start_time = time.perf_counter()
        response = self.query_engine.query(query)
        nodes, _ = self.retrieve(query)
mhauskn-dr (Contributor, Author):

This retrieve()/aretrieve() call is fulfilled by RetrieverFlow's retrieve()/aretrieve() method, which includes caching.

@mhauskn-dr mhauskn-dr marked this pull request as ready for review June 16, 2025 17:03
@mhauskn-dr mhauskn-dr requested a review from Copilot June 16, 2025 18:29
Copilot AI left a comment:

Pull Request Overview

This PR introduces a caching layer for retriever queries by incorporating a unique fingerprint (retriever_cache_fingerprint) into various flows and study configurations, and by integrating caching mechanisms at the local filesystem, Ray, and Amazon S3 levels. Key changes include:

  • Integrating cache fingerprint generation in flow building (e.g. qa_tuner.py and flows.py).
  • Introducing and using a new cached retriever module (cached_retriever.py) along with the Ray cache actor.
  • Updating study defaults and related configuration to support the new caching mechanism.

Reviewed Changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 1 comment.

Show a summary per file:

  • syftr/tuner/qa_tuner.py: Adds integration of retriever cache fingerprint into flow construction.
  • syftr/studies.py: Renames and updates default parameters handling for search spaces.
  • syftr/retrievers/storage.py: Updates caching API calls and incorporates Ray cache updates.
  • syftr/retrievers/cached_retriever.py: Implements new caching logic for retriever responses.
  • syftr/ray/utils.py: Introduces Ray cache actor and associated helper functions.
  • syftr/optimization.py: Replaces the deprecated update_defaults with direct custom_defaults update.
  • syftr/logger.py: Disables logger propagation to avoid duplicate logs with Ray.
  • syftr/flows.py: Refactors RetrieverFlow and RAGFlow to incorporate caching operations.
  • syftr/configuration.py: Adds a new configuration path for retrieval cache.
  • studies/*.yaml: Updates study configurations to leverage the new caching layer.
Comments suppressed due to low confidence (2)

syftr/retrievers/cached_retriever.py:83

  • Replace the curly quotes in the docstring with standard straight quotes to ensure consistency and avoid potential issues in environments that do not support Unicode quotes.
    """Mirror to both diskcache & S3 under “retrieval_cache/{cache_key}.pkl”."""

syftr/optimization.py:302

  • The update of defaults now directly uses 'custom_defaults.update'. Ensure that all consumers of the study configuration are updated accordingly to reflect the removal of the old 'update_defaults' method.
study_config.search_space.custom_defaults.update(defaults)

@mhauskn-dr mhauskn-dr requested review from alex-dr and shackmann June 16, 2025 18:32
assert self.retriever_cache_fingerprint, (
"Retriever fingerprint must be set to use cache."
)
start_time = time.perf_counter()
Contributor commented:

It does not seem that we use duration to measure retrieval time, do we?
It will also become even less useful, since it will be skewed by cache hits.
Perhaps we can remove it?

mhauskn-dr (Contributor, Author):

Retriever-only studies do measure and report retrieval time. I agree that cache hits will alter it, but I think it's still a useful number to track.

Contributor commented:

How about adding the retrieval time to the cache and then reporting the cached duration?

mhauskn-dr (Contributor, Author):

This is an interesting idea. I have concerns about switching compute platforms: if we cache to S3, retriever times on different clusters could differ substantially, so we might pull cached durations that don't apply to the current compute environment.

An alternative would be to disallow the combination of latency optimization and retriever caching.

mhauskn-dr (Contributor, Author):

I've implemented caching of retrieval time and llm_call_data in addition to the retrieval response. This should take care of timing discrepancies related to cache hits.
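A minimal sketch of that approach (store_retrieval/load_retrieval are hypothetical names, and a plain dict stands in for the real cache backends): the duration and llm_call_data are pickled alongside the nodes, so a cache hit replays the metrics recorded at insertion time instead of the near-instant lookup time.

```python
import pickle


def store_retrieval(cache, key, nodes, duration, llm_call_data):
    # Bundle the result with its original timing and call data.
    cache[key] = pickle.dumps(
        {"nodes": nodes, "duration": duration, "llm_call_data": llm_call_data}
    )


def load_retrieval(cache, key):
    # A hit returns the duration measured when the entry was created,
    # so reported retrieval times are not skewed by cache lookups.
    entry = pickle.loads(cache[key])
    return entry["nodes"], entry["duration"], entry["llm_call_data"]


cache = {}
store_retrieval(cache, "k", ["node-1"], 0.42, [{"llm": "gpt-4o-mini"}])
nodes, duration, call_data = load_retrieval(cache, "k")
assert (nodes, duration) == (["node-1"], 0.42)
```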

shackmann (Contributor):

How about adding a test for the cache?

shackmann (Contributor) left a comment:

Let's discuss the changes regarding defaults first.

syftr/flows.py Outdated
    else:
        logger.debug(f"Retriever cache miss: {query}")
        qb = QueryBundle(query)
        retrieval_result = self.query_engine.retrieve(qb)
alex-dr (Collaborator) commented Jun 18, 2025:

If we did self.retrieve instead, we'd cache the timing info along with the retrieval result and wouldn't need to report the cache-retrieval duration, which could be wildly different if the retriever uses query decomposition / HyDE.

It would also better encapsulate the retrieval logic in the retrieve method.

mhauskn-dr (Contributor, Author):

I've implemented the strategy we discussed around refactoring and appending llm_call_data etc.

mhauskn-dr (Contributor, Author):

@shackmann implemented the param_override changes we discussed. Mind taking a look?

@mhauskn-dr mhauskn-dr force-pushed the matt/retriever-cache branch from e557867 to 251ccc2 Compare June 23, 2025 15:28
@mhauskn-dr mhauskn-dr force-pushed the matt/retriever-cache branch from 251ccc2 to 7f6e126 Compare June 23, 2025 16:00
    invocation_id = uuid4().hex
    self._llm_call_data[invocation_id] = []
    response, duration = await self._agenerate(query, invocation_id)
    response, duration, retrieval_call_data = await self._agenerate(
Contributor commented:

I am wondering if it makes sense to introduce a dataclass for the resulting object instead of tuples. It would simplify the code and the type annotations a bit here and elsewhere.
I don't think it is in the scope of this PR, though.

mhauskn-dr (Contributor, Author):

I had wondered the exact same thing, but also thought it was likely outside the scope of this PR.
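For reference, the suggested change could look something like this (GenerationResult and its field names are hypothetical, mirroring the tuple shape in the diff above):

```python
import typing as T
from dataclasses import dataclass, field


@dataclass
class GenerationResult:
    # Replaces the (response, duration, retrieval_call_data) tuple,
    # giving the fields names and type annotations in one place.
    response: str
    duration: float
    retrieval_call_data: T.List[dict] = field(default_factory=list)


result = GenerationResult(response="answer", duration=1.5)
assert result.retrieval_call_data == []
```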
