Skip to content

Commit 7b142b3

Browse files
jacobhinkle and naoyam
authored
Clip slice range expressions (#460)
This PR normalizes the inputs to `slice` in order to mimic the semantics of numpy/PyTorch slicing. For an axis with extent `ext`, if we receive a slice of `(start, stop, step)` we normalize it to `(norm_start, norm_stop, step)` where ``` norm_start = max(0, start < 0 ? start + ext : start); norm_stop = max(norm_start, min(ext, stop < 0 ? stop + ext : stop)); ``` Specific changes in this PR: - Form the above expressions in the `slice` op. - Add shmoo tests that test various scenarios with constant and input size slices. The simple Fusion in the input range test prints like this: ``` Inputs: T0_g[ iS0{9} ], float i3, nvfuser_index_t i4, nvfuser_index_t Outputs: T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ], float %kernel_math { b7 = i3 < 0; i5 = i3 + 9; i9 = where(b7, i5, i3); i11 = fmax(0, i9); b15 = i4 < 0; i13 = i4 + 9; i17 = where(b15, i13, i4); i19 = fmin(9, i17); i21 = fmax(i11, i19); T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ] = slice( T0_g[ iS0{9} ], { {i11, i21, 1} } ) } T0_g[ iS0{9} ] root domain : (iS0{9}) contiguity: f leaf domain : (iS0{9}) T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ] root domain : (iS1{9}rf) Resize: iS1{9}rf by ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) and ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) -> ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), 
i4) )) )) ) - 9 ) )}rf rfactor domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf) contiguity: t leaf domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf) ``` resulting in the following CUDA kernel: ```c++ __global__ void kernel1(Tensor<float, 1, 1> T0, nvfuser_index_t i0, nvfuser_index_t i1, Tensor<float, 1, 1> T1) { nvfuser_index_t i2; i2 = i0 + 9; bool b3; b3 = i0 < 0; nvfuser_index_t i4; i4 = b3 ? i2 : i0; nvfuser_index_t i5; i5 = max(0, i4); nvfuser_index_t i6; i6 = i1 + 9; bool b7; b7 = i1 < 0; nvfuser_index_t i8; i8 = b7 ? i6 : i1; nvfuser_index_t i9; i9 = min(9, i8); nvfuser_index_t i10; i10 = max(i5, i9); nvfuser_index_t i11; i11 = (-i5) + i10; nvfuser_index_t i12; i12 = i5 * T0.alloc_stride[0]; #pragma unroll 1 for(nvfuser_index_t i13 = 0; i13 < i11; ++i13) { T1[i13] = T0[(i12 + (T0.alloc_stride[0] * i13))]; } } ``` This PR does NOT simplify these expressions for non-constant inputs. This can be done at concretization, which will be left for a follow-up PR. Stacked on #892 and #895. Fixes #439. Fixes #52. --------- Co-authored-by: Naoya Maruyama <[email protected]>
1 parent 2dcfef6 commit 7b142b3

File tree

5 files changed

+174
-26
lines changed

5 files changed

+174
-26
lines changed

csrc/dynamic_transform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ void DynamicTransformConcretizationInfo::analyzeResizes(
316316
out_id->toString());
317317
auto extent_int = extent_val.as<int64_t>();
318318
NVF_ERROR(
319-
extent_int > 0,
319+
extent_int >= 0,
320320
"Invalid resized domain extent ",
321321
extent_int,
322322
" for domain ",

csrc/kernel_cache.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -897,8 +897,8 @@ FusionKernelRuntime::FusionKernelRuntime(
897897
fusion.get());
898898

899899
if (isDebugDumpEnabled(DebugDumpOption::FusionIrPreseg)) {
900-
std::cout << "Fusion IR after pre-segmenter optimization passes:"
901-
<< std::endl;
900+
debug() << "Fusion IR after pre-segmenter optimization passes:"
901+
<< std::endl;
902902
fusion->printMath();
903903
}
904904

csrc/ops/alias.cpp

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,6 @@ TensorView* cat(
690690
return out;
691691
}
692692

693-
// Currently there's no error check about the actual values of the
694-
// Slice parameters. For example, the start parameter of a range of a
695-
// domain is assumed to be >= 0 and < the extent of the domain.
696693
TensorView* slice(TensorView* inp, const std::vector<Slice>& ranges) {
697694
const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
698695
const int ndims = static_cast<int>(inp_dom.size());
@@ -704,36 +701,58 @@ TensorView* slice(TensorView* inp, const std::vector<Slice>& ranges) {
704701
", Expected: ",
705702
ndims);
706703

707-
auto normalize_slice_range = [](Slice range, Val* extent) -> Slice {
704+
const auto normalize_slice_range = [](Slice range, Val* extent) -> Slice {
705+
auto cast_extent =
706+
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent);
707+
708+
auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index);
709+
710+
// norm_start = max(0, start < 0 ? start + extent : start)
708711
if (range.start == nullptr) {
709-
range.start = FusionGuard::getCurFusion()->zeroVal();
710-
}
711-
if (range.stop == nullptr) {
712-
range.stop = extent;
713-
}
714-
if (range.step == nullptr) {
715-
range.step = FusionGuard::getCurFusion()->oneVal();
716-
}
717-
if (range.start->dtype() != DataType::Index) {
712+
range.start = zero;
713+
} else if (!range.start->isZeroInt()) {
718714
range.start =
719715
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start);
716+
range.start = SimplifyingIrBuilder::maxExpr(
717+
zero,
718+
SimplifyingIrBuilder::whereExpr(
719+
SimplifyingIrBuilder::ltExpr(range.start, zero),
720+
SimplifyingIrBuilder::addExpr(range.start, cast_extent),
721+
range.start));
720722
}
721-
if (range.stop->dtype() != DataType::Index) {
723+
724+
// norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop)
725+
if (range.stop == nullptr) {
726+
range.stop = cast_extent;
727+
} else if (!range.stop->sameAs(extent)) {
722728
range.stop =
723729
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop);
730+
range.stop = SimplifyingIrBuilder::maxExpr(
731+
range.start,
732+
SimplifyingIrBuilder::minExpr(
733+
cast_extent,
734+
SimplifyingIrBuilder::whereExpr(
735+
SimplifyingIrBuilder::ltExpr(range.stop, zero),
736+
SimplifyingIrBuilder::addExpr(range.stop, cast_extent),
737+
range.stop)));
724738
}
725-
if (range.step->dtype() != DataType::Index) {
739+
740+
// Ensure step is of type Index
741+
if (range.step == nullptr) {
742+
range.step = FusionGuard::getCurFusion()->oneVal(DataType::Index);
743+
} else {
726744
range.step =
727745
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.step);
728746
}
747+
729748
return range;
730749
};
731750

732751
for (auto& range : ranges) {
733752
// Step not supported yet
734753
NVF_CHECK(
735754
range.step == nullptr || range.step->isOneInt(),
736-
"Unsupported step: ",
755+
"Unsupported step (must be 1 or null): ",
737756
range.step->toString());
738757
}
739758

@@ -754,12 +773,13 @@ TensorView* slice(TensorView* inp, const std::vector<Slice>& ranges) {
754773
out_root_id = inp_root_id->cloneWithoutRFactor();
755774
out_rf_id = out_root_id;
756775
} else {
776+
// Clip the start and stop values to the extent of the input
757777
out_root_id =
758778
IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build();
759779
out_rf_id = IterDomain::resize(
760780
out_root_id,
761781
SimplifyingIrBuilder::negExpr(range.start),
762-
sub(range.stop, inp_root_id->extent()),
782+
SimplifyingIrBuilder::subExpr(range.stop, inp_root_id->extent()),
763783
true);
764784
needs_real_slicing = true;
765785
}

csrc/ops/alias.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ TensorView* cat(
9191
std::optional<IterType> iter_type_opt = std::nullopt);
9292

9393
//! Return a tensor where each dimension is sliced as specified by the
94-
//! ranges parameter. Stepping must be one at this moment.
94+
//! ranges parameter. Stepping must be one at this moment. The semantics of
95+
//! slicing with negative values and values >= extent follow those of numpy and
96+
//! PyTorch.
9597
TensorView* slice(TensorView* inp, const std::vector<Slice>& ranges);
9698

9799
} // namespace nvfuser

