Skip to content

Commit

Permalink
Sort x for plot; Add print info to save
Browse files Browse the repository at this point in the history
  • Loading branch information
lululxvi committed Jul 12, 2019
1 parent 0a26932 commit 825ca96
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions deepxde/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def plot_loss_history(losshistory):


def save_loss_history(losshistory, fname):
print("Saving loss history to {} ...".format(fname))
loss = np.hstack(
(
np.array(losshistory.steps)[:, None],
Expand All @@ -57,15 +58,15 @@ def plot_best_state(train_state):

# Regression plot
plt.figure()
idx = np.argsort(X_test[:, 0])
X = X_test[idx, 0]
for i in range(y_dim):
plt.plot(X_train[:, 0], y_train[:, i], "ok", label="Train")
plt.plot(X_test[:, 0], y_test[:, i], "-k", label="True")
plt.plot(X_test[:, 0], best_y[:, i], "--r", label="Prediction")
plt.plot(X, y_test[idx, i], "-k", label="True")
plt.plot(X, best_y[idx, i], "--r", label="Prediction")
if best_ystd is not None:
plt.plot(
X_test[:, 0], best_y[:, i] + 2 * best_ystd[:, i], "-b", label="95% CI"
)
plt.plot(X_test[:, 0], best_y[:, i] - 2 * best_ystd[:, i], "-b")
plt.plot(X, best_y[idx, i] + 2 * best_ystd[idx, i], "-b", label="95% CI")
plt.plot(X, best_y[idx, i] - 2 * best_ystd[idx, i], "-b")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
Expand Down Expand Up @@ -93,10 +94,12 @@ def plot_best_state(train_state):


def save_best_state(train_state, fname_train, fname_test):
print("Saving training data to {} ...".format(fname_train))
X_train, y_train, X_test, y_test, best_y, best_ystd = train_state.packed_data()
train = np.hstack((X_train, y_train))
np.savetxt(fname_train, train, header="x, y")

print("Saving test data to {} ...".format(fname_test))
test = np.hstack((X_test, y_test, best_y))
if best_ystd is not None:
test = np.hstack((test, best_ystd))
Expand Down

0 comments on commit 825ca96

Please sign in to comment.