Commit 2fd3f5b

subhamsoni-google authored and copybara-github committed
Cache available tools for each run.
PiperOrigin-RevId: 834160987
1 parent fc53a4f commit 2fd3f5b

2 files changed: +397 -17 lines changed


plugin/xprof/profile_plugin.py

Lines changed: 250 additions & 17 deletions
@@ -24,11 +24,13 @@
 import logging
 import os
 import re
+import sys
 import threading
-from typing import Any, List, Optional, TypedDict
+from typing import Any, Dict, List, Optional, Sequence, TypedDict

 from etils import epath
 import etils.epath.backend
+from fsspec import core
 import six
 from werkzeug import wrappers

@@ -38,11 +40,19 @@
 from xprof.standalone.tensorboard_shim import plugin_asset_util
 from xprof.convert import _pywrap_profiler_plugin

-
-logger = logging.getLogger('tensorboard')
+logger = logging.getLogger('tensorboard.plugins.profile')
+logger.setLevel(logging.INFO)
+if not logger.handlers:
+  handler = logging.StreamHandler(sys.stderr)
+  formatter = logging.Formatter(
+      '%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s'
+  )
+  handler.setFormatter(formatter)
+  logger.addHandler(handler)
+  logger.propagate = False

 try:
-  import tensorflow.compat.v2 as tf  # pylint: disable=g-import-not-at-top
+  import tensorflow.compat.v2 as tf  # pylint: disable=g-import-not-at-top # pytype: disable=import-error

   tf.enable_v2_behavior()
 except ImportError:
@@ -399,6 +409,216 @@ def _get_bool_arg(
   return arg_str.lower() == 'true'


+class ToolsCache:
+  """Caches the list of tools for a profile run based on file content hashes or mtimes.
+
+  Attributes:
+    CACHE_FILE_NAME: The name of the cache file.
+    CACHE_VERSION: The version of the cache format.
+  """
+
+  CACHE_FILE_NAME = '.cached_tools.json'
+  CACHE_VERSION = 1
+
+  def __init__(self, profile_run_dir: epath.Path):
+    """Initializes the ToolsCache.
+
+    Args:
+      profile_run_dir: The directory containing the profile run data.
+    """
+    self._profile_run_dir = profile_run_dir
+    self._cache_file = self._profile_run_dir / self.CACHE_FILE_NAME
+    logger.info('ToolsCache initialized for %s', self._cache_file)
+
+  def _get_local_file_identifier(self, file_path_str: str) -> Optional[str]:
+    """Gets a string identifier for a local file.
+
+    The identifier is a combination of the file's last modification time (mtime)
+    and size, in the format "{mtime}-{size}".
+
+    Args:
+      file_path_str: The absolute path to the local file.
+
+    Returns:
+      A string identifier, or None if the file is not found or an error occurs.
+    """
+    try:
+      stat_result = os.stat(file_path_str)
+      return f'{int(stat_result.st_mtime)}-{stat_result.st_size}'
+    except FileNotFoundError:
+      logger.warning('Local file not found: %s', file_path_str)
+      return None
+    except OSError as e:
+      logger.error(
+          'OSError getting stat for local file %s: %s', file_path_str, e
+      )
+      return None
+
+  def _get_gcs_file_hash(self, file_path_str: str) -> Optional[str]:
+    """Gets the MD5 hash for a GCS file.
+
+    Args:
+      file_path_str: The GCS path (e.g., "gs://bucket/object").
+
+    Returns:
+      The MD5 hash string, or None if the file is not found or an error occurs.
+    """
+    try:
+      fs = core.get_fs_token_paths(file_path_str)[0]
+      info = fs.info(file_path_str)
+      md5_hash = info.get('md5Hash')
+
+      if not isinstance(md5_hash, str):
+        logger.warning(
+            'Could not find a valid md5Hash string in info for %s: %s',
+            file_path_str,
+            info,
+        )
+        return None
+
+      return md5_hash
+
+    except FileNotFoundError:
+      logger.warning('GCS path not found: %s', file_path_str)
+      return None
+    except IndexError:
+      logger.error('Could not get filesystem for GCS path: %s', file_path_str)
+      return None
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      logger.exception(
+          'Unexpected error getting hash for GCS path %s: %s', file_path_str, e
+      )
+      return None
+
+  def get_file_identifier(self, file_path_str: str) -> Optional[str]:
+    """Gets a string identifier for a file.
+
+    For GCS files, this is the MD5 hash.
+    For local files, this is a string combining mtime and size.
+
+    Args:
+      file_path_str: The full path to the file (local or GCS).
+
+    Returns:
+      A string identifier, or None if an error occurs.
+    """
+    if file_path_str.startswith('gs://'):
+      return self._get_gcs_file_hash(file_path_str)
+    else:
+      return self._get_local_file_identifier(file_path_str)
+
+  def _get_current_xplane_file_states(self) -> Optional[Dict[str, str]]:
+    """Gets the current state of XPlane files in the profile run directory.
+
+    Returns:
+      A dictionary mapping filename to a string identifier (hash or mtime-size),
+      or None if any file state cannot be determined.
+    """
+    try:
+      file_identifiers = {}
+      for xplane_file in self._profile_run_dir.glob(f"*.{TOOLS['xplane']}"):
+        file_id = self.get_file_identifier(str(xplane_file))
+        if file_id is None:
+          logger.warning(
+              'Could not get identifier for %s, cache will be invalidated.',
+              xplane_file,
+          )
+          return None
+        file_identifiers[xplane_file.name] = file_id
+      return file_identifiers
+    except OSError as e:
+      logger.warning('Could not glob files in %s: %s', self._profile_run_dir, e)
+      return None
+
+  def load(self) -> Optional[List[str]]:
+    """Loads the cached list of tools if the cache is valid.
+
+    The cache is valid if the cache file exists, the version matches, and
+    the file states (hashes/mtimes) of the XPlane files have not changed.
+
+    Returns:
+      A list of tool names if the cache is valid, otherwise None.
+    """
+    try:
+      with self._cache_file.open('r') as f:
+        cached_data = json.load(f)
+    except (OSError, json.JSONDecodeError) as e:
+      logger.warning(
+          'Error reading or decoding cache file %s: %s, invalidating.',
+          self._cache_file,
+          e,
+      )
+      self.invalidate()
+      return None
+
+    if cached_data.get('version') != self.CACHE_VERSION:
+      logger.info(
+          'ToolsCache invalid: version mismatch, expected %s, got %s.'
+          ' Invalidating %s',
+          self.CACHE_VERSION,
+          cached_data.get('version'),
+          self._cache_file,
+      )
+      self.invalidate()
+      return None
+
+    current_files = self._get_current_xplane_file_states()
+    if current_files is None:
+      logger.info(
+          'ToolsCache invalid: could not determine current file states.'
+          ' Invalidating %s',
+          self._cache_file,
+      )
+      self.invalidate()
+      return None
+
+    if cached_data.get('files') != current_files:
+      logger.info(
+          'ToolsCache invalid: file states differ. Invalidating %s',
+          self._cache_file,
+      )
+      self.invalidate()
+      return None
+
+    logger.info('ToolsCache hit: %s', self._cache_file)
+    return cached_data.get('tools')
+
+  def save(self, tools: Sequence[str]) -> None:
+    """Saves the list of tools and the current file states to the cache file.
+
+    Args:
+      tools: The list of tool names to cache.
+    """
+    current_files_for_cache = self._get_current_xplane_file_states()
+    if current_files_for_cache is None:
+      logger.warning(
+          'ToolsCache not saved: could not get file states %s', self._cache_file
+      )
+      return
+
+    new_cache_data = {
+        'version': self.CACHE_VERSION,
+        'files': current_files_for_cache,
+        'tools': tools,
+    }
+    try:
+      with self._cache_file.open('w') as f:
+        json.dump(new_cache_data, f, sort_keys=True, indent=2)
+      logger.info('ToolsCache saved: %s', self._cache_file)
+    except (OSError, TypeError) as e:
+      logger.error('Error writing cache file %s: %s', self._cache_file, e)
+
+  def invalidate(self) -> None:
+    """Deletes the cache file, forcing regeneration on the next load."""
+    try:
+      self._cache_file.unlink()
+      logger.info('ToolsCache invalidated: %s', self._cache_file)
+    except FileNotFoundError:
+      pass
+    except OSError as e:
+      logger.error('Error removing cache file %s: %s', self._cache_file, e)
+
+
 class _TfProfiler:
   """A helper class to encapsulate all TensorFlow-dependent profiler logic."""

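For orientation, here is the shape of the JSON payload that ToolsCache.save() writes to .cached_tools.json in the run directory, reconstructed from the class above. The filename, identifier value, and tool names are illustrative only; the xplane suffix comes from TOOLS['xplane'], and the identifiers come from get_file_identifier ("{mtime}-{size}" for local files, the object's md5Hash for GCS).

# Illustrative payload only: filenames, identifiers, and tools are made up.
example_cache_payload = {
    'version': 1,  # ToolsCache.CACHE_VERSION
    'files': {
        # One entry per "*.<TOOLS['xplane']>" file in the run directory.
        'host0.xplane.pb': '1761955200-52428800',  # local file: "{mtime}-{size}"
    },
    'tools': ['overview_page', 'trace_viewer'],
}
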
@@ -1256,19 +1476,32 @@ def generate_runs(self) -> Iterator[str]:

   def generate_tools_of_run(self, run: str) -> Iterator[str]:
     """Generate a list of tools given a certain run."""
-    profile_run_dir = self._run_to_profile_run_dir[run]
-    if epath.Path(profile_run_dir).is_dir():
-      try:
-        filenames = epath.Path(profile_run_dir).iterdir()
-      except OSError as e:
-        logger.warning('Cannot read asset directory: %s, NotFoundError %s',
-                       profile_run_dir, e)
-        filenames = []
-      if filenames:
-        for tool in self._get_active_tools(
-            [name.name for name in filenames], profile_run_dir
-        ):
-          yield tool
+    profile_run_dir = epath.Path(self._run_to_profile_run_dir[run])
+    cache = ToolsCache(profile_run_dir)
+
+    cached_tools = cache.load()
+
+    if cached_tools is not None:
+      for tool in cached_tools:
+        yield tool
+      return
+
+    # Cache is invalid or doesn't exist, regenerate
+    tools = []
+    try:
+      all_filenames = [f.name for f in profile_run_dir.iterdir()]
+    except OSError as e:
+      logger.warning(
+          'Cannot read asset directory: %s, Error %s', profile_run_dir, e
+      )
+      return tools
+
+    if all_filenames:
+      tools = self._get_active_tools(all_filenames, str(profile_run_dir))
+      cache.save(tools)
+
+    for tool in tools:
+      yield tool

   def _get_active_tools(self, filenames, profile_run_dir=''):
     """Get a list of tools available given the filenames created by profiler.

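As a rough sketch of how the new cache behaves end to end (not part of this commit), the snippet below exercises the load/save round trip against a local run directory. The import path, the '/tmp' directory, the '.xplane.pb' suffix (really TOOLS['xplane']), and the tool names are all assumptions for illustration.

# Minimal sketch under the assumptions stated above.
from etils import epath

from xprof.profile_plugin import ToolsCache  # import path is an assumption

run_dir = epath.Path('/tmp/profile_run')  # hypothetical run directory
run_dir.mkdir(parents=True, exist_ok=True)
(run_dir / 'host0.xplane.pb').write_text('x')  # stand-in xplane file

cache = ToolsCache(run_dir)
print(cache.load())  # None: no .cached_tools.json yet, so this is a miss.

cache.save(['overview_page', 'trace_viewer'])  # hypothetical tool list
print(cache.load())  # ['overview_page', 'trace_viewer']: file states match, hit.

# Rewriting an xplane file changes its "{mtime}-{size}" identifier, so the next
# load() sees a mismatch, deletes the cache file, and returns None.
(run_dir / 'host0.xplane.pb').write_text('xx')
print(cache.load())  # None: cache invalidated and removed.

On GCS paths the same comparison runs against each object's md5Hash instead of the local mtime-size pair, so re-uploading an xplane file likewise invalidates the cache.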