-
Notifications
You must be signed in to change notification settings - Fork 195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add string normalize method #1461
base: main
Are you sure you want to change the base?
Changes from all commits
587fa88
4fa0e34
681d2e3
1ec4b10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,6 +54,26 @@ public static partial class torch | |
static bool nativeBackendCudaLoaded = false; | ||
|
||
public static string __version__ => libtorchPackageVersion; | ||
public static string NormalizeNuGetVersion(string versionString) | ||
{ | ||
if (string.IsNullOrWhiteSpace(versionString)) | ||
throw new ArgumentException($"Invalid NuGet version: {versionString}. Version string is null, empty or only contains whitespaces"); | ||
|
||
string[] parts = versionString.Split('-', '+'); | ||
string[] versionParts = parts[0].Split('.'); | ||
|
||
if (versionParts.Length < 2 || versionParts.Length > 4 || !versionParts.All(v => int.TryParse(v, out _))) | ||
throw new ArgumentException($"Invalid NuGet version: {versionString}. Please check: https://learn.microsoft.com/en-us/nuget/concepts/package-versioning"); | ||
|
||
string normalizedVersion = versionParts[0] + "." + versionParts[1]; | ||
if (versionParts.Length > 2) normalizedVersion += "." + versionParts[2]; | ||
if (versionParts.Length > 3 && int.Parse(versionParts[3]) != 0) normalizedVersion += "." + versionParts[3]; | ||
|
||
if (parts.Length > 1) | ||
normalizedVersion += "-" + parts[1]; | ||
|
||
return normalizedVersion; | ||
} | ||
|
||
internal static bool TryLoadNativeLibraryFromFile(string path, StringBuilder trace) { | ||
bool ok; | ||
|
@@ -168,16 +188,17 @@ private static void LoadNativeBackend(bool useCudaBackend, out StringBuilder? tr | |
|
||
if (torchsharpLoc!.Contains("torchsharp") && torchsharpLoc.Contains("lib") && Directory.Exists(packagesDir) && Directory.Exists(torchsharpHome)) { | ||
|
||
var torchSharpVersion = Path.GetFileName(torchsharpHome); // really GetDirectoryName | ||
|
||
var assembly = typeof(torch).Assembly; | ||
var version = assembly.GetName().Version; | ||
var torchSharpVersion = (version != null) ? version.ToString() : Path.GetFileName(torchsharpHome); | ||
if (useCudaBackend) { | ||
var consolidatedDir = Path.Combine(torchsharpLoc, $"cuda-{cudaVersion}"); | ||
|
||
trace.AppendLine($" Trying dynamic load for .NET/F# Interactive by consolidating native {cudaRootPackage}-* binaries to {consolidatedDir}..."); | ||
|
||
var cudaOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, $"{cudaRootPackage}-*", libtorchPackageVersion, consolidatedDir, trace); | ||
var cudaOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, $"{cudaRootPackage}-*", NormalizeNuGetVersion(libtorchPackageVersion), consolidatedDir, trace); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of calling With doing it:
|
||
if (cudaOk) { | ||
cudaOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, "torchsharp", torchSharpVersion, consolidatedDir, trace); | ||
cudaOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, "torchsharp", NormalizeNuGetVersion(torchSharpVersion), consolidatedDir, trace); | ||
if (cudaOk) { | ||
var consolidated = Path.Combine(consolidatedDir, target); | ||
ok = TryLoadNativeLibraryFromFile(consolidated, trace); | ||
|
@@ -193,9 +214,9 @@ private static void LoadNativeBackend(bool useCudaBackend, out StringBuilder? tr | |
|
||
trace.AppendLine($" Trying dynamic load for .NET/F# Interactive by consolidating native {cpuRootPackage}-* binaries to {consolidatedDir}..."); | ||
|
||
var cpuOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, cpuRootPackage, libtorchPackageVersion, consolidatedDir, trace); | ||
var cpuOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, cpuRootPackage, NormalizeNuGetVersion(libtorchPackageVersion), consolidatedDir, trace); | ||
if (cpuOk) { | ||
cpuOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, "torchsharp", torchSharpVersion, consolidatedDir, trace); | ||
cpuOk = CopyNativeComponentsIntoSingleDirectory(packagesDir, "torchsharp", NormalizeNuGetVersion(torchSharpVersion), consolidatedDir, trace); | ||
if (cpuOk) { | ||
var consolidated = Path.Combine(consolidatedDir, target); | ||
ok = TryLoadNativeLibraryFromFile(consolidated, trace); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When taking the version string directly from the library it results in a 4 number version string: eg: 0.105.0.0, which needs to be normalized to remove last 0. I think it's better to have it taken straight from the dll that was loaded to the project than the nugget install folder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO, I would recommend to consider behavior change (reading via assembly) in a separated PR - issue.
For this PR, we can focus on the issue we have.
Later, we can do the other change (reading via assembly) with only that scope.
So, in case of any revert, we can keep the fix safe.