|
1 | | -import ctypes |
2 | | -import getpass |
3 | 1 | import logging |
4 | 2 | import os |
5 | | -import platform |
6 | | -import tempfile |
7 | | -import urllib.request |
8 | | -from pathlib import Path |
9 | | -from typing import Optional |
10 | 3 |
|
11 | 4 | import torch |
12 | 5 | from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh |
13 | | -from torch_tensorrt._version import __tensorrt_llm_version__ |
14 | | - |
15 | | -_WHL_CPYTHON_VERSION = "cp310" |
16 | 6 |
|
17 | 7 | logger = logging.getLogger(__name__) |
18 | 8 |
|
@@ -42,261 +32,10 @@ def get_tensor_parallel_device_mesh(
42 | 32 | return device_mesh, world_size, rank |
43 | 33 |
|
44 | 34 |
|
45 | | -def initialize_logger(rank: int, logger_file_name: str) -> logging.Logger: |
| 35 | +def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger: |
46 | 36 | logger = logging.getLogger() |
47 | 37 | logger.setLevel(logging.INFO) |
48 | 38 | fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") |
49 | 39 | fh.setLevel(logging.INFO) |
50 | 40 | logger.addHandler(fh) |
51 | 41 | return logger |
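
A minimal usage sketch of the two helpers this hunk keeps; the import path is an assumption, since the module's location is not shown in this diff:

    # Hypothetical import path for this module; adjust to its real location.
    from distributed_utils import (
        get_tensor_parallel_device_mesh,
        initialize_distributed_logger,
    )

    # Build the tensor-parallel device mesh and a per-rank file logger.
    device_mesh, world_size, rank = get_tensor_parallel_device_mesh()
    logger = initialize_distributed_logger(rank, "tensor_parallel")
    logger.info(f"rank {rank}/{world_size} initialized with mesh {device_mesh}")
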
52 | | - |
53 | | - |
54 | | -def is_platform_supported_for_trtllm() -> bool: |
55 | | - """ |
56 | | - Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend. |
57 | | -
|
58 | | - Returns: |
59 | | - bool: True if supported, False otherwise. |
60 | | -
|
61 | | - Unsupported: |
62 | | - - Windows platforms |
63 | | - - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release) |
64 | | -    - CUDA versions other than 12.x |
65 | | - """ |
66 | | - system = platform.system().lower() |
67 | | - machine = platform.machine().lower() |
68 | | - release = platform.release().lower() |
69 | | - |
70 | | - if "windows" in system: |
71 | | - logger.info( |
72 | | - "TensorRT-LLM plugins for NCCL backend are not supported on Windows." |
73 | | - ) |
74 | | - return False |
75 | | - |
76 | | - if machine == "aarch64" and "tegra" in release: |
77 | | - logger.info( |
78 | | - "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices." |
79 | | - ) |
80 | | - return False |
81 | | - |
82 | | - try: |
83 | | - cuda_version = torch.version.cuda # e.g., "12.4" or "13.0" |
84 | | - if cuda_version is None: |
85 | | - logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.") |
86 | | - return False |
87 | | - |
88 | | -        major = int(cuda_version.split(".")[0]) |
89 | | -        if major != 12: |
90 | | -            logger.warning(f"CUDA {cuda_version} is not supported for TRT-LLM plugins; CUDA 12.x is required.") |
91 | | -            return False |
92 | | - |
93 | | - return True |
94 | | - |
95 | | - except Exception as e: |
96 | | - logger.warning(f"Failed to detect CUDA version: {e}") |
97 | | - return False |
100 | | - |
101 | | - |
102 | | -def _cache_root() -> Path: |
103 | | - username = getpass.getuser() |
104 | | - return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" |
105 | | - |
106 | | - |
107 | | -def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: |
108 | | - return ( |
109 | | - _cache_root() |
110 | | - / "trtllm" |
111 | | - / f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}" |
112 | | - ) |
113 | | - |
114 | | - |
115 | | -def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None: |
116 | | - from torch.distributed import barrier, get_rank, is_initialized |
117 | | - |
118 | | - if not is_initialized(): |
119 | | - # Single process case, just unzip |
120 | | - is_master = True |
121 | | - else: |
122 | | - is_master = get_rank() == 0 # only rank 0 does the unzip |
123 | | - |
124 | | - if is_master: |
125 | | - try: |
126 | | - import zipfile |
127 | | -        except ImportError as e: |
128 | | -            raise ImportError( |
129 | | -                "The standard-library zipfile module could not be imported; the Python installation may be incomplete" |
130 | | -            ) from e |
131 | | - try: |
132 | | - with zipfile.ZipFile(wheel_path) as zip_ref: |
133 | | - zip_ref.extractall(extract_dir) |
134 | | - logger.debug(f"Extracted wheel to {extract_dir}") |
135 | | - |
136 | | - except FileNotFoundError as e: |
137 | | -            # Raised when the wheel file is missing, e.g. because the download above failed |
138 | | - logger.error(f"Wheel file not found at {wheel_path}: {e}") |
139 | | - raise RuntimeError( |
140 | | - f"Failed to find downloaded wheel file at {wheel_path}" |
141 | | - ) from e |
142 | | - except zipfile.BadZipFile as e: |
143 | | - logger.error(f"Invalid or corrupted wheel file: {e}") |
144 | | - raise RuntimeError( |
145 | | - "Downloaded wheel file is corrupted or not a valid zip archive" |
146 | | - ) from e |
147 | | - except Exception as e: |
148 | | - logger.error(f"Unexpected error while extracting wheel: {e}") |
149 | | - raise RuntimeError( |
150 | | - "Unexpected error during extraction of TensorRT-LLM wheel" |
151 | | - ) from e |
152 | | - |
153 | | -    # Make sure other ranks wait until rank 0 finishes unzipping |
154 | | - if is_initialized(): |
155 | | - barrier() |
156 | | - |
157 | | - |
158 | | -def download_and_get_plugin_lib_path() -> Optional[str]: |
159 | | - """ |
160 | | -    Returns the path to the TensorRT-LLM shared library, downloading and extracting it if necessary. |
161 | | - |
165 | | - Returns: |
166 | | - Optional[str]: Path to shared library or None if operation fails. |
167 | | - """ |
168 | | - platform_system = platform.system().lower() |
169 | | - platform_machine = platform.machine().lower() |
170 | | - wheel_filename = ( |
171 | | - f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-" |
172 | | - f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl" |
173 | | - ) |
174 | | - wheel_path = _cache_root() / wheel_filename |
175 | | - extract_dir = _extracted_dir_trtllm(platform_system, platform_machine) |
176 | | -    # Note: the .dll branch below is unreachable in practice, since Windows is rejected by is_platform_supported_for_trtllm() |
177 | | - lib_filename = ( |
178 | | - "libnvinfer_plugin_tensorrt_llm.so" |
179 | | - if "linux" in platform_system |
180 | | - else "libnvinfer_plugin_tensorrt_llm.dll" |
181 | | - ) |
182 | | - # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so |
183 | | - plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename |
184 | | - |
185 | | - if plugin_lib_path.exists(): |
186 | | - return str(plugin_lib_path) |
187 | | - |
188 | | - wheel_path.parent.mkdir(parents=True, exist_ok=True) |
189 | | - extract_dir.mkdir(parents=True, exist_ok=True) |
190 | | - |
191 | | - if not wheel_path.exists(): |
192 | | - base_url = "https://pypi.nvidia.com/tensorrt-llm/" |
193 | | - download_url = base_url + wheel_filename |
194 | | - try: |
195 | | - logger.debug(f"Downloading {download_url} ...") |
196 | | - urllib.request.urlretrieve(download_url, wheel_path) |
197 | | - logger.debug("Download succeeded and TRT-LLM wheel is now present") |
198 | | - except urllib.error.HTTPError as e: |
199 | | - logger.error( |
200 | | - f"HTTP error {e.code} when trying to download {download_url}: {e.reason}" |
201 | | - ) |
202 | | - except urllib.error.URLError as e: |
203 | | - logger.error( |
204 | | - f"URL error when trying to download {download_url}: {e.reason}" |
205 | | - ) |
206 | | - except OSError as e: |
207 | | - logger.error(f"Local file write error: {e}") |
208 | | - |
209 | | - extract_wheel_file(wheel_path, extract_dir) |
210 | | - |
211 | | - try: |
212 | | - wheel_path.unlink(missing_ok=True) |
213 | | - logger.debug(f"Deleted wheel file: {wheel_path}") |
214 | | - except Exception as e: |
215 | | - logger.warning(f"Could not delete wheel file {wheel_path}: {e}") |
216 | | - if not plugin_lib_path.exists(): |
217 | | - logger.error( |
218 | | - f"Plugin library not found at expected location: {plugin_lib_path}" |
219 | | - ) |
220 | | - return None |
221 | | - |
222 | | - return str(plugin_lib_path) |
223 | | - |
224 | | - |
225 | | -def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: |
226 | | - """ |
227 | | - Loads and initializes the TensorRT-LLM plugin from the given shared library path. |
228 | | -
|
229 | | - Args: |
230 | | - plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. |
231 | | -
|
232 | | - Returns: |
233 | | - bool: True if successful, False otherwise. |
234 | | - """ |
235 | | - try: |
236 | | - handle = ctypes.CDLL(plugin_lib_path) |
237 | | - logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") |
238 | | - except OSError as e_os_error: |
239 | | - if "libmpi" in str(e_os_error): |
240 | | - logger.warning( |
241 | | - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)", |
242 | | - exc_info=e_os_error, |
243 | | - ) |
244 | | - else: |
245 | | - logger.warning( |
246 | | - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " |
247 | | - f"Ensure the path is correct and the library is compatible.", |
248 | | - exc_info=e_os_error, |
249 | | - ) |
250 | | - return False |
251 | | - |
252 | | - try: |
253 | | - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] |
254 | | - handle.initTrtLlmPlugins.restype = ctypes.c_bool |
255 | | - except AttributeError as e_plugin_unavailable: |
256 | | - logger.warning( |
257 | | - "Unable to initialize the TensorRT-LLM plugin library", |
258 | | - exc_info=e_plugin_unavailable, |
259 | | - ) |
260 | | - return False |
261 | | - |
262 | | - try: |
263 | | - if handle.initTrtLlmPlugins(None, b"tensorrt_llm"): |
264 | | - logger.info("TensorRT-LLM plugin successfully initialized") |
265 | | - return True |
266 | | - else: |
267 | | - logger.warning("TensorRT-LLM plugin library failed in initialization") |
268 | | - return False |
269 | | - except Exception as e_initialization_error: |
270 | | - logger.warning( |
271 | | - "Exception occurred during TensorRT-LLM plugin library initialization", |
272 | | - exc_info=e_initialization_error, |
273 | | - ) |
274 | | - return False |
276 | | - |
277 | | - |
278 | | -def load_tensorrt_llm_for_nccl() -> bool: |
279 | | - """ |
280 | | -    Attempts to load the TensorRT-LLM plugin and initialize it. |
281 | | -    The env variable TRTLLM_PLUGINS_PATH can specify the path to the shared library; |
282 | | -    alternatively, set USE_TRTLLM_PLUGINS to one of (1, true, yes, on) to download the TRT-LLM distribution and load it. |
283 | | -
|
284 | | - Returns: |
285 | | - bool: True if the plugin was successfully loaded and initialized, False otherwise. |
286 | | - """ |
287 | | - if not is_platform_supported_for_trtllm(): |
288 | | - return False |
289 | | - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") |
290 | | - |
291 | | - if plugin_lib_path: |
292 | | - return load_and_initialize_trtllm_plugin(plugin_lib_path) |
293 | | - else: |
294 | | -        # Fallback when TRTLLM_PLUGINS_PATH is not set: the user can opt in to downloading the wheel |
295 | | - use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( |
296 | | - "1", |
297 | | - "true", |
298 | | - "yes", |
299 | | - "on", |
300 | | - ) |
301 | | - if not use_trtllm_plugin: |
302 | | - logger.warning( |
303 | | -                "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled. Set one of the two to use TRT-LLM plugins in Torch-TensorRT" |
304 | | - ) |
305 | | - return False |
306 | | - |
307 | | - plugin_lib_path = download_and_get_plugin_lib_path() |
308 | | -        return plugin_lib_path is not None and load_and_initialize_trtllm_plugin(plugin_lib_path) |
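
With the downloader removed, a minimal sketch of loading the plugin manually from a user-supplied path; the ctypes usage and the initTrtLlmPlugins signature are taken from the deleted loader above, while the path handling is an assumption:

    import ctypes
    import os

    # Path to libnvinfer_plugin_tensorrt_llm.so, supplied by the user
    # (e.g. via the TRTLLM_PLUGINS_PATH env variable, as in the deleted code).
    plugin_lib_path = os.environ["TRTLLM_PLUGINS_PATH"]
    handle = ctypes.CDLL(plugin_lib_path)

    # Entry point exported by the TRT-LLM plugin library, per the deleted loader.
    handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    handle.initTrtLlmPlugins.restype = ctypes.c_bool
    ok = handle.initTrtLlmPlugins(None, b"tensorrt_llm")
    print("TensorRT-LLM plugins initialized:", ok)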