diff --git a/R/RcppExports.R b/R/RcppExports.R index fedebf5537..cfd249f25f 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -11205,6 +11205,14 @@ cpp_jit_compile_get_function <- function(cu, name) { .Call('_torch_cpp_jit_compile_get_function', PACKAGE = 'torchpkg', cu, name) } +cpp_jit_get_all_operators_names <- function() { + .Call('_torch_cpp_jit_get_all_operators_names', PACKAGE = 'torchpkg') +} + +cpp_jit_get_operator_from_name <- function(x) { + .Call('_torch_cpp_jit_get_operator_from_name', PACKAGE = 'torchpkg', x) +} + cpp_lantern_configure <- function(log) { invisible(.Call('_torch_cpp_lantern_configure', PACKAGE = 'torchpkg', log)) } diff --git a/inst/include/lantern/lantern.h b/inst/include/lantern/lantern.h index be49c0e226..c9047b139b 100644 --- a/inst/include/lantern/lantern.h +++ b/inst/include/lantern/lantern.h @@ -2288,6 +2288,33 @@ HOST_API int lantern_string_size (void* self) return ret; } +LANTERN_API void* (LANTERN_PTR _lantern_jit_get_all_operators_names) (); +HOST_API void* lantern_jit_get_all_operators_names () +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_jit_get_all_operators_names(); + LANTERN_HOST_HANDLER; + return ret; +} + +LANTERN_API void* (LANTERN_PTR _lantern_jit_get_operation_schema) (void* name); +HOST_API void* lantern_jit_get_operation_schema (void* name) +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_jit_get_operation_schema(name); + LANTERN_HOST_HANDLER; + return ret; +} + +LANTERN_API void* (LANTERN_PTR _lantern_jit_FunctionSchema_name) (void* schema); +HOST_API void* lantern_jit_FunctionSchema_name (void* schema) +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_jit_FunctionSchema_name(schema); + LANTERN_HOST_HANDLER; + return ret; +} + /* Autogen Headers -- Start */ LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking); HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; } @@ -8321,6 +8348,9 @@ LOAD_SYMBOL(_lantern_cuda_device_stats); LOAD_SYMBOL(_lantern_cuda_get_runtime_version); LOAD_SYMBOL(_set_delete_lambda_fun); LOAD_SYMBOL(_lantern_string_size); +LOAD_SYMBOL(_lantern_jit_get_all_operators_names); +LOAD_SYMBOL(_lantern_jit_get_operation_schema); +LOAD_SYMBOL(_lantern_jit_FunctionSchema_name); /* Autogen Symbols -- Start */ LOAD_SYMBOL(_lantern__cast_byte_tensor_bool) LOAD_SYMBOL(_lantern__cast_char_tensor_bool) diff --git a/inst/include/lantern/types.h b/inst/include/lantern/types.h index f54e09bee5..b67d96ea65 100644 --- a/inst/include/lantern/types.h +++ b/inst/include/lantern/types.h @@ -74,6 +74,7 @@ void* bool_t(const bool& x); void* double_t(const double& x); void* Stream(const at::Stream& x); void* IValue(const torch::IValue& x); +void* FunctionSchema (const c10::FunctionSchema& x); namespace vector { void* string(const std::vector& x); @@ -147,6 +148,7 @@ LANTERN_FROM_RAW_DECL(bool_t, bool) LANTERN_FROM_RAW_DECL(double_t, double) LANTERN_FROM_RAW_DECL(Stream, at::Stream) LANTERN_FROM_RAW_DECL(IValue, torch::IValue) +LANTERN_FROM_RAW_DECL(FunctionSchema, c10::FunctionSchema) namespace optional { LANTERN_FROM_RAW_DECL(DimnameList, c10::optional) @@ -398,6 +400,8 @@ void* double_t(const double& x) { return make_ptr(x); } void* bool_t(const bool& x) { return make_ptr(x); } void* Stream(const at::Stream& x) { return make_ptr(x); } void* IValue(const at::IValue& x) { return make_ptr(x); } +void* FunctionSchema (const c10::FunctionSchema& x) { return make_ptr(x); } + namespace vector { @@ -527,6 +531,7 @@ LANTERN_FROM_RAW(bool_t, bool) LANTERN_FROM_RAW(double_t, double) LANTERN_FROM_RAW(Stream, at::Stream) LANTERN_FROM_RAW(IValue, torch::IValue) +LANTERN_FROM_RAW(FunctionSchema, c10::FunctionSchema) namespace optional { LANTERN_FROM_RAW_WRAPPED(DimnameList, self_contained::optional::DimnameList, diff --git a/lantern/include/lantern/lantern.h b/lantern/include/lantern/lantern.h index be49c0e226..c9047b139b 100644 --- a/lantern/include/lantern/lantern.h +++ b/lantern/include/lantern/lantern.h @@ -2288,6 +2288,33 @@ HOST_API int lantern_string_size (void* self) return ret; } +LANTERN_API void* (LANTERN_PTR _lantern_jit_get_all_operators_names) (); +HOST_API void* lantern_jit_get_all_operators_names () +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_jit_get_all_operators_names(); + LANTERN_HOST_HANDLER; + return ret; +} + +LANTERN_API void* (LANTERN_PTR _lantern_jit_get_operation_schema) (void* name); +HOST_API void* lantern_jit_get_operation_schema (void* name) +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_jit_get_operation_schema(name); + LANTERN_HOST_HANDLER; + return ret; +} + +LANTERN_API void* (LANTERN_PTR _lantern_jit_FunctionSchema_name) (void* schema); +HOST_API void* lantern_jit_FunctionSchema_name (void* schema) +{ + LANTERN_CHECK_LOADED + void* ret = _lantern_jit_FunctionSchema_name(schema); + LANTERN_HOST_HANDLER; + return ret; +} + /* Autogen Headers -- Start */ LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking); HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; } @@ -8321,6 +8348,9 @@ LOAD_SYMBOL(_lantern_cuda_device_stats); LOAD_SYMBOL(_lantern_cuda_get_runtime_version); LOAD_SYMBOL(_set_delete_lambda_fun); LOAD_SYMBOL(_lantern_string_size); +LOAD_SYMBOL(_lantern_jit_get_all_operators_names); +LOAD_SYMBOL(_lantern_jit_get_operation_schema); +LOAD_SYMBOL(_lantern_jit_FunctionSchema_name); /* Autogen Symbols -- Start */ LOAD_SYMBOL(_lantern__cast_byte_tensor_bool) LOAD_SYMBOL(_lantern__cast_char_tensor_bool) diff --git a/lantern/include/lantern/types.h b/lantern/include/lantern/types.h index f54e09bee5..b67d96ea65 100644 --- a/lantern/include/lantern/types.h +++ b/lantern/include/lantern/types.h @@ -74,6 +74,7 @@ void* bool_t(const bool& x); void* double_t(const double& x); void* Stream(const at::Stream& x); void* IValue(const torch::IValue& x); +void* FunctionSchema (const c10::FunctionSchema& x); namespace vector { void* string(const std::vector& x); @@ -147,6 +148,7 @@ LANTERN_FROM_RAW_DECL(bool_t, bool) LANTERN_FROM_RAW_DECL(double_t, double) LANTERN_FROM_RAW_DECL(Stream, at::Stream) LANTERN_FROM_RAW_DECL(IValue, torch::IValue) +LANTERN_FROM_RAW_DECL(FunctionSchema, c10::FunctionSchema) namespace optional { LANTERN_FROM_RAW_DECL(DimnameList, c10::optional) @@ -398,6 +400,8 @@ void* double_t(const double& x) { return make_ptr(x); } void* bool_t(const bool& x) { return make_ptr(x); } void* Stream(const at::Stream& x) { return make_ptr(x); } void* IValue(const at::IValue& x) { return make_ptr(x); } +void* FunctionSchema (const c10::FunctionSchema& x) { return make_ptr(x); } + namespace vector { @@ -527,6 +531,7 @@ LANTERN_FROM_RAW(bool_t, bool) LANTERN_FROM_RAW(double_t, double) LANTERN_FROM_RAW(Stream, at::Stream) LANTERN_FROM_RAW(IValue, torch::IValue) +LANTERN_FROM_RAW(FunctionSchema, c10::FunctionSchema) namespace optional { LANTERN_FROM_RAW_WRAPPED(DimnameList, self_contained::optional::DimnameList, diff --git a/lantern/src/Compile.cpp b/lantern/src/Compile.cpp index 6b01bca232..680315a031 100644 --- a/lantern/src/Compile.cpp +++ b/lantern/src/Compile.cpp @@ -30,4 +30,32 @@ void* _lantern_jit_compile_get_method(void* cu, void* name) { auto name_ = from_raw::string(name); return (void*)from_raw::CompilationUnit(cu).find_function(name_); LANTERN_FUNCTION_END -} \ No newline at end of file +} + +void * _lantern_jit_get_all_operators_names () { + LANTERN_FUNCTION_START + auto ops = torch::jit::getAllOperators(); + std::vector names; + for (const auto& op : ops) { + names.push_back(op->schema().name()); + } + return make_raw::vector::string(names); + LANTERN_FUNCTION_END +} + +void* _lantern_jit_get_operation_schema (void* name) { + LANTERN_FUNCTION_START + auto name_ = from_raw::string(name); + auto op_name = c10::Symbol::fromQualString(name_); + auto op = torch::jit::getAllOperatorsFor(op_name); + return make_raw::FunctionSchema(op[0]->schema()); + LANTERN_FUNCTION_END +} + +void* _lantern_jit_FunctionSchema_name (void* schema) { + auto schema_ = from_raw::FunctionSchema(schema); + return make_raw::string(schema_.name()); +} + +// https://cs.github.com/pytorch/pytorch/blob/47834679ba2f869e66450a74e2add4c04f0006e9/torch/csrc/jit/python/pybind_utils.h#L874 +// https://cs.github.com/pytorch/pytorch/blob/47834679ba2f869e66450a74e2add4c04f0006e9/torch/csrc/jit/python/pybind_utils.h#L1137 \ No newline at end of file diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index d329c81726..091a2ba6a2 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -36539,6 +36539,27 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// cpp_jit_get_all_operators_names +torch::vector::string cpp_jit_get_all_operators_names(); +RcppExport SEXP _torch_cpp_jit_get_all_operators_names() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(cpp_jit_get_all_operators_names()); + return rcpp_result_gen; +END_RCPP +} +// cpp_jit_get_operator_from_name +torch::string cpp_jit_get_operator_from_name(torch::string x); +RcppExport SEXP _torch_cpp_jit_get_operator_from_name(SEXP xSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< torch::string >::type x(xSEXP); + rcpp_result_gen = Rcpp::wrap(cpp_jit_get_operator_from_name(x)); + return rcpp_result_gen; +END_RCPP +} // cpp_lantern_configure void cpp_lantern_configure(int log); RcppExport SEXP _torch_cpp_lantern_configure(SEXP logSEXP) { @@ -40597,6 +40618,8 @@ static const R_CallMethodDef CallEntries[] = { {"_torch_cpp_jit_compile", (DL_FUNC) &_torch_cpp_jit_compile, 1}, {"_torch_cpp_jit_compile_list_methods", (DL_FUNC) &_torch_cpp_jit_compile_list_methods, 1}, {"_torch_cpp_jit_compile_get_function", (DL_FUNC) &_torch_cpp_jit_compile_get_function, 2}, + {"_torch_cpp_jit_get_all_operators_names", (DL_FUNC) &_torch_cpp_jit_get_all_operators_names, 0}, + {"_torch_cpp_jit_get_operator_from_name", (DL_FUNC) &_torch_cpp_jit_get_operator_from_name, 1}, {"_torch_cpp_lantern_configure", (DL_FUNC) &_torch_cpp_lantern_configure, 1}, {"_torch_cpp_lantern_version", (DL_FUNC) &_torch_cpp_lantern_version, 0}, {"_torch_cpp_lantern_init", (DL_FUNC) &_torch_cpp_lantern_init, 1}, diff --git a/src/jit-compile.cpp b/src/jit-compile.cpp index 7adb6ec740..307475c9e3 100644 --- a/src/jit-compile.cpp +++ b/src/jit-compile.cpp @@ -29,3 +29,13 @@ SEXP cpp_jit_compile_get_function(SEXP cu, XPtrTorchstring name) { return R_NilValue; } } + +// [[Rcpp::export]] +torch::vector::string cpp_jit_get_all_operators_names () { + return lantern_jit_get_all_operators_names(); +} + +// [[Rcpp::export]] +torch::string cpp_jit_get_operator_from_name (torch::string x) { + return lantern_jit_FunctionSchema_name(lantern_jit_get_operation_schema(x.get())); +} diff --git a/src/tensor.cpp b/src/tensor.cpp index b2a38ece2b..bc3cad7595 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -128,6 +128,12 @@ torch::Tensor torch_tensor_cpp(SEXP x, Rcpp::Nullable dtype, break; } } + case NILSXP: { + cdtype = lantern_Dtype_bool(); + final_type = dtype.isNull() ? torch::Dtype(lantern_Dtype_bool()) + : Rcpp::as(dtype); + break; + } default: { Rcpp::stop("R type not handled"); } diff --git a/tests/testthat/test-indexing.R b/tests/testthat/test-indexing.R index efd8f932dc..c3161a1979 100644 --- a/tests/testthat/test-indexing.R +++ b/tests/testthat/test-indexing.R @@ -250,3 +250,15 @@ test_that("regression test for #695", { as.array(a)[c(1, 3), , c(1, 3)] ) }) + +test_that("NULL tensor", { + + x <- torch_tensor(NULL) + expect_true(x$dtype == torch_bool()) + expect_equal(x$shape, 0) + + # subsetting shouldn't crash + expect_error(x[1], regexp = "out of bounds") + expect_error(torch_tensor(as.integer(NULL))[1], regexp = "out of bounds") + +}) diff --git a/tools/create-decls.R b/tools/create-decls.R index 97fb8c9a18..dff470dd6d 100644 --- a/tools/create-decls.R +++ b/tools/create-decls.R @@ -30,8 +30,8 @@ make_load_symbols <- function(decls) { decls <- readr::read_lines( " -void _lantern_autograd_edge_list_delete (void* x) -void _lantern_autograd_edge_delete (void* x) +void* _lantern_jit_get_operation_schema (void* name) +void* _lantern_jit_FunctionSchema_name (void* schema) " )