Skip to content

Commit 8e134b8

Browse files
add class label option to write metric report to improve readability … (#7249)
add class label option to write metric report to improve readability, without that option in case of many classes the resulting report is very hard to interpret. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: elitap <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c300b36 commit 8e134b8

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

monai/handlers/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def write_metrics_reports(
6161
summary_ops: str | Sequence[str] | None,
6262
deli: str = ",",
6363
output_type: str = "csv",
64+
class_labels: list[str] | None = None,
6465
) -> None:
6566
"""
6667
Utility function to write the metrics into files, contains 3 parts:
@@ -94,6 +95,8 @@ class mean median max 5percentile 95percentile notnans
9495
deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
9596
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
9697
output_type: expected output file type, supported types: ["csv"], default to "csv".
98+
class_labels: list of class names used to name the classes in the output report, if None,
99+
"class0", ..., "classn" are used, default to None.
97100
98101
"""
99102
if output_type.lower() != "csv":
@@ -118,7 +121,12 @@ class mean median max 5percentile 95percentile notnans
118121
v = v.reshape((-1, 1))
119122

120123
# add the average value of all classes to v
121-
class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"]
124+
if class_labels is None:
125+
class_labels = ["class" + str(i) for i in range(v.shape[1])]
126+
else:
127+
class_labels = [str(i) for i in class_labels] # ensure to have a list of str
128+
129+
class_labels += ["mean"]
122130
v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1)
123131

124132
with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f:

0 commit comments

Comments
 (0)