-
Notifications
You must be signed in to change notification settings - Fork 531
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
76a964f
commit cbf2d81
Showing
32 changed files
with
269 additions
and
216 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
namespace Tensorflow.Keras.Engine; | ||
|
||
public interface ICallback | ||
{ | ||
Dictionary<string, List<float>> history { get; set; } | ||
void on_train_begin(); | ||
void on_epoch_begin(int epoch); | ||
void on_train_batch_begin(long step); | ||
void on_train_batch_end(long end_step, Dictionary<string, float> logs); | ||
void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs); | ||
void on_predict_begin(); | ||
void on_predict_batch_begin(long step); | ||
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | ||
void on_predict_end(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,65 @@ | ||
namespace Tensorflow.Keras.Engine | ||
using Tensorflow.Functions; | ||
using Tensorflow.Keras.Losses; | ||
using Tensorflow.Keras.Saving; | ||
using Tensorflow.NumPy; | ||
|
||
namespace Tensorflow.Keras.Engine; | ||
|
||
public interface IModel : ILayer | ||
{ | ||
public interface IModel | ||
{ | ||
} | ||
void compile(IOptimizer optimizer = null, | ||
ILossFunc loss = null, | ||
string[] metrics = null); | ||
|
||
void compile(string optimizer, string loss, string[] metrics); | ||
|
||
ICallback fit(NDArray x, NDArray y, | ||
int batch_size = -1, | ||
int epochs = 1, | ||
int verbose = 1, | ||
float validation_split = 0f, | ||
bool shuffle = true, | ||
int initial_epoch = 0, | ||
int max_queue_size = 10, | ||
int workers = 1, | ||
bool use_multiprocessing = false); | ||
|
||
void save(string filepath, | ||
bool overwrite = true, | ||
bool include_optimizer = true, | ||
string save_format = "tf", | ||
SaveOptions? options = null, | ||
ConcreteFunction? signatures = null, | ||
bool save_traces = true); | ||
|
||
void save_weights(string filepath, | ||
bool overwrite = true, | ||
string save_format = null, | ||
object options = null); | ||
|
||
void load_weights(string filepath, | ||
bool by_name = false, | ||
bool skip_mismatch = false, | ||
object options = null); | ||
|
||
void evaluate(NDArray x, NDArray y, | ||
int batch_size = -1, | ||
int verbose = 1, | ||
int steps = -1, | ||
int max_queue_size = 10, | ||
int workers = 1, | ||
bool use_multiprocessing = false, | ||
bool return_dict = false); | ||
|
||
Tensors predict(Tensor x, | ||
int batch_size = -1, | ||
int verbose = 0, | ||
int steps = -1, | ||
int max_queue_size = 10, | ||
int workers = 1, | ||
bool use_multiprocessing = false); | ||
|
||
void summary(int line_length = -1, float[] positions = null); | ||
|
||
IKerasConfig get_config(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
namespace Tensorflow.Keras.Engine; | ||
|
||
public interface IOptimizer | ||
{ | ||
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); | ||
Tensor[] clip_gradients(Tensor[] grads); | ||
void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | ||
string name = null, | ||
bool experimental_aggregate_gradients = true); | ||
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | ||
string name = null, | ||
bool experimental_aggregate_gradients = true); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,63 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
using Tensorflow.Keras.Engine; | ||
|
||
namespace Tensorflow.Keras.Callbacks | ||
namespace Tensorflow.Keras.Callbacks; | ||
|
||
public class CallbackList | ||
{ | ||
public class CallbackList | ||
{ | ||
List<ICallback> callbacks = new List<ICallback>(); | ||
public History History => callbacks[0] as History; | ||
|
||
public CallbackList(CallbackParams parameters) | ||
{ | ||
callbacks.Add(new History(parameters)); | ||
callbacks.Add(new ProgbarLogger(parameters)); | ||
} | ||
|
||
public void on_train_begin() | ||
{ | ||
callbacks.ForEach(x => x.on_train_begin()); | ||
} | ||
|
||
public void on_epoch_begin(int epoch) | ||
{ | ||
callbacks.ForEach(x => x.on_epoch_begin(epoch)); | ||
} | ||
|
||
public void on_train_batch_begin(long step) | ||
{ | ||
callbacks.ForEach(x => x.on_train_batch_begin(step)); | ||
} | ||
|
||
public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | ||
{ | ||
callbacks.ForEach(x => x.on_train_batch_end(end_step, logs)); | ||
} | ||
|
||
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs) | ||
{ | ||
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); | ||
} | ||
|
||
public void on_predict_begin() | ||
{ | ||
callbacks.ForEach(x => x.on_predict_begin()); | ||
} | ||
|
||
public void on_predict_batch_begin(long step) | ||
{ | ||
callbacks.ForEach(x => x.on_predict_batch_begin(step)); | ||
} | ||
|
||
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | ||
{ | ||
callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs)); | ||
} | ||
|
||
public void on_predict_end() | ||
{ | ||
callbacks.ForEach(x => x.on_predict_end()); | ||
} | ||
List<ICallback> callbacks = new List<ICallback>(); | ||
public History History => callbacks[0] as History; | ||
|
||
public CallbackList(CallbackParams parameters) | ||
{ | ||
callbacks.Add(new History(parameters)); | ||
callbacks.Add(new ProgbarLogger(parameters)); | ||
} | ||
|
||
public void on_train_begin() | ||
{ | ||
callbacks.ForEach(x => x.on_train_begin()); | ||
} | ||
|
||
public void on_epoch_begin(int epoch) | ||
{ | ||
callbacks.ForEach(x => x.on_epoch_begin(epoch)); | ||
} | ||
|
||
public void on_train_batch_begin(long step) | ||
{ | ||
callbacks.ForEach(x => x.on_train_batch_begin(step)); | ||
} | ||
|
||
public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | ||
{ | ||
callbacks.ForEach(x => x.on_train_batch_end(end_step, logs)); | ||
} | ||
|
||
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs) | ||
{ | ||
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); | ||
} | ||
|
||
public void on_predict_begin() | ||
{ | ||
callbacks.ForEach(x => x.on_predict_begin()); | ||
} | ||
|
||
public void on_predict_batch_begin(long step) | ||
{ | ||
callbacks.ForEach(x => x.on_predict_batch_begin(step)); | ||
} | ||
|
||
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | ||
{ | ||
callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs)); | ||
} | ||
|
||
public void on_predict_end() | ||
{ | ||
callbacks.ForEach(x => x.on_predict_end()); | ||
} | ||
} |
Oops, something went wrong.