diff --git a/graph/InceptionV3.meta b/graph/InceptionV3.meta
index 2a11b0829..fe220cce1 100644
Binary files a/graph/InceptionV3.meta and b/graph/InceptionV3.meta differ
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
index feb34663c..3902d5d06 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
@@ -74,131 +74,23 @@ public class RetrainImageClassifier : IExample
         Tensor bottleneck_input;
         Tensor cross_entropy;
         Tensor ground_truth_input;
+        Tensor bottleneck_tensor;
+        bool wants_quantization;
+        float test_accuracy;
+        NDArray predictions;
 
         public bool Run()
         {
             PrepareData();
 
-            // Set up the pre-trained graph.
-            var (graph, bottleneck_tensor, resized_image_tensor, wants_quantization) =
-                create_module_graph();
+            var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
 
-            // Add the new layer that we'll be training.
-            with(graph.as_default(), delegate
+            with(tf.Session(graph), sess =>
             {
-                (train_step, cross_entropy, bottleneck_input,
-                 ground_truth_input, final_tensor) = add_final_retrain_ops(
-                     class_count, final_tensor_name, bottleneck_tensor,
-                     wants_quantization, is_training: true);
+                Train(sess);
             });
 
-            var sw = new Stopwatch();
-
-            return with(tf.Session(graph), sess =>
-            {
-                // Initialize all weights: for the module to their pretrained values,
-                // and for the newly added retraining layer to random initial values.
-                var init = tf.global_variables_initializer();
-                sess.run(init);
-
-                var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();
-
-                // We'll make sure we've calculated the 'bottleneck' image summaries and
-                // cached them on disk.
-                cache_bottlenecks(sess, image_lists, image_dir,
-                    bottleneck_dir, jpeg_data_tensor,
-                    decoded_image_tensor, resized_image_tensor,
-                    bottleneck_tensor, tfhub_module);
-
-                // Create the operations we need to evaluate the accuracy of our new layer.
-                var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);
-
-                // Merge all the summaries and write them out to the summaries_dir
-                var merged = tf.summary.merge_all();
-                var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
-                var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);
-
-                // Create a train saver that is used to restore values into an eval graph
-                // when exporting models.
-                var train_saver = tf.train.Saver();
-                train_saver.save(sess, CHECKPOINT_NAME);
-
-                sw.Restart();
-
-                for (int i = 0; i < how_many_training_steps; i++)
-                {
-                    var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
-                        sess, image_lists, train_batch_size, "training",
-                        bottleneck_dir, image_dir, jpeg_data_tensor,
-                        decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
-                        tfhub_module);
-
-                    // Feed the bottlenecks and ground truth into the graph, and run a training
-                    // step. Capture training summaries for TensorBoard with the `merged` op.
-                    var results = sess.run(
-                        new ITensorOrOperation[] { merged, train_step },
-                        new FeedItem(bottleneck_input, train_bottlenecks),
-                        new FeedItem(ground_truth_input, train_ground_truth));
-                    var train_summary = results[0];
-
-                    // TODO
-                    train_writer.add_summary(train_summary, i);
-
-                    // Every so often, print out how well the graph is training.
-                    bool is_last_step = (i + 1 == how_many_training_steps);
-                    if ((i % eval_step_interval) == 0 || is_last_step)
-                    {
-                        results = sess.run(
-                            new Tensor[] { evaluation_step, cross_entropy },
-                            new FeedItem(bottleneck_input, train_bottlenecks),
-                            new FeedItem(ground_truth_input, train_ground_truth));
-                        (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
-                        print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");
-
-                        var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
-                            sess, image_lists, validation_batch_size, "validation",
-                            bottleneck_dir, image_dir, jpeg_data_tensor,
-                            decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
-                            tfhub_module);
-
-                        // Run a validation step and capture training summaries for TensorBoard
-                        // with the `merged` op.
-                        results = sess.run(new Tensor[] { merged, evaluation_step },
-                            new FeedItem(bottleneck_input, validation_bottlenecks),
-                            new FeedItem(ground_truth_input, validation_ground_truth));
-
-                        (string validation_summary, float validation_accuracy) = (results[0], results[1]);
-
-                        validation_writer.add_summary(validation_summary, i);
-                        print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
-                        sw.Restart();
-                    }
-
-                    // Store intermediate results
-                    int intermediate_frequency = intermediate_store_frequency;
-                    if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
-                    {
-
-                    }
-                }
-
-                // After training is complete, force one last save of the train checkpoint.
-                train_saver.save(sess, CHECKPOINT_NAME);
-
-                // We've completed all our training, so run a final test evaluation on
-                // some new images we haven't used before.
-                var (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
-                    jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
-                    bottleneck_tensor);
-
-                // Write out the trained graph and labels with the weights stored as
-                // constants.
-                print($"Save final result to : {output_graph}");
-                save_graph_to_file(output_graph, class_count);
-                File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
-
-                return test_accuracy > 0.75f;
-            });
+            return test_accuracy > 0.75f;
         }
 
         ///
@@ -212,7 +104,7 @@ public bool Run()
         ///
         ///
        ///
-        private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count, 
+        private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count,
             Dictionary<string, Dictionary<string, string[]>> image_lists, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
             Tensor resized_image_tensor, Tensor bottleneck_tensor)
@@ -233,7 +125,7 @@ public bool Run()
             return (results[0], results[1]);
         }
 
-        private (Session, Tensor, Tensor, Tensor, Tensor, Tensor) 
+        private (Session, Tensor, Tensor, Tensor, Tensor, Tensor)
             build_eval_session(int class_count)
         {
             // If quantized, we need to create the correct eval graph for exporting.
@@ -245,7 +137,7 @@ public bool Run()
             with(eval_graph.as_default(), graph =>
             {
                 // Add the new layer for exporting.
-                var (_, _, bottleneck_input, ground_truth_input, final_tensor) = 
+                var (_, _, bottleneck_input, ground_truth_input, final_tensor) =
                     add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                         wants_quantization, is_training: false);
@@ -277,7 +169,7 @@ public bool Run()
         ///
         ///
        ///
-        private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name, 
+        private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
             Tensor bottleneck_tensor, bool quantize_layer, bool is_training)
         {
             var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.TensorShape.Dimensions[0], bottleneck_tensor.TensorShape.Dimensions[1]);
@@ -365,7 +257,8 @@ private void variable_summaries(RefVariable var)
             var mean = tf.reduce_mean(var);
             tf.summary.scalar("mean", mean);
             Tensor stddev = null;
-            with(tf.name_scope("stddev"), delegate {
+            with(tf.name_scope("stddev"), delegate
+            {
                 stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)));
             });
             tf.summary.scalar("stddev", stddev);
@@ -378,7 +271,7 @@ private void variable_summaries(RefVariable var)
         private (Graph, Tensor, Tensor, bool) create_module_graph()
         {
             var (height, width) = (299, 299);
-            
+
             return with(tf.Graph().as_default(), graph =>
             {
                 tf.train.import_meta_graph("graph/InceptionV3.meta");
@@ -390,8 +283,8 @@ private void variable_summaries(RefVariable var)
             });
         }
 
-        private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, 
-            int how_many, string category, string bottleneck_dir, string image_dir, 
+        private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
+            int how_many, string category, string bottleneck_dir, string image_dir,
             Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
             Tensor bottleneck_tensor, string module_name)
         {
@@ -423,7 +316,7 @@ private void variable_summaries(RefVariable var)
             // Retrieve all bottlenecks.
             foreach (var (label_index, label_name) in enumerate(image_lists.Keys.ToArray()))
             {
-                foreach(var (image_index, image_name) in enumerate(image_lists[label_name][category]))
+                foreach (var (image_index, image_name) in enumerate(image_lists[label_name][category]))
                 {
                     var bottleneck = get_or_create_bottleneck(
                         sess, image_lists, label_name, image_index, image_dir, category,
@@ -480,17 +373,17 @@ private void variable_summaries(RefVariable var)
         ///
         ///
        ///
-        private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, 
+        private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
             string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor,
             Tensor decoded_image_tensor, Tensor resized_input_tensor,
             Tensor bottleneck_tensor, string module_name)
         {
             int how_many_bottlenecks = 0;
-            foreach(var (label_name, label_lists) in image_lists)
+            foreach (var (label_name, label_lists) in image_lists)
             {
-                foreach(var category in new string[] { "training", "testing", "validation" })
+                foreach (var category in new string[] { "training", "testing", "validation" })
                 {
                     var category_list = label_lists[category];
-                    foreach(var (index, unused_base_name) in enumerate(category_list))
+                    foreach (var (index, unused_base_name) in enumerate(category_list))
                     {
                         get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
                             bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
@@ -503,8 +396,8 @@ private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
-        private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, 
-            string label_name, int index, string image_dir, string category, string bottleneck_dir, 
+        private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
+            string label_name, int index, string image_dir, string category, string bottleneck_dir,
             Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
             Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
         {
@@ -524,8 +417,8 @@ private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
-        private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists, 
-            string label_name, int index, string image_dir, string category, Session sess, 
+        private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
+            string label_name, int index, string image_dir, string category, Session sess,
             Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
             Tensor resized_input_tensor, Tensor bottleneck_tensor)
         {
             // Create a single bottleneck file.
@@ -557,14 +450,14 @@ private NDArray run_bottleneck_on_image(Session sess, byte[] image_data, Tensor
             Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
         {
             // First decode the JPEG image, resize it, and rescale the pixel values.
-            var resized_input_values = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, new Tensor( image_data, TF_DataType.TF_STRING)));
+            var resized_input_values = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, new Tensor(image_data, TF_DataType.TF_STRING)));
             // Then run it through the recognition network.
             var bottleneck_values = sess.run(bottleneck_tensor, new FeedItem(resized_input_tensor, resized_input_values));
             bottleneck_values = np.squeeze(bottleneck_values);
             return bottleneck_values;
         }
 
-        private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index, 
+        private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index,
             string bottleneck_dir, string category, string module_name)
         {
             module_name = (module_name.Replace("://", "~") // URL scheme.
@@ -667,7 +560,7 @@ private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
             var result = new Dictionary<string, Dictionary<string, string[]>>();
 
-            foreach(var sub_dir in sub_dirs)
+            foreach (var sub_dir in sub_dirs)
             {
                 var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last();
                 print($"Looking for images in '{dir_name}'");
@@ -689,7 +582,22 @@ private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
 
         public Graph ImportGraph()
         {
-            throw new NotImplementedException();
+            Graph graph;
+
+            // Set up the pre-trained graph.
+            (graph, bottleneck_tensor, resized_image_tensor, wants_quantization) =
+                create_module_graph();
+
+            // Add the new layer that we'll be training.
+            with(graph.as_default(), delegate
+            {
+                (train_step, cross_entropy, bottleneck_input,
+                 ground_truth_input, final_tensor) = add_final_retrain_ops(
+                     class_count, final_tensor_name, bottleneck_tensor,
+                     wants_quantization, is_training: true);
+            });
+
+            return graph;
         }
 
         public Graph BuildGraph()
@@ -699,7 +607,108 @@ public Graph BuildGraph()
 
         public void Train(Session sess)
         {
-            throw new NotImplementedException();
+            var sw = new Stopwatch();
+
+            // Initialize all weights: for the module to their pretrained values,
+            // and for the newly added retraining layer to random initial values.
+            var init = tf.global_variables_initializer();
+            sess.run(init);
+
+            var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();
+
+            // We'll make sure we've calculated the 'bottleneck' image summaries and
+            // cached them on disk.
+            cache_bottlenecks(sess, image_lists, image_dir,
+                bottleneck_dir, jpeg_data_tensor,
+                decoded_image_tensor, resized_image_tensor,
+                bottleneck_tensor, tfhub_module);
+
+            // Create the operations we need to evaluate the accuracy of our new layer.
+            var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);
+
+            // Merge all the summaries and write them out to the summaries_dir
+            var merged = tf.summary.merge_all();
+            var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
+            var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);
+
+            // Create a train saver that is used to restore values into an eval graph
+            // when exporting models.
+            var train_saver = tf.train.Saver();
+            train_saver.save(sess, CHECKPOINT_NAME);
+
+            sw.Restart();
+
+            for (int i = 0; i < how_many_training_steps; i++)
+            {
+                var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
+                    sess, image_lists, train_batch_size, "training",
+                    bottleneck_dir, image_dir, jpeg_data_tensor,
+                    decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
+                    tfhub_module);
+
+                // Feed the bottlenecks and ground truth into the graph, and run a training
+                // step. Capture training summaries for TensorBoard with the `merged` op.
+                var results = sess.run(
+                    new ITensorOrOperation[] { merged, train_step },
+                    new FeedItem(bottleneck_input, train_bottlenecks),
+                    new FeedItem(ground_truth_input, train_ground_truth));
+                var train_summary = results[0];
+
+                // TODO
+                train_writer.add_summary(train_summary, i);
+
+                // Every so often, print out how well the graph is training.
+                bool is_last_step = (i + 1 == how_many_training_steps);
+                if ((i % eval_step_interval) == 0 || is_last_step)
+                {
+                    results = sess.run(
+                        new Tensor[] { evaluation_step, cross_entropy },
+                        new FeedItem(bottleneck_input, train_bottlenecks),
+                        new FeedItem(ground_truth_input, train_ground_truth));
+                    (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
+                    print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");
+
+                    var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
+                        sess, image_lists, validation_batch_size, "validation",
+                        bottleneck_dir, image_dir, jpeg_data_tensor,
+                        decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
+                        tfhub_module);
+
+                    // Run a validation step and capture training summaries for TensorBoard
+                    // with the `merged` op.
+                    results = sess.run(new Tensor[] { merged, evaluation_step },
+                        new FeedItem(bottleneck_input, validation_bottlenecks),
+                        new FeedItem(ground_truth_input, validation_ground_truth));
+
+                    (string validation_summary, float validation_accuracy) = (results[0], results[1]);
+
+                    validation_writer.add_summary(validation_summary, i);
+                    print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
+                    sw.Restart();
+                }
+
+                // Store intermediate results
+                int intermediate_frequency = intermediate_store_frequency;
+                if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
+                {
+
+                }
+            }
+
+            // After training is complete, force one last save of the train checkpoint.
+            train_saver.save(sess, CHECKPOINT_NAME);
+
+            // We've completed all our training, so run a final test evaluation on
+            // some new images we haven't used before.
+            (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
+                jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
+                bottleneck_tensor);
+
+            // Write out the trained graph and labels with the weights stored as
+            // constants.
+            print($"Save final result to : {output_graph}");
+            save_graph_to_file(output_graph, class_count);
+            File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
         }
 
         public void Predict(Session sess)
diff --git a/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs b/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs
index 4884a76a5..a30912bee 100644
--- a/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs
+++ b/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs
@@ -51,7 +51,7 @@ public bool Run()
 
             var graph = tf.Graph().as_default();
 
-            tf.train.import_meta_graph("graph/word2vec.meta");
+            tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta");
 
             // Input data
             Tensor X = graph.OperationByName("Placeholder");
@@ -169,10 +169,10 @@ public void PrepareData()
             url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
             Web.Download(url, "word2vec", "text8.zip");
             // Unzip the dataset file. Text has already been processed
-            Compress.UnZip(@"word2vec\text8.zip", "word2vec");
+            Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec");
 
             int wordId = 0;
-            text_words = File.ReadAllText(@"word2vec\text8").Trim().ToLower().Split();
+            text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8").Trim().ToLower().Split();
             // Build the dictionary and replace rare words with UNK token
             word2id = text_words.GroupBy(x => x)
                 .Select(x => new WordId
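
Editor's note on the Word2Vec.cs hunks above: the patch replaces hard-coded backslash paths (e.g. @"word2vec\text8.zip") with $"...{Path.DirectorySeparatorChar}..." so the examples also run on Linux and macOS. As an illustrative sketch only (not part of the patch), System.IO.Path.Combine expresses the same intent without string interpolation; the class and variable names below are hypothetical.

    // Sketch: building the same relative paths with Path.Combine,
    // which picks the platform-correct separator automatically.
    using System;
    using System.IO;

    class PathSketch
    {
        static void Main()
        {
            // "word2vec/text8.zip" on Unix-like systems, "word2vec\text8.zip" on Windows.
            string zipPath = Path.Combine("word2vec", "text8.zip");
            string textPath = Path.Combine("word2vec", "text8");
            Console.WriteLine($"{zipPath} {textPath}");
        }
    }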