From cb959ca37dd5daad0f8c9dba0c819ad7f22a1085 Mon Sep 17 00:00:00 2001 From: Spiros Tzavellas Date: Tue, 24 Jan 2023 00:23:56 +0200 Subject: [PATCH] Make sure Redis.send(..) methods are tail recursive We need to do that cause if we have a faulty Redis server that has issues with connections the reconnect-on-error functionality of those methods can result in `StackOverflowError`. We had to create a private method for each `send(..)` method cause simply adding `@tailrec` would produce the following compiler error: "could not optimize @tailrec annotated method send: it is neither private nor final so can be overridden". We also opted for adding private methods instead of defining them inside each method in order to make fewer changes and better preserve the git history. --- src/main/scala/com/redis/RedisClient.scala | 31 ++++++++---- .../scala/com/redis/RedisClientSpec.scala | 49 +++++++++++++++++-- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/src/main/scala/com/redis/RedisClient.scala b/src/main/scala/com/redis/RedisClient.scala index 9850a23..d30ca63 100644 --- a/src/main/scala/com/redis/RedisClient.scala +++ b/src/main/scala/com/redis/RedisClient.scala @@ -2,9 +2,10 @@ package com.redis import java.net.SocketException import javax.net.ssl.SSLContext - import com.redis.serialization.Format +import scala.annotation.tailrec + object RedisClient { sealed trait SortOrder case object ASC extends SortOrder @@ -35,7 +36,8 @@ abstract class Redis(batch: Mode) extends IO with Protocol { var handlers: Vector[(String, () => Any)] = Vector.empty val commandBuffer = collection.mutable.ListBuffer.empty[CommandToSend] - def send[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = try { + @tailrec + private def doSend[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = try { if (batch == BATCH) { handlers :+= ((command, () => result)) commandBuffer += CommandToSend(command, args.map(format.apply)) @@ -46,14 +48,18 @@ abstract class Redis(batch: Mode) extends IO with Protocol { } } catch { case e: RedisConnectionException => - if (disconnect) send(command, args)(result) + if (disconnect) doSend(command, args)(result) else throw e case e: SocketException => - if (disconnect) send(command, args)(result) + if (disconnect) doSend(command, args)(result) else throw e } - def send[A](command: String)(result: => A): A = try { + def send[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = + doSend(command, args)(result) + + @tailrec + private def doSend[A](command: String)(result: => A): A = try { if (batch == BATCH) { handlers :+= ((command, () => result)) commandBuffer += CommandToSend(command, Seq.empty[Array[Byte]]) @@ -64,14 +70,17 @@ abstract class Redis(batch: Mode) extends IO with Protocol { } } catch { case e: RedisConnectionException => - if (disconnect) send(command)(result) + if (disconnect) doSend(command)(result) else throw e case e: SocketException => - if (disconnect) send(command)(result) + if (disconnect) doSend(command)(result) else throw e } - def send[A](commands: List[CommandToSend])(result: => A): A = try { + def send[A](command: String)(result: => A): A = doSend(command)(result) + + @tailrec + private def doSend[A](commands: List[CommandToSend])(result: => A): A = try { val cs = commands.map { command => command.command.getBytes("UTF-8") +: command.args } @@ -79,13 +88,15 @@ abstract class Redis(batch: Mode) extends IO with Protocol { result } catch { case e: RedisConnectionException => - if (disconnect) send(commands)(result) + if (disconnect) doSend(commands)(result) else throw e case e: SocketException => - if (disconnect) send(commands)(result) + if (disconnect) doSend(commands)(result) else throw e } + def send[A](commands: List[CommandToSend])(result: => A): A = doSend(commands)(result) + def cmd(args: Seq[Array[Byte]]): Array[Byte] = Commands.multiBulk(args) protected def flattenPairs(in: Iterable[Product2[Any, Any]]): List[Any] = diff --git a/src/test/scala/com/redis/RedisClientSpec.scala b/src/test/scala/com/redis/RedisClientSpec.scala index 37694e6..4cbddb1 100644 --- a/src/test/scala/com/redis/RedisClientSpec.scala +++ b/src/test/scala/com/redis/RedisClientSpec.scala @@ -1,11 +1,13 @@ package com.redis -import java.net.{ServerSocket, URI} +import com.redis.RedisClientSpec.DummyClientWithFaultyConnection +import java.net.{ServerSocket, URI} import com.redis.api.ApiSpec import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers +import java.io.OutputStream import scala.concurrent.Await import scala.concurrent.duration._ @@ -71,7 +73,7 @@ class RedisClientSpec extends AnyFunSpec r.close() }} -// describe("test reconnect") { + describe("test reconnect") { // it("should re-init after server restart") { // val docker = new Docker(DefaultDockerClientConfig.createDefaultConfigBuilder().build()).client // @@ -104,5 +106,46 @@ class RedisClientSpec extends AnyFunSpec // // got shouldBe Some(value) // } -// } + + it("should not trigger a StackOverflowError in send(..) if Redis is down") { + val maxFailures = 10000 // Should be enough to trigger StackOverflowError + val r = new DummyClientWithFaultyConnection(maxFailures) + r.send("PING") { + /* PONG */ + } + r.connected shouldBe true + } + + } +} + +object RedisClientSpec { + + private class DummyClientWithFaultyConnection(maxFailures: Int) extends Redis(RedisClient.SINGLE) { + + private var _connected = false + private var _failures = 0 + + override val host: String = null + override val port: Int = 0 + override val timeout: Int = 0 + + override def onConnect(): Unit = () + + override def connected: Boolean = _connected + + override def disconnect: Boolean = true + + override def write_to_socket(data: Array[Byte])(op: OutputStream => Unit): Unit = () + + override def connect: Boolean = + if (_failures <= maxFailures) { + _failures += 1 + throw RedisConnectionException("fail in order to trigger the reconnect") + } else { + _connected = true + true + } + } + }