|
24 | 24 | import logging |
25 | 25 | import os |
26 | 26 | import re |
| 27 | +import sys |
27 | 28 | import threading |
28 | | -from typing import Any, List, Optional, TypedDict |
| 29 | +from typing import Any, Dict, List, Optional, Sequence, TypedDict |
29 | 30 |
|
30 | 31 | from etils import epath |
31 | 32 | import etils.epath.backend |
| 33 | +from fsspec import core |
32 | 34 | import six |
33 | 35 | from werkzeug import wrappers |
34 | 36 |
|
|
38 | 40 | from xprof.standalone.tensorboard_shim import plugin_asset_util |
39 | 41 | from xprof.convert import _pywrap_profiler_plugin |
40 | 42 |
|
41 | | - |
42 | | -logger = logging.getLogger('tensorboard') |
# Dedicated logger for the profile plugin, pinned to INFO so cache
# hit/miss/invalidation messages below are visible in server logs.
logger = logging.getLogger('tensorboard.plugins.profile')
logger.setLevel(logging.INFO)
if not logger.handlers:
  # Attach a stderr handler only if none exists yet, so repeated imports of
  # this module do not install duplicate handlers (and duplicate log lines).
  handler = logging.StreamHandler(sys.stderr)
  formatter = logging.Formatter(
      '%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s'
  )
  handler.setFormatter(formatter)
  logger.addHandler(handler)
  # Stop records from bubbling up to ancestor loggers (e.g. the root
  # 'tensorboard' logger), which would emit each record a second time.
  logger.propagate = False
