Skip to content

Commit

Permalink
finish chunk writing codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
chaokunyang committed Jan 27, 2025
1 parent 3dc7956 commit 3e633ad
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@
import static org.apache.fury.codegen.ExpressionOptimizer.invokeGenerated;
import static org.apache.fury.codegen.ExpressionUtils.add;
import static org.apache.fury.codegen.ExpressionUtils.and;
import static org.apache.fury.codegen.ExpressionUtils.cast;
import static org.apache.fury.codegen.ExpressionUtils.eq;
import static org.apache.fury.codegen.ExpressionUtils.eqNull;
import static org.apache.fury.codegen.ExpressionUtils.gt;
import static org.apache.fury.codegen.ExpressionUtils.inline;
import static org.apache.fury.codegen.ExpressionUtils.isNull;
import static org.apache.fury.codegen.ExpressionUtils.list;
import static org.apache.fury.codegen.ExpressionUtils.neq;
import static org.apache.fury.codegen.ExpressionUtils.neqNull;
import static org.apache.fury.codegen.ExpressionUtils.not;
import static org.apache.fury.codegen.ExpressionUtils.notNull;
import static org.apache.fury.codegen.ExpressionUtils.nullValue;
Expand Down Expand Up @@ -349,7 +352,7 @@ protected Expression serializeFor(
buffer, "writeByte", new Literal(Fury.REF_VALUE_FLAG, PRIMITIVE_BYTE_TYPE)),
serializeForNotNull(inputObject, buffer, typeRef, serializer, generateNewMethod));
return new If(
ExpressionUtils.eqNull(inputObject),
eqNull(inputObject),
new Invoke(buffer, "writeByte", new Literal(Fury.NULL_FLAG, PRIMITIVE_BYTE_TYPE)),
action);
}
Expand Down Expand Up @@ -1050,9 +1053,16 @@ private Expression jitWriteMap(
map = new Invoke(serializer, "onMapWrite", TypeUtils.mapOf(keyType, valueType), buffer, map);
Expression iterator =
new Invoke(inlineInvoke(map, "entrySet", SET_TYPE), "iterator", ITERATOR_TYPE);
Expression entry = new Invoke(iterator, "next", MAP_ENTRY_TYPE);
Expression entry =
new Cast(inlineInvoke(iterator, "next", OBJECT_TYPE), MAP_ENTRY_TYPE, "entry");
boolean keyMonomorphic = isMonomorphic(keyType);
boolean valueMonomorphic = isMonomorphic(valueType);
Class<?> keyTypeRawType = keyType.getRawType();
Class<?> valueTypeRawType = valueType.getRawType();
boolean trackingKeyRef =
visitFury(fury -> fury.getClassResolver().needToWriteRef(keyTypeRawType));
boolean trackingValueRef =
visitFury(fury -> fury.getClassResolver().needToWriteRef(valueTypeRawType));
Expression keySerializer, valueSerializer;
if (keyMonomorphic && valueMonomorphic) {
keySerializer = getOrCreateSerializer(keyType.getRawType());
Expand All @@ -1069,22 +1079,28 @@ private Expression jitWriteMap(
}
Expression.While whileAction =
new Expression.While(
notNull(entry),
() ->
new ListExpression(
new Assign(
inlineInvoke(
serializer,
"writeJavaNullChunk",
MAP_ENTRY_TYPE,
buffer,
entry,
iterator,
keySerializer,
valueSerializer),
entry),
new If(
notNull(entry), writeChunk(buffer, entry, iterator, keyType, valueType))));
neqNull(entry),
() -> {
String method = "writeJavaNullChunk";
if (keyMonomorphic && valueMonomorphic) {
if (!trackingKeyRef && !trackingValueRef) {
method = "writeNullChunkKVNoRef";
}
}
return new ListExpression(
new Assign(
entry,
inlineInvoke(
serializer,
method,
MAP_ENTRY_TYPE,
buffer,
entry,
iterator,
keySerializer,
valueSerializer)),
new If(neqNull(entry), writeChunk(buffer, entry, iterator, keyType, valueType)));
});

return new If(not(inlineInvoke(map, "isEmpty", PRIMITIVE_BOOLEAN_TYPE)), whileAction);
}
Expand All @@ -1096,27 +1112,38 @@ protected Expression writeChunk(
TypeRef<?> keyType,
TypeRef<?> valueType) {
ListExpression expressions = new ListExpression();
Expression key =
tryCastIfPublic(new Invoke(entry, "getKey", "keyObj", OBJECT_TYPE), keyType, "key");
Expression key = tryCastIfPublic(inlineInvoke(entry, "getKey", OBJECT_TYPE), keyType, "key");
Expression value =
tryCastIfPublic(new Invoke(entry, "getValue", "valueObj", OBJECT_TYPE), valueType, "value");
Expression keyTypeExpr = new Invoke(key, "getClass", "keyType", CLASS_TYPE);
Expression valueTypeExpr = new Invoke(value, "getClass", "valueType", CLASS_TYPE);
Expression writePlaceHolder = new Invoke(buffer, "writeInt16", Literal.ofShort((short) -1));
Expression chunkSizeOffset =
subtract(inlineInvoke(buffer, "writerIndex", PRIMITIVE_INT_TYPE), Literal.ofInt(-1));
tryCastIfPublic(inlineInvoke(entry, "getValue", OBJECT_TYPE), valueType, "value");
boolean keyMonomorphic = isMonomorphic(keyType);
boolean valueMonomorphic = isMonomorphic(valueType);
Class<?> keyTypeRawType = keyType.getRawType();
Class<?> valueTypeRawType = valueType.getRawType();
Expression keyTypeExpr =
keyMonomorphic
? getClassExpr(keyTypeRawType)
: new Invoke(key, "getClass", "keyType", CLASS_TYPE);
Expression valueTypeExpr =
valueMonomorphic
? getClassExpr(valueTypeRawType)
: new Invoke(value, "getClass", "valueType", CLASS_TYPE);
Expression writePlaceHolder = new Invoke(buffer, "writeInt16", Literal.ofShort((short) -1));
Expression chunkSizeOffset =
subtract(
inlineInvoke(buffer, "writerIndex", PRIMITIVE_INT_TYPE),
Literal.ofInt(1),
"chunkSizeOffset");

Expression chunkHeader;
Expression keySerializer, valueSerializer;
boolean keyMonomorphic = isMonomorphic(keyType);
boolean valueMonomorphic = isMonomorphic(valueType);

boolean trackingKeyRef =
visitFury(fury -> fury.getClassResolver().needToWriteRef(keyTypeRawType));
boolean trackingValueRef =
visitFury(fury -> fury.getClassResolver().needToWriteRef(valueTypeRawType));
Expression keyWriteRef = Literal.ofBoolean(trackingKeyRef);
Expression valueWriteRef = Literal.ofBoolean(trackingValueRef);
boolean inline = keyMonomorphic && valueMonomorphic;
if (keyMonomorphic && valueMonomorphic) {
keySerializer = getOrCreateSerializer(keyTypeRawType);
valueSerializer = getOrCreateSerializer(valueTypeRawType);
Expand Down Expand Up @@ -1174,7 +1201,6 @@ protected Expression writeChunk(
}
}
Expression chunkSize = ofInt("chunkSize", 0);

expressions.add(
key,
value,
Expand All @@ -1189,21 +1215,38 @@ protected Expression writeChunk(
valueWriteRef,
new Invoke(buffer, "putByte", subtract(chunkSizeOffset, Literal.ofInt(1)), chunkHeader),
chunkSize);

Expression keyWriteRefExpr = keyWriteRef;
Expression valueWriteRefExpr = valueWriteRef;
Expression.While writeLoop =
new Expression.While(
Literal.ofBoolean(true),
() ->
new ListExpression(
new If(
or(
isNull(key),
isNull(value),
neq(inlineInvoke(key, "getClass", CLASS_TYPE), keyTypeExpr),
neq(inlineInvoke(value, "getClass", CLASS_TYPE), valueTypeExpr)),
new Break()),
() -> {
Expression breakCondition;
if (keyMonomorphic && valueMonomorphic) {
breakCondition = or(eqNull(key), eqNull(value));
} else if (keyMonomorphic) {
breakCondition =
or(
eqNull(key),
eqNull(value),
neq(inlineInvoke(value, "getClass", CLASS_TYPE), valueTypeExpr));
} else if (valueMonomorphic) {
breakCondition =
or(
eqNull(key),
eqNull(value),
neq(inlineInvoke(key, "getClass", CLASS_TYPE), keyTypeExpr));
} else {
breakCondition =
or(
eqNull(key),
eqNull(value),
neq(inlineInvoke(key, "getClass", CLASS_TYPE), keyTypeExpr),
neq(inlineInvoke(value, "getClass", CLASS_TYPE), valueTypeExpr));
}
Expression writeKey = new Invoke(keySerializer, "write", buffer, key);
if (trackingKeyRef) {
writeKey =
new If(
or(
not(keyWriteRefExpr),
Expand All @@ -1214,7 +1257,11 @@ protected Expression writeChunk(
PRIMITIVE_BOOLEAN_TYPE,
buffer,
key))),
new Invoke(keySerializer, "write", buffer, key)),
writeKey);
}
Expression writeValue = new Invoke(valueSerializer, "write", buffer, value);
if (trackingValueRef) {
writeValue =
new If(
or(
not(valueWriteRefExpr),
Expand All @@ -1225,24 +1272,43 @@ protected Expression writeChunk(
PRIMITIVE_BOOLEAN_TYPE,
buffer,
value))),
new Invoke(valueSerializer, "write", buffer, value)),
new Assign(add(chunkSize, Literal.ofInt(1)), chunkSize),
new If(eq(chunkSize, Literal.ofInt(MAX_CHUNK_SIZE)), new Break()),
new If(
inlineInvoke(iterator, "hasNext", PRIMITIVE_BOOLEAN_TYPE),
new ListExpression(
new Assign(inlineInvoke(iterator, "next", MAP_ENTRY_TYPE), entry),
new Assign(
tryInlineCast(inlineInvoke(entry, "getKey", OBJECT_TYPE), keyType),
key),
new Assign(
tryInlineCast(
inlineInvoke(entry, "getValue", OBJECT_TYPE), valueType),
value)),
new ListExpression(
new Assign(new Literal(null, MAP_ENTRY_TYPE), entry), new Break()))));
expressions.add(
writeLoop, new Invoke(buffer, "putByte", chunkSizeOffset, chunkSize), new Return(entry));
writeValue);
}
return new ListExpression(
new If(breakCondition, new Break()),
writeKey,
writeValue,
new Assign(chunkSize, add(chunkSize, Literal.ofInt(1))),
new If(eq(chunkSize, Literal.ofInt(MAX_CHUNK_SIZE)), new Break()),
new If(
inlineInvoke(iterator, "hasNext", PRIMITIVE_BOOLEAN_TYPE),
new ListExpression(
new Assign(
entry,
cast(inlineInvoke(iterator, "next", OBJECT_TYPE), MAP_ENTRY_TYPE)),
new Assign(
key,
tryInlineCast(inlineInvoke(entry, "getKey", OBJECT_TYPE), keyType)),
new Assign(
value,
tryInlineCast(
inlineInvoke(entry, "getValue", OBJECT_TYPE), valueType))),
list(new Assign(entry, new Literal(null, MAP_ENTRY_TYPE)), new Break())));
});
expressions.add(writeLoop, new Invoke(buffer, "putByte", chunkSizeOffset, chunkSize));
if (!inline) {
expressions.add(new Return(entry));
// method too big, spilt it into a new method.
// Generate similar signature as `AbstractMapSerializer.writeJavaChunk`(
// MemoryBuffer buffer,
// Entry<Object, Object> entry,
// Iterator<Entry<Object, Object>> iterator,
// Serializer keySerializer,
// Serializer valueSerializer
// )
Set<Expression> params = ofHashSet(buffer, entry, iterator);
return invokeGenerated(ctx, params, expressions, "writeChunk", false);
}
return expressions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,26 +266,29 @@ public ExprCode doGenCode(CodegenContext ctx) {
if (StringUtils.isNotBlank(targetExprCode.code())) {
codeBuilder.append(targetExprCode.code()).append('\n');
}
String name = ctx.newName(namePrefix);
String decl =
StringUtils.format(
"${type} ${name} = ${from};",
"type",
ctx.type(type()),
"name",
ctx.newName(namePrefix),
"from",
targetExprCode.value());
StringUtils.format(
"${type} ${name} = ${from};",
"type",
ctx.type(type()),
"name",
name,
"from",
targetExprCode.value());
codeBuilder.append(decl);
return new ExprCode(codeBuilder.toString(), null, null);
return new ExprCode(
codeBuilder.toString(),
targetExprCode.isNull(),
Code.variable(type().getRawType(), name));
}

