diff --git a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala index db3157f6130d..eca224941770 100644 --- a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala +++ b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala @@ -215,6 +215,24 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { ref(field).becomes(nullLiteral) } + /** + * Creates a ValDef used as the underlying volatile var, + * and also returns the term name as it's later needed as a TermName (and not just any symbol name) + */ + private def mkContainerTree(x: ValOrDefDef, isVolatile: Boolean)(using Context): (Names.TermName, ValDef) = + val claz = x.symbol.owner.asClass + val tpe = if isVolatile then defn.ObjectType else x.tpe.widen.resultType.widen + val containerName = LazyLocalName.fresh(x.name.asTermName) + val containerSymbol = newSymbol(claz, containerName, x.symbol.flags &~ containerFlagsMask | containerFlags | Private, tpe, coord = x.symbol.coord).enteredAfter(this) + if isVolatile then containerSymbol.addAnnotation(Annotation(defn.VolatileAnnot, containerSymbol.span)) + // Keep field annotations like @transient, see scala/scala3#23487 + for a <- x.symbol.annotations do + if a.hasOneOfMetaAnnotation(Set(defn.FieldMetaAnnot)) then + containerSymbol.addAnnotation(a) + // for the thread-safe implementation, the generated symbol must not be static or the CAS operations won't work, see scala/scala3#16800 + if isVolatile then containerSymbol.removeAnnotation(defn.ScalaStaticAnnot) + (containerName, ValDef(containerSymbol, defaultValue(tpe))) + /** Create thread-unsafe lazy accessor equivalent to such code * ``` * def methodSymbol() = { @@ -262,24 +280,15 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { } def transformMemberDefThreadUnsafe(x: ValOrDefDef)(using Context): Thicket = { - val claz = x.symbol.owner.asClass - val tpe = x.tpe.widen.resultType.widen - assert(!x.symbol.isMutableVarOrAccessor) - val containerName = LazyLocalName.fresh(x.name.asTermName) - val containerSymbol = newSymbol(claz, containerName, - x.symbol.flags &~ containerFlagsMask | containerFlags | Private, - tpe, coord = x.symbol.coord - ).enteredAfter(this) - - val containerTree = ValDef(containerSymbol, defaultValue(tpe)) - if (x.tpe.isNotNull && tpe <:< defn.ObjectType) + val (_, containerTree) = mkContainerTree(x, isVolatile = false) + if (x.tpe.isNotNull && containerTree.rhs.tpe <:< defn.ObjectType) // can use 'null' value instead of flag - Thicket(containerTree, mkDefThreadUnsafeNonNullable(x.symbol, containerSymbol, x.rhs)) + Thicket(containerTree, mkDefThreadUnsafeNonNullable(x.symbol, containerTree.symbol, x.rhs)) else { val flagName = LazyBitMapName.fresh(x.name.asTermName) val flagSymbol = newSymbol(x.symbol.owner, flagName, containerFlags | Private, defn.BooleanType).enteredAfter(this) val flag = ValDef(flagSymbol, Literal(Constant(false))) - Thicket(containerTree, flag, mkThreadUnsafeDef(x.symbol, flagSymbol, containerSymbol, x.rhs)) + Thicket(containerTree, flag, mkThreadUnsafeDef(x.symbol, flagSymbol, containerTree.symbol, x.rhs)) } } @@ -474,14 +483,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val claz = x.symbol.owner.asClass val thizClass = Literal(Constant(claz.info)) - val containerName = LazyLocalName.fresh(x.name.asTermName) - val containerSymbol = newSymbol(claz, containerName, x.symbol.flags &~ containerFlagsMask | containerFlags | Private, defn.ObjectType, coord = x.symbol.coord).enteredAfter(this) - containerSymbol.addAnnotation(Annotation(defn.VolatileAnnot, containerSymbol.span)) // private @volatile var _x: AnyRef - containerSymbol.addAnnotations(x.symbol.annotations) // pass annotations from original definition - containerSymbol.removeAnnotation(defn.ScalaStaticAnnot) - val getOffset = - Select(ref(defn.LazyValsModule), lazyNme.RLazyVals.getOffsetStatic) - val containerTree = ValDef(containerSymbol, nullLiteral) + val (containerName, containerTree) = mkContainerTree(x, isVolatile = true) // create a VarHandle for this lazy val val varHandleSymbol: TermSymbol = newSymbol(claz, LazyVarHandleName(containerName), Private | Synthetic, defn.VarHandleClass.typeRef).enteredAfter(this) @@ -499,7 +501,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val swapOver = This(claz) - val (accessorDef, initMethodDef) = mkThreadSafeDef(x, claz, containerSymbol, varHandle, swapOver) + val (accessorDef, initMethodDef) = mkThreadSafeDef(x, claz, containerTree.symbol, varHandle, swapOver) Thicket(containerTree, accessorDef, initMethodDef) } diff --git a/tests/printing/transformed/lazy-vals-new.check b/tests/printing/transformed/lazy-vals-new.check index 26fc3ec21ddb..51a2905782c8 100644 --- a/tests/printing/transformed/lazy-vals-new.check +++ b/tests/printing/transformed/lazy-vals-new.check @@ -19,7 +19,7 @@ package { classOf[Object {...}], "x$lzy1", classOf[Object]) private def writeReplace(): Object = new scala.runtime.ModuleSerializationProxy(classOf[A]) - @volatile private lazy var x$lzy1: Object = null + @volatile private lazy var x$lzy1: Object = null.asInstanceOf[Object] lazy def x(): Int = { val result: Object = A.x$lzy1 diff --git a/tests/run/lazy-threadUnsafe-transient.scala b/tests/run/lazy-threadUnsafe-transient.scala new file mode 100644 index 000000000000..a08982dae33f --- /dev/null +++ b/tests/run/lazy-threadUnsafe-transient.scala @@ -0,0 +1,51 @@ +// scalajs: --skip +// https://github.com/scala/scala3/issues/23487 +import java.io.* +import scala.annotation.threadUnsafe + +def serialize[T <: Serializable](obj: T): Array[Byte] = { + val byteArrayOutputStream = new ByteArrayOutputStream() + val objectOutputStream = new ObjectOutputStream(byteArrayOutputStream) + + try { + objectOutputStream.writeObject(obj) + byteArrayOutputStream.toByteArray + } finally { + objectOutputStream.close() + byteArrayOutputStream.close() + } +} + +def deserialize[T](bytes: Array[Byte]): T = { + val byteArrayInputStream = new ByteArrayInputStream(bytes) + val objectInputStream = new ObjectInputStream(byteArrayInputStream) + + try { + objectInputStream.readObject().asInstanceOf[T] + } finally { + objectInputStream.close() + byteArrayInputStream.close() + } +} + +case class Foo() { + @transient + lazy val value: Long = System.nanoTime() +} + +case class Bar() { + @transient @threadUnsafe + lazy val value: Long = System.nanoTime() +} + +@main def Test() = { + val foo1 = Foo() + foo1.value // init lazy val + val foo2 = deserialize[Foo](serialize(foo1)) + assert(foo1.value != foo2.value, "Foo#value is not transient") + + val bar1 = Bar() + bar1.value // init lazy val + val bar2 = deserialize[Bar](serialize(bar1)) + assert(bar1.value != bar2.value, "Bar#value is not transient") +}