@burbajr burbajr commented Dec 8, 2025

Description

For TPU7x devices, each physical chip contains 2 chiplets that are exposed to the host as separate devices. The previous implementation counted these devices directly, resulting in 2x the actual chip count being reported in logs.

Problem Being Solved
When running on tpu7x-8 hardware (which has 4 physical chips with 2 chiplets each), the system would log:
TPU info: ... | num_chips=8 | num_cores_per_chip=2

This was misleading because it reported 8 chips when there are only 4 physical chips. Each chip has 2 chiplets exposed as separate devices to the host.

Solution
This PR adds chiplet-aware logic to the chip counting mechanism:

  1. New helper function get_num_chiplets_per_chip() that returns 2 for tpu7x devices and 1 for all other TPU types
  2. Modified get_num_chips() to divide the device count by chiplets per chip (see the sketch after this list)
  3. Enhanced logging to show num_chiplets_per_chip for tpu7x devices only (to avoid confusing non-tpu7x users)
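
A minimal sketch of this first approach, assuming a hypothetical _count_host_devices() helper that stands in for however tpu_info already counts /dev/accel* or /dev/vfio devices (get_tpu_type() is the existing helper referenced later in this thread):

def get_num_chiplets_per_chip() -> int:
    # tpu7x exposes each physical chip to the host as 2 chiplet devices.
    tpu_type = get_tpu_type()
    if tpu_type and "tpu7x" in tpu_type.lower():
        return 2
    return 1

def get_num_chips() -> int:
    # _count_host_devices() is a placeholder for the existing device-count logic.
    num_devices = _count_host_devices()
    return num_devices // get_num_chiplets_per_chip()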

Why This is a Good Solution

  • Follows existing patterns: Uses the same pattern as get_num_cores_per_chip() for consistency
  • Backward compatible: All non-tpu7x devices get chiplets_per_chip=1, making the division a no-op
  • Minimal code changes: Only touches the necessary functions
  • Well tested: Comprehensive test coverage for all scenarios

Implementation Details
The fix uses integer division (//) which is safe because:

  • tpu7x-8: 8 devices // 2 chiplets = 4 chips ✓
  • tpu7x-4: 4 devices // 2 chiplets = 2 chips ✓
  • v5e/v6e: N devices // 1 chiplet = N chips ✓ (unchanged)

Logging is conditional - chiplet info only appears for tpu7x:
tpu7x-8 (after fix):
TPU info: ... | num_chips=4 | num_chiplets_per_chip=2 | num_cores_per_chip=2
v6e-8 (unchanged, no chiplet info to avoid confusion):
TPU info: ... | num_chips=8 | num_cores_per_chip=1
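
Roughly, the conditional logging could be assembled as in the sketch below (the surrounding function and logger setup are hypothetical; the field names mirror the examples above and the helpers are the ones sketched earlier):

import logging

logger = logging.getLogger(__name__)

def log_tpu_info() -> None:
    parts = [f"num_chips={get_num_chips()}"]
    if get_num_chiplets_per_chip() > 1:  # chiplet info only for tpu7x
        parts.append(f"num_chiplets_per_chip={get_num_chiplets_per_chip()}")
    parts.append(f"num_cores_per_chip={get_num_cores_per_chip()}")
    # "..." stands for the other fields already present in the log line.
    logger.info("TPU info: ... | %s", " | ".join(parts))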

Test Coverage Added

  • test_get_num_chiplets_per_chip: Tests all TPU types including tpu7x-8, tpu7x-4, v6e-8, v5litepod, and edge cases (None, empty string)
  • test_get_num_chips_tpu7x_from_accel: Verifies tpu7x-8 with 8 /dev/accel* devices returns 4 chips (a mocked-out sketch of this test style follows the list)
  • test_get_num_chips_tpu7x_4_from_accel: Verifies tpu7x-4 with 4 /dev/accel* devices returns 2 chips
  • test_get_num_chips_tpu7x_from_vfio: Verifies tpu7x-8 with /dev/vfio path returns 4 chips
  • test_get_num_chips_non_tpu7x_unchanged: Verifies v6e-8 still returns 8 chips (backward compatibility)
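
As a rough illustration of how such a test might stub out the hardware (the module path tpu_info, the _count_host_devices helper, and the patching style are assumptions for the sketch; the actual tests live in tests/test_tpu_info.py):

from unittest import mock

import tpu_info  # module path assumed from tests/test_tpu_info.py

def test_get_num_chips_tpu7x_from_accel_sketch():
    # Simulate a tpu7x-8 slice exposing 8 /dev/accel* devices to the host.
    with mock.patch.object(tpu_info, "get_tpu_type", return_value="tpu7x-8"), \
         mock.patch.object(tpu_info, "_count_host_devices", return_value=8):
        assert tpu_info.get_num_chips() == 4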

How to Test
pytest tests/test_tpu_info.py -v

Run specific new tests:
pytest tests/test_tpu_info.py::test_get_num_chiplets_per_chip -v
pytest tests/test_tpu_info.py::test_get_num_chips_tpu7x_from_accel -v
pytest tests/test_tpu_info.py::test_get_num_chips_non_tpu7x_unchanged -v
All tests pass with the changes.

@burbajr burbajr requested a review from vipannalla as a code owner December 8, 2025 23:17
@burbajr burbajr force-pushed the fix/tpu7x-chip-counting branch from 68301ae to e92f0a2 on December 8, 2025 23:21
For TPU7x devices, each physical chip contains 2 chiplets that are
exposed to the host as separate devices. This was causing
get_num_chips() to report double the actual chip count.

Example: tpu7x-8 has 4 physical chips (8 chiplets total)
- Before: reported num_chips=8 (incorrect)
- After: reports num_chips=4 (correct)

Implementation:
- Modified get_num_chips() to detect tpu7x devices and divide the
  device count by 2 using integer division
- Added test_get_num_chips_tpu7x to verify correct chip counting

All non-tpu7x devices are unaffected by this change.

Signed-off-by: burbajr <[email protected]>
@burbajr burbajr force-pushed the fix/tpu7x-chip-counting branch from f70ba2e to a9d4a49 on December 9, 2025 22:40
# For tpu7x, each chip has 2 chiplets exposed as separate devices
tpu_type = get_tpu_type()
if tpu_type and "tpu7x" in tpu_type.lower():
    return num_devices // 2

Collaborator

Not really a fan of hard-coding the value "2" or having tpu7x-specific logic.

Is there any logic within tpu_info that programmatically fetches num_devices?

Or wouldn't a simple len(jax.devices()) return the correct number of cores, regardless of having 1 chiplet or 2 chiplets?

@burbajr burbajr commented Dec 10, 2025

Hi, and thank you for taking a closer look at this.

The original logic (below) for determining the number of cores per TPU family remains correct for v7, which has 2 cores per chip, so no changes are needed there.

def get_num_cores_per_chip() -> int:
    tpu_type = get_tpu_type()
    if tpu_type.startswith(("v5litepod", "v6e")):
        return 1
    return 2

The issue is only with chip count inference. Prior to v7, each physical chip (regardless of core counts) exposed a single logical device, so using jax.devices() or the OS device list was a reliable way to infer the number of chips, since there was a 1:1 mapping between physical chips and logical devices. The original logic in tpu_info simply counts the number of host devices to report the number of chips.

With v7, each physical chip now exposes two logical devices, which means len(jax.devices()) and the OS device list report twice as many devices as there are actual chips, leading to the current reporting bug.

I was not able to identify a reliable programmatic way to retrieve the physical TPU topology on TPU/GKE nodes, so dividing the device count by two is the most accurate approach I have found for correcting the reported chip count.

burbajr commented Dec 9, 2025

Updated implementation based on feedback to use a simpler approach:

  • Removed get_num_chiplets_per_chip() helper function
  • Simplified get_num_chips() with inline tpu7x check (sketched below)
  • Reverted logging changes to original simple format
  • Kept single focused test for tpu7x chip counting
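
Putting the reviewed snippet and this summary together, the simplified get_num_chips() would look roughly like the following (again assuming a _count_host_devices() placeholder for the existing /dev/accel* and /dev/vfio counting):

def get_num_chips() -> int:
    num_devices = _count_host_devices()  # existing device-count logic, name assumed
    # For tpu7x, each chip has 2 chiplets exposed as separate devices.
    tpu_type = get_tpu_type()
    if tpu_type and "tpu7x" in tpu_type.lower():
        return num_devices // 2
    return num_devices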