test/test_resize.cpp

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,132 @@ TEST_F(ResizeTest, FusionResizeSlice5) {
11241124
testValidate(&fusion, cg_outputs, aten_inputs, {t2, t4}, __LINE__, __FILE__);
11251125
}
11261126

1127+
std::vector<std::pair<int64_t, int64_t>> slice_cases(
1128+
{{0, 5},
1129+
{3, 9},
1130+
{3, 4},
1131+
{7, 5},
1132+
{0, 11},
1133+
{11, 13},
1134+
{-3, 8},
1135+
{-3, -1},
1136+
{-3, -5},
1137+
{13, -1},
1138+
{-11, 9},
1139+
{-11, 0},
1140+
{-13, -11}});
1141+
1142+
// Test slice with a variety of constant ranges
1143+
TEST_F(NVFuserTest, FusionResizeSliceConstantShmoo_CUDA) {
1144+
for (auto [start, stop] : slice_cases) {
1145+
Fusion fusion;
1146+
FusionGuard fg(&fusion);
1147+
1148+
std::vector<int64_t> shape({9});
1149+
1150+
// concrete shapes to avoid dynamic Fusion
1151+
auto tv0 = makeConcreteTensor(shape);
1152+
fusion.addInput(tv0);
1153+
1154+
auto tv1 = slice(
1155+
tv0, {{IrBuilder::create<Val>(start), IrBuilder::create<Val>(stop)}});
1156+
fusion.addOutput(tv1);
1157+
1158+
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1159+
1160+
auto t0 = at::randn(shape, options);
1161+
std::vector<c10::IValue> aten_inputs({t0});
1162+
1163+
FusionExecutor fe;
1164+
fe.compileFusion(&fusion, aten_inputs);
1165+
auto cg_outputs = fe.runFusion(aten_inputs);
1166+
1167+
testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
1168+
}
1169+
}
1170+
1171+
// Test slice with a variety of non-constant input ranges
1172+
TEST_F(NVFuserTest, FusionResizeSliceInputShmoo_CUDA) {
1173+
Fusion fusion;
1174+
FusionGuard fg(&fusion);
1175+
1176+
std::vector<int64_t> shape({9});
1177+
1178+
// concrete shapes to avoid dynamic Fusion
1179+
auto tv0 = makeConcreteTensor(shape);
1180+
auto s0 = IrBuilder::create<Val>(DataType::Index);
1181+
auto s1 = IrBuilder::create<Val>(DataType::Index);
1182+
fusion.addInput(tv0);
1183+
fusion.addInput(s0);
1184+
fusion.addInput(s1);
1185+
1186+
auto tv1 = slice(tv0, {{s0, s1}});
1187+
fusion.addOutput(tv1);
1188+
1189+
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1190+
1191+
{
1192+
// Concretize so that we set output IterType as Iteration. We should now
1193+
// have expressions that work with any input range.
1194+
ExpressionEvaluator expr_eval;
1195+
1196+
expr_eval.bind(tv0->axis(0)->extent(), 9);
1197+
expr_eval.bind(s0, 0);
1198+
expr_eval.bind(s1, 9);
1199+
1200+
auto initial_info = DynamicTransform::getInitialInfo(&fusion);
1201+
auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval);
1202+
1203+
DynamicTransform::concretizeFusion(&fusion, &info);
1204+
NVF_CHECK(
1205+
!fusion.hasDynamicTransform(), "Expected to have no dynamic transform");
1206+
}
1207+
1208+
FusionExecutor fe;
1209+
fe.compileFusion(&fusion);
1210+
1211+
auto t0 = at::randn(shape, options);
1212+
for (auto [start, stop] : slice_cases) {
1213+
std::vector<c10::IValue> aten_inputs({t0, start, stop});
1214+
auto cg_outputs = fe.runFusion(aten_inputs);
1215+
1216+
testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
1217+
}
1218+
}
1219+
1220+
// Same as FusionResizeSliceInputShmoo_CUDA but use FusionExecutorCache, which
1221+
// might re-concretize when output sizes change
1222+
TEST_F(NVFuserTest, FusionResizeSliceInputShmooFusionExecutorCache_CUDA) {
1223+
auto fusion_ptr = std::make_unique<Fusion>();
1224+
auto fusion = fusion_ptr.get();
1225+
FusionGuard fg(fusion);
1226+
1227+
std::vector<int64_t> shape({9});
1228+
1229+
// concrete shapes to avoid dynamic Fusion
1230+
auto tv0 = makeConcreteTensor(shape);
1231+
auto s0 = IrBuilder::create<Val>(DataType::Index);
1232+
auto s1 = IrBuilder::create<Val>(DataType::Index);
1233+
fusion->addInput(tv0);
1234+
fusion->addInput(s0);
1235+
fusion->addInput(s1);
1236+
1237+
auto tv1 = slice(tv0, {{s0, s1}});
1238+
fusion->addOutput(tv1);
1239+
1240+
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
1241+
1242+
FusionExecutorCache fec(std::move(fusion_ptr));
1243+
1244+
auto t0 = at::randn(shape, options);
1245+
for (auto [start, stop] : slice_cases) {
1246+
std::vector<c10::IValue> aten_inputs({t0, start, stop});
1247+
auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
1248+
1249+
testValidate(fec.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__);
1250+
}
1251+
}
1252+
11271253
// Auto scheduled version of Slice1
11281254
TEST_F(ResizeTest, FusionResizeSliceScheduler1) {
11291255
auto fusion_ptr = std::make_unique<Fusion>();
@@ -2319,7 +2445,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual1) {
23192445
FusionGuard fg(fusion_ptr.get());
23202446

23212447
const int64_t slice_offset = 4;
2322-
const std::vector<int64_t> shape({1024 * 1024});
2448+
const std::vector<int64_t> shape({1024L * 1024L});
23232449

23242450
// Using a concrete tensor to avoid dynamic reshape
23252451
auto tv0 = makeContigConcreteTensor(shape);
@@ -2358,7 +2484,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual2) {
23582484
FusionGuard fg(fusion_ptr.get());
23592485

23602486
const int64_t slice_offset = 4;
2361-
const std::vector<int64_t> shape({1024 * 1024});
2487+
const std::vector<int64_t> shape({1024L * 1024L});
23622488

23632489
auto tv0 = makeContigConcreteTensor(shape);
23642490
fusion.addInput(tv0);
@@ -2414,7 +2540,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual3) {
24142540
FusionGuard fg(fusion_ptr.get());
24152541

24162542
const int64_t slice_offset = 4;
2417-
const std::vector<int64_t> shape({1024 * 1024});
2543+
const std::vector<int64_t> shape({1024L * 1024L});
24182544

24192545
auto tv0 = makeContigConcreteTensor(shape);
24202546
fusion.addInput(tv0);
@@ -2463,7 +2589,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual4) {
24632589
auto& fusion = *fusion_ptr;
24642590
FusionGuard fg(fusion_ptr.get());
24652591

2466-
const std::vector<int64_t> shape({1024 * 1024});
2592+
const std::vector<int64_t> shape({1024L * 1024L});
24672593

24682594
auto tv0 = makeContigConcreteTensor({shape[0] - 4});
24692595
fusion.addInput(tv0);
@@ -2505,7 +2631,7 @@ TEST_F(ResizeTest, Slice2DVectorizeManual1) {
25052631
// The extent of the innermost domain is just 2, and the outer
25062632
// domain is sliced. This slicing should be vectorizable by a
25072633
// factor of 4 as the two domains can be merged and vectorized.
2508-
const std::vector<int64_t> shape({1024 * 1024, 2});
2634+
const std::vector<int64_t> shape({1024L * 1024L, 2});
25092635

25102636
auto tv0 = makeContigConcreteTensor(shape);
25112637
fusion.addInput(tv0);

0 commit comments

Comments
 (0)