diff --git a/firrtl/src/main/scala/firrtl/annotations/JsonProtocol.scala b/firrtl/src/main/scala/firrtl/annotations/JsonProtocol.scala index 925f3215ac5..11a57a9d243 100644 --- a/firrtl/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/firrtl/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -25,12 +25,23 @@ trait HasSerializationHints { def typeHints: Seq[Class[_]] } +/** Similar to [[HasSerializationHints]] but for types whose serialization classes + * need to be overridden + */ +@deprecated("All APIs in package firrtl are deprecated.", "Chisel 7.0.0") +trait HasSerializationOverrides { + // For serialization of complicated constructor arguments, let the annotation + // writer specify additional type hints for relevant classes that might be + // contained within + def typeOverrides: Seq[(Class[_], String)] +} + /** Mix this in to override what class name is used for serialization * * Note that this breaks automatic deserialization. */ @deprecated("All APIs in package firrtl are deprecated.", "Chisel 7.0.0") -trait OverrideSerializationClass { self: Annotation => +trait OverrideSerializationClass { def serializationClassOverride: String } @@ -195,6 +206,18 @@ object JsonProtocol extends LazyLogging { anno.typeHints.foreach(addTag(_)) case _ => () } + anno match { + case anno: HasSerializationOverrides => + anno.typeOverrides.foreach { case (clazz, name) => + val existing = tagOverride.put(clazz, name) + if (existing.isDefined && existing.get != name) { + throw new Exception( + s"Class $clazz has multiple serialization class overrides: ${existing.get}, $name" + ) + } + } + case _ => () + } anno match { case anno: OverrideSerializationClass => val existing = tagOverride.put(anno.getClass, anno.serializationClassOverride) diff --git a/firrtl/src/test/scala/firrtl/JsonProtocolSpec.scala b/firrtl/src/test/scala/firrtl/JsonProtocolSpec.scala index 07ae1b96b15..44ab64e1430 100644 --- a/firrtl/src/test/scala/firrtl/JsonProtocolSpec.scala +++ b/firrtl/src/test/scala/firrtl/JsonProtocolSpec.scala @@ -36,6 +36,17 @@ object JsonProtocolTestClasses { case class AnnotationWithOverride(value: String) extends NoTargetAnnotation with OverrideSerializationClass { def serializationClassOverride = value } + + // Test case for OverrideSerializationClass on nested types + case class NestedTypeWithOverride(name: String) extends OverrideSerializationClass { + def serializationClassOverride = "custom.nested.type" + } + + case class AnnotationWithNestedOverride(nested: NestedTypeWithOverride) + extends NoTargetAnnotation + with HasSerializationOverrides { + def typeOverrides = Seq(nested.getClass -> nested.serializationClassOverride) + } } import JsonProtocolTestClasses._ @@ -139,4 +150,14 @@ class JsonProtocolSpec extends AnyFlatSpec with Matchers { val e = the[Exception] thrownBy JsonProtocol.serialize(annos) e.getMessage should include("multiple serialization class overrides: foo, bar") } + + it should "work on nested types inside annotations with HasSerializationOverrides" in { + val nested = NestedTypeWithOverride("test") + val anno = AnnotationWithNestedOverride(nested) + val res = JsonProtocol.serialize(Seq(anno)) + // Verify that the nested type uses the overridden class name + res should include(""""class":"custom.nested.type"""") + // Also verify that the annotation itself uses its normal class name + res should include(""""class":"firrtlTests.JsonProtocolTestClasses$AnnotationWithNestedOverride"""") + } }