Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions tunix/rl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@
Mesh = jax.sharding.Mesh
NamedSharding = jax.sharding.NamedSharding


def is_positive_integer(value: int | None, name: str):
"""Checks if the value is positive."""
if value is not None and (not value.is_integer() or value <= 0):
def is_positive_integer(value: int | float | None, name: str) -> None:
"""Checks if the value is a positive integer."""
if value is not None and (not is_integer_value(value) or value <= 0):
raise ValueError(f"{name} must be a positive integer. Got: {value}")

def is_integer_value(value) -> bool:
"""Checks if the value is an integer"""
return (isinstance(value, int) and not isinstance(value, bool)) or (isinstance(value, float) and value.is_integer())


def check_divisibility(
small_size,
Expand Down