Skip to content

Commit ae92c63

Browse files
authored
Merge pull request #402 from dskkato/add_pluggable_device
add PluggableDeviceLibrary
2 parents 33977d3 + 5d7572d commit ae92c63

File tree

8 files changed

+88
-7
lines changed

8 files changed

+88
-7
lines changed

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ doctest = false
1717

1818
# Prevent downloading or building TensorFlow when building docs on docs.rs.
1919
[package.metadata.docs.rs]
20-
features = ["private-docs-rs", "tensorflow_unstable", "ndarray", "eager"]
20+
features = ["private-docs-rs", "tensorflow_unstable", "ndarray", "eager", "experimental"]
2121

2222
[dependencies]
2323
libc = "0.2.132"
@@ -41,6 +41,7 @@ tempdir = "0.3"
4141

4242
[features]
4343
default = ["tensorflow-sys"]
44+
experimental = ["tensorflow-sys/experimental"]
4445
tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"]
4546
tensorflow_unstable = []
4647
tensorflow_runtime_linking = ["tensorflow-sys-runtime"]

src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ use tensorflow_sys as tf;
5454
#[cfg(feature = "tensorflow_runtime_linking")]
5555
use tensorflow_sys_runtime as tf;
5656

57+
#[cfg(feature = "experimental")]
58+
mod pluggable_device;
59+
#[cfg(feature = "experimental")]
60+
pub use pluggable_device::*;
61+
5762
////////////////////////
5863

5964
/// Will panic if `msg` contains an embedded 0 byte.

src/pluggable_device.rs

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use crate::{Result, Status};
2+
use std::ffi::CString;
3+
use tensorflow_sys as tf;
4+
5+
/// PluggableDeviceLibrary handler.
6+
#[derive(Debug)]
7+
pub struct PluggableDeviceLibrary {
8+
inner: *mut tf::TF_Library,
9+
}
10+
11+
impl PluggableDeviceLibrary {
12+
/// Load the library specified by library_filename and register the pluggable
13+
/// device and related kernels present in that library. This function is not
14+
/// supported on embedded on mobile and embedded platforms and will fail if
15+
/// called.
16+
///
17+
/// Pass "library_filename" to a platform-specific mechanism for dynamically
18+
/// loading a library. The rules for determining the exact location of the
19+
/// library are platform-specific and are not documented here.
20+
pub fn load(library_filename: &str) -> Result<PluggableDeviceLibrary> {
21+
let status = Status::new();
22+
let library_filename = CString::new(library_filename)?;
23+
let lib_handle =
24+
unsafe { tf::TF_LoadPluggableDeviceLibrary(library_filename.as_ptr(), status.inner) };
25+
status.into_result()?;
26+
27+
Ok(PluggableDeviceLibrary { inner: lib_handle })
28+
}
29+
}
30+
31+
impl Drop for PluggableDeviceLibrary {
32+
/// Frees the memory associated with the library handle.
33+
/// Does NOT unload the library.
34+
fn drop(&mut self) {
35+
unsafe {
36+
tf::TF_DeletePluggableDeviceLibraryHandle(self.inner);
37+
}
38+
}
39+
}
40+
41+
#[cfg(test)]
42+
mod tests {
43+
use super::*;
44+
45+
#[ignore]
46+
#[test]
47+
fn load_pluggable_device_library() {
48+
let library_filename = "path-to-library";
49+
let pluggable_divice_library = PluggableDeviceLibrary::load(library_filename);
50+
dbg!(&pluggable_divice_library);
51+
assert!((pluggable_divice_library.is_ok()));
52+
}
53+
}

tensorflow-sys/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ zip = "0.6.4"
3636
[features]
3737
tensorflow_gpu = []
3838
eager = []
39+
experimental = []
3940
# This is for testing purposes; users should not use this.
4041
examples_system_alloc = []
4142
private-docs-rs = [] # DO NOT RELY ON THIS

tensorflow-sys/generate_bindgen_rs.sh

+12-4
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,22 @@ if ! which bindgen > /dev/null; then
77
exit 1
88
fi
99

10-
include_dir="$HOME/git/tensorflow"
10+
include_dir="../../tensorflow"
1111

12+
# Export C-API
1213
bindgen_options_c_api="--allowlist-function TF_.+ --allowlist-type TF_.+ --allowlist-var TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions"
13-
cmd="bindgen ${bindgen_options_c_api} ${include_dir}/tensorflow/c/c_api.h --output src/c_api.rs -- -I ${include_dir}"
14+
cmd="bindgen ${bindgen_options_c_api} ${include_dir}/tensorflow/c/c_api.h --output src/c_api.rs -- -I ${include_dir}"
1415
echo ${cmd}
1516
${cmd}
1617

17-
bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --allowlist-var TFE_.+ --blocklist-type TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions"
18-
cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}"
18+
# Export PluggableDeviceLibrary from C-API experimental
19+
bindgen_options_c_api_experimental="--allowlist-function TF_.+PluggableDeviceLibrary.* --blocklist-type TF_.+ --size_t-is-usize"
20+
cmd="bindgen ${bindgen_options_c_api_experimental} ${include_dir}/tensorflow/c/c_api_experimental.h --output src/c_api_experimental.rs -- -I ${include_dir}"
21+
echo ${cmd}
22+
${cmd}
23+
24+
# Export Eager C-API
25+
bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --allowlist-var TFE_.+ --blocklist-type TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions --no-layout-tests"
26+
cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}"
1927
echo ${cmd}
2028
${cmd}
+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
/* automatically generated by rust-bindgen 0.59.1 */
2+
3+
extern "C" {
4+
pub fn TF_LoadPluggableDeviceLibrary(
5+
library_filename: *const ::std::os::raw::c_char,
6+
status: *mut TF_Status,
7+
) -> *mut TF_Library;
8+
}
9+
extern "C" {
10+
pub fn TF_DeletePluggableDeviceLibraryHandle(lib_handle: *mut TF_Library);
11+
}

tensorflow-sys/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ mod eager;
77
#[cfg(feature = "eager")]
88
pub use eager::*;
99
include!("c_api.rs");
10+
#[cfg(feature = "experimental")]
11+
include!("c_api_experimental.rs");
1012

1113
pub use crate::TF_AttrType::*;
1214
pub use crate::TF_Code::*;

test-all

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ run cargo run --example regression
5858
run cargo run --example xor
5959
run cargo run --features tensorflow_unstable --example expressions
6060
run cargo run --features eager --example mobilenetv3
61-
run cargo doc -vv --features tensorflow_unstable,ndarray,eager
62-
run cargo doc -vv --features tensorflow_unstable,ndarray,eager,private-docs-rs
61+
run cargo doc -vv --features experimental,tensorflow_unstable,ndarray,eager
62+
run cargo doc -vv --features experimental,tensorflow_unstable,ndarray,eager,private-docs-rs
6363
# TODO(#66): Re-enable: (cd tensorflow-sys && cargo test -vv -j 1)
6464
(cd tensorflow-sys && run cargo run --example multiplication)
6565
(cd tensorflow-sys && run cargo run --example tf_version)

0 commit comments

Comments
 (0)