-
Notifications
You must be signed in to change notification settings - Fork 60
Fix TPU7x chip counting to account for chiplet architecture #1266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,13 +66,20 @@ def get_num_cores_per_chip() -> int: | |
| def get_num_chips() -> int: | ||
| accel_files = glob.glob("/dev/accel*") | ||
| if accel_files: | ||
| return len(accel_files) | ||
| try: | ||
| vfio_entries = os.listdir("/dev/vfio") | ||
| numeric_entries = [ | ||
| int(entry) for entry in vfio_entries if entry.isdigit() | ||
| ] | ||
| return len(numeric_entries) | ||
| except FileNotFoundError as e: | ||
| logger.error("Failed to detect number of TPUs: %s", e) | ||
| return 0 | ||
| num_devices = len(accel_files) | ||
| else: | ||
| try: | ||
| vfio_entries = os.listdir("/dev/vfio") | ||
| numeric_entries = [ | ||
| int(entry) for entry in vfio_entries if entry.isdigit() | ||
| ] | ||
| num_devices = len(numeric_entries) | ||
| except FileNotFoundError as e: | ||
| logger.error("Failed to detect number of TPUs: %s", e) | ||
| return 0 | ||
|
|
||
| # For tpu7x, each chip has 2 chiplets exposed as separate devices | ||
| tpu_type = get_tpu_type() | ||
| if tpu_type and "tpu7x" in tpu_type.lower(): | ||
| return num_devices // 2 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not really a fan of hard coding the value "2" or having "tpu7x" specific logic. Is there any logic within tpu_info that programatically fetches num_devices? Or wouldn't a simple
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, and thank you for taking a closer look at this. The original logic (below) for determining the number of cores per TPU family remains correct for v7, which has 2 cores per chip, so no changes are needed there. def get_num_cores_per_chip() -> int: The issue is only with chip count inference. Prior to v7, each physical chip (regardless of core counts) exposed a single logical device, so using jax.devices() or the OS device list was a reliable way to infer the number of chips, since there was a 1:1 mapping between physical chips and logical devices. The original logic in tpu_info simply counts the number of host devices to report the number of chips. With v7, each physical chip now exposes two logical devices, which means len(jax.devices()) and the OS device list report twice as many devices as there are actual chips, leading to the current reporting bug. I was not able to identify a reliable programmatic way to retrieve the physical TPU topology on TPU/GKE nodes, so dividing the device count by two is the most accurate approach I have found for correcting the reported chip count. |
||
| return num_devices | ||
Uh oh!
There was an error while loading. Please reload this page.