Skip to content

Commit e92f0a2

Browse files
committed
Fix TPU7x chip counting to account for chiplet architecture
For TPU7x devices, each physical chip contains 2 chiplets that are exposed to the host as separate devices. The previous implementation counted these devices directly, so it reported 2x the actual chip count.

Example of the issue:
- tpu7x-8 has 4 physical chips with 8 chiplets total
- Previous behavior: reported num_chips=8 (incorrect)
- Fixed behavior: reports num_chips=4 (correct)

Changes:
- Add get_num_chiplets_per_chip() helper function that returns 2 for tpu7x devices and 1 for all other TPU types, following the same pattern as get_num_cores_per_chip()
- Modify get_num_chips() to divide device count by chiplets_per_chip using integer division
- Update logging to conditionally show chiplets_per_chip for tpu7x devices only (non-tpu7x devices don't have chiplets, so we avoid showing confusing information)
- Add comprehensive tests for tpu7x chip counting scenarios
- Update existing get_num_chips() tests to mock get_tpu_type() since the function now calls it

Test coverage:
- test_get_num_chiplets_per_chip: Tests tpu7x variants, other TPU types, and edge cases (None, empty string)
- test_get_num_chips_tpu7x_from_accel: Tests tpu7x-8 with /dev/accel*
- test_get_num_chips_tpu7x_4_from_accel: Tests tpu7x-4 with /dev/accel*
- test_get_num_chips_tpu7x_from_vfio: Tests tpu7x-8 with /dev/vfio
- test_get_num_chips_non_tpu7x_unchanged: Verifies backward compatibility

Backward compatibility: All non-tpu7x devices return chiplets_per_chip=1, making the division a no-op. Behavior is identical to before the fix.

Signed-off-by: burbajr <[email protected]>
1 parent 0d4a4d1 commit e92f0a2

File tree

3 files changed

+122
-9
lines changed

3 files changed

+122
-9
lines changed

tests/test_tpu_info.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import requests
66

77
from tpu_inference.tpu_info import (get_node_name, get_node_worker_id,
8-
get_num_chips, get_num_cores_per_chip,
9-
get_tpu_metadata, get_tpu_type)
8+
get_num_chiplets_per_chip, get_num_chips,
9+
get_num_cores_per_chip, get_tpu_metadata,
10+
get_tpu_type)
1011

1112

1213
# Mock requests.get for get_tpu_metadata tests
@@ -98,23 +99,105 @@ def test_get_num_cores_per_chip(mock_get_tpu_type, tpu_type, expected):
9899
assert get_num_cores_per_chip() == expected
99100

100101

102+
# Test get_num_chiplets_per_chip
103+
@pytest.mark.parametrize(
    "tpu_type, expected",
    [
        # tpu7x variants report two chiplets per chip.
        ("tpu7x-8", 2),
        ("tpu7x-4", 2),
        ("TPU7x", 2),  # Matching ignores case
        # Every other TPU type reports a single chiplet per chip.
        ("v5litepod-4", 1),
        ("v6e-8", 1),
        ("v4-8", 1),
        # A missing or empty TPU type falls back to 1.
        (None, 1),
        ("", 1),
    ])
@patch("tpu_inference.tpu_info.get_tpu_type")
def test_get_num_chiplets_per_chip(mock_get_tpu_type, tpu_type, expected):
    """Check the chiplet count returned for each mocked TPU type."""
    mock_get_tpu_type.return_value = tpu_type
    chiplets = get_num_chiplets_per_chip()
    assert chiplets == expected
120+
121+
101122
# Test get_num_chips
123+
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="v6e-8")
@patch("tpu_inference.tpu_info.glob.glob",
       return_value=["/dev/accel0", "/dev/accel1"])
def test_get_num_chips_from_accel(mock_glob, mock_get_tpu_type):
    """Two /dev/accel* entries on a mocked v6e-8 host count as two chips."""
    num_chips = get_num_chips()
    assert num_chips == 2
107129

108130

131+
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="v6e-8")
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
@patch("tpu_inference.tpu_info.os.listdir", return_value=["0", "1", "2"])
def test_get_num_chips_from_vfio(mock_listdir, mock_glob, mock_get_tpu_type):
    """With no /dev/accel* matches, the /dev/vfio entries are counted instead."""
    num_chips = get_num_chips()
    assert num_chips == 3
114137

115138

139+
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="v6e-8")
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
@patch("tpu_inference.tpu_info.os.listdir", side_effect=FileNotFoundError)
def test_get_num_chips_not_found(mock_listdir, mock_glob, mock_get_tpu_type,
                                 caplog):
    """No /dev/accel* matches and a missing /dev/vfio directory yield 0."""
    num_chips = get_num_chips()
    assert num_chips == 0
146+
147+
148+
# Test get_num_chips with tpu7x
149+
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="tpu7x-8")
@patch("tpu_inference.tpu_info.glob.glob",
       return_value=[f"/dev/accel{index}" for index in range(8)])
def test_get_num_chips_tpu7x_from_accel(mock_glob, mock_get_tpu_type):
    """Eight /dev/accel* devices on tpu7x-8 are reported as four chips.

    Each tpu7x chip is exposed as 2 chiplet devices, so the 8 mocked
    devices must collapse to 4 physical chips rather than 8.
    """
    num_chips = get_num_chips()
    assert num_chips == 4
