diff --git a/experimental/xla_tensor.py b/experimental/xla_tensor.py new file mode 100644 index 000000000000..0febbac3216e --- /dev/null +++ b/experimental/xla_tensor.py @@ -0,0 +1,103 @@ +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import logging +import functools +import time +from collections import namedtuple + +debug = False +SIMPLE_TEST = True +time_torch_function = 0 + +class XLATensor(torch.Tensor): + def __new__(cls, data, **kwargs): + return torch.Tensor._make_subclass(cls, torch.as_tensor(data, dtype=torch.float32, **kwargs)) + + def __init__(self, data, **kwargs): + super().__init__() + # if debug: print('[__init__]') + self.t_ = torch.as_tensor(data, dtype=torch.float32, **kwargs) + self.xla_tensor_sizes_ = namedtuple("XLATensorSizes", "static_size dynamic_size") + + def size(self): + # if debug: print('[size] ') + static_size = self.t_.size() + dynamic_size = [] + for i,_ in enumerate(self.t_.shape): + dynamic_size.append(torch_xla._XLAC._get_xla_tensor_dimension_size(self.t_, i)) + return self.xla_tensor_sizes_(static_size, dynamic_size[0]) + + def expand(self, size): + # if debug: print('[expand]') + return torch_xla._XLAC._xla_dynamic_expand(self.t_, size.dynamic_size, size.dynamic_size.int(), size.static_size) + + def nonzero(self): + return torch.nonzero(self.t_) + + def __repr__(self): + # if debug: print ('[__repr__]') + return "XLATensor:\n{}\n".format(self.t_) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + args = [a.t_ if hasattr(a, 't_') else a for a in args] + result = func(*args, **kwargs) + return result + +if SIMPLE_TEST == True: + a = XLATensor([[0, 0, 0], [1, 1, 0]], device=xm.xla_device()) + b = XLATensor([[1], [2]], device=xm.xla_device()) + + idxs = a.nonzero() + idxs_t = XLATensor([[1]], device=xm.xla_device()) + idxs_t.t_ = idxs + print('a.nonzero(): ', idxs) + + y = b.expand(idxs_t.size()) + print (f"b.expand(): {y}") + y_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(y, 0) + print("[RUNNING] _get_xla_tensor_dimension_size(y, 0)\n", y_dim0_shape) + print("[RUNNING] _get_xla_tensors_hlo([y])\n", torch_xla._XLAC._get_xla_tensors_hlo([y])) + print("[RUNNING] _get_xla_tensors_text([y])\n", torch_xla._XLAC._get_xla_tensors_text([y])) +else: + NUM = 500 + + time_size = 0 + time_expand = 0 + for i in range(NUM): + t = torch.tensor([[1], [2], [3]], device=xm.xla_device()) + o = torch.tensor([[1,1,1,1],[1,1,1,1],[1,1,1,1]], device=xm.xla_device()) + + tic = time.perf_counter() + b = o.size() + toc = time.perf_counter() + time_size += toc-tic + + tic = time.perf_counter() + tt = t.expand(b) + toc = time.perf_counter() + time_expand += toc-tic + print(f"size() time {time_size:0.4f} seconds") + print(f"expand() time {time_expand:0.4f} seconds") + + time_size = 0 + time_expand = 0 + for i in range(NUM): + t = XLATensor([[1], [2], [3]], device=xm.xla_device()) + o = XLATensor([[1,1,1,1],[1,1,1,1],[1,1,1,1]], device=xm.xla_device()) + + tic = time.perf_counter() + b = o.size() + toc = time.perf_counter() + time_size += toc-tic + + tic = time.perf_counter() + tt = t.expand(b) + toc = time.perf_counter() + time_expand += toc-tic + print(f"size() time {time_size:0.4f} seconds") + print(f"expand() time {time_expand:0.4f} seconds") + print(f"torch_function() time {time_torch_function:0.4f} seconds") \ No newline at end of file diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py new file mode 100644 index 000000000000..b0344034553a --- /dev/null +++ b/test/test_dynamic_shapes.py @@ -0,0 +1,50 @@ +# Run: +# XLA_EXPERIMENTAL="nonzero:masked_select" python3 +# + +import torch +import torch_xla + + +class TestDynamicShapes(object): + + def __init__(self): + self.device = 'xla:0' + + def runTest(self): + t1 = torch.tensor([1, 0, 2, 0, 0, 1, 3], device=self.device) + + t2 = torch.nonzero(t1) + t2_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t2, 0) + print("[RUNNING] _get_xla_tensor_dimension_size(t2, 0)\n", t2_dim0_shape) + print("[RUNNING] _get_xla_tensors_text([t2])\n", + torch_xla._XLAC._get_xla_tensors_text([t2])) + print("[RUNNING] _get_xla_tensors_hlo([t2])\n", + torch_xla._XLAC._get_xla_tensors_hlo([t2])) + assert t2_dim0_shape.item() == 4 + + t3 = torch.fill_(t2, 10) + print(f"size of {t3.size()}, size of t2 {t2.size()}") + t2_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t2, 0) + t3_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t3, 0) + print("[RUNNING] _get_xla_tensor_dimension_size(t2, 0)\n", t2_dim0_shape) + print("[RUNNING] _get_xla_tensor_dimension_size(t3, 0)\n", t3_dim0_shape) + print("[RUNNING] _get_xla_tensors_text([t2])\n", + torch_xla._XLAC._get_xla_tensors_text([t2])) + print("[RUNNING] _get_xla_tensors_text([t3])\n", + torch_xla._XLAC._get_xla_tensors_text([t3])) + print("[RUNNING] _get_xla_tensors_hlo([t2])\n", + torch_xla._XLAC._get_xla_tensors_hlo([t2])) + print("[RUNNING] _get_xla_tensors_hlo([t3])\n", + torch_xla._XLAC._get_xla_tensors_hlo([t3])) + assert t2_dim0_shape.item() == 4 + assert t3_dim0_shape.item() == 4 + + print('t1: ', t1) + print('t2: ', t2) + print('t3: ', t3) + + +if __name__ == "__main__": + test = TestDynamicShapes() + test.runTest() diff --git a/third_party/tensorflow b/third_party/tensorflow index dafc412b0a95..b8551d62cf75 160000 --- a/third_party/tensorflow +++ b/third_party/tensorflow @@ -1 +1 @@ -Subproject commit dafc412b0a95dbf8ea2d0487dc6518fdd39f8dec +Subproject commit b8551d62cf750ac6791dff92bc04c1e7d75e7ec4 diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index 402c2e8eb40e..cc1615d8b025 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -12,11 +12,16 @@ #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/compiler/xla/xla_client/util.h" +#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "tensorflow/compiler/xla/client/value_inference.h" #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/reduction.h" #include "torch_xla/csrc/tensor_util.h" +#include +using namespace std::chrono; + namespace torch_xla { namespace { @@ -112,6 +117,27 @@ xla::XlaOp BuildExpand(xla::XlaOp input, xla::util::Iota(output_sizes.size())); } +xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input, + xla::XlaOp dynamic_target, + xla::Shape dynamic_shapes) { + // auto start = high_resolution_clock::now(); + xla::XlaOp output = BuildExpand(static_input, dynamic_shapes.dimensions()); + bool seen_dynamic = false; // Limit support to one dynamic dimension + for (int i = 0; i < dynamic_shapes.rank(); ++i) { + if (dynamic_shapes.is_dynamic_dimension(i)) { + XLA_CHECK(seen_dynamic == false); + seen_dynamic = true; + output = xla::SetDimensionSize(output, MaybeConvertTo(dynamic_target, + xla::PrimitiveType::S32) , i); + } + } + // auto stop = high_resolution_clock::now(); + // auto duration = duration_cast(stop - start); + // std::cout << "Time taken dynamic shape (BuildDynamicShape): " + // << duration.count() << " microseconds" << std::endl; + return output; +} + std::vector BuildSqueezedDimensions( absl::Span dimensions, xla::int64 squeeze_dim) { std::vector output_dimensions; diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index c957d4fd4263..305dc426c764 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -37,6 +37,11 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input); xla::XlaOp BuildExpand(xla::XlaOp input, absl::Span output_sizes); +// Dynamic Shape version of BuildExpand() +xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input, + xla::XlaOp dynamic_target, + xla::Shape dynamic_shapes); + std::vector BuildSqueezedDimensions( absl::Span dimensions, xla::int64 squeeze_dim); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c96ccc026d34..f1842533a2f9 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -648,6 +648,17 @@ absl::flat_hash_map> ConvertDictToMap( return map; } +at::Tensor GetXlaDynamicExpand(const at::Tensor& tensor, + const at::Tensor& dynamic_size_tensors, + xla::int64 dynamic_size_values, + std::vector static_size) { + XLATensor xtensor = bridge::GetXlaTensor(tensor); + XLATensor xdynamicsizetensors = bridge::GetXlaTensor(dynamic_size_tensors); + at::Tensor dynamic_expand_tensor = bridge::AtenFromXlaTensor( + XLATensor::dynamic_expand(xtensor, xdynamicsizetensors, dynamic_size_values, static_size)); + return dynamic_expand_tensor; +} + void BuildProfilerSubmodule(py::module* m) { py::module profiler = m->def_submodule("profiler", "Profiler integration"); py::class_ static_size) { + return GetXlaDynamicExpand(tensor, dynamic_size_tensors, dynamic_size_values, static_size); + }); BuildProfilerSubmodule(&m); } diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 4098803292f2..2cbaf2ae4fcf 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -5,6 +5,7 @@ #include "tensorflow/compiler/xla/client/lib/logdet.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/value_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/util.h" @@ -882,6 +883,50 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf, input.shape(), std::move(lower_fn)); } +NodePtr DynamicExpand(const Value& input, const std::vector static_size, const xla::int64 dynamic_size, const Value& dynamic_target) { + xla::Shape input_shape = input.shape(); + auto lower_fn = [input_shape, static_size, dynamic_size](const Node& node, LoweringContext* loctx) -> XlaOpVector { + xla::XlaOp static_input = loctx->GetOutputOp(node.operand(0)); + xla::XlaOp dynamic_target = loctx->GetOutputOp(node.operand(1)); + std::vector dynamic_dims; + std::vector dims; + for (int i = 0; i < static_size.size(); i++) { + if (i == 0) { + dims.push_back(dynamic_size); + dynamic_dims.push_back(dynamic_size < static_size[i] ? true : false); + } else { + dims.push_back(static_size[i]); + dynamic_dims.push_back(false); + } + } + xla::Shape target_shape = xla::ShapeUtil::MakeShape(xla::S32, dims, dynamic_dims); + xla::XlaOp dynamic_output = BuildDynamicExpand(static_input, dynamic_target, target_shape); + return node.ReturnOp(dynamic_output, loctx); + }; + auto shape_fn = [input_shape, static_size, dynamic_size](absl::Span operands) -> xla::XlaOp { + std::vector dynamic_dims; + std::vector dims; + for (int i = 0; i < static_size.size(); i++) { + if (i == 0) { + dims.push_back(dynamic_size); + dynamic_dims.push_back(dynamic_size < static_size[i] ? true : false); + } else { + dims.push_back(static_size[i]); + dynamic_dims.push_back(false); + } + } + xla::Shape target_shape = xla::ShapeUtil::MakeShape(xla::S32, dims, dynamic_dims); + return BuildDynamicExpand(operands[0], operands[1], target_shape); + }; + return GenericOp(OpKind(at::aten::expand), {input, dynamic_target}, + [&]() { + return InferOutputShape( + {input.shape(), dynamic_target.shape()}, + shape_fn); + }, + std::move(lower_fn)); +} + } // namespace ops } // namespace ir -} // namespace torch_xla +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index fa5e428ef248..816b968b12a3 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -226,6 +226,10 @@ NodePtr LogicalOr(const Value& input, const Value& other); NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf, const Value& neginf); +NodePtr DynamicExpand(const Value& static_input, + const std::vector static_size_values, + const xla::int64 dynamic_size_values, + const Value& dynamic_target); } // namespace ops } // namespace ir diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 11bf844e259f..9a92e1cdfbda 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -513,6 +513,8 @@ class XLATensor { static XLATensor expand(const XLATensor& input, std::vector size); + static XLATensor dynamic_expand(const XLATensor& input, const XLATensor& dynamic_size_tensors, const xla::int64 dynamic_size, std::vector static_size); + static XLATensor expm1(const XLATensor& input); static void exponential_(XLATensor& input, double lambd); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 8831a31cc9db..418f68bb0497 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1183,6 +1183,17 @@ XLATensor XLATensor::expand(const XLATensor& input, GetExpandDimensions(input_shape.get(), std::move(size)))); } +XLATensor XLATensor::dynamic_expand(const XLATensor& input, + const XLATensor& dynamic_size_tensors, + const xla::int64 dynamic_size, + std::vector static_size) { + return input.CreateFrom(ir::ops::DynamicExpand( + input.GetIrValue(), + std::move(static_size), + std::move(dynamic_size), + dynamic_size_tensors.GetIrValue())); +} + XLATensor XLATensor::expm1(const XLATensor& input) { return input.CreateFrom(ir::ops::Expm1(input.GetIrValue())); } @@ -1211,8 +1222,12 @@ void XLATensor::eye_out(XLATensor& out, xla::int64 lines, xla::int64 cols) { } void XLATensor::fill_(XLATensor& input, const at::Scalar& value) { - ir::Value constant = - GetIrValueForScalar(value, input.shape(), input.GetDevice()); + ir::Value constant = GetIrValueForScalar(value, input.GetDevice()); + constant = ir::ops::DynamicExpand( + constant, + {-1}, //TODO: fix this line + std::move(-1), //TODO: fix this line + input.GetIrValue()); input.SetInPlaceIrValue(std::move(constant)); }