
Conversation

@adgilbert

Hi,

Thanks for putting together this repository. I'm using only the loss functions as part of another project, running on CPU or GPU at different times, and sometimes with half-precision training. This PR makes the loss functions use the device and dtype of the passed-in tensors rather than always using the GPU and torch.float32. Since the training code still uses get_torch_device() and float32 tensors, this should change behavior only when someone uses the loss functions on their own (my use case).
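As a sketch of the pattern (the dice-style loss body below is hypothetical and not one of this repository's functions; only get_torch_device() comes from the description above):

```python
import torch

def dice_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Before: constants pinned to a fixed device/dtype, e.g.
    #   smooth = torch.tensor(eps, dtype=torch.float32, device=get_torch_device())
    # After: inherit device and dtype from the incoming tensors, so the same
    # loss works on CPU, on GPU, and under half-precision training.
    smooth = torch.tensor(eps, dtype=pred.dtype, device=pred.device)
    target = target.to(dtype=pred.dtype, device=pred.device)
    intersection = (pred * target).sum()
    return 1.0 - (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
```

Because all intermediates follow pred.device and pred.dtype, callers that still pass float32 GPU tensors see identical behavior.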

Creating this PR in case it's useful to others. Obviously, feel free to reject it if you'd rather keep things as they are.
