Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,98 @@ class CleanupRetains(using Context) extends TypeMap:
else this(parent)
case _ => mapOver(tp)

/** Marks a poly-fn literal's tpt as explicit so Setup runs
* `transformExplicitType` on it (preserving user-written `@retains`)
* instead of `mapInferred` (which strips them). Required for soundness:
* without it, `[C^] => (x: File^{C}) => x` would have stored result
* type `File^{}`, and applying with an impure arg would yield a
* value falsely typed as pure.
*
* Restricted to poly-fn literals; ordinary Function1 lambdas need
* `mapInferred`'s capture-set-variable inference.
*/
object Explicify {

/** True if `rhs` is — or transitively wraps — a poly-fn literal.
* Wrappers are non-poly closures whose body is itself a closure,
* e.g. `(i: Int) => { [C^] => ... }` or `[A] => zs => [C^] => ...`.
*/
def isPolyFunLiteralRhs(rhs: Tree)(using Context): Boolean = rhs match
case closureDef(dd) =>
dd.symbol.info.isInstanceOf[PolyType] || isPolyFunLiteralRhs(dd.rhs)
case _ => false

/** Walk `tp`'s lambda chain, leaving params and binder bounds alone
* and sanitizing only the innermost result.
*/
def explicify(tp: Type)(using Context): Type = tp match
case defn.PolyFunctionOf(mt: MethodOrPoly) =>
val mt1 = explicify(mt).asInstanceOf[MethodOrPoly]
if mt1 eq mt then tp else defn.PolyFunctionOf(mt1)
case mt: MethodOrPoly =>
mt.derivedLambdaType(resType = explicify(mt.resType))
case tp @ AppliedType(tycon, args) if defn.isFunctionType(tp) =>
val res1 = explicify(args.last)
if res1 eq args.last then tp else AppliedType(tycon, args.init :+ res1)
case _ =>
sanitizeLeaf(tp)

/** Drop typer placeholder arguments from `@retains` annotations (e.g.
* `retain[TypeBounds(...)]` for unresolved inference) while keeping
* user-written capability references. Mixed args like
* `retain[x | TypeBounds(...)]` reduce to `retain[x]`.
*/
private def sanitizeLeaf(tp: Type)(using Context): Type =
val tm = new TypeMap:
def apply(tp: Type): Type = tp match
case AnnotatedType(parent, ann: RetainingAnnotation) =>
val args = ann.argumentTypes
val args1 = args.mapConserve(filterValidRetainArg)
if args1 eq args then mapOver(tp)
else AnnotatedType(this(parent), RetainingAnnotation(ann.symbol.asClass, args1*))
case _ => mapOver(tp)
tm(tp)

/** Replace non-capability sub-parts with `Nothing` (the empty-capture
* marker); `OrType` collapses to its valid branches.
*/
private def filterValidRetainArg(tp: Type)(using Context): Type = tp match
case _: (TermRef | TypeRef | TypeParamRef | ThisType | SkolemType) => tp
case tp @ AnnotatedType(parent, ann) =>
tp.derivedAnnotatedType(filterValidRetainArg(parent), ann)
case tp: OrType =>
tp.derivedOrType(filterValidRetainArg(tp.tp1), filterValidRetainArg(tp.tp2))
case _ => defn.NothingType

/** Flip an inferred TypeTree to non-inferred, with the explicified type. */
def explicifyTpt(tpt: Tree)(using Context): Tree = tpt match
case tpt: TypeTree if tpt.isInferred =>
tpd.TypeTree(explicify(tpt.tpe), inferred = false).withSpan(tpt.span)
case _ => tpt

/** Apply `explicifyTpt` to each `$anonfun` DefDef's result tpt along
* the closure chain so curried inner-layer params (e.g.
* `(ys: List[File^{C}])`) survive Setup.
*/
def explicifyClosureChain(rhs: Tree)(using Context): Tree = rhs match
case Block((dd: DefDef) :: Nil, closure: Closure) if dd.symbol == closure.meth.symbol =>
cpy.Block(rhs)(
cpy.DefDef(dd)(tpt = explicifyTpt(dd.tpt), rhs = explicifyClosureChain(dd.rhs)) :: Nil,
closure)
case Block(Nil, expr) =>
cpy.Block(rhs)(Nil, explicifyClosureChain(expr))
case _ => rhs

/** PostTyper entry point: explicify the val/def's tpt and its closure
* chain when `rhs` is a poly-fn literal.
*/
def maybeExplicifyChain(tpt: Tree, rhs: Tree)(using Context): (Tree, Tree) = tpt match
case tpt: TypeTree if Feature.ccEnabled && tpt.isInferred && isPolyFunLiteralRhs(rhs) =>
(explicifyTpt(tpt), explicifyClosureChain(rhs))
case _ => (tpt, rhs)

}

