Skip to content

Commit 64c0227

Browse files
committed
feat: fix bugs in normalize function (thanks to 'Ezio1018') and improve default REPS
1 parent 8923b2b commit 64c0227

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
get_rank_matrix
1717

1818

19-
REPS = 50
19+
REPS = 1000
2020
CONFIDENCE = 0.68
2121

2222

rlplot/plot_helpers.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,20 @@ def read_and_norm_algo_scores(
288288
algo: read_milestone_from_yaml(dir, algo, milestone)
289289
for algo in algos
290290
}
291+
292+
task_scores = defaultdict(list)
293+
for algo in algos:
294+
for task, scores in algo_scores[algo].items():
295+
task_scores[task] += scores
296+
291297
normalized_algo_scores = deepcopy(algo_scores)
292-
for algo in normalized_algo_scores:
293-
normalized_algo_scores[algo] = \
294-
{task: norm_func(task, scores)
295-
for task, scores in normalized_algo_scores[algo].items()}
296-
for algo, task_scores in algo_scores.items():
297-
for task, scores in task_scores.items():
298-
assert np.argmax(algo_scores[algo][task]) \
299-
== np.argmax(normalized_algo_scores[algo][task])
298+
for task, scores in task_scores.items():
299+
normalized_scores = norm_func(task, scores)
300+
num_runs = normalized_scores.shape[0] // len(algos)
301+
normalized_scores = \
302+
normalized_scores.reshape(len(algos), num_runs, -1).squeeze()
303+
for idx, algo in enumerate(normalized_algo_scores):
304+
normalized_algo_scores[algo][task] = normalized_scores[idx].tolist()
300305

301306
# num_runs * num_tasks
302307
algo_scores = \

0 commit comments

Comments
 (0)