Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 65 additions & 85 deletions hat/core/src/main/java/hat/BufferTagger.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
package hat;

import hat.phases.HATPhaseUtils;
import jdk.incubator.code.dialect.java.ArrayType;
import jdk.incubator.code.dialect.java.JavaOp;
import jdk.incubator.code.dialect.java.PrimitiveType;
import optkl.IfaceValue;
import optkl.OpHelper;
import optkl.ifacemapper.AccessType;
import optkl.ifacemapper.Buffer;
import optkl.ifacemapper.MappableIface;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.Block;
Expand All @@ -40,46 +40,32 @@
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.List;
import static optkl.OpHelper.Invoke;
import static optkl.OpHelper.Invoke.invoke;

public class BufferTagger {
static HashMap<Value, AccessType> accessMap = new HashMap<>();
static HashMap<Value, Value> remappedVals = new HashMap<>(); // maps values to their "root" parameter/value
static HashMap<Block, List<Block.Parameter>> blockParams = new HashMap<>(); // holds block parameters for easy lookup
static Map<Value, AccessType> accessMap = new HashMap<>(); // mapping of parameters/buffers to access type
static Map<Value, Value> rootValues = new HashMap<>(); // maps values to their "root" parameter/value

// generates a list of AccessTypes matching the given FuncOp's parameter order
public static ArrayList<AccessType> getAccessList(MethodHandles.Lookup lookup, CoreOp.FuncOp inlinedEntryPoint) {
buildAccessMap(lookup, inlinedEntryPoint);
ArrayList<AccessType> accessList = new ArrayList<>();
for (Block.Parameter p : inlinedEntryPoint.body().entryBlock().parameters()) {
public static List<AccessType> getAccessList(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
buildAccessMap(lookup, funcOp);
List<AccessType> accessList = new ArrayList<>();
for (Block.Parameter p : funcOp.body().entryBlock().parameters()) {
if (accessMap.containsKey(p)) {
accessList.add(accessMap.get(p)); // is an accessed buffer
} else if (OpHelper.isAssignable(lookup, p.type(), MappableIface.class)) {
} else if (OpHelper.isAssignable(lookup, p.type(), IfaceValue.class)) {
accessList.add(AccessType.NA); // is a buffer but not accessed
} else {
accessList.add(AccessType.NOT_BUFFER); // is not a buffer
}
}
return accessList;
}
private static boolean isReference(Invoke ioh) {
return ioh.returns(IfaceValue.class)
&& ioh.opFromOnlyUseOrNull() instanceof JavaOp.InvokeOp nextInvoke
&& invoke(ioh.lookup(), nextInvoke) instanceof Invoke nextIoh
&& nextIoh.refIs(IfaceValue.class)
&& nextIoh.returnsVoid();
}

// creates the access map
private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
// build blockParams so that we can map params to "root" params later
funcOp.elements()
.filter(elem -> elem instanceof Block)
.map(elem->(Block)elem)
.forEach(block -> blockParams.put(block, block.parameters()));

private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
funcOp.elements().forEach(op -> {
switch (op) {
case CoreOp.BranchOp b -> mapBranch(lookup, b.branch());
Expand All @@ -88,37 +74,32 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp f
mapBranch(lookup, cb.falseBranch()); // handle false branch
}
case JavaOp.InvokeOp invokeOp -> {
var ioh = invoke(lookup,invokeOp);
// we have to deal with array views too
// should .arrayview() calls be marked as reads?
if ( ioh.refIs(IfaceValue.class)) {
// updateAccessType(getRootValue(invokeOp), ioh.returnsVoid()? AccessType.WO : AccessType.RO); // update buffer access
// if the invokeOp retrieves an element that is only written to, don't update the access type
// (i.e. the only use is an invoke, the invoke is of MappableIface/HAType class, and is a write)
if (!isReference(ioh)) { // value retrieved and not just referenced?
updateAccessType(getRootValue(invokeOp), ioh.returnsVoid()? AccessType.WO : AccessType.RO); // update buffer access
}
if (ioh.refIs(IfaceValue.class) && (ioh.returns(IfaceValue.class) || ioh.returnsArray())) {
// if we access a struct/union from a buffer, we map the struct/union to the buffer root
remappedVals.put(invokeOp.result(), getRootValue(invokeOp));
var ioh = invoke(lookup,invokeOp);
if (ioh.refIs(KernelContext.class)) break; // if this is not referencing a buffer, we break
if (ioh.returns(IfaceValue.class) || ioh.returnsArray()) { // if we receive a buffer from this invoke, we save its root value
for (Value operand : ioh.op().operands()) {
Comment thread
rbrchen marked this conversation as resolved.
Outdated
if (!(operand.type() instanceof PrimitiveType) && rootValues.containsKey(operand)) {
if (operand instanceof Block.Parameter) updateAccessType(operand, AccessType.RO);
Comment thread
rbrchen marked this conversation as resolved.
Outdated
else updateAccessType(getRootValue(operand.result().op()), AccessType.RO);
}
}
rootValues.put(invokeOp.result(), getRootValue(invokeOp));
} else { // if we actually operate on a buffer instead of storing an element in a variable
updateAccessType(rootValues.getOrDefault(invokeOp.result(), getRootValue(invokeOp)), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access
}
}
case CoreOp.VarOp vop -> { // map the new VarOp to the "root" param
if (OpHelper.isAssignable(lookup, vop.resultType().valueType(), Buffer.class)) {
remappedVals.put(vop.initOperand(), getRootValue(vop));
}else{
// or else maybe CoreOp.VarOp vop when ??? ->
}
if (!OpHelper.isAssignable(lookup, vop.resultType().valueType(), IfaceValue.class)) break;
rootValues.put(vop.initOperand(), getRootValue(vop));
}
case JavaOp.FieldAccessOp.FieldLoadOp flop -> {
if (OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class)) {
updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access
}else{
// or else
}
if (!OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class)) break;
updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access
}
case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> {
if (alop.resultType() instanceof ArrayType) break;
updateAccessType(getRootValue(alop), AccessType.RO);
}
case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> updateAccessType(getRootValue(alop), AccessType.RO);
case JavaOp.ArrayAccessOp.ArrayStoreOp asop -> updateAccessType(getRootValue(asop), AccessType.WO);
default -> {}
}
Expand All @@ -127,55 +108,54 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp f

// maps the parameters of a block to the values passed to a branch
private static void mapBranch(MethodHandles.Lookup lookup, Block.Reference blockReference) {
List<Value> args = blockReference.arguments();
for (int i = 0; i < args.size(); i++) {
Value key = blockParams.get(blockReference.targetBlock()).get(i);
Value value = args.get(i);
if (value instanceof Op.Result result) {
// either find root param or it doesn't exist (is a constant for example)
if (OpHelper.isAssignable(lookup, value.type(), MappableIface.class)) {
value = getRootValue(result.op());
if (value instanceof Block.Parameter) {
value = remappedVals.getOrDefault(value, value);
}
}else{
// or else
}
}else{
// or else?
}
remappedVals.put(key, value);
List<Value> inputArgs = blockReference.arguments();
List<Block.Parameter> targetArgs = blockReference.targetBlock().parameters();
for (int i = 0; i < inputArgs.size(); i++) {
Comment thread
rbrchen marked this conversation as resolved.
Outdated
Value target = targetArgs.get(i);
Value input = inputArgs.get(i);
if (!(input instanceof Op.Result result && OpHelper.isAssignable(lookup, input.type(), IfaceValue.class))) break;
input = getRootValue(result.op());
rootValues.put(target, rootValues.getOrDefault(input, input));
}
}

// retrieves "root" value of an op, which is how we track accesses
// we will map the return value of this method to the accessType
private static Value getRootValue(Op op) {
private static Value getRootValue(Op op) {
// the op is a field load, an invoke, or something that reduces to one or the other
// first, check if we can retrieve a fieldloadop from the given op
Op fieldOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class);
if (fieldOp != null) return fieldOp.operands().getFirst(); // if so, we use its first operand to map to accesses

// we then check if there's an invokeop that has no operands (meaning a shared or private buffer that was created)
// or if there's an invokeop with a parameter as its first operation (this is a global buffer)
Op invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
while (invokeOp != null && !invokeOp.operands().isEmpty()) {
if (invokeOp.operands().getFirst() instanceof Block.Parameter p) return p; // return the parameter that is the global buffer
invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class);
Op rootOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class, JavaOp.InvokeOp.class);
switch (rootOp) {
case JavaOp.FieldAccessOp.FieldLoadOp fieldOp -> {
if (fieldOp.operands().isEmpty()) break; // e.g. handling kc.warpSize
Comment thread
rbrchen marked this conversation as resolved.
Outdated
return fieldOp.operands().getFirst();
}
case JavaOp.InvokeOp invokeOp -> {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of this would be easier I think if you used the Invoke wrapper.

while (invokeOp != null && !invokeOp.operands().isEmpty()) { // we look for either the parameter or initialization for the buffer
if (invokeOp.operands().getFirst() instanceof Block.Parameter p) return p; // return the parameter that is the global buffer
invokeOp = (JavaOp.InvokeOp) HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class);
}
if (invokeOp != null) return invokeOp.result();
}
case null, default -> {}
}
return (invokeOp == null) ? null : invokeOp.result(); // return the shared/private buffer invokeop that creates the buffer
return null;
}

// updates accessMap
private static void updateAccessType(Value value, AccessType currentAccess) {
Value remappedValue = remappedVals.getOrDefault(value, value);
// updates the access map
private static void updateAccessType(Value value, AccessType currentAccess) {
Value remappedValue = rootValues.getOrDefault(value, value);
AccessType storedAccess = accessMap.get(remappedValue);
Comment thread
rbrchen marked this conversation as resolved.
Outdated
if (storedAccess == null) {
accessMap.put(remappedValue, currentAccess);
} else if (currentAccess != storedAccess && storedAccess != AccessType.RW) {
accessMap.put(remappedValue, AccessType.RW);
} else {
// this is the same access type as what's already stored
} // otherwise this is the same access type as what's already stored
}

public static void printAccessList(CoreOp.FuncOp inlinedEntryPoint, List<AccessType> accessList) {
System.out.print("func " + inlinedEntryPoint.funcName() + " has parameters");
for (AccessType at : accessList) {
Comment thread
rbrchen marked this conversation as resolved.
System.out.print(" " + at);
}
System.out.println();
}
}
2 changes: 1 addition & 1 deletion hat/core/src/main/java/hat/buffer/S32Array.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ static S32Array createFrom(ArenaAndLookupCarrier cc, int[] arr){
return ints;
}

@Reflect default int[] arrayView() {
default int[] arrayView() {
return this.copyTo(new int[this.length()]);
}
}