diff --git a/src/main/scala/com/redis/RedisClient.scala b/src/main/scala/com/redis/RedisClient.scala index 9850a23..d30ca63 100644 --- a/src/main/scala/com/redis/RedisClient.scala +++ b/src/main/scala/com/redis/RedisClient.scala @@ -2,9 +2,10 @@ package com.redis import java.net.SocketException import javax.net.ssl.SSLContext - import com.redis.serialization.Format +import scala.annotation.tailrec + object RedisClient { sealed trait SortOrder case object ASC extends SortOrder @@ -35,7 +36,8 @@ abstract class Redis(batch: Mode) extends IO with Protocol { var handlers: Vector[(String, () => Any)] = Vector.empty val commandBuffer = collection.mutable.ListBuffer.empty[CommandToSend] - def send[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = try { + @tailrec + private def doSend[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = try { if (batch == BATCH) { handlers :+= ((command, () => result)) commandBuffer += CommandToSend(command, args.map(format.apply)) @@ -46,14 +48,18 @@ abstract class Redis(batch: Mode) extends IO with Protocol { } } catch { case e: RedisConnectionException => - if (disconnect) send(command, args)(result) + if (disconnect) doSend(command, args)(result) else throw e case e: SocketException => - if (disconnect) send(command, args)(result) + if (disconnect) doSend(command, args)(result) else throw e } - def send[A](command: String)(result: => A): A = try { + def send[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = + doSend(command, args)(result) + + @tailrec + private def doSend[A](command: String)(result: => A): A = try { if (batch == BATCH) { handlers :+= ((command, () => result)) commandBuffer += CommandToSend(command, Seq.empty[Array[Byte]]) @@ -64,14 +70,17 @@ abstract class Redis(batch: Mode) extends IO with Protocol { } } catch { case e: RedisConnectionException => - if (disconnect) send(command)(result) + if (disconnect) doSend(command)(result) else throw e case e: SocketException => - if (disconnect) send(command)(result) + if (disconnect) doSend(command)(result) else throw e } - def send[A](commands: List[CommandToSend])(result: => A): A = try { + def send[A](command: String)(result: => A): A = doSend(command)(result) + + @tailrec + private def doSend[A](commands: List[CommandToSend])(result: => A): A = try { val cs = commands.map { command => command.command.getBytes("UTF-8") +: command.args } @@ -79,13 +88,15 @@ abstract class Redis(batch: Mode) extends IO with Protocol { result } catch { case e: RedisConnectionException => - if (disconnect) send(commands)(result) + if (disconnect) doSend(commands)(result) else throw e case e: SocketException => - if (disconnect) send(commands)(result) + if (disconnect) doSend(commands)(result) else throw e } + def send[A](commands: List[CommandToSend])(result: => A): A = doSend(commands)(result) + def cmd(args: Seq[Array[Byte]]): Array[Byte] = Commands.multiBulk(args) protected def flattenPairs(in: Iterable[Product2[Any, Any]]): List[Any] = diff --git a/src/test/scala/com/redis/RedisClientSpec.scala b/src/test/scala/com/redis/RedisClientSpec.scala index 37694e6..4cbddb1 100644 --- a/src/test/scala/com/redis/RedisClientSpec.scala +++ b/src/test/scala/com/redis/RedisClientSpec.scala @@ -1,11 +1,13 @@ package com.redis -import java.net.{ServerSocket, URI} +import com.redis.RedisClientSpec.DummyClientWithFaultyConnection +import java.net.{ServerSocket, URI} import com.redis.api.ApiSpec import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers +import java.io.OutputStream import scala.concurrent.Await import scala.concurrent.duration._ @@ -71,7 +73,7 @@ class RedisClientSpec extends AnyFunSpec r.close() }} -// describe("test reconnect") { + describe("test reconnect") { // it("should re-init after server restart") { // val docker = new Docker(DefaultDockerClientConfig.createDefaultConfigBuilder().build()).client // @@ -104,5 +106,46 @@ class RedisClientSpec extends AnyFunSpec // // got shouldBe Some(value) // } -// } + + it("should not trigger a StackOverflowError in send(..) if Redis is down") { + val maxFailures = 10000 // Should be enough to trigger StackOverflowError + val r = new DummyClientWithFaultyConnection(maxFailures) + r.send("PING") { + /* PONG */ + } + r.connected shouldBe true + } + + } +} + +object RedisClientSpec { + + private class DummyClientWithFaultyConnection(maxFailures: Int) extends Redis(RedisClient.SINGLE) { + + private var _connected = false + private var _failures = 0 + + override val host: String = null + override val port: Int = 0 + override val timeout: Int = 0 + + override def onConnect(): Unit = () + + override def connected: Boolean = _connected + + override def disconnect: Boolean = true + + override def write_to_socket(data: Array[Byte])(op: OutputStream => Unit): Unit = () + + override def connect: Boolean = + if (_failures <= maxFailures) { + _failures += 1 + throw RedisConnectionException("fail in order to trigger the reconnect") + } else { + _connected = true + true + } + } + }