From ee3a76bf34ef26e2730e16850d67c4e2753469c4 Mon Sep 17 00:00:00 2001 From: Nithin Vasisth Date: Sat, 11 Mar 2017 01:24:44 +0530 Subject: [PATCH 1/2] Upgraded files to tf version 1.0.0 --- infogan/algos/infogan_trainer.py | 14 +- infogan/misc/custom_ops.py | 4 +- infogan/misc/datasets.py | 2 +- infogan/misc/distributions.py | 26 +- infogan/misc/utils.py | 2 +- tf_upgrade.py | 681 +++++++++++++++++++++++++++++++ 6 files changed, 705 insertions(+), 24 deletions(-) create mode 100644 tf_upgrade.py diff --git a/infogan/algos/infogan_trainer.py b/infogan/algos/infogan_trainer.py index 946dbb9..94375e9 100644 --- a/infogan/algos/infogan_trainer.py +++ b/infogan/algos/infogan_trainer.py @@ -121,7 +121,7 @@ def init_opt(self): self.generator_trainer = pt.apply_optimizer(generator_optimizer, losses=[generator_loss], var_list=g_vars) for k, v in self.log_vars: - tf.scalar_summary(k, v) + tf.summary.scalar(k, v) with pt.defaults_scope(phase=pt.Phase.test): with tf.variable_scope("model", reuse=True) as scope: @@ -199,23 +199,23 @@ def visualize_all_factors(self): row_img = [] for col in xrange(rows): row_img.append(imgs[row, col, :, :, :]) - stacked_img.append(tf.concat(1, row_img)) - imgs = tf.concat(0, stacked_img) + stacked_img.append(tf.concat(axis=1, values=row_img)) + imgs = tf.concat(axis=0, values=stacked_img) imgs = tf.expand_dims(imgs, 0) - tf.image_summary("image_%d_%s" % (dist_idx, dist.__class__.__name__), imgs) + tf.summary.image("image_%d_%s" % (dist_idx, dist.__class__.__name__), imgs) def train(self): self.init_opt() - init = tf.initialize_all_variables() + init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) - summary_op = tf.merge_all_summaries() - summary_writer = tf.train.SummaryWriter(self.log_dir, sess.graph) + summary_op = tf.summary.merge_all() + summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph) saver = tf.train.Saver() diff --git a/infogan/misc/custom_ops.py b/infogan/misc/custom_ops.py index 9434a2c..14f309e 100644 --- a/infogan/misc/custom_ops.py +++ b/infogan/misc/custom_ops.py @@ -79,7 +79,7 @@ def __call__(self, input_layer, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="deconv2d"): output_shape[0] = input_layer.shape[0] - ts_output_shape = tf.pack(output_shape) + ts_output_shape = tf.stack(output_shape) with tf.variable_scope(name): # filter : [height, width, output_channels, in_channels] w = self.variable('w', [k_h, k_w, output_shape[-1], input_layer.shape[-1]], @@ -108,7 +108,7 @@ def __call__(self, input_layer, output_size, scope=None, in_dim=None, stddev=0.0 input_ = input_layer.tensor try: if len(shape) == 4: - input_ = tf.reshape(input_, tf.pack([tf.shape(input_)[0], np.prod(shape[1:])])) + input_ = tf.reshape(input_, tf.stack([tf.shape(input_)[0], np.prod(shape[1:])])) input_.set_shape([None, np.prod(shape[1:])]) shape = input_.get_shape().as_list() diff --git a/infogan/misc/datasets.py b/infogan/misc/datasets.py index 4602d76..27db406 100644 --- a/infogan/misc/datasets.py +++ b/infogan/misc/datasets.py @@ -84,4 +84,4 @@ def transform(self, data): return data def inverse_transform(self, data): - return data + return data \ No newline at end of file diff --git a/infogan/misc/distributions.py b/infogan/misc/distributions.py index 83bfc69..c91576c 100644 --- a/infogan/misc/distributions.py +++ b/infogan/misc/distributions.py @@ -122,7 +122,7 @@ def effective_dim(self): def logli(self, x_var, dist_info): prob = dist_info["prob"] - return tf.reduce_sum(tf.log(prob + TINY) * x_var, reduction_indices=1) + 
return tf.reduce_sum(tf.log(prob + TINY) * x_var, axis=1) def prior_dist_info(self, batch_size): prob = tf.ones([batch_size, self.dim]) * floatX(1.0 / self.dim) @@ -131,8 +131,8 @@ def prior_dist_info(self, batch_size): def marginal_logli(self, x_var, dist_info): prob = dist_info["prob"] avg_prob = tf.tile( - tf.reduce_mean(prob, reduction_indices=0, keep_dims=True), - tf.pack([tf.shape(prob)[0], 1]) + tf.reduce_mean(prob, axis=0, keep_dims=True), + tf.stack([tf.shape(prob)[0], 1]) ) return self.logli(x_var, dict(prob=avg_prob)) @@ -149,7 +149,7 @@ def kl(self, p, q): q_prob = q["prob"] return tf.reduce_sum( p_prob * (tf.log(p_prob + TINY) - tf.log(q_prob + TINY)), - reduction_indices=1 + axis=1 ) def sample(self, dist_info): @@ -163,13 +163,13 @@ def activate_dist(self, flat_dist): def entropy(self, dist_info): prob = dist_info["prob"] - return -tf.reduce_sum(prob * tf.log(prob + TINY), reduction_indices=1) + return -tf.reduce_sum(prob * tf.log(prob + TINY), axis=1) def marginal_entropy(self, dist_info): prob = dist_info["prob"] avg_prob = tf.tile( - tf.reduce_mean(prob, reduction_indices=0, keep_dims=True), - tf.pack([tf.shape(prob)[0], 1]) + tf.reduce_mean(prob, axis=0, keep_dims=True), + tf.stack([tf.shape(prob)[0], 1]) ) return self.entropy(dict(prob=avg_prob)) @@ -201,7 +201,7 @@ def logli(self, x_var, dist_info): epsilon = (x_var - mean) / (stddev + TINY) return tf.reduce_sum( - 0.5 * np.log(2 * np.pi) - tf.log(stddev + TINY) - 0.5 * tf.square(epsilon), - reduction_indices=1, + axis=1, ) def prior_dist_info(self, batch_size): @@ -225,7 +225,7 @@ def kl(self, p, q): denominator = 2. * tf.square(q_stddev) return tf.reduce_sum( numerator / (denominator + TINY) + tf.log(q_stddev + TINY) - tf.log(p_stddev + TINY), - reduction_indices=1 + axis=1 ) def sample(self, dist_info): @@ -291,7 +291,7 @@ def logli(self, x_var, dist_info): p = dist_info["p"] return tf.reduce_sum( x_var * tf.log(p + TINY) + (1.0 - x_var) * tf.log(1.0 - p + TINY), - reduction_indices=1 + axis=1 ) def nonreparam_logli(self, x_var, dist_info): @@ -397,7 +397,7 @@ def join_vars(self, xs): """ Join the per component tensor variables into a whole tensor """ - return tf.concat(1, xs) + return tf.concat(axis=1, values=xs) def split_dist_flat(self, dist_flat): """ @@ -434,13 +434,13 @@ def sample(self, dist_info): ret = [] for dist_info_i, dist_i in zip(self.split_dist_info(dist_info), self.dists): ret.append(tf.cast(dist_i.sample(dist_info_i), tf.float32)) - return tf.concat(1, ret) + return tf.concat(axis=1, values=ret) def sample_prior(self, batch_size): ret = [] for dist_i in self.dists: ret.append(tf.cast(dist_i.sample_prior(batch_size), tf.float32)) - return tf.concat(1, ret) + return tf.concat(axis=1, values=ret) def logli(self, x_var, dist_info): ret = tf.constant(0.) diff --git a/infogan/misc/utils.py b/infogan/misc/utils.py index 3385b5d..ead1ed3 100644 --- a/infogan/misc/utils.py +++ b/infogan/misc/utils.py @@ -11,4 +11,4 @@ def mkdir_p(path): if exc.errno == errno.EEXIST and os.path.isdir(path): pass else: - raise + raise \ No newline at end of file diff --git a/tf_upgrade.py b/tf_upgrade.py new file mode 100644 index 0000000..3cd27a4 --- /dev/null +++ b/tf_upgrade.py @@ -0,0 +1,681 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import argparse +import ast +import collections +import os +import shutil +import sys +import tempfile +import traceback + + +class APIChangeSpec(object): + """List of maps that describe what changed in the API.""" + + def __init__(self): + # Maps from a function name to a dictionary that describes how to + # map from an old argument keyword to the new argument keyword. + self.function_keyword_renames = { + "tf.count_nonzero": { + "reduction_indices": "axis" + }, + "tf.reduce_all": { + "reduction_indices": "axis" + }, + "tf.reduce_any": { + "reduction_indices": "axis" + }, + "tf.reduce_max": { + "reduction_indices": "axis" + }, + "tf.reduce_mean": { + "reduction_indices": "axis" + }, + "tf.reduce_min": { + "reduction_indices": "axis" + }, + "tf.reduce_prod": { + "reduction_indices": "axis" + }, + "tf.reduce_sum": { + "reduction_indices": "axis" + }, + "tf.reduce_logsumexp": { + "reduction_indices": "axis" + }, + "tf.expand_dims": { + "dim": "axis" + }, + "tf.argmax": { + "dimension": "axis" + }, + "tf.argmin": { + "dimension": "axis" + }, + "tf.reduce_join": { + "reduction_indices": "axis" + }, + "tf.sparse_concat": { + "concat_dim": "axis" + }, + "tf.sparse_split": { + "split_dim": "axis" + }, + "tf.sparse_reduce_sum": { + "reduction_axes": "axis" + }, + "tf.reverse_sequence": { + "seq_dim": "seq_axis", + "batch_dim": "batch_axis" + }, + "tf.sparse_reduce_sum_sparse": { + "reduction_axes": "axis" + }, + "tf.squeeze": { + "squeeze_dims": "axis" + }, + "tf.split": { + "split_dim": "axis", + "num_split": "num_or_size_splits" + }, + "tf.concat": { + "concat_dim": "axis" + }, + } + + # Mapping from function to the new name of the function + self.function_renames = { + "tf.inv": "tf.reciprocal", + "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar", + "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram", + "tf.listdiff": "tf.setdiff1d", + "tf.list_diff": "tf.setdiff1d", + "tf.mul": "tf.multiply", + "tf.neg": "tf.negative", + "tf.sub": "tf.subtract", + "tf.train.SummaryWriter": "tf.summary.FileWriter", + "tf.scalar_summary": "tf.summary.scalar", + "tf.histogram_summary": "tf.summary.histogram", + "tf.audio_summary": "tf.summary.audio", + "tf.image_summary": "tf.summary.image", + "tf.merge_summary": "tf.summary.merge", + "tf.merge_all_summaries": "tf.summary.merge_all", + "tf.image.per_image_whitening": "tf.image.per_image_standardization", + "tf.all_variables": "tf.global_variables", + "tf.VARIABLES": "tf.GLOBAL_VARIABLES", + "tf.initialize_all_variables": "tf.global_variables_initializer", + "tf.initialize_variables": "tf.variables_initializer", + "tf.initialize_local_variables": "tf.local_variables_initializer", + "tf.batch_matrix_diag": "tf.matrix_diag", + "tf.batch_band_part": "tf.band_part", + "tf.batch_set_diag": "tf.set_diag", + "tf.batch_matrix_transpose": "tf.matrix_transpose", + 
"tf.batch_matrix_determinant": "tf.matrix_determinant", + "tf.batch_matrix_inverse": "tf.matrix_inverse", + "tf.batch_cholesky": "tf.cholesky", + "tf.batch_cholesky_solve": "tf.cholesky_solve", + "tf.batch_matrix_solve": "tf.matrix_solve", + "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve", + "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls", + "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig", + "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals", + "tf.batch_svd": "tf.svd", + "tf.batch_fft": "tf.fft", + "tf.batch_ifft": "tf.ifft", + "tf.batch_ifft2d": "tf.ifft2d", + "tf.batch_fft3d": "tf.fft3d", + "tf.batch_ifft3d": "tf.ifft3d", + "tf.select": "tf.where", + "tf.complex_abs": "tf.abs", + "tf.batch_matmul": "tf.matmul", + "tf.pack": "tf.stack", + "tf.unpack": "tf.unstack", + } + + self.change_to_function = { + "tf.ones_initializer", + "tf.zeros_initializer", + } + + # Functions that were reordered should be changed to the new keyword args + # for safety, if positional arguments are used. If you have reversed the + # positional arguments yourself, this could do the wrong thing. + self.function_reorders = { + "tf.split": ["axis", "num_or_size_splits", "value", "name"], + "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"], + "tf.concat": ["concat_dim", "values", "name"], + "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], + "tf.nn.softmax_cross_entropy_with_logits": [ + "logits", "labels", "dim", "name"], + "tf.nn.sparse_softmax_cross_entropy_with_logits": [ + "logits", "labels", "name"], + "tf.nn.sigmoid_cross_entropy_with_logits": [ + "logits", "labels", "name"] + } + + # Specially handled functions. + self.function_handle = {"tf.reverse": self._reverse_handler} + + @staticmethod + def _reverse_handler(file_edit_recorder, node): + # TODO(aselle): Could check for a literal list of bools and try to convert + # them to indices. + comment = ("ERROR: tf.reverse has had its argument semantics changed\n" + "significantly the converter cannot detect this reliably, so you" + "need to inspect this usage manually.\n") + file_edit_recorder.add(comment, + node.lineno, + node.col_offset, + "tf.reverse", + "tf.reverse", + error="tf.reverse requires manual check.") + + +class FileEditTuple(collections.namedtuple( + "FileEditTuple", ["comment", "line", "start", "old", "new"])): + """Each edit that is recorded by a FileEditRecorder. + + Fields: + comment: A description of the edit and why it was made. + line: The line number in the file where the edit occurs (1-indexed). + start: The line number in the file where the edit occurs (0-indexed). + old: text string to remove (this must match what was in file). + new: text string to add in place of `old`. + """ + + __slots__ = () + + +class FileEditRecorder(object): + """Record changes that need to be done to the file.""" + + def __init__(self, filename): + # all edits are lists of chars + self._filename = filename + + self._line_to_edit = collections.defaultdict(list) + self._errors = [] + + def process(self, text): + """Process a list of strings, each corresponding to the recorded changes. + + Args: + text: A list of lines of text (assumed to contain newlines) + Returns: + A tuple of the modified text and a textual description of what is done. + Raises: + ValueError: if substitution source location does not have expected text. 
+ """ + + change_report = "" + + # Iterate of each line + for line, edits in self._line_to_edit.items(): + offset = 0 + # sort by column so that edits are processed in order in order to make + # indexing adjustments cumulative for changes that change the string + # length + edits.sort(key=lambda x: x.start) + + # Extract each line to a list of characters, because mutable lists + # are editable, unlike immutable strings. + char_array = list(text[line - 1]) + + # Record a description of the change + change_report += "%r Line %d\n" % (self._filename, line) + change_report += "-" * 80 + "\n\n" + for e in edits: + change_report += "%s\n" % e.comment + change_report += "\n Old: %s" % (text[line - 1]) + + # Make underscore buffers for underlining where in the line the edit was + change_list = [" "] * len(text[line - 1]) + change_list_new = [" "] * len(text[line - 1]) + + # Iterate for each edit + for e in edits: + # Create effective start, end by accounting for change in length due + # to previous edits + start_eff = e.start + offset + end_eff = start_eff + len(e.old) + + # Make sure the edit is changing what it should be changing + old_actual = "".join(char_array[start_eff:end_eff]) + if old_actual != e.old: + raise ValueError("Expected text %r but got %r" % + ("".join(e.old), "".join(old_actual))) + # Make the edit + char_array[start_eff:end_eff] = list(e.new) + + # Create the underline highlighting of the before and after + change_list[e.start:e.start + len(e.old)] = "~" * len(e.old) + change_list_new[start_eff:end_eff] = "~" * len(e.new) + + # Keep track of how to generate effective ranges + offset += len(e.new) - len(e.old) + + # Finish the report comment + change_report += " %s\n" % "".join(change_list) + text[line - 1] = "".join(char_array) + change_report += " New: %s" % (text[line - 1]) + change_report += " %s\n\n" % "".join(change_list_new) + return "".join(text), change_report, self._errors + + def add(self, comment, line, start, old, new, error=None): + """Add a new change that is needed. + + Args: + comment: A description of what was changed + line: Line number (1 indexed) + start: Column offset (0 indexed) + old: old text + new: new text + error: this "edit" is something that cannot be fixed automatically + Returns: + None + """ + + self._line_to_edit[line].append( + FileEditTuple(comment, line, start, old, new)) + if error: + self._errors.append("%s:%d: %s" % (self._filename, line, error)) + + +class TensorFlowCallVisitor(ast.NodeVisitor): + """AST Visitor that finds TensorFlow Function calls. + + Updates function calls from old API version to new API version. + """ + + def __init__(self, filename, lines): + self._filename = filename + self._file_edit = FileEditRecorder(filename) + self._lines = lines + self._api_change_spec = APIChangeSpec() + + def process(self, lines): + return self._file_edit.process(lines) + + def generic_visit(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def _rename_functions(self, node, full_name): + function_renames = self._api_change_spec.function_renames + try: + new_name = function_renames[full_name] + self._file_edit.add("Renamed function %r to %r" % (full_name, + new_name), + node.lineno, node.col_offset, full_name, new_name) + except KeyError: + pass + + def _get_attribute_full_path(self, node): + """Traverse an attribute to generate a full name e.g. tf.foo.bar. + + Args: + node: A Node of type Attribute. + + Returns: + a '.'-delimited full-name or None if the tree was not a simple form. + i.e. 
`foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". + """ + curr = node + items = [] + while not isinstance(curr, ast.Name): + if not isinstance(curr, ast.Attribute): + return None + items.append(curr.attr) + curr = curr.value + items.append(curr.id) + return ".".join(reversed(items)) + + def _find_true_position(self, node): + """Return correct line number and column offset for a given node. + + This is necessary mainly because ListComp's location reporting reports + the next token after the list comprehension list opening. + + Args: + node: Node for which we wish to know the lineno and col_offset + """ + import re + find_open = re.compile("^\s*(\\[).*$") + find_string_chars = re.compile("['\"]") + + if isinstance(node, ast.ListComp): + # Strangely, ast.ListComp returns the col_offset of the first token + # after the '[' token which appears to be a bug. Workaround by + # explicitly finding the real start of the list comprehension. + line = node.lineno + col = node.col_offset + # loop over lines + while 1: + # Reverse the text to and regular expression search for whitespace + text = self._lines[line-1] + reversed_preceding_text = text[:col][::-1] + # First find if a [ can be found with only whitespace between it and + # col. + m = find_open.match(reversed_preceding_text) + if m: + new_col_offset = col - m.start(1) - 1 + return line, new_col_offset + else: + if (reversed_preceding_text=="" or + reversed_preceding_text.isspace()): + line = line - 1 + prev_line = self._lines[line - 1] + # TODO(aselle): + # this is poor comment detection, but it is good enough for + # cases where the comment does not contain string literal starting/ + # ending characters. If ast gave us start and end locations of the + # ast nodes rather than just start, we could use string literal + # node ranges to filter out spurious #'s that appear in string + # literals. + comment_start = prev_line.find("#") + if comment_start == -1: + col = len(prev_line) -1 + elif find_string_chars.search(prev_line[comment_start:]) is None: + col = comment_start + else: + return None, None + else: + return None, None + # Most other nodes return proper locations (with notably does not), but + # it is not possible to use that in an argument. + return node.lineno, node.col_offset + + + def visit_Call(self, node): # pylint: disable=invalid-name + """Handle visiting a call node in the AST. + + Args: + node: Current Node + """ + + + # Find a simple attribute name path e.g. "tf.foo.bar" + full_name = self._get_attribute_full_path(node.func) + + # Make sure the func is marked as being part of a call + node.func.is_function_for_call = True + + if full_name and full_name.startswith("tf."): + # Call special handlers + function_handles = self._api_change_spec.function_handle + if full_name in function_handles: + function_handles[full_name](self._file_edit, node) + + # Examine any non-keyword argument and make it into a keyword argument + # if reordering required. 
+ function_reorders = self._api_change_spec.function_reorders + function_keyword_renames = ( + self._api_change_spec.function_keyword_renames) + + if full_name in function_reorders: + reordered = function_reorders[full_name] + for idx, arg in enumerate(node.args): + lineno, col_offset = self._find_true_position(arg) + if lineno is None or col_offset is None: + self._file_edit.add( + "Failed to add keyword %r to reordered function %r" + % (reordered[idx], full_name), arg.lineno, arg.col_offset, + "", "", + error="A necessary keyword argument failed to be inserted.") + else: + keyword_arg = reordered[idx] + if (full_name in function_keyword_renames and + keyword_arg in function_keyword_renames[full_name]): + keyword_arg = function_keyword_renames[full_name][keyword_arg] + self._file_edit.add("Added keyword %r to reordered function %r" + % (reordered[idx], full_name), lineno, + col_offset, "", keyword_arg + "=") + + # Examine each keyword argument and convert it to the final renamed form + renamed_keywords = ({} if full_name not in function_keyword_renames else + function_keyword_renames[full_name]) + for keyword in node.keywords: + argkey = keyword.arg + argval = keyword.value + + if argkey in renamed_keywords: + argval_lineno, argval_col_offset = self._find_true_position(argval) + if (argval_lineno is not None and argval_col_offset is not None): + # TODO(aselle): We should scan backward to find the start of the + # keyword key. Unfortunately ast does not give you the location of + # keyword keys, so we are forced to infer it from the keyword arg + # value. + key_start = argval_col_offset - len(argkey) - 1 + key_end = key_start + len(argkey) + 1 + if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=": + self._file_edit.add("Renamed keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval_lineno, + argval_col_offset - len(argkey) - 1, + argkey + "=", renamed_keywords[argkey] + "=") + continue + self._file_edit.add( + "Failed to rename keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval.lineno, + argval.col_offset - len(argkey) - 1, + "", "", + error="Failed to find keyword lexographically. Fix manually.") + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Attribute(self, node): # pylint: disable=invalid-name + """Handle bare Attributes i.e. [tf.foo, tf.bar]. + + Args: + node: Node that is of type ast.Attribute + """ + full_name = self._get_attribute_full_path(node) + if full_name and full_name.startswith("tf."): + self._rename_functions(node, full_name) + if full_name in self._api_change_spec.change_to_function: + if not hasattr(node, "is_function_for_call"): + new_text = full_name + "()" + self._file_edit.add("Changed %r to %r"%(full_name, new_text), + node.lineno, node.col_offset, full_name, new_text) + + ast.NodeVisitor.generic_visit(self, node) + + +class TensorFlowCodeUpgrader(object): + """Class that handles upgrading a set of Python files to TensorFlow 1.0.""" + + def __init__(self): + pass + + def process_file(self, in_filename, out_filename): + """Process the given python file for incompatible changes. + + Args: + in_filename: filename to parse + out_filename: output file to write to + Returns: + A tuple representing number of files processed, log of actions, errors + """ + + # Write to a temporary file, just in case we are doing an implace modify. 
+ with open(in_filename, "r") as in_file, \ + tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + ret = self.process_opened_file( + in_filename, in_file, out_filename, temp_file) + + shutil.move(temp_file.name, out_filename) + return ret + + # Broad exceptions are required here because ast throws whatever it wants. + # pylint: disable=broad-except + def process_opened_file(self, in_filename, in_file, out_filename, out_file): + """Process the given python file for incompatible changes. + + This function is split out to facilitate StringIO testing from + tf_upgrade_test.py. + + Args: + in_filename: filename to parse + in_file: opened file (or StringIO) + out_filename: output file to write to + out_file: opened file (or StringIO) + Returns: + A tuple representing number of files processed, log of actions, errors + """ + process_errors = [] + text = "-" * 80 + "\n" + text += "Processing file %r\n outputting to %r\n" % (in_filename, + out_filename) + text += "-" * 80 + "\n\n" + + parsed_ast = None + lines = in_file.readlines() + try: + parsed_ast = ast.parse("".join(lines)) + except Exception: + text += "Failed to parse %r\n\n" % in_filename + text += traceback.format_exc() + if parsed_ast: + visitor = TensorFlowCallVisitor(in_filename, lines) + visitor.visit(parsed_ast) + out_text, new_text, process_errors = visitor.process(lines) + text += new_text + if out_file: + out_file.write(out_text) + text += "\n" + return 1, text, process_errors + # pylint: enable=broad-except + + def process_tree(self, root_directory, output_root_directory): + """Processes upgrades on an entire tree of python files in place. + + Note that only Python files. If you have custom code in other languages, + you will need to manually upgrade those. + + Args: + root_directory: Directory to walk and process. + output_root_directory: Directory to use as base + Returns: + A tuple of files processed, the report string ofr all files, and errors + """ + + # make sure output directory doesn't exist + if output_root_directory and os.path.exists(output_root_directory): + print("Output directory %r must not already exist." 
% ( + output_root_directory)) + sys.exit(1) + + # make sure output directory does not overlap with root_directory + norm_root = os.path.split(os.path.normpath(root_directory)) + norm_output = os.path.split(os.path.normpath(output_root_directory)) + if norm_root == norm_output: + print("Output directory %r same as input directory %r" % ( + root_directory, output_root_directory)) + sys.exit(1) + + # Collect list of files to process (we do this to correctly handle if the + # user puts the output directory in some sub directory of the input dir) + files_to_process = [] + for dir_name, _, file_list in os.walk(root_directory): + py_files = [f for f in file_list if f.endswith(".py")] + for filename in py_files: + fullpath = os.path.join(dir_name, filename) + fullpath_output = os.path.join( + output_root_directory, os.path.relpath(fullpath, root_directory)) + files_to_process.append((fullpath, fullpath_output)) + + file_count = 0 + tree_errors = [] + report = "" + report += ("=" * 80) + "\n" + report += "Input tree: %r\n" % root_directory + report += ("=" * 80) + "\n" + + for input_path, output_path in files_to_process: + output_directory = os.path.dirname(output_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + file_count += 1 + _, l_report, l_errors = self.process_file(input_path, output_path) + tree_errors += l_errors + report += l_report + return file_count, report, tree_errors + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""Convert a TensorFlow Python file to 1.0 + +Simple usage: + tf_convert.py --infile foo.py --outfile bar.py + tf_convert.py --intree ~/code/old --outtree ~/code/new +""") + parser.add_argument( + "--infile", + dest="input_file", + help="If converting a single file, the name of the file " + "to convert") + parser.add_argument( + "--outfile", + dest="output_file", + help="If converting a single file, the output filename.") + parser.add_argument( + "--intree", + dest="input_tree", + help="If converting a whole tree of files, the directory " + "to read from (relative or absolute).") + parser.add_argument( + "--outtree", + dest="output_tree", + help="If converting a whole tree of files, the output " + "directory (relative or absolute).") + parser.add_argument( + "--reportfile", + dest="report_filename", + help=("The name of the file where the report log is " + "stored." 
+ "(default: %(default)s)"), + default="report.txt") + args = parser.parse_args() + + upgrade = TensorFlowCodeUpgrader() + report_text = None + report_filename = args.report_filename + files_processed = 0 + if args.input_file: + files_processed, report_text, errors = upgrade.process_file( + args.input_file, args.output_file) + files_processed = 1 + elif args.input_tree: + files_processed, report_text, errors = upgrade.process_tree( + args.input_tree, args.output_tree) + else: + parser.print_help() + if report_text: + open(report_filename, "w").write(report_text) + print("TensorFlow 1.0 Upgrade Script") + print("-----------------------------") + print("Converted %d files\n" % files_processed) + print("Detected %d errors that require attention" % len(errors)) + print("-" * 80) + print("\n".join(errors)) + print("\nMake sure to read the detailed log %r\n" % report_filename) \ No newline at end of file From c8932c397dfacb56b09f3a2a9f1e6800171439bf Mon Sep 17 00:00:00 2001 From: Nithin Vasisth Date: Tue, 14 Mar 2017 23:23:52 +0530 Subject: [PATCH 2/2] Changes in batch_norm and reuse in VarScope --- infogan/misc/custom_ops.py | 34 ++++++++++--------------------- infogan/models/regularized_gan.py | 6 ++++-- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/infogan/misc/custom_ops.py b/infogan/misc/custom_ops.py index 14f309e..183215e 100644 --- a/infogan/misc/custom_ops.py +++ b/infogan/misc/custom_ops.py @@ -9,31 +9,19 @@ class conv_batch_norm(pt.VarStoreMethod): def __call__(self, input_layer, epsilon=1e-5, momentum=0.1, name="batch_norm", in_dim=None, phase=Phase.train): - self.ema = tf.train.ExponentialMovingAverage(decay=0.9) - shape = input_layer.shape shp = in_dim or shape[-1] - with tf.variable_scope(name) as scope: - self.gamma = self.variable("gamma", [shp], init=tf.random_normal_initializer(1., 0.02)) - self.beta = self.variable("beta", [shp], init=tf.constant_initializer(0.)) - - self.mean, self.variance = tf.nn.moments(input_layer.tensor, [0, 1, 2]) - # sigh...tf's shape system is so.. - self.mean.set_shape((shp,)) - self.variance.set_shape((shp,)) - self.ema_apply_op = self.ema.apply([self.mean, self.variance]) - - if phase == Phase.train: - with tf.control_dependencies([self.ema_apply_op]): - normalized_x = tf.nn.batch_norm_with_global_normalization( - input_layer.tensor, self.mean, self.variance, self.beta, self.gamma, epsilon, - scale_after_normalization=True) - else: - normalized_x = tf.nn.batch_norm_with_global_normalization( - x, self.ema.average(self.mean), self.ema.average(self.variance), self.beta, - self.gamma, epsilon, - scale_after_normalization=True) - return input_layer.with_tensor(normalized_x, parameters=self.vars) + self.gamma = self.variable("gamma", [shp], init=tf.random_normal_initializer(1., 0.02)) + self.beta = self.variable("beta", [shp], init=tf.constant_initializer(0.)) + + self.mean, self.variance = tf.nn.moments(input_layer, [0, 1, 2]) + # sigh...tf's shape system is so.. 
+ self.mean.set_shape((shp,)) + self.variance.set_shape((shp,)) + + normalized_x = tf.nn.batch_normalization(input_layer, self.mean, + self.variance, None, None, epsilon) + return input_layer.with_tensor(normalized_x, parameters=self.vars) pt.Register(assign_defaults=('phase'))(conv_batch_norm) diff --git a/infogan/models/regularized_gan.py b/infogan/models/regularized_gan.py index 1d585c2..c881d85 100644 --- a/infogan/models/regularized_gan.py +++ b/infogan/models/regularized_gan.py @@ -13,6 +13,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty :type batch_size: int :type network_type: string """ + self.reuse = False self.output_dist = output_dist self.latent_spec = latent_spec self.latent_dist = Product([x for x, _ in latent_spec]) @@ -28,7 +29,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty image_size = image_shape[0] if network_type == "mnist": - with tf.variable_scope("d_net"): + with tf.variable_scope("d_net", reuse=self.reuse): shared_template = \ (pt.template("input"). reshape([-1] + list(image_shape)). @@ -48,7 +49,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty apply(leaky_rectify). custom_fully_connected(self.reg_latent_dist.dist_flat_dim)) - with tf.variable_scope("g_net"): + with tf.variable_scope("g_net", reuse=self.reuse): self.generator_template = \ (pt.template("input"). custom_fully_connected(1024). @@ -63,6 +64,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty apply(tf.nn.relu). custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4). flatten()) + self.reuse = True else: raise NotImplementedError
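
Editor's note (not part of either commit): the first commit is the mechanical TF 1.0 rename pass that tf_upgrade.py automates (tf.pack → tf.stack, tf.concat(axis=..., values=...), tf.summary.*, tf.global_variables_initializer()). The second commit's variable-scope change in regularized_gan.py is subtler: the model starts with self.reuse = False, builds "d_net"/"g_net" with reuse=self.reuse, and flips self.reuse = True so any later graph construction shares the same weights instead of raising a "variable already exists" error. Below is a minimal, self-contained sketch of that idiom, assuming TensorFlow 1.x; the scope name "d_net" comes from the patch, while the layer sizes, placeholder shapes, and variable names are illustrative only.

```python
import tensorflow as tf  # assumes TensorFlow 1.x


class Discriminator(object):
    """Illustrative sketch of the reuse pattern used in regularized_gan.py."""

    def __init__(self):
        # False for the first graph construction, True afterwards.
        self.reuse = False

    def __call__(self, x):
        with tf.variable_scope("d_net", reuse=self.reuse):
            # First call creates d_net/w and d_net/b; later calls reuse them.
            w = tf.get_variable(
                "w", [784, 1],
                initializer=tf.random_normal_initializer(stddev=0.02))
            b = tf.get_variable("b", [1],
                                initializer=tf.constant_initializer(0.))
            logits = tf.matmul(x, w) + b
        self.reuse = True
        return logits


d = Discriminator()
x_real = tf.placeholder(tf.float32, [None, 784])
x_fake = tf.placeholder(tf.float32, [None, 784])
real_logits = d(x_real)   # creates the d_net/* variables
fake_logits = d(x_fake)   # shares the same variables instead of erroring
```

The same effect can be had with tf.variable_scope("d_net", reuse=tf.AUTO_REUSE) in later 1.x releases; the explicit boolean flag shown here matches what the patch does and works on 1.0.0.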