diff --git a/gematria/granite/python/gnn_model_base.py b/gematria/granite/python/gnn_model_base.py index a9426f63..70dadd3c 100644 --- a/gematria/granite/python/gnn_model_base.py +++ b/gematria/granite/python/gnn_model_base.py @@ -61,6 +61,8 @@ class GraphNetworkLayer(tf.Module): features computed by the layer. When both a residual connection and layer normalization are used, the layer normalization op is inserted after the residual connection. + extra_node_inputs: Names of extra feed_dict members that should be passed + to the node model. """ # NOTE(ondrasej): This should be one of the classes defined in @@ -74,6 +76,7 @@ class GraphNetworkLayer(tf.Module): edges_output_size: Sequence[int] | None = None nodes_output_size: Sequence[int] | None = None globals_output_size: Sequence[int] | None = None + extra_node_inputs: Sequence[str] | None = None class GnnModelBase(model_base.ModelBase): @@ -376,7 +379,15 @@ def _execute_graph_network(self, feed_dict) -> graph_nets.graphs.GraphsTuple: ) for iteration in range(num_iterations): residual_input = graphs_tuple - graphs_tuple = layer.module(graphs_tuple) + extra_node_args = {} + if layer.extra_node_inputs is not None: + for extra_node_arg_name in layer.extra_node_inputs: + extra_node_args[extra_node_arg_name] = feed_dict[ + extra_node_arg_name + ] + graphs_tuple = layer.module( + graphs_tuple, node_model_kwargs=extra_node_args + ) if use_residual_connections: residual_op_name_base = ( f'residual_connection_{layer_index}_{iteration}' diff --git a/gematria/granite/python/token_graph_builder_model.py b/gematria/granite/python/token_graph_builder_model.py index 848b8299..d7b226bc 100644 --- a/gematria/granite/python/token_graph_builder_model.py +++ b/gematria/granite/python/token_graph_builder_model.py @@ -345,7 +345,6 @@ def _create_graph_network_modules( vocab_size=len(self._token_list), common_embed_dim=self._common_node_embedding_size, num_annotations=self._num_annotations, - model_ref=self, initializer=embedding_initializer, ), global_model_fn=functools.partial( @@ -364,6 +363,10 @@ def _create_graph_network_modules( num_iterations=1, layer_normalization=options.EnableFeature.NEVER, residual_connection=options.EnableFeature.NEVER, + extra_node_inputs=( + 'instruction_node_mask', + 'instruction_annotations', + ), ), gnn_model_base.GraphNetworkLayer( module=graph_nets.modules.GraphNetwork( @@ -410,7 +413,6 @@ def __init__( self, common_embed_dim, num_annotations, - model_ref, **kwargs, ) -> None: """Initializes node embeddings. @@ -420,11 +422,8 @@ def __init__( embedding vectors. The remainder of the vector is filled with instruction annotation. num_annotations: The number of annotations per instruction. - model_ref: A reference to the model to get the instruction node mask and - instruction annotations. kwargs: Additional arguments to be passed to the internal `snt.Embed`s. """ - self._model_ref = model_ref # The first `embed_dim - num_annotations` embedding values for all nodes. self._common_embed = snt.Embed( @@ -446,6 +445,8 @@ def __init__( def __call__( self, inputs, + instruction_node_mask, + instruction_annotations, ): if not self._extra_embed: return self._common_embed(inputs) @@ -458,10 +459,8 @@ def __call__( common_embeddings, tf.tensor_scatter_nd_update( extra_embeddings, - indices=tf.where( - self._model_ref._instruction_node_mask, - ), - updates=self._model_ref._instruction_annotations, + indices=tf.where(instruction_node_mask), + updates=instruction_annotations, ), ], axis=1,