From 0abf1664370592dbcfc19606098a1e1244fdfc45 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 4 Oct 2020 20:06:57 -0500 Subject: [PATCH] Add Tensors class to adapt Tensor and Tensor[]. --- .gitignore | 2 + src/TensorFlowNET.Core/APIs/tf.reshape.cs | 2 +- .../Framework/sparse_tensor.py.cs | 7 +- .../Keras/ArgsDefinition/ModelArgs.cs | 2 + .../Keras/ArgsDefinition/NodeArgs.cs | 2 +- .../Keras/Engine/Flatten.cs | 2 +- .../Keras/Engine/KerasHistory.cs | 29 ++++++++ src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 53 +++----------- src/TensorFlowNET.Core/Keras/Engine/Model.cs | 6 +- src/TensorFlowNET.Core/Keras/Engine/Node.cs | 7 +- .../Keras/Engine/Sequential.cs | 4 +- src/TensorFlowNET.Core/Keras/KerasApi.cs | 13 ++++ .../Keras/Layers/BatchNormalization.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 2 +- .../Keras/Layers/Dropout.cs | 2 +- .../Keras/Layers/Embedding.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/LSTM.cs | 4 +- .../Keras/Layers/Pooling2D.cs | 2 +- .../Keras/Layers/Rescaling.cs | 2 +- src/TensorFlowNET.Core/Layers/Layer.cs | 43 +----------- .../Operations/NnOps/BasicLSTMCell.cs | 4 +- .../Operations/NnOps/BasicRNNCell.cs | 6 +- .../Operations/NnOps/rnn.cs | 4 +- .../Operations/array_ops.cs | 4 +- .../Operations/gen_array_ops.cs | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 2 + src/TensorFlowNET.Core/Tensors/Tensors.cs | 70 +++++++++++++++++++ src/TensorFlowNET.Core/Tensors/constant_op.cs | 2 + .../Variables/variable_scope.py.cs | 3 + src/TensorFlowNET.Core/ops.cs | 4 +- .../Keras/ModelSaveTest.cs | 37 ++++++++++ .../ManagedAPI/FunctionApiTest.cs | 34 +++++++++ .../NativeAPI/GraphBuildTest.cs | 35 ++++++++++ .../layers_test/flatten.cs | 59 ---------------- 36 files changed, 282 insertions(+), 176 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs create mode 100644 src/TensorFlowNET.Core/Tensors/Tensors.cs create mode 100644 test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs create mode 100644 test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs delete mode 100644 test/TensorFlowNET.UnitTest/layers_test/flatten.cs diff --git a/.gitignore b/.gitignore index 261c681a3..231d8379a 100644 --- a/.gitignore +++ b/.gitignore @@ -337,3 +337,5 @@ test/TensorFlowNET.Examples/mnist # training model resources .resources /redist +*.xml +*.xsd diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs index ab7d62a04..334889bb5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -18,7 +18,7 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor reshape(T tensor, + public Tensor reshape(Tensor tensor, TensorShape shape, string name = null) => gen_array_ops.reshape(tensor, shape, name); diff --git a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs index b03ce2de3..13b75ee96 100644 --- a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs +++ b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp; +using System; using System.Linq; using static Tensorflow.Binding; @@ -42,8 +43,8 @@ public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_) var values_shape = values.TensorShape.with_rank(1); var dense_shape_shape = dense_shape.TensorShape.with_rank(1); - indices_shape[0].merge_with(values_shape.dims[0]); - 
indices_shape[1].merge_with(dense_shape_shape.dims[0]); + indices_shape["0"].merge_with(values_shape[0]); + indices_shape["1"].merge_with(dense_shape_shape[0]); _shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs index f8e13bbe7..70238405d 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs @@ -6,5 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition { public class ModelArgs : LayerArgs { + public Tensor[] Inputs { get; set; } + public Tensor[] Outputs { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs index 0dd4355fd..303e832e4 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -12,6 +12,6 @@ public class NodeArgs public int[] NodeIndices { get; set; } public int[] TensorIndices { get; set; } public Tensor InputTensors { get; set; } - public Tensor Outputs { get; set; } + public Tensors Outputs { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs index 6bd101514..e6c2d9b03 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs @@ -21,7 +21,7 @@ public Flatten(FlattenArgs args) _channels_first = args.DataFormat == "channels_first"; } - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { if (_channels_first) { diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs new file mode 100644 index 000000000..832124e42 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + /// + /// Tracks the Layer call that created a Tensor, for Keras Graph Networks. + /// + public class KerasHistory + { + Layer layer; + int node_index; + int tensor_index; + + public KerasHistory(Layer layer, int node_index, int tensor_index) + { + this.layer = layer; + this.node_index = node_index; + this.tensor_index = tensor_index; + } + + public static implicit operator Layer(KerasHistory history) + => history.layer; + + public static implicit operator (Layer, int, int)(KerasHistory history) + => (history.layer, history.node_index, history.tensor_index); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 2887a97be..8c9432354 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -119,11 +119,12 @@ public Layer(LayerArgs args) /// Wraps `call`, applying pre- and post-processing steps. /// /// + /// /// /// - public Tensor Apply(Tensor inputs, bool is_training = false) + public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) { - Tensor outputs = null; + Tensors outputs = null; callContext = callContext ?? 
new ThreadLocal() { @@ -148,7 +149,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false) if (!built) MaybeBuild(inputs); - outputs = call(inputs, is_training: is_training); + outputs = call(inputs, state: state, is_training: is_training); outputs = _set_connectivity_metadata_(inputs, outputs); _handle_activity_regularization(inputs, outputs); @@ -161,36 +162,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false) return outputs; } - public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false) - { - Tensor[] outputs = null; - - callContext = callContext ?? new ThreadLocal() - { - Value = new CallContext() - }; - - var eager = tf.executing_eagerly(); - using var ctxManager = CallContext.enter(); - - string nameScope = ""; - if (eager) - nameScope = name; - else - nameScope = _name_scope(); - - tf_with(ops.name_scope(nameScope), scope => - { - if (!built) - MaybeBuild(inputs[0]); - - outputs = call(inputs, is_training: is_training, state: state); - }); - - return outputs; - } - - private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) + private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) { /*var returnOutputs = new List(); foreach(var x in outputs) @@ -211,7 +183,7 @@ private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) return outputs; } - private void _handle_activity_regularization(Tensor inputs, Tensor outputs) + private void _handle_activity_regularization(Tensors inputs, Tensors outputs) { //if(_activity_regularizer != null) { @@ -219,7 +191,7 @@ private void _handle_activity_regularization(Tensor inputs, Tensor outputs) } } - private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) + private void _set_mask_metadata(Tensors inputs, Tensors outputs, Tensors previous_mask) { } @@ -229,12 +201,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null) return null; } - protected virtual Tensor call(Tensor inputs, bool is_training = false) - { - throw new NotImplementedException(""); - } - - protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) + protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { throw new NotImplementedException(""); } @@ -244,7 +211,7 @@ protected virtual string _name_scope() return Name; } - protected void MaybeBuild(Tensor inputs) + protected void MaybeBuild(Tensors inputs) { // Check input assumptions set before layer building, e.g. input rank. 
if (built) @@ -252,7 +219,7 @@ protected void MaybeBuild(Tensor inputs) if (DType == TF_DataType.DtInvalid) args.DType = inputs.dtype; - var input_shapes = inputs.TensorShape; + var input_shapes = inputs.shape; build(input_shapes); built = true; } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index e4af7021d..b5e2b0c86 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -27,7 +27,11 @@ public class Model : Layer public Model(ModelArgs args) : base(args) { - + // Build _output_layers + /*foreach(var x in args.Outputs) + { + var layer = x.KerasHistory; + }*/ } public void compile(string optimizerName, string lossName) diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs index bb70d7796..5eef1195b 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs @@ -35,8 +35,8 @@ public class Node public int[] node_indices; public int[] tensor_indices; - public Tensor input_tensors; - public Tensor Outputs => args.Outputs; + public Tensors input_tensors; + public Tensors Outputs => args.Outputs; public TensorShape[] input_shapes; public TensorShape[] output_shapes; List kerasInputs; @@ -57,7 +57,8 @@ public Node(Layer layer, NodeArgs args) // Set metadata on outputs. var node_index = layer.InboundNodes.Count - 1; - args.Outputs.KerasHistory.Add(layer); + foreach (var (i, tensor) in enumerate(Outputs)) + tensor.KerasHistory = new KerasHistory(layer, node_index, i); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs index 16ce0e0a6..a70a5799f 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs @@ -60,7 +60,7 @@ public Sequential(SequentialArgs args) public void add(Tensor tensor) { - var layer = tensor.KerasHistory[0]; + Layer layer = tensor.KerasHistory; add(layer); } @@ -129,7 +129,7 @@ void _init_graph_network(Tensor inputs, Tensor outputs) void _map_graph_network(Tensor inputs, Tensor outputs) { - layers.add(outputs.KerasHistory[0]); + layers.add(outputs.KerasHistory); } } } diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index 39e9eedc0..603dd2cfb 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -30,6 +30,19 @@ public Sequential Sequential(List layers = null, Name = name }); + /// + /// `Model` groups layers into an object with training and inference features. + /// + /// + /// + /// + public Model Model(Tensor input, Tensor output) + => new Model(new ModelArgs + { + Inputs = new[] { input }, + Outputs = new[] { output } + }); + /// /// Instantiate a Keras tensor. 
/// diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index a1c0ab7b0..d7664493d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -143,7 +143,7 @@ protected override void build(TensorShape input_shape) built = true; } - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { Tensor outputs = null; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index 282cef9d2..b26f5465b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -95,7 +95,7 @@ protected override void build(TensorShape input_shape) built = true; } - protected override Tensor call(Tensor inputs, bool training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 9c117fd49..845cca2f1 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -65,7 +65,7 @@ protected override void build(TensorShape input_shape) built = true; } - protected override Tensor call(Tensor inputs, bool training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs index b581ac624..057dc2c75 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs @@ -18,7 +18,7 @@ public Dropout(DropoutArgs args) this.args = args; } - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { var output = tf_utils.smart_cond(is_training, () => tf.nn.dropout(inputs, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index f07c9c73b..47752081d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -62,7 +62,7 @@ protected override void build(TensorShape input_shape) built = true; } - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs index e5ddb1ec9..d61f3faa1 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs @@ -29,9 +29,9 @@ public LSTM(LSTMArgs args) : .ToArray(); } - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { - return base.call(inputs, is_training); + return base.call(inputs, state: state, is_training: is_training); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 
83bfdaab9..559fc982b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -36,7 +36,7 @@ public Pooling2D(Pooling2DArgs args) input_spec = new InputSpec(ndim: 4); } - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { int[] pool_shape; int[] strides; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs index 99d3a9f51..983522753 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs @@ -20,7 +20,7 @@ public Rescaling(RescalingArgs args) : base(args) this.args = args; } - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { scale = math_ops.cast(args.Scale, args.DType); offset = math_ops.cast(args.Offset, args.DType); diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 4aaae7d03..e07677e5b 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -61,44 +61,7 @@ public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null) return (results[0], results[1]); } - public Tensor __call__(Tensor inputs, - Tensor training = null, - VariableScope scope = null) - { - _set_scope(scope); - _graph = ops._get_graph_from_inputs(new Tensor[] { inputs }, graph: _graph); - - variable_scope scope_context_manager = null; - if (built) - { - scope_context_manager = tf.variable_scope(_scope, - reuse: true, - auxiliary_name_scope: false); - } - else - { - scope_context_manager = tf.variable_scope(_scope, - reuse: _reuse, - auxiliary_name_scope: false); - } - - Tensor outputs = null; - tf_with(scope_context_manager, scope2 => - { - _current_scope = scope2; - // Actually call layer - outputs = base.Apply(inputs[0], - is_training: training == null ? false : false); - }); - - - // Update global default collections. - _add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); - - return outputs; - } - - public Tensor[] __call__(Tensor[] inputs, + public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null, VariableScope scope = null) @@ -120,13 +83,13 @@ public Tensor[] __call__(Tensor[] inputs, auxiliary_name_scope: false); } - Tensor[] outputs = null; + Tensors outputs = null; tf_with(scope_context_manager, scope2 => { _current_scope = scope2; // Actually call layer outputs = base.Apply(inputs, - state, + state: state, is_training: training == null ? false : false); }); diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index bb53a4681..9e3e35912 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -74,7 +74,7 @@ public Tensor __call__(Tensor inputs, LSTMStateTuple state) /// /// /// - protected override Tensor call(Tensor inputs, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { var one = constant_op.constant(1, dtype: dtypes.int32); // Parameters of gates are concatenated into one multiply for efficiency. 
@@ -87,7 +87,7 @@ protected override Tensor call(Tensor inputs, bool is_training = false) // array_ops.split(value: state, num_or_size_splits: 2, axis: one); throw new NotImplementedException("BasicLstmCell call"); } - var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel.AsTensor()); + var gate_inputs = math_ops.matmul(array_ops.concat(new[] { (Tensor)inputs, h }, 1), _kernel.AsTensor()); gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); // i = input_gate, j = new_input, f = forget_gate, o = output_gate diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs index 3754072d5..7bf27e8e4 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -67,14 +67,14 @@ protected override void build(TensorShape inputs_shape) built = true; } - protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { // Most basic RNN: output = new_state = act(W * input + U * state + B). - var concat = array_ops.concat(new[] { inputs[0], state }, 1); + var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); var output = _activation(gate_inputs, null); - return new[] { output, output }; + return new Tensors(output, output); } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 66327cb57..842cc33e5 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -127,7 +127,7 @@ public static (Tensor[], LSTMStateTuple) static_rnn(BasicLstmCell cell, { input_shape = flat_input.TensorShape.with_rank_at_least(2); batch_size = tensor_shape.dimension_at_index(input_shape, 0); - var input_size = input_shape[1]; + var input_size = input_shape[new Slice(1)]; fixed_batch_size.merge_with(batch_size); foreach (var (i, size) in enumerate(input_size.dims)) { @@ -364,7 +364,7 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RnnCell cell, Tensor inputs, T if (sequence_length != null) throw new NotImplementedException("sequence_length != null"); else - outputs = cell.__call__(new[] { input_t_t }, state: state1); + outputs = cell.__call__(input_t_t, state: state1); var (output, new_state) = (outputs[0], outputs[1]); // Keras cells always wrap state as list, even if it's a single tensor. 
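The rnn.cs hunk above leans on the new Tensors type: the implicit Tensor-to-Tensors conversion lets a single tensor be passed where the cell previously required an array, and the cell's result can be indexed positionally. A minimal sketch of that call-site pattern, reusing the names from the hunk (cell, input_t_t, state1) and omitting the surrounding time-step loop:

// Sketch only: input_t_t is a single Tensor; the implicit conversion wraps it
// into a Tensors instance, so the explicit new[] { input_t_t } is no longer needed.
Tensors outputs = cell.__call__(input_t_t, state: state1);

// BasicRNNCell returns new Tensors(output, output), so index 0 is the cell
// output and index 1 is the new state.
Tensor output = outputs[0];
Tensor new_state = outputs[1];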
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 7b0d6a949..dc9bc5ce8 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -157,7 +157,7 @@ public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boo leading_size, shape(tensor_tensor)[$"{axis + ndims_mask}:"] }, 0); - tensor_tensor = reshape(tensor, shape1); + tensor_tensor = reshape(tensor_tensor, shape1); var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); @@ -353,7 +353,7 @@ public static Tensor rank_internal(Tensor input, string name = null, bool optimi public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => ones_like_impl(tensor, dtype, name, optimize); - public static Tensor reshape(T1 tensor, T2 shape, string name = null) + public static Tensor reshape(Tensor tensor, T2 shape, string name = null) => gen_array_ops.reshape(tensor, shape, null); private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 85fcdae80..6bce44c58 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -292,7 +292,7 @@ public static Tensor reverse(Tensor tensor, T axis, string name = null) return _op.output; } - public static Tensor reshape(T1 tensor, T2 shape, string name = null) + public static Tensor reshape(Tensor tensor, T shape, string name = null) { if (tf.Context.executing_eagerly()) { diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 926524e91..0a3ea47a2 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -144,7 +144,7 @@ public int[] _shape_tuple() /// /// Keras History: (Layer, (node_index, tensor_index)) /// - public List KerasHistory = new List(); + public KerasHistory KerasHistory { get; set; } /// /// Updates the shape of this tensor. diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 0bc783be6..2f1300026 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -132,6 +132,8 @@ public TensorShape this[Slice slice] } } + public int this[int index] => dims[index]; + /// /// Returns True iff `self` is fully defined in every dimension. /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs new file mode 100644 index 000000000..af8796bde --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -0,0 +1,70 @@ +using NumSharp; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Gradients; + +namespace Tensorflow +{ + /// + /// Tensors is used to represent a Tensor or a array of Tensor. + /// It will simplify the API interface, it converts Tensor + /// and Tensor[] to Tensors implicitily. And parse back to Tensor + /// and Tensor[] from Tensors implicitily. + /// It works for tuple and scalar as well. 
+ /// + public class Tensors : IEnumerable + { + Tensor[] items; + + public TF_DataType dtype => items.First().dtype; + public TensorShape shape => items.First().TensorShape; + public int rank => items.First().rank; + public bool IsEagerTensor => items.First().IsEagerTensor; + + public Tensor this[int index] => items[index]; + + public Tensors(params Tensor[] tensors) + { + items = tensors; + } + + public Tensors(NDArray nd) + { + items = new[] { ops.convert_to_tensor(nd) }; + } + + public IEnumerator GetEnumerator() + { + foreach (var tensor in items) + yield return tensor; + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotImplementedException(); + } + + public static implicit operator Tensors(Tensor tensor) + => new Tensors(tensor); + + public static implicit operator Tensors(NDArray nd) + => new Tensors(nd); + + public static implicit operator Tensors(Tensor[] tensors) + => new Tensors(tensors); + + public static implicit operator Tensor(Tensors tensors) + => tensors.FirstOrDefault(); + + public static implicit operator Tensor[](Tensors tensors) + => tensors.items; + + public override string ToString() + => items.Length == 1 + ? items.First().ToString() + : items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); + } +} diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 083119f4b..632d8370c 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -155,6 +155,8 @@ private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF return val; case NDArray val: return new EagerTensor(val, ctx.DeviceName); + //case TensorShape val: + //return new EagerTensor(val.dims, ctx.DeviceName); case string val: return new EagerTensor(val, ctx.DeviceName); case string[] val: diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 41f1132df..f21f195bc 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -16,6 +16,7 @@ limitations under the License. 
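As a quick illustration of the implicit operators defined in Tensors.cs above, the following sketch shows the conversions a caller gets for free (the tf.constant values are placeholders used only for demonstration):

// Tensor -> Tensors: a single tensor is wrapped automatically.
Tensors single = tf.constant(1);

// Tensor[] -> Tensors: an array of tensors passes straight through.
Tensors pair = new[] { tf.constant(1), tf.constant(2) };

// Tensors -> Tensor[]: the backing array is handed back unchanged.
Tensor[] asArray = pair;

// Tensors -> Tensor: FirstOrDefault() returns the first element, which is
// what single-output layer call sites rely on.
Tensor first = pair;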
using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; namespace Tensorflow @@ -280,6 +281,7 @@ public static implicit operator VariableScope(variable_scope scope) return scope._scope; } + [DebuggerHidden] public void __exit__() { _cached_pure_variable_scope.__exit__(); @@ -287,6 +289,7 @@ public void __exit__() _current_name_scope.__exit__(); } + [DebuggerHidden] public void Dispose() { if (_current_name_scope != null) diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 935414ea2..cf935ab30 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -76,10 +76,10 @@ public static List get_collection_ref(string key) return get_default_graph().get_collection_ref(key); } - public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) + public static Graph _get_graph_from_inputs(Tensors op_input_list) => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); - public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null) + public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph = null) { foreach(var op_input in op_input_list) { diff --git a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs new file mode 100644 index 000000000..050151aff --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs @@ -0,0 +1,37 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using NumSharp; +using Tensorflow.UnitTest; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Keras +{ + /// + /// https://www.tensorflow.org/guide/keras/save_and_serialize + /// + [TestClass] + public class ModelSaveTest : EagerModeTestBase + { + [TestMethod] + public void SaveAndLoadTest() + { + var model = GetModel(); + } + + Model GetModel() + { + var keras = tf.keras; + + // Create a simple model. + var inputs = keras.Input(shape: 32); + var outputs = keras.layers.Dense(1).Apply(inputs); + var model = keras.Model(inputs, outputs); + model.compile("adam", "mean_squared_error"); + return model; + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs index cfc234e72..109bafb26 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs @@ -12,6 +12,40 @@ namespace TensorFlowNET.UnitTest.ManagedAPI [TestClass] public class FunctionApiTest : TFNetApiTest { + Tensor Min(Tensor a, Tensor b) + { + return tf.cond(a < b, () => a, () => b); + } + + [TestMethod] + public void MulInAutoGraph() + { + var a = tf.constant(1); + var b = tf.constant(2); + // For first time running, tf.net will record the operations in graph mode. + // And register to tensorflow op library. + var output = Mul(a, b); + Assert.AreEqual(2, (int)output); + + var c = tf.constant(3); + // for the following invoke, Mul will be intercepted and run it in eager mode. + output = Mul(b, c); + Assert.AreEqual(6, (int)output); + } + + /// + /// Method with AutoGraph attribute will be converted to FuncGraph + /// when it's invoked for the first time. 
+ /// + /// + /// + /// + [AutoGraph] + Tensor Mul(Tensor a, Tensor b) + { + return a * b; + } + [TestMethod] public void TwoInputs_OneOutput() { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs new file mode 100644 index 000000000..5352976d5 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs @@ -0,0 +1,35 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.NativeAPI +{ + [TestClass] + public class GraphBuildTest : CApiTest + { + [TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] + public void UpdateEdge() + { + using var graph = new Graph().as_default(); + + var one = tf.constant(1, name: "one"); + var two = tf.constant(2, name: "two"); + var add = tf.add(one, two, name: "add"); + var neg = tf.negative(add, name: "neg"); + + Assert.AreEqual(1, one.consumers().Length); + Assert.AreEqual("add", neg.op.node_def.Input[0]); + + // update edge + neg.op._update_input(0, one); + // c_api.TF_UpdateEdge(graph, new TF_Output(c1.op, 0), new TF_Input(neg.op, 0), tf.Status.Handle); + + Assert.AreEqual(2, one.consumers().Length); + Assert.AreEqual("one:0", neg.op.node_def.Input[0]); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs deleted file mode 100644 index eb3fef938..000000000 --- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System; -using FluentAssertions; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using NumSharp; -using Tensorflow; -using Tensorflow.UnitTest; -using static Tensorflow.Binding; - -namespace TensorFlowNET.UnitTest.layers_test -{ - [TestClass] - public class flatten : GraphModeTestBase - { - [TestMethod] - public void Case1() - { - var sess = tf.Session().as_default(); - - var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2)); - sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); - } - - [TestMethod] - public void Case2() - { - var sess = tf.Session().as_default(); - - var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); - sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1); - } - - [TestMethod] - public void Case3() - { - var sess = tf.Session().as_default(); - - var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape()); - new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw(); - } - - [TestMethod] - public void Case4() - { - var sess = tf.Session().as_default(); - - var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2)); - sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); - } - - [TestMethod] - public void Case5() - { - var sess = tf.Session().as_default(); - - var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2)); - sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); - } - } -} \ No newline at end of file