From acc83c2708b9106e4b5c7d26ecb46f2b1b81595a Mon Sep 17 00:00:00 2001 From: Caio Camatta Date: Mon, 26 Feb 2024 17:05:30 -0500 Subject: [PATCH] Add getServingInfo unit tests --- .../scala/ai/chronon/online/FetcherBase.scala | 4 +- .../ai/chronon/online/FetcherBaseTest.scala | 51 +++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/FetcherBase.scala b/online/src/main/scala/ai/chronon/online/FetcherBase.scala index 45f0a52ca..84f37155f 100644 --- a/online/src/main/scala/ai/chronon/online/FetcherBase.scala +++ b/online/src/main/scala/ai/chronon/online/FetcherBase.scala @@ -222,8 +222,8 @@ class FetcherBase(kvStore: KVStore, * @param batchEndTs the new batchEndTs from the latest batch data * @param groupByServingInfo the current GroupByServingInfo */ - private def updateServingInfo(batchEndTs: Long, - groupByServingInfo: GroupByServingInfoParsed): GroupByServingInfoParsed = { + private[online] def updateServingInfo(batchEndTs: Long, + groupByServingInfo: GroupByServingInfoParsed): GroupByServingInfoParsed = { val name = groupByServingInfo.groupBy.metaData.name if (batchEndTs > groupByServingInfo.batchEndTsMillis) { logger.info(s"""$name's value's batch timestamp of $batchEndTs is diff --git a/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala b/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala index 51b599ce7..e033c4888 100644 --- a/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala +++ b/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala @@ -16,7 +16,12 @@ package ai.chronon.online +import ai.chronon.aggregator.windowing.FinalBatchIr +import ai.chronon.api.Extensions.GroupByOps +import ai.chronon.api.MetaData import ai.chronon.online.Fetcher.{ColumnSpec, Request, Response} +import ai.chronon.online.FetcherCache.BatchResponses +import ai.chronon.online.KVStore.TimedValue import org.junit.{Before, Test} import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ @@ -25,12 +30,14 @@ import org.mockito.stubbing.Answer import org.mockito.{Answers, ArgumentCaptor} import org.scalatest.matchers.should.Matchers import org.scalatestplus.mockito.MockitoSugar +import org.junit.Assert.assertSame import scala.concurrent.duration.DurationInt import scala.concurrent.{Await, ExecutionContext, Future} import scala.util.{Failure, Success} +import scala.util.Try -class FetcherBaseTest extends MockitoSugar with Matchers { +class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { val GroupBy = "relevance.short_term_user_features" val Column = "pdp_view_count_14d" val GuestKey = "guest" @@ -118,7 +125,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers { // Fetch a single query val keyMap = Map(GuestKey -> GuestId) val query = ColumnSpec(GroupBy, Column, None, Some(keyMap)) - + doAnswer(new Answer[Future[Seq[Fetcher.Response]]] { def answer(invocation: InvocationOnMock): Future[Seq[Response]] = { Future.successful(Seq()) @@ -130,7 +137,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers { queryResults.contains(query) shouldBe true queryResults.get(query).map(_.values) match { case Some(Failure(ex: IllegalStateException)) => succeed - case _ => fail() + case _ => fail() } // GroupBy request sent to KV store for the query @@ -141,4 +148,42 @@ class FetcherBaseTest extends MockitoSugar with Matchers { actualRequest.get.name shouldBe query.groupByName + "." + query.columnName actualRequest.get.keys shouldBe query.keyMapping.get } + + @Test + def test_getServingInfo_ShouldCallUpdateServingInfoIfBatchResponseIsFromKvStore(): Unit = { + val baseFetcher = new FetcherBase(mock[KVStore]) + val spiedFetcherBase = spy(baseFetcher) + val oldServingInfo = mock[GroupByServingInfoParsed] + val updatedServingInfo = mock[GroupByServingInfoParsed] + val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L))) + val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess) + doReturn(updatedServingInfo).when(spiedFetcherBase).updateServingInfo(any(), any()) + + // updateServingInfo is called + val result = spiedFetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses) + assertSame(result, updatedServingInfo) + verify(spiedFetcherBase).updateServingInfo(any(), any()) + } + + @Test + def test_getServingInfo_ShouldRefreshServingInfoIfBatchResponseIsCached(): Unit = { + val baseFetcher = new FetcherBase(mock[KVStore]) + val spiedFetcherBase = spy(baseFetcher) + val oldServingInfo = mock[GroupByServingInfoParsed] + val metaData = mock[MetaData] + val groupByOpsMock = mock[GroupByOps] + val cachedBatchResponses = BatchResponses(mock[FinalBatchIr]) + val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]] + doReturn(ttlCache).when(spiedFetcherBase).getGroupByServingInfo + doReturn(Success(oldServingInfo)).when(ttlCache).refresh(any[String]) + metaData.name = "test" + groupByOpsMock.metaData = metaData + when(oldServingInfo.groupByOps).thenReturn(groupByOpsMock) + + // FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is. + val result = spiedFetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses) + assertSame(result, oldServingInfo) + verify(ttlCache).refresh(any()) + verify(spiedFetcherBase, never()).updateServingInfo(any(), any()) + } }