diff --git a/compiler/src/dotty/tools/dotc/cc/Capability.scala b/compiler/src/dotty/tools/dotc/cc/Capability.scala index 95f8f180b339..a25a3eb461a7 100644 --- a/compiler/src/dotty/tools/dotc/cc/Capability.scala +++ b/compiler/src/dotty/tools/dotc/cc/Capability.scala @@ -473,27 +473,28 @@ object Capabilities: case info: OrType => viaInfo(info.tp1)(test) && viaInfo(info.tp2)(test) case _ => false + def trySubpath(y: TermRef): Boolean = + y.prefix.match + case ypre: Capability => + this.subsumes(ypre) + || this.match + case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol => + // To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y` + // are equvalent, which means `x =:= y` in terms of subtyping, + // not just `{x} =:= {y}` in terms of subcapturing. + // It is possible to construct two singleton types `x` and `y`, + // which subsume each other, but are not equal references. + // See `tests/neg-custom-args/captures/path-prefix.scala` for example. + withMode(Mode.IgnoreCaptures): + TypeComparer.isSameRef(xpre, ypre) + case _ => + false + case _ => false + try (this eq y) || maxSubsumes(y, canAddHidden = !vs.isOpen) || y.match - case y: TermRef => - y.prefix.match - case ypre: Capability => - this.subsumes(ypre) - || this.match - case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol => - // To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y` - // are equvalent, which means `x =:= y` in terms of subtyping, - // not just `{x} =:= {y}` in terms of subcapturing. - // It is possible to construct two singleton types `x` and `y`, - // which subsume each other, but are not equal references. - // See `tests/neg-custom-args/captures/path-prefix.scala` for example. - withMode(Mode.IgnoreCaptures): - TypeComparer.isSameRef(xpre, ypre) - case _ => - false - case _ => false - || viaInfo(y.info)(subsumingRefs(this, _)) + case y: TermRef => trySubpath(y) || viaInfo(y.info)(subsumingRefs(this, _)) case Maybe(y1) => this.stripMaybe.subsumes(y1) case ReadOnly(y1) => this.stripReadOnly.subsumes(y1) case y: TypeRef if y.derivesFrom(defn.Caps_CapSet) => @@ -507,6 +508,15 @@ object Capabilities: this.subsumes(hi) case _ => y.captureSetOfInfo.elems.forall(this.subsumes) + case Reach(y1: TermRef) => + val sym = y1.symbol + def isUseClassParam: Boolean = + sym.owner match + case classSym: ClassSymbol => + val paramSym = classSym.primaryConstructor.paramNamed(sym.name) + paramSym.isUseParam + case _ => false + isUseClassParam && trySubpath(y1) case _ => false || this.match case Reach(x1) => x1.subsumes(y.stripReach) @@ -858,4 +868,4 @@ object Capabilities: case tp1 => tp1 end toResultInResults -end Capabilities \ No newline at end of file +end Capabilities diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 3dd847f19b56..d8d2a5a039c8 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -369,6 +369,7 @@ extension (tp: Type) val tp1 = narrowCaps(tp) if narrowCaps.change then capt.println(i"narrow $tp of $ref to $tp1") + //println(i"reach refinement $tp at $ref to $tp1 (${ctx.compilationUnit})") tp1 else tp @@ -395,6 +396,9 @@ extension (tp: Type) RefinedType(tp, name, AnnotatedType(rinfo, Annotation(defn.RefineOverrideAnnot, util.Spans.NoSpan))) + def dropUseAndConsumeAnnots(using Context): Type = + tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot) + extension (tp: MethodType) /** A method marks an existential scope unless it is the prefix of a curried method */ def marksExistentialScope(using Context): Boolean = @@ -490,11 +494,12 @@ extension (sym: Symbol) def hasTrackedParts(using Context): Boolean = !CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty - /** `sym` is annotated @use or it is a type parameter with a matching + /** `sym` itself or its info is annotated @use or it is a type parameter with a matching * @use-annotated term parameter that contains `sym` in its deep capture set. */ def isUseParam(using Context): Boolean = sym.hasAnnotation(defn.UseAnnot) + || sym.info.hasAnnotation(defn.UseAnnot) || sym.is(TypeParam) && sym.owner.rawParamss.nestedExists: param => param.is(TermParam) && param.hasAnnotation(defn.UseAnnot) @@ -502,6 +507,11 @@ extension (sym: Symbol) case c: TypeRef => c.symbol == sym case _ => false + /** `sym` or its info is annotated with `@consume`. */ + def isConsumeParam(using Context): Boolean = + sym.hasAnnotation(defn.ConsumeAnnot) + || sym.info.hasAnnotation(defn.ConsumeAnnot) + def isUpdateMethod(using Context): Boolean = sym.isAllOf(Mutable | Method, butNot = Accessor) diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 1bdd7ce92129..879cd9a512c0 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -716,7 +716,7 @@ class CheckCaptures extends Recheck, SymTransformer: funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) => val param = meth.paramNamed(pname) def copyAnnot(tp: Type, cls: ClassSymbol) = param.getAnnotation(cls) match - case Some(ann) => AnnotatedType(tp, ann) + case Some(ann) if !tp.hasAnnotation(cls) => AnnotatedType(tp, ann) case _ => tp copyAnnot(copyAnnot(formal, defn.UseAnnot), defn.ConsumeAnnot) funtpe.derivedLambdaType(paramInfos = paramInfosWithUses) @@ -789,6 +789,7 @@ class CheckCaptures extends Recheck, SymTransformer: case appType @ CapturingType(appType1, refs) if qualType.exists && !tree.fun.symbol.isConstructor + && funType.paramInfos.isEmpty && qualCaptures.mightSubcapture(refs) && argCaptures.forall(_.mightSubcapture(refs)) => val callCaptures = argCaptures.foldLeft(qualCaptures)(_ ++ _) @@ -845,10 +846,14 @@ class CheckCaptures extends Recheck, SymTransformer: initCs ++ FreshCap(Origin.NewCapability(core)).readOnly.singletonCaptureSet else initCs for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do + val paramSym = cls.primaryConstructor.paramNamed(getterName) val getter = cls.info.member(getterName).suchThat(_.isRefiningParamAccessor).symbol if !getter.is(Private) && getter.hasTrackedParts then refined = refined.refinedOverride(getterName, argType.unboxed) // Yichen you might want to check this - allCaptures ++= argType.captureSet + if paramSym.isUseParam then + allCaptures ++= argType.deepCaptureSet + else + allCaptures ++= argType.captureSet (refined, allCaptures) /** Augment result type of constructor with refinements and captures. @@ -1616,7 +1621,10 @@ class CheckCaptures extends Recheck, SymTransformer: if noWiden(actual, expected) then actual else - val improvedVAR = improveCaptures(actual.widen.dealiasKeepAnnots, actual) + // Compute the widened type. Drop `@use` and `@consume` annotations from the type, + // since they obscures the capturing type. + val widened = actual.widen.dealiasKeepAnnots.dropUseAndConsumeAnnots + val improvedVAR = improveCaptures(widened, actual) val improved = improveReadOnly(improvedVAR, expected) val adapted = adaptBoxed( improved.withReachCaptures(actual), expected, tree, diff --git a/compiler/src/dotty/tools/dotc/cc/SepCheck.scala b/compiler/src/dotty/tools/dotc/cc/SepCheck.scala index 6dad0e9a2ff7..a402e58624f2 100644 --- a/compiler/src/dotty/tools/dotc/cc/SepCheck.scala +++ b/compiler/src/dotty/tools/dotc/cc/SepCheck.scala @@ -620,7 +620,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.maybeOwner.enclosingMethodOrClass) then report.error(em"""Separation failure: $descr non-local $refSym""", pos) else if refSym.is(TermParam) - && !refSym.hasAnnotation(defn.ConsumeAnnot) + && !refSym.isConsumeParam && currentOwner.isContainedIn(refSym.owner) then badParams += refSym @@ -899,7 +899,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: if !isUnsafeAssumeSeparate(tree) then trace(i"checking separate $tree"): checkUse(tree) tree match - case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.hasAnnotation(defn.ConsumeAnnot) => + case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.isConsumeParam => traverseChildren(tree) checkConsumedRefs( captures(qual).footprint(), qual.nuType, @@ -962,4 +962,4 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: consumeInLoopError(ref, pos) case _ => traverseChildren(tree) -end SepCheck \ No newline at end of file +end SepCheck diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index b06bd5c00a28..e10a5221e8e7 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4234,6 +4234,11 @@ object Types extends TypeUtils { paramType = addAnnotation(paramType, defn.InlineParamAnnot, param) if param.is(Erased) then paramType = addAnnotation(paramType, defn.ErasedParamAnnot, param) + // Copy `@use` and `@consume` annotations from parameter symbols to the type. + if param.hasAnnotation(defn.UseAnnot) then + paramType = addAnnotation(paramType, defn.UseAnnot, param) + if param.hasAnnotation(defn.ConsumeAnnot) then + paramType = addAnnotation(paramType, defn.ConsumeAnnot, param) paramType def adaptParamInfo(param: Symbol)(using Context): Type = diff --git a/scala2-library-cc/src/scala/collection/Iterable.scala b/scala2-library-cc/src/scala/collection/Iterable.scala index 6556f31d378d..1fc40a019c4d 100644 --- a/scala2-library-cc/src/scala/collection/Iterable.scala +++ b/scala2-library-cc/src/scala/collection/Iterable.scala @@ -682,9 +682,9 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable def map[B](f: A => B): CC[B]^{this, f} = iterableFactory.from(new View.Map(this, f)) - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} = iterableFactory.from(new View.FlatMap(this, f)) + def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} = iterableFactory.from(new View.FlatMap(this, f)) - def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this} = flatMap(asIterable) + def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this, asIterable*} = flatMap(asIterable) def collect[B](pf: PartialFunction[A, B]^): CC[B]^{this, pf} = iterableFactory.from(new View.Collect(this, pf)) @@ -902,7 +902,7 @@ object IterableOps { def map[B](f: A => B): CC[B]^{this, f} = self.iterableFactory.from(new View.Map(filtered, f)) - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} = + def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} = self.iterableFactory.from(new View.FlatMap(filtered, f)) def foreach[U](f: A => U): Unit = filtered.foreach(f) diff --git a/scala2-library-cc/src/scala/collection/IterableOnce.scala b/scala2-library-cc/src/scala/collection/IterableOnce.scala index 7ea62a9e1a65..803d97e963c7 100644 --- a/scala2-library-cc/src/scala/collection/IterableOnce.scala +++ b/scala2-library-cc/src/scala/collection/IterableOnce.scala @@ -246,10 +246,9 @@ final class IterableOnceExtensionMethods[A](private val it: IterableOnce[A]) ext } @deprecated("Use .iterator.flatMap instead or consider requiring an Iterable", "2.13.0") - def flatMap[B](f: A => IterableOnce[B]^): IterableOnce[B]^{f} = it match { + def flatMap[B](@caps.use f: A => IterableOnce[B]^): IterableOnce[B]^{f*} = it match case it: Iterable[A] => it.flatMap(f) case _ => it.iterator.flatMap(f) - } @deprecated("Use .iterator.sameElements instead", "2.13.0") def sameElements[B >: A](that: IterableOnce[B]): Boolean = it.iterator.sameElements(that) @@ -439,7 +438,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ => * @return a new $coll resulting from applying the given collection-valued function * `f` to each element of this $coll and concatenating the results. */ - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} + def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} /** Converts this $coll of iterable collections into * a $coll formed by the elements of these iterable diff --git a/scala2-library-cc/src/scala/collection/Iterator.scala b/scala2-library-cc/src/scala/collection/Iterator.scala index 91a22caa288c..275e0651da3d 100644 --- a/scala2-library-cc/src/scala/collection/Iterator.scala +++ b/scala2-library-cc/src/scala/collection/Iterator.scala @@ -588,8 +588,8 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite def next() = f(self.next()) } - def flatMap[B](f: A => IterableOnce[B]^): Iterator[B]^{this, f} = new AbstractIterator[B] { - private[this] var cur: Iterator[B]^{f} = Iterator.empty + def flatMap[B](@caps.use f: A => IterableOnce[B]^): Iterator[B]^{this, f*} = new AbstractIterator[B] { + private[this] var cur: Iterator[B]^{f*} = Iterator.empty /** Trillium logic boolean: -1 = unknown, 0 = false, 1 = true */ private[this] var _hasNext: Int = -1 @@ -623,7 +623,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite } } - def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this} = + def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this, ev*} = flatMap[B](ev) def concat[B >: A](xs: => IterableOnce[B]^): Iterator[B]^{this, xs} = new Iterator.ConcatIterator[B](self).concat(xs) @@ -982,7 +982,7 @@ object Iterator extends IterableFactory[Iterator] { /** Creates a target $coll from an existing source collection * * @param source Source collection - * @tparam A the type of the collection’s elements + * @tparam A the type of the collection's elements * @return a new $coll with the elements of `source` */ override def from[A](source: IterableOnce[A]^): Iterator[A]^{source} = source.iterator @@ -1003,7 +1003,7 @@ object Iterator extends IterableFactory[Iterator] { /** * @return A builder for $Coll objects. - * @tparam A the type of the ${coll}’s elements + * @tparam A the type of the ${coll}'s elements */ def newBuilder[A]: Builder[A, Iterator[A]] = new ImmutableBuilder[A, Iterator[A]](empty[A]) { diff --git a/scala2-library-cc/src/scala/collection/Map.scala b/scala2-library-cc/src/scala/collection/Map.scala index 7ba393ecd242..3734ecf5bc83 100644 --- a/scala2-library-cc/src/scala/collection/Map.scala +++ b/scala2-library-cc/src/scala/collection/Map.scala @@ -321,7 +321,7 @@ trait MapOps[K, +V, +CC[_, _] <: IterableOps[_, AnyConstr, _], +C] * @return a new $coll resulting from applying the given collection-valued function * `f` to each element of this $coll and concatenating the results. */ - def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = mapFactory.from(new View.FlatMap(this, f)) + def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = mapFactory.from(new View.FlatMap(this, f)) /** Returns a new $coll containing the elements from the left hand operand followed by the elements from the * right hand operand. The element type of the $coll is the most specific superclass encompassing @@ -383,7 +383,7 @@ object MapOps { def map[K2, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2]^{this, f} = self.mapFactory.from(new View.Map(filtered, f)) - def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2]^{this, f} = + def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2]^{this, f*} = self.mapFactory.from(new View.FlatMap(filtered, f)) override def withFilter(q: ((K, V)) => Boolean): WithFilter[K, V, IterableCC, CC]^{this, q} = diff --git a/scala2-library-cc/src/scala/collection/SortedMap.scala b/scala2-library-cc/src/scala/collection/SortedMap.scala index 876a83b2709c..546f10b452a6 100644 --- a/scala2-library-cc/src/scala/collection/SortedMap.scala +++ b/scala2-library-cc/src/scala/collection/SortedMap.scala @@ -208,7 +208,7 @@ object SortedMapOps { def map[K2 : Ordering, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2] = self.sortedMapFactory.from(new View.Map(filtered, f)) - def flatMap[K2 : Ordering, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = + def flatMap[K2 : Ordering, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = self.sortedMapFactory.from(new View.FlatMap(filtered, f)) override def withFilter(q: ((K, V)) => Boolean): WithFilter[K, V, IterableCC, MapCC, CC]^{this, q} = diff --git a/scala2-library-cc/src/scala/collection/StrictOptimizedIterableOps.scala b/scala2-library-cc/src/scala/collection/StrictOptimizedIterableOps.scala index 5b504a2469b5..4dae50afea6e 100644 --- a/scala2-library-cc/src/scala/collection/StrictOptimizedIterableOps.scala +++ b/scala2-library-cc/src/scala/collection/StrictOptimizedIterableOps.scala @@ -104,7 +104,7 @@ trait StrictOptimizedIterableOps[+A, +CC[_], +C] b.result() } - override def flatMap[B](f: A => IterableOnce[B]^): CC[B] = + override def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B] = strictOptimizedFlatMap(iterableFactory.newBuilder, f) /** diff --git a/scala2-library-cc/src/scala/collection/StrictOptimizedMapOps.scala b/scala2-library-cc/src/scala/collection/StrictOptimizedMapOps.scala index a9c5e0af43b3..a26ab590e91f 100644 --- a/scala2-library-cc/src/scala/collection/StrictOptimizedMapOps.scala +++ b/scala2-library-cc/src/scala/collection/StrictOptimizedMapOps.scala @@ -29,7 +29,7 @@ trait StrictOptimizedMapOps[K, +V, +CC[_, _] <: IterableOps[_, AnyConstr, _], +C override def map[K2, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2] = strictOptimizedMap(mapFactory.newBuilder, f) - override def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = + override def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = strictOptimizedFlatMap(mapFactory.newBuilder, f) override def concat[V2 >: V](suffix: IterableOnce[(K, V2)]^): CC[K, V2] = diff --git a/scala2-library-cc/src/scala/collection/View.scala b/scala2-library-cc/src/scala/collection/View.scala index 72a073836e77..a376895b06a7 100644 --- a/scala2-library-cc/src/scala/collection/View.scala +++ b/scala2-library-cc/src/scala/collection/View.scala @@ -309,8 +309,8 @@ object View extends IterableFactory[View] { /** A view that flatmaps elements of the underlying collection. */ @SerialVersionUID(3L) - class FlatMap[A, B](underlying: SomeIterableOps[A]^, f: A => IterableOnce[B]^) extends AbstractView[B] { - def iterator: Iterator[B]^{underlying, f} = underlying.iterator.flatMap(f) + class FlatMap[A, B](underlying: SomeIterableOps[A]^, @caps.use f: A => IterableOnce[B]^) extends AbstractView[B] { + def iterator: Iterator[B]^{underlying, f*} = underlying.iterator.flatMap(f) override def knownSize: Int = if (underlying.knownSize == 0) 0 else super.knownSize override def isEmpty: Boolean = iterator.isEmpty } diff --git a/scala2-library-cc/src/scala/collection/WithFilter.scala b/scala2-library-cc/src/scala/collection/WithFilter.scala index a2255a8cc0c5..3dfe920461c7 100644 --- a/scala2-library-cc/src/scala/collection/WithFilter.scala +++ b/scala2-library-cc/src/scala/collection/WithFilter.scala @@ -45,7 +45,7 @@ abstract class WithFilter[+A, +CC[_]] extends Serializable { * of the filtered outer $coll and * concatenating the results. */ - def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} + def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} /** Applies a function `f` to all elements of the `filtered` outer $coll. * diff --git a/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala b/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala index 726b011c6929..5ab128cb01d7 100644 --- a/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala +++ b/scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala @@ -592,7 +592,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz */ // optimisations are not for speed, but for functionality // see tickets #153, #498, #2147, and corresponding tests in run/ (as well as run/stream_flatmap_odds.scala) - override def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f} = + override def flatMap[B](@caps.use f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f*} = if (knownIsEmpty) LazyListIterable.empty else LazyListIterable.flatMapImpl(this, f) @@ -600,7 +600,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz * * $preservesLaziness */ - override def flatten[B](implicit asIterable: A -> IterableOnce[B]): LazyListIterable[B]^{this} = flatMap(asIterable) + override def flatten[B](implicit asIterable: A -> IterableOnce[B]): LazyListIterable[B]^{this, asIterable*} = flatMap(asIterable) /** @inheritdoc * @@ -1061,11 +1061,11 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { } } - private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f} = { + private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f*} = { // DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD var restRef: LazyListIterable[A]^{ll} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric newLL { - var it: Iterator[B]^{ll, f} = null + var it: Iterator[B]^{ll, f*} = null var itHasNext = false var rest = restRef // var rest = restRef.elem while (!itHasNext && !rest.isEmpty) { @@ -1307,7 +1307,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] { extends collection.WithFilter[A, LazyListIterable] { private[this] val filtered = lazyList.filter(p) def map[B](f: A => B): LazyListIterable[B]^{this, f} = filtered.map(f) - def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f} = filtered.flatMap(f) + def flatMap[B](@caps.use f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f*} = filtered.flatMap(f) def foreach[U](f: A => U): Unit = filtered.foreach(f) def withFilter(q: A => Boolean): collection.WithFilter[A, LazyListIterable]^{this, q} = new WithFilter(filtered, q) } diff --git a/scala2-library-cc/src/scala/collection/immutable/List.scala b/scala2-library-cc/src/scala/collection/immutable/List.scala index 913de8b0be08..ad4edf1b3989 100644 --- a/scala2-library-cc/src/scala/collection/immutable/List.scala +++ b/scala2-library-cc/src/scala/collection/immutable/List.scala @@ -286,7 +286,7 @@ sealed abstract class List[+A] } } - final override def flatMap[B](f: A => IterableOnce[B]^): List[B] = { + final override def flatMap[B](@caps.use f: A => IterableOnce[B]^): List[B] = { var rest = this var h: ::[B] = null var t: ::[B] = null diff --git a/scala2-library-cc/src/scala/collection/immutable/TreeSeqMap.scala b/scala2-library-cc/src/scala/collection/immutable/TreeSeqMap.scala index dc59d21b8b19..6fbeea560e07 100644 --- a/scala2-library-cc/src/scala/collection/immutable/TreeSeqMap.scala +++ b/scala2-library-cc/src/scala/collection/immutable/TreeSeqMap.scala @@ -234,7 +234,7 @@ final class TreeSeqMap[K, +V] private ( bdr.result() } - override def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): TreeSeqMap[K2, V2] = { + override def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): TreeSeqMap[K2, V2] = { val bdr = newBuilder[K2, V2](orderedBy) val iter = ordering.iterator while (iter.hasNext) { diff --git a/tests/neg-custom-args/captures/cc-annot-value-classes.scala b/tests/neg-custom-args/captures/cc-annot-value-classes.scala new file mode 100644 index 000000000000..745b1c85b8b1 --- /dev/null +++ b/tests/neg-custom-args/captures/cc-annot-value-classes.scala @@ -0,0 +1,18 @@ +import language.experimental.captureChecking +import caps.* + +class Runner(val x: Int) extends AnyVal: + def runOps(@use ops: List[() => Unit]): Unit = + ops.foreach(_()) // ok + +class RunnerAlt(val x: Int): + def runOps(@use ops: List[() => Unit]): Unit = + ops.foreach(_()) // ok, of course + +class RunnerAltAlt(val x: Int) extends AnyVal: + def runOps(ops: List[() => Unit]): Unit = + ops.foreach(_()) // error, as expected + +class RunnerAltAltAlt(val x: Int): + def runOps(ops: List[() => Unit]): Unit = + ops.foreach(_()) // error, as expected diff --git a/tests/neg-custom-args/captures/cc-annot-value-classes2.scala b/tests/neg-custom-args/captures/cc-annot-value-classes2.scala new file mode 100644 index 000000000000..5821f9664f6b --- /dev/null +++ b/tests/neg-custom-args/captures/cc-annot-value-classes2.scala @@ -0,0 +1,16 @@ +import language.experimental.captureChecking +import caps.* +trait Ref extends Mutable +def kill(@consume x: Ref^): Unit = () + +class C1: + def myKill(@consume x: Ref^): Unit = kill(x) // ok + +class C2(val dummy: Int) extends AnyVal: + def myKill(@consume x: Ref^): Unit = kill(x) // ok, too + +class C3: + def myKill(x: Ref^): Unit = kill(x) // error + +class C4(val dummy: Int) extends AnyVal: + def myKill(x: Ref^): Unit = kill(x) // error, too diff --git a/tests/neg-custom-args/captures/cc-class-reach.scala b/tests/neg-custom-args/captures/cc-class-reach.scala new file mode 100644 index 000000000000..860eb782b427 --- /dev/null +++ b/tests/neg-custom-args/captures/cc-class-reach.scala @@ -0,0 +1,7 @@ +import language.experimental.captureChecking +import caps.* +class Runner(@use xs: List[() => Unit]): + def execute: Unit = xs.foreach(op => op()) +def test1(@use ops: List[() => Unit]): Unit = + val runner: Runner^{} = Runner(ops) // error + diff --git a/tests/neg-custom-args/captures/cc-class-this-reach1.scala b/tests/neg-custom-args/captures/cc-class-this-reach1.scala new file mode 100644 index 000000000000..1eff58a6202b --- /dev/null +++ b/tests/neg-custom-args/captures/cc-class-this-reach1.scala @@ -0,0 +1,8 @@ +import language.experimental.captureChecking +import caps.* +trait Runner: + def run: () ->{this} Unit +class Runner1(f: List[() => Unit]) extends Runner: + def run: () ->{f*} Unit = f.head // error +class Runner2(@use f: List[() => Unit]) extends Runner: + def run: () ->{f*} Unit = f.head // ok diff --git a/tests/pos-custom-args/captures/cc-use-iterable.scala b/tests/pos-custom-args/captures/cc-use-iterable.scala new file mode 100644 index 000000000000..84c497c0f6ce --- /dev/null +++ b/tests/pos-custom-args/captures/cc-use-iterable.scala @@ -0,0 +1,10 @@ +import language.experimental.captureChecking +trait IterableOnce[+T] +trait Iterable[+T] extends IterableOnce[T]: + def flatMap[U](@caps.use f: T => IterableOnce[U]^): Iterable[U]^{this, f*} + + +class IterableOnceExtensionMethods[T](val it: IterableOnce[T]) extends AnyVal: + def flatMap[U](@caps.use f: T => IterableOnce[U]^): IterableOnce[U]^{f*} = it match + case it: Iterable[T] => it.flatMap(f) +