From 01404f3b006fa91669f169691b1a336ccc6199f2 Mon Sep 17 00:00:00 2001 From: Ruby Chen Date: Wed, 25 Mar 2026 10:56:35 -0700 Subject: [PATCH 1/4] Initial buffer tagger refactoring --- hat/core/src/main/java/hat/BufferTagger.java | 86 ++++++++++--------- .../src/main/java/hat/buffer/S32Array.java | 2 +- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/hat/core/src/main/java/hat/BufferTagger.java b/hat/core/src/main/java/hat/BufferTagger.java index 4216d6aa335..4a1f16eadf4 100644 --- a/hat/core/src/main/java/hat/BufferTagger.java +++ b/hat/core/src/main/java/hat/BufferTagger.java @@ -26,12 +26,12 @@ package hat; import hat.phases.HATPhaseUtils; +import jdk.incubator.code.dialect.java.ArrayType; import jdk.incubator.code.dialect.java.JavaOp; +import jdk.incubator.code.dialect.java.PrimitiveType; import optkl.IfaceValue; import optkl.OpHelper; import optkl.ifacemapper.AccessType; -import optkl.ifacemapper.Buffer; -import optkl.ifacemapper.MappableIface; import jdk.incubator.code.Op; import jdk.incubator.code.Value; import jdk.incubator.code.Block; @@ -45,8 +45,9 @@ import static optkl.OpHelper.Invoke.invoke; public class BufferTagger { - static HashMap accessMap = new HashMap<>(); - static HashMap remappedVals = new HashMap<>(); // maps values to their "root" parameter/value + static HashMap accessMap = new HashMap<>(); // mapping of parameters/buffers to access type + // TODO: fix how we use rootValues + static HashMap rootValues = new HashMap<>(); // maps values to their "root" parameter/value static HashMap> blockParams = new HashMap<>(); // holds block parameters for easy lookup // generates a list of AccessTypes matching the given FuncOp's parameter order @@ -56,7 +57,7 @@ public static ArrayList getAccessList(MethodHandles.Lookup lookup, C for (Block.Parameter p : inlinedEntryPoint.body().entryBlock().parameters()) { if (accessMap.containsKey(p)) { accessList.add(accessMap.get(p)); // is an accessed buffer - } else if (OpHelper.isAssignable(lookup, p.type(), MappableIface.class)) { + } else if (OpHelper.isAssignable(lookup, p.type(), IfaceValue.class)) { accessList.add(AccessType.NA); // is a buffer but not accessed } else { accessList.add(AccessType.NOT_BUFFER); // is not a buffer @@ -64,16 +65,17 @@ public static ArrayList getAccessList(MethodHandles.Lookup lookup, C } return accessList; } - private static boolean isReference(Invoke ioh) { - return ioh.returns(IfaceValue.class) + + private static boolean isAccessed(Invoke ioh) { + return !(ioh.returns(IfaceValue.class) && ioh.opFromOnlyUseOrNull() instanceof JavaOp.InvokeOp nextInvoke && invoke(ioh.lookup(), nextInvoke) instanceof Invoke nextIoh && nextIoh.refIs(IfaceValue.class) - && nextIoh.returnsVoid(); + && nextIoh.returnsVoid()); } // creates the access map - private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { + private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { // build blockParams so that we can map params to "root" params later funcOp.elements() .filter(elem -> elem instanceof Block) @@ -88,25 +90,24 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp f mapBranch(lookup, cb.falseBranch()); // handle false branch } case JavaOp.InvokeOp invokeOp -> { - var ioh = invoke(lookup,invokeOp); - // we have to deal with array views too - // should .arrayview() calls be marked as reads? - if ( ioh.refIs(IfaceValue.class)) { - // updateAccessType(getRootValue(invokeOp), ioh.returnsVoid()? AccessType.WO : AccessType.RO); // update buffer access - // if the invokeOp retrieves an element that is only written to, don't update the access type - // (i.e. the only use is an invoke, the invoke is of MappableIface/HAType class, and is a write) - if (!isReference(ioh)) { // value retrieved and not just referenced? - updateAccessType(getRootValue(invokeOp), ioh.returnsVoid()? AccessType.WO : AccessType.RO); // update buffer access - } - if (ioh.refIs(IfaceValue.class) && (ioh.returns(IfaceValue.class) || ioh.returnsArray())) { - // if we access a struct/union from a buffer, we map the struct/union to the buffer root - remappedVals.put(invokeOp.result(), getRootValue(invokeOp)); + var ioh = invoke(lookup,invokeOp); + if (ioh.refIs(KernelContext.class)) break; // if this is not referencing a buffer, we continue + if (ioh.returns(IfaceValue.class) || ioh.returnsArray()) { // if we receive a buffer from this invoke, we save its root value + for (Value operand : ioh.op().operands()) { + if (!(operand.type() instanceof PrimitiveType) && rootValues.containsKey(operand)) { + // TODO: FIX + if (operand instanceof Block.Parameter) updateAccessType(operand, AccessType.RO); + else updateAccessType(getRootValue(operand.result().op()), AccessType.RO); + } } + rootValues.put(invokeOp.result(), getRootValue(invokeOp)); + } else if (isAccessed(ioh)) { // if we actually operate on a buffer instead of storing an element in a variable + updateAccessType(getRootValue(invokeOp), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access } } case CoreOp.VarOp vop -> { // map the new VarOp to the "root" param - if (OpHelper.isAssignable(lookup, vop.resultType().valueType(), Buffer.class)) { - remappedVals.put(vop.initOperand(), getRootValue(vop)); + if (OpHelper.isAssignable(lookup, vop.resultType().valueType(), IfaceValue.class)) { + rootValues.put(vop.initOperand(), getRootValue(vop)); }else{ // or else maybe CoreOp.VarOp vop when ??? -> } @@ -118,7 +119,10 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp f // or else } } - case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> updateAccessType(getRootValue(alop), AccessType.RO); + case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> { + if (alop.resultType() instanceof ArrayType) break; + updateAccessType(getRootValue(alop), AccessType.RO); + } case JavaOp.ArrayAccessOp.ArrayStoreOp asop -> updateAccessType(getRootValue(asop), AccessType.WO); default -> {} } @@ -131,30 +135,30 @@ private static void mapBranch(MethodHandles.Lookup lookup, Block.Reference block for (int i = 0; i < args.size(); i++) { Value key = blockParams.get(blockReference.targetBlock()).get(i); Value value = args.get(i); - if (value instanceof Op.Result result) { - // either find root param or it doesn't exist (is a constant for example) - if (OpHelper.isAssignable(lookup, value.type(), MappableIface.class)) { - value = getRootValue(result.op()); - if (value instanceof Block.Parameter) { - value = remappedVals.getOrDefault(value, value); - } - }else{ - // or else + if (value instanceof Op.Result result && OpHelper.isAssignable(lookup, value.type(), IfaceValue.class)) { + value = getRootValue(result.op()); + if (value instanceof Block.Parameter) { + value = rootValues.getOrDefault(value, value); } }else{ - // or else? + // or else? } - remappedVals.put(key, value); + rootValues.put(key, value); } } // retrieves "root" value of an op, which is how we track accesses // we will map the return value of this method to the accessType - private static Value getRootValue(Op op) { + private static Value getRootValue(Op op) { // the op is a field load, an invoke, or something that reduces to one or the other // first, check if we can retrieve a fieldloadop from the given op Op fieldOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class); - if (fieldOp != null) return fieldOp.operands().getFirst(); // if so, we use its first operand to map to accesses + if (fieldOp != null) { + if (fieldOp.operands().isEmpty()) { + return null; + } + return fieldOp.operands().getFirst(); // if so, we use its first operand to map to accesses + } // we then check if there's an invokeop that has no operands (meaning a shared or private buffer that was created) // or if there's an invokeop with a parameter as its first operation (this is a global buffer) @@ -166,9 +170,9 @@ private static Value getRootValue(Op op) { return (invokeOp == null) ? null : invokeOp.result(); // return the shared/private buffer invokeop that creates the buffer } - // updates accessMap - private static void updateAccessType(Value value, AccessType currentAccess) { - Value remappedValue = remappedVals.getOrDefault(value, value); + // updates the access map + private static void updateAccessType(Value value, AccessType currentAccess) { + Value remappedValue = rootValues.getOrDefault(value, value); AccessType storedAccess = accessMap.get(remappedValue); if (storedAccess == null) { accessMap.put(remappedValue, currentAccess); diff --git a/hat/core/src/main/java/hat/buffer/S32Array.java b/hat/core/src/main/java/hat/buffer/S32Array.java index 7c3bee2e179..1743e55a701 100644 --- a/hat/core/src/main/java/hat/buffer/S32Array.java +++ b/hat/core/src/main/java/hat/buffer/S32Array.java @@ -69,7 +69,7 @@ static S32Array createFrom(ArenaAndLookupCarrier cc, int[] arr){ return ints; } - @Reflect default int[] arrayView() { + default int[] arrayView() { return this.copyTo(new int[this.length()]); } } From 0f0c339243d11a3b4db4ec6b116db8a03038f99f Mon Sep 17 00:00:00 2001 From: Ruby Chen Date: Thu, 26 Mar 2026 16:02:44 -0700 Subject: [PATCH 2/4] More cleanup --- hat/core/src/main/java/hat/BufferTagger.java | 112 +++++++----------- .../src/main/java/hat/buffer/S32Array.java | 2 +- 2 files changed, 45 insertions(+), 69 deletions(-) diff --git a/hat/core/src/main/java/hat/BufferTagger.java b/hat/core/src/main/java/hat/BufferTagger.java index 4a1f16eadf4..c9a2a944069 100644 --- a/hat/core/src/main/java/hat/BufferTagger.java +++ b/hat/core/src/main/java/hat/BufferTagger.java @@ -40,21 +40,19 @@ import java.lang.invoke.MethodHandles; import java.util.ArrayList; import java.util.HashMap; +import java.util.Map; import java.util.List; -import static optkl.OpHelper.Invoke; import static optkl.OpHelper.Invoke.invoke; public class BufferTagger { - static HashMap accessMap = new HashMap<>(); // mapping of parameters/buffers to access type - // TODO: fix how we use rootValues - static HashMap rootValues = new HashMap<>(); // maps values to their "root" parameter/value - static HashMap> blockParams = new HashMap<>(); // holds block parameters for easy lookup + static Map accessMap = new HashMap<>(); // mapping of parameters/buffers to access type + static Map rootValues = new HashMap<>(); // maps values to their "root" parameter/value // generates a list of AccessTypes matching the given FuncOp's parameter order - public static ArrayList getAccessList(MethodHandles.Lookup lookup, CoreOp.FuncOp inlinedEntryPoint) { - buildAccessMap(lookup, inlinedEntryPoint); - ArrayList accessList = new ArrayList<>(); - for (Block.Parameter p : inlinedEntryPoint.body().entryBlock().parameters()) { + public static List getAccessList(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { + buildAccessMap(lookup, funcOp); + List accessList = new ArrayList<>(); + for (Block.Parameter p : funcOp.body().entryBlock().parameters()) { if (accessMap.containsKey(p)) { accessList.add(accessMap.get(p)); // is an accessed buffer } else if (OpHelper.isAssignable(lookup, p.type(), IfaceValue.class)) { @@ -66,22 +64,8 @@ public static ArrayList getAccessList(MethodHandles.Lookup lookup, C return accessList; } - private static boolean isAccessed(Invoke ioh) { - return !(ioh.returns(IfaceValue.class) - && ioh.opFromOnlyUseOrNull() instanceof JavaOp.InvokeOp nextInvoke - && invoke(ioh.lookup(), nextInvoke) instanceof Invoke nextIoh - && nextIoh.refIs(IfaceValue.class) - && nextIoh.returnsVoid()); - } - // creates the access map private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - // build blockParams so that we can map params to "root" params later - funcOp.elements() - .filter(elem -> elem instanceof Block) - .map(elem->(Block)elem) - .forEach(block -> blockParams.put(block, block.parameters())); - funcOp.elements().forEach(op -> { switch (op) { case CoreOp.BranchOp b -> mapBranch(lookup, b.branch()); @@ -91,33 +75,26 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp fu } case JavaOp.InvokeOp invokeOp -> { var ioh = invoke(lookup,invokeOp); - if (ioh.refIs(KernelContext.class)) break; // if this is not referencing a buffer, we continue + if (ioh.refIs(KernelContext.class)) break; // if this is not referencing a buffer, we break if (ioh.returns(IfaceValue.class) || ioh.returnsArray()) { // if we receive a buffer from this invoke, we save its root value for (Value operand : ioh.op().operands()) { if (!(operand.type() instanceof PrimitiveType) && rootValues.containsKey(operand)) { - // TODO: FIX if (operand instanceof Block.Parameter) updateAccessType(operand, AccessType.RO); else updateAccessType(getRootValue(operand.result().op()), AccessType.RO); } } rootValues.put(invokeOp.result(), getRootValue(invokeOp)); - } else if (isAccessed(ioh)) { // if we actually operate on a buffer instead of storing an element in a variable - updateAccessType(getRootValue(invokeOp), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access + } else { // if we actually operate on a buffer instead of storing an element in a variable + updateAccessType(rootValues.getOrDefault(invokeOp.result(), getRootValue(invokeOp)), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access } } case CoreOp.VarOp vop -> { // map the new VarOp to the "root" param - if (OpHelper.isAssignable(lookup, vop.resultType().valueType(), IfaceValue.class)) { - rootValues.put(vop.initOperand(), getRootValue(vop)); - }else{ - // or else maybe CoreOp.VarOp vop when ??? -> - } + if (!OpHelper.isAssignable(lookup, vop.resultType().valueType(), IfaceValue.class)) break; + rootValues.put(vop.initOperand(), getRootValue(vop)); } case JavaOp.FieldAccessOp.FieldLoadOp flop -> { - if (OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class)) { - updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access - }else{ - // or else - } + if (!OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class)) break; + updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access } case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> { if (alop.resultType() instanceof ArrayType) break; @@ -131,43 +108,36 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp fu // maps the parameters of a block to the values passed to a branch private static void mapBranch(MethodHandles.Lookup lookup, Block.Reference blockReference) { - List args = blockReference.arguments(); - for (int i = 0; i < args.size(); i++) { - Value key = blockParams.get(blockReference.targetBlock()).get(i); - Value value = args.get(i); - if (value instanceof Op.Result result && OpHelper.isAssignable(lookup, value.type(), IfaceValue.class)) { - value = getRootValue(result.op()); - if (value instanceof Block.Parameter) { - value = rootValues.getOrDefault(value, value); - } - }else{ - // or else? - } - rootValues.put(key, value); + List inputArgs = blockReference.arguments(); + List targetArgs = blockReference.targetBlock().parameters(); + for (int i = 0; i < inputArgs.size(); i++) { + Value target = targetArgs.get(i); + Value input = inputArgs.get(i); + if (!(input instanceof Op.Result result && OpHelper.isAssignable(lookup, input.type(), IfaceValue.class))) break; + input = getRootValue(result.op()); + rootValues.put(target, rootValues.getOrDefault(input, input)); } } // retrieves "root" value of an op, which is how we track accesses - // we will map the return value of this method to the accessType private static Value getRootValue(Op op) { // the op is a field load, an invoke, or something that reduces to one or the other - // first, check if we can retrieve a fieldloadop from the given op - Op fieldOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class); - if (fieldOp != null) { - if (fieldOp.operands().isEmpty()) { - return null; + Op rootOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class, JavaOp.InvokeOp.class); + switch (rootOp) { + case JavaOp.FieldAccessOp.FieldLoadOp fieldOp -> { + if (fieldOp.operands().isEmpty()) break; // e.g. handling kc.warpSize + return fieldOp.operands().getFirst(); } - return fieldOp.operands().getFirst(); // if so, we use its first operand to map to accesses - } - - // we then check if there's an invokeop that has no operands (meaning a shared or private buffer that was created) - // or if there's an invokeop with a parameter as its first operation (this is a global buffer) - Op invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class); - while (invokeOp != null && !invokeOp.operands().isEmpty()) { - if (invokeOp.operands().getFirst() instanceof Block.Parameter p) return p; // return the parameter that is the global buffer - invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class); + case JavaOp.InvokeOp invokeOp -> { + while (invokeOp != null && !invokeOp.operands().isEmpty()) { // we look for either the parameter or initialization for the buffer + if (invokeOp.operands().getFirst() instanceof Block.Parameter p) return p; // return the parameter that is the global buffer + invokeOp = (JavaOp.InvokeOp) HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class); + } + if (invokeOp != null) return invokeOp.result(); + } + case null, default -> {} } - return (invokeOp == null) ? null : invokeOp.result(); // return the shared/private buffer invokeop that creates the buffer + return null; } // updates the access map @@ -178,8 +148,14 @@ private static void updateAccessType(Value value, AccessType currentAccess) { accessMap.put(remappedValue, currentAccess); } else if (currentAccess != storedAccess && storedAccess != AccessType.RW) { accessMap.put(remappedValue, AccessType.RW); - } else { - // this is the same access type as what's already stored + } // otherwise this is the same access type as what's already stored + } + + public static void printAccessList(CoreOp.FuncOp inlinedEntryPoint, List accessList) { + System.out.print("func " + inlinedEntryPoint.funcName() + " has parameters"); + for (AccessType at : accessList) { + System.out.print(" " + at); } + System.out.println(); } } \ No newline at end of file diff --git a/hat/core/src/main/java/hat/buffer/S32Array.java b/hat/core/src/main/java/hat/buffer/S32Array.java index 1743e55a701..9fdbc989c21 100644 --- a/hat/core/src/main/java/hat/buffer/S32Array.java +++ b/hat/core/src/main/java/hat/buffer/S32Array.java @@ -69,7 +69,7 @@ static S32Array createFrom(ArenaAndLookupCarrier cc, int[] arr){ return ints; } - default int[] arrayView() { + default int[] arrayView() { return this.copyTo(new int[this.length()]); } } From 8af6b80be7f205883792174916a783bfb56f0568 Mon Sep 17 00:00:00 2001 From: Ruby Chen Date: Sat, 28 Mar 2026 16:57:39 -0700 Subject: [PATCH 3/4] Fixes --- hat/core/src/main/java/hat/BufferTagger.java | 70 ++++++++++---------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/hat/core/src/main/java/hat/BufferTagger.java b/hat/core/src/main/java/hat/BufferTagger.java index c9a2a944069..4a84d88d627 100644 --- a/hat/core/src/main/java/hat/BufferTagger.java +++ b/hat/core/src/main/java/hat/BufferTagger.java @@ -42,6 +42,7 @@ import java.util.HashMap; import java.util.Map; import java.util.List; +import java.util.stream.IntStream; import static optkl.OpHelper.Invoke.invoke; public class BufferTagger { @@ -73,34 +74,27 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp fu mapBranch(lookup, cb.trueBranch()); // handle true branch mapBranch(lookup, cb.falseBranch()); // handle false branch } - case JavaOp.InvokeOp invokeOp -> { - var ioh = invoke(lookup,invokeOp); - if (ioh.refIs(KernelContext.class)) break; // if this is not referencing a buffer, we break + case JavaOp.InvokeOp $ when invoke(lookup, $) instanceof OpHelper.Invoke ioh && !ioh.refIs(KernelContext.class) -> { if (ioh.returns(IfaceValue.class) || ioh.returnsArray()) { // if we receive a buffer from this invoke, we save its root value for (Value operand : ioh.op().operands()) { if (!(operand.type() instanceof PrimitiveType) && rootValues.containsKey(operand)) { if (operand instanceof Block.Parameter) updateAccessType(operand, AccessType.RO); - else updateAccessType(getRootValue(operand.result().op()), AccessType.RO); + else updateAccessType(operand.result().op(), AccessType.RO); } } - rootValues.put(invokeOp.result(), getRootValue(invokeOp)); + rootValues.put(ioh.returnResult(), getRootValue(ioh.op())); } else { // if we actually operate on a buffer instead of storing an element in a variable - updateAccessType(rootValues.getOrDefault(invokeOp.result(), getRootValue(invokeOp)), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access + updateAccessType(ioh.op(), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access } } - case CoreOp.VarOp vop -> { // map the new VarOp to the "root" param - if (!OpHelper.isAssignable(lookup, vop.resultType().valueType(), IfaceValue.class)) break; - rootValues.put(vop.initOperand(), getRootValue(vop)); - } - case JavaOp.FieldAccessOp.FieldLoadOp flop -> { - if (!OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class)) break; - updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access - } - case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> { - if (alop.resultType() instanceof ArrayType) break; - updateAccessType(getRootValue(alop), AccessType.RO); - } - case JavaOp.ArrayAccessOp.ArrayStoreOp asop -> updateAccessType(getRootValue(asop), AccessType.WO); + case CoreOp.VarOp vop when OpHelper.isAssignable(lookup, vop.varValueType(), IfaceValue.class) -> + rootValues.put(vop.initOperand(), getRootValue(vop)); // map the new VarOp to the "root" param + case JavaOp.FieldAccessOp.FieldLoadOp flop when OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class) -> + updateAccessType(flop, AccessType.RO); // handle kc access + case JavaOp.ArrayAccessOp.ArrayLoadOp alop when !(alop.resultType() instanceof ArrayType) -> + updateAccessType(alop, AccessType.RO); + case JavaOp.ArrayAccessOp.ArrayStoreOp asop -> + updateAccessType(asop, AccessType.WO); default -> {} } }); @@ -110,13 +104,13 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp fu private static void mapBranch(MethodHandles.Lookup lookup, Block.Reference blockReference) { List inputArgs = blockReference.arguments(); List targetArgs = blockReference.targetBlock().parameters(); - for (int i = 0; i < inputArgs.size(); i++) { - Value target = targetArgs.get(i); - Value input = inputArgs.get(i); - if (!(input instanceof Op.Result result && OpHelper.isAssignable(lookup, input.type(), IfaceValue.class))) break; - input = getRootValue(result.op()); - rootValues.put(target, rootValues.getOrDefault(input, input)); - } + IntStream.range(0, inputArgs.size()).filter(i -> + inputArgs.get(i) instanceof Op.Result && OpHelper.isAssignable(lookup, inputArgs.get(i).type(), IfaceValue.class)) + .forEach(i -> { + Value input = inputArgs.get(i); + input = getRootValue(input.result().op()); + rootValues.put(targetArgs.get(i), rootValues.getOrDefault(input, input)); + }); } // retrieves "root" value of an op, which is how we track accesses @@ -124,8 +118,7 @@ private static Value getRootValue(Op op) { // the op is a field load, an invoke, or something that reduces to one or the other Op rootOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class, JavaOp.InvokeOp.class); switch (rootOp) { - case JavaOp.FieldAccessOp.FieldLoadOp fieldOp -> { - if (fieldOp.operands().isEmpty()) break; // e.g. handling kc.warpSize + case JavaOp.FieldAccessOp.FieldLoadOp fieldOp when !fieldOp.operands().isEmpty() -> { return fieldOp.operands().getFirst(); } case JavaOp.InvokeOp invokeOp -> { @@ -140,22 +133,27 @@ private static Value getRootValue(Op op) { return null; } + // retrieves root value of op before updating the access map + private static void updateAccessType(Op op, AccessType currentAccess) { + updateAccessType(getRootValue(op), currentAccess); + } + // updates the access map private static void updateAccessType(Value value, AccessType currentAccess) { - Value remappedValue = rootValues.getOrDefault(value, value); - AccessType storedAccess = accessMap.get(remappedValue); + AccessType storedAccess = accessMap.get(value); if (storedAccess == null) { - accessMap.put(remappedValue, currentAccess); + accessMap.put(value, currentAccess); } else if (currentAccess != storedAccess && storedAccess != AccessType.RW) { - accessMap.put(remappedValue, AccessType.RW); + accessMap.put(value, AccessType.RW); } // otherwise this is the same access type as what's already stored } - public static void printAccessList(CoreOp.FuncOp inlinedEntryPoint, List accessList) { - System.out.print("func " + inlinedEntryPoint.funcName() + " has parameters"); + public static void printAccessList(CoreOp.FuncOp funcOp, List accessList) { + StringBuilder output = new StringBuilder(); + output.append("func ").append(funcOp.funcName()).append(" has parameters"); for (AccessType at : accessList) { - System.out.print(" " + at); + output.append(" ").append(at); } - System.out.println(); + System.out.println(output); } } \ No newline at end of file From 37c837f56e98167c15dd0be0b7b4f81942cce400 Mon Sep 17 00:00:00 2001 From: Ruby Chen Date: Mon, 30 Mar 2026 11:06:27 -0700 Subject: [PATCH 4/4] Add blocks for if statements and use streams --- hat/core/src/main/java/hat/BufferTagger.java | 24 +++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/hat/core/src/main/java/hat/BufferTagger.java b/hat/core/src/main/java/hat/BufferTagger.java index 4a84d88d627..745e53680c1 100644 --- a/hat/core/src/main/java/hat/BufferTagger.java +++ b/hat/core/src/main/java/hat/BufferTagger.java @@ -26,6 +26,7 @@ package hat; import hat.phases.HATPhaseUtils; +import jdk.incubator.code.CodeItem; import jdk.incubator.code.dialect.java.ArrayType; import jdk.incubator.code.dialect.java.JavaOp; import jdk.incubator.code.dialect.java.PrimitiveType; @@ -76,12 +77,15 @@ private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp fu } case JavaOp.InvokeOp $ when invoke(lookup, $) instanceof OpHelper.Invoke ioh && !ioh.refIs(KernelContext.class) -> { if (ioh.returns(IfaceValue.class) || ioh.returnsArray()) { // if we receive a buffer from this invoke, we save its root value - for (Value operand : ioh.op().operands()) { - if (!(operand.type() instanceof PrimitiveType) && rootValues.containsKey(operand)) { - if (operand instanceof Block.Parameter) updateAccessType(operand, AccessType.RO); - else updateAccessType(operand.result().op(), AccessType.RO); - } - } + ioh.op().operands().stream() + .filter(operand -> !(operand.type() instanceof PrimitiveType) && rootValues.containsKey(operand)) + .forEach(operand -> { + if (operand instanceof Block.Parameter) { + updateAccessType(operand, AccessType.RO); + } else { + updateAccessType(operand.result().op(), AccessType.RO); + } + }); rootValues.put(ioh.returnResult(), getRootValue(ioh.op())); } else { // if we actually operate on a buffer instead of storing an element in a variable updateAccessType(ioh.op(), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access @@ -123,10 +127,14 @@ private static Value getRootValue(Op op) { } case JavaOp.InvokeOp invokeOp -> { while (invokeOp != null && !invokeOp.operands().isEmpty()) { // we look for either the parameter or initialization for the buffer - if (invokeOp.operands().getFirst() instanceof Block.Parameter p) return p; // return the parameter that is the global buffer + if (invokeOp.operands().getFirst() instanceof Block.Parameter p) { + return p; // return the parameter that is the global buffer + } invokeOp = (JavaOp.InvokeOp) HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class); } - if (invokeOp != null) return invokeOp.result(); + if (invokeOp != null) { + return invokeOp.result(); + } } case null, default -> {} }