Commit 7b142b3
Clip slice range expressions (#460)
This PR normalizes the inputs to `slice` in order to mimic the semantics
of numpy/PyTorch slicing. For an axis with extent `ext`, if we receive a
slice of `(start, stop, step)` we normalize it to `(norm_start,
norm_stop, step)` where
```
norm_start = max(0, start < 0 ? start + ext : start);
norm_stop = max(norm_start, min(ext, stop < 0 ? stop + ext : stop));
```
Specific changes in this PR:
- Form the above expressions in the `slice` op.
- Add shmoo tests that test various scenarios with constant and input
size slices.
The simple Fusion in the input range test prints like this:
```
Inputs:
T0_g[ iS0{9} ], float
i3, nvfuser_index_t
i4, nvfuser_index_t
Outputs:
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ], float
%kernel_math {
b7 = i3 < 0;
i5 = i3 + 9;
i9 = where(b7, i5, i3);
i11 = fmax(0, i9);
b15 = i4 < 0;
i13 = i4 + 9;
i17 = where(b15, i13, i4);
i19 = fmin(9, i17);
i21 = fmax(i11, i19);
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ]
= slice( T0_g[ iS0{9} ], { {i11, i21, 1} } )
}
T0_g[ iS0{9} ]
root domain : (iS0{9})
contiguity: f
leaf domain : (iS0{9})
T1_g[ ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf ]
root domain : (iS1{9}rf)
Resize: iS1{9}rf by ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) and ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) -> ?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf
rfactor domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf)
contiguity: t
leaf domain : (?S2{( ( ( -( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ) ) + 9 ) + ( ( fmax(( fmax(0, ( where(( i3 < 0 ), ( i3 + 9 ), i3) )) ), ( fmin(9, ( where(( i4 < 0 ), ( i4 + 9 ), i4) )) )) ) - 9 ) )}rf)
```
resulting in the following CUDA kernel:
```c++
__global__ void kernel1(Tensor<float, 1, 1> T0, nvfuser_index_t i0, nvfuser_index_t i1, Tensor<float, 1, 1> T1) {
nvfuser_index_t i2;
i2 = i0 + 9;
bool b3;
b3 = i0 < 0;
nvfuser_index_t i4;
i4 = b3 ? i2 : i0;
nvfuser_index_t i5;
i5 = max(0, i4);
nvfuser_index_t i6;
i6 = i1 + 9;
bool b7;
b7 = i1 < 0;
nvfuser_index_t i8;
i8 = b7 ? i6 : i1;
nvfuser_index_t i9;
i9 = min(9, i8);
nvfuser_index_t i10;
i10 = max(i5, i9);
nvfuser_index_t i11;
i11 = (-i5) + i10;
nvfuser_index_t i12;
i12 = i5 * T0.alloc_stride[0];
#pragma unroll 1
for(nvfuser_index_t i13 = 0; i13 < i11; ++i13) {
T1[i13]
= T0[(i12 + (T0.alloc_stride[0] * i13))];
}
}
```
This PR does NOT simplify these expressions for non-constant inputs.
This can be done at concretization, which will be left for a follow-up
PR.
Stacked on #892 and #895.
Fixes #439. Fixes #52.
---------
Co-authored-by: Naoya Maruyama <[email protected]>1 parent 2dcfef6 commit 7b142b3
File tree
5 files changed
+174
-26
lines changed- csrc
- ops
- test
5 files changed
+174
-26
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
316 | 316 | | |
317 | 317 | | |
318 | 318 | | |
319 | | - | |
| 319 | + | |
320 | 320 | | |
321 | 321 | | |
322 | 322 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
897 | 897 | | |
898 | 898 | | |
899 | 899 | | |
900 | | - | |
901 | | - | |
| 900 | + | |
| 901 | + | |
902 | 902 | | |
903 | 903 | | |
904 | 904 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
690 | 690 | | |
691 | 691 | | |
692 | 692 | | |
693 | | - | |
694 | | - | |
695 | | - | |
696 | 693 | | |
697 | 694 | | |
698 | 695 | | |
| |||
704 | 701 | | |
705 | 702 | | |
706 | 703 | | |
707 | | - | |
| 704 | + | |
| 705 | + | |
| 706 | + | |
| 707 | + | |
| 708 | + | |
| 709 | + | |
| 710 | + | |
708 | 711 | | |
709 | | - | |
710 | | - | |
711 | | - | |
712 | | - | |
713 | | - | |
714 | | - | |
715 | | - | |
716 | | - | |
717 | | - | |
| 712 | + | |
| 713 | + | |
718 | 714 | | |
719 | 715 | | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
| 719 | + | |
| 720 | + | |
| 721 | + | |
720 | 722 | | |
721 | | - | |
| 723 | + | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
722 | 728 | | |
723 | 729 | | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
| 733 | + | |
| 734 | + | |
| 735 | + | |
| 736 | + | |
| 737 | + | |
724 | 738 | | |
725 | | - | |
| 739 | + | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
726 | 744 | | |
727 | 745 | | |
728 | 746 | | |
| 747 | + | |
729 | 748 | | |
730 | 749 | | |
731 | 750 | | |
732 | 751 | | |
733 | 752 | | |
734 | 753 | | |
735 | 754 | | |
736 | | - | |
| 755 | + | |
737 | 756 | | |
738 | 757 | | |
739 | 758 | | |
| |||
754 | 773 | | |
755 | 774 | | |
756 | 775 | | |
| 776 | + | |
757 | 777 | | |
758 | 778 | | |
759 | 779 | | |
760 | 780 | | |
761 | 781 | | |
762 | | - | |
| 782 | + | |
763 | 783 | | |
764 | 784 | | |
765 | 785 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
91 | 91 | | |
92 | 92 | | |
93 | 93 | | |
94 | | - | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
95 | 97 | | |
96 | 98 | | |
97 | 99 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1124 | 1124 | | |
1125 | 1125 | | |
1126 | 1126 | | |
| 1127 | + | |
| 1128 | + | |
| 1129 | + | |
| 1130 | + | |
| 1131 | + | |
| 1132 | + | |
| 1133 | + | |
| 1134 | + | |
| 1135 | + | |
| 1136 | + | |
| 1137 | + | |
| 1138 | + | |
| 1139 | + | |
| 1140 | + | |
| 1141 | + | |
| 1142 | + | |
| 1143 | + | |
| 1144 | + | |
| 1145 | + | |
| 1146 | + | |
| 1147 | + | |
| 1148 | + | |
| 1149 | + | |
| 1150 | + | |
| 1151 | + | |
| 1152 | + | |
| 1153 | + | |
| 1154 | + | |
| 1155 | + | |
| 1156 | + | |
| 1157 | + | |
| 1158 | + | |
| 1159 | + | |
| 1160 | + | |
| 1161 | + | |
| 1162 | + | |
| 1163 | + | |
| 1164 | + | |
| 1165 | + | |
| 1166 | + | |
| 1167 | + | |
| 1168 | + | |
| 1169 | + | |
| 1170 | + | |
| 1171 | + | |
| 1172 | + | |
| 1173 | + | |
| 1174 | + | |
| 1175 | + | |
| 1176 | + | |
| 1177 | + | |
| 1178 | + | |
| 1179 | + | |
| 1180 | + | |
| 1181 | + | |
| 1182 | + | |
| 1183 | + | |
| 1184 | + | |
| 1185 | + | |
| 1186 | + | |
| 1187 | + | |
| 1188 | + | |
| 1189 | + | |
| 1190 | + | |
| 1191 | + | |
| 1192 | + | |
| 1193 | + | |
| 1194 | + | |
| 1195 | + | |
| 1196 | + | |
| 1197 | + | |
| 1198 | + | |
| 1199 | + | |
| 1200 | + | |
| 1201 | + | |
| 1202 | + | |
| 1203 | + | |
| 1204 | + | |
| 1205 | + | |
| 1206 | + | |
| 1207 | + | |
| 1208 | + | |
| 1209 | + | |
| 1210 | + | |
| 1211 | + | |
| 1212 | + | |
| 1213 | + | |
| 1214 | + | |
| 1215 | + | |
| 1216 | + | |
| 1217 | + | |
| 1218 | + | |
| 1219 | + | |
| 1220 | + | |
| 1221 | + | |
| 1222 | + | |
| 1223 | + | |
| 1224 | + | |
| 1225 | + | |
| 1226 | + | |
| 1227 | + | |
| 1228 | + | |
| 1229 | + | |
| 1230 | + | |
| 1231 | + | |
| 1232 | + | |
| 1233 | + | |
| 1234 | + | |
| 1235 | + | |
| 1236 | + | |
| 1237 | + | |
| 1238 | + | |
| 1239 | + | |
| 1240 | + | |
| 1241 | + | |
| 1242 | + | |
| 1243 | + | |
| 1244 | + | |
| 1245 | + | |
| 1246 | + | |
| 1247 | + | |
| 1248 | + | |
| 1249 | + | |
| 1250 | + | |
| 1251 | + | |
| 1252 | + | |
1127 | 1253 | | |
1128 | 1254 | | |
1129 | 1255 | | |
| |||
2319 | 2445 | | |
2320 | 2446 | | |
2321 | 2447 | | |
2322 | | - | |
| 2448 | + | |
2323 | 2449 | | |
2324 | 2450 | | |
2325 | 2451 | | |
| |||
2358 | 2484 | | |
2359 | 2485 | | |
2360 | 2486 | | |
2361 | | - | |
| 2487 | + | |
2362 | 2488 | | |
2363 | 2489 | | |
2364 | 2490 | | |
| |||
2414 | 2540 | | |
2415 | 2541 | | |
2416 | 2542 | | |
2417 | | - | |
| 2543 | + | |
2418 | 2544 | | |
2419 | 2545 | | |
2420 | 2546 | | |
| |||
2463 | 2589 | | |
2464 | 2590 | | |
2465 | 2591 | | |
2466 | | - | |
| 2592 | + | |
2467 | 2593 | | |
2468 | 2594 | | |
2469 | 2595 | | |
| |||
2505 | 2631 | | |
2506 | 2632 | | |
2507 | 2633 | | |
2508 | | - | |
| 2634 | + | |
2509 | 2635 | | |
2510 | 2636 | | |
2511 | 2637 | | |
| |||
0 commit comments