/** A base class for extractors that match annotated types with a specific
* Capability annotation.
*/
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -582,15 +582,17 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
annotateExperimentalCompanion(tree.symbol)
registerIfHasMacroAnnotations(tree)
Checking.checkPolyFunctionType(tree.tpt)
val tree1 = cpy.ValDef(tree)(tpt = makeOverrideTypeDeclared(tree.symbol, tree.tpt))
val (tpt1, rhs1) = Explicify.maybeExplicifyChain(makeOverrideTypeDeclared(tree.symbol, tree.tpt), tree.rhs)
val tree1 = cpy.ValDef(tree)(tpt = tpt1, rhs = rhs1)
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
checkStableSelection(tree.rhs)
processValOrDefDef(super.transform(tree1))
case tree: DefDef =>
registerIfHasMacroAnnotations(tree)
Checking.checkPolyFunctionType(tree.tpt)
annotateContextResults(tree)
val tree1 = cpy.DefDef(tree)(tpt = makeOverrideTypeDeclared(tree.symbol, tree.tpt))
val (tpt1, rhs1) = Explicify.maybeExplicifyChain(makeOverrideTypeDeclared(tree.symbol, tree.tpt), tree.rhs)
val tree1 = cpy.DefDef(tree)(tpt = tpt1, rhs = rhs1)
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
case tree: TypeDef =>
registerIfHasMacroAnnotations(tree)
Expand Down
24 changes: 24 additions & 0 deletions tests/neg-custom-args/captures/i25830-nicolas-lambda.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import language.experimental.captureChecking
import caps.*

// Soundness regression: for nicolas1-shape poly-fn lambdas with capset
// binder `B^`, applying with `B := {seed}` must propagate `^{seed}` into
// the result. Assigning the result to a strict pure-bound (`Rand -> Int`)
// must be rejected.

trait Rand extends SharedCapability:
def range(min: Int, max: Int): Int

val pickFirst =
[A, B^] => (head: Rand ->{B} A, tail: Rand ->{B} A) => head

val oneOf =
[A, B^] => (head: Rand ->{B} A, tail: Seq[Rand ->{B} A]) =>
val all: Seq[Rand ->{B} A] = head +: tail
all.head

def check =
val seed: Rand = ???
val f: Rand ->{seed} Int = (r: Rand) => r.range(0, 10)
val r2: Rand -> Int = pickFirst[Int, {seed}](f, f) // error
val r4: Rand -> Int = oneOf[Int, {seed}](f, Seq(f, f)) // error
16 changes: 16 additions & 0 deletions tests/neg-custom-args/captures/i25830-soundness.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import language.experimental.captureChecking
import caps.*

class File extends SharedCapability

// Soundness regression test: applying an identity poly-fn must not
// erase the captures of its argument. If the lambda's stored result
// type were scrubbed to `^{}`, an impure value could be claimed pure
// and leaked past a strict capture-set bound.

object Test:
val id = [C^] => (x: File^{C}) => x

def check(): Unit =
val a = File()
val r: File^{} = id[{a}](a) // error
18 changes: 18 additions & 0 deletions tests/pos-custom-args/captures/i25830-apply-workaround.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import language.experimental.captureChecking
import caps.*

