23
23
# See the License for the specific language governing permissions and
24
24
# limitations under the License.
25
25
26
- from typing import Callable , Dict , Type
26
+ from typing import Callable , Dict , List , Type
27
27
28
28
import numpy as np
29
29
import onnx_graphsurgeon as gs
@@ -82,6 +82,8 @@ def __init__(self,
82
82
RemoveGlobalOutputReshapePass (),
83
83
]
84
84
85
+ self .extNameCount = 0
86
+
85
87
def bind (self ):
86
88
# SCHEREMO: THIS IS A STOP GAP SOLUTION. DONT REUSE. I MEAN IT. I WILL FIND YOU.
87
89
# SCHEREMO: The BindingOptimizationPass system is fairly fragile;
@@ -96,28 +98,33 @@ def bind(self):
96
98
self .ctxt .hoistGlobalDefinition ("cluster_dev" , "extern struct pi_device cluster_dev;" )
97
99
return ret
98
100
99
- def generateBufferAllocationCode (self ) -> str :
100
- retStr = super ().generateBufferAllocationCode ()
101
-
102
- L3FileStr = ""
103
- globalConstBuffers = [
104
- buf for key , buf in self .ctxt .globalObjects .items () if isinstance (buf , VariableBuffer ) and buf ._deploy
101
+ def _l3ConstBuffer (self ) -> List [VariableBuffer ]:
102
+ return [
103
+ buf for buf in self .ctxt .globalObjects .values () if all ([
104
+ isinstance (buf , VariableBuffer ) and buf ._deploy ,
105
+ hasattr (buf , "_users" ) and len (buf ._users ) > 0 ,
106
+ hasattr (buf , "_memoryLevel" ) and buf ._memoryLevel == "L3" ,
107
+ ])
105
108
]
106
- nonArenaBuffers = [buf for buf in globalConstBuffers if buf ._users != []]
107
- l3ConstBuffer = [buf for buf in nonArenaBuffers if hasattr (buf , "_memoryLevel" ) and buf ._memoryLevel == "L3" ]
108
-
109
- for idx , buf in enumerate (l3ConstBuffer ):
110
-
111
- locPtr = str (buf ._instance )
112
- extName = str (idx )
113
- buf .extName = extName
114
- size = np .prod (buf .shape ) * (buf ._type .referencedType .typeWidth // 8 )
115
109
116
- if isinstance (buf , ConstantBuffer ):
117
- L3FileStr += _L3AllocTemplate .generate ({"locPtr" : locPtr , "extName" : extName , "size" : size })
118
-
119
- L3FileStr += _L3InitTemplate .generate ({"locPtr" : locPtr , "extName" : extName , "size" : size })
120
-
121
- retStr = retStr + L3FileStr
110
+ def _newExtName (self ) -> str :
111
+ name = str (self .extNameCount )
112
+ self .extNameCount += 1
113
+ return name
114
+
115
+ def _generateL3BufferAllocationCode (self , buf : VariableBuffer ) -> str :
116
+ retStr = ""
117
+ locPtr = str (buf ._instance )
118
+ extName = self ._newExtName ()
119
+ buf .extName = extName
120
+ size = np .prod (buf .shape ) * (buf ._type .referencedType .typeWidth // 8 )
121
+ if isinstance (buf , ConstantBuffer ):
122
+ retStr += _L3AllocTemplate .generate ({"locPtr" : locPtr , "extName" : extName , "size" : size })
123
+ retStr += _L3InitTemplate .generate ({"locPtr" : locPtr , "extName" : extName , "size" : size })
124
+ return retStr
122
125
126
+ def generateBufferAllocationCode (self ) -> str :
127
+ retStr = super ().generateBufferAllocationCode ()
128
+ for buf in self ._l3ConstBuffer ():
129
+ retStr += self ._generateL3BufferAllocationCode (buf )
123
130
return retStr
0 commit comments