Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions hta/common/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def parse_trace_file(
meta, df, local_symbol_table = parse_trace_dataframe(trace_file_path, cfg)

# add fwd bwd links between CPU ops
add_fwd_bwd_links(df)
add_fwd_bwd_links(df, local_symbol_table)

df = transform_correlation_to_index(df, local_symbol_table)

Expand All @@ -274,9 +274,10 @@ def __call__(
return parse_trace_file(trace_file, self.cfg)


def add_fwd_bwd_links(df: pd.DataFrame) -> None:
def add_fwd_bwd_links(df: pd.DataFrame, symbol_table: TraceSymbolTable) -> None:
t0 = time.perf_counter()
if df.cat.eq("fwdbwd").sum() == 0:
fwdbwd_sym_id = symbol_table.get_sym_id_map().get("fwdbwd", None)
if df.cat.eq(fwdbwd_sym_id).sum() == 0:
return

# Initialize the fwdbwd columns to -1
Expand All @@ -285,14 +286,14 @@ def add_fwd_bwd_links(df: pd.DataFrame) -> None:
df["key"] = list(zip(df["ts"], df["tid"], df["pid"]))

# Get the fwdbwd events. Only the "id" and "key" columns are needed for merging.
df_fwdbwd = df.loc[df.cat.eq("fwdbwd")]
df_fwdbwd = df.loc[df.cat.eq(fwdbwd_sym_id)]
df_fwdbwd_start = df_fwdbwd.query("ph == 's'")[["id", "key"]]
df_fwdbwd_end = df_fwdbwd.query("ph == 'f' and bp == 'e'")[["id", "key"]]

# The "index" column for the cpu event will be used when merging with the fwdbwd events.
# The "key" column will be used for the merge.
df_cpu = df.loc[df.cat.eq("cpu_op")][["index", "key"]]

cpu_op_sym_id = symbol_table.get_sym_id_map().get("cpu_op", None)
df_cpu = df.loc[df.cat.eq(cpu_op_sym_id)][["index", "key"]]
# Merge the fwdbwd events with the cpu events.
# We will be using the index of last cpu event when multiple cpu events start from the same ts.
df_fwdbwd_start_events = (
Expand Down Expand Up @@ -322,6 +323,10 @@ def add_fwd_bwd_links(df: pd.DataFrame) -> None:
df.loc[start_indices, "fwdbwd"] = 0
df.loc[end_indices, "fwdbwd"] = 1
df.drop(columns=["key"], inplace=True)

df.dropna(axis=0, subset=["dur"], inplace=True)
columns_to_drop = {"ph", "id", "bp", "s"}.intersection(set(df.columns))
df.drop(list(columns_to_drop), axis=1, inplace=True)
t1 = time.perf_counter()
logger.debug(f"Time taken to add fwd_bwd links: {t1 - t0 :.2f} seconds")

Expand Down
5 changes: 1 addition & 4 deletions hta/common/trace_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,9 @@ def _compress_df(
cfg = cfg or ParserConfig.get_default_cfg()

# drop rows with null values
df.dropna(axis=0, subset=["dur", "cat"], inplace=True)
df.dropna(axis=0, subset=["cat"], inplace=True)
df.drop(df[df["cat"] == "Trace"].index, inplace=True)

# drop columns
columns_to_drop = {"ph", "id", "bp", "s"}.intersection(set(df.columns))
df.drop(list(columns_to_drop), axis=1, inplace=True)
columns = set(df.columns)

# performance counters appear as args
Expand Down
55 changes: 55 additions & 0 deletions tests/test_trace_fwd_bwd_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import unittest
from collections import namedtuple
from pathlib import Path

import pandas as pd
from hta.common.trace import Trace


class TraceFWDBWDLinkTestCase(unittest.TestCase):
    """Verify that trace parsing adds forward/backward correlation columns.

    Uses the bundled H100 trace fixture and checks that:
      * the ``fwdbwd_index`` link column exists in the parsed DataFrame,
      * the ``fwdbwd`` category symbol is registered in the symbol table,
      * each forward embedding-lookup op links to its expected backward op.
    """

    def setUp(self) -> None:
        super().setUp()
        # Resolve the fixture relative to this file so the test is
        # independent of the current working directory.
        test_data_path = Path(__file__).parent.parent.joinpath(
            "tests/data/h100/h100_trace.json"
        )
        self.trace = Trace(
            trace_files={0: str(test_data_path)},
            trace_dir="",
        )
        self.trace.parse_traces()
        # Decode symbol ids into readable names (adds the s_name column).
        self.trace.decode_symbol_ids(use_shorten_name=False)

    def test_fwdbwd_index_column(self):
        """The parsed trace DataFrame must expose the fwdbwd_index column."""
        self.assertIn(
            "fwdbwd_index",
            self.trace.get_trace(0).columns,
            "fwdbwd_index column not found in trace DataFrame",
        )

    def test_fwdbwd_symbol_and_id(self):
        """The 'fwdbwd' category must be registered in the symbol table."""
        self.assertIn(
            "fwdbwd",
            self.trace.symbol_table.sym_index,
            "fwdbwd symbol not found in trace symbol table",
        )

    def test_fwdbwd_correlation(self):
        """Every forward lookup op must link to the expected backward op."""
        df = self.trace.traces[0]
        fwd_func_name = "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function"
        expect_bwd_func_name = "torch::autograd::CppNode<SplitLookupFunction_rowwise_adagrad_Op>"
        # Exact equality instead of str.match with an unescaped pattern:
        # the op names contain regex metacharacters ('<', '>') and we want a
        # literal comparison anyway.
        fwd_rows = df[df["s_name"] == fwd_func_name]
        # Guard against a vacuous pass if the fixture has no matching ops.
        self.assertFalse(
            fwd_rows.empty, f"no rows found for forward op {fwd_func_name}"
        )
        for index, row in fwd_rows.iterrows():
            fwdbwd_type, bwd_id = row["fwdbwd"], row["fwdbwd_index"]
            self.assertEqual(fwdbwd_type, 0, "fwdbwd type should be 0")
            bwd_func_name = df.loc[bwd_id, "s_name"] if bwd_id in df.index else None
            # Fixed: expected/actual values were swapped in this message.
            self.assertEqual(
                bwd_func_name,
                expect_bwd_func_name,
                f"Expected bwd function name to be {expect_bwd_func_name}, "
                f"but got {bwd_func_name}",
            )
            bwd_type = df.loc[bwd_id, "fwdbwd"] if bwd_id in df.index else None
            self.assertEqual(
                bwd_type, 1, "bwd type should be 1 for backward function"
            )

# Allow running this test module directly: `python test_trace_fwd_bwd_correlation.py`.
if __name__ == "__main__":  # pragma: no cover
    unittest.main()