Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(exla): add LU custom_call #1549

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion exla/c_src/exla/custom_calls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]);
void qr_cpu_custom_call_f64(void *out[], const void *in[]);
void qr_cpu_custom_call_f16(void *out[], const void *in[]);
void qr_cpu_custom_call_bf16(void *out[], const void *in[]);
// LU decomposition CPU custom-call entry points, one per supported element
// type (f32, f64, f16, bf16). All share the (out[], in[]) calling convention
// used by the other custom calls declared in this file.
void lu_cpu_custom_call_f32(void *out[], const void *in[]);
void lu_cpu_custom_call_f64(void *out[], const void *in[]);
void lu_cpu_custom_call_f16(void *out[], const void *in[]);
void lu_cpu_custom_call_bf16(void *out[], const void *in[]);
void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
void eigh_cpu_custom_call_f64(void *out[], const void *in[]);

Expand All @@ -12,4 +16,8 @@ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_cu
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
// Register the LU CPU custom-call targets. Each target is registered under a
// string identical to its C++ symbol name.
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f64", lu_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f32", lu_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f16", lu_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_bf16", lu_cpu_custom_call_bf16);
95 changes: 95 additions & 0 deletions exla/c_src/exla/custom_calls/lu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#pragma once

#include <cstdint>

#include "Eigen/LU"

// Computes the partial-pivoting LU decomposition of one n x n row-major matrix
// and writes the three factors as dense n x n row-major buffers.
//
// Parameters:
//   p_out - receives the permutation as a 0/1 matrix of uint8_t. Row i holds
//           its single 1 in column permutationP().indices()[i].
//   l_out - receives the unit lower-triangular factor: entries of the packed
//           LU matrix strictly below the diagonal, ones on the diagonal and
//           zeros above it.
//   u_out - receives the upper-triangular factor: entries of the packed LU
//           matrix on and above the diagonal, zeros below it.
//   in    - the n * n input elements; they are only read.
//   n     - number of rows (and columns) of the matrix.
template <typename DataType>
void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) {
  typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;

  // Nothing to decompose and nothing to write for an empty matrix.
  if (n == 0) {
    return;
  }

  Eigen::Map<RowMajorMatrix> input(in, n, n);
  Eigen::PartialPivLU<RowMajorMatrix> lu = input.partialPivLu();

  // Expand the permutation indices into a dense 0/1 matrix.
  Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> P = lu.permutationP();
  for (uint64_t i = 0; i < n; i++) {
    // Hoisted out of the inner loop; the cast avoids a signed/unsigned
    // comparison between Eigen's index type and the uint64_t loop counter.
    const uint64_t one_column = static_cast<uint64_t>(P.indices()[i]);
    for (uint64_t j = 0; j < n; j++) {
      p_out[i * n + j] = static_cast<uint8_t>(j == one_column ? 1 : 0);
    }
  }

  // matrixLU() stores L (strictly below the diagonal, implicit unit diagonal)
  // and U (on and above the diagonal) packed in one matrix, so both outputs
  // can be filled in a single pass without materialising dense temporaries.
  const auto &packed = lu.matrixLU();
  for (uint64_t i = 0; i < n; i++) {
    for (uint64_t j = 0; j < n; j++) {
      const uint64_t offset = i * n + j;
      if (j < i) {
        l_out[offset] = static_cast<DataType>(packed(i, j));
        u_out[offset] = static_cast<DataType>(0.0);
      } else {
        l_out[offset] = static_cast<DataType>(j == i ? 1.0 : 0.0);
        u_out[offset] = static_cast<DataType>(packed(i, j));
      }
    }
  }
}

