Skip to content

Commit 479dccc

Browse files
committed
Fix ctypes, don't close fd, use smart pointers
1 parent 0a124ef commit 479dccc

File tree

6 files changed

+58
-29
lines changed

6 files changed

+58
-29
lines changed

src/python/library/CMakeLists.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,10 @@ else()
128128

129129
if (${TRITON_ENABLE_PERF_ANALYZER})
130130
set(perf_analyzer_arg --perf-analyzer ${CMAKE_INSTALL_PREFIX}/bin/perf_analyzer)
131-
endif()
131+
endif() # TRITON_ENABLE_PERF_ANALYZER
132+
if (${TRITON_ENABLE_GPU})
133+
set(gpu_arg --include-gpu-libs)
134+
endif() # TRITON_ENABLE_GPU
132135
set(linux_wheel_stamp_file "linux_stamp.whl")
133136
add_custom_command(
134137
OUTPUT "${linux_wheel_stamp_file}"
@@ -138,6 +141,7 @@ else()
138141
--dest-dir "${CMAKE_CURRENT_BINARY_DIR}/linux"
139142
--linux
140143
${perf_analyzer_arg}
144+
${gpu_arg}
141145
DEPENDS ${LINUX_WHEEL_DEPENDS}
142146
)
143147

@@ -178,7 +182,14 @@ if(${TRITON_ENABLE_PYTHON_GRPC})
178182
)
179183
endif() # TRITON_ENABLE_PYTHON_GRPC
180184

185+
# Generic Wheel
186+
set(WHEEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/generic")
187+
install(
188+
CODE "file(GLOB _Wheel \"${WHEEL_DIR}/triton*.whl\")"
189+
CODE "file(INSTALL \${_Wheel} DESTINATION \"${CMAKE_INSTALL_PREFIX}/python\")"
190+
)
181191

192+
# Platform-specific wheels
182193
if(WIN32)
183194
set(WHEEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/windows")
184195
else()

src/python/library/build_wheel.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ def sed(pattern, replace, source, dest=None):
9191
required=False,
9292
help="Include windows specific artifacts.",
9393
)
94+
parser.add_argument(
95+
"--include-gpu-libs",
96+
action="store_true",
97+
required=False,
98+
help="Include gpu specific libraries",
99+
)
94100
parser.add_argument(
95101
"--perf-analyzer",
96102
type=str,
@@ -186,10 +192,11 @@ def sed(pattern, replace, source, dest=None):
186192
"tritonclient/utils/libcshm.so",
187193
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory/libcshm.so"),
188194
)
189-
cpdir(
190-
"tritonclient/utils/cuda_shared_memory",
191-
os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
192-
)
195+
if FLAGS.include_gpu_libs:
196+
cpdir(
197+
"tritonclient/utils/cuda_shared_memory",
198+
os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
199+
)
193200

194201
# Copy the pre-compiled perf_analyzer binary
195202
if FLAGS.perf_analyzer is not None:
@@ -212,10 +219,11 @@ def sed(pattern, replace, source, dest=None):
212219
os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory/cshm.dll"),
213220
)
214221
# FIXME: Enable when Windows supports GPU tensors DLIS-4169
215-
# cpdir(
216-
# "tritonclient/utils/cuda_shared_memory",
217-
# os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
218-
# )
222+
# if FLAGS.include_gpu_libs:
223+
# cpdir(
224+
# "tritonclient/utils/cuda_shared_memory",
225+
# os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"),
226+
# )
219227

220228
shutil.copyfile("LICENSE.txt", os.path.join(FLAGS.whl_dir, "LICENSE.txt"))
221229
shutil.copyfile("setup.py", os.path.join(FLAGS.whl_dir, "setup.py"))

src/python/library/tritonclient/utils/shared_memory/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,10 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
205205
The numpy array generated using the contents of the specified shared
206206
memory region.
207207
"""
208-
shm_file = c_void_p()
208+
# Safe initializer for the Unix case: the C library dereferences shm_file to
209+
# get an int* and stores the file descriptor through it, so it must not be NULL.
210+
safe_initializer = c_int(-1)
211+
shm_file = cast(byref(safe_initializer), c_void_p)
209212
region_offset = c_uint64()
210213
byte_size = c_uint64()
211214
shm_addr = c_char_p()
@@ -284,8 +287,10 @@ def destroy_shared_memory_region(shm_handle):
284287
SharedMemoryException
285288
If unable to unlink the shared memory region.
286289
"""
287-
288-
shm_file = c_void_p()
290+
# Safe initializer for the Unix case: the C library dereferences shm_file to
291+
# get an int* and stores the file descriptor through it, so it must not be NULL.
292+
safe_initializer = c_int(-1)
293+
shm_file = cast(byref(safe_initializer), c_void_p)
289294
offset = c_uint64()
290295
byte_size = c_uint64()
291296
shm_addr = c_char_p()

src/python/library/tritonclient/utils/shared_memory/shared_memory.cc

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ SharedMemoryHandleCreate(
5252
handle->triton_shm_name_ = triton_shm_name;
5353
handle->base_addr_ = shm_addr;
5454
handle->shm_key_ = shm_key;
55-
handle->platform_handle_ = new ShmFile(shm_file);
55+
handle->platform_handle_ = std::make_unique<ShmFile>(shm_file);
5656
handle->offset_ = offset;
5757
handle->byte_size_ = byte_size;
5858
return static_cast<void*>(handle);
@@ -97,8 +97,7 @@ SharedMemoryRegionMap(
9797
return -1;
9898
}
9999

100-
// close shared memory descriptor, return 0 if success else return -1
101-
return close(fd);
100+
return 0;
102101
#endif
103102
}
104103

