Fix: property interface refactor #16

Open · wants to merge 4 commits into base: main
.gitignore (1 addition, 1 deletion)
@@ -7,4 +7,4 @@ __pycache__
 .env
 *.jpg
 testood/
-finetune/
+finetune*/
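
Note: widening finetune/ to finetune*/ makes Git ignore every top-level directory whose name begins with finetune, not just finetune/ itself. A minimal Python sketch of the intended matching, using fnmatch as an approximation of gitignore globbing (the candidate directory names are hypothetical):

from fnmatch import fnmatch

# Hypothetical directory names that might appear in the repo root.
candidates = ["finetune", "finetune_lr1e-5", "finetune_old", "pretrain"]

# "finetune*/" ignores any directory whose name begins with "finetune".
ignored = [d for d in candidates if fnmatch(d, "finetune*")]
print(ignored)  # ['finetune', 'finetune_lr1e-5', 'finetune_old']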
lamstare/experiments/plt_lcurve.py (53 additions, 46 deletions)
@@ -1,4 +1,5 @@
 from pathlib import Path
+from lamstare.experiments.plt_test import COLOR
 from lamstare.utils.plot import fetch_lcurve, sendimg
 from lamstare.utils.dptest import get_head_weights
 import matplotlib.pyplot as plt
@@ -14,69 +15,75 @@



-def main(exp_path:str, roll:int=50):
-    run_id=exp_path.split("/")[-1] # Get basename as id
+def main(exp_paths:list, roll:int=10):
     try:
-        weights = get_head_weights(exp_path)
-        heads = list(get_head_weights(exp_path).keys())
+        weights = get_head_weights(exp_paths[0])
+        heads = list(get_head_weights(exp_paths[0]).keys())
     except KeyError:
         heads = [""] # single task

     n_heads = len(heads)
-    mult_hist = fetch_lcurve(exp_path)
     fig, ax = plt.subplots(n_heads, 3, figsize=(12,2*n_heads+1),sharex=True)

-    if n_heads == 1:
-        ax[0].loglog(mult_hist["step"], mult_hist[f"rmse_e_trn"].rolling(roll).mean(), linestyle='-',color="blue")
-        ax[0].loglog(mult_hist["step"], mult_hist[f"rmse_e_val"].rolling(roll).mean(), linestyle='-.',color="blue")
-        ax[0].set_ylabel(f"rmse_e")
+    for i, exp_path in enumerate(exp_paths):
+        color = COLOR[i]
+        mult_hist = fetch_lcurve(exp_path)
+        run_id=exp_path.split("/")[-1] # Get basename as id
+        if n_heads == 1:
+            ax[0].loglog(mult_hist["step"], mult_hist[f"rmse_e_trn"].rolling(roll).mean(), linestyle='-',color=color)
+            ax[0].loglog(mult_hist["step"], mult_hist[f"rmse_e_val"].rolling(roll).mean(), linestyle='-.',color=color)
+            ax[0].set_ylabel(f"rmse_e")

-        ax[1].loglog(mult_hist["step"], mult_hist[f"rmse_f_trn"].rolling(roll).mean(), linestyle='-',color="blue")
-        ax[1].loglog(mult_hist["step"], mult_hist[f"rmse_f_val"].rolling(roll).mean(), linestyle='-.',color="blue")
-        ax[1].set_ylabel(f"rmse_f")
+            ax[1].loglog(mult_hist["step"], mult_hist[f"rmse_f_trn"].rolling(roll).mean(), linestyle='-',color=color)
+            ax[1].loglog(mult_hist["step"], mult_hist[f"rmse_f_val"].rolling(roll).mean(), linestyle='-.',color=color)
+            ax[1].set_ylabel(f"rmse_f")

-        ax[2].loglog(mult_hist["step"], mult_hist[f"rmse_v_trn"].rolling(roll).mean(), linestyle='-',color="blue")
-        ax[2].loglog(mult_hist["step"], mult_hist[f"rmse_v_val"].rolling(roll).mean(), linestyle='-.',color="blue")
-        ax[2].set_ylabel(f"rmse_v")
-        fig.suptitle(run_id)
-    else:
-        for i, head in enumerate(heads):
-            ax[i][0].loglog(mult_hist["step"], mult_hist[f"rmse_e_trn_{head}"].rolling(roll).mean(), linestyle='-',color="blue")
-            ax[i][0].loglog(mult_hist["step"], mult_hist[f"rmse_e_val_{head}"].rolling(roll).mean(), linestyle='-.',color="blue")
-            ax[i][0].set_ylabel(f"rmse_e_{head}")
+            ax[2].loglog(mult_hist["step"], mult_hist[f"rmse_v_trn"].rolling(roll).mean(), linestyle='-',color=color)
+            ax[2].loglog(mult_hist["step"], mult_hist[f"rmse_v_val"].rolling(roll).mean(), linestyle='-.',color=color)
+            ax[2].set_ylabel(f"rmse_v")
+            fig.suptitle(run_id)
+        else:
+            for i, head in enumerate(heads):
+                ax[i][0].loglog(mult_hist["step"], mult_hist[f"rmse_e_trn_{head}"].rolling(roll).mean(), linestyle='-',color=color)
+                ax[i][0].loglog(mult_hist["step"], mult_hist[f"rmse_e_val_{head}"].rolling(roll).mean(), linestyle='-.',color=color)
+                ax[i][0].set_ylabel(f"rmse_e_{head}")

-            ax[i][1].loglog(mult_hist["step"], mult_hist[f"rmse_f_trn_{head}"].rolling(roll).mean(), linestyle='-',color="blue")
-            ax[i][1].loglog(mult_hist["step"], mult_hist[f"rmse_f_val_{head}"].rolling(roll).mean(), linestyle='-.',color="blue")
-            ax[i][1].set_ylabel(f"rmse_f_{head}")
+                ax[i][1].loglog(mult_hist["step"], mult_hist[f"rmse_f_trn_{head}"].rolling(roll).mean(), linestyle='-',color=color)
+                ax[i][1].loglog(mult_hist["step"], mult_hist[f"rmse_f_val_{head}"].rolling(roll).mean(), linestyle='-.',color=color)
+                ax[i][1].set_ylabel(f"rmse_f_{head}")

-            ax[i][2].loglog(mult_hist["step"], mult_hist[f"rmse_v_trn_{head}"].rolling(roll).mean(), linestyle='-',color="blue")
-            ax[i][2].loglog(mult_hist["step"], mult_hist[f"rmse_v_val_{head}"].rolling(roll).mean(), linestyle='-.',color="blue")
-            ax[i][2].set_ylabel(f"rmse_v_{head}")
+                ax[i][2].loglog(mult_hist["step"], mult_hist[f"rmse_v_trn_{head}"].rolling(roll).mean(), linestyle='-',color=color)
+                ax[i][2].loglog(mult_hist["step"], mult_hist[f"rmse_v_val_{head}"].rolling(roll).mean(), linestyle='-.',color=color)
+                ax[i][2].set_ylabel(f"rmse_v_{head}")

-            if head in BASELINE_MAP:
-                baseline_hist = fetch_lcurve(BASELINE_MAP[head])
-                STEP_NORMAL_PREF = sum(weights.values())/weights[head]*128/120 # need to adjust this value
+                if head in BASELINE_MAP:
+                    baseline_hist = fetch_lcurve(BASELINE_MAP[head])
+                    STEP_NORMAL_PREF = sum(weights.values())/weights[head]*128/120 # need to adjust this value

-                ax[i][0].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_e_trn"].rolling(1000).mean(), linestyle='-',color="red")
-                ax[i][0].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_e_val"].rolling(1000).mean(), linestyle='-.',color="red")
-                ax[i][1].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_f_trn"].rolling(1000).mean(), linestyle='-',color="red")
-                ax[i][1].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_f_val"].rolling(1000).mean(), linestyle='-.',color="red")
-                if "rmse_v_val" in baseline_hist:
-                    ax[i][2].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_v_trn"].rolling(1000).mean(), linestyle='-',color="red")
-                    ax[i][2].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_v_val"].rolling(1000).mean(), linestyle='-.',color="red")
-                if head in PREVIOUS_BASELINE:
-                    ax[i][0].axhline(PREVIOUS_BASELINE[head]["rmse_e"],color="green", linestyle="-.")
-                    ax[i][0].axhline(PREVIOUS_BASELINE[head]["e_std"],color="purple", linestyle="-.")
-                    ax[i][1].axhline(PREVIOUS_BASELINE[head]["rmse_f"],color="green", linestyle="-.")
-                    ax[i][1].axhline(PREVIOUS_BASELINE[head]["f_std"],color="purple", linestyle="-.")
-                    ax[i][2].axhline(PREVIOUS_BASELINE[head]["rmse_v"],color="green", linestyle="-.")
-                    ax[i][2].axhline(PREVIOUS_BASELINE[head]["v_std"],color="purple", linestyle="-.")
+                    ax[i][0].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_e_trn"].rolling(1000).mean(), linestyle='-',color="red")
+                    ax[i][0].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_e_val"].rolling(1000).mean(), linestyle='-.',color="red")
+                    ax[i][1].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_f_trn"].rolling(1000).mean(), linestyle='-',color="red")
+                    ax[i][1].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_f_val"].rolling(1000).mean(), linestyle='-.',color="red")
+                    if "rmse_v_val" in baseline_hist:
+                        ax[i][2].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_v_trn"].rolling(1000).mean(), linestyle='-',color="red")
+                        ax[i][2].loglog([s * STEP_NORMAL_PREF for s in baseline_hist["step"]], baseline_hist[f"rmse_v_val"].rolling(1000).mean(), linestyle='-.',color="red")
+                    if head in PREVIOUS_BASELINE:
+                        ax[i][0].axhline(PREVIOUS_BASELINE[head]["rmse_e"],color="green", linestyle="-.")
+                        ax[i][0].axhline(PREVIOUS_BASELINE[head]["e_std"],color="purple", linestyle="-.")
+                        ax[i][1].axhline(PREVIOUS_BASELINE[head]["rmse_f"],color="green", linestyle="-.")
+                        ax[i][1].axhline(PREVIOUS_BASELINE[head]["f_std"],color="purple", linestyle="-.")
+                        ax[i][2].axhline(PREVIOUS_BASELINE[head]["rmse_v"],color="green", linestyle="-.")
+                        ax[i][2].axhline(PREVIOUS_BASELINE[head]["v_std"],color="purple", linestyle="-.")

     plt.tight_layout()
     fig.savefig("lcurve.jpg")
     sendimg(["lcurve.jpg"], run_id)

 if __name__ == "__main__":
-    for exp_path in ["/mnt/data_nas/public/multitask/training_exps/1126_prod_shareft_120GUP_240by3_single_384_96_24"]:
-        main(exp_path)
+    exp_path = [
+        # "/mnt/data_nas/public/multitask/training_exps/1126_prod_shareft_120GUP_240by3_single_384_96_24",
+        # "/mnt/data_nas/public/multitask/training_exps/1122_shareft_lr1e-3_1e-5_pref0021_1000100_24GUP_240by3_single_384_96_24",
+        "/mnt/data_nas/public/multitask/training_exps/1220_dpa3a_shareft_rc6_120_arc_4_30_l6_120GPU_240by3_384_96_64"
+    ]
+    main(exp_path)
     # main("/mnt/data_nas/public/multitask/training_exps/1018_b4_medium_l6_atton_37head_linear_fitting_tanh")
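
Note on the refactor above: main() now takes a list of experiment paths and overlays them on one figure, one color per experiment taken from COLOR, with solid lines for training curves and dash-dotted lines for validation, smoothed by a rolling mean (default window now 10 rather than 50). Below is a self-contained sketch of that overlay pattern on fabricated learning-curve data; the column names ("step", "rmse_e_trn", "rmse_e_val") follow the lcurve layout the script assumes, and COLOR here is a stand-in for plt_test.COLOR.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

COLOR = ["blue", "orange", "green"]  # stand-in for plt_test.COLOR

# Fabricated learning curves for two hypothetical runs.
rng = np.random.default_rng(0)
steps = np.arange(1, 5001)
runs = {
    name: pd.DataFrame({
        "step": steps,
        "rmse_e_trn": scale * steps**-0.3 * (1 + 0.1 * rng.standard_normal(steps.size)),
        "rmse_e_val": scale * steps**-0.3 * (1 + 0.1 * rng.standard_normal(steps.size)),
    })
    for name, scale in [("run_a", 1.0), ("run_b", 0.8)]
}

roll = 10  # rolling window, matching the new default
fig, ax = plt.subplots()
for i, (run_id, hist) in enumerate(runs.items()):
    color = COLOR[i]
    # One color per run; solid = training, dash-dot = validation.
    ax.loglog(hist["step"], hist["rmse_e_trn"].rolling(roll).mean(), linestyle="-", color=color, label=f"{run_id} trn")
    ax.loglog(hist["step"], hist["rmse_e_val"].rolling(roll).mean(), linestyle="-.", color=color, label=f"{run_id} val")
ax.set_ylabel("rmse_e")
ax.legend()
fig.savefig("lcurve_sketch.jpg")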
lamstare/experiments/plt_ood.py (33 additions, 13 deletions)
@@ -1,6 +1,6 @@
 from functools import lru_cache
 import os
-
+from typing import Optional
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas
@@ -14,7 +14,7 @@
 from lamstare.utils.plot import sendimg


-with open(os.path.dirname(__file__) + "/../release/OOD_DATASET.yml", "r") as f:
+with open(os.path.dirname(__file__) + "/../release/ood_test/OOD_DATASET_v2.yml", "r") as f:
     OOD_DATASET = yaml.load(f, Loader=yaml.FullLoader)
 OOD_DATASET = (
     DataFrame(OOD_DATASET["OOD_TO_HEAD_MAP"]).T.rename_axis("Dataset").infer_objects()
@@ -27,7 +27,7 @@
 print(OOD_DATASET)

 OOD_DATASET_STD = pandas.read_csv(
-    "/mnt/workspace/cc/LAMstare_new/lamstare/release/ood_data_std.csv"
+    "/mnt/workspace/cc/LAMstare_new/lamstare/release/ood_test/ood_data_std.csv"
 ).infer_objects()
 OOD_DATASET_STD.set_index("Dataset", inplace=True)
 print(OOD_DATASET_STD)
@@ -45,21 +45,23 @@ def get_weighted_result(exp_path: str) -> DataFrame:

     weighted_avg = all_records_df.groupby(
         "Training Steps"
-    ).mean() # provide a baseline with same shape
+    ).mean().map(lambda x: np.nan) # provide a df with same shape

+    # mask.inplace and update() won't work; need to assign to a new variable
     for efv in ["energy", "force", "virial"]:
         data = all_records_df.loc[
             :, [key for key in all_records_df.keys() if efv in key]
         ]
         weights = OOD_DATASET[efv + "_weight"]
-        # data.mask(weights == 0, inplace=True)
         weighted_avg_efv = (
             data.apply(np.log)
             .mul(weights, axis="index")
             .groupby("Training Steps")
             .mean()
             .apply(np.exp)
         )
+        # mask out the results where NAN exists in the original data
+        weighted_avg_efv.mask(all_records_df.isna().any(axis=1).groupby("Training Steps").any(), inplace=True)
         weighted_avg.update(weighted_avg_efv)

     weighted_avg["Dataset"] = "Weighted"
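
Note on the hunk above: per training step, the aggregation scales each dataset's log-error by its weight, takes the mean, and exponentiates back (a geometric-mean-style average), and the added .mask() call now invalidates any step at which some source record is NaN instead of silently averaging around it. A minimal sketch of the same aggregation on fabricated data (the dataset names and weights are hypothetical):

import numpy as np
import pandas as pd

# Fabricated per-dataset, per-step errors; NaN marks a missing test result.
idx = pd.MultiIndex.from_product(
    [["ds_a", "ds_b"], [1000, 2000]], names=["Dataset", "Training Steps"]
)
data = pd.Series([0.02, 0.03, 0.04, np.nan], index=idx, name="energy_rmse")
weights = pd.Series({"ds_a": 2.0, "ds_b": 1.0})  # hypothetical energy_weight

# Weight-scaled log errors, averaged per step, mapped back with exp,
# mirroring data.apply(np.log).mul(weights...).mean().apply(np.exp) above.
weighted = (
    data.apply(np.log)
    .mul(weights, level="Dataset")
    .groupby("Training Steps")
    .mean()
    .apply(np.exp)
)

# Invalidate steps where any source value was NaN, as the new mask() does.
weighted = weighted.mask(data.isna().groupby("Training Steps").any())
print(weighted)  # step 2000 becomes NaN because ds_b has no result there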
@@ -77,6 +79,7 @@ def plotting(
     all_records_df: DataFrame,
     color: str,
     legend_handles: list[Line2D],
+    metric_key: str="rmse"
 ):
     for dataset, records in all_records_df.groupby("Dataset"):
         assert dataset in dataset_to_subplot.keys(), f"Dataset {dataset} not presented"
@@ -95,7 +98,7 @@
             subsubplot.axhline(std, color="purple", linestyle="-.")
             # note: this will draw duplicated lines

-            metric_name = efv + "_rmse" + suffix
+            metric_name = efv + f"_{metric_key}" + suffix
             line = subsubplot.loglog(
                 records.index, # step
                 records[metric_name],
@@ -107,7 +110,7 @@
             legend_handles.extend(line) # type: ignore


-def main(exps: list[str]):
+def main(exps: list[str], metric_key: str="rmse"):
     # Get dataset list from yaml file to preserve the order
     datasets: list[str] = OOD_DATASET.index.tolist()
     datasets.append("Weighted")
@@ -126,11 +129,19 @@ def main(exps: list[str]):

     for exp_path, color in zip(exps, COLOR):
         all_records_df = get_weighted_result(exp_path)
-        plotting(dataset_to_subplot, all_records_df, color, legend_handles)
-
+        plotting(dataset_to_subplot, all_records_df, color, legend_handles, metric_key)
+
+    ## to set finer tick
+    from matplotlib.ticker import FixedLocator
+    ax[-1][0].yaxis.set_major_locator(FixedLocator(np.arange(0.02, 0.04, 0.002)))
+    ax[-1][1].yaxis.set_major_locator(FixedLocator(np.arange(0.2, 0.5, 0.04)))
+    ## to handle hpt explosion
+    # for ax in dataset_to_subplot["HPt_NC_2022"]:
+    #     ax.set_ylim(0.05,0.2)
+
     fig.tight_layout()
     fig.subplots_adjust(top=0.975)
-    title = "Compare OOD"
+    title = f"Compare OOD-{metric_key}"
     # fig.suptitle(title) # Poor placement
     fig.legend(
         handles=legend_handles,
@@ -146,8 +157,17 @@ def main(exps: list[str]):

 if __name__ == "__main__":
     exps = [

-        "/mnt/data_nas/public/multitask/training_exps/1122_shareft_lr1e-3_1e-5_pref0021_1000100_24GUP_240by3_single_384_96_24",
-        "/mnt/data_nas/public/multitask/training_exps/1126_prod_shareft_120GUP_240by3_single_384_96_24"
+        # "/mnt/data_nas/public/multitask/training_exps/1122_shareft_lr1e-3_1e-5_pref0021_1000100_24GUP_240by3_single_384_96_24",
+        "/mnt/data_nas/public/multitask/training_exps/1126_prod_shareft_120GUP_240by3_single_384_96_24",
+        # "/mnt/data_nas/public/multitask/training_exps/1223_prod_shareft_40GPU_finetune_pref0210_10010",
+        # "/mnt/data_nas/public/multitask/training_exps/1226_prod_shareft_40GPU_finetune_pref0210_10010_lr1e-5",
+        "/mnt/data_nas/public/multitask/training_exps/1225_dpa3a_shareft_rc6_120_arc_4_30_l6_120GPU_240by3_384_96_32_comp1",
+        "/mnt/data_nas/public/multitask/training_exps/0105_dpa3a_shareft_384_96_32_scp1_e1a_tanh_rc6_120_arc_4_30_l6_120GPU_240by3",
+        # "/mnt/workspace/public/multitask/training_exps/N0130_dpa3a_shareft_128_64_32_scp1_e1a_cdsilu10_rc6_120_arc_4_30_l6_64GPU_240by3_float32",
+        "/mnt/workspace/public/multitask/training_exps/0202_dpa3a_shareft_256_128_32_scp1_e1a_csilu10_rc6_120_arc_4_30_l9_104GPU_240by3"
+        # "/mnt/data_nas/public/multitask/training_exps/0115_dpa3a_shareft_128_64_32_scp1_e1a_tanh_rc6_120_arc_4_30_l6_64GPU_240by3_float32"
+
+
     ]
-    main(exps)
+    main(exps, "mae")
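
Note: metric_key only selects which pre-computed columns plotting() reads, since names are composed as efv + f"_{metric_key}" + suffix, so main(exps, "mae") redraws the whole comparison in MAE rather than RMSE without touching the test records. A quick illustration of the naming scheme (the suffix values here are illustrative assumptions):

# How plotting() composes the column it reads for each panel:
#     metric_name = efv + f"_{metric_key}" + suffix
for metric_key in ["rmse", "mae"]:
    for efv in ["energy", "force", "virial"]:
        for suffix in ["", "_natoms"]:  # suffix values assumed for illustration
            print(efv + f"_{metric_key}" + suffix)
# e.g. energy_rmse, energy_rmse_natoms, ..., force_mae, virial_mae_natoms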