Skip to content

Commit 43700e2

Browse files
committed
Review comments from Guan
1 parent e97179c commit 43700e2

File tree

4 files changed

+44
-52
lines changed

4 files changed

+44
-52
lines changed

src/python/library/tritonclient/utils/shared_memory/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import os
3030
import struct
31+
import sys
3132
from ctypes import *
3233

3334
import numpy as np
@@ -45,6 +46,13 @@ def from_param(cls, value):
4546
return value.encode("utf8")
4647

4748

49+
class ShmFile(Structure):
50+
if sys.platform == "win32":
51+
_fields_ = [("shm_handle_", c_void_p)]
52+
else:
53+
_fields_ = [("shm_fd_", c_int)]
54+
55+
4856
_cshm_lib = "cshm" if os.name == "nt" else "libcshm.so"
4957
_cshm_path = pkg_resources.resource_filename(
5058
"tritonclient.utils.shared_memory", _cshm_lib
@@ -63,7 +71,7 @@ def from_param(cls, value):
6371
c_void_p,
6472
POINTER(c_char_p),
6573
POINTER(c_char_p),
66-
POINTER(c_void_p),
74+
POINTER(ShmFile),
6775
POINTER(c_uint64),
6876
POINTER(c_uint64),
6977
]
@@ -205,10 +213,7 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0):
205213
The numpy array generated using the contents of the specified shared
206214
memory region.
207215
"""
208-
# Safe initializer for Unix case where shm_file must be dereferenced to
209-
# base in order to store file descriptor.
210-
safe_initializer = c_int(-1)
211-
shm_file = cast(byref(safe_initializer), c_void_p)
216+
shm_file = ShmFile()
212217
region_offset = c_uint64()
213218
byte_size = c_uint64()
214219
shm_addr = c_char_p()
@@ -287,10 +292,7 @@ def destroy_shared_memory_region(shm_handle):
287292
SharedMemoryException
288293
If unable to unlink the shared memory region.
289294
"""
290-
# Safe initializer for Unix case where shm_file must be dereferenced to
291-
# base in order to store file descriptor.
292-
safe_initializer = c_int(-1)
293-
shm_file = cast(byref(safe_initializer), c_void_p)
295+
shm_file = ShmFile()
294296
offset = c_uint64()
295297
byte_size = c_uint64()
296298
shm_addr = c_char_p()

src/python/library/tritonclient/utils/shared_memory/shared_memory.cc

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,23 @@ namespace {
4646
void*
4747
SharedMemoryHandleCreate(
4848
std::string triton_shm_name, void* shm_addr, std::string shm_key,
49-
void* shm_file, size_t offset, size_t byte_size)
49+
ShmFile* shm_file, size_t offset, size_t byte_size)
5050
{
5151
SharedMemoryHandle* handle = new SharedMemoryHandle();
5252
handle->triton_shm_name_ = triton_shm_name;
5353
handle->base_addr_ = shm_addr;
5454
handle->shm_key_ = shm_key;
55-
handle->platform_handle_ = std::make_unique<ShmFile>(shm_file);
55+
handle->platform_handle_.reset(shm_file);
5656
handle->offset_ = offset;
5757
handle->byte_size_ = byte_size;
5858
return static_cast<void*>(handle);
5959
}
6060

