Skip to content

Commit e4f9134

Browse files
authored
Merge pull request #331 from dskkato/eager_api_wrappers
Implement Eager api wrappers for Context and TensorHandle
2 parents de27f4e + 145a11d commit e4f9134

16 files changed

+16065
-0
lines changed

src/eager.rs

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//! C API extensions to experiment with eager execution of kernels.
2+
//!
3+
//! WARNING: The underlying C-API for the eager execution is not guaranteed to be
4+
//! stable and can be changed without notice, which could result in breaking.
5+
6+
mod context;
7+
pub use context::*;
8+
9+
mod tensor_handle;
10+
pub use tensor_handle::*;

src/eager/context.rs

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
use std::ffi::CStr;
2+
3+
use tensorflow_sys as tf;
4+
5+
use crate::{Device, Result, Status};
6+
7+
/// Options that can be passed during context creation.
8+
#[derive(Debug)]
9+
pub struct ContextOptions {
10+
inner: *mut tf::TFE_ContextOptions,
11+
}
12+
impl_new!(
13+
ContextOptions,
14+
TFE_NewContextOptions,
15+
"Creates a blank set of context options."
16+
);
17+
impl_drop!(ContextOptions, TFE_DeleteContextOptions);
18+
19+
impl ContextOptions {
20+
/// Set the config.
21+
///
22+
/// `config` should be a serialized [`ConfigProto` proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto).
23+
/// Returns an error if config was not parsed successfully as a `ConfigProto`.
24+
pub fn set_config(&mut self, config: &[u8]) -> Result<()> {
25+
let mut status = Status::new();
26+
unsafe {
27+
tf::TFE_ContextOptionsSetConfig(
28+
self.inner,
29+
config.as_ptr() as *const _,
30+
config.len(),
31+
status.inner(),
32+
);
33+
}
34+
status.into_result()
35+
}
36+
37+
/// Sets the default execution mode (sync/async).
38+
pub fn set_async(&mut self, enable: bool) {
39+
unsafe {
40+
tf::TFE_ContextOptionsSetAsync(self.inner, enable as u8);
41+
}
42+
}
43+
}
44+
45+
/// Context under which operations/functions are executed.
46+
#[derive(Debug)]
47+
pub struct Context {
48+
pub(crate) inner: *mut tf::TFE_Context,
49+
}
50+
impl_drop!(Context, TFE_DeleteContext);
51+
52+
impl Context {
53+
/// Create a Context
54+
pub fn new(opts: ContextOptions) -> Result<Self> {
55+
let status = Status::new();
56+
57+
let inner = unsafe { tf::TFE_NewContext(opts.inner, status.inner) };
58+
if inner.is_null() {
59+
Err(status)
60+
} else {
61+
Ok(Context { inner })
62+
}
63+
}
64+
65+
/// Lists all devices in a context.
66+
pub fn device_list(&self) -> Result<Vec<Device>> {
67+
let status = Status::new();
68+
unsafe {
69+
let list = tf::TFE_ContextListDevices(self.inner, status.inner);
70+
if !status.is_ok() {
71+
return Err(status);
72+
}
73+
let result = (|| {
74+
let n = tf::TF_DeviceListCount(list);
75+
let mut devices = Vec::with_capacity(n as usize);
76+
for i in 0..n {
77+
let c_name = tf::TF_DeviceListName(list, i, status.inner);
78+
if !status.is_ok() {
79+
return Err(status);
80+
}
81+
let c_type = tf::TF_DeviceListType(list, i, status.inner);
82+
if !status.is_ok() {
83+
return Err(status);
84+
}
85+
let bytes = tf::TF_DeviceListMemoryBytes(list, i, status.inner);
86+
if !status.is_ok() {
87+
return Err(status);
88+
}
89+
let incarnation = tf::TF_DeviceListIncarnation(list, i, status.inner);
90+
if !status.is_ok() {
91+
return Err(status);
92+
}
93+
devices.push(Device {
94+
name: CStr::from_ptr(c_name).to_str()?.to_string(),
95+
device_type: CStr::from_ptr(c_type).to_str()?.to_string(),
96+
memory_bytes: bytes,
97+
incarnation,
98+
});
99+
}
100+
Ok(devices)
101+
})();
102+
tf::TF_DeleteDeviceList(list);
103+
result
104+
}
105+
}
106+
107+
/// Clears the internal caches in the context.
108+
pub fn clear_caches(&mut self) {
109+
unsafe {
110+
tf::TFE_ContextClearCaches(self.inner);
111+
}
112+
}
113+
}
114+
115+
unsafe impl std::marker::Send for Context {}
116+
unsafe impl std::marker::Sync for Context {}
117+
118+
#[cfg(test)]
119+
mod test {
120+
use super::*;
121+
122+
#[test]
123+
fn test_create_context() {
124+
let opts = ContextOptions::new();
125+
Context::new(opts).unwrap();
126+
}
127+
128+
#[test]
129+
fn test_create_async_context() {
130+
let mut opts = ContextOptions::new();
131+
opts.set_async(true);
132+
Context::new(opts).unwrap();
133+
}
134+
135+
#[test]
136+
fn test_context_set_config() {
137+
use crate::protos::config::{ConfigProto, GPUOptions};
138+
use protobuf::Message;
139+
140+
let gpu_options = GPUOptions {
141+
per_process_gpu_memory_fraction: 0.5,
142+
allow_growth: true,
143+
..Default::default()
144+
};
145+
let mut config = ConfigProto::new();
146+
config.set_gpu_options(gpu_options);
147+
148+
let mut buf = vec![];
149+
config.write_to_writer(&mut buf).unwrap();
150+
151+
let mut opts = ContextOptions::new();
152+
opts.set_config(&buf).unwrap();
153+
Context::new(opts).unwrap();
154+
}
155+
156+
#[test]
157+
fn test_device_list() {
158+
let opts = ContextOptions::new();
159+
let ctx = Context::new(opts).unwrap();
160+
161+
let devices = ctx.device_list().unwrap();
162+
for d in &devices {
163+
assert_ne!(String::from(""), d.name);
164+
}
165+
}
166+
}

0 commit comments

Comments
 (0)