Retriever cache #129
Conversation
with local_cache() as cache:
    logger.info(f"Storing index to {cache.directory}")
-   cache.add(cache_key, serialized_obj)
+   cache.set(cache_key, serialized_obj)
cache.set() always overwrites the existing cache entry, while cache.add() will not overwrite if the cache_key is already present. In testing on my local machine, I was getting cache misses for indexes whose keys were validly cached but whose cached value was None (not totally sure why). When a None cache value is present, it results in a cache miss, and cache.add() then refuses to update the entry (since the key already exists), so the bad cache value persists and causes cache misses into the future. cache.set() resolved these issues for me by always setting the cache value, and I'm now getting cache hits with local caching.
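To make the add()/set() distinction concrete, here is a minimal in-memory sketch mimicking diskcache-style semantics (ToyCache is a hypothetical stand-in, not the real diskcache.Cache):

```python
class ToyCache:
    """In-memory stand-in for a diskcache-style cache."""

    def __init__(self):
        self._store = {}

    def add(self, key, value):
        # Insert only if the key is absent; returns False otherwise.
        if key in self._store:
            return False
        self._store[key] = value
        return True

    def set(self, key, value):
        # Always overwrite the existing entry.
        self._store[key] = value
        return True


cache = ToyCache()
cache.set("idx", None)                       # a bad None entry gets cached
assert cache.add("idx", b"good") is False    # add() refuses: key already exists
assert cache._store["idx"] is None           # stale None persists -> future misses
cache.set("idx", b"good")                    # set() repairs the entry
assert cache._store["idx"] == b"good"
```

Under these semantics, only set() can replace a poisoned entry, which matches the behavior described above.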
if params["reranker_enabled"]:
    params.update(**self.reranker.sample(trial))
else:
    params["reranker_enabled"] = False
Here and elsewhere: it's more general to query defaults for the value of this parameter rather than hard-coding a value.
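A hedged sketch of the suggestion (the defaults source and parameter handling are illustrative, not the actual syftr API):

```python
DEFAULTS = {"reranker_enabled": False}  # hypothetical defaults source

def sample_params(trial_params, defaults=DEFAULTS):
    """Fall back to defaults for reranker_enabled instead of hard-coding False."""
    params = dict(trial_params)
    params.setdefault("reranker_enabled", defaults["reranker_enabled"])
    return params

assert sample_params({})["reranker_enabled"] is False
assert sample_params({"reranker_enabled": True})["reranker_enabled"] is True
```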
    params.update(**self.lats_rag_agent.sample(trial))
else:
    params.update(**self.lats_rag_agent.defaults())
    params.update(**self.lats_rag_agent.sample(trial))
I think this line was a mistake: it was always sampling the lats agent, even when it should have been using defaults.
    "few_shot_enabled", self.few_shot_enabled
)
params["few_shot_enabled"] = few_shot_enabled
params.update(**self.few_shot_retriever.sample(trial))
I think we were always sampling few_shot_enabled and few_shot_retriever, even if few-shot wasn't in the parameters. Updated the code to sample few_shot only when it is in the parameters.
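A hedged sketch of the fix described (names are illustrative, not the PR's actual code): only sample the few-shot parameters when they are part of the search space.

```python
def sample(space, trial):
    """space: parameter names in the search space; trial: name -> sampled value."""
    params = {}
    if "few_shot_enabled" in space:          # previously sampled unconditionally
        enabled = trial["few_shot_enabled"]
        params["few_shot_enabled"] = enabled
        if enabled:
            # Only sample the retriever's sub-parameters when few-shot is on.
            params.update(trial["few_shot_retriever"])
    return params

assert sample([], {}) == {}   # few-shot not in the space: nothing sampled
assert sample(
    ["few_shot_enabled"],
    {"few_shot_enabled": True, "few_shot_retriever": {"top_k": 3}},
) == {"few_shot_enabled": True, "top_k": 3}
```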
def defaults(self) -> ParamDict:
    return {
        **self._defaults(),
        **self._custom_defaults,
This method of updating defaults with custom_defaults does not work, due to the way the sampling of sub-components is implemented. For example, suppose our custom_defaults has hyde_enabled=True, hyde_llm=gpt-4o-mini. SearchSpace.sample() would pick up hyde_enabled=True from defaults, then call params.update(**self.hyde.defaults()), which would override our hyde_llm param with whatever the default hyde_llm was.
Therefore I've updated the way custom_defaults are integrated: they are now applied at the end of the sampling process to override any sampled values (see below).
I think this is not correct. The lines above only mean that custom defaults take precedence over standard defaults. If hyde is set, defaults are not used; values are sampled instead. After your change, you would override whatever was set, even sampled values, with custom defaults. That's not the original idea: defaults are used when the parameter is not specified, which is important for block optimization. Maybe some good unit tests could help clarify this.
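To make the disagreement concrete, here is a toy sketch of the two orderings under discussion (illustrative functions, not the actual SearchSpace implementation):

```python
defaults = {"hyde_enabled": True, "hyde_llm": "default-llm"}
custom_defaults = {"hyde_enabled": True, "hyde_llm": "gpt-4o-mini"}

def merge_up_front(sub_defaults):
    # Original approach: custom_defaults override defaults up front,
    # but a later params.update(**self.hyde.defaults()) clobbers hyde_llm.
    params = {**defaults, **custom_defaults}
    params.update(sub_defaults)
    return params

def override_at_end(sampled, custom):
    # PR's approach: apply custom_defaults after sampling. This fixes the
    # clobbering above but also overwrites genuinely sampled values,
    # which is the objection raised in this reply.
    sampled.update(custom)
    return sampled

# The clobbering problem the PR author describes:
assert merge_up_front({"hyde_llm": "default-llm"})["hyde_llm"] == "default-llm"
# The overwrite-of-sampled-values problem this reply describes:
assert override_at_end({"hyde_llm": "sampled-llm"}, custom_defaults)["hyde_llm"] == "gpt-4o-mini"
```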
# Retrieval cache constants and key builder
RETRIEVAL_CACHE_PREFIX = "retrieval_cache"
RETRIEVER_CACHE_VERSION = 1
Lots of code in this file is a near-duplicate of syftr/retrievers/storage.py. It would be good to refactor both into a generalized caching mechanism covering both embeddings and retrieved documents.
    few_shot_examples=examples,
)

def retrieve(self, query: str) -> T.List[NodeWithScore]:
retrieve() and aretrieve() are now inherited from RetrieverFlow.
    self, query: str, invocation_id: str
) -> T.Tuple[CompletionResponse, float]:
    start_time = time.perf_counter()
    response = await self.query_engine.aquery(query)
The LlamaIndex code for RetrieverQueryEngine's aquery() is:
@dispatcher.span
async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
"""Answer a query."""
with self.callback_manager.event(
CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
) as query_event:
nodes = await self.aretrieve(query_bundle)
response = await self._response_synthesizer.asynthesize(
query=query_bundle,
nodes=nodes,
)
query_event.on_end(payload={EventPayload.RESPONSE: response})
return response
So I believe my refactoring is identical, minus the callbacks.
For LlamaIndex's TransformQueryEngine (e.g. HyDE), _aquery calls the query_transform first, then the query_engine with an updated QueryBundle:
async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
"""Answer a query."""
query_bundle = self._query_transform.run(
query_bundle, metadata=self._transform_metadata
)
return await self._query_engine.aquery(query_bundle)
The same thing happens for retrieve/aretrieve():
def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
query_bundle = self._query_transform.run(
query_bundle, metadata=self._transform_metadata
)
return self._query_engine.retrieve(query_bundle)
Therefore, in this PR's code, I believe the correct transformations are still being applied when we call aretrieve() followed by asynthesize().
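The equivalence argued above can be sketched with toy stand-ins (not the actual LlamaIndex or syftr classes):

```python
import asyncio

class ToyQueryEngine:
    async def aretrieve(self, query):
        return [("doc-1", 0.9), ("doc-2", 0.7)]   # stand-in for NodeWithScore

    async def asynthesize(self, query, nodes):
        return f"answer from {len(nodes)} nodes"

    async def aquery(self, query):
        # What RetrieverQueryEngine._aquery does internally (minus callbacks):
        nodes = await self.aretrieve(query)
        return await self.asynthesize(query, nodes)

async def refactored(engine, query):
    # The PR's decomposition: retrieve (cacheable step), then synthesize.
    nodes = await engine.aretrieve(query)
    return await engine.asynthesize(query, nodes)

engine = ToyQueryEngine()
assert asyncio.run(engine.aquery("q")) == asyncio.run(refactored(engine, "q"))
```

Splitting the call this way is what lets a cache sit between retrieval and synthesis.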
syftr/flows.py (Outdated)
) -> T.Tuple[CompletionResponse, float]:
    start_time = time.perf_counter()
-   response = self.query_engine.query(query)
+   nodes, _ = self.retrieve(query)
This retrieve()/aretrieve() call will be fulfilled by the RetrieverFlow's retrieve/aretrieve() method, which includes caching.
Pull Request Overview
This PR introduces a caching layer for retriever queries by incorporating a unique fingerprint (retriever_cache_fingerprint) into various flows and study configurations, and by integrating caching mechanisms at the local filesystem, Ray, and Amazon S3 levels. Key changes include:
- Integrating cache fingerprint generation in flow building (e.g. qa_tuner.py and flows.py).
- Introducing and using a new cached retriever module (cached_retriever.py) along with the Ray cache actor.
- Updating study defaults and related configuration to support the new caching mechanism.
Reviewed Changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| syftr/tuner/qa_tuner.py | Adds integration of retriever cache fingerprint into flow construction. |
| syftr/studies.py | Renames and updates default parameters handling for search spaces. |
| syftr/retrievers/storage.py | Updates caching API calls and incorporates Ray cache updates. |
| syftr/retrievers/cached_retriever.py | Implements new caching logic for retriever responses. |
| syftr/ray/utils.py | Introduces Ray cache actor and associated helper functions. |
| syftr/optimization.py | Replaces the deprecated update_defaults with direct custom_defaults update. |
| syftr/logger.py | Disables logger propagation to avoid duplicate logs with Ray. |
| syftr/flows.py | Refactors RetrieverFlow and RAGFlow to incorporate caching operations. |
| syftr/configuration.py | Adds a new configuration path for retrieval cache. |
| studies/*.yaml | Updates study configurations to leverage the new caching layer. |
Comments suppressed due to low confidence (2)
syftr/retrievers/cached_retriever.py:83
- Replace the curly quotes in the docstring with standard straight quotes to ensure consistency and avoid potential issues in environments that do not support Unicode quotes.
"""Mirror to both diskcache & S3 under “retrieval_cache/{cache_key}.pkl”."""
syftr/optimization.py:302
- The update of defaults now directly uses 'custom_defaults.update'. Ensure that all consumers of the study configuration are updated accordingly to reflect the removal of the old 'update_defaults' method.
study_config.search_space.custom_defaults.update(defaults)
assert self.retriever_cache_fingerprint, (
    "Retriever fingerprint must be set to use cache."
)
start_time = time.perf_counter()
It does not seem that we use duration to measure retrieval time, do we? It will also become even less useful, since it will be skewed by cache hits. Perhaps we can remove it?
Retriever-only studies do measure and report the retrieval time. I agree that this will be altered by cache hits, but I think it's still a useful number to track.
How about adding the retrieval time to the cache and then reporting the cached duration?
This is an interesting idea, but I have concerns around switching compute platforms. For example, if we are caching to S3, retriever times on different clusters could be quite different, leading us to pull cached durations that aren't applicable to the current compute environment.
An alternative would be to disable/disallow the combination of latency optimization and retriever caching.
I've implemented caching of retrieval time and llm_call_data in addition to the retrieval response. This should take care of timing discrepancies related to cache hits.
How about adding a test for the cache?
Let's discuss the changes regarding defaults first.
syftr/flows.py (Outdated)
else:
    logger.debug(f"Retriever cache miss: {query}")
    qb = QueryBundle(query)
    retrieval_result = self.query_engine.retrieve(qb)
If we called self.retrieve instead, we'd cache the timing info along with the retrieval result and wouldn't need to report the cache-lookup duration, which could be wildly different if the retriever uses query decomposition / HyDE.
It would also better encapsulate the retrieval logic in the retrieve method.
I've implemented the strategy we discussed around refactoring and appending llm_call_data etc.
@shackmann implemented the
Force-pushed e557867 to 251ccc2, then 251ccc2 to 7f6e126.
invocation_id = uuid4().hex
self._llm_call_data[invocation_id] = []
-   response, duration = await self._agenerate(query, invocation_id)
+   response, duration, retrieval_call_data = await self._agenerate(
I am wondering if it makes sense to introduce a dataclass and use it for the resulting object instead of tuples. It would simplify the code and the type annotations a bit, here and everywhere.
I don't think it is in the scope of this PR, though.
I had wondered the exact same thing, but also thought it was likely outside the scope of this PR.
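For reference, a sketch of the dataclass idea (hypothetical names, not part of the PR):

```python
from dataclasses import dataclass
from typing import Any, List

@dataclass
class GenerationResult:
    """Named result replacing the (response, duration, retrieval_call_data) tuple."""
    response: Any
    duration: float
    retrieval_call_data: List[Any]

res = GenerationResult(response="ok", duration=0.12, retrieval_call_data=[])
assert res.duration == 0.12   # attribute access replaces positional unpacking
```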
This PR adds a caching layer for the retriever: {query+retriever-params} --> retrieved_nodes.
Additionally a Ray cache actor is introduced which performs caching at the ray level, in addition to caching at the local filesystem and Amazon S3 layers.
The PR also contains a variety of other updates, which I've tried to explain in the comments below. The biggest change is that RAGFlow now inherits from RetrieverFlow and, instead of directly calling the query()/aquery() methods, calls retrieve()/aretrieve() followed by synthesize()/asynthesize(). This change enables us to do query caching across all flow types.
Given the scope of this change, it deserves adequate testing. I have run retriever-only and normal studies, which don't raise errors. But I have not done more detailed tests - such as verifying the exact set of returned documents is the same.
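As a hedged illustration of the {query + retriever-params} mapping, a fingerprint-based key builder could look like the following (the PR's actual key builder in cached_retriever.py may differ; only the two constants are taken from the diff):

```python
import hashlib
import json

RETRIEVAL_CACHE_PREFIX = "retrieval_cache"   # constants shown in the diff
RETRIEVER_CACHE_VERSION = 1

def build_cache_key(query: str, retriever_params: dict) -> str:
    # Fingerprint the retriever configuration deterministically.
    fingerprint = hashlib.sha256(
        json.dumps(retriever_params, sort_keys=True).encode()
    ).hexdigest()[:16]
    query_hash = hashlib.sha256(query.encode()).hexdigest()[:16]
    return f"{RETRIEVAL_CACHE_PREFIX}/v{RETRIEVER_CACHE_VERSION}/{fingerprint}/{query_hash}"

k1 = build_cache_key("what is syftr?", {"top_k": 5, "hyde_enabled": True})
k2 = build_cache_key("what is syftr?", {"top_k": 5, "hyde_enabled": True})
k3 = build_cache_key("what is syftr?", {"top_k": 10, "hyde_enabled": True})
assert k1 == k2 and k1 != k3   # same params -> same key; changed params -> new key
```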