Refactor providers into separate libraries #1190
base: main
Conversation
Details: Add a DML DeviceInterface and a DML DeviceBuffer handler. Remove the #if blocks that do memory copies between device and CPU memory, and use the DeviceSpan interface instead.
Remove as many #if USE_CUDA/USE_DML blocks as possible.
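As background for the review below, here is a minimal sketch of the shape such a device abstraction could take; every member name in it is an assumption for illustration, not the PR's actual interface:

#include <cstddef>
#include <span>

// Hypothetical sketch: a DeviceBuffer owns memory on one device and can
// mirror it to and from host memory, so callers can copy data without
// knowing whether the device is CPU, CUDA, or DML. Names are illustrative.
struct DeviceBuffer {
  virtual ~DeviceBuffer() = default;
  virtual std::span<std::byte> CpuSpan() = 0;  // host-visible mirror of the buffer
  virtual void CopyDeviceToCpu() = 0;          // refresh the mirror from device memory
  virtual void CopyCpuToDevice() = 0;          // flush host writes back to the device
};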
std::string CurrentModulePath();

namespace Generators {
namespace Dml {  // If this was in a shared library it wouldn't need to be in its own namespace
Why isn't the CUDA one in its own namespace? If we build with both DML and CUDA, wouldn't GpuMemory overlap (if it weren't for the namespace)?
CUDA is built as a separate shared library, so it shouldn't overlap at all. Though I'm not 100% sure I'm using the right dlopen options, as symbols on non-Windows platforms can behave differently.
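For context, a hedged sketch of how that symbol isolation is usually controlled on POSIX systems; the loader function below is hypothetical, not the PR's code:

#include <dlfcn.h>
#include <stdexcept>

// RTLD_LOCAL keeps the provider's symbols out of the process-global
// namespace, so a GpuMemory defined inside the CUDA shared library cannot
// collide with one defined elsewhere; RTLD_GLOBAL would merge them.
void* LoadProviderLibrary(const char* path) {
  void* handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
  if (!handle)
    throw std::runtime_error(dlerror());
  return handle;
}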
src/models/input_ids.cpp
Outdated
#endif
}

value_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, shape_);
Where did the static buffer go?
It went away; the existing logic was overly complicated and I couldn't figure out why we still needed it. Do you know?
Likely related to the graph capture stuff, but I don't know the details.
}
}
}
// Update input_ids with next tokens
nit: the comment made me think WrapTensor was doing the update, when it's actually the following code.
Ah, I can see that. I'm not sure what would make it clearer. Maybe an extra newline?
Can you enumerate the places where any #ifdefs remain and why they need to be there, please?
And what impact will the rough edges have? Can they be smoothed before you merge this PR?
First pass through the code
test/c_api_tests.cpp
Outdated
@@ -31,7 +31,6 @@ TEST(CAPITests, Config) {
    config->AppendProvider("cuda");
#endif
  }
-
Why remove this line?
Accidental
@@ -19,7 +19,7 @@ void CapturedGraphInfoRecycler::operator()(CapturedGraphInfo* captured_graph_inf
}

CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const {
-  if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) {
+  if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA)) {
Does this mean that DML will no longer use graph capture?
@@ -160,6 +167,8 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config);  // For benchmarking purposes only
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);

void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes);
A comment explaining what CopyThroughCpu means here would be helpful.
I see the comment in the cpp. :)
Could we move it here?
Yep, I put it in both places as it was small.
@@ -96,77 +92,73 @@ OrtEnv& GetOrtEnv() {
  return *GetOrtGlobals()->env_;
}

// Fallback to copy between two separate device buffers by going through CPU memory (slow unless we're the CPU device)
void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) {
Should this go inside the device interface file as a free function, as opposed to generators.cpp?
It's used by the CUDA provider directly, so having it inside the CPU interface made that difficult. I might change it going forward once we have more shared providers and base-class implementations (inheritance doesn't work across shared libraries, so passing in the CPU provider as the base interface might be the solution).
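A minimal sketch of what such a fallback can look like, reusing the hypothetical DeviceBuffer members sketched earlier (CpuSpan, CopyDeviceToCpu, CopyCpuToDevice); the PR's real implementation may differ:

#include <cstddef>
#include <cstring>

// Stage the copy through host memory: pull the source bytes to the CPU,
// memcpy the requested range, then push the destination back to its device.
// Slow for device-to-device copies, essentially free on the CPU device.
void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) {
  source.CopyDeviceToCpu();
  std::memcpy(dest.CpuSpan().data() + begin_dest, source.CpuSpan().data() + begin_source, size_in_bytes);
  dest.CopyCpuToDevice();
}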
void DumpSpan(std::ostream& stream, std::span<const float> values) override { return Generators::DumpSpan(stream, values); }
void DumpSpan(std::ostream& stream, std::span<const int> values) override { return Generators::DumpSpan(stream, values); }
Why is this scope needed?
}
throw std::runtime_error("Unknown device type");
Are the static analysis tools we have smart enough to detect that control will never reach the end of this function? Do we need a dummy return std::string()?
It was actually a build warning on Android or iOS that spurred me to fix that. I didn't see any other issues after doing it.
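For illustration, a hedged sketch of the pattern (the function name and enum here are placeholders): ending the function with an unconditional throw means every path either returns or throws, so the compiler no longer warns about control reaching the end of a non-void function, and no dummy return std::string() is needed.

#include <stdexcept>
#include <string>

enum class ExampleDevice { CPU, CUDA };  // placeholder enum for the sketch

std::string ExampleDeviceName(ExampleDevice type) {
  switch (type) {
    case ExampleDevice::CPU:
      return "CPU";
    case ExampleDevice::CUDA:
      return "CUDA";
  }
  // Unreachable for valid enum values; throwing here satisfies the
  // compiler without a meaningless dummy return value.
  throw std::runtime_error("Unknown device type");
}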
case DeviceType::CUDA:
  return GetCudaInterface();
#if USE_DML
Why do we need this #if and not one for CUDA?
It's not built as a shared library yet; only CUDA is. So we can't just try loading it and fail if it doesn't exist: it's either built with it or not. Once it's a shared library there will be no #ifdef.
I could have the function exist, but the definition of this function would be inside another #ifdef !USE_DML, so it's just moving the problem around.
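To make the trade-off concrete, a sketch of the dispatch shape being described (GetDeviceInterface and GetDmlInterface are assumed names; only GetCudaInterface appears in the diff above): the CUDA case always compiles because its provider is resolved from a shared library at runtime, while the DML case can only exist when it is compiled in.

DeviceInterface* GetDeviceInterface(DeviceType type) {
  switch (type) {
    case DeviceType::CUDA:
      return GetCudaInterface();  // resolved from the CUDA shared library at runtime
#if USE_DML
    case DeviceType::DML:
      return GetDmlInterface();   // statically linked, so it must stay guarded for now
#endif
    default:
      throw std::runtime_error("Unknown device type");
  }
}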
@@ -1,3 +1,4 @@
+#include "models/onnxruntime_api.h"
The license header should stay at the top of the file.
@@ -42,7 +38,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
}

  if (inputs.alignment_heads != nullptr) {
-#if USE_CUDA
+#if 0  // USE_CUDA
Will these USE_CUDAs be removed?
Yep, that will happen when Kunal merges his whisper branch.
There are some #if USE_CUDA blocks in our tests; this shouldn't be a problem. The rough edges are just expected simple bugs that we'll find and easily fix, but that I can't find in advance.
src/models/model.cpp
Outdated

auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
if (!device) {
  static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU - SEE ABOVE", "Cuda", "DML", "WebGPU Buffer"};
What about DeviceType::QNN? And I think "WebGPU Buffer" should be "WebGPU_Buffer":
https://github.com/microsoft/onnxruntime/blob/e3e41739a7ca0ce0806805aa7e2814c72748d0e5/include/onnxruntime/core/framework/allocator.h#L56
https://github.com/microsoft/onnxruntime/blob/e3e41739a7ca0ce0806805aa7e2814c72748d0e5/onnxruntime/core/framework/allocator.cc#L143
Good catch on WebGPU_Buffer.
Doesn't QNN use CPU memory? It doesn't have a device allocator "QnnWithSharedMemory".
Actually, you gave me an idea: this should fail to compile if new providers are added. I fixed it.
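One way to get that compile-time failure, sketched under an assumed enum layout (the PR's actual fix may differ):

#include <cstddef>
#include <iterator>

enum class DeviceType { CPU, CUDA, DML, WEBGPU, MAX };  // assumed layout for the sketch

static constexpr const char* device_type_names[] = {"CPU - SEE ABOVE", "Cuda", "DML", "WebGPU_Buffer"};
// Adding a DeviceType value without a matching name now breaks the build
// instead of silently indexing past the end of the table.
static_assert(std::size(device_type_names) == static_cast<std::size_t>(DeviceType::MAX),
              "device_type_names must cover every DeviceType");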
This removes most of the #if USE_CUDA and #if USE_DML blocks from the model handling code. Device memory management is also handled through the DeviceSpan structure, and all data copying is now done in a device-independent manner.
It's a huge change, and there will be some rough edges when submitted. The goal is to unblock other people who need the changes, and then to make larger improvements in future PRs.