Skip to content

Commit 9f28f1a

Browse files
feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling (#10699)
* feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling * chore: update type hint * refactor: use union for type hint --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 5d2d239 commit 9f28f1a

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/diffusers/training_utils.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,13 @@ def _set_state_dict_into_text_encoder(
248248

249249

250250
def compute_density_for_timestep_sampling(
251-
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
251+
weighting_scheme: str,
252+
batch_size: int,
253+
logit_mean: float = None,
254+
logit_std: float = None,
255+
mode_scale: float = None,
256+
device: Union[torch.device, str] = "cpu",
257+
generator: Optional[torch.Generator] = None,
252258
):
253259
"""
254260
Compute the density for sampling the timesteps when doing SD3 training.
@@ -258,14 +264,13 @@ def compute_density_for_timestep_sampling(
258264
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
259265
"""
260266
if weighting_scheme == "logit_normal":
261-
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
262-
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
267+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
263268
u = torch.nn.functional.sigmoid(u)
264269
elif weighting_scheme == "mode":
265-
u = torch.rand(size=(batch_size,), device="cpu")
270+
u = torch.rand(size=(batch_size,), device=device, generator=generator)
266271
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
267272
else:
268-
u = torch.rand(size=(batch_size,), device="cpu")
273+
u = torch.rand(size=(batch_size,), device=device, generator=generator)
269274
return u
270275

271276

0 commit comments

Comments
 (0)