43 | 53 |
|
44 | 54 | try: |
45 | | - import tensorflow.compat.v2 as tf # pylint: disable=g-import-not-at-top |
| 55 | + import tensorflow.compat.v2 as tf # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
46 | 56 |
|
47 | 57 | tf.enable_v2_behavior() |
48 | 58 | except ImportError: |
@@ -399,6 +409,216 @@ def _get_bool_arg( |
399 | 409 | return arg_str.lower() == 'true' |
400 | 410 |
|
401 | 411 |
|
| 412 | +class ToolsCache: |
| 413 | + """Caches the list of tools for a profile run based on file content hashes or mtimes. |
| 414 | +
|
| 415 | + Attributes: |
| 416 | + CACHE_FILE_NAME: The name of the cache file. |
| 417 | + CACHE_VERSION: The version of the cache format. |
| 418 | + """ |
| 419 | + |
| 420 | + CACHE_FILE_NAME = '.cached_tools.json' |
| 421 | + CACHE_VERSION = 1 |
| 422 | + |
| 423 | + def __init__(self, profile_run_dir: epath.Path): |
| 424 | + """Initializes the ToolsCache. |
| 425 | +
|
| 426 | + Args: |
| 427 | + profile_run_dir: The directory containing the profile run data. |
| 428 | + """ |
| 429 | + self._profile_run_dir = profile_run_dir |
| 430 | + self._cache_file = self._profile_run_dir / self.CACHE_FILE_NAME |
| 431 | + logger.info('ToolsCache initialized for %s', self._cache_file) |
| 432 | + |
| 433 | + def _get_local_file_identifier(self, file_path_str: str) -> Optional[str]: |
| 434 | + """Gets a string identifier for a local file. |
| 435 | +
|
| 436 | + The identifier is a combination of the file's last modification time (mtime) |
| 437 | + and size, in the format "{mtime}-{size}". |
| 438 | +
|
| 439 | + Args: |
| 440 | + file_path_str: The absolute path to the local file. |
| 441 | +
|
| 442 | + Returns: |
| 443 | + A string identifier, or None if the file is not found or an error occurs. |
| 444 | + """ |
| 445 | + try: |
| 446 | + stat_result = os.stat(file_path_str) |
| 447 | + return f'{int(stat_result.st_mtime)}-{stat_result.st_size}' |
| 448 | + except FileNotFoundError: |
| 449 | + logger.warning('Local file not found: %s', file_path_str) |
| 450 | + return None |
| 451 | + except OSError as e: |
| 452 | + logger.error( |
| 453 | + 'OSError getting stat for local file %s: %s', file_path_str, e |
| 454 | + ) |
| 455 | + return None |
| 456 | + |
| 457 | + def _get_gcs_file_hash(self, file_path_str: str) -> Optional[str]: |
| 458 | + """Gets the MD5 hash for a GCS file. |
| 459 | +
|
| 460 | + Args: |
| 461 | + file_path_str: The GCS path (e.g., "gs://bucket/object"). |
| 462 | +
|
| 463 | + Returns: |
| 464 | + The MD5 hash string, or None if the file is not found or an error occurs. |
| 465 | + """ |
| 466 | + try: |
| 467 | + fs = core.get_fs_token_paths(file_path_str)[0] |
| 468 | + info = fs.info(file_path_str) |
| 469 | + md5_hash = info.get('md5Hash') |
| 470 | + |
| 471 | + if not isinstance(md5_hash, str): |
| 472 | + logger.warning( |
| 473 | + 'Could not find a valid md5Hash string in info for %s: %s', |
| 474 | + file_path_str, |
| 475 | + info, |
| 476 | + ) |
| 477 | + return None |
| 478 | + |
| 479 | + return md5_hash |
| 480 | + |
| 481 | + except FileNotFoundError: |
| 482 | + logger.warning('GCS path not found: %s', file_path_str) |
| 483 | + return None |
| 484 | + except IndexError: |
| 485 | + logger.error('Could not get filesystem for GCS path: %s', file_path_str) |
| 486 | + return None |
| 487 | + except Exception as e: # pylint: disable=broad-exception-caught |
| 488 | + logger.exception( |
| 489 | + 'Unexpected error getting hash for GCS path %s: %s', file_path_str, e |
| 490 | + ) |
| 491 | + return None |
| 492 | + |
| 493 | + def get_file_identifier(self, file_path_str: str) -> Optional[str]: |
| 494 | + """Gets a string identifier for a file. |
| 495 | +
|
| 496 | + For GCS files, this is the MD5 hash. |
| 497 | + For local files, this is a string combining mtime and size. |
| 498 | +
|
| 499 | + Args: |
| 500 | + file_path_str: The full path to the file (local or GCS). |
| 501 | +
|
| 502 | + Returns: |
| 503 | + A string identifier, or None if an error occurs. |
| 504 | + """ |
| 505 | + if file_path_str.startswith('gs://'): |
| 506 | + return self._get_gcs_file_hash(file_path_str) |
| 507 | + else: |
| 508 | + return self._get_local_file_identifier(file_path_str) |
| 509 | + |
| 510 | + def _get_current_xplane_file_states(self) -> Optional[Dict[str, str]]: |
| 511 | + """Gets the current state of XPlane files in the profile run directory. |
| 512 | +
|
| 513 | + Returns: |
| 514 | + A dictionary mapping filename to a string identifier (hash or mtime-size), |
| 515 | + or None if any file state cannot be determined. |
| 516 | + """ |
| 517 | + try: |
| 518 | + file_identifiers = {} |
| 519 | + for xplane_file in self._profile_run_dir.glob(f"*.{TOOLS['xplane']}"): |
| 520 | + file_id = self.get_file_identifier(str(xplane_file)) |
| 521 | + if file_id is None: |
| 522 | + logger.warning( |
| 523 | + 'Could not get identifier for %s, cache will be invalidated.', |
| 524 | + xplane_file, |
| 525 | + ) |
| 526 | + return None |
| 527 | + file_identifiers[xplane_file.name] = file_id |
| 528 | + return file_identifiers |
| 529 | + except OSError as e: |
| 530 | + logger.warning('Could not glob files in %s: %s', self._profile_run_dir, e) |
| 531 | + return None |
| 532 | + |
| 533 | + def load(self) -> Optional[List[str]]: |
| 534 | + """Loads the cached list of tools if the cache is valid. |
| 535 | +
|
| 536 | + The cache is valid if the cache file exists, the version matches, and |
| 537 | + the file states (hashes/mtimes) of the XPlane files have not changed. |
| 538 | +
|
| 539 | + Returns: |
| 540 | + A list of tool names if the cache is valid, otherwise None. |
| 541 | + """ |
| 542 | + try: |
| 543 | + with self._cache_file.open('r') as f: |
| 544 | + cached_data = json.load(f) |
| 545 | + except (OSError, json.JSONDecodeError) as e: |
| 546 | + logger.warning( |
| 547 | + 'Error reading or decoding cache file %s: %s, invalidating.', |
| 548 | + self._cache_file, |
| 549 | + e, |
| 550 | + ) |
| 551 | + self.invalidate() |
| 552 | + return None |
| 553 | + |
| 554 | + if cached_data.get('version') != self.CACHE_VERSION: |
| 555 | + logger.info( |
| 556 | + 'ToolsCache invalid: version mismatch, expected %s, got %s.' |
| 557 | + ' Invalidating %s', |
| 558 | + self.CACHE_VERSION, |
| 559 | + cached_data.get('version'), |
| 560 | + self._cache_file, |
| 561 | + ) |
| 562 | + self.invalidate() |
| 563 | + return None |
| 564 | + |
| 565 | + current_files = self._get_current_xplane_file_states() |
| 566 | + if current_files is None: |
| 567 | + logger.info( |
| 568 | + 'ToolsCache invalid: could not determine current file states.' |
| 569 | + ' Invalidating %s', |
| 570 | + self._cache_file, |
| 571 | + ) |
| 572 | + self.invalidate() |
| 573 | + return None |
| 574 | + |
| 575 | + if cached_data.get('files') != current_files: |
| 576 | + logger.info( |
| 577 | + 'ToolsCache invalid: file states differ. Invalidating %s', |
| 578 | + self._cache_file, |
| 579 | + ) |
| 580 | + self.invalidate() |
| 581 | + return None |
| 582 | + |
| 583 | + logger.info('ToolsCache hit: %s', self._cache_file) |
| 584 | + return cached_data.get('tools') |
| 585 | + |
| 586 | + def save(self, tools: Sequence[str]) -> None: |
| 587 | + """Saves the list of tools and the current file states to the cache file. |
| 588 | +
|
| 589 | + Args: |
| 590 | + tools: The list of tool names to cache. |
| 591 | + """ |
| 592 | + current_files_for_cache = self._get_current_xplane_file_states() |
| 593 | + if current_files_for_cache is None: |
| 594 | + logger.warning( |
| 595 | + 'ToolsCache not saved: could not get file states %s', self._cache_file |
| 596 | + ) |
| 597 | + return |
| 598 | + |
| 599 | + new_cache_data = { |
| 600 | + 'version': self.CACHE_VERSION, |
| 601 | + 'files': current_files_for_cache, |
| 602 | + 'tools': tools, |
| 603 | + } |
| 604 | + try: |
| 605 | + with self._cache_file.open('w') as f: |
| 606 | + json.dump(new_cache_data, f, sort_keys=True, indent=2) |
| 607 | + logger.info('ToolsCache saved: %s', self._cache_file) |
| 608 | + except (OSError, TypeError) as e: |
| 609 | + logger.error('Error writing cache file %s: %s', self._cache_file, e) |
| 610 | + |
| 611 | + def invalidate(self) -> None: |
| 612 | + """Deletes the cache file, forcing regeneration on the next load.""" |
| 613 | + try: |
| 614 | + self._cache_file.unlink() |
| 615 | + logger.info('ToolsCache invalidated: %s', self._cache_file) |
| 616 | + except FileNotFoundError: |
| 617 | + pass |
| 618 | + except OSError as e: |
| 619 | + logger.error('Error removing cache file %s: %s', self._cache_file, e) |
| 620 | + |
| 621 | + |
402 | 622 | class _TfProfiler: |
403 | 623 | """A helper class to encapsulate all TensorFlow-dependent profiler logic.""" |
404 | 624 |
|
@@ -1256,19 +1476,32 @@ def generate_runs(self) -> Iterator[str]: |
1256 | 1476 |
|
def generate_tools_of_run(self, run: str) -> Iterator[str]:
  """Generate a list of tools given a certain run.

  Tool discovery is cached per run directory via ToolsCache; a valid cache
  short-circuits the (potentially expensive) directory scan and
  _get_active_tools call.

  Args:
    run: The run name, as produced by generate_runs.

  Yields:
    Names of the tools available for this run.
  """
  profile_run_dir = epath.Path(self._run_to_profile_run_dir[run])
  cache = ToolsCache(profile_run_dir)

  cached_tools = cache.load()
  if cached_tools is not None:
    # Cache hit: serve the previously discovered tool list as-is.
    yield from cached_tools
    return

  # Cache is invalid or doesn't exist; regenerate from the directory listing.
  try:
    all_filenames = [f.name for f in profile_run_dir.iterdir()]
  except OSError as e:
    logger.warning(
        'Cannot read asset directory: %s, Error %s', profile_run_dir, e
    )
    # Bare return: this is a generator, so `return tools` would only set the
    # (unused) StopIteration value while reading like returning data.
    return

  tools: List[str] = []
  if all_filenames:
    # Materialize so the exact same concrete list is both saved to the cache
    # (JSON-serializable) and yielded below, even if _get_active_tools ever
    # returns a lazy iterable.
    tools = list(self._get_active_tools(all_filenames, str(profile_run_dir)))
    cache.save(tools)

  yield from tools
1272 | 1505 |
|
1273 | 1506 | def _get_active_tools(self, filenames, profile_run_dir=''): |
1274 | 1507 | """Get a list of tools available given the filenames created by profiler. |
|
0 commit comments