Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion firrtl/src/main/scala/firrtl/annotations/JsonProtocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,23 @@ trait HasSerializationHints {
def typeHints: Seq[Class[_]]
}

/** Similar to [[HasSerializationHints]] but for types whose serialization classes
* need to be overridden
*/
@deprecated("All APIs in package firrtl are deprecated.", "Chisel 7.0.0")
trait HasSerializationOverrides {
// For serialization of complicated constructor arguments, let the annotation
// writer specify additional type hints for relevant classes that might be
// contained within
def typeOverrides: Seq[(Class[_], String)]
}

/** Mix this in to override what class name is used for serialization
*
* Note that this breaks automatic deserialization.
*/
@deprecated("All APIs in package firrtl are deprecated.", "Chisel 7.0.0")
trait OverrideSerializationClass { self: Annotation =>
trait OverrideSerializationClass {
def serializationClassOverride: String
}

Expand Down Expand Up @@ -195,6 +206,18 @@ object JsonProtocol extends LazyLogging {
anno.typeHints.foreach(addTag(_))
case _ => ()
}
anno match {
case anno: HasSerializationOverrides =>
anno.typeOverrides.foreach { case (clazz, name) =>
val existing = tagOverride.put(clazz, name)
if (existing.isDefined && existing.get != name) {
throw new Exception(
s"Class $clazz has multiple serialization class overrides: ${existing.get}, $name"
)
}
}
case _ => ()
}
anno match {
case anno: OverrideSerializationClass =>
val existing = tagOverride.put(anno.getClass, anno.serializationClassOverride)
Expand Down
21 changes: 21 additions & 0 deletions firrtl/src/test/scala/firrtl/JsonProtocolSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ object JsonProtocolTestClasses {
case class AnnotationWithOverride(value: String) extends NoTargetAnnotation with OverrideSerializationClass {
def serializationClassOverride = value
}

// Test case for OverrideSerializationClass on nested types
case class NestedTypeWithOverride(name: String) extends OverrideSerializationClass {
def serializationClassOverride = "custom.nested.type"
}

case class AnnotationWithNestedOverride(nested: NestedTypeWithOverride)
extends NoTargetAnnotation
with HasSerializationOverrides {
def typeOverrides = Seq(nested.getClass -> nested.serializationClassOverride)
}
}

import JsonProtocolTestClasses._
Expand Down Expand Up @@ -139,4 +150,14 @@ class JsonProtocolSpec extends AnyFlatSpec with Matchers {
val e = the[Exception] thrownBy JsonProtocol.serialize(annos)
e.getMessage should include("multiple serialization class overrides: foo, bar")
}

it should "work on nested types inside annotations with HasSerializationOverrides" in {
val nested = NestedTypeWithOverride("test")
val anno = AnnotationWithNestedOverride(nested)
val res = JsonProtocol.serialize(Seq(anno))
// Verify that the nested type uses the overridden class name
res should include(""""class":"custom.nested.type"""")
// Also verify that the annotation itself uses its normal class name
res should include(""""class":"firrtlTests.JsonProtocolTestClasses$AnnotationWithNestedOverride"""")
}
}
Loading