// Shared implementation behind the typed LU CPU custom calls: decomposes every
// n x n matrix of a (possibly batched) operand.
//
// Inputs (in[]):
//   in[0] - operand elements (DataType).
//   in[1] - four uint64 values: the ranks of the operand, P, L and U.
//   in[2] - operand dimensions (uint64).
//   in[3] - P dimensions (uint64); not needed here.
//   in[4] - L dimensions (uint64); the last one is used as n.
//   in[5] - U dimensions (uint64); not needed here.
//
// Outputs (out[]):
//   out[0] - P as uint8_t, out[1] - L, out[2] - U (both DataType).
//
// All leading operand dimensions (everything except the last two) are treated
// as batch dimensions; each batch item advances every buffer by n * n elements.
template <typename DataType>
void lu_cpu_custom_call(void *out[], const void *in[]) {
  DataType *operand = (DataType *)in[0];

  const uint64_t *dim_sizes = (const uint64_t *)in[1];
  const uint64_t num_operand_dims = dim_sizes[0];
  const uint64_t num_l_dims = dim_sizes[2];

  // A matrix needs at least two dimensions. Bail out instead of reading before
  // the start of the dimension arrays; the outputs are left untouched.
  if (num_operand_dims < 2 || num_l_dims == 0) {
    return;
  }

  const uint64_t *operand_dims = (const uint64_t *)in[2];
  const uint64_t *l_dims = (const uint64_t *)in[4];

  const uint64_t n = l_dims[num_l_dims - 1];

  // Product of every operand dimension except the trailing two.
  uint64_t batch_items = 1;
  for (uint64_t i = 0; i + 2 < num_operand_dims; i++) {
    batch_items *= operand_dims[i];
  }

  uint8_t *p = (uint8_t *)out[0];
  DataType *l = (DataType *)out[1];
  DataType *u = (DataType *)out[2];

  // Strides are in elements, not bytes: pointer arithmetic scales by type.
  const uint64_t stride = n * n;

  for (uint64_t i = 0; i < batch_items; i++) {
    single_matrix_lu_cpu_custom_call<DataType>(
        p + i * stride,
        l + i * stride,
        u + i * stride,
        operand + i * stride,
        n);
  }
}
6 changes: 6 additions & 0 deletions exla/c_src/exla/custom_calls/lu_bf16.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include "lu.h"
#include "../exla_types.h"

// LU CPU custom-call entry point for bf16 operands: forwards to the shared
// lu_cpu_custom_call template instantiated with exla::bfloat16.
void lu_cpu_custom_call_bf16(void *out[], const void *in[]) {
lu_cpu_custom_call<exla::bfloat16>(out, in);
}
6 changes: 6 additions & 0 deletions exla/c_src/exla/custom_calls/lu_f16.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include "lu.h"
#include "../exla_types.h"

// LU CPU custom-call entry point for f16 operands: forwards to the shared
// lu_cpu_custom_call template instantiated with exla::float16.
void lu_cpu_custom_call_f16(void *out[], const void *in[]) {
lu_cpu_custom_call<exla::float16>(out, in);
}
5 changes: 5 additions & 0 deletions exla/c_src/exla/custom_calls/lu_f32.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "lu.h"

// LU CPU custom-call entry point for f32 operands: forwards to the shared
// lu_cpu_custom_call template instantiated with float.
void lu_cpu_custom_call_f32(void *out[], const void *in[]) {
lu_cpu_custom_call<float>(out, in);
}
5 changes: 5 additions & 0 deletions exla/c_src/exla/custom_calls/lu_f64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "lu.h"

// LU CPU custom-call entry point for f64 operands: forwards to the shared
// lu_cpu_custom_call template instantiated with double.
void lu_cpu_custom_call_f64(void *out[], const void *in[]) {
lu_cpu_custom_call<double>(out, in);
}
8 changes: 4 additions & 4 deletions exla/c_src/exla/custom_calls/qr.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ void qr_cpu_custom_call(void *out[], const void *in[]) {
DataType *q = (DataType *)out[0];
DataType *r = (DataType *)out[1];

uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType);
uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType);
uint64_t inner_stride = m * n * sizeof(DataType);
uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2];
uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2];
uint64_t inner_stride = m * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_qr_cpu_custom_call<DataType>(
(DataType *)out[0] + i * q_stride,
(DataType *)out[1] + i * r_stride,
operand + i * inner_stride * sizeof(DataType),
operand + i * inner_stride,
m, k, n, complete);
}
}
41 changes: 37 additions & 4 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,43 @@ defmodule EXLA.Defn do
end
end

# Lowers an `:lu` expression.
#
# Raises `ArgumentError` when the client platform is not `:host`, or when the
# P or L output type is complex. Otherwise the operand is lowered, converted to
# the type of the U output when its type differs, and passed to `Value.lu/4`
# together with the typespecs of the three outputs.
#
# Returns `{[p, l, u], cache}`.
defp cached_recur_operator(
       :lu,
       %T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}},
       state,
       cache
     ) do
  %{type: {p_kind, _}} = p_expr
  %{type: {l_kind, _}} = l_expr

  cond do
    state.client.platform != :host ->
      raise ArgumentError, "XLA does not currently support the LU operation on non-host devices"

    p_kind == :c or l_kind == :c ->
      raise ArgumentError, "XLA does not currently support the LU operation for complex inputs"

    true ->
      :ok
  end

  {operand, cache} = tensor |> recur_operator(state, cache) |> unwrap_single_tensor!()

  # Only insert a conversion when the lowered operand is not already U-typed.
  operand =
    if op_type(operand) == u_expr.type,
      do: operand,
      else: to_type(operand, u_expr.type)

  [p_spec, l_spec, u_spec] = Enum.map([p_expr, l_expr, u_expr], &expr_to_typespec/1)

  {p, l, u} = Value.lu(operand, p_spec, l_spec, u_spec)

  {[p, l, u], cache}
