Skip to content

Commit 476f8cf

Browse files
committed
add predict for logistic regression.
1 parent 5a2433c commit 476f8cf

File tree

12 files changed

+196
-18
lines changed

12 files changed

+196
-18
lines changed

docs/source/Graph.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,61 @@ A typical graph is looks like below:
2121

2222
![image](../assets/graph_vis_animation.gif)
2323

24+
25+
26+
### Save Model
27+
28+
Saving the model means saving all the values of the parameters and the graph.
29+
30+
```python
31+
saver = tf.train.Saver()
32+
saver.save(sess,'./tensorflowModel.ckpt')
33+
```
34+
35+
After saving the model there will be four files:
36+
37+
* tensorflowModel.ckpt.meta:
38+
* tensorflowModel.ckpt.data-00000-of-00001:
39+
* tensorflowModel.ckpt.index
40+
* checkpoint
41+
42+
We also created a protocol buffer file .pbtxt. It is human readable if you want to convert it to binary: `as_text: false`.
43+
44+
* tensorflowModel.pbtxt:
45+
46+
This holds a network of nodes, each representing one operation, connected to each other as inputs and outputs.
47+
48+
49+
50+
### Freezing the Graph
51+
52+
##### *Why we need it?*
53+
54+
When we need to keep all the values of the variables and the Graph structure in a single file we have to freeze the graph.
55+
56+
```csharp
57+
from tensorflow.python.tools import freeze_graph
58+
59+
freeze_graph.freeze_graph(input_graph = 'logistic_regression/tensorflowModel.pbtxt',
60+
input_saver = "",
61+
input_binary = False,
62+
input_checkpoint = 'logistic_regression/tensorflowModel.ckpt',
63+
output_node_names = "Softmax",
64+
restore_op_name = "save/restore_all",
65+
filename_tensor_name = "save/Const:0",
66+
output_graph = 'frozentensorflowModel.pb',
67+
clear_devices = True,
68+
initializer_nodes = "")
69+
70+
```
71+
72+
### Optimizing for Inference
73+
74+
To Reduce the amount of computation needed when the network is used only for inferences we can remove some parts of a graph that are only needed for training.
75+
76+
77+
78+
### Restoring the Model
79+
80+
81+

docs/source/LogisticRegression.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Chapter. Logistic Regression
2+
3+
### What is logistic regression?
4+
5+
6+
7+
The full example is [here](https://github.com/SciSharp/TensorFlow.NET/blob/master/test/TensorFlowNET.Examples/LogisticRegression.cs).

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ Welcome to TensorFlow.NET's documentation!
2626
Train
2727
EagerMode
2828
LinearRegression
29+
LogisticRegression
2930
ImageRecognition

src/TensorFlowNET.Core/Framework/meta_graph.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def = nu
184184

185185
// Adds graph_def or the default.
186186
if (graph_def == null)
187-
meta_graph_def.GraphDef = graph._as_graph_def(add_shapes: true);
187+
meta_graph_def.GraphDef = graph.as_graph_def(add_shapes: true);
188188
else
189189
meta_graph_def.GraphDef = graph_def;
190190

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class FreezeGraph
8+
{
9+
public static void freeze_graph(string input_graph,
10+
string input_saver,
11+
bool input_binary,
12+
string input_checkpoint,
13+
string output_node_names,
14+
string restore_op_name,
15+
string filename_tensor_name,
16+
string output_graph,
17+
bool clear_devices,
18+
string initializer_nodes)
19+
{
20+
21+
}
22+
}
23+
}

src/TensorFlowNET.Core/Graphs/Graph.Export.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public Buffer ToGraphDef(Status s)
1818
return buffer;
1919
}
2020

21-
public GraphDef _as_graph_def(bool add_shapes = false)
21+
private GraphDef _as_graph_def(bool add_shapes = false)
2222
{
2323
var buffer = ToGraphDef(Status);
2424
Status.Check();
@@ -30,5 +30,8 @@ public GraphDef _as_graph_def(bool add_shapes = false)
3030

3131
return def;
3232
}
33+
34+
public GraphDef as_graph_def(bool add_shapes = false)
35+
=> _as_graph_def(add_shapes);
3336
}
3437
}

src/TensorFlowNET.Core/Graphs/graph_io.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class graph_io
1010
{
1111
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
1212
{
13-
var graph_def = graph._as_graph_def();
13+
var graph_def = graph.as_graph_def();
1414
string path = Path.Combine(logdir, name);
1515
if (as_text)
1616
File.WriteAllText(path, graph_def.ToString());

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,24 @@ public NDArray build_results(BaseSession session, NDArray[] tensor_values)
5858
{
5959
var value = tensor_values[j];
6060
j += 1;
61-
switch (value.dtype.Name)
61+
if (value.ndim == 2)
6262
{
63-
case "Int32":
64-
full_values.Add(value.Data<int>(0));
65-
break;
66-
case "Single":
67-
full_values.Add(value.Data<float>(0));
68-
break;
69-
case "Double":
70-
full_values.Add(value.Data<double>(0));
71-
break;
63+
full_values.Add(value[0]);
64+
}
65+
else
66+
{
67+
switch (value.dtype.Name)
68+
{
69+
case "Int32":
70+
full_values.Add(value.Data<int>(0));
71+
break;
72+
case "Single":
73+
full_values.Add(value.Data<float>(0));
74+
break;
75+
case "Double":
76+
full_values.Add(value.Data<double>(0));
77+
break;
78+
}
7279
}
7380
}
7481
i += 1;

src/TensorFlowNET.Core/Train/Saving/Saver.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public MetaGraphDef export_meta_graph(string filename= "",
251251
{
252252
return export_meta_graph(
253253
filename: filename,
254-
graph_def: ops.get_default_graph()._as_graph_def(add_shapes: true),
254+
graph_def: ops.get_default_graph().as_graph_def(add_shapes: true),
255255
saver_def: _saver_def,
256256
collection_list: collection_list,
257257
as_text: as_text,

src/TensorFlowNET.Core/Train/tf.optimizers.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ public static class train
1616

1717
public static Saver Saver() => new Saver();
1818

19-
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text);
19+
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
20+
=> graph_io.write_graph(graph, logdir, name, as_text);
2021

2122
public static Saver import_meta_graph(string meta_graph_or_file,
2223
bool clear_devices = false,

0 commit comments

Comments
 (0)