diff --git a/run_evaluation.py b/run_evaluation.py index d0cec2f..5c6b3db 100644 --- a/run_evaluation.py +++ b/run_evaluation.py @@ -17,24 +17,25 @@ def get_dataset(dataset: str, tokenizer: PreTrainedTokenizerBase) -> (pd.DataFrame, pd.DataFrame, List): da = DATASET_NAMES2LOADERS[dataset]() # Filter extremely long samples from both train and test samples: - _logger.info("filtering test set:") + _logger.info(f"Filtering test set for dataset: {dataset}") test_df = filter_extremely_long_samples(da.test_df, tokenizer) - _logger.info("filtering train set:") + _logger.info(f"Filtering train set for dataset: {dataset}") train_df = filter_extremely_long_samples(da.train_df, tokenizer) return test_df, train_df, da.labels def run_pcw_experiment(dataset: str, model: str, cache_dir: str, subsample_test_set: int, output_dir: str, n_windows: List[int], n_shots_per_window: Optional[int], n_runs: int, - random_seed: int, right_indentation: bool) -> None: + random_seed: int, right_indentation: bool) -> None: + _logger.info(f'Starting experiment with dataset: {dataset} and model: {model}') pcw_model = load_pcw_wrapper(model, cache_dir, right_indentation, max(n_windows)) test_df, train_df, labels = get_dataset(dataset, pcw_model.tokenizer) - if n_shots_per_window is None: - # default behaviour: we take the maximum number of samples per window + if n_shots_per_window is None: + # Default behaviour: we take the maximum number of samples per window n_shots_per_window = get_max_n_shots(train_df, test_df, pcw_model.tokenizer, pcw_model.context_window_size) - _logger.info(f"Found max n shot per window = {n_shots_per_window}") + _logger.info(f"Found max n shot per window = {n_shots_per_window} for dataset: {dataset}") n_shots = [i * n_shots_per_window for i in n_windows] @@ -43,6 +44,7 @@ def run_pcw_experiment(dataset: str, model: str, cache_dir: str, subsample_test_ accuracies = em.run_experiment_across_shots(n_shots, n_runs) save_results(dataset, n_shots, accuracies, output_dir, model) 
+ _logger.info(f'Experiment completed for dataset: {dataset} and model: {model}') if __name__ == '__main__': @@ -69,3 +71,4 @@ def run_pcw_experiment(dataset: str, model: str, cache_dir: str, subsample_test_ action='store_true', default=False) args = parser.parse_args() run_pcw_experiment(**vars(args)) +