From 0e7b885d0c81fa9c105ee20984589bd1931411a7 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Fri, 19 Sep 2025 11:36:55 -0400 Subject: [PATCH 1/2] [refactor] Updated type reflection attribute `class_` with Literal instead of string --- src/oqd_compiler_infrastructure/interface.py | 31 ++++---------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/src/oqd_compiler_infrastructure/interface.py b/src/oqd_compiler_infrastructure/interface.py index 21ae5e3..00cdab4 100644 --- a/src/oqd_compiler_infrastructure/interface.py +++ b/src/oqd_compiler_infrastructure/interface.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Literal -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import BaseModel, ConfigDict ######################################################################################## @@ -42,26 +42,7 @@ class TypeReflectBaseModel(VisitableBaseModel): Class representing a datastruct with type reflection """ - class_: Optional[str] - - @model_validator(mode="before") - @classmethod - def reflect(cls, data): - if isinstance(data, BaseModel): - return data - if "class_" in data.keys(): - if data["class_"] != cls.__name__: - raise ValueError('discrepency between "class_" field and model type') - - data["class_"] = cls.__name__ - - return data - - # @classmethod - # def get_subclasses(cls): - # # necessary for deserializing subclassed operators - # def all_subclasses(cls): - # return set(cls.__subclasses__()).union( - # [s for c in cls.__subclasses__() for s in all_subclasses(c)]) - # return tuple(all_subclasses(cls)) - # # return tuple(cls.__subclasses__()) + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.__annotations__ = dict(class_=Literal[cls.__name__], **cls.__annotations__) + setattr(cls, "class_", cls.__name__) From 20224a5a8beae18f9e737029bf03f3d1f96cbca3 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Fri, 19 Sep 2025 11:45:04 -0400 Subject: [PATCH 2/2] [fix] modify model_fields accessed from class instead of instance due to deprecation of instance access in pydantic 2.11 --- src/oqd_compiler_infrastructure/walk.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/oqd_compiler_infrastructure/walk.py b/src/oqd_compiler_infrastructure/walk.py index 0250e8f..ab6562f 100644 --- a/src/oqd_compiler_infrastructure/walk.py +++ b/src/oqd_compiler_infrastructure/walk.py @@ -119,7 +119,9 @@ def walk_VisitableBaseModel(self, model): new_model = self.rule(model) new_fields = {} - for key in self.controlled_reverse(new_model.model_fields.keys(), self.reverse): + for key in self.controlled_reverse( + new_model.__class__.model_fields.keys(), self.reverse + ): if key == "class_": continue new_fields[key] = self(getattr(new_model, key)) @@ -181,7 +183,9 @@ def walk_tuple(self, model): def walk_VisitableBaseModel(self, model): new_fields = {} - for key in self.controlled_reverse(model.model_fields.keys(), self.reverse): + for key in self.controlled_reverse( + model.__class__.model_fields.keys(), self.reverse + ): if key == "class_": continue new_fields[key] = self(getattr(model, key)) @@ -260,7 +264,11 @@ def walk_VisitableBaseModel(self, model): self.stack.extend( self.controlled_reverse( - [getattr(model, k) for k in model.model_fields.keys() if k != "class_"], + [ + getattr(model, k) + for k in model.__class__.model_fields.keys() + if k != "class_" + ], self.reverse, ) ) @@ -308,7 +316,7 @@ def walk_dict(self, model): return model def walk_VisitableBaseModel(self, model): - keys = [k for k in model.model_fields.keys() if k != "class_"] + keys = [k for k in model.__class__.model_fields.keys() if k != "class_"] keys = self.controlled_reverse(keys, self.reverse, restore_type=True) for k in keys[:-1]: self(getattr(model, k))