class File extends SharedCapability

def test() =
val external = File()
class Convert:
def apply[C^, D^ <: {C}, E^ >: {C} <: {C, external}](
xs: List[File^{C, external}],
ys: List[File^{D, external}])(
zs: List[File^{E, external}]): List[File^{E, external}] = zs
val x = File()
val files1: List[File^{x, external}] = List(x)
val files2: List[File^{x, external}] = List(x)
val files3: List[File^{x, external}] = List(x)
val _ : List[File^{x, external}] =
Convert()[{x}, {x}, {x, external}](files1, files2)(files3)
27 changes: 27 additions & 0 deletions tests/pos-custom-args/captures/i25830-bounded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import language.experimental.captureChecking
import caps.*

class File extends SharedCapability

def testFlat() =
val f = { [C^, D^ <: {C}] => (xs: List[File^{D}]) => xs }
val a = File()
val _ : List[File^{a}] = f[{a}, {a}](List[File^{a}](a))

def testLowerBound() =
val f = { [C^, D^ >: {C}] => (xs: List[File^{D}]) => xs }
val a = File()
val _ : List[File^{a}] = f[{a}, {a}](List[File^{a}](a))

def testCurriedBounded() =
val f =
{ [C^, D^ <: {C}, E^ >: {C} <: {C, D}] =>
(xs: List[File^{D}], ys: List[File^{C}]) =>
(zs: List[File^{E}], ws: List[File^{C, D}]) => ()
}
val a = File()
val xs: List[File^{a}] = List(a)
val ys: List[File^{a}] = List(a)
val zs: List[File^{a}] = List(a)
val ws: List[File^{a}] = List(a)
val _ : Unit = f[{a}, {a}, {a}](xs, ys)(zs, ws)
58 changes: 58 additions & 0 deletions tests/pos-custom-args/captures/i25830-external.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import language.experimental.captureChecking
import caps.*

class File extends SharedCapability

// Capture-polymorphic lambdas whose retains mention enclosing capabilities
// in addition to (or instead of) the lambda's own capset binders. Because
// PostTyper now keeps the user-written parameter types and binder bounds
// verbatim ("explicify"), Setup processes them through transformExplicitType
// which preserves @retains. Only the inferred result type is cleaned up.

def mixedExternal() =
val external = File()
val f =
{ [C^] => (xs: List[File^{C, external}]) => xs }

def externalOnly() =
val external = File()
val f =
{ [C^] => (xs: List[File^{external}]) => xs }

def mixedExternalWithLowerBoundedParam() =
val external = File()
val f =
{ [C^, D^ >: {C}] => (xs: List[File^{D, external}]) => xs }

def mixedExternalInLaterParamList() =
val external = File()
val f =
{ [C^] => (xs: List[File^{C}]) => (ys: List[File^{C, external}]) => ys }

def enclosingParam(external: File^) =
val f =
{ [C^] => (xs: List[File^{C}]) => (ys: List[File^{external}]) => ys }

def supportedDef() =
val external = File()
def f =
{ [C^] => (xs: List[File^{C, external}]) => xs }

def insideAnonymousFunction() =
List(File()).map: external =>
val f =
{ [C^] => (xs: List[File^{C}]) => (ys: List[File^{external}]) => ys }

def externalInBound() =
val external = File()
val f =
{ [C^, D^ <: {C, external}] => (xs: List[File^{D}]) => xs }

def nestedCapsetBinders() =
val f =
{ [C^] => (xs: List[File^{C}]) => [D^] => (ys: List[File^{C, D}]) => ys }

def literalNestedInFunction1() =
val external = File()
val f =
(i: Int) => { [C^] => (xs: List[File^{C, external}]) => xs }
24 changes: 24 additions & 0 deletions tests/pos-custom-args/captures/i25830-nicolas-lambda.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import language.experimental.captureChecking
import caps.*