@Override
public String toString() {
return String.format("%s %s = %s", type(), namePrefix, from);
return String.format("%s %s = %s;", type(), namePrefix, from);
}
}


class Literal implements Expression {
public static final Literal True = new Literal(true, PRIMITIVE_BOOLEAN_TYPE);
public static final Literal False = new Literal(false, PRIMITIVE_BOOLEAN_TYPE);
Expand Down Expand Up @@ -2595,12 +2598,12 @@ public String toString() {
}

class Assign implements Expression {
private Expression from;
private Expression to;
private Expression from;

public Assign(Expression from, Expression to) {
this.from = from;
public Assign(Expression to, Expression from) {
this.to = to;
this.from = from;
}

@Override
Expand All @@ -2622,7 +2625,7 @@ public ExprCode doGenCode(CodegenContext ctx) {
});
String assign =
StringUtils.format(
"${from} = ${to};", "from", fromExprCode.value(), "to", toExprCode.value());
"${to} = ${from};", "from", fromExprCode.value(), "to", toExprCode.value());
codeBuilder.append(assign);
return new ExprCode(codeBuilder.toString(), null, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Collections;
import java.util.List;
import org.apache.fury.codegen.Expression.Cast;
import org.apache.fury.codegen.Expression.ListExpression;
import org.apache.fury.codegen.Expression.LogicalAnd;
import org.apache.fury.codegen.Expression.LogicalOr;
import org.apache.fury.codegen.Expression.Null;
Expand All @@ -45,6 +46,9 @@

/** Expression utils to create expression and code in a more convenient way. */
public class ExpressionUtils {
public static ListExpression list(Expression... expressions) {
return new ListExpression(expressions);
}

public static Expression newObjectArray(Expression... expressions) {
return new NewArray(TypeRef.of(Object[].class), expressions);
Expand All @@ -71,6 +75,11 @@ public static Expression eqNull(Expression target) {
return eq(target, new Null(target.type()));
}

public static Expression neqNull(Expression target) {
Preconditions.checkArgument(!target.type().isPrimitive());
return neq(target, new Null(target.type()));
}

public static LogicalAnd and(Expression left, Expression right, String name) {
return new LogicalAnd(false, left, right);
}
Expand Down Expand Up @@ -153,7 +162,7 @@ public static Arithmetic subtract(Expression left, Expression right) {
}

public static Arithmetic subtract(Expression left, Expression right, String valuePrefix) {
Arithmetic arithmetic = new Arithmetic(true, "-", left, right);
Arithmetic arithmetic = new Arithmetic(false, "-", left, right);
arithmetic.valuePrefix = valuePrefix;
return arithmetic;
}
Expand Down
Loading

0 comments on commit 3e633ad

Please sign in to comment.