From 5bd59e4d4a578b56e694d3d8ec13664de880fc7a Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Sun, 19 Feb 2023 13:01:07 +0000 Subject: [PATCH] Revert "Add Python 3.11 to CI tests (#91)" This reverts commit a3fadd2da04d374d2c9225c38eb8beff8c2d6d61. --- .github/workflows/cpu_ci.yml | 22 ---- elk/__main__.py | 64 +++++++---- elk/extraction/extraction.py | 170 ++++++++---------------------- elk/extraction/extraction_main.py | 78 ++++++++++---- elk/extraction/prompt_collator.py | 41 +++---- 5 files changed, 164 insertions(+), 211 deletions(-) diff --git a/.github/workflows/cpu_ci.yml b/.github/workflows/cpu_ci.yml index 4f51c990..bae0477c 100644 --- a/.github/workflows/cpu_ci.yml +++ b/.github/workflows/cpu_ci.yml @@ -46,25 +46,3 @@ jobs: - name: Run CPU Tests run: pytest -m cpu - - run-tests-python3_11: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - - name: Install Python - uses: actions/setup-python@v4 - with: - python-version: "3.11" - - - name: Upgrade Pip - run: python -m pip install --upgrade pip - - - name: Install Dependencies - run: pip install -e .[dev] - - - name: Type Checking - uses: jakebailey/pyright-action@v1 - - - name: Run CPU Tests - run: pytest -m cpu diff --git a/elk/__main__.py b/elk/__main__.py index 1d592980..0bdddedd 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -4,6 +4,8 @@ from .files import args_to_uuid from .list import list_runs from argparse import ArgumentParser +from contextlib import nullcontext, redirect_stdout +import logging import warnings @@ -38,7 +40,6 @@ def run(): list_runs(args) return - # Import here and not at the top to speed up `elk list` from transformers import AutoConfig, PretrainedConfig config = AutoConfig.from_pretrained(args.model) @@ -66,30 +67,47 @@ def run(): # Import here and not at the top to speed up `elk list` from .extraction.extraction_main import run as run_extraction from .training.train import train + import os + import torch.distributed as dist - # Print CLI arguments to stdout - for key, value in vars(args).items(): - print(f"{key}: {value}") - - if args.command == "extract": - run_extraction(args) - elif args.command == "elicit": - # The user can specify a name for the run, but by default we use the - # MD5 hash of the arguments to ensure the name is unique - if not args.name: - args.name = args_to_uuid(args) - - try: - train(args) - except (EOFError, FileNotFoundError): - run_extraction(args) - train(args) + # Check if we were called with torchrun or not + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + dist.init_process_group("nccl") + local_rank = int(local_rank) + + with redirect_stdout(None) if local_rank else nullcontext(): + # Print CLI arguments to stdout + for key, value in vars(args).items(): + print(f"{key}: {value}") - elif args.command == "eval": - # TODO: Implement evaluation script - raise NotImplementedError - else: - raise ValueError(f"Unknown command {args.command}") + if local_rank: + logging.getLogger("transformers").setLevel(logging.CRITICAL) + + if args.command == "extract": + run_extraction(args) + elif args.command == "elicit": + # The user can specify a name for the run, but by default we use the + # MD5 hash of the arguments to ensure the name is unique + if not args.name: + args.name = args_to_uuid(args) + + try: + train(args) + except (EOFError, FileNotFoundError): + run_extraction(args) + + # Ensure the extraction is finished before starting training + if dist.is_initialized(): + dist.barrier() + + train(args) + + elif args.command == "eval": 
+ # TODO: Implement evaluation script + raise NotImplementedError + else: + raise ValueError(f"Unknown command {args.command}") if __name__ == "__main__": diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 9638371f..e09a37b6 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -2,36 +2,19 @@ from ..utils import pytree_map from .prompt_collator import Prompt, PromptCollator -from dataclasses import dataclass from einops import rearrange from torch.utils.data import DataLoader -from transformers import ( - BatchEncoding, - PreTrainedModel, - PreTrainedTokenizerBase, - AutoModel, -) -from typing import cast, Literal, Sequence -import logging -import numpy as np +from tqdm.auto import tqdm +from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizerBase +from typing import cast, Iterable, Literal, Sequence import torch -import torch.multiprocessing as mp - - -@dataclass -class ExtractionParameters: - model_str: str - tokenizer: PreTrainedTokenizerBase - collator: PromptCollator - batch_size: int = 1 - layers: Sequence[int] = () - prompt_suffix: str = "" - token_loc: Literal["first", "last", "mean"] = "last" - use_encoder_states: bool = False +import torch.distributed as dist +@torch.autocast("cuda", enabled=torch.cuda.is_available()) # type: ignore +@torch.no_grad() def extract_hiddens( - model_str: str, + model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, collator: PromptCollator, *, @@ -41,85 +24,40 @@ def extract_hiddens( prompt_suffix: str = "", token_loc: Literal["first", "last", "mean"] = "last", use_encoder_states: bool = False, -): - """Run inference on a model with a set of prompts, yielding the hidden states.""" - - ctx = mp.get_context("spawn") - queue = ctx.Queue() - - num_gpus = torch.cuda.device_count() - params = ExtractionParameters( - model_str=model_str, - tokenizer=tokenizer, - collator=collator, - batch_size=batch_size, - layers=layers, - prompt_suffix=prompt_suffix, - token_loc=token_loc, - use_encoder_states=use_encoder_states, - ) - - # Spawn a process for each GPU - ctx = torch.multiprocessing.spawn( - _extract_hiddens_process, - args=(num_gpus, queue, params), - nprocs=num_gpus, - join=False, - ) - assert ctx is not None - - # Yield results from the queue - for _ in range(len(collator)): - yield queue.get() - - # Clean up - ctx.join() - - -@torch.no_grad() -def _extract_hiddens_process( - rank: int, - world_size: int, - queue: mp.Queue, - params: ExtractionParameters, -): +) -> Iterable[tuple[torch.Tensor, list[int]]]: + """Run inference on a model with a set of prompts, yielding the hidden states. + + Args: + model: The model to run inference on. + tokenizer: The tokenizer to use for tokenization. + collator: The PromptCollator to use for generating prompts. + batch_size: The batch size to use for inference. + layers (Sequence[int]): The layers to extract hidden states from. + prompt_suffix (str): A string to append to the end of each prompt. + token_loc: The location of the token to extract hidden states from. + can be either "first", "last", or "mean". Defaults to "last". + use_encoder_states: Whether to use the encoder states instead of the + decoder states. This allows simplification from an encoder-decoder + model to an encoder-only model. Defaults to False. """ - Do inference on a model with a set of prompts on a single process. - To be passed to Dataset.from_generator. 
- """ - print(f"Process with rank={rank}") - if rank != 0: - logging.getLogger("transformers").setLevel(logging.CRITICAL) - - num_choices = params.collator.num_classes - shards = np.array_split(np.arange(len(params.collator)), world_size) - params.collator.select_(shards[rank]) - - # AutoModel should do the right thing here in nearly all cases. We don't actually - # care what head the model has, since we are just extracting hidden states. - model = AutoModel.from_pretrained(params.model_str, torch_dtype="auto").to( - f"cuda:{rank}" - ) - if params.use_encoder_states and not model.config.is_encoder_decoder: - raise ValueError( - "use_encoder_states is only compatible with encoder-decoder models." - ) + device = model.device + num_choices = collator.num_classes # TODO: Make this configurable or something # Token used to separate the question from the answer - sep_token = params.tokenizer.sep_token or "\n" + sep_token = tokenizer.sep_token or "\n" # TODO: Maybe also make this configurable? # We want to make sure the answer is never truncated - params.tokenizer.truncation_side = "left" - if not params.tokenizer.pad_token: - params.tokenizer.pad_token = params.tokenizer.eos_token + tokenizer.truncation_side = "left" + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token def tokenize(strings: list[str]): return pytree_map( - lambda x: x.to(f"cuda:{rank}"), - params.tokenizer( + lambda x: x.to(device), + tokenizer( strings, padding=True, return_tensors="pt", @@ -131,7 +69,7 @@ def tokenize(strings: list[str]): # each question-answer pair. After inference we need to reshape the results. def collate(prompts: list[Prompt]) -> tuple[BatchEncoding, list[int]]: choices = [ - prompt.to_string(i, sep=sep_token) + params.prompt_suffix + prompt.to_string(i, sep=sep_token) + prompt_suffix for prompt in prompts for i in range(num_choices) ] @@ -145,7 +83,7 @@ def collate_enc_dec( ) tokenized_answers = tokenize( [ - prompt.answers[i] + params.prompt_suffix + prompt.answers[i] + prompt_suffix for prompt in prompts for i in range(num_choices) ] @@ -161,22 +99,22 @@ def reduce_seqs( # Unflatten the hiddens hiddens = [rearrange(h, "(b c) l d -> b c l d", c=num_choices) for h in hiddens] - if params.token_loc == "first": + if token_loc == "first": hiddens = [h[..., 0, :] for h in hiddens] - elif params.token_loc == "last": + elif token_loc == "last": # Because of padding, the last token is going to be at a different index # for each example, so we use gather. B, C, _, D = hiddens[0].shape lengths = attention_mask.sum(dim=-1).view(B, C, 1, 1) indices = lengths.sub(1).expand(B, C, 1, D) hiddens = [h.gather(index=indices, dim=-2).squeeze(-2) for h in hiddens] - elif params.token_loc == "mean": + elif token_loc == "mean": hiddens = [h.mean(dim=-2) for h in hiddens] else: - raise ValueError(f"Invalid token_loc: {params.token_loc}") + raise ValueError(f"Invalid token_loc: {token_loc}") - if params.layers: - hiddens = [hiddens[i] for i in params.layers] + if layers: + hiddens = [hiddens[i] for i in layers] # [batch size, layers, num choices, hidden size] return torch.stack(hiddens, dim=1) @@ -185,7 +123,7 @@ def reduce_seqs( # we don't need to run the decoder at all. Just strip it off, making the problem # equivalent to a regular encoder-only model. 
is_enc_dec = model.config.is_encoder_decoder - if is_enc_dec and params.use_encoder_states: + if is_enc_dec and use_encoder_states: # This isn't actually *guaranteed* by HF, but it's true for all existing models if not hasattr(model, "get_encoder") or not callable(model.get_encoder): raise ValueError( @@ -196,16 +134,17 @@ def reduce_seqs( # Whether to concatenate the question and answer before passing to the model. # If False pass them to the encoder and decoder separately. - should_concat = not is_enc_dec or params.use_encoder_states + should_concat = not is_enc_dec or use_encoder_states dl = DataLoader( - params.collator, - batch_size=params.batch_size, + collator, + batch_size=batch_size, collate_fn=collate if should_concat else collate_enc_dec, ) # Iterating over questions - for batch in dl: + rank = dist.get_rank() if dist.is_initialized() else 0 + for batch in tqdm(dl, position=rank): # Condition 1: Encoder-decoder transformer, with answer in the decoder if not should_concat: questions, answers, labels = batch @@ -215,16 +154,7 @@ def reduce_seqs( output_hidden_states=True, ) # [batch_size, num_layers, num_choices, hidden_size] - # need to convert hidden states to numpy array first or - # you get a ConnectionResetErrror - queue.put( - { - "hiddens": torch.stack(outputs.decoder_hidden_states, dim=2) - .cpu() - .numpy(), - "labels": labels, - } - ) + yield torch.stack(outputs.decoder_hidden_states, dim=2), labels # Condition 2: Either a decoder-only transformer or a transformer encoder else: @@ -232,12 +162,4 @@ def reduce_seqs( # Skip the input embeddings which are unlikely to be interesting h = model(**choices, output_hidden_states=True).hidden_states[1:] - - # need to convert hidden states to numpy array first or - # you get a ConnectionResetErrror - queue.put( - { - "hiddens": reduce_seqs(h, choices["attention_mask"]).cpu().numpy(), - "labels": labels, - } - ) + yield reduce_seqs(h, choices["attention_mask"]), labels diff --git a/elk/extraction/extraction_main.py b/elk/extraction/extraction_main.py index 64111b87..390c9fc3 100644 --- a/elk/extraction/extraction_main.py +++ b/elk/extraction/extraction_main.py @@ -3,9 +3,11 @@ from .extraction import extract_hiddens, PromptCollator from ..files import args_to_uuid, elk_cache_dir from ..training.preprocessing import silence_datasets_messages -from transformers import AutoConfig, AutoTokenizer +from ..utils import maybe_all_gather, maybe_barrier, select_usable_gpus +from transformers import AutoModel, AutoTokenizer import json -from datasets import Dataset +import torch +import torch.distributed as dist def run(args): @@ -43,17 +45,49 @@ def extract(args, split: str): else: raise ValueError(f"Unknown prompt strategy: {args.prompts}") - return Dataset.from_generator( - extract_hiddens, - gen_kwargs={ - "model_str": args.model, - "tokenizer": tokenizer, - "collator": collator, - "layers": args.layers, - "prompt_suffix": args.prompt_suffix, - "token_loc": args.token_loc, - "use_encoder_states": args.use_encoder_states, - }, + items = [ + (features, labels) + for features, labels in extract_hiddens( + model, + tokenizer, + collator, + layers=args.layers, + prompt_suffix=args.prompt_suffix, + token_loc=args.token_loc, + use_encoder_states=args.use_encoder_states, + ) + ] + save_dir.mkdir(parents=True, exist_ok=True) + + with open(save_dir / f"{split}_hiddens.pt", "wb") as f: + hidden_batches, label_batches = zip(*items) + hiddens = maybe_all_gather(torch.cat(hidden_batches)) # type: ignore + + # Moving labels to GPU just to be able to use 
maybe_all_gather + labels = torch.tensor(sum(label_batches, []), device=hiddens.device) + labels = maybe_all_gather(labels) # type: ignore + + if not dist.is_initialized() or dist.get_rank() == 0: + torch.save((hiddens.cpu(), labels.cpu()), f) + + # AutoModel should do the right thing here in nearly all cases. We don't actually + # care what head the model has, since we are just extracting hidden states. + print(f"Loading model '{args.model}'...") + model = AutoModel.from_pretrained(args.model, torch_dtype="auto") + print(f"Done. Model class: '{model.__class__.__name__}'") + + # Intelligently select a GPU with enough memory + if dist.is_initialized(): + model.to(f"cuda:{dist.get_rank()}") + elif torch.cuda.is_available(): + # We at least need enough VRAM to hold the model parameters + min_memory = sum(p.element_size() * p.numel() for p in model.parameters()) + (device_idx,) = select_usable_gpus(max_gpus=1, min_memory=min_memory) + model.to(f"cuda:{device_idx}") + + if args.use_encoder_states and not model.config.is_encoder_decoder: + raise ValueError( + "--use_encoder_states is only compatible with encoder-decoder models." ) print("Loading tokenizer...") @@ -69,14 +103,14 @@ def extract(args, split: str): print("Loading datasets") silence_datasets_messages() - train_dset = extract(args, "train") - valid_dset = extract(args, "validation") + maybe_barrier() # Not strictly necessary but makes the output cleaner + extract(args, "train") + maybe_barrier() + extract(args, "validation") - with open(save_dir / "args.json", "w") as f: - json.dump(vars(args), f) + if not dist.is_initialized() or dist.get_rank() == 0: + with open(save_dir / "args.json", "w") as f: + json.dump(vars(args), f) - with open(save_dir / "model_config.json", "w") as f: - config = AutoConfig.from_pretrained(args.model) - json.dump(config.to_dict(), f) - - return train_dset, valid_dset + with open(save_dir / "model_config.json", "w") as f: + json.dump(model.config.to_dict(), f) diff --git a/elk/extraction/prompt_collator.py b/elk/extraction/prompt_collator.py index 912ba0cb..c50cc9c5 100644 --- a/elk/extraction/prompt_collator.py +++ b/elk/extraction/prompt_collator.py @@ -13,8 +13,9 @@ from promptsource.templates import DatasetTemplates from random import Random from torch.utils.data import Dataset as TorchDataset -from typing import Literal, Optional +from typing import Literal, Optional, cast import numpy as np +import torch.distributed as dist @dataclass @@ -101,7 +102,6 @@ def __init__( ds_dict = ds_dict[split_name].train_test_split( seed=seed, shuffle=False, stratify_by_column=label_column ) - assert isinstance(ds_dict, DatasetDict) # Lots of datasets have a validation split or a test split, but not both. If # the requested split doesn't exist, we try to use the other one instead. 
@@ -166,33 +166,37 @@ def __init__( print(f"Undersampling classes to {smallest_size} examples each") # First group the active split by class - strata = ( + strata = [ self.active_split.filter(lambda ex: ex[label_column] == i) for i in range(self.num_classes) - ) + ] # Then randomly sample `smallest_size` examples from each class and merge - undersampled = concatenate_datasets( - [ - stratum.select( - self.rng.sample(range(len(stratum)), k=smallest_size) - ) - for stratum in strata - ] + self.active_split = cast( + Dataset, + concatenate_datasets( + [ + stratum.select( + self.rng.sample(range(len(stratum)), k=smallest_size) + ) + for stratum in strata + ] + ), ) - assert isinstance(undersampled, Dataset) - self.active_split = undersampled # Sanity check that we successfully balanced the classes class_sizes = np.bincount( - list(self.active_split[label_column]), minlength=self.num_classes + self.active_split[label_column], minlength=self.num_classes ) assert np.all(class_sizes == smallest_size) # Store the (possibly post-undersampling) empirical class balance for later self.class_fracs: NDArray[np.floating] = class_sizes / class_sizes.sum() - if self.num_classes < 2: - raise ValueError(f"Dataset {path}/{name} has only one label") + # Shard across ranks iff we're in a distributed setting + if dist.is_initialized(): + self.active_split = self.active_split.shard( + dist.get_world_size(), dist.get_rank() + ) # We use stratified sampling to create few-shot prompts that are as balanced as # possible. If needed, create the strata now so that we can use them later. @@ -216,7 +220,7 @@ def __init__( for i in range(self.num_classes) ] else: - self.fewshot_strata = [] + self.fewshot_strata: list[Dataset] = [] # Now shuffle the active split and truncate it if needed self.active_split = self.active_split.shuffle(seed=seed) @@ -289,6 +293,3 @@ def num_classes(self) -> int: # We piggyback on the ClassLabel feature type to get the number of classes return self.active_split.features[self.label_column].num_classes - - def select_(self, indices): - self.dataset = self.dataset.select(indices)
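
A note on the token_loc == "last" branch of reduce_seqs() in extraction.py: because sequences are right-padded, the last real token sits at a different position in every row, so the code gathers the hidden state at index (length - 1) along the sequence dimension. Below is a small self-contained illustration of that gather; the shapes, lengths, and toy mask are made up for the example, and in the patch the attention mask is flat over batch * choices before the .view().

    import torch

    B, C, L, D = 2, 3, 5, 4          # batch, num choices, seq len, hidden size
    h = torch.randn(B, C, L, D)      # one layer's hiddens, already unflattened to [B, C, L, D]

    # Toy mask: each sequence has some number of real tokens followed by padding.
    true_lengths = torch.tensor([[5, 3, 4], [2, 5, 1]])            # [B, C]
    mask = (torch.arange(L) < true_lengths[..., None]).long()      # [B, C, L]

    lengths = mask.sum(dim=-1).view(B, C, 1, 1)    # real-token count per sequence
    indices = lengths.sub(1).expand(B, C, 1, D)    # index of the last real token, broadcast over D
    last = h.gather(index=indices, dim=-2).squeeze(-2)             # [B, C, D]

    # Slow reference: pick the hidden state at position length - 1 explicitly.
    ref = torch.stack([
        torch.stack([h[b, c, true_lengths[b, c] - 1] for c in range(C)])
        for b in range(B)
    ])
    assert torch.allclose(last, ref)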
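
extraction_main.py now imports maybe_barrier, maybe_all_gather, and select_usable_gpus from elk.utils, and __main__.py switches into distributed mode whenever LOCAL_RANK is set in the environment, i.e. when the module is launched with torchrun (something like "torchrun --nproc_per_node=<num_gpus> -m elk elicit ..."). Those helpers are not part of this diff, so the sketch below is only an illustrative assumption of their shape, not the project's actual implementation; it also assumes every rank contributes a tensor of the same size, whereas the real helper presumably handles uneven shards (e.g. when the dataset length is not divisible by the world size).

    import torch
    import torch.distributed as dist

    def maybe_barrier() -> None:
        # Synchronize all ranks; a no-op outside a distributed run.
        if dist.is_initialized():
            dist.barrier()

    def maybe_all_gather(x: torch.Tensor) -> torch.Tensor:
        # Concatenate x across ranks along dim 0; identity when not distributed.
        if not dist.is_initialized():
            return x
        buffers = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(buffers, x)   # with the NCCL backend this requires CUDA tensors
        return torch.cat(buffers, dim=0)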
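
PromptCollator now shards its active split across ranks with Dataset.shard(world_size, rank) from the datasets library, so each process only builds prompts for its own slice. A toy illustration of what each rank would see (the dataset contents here are made up):

    from datasets import Dataset

    ds = Dataset.from_dict({"idx": list(range(10))})

    # With contiguous=False the rows are interleaved: rank r of w gets rows r, r + w, r + 2w, ...
    print(ds.shard(2, 0, contiguous=False)["idx"])   # [0, 2, 4, 6, 8]
    print(ds.shard(2, 1, contiguous=False)["idx"])   # [1, 3, 5, 7, 9]

Because the labels are gathered in the same order as the hidden states in extraction_main.py, the exact sharding pattern does not affect the pairing in the saved (hiddens, labels) tuple.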