diff --git a/examples/evaluate_model.py b/examples/evaluate_model.py
index 7e7f9410..d2bff821 100644
--- a/examples/evaluate_model.py
+++ b/examples/evaluate_model.py
@@ -18,6 +18,8 @@
     load_model,
     print_colored,
     string_to_float,
+    DatasetInfo,
+    load_json_file_dataset,
 )
 
 # Local
@@ -49,7 +51,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--dataset",
         help="Dataset to use to train prompt vectors. Options: {}".format(
-            list(SUPPORTED_DATASETS.keys())
+            list(SUPPORTED_DATASETS.keys()) + ["json_file"]
         ),
         default="twitter_complaints",
     )
@@ -64,6 +66,46 @@ def parse_args() -> argparse.Namespace:
         help="JSON file to dump raw source / target texts to.",
         default="model_preds.json",
     )
+
+    parser.add_argument(
+        "--json_file_path",
+        help="File path of the local JSON file to load as the dataset; it should be a JSON array or in JSONL format.",
+        default=None,
+        type=pathlib.Path,
+    )
+    parser.add_argument(
+        "--json_file_input_field",
+        help="The input field to use from the JSON file dataset",
+        default="input",
+    )
+    parser.add_argument(
+        "--json_file_output_field",
+        help="The output field to use from the JSON file dataset",
+        default="output",
+    )
+    parser.add_argument(
+        "--json_file_init_text",
+        help="The init text to be used by the JSON file dataset",
+        default="",
+    )
+    parser.add_argument(
+        "--json_file_verbalizer",
+        help="The custom verbalizer to be used by the JSON file dataset",
+        default="{{input}}",
+    )
+    parser.add_argument(
+        "--json_file_test_size",
+        help="The fraction of the dataset to hold out as the test set",
+        default=0.1,
+        type=float,
+    )
+    parser.add_argument(
+        "--json_file_validation_size",
+        help="The fraction of the remaining train data to hold out as the validation set",
+        default=0.1,
+        type=float,
+    )
+
     args = parser.parse_args()
     return args
 
@@ -89,10 +131,11 @@ def get_model_preds_and_references(model, validation_stream):
     for datum in tqdm(validation_stream):
         # Local .run() currently prepends the input text to the generated string;
         # Ensure that we're just splitting the first predicted token & beyond.
-        raw_model_text = model.run(datum.input).text
-        parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
-        model_preds.append(parse_pred_text)
-        targets.append(datum.output)
+        if len(datum.input) > 0:
+            raw_model_text = model.run(datum.input).text
+            parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
+            model_preds.append(parse_pred_text)
+            targets.append(datum.output)
     return (
         model_preds,
         targets,
@@ -148,8 +191,25 @@ def export_model_preds(preds_file, predictions, validation_stream, verbalizer):
     model = load_model(args.tgis, str(args.model_path))
     # Load the validation stream with marked target sequences
     print_colored("Grabbing validation data...")
-    dataset_info = SUPPORTED_DATASETS[args.dataset]
-    validation_stream = dataset_info.dataset_loader()[1]
+    if args.dataset != "json_file":
+        dataset_info = SUPPORTED_DATASETS[args.dataset]
+        validation_stream = dataset_info.dataset_loader()[1]
+    else:
+        # For json_file, the loader is invoked eagerly, so dataset_loader holds
+        # the (train, validation, test) stream tuple rather than a callable.
+        dataset_info = DatasetInfo(
+            verbalizer=args.json_file_verbalizer,
+            dataset_loader=load_json_file_dataset(
+                str(args.json_file_path),
+                str(args.json_file_input_field),
+                str(args.json_file_output_field),
+                test_size=args.json_file_test_size,
+                validation_size=args.json_file_validation_size,
+            ),
+            init_text=args.json_file_init_text,
+        )
+        validation_stream = dataset_info.dataset_loader[1]
+
     if validation_stream is None:
         raise ValueError(
             "Selected dataset does not have a validation dataset available!"
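
Note on the expected file format: the new --json_file_* flags assume a local file where each record carries an input and an output field (names configurable via --json_file_input_field / --json_file_output_field). A minimal sketch of producing a compatible JSONL file; the path and record contents here are hypothetical:

    # Hypothetical helper: write a tiny JSONL dataset for the new --json_file_path flag.
    # Field names match the "input" / "output" defaults used by both scripts.
    import json

    records = [
        {"input": "Tweet text: my order never arrived. Label:", "output": "complaint"},
        {"input": "Tweet text: great service today! Label:", "output": "no complaint"},
    ]
    with open("my_dataset.jsonl", "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")

A single JSON array of the same objects should load as well, since the loader added in utils.py below delegates to datasets.load_dataset("json", ...), which accepts both layouts.
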
diff --git a/examples/run_peft_tuning.py b/examples/run_peft_tuning.py
index 44d7d28d..53221516 100644
--- a/examples/run_peft_tuning.py
+++ b/examples/run_peft_tuning.py
@@ -26,6 +26,7 @@
     SUPPORTED_DATASETS,
     DatasetInfo,
     configure_random_seed_and_logging,
+    load_json_file_dataset,
     print_colored,
 )
 import datasets
@@ -166,7 +167,7 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
         subparser.add_argument(
             "--dataset",
             help="Dataset to use to train prompt vectors. Options: {}".format(
-                list(SUPPORTED_DATASETS.keys())
+                list(SUPPORTED_DATASETS.keys()) + ["json_file"]
             ),
             default="twitter_complaints",
         )
@@ -235,7 +236,44 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
             default=1,
             type=int,
         )
-
+        subparser.add_argument(
+            "--json_file_path",
+            help="File path of the local JSON file to load as the dataset; it should be a JSON array or in JSONL format.",
+            default=None,
+            type=pathlib.Path,
+        )
+        subparser.add_argument(
+            "--json_file_input_field",
+            help="The input field to use from the JSON file dataset",
+            default="input",
+        )
+        subparser.add_argument(
+            "--json_file_output_field",
+            help="The output field to use from the JSON file dataset",
+            default="output",
+        )
+        subparser.add_argument(
+            "--json_file_init_text",
+            help="The init text to be used by the JSON file dataset",
+            default="",
+        )
+        subparser.add_argument(
+            "--json_file_verbalizer",
+            help="The custom verbalizer to be used by the JSON file dataset",
+            default="{{input}}",
+        )
+        subparser.add_argument(
+            "--json_file_test_size",
+            help="The fraction of the dataset to hold out as the test set",
+            default=0.1,
+            type=float,
+        )
+        subparser.add_argument(
+            "--json_file_validation_size",
+            help="The fraction of the remaining train data to hold out as the validation set",
+            default=0.1,
+            type=float,
+        )
 
 def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser):
     """Register additional configuration options for MP(rompt)T subtask.
@@ -281,12 +319,16 @@ def validate_common_args(args: argparse.Namespace):
         Parsed args corresponding to one tuning task.
     """
     # Validate that the dataset is one of our allowed values
-    if args.dataset not in SUPPORTED_DATASETS:
+    if args.dataset != "json_file" and args.dataset not in SUPPORTED_DATASETS:
         raise KeyError(
             "[{}] is not a supported dataset; see --help for options.".format(
                 args.dataset
             )
         )
+    if args.dataset == "json_file" and args.json_file_path is None:
+        raise argparse.ArgumentError(
+            None, "--json_file_path is required when dataset value is json_file."
+        )
     # Purge our output directory if one already exists.
     if os.path.isdir(args.output_dir):
         print("Existing model directory found; purging it now.")
@@ -378,13 +420,32 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
     configure_random_seed_and_logging()
     args = parse_args()
     model_type = get_resource_type(args.model_name)
-    # Unpack the dataset dictionary into a loaded dataset & verbalizer
-    dataset_info = SUPPORTED_DATASETS[args.dataset]
-    show_experiment_configuration(args, dataset_info, model_type)
     # Convert the loaded dataset to a stream
     print_colored("[Loading the dataset...]")
-    # TODO - conditionally enable validation stream
-    train_stream = dataset_info.dataset_loader()[0]
+    # Unpack the dataset dictionary into a loaded dataset & verbalizer
+    if args.dataset != "json_file":
+        dataset_info = SUPPORTED_DATASETS[args.dataset]
+        # TODO - conditionally enable validation stream
+        train_stream = dataset_info.dataset_loader()[0]
+    else:
+        print_colored(
+            "[Loading JSON file dataset from: {}]".format(args.json_file_path)
+        )
+        dataset_info = DatasetInfo(
+            verbalizer=args.json_file_verbalizer,
+            dataset_loader=load_json_file_dataset(
+                str(args.json_file_path),
+                str(args.json_file_input_field),
+                str(args.json_file_output_field),
+                test_size=args.json_file_test_size,
+                validation_size=args.json_file_validation_size,
+            ),
+            init_text=args.json_file_init_text,
+        )
+        # The json_file loader is invoked eagerly, so dataset_loader holds the
+        # (train, validation, test) stream tuple rather than a callable.
+        train_stream = dataset_info.dataset_loader[0]
+    show_experiment_configuration(args, dataset_info, model_type)
     if args.num_shots is not None:
         train_stream = subsample_stream(train_stream, args.num_shots)
     # Init the resource & Build the tuning config from our dataset/arg info
diff --git a/examples/utils.py b/examples/utils.py
index cfb944fd..fc1506a2 100644
--- a/examples/utils.py
+++ b/examples/utils.py
@@ -280,6 +280,37 @@ def to_generation_fmt(x):
     return (train_stream, validation_stream, test_stream)
 
 
+def load_json_file_dataset(
+    file_path, input_field, output_field, test_size=0.1, validation_size=0.1
+) -> Tuple[caikit.core.data_model.DataStream]:
+    """Load a dataset from a local JSON array or JSONL file."""
+
+    def to_generation_fmt(x):
+        return GenerationTrainRecord(input=x[input_field], output=str(x[output_field]))
+
+    dataset = datasets.load_dataset("json", data_files=file_path)
+    if test_size > 0:
+        train_test_dataset = dataset["train"].train_test_split(test_size=test_size, shuffle=False)
+    else:
+        train_test_dataset = dataset
+        train_test_dataset["test"] = []
+    # Split the validation set off of the remaining train data
+    if validation_size > 0:
+        test_valid = train_test_dataset["train"].train_test_split(test_size=validation_size, shuffle=False)
+        train_test_dataset["train"] = test_valid["train"]
+        train_test_dataset["validation"] = test_valid["test"]
+    else:
+        train_test_dataset["validation"] = []
+
+    build_stream = lambda split: caikit.core.data_model.DataStream.from_iterable(
+        [to_generation_fmt(datum) for datum in train_test_dataset[split]]
+    )
+    train_stream = build_stream("train")
+    validation_stream = build_stream("validation")
+    test_stream = build_stream("test")
+    return (train_stream, validation_stream, test_stream)
+
+
 def get_wrapped_evaluate_metric(metric_name: str, convert_to_numeric: bool) -> Callable:
     """Wrapper for running metrics out of evaluate which operate on numeric
     arrays named predictions & references, respectively.
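
A note on split semantics in load_json_file_dataset above: the test split is carved off the full dataset first, and the validation split is then carved off the remaining train data, so the effective validation fraction is relative to what remains after the test split. An illustrative sketch of the arithmetic; the record counts are hypothetical, and exact sizes follow the rounding behavior of datasets' train_test_split:

    # Illustrative arithmetic for the two-stage split in load_json_file_dataset.
    n = 1000                                           # hypothetical dataset size
    test_size = 0.1                                    # --json_file_test_size
    validation_size = 0.1                              # --json_file_validation_size

    n_test = int(n * test_size)                        # 100 records -> test
    n_remaining = n - n_test                           # 900 records remain
    n_validation = int(n_remaining * validation_size)  # 90 records -> validation
    n_train = n_remaining - n_validation               # 810 records -> train

Because shuffle=False is passed to both splits, the test and validation records should come deterministically from the tail of the file rather than being sampled at random.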
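
Finally, a minimal usage sketch for the new loader on its own, assuming a hypothetical my_dataset.jsonl with "input" / "output" fields and that examples/utils.py is importable:

    # Minimal sketch: load a local JSONL file and iterate the resulting streams.
    # The file path and field names here are assumptions, not fixed API values.
    from utils import load_json_file_dataset

    train_stream, validation_stream, test_stream = load_json_file_dataset(
        "my_dataset.jsonl",
        "input",
        "output",
        test_size=0.1,
        validation_size=0.1,
    )
    for record in train_stream:
        print(record.input, "->", record.output)

Each stream yields GenerationTrainRecord instances, which is the same record type the tuning and evaluation scripts consume.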