
Commit 7af065e

feat(exla): add LU custom_call (#1549)
Co-authored-by: José Valim <[email protected]>
1 parent 9d73de2 commit 7af065e

11 files changed (+258, -17 lines)

exla/c_src/exla/custom_calls.cc

Lines changed: 9 additions & 1 deletion
@@ -4,6 +4,10 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]);
 void qr_cpu_custom_call_f64(void *out[], const void *in[]);
 void qr_cpu_custom_call_f16(void *out[], const void *in[]);
 void qr_cpu_custom_call_bf16(void *out[], const void *in[]);
+void lu_cpu_custom_call_f32(void *out[], const void *in[]);
+void lu_cpu_custom_call_f64(void *out[], const void *in[]);
+void lu_cpu_custom_call_f16(void *out[], const void *in[]);
+void lu_cpu_custom_call_bf16(void *out[], const void *in[]);
 void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
 void eigh_cpu_custom_call_f64(void *out[], const void *in[]);

@@ -12,4 +16,8 @@ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_cu
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
-XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f64", lu_cpu_custom_call_f64);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f32", lu_cpu_custom_call_f32);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f16", lu_cpu_custom_call_f16);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_bf16", lu_cpu_custom_call_bf16);

exla/c_src/exla/custom_calls/lu.h

Lines changed: 95 additions & 0 deletions
#pragma once

#include <cstdint>
#include <vector>

#include "Eigen/LU"

template <typename DataType>
void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) {
  typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;

  Eigen::Map<RowMajorMatrix> input(in, n, n);
  Eigen::PartialPivLU<RowMajorMatrix> lu = input.partialPivLu();

  // Get the permutation matrix P and convert to indices
  Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> P = lu.permutationP();
  for (uint64_t i = 0; i < n; i++) {
    for (uint64_t j = 0; j < n; j++) {
      p_out[i * n + j] = static_cast<uint8_t>(P.indices()[i] == j ? 1 : 0);
    }
  }

  // Get L and U matrices
  RowMajorMatrix L = lu.matrixLU().template triangularView<Eigen::UnitLower>();
  RowMajorMatrix U = lu.matrixLU().template triangularView<Eigen::Upper>();

  // Copy L matrix
  for (uint64_t i = 0; i < n; i++) {
    for (uint64_t j = 0; j < n; j++) {
      if (j < i) {
        l_out[i * n + j] = static_cast<DataType>(L(i, j));
      } else if (j == i) {
        l_out[i * n + j] = static_cast<DataType>(1.0);
      } else {
        l_out[i * n + j] = static_cast<DataType>(0.0);
      }
    }
  }

  // Copy U matrix
  for (uint64_t i = 0; i < n; i++) {
    for (uint64_t j = 0; j < n; j++) {
      if (j >= i) {
        u_out[i * n + j] = static_cast<DataType>(U(i, j));
      } else {
        u_out[i * n + j] = static_cast<DataType>(0.0);
      }
    }
  }
}

template <typename DataType>
void lu_cpu_custom_call(void *out[], const void *in[]) {
  DataType *operand = (DataType *)in[0];

  // in[1] holds the ranks of the operand and of the P/L/U outputs;
  // in[2..5] hold the corresponding dimension sizes.
  uint64_t *dim_sizes = (uint64_t *)in[1];
  uint64_t num_operand_dims = dim_sizes[0];
  uint64_t num_p_dims = dim_sizes[1];
  uint64_t num_l_dims = dim_sizes[2];
  uint64_t num_u_dims = dim_sizes[3];

  uint64_t *operand_dims_ptr = (uint64_t *)in[2];
  std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);

  uint64_t *p_dims_ptr = (uint64_t *)in[3];
  std::vector<uint64_t> p_dims(p_dims_ptr, p_dims_ptr + num_p_dims);

  uint64_t *l_dims_ptr = (uint64_t *)in[4];
  std::vector<uint64_t> l_dims(l_dims_ptr, l_dims_ptr + num_l_dims);

  uint64_t *u_dims_ptr = (uint64_t *)in[5];
  std::vector<uint64_t> u_dims(u_dims_ptr, u_dims_ptr + num_u_dims);

  uint64_t n = l_dims[l_dims.size() - 1];

  // All leading dimensions are treated as a flat batch of n x n matrices.
  auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);

  uint64_t batch_items = 1;
  for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
    batch_items *= leading_dimensions[i];
  }

  uint8_t *p = (uint8_t *)out[0];
  DataType *l = (DataType *)out[1];
  DataType *u = (DataType *)out[2];

  uint64_t stride = n * n;

  for (uint64_t i = 0; i < batch_items; i++) {
    single_matrix_lu_cpu_custom_call<DataType>(
        p + i * stride,
        l + i * stride,
        u + i * stride,
        operand + i * stride,
        n);
  }
}
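
For orientation only (this harness is not part of the commit): the routine above factors each n x n matrix with Eigen's partialPivLu, which computes P*A = L*U, and the permutation buffer is written so that multiplying the three outputs back together reproduces the operand. A minimal standalone check, assuming Eigen is on the include path and this file is compiled next to lu.h; the 3x3 input values and the main() driver are illustrative:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

#include "lu.h"

int main() {
  constexpr uint64_t n = 3;
  // Any nonsingular matrix works; values here are arbitrary.
  double a[n * n] = {2, 1, 1,
                     4, 3, 3,
                     8, 7, 9};
  uint8_t p[n * n];
  double l[n * n];
  double u[n * n];

  single_matrix_lu_cpu_custom_call<double>(p, l, u, a, n);

  // Reconstruct P * L * U and report the largest deviation from A.
  double max_err = 0.0;
  for (uint64_t i = 0; i < n; i++) {
    for (uint64_t j = 0; j < n; j++) {
      double acc = 0.0;
      for (uint64_t r = 0; r < n; r++) {
        for (uint64_t c = 0; c < n; c++) {
          acc += p[i * n + r] * l[r * n + c] * u[c * n + j];
        }
      }
      max_err = std::max(max_err, std::fabs(acc - a[i * n + j]));
    }
  }
  std::printf("max |P*L*U - A| = %g\n", max_err);
  return 0;
}

With exact arithmetic the reported error would be zero; for a well-conditioned input it should be on the order of machine epsilon.
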
Lines changed: 6 additions & 0 deletions
#include "lu.h"
#include "../exla_types.h"

void lu_cpu_custom_call_bf16(void *out[], const void *in[]) {
  lu_cpu_custom_call<exla::bfloat16>(out, in);
}

Lines changed: 6 additions & 0 deletions
#include "lu.h"
#include "../exla_types.h"

void lu_cpu_custom_call_f16(void *out[], const void *in[]) {
  lu_cpu_custom_call<exla::float16>(out, in);
}

Lines changed: 5 additions & 0 deletions
#include "lu.h"

void lu_cpu_custom_call_f32(void *out[], const void *in[]) {
  lu_cpu_custom_call<float>(out, in);
}

Lines changed: 5 additions & 0 deletions
#include "lu.h"

void lu_cpu_custom_call_f64(void *out[], const void *in[]) {
  lu_cpu_custom_call<double>(out, in);
}
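
A side note on the four new one-function .cc files above: the registration macro in custom_calls.cc pairs a string target name with a plain function symbol, so each supported dtype gets a thin wrapper that pins the lu_cpu_custom_call template to a concrete type, and that symbol is registered under the same name EXLA.MLIR.Value.lu later emits as call_target_name. A rough sketch of the pattern, with hypothetical names (CustomCallFn, lu_generic, kRegistry) standing in for the real lu_cpu_custom_call and XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM machinery:

#include <map>
#include <string>

// Plain function-pointer type matching the wrappers' signature.
using CustomCallFn = void (*)(void *out[], const void *in[]);

// Stand-in for the templated lu_cpu_custom_call<DataType> in lu.h.
template <typename DataType>
void lu_generic(void *out[], const void *in[]) {
  (void)out;
  (void)in;
}

// One concrete, non-template symbol per dtype, as in the wrapper files above.
void lu_f32(void *out[], const void *in[]) { lu_generic<float>(out, in); }
void lu_f64(void *out[], const void *in[]) { lu_generic<double>(out, in); }

// Hypothetical stand-in for the registration step: a name -> function map the
// runtime can consult when a custom_call names one of these targets.
static const std::map<std::string, CustomCallFn> kRegistry = {
    {"lu_cpu_custom_call_f32", lu_f32},
    {"lu_cpu_custom_call_f64", lu_f64},
};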

exla/c_src/exla/custom_calls/qr.h

Lines changed: 4 additions & 4 deletions
@@ -73,15 +73,15 @@ void qr_cpu_custom_call(void *out[], const void *in[]) {
   DataType *q = (DataType *)out[0];
   DataType *r = (DataType *)out[1];
 
-  uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType);
-  uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType);
-  uint64_t inner_stride = m * n * sizeof(DataType);
+  uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2];
+  uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2];
+  uint64_t inner_stride = m * n;
 
   for (uint64_t i = 0; i < batch_items; i++) {
     single_matrix_qr_cpu_custom_call<DataType>(
         (DataType *)out[0] + i * q_stride,
         (DataType *)out[1] + i * r_stride,
-        operand + i * inner_stride * sizeof(DataType),
+        operand + i * inner_stride,
         m, k, n, complete);
   }
 }
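
The qr.h change above is a stride fix rather than new behavior: operand, q, and r are typed DataType pointers, so ptr + i * stride already advances by whole elements, and multiplying the stride by sizeof(DataType) again would step sizeof(DataType) times too far between batch items. A small illustration, not from the commit; the buffer size and step are arbitrary:

#include <cstdio>

int main() {
  float buf[32] = {};
  float *base = buf;

  // Element stride: a step of 4 floats lands on buf[4].
  float *elem_step = base + 4;

  // The pattern removed by this commit: scaling by sizeof(float) again
  // lands on buf[16], i.e. 4x too far for 4-byte floats.
  float *byte_scaled = base + 4 * sizeof(float);

  std::printf("%td vs %td elements past base\n",
              elem_step - base, byte_scaled - base);
  return 0;
}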

exla/lib/exla/defn.ex

Lines changed: 37 additions & 4 deletions
@@ -544,6 +544,43 @@ defmodule EXLA.Defn do
     end
   end
 
+  defp cached_recur_operator(
+         :lu,
+         %T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}},
+         state,
+         cache
+       ) do
+    %{type: {p_type_kind, _}} = p_expr
+    %{type: {out_type_kind, _}} = l_expr
+
+    if state.client.platform != :host do
+      raise ArgumentError, "XLA does not currently support the LU operation on non-host devices"
+    end
+
+    if p_type_kind == :c or out_type_kind == :c do
+      raise ArgumentError, "XLA does not currently support the LU operation for complex inputs"
+    end
+
+    {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
+
+    tensor =
+      if op_type(tensor) != u_expr.type do
+        to_type(tensor, u_expr.type)
+      else
+        tensor
+      end
+
+    {p, l, u} =
+      Value.lu(
+        tensor,
+        expr_to_typespec(p_expr),
+        expr_to_typespec(l_expr),
+        expr_to_typespec(u_expr)
+      )
+
+    {[p, l, u], cache}
+  end
+
   defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do
     {op, cache} = recur_operator(expr, state, cache)
     {_, cache} = recur_operator(token, state, cache)
@@ -772,10 +809,6 @@ defmodule EXLA.Defn do
     end
   end
 
-  defp to_operator(:lu, [{_, _, _}, _tensor, _opts], _ans, _state) do
-    raise ArgumentError, "XLA does not currently support the LU operation"
-  end
-
   ## to_operator element-wise
 
   defp to_operator(:negate, [%Value{} = op], ans, _state),

exla/lib/exla/mlir/value.ex

Lines changed: 75 additions & 0 deletions
@@ -815,6 +815,81 @@ defmodule EXLA.MLIR.Value do
     {q, r}
   end
 
+  def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do
+    %{type: op_type, shape: op_shape} = get_typespec(value)
+    %{type: _p_type, shape: p_shape} = p_typespec
+    %{type: l_type, shape: l_shape} = l_typespec
+    %{type: u_type, shape: u_shape} = u_typespec
+
+    dim_sizes = [
+      tuple_size(op_shape),
+      tuple_size(p_shape),
+      tuple_size(l_shape),
+      tuple_size(u_shape)
+    ]
+
+    operand_dims = Tuple.to_list(op_shape)
+    p_dims = Tuple.to_list(p_shape)
+    l_dims = Tuple.to_list(l_shape)
+    u_dims = Tuple.to_list(u_shape)
+
+    dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)}))
+    operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)}))
+    p_dims = constant(func, p_dims, Typespec.tensor({:u, 64}, {length(p_dims)}))
+    l_dims = constant(func, l_dims, Typespec.tensor({:u, 64}, {length(l_dims)}))
+    u_dims = constant(func, u_dims, Typespec.tensor({:u, 64}, {length(u_dims)}))
+    operands = [value, dim_sizes, operand_dims, p_dims, l_dims, u_dims]
+
+    # Force P to always be u8 to avoid requiring too many template instances during custom_call registration
+    p_result_type = type_tensor({:u, 8}, p_shape)
+    l_result_type = type_tensor(l_type, l_shape)
+    u_result_type = type_tensor(u_type, u_shape)
+    result_types = [type_tuple([p_result_type, l_result_type, u_result_type])]
+
+    call_target_name =
+      case op_type do
+        {:f, 32} ->
+          "lu_cpu_custom_call_f32"
+
+        {:f, 64} ->
+          "lu_cpu_custom_call_f64"
+
+        {:f, 16} ->
+          "lu_cpu_custom_call_f16"
+
+        {:bf, 16} ->
+          "lu_cpu_custom_call_bf16"
+
+        type ->
+          # Due to matching on EXLA.Defn, we are sure that the device here is always :host
+          raise "LU decomposition not supported on :host device for type #{inspect(type)}"
+      end
+
+    attributes = [
+      call_target_name: attr_string(call_target_name),
+      backend_config: attr_string("Host")
+    ]
+
+    result =
+      op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!()
+
+    # This is not the best approach, but the alternative would require many more template instances
+    u8_typespec = Typespec.to_type(p_typespec, {:u, 8})
+    p = get_tuple_element(result, 0, u8_typespec)
+
+    p =
+      if u8_typespec != p_typespec do
+        convert(p, p_typespec)
+      else
+        p
+      end
+
+    l = get_tuple_element(result, 1, l_typespec)
+    u = get_tuple_element(result, 2, u_typespec)
+
+    {p, l, u}
+  end
+
   def get_tuple_element(%Value{function: func} = operand, index, typespec) do
     result_types = typespecs_to_mlir_types([typespec])
     attributes = [index: attr_i32(index)]

exla/mix.exs

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ defmodule EXLA.MixProject do
       File.rm_rf!("cache/#{@version}/libexla.so")
 
       Mix.shell().info("Removing libexla.so cache at #{cached_so}")
-      File.rm!(cached_so)
+      File.rm_rf!(cached_so)
     end
 
     if cached? do
