Feat: Multi-GPU Evaluation #3611

Open: wants to merge 5 commits into base: master
84 changes: 71 additions & 13 deletions flair/distributed_utils.py
@@ -2,9 +2,8 @@
import os
import random
from multiprocessing.connection import Connection
from typing import Callable
from typing import Callable, Collection, Iterable, TypeVar

import numpy as np
import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group, init_process_group
@@ -15,8 +14,10 @@

log = logging.getLogger("flair")

T = TypeVar("T")

def launch_distributed(fn, *args, **kwargs):

def launch_distributed(fn: Callable, *args, **kwargs):
"""Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU).

If training with multi_gpu=True, launch_distributed should wrap your code that calls .train or .fine_tune.
@@ -61,16 +62,6 @@ def is_main_process() -> bool:
return True


def aggregate(value, aggregation_fn=np.mean):
"""Gather `value` from all processes and send to `aggregation_fn` to get a single return value."""
if torch.distributed.is_initialized():
gathered_values = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_values, value)
else:
gathered_values = [value]
return aggregation_fn(gathered_values)


def validate_corpus_same_each_process(corpus: Corpus) -> None:
"""Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two
reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable
@@ -81,9 +72,76 @@ def validate_corpus_same_each_process(corpus: Corpus) -> None:


def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None:
""":raises: ValueError if the dataset is not the same on each process."""
random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset)))
for i in random_indices:
example = str(dataset[i])
examples = aggregate(example, list)
if not all(example == examples[0] for example in examples):
raise ValueError("Dataset must be the same on each process")


def gather(value: T) -> list[T]:
"""Gather `value` from all processes and return a list of values."""
if torch.distributed.is_initialized():
gathered_values = [value for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_values, value)
else:
gathered_values = [value]
return gathered_values


def aggregate(value: T, aggregation_fn: Callable):
"""Gather `value` from all processes and send to `aggregation_fn` to get a single return value."""
gathered_values = gather(value)
return aggregation_fn(gathered_values)


def broadcast_value(value: T, src: int = 0) -> T:
"""
Broadcasts a Python object from the source process (src) to all other processes.
Every process returns the same object.
"""
obj_list = [value]
torch.distributed.broadcast_object_list(obj_list, src=src)
return obj_list[0]


# aggregation functions
def flatten(l: Iterable[Iterable[T]]) -> list[T]:
"""Flattens all elements in an iterable, such as a list, of iterables into a single list."""
return [x for s in l for x in s]


def flatten_set(list_of_sets: Iterable[Iterable[T]]) -> set[T]:
"""Flattens all elements in an iterable, such as a list, of iterables into a single set."""
return {x for subset in list_of_sets for x in subset}


def merge_sets(list_of_sets: Collection[set[T]]) -> set[T]:
"""Merges a collection of sets into a single set."""
merged_set = set()
for s in list_of_sets:
merged_set.update(s)
return merged_set


def flatten_dicts(list_of_dicts: list[dict[str, list[T]]]) -> dict[str, list[T]]:
"""This function merges a list of dictionaries with list values into a single dictionary with merged list values."""
merged_dict: dict[str, list[T]] = {}
for d in list_of_dicts:
for k, v in d.items():
if k not in merged_dict:
merged_dict[k] = []
merged_dict[k].extend(v)
return merged_dict


def aggregate_tensor_sum(list_of_tensors: list[torch.Tensor]) -> torch.Tensor:
"""
Custom aggregation function to sum loss values from all processes.
Moves all tensors to CPU and converts them to Python scalars before summing.
Returns a single tensor containing the summed loss.
"""
total = sum(t.cpu().item() for t in list_of_tensors)
return torch.tensor(total)
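
The helpers above degrade gracefully when torch.distributed is not initialized: gather() then returns a one-element list, so the aggregation functions can be exercised in a single process. A minimal sketch (assuming the module layout of this PR, flair.distributed_utils):

import torch

from flair.distributed_utils import aggregate, aggregate_tensor_sum, flatten, flatten_dicts

# With no process group initialized, gather(value) returns [value], so each
# aggregation function simply operates on a one-element list.
local_predictions = [0.1, 0.7, 0.3]
all_predictions = aggregate(local_predictions, flatten)           # [0.1, 0.7, 0.3]

local_loss = torch.tensor(1.25)
total_loss = aggregate(local_loss, aggregate_tensor_sum)          # tensor(1.2500)

local_lines = {"lines": ["sent_1\t0.5\t0.4"]}
merged_lines = aggregate(local_lines, flatten_dicts)              # {'lines': ['sent_1\t0.5\t0.4']}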
76 changes: 51 additions & 25 deletions flair/models/pairwise_regression_model.py
Expand Up @@ -4,13 +4,15 @@

import torch
from torch import nn
from torch.utils.data import DistributedSampler
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

import flair.embeddings
import flair.nn
from flair.data import Corpus, Dictionary, Sentence, TextPair, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.distributed_utils import aggregate, aggregate_tensor_sum, broadcast_value, flatten, is_main_process
from flair.nn.model import ReduceTransformerVocabMixin
from flair.training_utils import EmbeddingStorageMode, MetricRegression, Result, store_embeddings

@@ -288,13 +290,21 @@ def evaluate(
exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
multi_gpu: bool = False,
**kwargs,
) -> Result:
exclude_labels = exclude_labels if exclude_labels is not None else []

# read Dataset into data loader, if list of sentences passed, make Dataset first
if not isinstance(data_points, Dataset):
data_points = FlairDatapointDataset(data_points)
data_loader = DataLoader(data_points, batch_size=mini_batch_size)

data_loader = DataLoader(
data_points,
batch_size=mini_batch_size,
shuffle=False,
sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None,
)

with torch.no_grad():
eval_loss = torch.zeros(1, device=flair.device)
@@ -311,15 +321,15 @@
if isinstance(batch, Sentence):
batch = [batch]

loss, num, scores = self._forward_loss_and_scores(batch, return_scores=True)
loss, num, scores_forward = self._forward_loss_and_scores(batch, return_scores=True)

true_values = []
for sentence in batch:
total_count += 1
for label in sentence.get_labels(gold_label_type):
true_values.append(float(label.value))

results = scores.cpu().tolist()
results = scores_forward.cpu().tolist()

eval_loss += loss

@@ -336,30 +346,46 @@
if out_path is not None:
out_file.close()

if multi_gpu:
metric.true = aggregate(metric.true, flatten)
metric.pred = aggregate(metric.pred, flatten)
eval_loss = aggregate(eval_loss, aggregate_tensor_sum)
total_count = aggregate(total_count, sum)

eval_loss /= total_count

detailed_result = (
f"AVG: mse: {metric.mean_squared_error():.4f} - "
f"mae: {metric.mean_absolute_error():.4f} - "
f"pearson: {metric.pearsonr():.4f} - "
f"spearman: {metric.spearmanr():.4f}"
)
if is_main_process(): # only calculate metrics in main process

eval_metrics = {
"loss": eval_loss.item(),
"mse": metric.mean_squared_error(),
"mae": metric.mean_absolute_error(),
"pearson": metric.pearsonr(),
"spearman": metric.spearmanr(),
}
detailed_result = (
f"AVG: mse: {metric.mean_squared_error():.4f} - "
f"mae: {metric.mean_absolute_error():.4f} - "
f"pearson: {metric.pearsonr():.4f} - "
f"spearman: {metric.spearmanr():.4f}"
)

if main_evaluation_metric[0] in ("correlation", "other"):
main_score = eval_metrics[main_evaluation_metric[1]]
else:
main_score = eval_metrics["spearman"]
scores = {
"loss": eval_loss.item(),
"mse": metric.mean_squared_error(),
"mae": metric.mean_absolute_error(),
"pearson": metric.pearsonr(),
"spearman": metric.spearmanr(),
}

return Result(
main_score=main_score,
detailed_results=detailed_result,
scores=eval_metrics,
)
if main_evaluation_metric[0] in ("correlation", "other"):
main_score = scores[main_evaluation_metric[1]]
else:
main_score = scores["spearman"]

result = Result(
main_score=main_score,
detailed_results=detailed_result,
scores=scores,
)

else: # if it's not the main process, just set a dummy Result
result = Result(0.0, "", {}, {"loss": 0.0})

if multi_gpu:
result = broadcast_value(result, src=0)

return result
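
One way to use the new multi_gpu=True flag is from inside a function launched via launch_distributed, so that each process scores its own shard of the test split and all processes end up with the same Result. A hedged sketch; the model path, corpus helper, and label type are placeholders, not taken from this PR:

from flair.distributed_utils import launch_distributed
from flair.models import TextPairRegressor


def evaluate_distributed():
    # placeholders: load your own trained pair-regression model and corpus
    model = TextPairRegressor.load("path/to/model.pt")
    corpus = make_corpus()  # hypothetical helper returning a flair Corpus

    result = model.evaluate(
        corpus.test,
        gold_label_type="similarity",  # placeholder label type
        mini_batch_size=32,
        multi_gpu=True,  # DistributedSampler shards the data; metrics are computed on rank 0
    )
    print(result.main_score)  # identical on every process thanks to broadcast_value


if __name__ == "__main__":
    launch_distributed(evaluate_distributed)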
82 changes: 53 additions & 29 deletions flair/models/text_regression_model.py
Expand Up @@ -5,13 +5,15 @@

import torch
from torch import nn
from torch.utils.data import DistributedSampler
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

import flair
import flair.embeddings
from flair.data import Corpus, Dictionary, Sentence, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.distributed_utils import aggregate, aggregate_tensor_sum, broadcast_value, flatten, is_main_process
from flair.embeddings.base import load_embeddings
from flair.nn.model import ReduceTransformerVocabMixin
from flair.training_utils import EmbeddingStorageMode, MetricRegression, Result, store_embeddings
@@ -141,13 +143,21 @@ def evaluate(
exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
multi_gpu: bool = False,
**kwargs,
) -> Result:
exclude_labels = exclude_labels if exclude_labels is not None else []

# read Dataset into data loader, if list of sentences passed, make Dataset first
if not isinstance(data_points, Dataset):
data_points = FlairDatapointDataset(data_points)
data_loader = DataLoader(data_points, batch_size=mini_batch_size)

data_loader = DataLoader(
data_points,
batch_size=mini_batch_size,
shuffle=False,
sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None,
)

with torch.no_grad():
eval_loss = torch.zeros(1, device=flair.device)
@@ -156,19 +166,19 @@

lines: list[str] = []
total_count = 0
for batch in data_loader:
for batch in tqdm(data_loader):
if isinstance(batch, Sentence):
batch = [batch]

scores, loss = self.forward_labels_and_loss(batch)
scores_forward, loss = self.forward_labels_and_loss(batch)

true_values = []
for sentence in batch:
total_count += 1
for label in sentence.get_labels(gold_label_type):
true_values.append(float(label.value))

results = scores[:, 0].cpu().tolist()
results = scores_forward[:, 0].cpu().tolist()

eval_loss += loss

@@ -181,38 +191,52 @@

store_embeddings(batch, embedding_storage_mode)

if multi_gpu:
metric.true = aggregate(metric.true, flatten)
metric.pred = aggregate(metric.pred, flatten)
eval_loss = aggregate(eval_loss, aggregate_tensor_sum)
total_count = aggregate(total_count, sum)

eval_loss /= total_count

# TODO: not saving lines yet
if out_path is not None:
with open(out_path, "w", encoding="utf-8") as outfile:
outfile.write("".join(lines))

detailed_result = (
f"AVG: mse: {metric.mean_squared_error():.4f} - "
f"mae: {metric.mean_absolute_error():.4f} - "
f"pearson: {metric.pearsonr():.4f} - "
f"spearman: {metric.spearmanr():.4f}"
)

eval_metrics = {
"loss": eval_loss.item(),
"mse": metric.mean_squared_error(),
"mae": metric.mean_absolute_error(),
"pearson": metric.pearsonr(),
"spearman": metric.spearmanr(),
}

if main_evaluation_metric[0] in ("correlation", "other"):
main_score = eval_metrics[main_evaluation_metric[1]]
else:
main_score = eval_metrics["spearman"]

result = Result(
main_score=main_score,
detailed_results=detailed_result,
scores=eval_metrics,
)
if is_main_process(): # only calculate metrics in main process

detailed_result = (
f"AVG: mse: {metric.mean_squared_error():.4f} - "
f"mae: {metric.mean_absolute_error():.4f} - "
f"pearson: {metric.pearsonr():.4f} - "
f"spearman: {metric.spearmanr():.4f}"
)

scores = {
"loss": eval_loss.item(),
"mse": metric.mean_squared_error(),
"mae": metric.mean_absolute_error(),
"pearson": metric.pearsonr(),
"spearman": metric.spearmanr(),
}

if main_evaluation_metric[0] in ("correlation", "other"):
main_score = scores[main_evaluation_metric[1]]
else:
main_score = scores["spearman"]

result = Result(
main_score=main_score,
detailed_results=detailed_result,
scores=scores,
)

else: # if it's not the main process, just set a dummy Result
result = Result(0.0, "", {}, {"loss": 0.0})

if multi_gpu:
result = broadcast_value(result, src=0)

return result
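
Both regression models now follow the same gather / rank-0 metrics / broadcast pattern. The pattern in isolation, as a simplified illustration built on the PR's helpers (mse stands in for the real metric suite; this is not code from either model):

import torch

from flair.distributed_utils import aggregate, broadcast_value, flatten, is_main_process
from flair.training_utils import Result


def finalize(local_true: list[float], local_pred: list[float]) -> Result:
    # 1) gather the partial predictions from every process
    all_true = aggregate(local_true, flatten)
    all_pred = aggregate(local_pred, flatten)

    if is_main_process():
        # 2) compute metrics once, on rank 0 only
        mse = sum((t - p) ** 2 for t, p in zip(all_true, all_pred)) / len(all_true)
        result = Result(main_score=mse, detailed_results=f"mse: {mse:.4f}", scores={"loss": mse, "mse": mse})
    else:
        result = Result(0.0, "", {}, {"loss": 0.0})

    # 3) every process returns rank 0's Result
    if torch.distributed.is_initialized():
        result = broadcast_value(result, src=0)
    return result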
