Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions tests/test_tpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,38 @@ def test_get_num_cores_per_chip(mock_get_tpu_type, tpu_type, expected):


# Test get_num_chips
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="v6e-8")
@patch("tpu_inference.tpu_info.glob.glob",
return_value=["/dev/accel0", "/dev/accel1"])
def test_get_num_chips_from_accel(mock_glob):
def test_get_num_chips_from_accel(mock_glob, mock_get_tpu_type):
"""Test get_num_chips when /dev/accel* files exist."""
assert get_num_chips() == 2


@patch("tpu_inference.tpu_info.get_tpu_type", return_value="v6e-8")
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
@patch("tpu_inference.tpu_info.os.listdir", return_value=["0", "1", "2"])
def test_get_num_chips_from_vfio(mock_listdir, mock_glob):
def test_get_num_chips_from_vfio(mock_listdir, mock_glob, mock_get_tpu_type):
"""Test get_num_chips when /dev/accel* files don't exist but /dev/vfio entries do."""
assert get_num_chips() == 3


@patch("tpu_inference.tpu_info.get_tpu_type", return_value="v6e-8")
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
@patch("tpu_inference.tpu_info.os.listdir", side_effect=FileNotFoundError)
def test_get_num_chips_not_found(mock_listdir, mock_glob, caplog):
def test_get_num_chips_not_found(mock_listdir, mock_glob, mock_get_tpu_type,
caplog):
"""Test get_num_chips when neither files nor directory are found."""
assert get_num_chips() == 0


# Test get_num_chips with tpu7x
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="tpu7x-8")
@patch(
    "tpu_inference.tpu_info.glob.glob",
    return_value=[f"/dev/accel{index}" for index in range(8)],
)
def test_get_num_chips_tpu7x(mock_glob, mock_get_tpu_type):
    """Check that 8 accel device nodes on a tpu7x host are reported as 4 chips."""
    # The glob patch yields /dev/accel0 .. /dev/accel7 and the TPU type is
    # patched to "tpu7x-8"; the expected chip count is half the device count.
    expected_chips = 4
    assert get_num_chips() == expected_chips
27 changes: 17 additions & 10 deletions tpu_inference/tpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,20 @@ def get_num_cores_per_chip() -> int:
def get_num_chips() -> int:
accel_files = glob.glob("/dev/accel*")
if accel_files:
return len(accel_files)
try:
vfio_entries = os.listdir("/dev/vfio")
numeric_entries = [
int(entry) for entry in vfio_entries if entry.isdigit()
]
return len(numeric_entries)
except FileNotFoundError as e:
logger.error("Failed to detect number of TPUs: %s", e)
return 0
num_devices = len(accel_files)
else:
try:
vfio_entries = os.listdir("/dev/vfio")
numeric_entries = [
int(entry) for entry in vfio_entries if entry.isdigit()
]
num_devices = len(numeric_entries)
except FileNotFoundError as e:
logger.error("Failed to detect number of TPUs: %s", e)
return 0

# For tpu7x, each chip has 2 chiplets exposed as separate devices
tpu_type = get_tpu_type()
if tpu_type and "tpu7x" in tpu_type.lower():
return num_devices // 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really a fan of hard-coding the value "2" or having "tpu7x"-specific logic.

Is there any logic within tpu_info that programmatically fetches num_devices?

Or wouldn't a simple len(jax.devices()) return the correct number of cores - regardless of having 1 chiplet or 2 chiplets?

Copy link
Author

@burbajr burbajr Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, and thank you for taking a closer look at this.

The original logic (below) for determining the number of cores per TPU family remains correct for v7, which has 2 cores per chip, so no changes are needed there.

def get_num_cores_per_chip() -> int:
tpu_type = get_tpu_type()
if tpu_type.startswith(("v5litepod", "v6e")):
return 1
return 2

The issue is only with chip count inference. Prior to v7, each physical chip (regardless of core count) exposed a single logical device, so using jax.devices() or the OS device list was a reliable way to infer the number of chips, since there was a 1:1 mapping between physical chips and logical devices. The original logic in tpu_info simply counts the number of host devices to report the number of chips.

With v7, each physical chip now exposes two logical devices, which means len(jax.devices()) and the OS device list report twice as many devices as there are actual chips, leading to the current reporting bug.

I was not able to identify a reliable programmatic way to retrieve the physical TPU topology on TPU/GKE nodes, so dividing the device count by two is the most accurate approach I have found for correcting the reported chip count.

return num_devices