# Report the compute resources available to this runtime: CPU count,
# NVIDIA GPU status (when nvidia-smi exists), the devices JAX sees on its
# default backend, and current system memory usage.
import os
import subprocess

import jax
import psutil

print(f"CPU count: {os.cpu_count()}")

# Run nvidia-smi through subprocess instead of the IPython `!` shell escape so
# this also works as a plain Python script. This is best-effort: on hosts
# without an NVIDIA driver the binary may be missing, so report that rather
# than failing.
try:
    _smi = subprocess.run(
        ["nvidia-smi"], capture_output=True, text=True, check=False
    )
except FileNotFoundError:
    print("nvidia-smi: not found on PATH")
else:
    # nvidia-smi writes errors to stderr when no GPU/driver is usable.
    print(_smi.stdout or _smi.stderr, end="")

# Query JAX once and reuse the result. jax.devices() lists the devices of the
# default backend, which is not necessarily TPU (it can be CPU or GPU), so
# the labels below are derived from the reported platform.
devices = jax.devices()
print(devices)
backend = devices[0].platform
label = backend.upper()  # "TPU" on TPU hosts, matching the original wording
print(f"Detected {len(devices)} {label} device(s)")
print(f"Device type: {backend}")
print(f"{label} Kind: {devices[0].device_kind}")

print(f"Memory: {psutil.virtual_memory().percent}%")