17 changes: 11 additions & 6 deletions run_evaluation.py
@@ -17,24 +17,25 @@
 def get_dataset(dataset: str, tokenizer: PreTrainedTokenizerBase) -> (pd.DataFrame, pd.DataFrame, List):
     da = DATASET_NAMES2LOADERS[dataset]()
     # Filter extremely long samples from both train and test samples:
-    _logger.info("filtering test set:")
+    _logger.info(f"Filtering test set for dataset: {dataset}")
     test_df = filter_extremely_long_samples(da.test_df, tokenizer)
-    _logger.info("filtering train set:")
+    _logger.info(f"Filtering train set for dataset: {dataset}")
     train_df = filter_extremely_long_samples(da.train_df, tokenizer)
     return test_df, train_df, da.labels


 def run_pcw_experiment(dataset: str, model: str, cache_dir: str, subsample_test_set: int, output_dir: str,
                        n_windows: List[int], n_shots_per_window: Optional[int], n_runs: int,
                        random_seed: int, right_indentation: bool) -> None:
+    _logger.info(f'Starting experiment with dataset: {dataset} and model: {model}')
     pcw_model = load_pcw_wrapper(model, cache_dir, right_indentation, max(n_windows))

     test_df, train_df, labels = get_dataset(dataset, pcw_model.tokenizer)

-    if n_shots_per_window is None:
-        # default behaviour: we take the maximum number of samples per window
+    if n_shots_per_window is None:
+        # Default behaviour: we take the maximum number of samples per window
         n_shots_per_window = get_max_n_shots(train_df, test_df, pcw_model.tokenizer, pcw_model.context_window_size)
-        _logger.info(f"Found max n shot per window = {n_shots_per_window}")
+        _logger.info(f"Found max n shot per window = {n_shots_per_window} for dataset: {dataset}")

     n_shots = [i * n_shots_per_window for i in n_windows]
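For reference, the unchanged line n_shots = [i * n_shots_per_window for i in n_windows] gives the total number of demonstrations used at each window setting. A minimal sketch with assumed values (the numbers below are illustrative, not taken from this PR):

    # Hypothetical values, for illustration only.
    n_windows = [1, 2, 4]          # how many parallel context windows to use per run
    n_shots_per_window = 5         # in-context examples packed into each window
    n_shots = [i * n_shots_per_window for i in n_windows]
    print(n_shots)                 # [5, 10, 20] -- total demonstrations per setting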

@@ -43,6 +44,8 @@ def run_pcw_experiment(dataset: str, model: str, cache_dir: str, subsample_test_

     accuracies = em.run_experiment_across_shots(n_shots, n_runs)
     save_results(dataset, n_shots, accuracies, output_dir, model)
+    _logger.info(f'Experiment completed for dataset: {dataset} and model: {model}')


 if __name__ == '__main__':
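The added messages go through the module-level _logger; a minimal sketch of a logging configuration under which they would actually be emitted (the basicConfig call and the sample values are assumptions, not part of this repository):

    import logging

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")
    _logger = logging.getLogger(__name__)

    dataset, model = "sst2", "gpt2"   # hypothetical values, for illustration only
    _logger.info(f"Starting experiment with dataset: {dataset} and model: {model}")
    # Emits something like: 2024-01-01 12:00:00 __main__ INFO: Starting experiment with dataset: sst2 and model: gpt2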
@@ -69,3 +72,5 @@ def run_pcw_experiment(dataset: str, model: str, cache_dir: str, subsample_test_
                         action='store_true', default=False)
     args = parser.parse_args()
     run_pcw_experiment(**vars(args))
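For context, run_pcw_experiment(**vars(args)) forwards every parsed CLI option as a keyword argument. A minimal, self-contained sketch of the same pattern (the flags and the reduced signature below are hypothetical, not the script's real argument list):

    import argparse

    def run_pcw_experiment(dataset: str, model: str, n_runs: int) -> None:
        # Reduced signature, for illustration only.
        print(f"dataset={dataset}, model={model}, n_runs={n_runs}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='sst2')
    parser.add_argument('--model', default='gpt2')
    parser.add_argument('--n-runs', type=int, default=1)   # argparse stores this as args.n_runs
    args = parser.parse_args([])                            # empty list -> use the defaults above
    run_pcw_experiment(**vars(args))                        # vars(args) == {'dataset': 'sst2', 'model': 'gpt2', 'n_runs': 1}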