Skip to content

Commit

Permalink
Merge pull request #21 from VirtusLab/sanity
Browse files Browse the repository at this point in the history
Sane way of transforming unapply patterns and correct owner setting
  • Loading branch information
KacperFKorban authored Sep 21, 2022
2 parents aa2f7fb + 1e42865 commit c9d2dce
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 37 deletions.
70 changes: 41 additions & 29 deletions avocADO/src/main/scala/macros.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package avocado

import scala.annotation.tailrec
import scala.collection.mutable
import scala.quoted.*

private[avocado] object macros {
Expand Down Expand Up @@ -63,7 +64,7 @@ private[avocado] object macros {
.appliedToArgs(List(acc, arg))
case _ =>
val (toZip, rest, newLastBinding) = splitToZip(bindings)
val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1)), newLastBinding)
val body = go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), newLastBinding)
val arg = funFromZipped(zipped, body, Symbol.spliceOwner)
val tpes = lastBinding.typeArgs.map(_.widen)
ctx.instance
Expand All @@ -73,7 +74,7 @@ private[avocado] object macros {
}

val (toZip, rest, lastMethod) = splitToZip(bindings)
go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1)), lastMethod)
go(rest, toZip.map(b => b._1.pattern -> b._1.tpe), zipExprs(toZip.map(_._1), Symbol.spliceOwner), lastMethod)
}

private def adaptTpeForMethod(arg: Term, methodName: String): TypeRepr =
Expand All @@ -86,7 +87,7 @@ private[avocado] object macros {
case AppliedType(_, args) => args.last
}

