Skip to content

Commit

Permalink
fix getting field names from of pydantic models
Browse files Browse the repository at this point in the history
  • Loading branch information
sillitoe committed Nov 15, 2023
1 parent 56ced3c commit f01233e
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions cath_alphaflow/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import List, Type
import dataclasses

import pydantic
from Bio.PDB import PDBParser
from Bio.PDB import Structure

Expand Down Expand Up @@ -62,23 +63,23 @@ def __init__(
fieldnames = kwargs["fieldnames"]
else:
fieldnames = [
f.name
for f in self.get_result_object_fields()
if not f.name.startswith("_")
f
for f in self.get_result_object_fieldnames()
if not f.startswith("_")
]
self.fieldnames = fieldnames

def get_default_delimiter(self):
return self.DEFAULT_DELIMITER

def get_result_object_fields(self):
def get_result_object_fieldnames(self):
# pydantic model
if hasattr(self.object_class, "__fields__"):
fields = list(self.object_class.__fields__.values())
if issubclass(self.object_class, pydantic.BaseModel):
fieldnames = [f for f in list(self.object_class.model_fields.keys())]
# dataclasses
else:
fields = dataclasses.fields(self.object_class)
return fields
fieldnames = [f.name for f in dataclasses.fields(self.object_class)]
return fieldnames

def __next__(self):
"""
Expand Down

0 comments on commit f01233e

Please sign in to comment.