Skip to content

Commit e8ad72b

Browse files
committed
Move norm_func into main.py to make custom norm-function design easier
1 parent 2b20a72 commit e8ad72b

File tree

5 files changed

+79
-51
lines changed

5 files changed

+79
-51
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ tests/
55
exps_cached/
66
diagnosis/
77
figs/
8+
exps/
89
__pycache__/
910

1011
# C extensions

Makefile

+9
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,12 @@ imporve:
2222

2323
efficiency:
2424
python main.py type=sample_efficiency_curve
25+
26+
## Delete all compiled Python files
27+
clean:
28+
find . -type f -name "*.py[co]" -delete
29+
find . -type d -name "__pycache__" -delete
30+
31+
## Lint using flake8
32+
lint:
33+
flake8 src

main.py

+40-6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@
2020
CONFIDENCE = 0.68
2121

2222

23+
def random_score_norm_func(task: str, scores: List):
24+
random_score = {
25+
'HalfCheetah-v4': -290.0479832104089,
26+
'Ant-v4': -55.14243068976598,
27+
'Walker2d-v4': 2.5912887180069686,
28+
'Humanoid-v4': 120.45141735893694
29+
}
30+
scores = np.array(scores)
31+
nume = scores - random_score[task]
32+
deno = np.max(scores) - random_score[task]
33+
return nume / deno
34+
35+
2336
def create_diagnosis(
2437
n_epoch: int = 200,
2538
epoch_len: int = 5000,
@@ -103,7 +116,11 @@ def OG(metric_val): return \
103116
save_fig(fig, f'metric_curve_{task.lower()}', fig_dir)
104117

105118
algo_scores, normalized_algo_scores = \
106-
read_and_norm_algo_scores(diagnosis_dir, algos, 'all')
119+
read_and_norm_algo_scores(
120+
diagnosis_dir, algos,
121+
milestone='all',
122+
norm_func=random_score_norm_func
123+
)
107124

108125
scores, cis = \
109126
rly.get_interval_estimates(
@@ -131,7 +148,10 @@ def metric_value(
131148
):
132149

133150
algo_scores, normalized_algo_scores = \
134-
read_and_norm_algo_scores(diagnosis_dir, algos, milestone)
151+
read_and_norm_algo_scores(
152+
diagnosis_dir, algos, milestone,
153+
norm_func=random_score_norm_func,
154+
)
135155

136156
aggregate_func_mapper = {
137157
'Mean': metrics.aggregate_mean,
@@ -180,7 +200,11 @@ def performance_profiles(
180200
for i, milestone in enumerate(milestones):
181201

182202
algo_scores, normalized_algo_scores = \
183-
read_and_norm_algo_scores(diagnosis_dir, algos, milestone)
203+
read_and_norm_algo_scores(
204+
diagnosis_dir, algos, milestone,
205+
norm_func=random_score_norm_func
206+
)
207+
184208
perf_prof, perf_prof_cis = \
185209
rly.create_performance_profile(
186210
normalized_algo_scores, tau,
@@ -220,7 +244,10 @@ def probability_of_improvement(
220244
**kwargs,
221245
):
222246
algo_scores, normalized_algo_scores = \
223-
read_and_norm_algo_scores(diagnosis_dir, algos, milestone)
247+
read_and_norm_algo_scores(
248+
diagnosis_dir, algos, milestone,
249+
norm_func=random_score_norm_func
250+
)
224251

225252
pairs = generate_pairs(algos)
226253

@@ -250,7 +277,11 @@ def sample_efficiency_curve(
250277
**kwargs,
251278
):
252279
algo_scores, normalized_algo_scores = \
253-
read_and_norm_algo_scores(diagnosis_dir, algos, 'all')
280+
read_and_norm_algo_scores(
281+
diagnosis_dir, algos, 'all',
282+
norm_func=random_score_norm_func
283+
)
284+
254285
steps = np.array(steps) - 1
255286
normalized_algo_steps_scores_dict = {algo: scores[:, :, steps] for algo, scores
256287
in normalized_algo_scores.items()}
@@ -294,7 +325,10 @@ def overall_ranks(
294325

295326
for i, milestone in enumerate(milestones):
296327
algo_scores, normalized_algo_scores = \
297-
read_and_norm_algo_scores(diagnosis_dir, algos, milestone)
328+
read_and_norm_algo_scores(
329+
diagnosis_dir, algos, milestone,
330+
norm_func=random_score_norm_func
331+
)
298332

299333
# num_task * (num_algo * num_algo)
300334
rank_matrix = \

rlplot/plot_helpers.py

+2-45
Original file line numberDiff line numberDiff line change
@@ -262,49 +262,6 @@ def generate_pairs(elements):
262262
return pairs[::-1]
263263

264264

265-
random_score = {
266-
267-
'HalfCheetah-v4': -290.0479832104089,
268-
'Ant-v4': -55.14243068976598,
269-
'Walker2d-v4': 2.5912887180069686,
270-
'Humanoid-v4': 120.45141735893694,
271-
272-
}
273-
274-
275-
def random_score_norm_func(task: str, scores: List):
276-
scores = np.array(scores)
277-
nume = scores - random_score[task]
278-
deno = np.max(scores) - random_score[task]
279-
return nume / deno
280-
281-
282-
def normalized_scores(
283-
task: str,
284-
scores: Union[np.ndarray, List],
285-
norm_func: Callable
286-
):
287-
algos = list(scores.keys())
288-
envs = list(scores[algos[0]].keys())
289-
num_runs = scores[algos[0]][envs[0]].shape[0]
290-
env_scores = {env: [] for env in envs}
291-
for algo in algos:
292-
for env in envs:
293-
env_scores[env] += scores[algo][env].tolist()
294-
normalized_env_scores = {}
295-
for env in envs:
296-
normalized_env_scores[env] = norm_func(env, env_scores[env])
297-
normalized_scores = {}
298-
start, end = 0, num_runs
299-
for algo in algos:
300-
normalized_scores[algo] = {}
301-
for env in envs:
302-
normalized_scores[algo][env] = normalized_env_scores[env][start:end]
303-
start += num_runs
304-
end += num_runs
305-
return normalized_scores
306-
307-
308265
def convert_to_matrix(score_dict, sort=False):
309266
if sort:
310267
keys = sorted(list(score_dict.keys()))
@@ -324,8 +281,8 @@ def read_milestone_from_yaml(
324281

325282

326283
def read_and_norm_algo_scores(
327-
dir, algos, milestone='1m',
328-
norm_func=random_score_norm_func,
284+
dir, algos, milestone: str,
285+
norm_func: Callable,
329286
):
330287
algo_scores = {
331288
algo: read_milestone_from_yaml(dir, algo, milestone)

rlplot/plot_utils.py

+27
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
# yanked and modified from https://github.com/google-research/rliable/blob/master/rliable/plot_utils.py
1010

11+
1112
def _non_linear_scaling(performance_profiles,
1213
tau_list,
1314
xticklabels=None,
@@ -714,3 +715,29 @@ def plot_overall_ranks(
714715
# # fig.subplots_adjust(hspace=0.25)
715716
#
716717
# save_fig(fig, save_name, save_dir)
718+
719+
720+
# def normalized_scores(
721+
# task: str,
722+
# scores: Union[np.ndarray, List],
723+
# norm_func: Callable
724+
# ):
725+
# algos = list(scores.keys())
726+
# envs = list(scores[algos[0]].keys())
727+
# num_runs = scores[algos[0]][envs[0]].shape[0]
728+
# env_scores = {env: [] for env in envs}
729+
# for algo in algos:
730+
# for env in envs:
731+
# env_scores[env] += scores[algo][env].tolist()
732+
# normalized_env_scores = {}
733+
# for env in envs:
734+
# normalized_env_scores[env] = norm_func(env, env_scores[env])
735+
# normalized_scores = {}
736+
# start, end = 0, num_runs
737+
# for algo in algos:
738+
# normalized_scores[algo] = {}
739+
# for env in envs:
740+
# normalized_scores[algo][env] = normalized_env_scores[env][start:end]
741+
# start += num_runs
742+
# end += num_runs
743+
# return normalized_scores

0 commit comments

Comments
 (0)