[POC] Dynamic Shape size API #3360

Draft: wants to merge 6 commits into base: master

Changes from all commits
82 changes: 82 additions & 0 deletions test/dynamic_lazy_tensor5.py
@@ -0,0 +1,82 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla._XLAC._xla_init_ts_backend()
torch_xla._XLAC._xla_set_dynamic_shapes_mode()


class LazyDynamicSize:
    """Wraps the IR node produced by _dynamic_size2 for a tensor."""

    def __init__(self, n):
        super().__init__()
        self._n = n

    @staticmethod
    def fromTensor(t):
        n = torch_xla._XLAC._dynamic_size2(t)
        return LazyDynamicSize(n)


class AutogradExpand(torch.autograd.Function):
    @staticmethod
    def forward(ctx, self, size):
        ctx.save_for_backward(self)
        print("running Expand forward")
        t = torch_xla._XLAC._dynamic_expand2(self, size._n)
        return t

    @staticmethod
    def backward(ctx, grad_output):
        # TODO: we need sum_to for expand
        print("running Expand backward")
        (input,) = ctx.saved_tensors
        return torch_xla._XLAC._sum_to_size(grad_output, input.sizes()._n), None


class DynamicLazyTensor4(torch.Tensor):
    """
    Non-compound ops
    """

    # Category 1: non-compound ops that take no size or dim argument and don't
    # use sizes() in backward, e.g. add, relu, abs, mul, div.
    # No extra work is necessary as long as we can rely on
    # `torch.autograd._register_py_tensor_class_for_device("xla:0", DynamicLazyTensor4)`.

    # Category 2: non-compound ops that return sizes() or dim can be
    # implemented as thin wrappers around the IR. We handwrite these ops.
    def sizes(self):
        return LazyDynamicSize.fromTensor(self)

    def size(self, index):
        raise RuntimeError("NYI")

    def expand(self, sizes):
        return AutogradExpand.apply(self, sizes)


# Register the Python tensor subclass for the XLA device.
torch.autograd._register_py_tensor_class_for_device("xla:0", DynamicLazyTensor4)

a = torch.ones(2, 2, requires_grad=True).to(device="xla:0")
a2 = torch.ones(2, 2).to(device="xla:0")
# print(f"a2={a2[0, 0]}")
# print(f"a={a[0, 0]}")

print(type(a).__name__)
sz = a.sizes()
c = a.expand(sz)
b = torch.nn.functional.gelu(c)

grad_outputs = [torch.ones(2, 2).to(device="xla:0")]
grads = torch.autograd.grad((b,), (a,), grad_outputs)

print(torch_xla._XLAC._get_xla_tensors_text([grads[0]]))
print(torch_xla._XLAC._get_xla_tensors_backend([grads[0]]))

print(grads[0][0, 0].item())
# xm.mark_step()
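
The `AutogradExpand.backward` above mirrors what stock PyTorch autograd already does for `expand`: a gradient flowing into an expanded tensor must be summed back down to the pre-expand shape. A minimal eager-mode sketch with plain `torch` (no XLA, using the public `Tensor.sum_to_size`) shows the reduction the `_sum_to_size` binding stands in for:

import torch

x = torch.ones(1, 3, requires_grad=True)
y = x.expand(4, 3)          # broadcasts dim 0 from 1 to 4; no data is copied
y.sum().backward()          # autograd sums the expanded gradient back down
print(x.grad)               # tensor([[4., 4., 4.]])

# The same reduction, written explicitly:
g = torch.ones(4, 3)
print(g.sum_to_size(1, 3))  # tensor([[4., 4., 4.]])
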
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1394,6 +1394,13 @@ at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
bridge::GetXlaTensor(self), xla::util::ToVector<int64_t>(size)));
}

// The `implicit` flag does not change the lowering; both overloads expand to
// the target size.
at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
                                      at::IntArrayRef size, bool implicit) {
  XLA_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(XLATensor::expand(
      bridge::GetXlaTensor(self), xla::util::ToVector<int64_t>(size)));
}

at::Tensor XLANativeFunctions::expm1(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
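For context, `implicit=true` is the variant the dispatcher emits when an expand is introduced by broadcasting rather than written by the user, which is why forwarding both overloads to the same lowering is reasonable here. A plain-PyTorch illustration of an implicit expand and the matching gradient reduction:

import torch

a = torch.ones(4, 3)
b = torch.ones(1, 3, requires_grad=True)
c = a + b                # b is implicitly expanded to (4, 3) by broadcasting
c.sum().backward()
print(b.grad)            # tensor([[4., 4., 4.]]) -- summed back, as with expand
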
13 changes: 13 additions & 0 deletions torch_xla/csrc/data_ops.cpp
@@ -110,6 +110,19 @@ xla::XlaOp BuildExpand(xla::XlaOp input,
xla::util::Iota<int64_t>(output_sizes.size()));
}

xla::XlaOp BuildDynamicExpand(xla::XlaOp input, xla::XlaOp output_size,
                              xla::Shape output_shape) {
  // Expand to the static upper-bound shape first, then mark every dynamic
  // dimension with its runtime size.
  xla::XlaOp output = BuildExpand(input, output_shape.dimensions());
  for (int i = 0; i < output_shape.rank(); ++i) {
    if (output_shape.is_dynamic_dimension(i)) {
      output = xla::SetDimensionSize(
          output, MaybeConvertTo(output_size, xla::PrimitiveType::S32), i);
    }
  }
  return output;
}

std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, int64_t squeeze_dim) {
std::vector<int64_t> output_dimensions;
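A toy model of what `BuildDynamicExpand` emits, assuming the usual padded representation of dynamic shapes (the array is materialized at its static upper bound while `SetDimensionSize` records the runtime extent). The helper below is hypothetical and only illustrates the shape bookkeeping, not the XLA op emission:

import numpy as np

def toy_dynamic_expand(arr, upper_bound_shape, dynamic_dims, runtime_size):
    # Static expand to the upper-bound shape (the BuildExpand step).
    out = np.broadcast_to(arr, upper_bound_shape)
    # Record the runtime size of each dynamic dimension
    # (the xla::SetDimensionSize step).
    sizes = list(out.shape)
    for i in dynamic_dims:
        sizes[i] = runtime_size
    return out, sizes

out, sizes = toy_dynamic_expand(np.ones((1, 3)), (4, 3),
                                dynamic_dims=[0], runtime_size=2)
print(out.shape, sizes)  # (4, 3) [2, 3]
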
6 changes: 6 additions & 0 deletions torch_xla/csrc/data_ops.h
@@ -35,6 +35,12 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);
xla::XlaOp BuildExpand(xla::XlaOp input,
absl::Span<const int64_t> output_sizes);

// Like BuildExpand, but also sets the runtime size of each dynamic dimension
// of the expanded output.
xla::XlaOp BuildDynamicExpand(xla::XlaOp input, xla::XlaOp output_size,
                              xla::Shape output_shape);

std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, int64_t squeeze_dim);

24 changes: 24 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -48,6 +48,10 @@
#include "torch_xla/csrc/version.h"
#include "torch_xla/csrc/xla_op_builder.h"

#include "torch_xla/csrc/ops/dynamic_expand.h"
#include "torch_xla/csrc/ops/dynamic_size.h"
#include "torch_xla/csrc/ops/sum_to_size.h"

namespace torch_xla {
namespace {

@@ -851,6 +855,7 @@ void InitXlaModuleBindings(py::module m) {
});

py::class_<ir::Value, std::shared_ptr<ir::Value>>(m, "IrValue");
py::class_<ir::Node, std::shared_ptr<ir::Node>>(m, "IrNode");
m.def("_xla_create_token",
[](const std::string& device) { return CreateToken(device); });
m.def("_xla_all_reduce_inplace", [](const std::string& reduce_type,
@@ -867,6 +872,23 @@ void InitXlaModuleBindings(py::module m) {
}
return new_token;
});
m.def("_dynamic_expand2",
[](at::Tensor& self, std::shared_ptr<ir::Node> val) {
XLATensor self_xla_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(
self_xla_tensor.CreateFrom(ir::MakeNode<ir::ops::DynamicExpand2>(
self_xla_tensor.GetIrValue(),val)));
});
m.def("_dynamic_size2",
[](at::Tensor& self) {
XLATensor self_xla_tensor = bridge::GetXlaTensor(self);
return ir::MakeNode<ir::ops::DynamicSize2>(self_xla_tensor.GetIrValue());
});
m.def("_sum_to_size",
[](at::Tensor& self, std::shared_ptr<ir::Node> val) {
XLATensor self_xla_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(self_xla_tensor.CreateFrom(ir::MakeNode<ir::ops::SumToOrThrow>(self_xla_tensor.GetIrValue(), val)));
});
m.def("_xla_all_reduce",
[](const std::string& reduce_type, const at::Tensor& input,
const std::shared_ptr<ir::Value>& token, double scale,
@@ -1196,6 +1218,8 @@ void InitXlaModuleBindings(py::module m) {
});

BuildProfilerSubmodule(&m);
m.def("_ltc_set_dynamic_shapes_mode",
[]() { lazy_tensors::Shape::SetDynamicMode(); });
}

} // namespace
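Taken together, the three bindings are meant to be driven from Python roughly as the test above does. A condensed sketch mirroring test/dynamic_lazy_tensor5.py (assuming the TS backend and dynamic mode are initialized first):

import torch
import torch_xla

torch_xla._XLAC._xla_init_ts_backend()
torch_xla._XLAC._xla_set_dynamic_shapes_mode()

t = torch.ones(2, 2).to(device="xla:0")
size_node = torch_xla._XLAC._dynamic_size2(t)                # IrNode: prim::_dynamic_size2
expanded = torch_xla._XLAC._dynamic_expand2(t, size_node)    # expand to that dynamic size
reduced = torch_xla._XLAC._sum_to_size(expanded, size_node)  # gradient-style reduction
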
6 changes: 6 additions & 0 deletions torch_xla/csrc/ir.cpp
@@ -249,6 +249,12 @@ XlaOpVector Node::Lower(LoweringContext* loctx) const {

torch::lazy::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape,
torch::lazy::hash_t hash_seed) {
  if (xla::Shape::IsDynamicMode()) {  // TODO(miladm): implement this properly.
    // In dynamic shapes mode, hash only the rank so that graphs differing
    // just in their concrete dimension sizes share a compilation.
    torch::lazy::hash_t h =
        torch::lazy::HashCombine(op.hash(), torch::lazy::Hash(shape.rank()));
    return torch::lazy::HashCombine(h, hash_seed);
  }

torch::lazy::hash_t h =
torch::lazy::HashCombine(op.hash(), torch::lazy::Hash(shape.ToString()));
return torch::lazy::HashCombine(h, hash_seed);
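The intent of the dynamic-mode branch is to make the node hash insensitive to concrete dimension sizes, so traces that differ only in runtime extents hit the same compilation cache entry. A conceptual Python sketch (not the real torch::lazy hash functions):

def op_hash(op, shape_dims, dynamic_mode, seed=0):
    if dynamic_mode:
        # Hash only the rank: [2, 3] and [5, 3] collide on purpose,
        # so both reuse one compiled graph.
        return hash((op, len(shape_dims), seed))
    # Static mode: the full concrete shape participates in the hash.
    return hash((op, tuple(shape_dims), seed))

assert op_hash("expand", [2, 3], True) == op_hash("expand", [5, 3], True)
assert op_hash("expand", [2, 3], False) != op_hash("expand", [5, 3], False)
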
37 changes: 37 additions & 0 deletions torch_xla/csrc/ops/dynamic_expand.cpp
@@ -0,0 +1,37 @@
#include "torch_xla/csrc/ops/dynamic_expand.h"


namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(const Value& input,
const Value& size) {
xla::Shape shape = size.shape();
auto lower_for_shape_fn =
[shape](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildDynamicExpand(operands[0], operands[1], shape);
};
return InferOutputShape({input.shape(), shape}, lower_for_shape_fn);
}

} // namespace

DynamicExpand2::DynamicExpand2(Value& lhs, Value& sz)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_expand2")), {lhs, sz},
[&]() { return NodeOutputShape(input, sz); },
/*num_outputs=*/1, torch::lazy::MHash(sz.shape)), /*TODO: cast lazy shape to xla shape */
) {}

XlaOpVector Lower(LoweringContext* loctx) const {
XLA_CHECK(operands().size() == 2);
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp size_ = loctx->GetOutputOp(operand(1)); // TODO: confirm with Nick if .input is needed
xla::Shape shape_ = operand(1).shape(); //TODO: cast lazy::shape to xla::Shape (xla::ShapeUtil::CastShape ?) - confirm with Nick
return ReturnOp(BuildDynamicExpand(input, size_, shape_), loctx);
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
20 changes: 20 additions & 0 deletions torch_xla/csrc/ops/dynamic_expand.h
@@ -0,0 +1,20 @@
#pragma once

#include "torch_xla/csrc/ir.h"
#include "lazy_tensor_core/csrc/ts_backend/ts_shape_inference.h"
#include "lazy_tensor_core/csrc/ts_backend/ts_node_lowering.h"

namespace torch_xla {
namespace ir {
namespace ops {

class DynamicExpand2 : public Node {
public:
DynamicExpand2(Value& lhs, Value& sz);

XlaOpVector Lower(LoweringContext* loctx) const override;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
24 changes: 24 additions & 0 deletions torch_xla/csrc/ops/dynamic_size.cpp
@@ -0,0 +1,24 @@
#include "torch_xla/csrc/ops/dynamic_size.h"


namespace torch_xla {
namespace ir {
namespace ops {

DynamicSize2::DynamicSize2(Value lhs)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_size2")), lhs,
{ir::GetShapeFromTsValue(lhs)}) {}

XlaOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function, //TODO: milad fix this
LoweringContext* loctx) const override {

CHECK(operands().size() == 1);
auto graph = function->graph();

auto size_val = graph->insert(at::aten::size, {loctx->GetOutputOp(operands().at(0))});
return {size_val};
}

} // namespace ops
} // namespace ir
} // namespace torch_lazy_tensors
21 changes: 21 additions & 0 deletions torch_xla/csrc/ops/dynamic_size.h
@@ -0,0 +1,21 @@
#pragma once

#include "torch_xla/csrc/ir.h"
#include "lazy_tensor_core/csrc/ts_backend/ts_shape_inference.h"
#include "lazy_tensor_core/csrc/ts_backend/ts_node_lowering.h"

namespace torch_xla {
namespace ir {
namespace ops {

class DynamicSize2 : public Node {
public:
DynamicSize2(Value lhs);

  // TODO(miladm): fix this signature; it mixes the TS and XLA lowering interfaces.
  XlaOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,
                    LoweringContext* loctx) const override;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
20 changes: 20 additions & 0 deletions torch_xla/csrc/tensor.cpp
@@ -11,6 +11,12 @@
#include <stdexcept>
#include <unordered_set>

#include "torch/csrc/autograd/function.h"
#include "torch_xla/csrc/ops/sum_to_size.h"
#include "torch_xla/csrc/ops/dynamic_size.h"
#include "torch_xla/csrc/ops/dynamic_expand.h"
#include "torch_xla/csrc/aten_ltc_bridge.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -46,6 +52,20 @@
namespace torch_xla {
namespace {

// Appends a sum-to-size node to the gradient: the gradient is reduced to the
// input's dynamic size, or the op throws at runtime if the sizes mismatch.
static at::Tensor generate_size_check_for(at::Tensor& input, at::Tensor& grad) {
  auto grad_lt = bridge::GetOrCreateLtcTensor(grad, GetCurrentDevice());
  auto input_lt = bridge::GetLtcTensor(input);
  auto sz = ir::MakeNode<ir::ops::DynamicSize2>(input_lt.GetIrValue());
  return bridge::AtenFromLtcTensor(grad_lt.CreateFrom(
      ir::MakeNode<ir::ops::SumToOrThrow>(grad_lt.GetIrValue(), sz)));
}

// Install the handler once at static-initialization time.
static struct InitializeHandlers {
  InitializeHandlers() {
    torch::autograd::setSumToOrThrowHandler(generate_size_check_for);
  }
} init_handlers;

struct TlsData {
void Reset() { trim_counter = 0; }

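`generate_size_check_for` installs the gradient-side counterpart of the dynamic expand: whatever shape the gradient arrives with, sum it down to the input's (dynamic) size or fail at runtime. In eager PyTorch terms, the handler behaves roughly like this hypothetical helper (the real one builds lazy IR nodes instead of eager tensors):

import torch

def sum_to_or_throw(grad, input_sizes):
    # Reduce grad to input_sizes by summing broadcast dims; raises on mismatch.
    return grad.sum_to_size(input_sizes)

print(sum_to_or_throw(torch.ones(4, 3), (1, 3)))  # tensor([[4., 4., 4.]])
try:
    sum_to_or_throw(torch.ones(4, 3), (2, 3))     # 4 cannot reduce to 2
except RuntimeError as e:
    print("size check failed:", e)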