
Fix torch.jit.ScriptModule.zero_grad. #1478


Open · wants to merge 2 commits into main
17 changes: 17 additions & 0 deletions src/Native/LibTorchSharp/THSJIT.cpp
@@ -68,6 +68,23 @@ int THSJIT_Module_is_training(JITModule module)
return (*module)->is_training();
}

void THSJIT_Module_zero_grad(const JITModule module, bool set_to_none)
{
    // According to https://github.com/pytorch/pytorch/issues/27144,
    // torch::jit::Module has no zero_grad().
    // As a workaround, manually loop over the parameters and clear their
    // grads the same way the optimizer does:
    // https://github.com/pytorch/pytorch/blob/v2.5.1/torch/csrc/api/src/optim/optimizer.cpp#L123
    for (const auto& p : (*module)->parameters()) {
        if (p.mutable_grad().defined()) {
            p.mutable_grad().detach_();
            if (set_to_none)
                p.mutable_grad().reset();
            else
                p.mutable_grad().zero_();
        }
    }
}

void THSJIT_Module_train(JITModule module, bool on)
{
(*module)->train(on);
1 change: 1 addition & 0 deletions src/Native/LibTorchSharp/THSJIT.h
@@ -44,6 +44,7 @@ EXPORT_API(void) THSJIT_Module_invoke(const JITModule module, const char* name,
EXPORT_API(void) THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(int32_t idx, size_t length), int8_t* typeCode, int32_t idx);

EXPORT_API(int) THSJIT_Module_is_training(JITModule module);
EXPORT_API(void) THSJIT_Module_zero_grad(const JITModule module, bool set_to_none);
EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on);
EXPORT_API(void) THSJIT_Module_eval(JITModule module);

17 changes: 17 additions & 0 deletions src/TorchSharp/JIT/ScriptModule.cs
@@ -143,6 +143,23 @@ public override bool training {
}
}

public override void zero_grad(bool set_to_none = true)
{
    THSJIT_Module_zero_grad(handle, set_to_none);
    CheckForErrors();

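    // Mirror the native zero_grad on the managed side as well, clearing
    // (or zeroing) the gradients exposed through the parameter wrappers.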
    foreach (var (_, p) in named_parameters()) {
        using var grad = p.grad;
        if (grad is not null) {
            if (set_to_none) {
                p.grad = null;
            } else {
                grad.zero_();
            }
        }
    }
}

protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking)
{
if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };
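For reference, a minimal sketch of the fixed behavior from the managed side. The file name, shapes, and the single-tensor signature are placeholders for illustration, not part of this PR:

using TorchSharp;
using static TorchSharp.torch;

// Load a TorchScript module previously exported from PyTorch
// ("model.pt" is a hypothetical path).
var module = torch.jit.load<Tensor, Tensor>("model.pt");
module.train();

var loss = module.forward(randn(4, 10)).sum();
loss.backward();

// Default: each grad is detached and set to null.
module.zero_grad();

// Alternative: keep the grad tensors but overwrite them with zeros.
loss = module.forward(randn(4, 10)).sum();
loss.backward();
module.zero_grad(set_to_none: false);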
3 changes: 3 additions & 0 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
@@ -57,6 +57,9 @@ internal static partial class NativeMethods
[return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSJIT_Module_is_training(torch.nn.Module.HType module);

[DllImport("LibTorchSharp")]
internal static extern void THSJIT_Module_zero_grad(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool set_to_none);

[DllImport("LibTorchSharp")]
internal static extern void THSJIT_Module_to_device(torch.nn.Module.HType module, long deviceType, long deviceIndex);
