Skip to content

Commit ecf51e5

Browse files
committed
Fix type annotations and hidden imports
1 parent f4989fd commit ecf51e5

File tree

5 files changed

+184
-26
lines changed

5 files changed

+184
-26
lines changed

compiler/src/main/java/dev/ultreon/pythonc/JClass.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ final class JClass implements JvmClass {
1313

1414
JClass(String className) {
1515
this.className = className;
16-
this.asmType = Type.getType("L" + className + ";");
16+
this.asmType = Type.getObjectType(className.replace(".", "/"));
1717
try {
1818
this.type = Class.forName(className.replace("/", "."), false, getClass().getClassLoader());
1919
} catch (ClassNotFoundException e) {

compiler/src/main/java/dev/ultreon/pythonc/JvmWriter.java

+11
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,13 @@ public void cast(Type type) {
947947
default -> throw new RuntimeException("Unsupported cast from " + pop + " to " + type);
948948
}
949949
}
950+
case Type.ARRAY -> {
951+
if (!pop.equals(type)) {
952+
if (pop.getSort() != Type.ARRAY && pop.getSort() != Type.OBJECT) {
953+
throw new RuntimeException("Unsupported cast from " + pop + " to " + type);
954+
}
955+
}
956+
}
950957
default -> throw new RuntimeException("Unsupported cast to " + type);
951958
}
952959
context.push(type);
@@ -1147,6 +1154,10 @@ public void smartCast(Type from, Type to) {
11471154
} else if (!from.equals(to)) {
11481155
throw new RuntimeException("Cannot smart cast " + from + " to " + to);
11491156
}
1157+
} else if (to.getSort() == Type.ARRAY) {
1158+
if (!from.equals(to)) {
1159+
throw new RuntimeException("Cannot smart cast " + from + " to " + to);
1160+
}
11501161
} else {
11511162
throw new RuntimeException("Cannot smart cast " + from + " to " + to);
11521163
}

compiler/src/main/java/dev/ultreon/pythonc/PyVariable.java

