Skip to content

Commit

Permalink
Add Tensors class to adapt Tensor and Tensor[].
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Oct 5, 2020
1 parent 69967b4 commit 0abf166
Show file tree
Hide file tree
Showing 36 changed files with 282 additions and 176 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,5 @@ test/TensorFlowNET.Examples/mnist
# training model resources
.resources
/redist
*.xml
*.xsd
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/APIs/tf.reshape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Tensorflow
{
public partial class tensorflow
{
public Tensor reshape<T>(T tensor,
public Tensor reshape(Tensor tensor,
TensorShape shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);

Expand Down
7 changes: 4 additions & 3 deletions src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using NumSharp;
using System;
using System.Linq;
using static Tensorflow.Binding;

Expand Down Expand Up @@ -42,8 +43,8 @@ public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_)
var values_shape = values.TensorShape.with_rank(1);
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);

indices_shape[0].merge_with(values_shape.dims[0]);
indices_shape[1].merge_with(dense_shape_shape.dims[0]);
indices_shape["0"].merge_with(values_shape[0]);
indices_shape["1"].merge_with(dense_shape_shape[0]);

_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray());
}
Expand Down
2 changes: 2 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class ModelArgs : LayerArgs
{
public Tensor[] Inputs { get; set; }
public Tensor[] Outputs { get; set; }
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ public class NodeArgs
public int[] NodeIndices { get; set; }
public int[] TensorIndices { get; set; }
public Tensor InputTensors { get; set; }
public Tensor Outputs { get; set; }
public Tensors Outputs { get; set; }
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public Flatten(FlattenArgs args)
_channels_first = args.DataFormat == "channels_first";
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
if (_channels_first)
{
Expand Down
29 changes: 29 additions & 0 deletions src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
/// <summary>
/// Tracks the Layer call that created a Tensor, for Keras Graph Networks.
/// </summary>
public class KerasHistory
{
Layer layer;
int node_index;
int tensor_index;

public KerasHistory(Layer layer, int node_index, int tensor_index)
{
this.layer = layer;
this.node_index = node_index;
this.tensor_index = tensor_index;
}

public static implicit operator Layer(KerasHistory history)
=> history.layer;

public static implicit operator (Layer, int, int)(KerasHistory history)
=> (history.layer, history.node_index, history.tensor_index);
}
}
53 changes: 10 additions & 43 deletions src/TensorFlowNET.Core/Keras/Engine/Layer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ public Layer(LayerArgs args)
/// Wraps `call`, applying pre- and post-processing steps.
/// </summary>
/// <param name="input"></param>
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <returns></returns>
public Tensor Apply(Tensor inputs, bool is_training = false)
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
{
Tensor outputs = null;
Tensors outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
Expand All @@ -148,7 +149,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
if (!built)
MaybeBuild(inputs);

outputs = call(inputs, is_training: is_training);
outputs = call(inputs, state: state, is_training: is_training);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
Expand All @@ -161,36 +162,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
return outputs;
}

public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false)
{
Tensor[] outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();

string nameScope = "";
if (eager)
nameScope = name;
else
nameScope = _name_scope();

tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs[0]);

outputs = call(inputs, is_training: is_training, state: state);
});

return outputs;
}

private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
{
/*var returnOutputs = new List<Tensor>();
foreach(var x in outputs)
Expand All @@ -211,15 +183,15 @@ private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
return outputs;
}

private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
private void _handle_activity_regularization(Tensors inputs, Tensors outputs)
{
//if(_activity_regularizer != null)
{

}
}

private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
private void _set_mask_metadata(Tensors inputs, Tensors outputs, Tensors previous_mask)
{

}
Expand All @@ -229,12 +201,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
return null;
}

protected virtual Tensor call(Tensor inputs, bool is_training = false)
{
throw new NotImplementedException("");
}

protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
throw new NotImplementedException("");
}
Expand All @@ -244,15 +211,15 @@ protected virtual string _name_scope()
return Name;
}

protected void MaybeBuild(Tensor inputs)
protected void MaybeBuild(Tensors inputs)
{
// Check input assumptions set before layer building, e.g. input rank.
if (built)
return;
if (DType == TF_DataType.DtInvalid)
args.DType = inputs.dtype;

var input_shapes = inputs.TensorShape;
var input_shapes = inputs.shape;
build(input_shapes);
built = true;
}
Expand Down
6 changes: 5 additions & 1 deletion src/TensorFlowNET.Core/Keras/Engine/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ public class Model : Layer
public Model(ModelArgs args)
: base(args)
{

// Build _output_layers
/*foreach(var x in args.Outputs)
{
var layer = x.KerasHistory;
}*/
}

public void compile(string optimizerName, string lossName)
Expand Down
7 changes: 4 additions & 3 deletions src/TensorFlowNET.Core/Keras/Engine/Node.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ public class Node

public int[] node_indices;
public int[] tensor_indices;
public Tensor input_tensors;
public Tensor Outputs => args.Outputs;
public Tensors input_tensors;
public Tensors Outputs => args.Outputs;
public TensorShape[] input_shapes;
public TensorShape[] output_shapes;
List<Layer> kerasInputs;
Expand All @@ -57,7 +57,8 @@ public Node(Layer layer, NodeArgs args)

// Set metadata on outputs.
var node_index = layer.InboundNodes.Count - 1;
args.Outputs.KerasHistory.Add(layer);
foreach (var (i, tensor) in enumerate(Outputs))
tensor.KerasHistory = new KerasHistory(layer, node_index, i);
}
}
}
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public Sequential(SequentialArgs args)

public void add(Tensor tensor)
{
var layer = tensor.KerasHistory[0];
Layer layer = tensor.KerasHistory;
add(layer);
}

Expand Down Expand Up @@ -129,7 +129,7 @@ void _init_graph_network(Tensor inputs, Tensor outputs)

void _map_graph_network(Tensor inputs, Tensor outputs)
{
layers.add(outputs.KerasHistory[0]);
layers.add(outputs.KerasHistory);
}
}
}
13 changes: 13 additions & 0 deletions src/TensorFlowNET.Core/Keras/KerasApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ public Sequential Sequential(List<Layer> layers = null,
Name = name
});

/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
/// <param name="input"></param>
/// <param name="output"></param>
/// <returns></returns>
public Model Model(Tensor input, Tensor output)
=> new Model(new ModelArgs
{
Inputs = new[] { input },
Outputs = new[] { output }
});

/// <summary>
/// Instantiate a Keras tensor.
/// </summary>
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ protected override void build(TensorShape input_shape)
built = true;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
Tensor outputs = null;

Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/Conv.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ protected override void build(TensorShape input_shape)
built = true;
}

protected override Tensor call(Tensor inputs, bool training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
{
var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias)
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/Dense.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ protected override void build(TensorShape input_shape)
built = true;
}

protected override Tensor call(Tensor inputs, bool training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
{
Tensor outputs = null;
var rank = inputs.rank;
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public Dropout(DropoutArgs args)
this.args = args;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var output = tf_utils.smart_cond(is_training,
() => tf.nn.dropout(inputs,
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ protected override void build(TensorShape input_shape)
built = true;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)
Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ public LSTM(LSTMArgs args) :
.ToArray();
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.call(inputs, is_training);
return base.call(inputs, state: state, is_training: is_training);
}
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Pooling2D(Pooling2DArgs args)
input_spec = new InputSpec(ndim: 4);
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
int[] pool_shape;
int[] strides;
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public Rescaling(RescalingArgs args) : base(args)
this.args = args;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType);
Expand Down
Loading

0 comments on commit 0abf166

Please sign in to comment.