@@ -38,18 +38,12 @@ class ConstantVariable(VariableTracker):
3838 @staticmethod
3939 def create (value , ** kwargs ) -> VariableTracker :
4040 source = kwargs .get ("source" , None )
41- is_literal = ConstantVariable .is_literal (value )
42- if not is_literal :
43- for disallowed_type , reason in _type_to_assert_reason .items ():
44- assert not isinstance (value , disallowed_type ), reason
4541
46- # Routing for list and tuple literals.
47- if is_literal and isinstance (value , (set , frozenset )):
48- items = []
49- for i , x in enumerate (value ):
50- items .append (ConstantVariable .create (x ))
42+ # Routing for supported collection literals.
43+ if isinstance (value , (set , frozenset )):
44+ items = [ConstantVariable .create (x ) for x in value ]
5145 return variables .SetVariable (items , ** kwargs )
52- elif is_literal and isinstance (value , (list , tuple )):
46+ elif isinstance (value , (list , tuple )):
5347 items = []
5448 for i , x in enumerate (value ):
5549 item_source = GetItemSource (source , i ) if source else None
@@ -67,13 +61,10 @@ def create(value, **kwargs) -> VariableTracker:
6761
6862 def __init__ (self , value , ** kwargs ) -> None :
6963 super ().__init__ (** kwargs )
70- if not ConstantVariable .is_literal (value ):
64+ if not ConstantVariable .is_base_literal (value ):
7165 for disallowed_type , reason in _type_to_assert_reason .items ():
7266 assert not isinstance (value , disallowed_type ), reason
7367
74- assert not isinstance (
75- value , (list , tuple )
76- ), "ConstantVariable(list) is banned - please create a ListVariable(items)"
7768 if np is not None and isinstance (value , np .number ):
7869 self .value = value .item ()
7970 else :
@@ -104,14 +95,15 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
10495 self .value [arg .as_python_constant ()],
10596 )
10697
98+ @staticmethod
99+ def is_base_literal (obj ):
100+ return type (obj ) in common_constant_types
101+
107102 @staticmethod
108103 def is_literal (obj ):
109- if type (obj ) in common_constant_types :
110- return True
111- # The structure within is_literal get routed to variables.BaseListVariable
112104 if type (obj ) in (list , tuple , set , frozenset , torch .Size ):
113105 return all (ConstantVariable .is_literal (x ) for x in obj )
114- return False
106+ return ConstantVariable . is_base_literal ( obj )
115107
116108 def unpack_var_sequence (self , tx ):
117109 try :
0 commit comments