Skip to content

Commit b446644

Browse files
author
Mourad Abbay
committed
Update the modeling of switch label
Reviewed-by: psandoz
1 parent a368ff6 commit b446644

6 files changed

Lines changed: 615 additions & 105 deletions

File tree

src/jdk.incubator.code/share/classes/jdk/incubator/code/bytecode/impl/LoweringTransform.java

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
* Lowering transformer generates models supported by {@code BytecodeGenerator}.
5454
* Constant-labeled switch statements and switch expressions are lowered to
5555
* {@code ConstantLabelSwitchOp} with evaluated labels.
56+
* We expect label value to be the second operand to the operation that perform equality check.
5657
*/
5758
public final class LoweringTransform {
5859

@@ -175,22 +176,11 @@ private static List<Integer> isCaseConstantLabel(MethodHandles.Lookup l, Body la
175176
MethodRef objectsEquals = MethodRef.method(Objects.class, "equals", boolean.class, Object.class, Object.class);
176177
switch (r.op()) {
177178
case JavaOp.EqOp eqOp -> {
178-
Value labelValue = getLabelValue(eqOp);
179-
Optional<Object> v = JavaOp.JavaExpression.evaluate(l, labelValue);
179+
Optional<Object> v = JavaOp.JavaExpression.evaluate(l, eqOp.operands().getLast());
180180
v.ifPresent(o -> labels.add(toInteger(o)));
181181
}
182182
case JavaOp.InvokeOp ie when ie.invokeReference().equals(objectsEquals) -> {
183-
Value labelValue;
184-
if (ie.operands().getLast() instanceof Op.Result opr && opr.op() instanceof JavaOp.InvokeOp ib
185-
&& integralReferenceTypes.contains(ib.invokeReference().refType()) && ib.invokeReference().name().equals("valueOf")) {
186-
// workaround the modeling of switch that has a selector of type primitive wrapper and a case constant of type primitive
187-
// we skip the boxing operation that's contained in the model
188-
// invoking a boxing method is not a valid operation in a constant expr
189-
labelValue = ib.operands().getFirst();
190-
} else {
191-
labelValue = getLabelValue(ie);
192-
}
193-
Optional<Object> v = JavaOp.JavaExpression.evaluate(l, labelValue);
183+
Optional<Object> v = JavaOp.JavaExpression.evaluate(l, ie.operands().getLast());
194184
v.ifPresent(o -> labels.add(toInteger(o)));
195185
}
196186
case JavaOp.ConditionalOrOp cor -> {
@@ -211,21 +201,6 @@ private static List<Integer> isCaseConstantLabel(MethodHandles.Lookup l, Body la
211201
return labels;
212202
}
213203

214-
private static Value getLabelValue(Op equalityOp) {
215-
if (equalityOp.operands().size() != 2) {
216-
throw new IllegalArgumentException("Expect operation to have two operands");
217-
}
218-
Value labelValue;
219-
if (equalityOp.operands().getFirst() instanceof Block.Parameter && equalityOp.operands().getLast() instanceof Op.Result) {
220-
labelValue = equalityOp.operands().getLast();
221-
} else if (equalityOp.operands().getFirst() instanceof Op.Result && equalityOp.operands().getLast() instanceof Block.Parameter) {
222-
labelValue = equalityOp.operands().getFirst();
223-
} else {
224-
throw new IllegalArgumentException("Switch label model not valid");
225-
}
226-
return labelValue;
227-
}
228-
229204
private static Integer toInteger(Object o) {
230205
return switch (o) {
231206
case Byte b -> Integer.valueOf(b);

src/jdk.incubator.code/share/classes/jdk/incubator/code/internal/ReflectMethods.java

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
import java.lang.constant.ClassDesc;
102102
import java.util.*;
103103
import java.util.List;
104+
import java.util.function.BiFunction;
104105
import java.util.function.Function;
105106
import java.util.function.Supplier;
106107

@@ -1643,7 +1644,7 @@ private List<Body.Builder> visitSwitchStatAndExpr(JCTree tree, JCExpression sele
16431644
boolean hasDefaultCase = false;
16441645

16451646
for (JCTree.JCCase c : cases) {
1646-
Body.Builder caseLabel = visitCaseLabel(tree, selector, target, c);
1647+
Body.Builder caseLabel = visitCaseLabel(tree, target, c);
16471648
Body.Builder caseBody = visitCaseBody(tree, c, caseBodyType, cases.getLast() == c);
16481649
bodies.add(caseLabel);
16491650
bodies.add(caseBody);
@@ -1669,7 +1670,27 @@ private List<Body.Builder> visitSwitchStatAndExpr(JCTree tree, JCExpression sele
16691670
return bodies;
16701671
}
16711672

1672-
private Body.Builder visitCaseLabel(JCTree tree, JCExpression selector, Value target, JCTree.JCCase c) {
1673+
private Value processConstantLabel(Value target, JCTree.JCConstantCaseLabel label) {
1674+
if (target.type().equals(JavaType.J_L_STRING)) {
1675+
return append(JavaOp.invoke(
1676+
MethodRef.method(Objects.class, "equals", boolean.class, Object.class, Object.class),
1677+
target, toValue(label.expr)));
1678+
} else {
1679+
// target is primitive wrapper, primitive or enum
1680+
// if target of type Character, Byte, Short or Integer, unbox it
1681+
if (target.type().equals(JavaType.J_L_CHARACTER) || target.type().equals(JavaType.J_L_BYTE) ||
1682+
target.type().equals(JavaType.J_L_SHORT) || target.type().equals(JavaType.J_L_INTEGER)) {
1683+
PrimitiveType pt = ((ClassType) target.type()).unbox().get();
1684+
target = convert(target, typeElementToType(pt));
1685+
}
1686+
Value expr = toValue(label.expr);
1687+
// conversion may be needed for primitive, e.g. label (byte) 1 and selector of type int
1688+
expr = convert(expr, typeElementToType(target.type()));
1689+
return append(JavaOp.eq(target, expr));
1690+
}
1691+
}
1692+
1693+
private Body.Builder visitCaseLabel(JCTree tree, Value target, JCTree.JCCase c) {
16731694
Body.Builder body;
16741695
FunctionType caseLabelType = CoreType.functionType(JavaType.BOOLEAN, target.type());
16751696

@@ -1687,6 +1708,8 @@ private Body.Builder visitCaseLabel(JCTree tree, JCExpression selector, Value ta
16871708
List<Body.Builder> clBodies = new ArrayList<>();
16881709

16891710
pushBody(pcl.pat, CoreType.functionType(JavaType.BOOLEAN));
1711+
1712+
localTarget = boxIfNeeded(localTarget);
16901713
Value patVal = scanPattern(pcl.pat, localTarget);
16911714
append(CoreOp.core_yield(patVal));
16921715
clBodies.add(stack.body);
@@ -1699,6 +1722,7 @@ private Body.Builder visitCaseLabel(JCTree tree, JCExpression selector, Value ta
16991722

17001723
localResult = append(JavaOp.conditionalAnd(clBodies));
17011724
} else {
1725+
localTarget = boxIfNeeded(localTarget);
17021726
localResult = scanPattern(pcl.pat, localTarget);
17031727
}
17041728
// Yield the boolean result of the condition
@@ -1713,33 +1737,14 @@ private Body.Builder visitCaseLabel(JCTree tree, JCExpression selector, Value ta
17131737
Value localTarget = stack.block.parameters().get(0);
17141738
final Value localResult;
17151739
if (c.labels.size() == 1) {
1716-
Value expr = toValue(ccl.expr);
1717-
// per java spec, constant type is compatible with the type of the selector expression
1718-
// so, we convert constant to the type of the selector expression
1719-
expr = convert(expr, selector.type);
1720-
if (selector.type.isPrimitive()) {
1721-
localResult = append(JavaOp.eq(localTarget, expr));
1722-
} else {
1723-
localResult = append(JavaOp.invoke(
1724-
MethodRef.method(Objects.class, "equals", boolean.class, Object.class, Object.class),
1725-
localTarget, expr));
1726-
}
1740+
localResult = processConstantLabel(localTarget, ccl);
17271741
} else {
17281742
List<Body.Builder> clBodies = new ArrayList<>();
17291743
for (JCTree.JCCaseLabel cl : c.labels) {
17301744
ccl = (JCTree.JCConstantCaseLabel) cl;
17311745
pushBody(ccl, CoreType.functionType(JavaType.BOOLEAN));
17321746

1733-
Value expr = toValue(ccl.expr);
1734-
expr = convert(expr, selector.type);
1735-
final Value labelResult;
1736-
if (selector.type.isPrimitive()) {
1737-
labelResult = append(JavaOp.eq(localTarget, expr));
1738-
} else {
1739-
labelResult = append(JavaOp.invoke(
1740-
MethodRef.method(Objects.class, "equals", boolean.class, Object.class, Object.class),
1741-
localTarget, expr));
1742-
}
1747+
final Value labelResult = processConstantLabel(localTarget, ccl);
17431748

17441749
append(CoreOp.core_yield(labelResult));
17451750
clBodies.add(stack.body);

test/jdk/jdk/incubator/code/TestSwitchExpressionOp.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
/*
1818
* @test
1919
* @modules jdk.incubator.code
20+
* @enablePreview
2021
* @run junit TestSwitchExpressionOp
2122
* @run main Unreflect TestSwitchExpressionOp
2223
* @run junit TestSwitchExpressionOp
@@ -454,6 +455,78 @@ static String defaultCaseNotTheLast(String s) {
454455
};
455456
}
456457

458+
@Reflect
459+
static String caseConstantPrimitiveWrapperSelector(Integer i) {
460+
return switch (i) {
461+
case 1 -> "one";
462+
case 2, 3 -> "two or three";
463+
default -> "else";
464+
};
465+
}
466+
467+
@Test
468+
void testCaseConstantPrimitiveWrapperSelector() {
469+
CoreOp.FuncOp lf = lower("caseConstantPrimitiveWrapperSelector");
470+
Integer[] args = {1, 2, 3, 4};
471+
for (Integer a : args) {
472+
Assertions.assertEquals(caseConstantPrimitiveWrapperSelector(a),
473+
Interpreter.invoke(MethodHandles.lookup(), lf, a));
474+
}
475+
}
476+
477+
@Reflect
478+
static String constantLabelCasted(int i) {
479+
return switch (i) {
480+
case (byte) 1 -> "one";
481+
default -> "not one";
482+
};
483+
}
484+
485+
@Test
486+
void testConstantLabelCasted() {
487+
CoreOp.FuncOp lf = lower("constantLabelCasted");
488+
int[] args = {-1, 1};
489+
for (int a : args) {
490+
Assertions.assertEquals(constantLabelCasted(a), Interpreter.invoke(MethodHandles.lookup(), lf, a));
491+
}
492+
}
493+
494+
@Reflect
495+
static String caseConstantStringLiteral(String s) {
496+
return switch (s) {
497+
case "1" -> "one";
498+
case "2", "3" -> "two or three";
499+
default -> "else";
500+
};
501+
}
502+
503+
@Test
504+
void testCeaseConstantStringLiteral() {
505+
CoreOp.FuncOp lf = lower("caseConstantStringLiteral");
506+
String[] args = {"1", "2", "3", ""};
507+
for (String a : args) {
508+
Assertions.assertEquals(caseConstantStringLiteral(a), Interpreter.invoke(MethodHandles.lookup(), lf, a));
509+
}
510+
}
511+
512+
@Reflect
513+
static String casePatternWithCaseConstant2(int i) {
514+
return switch (i) {
515+
case 0 -> "zero";
516+
case Integer j when j > 0 -> "positive";
517+
case Integer _ -> "negative";
518+
};
519+
}
520+
521+
@Test
522+
void testCasePatterWithCaseConstant() {
523+
CoreOp.FuncOp lf = lower("casePatternWithCaseConstant2");
524+
Integer[] args = {2, 0, -1};
525+
for (Integer a : args) {
526+
Assertions.assertEquals(casePatternWithCaseConstant2(a), Interpreter.invoke(MethodHandles.lookup(), lf, a));
527+
}
528+
}
529+
457530
// we are not testing switch expr that has no default,
458531
// because to test for MatchException we need to set up separate compilation
459532
// in compiler tests we are checking that the code model contains a default case that throws MatchException

test/jdk/jdk/incubator/code/TestSwitchStatementOp.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,66 @@ static String defaultCaseNotTheLast(String s) {
510510
return r;
511511
}
512512

513+
@Reflect
514+
static String caseConstantPrimitiveWrapperSelector(Integer i) {
515+
String r = "";
516+
switch (i) {
517+
case 1 -> r += "one";
518+
case 2, 3 -> r += "two or three";
519+
default -> r += "else";
520+
};
521+
return r;
522+
}
523+
524+
@Test
525+
void testCaseConstantPrimitiveWrapperSelector() {
526+
CoreOp.FuncOp lf = lower("caseConstantPrimitiveWrapperSelector");
527+
Integer[] args = {1, 2, 3, 4};
528+
for (Integer a : args) {
529+
Assertions.assertEquals(caseConstantPrimitiveWrapperSelector(a),
530+
Interpreter.invoke(MethodHandles.lookup(), lf, a));
531+
}
532+
}
533+
534+
@Reflect
535+
static String constantLabelCasted(int i) {
536+
String r = "";
537+
switch (i) {
538+
case (byte) 1 -> r += "one";
539+
default -> r += "not one";
540+
};
541+
return r;
542+
}
543+
544+
@Test
545+
void testConstantLabelCasted() {
546+
CoreOp.FuncOp lf = lower("constantLabelCasted");
547+
int[] args = {-1, 1};
548+
for (int a : args) {
549+
Assertions.assertEquals(constantLabelCasted(a), Interpreter.invoke(MethodHandles.lookup(), lf, a));
550+
}
551+
}
552+
553+
@Reflect
554+
static String caseConstantStringLiteral(String s) {
555+
String r = "";
556+
switch (s) {
557+
case "1" -> r += "one";
558+
case "2", "3" -> r+= "two or three";
559+
default -> r += "else";
560+
};
561+
return r;
562+
}
563+
564+
@Test
565+
void testCeaseConstantStringLiteral() {
566+
CoreOp.FuncOp lf = lower("caseConstantStringLiteral");
567+
String[] args = {"1", "2", "3", ""};
568+
for (String a : args) {
569+
Assertions.assertEquals(caseConstantStringLiteral(a), Interpreter.invoke(MethodHandles.lookup(), lf, a));
570+
}
571+
}
572+
513573
@Test
514574
void testTryAndSwitch() {
515575
CoreOp.FuncOp lmodel = lower("tryAndSwitch");

0 commit comments

Comments
 (0)