Skip to content

Commit

Permalink
Tensor: Rewrite Creation to fix heap corruption.
Browse files Browse the repository at this point in the history
TensorTest: Added Unit tests for more coverage, reformatted file.
c_api.tensor: Added overloads for TF_NewTensor that does not have deallocator parameters.
BaseSession: Removed disposal immediately after TF_SessionRun call.
c_api.DeallocatorArgs: Added DeallocatorArgs.Empty
  • Loading branch information
Nucs committed Aug 31, 2019
1 parent 1d6f3de commit e4e62dd
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 329 deletions.
2 changes: 2 additions & 0 deletions TensorFlow.NET.sln.DotSettings
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:Boolean x:Key="/Default/UserDictionary/Words/=Tensorflow/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
9 changes: 9 additions & 0 deletions src/TensorFlowNET.Core/APIs/c_api.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ public static string StringPiece(IntPtr handle)

public struct DeallocatorArgs
{
internal static unsafe c_api.DeallocatorArgs* EmptyPtr;
internal static unsafe IntPtr Empty;

static unsafe DeallocatorArgs()
{
Empty = new IntPtr(EmptyPtr = (DeallocatorArgs*) Marshal.AllocHGlobal(Marshal.SizeOf<DeallocatorArgs>()));
*EmptyPtr = new DeallocatorArgs() {gc_handle = IntPtr.Zero, deallocator_called = false};
}

public bool deallocator_called;
public IntPtr gc_handle;
}
Expand Down
18 changes: 3 additions & 15 deletions src/TensorFlowNET.Core/Sessions/BaseSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,16 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
{

var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
var ignoreDispose = new bool[feed_dict.Count];
int i = 0;
foreach (var x in feed_dict)
{
if (x.Key is Tensor tensor)
{
switch (x.Value)
{
case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break;
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
#if _REGEN
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
Expand Down Expand Up @@ -194,7 +194,6 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
#endif
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break;
case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
default:
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
}
Expand All @@ -203,18 +202,7 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,

var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
//var targets = target_list;
try
{
return _call_tf_sessionrun(feeds, fetches, target_list);
} finally
{
for (var idx = 0; idx < feeds.Length; idx++)
{
if (ignoreDispose[idx])
continue;
feeds[idx].Value.Dispose();
}
}
return _call_tf_sessionrun(feeds, fetches, target_list);
}

private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
Expand Down
27 changes: 27 additions & 0 deletions src/TensorFlowNET.Core/Tensors/AllocationType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
namespace Tensorflow
{
/// <summary>
/// Used internally to
/// </summary>
public enum AllocationType
{
None = 0,
/// <summary>
/// Allocation was done by passing in a pointer, might be also holding reference to a C# object.
/// </summary>
FromPointer = 1,
/// <summary>
/// Allocation was done by calling c_api.TF_AllocateTensor or TF decided it has to copy data during c_api.TF_NewTensor. <br></br>
/// Deallocation is handled solely by Tensorflow.
/// </summary>
Tensorflow = 2,
/// <summary>
/// Allocation was done by Marshal.AllocateHGlobal
/// </summary>
Marshal = 3,
/// <summary>
/// Allocation was done by GCHandle.Alloc
/// </summary>
GCHandle = 4,
}
}
Loading

0 comments on commit e4e62dd

Please sign in to comment.