6161
int
6262
SharedMemoryRegionMap(
63-
void* shm_file, size_t offset, size_t byte_size, void** shm_addr)
63+
ShmFile* shm_file, size_t offset, size_t byte_size, void** shm_addr)
6464
{
6565
#ifdef _WIN32
66-
HANDLE file_handle = static_cast<HANDLE>(shm_file);
6766
// The MapViewOfFile function takes a high-order and low-order DWORD (4 bytes
6867
// each) for offset. 'size_t' can either be 4 or 8 bytes depending on the
6968
// operating system. To handle both cases agnostically, we cast 'offset' to
@@ -74,14 +73,14 @@ SharedMemoryRegionMap(
7473
DWORD low_order_offset = upperbound_offset & 0xFFFFFFFF;
7574
// map shared memory to process address space
7675
*shm_addr = MapViewOfFile(
77-
file_handle, // handle to map object
78-
FILE_MAP_ALL_ACCESS, // read/write permission
79-
high_order_offset, // offset (high-order DWORD)
80-
low_order_offset, // offset (low-order DWORD)
76+
shm_file->shm_handle_, // handle to map object
77+
FILE_MAP_ALL_ACCESS, // read/write permission
78+
high_order_offset, // offset (high-order DWORD)
79+
low_order_offset, // offset (low-order DWORD)
8180
byte_size);
8281

8382
if (*shm_addr == NULL) {
84-
CloseHandle(file_handle);
83+
CloseHandle(shm_file->shm_handle_);
8584
return -1;
8685
}
8786
// For Windows, we cannot close the shared memory handle here. When all
@@ -90,9 +89,9 @@ SharedMemoryRegionMap(
9089
// we are destroying the shared memory object.
9190
return 0;
9291
#else
93-
int fd = *static_cast<int*>(shm_file);
9492
// map shared memory to process address space
95-
*shm_addr = mmap(NULL, byte_size, PROT_WRITE, MAP_SHARED, fd, offset);
93+
*shm_addr =
94+
mmap(NULL, byte_size, PROT_WRITE, MAP_SHARED, shm_file->shm_fd_, offset);
9695
if (*shm_addr == MAP_FAILED) {
9796
return -1;
9897
}
@@ -118,29 +117,25 @@ SharedMemoryRegionCreate(
118117
DWORD high_order_size = (upperbound_size >> 32) & 0xFFFFFFFF;
119118
DWORD low_order_size = upperbound_size & 0xFFFFFFFF;
120119

121-
HANDLE shm_file = CreateFileMapping(
120+
HANDLE win_handle = CreateFileMapping(
122121
INVALID_HANDLE_VALUE, // use paging file
123122
NULL, // default security
124123
PAGE_READWRITE, // read/write access
125124
high_order_size, // maximum object size (high-order DWORD)
126125
low_order_size, // maximum object size (low-order DWORD)
127126
shm_key); // name of mapping object
128127

129-
if (shm_file == NULL) {
128+
if (win_handle == NULL) {
130129
return -7;
131130
}
132131

132+
ShmFile* shm_file = new ShmFile(win_handle);
133133
// get base address of shared memory region
134134
void* shm_addr = nullptr;
135-
int err = SharedMemoryRegionMap((void*)shm_file, 0, byte_size, &shm_addr);
135+
int err = SharedMemoryRegionMap(shm_file, 0, byte_size, &shm_addr);
136136
if (err == -1) {
137137
return -4;
138138
}
139-
140-
// create a handle for the shared memory region
141-
*shm_handle = SharedMemoryHandleCreate(
142-
std::string(triton_shm_name), shm_addr, std::string(shm_key),
143-
(void*)shm_file, 0, byte_size);
144139
#else
145140
// get shared memory region descriptor
146141
int shm_fd = shm_open(shm_key, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
@@ -154,18 +149,18 @@ SharedMemoryRegionCreate(
154149
return -3;
155150
}
156151

152+
ShmFile* shm_file = new ShmFile(shm_fd);
157153
// get base address of shared memory region
158154
void* shm_addr = nullptr;
159-
int err = SharedMemoryRegionMap((void*)&shm_fd, 0, byte_size, &shm_addr);
155+
int err = SharedMemoryRegionMap(shm_file, 0, byte_size, &shm_addr);
160156
if (err == -1) {
161157
return -4;
162158
}
163-
159+
#endif
164160
// create a handle for the shared memory region
165161
*shm_handle = SharedMemoryHandleCreate(
166-
std::string(triton_shm_name), shm_addr, std::string(shm_key),
167-
(void*)&shm_fd, 0, byte_size);
168-
#endif
162+
std::string(triton_shm_name), shm_addr, std::string(shm_key), shm_file, 0,
163+
byte_size);
169164
return 0;
170165
}
171166

@@ -181,20 +176,20 @@ SharedMemoryRegionSet(
181176

182177
TRITONCLIENT_DECLSPEC int
183178
GetSharedMemoryHandleInfo(
184-
void* shm_handle, char** shm_addr, const char** shm_key, void** shm_file,
179+
void* shm_handle, char** shm_addr, const char** shm_key, void* shm_file,
185180
size_t* offset, size_t* byte_size)
186181
{
187-
#ifdef _WIN32
188-
HANDLE* file = static_cast<HANDLE*>(shm_file);
189-
#else
190-
int* file = *reinterpret_cast<int**>(shm_file);
191-
#endif // _WIN32
192182
SharedMemoryHandle* handle = static_cast<SharedMemoryHandle*>(shm_handle);
183+
ShmFile* file = static_cast<ShmFile*>(shm_file);
193184
*shm_addr = static_cast<char*>(handle->base_addr_);
194185
*shm_key = handle->shm_key_.c_str();
195-
*file = *(handle->platform_handle_->GetShmFile());
196186
*offset = handle->offset_;
197187
*byte_size = handle->byte_size_;
188+
#ifdef _WIN32
189+
file->shm_handle_ = handle->platform_handle_->shm_handle_;
190+
#else
191+
file->shm_fd_ = handle->platform_handle_->shm_fd_;
192+
#endif
198193
return 0;
199194
}
200195

@@ -212,7 +207,7 @@ SharedMemoryRegionDestroy(void* shm_handle)
212207
// We keep Windows shared memory handles open until we are done
213208
// using them. When all handles are closed, the system will free
214209
// the section of the paging file that the object uses.
215-
CloseHandle(*(handle->platform_handle_->GetShmFile()));
210+
CloseHandle(handle->platform_handle_->shm_handle_);
216211
#else
217212
int status = munmap(shm_addr, handle->byte_size_);
218213
if (status == -1) {
@@ -223,7 +218,7 @@ SharedMemoryRegionDestroy(void* shm_handle)
223218
if (shm_fd == -1) {
224219
return -5;
225220
}
226-
close(*(handle->platform_handle_->GetShmFile()));
221+
close(handle->platform_handle_->shm_fd_);
227222
#endif // _WIN32
228223

229224
// FIXME: Investigate use of smart pointers for this

src/python/library/tritonclient/utils/shared_memory/shared_memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ TRITONCLIENT_DECLSPEC int SharedMemoryRegionCreate(
4343
TRITONCLIENT_DECLSPEC int SharedMemoryRegionSet(
4444
void* shm_handle, size_t offset, size_t byte_size, const void* data);
4545
TRITONCLIENT_DECLSPEC int GetSharedMemoryHandleInfo(
46-
void* shm_handle, char** shm_addr, const char** shm_key, void** shm_file,
46+
void* shm_handle, char** shm_addr, const char** shm_key, void* shm_file,
4747
size_t* offset, size_t* byte_size);
4848
TRITONCLIENT_DECLSPEC int SharedMemoryRegionDestroy(void* shm_handle);
4949

src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,11 @@
3737

3838
struct ShmFile {
3939
#ifdef _WIN32
40-
HANDLE shm_file_;
41-
ShmFile(void* shm_file) { shm_file_ = static_cast<HANDLE>(shm_file); };
42-
HANDLE* GetShmFile() { return &shm_file_; };
40+
HANDLE shm_handle_;
41+
ShmFile(HANDLE shm_handle) : shm_handle_(shm_handle){};
4342
#else
44-
std::unique_ptr<int> shm_file_;
45-
ShmFile(void* shm_file)
46-
{
47-
shm_file_ = std::make_unique<int>(*static_cast<int*>(shm_file));
48-
};
49-
int* GetShmFile() { return shm_file_.get(); }
43+
int shm_fd_;
44+
ShmFile(int shm_fd) : shm_fd_(shm_fd){};
5045
#endif // _WIN32
5146
};
5247

0 commit comments

Comments
 (0)