Skip to content

Commit

Permalink
feat: Add loom Support.
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Jan 11, 2025
1 parent 78031f7 commit a2a40cd
Show file tree
Hide file tree
Showing 9 changed files with 458 additions and 35 deletions.
15 changes: 12 additions & 3 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ jobs:
- name: Run tests
run: |
set -eux
./mill -ikj1 --disable-ticker __.testLocal
if [ "${{ matrix.java }}" == "21" ]; then
./mill -ikj1 --disable-ticker __.testLocal -Dcask.virtual-thread.enabled=true --add-opens java.base/java.lang=ALL-UNNAMED
else
./mill -ikj1 --disable-ticker __.testLocal
fi
test-examples:
runs-on: ubuntu-latest
Expand All @@ -45,8 +49,13 @@ jobs:
- name: Run tests
run: |
set -eux
./mill __.publishLocal
./mill -ikj1 --disable-ticker testExamples
if [ "${{ matrix.java }}" == "21" ]; then
./mill __.publishLocal
./mill -ikj1 --disable-ticker testExamples -Dcask.virtual-thread.enabled=true --add-opens java.base/java.lang=ALL-UNNAMED
else
./mill __.publishLocal
./mill -ikj1 --disable-ticker testExamples
fi
publish-sonatype:
if: github.repository == 'com-lihaoyi/cask' && contains(github.ref, 'refs/tags/')
Expand Down
105 changes: 104 additions & 1 deletion build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,114 @@ object cask extends Cross[CaskMainModule](scalaVersions) {
}
}

trait BenchmarkModule extends CrossScalaModule {
def moduleDeps = Seq(cask(crossScalaVersion))
def ivyDeps = Agg[Dep](
)
}

object benchmark extends Cross[BenchmarkModule](build.scalaVersions) with RunModule {

def waitForServer(url: String, maxAttempts: Int = 120): Boolean = {
(1 to maxAttempts).exists { attempt =>
try {
Thread.sleep(3000)
println("Checking server... Attempt " + attempt)
os.proc("curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", url)
.call(check = false)
.exitCode == 0
} catch {
case _: Throwable =>
Thread.sleep(3000)
false
}
}
}

def runBenchmark() = T.command {
if (os.proc("which", "wrk").call(check = false).exitCode != 0) {
println("Error: wrk is not installed. Please install wrk first.")
sys.exit(1)
}

val duration = "30s"
val threads = "4"
val connections = "100"
val url = "http://localhost:8080/"

println("Testing with regular threads...")

val projectRoot = T.workspace
println("projectRoot: " + projectRoot)

def runMillBackground(example: String, vt:Boolean) = {
println(s"Running $example with vt: $vt")

os.proc(
"mill",
s"$example.app[$scala213].run")
.spawn(
cwd = projectRoot,
env = Map("CASK_VIRTUAL_THREAD" -> vt.toString),
stdout = os.Inherit,
stderr = os.Inherit)
}

val regularApp = runMillBackground("example.minimalApplicationWithLoom", vt = false)

println("Waiting for regular server to start...")
if (!waitForServer(url)) {
regularApp.destroy()
println("Failed to start regular server")
sys.exit(1)
}

println("target server started, starting run benchmark with wrk")
val regularResults = os.proc("wrk",
"-t", threads,
"-c", connections,
"-d", duration,
url
).call(stderr = os.Pipe)
regularApp.destroy()

println("\nRegular Threads Results:")
println(regularResults.out.text())

Thread.sleep(1000)
println("\nTesting with virtual threads, please use Java 21+...")
val virtualApp = runMillBackground("example.minimalApplicationWithLoom", vt = true)

println("Waiting for virtual server to start...")
if (!waitForServer(url)) {
virtualApp.destroy()
println("Failed to start virtual server")
sys.exit(1)
}

println("target server started, starting run benchmark with wrk")
val virtualResults = os.proc("wrk",
"-t", threads,
"-c", connections,
"-d", duration,
url
).call(stderr = os.Pipe)
virtualApp.destroy()

println("\nVirtual Threads Results:")
println(virtualResults.out.text())
}
}

trait LocalModule extends CrossScalaModule{
override def millSourcePath = super.millSourcePath / "app"
def moduleDeps = Seq(cask(crossScalaVersion))
}


object ZincWorkerJava11Latest extends ZincWorkerModule with CoursierModule {
def jvmId = "temurin:23.0.1"
def jvmIndexVersion = "latest.release"
}

