Skip to content

Commit

Permalink
Merge branch 'feature/tensor_schema_padding_values' into 'main'
Browse files Browse the repository at this point in the history
Add new parameter `padding_value` to TensorFeatureInfo

See merge request ai-lab-pmo/mltools/recsys/RePlay!248
  • Loading branch information
OnlyDeniko committed Feb 14, 2025
2 parents e2dcc5e + 1dd0661 commit bbcc2ea
Show file tree
Hide file tree
Showing 13 changed files with 237 additions and 161 deletions.
10 changes: 10 additions & 0 deletions replay/data/nn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
feature_hint: Optional[FeatureHint] = None,
feature_sources: Optional[List[TensorFeatureSource]] = None,
cardinality: Optional[int] = None,
padding_value: int = 0,
embedding_dim: Optional[int] = None,
tensor_dim: Optional[int] = None,
) -> None:
Expand All @@ -96,6 +97,7 @@ def __init__(
:param cardinality: cardinality of categorical feature, required for ids columns,
optional for others,
default: ``None``.
:param padding_value: value to pad sequences to desired length,
default: ``0``.
:param embedding_dim: embedding dimensions of categorical feature,
default: ``None``.
:param tensor_dim: tensor dimensions of numerical feature,
Expand All @@ -105,6 +107,7 @@ def __init__(
self._feature_hint = feature_hint
self._feature_sources = feature_sources
self._is_seq = is_seq
self._padding_value = padding_value

if not isinstance(feature_type, FeatureType):
msg = "Unknown feature type"
Expand Down Expand Up @@ -203,6 +206,13 @@ def is_list(self) -> bool:
"""
return self.feature_type in [FeatureType.CATEGORICAL_LIST, FeatureType.NUMERICAL_LIST]

@property
def padding_value(self) -> int:
"""
:returns: value used to pad sequences of this feature to the desired length,
as passed to the constructor via ``padding_value`` (``0`` when not specified).
"""
return self._padding_value

@property
def cardinality(self) -> Optional[int]:
"""
Expand Down
16 changes: 13 additions & 3 deletions replay/data/nn/torch_sequential_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
from torch.utils.data import Dataset as TorchDataset

from replay.utils.model_handler import deprecation_warning

from .schema import TensorFeatureInfo, TensorMap, TensorSchema
from .sequential_dataset import SequentialDataset

Expand All @@ -25,6 +27,10 @@ class TorchSequentialDataset(TorchDataset):
Torch dataset for sequential recommender models
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down Expand Up @@ -93,11 +99,11 @@ def _generate_tensor_feature(
tensor_dtype = self._get_tensor_dtype(feature)
tensor_sequence = torch.tensor(sequence, dtype=tensor_dtype)
if feature.is_seq:
tensor_sequence = self._pad_sequence(tensor_sequence)
tensor_sequence = self._pad_sequence(tensor_sequence, feature.padding_value)

return tensor_sequence

def _pad_sequence(self, sequence: torch.Tensor) -> torch.Tensor:
def _pad_sequence(self, sequence: torch.Tensor, padding_value: int) -> torch.Tensor:
assert len(sequence) <= self._max_sequence_length
if len(sequence) == self._max_sequence_length:
return sequence
Expand All @@ -114,7 +120,7 @@ def _pad_sequence(self, sequence: torch.Tensor) -> torch.Tensor:

padded_sequence = torch.full(
padded_sequence_shape,
self._padding_value,
padding_value,
dtype=sequence.dtype,
)
padded_sequence[-len(sequence) :].copy_(sequence)
Expand Down Expand Up @@ -169,6 +175,10 @@ class TorchSequentialValidationDataset(TorchDataset):
Torch dataset for sequential recommender models that additionally stores ground truth
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down
8 changes: 8 additions & 0 deletions replay/experimental/nn/data/schema_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def categorical(
feature_source: Optional[TensorFeatureSource] = None,
feature_hint: Optional[FeatureHint] = None,
embedding_dim: Optional[int] = None,
padding_value: int = 0,
) -> "TensorSchemaBuilder":
source = [feature_source] if feature_source else None
self._tensor_schema[name] = TensorFeatureInfo(
Expand All @@ -29,6 +30,7 @@ def categorical(
feature_sources=source,
feature_hint=feature_hint,
cardinality=cardinality,
padding_value=padding_value,
embedding_dim=embedding_dim,
)
return self
Expand All @@ -40,6 +42,7 @@ def numerical(
is_seq: bool = False,
feature_sources: Optional[List[TensorFeatureSource]] = None,
feature_hint: Optional[FeatureHint] = None,
padding_value: int = 0,
) -> "TensorSchemaBuilder":
self._tensor_schema[name] = TensorFeatureInfo(
name=name,
Expand All @@ -48,6 +51,7 @@ def numerical(
feature_sources=feature_sources,
feature_hint=feature_hint,
tensor_dim=tensor_dim,
padding_value=padding_value,
)
return self

Expand All @@ -59,6 +63,7 @@ def categorical_list(
feature_source: Optional[TensorFeatureSource] = None,
feature_hint: Optional[FeatureHint] = None,
embedding_dim: Optional[int] = None,
padding_value: int = 0,
) -> "TensorSchemaBuilder":
source = [feature_source] if feature_source else None
self._tensor_schema[name] = TensorFeatureInfo(
Expand All @@ -68,6 +73,7 @@ def categorical_list(
feature_sources=source,
feature_hint=feature_hint,
cardinality=cardinality,
padding_value=padding_value,
embedding_dim=embedding_dim,
)
return self
Expand All @@ -79,6 +85,7 @@ def numerical_list(
is_seq: bool = False,
feature_sources: Optional[List[TensorFeatureSource]] = None,
feature_hint: Optional[FeatureHint] = None,
padding_value: int = 0,
) -> "TensorSchemaBuilder":
self._tensor_schema[name] = TensorFeatureInfo(
name=name,
Expand All @@ -87,6 +94,7 @@ def numerical_list(
feature_sources=feature_sources,
feature_hint=feature_hint,
tensor_dim=tensor_dim,
padding_value=padding_value,
)
return self

Expand Down
21 changes: 17 additions & 4 deletions replay/models/nn/sequential/bert4rec/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
TorchSequentialDataset,
TorchSequentialValidationDataset,
)
from replay.utils.model_handler import deprecation_warning


class Bert4RecTrainingBatch(NamedTuple):
Expand Down Expand Up @@ -88,6 +89,10 @@ class Bert4RecTrainingDataset(TorchDataset):
Dataset that generates samples to train BERT-like model
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down Expand Up @@ -176,6 +181,10 @@ class Bert4RecPredictionDataset(TorchDataset):
Dataset that generates samples to infer BERT-like model
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down Expand Up @@ -230,6 +239,10 @@ class Bert4RecValidationDataset(TorchDataset):
Dataset that generates samples to infer and validate BERT-like model
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down Expand Up @@ -286,12 +299,12 @@ def _shift_features(
shifted_features: MutableTensorMap = {}
for feature_name, feature in schema.items():
if feature.is_seq:
shifted_features[feature_name] = _shift_seq(features[feature_name])
shifted_features[feature_name] = _shift_seq(features[feature_name], feature.padding_value)
else:
shifted_features[feature_name] = features[feature_name]

# [0, 0, 1, 1, 1] -> [0, 1, 1, 1, 0]
tokens_mask = _shift_seq(padding_mask)
tokens_mask = _shift_seq(padding_mask, 0)

# [0, 1, 1, 1, 0] -> [0, 1, 1, 1, 1]
shifted_padding_mask = tokens_mask.clone()
Expand All @@ -304,7 +317,7 @@ def _shift_features(
)


def _shift_seq(seq: torch.Tensor, padding_value: int) -> torch.Tensor:
    """
    Shift a sequence one step to the left along the first dimension.

    The first element is dropped and the freed last position is filled with
    ``padding_value``, e.g. ``[1, 2, 3]`` -> ``[2, 3, padding_value]``.

    :param seq: tensor whose first dimension is shifted.
    :param padding_value: value written into the last position after the shift.

    :returns: new shifted tensor; ``seq`` itself is left unchanged.
    """
    # ``roll`` returns a new tensor, so the assignment below does not touch ``seq``.
    shifted_seq = seq.roll(-1, dims=0)
    # The rolled tensor wraps the first element around to the end; overwrite it.
    shifted_seq[-1, ...] = padding_value
    return shifted_seq
6 changes: 4 additions & 2 deletions replay/models/nn/sequential/bert4rec/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,15 @@ def _prepare_prediction_batch(self, batch: Bert4RecPredictionBatch) -> Bert4RecP
for feature_name, feature_tensor in features.items():
if self._schema[feature_name].is_cat:
features[feature_name] = torch.nn.functional.pad(
feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
feature_tensor,
(self._model.max_len - sequence_item_count, 0),
value=self._schema[feature_name].padding_value,
)
else:
features[feature_name] = torch.nn.functional.pad(
feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
(self._model.max_len - sequence_item_count, 0),
value=0,
value=self._schema[feature_name].padding_value,
).unsqueeze(-1)
padding_mask = torch.nn.functional.pad(
padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
Expand Down
13 changes: 13 additions & 0 deletions replay/models/nn/sequential/sasrec/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TorchSequentialDataset,
TorchSequentialValidationDataset,
)
from replay.utils.model_handler import deprecation_warning


class SasRecTrainingBatch(NamedTuple):
Expand All @@ -30,6 +31,10 @@ class SasRecTrainingDataset(TorchDataset):
Dataset that generates samples to train SasRec-like model
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down Expand Up @@ -122,6 +127,10 @@ class SasRecPredictionDataset(TorchDataset):
Dataset that generates samples to infer SasRec-like model
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down Expand Up @@ -170,6 +179,10 @@ class SasRecValidationDataset(TorchDataset):
Dataset that generates samples to infer and validate SasRec-like model
"""

@deprecation_warning(
"`padding_value` parameter will be removed in future versions. "
"Instead, you should specify `padding_value` for each column in TensorSchema"
)
def __init__(
self,
sequential: SequentialDataset,
Expand Down
6 changes: 4 additions & 2 deletions replay/models/nn/sequential/sasrec/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,15 @@ def _prepare_prediction_batch(self, batch: SasRecPredictionBatch) -> SasRecPredi
for feature_name, feature_tensor in features.items():
if self._schema[feature_name].is_cat:
features[feature_name] = torch.nn.functional.pad(
feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
feature_tensor,
(self._model.max_len - sequence_item_count, 0),
value=self._schema[feature_name].padding_value,
)
else:
features[feature_name] = torch.nn.functional.pad(
feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
(self._model.max_len - sequence_item_count, 0),
value=0,
value=self._schema[feature_name].padding_value,
).unsqueeze(-1)
padding_mask = torch.nn.functional.pad(
padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
Expand Down
4 changes: 4 additions & 0 deletions tests/data/nn/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def item_id_and_item_features_schema():
is_seq=True,
feature_type=FeatureType.CATEGORICAL_LIST,
feature_sources=[TensorFeatureSource(FeatureSource.ITEM_FEATURES, "item_cat_list")],
padding_value=1,
),
TensorFeatureInfo(
"item_num",
Expand All @@ -375,6 +376,7 @@ def item_id_and_item_features_schema():
is_seq=True,
feature_type=FeatureType.NUMERICAL_LIST,
feature_sources=[TensorFeatureSource(FeatureSource.ITEM_FEATURES, "item_num_list")],
padding_value=1,
),
]
)
Expand Down Expand Up @@ -564,6 +566,7 @@ def sequential_dataset():
is_seq=True,
feature_type=FeatureType.CATEGORICAL,
feature_hint=FeatureHint.ITEM_ID,
padding_value=-1,
),
TensorFeatureInfo(
"some_user_feature",
Expand All @@ -576,6 +579,7 @@ def sequential_dataset():
cardinality=6,
is_seq=True,
feature_type=FeatureType.CATEGORICAL,
padding_value=-2,
),
]
)
Expand Down
Loading

0 comments on commit bbcc2ea

Please sign in to comment.