Explicitly cast input_ids to jnp.int32
#838
+3
−1
Open