// Lambda forms of the nicolas1 patterns: capset binder `B^` parameterising
// regular function types `Rand ->{B} A`. Verifies the explicify pipeline
// keeps `^{B}` flowing through to the result so the lambda can be applied
// with `B := {seed}` and used at the strict bound `Rand ->{seed} Int`.

trait Rand extends SharedCapability:
def range(min: Int, max: Int): Int

val pickFirst =
[A, B^] => (head: Rand ->{B} A, tail: Rand ->{B} A) => head

val oneOf =
[A, B^] => (head: Rand ->{B} A, tail: Seq[Rand ->{B} A]) =>
val all: Seq[Rand ->{B} A] = head +: tail
all.head

def use =
val seed: Rand = ???
val f: Rand ->{seed} Int = (r: Rand) => r.range(0, 10)
val r1: Rand ->{seed} Int = pickFirst[Int, {seed}](f, f)
val r3: Rand ->{seed} Int = oneOf[Int, {seed}](f, Seq(f, f))
63 changes: 63 additions & 0 deletions tests/pos-custom-args/captures/i25830.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import language.experimental.captureChecking
import caps.*

class File extends SharedCapability

@main def test =
val convert = { [C^] => (xs: List[File^{C}]) => xs.map(_ => ()) }
val x = File()
val files: List[File^{x}] = List(x)
val result = convert[{x}](files)

val convertCurried =
{ [C^] => (xs: List[File^{C}]) => (ys: List[File^{C}]) =>
xs.map(_ => ()) ++ ys.map(_ => ())
}
val resultCurried = convertCurried[{x}](files)(files)

def convertDef =
{ [C^] => (xs: List[File^{C}]) => xs.map(_ => ()) }
val resultDef = convertDef[{x}](files)

val resultInAnonymousFunction =
files.map: file =>
val localFiles: List[File^{file}] = List(file)
val localConvert =
{ [C^] => (xs: List[File^{C}]) => xs }
localConvert[{file}](localFiles)

// Poly-fn literal nested inside a Function1: fine as long as retains
// only mention the literal's own capset binders.
val nestedInFunction1 = (i: Int) => { [C^] => (xs: List[File^{C}]) => xs }
val resultNested = nestedInFunction1(0)[{x}](files)

// Capset binders interleaved with regular type binders.
val interleaved1 = { [C^, A] => (xs: List[A]) => (ys: List[File^{C}]) => ys }
val resultInterleaved1 = interleaved1[{x}, Int](List(1))(files)

val interleaved2 = { [A, C^] => (xs: List[A]) => (ys: List[File^{C}]) => ys }
val resultInterleaved2 = interleaved2[Int, {x}](List(1))(files)

val interleaved3 =
{ [A, C^, B, D^] => (xs: List[A], ys: List[B]) =>
(zs: List[File^{C}], ws: List[File^{D}]) => zs
}
val resultInterleaved3 = interleaved3[Int, {x}, String, {x}](List(1), List("a"))(files, files)

// Multiple capset binder blocks separated by term-parameter lists.
val multi1 =
{ [C^] => (xs: List[File^{C}]) => [D^] => (ys: List[File^{D}]) => (xs, ys) }
val resultMulti1 = multi1[{x}](files)[{x}](files)

val multi2 =
{ [C^] => (xs: List[File^{C}]) => [A] => (zs: List[A]) => [D^] => (ws: List[File^{D}]) => (xs, zs, ws) }
val resultMulti2 = multi2[{x}](files)[Int](List(1))[{x}](files)

// Non-capset block first, then capset block.
val multi3 = { [A] => (zs: List[A]) => [C^] => (xs: List[File^{C}]) => (zs, xs) }
val resultMulti3 = multi3[Int](List(1))[{x}](files)

// Inner block references both capset binders.
val multi4 =
{ [C^] => (xs: List[File^{C}]) => [D^] => (ys: List[File^{C, D}]) => ys }
val resultMulti4 = multi4[{x}](files)[{x}](files)
Loading