2525 Computed ,
2626 Constraint ,
2727 DefaultClause ,
28+ Dialect ,
2829 Enum ,
2930 Float ,
3031 ForeignKey ,
3940 UniqueConstraint ,
4041)
4142from sqlalchemy .dialects .postgresql import JSONB
42- from sqlalchemy .engine import Connection , Engine
4343from sqlalchemy .exc import CompileError
4444from sqlalchemy .sql .elements import TextClause
4545
@@ -94,11 +94,9 @@ class Base:
9494class CodeGenerator (metaclass = ABCMeta ):
9595 valid_options : ClassVar [set [str ]] = set ()
9696
97- def __init__ (
98- self , metadata : MetaData , bind : Connection | Engine , options : Sequence [str ]
99- ):
97+ def __init__ (self , metadata : MetaData , dialect : Dialect , options : Sequence [str ]):
10098 self .metadata : MetaData = metadata
101- self .bind : Connection | Engine = bind
99+ self .dialect : Dialect = dialect
102100 self .options : set [str ] = set (options )
103101
104102 # Validate options
@@ -124,12 +122,12 @@ class TablesGenerator(CodeGenerator):
124122 def __init__ (
125123 self ,
126124 metadata : MetaData ,
127- bind : Connection | Engine ,
125+ dialect : Dialect ,
128126 options : Sequence [str ],
129127 * ,
130128 indentation : str = " " ,
131129 ):
132- super ().__init__ (metadata , bind , options )
130+ super ().__init__ (metadata , dialect , options )
133131 self .indentation : str = indentation
134132 self .imports : dict [str , set [str ]] = defaultdict (set )
135133 self .module_imports : set [str ] = set ()
@@ -562,7 +560,7 @@ def add_fk_options(*opts: Any) -> None:
562560 ]
563561 add_fk_options (local_columns , remote_columns )
564562 elif isinstance (constraint , CheckConstraint ):
565- args .append (repr (get_compiled_expression (constraint .sqltext , self .bind )))
563+ args .append (repr (get_compiled_expression (constraint .sqltext , self .dialect )))
566564 elif isinstance (constraint , (UniqueConstraint , PrimaryKeyConstraint )):
567565 args .extend (repr (col .name ) for col in constraint .columns )
568566 else :
@@ -608,7 +606,7 @@ def fix_column_types(self, table: Table) -> None:
608606 # Detect check constraints for boolean and enum columns
609607 for constraint in table .constraints .copy ():
610608 if isinstance (constraint , CheckConstraint ):
611- sqltext = get_compiled_expression (constraint .sqltext , self .bind )
609+ sqltext = get_compiled_expression (constraint .sqltext , self .dialect )
612610
613611 # Turn any integer-like column with a CheckConstraint like
614612 # "column IN (0, 1)" into a Boolean
@@ -646,7 +644,7 @@ def fix_column_types(self, table: Table) -> None:
646644 pass
647645
648646 # PostgreSQL specific fix: detect sequences from server_default
649- if column .server_default and self .bind . dialect .name == "postgresql" :
647+ if column .server_default and self .dialect .name == "postgresql" :
650648 if isinstance (column .server_default , DefaultClause ) and isinstance (
651649 column .server_default .arg , TextClause
652650 ):
@@ -661,7 +659,7 @@ def fix_column_types(self, table: Table) -> None:
661659 column .server_default = None
662660
663661 def get_adapted_type (self , coltype : Any ) -> Any :
664- compiled_type = coltype .compile (self .bind . engine . dialect )
662+ compiled_type = coltype .compile (self .dialect )
665663 for supercls in coltype .__class__ .__mro__ :
666664 if not supercls .__name__ .startswith ("_" ) and hasattr (
667665 supercls , "__visit_name__"
@@ -687,7 +685,7 @@ def get_adapted_type(self, coltype: Any) -> Any:
687685 try :
688686 # If the adapted column type does not render the same as the
689687 # original, don't substitute it
690- if new_coltype .compile (self .bind . engine . dialect ) != compiled_type :
688+ if new_coltype .compile (self .dialect ) != compiled_type :
691689 # Make an exception to the rule for Float and arrays of Float,
692690 # since at least on PostgreSQL, Float can accurately represent
693691 # both REAL and DOUBLE_PRECISION
@@ -718,13 +716,13 @@ class DeclarativeGenerator(TablesGenerator):
718716 def __init__ (
719717 self ,
720718 metadata : MetaData ,
721- bind : Connection | Engine ,
719+ dialect : Dialect ,
722720 options : Sequence [str ],
723721 * ,
724722 indentation : str = " " ,
725723 base_class_name : str = "Base" ,
726724 ):
727- super ().__init__ (metadata , bind , options , indentation = indentation )
725+ super ().__init__ (metadata , dialect , options , indentation = indentation )
728726 self .base_class_name : str = base_class_name
729727 self .inflect_engine = inflect .engine ()
730728
@@ -1305,7 +1303,7 @@ class DataclassGenerator(DeclarativeGenerator):
13051303 def __init__ (
13061304 self ,
13071305 metadata : MetaData ,
1308- bind : Connection | Engine ,
1306+ dialect : Dialect ,
13091307 options : Sequence [str ],
13101308 * ,
13111309 indentation : str = " " ,
@@ -1315,7 +1313,7 @@ def __init__(
13151313 ):
13161314 super ().__init__ (
13171315 metadata ,
1318- bind ,
1316+ dialect ,
13191317 options ,
13201318 indentation = indentation ,
13211319 base_class_name = base_class_name ,
@@ -1344,15 +1342,15 @@ class SQLModelGenerator(DeclarativeGenerator):
13441342 def __init__ (
13451343 self ,
13461344 metadata : MetaData ,
1347- bind : Connection | Engine ,
1345+ dialect : Dialect ,
13481346 options : Sequence [str ],
13491347 * ,
13501348 indentation : str = " " ,
13511349 base_class_name : str = "SQLModel" ,
13521350 ):
13531351 super ().__init__ (
13541352 metadata ,
1355- bind ,
1353+ dialect ,
13561354 options ,
13571355 indentation = indentation ,
13581356 base_class_name = base_class_name ,
0 commit comments