def zippedExamples = T {
val vcsState = VcsVersion.vcsState()
Expand All @@ -111,6 +213,7 @@ def zippedExamples = T {
build.example.httpMethods.millSourcePath,
build.example.minimalApplication.millSourcePath,
build.example.minimalApplication2.millSourcePath,
build.example.minimalApplicationWithLoom.millSourcePath,
build.example.redirectAbort.millSourcePath,
build.example.scalatags.millSourcePath,
build.example.staticFiles.millSourcePath,
Expand Down
17 changes: 17 additions & 0 deletions cask/src/cask/internal/ThreadBlockingHandler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package cask.internal

import io.undertow.server.{HttpHandler, HttpServerExchange}

import java.util.concurrent.Executor

/**
* A handler that dispatches the request to the given handler using the given executor.
* */
final class ThreadBlockingHandler(executor: Executor, handler: HttpHandler) extends HttpHandler {
require(executor ne null, "Executor should not be null")

def handleRequest(exchange: HttpServerExchange): Unit = {
exchange.startBlocking()
exchange.dispatch(executor, handler)
}
}
140 changes: 120 additions & 20 deletions cask/src/cask/internal/Util.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,121 @@
package cask.internal

import java.io.{InputStream, PrintWriter, StringWriter}

import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import java.io.OutputStream

import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{Executor, ExecutorService, ForkJoinPool, ThreadFactory}
import scala.annotation.switch
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Try
import scala.util.control.NonFatal

object Util {
private val lookup = MethodHandles.lookup()

import cask.util.Logger.Console.globalLogger

/**
* Create a virtual thread executor with the given executor as the scheduler.
* */
def createVirtualThreadExecutor(executor: Executor): Option[Executor] = {
(for {
factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor))
executor <- Try(createNewThreadPerTaskExecutor(factory))
} yield executor).toOption
}

/**
* Create a default cask virtual thread executor if possible.
* */
def createDefaultCaskVirtualThreadExecutor: Option[Executor] = {
for {
scheduler <- getDefaultVirtualThreadScheduler
executor <- createVirtualThreadExecutor(scheduler)
} yield executor
}

/**
* Try to get the default virtual thread scheduler, or null if not supported.
* */
def getDefaultVirtualThreadScheduler: Option[ForkJoinPool] = {
try {
val virtualThreadClass = Class.forName("java.lang.VirtualThread")
val privateLookup = MethodHandles.privateLookupIn(virtualThreadClass, lookup)
val defaultSchedulerField = privateLookup.findStaticVarHandle(virtualThreadClass, "DEFAULT_SCHEDULER", classOf[ForkJoinPool])
Option(defaultSchedulerField.get().asInstanceOf[ForkJoinPool])
} catch {
case NonFatal(e) =>
//--add-opens java.base/java.lang=ALL-UNNAMED
globalLogger.exception(e)
None
}
}

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory)
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
}
}

/**
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
def createVirtualThreadFactory(prefix: String,
executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
lookup
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
//--add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
}

def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = {
val p = Promise[T]
futures.foreach(_.foreach(p.trySuccess))
p.future
}

/**
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
def literalize(s: IndexedSeq[Char], unicode: Boolean = true) = {
val sb = new StringBuilder
sb.append('"')
Expand Down Expand Up @@ -47,29 +144,30 @@ object Util {
def transferTo(in: InputStream, out: OutputStream) = {
val buffer = new Array[Byte](8192)

while ({
in.read(buffer) match{
while ( {
in.read(buffer) match {
case -1 => false
case n =>
out.write(buffer, 0, n)
true
}
}) ()
}

def pluralize(s: String, n: Int) = {
if (n == 1) s else s + "s"
}

/**
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
def splitPath(p: String): collection.IndexedSeq[String] = {
val pLength = p.length
var i = 0
while(i < pLength && p(i) == '/') i += 1
while (i < pLength && p(i) == '/') i += 1
var segmentStart = i
val out = mutable.ArrayBuffer.empty[String]

Expand All @@ -81,7 +179,7 @@ object Util {
segmentStart = i + 1
}

while(i < pLength){
while (i < pLength) {
if (p(i) == '/') complete()
i += 1
}
Expand All @@ -96,33 +194,35 @@ object Util {
pw.flush()
trace.toString
}

def softWrap(s: String, leftOffset: Int, maxWidth: Int) = {
val oneLine = s.linesIterator.mkString(" ").split(' ')

lazy val indent = " " * leftOffset

val output = new StringBuilder(oneLine.head)
var currentLineWidth = oneLine.head.length
for(chunk <- oneLine.tail){
for (chunk <- oneLine.tail) {
val addedWidth = currentLineWidth + chunk.length + 1
if (addedWidth > maxWidth){
if (addedWidth > maxWidth) {
output.append("\n" + indent)
output.append(chunk)
currentLineWidth = chunk.length
} else{
} else {
currentLineWidth = addedWidth
output.append(' ')
output.append(chunk)
}
}
output.mkString
}

def sequenceEither[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])(
implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = {
in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) {
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
.map(_.result())
}
}
Loading

0 comments on commit a2a40cd

Please sign in to comment.