end

defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do
{op, cache} = recur_operator(expr, state, cache)
{_, cache} = recur_operator(token, state, cache)
Expand Down Expand Up @@ -772,10 +809,6 @@ defmodule EXLA.Defn do
end
end

defp to_operator(:lu, [{_, _, _}, _tensor, _opts], _ans, _state) do
raise ArgumentError, "XLA does not currently support the LU operation"
end

## to_operator element-wise

defp to_operator(:negate, [%Value{} = op], ans, _state),
Expand Down
75 changes: 75 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,81 @@ defmodule EXLA.MLIR.Value do
{q, r}
end

# Emits a `stablehlo.custom_call` that computes the LU decomposition of `value`.
#
# The custom call receives, in order: the operand, a u64 vector with the ranks
# of the operand/P/L/U shapes, and one u64 vector per shape holding its
# dimensions. It produces a tuple `(P, L, U)` from which the three results are
# extracted.
#
# Returns `{p, l, u}`, with `p` converted to the type requested by
# `p_typespec` when that type is not u8.
#
# Raises for operand types other than f16, bf16, f32 and f64.
def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do
  %{type: op_type, shape: op_shape} = get_typespec(value)
  %{type: _p_type, shape: p_shape} = p_typespec
  %{type: l_type, shape: l_shape} = l_typespec
  %{type: u_type, shape: u_shape} = u_typespec

  shapes = [op_shape, p_shape, l_shape, u_shape]

  # Builds a rank-1 u64 constant out of a list of integers.
  u64_vector = fn values ->
    constant(func, values, Typespec.tensor({:u, 64}, {length(values)}))
  end

  dim_sizes = u64_vector.(Enum.map(shapes, &tuple_size/1))

  [operand_dims, p_dims, l_dims, u_dims] =
    Enum.map(shapes, fn shape -> u64_vector.(Tuple.to_list(shape)) end)

  operands = [value, dim_sizes, operand_dims, p_dims, l_dims, u_dims]

  # Force P to always be u8 to avoid requiring too many template instances
  # during custom_call registration.
  result_types = [
    type_tuple([
      type_tensor({:u, 8}, p_shape),
      type_tensor(l_type, l_shape),
      type_tensor(u_type, u_shape)
    ])
  ]

  type_suffix =
    case op_type do
      {:f, bits} when bits in [16, 32, 64] ->
        "f#{bits}"

      {:bf, 16} ->
        "bf16"

      type ->
        # EXLA.Defn only reaches this function on the :host platform, hence
        # the wording of the message.
        raise "LU decomposition not supported on :host device for type #{inspect(type)}"
    end

  call_target_name = "lu_cpu_custom_call_" <> type_suffix

  attributes = [
    call_target_name: attr_string(call_target_name),
    backend_config: attr_string("Host")
  ]

  result =
    func
    |> op("stablehlo.custom_call", operands, result_types, attributes: attributes)
    |> one!()

  # Not the best approach, but the alternative would require many more
  # template instances: read P back as u8 and convert afterwards if needed.
  u8_typespec = Typespec.to_type(p_typespec, {:u, 8})
  raw_p = get_tuple_element(result, 0, u8_typespec)

  p =
    if u8_typespec == p_typespec,
      do: raw_p,
      else: convert(raw_p, p_typespec)

  l = get_tuple_element(result, 1, l_typespec)
  u = get_tuple_element(result, 2, u_typespec)

  {p, l, u}
end

def get_tuple_element(%Value{function: func} = operand, index, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [index: attr_i32(index)]
Expand Down
2 changes: 1 addition & 1 deletion exla/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ defmodule EXLA.MixProject do
File.rm_rf!("cache/#{@version}/libexla.so")

Mix.shell().info("Removing libexla.so cache at #{cached_so}")
File.rm!(cached_so)
File.rm_rf!(cached_so)
end

if cached? do
Expand Down
22 changes: 15 additions & 7 deletions exla/test/exla/nx_linalg_doctest_test.exs
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
defmodule EXLA.MLIR.NxLinAlgDoctestTest do
use EXLA.Case, async: true

@invalid_type_error_doctests [svd: 2, pinv: 2, matrix_rank: 2]
@invalid_type_error_doctests [
svd: 2,
pinv: 2
]

@function_clause_error_doctests [
norm: 2,
lu: 2,
solve: 2,
solve: 2
]

@rounding_error_doctests [
triangular_solve: 3,
eigh: 2,
cholesky: 1,
least_squares: 3,
determinant: 1,
invert: 1,
matrix_power: 2
matrix_power: 2,
lu: 2
]
@rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3]

@excluded_doctests @function_clause_error_doctests ++
@rounding_error_doctests ++
Expand Down
Loading