@@ -119,29 +118,29 @@ SharedMemoryRegionCreate(
119118
DWORD high_order_size = (upperbound_size >> 32) & 0xFFFFFFFF;
120119
DWORD low_order_size = upperbound_size & 0xFFFFFFFF;
121120

122-
HANDLE local_handle = CreateFileMapping(
121+
HANDLE shm_file = CreateFileMapping(
123122
INVALID_HANDLE_VALUE, // use paging file
124123
NULL, // default security
125124
PAGE_READWRITE, // read/write access
126125
high_order_size, // maximum object size (high-order DWORD)
127126
low_order_size, // maximum object size (low-order DWORD)
128127
shm_key); // name of mapping object
129128

130-
if (local_handle == NULL) {
129+
if (shm_file == NULL) {
131130
return -7;
132131
}
133132

134133
// get base address of shared memory region
135134
void* shm_addr = nullptr;
136-
int err = SharedMemoryRegionMap((void*)local_handle, 0, byte_size, &shm_addr);
135+
int err = SharedMemoryRegionMap((void*)shm_file, 0, byte_size, &shm_addr);
137136
if (err == -1) {
138137
return -4;
139138
}
140139

141140
// create a handle for the shared memory region
142141
*shm_handle = SharedMemoryHandleCreate(
143142
std::string(triton_shm_name), shm_addr, std::string(shm_key),
144-
(void*)local_handle, 0, byte_size);
143+
(void*)shm_file, 0, byte_size);
145144
#else
146145
// get shared memory region descriptor
147146
int shm_fd = shm_open(shm_key, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
@@ -188,12 +187,12 @@ GetSharedMemoryHandleInfo(
188187
#ifdef _WIN32
189188
HANDLE* file = static_cast<HANDLE*>(shm_file);
190189
#else
191-
int* file = *static_cast<int**>(shm_file);
190+
int* file = *reinterpret_cast<int**>(shm_file);
192191
#endif // _WIN32
193192
SharedMemoryHandle* handle = static_cast<SharedMemoryHandle*>(shm_handle);
194193
*shm_addr = static_cast<char*>(handle->base_addr_);
195194
*shm_key = handle->shm_key_.c_str();
196-
*file = handle->platform_handle_->shm_file_;
195+
*file = *(handle->platform_handle_->GetShmFile());
197196
*offset = handle->offset_;
198197
*byte_size = handle->byte_size_;
199198
return 0;
@@ -213,7 +212,7 @@ SharedMemoryRegionDestroy(void* shm_handle)
213212
// We keep Windows shared memory handles open until we are done
214213
// using them. When all handles are closed, the system will free
215214
// the section of the paging file that the object uses.
216-
CloseHandle(handle->platform_handle_->shm_file_);
215+
CloseHandle(*(handle->platform_handle_->GetShmFile()));
217216
#else
218217
int status = munmap(shm_addr, handle->byte_size_);
219218
if (status == -1) {
@@ -224,11 +223,11 @@ SharedMemoryRegionDestroy(void* shm_handle)
224223
if (shm_fd == -1) {
225224
return -5;
226225
}
226+
close(*(handle->platform_handle_->GetShmFile()));
227227
#endif // _WIN32
228228

229-
// FIXME: Investigate use of smart pointers for these
230-
// allocations instead
231-
delete handle->platform_handle_;
229+
// FIXME: Investigate use of smart pointers for this
230+
// allocation instead
232231
delete handle;
233232

234233
return 0;

src/python/library/tritonclient/utils/shared_memory/shared_memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ extern "C" {
3232
#ifdef _WIN32
3333
#define TRITONCLIENT_DECLSPEC __declspec(dllexport)
3434
#else
35-
define TRITONCLIENT_DECLSPEC
35+
#define TRITONCLIENT_DECLSPEC
3636
#endif
3737

3838
//==============================================================================

src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,28 @@
3333
#ifdef _WIN32
3434
#include <windows.h>
3535
#endif // _WIN32
36+
#include <memory>
3637

3738
struct ShmFile {
3839
#ifdef _WIN32
3940
HANDLE shm_file_;
4041
ShmFile(void* shm_file) { shm_file_ = static_cast<HANDLE>(shm_file); };
42+
HANDLE* GetShmFile() { return &shm_file_; };
4143
#else
42-
int shm_file_;
43-
ShmFile(int shm_file) { shm_file_ = *static_cast<int*>(shm_file); };
44+
std::unique_ptr<int> shm_file_;
45+
ShmFile(void* shm_file)
46+
{
47+
shm_file_ = std::make_unique<int>(*static_cast<int*>(shm_file));
48+
};
49+
int* GetShmFile() { return shm_file_.get(); }
4450
#endif // _WIN32
4551
};
4652

4753
struct SharedMemoryHandle {
4854
std::string triton_shm_name_;
4955
std::string shm_key_;
5056
void* base_addr_;
51-
ShmFile* platform_handle_;
57+
std::unique_ptr<ShmFile> platform_handle_;
5258
size_t offset_;
5359
size_t byte_size_;
5460
};

0 commit comments

Comments
 (0)