Skip to content

Commit

Permalink
Add posterior mean/stddev into agent_data
Browse files Browse the repository at this point in the history
  • Loading branch information
XPD Operator committed Jun 11, 2024
1 parent 25ecc3a commit b9bb121
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 14 deletions.
15 changes: 12 additions & 3 deletions scripts/kafka_consumer_iterate_1LL09_RM.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@
'ZnI':'infusion_rate_I2',
'ZnCl':'infusion_rate_Cl'}

new_points_label = ['infusion_rate_CsPb', 'infusion_rate_Br', 'infusion_rate_I2', 'infusion_rate_Cl']

use_good_bad = True
post_dilute = True
write_agent_data = True
Expand Down Expand Up @@ -417,7 +419,7 @@ def print_message(consumer, doctype, doc,
if USE_AGENT_iterate:

# print(f"\ntelling agent {agent_data}")
agent = build_agen2(peak_target=peak_target)
agent = build_agen2(peak_target=peak_target, agent_data_path=agent_data_path)

if len(agent.table) < 2:
acq_func = "qr"
Expand Down Expand Up @@ -530,8 +532,15 @@ def print_message(consumer, doctype, doc,
if post_dilute:
set_target_list = [0 for i in range(len(pump_list))]
# rate_list = new_points['points'].tolist()[0][:-1] + [new_points['points'].sum()]
rate_list = [rr for rr in new_points['points'].tolist()[0] if rr!=0] + [new_points['points'].sum()]
rate_list = np.asarray(rate_list)
# rate_list = [rr for rr in new_points['points'].tolist()[0] if rr!=0] + [new_points['points'].sum()]
# rate_list = np.asarray(rate_list)
rate_list = []
for i in new_points_label:
for key in new_points['points']:
if i == key:
rate_list.append(new_points['points'][key][0])
rate_list.insert(2, sum(rate_list)/10)
rate_list.append(sum(rate_list)*5)

else:
# set_target_list = [0 for i in range(new_points['points'].shape[1])]
Expand Down
29 changes: 28 additions & 1 deletion scripts/kafka_consumer_iterate_XPD_RM.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
USE_AGENT_iterate = False
peak_target = 515
if USE_AGENT_iterate:
import torch
from prepare_agent_pdf import build_agen
agent = build_agen(peak_target=peak_target, agent_data_path=agent_data_path)

Expand Down Expand Up @@ -549,7 +550,7 @@ def print_message(consumer, doctype, doc,
if USE_AGENT_iterate:

# print(f"\ntelling agent {agent_data}")
agent = build_agen(peak_target=peak_target)
agent = build_agen(peak_target=peak_target, agent_data_path=agent_data_path)

if len(agent.table) < 2:
acq_func = "qr"
Expand All @@ -558,6 +559,32 @@ def print_message(consumer, doctype, doc,

new_points = agent.ask(acq_func, n=1)

## Get target of agent.ask()
agent_target = agent.objectives.summary['target'].tolist()

## Get mean and standard deviation of agent.ask()
res_values = []
for i in new_points_label:
if i in new_points['points'].keys():
res_values.append(new_points['points'][i])
x_tensor = torch.tensor(res_value)
post = agent.posterior(x_tensor)
post_mean = post.mean.tolist()[0]
post_stddev = post.stddev.tolist()[0]

## apply np.exp for log-transform objectives
if_log = agent.objectives.summary['transform']
for j in range(if_log.shape[0]):
if if_log[j] == 'log':
post_mean[j] = np.exp(post_mean[j])
post_stddev[j] = np.exp(post_stddev[j])

## Update target, mean, and standard deviation in agent_data
agent_data.update({'agent_target': agent_target})
agent_data.update({'posterior_mean': post_mean})
agent_data.update({'posterior_stddev': post_stddev})


# peak_diff = peak_emission - peak_target
peak_diff = False

Expand Down
51 changes: 41 additions & 10 deletions scripts/prepare_agent_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,45 @@ def build_agen(peak_target=660, peak_tolerance=5, size_target=6, agent_data_path


'''
agent.posterior
import torch
res = agent.ask('qem', n=1)
agent.posterior(torch.tensor(res['points'])).mean
x = torch.tensor([[ 24.0426776 , 159.30614932, 101.20516362]])
agent.posterior(x)
agent.posterior(x).mean
agent.plot_acquisition(); plt.show()
agent.posterior
import torch
res = agent.ask('qem', n=1)
agent.posterior(torch.tensor(res['points'])).mean
x = torch.tensor([[ 24.0426776 , 159.30614932, 101.20516362]])
agent.posterior(x)
agent.posterior(x).mean
agent.plot_acquisition(); plt.show()
18/2: agent
18/3: agent.table
18/4: agent.table.Peak
18/5: plt.rcParams['font.size'] = 4
18/6: import matplotlib.pyplot as plt
18/7: plt.rcParams['font.size'] = 4
18/8: agent.plot_objectives(); plt.show()
18/9: agent.objectives
18/10: agent.ask("qem", n=1)
18/11: agent.ask("qei", n=1)
18/12: agent.ask("qei", n=1)
18/13: import torch
18/14: x = torch.tensor(res[0])
18/15: res = agent.ask("qem", n=4)
18/16: x = torch.tensor(res[0])
18/17: agent.posterior(x).mean
18/18: agent.best
18/19: agent.objectives
18/20: post = agent.posterior(x)
18/21: post.mean
18/22: post.sigma
18/23: post.stddev
18/24: agent.objectives
18/25: agent.plot_acquisition(); plt.show()
18/26: agent.plot_constraint(); plt.show()
18/27: agent.dofs
18/28: agent.objectives
18/29: agent.best
'''

0 comments on commit b9bb121

Please sign in to comment.