Skip to content

sema: The builtin function @max/@min support incompatible arbitrary i… #23581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 111 additions & 11 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -24040,6 +24040,7 @@ fn checkSimdBinOp(
sema: *Sema,
block: *Block,
src: LazySrcLoc,
comptime air_tag: Air.Inst.Tag,
uncasted_lhs: Air.Inst.Ref,
uncasted_rhs: Air.Inst.Ref,
lhs_src: LazySrcLoc,
Expand All @@ -24052,11 +24053,11 @@ fn checkSimdBinOp(

try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
const vec_len: ?usize = if (lhs_ty.zigTypeTag(zcu) == .vector) lhs_ty.vectorLen(zcu) else null;
const result_ty = try sema.resolvePeerTypes(block, src, &.{ uncasted_lhs, uncasted_rhs }, .{
const result_ty = try sema.resolvePeerTypesWithOp(block, src, air_tag, &.{ uncasted_lhs, uncasted_rhs }, .{
.override = &[_]?LazySrcLoc{ lhs_src, rhs_src },
});
const lhs = try sema.coerce(block, result_ty, uncasted_lhs, lhs_src);
const rhs = try sema.coerce(block, result_ty, uncasted_rhs, rhs_src);
const lhs = try sema.coerceWithOp(block, result_ty, uncasted_lhs, lhs_src, air_tag);
const rhs = try sema.coerceWithOp(block, result_ty, uncasted_rhs, rhs_src, air_tag);

return SimdBinOp{
.len = vec_len,
Expand Down Expand Up @@ -25407,7 +25408,7 @@ fn analyzeMinMax(
continue;
};

const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
const simd_op = try sema.checkSimdBinOp(block, src, air_tag, cur, operand, cur_minmax_src, operand_src);
const cur_val = try sema.resolveLazyValue(simd_op.lhs_val.?); // cur_minmax is comptime-known
const operand_val = try sema.resolveLazyValue(simd_op.rhs_val.?); // we checked the operand was resolvable above

Expand Down Expand Up @@ -25499,7 +25500,7 @@ fn analyzeMinMax(
const lhs_src = cur_minmax_src;
const rhs = operands[idx];
const rhs_src = operand_srcs[idx];
const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
const simd_op = try sema.checkSimdBinOp(block, src, air_tag, lhs, rhs, lhs_src, rhs_src);
if (known_undef) {
cur_minmax = try pt.undefRef(simd_op.result_ty);
} else {
Expand Down Expand Up @@ -28887,6 +28888,20 @@ pub fn coerce(
};
}

pub fn coerceWithOp(
sema: *Sema,
block: *Block,
dest_ty_unresolved: Type,
inst: Air.Inst.Ref,
inst_src: LazySrcLoc,
op_tag: Air.Inst.Tag,
) CompileError!Air.Inst.Ref {
return sema.coerceExtra(block, dest_ty_unresolved, inst, inst_src, .{ .opt_op_tag = op_tag }) catch |err| switch (err) {
error.NotCoercible => unreachable,
else => |e| return e,
};
}

const CoersionError = CompileError || error{
/// When coerce is called recursively, this error should be returned instead of using `fail`
/// to ensure correct types in compile errors.
Expand All @@ -28900,6 +28915,8 @@ const CoerceOpts = struct {
is_ret: bool = false,
/// Should coercion to comptime_int emit an error message.
no_cast_to_comptime_int: bool = false,
/// The tag of operator in which the coerce is called
opt_op_tag: ?Air.Inst.Tag = null,

param_src: struct {
func_inst: Air.Inst.Ref = .none,
Expand Down Expand Up @@ -29286,6 +29303,13 @@ fn coerceExtra(
if (maybe_inst_val) |val| {
// comptime-known integer to other number
if (!(try sema.intFitsInType(val, dest_ty, null))) {
if (opts.opt_op_tag) |op_tag| {
switch (op_tag) {
.min => return Air.internedToRef((try dest_ty.maxInt(pt, dest_ty)).toIntern()),
.max => return pt.intRef(dest_ty, 0),
else => {},
}
}
if (!opts.report_err) return error.NotCoercible;
return sema.fail(block, inst_src, "type '{}' cannot represent integer value '{}'", .{ dest_ty.fmt(pt), val.fmtValueSema(pt, sema) });
}
Expand Down Expand Up @@ -29313,6 +29337,30 @@ fn coerceExtra(
try sema.requireRuntimeBlock(block, inst_src, null);
return block.addTyOp(.intcast, dest_ty, inst);
}

if (opts.opt_op_tag) |op_tag| {
switch (op_tag) {
.min => {
if (src_info.signedness != dst_info.signedness and dst_info.signedness == .signed) {
std.debug.assert(dst_info.bits <= src_info.bits);
try sema.requireRuntimeBlock(block, inst_src, null);
const max_int_inst = Air.internedToRef((try dest_ty.maxInt(pt, inst_ty)).toIntern());
const min_inst = try block.addBinOp(.min, inst, max_int_inst);
return block.addTyOp(.intcast, dest_ty, min_inst);
}
},
.max => {
if (src_info.signedness != dst_info.signedness and dst_info.signedness == .unsigned) {
std.debug.assert(dst_info.bits >= src_info.bits);
try sema.requireRuntimeBlock(block, inst_src, null);
const zero_inst = try pt.intRef(inst_ty, 0);
const max_inst = try block.addBinOp(.max, inst, zero_inst);
return block.addTyOp(.intcast, dest_ty, max_inst);
}
},
else => {},
}
}
},
else => {},
},
Expand Down Expand Up @@ -29465,9 +29513,9 @@ fn coerceExtra(
}
}

return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src);
return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src, opts.opt_op_tag);
},
.vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src),
.vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src, opts.opt_op_tag),
.@"struct" => {
if (inst_ty.isTuple(zcu)) {
return sema.coerceTupleToArray(block, dest_ty, dest_ty_src, inst, inst_src);
Expand All @@ -29476,7 +29524,7 @@ fn coerceExtra(
else => {},
},
.vector => switch (inst_ty.zigTypeTag(zcu)) {
.array, .vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src),
.array, .vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src, opts.opt_op_tag),
.@"struct" => {
if (inst_ty.isTuple(zcu)) {
return sema.coerceTupleToArray(block, dest_ty, dest_ty_src, inst, inst_src);
Expand Down Expand Up @@ -31353,6 +31401,7 @@ fn coerceArrayLike(
dest_ty_src: LazySrcLoc,
inst: Air.Inst.Ref,
inst_src: LazySrcLoc,
opt_op_tag: ?Air.Inst.Tag,
) !Air.Inst.Ref {
const pt = sema.pt;
const zcu = pt.zcu;
Expand Down Expand Up @@ -31401,6 +31450,31 @@ fn coerceArrayLike(
try sema.requireRuntimeBlock(block, inst_src, null);
return block.addTyOp(.intcast, dest_ty, inst);
}

if (opt_op_tag) |op_tag| {
switch (op_tag) {
.min => {
if (src_info.signedness != dst_info.signedness and dst_info.signedness == .signed) {
std.debug.assert(dst_info.bits <= src_info.bits);
try sema.requireRuntimeBlock(block, inst_src, null);
const max_int_inst = Air.internedToRef((try dest_ty.maxInt(pt, inst_ty)).toIntern());
const min_inst = try block.addBinOp(.min, inst, max_int_inst);
return block.addTyOp(.intcast, dest_ty, min_inst);
}
},
.max => {
if (src_info.signedness != dst_info.signedness and dst_info.signedness == .unsigned) {
std.debug.assert(dst_info.bits >= src_info.bits);
try sema.requireRuntimeBlock(block, inst_src, null);
const zeros = try sema.splat(inst_ty, try pt.intValue(inst_elem_ty, 0));
const zero_inst = Air.internedToRef(zeros.toIntern());
const max_inst = try block.addBinOp(.max, inst, zero_inst);
return block.addTyOp(.intcast, dest_ty, max_inst);
}
},
else => {},
}
}
},
.float => if (inst_elem_ty.isRuntimeFloat()) {
// float widening
Expand All @@ -31424,7 +31498,10 @@ fn coerceArrayLike(
const src = inst_src; // TODO better source location
const elem_src = inst_src; // TODO better source location
const elem_ref = try sema.elemValArray(block, src, inst_src, inst, elem_src, index_ref, true);
const coerced = try sema.coerce(block, dest_elem_ty, elem_ref, elem_src);
const coerced = if (opt_op_tag) |op_tag|
try sema.coerceWithOp(block, dest_elem_ty, elem_ref, elem_src, op_tag)
else
try sema.coerce(block, dest_elem_ty, elem_ref, elem_src);
ref.* = coerced;
if (runtime_src == null) {
if (try sema.resolveValue(coerced)) |elem_val| {
Expand Down Expand Up @@ -33499,6 +33576,17 @@ fn resolvePeerTypes(
src: LazySrcLoc,
instructions: []const Air.Inst.Ref,
candidate_srcs: PeerTypeCandidateSrc,
) !Type {
return resolvePeerTypesWithOp(sema, block, src, null, instructions, candidate_srcs);
}

fn resolvePeerTypesWithOp(
sema: *Sema,
block: *Block,
src: LazySrcLoc,
comptime opt_op_tag: ?Air.Inst.Tag,
instructions: []const Air.Inst.Ref,
candidate_srcs: PeerTypeCandidateSrc,
) !Type {
switch (instructions.len) {
0 => return Type.noreturn,
Expand All @@ -33525,7 +33613,7 @@ fn resolvePeerTypes(
val.* = try sema.resolveValue(inst);
}

switch (try sema.resolvePeerTypesInner(block, src, peer_tys, peer_vals)) {
switch (try sema.resolvePeerTypesInner(block, src, opt_op_tag, peer_tys, peer_vals)) {
.success => |ty| return ty,
else => |result| {
const msg = try result.report(sema, block, src, instructions, candidate_srcs);
Expand All @@ -33538,6 +33626,7 @@ fn resolvePeerTypesInner(
sema: *Sema,
block: *Block,
src: LazySrcLoc,
comptime opt_op_tag: ?Air.Inst.Tag,
peer_tys: []?Type,
peer_vals: []?Value,
) !PeerResolveResult {
Expand Down Expand Up @@ -33623,6 +33712,7 @@ fn resolvePeerTypesInner(
const final_payload = switch (try sema.resolvePeerTypesInner(
block,
src,
opt_op_tag,
peer_tys,
peer_vals,
)) {
Expand Down Expand Up @@ -33661,6 +33751,7 @@ fn resolvePeerTypesInner(
const child_ty = switch (try sema.resolvePeerTypesInner(
block,
src,
opt_op_tag,
peer_tys,
peer_vals,
)) {
Expand Down Expand Up @@ -33810,6 +33901,7 @@ fn resolvePeerTypesInner(
const child_ty = switch (try sema.resolvePeerTypesInner(
block,
src,
opt_op_tag,
peer_tys,
peer_vals,
)) {
Expand Down Expand Up @@ -34415,6 +34507,14 @@ fn resolvePeerTypesInner(
return .{ .success = peer_tys[idx_signed.?].? };
}

if (opt_op_tag) |op_tag| {
switch (op_tag) {
.min => return .{ .success = peer_tys[idx_signed.?].? },
.max => return .{ .success = peer_tys[idx_unsigned.?].? },
else => {},
}
}

// TODO: this is for compatibility with legacy behavior. Before this version of PTR was
// implemented, the algorithm very often returned false positives, with the expectation
// that you'd just hit a coercion error later. One of these was that for integers, the
Expand Down Expand Up @@ -34533,7 +34633,7 @@ fn resolvePeerTypesInner(
}

// Resolve field type recursively
field_ty.* = switch (try sema.resolvePeerTypesInner(block, src, sub_peer_tys, sub_peer_vals)) {
field_ty.* = switch (try sema.resolvePeerTypesInner(block, src, opt_op_tag, sub_peer_tys, sub_peer_vals)) {
.success => |ty| ty.toIntern(),
else => |result| {
const result_buf = try sema.arena.create(PeerResolveResult);
Expand Down
110 changes: 110 additions & 0 deletions test/behavior/maximum_minimum.zig
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,116 @@ test "@min/@max more than two vector arguments" {
try expectEqual(@Vector(2, u32){ 5, 2 }, @max(x, y, z));
}

test "@min/@max with incompatible arbitrary integer types" {
const x: i32 = -30;
const y: u32 = 0x8000_0010;
const z: u36 = 0x2_0000_0010;

const M = struct {
fn min(lhs: i32, rhs: u32) i32 {
return @min(lhs, rhs);
}

fn min3(fst: i32, snd: u32, thrd: u36) i32 {
return @min(fst, snd, thrd);
}

fn max(lhs: i32, rhs: u32) u32 {
return @max(lhs, rhs);
}

fn max3(fst: i32, snd: u32, thrd: u36) u36 {
return @max(fst, snd, thrd);
}
};

// test min for comptime value
const min = @min(x, y);
try expectEqual(i6, @TypeOf(min));
try expectEqual(-30, min);
const min3 = @min(x, y, z);
try expectEqual(i6, @TypeOf(min3));
try expectEqual(-30, min3);

// test min for runtime value
const m_min = M.min(x, y);
try expectEqual(-30, m_min);
const m_min3 = M.min3(x, y, z);
try expectEqual(-30, m_min3);

// test max for comptime value
const max = @max(x, y);
try expectEqual(u32, @TypeOf(max));
try expectEqual(0x8000_0010, max);
const max3 = @max(x, y, z);
try expectEqual(u34, @TypeOf(max3));
try expectEqual(0x2_0000_0010, max3);

// test max for runtime value
const m_max = M.max(x, y);
try expectEqual(0x8000_0010, m_max);
const m_max3 = M.max3(x, y, z);
try expectEqual(0x2_0000_0010, m_max3);
}

test "@min/@max vector with incompatible arbitrary integer types" {
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;

const x: @Vector(2, i16) = @splat(-30);
const y: @Vector(2, u16) = @splat(0x8010);
const z: @Vector(2, u32) = @splat(0x2_0010);

const M = struct {
fn min(lhs: @Vector(2, i16), rhs: @Vector(2, u16)) @Vector(2, i16) {
return @min(lhs, rhs);
}

fn min3(fst: @Vector(2, i16), snd: @Vector(2, u16), thrd: @Vector(2, u32)) @Vector(2, i16) {
return @min(fst, snd, thrd);
}

fn max(lhs: @Vector(2, i16), rhs: @Vector(2, u16)) @Vector(2, u16) {
return @max(lhs, rhs);
}

fn max3(fst: @Vector(2, i16), snd: @Vector(2, u16), thrd: @Vector(2, u32)) @Vector(2, u32) {
return @max(fst, snd, thrd);
}
};

// test min for comptime value
const min = @min(x, y);
try expectEqual(@Vector(2, i6), @TypeOf(min));
try expectEqual(@as(@Vector(2, i16), @splat(-30)), min);
const min3 = @min(x, y, z);
try expectEqual(@Vector(2, i6), @TypeOf(min3));
try expectEqual(@as(@Vector(2, i16), @splat(-30)), min3);

// test min for runtime value
const m_min = M.min(x, y);
try expectEqual(@as(@Vector(2, i16), @splat(-30)), m_min);
const m_min3 = M.min3(x, y, z);
try expectEqual(@as(@Vector(2, i16), @splat(-30)), m_min3);

// test max for comptime value
const max = @max(x, y);
try expectEqual(@Vector(2, u16), @TypeOf(max));
try expectEqual(@as(@Vector(2, u16), @splat(0x8010)), max);
const max3 = @max(x, y, z);
try expectEqual(@Vector(2, u18), @TypeOf(max3));
try expectEqual(@as(@Vector(2, u18), @splat(0x2_0010)), max3);

// test max for runtime value
const m_max = M.max(x, y);
try expectEqual(@as(@Vector(2, u16), @splat(0x8010)), m_max);
const m_max3 = M.max3(x, y, z);
try expectEqual(@as(@Vector(2, u32), @splat(0x2_0010)), m_max3);
}

test "@min/@max notices bounds" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
Expand Down
Loading