2626package hat ;
2727
2828import hat .phases .HATPhaseUtils ;
29+ import jdk .incubator .code .dialect .java .ArrayType ;
2930import jdk .incubator .code .dialect .java .JavaOp ;
31+ import jdk .incubator .code .dialect .java .PrimitiveType ;
3032import optkl .IfaceValue ;
3133import optkl .OpHelper ;
3234import optkl .ifacemapper .AccessType ;
33- import optkl .ifacemapper .Buffer ;
34- import optkl .ifacemapper .MappableIface ;
3535import jdk .incubator .code .Op ;
3636import jdk .incubator .code .Value ;
3737import jdk .incubator .code .Block ;
4545import static optkl .OpHelper .Invoke .invoke ;
4646
4747public class BufferTagger {
48- static HashMap <Value , AccessType > accessMap = new HashMap <>();
49- static HashMap <Value , Value > remappedVals = new HashMap <>(); // maps values to their "root" parameter/value
48+ static HashMap <Value , AccessType > accessMap = new HashMap <>(); // mapping of parameters/buffers to access type
49+ // TODO: fix how we use rootValues
50+ static HashMap <Value , Value > rootValues = new HashMap <>(); // maps values to their "root" parameter/value
5051 static HashMap <Block , List <Block .Parameter >> blockParams = new HashMap <>(); // holds block parameters for easy lookup
5152
5253 // generates a list of AccessTypes matching the given FuncOp's parameter order
@@ -56,24 +57,25 @@ public static ArrayList<AccessType> getAccessList(MethodHandles.Lookup lookup, C
5657 for (Block .Parameter p : inlinedEntryPoint .body ().entryBlock ().parameters ()) {
5758 if (accessMap .containsKey (p )) {
5859 accessList .add (accessMap .get (p )); // is an accessed buffer
59- } else if (OpHelper .isAssignable (lookup , p .type (), MappableIface .class )) {
60+ } else if (OpHelper .isAssignable (lookup , p .type (), IfaceValue .class )) {
6061 accessList .add (AccessType .NA ); // is a buffer but not accessed
6162 } else {
6263 accessList .add (AccessType .NOT_BUFFER ); // is not a buffer
6364 }
6465 }
6566 return accessList ;
6667 }
67- private static boolean isReference (Invoke ioh ) {
68- return ioh .returns (IfaceValue .class )
68+
69+ private static boolean isAccessed (Invoke ioh ) {
70+ return !(ioh .returns (IfaceValue .class )
6971 && ioh .opFromOnlyUseOrNull () instanceof JavaOp .InvokeOp nextInvoke
7072 && invoke (ioh .lookup (), nextInvoke ) instanceof Invoke nextIoh
7173 && nextIoh .refIs (IfaceValue .class )
72- && nextIoh .returnsVoid ();
74+ && nextIoh .returnsVoid ()) ;
7375 }
7476
7577 // creates the access map
76- private static void buildAccessMap (MethodHandles .Lookup lookup , CoreOp .FuncOp funcOp ) {
78+ private static void buildAccessMap (MethodHandles .Lookup lookup , CoreOp .FuncOp funcOp ) {
7779 // build blockParams so that we can map params to "root" params later
7880 funcOp .elements ()
7981 .filter (elem -> elem instanceof Block )
@@ -88,25 +90,24 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp f
8890 mapBranch (lookup , cb .falseBranch ()); // handle false branch
8991 }
9092 case JavaOp .InvokeOp invokeOp -> {
91- var ioh = invoke (lookup ,invokeOp );
92- // we have to deal with array views too
93- // should .arrayview() calls be marked as reads?
94- if ( ioh .refIs (IfaceValue .class )) {
95- // updateAccessType(getRootValue(invokeOp), ioh.returnsVoid()? AccessType.WO : AccessType.RO); // update buffer access
96- // if the invokeOp retrieves an element that is only written to, don't update the access type
97- // (i.e. the only use is an invoke, the invoke is of MappableIface/HAType class, and is a write)
98- if (!isReference (ioh )) { // value retrieved and not just referenced?
99- updateAccessType (getRootValue (invokeOp ), ioh .returnsVoid ()? AccessType .WO : AccessType .RO ); // update buffer access
100- }
101- if (ioh .refIs (IfaceValue .class ) && (ioh .returns (IfaceValue .class ) || ioh .returnsArray ())) {
102- // if we access a struct/union from a buffer, we map the struct/union to the buffer root
103- remappedVals .put (invokeOp .result (), getRootValue (invokeOp ));
93+ var ioh = invoke (lookup ,invokeOp );
94+ if (ioh .refIs (KernelContext .class )) break ; // if this is not referencing a buffer, we continue
95+ if (ioh .returns (IfaceValue .class ) || ioh .returnsArray ()) { // if we receive a buffer from this invoke, we save its root value
96+ for (Value operand : ioh .op ().operands ()) {
97+ if (!(operand .type () instanceof PrimitiveType ) && rootValues .containsKey (operand )) {
98+ // TODO: FIX
99+ if (operand instanceof Block .Parameter ) updateAccessType (operand , AccessType .RO );
100+ else updateAccessType (getRootValue (operand .result ().op ()), AccessType .RO );
101+ }
104102 }
103+ rootValues .put (invokeOp .result (), getRootValue (invokeOp ));
104+ } else if (isAccessed (ioh )) { // if we actually operate on a buffer instead of storing an element in a variable
105+ updateAccessType (getRootValue (invokeOp ), ioh .returnsVoid () ? AccessType .WO : AccessType .RO ); // update buffer access
105106 }
106107 }
107108 case CoreOp .VarOp vop -> { // map the new VarOp to the "root" param
108- if (OpHelper .isAssignable (lookup , vop .resultType ().valueType (), Buffer .class )) {
109- remappedVals .put (vop .initOperand (), getRootValue (vop ));
109+ if (OpHelper .isAssignable (lookup , vop .resultType ().valueType (), IfaceValue .class )) {
110+ rootValues .put (vop .initOperand (), getRootValue (vop ));
110111 }else {
111112 // or else maybe CoreOp.VarOp vop when ??? ->
112113 }
@@ -118,7 +119,10 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp f
118119 // or else
119120 }
120121 }
121- case JavaOp .ArrayAccessOp .ArrayLoadOp alop -> updateAccessType (getRootValue (alop ), AccessType .RO );
122+ case JavaOp .ArrayAccessOp .ArrayLoadOp alop -> {
123+ if (alop .resultType () instanceof ArrayType ) break ;
124+ updateAccessType (getRootValue (alop ), AccessType .RO );
125+ }
122126 case JavaOp .ArrayAccessOp .ArrayStoreOp asop -> updateAccessType (getRootValue (asop ), AccessType .WO );
123127 default -> {}
124128 }
@@ -131,30 +135,30 @@ private static void mapBranch(MethodHandles.Lookup lookup, Block.Reference block
131135 for (int i = 0 ; i < args .size (); i ++) {
132136 Value key = blockParams .get (blockReference .targetBlock ()).get (i );
133137 Value value = args .get (i );
134- if (value instanceof Op .Result result ) {
135- // either find root param or it doesn't exist (is a constant for example)
136- if (OpHelper .isAssignable (lookup , value .type (), MappableIface .class )) {
137- value = getRootValue (result .op ());
138- if (value instanceof Block .Parameter ) {
139- value = remappedVals .getOrDefault (value , value );
140- }
141- }else {
142- // or else
138+ if (value instanceof Op .Result result && OpHelper .isAssignable (lookup , value .type (), IfaceValue .class )) {
139+ value = getRootValue (result .op ());
140+ if (value instanceof Block .Parameter ) {
141+ value = rootValues .getOrDefault (value , value );
143142 }
144143 }else {
145- // or else?
144+ // or else?
146145 }
147- remappedVals .put (key , value );
146+ rootValues .put (key , value );
148147 }
149148 }
150149
151150 // retrieves "root" value of an op, which is how we track accesses
152151 // we will map the return value of this method to the accessType
153- private static Value getRootValue (Op op ) {
152+ private static Value getRootValue (Op op ) {
154153 // the op is a field load, an invoke, or something that reduces to one or the other
155154 // first, check if we can retrieve a fieldloadop from the given op
156155 Op fieldOp = HATPhaseUtils .findOpInResultFromFirstOperandsOrNull (op , JavaOp .FieldAccessOp .FieldLoadOp .class );
157- if (fieldOp != null ) return fieldOp .operands ().getFirst (); // if so, we use its first operand to map to accesses
156+ if (fieldOp != null ) {
157+ if (fieldOp .operands ().isEmpty ()) {
158+ return null ;
159+ }
160+ return fieldOp .operands ().getFirst (); // if so, we use its first operand to map to accesses
161+ }
158162
159163 // we then check if there's an invokeop that has no operands (meaning a shared or private buffer that was created)
160164 // or if there's an invokeop with a parameter as its first operation (this is a global buffer)
@@ -166,9 +170,9 @@ private static Value getRootValue(Op op) {
166170 return (invokeOp == null ) ? null : invokeOp .result (); // return the shared/private buffer invokeop that creates the buffer
167171 }
168172
169- // updates accessMap
170- private static void updateAccessType (Value value , AccessType currentAccess ) {
171- Value remappedValue = remappedVals .getOrDefault (value , value );
173+ // updates the access map
174+ private static void updateAccessType (Value value , AccessType currentAccess ) {
175+ Value remappedValue = rootValues .getOrDefault (value , value );
172176 AccessType storedAccess = accessMap .get (remappedValue );
173177 if (storedAccess == null ) {
174178 accessMap .put (remappedValue , currentAccess );
0 commit comments