private def zipExprs(toZip: List[Binding])(using Context): Term = {
private def zipExprs(toZip: List[Binding], owner: Symbol)(using Context): Term = {
def doZip(receiver: Term, receiverTpe: TypeRepr, arg: Term)(using Context): Term = {
val receiverTypeSymbol = receiver.tpe.typeSymbol
val argTpe = extractTypeFromApplicative(arg.tpe)
Expand All @@ -99,7 +100,7 @@ private[avocado] object macros {
toZip.init.foldRight(toZip.last.tree) {
case (binding, acc) =>
doZip(binding.tree, binding.tpe, acc)
}
}.changeOwner(owner)
}

private def splitToZip(bindings: List[(Binding, Set[Symbol])]): (List[(Binding, Set[Symbol])], List[(Binding, Set[Symbol])], Binding) = {
Expand Down Expand Up @@ -144,31 +145,42 @@ private[avocado] object macros {

private def funFromZipped(zipped: List[(Tree, TypeRepr)], body: Term, owner: Symbol): Term = {

def makeUnapplies(unaply: Tree, owner: Symbol, binds: Set[String]): (Tree, Map[Symbol, Symbol]) = unaply match {
case valdef: ValDef if valdef.name == "_" || binds.contains(valdef.name) =>
Wildcard() -> Map.empty
case Ident(name) if name == "_" =>
Wildcard() -> Map.empty
case valdef: ValDef =>
val sym = Symbol.newBind(owner, valdef.name, Flags.EmptyFlags, valdef.tpt.tpe)
Bind(sym, Wildcard()) -> Map(valdef.symbol -> sym)
case bind@Bind(name, pattern0) =>
val (pattern, renames) = makeUnapplies(pattern0, owner, binds)
val sym = Symbol.newBind(owner, name, Flags.EmptyFlags, bind.symbol.typeRef.widen)
Bind(sym, pattern) -> (renames + (bind.symbol -> sym))
case Typed(term, _) =>
val (pattern, renames) = makeUnapplies(term, owner, binds)
pattern -> renames
case unaply@Unapply(fun, implicits, patterns0) =>
val (patterns, renames, _) =
patterns0.foldRight((List.empty[Tree], Map.empty[Symbol, Symbol], binds)) {
case (p, (accList, accMap, bindsAcc)) =>
val (pattern, renames) = makeUnapplies(p, owner, bindsAcc)
((pattern :: accList), (accMap ++ renames), bindsAcc ++ renames.keys.map(_.name).toSet)
}
Unapply.copy(unaply)(fun, implicits, patterns) -> renames
case _ =>
unaply.changeOwner(owner) -> Map.empty
def makeUnapplies(unaply: Tree, owner: Symbol, binds0: Set[String]): (Tree, Map[Symbol, Symbol]) = {
val renamesRes = mutable.Map.empty[Symbol, Symbol] // sorry :/
def binds = renamesRes.values.map(_.name).toSet ++ binds0
object mapper extends TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term = tree match {
case Ident(name) if name == "_" =>
Wildcard()
case _ =>
super.transformTerm(tree)(owner)
}
override def transformTree(tree: Tree)(owner: Symbol): Tree = tree match {
case valdef: ValDef if valdef.name == "_" || binds.contains(valdef.name) =>
Wildcard()
case valdef: ValDef =>
val sym = Symbol.newBind(owner, valdef.name, Flags.EmptyFlags, valdef.tpt.tpe)
renamesRes += (valdef.symbol -> sym)
Bind(sym, Wildcard())
case bind@Bind(name, pattern0) =>
val sym = Symbol.newBind(owner, name, Flags.EmptyFlags, bind.symbol.typeRef.widen)
renamesRes += (bind.symbol -> sym)
Bind(sym, transformTree(pattern0)(owner))
case Typed(term, _) =>
transformTerm(term)(owner)
case unaply@Unapply(fun, implicits, patterns0) =>
val patterns =
patterns0.foldRight(List.empty[Tree]) {
case (p, accList) =>
val pattern = transformTree(p)(owner)
pattern :: accList
}
Unapply.copy(unaply)(fun, implicits, patterns).changeOwner(owner)
case _ =>
super.transformTree(tree)(owner).changeOwner(owner)
}
}
mapper.transformTree(unaply)(owner) -> renamesRes.to(Map)
}

def unapplies(zipped: List[(Tree, TypeRepr)], owner: Symbol): (Tree, Map[Symbol, Symbol]) = zipped match {
Expand Down
46 changes: 46 additions & 0 deletions avocADO/src/test/scala/OptionTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,50 @@ class OptionTests extends munit.FunSuite {
assertEquals(res, Some(3))
}

test("option comprehension 19") {
val res = ado {
for {
a <- Some(1)
_ <- Some(2)
(_, a) = (3, 4)
} yield a
}
assertEquals(res, Some(4))
}

test("option comprehension 20") {
val res = ado {
for {
a <- Some(1)
_ <- Some(2)
(_, (_, a)) = (3, (4, 5))
} yield a
}
assertEquals(res, Some(5))
}

test("option comprehension 21") {
case class C(i: Int, j: Int)
val res = ado {
for {
a <- Some(1)
_ <- Some(2)
C(_, a) = C(3, 4)
} yield a
}
assertEquals(res, Some(4))
}

test("option comprehension 22") {
case class C(i: Int*)
val res = ado {
for {
a <- Some(1)
_ <- Some(2)
C(_, a) = C(3, 4)
} yield a
}
assertEquals(res, Some(4))
}

}
17 changes: 9 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ val commonSettings = Seq(
url("https://twitter.com/KacperKorban")
)
),
scalacOptions ++= Seq(
"-Xcheck-macros",
"-Ycheck:inlining",
"-explain",
"-deprecation",
"-unchecked",
"-feature"
),
libraryDependencies ++= Seq(
"org.scalameta" %%% "munit" % "0.7.29" % Test
)
Expand All @@ -35,14 +43,7 @@ lazy val avocado = projectMatrix
.in(file("avocADO"))
.settings(commonSettings)
.settings(
name := "avocADO",
scalacOptions ++= Seq(
"-Xcheck-macros",
"-explain",
"-deprecation",
"-unchecked",
"-feature"
)
name := "avocADO"
)
.jvmPlatform(scalaVersions = List(scala3))

Expand Down
45 changes: 45 additions & 0 deletions cats-effect-3/src/test/scala/CatsEffectTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -488,4 +488,49 @@ class CatsEffectTests extends BaseCatsEffectTest {
}
}(5)

testWithTimeLimit("cats effect comprehension 38", 1900) {
val wait = IO.sleep(500.millis)
case class C(i: Int*)
ado {
for {
a <- wait.map(_ => 1)
b <- wait.map(_ => 2)
c <- wait.map(_ => a + 2)
d <- wait.map(_ => 4)
_ <- wait
C(i, j, k) = C(1, 2, 3)
} yield (c + b, i, j, k)
}
}((5, 1, 2, 3))

testWithTimeLimit("cats effect comprehension 39", 1900) {
val wait = IO.sleep(500.millis)
case class C(i: Int)
ado {
for {
a <- wait.map(_ => 1)
b <- wait.map(_ => 2)
c <- wait.map(_ => 3)
C(i) = C(1)
d <- wait.map(_ => 4)
_ <- wait
} yield (c, i)
}
}((3, 1))

testWithTimeLimit("cats effect comprehension 40", 1900) {
val wait = IO.sleep(500.millis)
case class C(i: Int*)
ado {
for {
a <- wait.map(_ => 1)
b <- wait.map(_ => 2)
c <- wait.map(_ => 3)
C(i, j, k) = C(1, 2, 3)
d <- wait.map(_ => 4)
_ <- wait
} yield (c + b, i, j, k)
}
}((5, 1, 2, 3))

}

0 comments on commit c9d2dce

Please sign in to comment.