Merge pull request #140 from AIRI-Institute/dataloader-fix
Dataloader fix
Vitaly-Protasov authored Mar 24, 2024
2 parents aa7af38 + 2b45a68 commit 8eb7adb
Showing 6 changed files with 193 additions and 180 deletions.
22 changes: 13 additions & 9 deletions docs/usage.md
@@ -2,17 +2,21 @@
 - Basic example:
 ```python
 from probing.pipeline import ProbingPipeline
 from pathlib import Path
 
 experiment = ProbingPipeline(
-    hf_model_name="bert-base-uncased",
-    device="cuda:0",
-    metric_names=["f1", "accuracy"],
-    encoding_batch_size=32,
-    classifier_batch_size=32)
+    hf_model_name="DeepPavlov/rubert-base-cased",
+    device="cuda:0",
+    metric_names=["f1", "accuracy"],
+    encoding_batch_size=32,
+    classifier_batch_size=32
+)
 
-experiment.run(probe_task=Path("example.csv").stem,
-               path_to_task_file="example.csv",
-               verbose=True,
-               train_epochs=20,)
+experiment.run(
+    probe_task="gapping",
+    #path_to_task_file="example.csv",
+    verbose=True,
+    train_epochs=20
+)
 ```
- For other examples, see [Probing Pipeline documentation](https://github.com/AIRI-Institute/Probing_framework/tree/main/scripts).
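If you still want to point `run` at a local CSV (the style the old snippet used), the commented-out argument above is the hook. A minimal sketch, assuming a local `example.csv` in the three-column, tab-separated layout `form_data` reads (stage, label, text); the file and the stem-derived task name are illustrative:

```python
from pathlib import Path

from probing.pipeline import ProbingPipeline

experiment = ProbingPipeline(
    hf_model_name="DeepPavlov/rubert-base-cased",
    device="cuda:0",
    metric_names=["f1", "accuracy"],
    encoding_batch_size=32,
    classifier_batch_size=32,
)

# Assumes example.csv exists next to the script; the task name is just
# the file stem, as the previous version of this example used.
experiment.run(
    probe_task=Path("example.csv").stem,
    path_to_task_file="example.csv",
    verbose=True,
    train_epochs=20,
)
```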
17 changes: 11 additions & 6 deletions probing/data_former.py
@@ -1,5 +1,11 @@
 import os
 import typing
+
+try:
+    from typing import Literal  # type: ignore
+except:
+    from typing_extensions import Literal  # type: ignore
+
 from collections import Counter, defaultdict
 from typing import DefaultDict, Dict, Optional, Set, Tuple, Union
@@ -22,9 +28,10 @@ def __init__(
     ):
         self.probe_task = probe_task
         self.shuffle = shuffle
+        self.sep = sep
         self.data_path = get_probe_task_path(probe_task, data_path)
 
-        self.samples, self.unique_labels = self.form_data(sep=sep)
+        self.samples, self.unique_labels = self.form_data()
 
     def __len__(self):
         return len(self.samples)
@@ -33,7 +40,7 @@ def __getitem__(self, idx):
         return self.samples[idx]
 
     @property
-    def ratio_by_classes(self) -> Dict[str, Dict[str, int]]:
+    def ratio_by_classes(self) -> Dict[Literal["tr", "va", "te"], Dict[str, int]]:
         ratio_by_classes = {}
         for class_name in self.samples:
             class_labels_all = [i[1] for i in self.samples[class_name]]
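The new `Literal["tr", "va", "te"]` annotation is what the fallback import at the top of the file exists for: `typing.Literal` only ships with Python 3.8+, and `typing_extensions` backports it. The same pattern in isolation (the `Stage` alias and `describe` function are hypothetical, not code from this repository):

```python
try:
    from typing import Literal  # Python >= 3.8
except ImportError:
    from typing_extensions import Literal  # backport for older interpreters

# Hypothetical alias mirroring the annotation in ratio_by_classes.
Stage = Literal["tr", "va", "te"]

def describe(stage: Stage) -> str:
    # A type checker now rejects any stage name outside the three splits.
    return {"tr": "train", "va": "validation", "te": "test"}[stage]
```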
@@ -42,12 +49,10 @@ def ratio_by_classes(self) -> Dict[str, Dict[str, int]]:
         return ratio_by_classes
 
     @typing.no_type_check
-    def form_data(
-        self, sep: str = "\t"
-    ) -> Tuple[DefaultDict[str, np.ndarray], Set[str]]:
+    def form_data(self) -> Tuple[DefaultDict[str, np.ndarray], Set[str]]:
         samples_dict = defaultdict(list)
         unique_labels = set()
-        dataset = pd.read_csv(self.data_path, sep=sep, header=None, dtype=str)
+        dataset = pd.read_csv(self.data_path, sep=self.sep, header=None, dtype=str)
         for _, (stage, label, text) in dataset.iterrows():
             samples_dict[stage].append((text, label))
             unique_labels.add(label)
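Net effect of the `sep` change: the separator is captured once in `__init__` and read back as `self.sep`, rather than threaded through the `form_data` call. A self-contained sketch of the parsing step above, with a toy in-memory CSV standing in for a real task file:

```python
from collections import defaultdict
from io import StringIO

import pandas as pd

# Toy tab-separated task data in the three-column layout form_data
# expects: stage ("tr"/"va"/"te"), label, text.
csv_text = "tr\t1\tfirst example sentence\nva\t0\tsecond example sentence\n"

sep = "\t"  # in the class this now lives on the instance as self.sep
dataset = pd.read_csv(StringIO(csv_text), sep=sep, header=None, dtype=str)

samples_dict, unique_labels = defaultdict(list), set()
for _, (stage, label, text) in dataset.iterrows():
    samples_dict[stage].append((text, label))
    unique_labels.add(label)

print(dict(samples_dict))  # {'tr': [('first example sentence', '1')], 'va': [...]}
print(unique_labels)       # {'0', '1'}
```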
27 changes: 15 additions & 12 deletions probing/encoder.py
@@ -162,7 +162,11 @@ def _get_embeddings_by_layers(
         aggregation_embeddings: AggregationType,
     ) -> List[torch.Tensor]:
         layers_outputs = []
-        for output in model_outputs[1:]:  # type: ignore
+        if len(model_outputs) == 1:
+            process_outputs = model_outputs
+        else:
+            process_outputs = model_outputs[1:]
+        for output in process_outputs:  # type: ignore
             if aggregation_embeddings == AggregationType("first"):
                 sent_vector = output[:, 0, :]  # type: ignore
             elif aggregation_embeddings == AggregationType("last"):
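Why the length check matters: with `output_hidden_states=True`, HuggingFace models return one tensor per layer with the embedding output first, so `model_outputs[1:]` skips the embedding entry; a model exposing only `last_hidden_state` arrives as a single tensor, which `[1:]` would silently reduce to nothing. A toy sketch of the guard, with random tensors standing in for real hidden states:

```python
import torch

# Stand-in for hidden_states: embedding output plus 3 transformer
# layers, each of shape [batch, seq_len, hidden].
model_outputs = tuple(torch.randn(2, 5, 8) for _ in range(4))

# The guard from the diff: keep a length-1 output as-is, otherwise
# drop the embedding-layer entry.
process_outputs = model_outputs if len(model_outputs) == 1 else model_outputs[1:]

# "first" aggregation: the vector at token position 0 of each layer.
layers_outputs = [output[:, 0, :] for output in process_outputs]
print(len(layers_outputs), layers_outputs[0].shape)  # 3 torch.Size([2, 8])
```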
@@ -255,11 +259,13 @@ def model_layers_forward(
             return_dict=self.return_dict,
         )
 
-        model_outputs = (
-            model_outputs["hidden_states"]
-            if "hidden_states" in model_outputs
-            else model_outputs["encoder_hidden_states"]
-        )
+        if "hidden_states" in model_outputs:
+            model_outputs = model_outputs["hidden_states"]
+        elif "last_hidden_state" in model_outputs:
+            model_outputs = model_outputs["last_hidden_state"]
+        else:
+            model_outputs = model_outputs["encoder_hidden_states"]
+
         layers_outputs = self._get_embeddings_by_layers(
             model_outputs, aggregation_embeddings=aggregation_embeddings
         )
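The new `elif` covers models whose output carries only `last_hidden_state`, which the old two-way branch turned into a `KeyError` on `encoder_hidden_states`. A quick way to inspect which keys a given checkpoint actually returns (the model name is illustrative; HuggingFace `ModelOutput` objects support `in` like a mapping):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

encoded = tokenizer("a probing sentence", return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoded, output_hidden_states=False, return_dict=True)
print("hidden_states" in outputs)  # False: only last_hidden_state (and pooler_output)

with torch.no_grad():
    outputs = model(**encoded, output_hidden_states=True, return_dict=True)
print("hidden_states" in outputs)  # True: embeddings plus one tensor per layer
```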
@@ -285,7 +291,8 @@ def encode_data(
             else data
         )
 
-        for batch_input_ids, batch_attention_mask, batch_labels in iter_data:
+        for batch in iter_data:
+            batch_input_ids, batch_attention_mask, batch_labels = batch
             in_cache_ids, out_cache_ids = self.Caching.check_cache_ids(
                 batch_input_ids
             )
@@ -357,18 +364,14 @@ def get_encoded_dataloaders(
         verbose: bool = True,
         do_control_task: bool = False,
     ) -> Tuple[Dict[Literal["tr", "va", "te"], DataLoader], Dict[str, int]]:
-        # if self.tokenizer.model_max_length > self.model_max_length:
-        #     logger.warning(
-        #         f"In tokenizer model_max_length = {self.tokenizer.model_max_length}. Changed to {self.model_max_length} for preventing Out-Of-Memory."
-        #     )
         if self.Caching is None:
             if self.tokenizer is None:
                 raise RuntimeError("Tokenizer is None")
             self.Caching = Cacher(tokenizer=self.tokenizer, cache={})
 
         tokenized_datasets = self.get_tokenized_datasets(task_dataset)
         encoded_dataloaders = {}
-        for stage, _ in tokenized_datasets.items():
+        for stage in tokenized_datasets:
             stage_dataloader_tokenized = DataLoader(
                 tokenized_datasets[stage], batch_size=encoding_batch_size
             )