@@ -36,6 +36,7 @@ limitations under the License.
36
36
#include " tensorflow/core/profiler/protobuf/steps_db.pb.h"
37
37
#include " tensorflow/core/profiler/protobuf/tf_function.pb.h"
38
38
#include " tensorflow/core/profiler/protobuf/xplane.pb.h"
39
+ #include " tensorflow/core/profiler/utils/device_caps_utils.h"
39
40
#include " tensorflow/core/profiler/utils/event_span.h"
40
41
#include " tensorflow/core/profiler/utils/hardware_type_utils.h"
41
42
#include " tensorflow/core/profiler/utils/kernel_stats_utils.h"
@@ -49,38 +50,6 @@ limitations under the License.
49
50
namespace tensorflow {
50
51
namespace profiler {
51
52
52
- DeviceCapabilities GetDeviceCapFromXPlane (const XPlane& device_plane) {
53
- DeviceCapabilities cap;
54
- XPlaneVisitor plane = CreateTfXPlaneVisitor (&device_plane);
55
- plane.ForEachStat ([&cap](const XStatVisitor& stat) {
56
- if (!stat.Type ().has_value ()) return ;
57
- switch (stat.Type ().value ()) {
58
- case kDevCapClockRateKHz :
59
- cap.set_clock_rate_in_ghz (stat.IntValue () / 1000000.0 );
60
- break ;
61
- case kDevCapCoreCount :
62
- cap.set_num_cores (stat.IntValue ());
63
- break ;
64
- case kDevCapMemoryBandwidth :
65
- cap.set_memory_bandwidth (stat.UintValue ()); // bytes/s
66
- break ;
67
- case kDevCapMemorySize :
68
- cap.set_memory_size_in_bytes (stat.UintValue ());
69
- break ;
70
- case kDevCapComputeCapMajor :
71
- cap.mutable_compute_capability ()->set_major (stat.IntValue ());
72
- break ;
73
- case kDevCapComputeCapMinor :
74
- cap.mutable_compute_capability ()->set_minor (stat.IntValue ());
75
- break ;
76
- case kDevVendor :
77
- cap.set_device_vendor (std::string (stat.StrOrRefValue ()));
78
- break ;
79
- }
80
- });
81
- return cap;
82
- }
83
-
84
53
PerfEnv MakePerfEnv (double peak_tera_flops_per_second,
85
54
double peak_hbm_bw_giga_bytes_per_second) {
86
55
PerfEnv result;
@@ -93,7 +62,7 @@ PerfEnv MakePerfEnv(double peak_tera_flops_per_second,
93
62
}
94
63
95
64
PerfEnv GetPerfEnvFromXPlane (const XPlane& device_plane) {
96
- DeviceCapabilities cap = GetDeviceCapFromXPlane (device_plane);
65
+ DeviceCapabilities cap = GetDeviceCaps (device_plane);
97
66
return MakePerfEnv (GetFlopMaxThroughputPerSM (cap) / 1000 * cap.num_cores (),
98
67
cap.memory_bandwidth () / 1e9 );
99
68
}
@@ -159,7 +128,7 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space,
159
128
op_metrics_db_combiner.Combine (device_op_metrics_db);
160
129
}
161
130
if (gpu_model.empty ()) {
162
- gpu_model = GpuModelName (GetDeviceCapFromXPlane (*device_trace));
131
+ gpu_model = GpuModelName (GetDeviceCaps (*device_trace));
163
132
}
164
133
if (options.generate_step_db ) {
165
134
StepEvents device_step_events =
0 commit comments