Skip to content

Commit

Permalink
Simplify casting base layer. Inputs are ignored for ragged inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Jan 5, 2025
1 parent 123194a commit ab2a914
Showing 1 changed file with 19 additions and 33 deletions.
52 changes: 19 additions & 33 deletions kgcnn/layers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ def _cat_one(t):

class _CastBatchedDisjointBase(Layer):

def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dtype_index=None,
padded_disjoint: bool = False, uses_mask: bool = False,
def __init__(self,
reverse_indices: bool = False,
dtype_batch: str = "int64",
dtype_index=None,
padded_disjoint: bool = False,
uses_mask: bool = False,
static_batched_node_output_shape: tuple = None,
static_batched_edge_output_shape: tuple = None,
remove_padded_disjoint_from_batched_output: bool = True,
Expand All @@ -29,20 +33,26 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt
dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
dtype_index (str): Dtype for index tensor. Default is None.
padded_disjoint (bool): Whether to keep padding in disjoint representation. Default is False.
Not used for ragged arguments.
uses_mask (bool): Whether the padding is marked by a boolean mask or by a length tensor, counting the
non-padded nodes from index 0. Default is False.
Not used for ragged arguments.
            static_batched_node_output_shape (tuple): Static output shape of nodes. Default is None.
Not used for ragged arguments.
            static_batched_edge_output_shape (tuple): Static output shape of edges. Default is None.
Not used for ragged arguments.
remove_padded_disjoint_from_batched_output (bool): Whether to remove the first element on batched output
in case of padding.
Not used for ragged arguments.
"""
super(_CastBatchedDisjointBase, self).__init__(**kwargs)
self.reverse_indices = reverse_indices
self.dtype_index = dtype_index
self.dtype_batch = dtype_batch
self.uses_mask = uses_mask
self.padded_disjoint = padded_disjoint
self.supports_jit = padded_disjoint
if padded_disjoint:
self.supports_jit = True
self.static_batched_node_output_shape = static_batched_node_output_shape
self.static_batched_edge_output_shape = static_batched_edge_output_shape
self.remove_padded_disjoint_from_batched_output = remove_padded_disjoint_from_batched_output
Expand Down Expand Up @@ -536,31 +546,7 @@ def call(self, inputs: list, **kwargs):
CastBatchedGraphStateToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__


class _CastRaggedToDisjointBase(Layer):

def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dtype_index=None, **kwargs):
r"""Initialize layer.
Args:
reverse_indices (bool): Whether to reverse index order. Default is False.
dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
dtype_index (str): Dtype for index tensor. Default is None.
"""
super(_CastRaggedToDisjointBase, self).__init__(**kwargs)
self.reverse_indices = reverse_indices
self.dtype_index = dtype_index
self.dtype_batch = dtype_batch
# self.supports_jit = False

def get_config(self):
"""Get config dictionary for this layer."""
config = super(_CastRaggedToDisjointBase, self).get_config()
config.update({"reverse_indices": self.reverse_indices, "dtype_batch": self.dtype_batch,
"dtype_index": self.dtype_index})
return config


class CastRaggedAttributesToDisjoint(_CastRaggedToDisjointBase):
class CastRaggedAttributesToDisjoint(_CastBatchedDisjointBase):

def __init__(self, **kwargs):
super(CastRaggedAttributesToDisjoint, self).__init__(**kwargs)
Expand Down Expand Up @@ -598,10 +584,10 @@ def call(self, inputs, **kwargs):
return decompose_ragged_tensor(inputs, batch_dtype=self.dtype_batch)


CastRaggedAttributesToDisjoint.__init__.__doc__ = _CastRaggedToDisjointBase.__init__.__doc__
CastRaggedAttributesToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__


class CastRaggedIndicesToDisjoint(_CastRaggedToDisjointBase):
class CastRaggedIndicesToDisjoint(_CastBatchedDisjointBase):

def __init__(self, **kwargs):
super(CastRaggedIndicesToDisjoint, self).__init__(**kwargs)
Expand Down Expand Up @@ -685,10 +671,10 @@ def call(self, inputs, **kwargs):
return [nodes_flatten, disjoint_indices, graph_id_node, graph_id_edge, node_id, edge_id, node_len, edge_len]


CastRaggedIndicesToDisjoint.__init__.__doc__ = _CastRaggedToDisjointBase.__init__.__doc__
CastRaggedIndicesToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__


class CastDisjointToRaggedAttributes(_CastRaggedToDisjointBase):
class CastDisjointToRaggedAttributes(_CastBatchedDisjointBase):

def __init__(self, **kwargs):
super(CastDisjointToRaggedAttributes, self).__init__(**kwargs)
Expand All @@ -713,4 +699,4 @@ def call(self, inputs, **kwargs):
raise NotImplementedError()


CastDisjointToRaggedAttributes.__init__.__doc__ = CastDisjointToRaggedAttributes.__init__.__doc__
CastDisjointToRaggedAttributes.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__

0 comments on commit ab2a914

Please sign in to comment.