Skip to content

Commit 0abf166

Browse files
committed
Add Tensors class to adapt Tensor and Tensor[].
1 parent 69967b4 commit 0abf166

36 files changed

+282
-176
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,5 @@ test/TensorFlowNET.Examples/mnist
337337
# training model resources
338338
.resources
339339
/redist
340+
*.xml
341+
*.xsd

src/TensorFlowNET.Core/APIs/tf.reshape.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace Tensorflow
1818
{
1919
public partial class tensorflow
2020
{
21-
public Tensor reshape<T>(T tensor,
21+
public Tensor reshape(Tensor tensor,
2222
TensorShape shape,
2323
string name = null) => gen_array_ops.reshape(tensor, shape, name);
2424

src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using NumSharp;
2+
using System;
23
using System.Linq;
34
using static Tensorflow.Binding;
45

@@ -42,8 +43,8 @@ public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_)
4243
var values_shape = values.TensorShape.with_rank(1);
4344
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);
4445

45-
indices_shape[0].merge_with(values_shape.dims[0]);
46-
indices_shape[1].merge_with(dense_shape_shape.dims[0]);
46+
indices_shape["0"].merge_with(values_shape[0]);
47+
indices_shape["1"].merge_with(dense_shape_shape[0]);
4748

4849
_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray());
4950
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
66
{
77
public class ModelArgs : LayerArgs
88
{
9+
public Tensor[] Inputs { get; set; }
10+
public Tensor[] Outputs { get; set; }
911
}
1012
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ public class NodeArgs
1212
public int[] NodeIndices { get; set; }
1313
public int[] TensorIndices { get; set; }
1414
public Tensor InputTensors { get; set; }
15-
public Tensor Outputs { get; set; }
15+
public Tensors Outputs { get; set; }
1616
}
1717
}

src/TensorFlowNET.Core/Keras/Engine/Flatten.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public Flatten(FlattenArgs args)
2121
_channels_first = args.DataFormat == "channels_first";
2222
}
2323

24-
protected override Tensor call(Tensor inputs, bool is_training = false)
24+
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
2525
{
2626
if (_channels_first)
2727
{
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Engine
6+
{
7+
/// <summary>
8+
/// Tracks the Layer call that created a Tensor, for Keras Graph Networks.
9+
/// </summary>
10+
public class KerasHistory
11+
{
12+
Layer layer;
13+
int node_index;
14+
int tensor_index;
15+
16+
public KerasHistory(Layer layer, int node_index, int tensor_index)
17+
{
18+
this.layer = layer;
19+
this.node_index = node_index;
20+
this.tensor_index = tensor_index;
21+
}
22+
23+
public static implicit operator Layer(KerasHistory history)
24+
=> history.layer;
25+
26+
public static implicit operator (Layer, int, int)(KerasHistory history)
27+
=> (history.layer, history.node_index, history.tensor_index);
28+
}
29+
}

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,12 @@ public Layer(LayerArgs args)
119119
/// Wraps `call`, applying pre- and post-processing steps.
120120
/// </summary>
121121
/// <param name="input"></param>
122+
/// <param name="state"></param>
122123
/// <param name="is_training"></param>
123124
/// <returns></returns>
124-
public Tensor Apply(Tensor inputs, bool is_training = false)
125+
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
125126
{
126-
Tensor outputs = null;
127+
Tensors outputs = null;
127128

128129
callContext = callContext ?? new ThreadLocal<CallContext>()
129130
{
@@ -148,7 +149,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
148149
if (!built)
149150
MaybeBuild(inputs);
150151

151-
outputs = call(inputs, is_training: is_training);
152+
outputs = call(inputs, state: state, is_training: is_training);
152153

153154
outputs = _set_connectivity_metadata_(inputs, outputs);
154155
_handle_activity_regularization(inputs, outputs);
@@ -161,36 +162,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
161162
return outputs;
162163
}
163164

164-
public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false)
165-
{
166-
Tensor[] outputs = null;
167-
168-
callContext = callContext ?? new ThreadLocal<CallContext>()
169-
{
170-
Value = new CallContext()
171-
};
172-
173-
var eager = tf.executing_eagerly();
174-
using var ctxManager = CallContext.enter();
175-
176-
string nameScope = "";
177-
if (eager)
178-
nameScope = name;
179-
else
180-
nameScope = _name_scope();
181-
182-
tf_with(ops.name_scope(nameScope), scope =>
183-
{
184-
if (!built)
185-
MaybeBuild(inputs[0]);
186-
187-
outputs = call(inputs, is_training: is_training, state: state);
188-
});
189-
190-
return outputs;
191-
}
192-
193-
private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
165+
private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
194166
{
195167
/*var returnOutputs = new List<Tensor>();
196168
foreach(var x in outputs)
@@ -211,15 +183,15 @@ private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
211183
return outputs;
212184
}
213185

214-
private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
186+
private void _handle_activity_regularization(Tensors inputs, Tensors outputs)
215187
{
216188
//if(_activity_regularizer != null)
217189
{
218190

219191
}
220192
}
221193

222-
private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
194+
private void _set_mask_metadata(Tensors inputs, Tensors outputs, Tensors previous_mask)
223195
{
224196

225197
}
@@ -229,12 +201,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
229201
return null;
230202
}
231203

232-
protected virtual Tensor call(Tensor inputs, bool is_training = false)
233-
{
234-
throw new NotImplementedException("");
235-
}
236-
237-
protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
204+
protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
238205
{
239206
throw new NotImplementedException("");
240207
}
@@ -244,15 +211,15 @@ protected virtual string _name_scope()
244211
return Name;
245212
}
246213

247-
protected void MaybeBuild(Tensor inputs)
214+
protected void MaybeBuild(Tensors inputs)
248215
{
249216
// Check input assumptions set before layer building, e.g. input rank.
250217
if (built)
251218
return;
252219
if (DType == TF_DataType.DtInvalid)
253220
args.DType = inputs.dtype;
254221

255-
var input_shapes = inputs.TensorShape;
222+
var input_shapes = inputs.shape;
256223
build(input_shapes);
257224
built = true;
258225
}

src/TensorFlowNET.Core/Keras/Engine/Model.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ public class Model : Layer
2727
public Model(ModelArgs args)
2828
: base(args)
2929
{
30-
30+
// Build _output_layers
31+
/*foreach(var x in args.Outputs)
32+
{
33+
var layer = x.KerasHistory;
34+
}*/
3135
}
3236

3337
public void compile(string optimizerName, string lossName)

src/TensorFlowNET.Core/Keras/Engine/Node.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ public class Node
3535

3636
public int[] node_indices;
3737
public int[] tensor_indices;
38-
public Tensor input_tensors;
39-
public Tensor Outputs => args.Outputs;
38+
public Tensors input_tensors;
39+
public Tensors Outputs => args.Outputs;
4040
public TensorShape[] input_shapes;
4141
public TensorShape[] output_shapes;
4242
List<Layer> kerasInputs;
@@ -57,7 +57,8 @@ public Node(Layer layer, NodeArgs args)
5757

5858
// Set metadata on outputs.
5959
var node_index = layer.InboundNodes.Count - 1;
60-
args.Outputs.KerasHistory.Add(layer);
60+
foreach (var (i, tensor) in enumerate(Outputs))
61+
tensor.KerasHistory = new KerasHistory(layer, node_index, i);
6162
}
6263
}
6364
}

0 commit comments

Comments
 (0)