Open
Description
Very nice work. But I am a little confused about some details.
In mnist_runner.py, train_fn() is defined like this:
    def train_fn(state, params, timesteps):
        net = Net()
        copy_params(base_net, net)
        train_net(params, train_loader, net, timesteps, meta_train=True)
        avg_loss = test_net(test_loader, net, timesteps)
        compute = timesteps
        return avg_loss, compute
I think the objective of this function is to run a forward pass and compute the loss on the training dataset. However, test_net() is called with test_loader rather than train_loader, so the returned avg_loss is measured on the test set. Why not use train_loader directly?
Thanks!
jordanrule