+1-15
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,6 @@ public Object preload(MethodVisitor mv, PythonCompiler compiler, boolean boxed)
3535
@Override
3636
public void load(MethodVisitor mv, PythonCompiler compiler, Object preloaded, boolean boxed) {
3737
int opcode;
38-
if (type.getSort() == Type.OBJECT) {
39-
if (type.equals(Type.getType(String.class)) || type.equals(Type.BYTE_TYPE) || type.equals(Type.CHAR_TYPE) || type.equals(Type.SHORT_TYPE) || type.equals(Type.INT_TYPE) || type.equals(Type.LONG_TYPE) || type.equals(Type.FLOAT_TYPE) || type.equals(Type.DOUBLE_TYPE) || type.equals(Type.BOOLEAN_TYPE)
40-
|| type.equals(Type.getType(byte[].class)) || type.equals(Type.getType(Object[].class)) || type.equals(Type.getType(Object.class)) || type.equals(Type.getType(Class.class))
41-
|| type.equals(Type.getType(Byte.class)) || type.equals(Type.getType(Character.class)) || type.equals(Type.getType(Short.class)) || type.equals(Type.getType(Integer.class)) || type.equals(Type.getType(Long.class)) || type.equals(Type.getType(Float.class)) || type.equals(Type.getType(Double.class)) || type.equals(Type.getType(Boolean.class))) {
42-
43-
} else if (compiler.imports.get(compiler.writer.boxType(type).getClassName().substring(compiler.writer.boxType(type).getClassName().lastIndexOf('.') + 1)) == null) {
44-
throw compiler.typeNotFound(compiler.writer.boxType(type).getClassName().substring(compiler.writer.boxType(type).getClassName().lastIndexOf('.') + 1), this);
45-
}
46-
}
47-
4838
compiler.writer.loadObject(index, compiler.writer.boxType(type));
4939

5040
if (!boxed) {
@@ -87,11 +77,7 @@ public Type type(PythonCompiler compiler) {
8777
} else if (type.equals(Type.getType(Class.class))) {
8878
return Type.getType(Class.class);
8979
}
90-
if (compiler.symbols.get(type.getClassName().substring(type.getClassName().lastIndexOf('.') + 1)) == null) {
91-
throw compiler.typeNotFound(type.getClassName().substring(type.getClassName().lastIndexOf('.') + 1), this);
92-
}
93-
94-
return compiler.symbols.get(type.getClassName().substring(type.getClassName().lastIndexOf('.') + 1)).type(compiler);
80+
return compiler.typeCheck(type, this);
9581
}
9682

9783
@Override

compiler/src/main/java/dev/ultreon/pythonc/PythonCompiler.java

+167-9
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,9 @@ else if (name.getText().startsWith("_"))
678678

679679
writer.returnVoid();
680680
writer.end();
681-
681+
} catch (Exception e) {
682+
e.printStackTrace();
683+
throw new RuntimeException(e);
682684
} finally {
683685
flags.clear(F_CPL_STATIC_FUNC);
684686
flags.clear(F_CPL_CLASS_FUNC);
@@ -1193,7 +1195,7 @@ public Object visitAssignment(PythonParser.AssignmentContext ctx) {
11931195
pyVariable.set(mv, this, expr);
11941196
}
11951197
} else {
1196-
createVariable(name, switch (visit1) {
1198+
Type type = switch (visit1) {
11971199
case FuncCall funcCall -> funcCall.type(this);
11981200
case Symbol symbol -> symbol.type(this);
11991201
case Boolean booleanValue -> Type.BOOLEAN_TYPE;
@@ -1208,7 +1210,9 @@ public Object visitAssignment(PythonParser.AssignmentContext ctx) {
12081210
case PyConstant pyConstant -> pyConstant.type(this);
12091211
case PyExpr pyExpr -> pyExpr.type(this);
12101212
default -> throw new RuntimeException("Expression for variable assignment wasn't found.");
1211-
}, switch (visit1) {
1213+
};
1214+
String s = importType(type);
1215+
createVariable(name, type, switch (visit1) {
12121216
case PyConstant pyConstant -> pyConstant;
12131217
case Symbol symbol -> symbol;
12141218
case String string -> new PyConstant(string, ctx.start.getLine());
@@ -1223,6 +1227,8 @@ public Object visitAssignment(PythonParser.AssignmentContext ctx) {
12231227
case PyExpr pyExpr -> pyExpr;
12241228
default -> throw new RuntimeException("Expression for variable assignment wasn't found.");
12251229
}, false);
1230+
1231+
imports.remove(s);
12261232
}
12271233

12281234
return Unit.Instance;
@@ -1295,6 +1301,35 @@ public Object visitAssignment(PythonParser.AssignmentContext ctx) {
12951301
}
12961302
}
12971303

1304+
private String importType(Type type) {
1305+
if (type == null) throw new RuntimeException("Can't import null type");
1306+
if (type.getSort() == Type.ARRAY) {
1307+
importType(type.getElementType());
1308+
return null;
1309+
}
1310+
if (type.getSort() != Type.OBJECT) return null;
1311+
Class clazz = null;
1312+
try {
1313+
clazz = Class.forName(type.getClassName(), false, getClass().getClassLoader());
1314+
JClass value = new JClass(type.getClassName());
1315+
String internalName = type.getInternalName();
1316+
String[] split = internalName.split("/");
1317+
String importedName = split[split.length - 1];
1318+
imports.put(importedName, value);
1319+
symbols.put(importedName, value);
1320+
return importedName;
1321+
} catch (ClassNotFoundException e) {
1322+
PyClass value = classes.get(type.getClassName());
1323+
if (value == null) throw new CompilerException("JVM Class not found: " + type.getClassName());
1324+
String internalName = type.getInternalName();
1325+
String[] split = internalName.split("/");
1326+
String importedName = split[split.length - 1];
1327+
imports.put(importedName, value);
1328+
symbols.put(importedName, value);
1329+
return importedName;
1330+
}
1331+
}
1332+
12981333
public void loadConstant(ParserRuleContext ctx, Object visit1, MethodVisitor mv) {
12991334
var a = constant(ctx, visit1);
13001335

@@ -1359,6 +1394,15 @@ public Object visitBitwise_or(PythonParser.Bitwise_orContext ctx) {
13591394
}
13601395
Object finalValue = value;
13611396
Object finalAddition = addition;
1397+
if (flags.get(F_CPL_TYPE_ANNO)) {
1398+
if (addition != null) {
1399+
throw new RuntimeException("Binary operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1400+
}
1401+
1402+
if (value != null) {
1403+
return value;
1404+
}
1405+
}
13621406
PyEval.Operator operator = null;
13631407
if (ctx.VBAR() != null) {
13641408
operator = PyEval.Operator.OR;
@@ -1381,6 +1425,15 @@ public Object visitBitwise_xor(PythonParser.Bitwise_xorContext ctx) {
13811425
Object finalValue = value;
13821426
Object finalAddition = addition;
13831427
PyEval.Operator operator = null;
1428+
if (flags.get(F_CPL_TYPE_ANNO)) {
1429+
if (addition != null) {
1430+
throw new RuntimeException("Binary operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1431+
}
1432+
1433+
if (value != null) {
1434+
return value;
1435+
}
1436+
}
13841437
if (ctx.CIRCUMFLEX() != null) {
13851438
operator = PyEval.Operator.XOR;
13861439
}
@@ -1440,6 +1493,15 @@ public Object visitComparison(PythonParser.ComparisonContext ctx) {
14401493
List<PythonParser.Compare_op_bitwise_or_pairContext> compareOpBitwiseOrPairContexts = ctx.compare_op_bitwise_or_pair();
14411494
if (bitwiseOrContext != null) {
14421495
Object visit = visit(bitwiseOrContext);
1496+
if (flags.get(F_CPL_TYPE_ANNO)) {
1497+
if (!compareOpBitwiseOrPairContexts.isEmpty()) {
1498+
throw new RuntimeException("Comparison is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1499+
}
1500+
1501+
if (visit != null) {
1502+
return visit;
1503+
}
1504+
}
14431505
if (!compareOpBitwiseOrPairContexts.isEmpty()) {
14441506
for (PythonParser.Compare_op_bitwise_or_pairContext compareOpBitwiseOrPairContext : compareOpBitwiseOrPairContexts) {
14451507
Object visit1 = visit(compareOpBitwiseOrPairContext);
@@ -1480,6 +1542,28 @@ public Type type(PythonCompiler compiler) {
14801542
@Override
14811543
public Object visitCompare_op_bitwise_or_pair(PythonParser.Compare_op_bitwise_or_pairContext ctx) {
14821544
PythonParser.Eq_bitwise_orContext eqBitwiseOrContext = ctx.eq_bitwise_or();
1545+
if (flags.get(F_CPL_TYPE_ANNO)) {
1546+
if (ctx.eq_bitwise_or() != null) {
1547+
throw new RuntimeException("Equality is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1548+
}
1549+
if (ctx.noteq_bitwise_or() != null) {
1550+
throw new RuntimeException("Inequality is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1551+
}
1552+
if (ctx.gt_bitwise_or() != null) {
1553+
throw new RuntimeException("Relational operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1554+
}
1555+
if (ctx.gte_bitwise_or() != null) {
1556+
throw new RuntimeException("Relational operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1557+
}
1558+
if (ctx.lt_bitwise_or() != null) {
1559+
throw new RuntimeException("Relational operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1560+
}
1561+
if (ctx.lte_bitwise_or() != null) {
1562+
throw new RuntimeException("Relational operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1563+
}
1564+
1565+
throw new RuntimeException("No supported matching compare_op_bitwise_or_pair found for:\n" + ctx.getText());
1566+
}
14831567
if (eqBitwiseOrContext != null) {
14841568
return new PyComparison(eqBitwiseOrContext, PyComparison.Comparison.EQ, ctx);
14851569
}
@@ -1529,6 +1613,15 @@ public Object visitBitwise_and(PythonParser.Bitwise_andContext ctx) {
15291613
Object finalValue = value;
15301614
Object finalAddition = addition;
15311615
PyEval.Operator operator = null;
1616+
if (flags.get(F_CPL_TYPE_ANNO)) {
1617+
if (addition != null) {
1618+
throw new RuntimeException("Type annotation is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1619+
}
1620+
1621+
if (value != null) {
1622+
return value;
1623+
}
1624+
}
15321625
if (ctx.AMPER() != null) {
15331626
operator = PyEval.Operator.AND;
15341627
}
@@ -1550,6 +1643,15 @@ public Object visitShift_expr(PythonParser.Shift_exprContext ctx) {
15501643
Object finalValue = value;
15511644
Object finalAddition = addition;
15521645
PyEval.Operator operator = null;
1646+
if (flags.get(F_CPL_TYPE_ANNO)) {
1647+
if (addition != null) {
1648+
throw new RuntimeException("Type annotation is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1649+
}
1650+
1651+
if (value != null) {
1652+
return value;
1653+
}
1654+
}
15531655
if (ctx.LEFTSHIFT() != null) {
15541656
operator = PyEval.Operator.LSHIFT;
15551657
} else if (ctx.RIGHTSHIFT() != null) {
@@ -1573,6 +1675,15 @@ public Object visitSum(PythonParser.SumContext ctx) {
15731675
Object finalValue = value;
15741676
Object finalAddition = addition;
15751677
PyEval.Operator operator = null;
1678+
if (flags.get(F_CPL_TYPE_ANNO)) {
1679+
if (addition != null) {
1680+
throw new RuntimeException("Type annotation is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1681+
}
1682+
1683+
if (value != null) {
1684+
return value;
1685+
}
1686+
}
15761687
if (ctx.PLUS() != null) {
15771688
operator = PyEval.Operator.ADD;
15781689
} else if (ctx.MINUS() != null) {
@@ -1615,6 +1726,15 @@ public Object visitTerm(PythonParser.TermContext ctx) {
16151726
Object finalValue = value;
16161727
Object finalAddition = addition;
16171728
PyEval.Operator operator = null;
1729+
if (flags.get(F_CPL_TYPE_ANNO)) {
1730+
if (addition != null) {
1731+
throw new RuntimeException("Type annotation is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1732+
}
1733+
1734+
if (value != null) {
1735+
return value;
1736+
}
1737+
}
16181738
if (ctx.STAR() != null) {
16191739
operator = PyEval.Operator.MUL;
16201740
} else if (ctx.SLASH() != null) {
@@ -1641,6 +1761,9 @@ public Object visitFactor(PythonParser.FactorContext ctx) {
16411761
}
16421762
Object finalValue = value;
16431763
PyEval.Operator operator = PyEval.Operator.UNARY_MINUS;
1764+
if (flags.get(F_CPL_TYPE_ANNO)) {
1765+
throw new RuntimeException("Unary operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1766+
}
16441767
return new PyEval(this, ctx, operator, finalValue, null);
16451768
}
16461769

@@ -1655,6 +1778,9 @@ public Object visitFactor(PythonParser.FactorContext ctx) {
16551778
}
16561779
Object finalValue = value;
16571780
PyEval.Operator operator = PyEval.Operator.UNARY_PLUS;
1781+
if (flags.get(F_CPL_TYPE_ANNO)) {
1782+
throw new RuntimeException("Unary operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1783+
}
16581784
return new PyEval(this, ctx, operator, finalValue, null);
16591785
}
16601786

@@ -1669,6 +1795,9 @@ public Object visitFactor(PythonParser.FactorContext ctx) {
16691795
}
16701796
Object finalValue = value;
16711797
PyEval.Operator operator = PyEval.Operator.UNARY_NOT;
1798+
if (flags.get(F_CPL_TYPE_ANNO)) {
1799+
throw new RuntimeException("Unary operator is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1800+
}
16721801
return new PyEval(this, ctx, operator, finalValue, null);
16731802
}
16741803
PythonParser.PowerContext powerContext = ctx.power();
@@ -1694,6 +1823,15 @@ public Object visitPower(PythonParser.PowerContext ctx) {
16941823
Object finalValue = value;
16951824
Object finalAddition = addition;
16961825
PyEval.Operator operator = PyEval.Operator.POW;
1826+
if (flags.get(F_CPL_TYPE_ANNO)) {
1827+
if (addition != null) {
1828+
throw new RuntimeException("Type annotation is not allowed in type annotations (at " + fileName + ":" + ctx.getStart().getLine() + ":" + ctx.getStart().getCharPositionInLine() + ")");
1829+
}
1830+
1831+
if (value != null) {
1832+
return value;
1833+
}
1834+
}
16971835
return new PyEval(this, ctx, operator, finalValue, finalAddition);
16981836

16991837
}
@@ -2126,12 +2264,7 @@ int createVariable(String name, Type type, PyExpr expr, boolean boxed) {
21262264
writer.localVariable(name, Type.getType(Object.class).getDescriptor(), null, endLabel, endLabel, currentVariableIndex);
21272265
mv.visitLineNumber(expr.lineNo(), label);
21282266
int opcode;
2129-
if (!type.equals(Type.getType(String.class)) && !type.equals(Type.LONG_TYPE) && !type.equals(Type.DOUBLE_TYPE)
2130-
&& !type.equals(Type.FLOAT_TYPE) && !type.equals(Type.INT_TYPE) && !type.equals(Type.BOOLEAN_TYPE)
2131-
&& !type.equals(Type.BYTE_TYPE) && !type.equals(Type.SHORT_TYPE)
2132-
&& symbols.get(type.getClassName().substring(type.getClassName().lastIndexOf('.') + 1)) == null) {
2133-
throw typeNotFound(type.getClassName(), expr);
2134-
}
2267+
typeCheck(type, expr);
21352268
Context context = writer.getContext();
21362269
context.pop();
21372270

@@ -2142,6 +2275,31 @@ int createVariable(String name, Type type, PyExpr expr, boolean boxed) {
21422275
return currentVariableIndex++;
21432276
}
21442277

2278+
public Type typeCheck(Type type, PyExpr expr) {
2279+
if (!type.equals(Type.getType(String.class)) && !type.equals(Type.LONG_TYPE) && !type.equals(Type.DOUBLE_TYPE)
2280+
&& !type.equals(Type.FLOAT_TYPE) && !type.equals(Type.INT_TYPE) && !type.equals(Type.BOOLEAN_TYPE)
2281+
&& !type.equals(Type.BYTE_TYPE) && !type.equals(Type.SHORT_TYPE)
2282+
&& symbols.get(type.getClassName().substring(type.getClassName().lastIndexOf('.') + 1)) == null) {
2283+
if (type.getSort() == Type.ARRAY) {
2284+
Type actualType = type.getElementType();
2285+
while (actualType.getSort() == Type.ARRAY) {
2286+
actualType = actualType.getElementType();
2287+
}
2288+
if (!actualType.equals(Type.getType(String.class)) && !actualType.equals(Type.LONG_TYPE) && !actualType.equals(Type.DOUBLE_TYPE)
2289+
&& !actualType.equals(Type.FLOAT_TYPE) && !actualType.equals(Type.INT_TYPE) && !actualType.equals(Type.BOOLEAN_TYPE)
2290+
&& !actualType.equals(Type.BYTE_TYPE) && !actualType.equals(Type.SHORT_TYPE)
2291+
&& symbols.get(actualType.getClassName().substring(actualType.getClassName().lastIndexOf('.') + 1)) == null) {
2292+
throw typeNotFound(actualType.getClassName(), expr);
2293+
}
2294+
2295+
return actualType;
2296+
} else {
2297+
throw typeNotFound(type.getClassName(), expr);
2298+
}
2299+
}
2300+
return type;
2301+
}
2302+
21452303
CompilerException typeNotFound(String type, PyExpr expr) {
21462304
return new CompilerException("Type '" + type + "' not found " + getLocation(expr));
21472305
}

0 commit comments

Comments
 (0)