
Commit 1a51efd

bdhirsh authored and pytorchmergebot committed
dispatch API for checking computed table, use it in prim decomps (pytorch#82358)

Fixes pytorch#82331

Exposes `torch._C._dispatch_has_computed_kernel_for_dispatch_key` to check whether an operator has a kernel registered to the given dispatch key in the **computed table**. The prim registration logic now uses it, making the check more accurate and robust (so that it e.g. picks up `CompositeExplicitAutograd` kernels).

Before this change we'd register 134 prim ops to the meta key; after it we register only 62. That's 72 ops that now use an existing C++ decomp to get meta working, instead of going directly through the prim decomp.

Pull Request resolved: pytorch#82358
Approved by: https://github.com/ezyang
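As a quick illustration of what the new binding reports, here is a sketch, assuming `aten::matmul` still has a `CompositeImplicitAutograd` kernel and no direct `Meta` registration:

import torch

# Assumption: aten::matmul has a CompositeImplicitAutograd kernel and no
# kernel registered directly to the Meta key; any op with those properties
# shows the same contrast.
name = "aten::matmul"

# Direct-registration check: nothing is registered to the Meta key itself.
print(torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"))           # False

# Computed-table check: the composite kernel fills the Meta slot indirectly,
# so the prim registration logic can now skip the python decomp for this op.
print(torch._C._dispatch_has_computed_kernel_for_dispatch_key(name, "Meta"))  # True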

File tree: 8 files changed, +64 -17 lines

aten/src/ATen/core/dispatch/Dispatcher.h (3 additions, 0 deletions)

@@ -333,6 +333,9 @@ class TORCH_API OperatorHandle {
     return operatorDef_->op.hasKernelForDispatchKey(k);
   }
 
+  bool hasComputedKernelForDispatchKey(DispatchKey k) const {
+    return operatorDef_->op.hasComputedKernelForDispatchKey(k);
+  }
 
   std::string dumpComputedTable() const {
     return operatorDef_->op.dumpComputedTable();

aten/src/ATen/core/dispatch/OperatorEntry.cpp (7 additions, 0 deletions)

@@ -211,6 +211,13 @@ const KernelFunction& OperatorEntry::kernelForDispatchKey(DispatchKey k) const {
   return jt->kernel;
 }
 
+bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
+  TORCH_CHECK(!isAliasDispatchKey(k), "Alias keys do not have runtime kernel registrations.");
+  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
+  TORCH_INTERNAL_ASSERT(dispatch_ix >= 0 && dispatch_ix < c10::num_runtime_entries, toString(k), dispatch_ix);
+  return dispatchTable_[dispatch_ix].isValid();
+}
+
 const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
   auto kern_it = kernels_.find(dispatch_key);
   if (kern_it != kernels_.end()) {

aten/src/ATen/core/dispatch/OperatorEntry.h (2 additions, 0 deletions)

@@ -210,6 +210,8 @@ class TORCH_API OperatorEntry final {
   // hasKernelForDispatchKey. To get the AnnotatedKernel, see
   // getKernelForDispatchKey (private)
   const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
+  // Returns true if the "computed table" has an entry for a particular key.
+  bool hasComputedKernelForDispatchKey(DispatchKey k) const;
   // Returns all the operator tags added at the time of registration
   const std::vector<at::Tag>& getTags() const;

aten/src/ATen/native/LinearAlgebra.cpp (14 additions, 12 deletions)

@@ -1108,24 +1108,26 @@ Tensor math_addr(const Tensor& self,
                  const Scalar& beta, const Scalar& alpha) {
   // when beta==0, values in self should be ignored,
   // nans and infs in self should not propagate.
+  Tensor out;
   if (beta.toComplexDouble() == 0.0) {
     if (alpha.toComplexDouble() == 1.0) {
-      return at::outer(vec1, vec2);
+      out = at::outer(vec1, vec2);
+    } else {
+      out = alpha * at::outer(vec1, vec2);
     }
-    return alpha * at::outer(vec1, vec2);
-  }
-
-  if (beta.toComplexDouble() == 1.0) {
+  } else if (beta.toComplexDouble() == 1.0) {
     if (alpha.toComplexDouble() == 1.0) {
-      return self + at::outer(vec1, vec2);
+      out = self + at::outer(vec1, vec2);
+    } else {
+      out = self + alpha * at::outer(vec1, vec2);
     }
-    return self + alpha * at::outer(vec1, vec2);
-  }
-
-  if (alpha.toComplexDouble() == 1.0) {
-    return beta * self + at::outer(vec1, vec2);
+  } else if (alpha.toComplexDouble() == 1.0) {
+    out = beta * self + at::outer(vec1, vec2);
+  } else {
+    out = beta * self + alpha * at::outer(vec1, vec2);
   }
-  return beta * self + alpha * at::outer(vec1, vec2);
+  auto result_type = c10::promoteTypes(c10::promoteTypes(self.scalar_type(), vec1.scalar_type()), vec2.scalar_type());
+  return out.to(c10::TensorOptions().dtype(result_type));
 }
 
 Tensor& math_addr_out(const Tensor& self,
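To see the effect of the new promotion step, here is a hedged sketch using meta tensors; whether a given build routes `addr` through this `math_addr` composite is an assumption, but the promotion rule mirrors the `c10::promoteTypes` chain above:

import torch

# self is float64; vec1/vec2 are float32, so outer(vec1, vec2) is float32.
self_ = torch.empty(3, 4, dtype=torch.float64, device="meta")
vec1 = torch.empty(3, dtype=torch.float32, device="meta")
vec2 = torch.empty(4, dtype=torch.float32, device="meta")

# Even on the beta == 0 path, where self never enters the arithmetic,
# the result is promoted to promote_types(float64, float32) == float64.
# Before this change, that path returned the un-promoted float32 outer product.
out = torch.addr(self_, vec1, vec2, beta=0)
print(out.dtype)  # torch.float64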

aten/src/ATen/native/layer_norm.cpp (10 additions, 0 deletions)

@@ -206,6 +206,16 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   const int normalized_ndim = normalized_shape.size();
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   const int axis = input_ndim - normalized_ndim;
+
+  // Properly handle zero-size inputs: the view(1, M, -1) call below breaks on this.
+  if (input.numel() == 0) {
+    auto result_type = c10::promoteTypes(input.scalar_type(), kFloat);
+    return std::make_tuple(
+        at::empty_like(input),
+        at::empty_like(input, c10::TensorOptions().dtype(result_type)),
+        at::empty_like(input, c10::TensorOptions().dtype(result_type))
+    );
+  }
   at::Tensor input_reshaped = input.view({1, M, -1});
   // Unlike Batch Normalization, which applies scalar scale and bias for each
   // entire channel/plane with the affine option, Layer Normalization applies
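A hedged sketch of the zero-size case the early return now covers; whether meta tensors route to `math_native_layer_norm` in a given build is an assumption:

import torch

# A zero-size input used to crash in input.view({1, M, -1}); the early
# return above produces empty outputs instead.
x = torch.empty(0, 8, device="meta")
out, mean, rstd = torch.native_layer_norm(x, (8,), None, None, 1e-5)
print(out.shape)  # torch.Size([0, 8]); mean and rstd are zero-size as well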

torch/_C/__init__.pyi.in (1 addition, 0 deletions)

@@ -965,6 +965,7 @@ class Generator(object):
 # Defined in torch/csrc/utils/python_dispatch.cpp
 def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
 def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
+def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
 def _dispatch_has_kernel(name: str) -> _bool: ...
 def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
 def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...

torch/_decomp/__init__.py (9 additions, 5 deletions)

@@ -110,11 +110,15 @@ def add_op_to_table(aten_op):
             # which don't have corresponding dispatcher entries, we need
             # to filter those out
             and torch._C._dispatch_has_kernel(name)
-            # Don't register a meta kernel to any operator that has
-            # a CompositeImplicitAutograd kernel in core.
-            # Otherwise we won't be able to run autograd for that operator with the meta backend.
-            and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name)
-            and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
+            # Don't register a python meta kernel to any operator that
+            # should already work with meta tensors today.
+            # We can check that by seeing if the "computed table" for the
+            # operator has a registration to Meta, either through a direct
+            # registration or an indirect one through an alias dispatch key
+            # (e.g. CompositeImplicitAutograd)
+            and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(
+                name, "Meta"
+            )
         ):
             if any(
                 a.alias_info is not None and not a.alias_info.is_write
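The new gate can be paraphrased as a small standalone predicate; this is an illustrative sketch, not a helper that actually exists in `torch._decomp`:

import torch

# Hypothetical helper mirroring the condition above: register a python meta
# kernel only for ops the computed table does not already cover.
def needs_python_meta_kernel(name: str) -> bool:
    return (
        # the op must have a real dispatcher entry at all...
        torch._C._dispatch_has_kernel(name)
        # ...and its computed table must not already have a Meta entry,
        # whether from a direct registration or an alias key like
        # CompositeImplicitAutograd or CompositeExplicitAutograd.
        and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(name, "Meta")
    )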

torch/csrc/utils/python_dispatch.cpp (18 additions, 0 deletions)

@@ -292,6 +292,8 @@ void initDispatchBindings(PyObject* module) {
   });
 
   m.def(
+      // Returns whether or not a direct kernel registration exists
+      // for this <op_name, dispatch_key> pair.
       "_dispatch_has_kernel_for_dispatch_key",
       [](const char* name, const char* dispatch) -> bool {
         auto op =
@@ -300,6 +302,22 @@ void initDispatchBindings(PyObject* module) {
         return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
       });
 
+  m.def(
+      // Returns whether or not there is an entry in the runtime computed
+      // dispatch table for this <op_name, dispatch_key> pair. For example, if
+      // "op" has a `CompositeImplicitAutograd` kernel, then
+      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
+      // true for all backends that are part of the alias set for
+      // CompositeImplicitAutograd.
+      "_dispatch_has_computed_kernel_for_dispatch_key",
+      [](const char* name, const char* dispatch) -> bool {
+        auto op =
+            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
+        TORCH_CHECK(op, "operator ", name, " does not exist");
+        return op->hasComputedKernelForDispatchKey(
+            c10::parseDispatchKey(dispatch));
+      });
+
   m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
     auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
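The alias-set behavior described in the binding's comment can be observed from Python; a sketch, again assuming `aten::matmul` carries only a `CompositeImplicitAutograd` kernel:

import torch

# The computed table projects an alias registration onto every runtime key
# in its alias set, so the direct and computed checks diverge per backend.
name = "aten::matmul"
for key in ("CPU", "CUDA", "Meta"):
    direct = torch._C._dispatch_has_kernel_for_dispatch_key(name, key)
    computed = torch._C._dispatch_has_computed_kernel_for_dispatch_key(name, key)
    print(key, direct, computed)  # expect False / True for each key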
