Refactor providers into separate libraries #1190

Merged Feb 13, 2025 (39 commits)
Commits
7e4668b
Use DeviceInterface for debugging
RyanUnderhill Nov 23, 2024
34381af
Merge remote-tracking branch 'origin/main' into ryanunderhill/providers
RyanUnderhill Nov 23, 2024
35e79ce
Merge remote-tracking branch 'origin/main' into ryanunderhill/providers
RyanUnderhill Nov 26, 2024
3823664
Summary: Remove #ifdefs for providers and go through device interface.
RyanUnderhill Dec 16, 2024
41b462a
Finish refactoring model processing
RyanUnderhill Jan 15, 2025
237fb1e
Merge with main
RyanUnderhill Jan 15, 2025
bdbb09c
Fix merge build issues
RyanUnderhill Jan 16, 2025
66321dd
Formatting
RyanUnderhill Jan 16, 2025
0bc39a5
Build fixes
RyanUnderhill Jan 17, 2025
0f2ea36
Merge with main
RyanUnderhill Jan 17, 2025
5244049
Build fix
RyanUnderhill Jan 17, 2025
49b51ef
Build fix
RyanUnderhill Jan 17, 2025
d3db2f6
Fix input_ids issue from merge
RyanUnderhill Jan 21, 2025
133d5a0
Fix C# unit tests
RyanUnderhill Jan 22, 2025
b079b74
Try again to fix C# test
RyanUnderhill Jan 22, 2025
0e7064c
Merge with main
RyanUnderhill Jan 23, 2025
afecf1d
Test theory
RyanUnderhill Jan 23, 2025
1734f5c
Test instrumenting
RyanUnderhill Jan 24, 2025
2bc83eb
Crash investigation
RyanUnderhill Jan 24, 2025
fd788d7
Extra debug logging
RyanUnderhill Jan 24, 2025
67d914c
Merge with main
RyanUnderhill Jan 24, 2025
0303592
Undefined behavior fix in startup
RyanUnderhill Jan 25, 2025
d87807c
Don't load cuda library outside of linux & windows
RyanUnderhill Jan 25, 2025
2df5fe1
Fix iOS break
RyanUnderhill Jan 25, 2025
6736517
Android tweak
RyanUnderhill Jan 25, 2025
a011fe0
Leftover #ifdef fix
RyanUnderhill Jan 27, 2025
c11704f
Type tweak
RyanUnderhill Jan 27, 2025
45dad2b
Review feedback
RyanUnderhill Jan 28, 2025
53c666c
Edward gave me ideas.
RyanUnderhill Jan 29, 2025
e804697
Clean up allocators, now everything is through p_device_* interfaces.
RyanUnderhill Jan 30, 2025
f8ed9ce
Previous change also added device interfaces for webgpu & qnn
RyanUnderhill Jan 30, 2025
198e8f8
Remove accidental change
RyanUnderhill Jan 30, 2025
e6b77f2
Device check simplifications
RyanUnderhill Jan 30, 2025
4f2f084
Refactor device_type
RyanUnderhill Jan 31, 2025
0765339
Merge remote-tracking branch 'origin/main' into ryanunderhill/providers
RyanUnderhill Feb 1, 2025
acba52c
Update src/models/model.h
RyanUnderhill Feb 3, 2025
4bcfa33
Merge with main
RyanUnderhill Feb 13, 2025
68a6ea7
Fix merge conflicts
RyanUnderhill Feb 13, 2025
12e2f76
Formatting
RyanUnderhill Feb 13, 2025
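
Several commits above describe removing provider `#ifdef`s and routing everything through a device interface. The following is only a rough sketch of that general pattern, not the code from this PR: `DeviceInterface`, `CpuDevice`, `GetDevice`, and this two-entry `DeviceType` are hypothetical stand-ins invented for illustration. The idea is that callers ask for an interface at runtime and get an error if the provider is missing, instead of compiling different code paths per provider.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for illustration; not the project's real types.
enum class DeviceType { CPU, CUDA, MAX };

// One abstract interface per device, so callers need no provider #ifdefs.
struct DeviceInterface {
  virtual ~DeviceInterface() = default;
  virtual std::string Name() const = 0;
  virtual std::unique_ptr<std::byte[]> Allocate(std::size_t bytes) = 0;
};

struct CpuDevice final : DeviceInterface {
  std::string Name() const override { return "CPU"; }
  std::unique_ptr<std::byte[]> Allocate(std::size_t bytes) override {
    return std::make_unique<std::byte[]>(bytes);
  }
};

// Returns the interface for a device type, or throws when this sketch
// has no implementation for it (standing in for a provider library
// that could not be loaded).
DeviceInterface& GetDevice(DeviceType type) {
  static CpuDevice cpu;
  switch (type) {
    case DeviceType::CPU:
      return cpu;
    default:
      throw std::runtime_error("Provider not available");
  }
}
```

With this shape, a call such as `GetDevice(DeviceType::CPU).Allocate(64)` compiles the same way on every platform, and an unavailable provider surfaces as a runtime error rather than a build-time branch.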
Edward gave me ideas.
RyanUnderhill committed Jan 29, 2025
commit 53c666c1013223b731c18d32e7bc0f7562ca1ab7
8 changes: 6 additions & 2 deletions src/models/model.cpp
@@ -228,10 +228,14 @@ Ort::Allocator* GetDeviceAllocator(OrtSession& session, DeviceType type) {
 
   auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
   if (!device) {
-    static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU - SEE ABOVE", "Cuda", "DML", "WebGPU Buffer"};
+    static const char* device_type_names[] = {"CPU (Not used, see above)", "Cuda", "DML", "WebGPU_Buffer", "QNN (Not used, uses CPU memory)"};
+    static_assert(std::size(device_type_names) == static_cast<size_t>(DeviceType::MAX));
 
-    auto memory_info = OrtMemoryInfo::Create(device_type_names[static_cast<int>(type)], OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
+    auto name = device_type_names[static_cast<int>(type)];
+    auto memory_info = OrtMemoryInfo::Create(name, OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
     device = Ort::Allocator::Create(session, *memory_info);
+    if (!device)
+      throw std::runtime_error("Unexpected failure to create device memory allocator for " + std::string(name));
     GetDeviceInterface(type)->InitOrt(*Ort::api, *device);  // Necessary for any shared library providers so they can access Ort::api
   }
   return device.get();
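
The change to `device_type_names` swaps an explicitly sized array for one sized by its initializer, plus a `static_assert` against `DeviceType::MAX`. With the explicit size, a missing entry is silently left as a null pointer. With the new form, adding an enum value without a matching name fails the build. Below is a standalone sketch of that technique. The `DeviceType` enumerators and their order, the shortened name strings, and the helper `DeviceTypeName` are assumptions made for illustration and are not taken from the project.

```cpp
#include <cstddef>
#include <iterator>
#include <stdexcept>

// Assumed enumerators and order, for illustration only.
enum class DeviceType { CPU, CUDA, DML, WEBGPU, QNN, MAX };

// Hypothetical helper: maps a DeviceType to its display name.
inline const char* DeviceTypeName(DeviceType type) {
  // The array length comes from the initializer list, so the
  // static_assert catches any mismatch with the enum at compile time.
  static const char* const names[] = {"CPU", "Cuda", "DML", "WebGPU_Buffer", "QNN"};
  static_assert(std::size(names) == static_cast<std::size_t>(DeviceType::MAX),
                "names needs exactly one entry per DeviceType");

  const auto index = static_cast<std::size_t>(type);
  if (index >= std::size(names))
    throw std::out_of_range("Invalid DeviceType");  // e.g. DeviceType::MAX itself
  return names[index];
}
```

`std::size` requires C++17. The runtime bounds check is an extra guard added in this sketch; the diff above indexes the array directly.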