
Commit e61ac28

fixed a typo

Merge: 2 parents 3cde737 + 5c915f7

13 files changed: +70 −196 lines

scripts/gen_lazy_tensor.py

Lines changed: 6 additions & 1 deletion
@@ -8,6 +8,7 @@
     BaseCType,
     OptionalCType,
     VectorCType,
+    boolT,
     kernel_signature,
 )
 import pathlib
@@ -22,6 +23,10 @@
 source_yaml = str(torch_xla_root / "xla_native_functions.yaml")
 
 
+def is_boolean_dtype(lazy_type):
+    return lazy_type == BaseCType(boolT)
+
+
 @dataclass(frozen=True)
 class GenXlaLazyIR(GenLazyIR):
 
@@ -47,7 +52,7 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
         shape_fn_inputs_list = [
             f"{a.name}" for a in schema.positional_args
             if (a.is_lazy_value or isinstance(a.lazy_type, VectorCType) or
-                a.name == 'reduction')
+                is_boolean_dtype(a.lazy_type) or a.name == 'reduction')
         ]
         shape_fn_inputs = ", ".join(shape_fn_inputs_list)
 
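
The effect of this script change: positional arguments whose lazy type is BaseCType(boolT), such as keepdim, are now forwarded into the shape-function call of the generated IR node, alongside lazy values, vector-typed arguments, and the special-cased 'reduction' argument. As a hedged sketch, the generated base-constructor call for amax would plausibly look like the snippet below; the real text is emitted by GenLazyIR, so the names and layout here are assumptions, not the literal output.

// Hypothetical sketch of the codegen'd Amax constructor after this change.
// `keepdim` now reaches the shape helper because its lazy type is
// BaseCType(boolT). Illustrative only -- not the actual generated source.
Amax::Amax(const torch::lazy::Value& input, std::vector<int64_t> dim,
           bool keepdim)
    : XlaNode(torch::lazy::OpKind(at::aten::amax), {input},
              [&]() { return AmaxOutputShape(input, dim, keepdim); },
              /*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)),
      dim(dim),
      keepdim(keepdim) {}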

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 16 deletions
@@ -676,22 +676,6 @@ at::Tensor XLANativeFunctions::all(const at::Tensor& self, int64_t dim,
       XLATensor::all_dim(bridge::GetXlaTensor(self), {dim}, keepdim));
 }
 
-at::Tensor XLANativeFunctions::amax(const at::Tensor& self, at::IntArrayRef dim,
-                                    bool keepdim) {
-  XLA_FN_COUNTER("xla::");
-  auto xdim = XlaHelpers::I64List(dim);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::amax(bridge::GetXlaTensor(self), std::move(xdim), keepdim));
-}
-
-at::Tensor XLANativeFunctions::amin(const at::Tensor& self, at::IntArrayRef dim,
-                                    bool keepdim) {
-  XLA_FN_COUNTER("xla::");
-  auto xdim = XlaHelpers::I64List(dim);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::amin(bridge::GetXlaTensor(self), std::move(xdim), keepdim));
-}
-
 at::Tensor XLANativeFunctions::any(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);

torch_xla/csrc/ops/amax.cpp

Lines changed: 0 additions & 46 deletions
This file was deleted.

torch_xla/csrc/ops/amax.h

Lines changed: 0 additions & 27 deletions
This file was deleted.

torch_xla/csrc/ops/amin.cpp

Lines changed: 0 additions & 46 deletions
This file was deleted.

torch_xla/csrc/ops/amin.h

Lines changed: 0 additions & 27 deletions
This file was deleted.

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 10 additions & 1 deletion
@@ -9,7 +9,6 @@
 #include "torch_xla/csrc/reduction.h"
 
 namespace torch_xla {
-
 torch_xla::XlaOpVector Abs::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildAbs(xla_input), loctx);
@@ -60,6 +59,16 @@ torch_xla::XlaOpVector All::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildAll(input, dimensions, false), loctx);
 }
 
+torch_xla::XlaOpVector Amax::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildMaxInDims(input, dim, keepdim), loctx);
+}
+
+torch_xla::XlaOpVector Amin::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildMinInDims(input, dim, keepdim), loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 18 additions & 0 deletions
@@ -68,6 +68,24 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape AmaxOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildMaxInDims(operands[0], dim, keepdim);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
+xla::Shape AminOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildMinInDims(operands[0], dim, keepdim);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
 xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
     const torch::lazy::Value& grad_output, const torch::lazy::Value& input) {
   auto lower_for_shape_fn =

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,12 @@ xla::Shape AdaptiveAvgPool3dOutputShape(const torch::lazy::Value& input,
 xla::Shape AdaptiveAvgPool3dBackwardOutputShape(
     const torch::lazy::Value& grad_output, const torch::lazy::Value& input);
 
+xla::Shape AmaxOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim);
+
+xla::Shape AminOutputShape(const torch::lazy::Value& input,
+                           absl::Span<const int64_t> dim, bool keepdim);
+
 xla::Shape AllOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AsinOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/reduction.cpp

Lines changed: 15 additions & 6 deletions
@@ -9,6 +9,7 @@
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "torch/csrc/lazy/core/helpers.h"
 #include "torch/csrc/lazy/core/util.h"
 #include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/helpers.h"
@@ -318,17 +319,20 @@ xla::XlaOp BuildMaxInDims(xla::XlaOp input,
                           bool keep_reduced_dimensions) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(shape.element_type());
+  std::vector<int64_t> canonical_dimensions =
+      torch::lazy::GetCanonicalDimensionIndices(
+          xla::util::ToVector<int64_t>(dimensions), shape.rank());
   xla::XlaOp init_value = XlaHelpers::ScalarValue(
       min_max.min, shape.element_type(), input.builder());
-  ReductionInfo rinfo =
-      GetReductionInfo(input, shape, dimensions, keep_reduced_dimensions);
+  ReductionInfo rinfo = GetReductionInfo(input, shape, canonical_dimensions,
+                                         keep_reduced_dimensions);
   if (rinfo.element_count.scalar_size) {
     // We can only assert this if dimensions are not dynamic.
     XLA_CHECK_GT(*rinfo.element_count.scalar_size, 0);
   }
   xla::XlaOp result = xla::Reduce(
       input, init_value, XlaHelpers::CreateMaxComputation(shape.element_type()),
-      dimensions);
+      canonical_dimensions);
   if (keep_reduced_dimensions) {
     result = XlaHelpers::DynamicReshape(result, rinfo.new_dimensions);
   }
@@ -345,17 +349,22 @@ xla::XlaOp BuildMinInDims(xla::XlaOp input,
                           bool keep_reduced_dimensions) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(shape.element_type());
+
+  std::vector<int64_t> canonical_dimensions =
+      torch::lazy::GetCanonicalDimensionIndices(
+          xla::util::ToVector<int64_t>(dimensions), shape.rank());
+
   xla::XlaOp init_value = XlaHelpers::ScalarValue(
       min_max.max, shape.element_type(), input.builder());
-  ReductionInfo rinfo =
-      GetReductionInfo(input, shape, dimensions, keep_reduced_dimensions);
+  ReductionInfo rinfo = GetReductionInfo(input, shape, canonical_dimensions,
+                                         keep_reduced_dimensions);
   if (rinfo.element_count.scalar_size) {
     // We can only assert this if dimensions are not dynamic.
     XLA_CHECK_GT(*rinfo.element_count.scalar_size, 0);
   }
   xla::XlaOp result = xla::Reduce(
       input, init_value, XlaHelpers::CreateMinComputation(shape.element_type()),
-      dimensions);
+      canonical_dimensions);
   if (keep_reduced_dimensions) {
     result = XlaHelpers::DynamicReshape(result, rinfo.new_dimensions);
   }
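
The substantive change here: BuildMaxInDims/BuildMinInDims now canonicalize the reduction dimensions themselves, instead of relying on the caller to do it (the deleted XLATensor::amax/amin in tensor_methods.cpp used to). Canonicalization maps Python-style negative dimension indices into [0, rank). A standalone sketch of the assumed semantics of torch::lazy::GetCanonicalDimensionIndices, with no torch/XLA dependencies:

// Standalone illustration of dimension canonicalization (assumed semantics:
// wrap negative indices into [0, rank)). Any C++17 compiler will build this.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> CanonicalDims(const std::vector<int64_t>& dims,
                                   int64_t rank) {
  std::vector<int64_t> out;
  out.reserve(dims.size());
  for (int64_t d : dims) {
    // A negative index counts from the back, as in Python: -1 is the last dim.
    out.push_back(d < 0 ? d + rank : d);
  }
  return out;
}

int main() {
  // amax(x, dim=-1) on a rank-3 tensor reduces dimension 2.
  std::vector<int64_t> canonical = CanonicalDims({-1, 0}, /*rank=*/3);
  assert(canonical[0] == 2 && canonical[1] == 0);
  return 0;
}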

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 24 deletions
@@ -26,8 +26,6 @@
 #include "torch_xla/csrc/ops/all_gather.h"
 #include "torch_xla/csrc/ops/all_reduce.h"
 #include "torch_xla/csrc/ops/all_to_all.h"
-#include "torch_xla/csrc/ops/amax.h"
-#include "torch_xla/csrc/ops/amin.h"
 #include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h"
 #include "torch_xla/csrc/ops/amp_update_scale.h"
 #include "torch_xla/csrc/ops/any.h"
@@ -713,28 +711,6 @@ XLATensorPtr XLATensor::all_dim(const XLATensorPtr& input,
                                 result_type);
 }
 
-XLATensorPtr XLATensor::amax(const XLATensorPtr& input,
-                             std::vector<int64_t> dimensions,
-                             bool keep_reduced_dimensions) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Amax>(input->GetIrValue(),
-                                  torch::lazy::GetCanonicalDimensionIndices(
-                                      xla::util::ToVector<int64_t>(dimensions),
-                                      input->shape().get().rank()),
-                                  keep_reduced_dimensions));
-}
-
-XLATensorPtr XLATensor::amin(const XLATensorPtr& input,
-                             std::vector<int64_t> dimensions,
-                             bool keep_reduced_dimensions) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Amin>(input->GetIrValue(),
-                                  torch::lazy::GetCanonicalDimensionIndices(
-                                      xla::util::ToVector<int64_t>(dimensions),
-                                      input->shape().get().rank()),
-                                  keep_reduced_dimensions));
-}
-
 XLATensorPtr XLATensor::any(const XLATensorPtr& input,
                             std::vector<int64_t> dimensions,
                             bool keep_reduced_dimensions) {

torch_xla/csrc/token_handler.cpp

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_client/sys_util.h"
 #include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/helpers.h"
 
@@ -32,6 +33,12 @@ xla::XlaOp SliceOneToken(xla::XlaOp input) {
 
 xla::XlaOp TokenHandler::GetInput(xla::XlaOp input,
                                   const xla::Shape* input_shape) {
+  static bool disable_numeric_token =
+      xla::sys_util::GetEnvBool("DISABLE_NUMERIC_CC_TOKEN", false);
+  if (disable_numeric_token) {
+    return input;
+  }
+
   if (input_shape == nullptr) {
     input_shape = &XlaHelpers::ShapeOfXlaOp(input);
   }
@@ -40,6 +47,12 @@
 }
 
 xla::XlaOp TokenHandler::GetNewToken(xla::XlaOp result) {
+  static bool disable_numeric_token =
+      xla::sys_util::GetEnvBool("DISABLE_NUMERIC_CC_TOKEN", false);
+  if (disable_numeric_token) {
+    return token_;
+  }
+
   xla::XlaOp slice = SliceOneToken(result);
   // Token is always a numeric zero, and multiplying it by one element of the
   // result will still leave it as zero.
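
This hunk is independent of the amax/amin migration: it adds an escape hatch so that setting DISABLE_NUMERIC_CC_TOKEN makes the token handler pass collective-communication inputs and tokens through untouched. A minimal standalone illustration of the gate, substituting std::getenv for xla::sys_util::GetEnvBool (whose exact parsing rules are assumed here):

// Minimal sketch of the env-flag gate added above; parsing semantics are
// assumed (unset, "0", or "false" means false), not taken from sys_util.
#include <cstdlib>
#include <cstring>
#include <iostream>

static bool GetEnvBool(const char* name, bool defval) {
  const char* value = std::getenv(name);
  if (value == nullptr) return defval;
  return std::strcmp(value, "0") != 0 && std::strcmp(value, "false") != 0;
}

int main() {
  // A function-local static, as in the patch, reads the environment only
  // once per process, on first use of the gated code path.
  static bool disable_numeric_token =
      GetEnvBool("DISABLE_NUMERIC_CC_TOKEN", false);
  std::cout << "numeric CC token disabled: " << disable_numeric_token << "\n";
  return 0;
}

Because the flag is cached in a static, changing the variable after the first collective op has run has no effect within that process.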
