diff --git a/tunix/rl/utils.py b/tunix/rl/utils.py index 1888368f..5c5bf7c6 100644 --- a/tunix/rl/utils.py +++ b/tunix/rl/utils.py @@ -31,12 +31,15 @@ Mesh = jax.sharding.Mesh NamedSharding = jax.sharding.NamedSharding - -def is_positive_integer(value: int | None, name: str): - """Checks if the value is positive.""" - if value is not None and (not value.is_integer() or value <= 0): +def is_positive_integer(value: int | float | None, name: str) -> None: + """Checks if the value is a positive integer.""" + if value is not None and (not is_integer_value(value) or value <= 0): raise ValueError(f"{name} must be a positive integer. Got: {value}") +def is_integer_value(value) -> bool: + """Checks if the value is an integer""" + return (isinstance(value, int) and not isinstance(value, bool)) or (isinstance(value, float) and value.is_integer()) + def check_divisibility( small_size,