From 5b9dde51a891d518c8d47877820f4d81b40c1f7d Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Sun, 18 May 2025 04:14:12 +0000 Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20to=20main=20this=20commit=20is=20based=20on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- gematria/granite/python/gnn_model_base.py | 21 +- .../python/token_graph_builder_model.py | 17 +- gematria/model/python/loss_utils.py | 296 ++++++++++-------- gematria/model/python/loss_utils_test.py | 90 +++--- gematria/model/python/model_base.py | 20 +- 5 files changed, 246 insertions(+), 198 deletions(-) diff --git a/gematria/granite/python/gnn_model_base.py b/gematria/granite/python/gnn_model_base.py index c2b33e0e..fa3dcf73 100644 --- a/gematria/granite/python/gnn_model_base.py +++ b/gematria/granite/python/gnn_model_base.py @@ -61,6 +61,8 @@ class GraphNetworkLayer(tf.Module): features computed by the layer. When both a residual connection and layer normalization are used, the layer normalization op is inserted after the residual connection. + extra_node_inputs: Names of extra feed_dict members that should be passed + to the node model. """ # NOTE(ondrasej): This should be one of the classes defined in @@ -74,6 +76,7 @@ class GraphNetworkLayer(tf.Module): edges_output_size: Sequence[int] | None = None nodes_output_size: Sequence[int] | None = None globals_output_size: Sequence[int] | None = None + extra_node_inputs: Sequence[str] | None = None class GnnModelBase(model_base.ModelBase): @@ -210,6 +213,14 @@ def __init__( self._graph_module_residual_connections = graph_module_residual_connections self._graph_module_layer_normalization = graph_module_layer_normalization + @property + def trainable_variables(self): + trainable_vars = set([var.ref() for var in super().trainable_variables]) + for layer in self._graph_network: + layer_vars = [var.ref() for var in layer.module.trainable_variables] + trainable_vars.update(layer_vars) + return tuple([var.deref() for var in trainable_vars]) + def initialize(self): super().initialize() self._graph_network = self._create_graph_network_modules() @@ -368,7 +379,15 @@ def _execute_graph_network(self, feed_dict) -> graph_nets.graphs.GraphsTuple: ) for iteration in range(num_iterations): residual_input = graphs_tuple - graphs_tuple = layer.module(graphs_tuple) + extra_node_args = {} + if layer.extra_node_inputs is not None: + for extra_node_arg_name in layer.extra_node_inputs: + extra_node_args[extra_node_arg_name] = feed_dict[ + extra_node_arg_name + ] + graphs_tuple = layer.module( + graphs_tuple, node_model_kwargs=extra_node_args + ) if use_residual_connections: residual_op_name_base = ( f'residual_connection_{layer_index}_{iteration}' diff --git a/gematria/granite/python/token_graph_builder_model.py b/gematria/granite/python/token_graph_builder_model.py index 848b8299..f63da407 100644 --- a/gematria/granite/python/token_graph_builder_model.py +++ b/gematria/granite/python/token_graph_builder_model.py @@ -345,7 +345,6 @@ def _create_graph_network_modules( vocab_size=len(self._token_list), common_embed_dim=self._common_node_embedding_size, num_annotations=self._num_annotations, - model_ref=self, initializer=embedding_initializer, ), global_model_fn=functools.partial( @@ -364,6 +363,10 @@ def _create_graph_network_modules( num_iterations=1, layer_normalization=options.EnableFeature.NEVER, residual_connection=options.EnableFeature.NEVER, + extra_node_inputs=[ + 'instruction_node_mask', + 'instruction_annotations', + ], ), gnn_model_base.GraphNetworkLayer( module=graph_nets.modules.GraphNetwork( @@ -410,7 +413,6 @@ def __init__( self, common_embed_dim, num_annotations, - model_ref, **kwargs, ) -> None: """Initializes node embeddings. @@ -420,11 +422,8 @@ def __init__( embedding vectors. The remainder of the vector is filled with instruction annotation. num_annotations: The number of annotations per instruction. - model_ref: A reference to the model to get the instruction node mask and - instruction annotations. kwargs: Additional arguments to be passed to the internal `snt.Embed`s. """ - self._model_ref = model_ref # The first `embed_dim - num_annotations` embedding values for all nodes. self._common_embed = snt.Embed( @@ -446,6 +445,8 @@ def __init__( def __call__( self, inputs, + instruction_node_mask, + instruction_annotations, ): if not self._extra_embed: return self._common_embed(inputs) @@ -458,10 +459,8 @@ def __call__( common_embeddings, tf.tensor_scatter_nd_update( extra_embeddings, - indices=tf.where( - self._model_ref._instruction_node_mask, - ), - updates=self._model_ref._instruction_annotations, + indices=tf.where(instruction_node_mask), + updates=instruction_annotations, ), ], axis=1, diff --git a/gematria/model/python/loss_utils.py b/gematria/model/python/loss_utils.py index 06d834fd..a5a75a20 100644 --- a/gematria/model/python/loss_utils.py +++ b/gematria/model/python/loss_utils.py @@ -20,14 +20,17 @@ import tensorflow_probability as tfp -# Type of keys used when caching the loss tensors generated by a LossComputation -# object. -_LossTensorType = tuple[options.LossType, options.ErrorNormalization] - - -class LossComputation: +class LossComputation(tf.experimental.ExtensionType): """Maintains TF ops for computing loss from actual and expected outputs.""" + loss_tensor: tf.Tensor + mean_absolute_error: tf.Tensor + mean_squared_error: tf.Tensor + mean_absolute_percentage_error: tf.Tensor + mean_squared_percentage_error: tf.Tensor + absolute_error_percentiles: tf.Tensor + absolute_percentage_error_percentiles: tf.Tensor + def __init__( self, output_values: tf.Tensor, @@ -35,6 +38,8 @@ def __init__( mask: tf.Tensor, dtype: tf.dtypes.DType, percentile_ranks: Sequence[int] = (), + normalization: options.ErrorNormalization = options.ErrorNormalization.NONE, + loss_type: options.LossType = options.LossType.MEAN_SQUARED_ERROR, ): """Initializes the loss computation. @@ -61,7 +66,7 @@ def __init__( f'{output_values.shape}. Found {expected_outputs.shape}' ) - self._num_tasks = output_values.shape[1] or expected_outputs.shape[1] + num_tasks = output_values.shape[1] or expected_outputs.shape[1] if not mask.shape.is_compatible_with(output_values.shape): raise ValueError( 'Expected mask.shape to be compatible with' @@ -72,120 +77,109 @@ def __init__( f'Expected mask.dtype to be tf.dtypes.bool. Found {mask.dtype}.' ) - self._percentile_ranks = percentile_ranks - self._dtype = dtype - self._loss_tensors: dict[_LossTensorType, tf.Tensor] = {} - # tf.ragged.boolean_mask() does not have an `axis` argument to control which # dimension is ragged and in case of 2D tensors it is always the second one. # We transpose the data so that the first (non-ragged) dimension goes along # tasks, and the second (ragged) dimension goes along the values. # All the tensors below have the shape - self._mask = tf.transpose(mask) - self._output_values = tf.ragged.boolean_mask( - tf.transpose(output_values), self._mask - ) - assert self._output_values.shape.is_compatible_with((self._num_tasks, None)) - self._expected_outputs = tf.ragged.boolean_mask( - tf.transpose(expected_outputs), self._mask - ) - assert self._expected_outputs.shape.is_compatible_with( - (self._num_tasks, None) + mask = tf.transpose(mask) + output_values = tf.ragged.boolean_mask(tf.transpose(output_values), mask) + assert output_values.shape.is_compatible_with((num_tasks, None)) + expected_outputs = tf.ragged.boolean_mask( + tf.transpose(expected_outputs), mask ) + assert expected_outputs.shape.is_compatible_with((num_tasks, None)) - self._delta = self._output_values - self._expected_outputs - assert self._delta.shape.is_compatible_with((self._num_tasks, None)) + delta = output_values - expected_outputs + assert delta.shape.is_compatible_with((num_tasks, None)) - self._squared_errors = tf.square(self._delta) - assert self._squared_errors.shape.is_compatible_with( - (self._num_tasks, None) - ) - self._absolute_errors = tf.abs(self._delta) - assert self._absolute_errors.shape.is_compatible_with( - (self._num_tasks, None) - ) - self._absolute_percentage_errors = ( - self._absolute_errors / self._expected_outputs + squared_errors = tf.square(delta) + assert squared_errors.shape.is_compatible_with((num_tasks, None)) + absolute_errors = tf.abs(delta) + assert absolute_errors.shape.is_compatible_with((num_tasks, None)) + absolute_percentage_errors = absolute_errors / expected_outputs + assert absolute_percentage_errors.shape.is_compatible_with( + (num_tasks, None) ) - assert self._absolute_percentage_errors.shape.is_compatible_with( - (self._num_tasks, None) - ) - self._squared_percentage_error = tf.square(self._absolute_percentage_errors) - assert self._squared_percentage_error.shape.is_compatible_with( - (self._num_tasks, None) - ) - self._absolute_error_percentiles = self._make_percentile_tensor( - self._absolute_errors + squared_percentage_error = tf.square(absolute_percentage_errors) + assert squared_percentage_error.shape.is_compatible_with((num_tasks, None)) + self.absolute_error_percentiles = self._make_percentile_tensor( + absolute_errors, num_tasks, percentile_ranks, dtype ) assert ( - not self._percentile_ranks - or self._absolute_error_percentiles.shape.is_compatible_with( - (len(self._percentile_ranks), self._num_tasks) + not percentile_ranks + or self.absolute_error_percentiles.shape.is_compatible_with( + (len(percentile_ranks), num_tasks) ) ) - self._absolute_percentage_error_percentiles = self._make_percentile_tensor( - self._absolute_percentage_errors + self.absolute_percentage_error_percentiles = self._make_percentile_tensor( + absolute_percentage_errors, num_tasks, percentile_ranks, dtype ) assert ( - not self._percentile_ranks - or self._absolute_percentage_error_percentiles.shape.is_compatible_with( - (len(self._percentile_ranks), self._num_tasks) + not percentile_ranks + or self.absolute_percentage_error_percentiles.shape.is_compatible_with( + (len(percentile_ranks), num_tasks) ) ) # The absolute value of expected_outputs. Contains 1.0 in place of values # that are smaller than one. - self._absolute_expected_outputs_or_one = tf.math.maximum( - self._expected_outputs, - tf.ones_like(self._expected_outputs, dtype=dtype), + absolute_expected_outputs_or_one = tf.math.maximum( + expected_outputs, + tf.ones_like(expected_outputs, dtype=dtype), ) - assert self._absolute_expected_outputs_or_one.shape.is_compatible_with( - (self._num_tasks, None) + assert absolute_expected_outputs_or_one.shape.is_compatible_with( + (num_tasks, None) ) - @property - def mean_absolute_error(self) -> tf.Tensor: - """Returns the mean absolute error.""" - return self.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.MEAN_ABSOLUTE_ERROR - ) + tensor_args = { + 'num_tasks': num_tasks, + 'dtype': dtype, + 'delta': delta, + 'squared_errors': squared_errors, + 'absolute_errors': absolute_errors, + 'absolute_percentage_errors': absolute_percentage_errors, + 'squared_percentage_error': squared_percentage_error, + 'absolute_expected_outputs_or_one': absolute_expected_outputs_or_one, + } - @property - def mean_squared_error(self) -> tf.Tensor: - """Returns the mean squared error.""" - return self.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.MEAN_SQUARED_ERROR + self.mean_absolute_error = self._loss_tensor( + options.ErrorNormalization.NONE, + options.LossType.MEAN_ABSOLUTE_ERROR, + **tensor_args, ) - - @property - def mean_absolute_percentage_error(self) -> tf.Tensor: - """Returns the mean absolute percentager error.""" - return self.loss_tensor( + self.mean_squared_error = self._loss_tensor( + options.ErrorNormalization.NONE, + options.LossType.MEAN_SQUARED_ERROR, + **tensor_args, + ) + self.mean_absolute_percentage_error = self._loss_tensor( options.ErrorNormalization.PERCENTAGE_ERROR, options.LossType.MEAN_ABSOLUTE_ERROR, + **tensor_args, ) - - @property - def mean_squared_percentage_error(self) -> tf.Tensor: - """Returns the mean squared percentage error.""" - return self.loss_tensor( + self.mean_squared_percentage_error = self._loss_tensor( options.ErrorNormalization.PERCENTAGE_ERROR, options.LossType.MEAN_SQUARED_ERROR, + **tensor_args, + ) + self.loss_tensor = self._loss_tensor( + normalization, + loss_type, + **tensor_args, ) - @property - def absolute_error_percentiles(self) -> tf.Tensor: - """Returns the percentiles of the absolute error.""" - return self._absolute_error_percentiles - - @property - def absolute_percentage_error_percentiles(self) -> tf.Tensor: - """Returns the percentiles of the absolute percentage error.""" - return self._absolute_percentage_error_percentiles - - def loss_tensor( + def _loss_tensor( self, normalization: options.ErrorNormalization, loss_type: options.LossType, + num_tasks: int, + dtype: tf.dtypes.DType, + delta: tf.Tensor, + squared_errors: tf.Tensor, + absolute_errors: tf.Tensor, + absolute_percentage_errors: tf.Tensor, + squared_percentage_error: tf.Tensor, + absolute_expected_outputs_or_one: tf.Tensor, ) -> tf.Tensor: """Returns a loss tensor of the given type. @@ -193,104 +187,140 @@ def loss_tensor( normalization: Determines whether and how the errors in the loss tensor are normalized. loss_type: The type of loss. + num_tasks: The number of tasks for the current model. Used for validation + purposes. + dtype: The Tensorflow DType used by the model. Returns: A tensor that contains the requested loss. When called multiple times with the same arguments, this method will always return the same tensor object. The returned tensor is of shape (N, T), where T is the number of tasks. """ - tensor = self._loss_tensors.get((loss_type, normalization)) - if tensor is None: - match loss_type: - case options.LossType.MEAN_SQUARED_ERROR: - tensor = tf.reduce_mean( - self._squared_errors_witn_normalization(normalization), axis=1 - ) - case options.LossType.MEAN_ABSOLUTE_ERROR: - tensor = tf.reduce_mean( - self._absolute_errors_with_normalization(normalization), axis=1 - ) - case options.LossType.HUBER: - absolute_errors = self._absolute_errors_with_normalization( - normalization - ) - # The delta parameter from the Huber loss definition. - huber_delta = tf.constant(1.0, dtype=self._dtype) - # The expression in the quadratic part of the Huber loss expression. - # It is squared in the return statement below. - quadratic = tf.minimum(absolute_errors, huber_delta) - # The linear part of the Huber loss expression. This is zero when - # absolute_error <= huber_delta. - linear = absolute_errors - quadratic - tensor = tf.reduce_mean( - 0.5 * tf.square(quadratic) + huber_delta * linear, axis=1 - ) - case _: - raise ValueError(f'Unexpected loss type: {loss_type}') - assert tensor.shape.is_compatible_with(( - self._num_tasks, - )), f'The actual shape is {tensor.shape}' - self._loss_tensors[loss_type, normalization] = tensor + match loss_type: + case options.LossType.MEAN_SQUARED_ERROR: + tensor = tf.reduce_mean( + self._squared_errors_witn_normalization( + normalization, + num_tasks, + delta, + squared_errors, + squared_percentage_error, + absolute_expected_outputs_or_one, + ), + axis=1, + ) + case options.LossType.MEAN_ABSOLUTE_ERROR: + tensor = tf.reduce_mean( + self._absolute_errors_with_normalization( + normalization, + absolute_errors, + absolute_percentage_errors, + absolute_expected_outputs_or_one, + ), + axis=1, + ) + case options.LossType.HUBER: + absolute_errors = self._absolute_errors_with_normalization( + normalization, + absolute_errors, + absolute_percentage_errors, + absolute_expected_outputs_or_one, + ) + # The delta parameter from the Huber loss definition. + huber_delta = tf.constant(1.0, dtype=dtype) + # The expression in the quadratic part of the Huber loss expression. + # It is squared in the return statement below. + quadratic = tf.minimum(absolute_errors, huber_delta) + # The linear part of the Huber loss expression. This is zero when + # absolute_error <= huber_delta. + linear = absolute_errors - quadratic + tensor = tf.reduce_mean( + 0.5 * tf.square(quadratic) + huber_delta * linear, axis=1 + ) + case _: + raise ValueError(f'Unexpected loss type: {loss_type}') + assert tensor.shape.is_compatible_with(( + num_tasks, + )), f'The actual shape is {tensor.shape}' return tensor - def _make_percentile_tensor(self, values: tf.RaggedTensor) -> tf.Tensor: + def _make_percentile_tensor( + self, + values: tf.RaggedTensor, + num_tasks: int, + percentile_ranks: Sequence[int], + dtype: tf.dtypes.DType, + ) -> tf.Tensor: """Creates a percentile tensor from 'values' using self.percentile_ranks. Args: values: A 2D ragged tensor from which the percentiles are collected. The percentiles are collected along the axis 0 of `values`. + num_tasks: The number of tasks for the current model. Used for validation + purposes. + percentile_ranks: The percentile ranks to use for calculating the tensor. Returns: Percentiles based on self_percentile_ranks and the values. The returned tensor is of shape (N_PERCENTILE_RANKS, T), where T is the number of tasks. """ - if not self._percentile_ranks: - return tf.constant([], dtype=self._dtype) + if not percentile_ranks: + return tf.constant([], dtype=dtype) percentile_tensors = [] # NOTE(ondrasej): As of Nov 2022, tfp.stats.percentile() is not compatible # with ragged tensors, so we need to split the ragged tensor into rows and # then stack the individual percentile tensors to the desired output shape. - for task in range(self._num_tasks): + for task in range(num_tasks): task_values = values[task] percentile_tensors.append( - tfp.stats.percentile(task_values, self._percentile_ranks) + tfp.stats.percentile(task_values, percentile_ranks) ) assert percentile_tensors[-1].shape.is_compatible_with((None,)) return tf.stack(percentile_tensors, axis=1) def _squared_errors_witn_normalization( - self, normalization: options.ErrorNormalization + self, + normalization: options.ErrorNormalization, + num_tasks: int, + delta: tf.Tensor, + squared_errors: tf.Tensor, + squared_percentage_error: tf.Tensor, + absolute_expected_outputs_or_one: tf.Tensor, ) -> tf.Tensor: """Returns the tensor of squared errors.""" match normalization: case options.ErrorNormalization.NONE: - result = self._squared_errors + result = squared_errors case options.ErrorNormalization.PERCENTAGE_ERROR: - result = self._squared_percentage_error + result = squared_percentage_error case options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE: - result = tf.square(self._delta / self._absolute_expected_outputs_or_one) + result = tf.square(delta / absolute_expected_outputs_or_one) case _: raise NotImplementedError( f'Squared errors not implemented yet: {normalization}' ) assert result.shape.is_compatible_with( - (self._num_tasks, None) + (num_tasks, None) ), f'Actual shape of the squared errors tensor is {result.shape}' return result def _absolute_errors_with_normalization( - self, normalization: options.ErrorNormalization + self, + normalization: options.ErrorNormalization, + absolute_errors: tf.Tensor, + absolute_percentage_errors: tf.Tensor, + absolute_expected_outputs_or_one: tf.Tensor, ) -> tf.Tensor: """Returns the tensor of absolute errors.""" match normalization: case options.ErrorNormalization.NONE: - return self._absolute_errors + return absolute_errors case options.ErrorNormalization.PERCENTAGE_ERROR: - return self._absolute_percentage_errors + return absolute_percentage_errors case options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE: - return self._absolute_errors / self._absolute_expected_outputs_or_one + return absolute_errors / absolute_expected_outputs_or_one case _: raise NotImplementedError( 'Absolute errors not implemented yet: {normalization}' diff --git a/gematria/model/python/loss_utils_test.py b/gematria/model/python/loss_utils_test.py index 38e3bd8a..20b7f440 100644 --- a/gematria/model/python/loss_utils_test.py +++ b/gematria/model/python/loss_utils_test.py @@ -60,18 +60,16 @@ def test_unscaled_loss(self): self.full_mask, percentile_ranks=(0, 10, 50, 75, 100), dtype=self.dtype, + loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error - huber = loss.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.HUBER - ) percentiles = loss.absolute_error_percentiles self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6) self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6) self.assertNear( - float(huber), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 + float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 ) self.assertAllEqual(percentiles, ((0,), (0,), (0.5,), (3,), (4,))) @@ -101,23 +99,33 @@ def test_normalized_loss_when_expected_value_greater_than_one(self): ((1,), (4,), (3,), (0.5,), (0,)), dtype=self.dtype ) mask = tf.ones_like(actual_outputs, dtype=tf.dtypes.bool) - loss = loss_utils.LossComputation( - actual_outputs, expected_outputs, mask, dtype=self.dtype - ) + loss_params = { + 'output_values': tf.constant( + ((1.3,), (-2,), (3,), (1,), (2,)), dtype=self.dtype + ), + 'expected_outputs': tf.constant( + ((1,), (4,), (3,), (0.5,), (0,)), dtype=self.dtype + ), + 'mask': tf.ones_like(actual_outputs, dtype=tf.dtypes.bool), + 'dtype': self.dtype, + } - mean_absolute_error = loss.loss_tensor( - options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, - options.LossType.MEAN_ABSOLUTE_ERROR, + mean_absolute_error = loss_utils.LossComputation( + **loss_params, + normalization=options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, + loss_type=options.LossType.MEAN_ABSOLUTE_ERROR ) - mean_squared_error = loss.loss_tensor( - options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, - options.LossType.MEAN_SQUARED_ERROR, + mean_squared_error = loss_utils.LossComputation( + **loss_params, + normalization=options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, + loss_type=options.LossType.MEAN_SQUARED_ERROR ) self.assertAlmostEqual( - float(mean_absolute_error), (0.3 + 1.5 + 0.0 + 0.5 + 2.0) / 5 + float(mean_absolute_error.loss_tensor), + (0.3 + 1.5 + 0.0 + 0.5 + 2.0) / 5, ) self.assertAlmostEqual( - float(mean_squared_error), + float(mean_squared_error.loss_tensor), (0.3**2 + 1.5**2 + 0.0 + 0.5**2 + 2.0**2) / 5, delta=1e-6, ) @@ -144,13 +152,12 @@ def test_multi_task(self): self.multitask_full_mask, percentile_ranks=(0, 10, 50, 75, 100), dtype=self.dtype, + loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error - huber = loss.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.HUBER - ) + huber = loss.loss_tensor percentiles = loss.absolute_error_percentiles self.assertAllClose( mse, @@ -176,35 +183,36 @@ def test_multi_task(self): ) def test_multi_task_with_mask(self): - loss = loss_utils.LossComputation( - output_values=tf.constant( + loss_params = { + 'output_values': tf.constant( ((1, 20), (2, 12.1), (3, 100), (50, 150), (2.5, 12.5)), dtype=self.dtype, ), - expected_outputs=tf.constant( + 'expected_outputs': tf.constant( ((4, 14), (500, 12), (30, 13), (1, 11), (2, 12)), dtype=self.dtype ), - mask=tf.constant(( + 'mask': tf.constant(( (True, True), (False, True), (True, False), (False, False), (True, False), )), - percentile_ranks=(0, 10, 50, 75, 100), - dtype=self.dtype, + 'percentile_ranks': (0, 10, 50, 75, 100), + 'dtype': self.dtype, + 'loss_type': options.LossType.HUBER, + } + loss = loss_utils.LossComputation(**loss_params) + loss_percentage_normalized = loss_utils.LossComputation( + **loss_params, normalization=options.ErrorNormalization.PERCENTAGE_ERROR ) mse = loss.mean_squared_error mspe = loss.mean_squared_percentage_error mae = loss.mean_absolute_error mape = loss.mean_absolute_percentage_error - huber = loss.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.HUBER - ) - huber_percentage = loss.loss_tensor( - options.ErrorNormalization.PERCENTAGE_ERROR, options.LossType.HUBER - ) + huber = loss.loss_tensor + huber_percentage = loss_percentage_normalized.loss_tensor percentiles = loss.absolute_error_percentiles self.assertAllClose( mse, @@ -252,24 +260,22 @@ def test_unknown_shape(self): tf.ones_like(self.actual_outputs, dtype=bool), percentile_ranks=percentile_ranks, dtype=self.dtype, + loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error - huber = loss.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.HUBER - ) percentiles = loss.absolute_error_percentiles self.assertEqual(mse.shape, (1,)) self.assertEqual(mae.shape, (1,)) - self.assertEqual(huber.shape, (1,)) + self.assertEqual(loss.loss_tensor.shape, (1,)) self.assertEqual(percentiles.shape, (len(percentile_ranks), 1)) self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6) self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6) self.assertNear( - float(huber), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 + float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 ) self.assertAllEqual(percentiles, ((0,), (0.5,), (3,), (4,))) @@ -283,13 +289,12 @@ def test_multi_task_unknown_shape(self): tf.ones_like(self.multitask_actual_outputs_array, dtype=bool), percentile_ranks=percentile_ranks, dtype=self.dtype, + loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error - huber = loss.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.HUBER - ) + huber = loss.loss_tensor percentiles = loss.absolute_error_percentiles self.assertEqual(mse.shape, (num_tasks,)) @@ -331,24 +336,23 @@ def test_single_task_unknown_shape(self): mask, percentile_ranks=percentile_ranks, dtype=self.dtype, + normalization=options.ErrorNormalization.NONE, + loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error - huber = loss.loss_tensor( - options.ErrorNormalization.NONE, options.LossType.HUBER - ) percentiles = loss.absolute_error_percentiles self.assertEqual(mse.shape, (num_tasks,)) self.assertEqual(mae.shape, (num_tasks,)) - self.assertEqual(huber.shape, (num_tasks,)) + self.assertEqual(loss.loss_tensor.shape, (num_tasks,)) self.assertEqual(percentiles.shape, (len(percentile_ranks), num_tasks)) self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6) self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6) self.assertNear( - float(huber), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 + float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 ) self.assertAllEqual(percentiles, ((0,), (0.5,), (3,), (4,))) diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index 6bae4319..d499fd4f 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -335,7 +335,6 @@ def __init__( self._decayed_learning_rate = None self._loss: Optional[loss_utils.LossComputation] = None - self._delta_loss: Optional[loss_utils.LossComputation] = None self._train_step: Optional[tf.Operation] = None self._optimizer: Union[ tf.train.Optimizer, tf.train.SyncReplicasOptimizer @@ -1308,28 +1307,28 @@ def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation: tf.constant(schedule['output_mask']), percentile_ranks=self._collected_percentile_ranks, dtype=self.dtype, + normalization=self._loss_normalization, + loss_type=self._loss_type, ) if self._use_deltas: - self._delta_loss = loss_utils.LossComputation( + delta_loss = loss_utils.LossComputation( output['output_deltas'], tf.constant(schedule['expected_outputs_deltas']), output['output_mask_deltas'], percentile_ranks=self._collected_percentile_ranks, dtype=self.dtype, + normalization=self._loss_normalization, + loss_type=self._loss_type, ) if self._use_delta_loss: - loss = self._delta_loss + loss = delta_loss return loss def compute_loss_tensor(self, schedule: FeedDict): - return tf.reduce_mean( - self._compute_loss(schedule).loss_tensor( - self._loss_normalization, self._loss_type - ) - ) + return tf.reduce_mean(self._compute_loss(schedule).loss_tensor) def train_batch( self, @@ -1350,10 +1349,7 @@ def train_batch( with tf.GradientTape() as tape: stats = {} loss = self._compute_loss(schedule) - loss_tensor_per_task = loss.loss_tensor( - self._loss_normalization, self._loss_type - ) - loss_tensor = tf.reduce_mean(loss_tensor_per_task) + loss_tensor = tf.reduce_mean(loss.loss_tensor) # The list of variables to optimize. By default, the list is empty which # means optimize all trainable variables. From f1b902f784bc56f4b878979213f421a1516e396b Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Mon, 19 May 2025 16:18:07 +0000 Subject: [PATCH 2/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20introduced=20through=20rebase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- gematria/model/python/loss_utils.py | 296 ++++++++++------------- gematria/model/python/loss_utils_test.py | 90 ++++--- gematria/model/python/model_base.py | 20 +- 3 files changed, 188 insertions(+), 218 deletions(-) diff --git a/gematria/model/python/loss_utils.py b/gematria/model/python/loss_utils.py index a5a75a20..06d834fd 100644 --- a/gematria/model/python/loss_utils.py +++ b/gematria/model/python/loss_utils.py @@ -20,16 +20,13 @@ import tensorflow_probability as tfp -class LossComputation(tf.experimental.ExtensionType): - """Maintains TF ops for computing loss from actual and expected outputs.""" +# Type of keys used when caching the loss tensors generated by a LossComputation +# object. +_LossTensorType = tuple[options.LossType, options.ErrorNormalization] + - loss_tensor: tf.Tensor - mean_absolute_error: tf.Tensor - mean_squared_error: tf.Tensor - mean_absolute_percentage_error: tf.Tensor - mean_squared_percentage_error: tf.Tensor - absolute_error_percentiles: tf.Tensor - absolute_percentage_error_percentiles: tf.Tensor +class LossComputation: + """Maintains TF ops for computing loss from actual and expected outputs.""" def __init__( self, @@ -38,8 +35,6 @@ def __init__( mask: tf.Tensor, dtype: tf.dtypes.DType, percentile_ranks: Sequence[int] = (), - normalization: options.ErrorNormalization = options.ErrorNormalization.NONE, - loss_type: options.LossType = options.LossType.MEAN_SQUARED_ERROR, ): """Initializes the loss computation. @@ -66,7 +61,7 @@ def __init__( f'{output_values.shape}. Found {expected_outputs.shape}' ) - num_tasks = output_values.shape[1] or expected_outputs.shape[1] + self._num_tasks = output_values.shape[1] or expected_outputs.shape[1] if not mask.shape.is_compatible_with(output_values.shape): raise ValueError( 'Expected mask.shape to be compatible with' @@ -77,109 +72,120 @@ def __init__( f'Expected mask.dtype to be tf.dtypes.bool. Found {mask.dtype}.' ) + self._percentile_ranks = percentile_ranks + self._dtype = dtype + self._loss_tensors: dict[_LossTensorType, tf.Tensor] = {} + # tf.ragged.boolean_mask() does not have an `axis` argument to control which # dimension is ragged and in case of 2D tensors it is always the second one. # We transpose the data so that the first (non-ragged) dimension goes along # tasks, and the second (ragged) dimension goes along the values. # All the tensors below have the shape - mask = tf.transpose(mask) - output_values = tf.ragged.boolean_mask(tf.transpose(output_values), mask) - assert output_values.shape.is_compatible_with((num_tasks, None)) - expected_outputs = tf.ragged.boolean_mask( - tf.transpose(expected_outputs), mask + self._mask = tf.transpose(mask) + self._output_values = tf.ragged.boolean_mask( + tf.transpose(output_values), self._mask + ) + assert self._output_values.shape.is_compatible_with((self._num_tasks, None)) + self._expected_outputs = tf.ragged.boolean_mask( + tf.transpose(expected_outputs), self._mask + ) + assert self._expected_outputs.shape.is_compatible_with( + (self._num_tasks, None) ) - assert expected_outputs.shape.is_compatible_with((num_tasks, None)) - delta = output_values - expected_outputs - assert delta.shape.is_compatible_with((num_tasks, None)) + self._delta = self._output_values - self._expected_outputs + assert self._delta.shape.is_compatible_with((self._num_tasks, None)) - squared_errors = tf.square(delta) - assert squared_errors.shape.is_compatible_with((num_tasks, None)) - absolute_errors = tf.abs(delta) - assert absolute_errors.shape.is_compatible_with((num_tasks, None)) - absolute_percentage_errors = absolute_errors / expected_outputs - assert absolute_percentage_errors.shape.is_compatible_with( - (num_tasks, None) + self._squared_errors = tf.square(self._delta) + assert self._squared_errors.shape.is_compatible_with( + (self._num_tasks, None) + ) + self._absolute_errors = tf.abs(self._delta) + assert self._absolute_errors.shape.is_compatible_with( + (self._num_tasks, None) + ) + self._absolute_percentage_errors = ( + self._absolute_errors / self._expected_outputs ) - squared_percentage_error = tf.square(absolute_percentage_errors) - assert squared_percentage_error.shape.is_compatible_with((num_tasks, None)) - self.absolute_error_percentiles = self._make_percentile_tensor( - absolute_errors, num_tasks, percentile_ranks, dtype + assert self._absolute_percentage_errors.shape.is_compatible_with( + (self._num_tasks, None) + ) + self._squared_percentage_error = tf.square(self._absolute_percentage_errors) + assert self._squared_percentage_error.shape.is_compatible_with( + (self._num_tasks, None) + ) + self._absolute_error_percentiles = self._make_percentile_tensor( + self._absolute_errors ) assert ( - not percentile_ranks - or self.absolute_error_percentiles.shape.is_compatible_with( - (len(percentile_ranks), num_tasks) + not self._percentile_ranks + or self._absolute_error_percentiles.shape.is_compatible_with( + (len(self._percentile_ranks), self._num_tasks) ) ) - self.absolute_percentage_error_percentiles = self._make_percentile_tensor( - absolute_percentage_errors, num_tasks, percentile_ranks, dtype + self._absolute_percentage_error_percentiles = self._make_percentile_tensor( + self._absolute_percentage_errors ) assert ( - not percentile_ranks - or self.absolute_percentage_error_percentiles.shape.is_compatible_with( - (len(percentile_ranks), num_tasks) + not self._percentile_ranks + or self._absolute_percentage_error_percentiles.shape.is_compatible_with( + (len(self._percentile_ranks), self._num_tasks) ) ) # The absolute value of expected_outputs. Contains 1.0 in place of values # that are smaller than one. - absolute_expected_outputs_or_one = tf.math.maximum( - expected_outputs, - tf.ones_like(expected_outputs, dtype=dtype), + self._absolute_expected_outputs_or_one = tf.math.maximum( + self._expected_outputs, + tf.ones_like(self._expected_outputs, dtype=dtype), ) - assert absolute_expected_outputs_or_one.shape.is_compatible_with( - (num_tasks, None) + assert self._absolute_expected_outputs_or_one.shape.is_compatible_with( + (self._num_tasks, None) ) - tensor_args = { - 'num_tasks': num_tasks, - 'dtype': dtype, - 'delta': delta, - 'squared_errors': squared_errors, - 'absolute_errors': absolute_errors, - 'absolute_percentage_errors': absolute_percentage_errors, - 'squared_percentage_error': squared_percentage_error, - 'absolute_expected_outputs_or_one': absolute_expected_outputs_or_one, - } - - self.mean_absolute_error = self._loss_tensor( - options.ErrorNormalization.NONE, - options.LossType.MEAN_ABSOLUTE_ERROR, - **tensor_args, + @property + def mean_absolute_error(self) -> tf.Tensor: + """Returns the mean absolute error.""" + return self.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.MEAN_ABSOLUTE_ERROR ) - self.mean_squared_error = self._loss_tensor( - options.ErrorNormalization.NONE, - options.LossType.MEAN_SQUARED_ERROR, - **tensor_args, + + @property + def mean_squared_error(self) -> tf.Tensor: + """Returns the mean squared error.""" + return self.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.MEAN_SQUARED_ERROR ) - self.mean_absolute_percentage_error = self._loss_tensor( + + @property + def mean_absolute_percentage_error(self) -> tf.Tensor: + """Returns the mean absolute percentager error.""" + return self.loss_tensor( options.ErrorNormalization.PERCENTAGE_ERROR, options.LossType.MEAN_ABSOLUTE_ERROR, - **tensor_args, ) - self.mean_squared_percentage_error = self._loss_tensor( + + @property + def mean_squared_percentage_error(self) -> tf.Tensor: + """Returns the mean squared percentage error.""" + return self.loss_tensor( options.ErrorNormalization.PERCENTAGE_ERROR, options.LossType.MEAN_SQUARED_ERROR, - **tensor_args, - ) - self.loss_tensor = self._loss_tensor( - normalization, - loss_type, - **tensor_args, ) - def _loss_tensor( + @property + def absolute_error_percentiles(self) -> tf.Tensor: + """Returns the percentiles of the absolute error.""" + return self._absolute_error_percentiles + + @property + def absolute_percentage_error_percentiles(self) -> tf.Tensor: + """Returns the percentiles of the absolute percentage error.""" + return self._absolute_percentage_error_percentiles + + def loss_tensor( self, normalization: options.ErrorNormalization, loss_type: options.LossType, - num_tasks: int, - dtype: tf.dtypes.DType, - delta: tf.Tensor, - squared_errors: tf.Tensor, - absolute_errors: tf.Tensor, - absolute_percentage_errors: tf.Tensor, - squared_percentage_error: tf.Tensor, - absolute_expected_outputs_or_one: tf.Tensor, ) -> tf.Tensor: """Returns a loss tensor of the given type. @@ -187,140 +193,104 @@ def _loss_tensor( normalization: Determines whether and how the errors in the loss tensor are normalized. loss_type: The type of loss. - num_tasks: The number of tasks for the current model. Used for validation - purposes. - dtype: The Tensorflow DType used by the model. Returns: A tensor that contains the requested loss. When called multiple times with the same arguments, this method will always return the same tensor object. The returned tensor is of shape (N, T), where T is the number of tasks. """ - match loss_type: - case options.LossType.MEAN_SQUARED_ERROR: - tensor = tf.reduce_mean( - self._squared_errors_witn_normalization( - normalization, - num_tasks, - delta, - squared_errors, - squared_percentage_error, - absolute_expected_outputs_or_one, - ), - axis=1, - ) - case options.LossType.MEAN_ABSOLUTE_ERROR: - tensor = tf.reduce_mean( - self._absolute_errors_with_normalization( - normalization, - absolute_errors, - absolute_percentage_errors, - absolute_expected_outputs_or_one, - ), - axis=1, - ) - case options.LossType.HUBER: - absolute_errors = self._absolute_errors_with_normalization( - normalization, - absolute_errors, - absolute_percentage_errors, - absolute_expected_outputs_or_one, - ) - # The delta parameter from the Huber loss definition. - huber_delta = tf.constant(1.0, dtype=dtype) - # The expression in the quadratic part of the Huber loss expression. - # It is squared in the return statement below. - quadratic = tf.minimum(absolute_errors, huber_delta) - # The linear part of the Huber loss expression. This is zero when - # absolute_error <= huber_delta. - linear = absolute_errors - quadratic - tensor = tf.reduce_mean( - 0.5 * tf.square(quadratic) + huber_delta * linear, axis=1 - ) - case _: - raise ValueError(f'Unexpected loss type: {loss_type}') - assert tensor.shape.is_compatible_with(( - num_tasks, - )), f'The actual shape is {tensor.shape}' + tensor = self._loss_tensors.get((loss_type, normalization)) + if tensor is None: + match loss_type: + case options.LossType.MEAN_SQUARED_ERROR: + tensor = tf.reduce_mean( + self._squared_errors_witn_normalization(normalization), axis=1 + ) + case options.LossType.MEAN_ABSOLUTE_ERROR: + tensor = tf.reduce_mean( + self._absolute_errors_with_normalization(normalization), axis=1 + ) + case options.LossType.HUBER: + absolute_errors = self._absolute_errors_with_normalization( + normalization + ) + # The delta parameter from the Huber loss definition. + huber_delta = tf.constant(1.0, dtype=self._dtype) + # The expression in the quadratic part of the Huber loss expression. + # It is squared in the return statement below. + quadratic = tf.minimum(absolute_errors, huber_delta) + # The linear part of the Huber loss expression. This is zero when + # absolute_error <= huber_delta. + linear = absolute_errors - quadratic + tensor = tf.reduce_mean( + 0.5 * tf.square(quadratic) + huber_delta * linear, axis=1 + ) + case _: + raise ValueError(f'Unexpected loss type: {loss_type}') + assert tensor.shape.is_compatible_with(( + self._num_tasks, + )), f'The actual shape is {tensor.shape}' + self._loss_tensors[loss_type, normalization] = tensor return tensor - def _make_percentile_tensor( - self, - values: tf.RaggedTensor, - num_tasks: int, - percentile_ranks: Sequence[int], - dtype: tf.dtypes.DType, - ) -> tf.Tensor: + def _make_percentile_tensor(self, values: tf.RaggedTensor) -> tf.Tensor: """Creates a percentile tensor from 'values' using self.percentile_ranks. Args: values: A 2D ragged tensor from which the percentiles are collected. The percentiles are collected along the axis 0 of `values`. - num_tasks: The number of tasks for the current model. Used for validation - purposes. - percentile_ranks: The percentile ranks to use for calculating the tensor. Returns: Percentiles based on self_percentile_ranks and the values. The returned tensor is of shape (N_PERCENTILE_RANKS, T), where T is the number of tasks. """ - if not percentile_ranks: - return tf.constant([], dtype=dtype) + if not self._percentile_ranks: + return tf.constant([], dtype=self._dtype) percentile_tensors = [] # NOTE(ondrasej): As of Nov 2022, tfp.stats.percentile() is not compatible # with ragged tensors, so we need to split the ragged tensor into rows and # then stack the individual percentile tensors to the desired output shape. - for task in range(num_tasks): + for task in range(self._num_tasks): task_values = values[task] percentile_tensors.append( - tfp.stats.percentile(task_values, percentile_ranks) + tfp.stats.percentile(task_values, self._percentile_ranks) ) assert percentile_tensors[-1].shape.is_compatible_with((None,)) return tf.stack(percentile_tensors, axis=1) def _squared_errors_witn_normalization( - self, - normalization: options.ErrorNormalization, - num_tasks: int, - delta: tf.Tensor, - squared_errors: tf.Tensor, - squared_percentage_error: tf.Tensor, - absolute_expected_outputs_or_one: tf.Tensor, + self, normalization: options.ErrorNormalization ) -> tf.Tensor: """Returns the tensor of squared errors.""" match normalization: case options.ErrorNormalization.NONE: - result = squared_errors + result = self._squared_errors case options.ErrorNormalization.PERCENTAGE_ERROR: - result = squared_percentage_error + result = self._squared_percentage_error case options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE: - result = tf.square(delta / absolute_expected_outputs_or_one) + result = tf.square(self._delta / self._absolute_expected_outputs_or_one) case _: raise NotImplementedError( f'Squared errors not implemented yet: {normalization}' ) assert result.shape.is_compatible_with( - (num_tasks, None) + (self._num_tasks, None) ), f'Actual shape of the squared errors tensor is {result.shape}' return result def _absolute_errors_with_normalization( - self, - normalization: options.ErrorNormalization, - absolute_errors: tf.Tensor, - absolute_percentage_errors: tf.Tensor, - absolute_expected_outputs_or_one: tf.Tensor, + self, normalization: options.ErrorNormalization ) -> tf.Tensor: """Returns the tensor of absolute errors.""" match normalization: case options.ErrorNormalization.NONE: - return absolute_errors + return self._absolute_errors case options.ErrorNormalization.PERCENTAGE_ERROR: - return absolute_percentage_errors + return self._absolute_percentage_errors case options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE: - return absolute_errors / absolute_expected_outputs_or_one + return self._absolute_errors / self._absolute_expected_outputs_or_one case _: raise NotImplementedError( 'Absolute errors not implemented yet: {normalization}' diff --git a/gematria/model/python/loss_utils_test.py b/gematria/model/python/loss_utils_test.py index 20b7f440..38e3bd8a 100644 --- a/gematria/model/python/loss_utils_test.py +++ b/gematria/model/python/loss_utils_test.py @@ -60,16 +60,18 @@ def test_unscaled_loss(self): self.full_mask, percentile_ranks=(0, 10, 50, 75, 100), dtype=self.dtype, - loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error + huber = loss.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.HUBER + ) percentiles = loss.absolute_error_percentiles self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6) self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6) self.assertNear( - float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 + float(huber), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 ) self.assertAllEqual(percentiles, ((0,), (0,), (0.5,), (3,), (4,))) @@ -99,33 +101,23 @@ def test_normalized_loss_when_expected_value_greater_than_one(self): ((1,), (4,), (3,), (0.5,), (0,)), dtype=self.dtype ) mask = tf.ones_like(actual_outputs, dtype=tf.dtypes.bool) - loss_params = { - 'output_values': tf.constant( - ((1.3,), (-2,), (3,), (1,), (2,)), dtype=self.dtype - ), - 'expected_outputs': tf.constant( - ((1,), (4,), (3,), (0.5,), (0,)), dtype=self.dtype - ), - 'mask': tf.ones_like(actual_outputs, dtype=tf.dtypes.bool), - 'dtype': self.dtype, - } + loss = loss_utils.LossComputation( + actual_outputs, expected_outputs, mask, dtype=self.dtype + ) - mean_absolute_error = loss_utils.LossComputation( - **loss_params, - normalization=options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, - loss_type=options.LossType.MEAN_ABSOLUTE_ERROR + mean_absolute_error = loss.loss_tensor( + options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, + options.LossType.MEAN_ABSOLUTE_ERROR, ) - mean_squared_error = loss_utils.LossComputation( - **loss_params, - normalization=options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, - loss_type=options.LossType.MEAN_SQUARED_ERROR + mean_squared_error = loss.loss_tensor( + options.ErrorNormalization.EXPECTED_VALUE_GREATER_THAN_ONE, + options.LossType.MEAN_SQUARED_ERROR, ) self.assertAlmostEqual( - float(mean_absolute_error.loss_tensor), - (0.3 + 1.5 + 0.0 + 0.5 + 2.0) / 5, + float(mean_absolute_error), (0.3 + 1.5 + 0.0 + 0.5 + 2.0) / 5 ) self.assertAlmostEqual( - float(mean_squared_error.loss_tensor), + float(mean_squared_error), (0.3**2 + 1.5**2 + 0.0 + 0.5**2 + 2.0**2) / 5, delta=1e-6, ) @@ -152,12 +144,13 @@ def test_multi_task(self): self.multitask_full_mask, percentile_ranks=(0, 10, 50, 75, 100), dtype=self.dtype, - loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error - huber = loss.loss_tensor + huber = loss.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.HUBER + ) percentiles = loss.absolute_error_percentiles self.assertAllClose( mse, @@ -183,36 +176,35 @@ def test_multi_task(self): ) def test_multi_task_with_mask(self): - loss_params = { - 'output_values': tf.constant( + loss = loss_utils.LossComputation( + output_values=tf.constant( ((1, 20), (2, 12.1), (3, 100), (50, 150), (2.5, 12.5)), dtype=self.dtype, ), - 'expected_outputs': tf.constant( + expected_outputs=tf.constant( ((4, 14), (500, 12), (30, 13), (1, 11), (2, 12)), dtype=self.dtype ), - 'mask': tf.constant(( + mask=tf.constant(( (True, True), (False, True), (True, False), (False, False), (True, False), )), - 'percentile_ranks': (0, 10, 50, 75, 100), - 'dtype': self.dtype, - 'loss_type': options.LossType.HUBER, - } - loss = loss_utils.LossComputation(**loss_params) - loss_percentage_normalized = loss_utils.LossComputation( - **loss_params, normalization=options.ErrorNormalization.PERCENTAGE_ERROR + percentile_ranks=(0, 10, 50, 75, 100), + dtype=self.dtype, ) mse = loss.mean_squared_error mspe = loss.mean_squared_percentage_error mae = loss.mean_absolute_error mape = loss.mean_absolute_percentage_error - huber = loss.loss_tensor - huber_percentage = loss_percentage_normalized.loss_tensor + huber = loss.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.HUBER + ) + huber_percentage = loss.loss_tensor( + options.ErrorNormalization.PERCENTAGE_ERROR, options.LossType.HUBER + ) percentiles = loss.absolute_error_percentiles self.assertAllClose( mse, @@ -260,22 +252,24 @@ def test_unknown_shape(self): tf.ones_like(self.actual_outputs, dtype=bool), percentile_ranks=percentile_ranks, dtype=self.dtype, - loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error + huber = loss.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.HUBER + ) percentiles = loss.absolute_error_percentiles self.assertEqual(mse.shape, (1,)) self.assertEqual(mae.shape, (1,)) - self.assertEqual(loss.loss_tensor.shape, (1,)) + self.assertEqual(huber.shape, (1,)) self.assertEqual(percentiles.shape, (len(percentile_ranks), 1)) self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6) self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6) self.assertNear( - float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 + float(huber), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 ) self.assertAllEqual(percentiles, ((0,), (0.5,), (3,), (4,))) @@ -289,12 +283,13 @@ def test_multi_task_unknown_shape(self): tf.ones_like(self.multitask_actual_outputs_array, dtype=bool), percentile_ranks=percentile_ranks, dtype=self.dtype, - loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error - huber = loss.loss_tensor + huber = loss.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.HUBER + ) percentiles = loss.absolute_error_percentiles self.assertEqual(mse.shape, (num_tasks,)) @@ -336,23 +331,24 @@ def test_single_task_unknown_shape(self): mask, percentile_ranks=percentile_ranks, dtype=self.dtype, - normalization=options.ErrorNormalization.NONE, - loss_type=options.LossType.HUBER, ) mse = loss.mean_squared_error mae = loss.mean_absolute_error + huber = loss.loss_tensor( + options.ErrorNormalization.NONE, options.LossType.HUBER + ) percentiles = loss.absolute_error_percentiles self.assertEqual(mse.shape, (num_tasks,)) self.assertEqual(mae.shape, (num_tasks,)) - self.assertEqual(loss.loss_tensor.shape, (num_tasks,)) + self.assertEqual(huber.shape, (num_tasks,)) self.assertEqual(percentiles.shape, (len(percentile_ranks), num_tasks)) self.assertNear(float(mse), (3**2 + 0 + 0 + 4**2 + 0.5**2) / 5, 1e-6) self.assertNear(float(mae), (3 + 0 + 0 + 4 + 0.5) / 5, 1e-6) self.assertNear( - float(loss.loss_tensor), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 + float(huber), (2.5 + 0 + 0 + 3.5 + (0.5**2) / 2) / 5, 1e-6 ) self.assertAllEqual(percentiles, ((0,), (0.5,), (3,), (4,))) diff --git a/gematria/model/python/model_base.py b/gematria/model/python/model_base.py index d499fd4f..6bae4319 100644 --- a/gematria/model/python/model_base.py +++ b/gematria/model/python/model_base.py @@ -335,6 +335,7 @@ def __init__( self._decayed_learning_rate = None self._loss: Optional[loss_utils.LossComputation] = None + self._delta_loss: Optional[loss_utils.LossComputation] = None self._train_step: Optional[tf.Operation] = None self._optimizer: Union[ tf.train.Optimizer, tf.train.SyncReplicasOptimizer @@ -1307,28 +1308,28 @@ def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation: tf.constant(schedule['output_mask']), percentile_ranks=self._collected_percentile_ranks, dtype=self.dtype, - normalization=self._loss_normalization, - loss_type=self._loss_type, ) if self._use_deltas: - delta_loss = loss_utils.LossComputation( + self._delta_loss = loss_utils.LossComputation( output['output_deltas'], tf.constant(schedule['expected_outputs_deltas']), output['output_mask_deltas'], percentile_ranks=self._collected_percentile_ranks, dtype=self.dtype, - normalization=self._loss_normalization, - loss_type=self._loss_type, ) if self._use_delta_loss: - loss = delta_loss + loss = self._delta_loss return loss def compute_loss_tensor(self, schedule: FeedDict): - return tf.reduce_mean(self._compute_loss(schedule).loss_tensor) + return tf.reduce_mean( + self._compute_loss(schedule).loss_tensor( + self._loss_normalization, self._loss_type + ) + ) def train_batch( self, @@ -1349,7 +1350,10 @@ def train_batch( with tf.GradientTape() as tape: stats = {} loss = self._compute_loss(schedule) - loss_tensor = tf.reduce_mean(loss.loss_tensor) + loss_tensor_per_task = loss.loss_tensor( + self._loss_normalization, self._loss_type + ) + loss_tensor = tf.reduce_mean(loss_tensor_per_task) # The list of variables to optimize. By default, the list is empty which # means optimize all trainable variables.