162+
163+
164+
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="tpu7x-4")
@patch("tpu_inference.tpu_info.glob.glob",
       return_value=[f"/dev/accel{index}" for index in range(4)])
def test_get_num_chips_tpu7x_4_from_accel(mock_glob, mock_get_tpu_type):
    """Four /dev/accel* devices on tpu7x-4 are reported as two chips.

    Each tpu7x chip is exposed as 2 chiplet devices, so the 4 mocked
    devices must collapse to 2 physical chips rather than 4.
    """
    num_chips = get_num_chips()
    assert num_chips == 2
175+
176+
177+
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="tpu7x-8")
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
@patch("tpu_inference.tpu_info.os.listdir",
       return_value=[str(index) for index in range(8)])
def test_get_num_chips_tpu7x_from_vfio(mock_listdir, mock_glob,
                                       mock_get_tpu_type):
    """Eight numeric /dev/vfio entries on tpu7x-8 are reported as four chips.

    With no /dev/accel* matches the vfio entries are counted; the 8 mocked
    entries (2 chiplets per chip) must collapse to 4 physical chips.
    """
    num_chips = get_num_chips()
    assert num_chips == 4
189+
190+
191+
@patch("tpu_inference.tpu_info.get_tpu_type", return_value="v6e-8")
@patch("tpu_inference.tpu_info.glob.glob",
       return_value=[f"/dev/accel{index}" for index in range(8)])
def test_get_num_chips_non_tpu7x_unchanged(mock_glob, mock_get_tpu_type):
    """Non-tpu7x hosts keep a one-to-one device-to-chip count.

    With the TPU type mocked as v6e-8 (1 chiplet per chip), the 8 mocked
    /dev/accel* devices must still be reported as 8 chips. This guards
    backward compatibility.
    """
    num_chips = get_num_chips()
    assert num_chips == 8

tpu_inference/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,15 @@
4242
else:
4343
# Either running on TPU or CPU
4444
try:
45+
tpu_type = ti.get_tpu_type()
46+
# For tpu7x, show chiplets_per_chip since each chip has 2 chiplets.
47+
# Other TPU types don't have chiplets, so we don't show this info.
48+
chiplets = f" | num_chiplets_per_chip={ti.get_num_chiplets_per_chip()}" if (
49+
tpu_type and "tpu7x" in tpu_type.lower()) else ""
4550
logger.info(f"TPU info: node_name={ti.get_node_name()} | "
46-
f"tpu_type={ti.get_tpu_type()} | "
51+
f"tpu_type={tpu_type} | "
4752
f"worker_id={ti.get_node_worker_id()} | "
48-
f"num_chips={ti.get_num_chips()} | "
53+
f"num_chips={ti.get_num_chips()}{chiplets} | "
4954
f"num_cores_per_chip={ti.get_num_cores_per_chip()}")
5055
except Exception as e:
5156
logger.error(

tpu_inference/tpu_info.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,41 @@ def get_num_cores_per_chip() -> int:
6363
return 2
6464

6565

66+
def get_num_chiplets_per_chip() -> int:
    """Return how many chiplets make up one physical chip on this TPU type.

    The value of ``get_tpu_type()`` is searched case-insensitively for the
    substring ``"tpu7x"``. On tpu7x each physical chip holds 2 chiplets that
    the host sees as separate devices; every other TPU type is treated as
    having 1 chiplet per chip.

    Returns:
        int: 2 when the TPU type contains "tpu7x", otherwise 1 (including
        when the TPU type is ``None`` or empty).
    """
    # Normalise a missing type to "" so the substring test below is safe.
    detected_type = get_tpu_type() or ""
    return 2 if "tpu7x" in detected_type.lower() else 1
80+
81+
6682
def get_num_chips() -> int:
    """Return the number of physical TPU chips available on the current host.

    Devices are discovered from ``/dev/accel*`` first; if nothing matches,
    the numeric entries under ``/dev/vfio`` are counted instead. The device
    count is then divided by ``get_num_chiplets_per_chip()`` because on tpu7x
    every physical chip is exposed to the host as 2 separate devices.

    Returns:
        int: Number of physical TPU chips. 0 if no ``/dev/accel*`` device
        exists and ``/dev/vfio`` is missing (an error is logged in that
        case).
    """
    chiplets_per_chip = get_num_chiplets_per_chip()

    accel_files = glob.glob("/dev/accel*")
    if accel_files:
        num_devices = len(accel_files)
    else:
        # Keep the try body limited to the single call that can raise.
        try:
            vfio_entries = os.listdir("/dev/vfio")
        except FileNotFoundError as e:
            logger.error("Failed to detect number of TPUs: %s", e)
            return 0
        # Only the numeric entries are counted; there is no need to convert
        # them to int just to take a length.
        num_devices = sum(1 for entry in vfio_entries if entry.isdigit())

    num_chips, leftover = divmod(num_devices, chiplets_per_chip)
    if leftover:
        # An uneven device count cannot be mapped onto whole chips; surface
        # it instead of silently rounding down.
        logger.warning(
            "Detected %d TPU devices, which is not a multiple of %d "
            "chiplets per chip; reporting %d chips.", num_devices,
            chiplets_per_chip, num_chips)
    return num_chips

0 commit comments

Comments
 (0)