72 changes: 65 additions & 7 deletions examples/evaluate_model.py
@@ -18,6 +18,8 @@
load_model,
print_colored,
string_to_float,
DatasetInfo,
load_json_file_dataset,
)

# Local
@@ -49,7 +51,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--dataset",
help="Dataset to use to train prompt vectors. Options: {}".format(
list(SUPPORTED_DATASETS.keys())
list(SUPPORTED_DATASETS.keys()) + ["json_file"]
),
default="twitter_complaints",
)
@@ -64,6 +66,46 @@ def parse_args() -> argparse.Namespace:
help="JSON file to dump raw source / target texts to.",
default="model_preds.json",
)

parser.add_argument(
"--json_file_path",
help="File path of the local JSON file to be loaded as dataset. It should be a JSON array or JSONL format.",
default=None,
type=pathlib.Path,
)
parser.add_argument(
"--json_file_input_field",
help="The input field to be used at the JSON file dataset",
default="input",
)
parser.add_argument(
"--json_file_output_field",
help="The output field to be used at the JSON file dataset",
default="output",
)
parser.add_argument(
"--json_file_init_text",
help="The init text to be used by the JSON file dataset",
default="",
)
parser.add_argument(
"--json_file_verbalizer",
help="The custom verbalizer to be used by the JSON file dataset",
default="{{input}}",
)
parser.add_argument(
"--json_file_test_size",
help="The percentage of the dataset that will be extracted as test set",
default=0.1,
type=float
)
parser.add_argument(
"--json_file_validation_size",
help="The percentage of the dataset that will be extracted as validation set",
default=0.1,
type=float
)

args = parser.parse_args()
return args

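For context, a hypothetical invocation exercising the new flags might look like the following (complaints.jsonl and its text/label field names are illustrative, not part of this PR):

python examples/evaluate_model.py \
    --dataset json_file \
    --json_file_path complaints.jsonl \
    --json_file_input_field text \
    --json_file_output_field label
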
@@ -89,10 +131,11 @@ def get_model_preds_and_references(model, validation_stream):
for datum in tqdm(validation_stream):
# Local .run() currently prepends the input text to the generated string;
# Ensure that we're just splitting the first predicted token & beyond.
raw_model_text = model.run(datum.input).text
parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
model_preds.append(parse_pred_text)
targets.append(datum.output)
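# Skip records with empty input text.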
if len(datum.input) > 0:
raw_model_text = model.run(datum.input).text
parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
model_preds.append(parse_pred_text)
targets.append(datum.output)
return (
model_preds,
targets,
@@ -148,8 +191,23 @@ def export_model_preds(preds_file, predictions, validation_stream, verbalizer):
model = load_model(args.tgis, str(args.model_path))
# Load the validation stream with marked target sequences
print_colored("Grabbing validation data...")
dataset_info = SUPPORTED_DATASETS[args.dataset]
validation_stream = dataset_info.dataset_loader()[1]
if args.dataset != "json_file":
dataset_info = SUPPORTED_DATASETS[args.dataset]
validation_stream = dataset_info.dataset_loader()[1]
else:
dataset_info = DatasetInfo(
verbalizer=args.json_file_verbalizer,
dataset_loader=load_json_file_dataset(
str(args.json_file_path),
str(args.json_file_input_field),
str(args.json_file_output_field),
test_size=args.json_file_test_size,
validation_size=args.json_file_validation_size
),
init_text=args.json_file_init_text,
)
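# Note: load_json_file_dataset() was already invoked when building DatasetInfo above,
# so dataset_loader holds the (train, validation, test) tuple and is indexed rather than called.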
validation_stream = dataset_info.dataset_loader[1]

if validation_stream is None:
raise ValueError(
"Selected dataset does not have a validation dataset available!"
75 changes: 67 additions & 8 deletions examples/run_peft_tuning.py
@@ -26,6 +26,7 @@
SUPPORTED_DATASETS,
DatasetInfo,
configure_random_seed_and_logging,
load_json_file_dataset,
print_colored,
)
import datasets
@@ -166,7 +167,7 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
subparser.add_argument(
"--dataset",
help="Dataset to use to train prompt vectors. Options: {}".format(
list(SUPPORTED_DATASETS.keys())
list(SUPPORTED_DATASETS.keys()) + ["json_file"]
),
default="twitter_complaints",
)
@@ -235,7 +236,44 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
default=1,
type=int,
)

subparser.add_argument(
"--json_file_path",
help="File path of the local JSON file to be loaded as dataset. It should be a JSON array or JSONL format.",
default=None,
type=pathlib.Path,
)
subparser.add_argument(
"--json_file_input_field",
help="The input field to be used at the JSON file dataset",
default="input",
)
subparser.add_argument(
"--json_file_output_field",
help="The output field to be used at the JSON file dataset",
default="output",
)
subparser.add_argument(
"--json_file_init_text",
help="The init text to be used by the JSON file dataset",
default="",
)
subparser.add_argument(
"--json_file_verbalizer",
help="The custom verbalizer to be used by the JSON file dataset",
default="{{input}}",
)
subparser.add_argument(
"--json_file_test_size",
help="The percentage of the dataset that will be extracted as test set",
default=0.1,
type=float
)
subparser.add_argument(
"--json_file_validation_size",
help="The percentage of the dataset that will be extracted as validation set",
default=0.1,
type=float
)

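The same json_file flags are registered for every tuning subcommand; a hypothetical invocation might look like the following (the subcommand placeholder and data file are illustrative, since the subcommand names are not shown in this diff):

python examples/run_peft_tuning.py <tuning_subcommand> \
    --dataset json_file \
    --json_file_path complaints.jsonl \
    --json_file_input_field text \
    --json_file_output_field label
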
def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser):
"""Register additional configuration options for MP(rompt)T subtask.
@@ -281,12 +319,16 @@ def validate_common_args(args: argparse.Namespace):
Parsed args corresponding to one tuning task.
"""
# Validate that the dataset is one of our allowed values
if args.dataset not in SUPPORTED_DATASETS:
if args.dataset != "json_file" and args.dataset not in SUPPORTED_DATASETS:
raise KeyError(
"[{}] is not a supported dataset; see --help for options.".format(
args.dataset
)
)
if args.dataset == "json_file" and args.json_file_path is None:
raise argparse.ArgumentError(
None, "--json_file_path is required when dataset value is json_file."
)
# Purge our output directory if one already exists.
if os.path.isdir(args.output_dir):
print("Existing model directory found; purging it now.")
@@ -378,13 +420,30 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
configure_random_seed_and_logging()
args = parse_args()
model_type = get_resource_type(args.model_name)
# Unpack the dataset dictionary into a loaded dataset & verbalizer
dataset_info = SUPPORTED_DATASETS[args.dataset]
show_experiment_configuration(args, dataset_info, model_type)
# Convert the loaded dataset to a stream
print_colored("[Loading the dataset...]")
# TODO - conditionally enable validation stream
train_stream = dataset_info.dataset_loader()[0]
# Unpack the dataset dictionary into a loaded dataset & verbalizer
if args.dataset != "json_file":
dataset_info = SUPPORTED_DATASETS[args.dataset]
# TODO - conditionally enable validation stream
train_stream = dataset_info.dataset_loader()[0]
else:
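# Echo the JSON file dataset configuration for visibility.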
print(
args.json_file_path, args.json_file_input_field, args.json_file_output_field
)
dataset_info = DatasetInfo(
verbalizer=args.json_file_verbalizer,
dataset_loader=load_json_file_dataset(
str(args.json_file_path),
str(args.json_file_input_field),
str(args.json_file_output_field),
test_size=args.json_file_test_size,
validation_size=args.json_file_validation_size
),
init_text=args.json_file_init_text,
)
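# As in evaluate_model.py, dataset_loader here is the already-loaded
# (train, validation, test) tuple, so it is indexed rather than called.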
train_stream = dataset_info.dataset_loader[0]
show_experiment_configuration(args, dataset_info, model_type)
if args.num_shots is not None:
train_stream = subsample_stream(train_stream, args.num_shots)
# Init the resource & Build the tuning config from our dataset/arg info
31 changes: 31 additions & 0 deletions examples/utils.py
@@ -280,6 +280,37 @@ def to_generation_fmt(x):
return (train_stream, validation_stream, test_stream)


def load_json_file_dataset(
file_path, input_field, output_field, test_size=0.1, validation_size=0.1
) -> Tuple[caikit.core.data_model.DataStream]:
"""Load the dataset from local JSON file."""

def to_generation_fmt(x):
return GenerationTrainRecord(input=x[input_field], output=str(x[output_field]))

dataset = datasets.load_dataset("json", data_files=file_path)
if test_size > 0:
train_test_dataset = dataset["train"].train_test_split(test_size=test_size, shuffle=False)
else:
train_test_dataset = dataset
train_test_dataset["test"] = []
# Carve the validation split out of the remaining training data; validation_size
# is applied after the test split (0.1 of the remaining 90% is 9% of the full dataset).
if validation_size > 0:
test_valid = train_test_dataset["train"].train_test_split(test_size=validation_size, shuffle=False)
train_test_dataset["train"] = test_valid["train"]
train_test_dataset["validation"] = test_valid["test"]
else:
train_test_dataset["validation"] = []

build_stream = lambda split: caikit.core.data_model.DataStream.from_iterable(
[to_generation_fmt(datum) for datum in train_test_dataset[split]]
)
train_stream = build_stream("train")
validation_stream = build_stream("validation")
test_stream = build_stream("test")
return (train_stream, validation_stream, test_stream)


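For reference, a minimal sketch of how the loader might be exercised (the data.jsonl path and field names are illustrative, not part of this PR). With test_size=0.1 and validation_size=0.1, 10% of the records are held out for test and 10% of the remainder (9% of the total) for validation:

train_stream, validation_stream, test_stream = load_json_file_dataset(
    "data.jsonl", "input", "output", test_size=0.1, validation_size=0.1
)
# DataStream objects are iterable, as in the evaluation loop above.
for record in validation_stream:
    print(record.input, "->", record.output)
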
def get_wrapped_evaluate_metric(metric_name: str, convert_to_numeric: bool) -> Callable:
"""Wrapper for running metrics out of evaluate which operate on numeric arrays
named predictions & references, respectively.