Refactor providers into separate libraries #1190

Open
RyanUnderhill wants to merge 31 commits into main

Conversation

RyanUnderhill (Member):

This removes most of the #if USE_CUDA and #if USE_DML blocks from the model-handling code. Device memory management is now handled through the DeviceSpan structure, and all data copying is done in a device-independent manner.

It's a huge change, and there will be some rough edges when it's submitted. The goal is to unblock other people who need these changes and then to make larger improvements in future PRs.
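To make the shape of the change concrete, here is a minimal, self-contained sketch of the pattern the description refers to: model code calls an abstract device interface instead of branching on #if USE_CUDA / #if USE_DML. All type and method names below are illustrative; they are not the exact DeviceSpan/DeviceBuffer API in this repository.

```cpp
// Minimal sketch (illustrative names only): compile-time #if device branches are
// replaced by a runtime device interface, so model code is device independent.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

struct DeviceInterface {
  virtual ~DeviceInterface() = default;
  // Copy size_in_bytes from source to dest, however this device does it.
  virtual void Copy(void* dest, const void* source, size_t size_in_bytes) = 0;
};

struct CpuInterface : DeviceInterface {
  void Copy(void* dest, const void* source, size_t size_in_bytes) override {
    std::memcpy(dest, source, size_in_bytes);  // CPU path: plain memcpy
  }
};

// A CUDA or DML implementation would live in its own library and use its native
// copy primitives; the calling code below would not change.

int main() {
  std::unique_ptr<DeviceInterface> device = std::make_unique<CpuInterface>();
  std::vector<int32_t> next_tokens{1, 2, 3, 4}, input_ids(4);
  device->Copy(input_ids.data(), next_tokens.data(), next_tokens.size() * sizeof(int32_t));
  std::cout << input_ids[2] << "\n";  // prints 3
}
```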

std::string CurrentModulePath();

namespace Generators {
namespace Dml { // If this was in a shared library it wouldn't need to be in its own namespace
Contributor:

Why isn't the CUDA one in its own namespace? If we build with both DML and CUDA, wouldn't GpuMemory overlap (if it weren't for the namespace)?

Member Author:

CUDA is built as a separate shared library, so it shouldn't overlap at all. Though I'm not 100% sure I'm using the right dlopen options, as symbol visibility on non-Windows platforms can behave differently.
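For reference, loading each provider as its own shared library with RTLD_LOCAL keeps its symbols out of the process-wide namespace on Linux, which is what prevents two providers' internal types from colliding. A sketch under assumed names (the library path and entry-point symbol below are placeholders, not the actual ones in this PR):

```cpp
// Sketch: load a provider library without exporting its symbols globally.
// RTLD_LOCAL (the opposite of RTLD_GLOBAL) keeps the library's symbols private,
// so e.g. a GpuMemory type inside one provider can't clash with another's.
#include <dlfcn.h>
#include <stdexcept>

void* LoadCudaProviderEntryPoint() {
  // Placeholder library name; the real provider library name may differ.
  void* handle = dlopen("libonnxruntime-genai-cuda.so", RTLD_NOW | RTLD_LOCAL);
  if (!handle)
    throw std::runtime_error(dlerror());
  // Resolve one explicit entry point instead of relying on global symbol lookup.
  void* entry = dlsym(handle, "GetInterface");  // placeholder symbol name
  if (!entry)
    throw std::runtime_error(dlerror());
  return entry;
}
```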

#endif
}

value_ = OrtValue::CreateTensor<int32_t>(*model_.allocator_device_, shape_);
Contributor:

Where did the static buffer go?

Member Author:

It went away; the existing logic was overly complicated, and I couldn't figure out why we still needed it. Do you know?

Collaborator:

Likely related to the graph-capture logic, but I don't know the details.

}
}
}
// Update input_ids with next tokens
Contributor:

Nit: the comment made me think WrapTensor was doing the update, when it's actually the following code.

Member Author:

Ah, I can see that. I'm not sure what would make it clearer. Maybe an extra newline?

@natke self-requested a review January 27, 2025 18:31
@natke (Contributor) left a comment:

Can you enumerate the places where any #ifdefs remain and explain why they need to be there, please?

And what impact will the rough edges have, and can they be smoothed out before you merge this PR?

@baijumeswani (Collaborator) left a comment:

First pass through the code

@@ -31,7 +31,6 @@ TEST(CAPITests, Config) {
config->AppendProvider("cuda");
#endif
}

Collaborator:

Why remove this line?

Member Author:

Accidental

@@ -19,7 +19,7 @@ void CapturedGraphInfoRecycler::operator()(CapturedGraphInfo* captured_graph_inf
}

CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const {
if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) {
if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA)) {
Collaborator:

Does this mean that DML will no longer use graph capture?

@@ -160,6 +167,8 @@ std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config); // For benchmarking purposes only
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);

void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes);
Collaborator:

A comment explaining what CopyThroughCpu means here would be helpful.

Collaborator:

I see the comment in the cpp. :)
Could we move it here?

Member Author:

Yep, I put it in both places as it was small.

@@ -96,77 +92,73 @@ OrtEnv& GetOrtEnv() {
return *GetOrtGlobals()->env_;
}

// Fallback to copy between two separate device buffers by going through CPU memory (slow unless we're the CPU device)
void CopyThroughCpu(DeviceBuffer& dest, size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) {
Collaborator:

Should this go inside the device interface file as a free function, as opposed to generators.cpp?

Member Author:

It's used by the CUDA provider directly, so having it inside the CPU interface made that difficult. I might change it going forward once we have more shared providers and base-class implementations (inheritance doesn't work across shared libraries, so passing in the CPU provider as the base interface might be the solution).
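For readers following along, the in-source comment above describes the intent: stage the bytes through host memory so any two devices can copy between each other. A minimal sketch of that fallback, with CopyToHost/CopyFromHost as assumed names rather than the repository's actual DeviceBuffer API:

```cpp
// Sketch of a CPU-staged copy between two device buffers. The member functions
// CopyToHost/CopyFromHost are assumed names used only to illustrate the idea.
#include <cstddef>
#include <vector>

template <typename Buffer>
void CopyThroughCpuSketch(Buffer& dest, size_t begin_dest,
                          Buffer& source, size_t begin_source, size_t size_in_bytes) {
  std::vector<std::byte> staging(size_in_bytes);
  source.CopyToHost(staging.data(), begin_source, size_in_bytes);  // device -> CPU
  dest.CopyFromHost(staging.data(), begin_dest, size_in_bytes);    // CPU -> device
  // Slow for device<->device transfers, but it works for any pair of devices,
  // and on the CPU device both calls reduce to plain memcpy.
}
```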

Comment on lines +122 to +123
void DumpSpan(std::ostream& stream, std::span<const float> values) override { return Generators::DumpSpan(stream, values); }
void DumpSpan(std::ostream& stream, std::span<const int> values) override { return Generators::DumpSpan(stream, values); }
Collaborator:

Why is this scope needed?

}
throw std::runtime_error("Unknown device type");
Collaborator:

Are the static analysis tools we have smart enough to detect that control will never reach the end of this function? Do we need a dummy return std::string()?

Member Author:

It was actually a build warning on Android or iOS that spurred me to fix that. I didn't see any other issues after doing it.
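For context, the pattern being discussed looks roughly like this generic sketch (not the exact function in this PR): with a throw after an exhaustive switch, every control path returns or throws, so no dummy return std::string() is needed and the end-of-function warning disappears.

```cpp
// Generic sketch: exhaustive switch plus trailing throw means no path falls off
// the end of the function, so compilers/analyzers don't ask for a dummy return.
#include <stdexcept>
#include <string>

enum class DeviceType { CPU, CUDA, DML };

std::string DeviceTypeName(DeviceType type) {
  switch (type) {
    case DeviceType::CPU:  return "CPU";
    case DeviceType::CUDA: return "CUDA";
    case DeviceType::DML:  return "DML";
  }
  throw std::runtime_error("Unknown device type");  // replaces 'return std::string();'
}
```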

case DeviceType::CUDA:
return GetCudaInterface();
#if USE_DML
Collaborator:

Why do we need this #if and not one for CUDA?

Member Author:

DML isn't built as a shared library yet; only CUDA is. So we can't just try loading it and fail gracefully if it doesn't exist; it's either built in or it isn't. Once it's a shared library, there will be no #ifdef.

I could have the function always exist, but its definition would then be inside another #if !USE_DML, so it's just moving the problem around.
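In other words, the dispatch ends up shaped roughly like the sketch below (self-contained and simplified; it is not the exact code in this PR). The CUDA case can always be compiled because GetCudaInterface() resolves the provider from a shared library at runtime; the DML case is statically linked, so its declaration and case label only exist when built with USE_DML.

```cpp
// Simplified sketch of the dispatch discussed above; names are illustrative.
#include <stdexcept>

struct DeviceInterface;
enum class DeviceType { CPU, CUDA, DML };

DeviceInterface* GetCpuInterface();
DeviceInterface* GetCudaInterface();  // loads the CUDA provider library at runtime
#if USE_DML
DeviceInterface* GetDmlInterface();   // statically linked, only exists in DML builds
#endif

DeviceInterface* GetDeviceInterface(DeviceType type) {
  switch (type) {
    case DeviceType::CPU:
      return GetCpuInterface();
    case DeviceType::CUDA:
      return GetCudaInterface();
#if USE_DML
    case DeviceType::DML:
      return GetDmlInterface();
#endif
  }
  throw std::runtime_error("Unknown device type");
}
```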

@@ -1,3 +1,4 @@
#include "models/onnxruntime_api.h"
Collaborator:

Nit: the license header is missing at the top of the file.

@@ -42,7 +38,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, DeviceSpan<int32_t> seq
}

if (inputs.alignment_heads != nullptr) {
#if USE_CUDA
#if 0 // USE_CUDA
Collaborator:

Will these USE_CUDA blocks be removed?

Member Author:

Yep, that will happen when Kunal merges his Whisper branch.

@RyanUnderhill (Member Author):

> Can you enumerate the places where any #ifdefs remain and explain why they need to be there, please?
>
> And what impact will the rough edges have, and can they be smoothed out before you merge this PR?

There are some #if USE_CUDA blocks in our tests; those shouldn't be a problem.

There are two #if USE_DML blocks: one in generators.cpp because DML isn't a shared library yet, and a second in model.cpp for a similar reason. Making DML into a shared library should factor those out and remove the #ifs (the shared library's existence takes the place of the #if; when it's statically linked, the build fails without the #if).

The rough edges are just expected simple bugs that we'll find and easily fix, but that I can't track down in advance.


auto& device = GetOrtGlobals()->allocator_device_[static_cast<int>(type)];
if (!device) {
static const char* device_type_names[static_cast<int>(DeviceType::MAX)] = {"CPU - SEE ABOVE", "Cuda", "DML", "WebGPU Buffer"};
Contributor:

Member Author:

Good catch on WebGPU_Buffer.

Doesn't QNN use CPU memory? It doesn't have a device allocator "QnnWithSharedMemory".

Member Author:

Actually, you gave me an idea: this should fail to compile if new providers are added. I fixed it.
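One common way to get that compile-time failure is to declare the name table without an explicit size and static_assert its length against DeviceType::MAX, so adding a provider without a matching name breaks the build. A sketch of the idea under an illustrative enum; the actual fix in the PR may differ:

```cpp
// Sketch: adding a DeviceType entry without a matching name fails the static_assert.
#include <cstddef>
#include <iterator>

enum class DeviceType { CPU, CUDA, DML, WEBGPU, MAX };

static const char* device_type_names[] = {"CPU", "CUDA", "DML", "WebGPU_Buffer"};
static_assert(std::size(device_type_names) == static_cast<std::size_t>(DeviceType::MAX),
              "every DeviceType needs an entry in device_type_names");
```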
