Fix TPU7x chip counting to account for chiplet architecture #1266
For TPU7x devices, each physical chip contains 2 chiplets that are exposed to the host as separate devices. This was causing get_num_chips() to report double the actual chip count.

Example: tpu7x-8 has 4 physical chips (8 chiplets total)
- Before: reported num_chips=8 (incorrect)
- After: reports num_chips=4 (correct)

Implementation:
- Modified get_num_chips() to detect tpu7x devices and divide the device count by 2 using integer division
- Added test_get_num_chips_tpu7x to verify correct chip counting

All non-tpu7x devices are unaffected by this change.

Signed-off-by: burbajr <[email protected]>
    # For tpu7x, each chip has 2 chiplets exposed as separate devices
    tpu_type = get_tpu_type()
    if tpu_type and "tpu7x" in tpu_type.lower():
        return num_devices // 2
Not really a fan of hard-coding the value "2" or having "tpu7x"-specific logic.
Is there any logic within tpu_info that programmatically fetches num_devices?
Or wouldn't a simple len(jax.devices()) return the correct number of cores, regardless of whether a chip has 1 chiplet or 2 chiplets?
Hi, and thank you for taking a closer look at this.
The original logic (below) for determining the number of cores per TPU family remains correct for v7, which has 2 cores per chip, so no changes are needed there.
    def get_num_cores_per_chip() -> int:
        tpu_type = get_tpu_type()
        if tpu_type.startswith(("v5litepod", "v6e")):
            return 1
        return 2
The issue is only with chip count inference. Prior to v7, each physical chip (regardless of core counts) exposed a single logical device, so using jax.devices() or the OS device list was a reliable way to infer the number of chips, since there was a 1:1 mapping between physical chips and logical devices. The original logic in tpu_info simply counts the number of host devices to report the number of chips.
With v7, each physical chip now exposes two logical devices, which means len(jax.devices()) and the OS device list report twice as many devices as there are actual chips, leading to the current reporting bug.
I was not able to identify a reliable programmatic way to retrieve the physical TPU topology on TPU/GKE nodes, so dividing the device count by two is the most accurate approach I have found for correcting the reported chip count.
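To make the mapping concrete, here is a minimal sketch of counting host devices and then correcting for tpu7x. It is not the tpu_info implementation: count_host_devices and its dev_root parameter are hypothetical, and treating numeric entries under vfio/ as devices is an assumption made for illustration. A temporary directory stands in for /dev so the example runs anywhere.

```python
import glob
import os
import tempfile


def count_host_devices(dev_root: str = "/dev") -> int:
    """Count TPU device nodes under dev_root (hypothetical helper)."""
    accel_nodes = glob.glob(os.path.join(dev_root, "accel*"))
    if accel_nodes:
        return len(accel_nodes)
    vfio_dir = os.path.join(dev_root, "vfio")
    if os.path.isdir(vfio_dir):
        # Assumption: each numeric entry under vfio/ is one device.
        return sum(1 for entry in os.listdir(vfio_dir) if entry.isdigit())
    return 0


# Simulate a tpu7x-8 host: 4 physical chips exposed as 8 device nodes.
with tempfile.TemporaryDirectory() as fake_dev:
    for i in range(8):
        open(os.path.join(fake_dev, f"accel{i}"), "w").close()
    num_devices = count_host_devices(fake_dev)

print(num_devices)       # 8 logical devices
print(num_devices // 2)  # 4 physical chips on tpu7x
```

Before v7 the first number would already be the chip count; on tpu7x only the second one is.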
Updated implementation based on feedback to use a simpler approach:
Description
For TPU7x devices, each physical chip contains 2 chiplets that are exposed to the host as separate devices. The previous implementation counted these devices directly, resulting in 2x the actual chip count being reported in logs.
Problem Being Solved
When running on tpu7x-8 hardware (which has 4 physical chips with 2 chiplets each), the system would log:
TPU info: ... | num_chips=8 | num_cores_per_chip=2
This was misleading because it reported 8 chips when there are only 4 physical chips. Each chip has 2 chiplets exposed as separate devices to the host.
Solution
This PR adds chiplet-aware logic to the chip counting mechanism:
- Adds get_num_chiplets_per_chip(), which returns 2 for tpu7x devices and 1 for all other TPU types
- Updates get_num_chips() to divide the device count by chiplets per chip
- Logs num_chiplets_per_chip for tpu7x devices only (to avoid confusing non-tpu7x users)
Why This is a Good Solution
- Follows the same pattern as get_num_cores_per_chip() for consistency
- Non-tpu7x devices use chiplets_per_chip=1, making the division a no-op
Implementation Details
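The fix uses integer division (//), which is safe because tpu7x exposes 2 devices per chip, so the device count is even, and every other TPU type divides by 1.
Logging is conditional - chiplet info only appears for tpu7x: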
tpu7x-8 (after fix):
TPU info: ... | num_chips=4 | num_chiplets_per_chip=2 | num_cores_per_chip=2
v6e-8 (unchanged, no chiplet info to avoid confusion):
TPU info: ... | num_chips=8 | num_cores_per_chip=1
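The chiplet-aware counting and conditional logging described above can be sketched as follows. This is an illustration under stated assumptions, not the code in this PR: the functions here take tpu_type and num_devices as explicit parameters so the block is self-contained, format_tpu_info is a hypothetical helper, and the log fields elided as "..." in the examples above are left out.

```python
from typing import Optional


def get_num_chiplets_per_chip(tpu_type: Optional[str]) -> int:
    """Return 2 for tpu7x, 1 for any other (or unknown) TPU type."""
    if tpu_type and "tpu7x" in tpu_type.lower():
        return 2
    return 1


def get_num_chips(num_devices: int, tpu_type: Optional[str]) -> int:
    """Convert a host device count into a physical chip count."""
    # For non-tpu7x types this divides by 1, so the count is unchanged.
    return num_devices // get_num_chiplets_per_chip(tpu_type)


def format_tpu_info(tpu_type: Optional[str], num_devices: int,
                    num_cores_per_chip: int) -> str:
    """Build the chip-related log fields (hypothetical helper)."""
    chiplets = get_num_chiplets_per_chip(tpu_type)
    fields = [f"num_chips={get_num_chips(num_devices, tpu_type)}"]
    if chiplets > 1:
        # Only mention chiplets where they exist, i.e. on tpu7x.
        fields.append(f"num_chiplets_per_chip={chiplets}")
    fields.append(f"num_cores_per_chip={num_cores_per_chip}")
    return " | ".join(fields)


print(format_tpu_info("tpu7x-8", 8, 2))
# num_chips=4 | num_chiplets_per_chip=2 | num_cores_per_chip=2
print(format_tpu_info("v6e-8", 8, 1))
# num_chips=8 | num_cores_per_chip=1
```

Keeping the chiplet factor in its own function means get_num_chips() carries no TPU-specific branch of its own.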
Test Coverage Added
- test_get_num_chiplets_per_chip: Tests all TPU types including tpu7x-8, tpu7x-4, v6e-8, v5litepod, and edge cases (None, empty string)
- test_get_num_chips_tpu7x_from_accel: Verifies tpu7x-8 with 8 /dev/accel* devices returns 4 chips
- test_get_num_chips_tpu7x_4_from_accel: Verifies tpu7x-4 with 4 /dev/accel* devices returns 2 chips
- test_get_num_chips_tpu7x_from_vfio: Verifies tpu7x-8 with /dev/vfio path returns 4 chips
- test_get_num_chips_non_tpu7x_unchanged: Verifies v6e-8 still returns 8 chips (backward compatibility)
How to Test
pytest tests/test_tpu_info.py -v
Run specific new tests
pytest tests/test_tpu_info.py::test_get_num_chiplets_per_chip -v
pytest tests/test_tpu_info.py::test_get_num_chips_tpu7x_from_accel -v
pytest tests/test_tpu_info.py::test_get_num_chips_non_tpu7x_unchanged -v
All tests pass with the changes.