Skip to content

Commit 3e4ec2f

Browse files
authored
Generate where with C++ ? ... : ... rather than using device functions (#2472)
1 parent 7b37a83 commit 3e4ec2f

File tree

2 files changed

+13
-54
lines changed

2 files changed

+13
-54
lines changed

third_party/nvfuser/csrc/codegen.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -994,17 +994,22 @@ class CudaKernelGenerator : private OptOutConstDispatch {
994994
code_ << " = ";
995995
}
996996

997-
code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", ";
998-
999-
// Make sure the two operands of where has the same
1000-
// type. Note that compiling "where(0.0f, 0.0)" fails because of
1001-
// the overloading ambiguity.
997+
// Don't use a runtime device function for where as the second and
998+
// third aguments should not be evaluated unless picked by the
999+
// condition. If a device function is implemnted as pass-by-value,
1000+
// both arguments would be evaluated. Could be worked around by
1001+
// pass-by-reference, but it's just simpler to use the C++ ? operator.
10021002
if (top->getTernaryOpType() == TernaryOpType::Where) {
1003+
code_ << gen(top->in1()) << " ? ";
1004+
// Make sure the two operands of where has the same
1005+
// type. Note that compiling "where(0.0f, 0.0)" fails because of
1006+
// the overloading ambiguity.
10031007
auto cast = scalarCast(top->in2(), top->in3());
1004-
code_ << (top->in2()->isScalar() ? cast : "") << gen(top->in2()) << ", "
1005-
<< (top->in3()->isScalar() ? cast : "") << gen(top->in3()) << ")";
1008+
code_ << (top->in2()->isScalar() ? cast : "") << gen(top->in2()) << " : "
1009+
<< (top->in3()->isScalar() ? cast : "") << gen(top->in3());
10061010
} else {
1007-
code_ << gen(top->in2()) << ", " << gen(top->in3()) << ")";
1011+
code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", "
1012+
<< gen(top->in2()) << ", " << gen(top->in3()) << ")";
10081013
}
10091014

10101015
if (!print_inline_) {

third_party/nvfuser/runtime/helpers.cu

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -275,20 +275,6 @@ __device__ float threshold(float x, double t, double v) {
275275
return x <= t ? v : x;
276276
}
277277

278-
__device__ std::complex<double> where(
279-
bool c,
280-
std::complex<double> a,
281-
std::complex<double> b) {
282-
return c ? a : b;
283-
}
284-
285-
__device__ std::complex<float> where(
286-
bool c,
287-
std::complex<float> a,
288-
std::complex<float> b) {
289-
return c ? a : b;
290-
}
291-
292278
__device__ int threshold(int x, int64_t t, int64_t v) {
293279
return x <= t ? v : x;
294280
}
@@ -297,38 +283,6 @@ __device__ int64_t threshold(int64_t x, int64_t t, int64_t v) {
297283
return x <= t ? v : x;
298284
}
299285

300-
__device__ double where(bool c, double a, double b) {
301-
return c ? a : b;
302-
}
303-
304-
__device__ float where(bool c, float a, float b) {
305-
return c ? a : b;
306-
}
307-
308-
__device__ __half where(bool c, __half a, __half b) {
309-
return c ? a : b;
310-
}
311-
312-
__device__ __bfloat where(bool c, __bfloat a, __bfloat b) {
313-
return c ? a : b;
314-
}
315-
316-
__device__ int64_t where(bool c, int64_t a, int64_t b) {
317-
return c ? a : b;
318-
}
319-
320-
__device__ int where(bool c, int a, int b) {
321-
return c ? a : b;
322-
}
323-
324-
__device__ int64_t where(bool c, int64_t a, int b) {
325-
return c ? a : b;
326-
}
327-
328-
__device__ int64_t where(bool c, int a, int64_t b) {
329-
return c ? a : b;
330-
}
331-
332286
__device__ constexpr int64_t remainder(int64_t a, int64_t b) {
333287
auto mod = a % b;
334288
if ((mod != 0) && ((b < 0) != (mod < 0)))

0 commit comments

Comments
 (0)