-
Notifications
You must be signed in to change notification settings - Fork 31
Refactor get valid prompts - for memory optimization #170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c156101
ba345b1
22c13ea
0ce59bf
952d432
0aa793f
42e51cf
e308e67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -518,112 +518,111 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: | |
| for v in program_map.values(): | ||
| random.Random(42).shuffle(v) | ||
|
|
||
|
|
||
| # select prompts that fit the batch size criteria | ||
| valid_prompts = [] | ||
| if custom_shape: | ||
| for program_criteria_seq, valid_prompt_shapes in program_map.items(): | ||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| if valid_prompt_shape == custom_shape: | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| ) | ||
| valid_prompts = [ | ||
| ( | ||
| def get_program_prompt_list(): | ||
| if custom_shape: | ||
| prompt_found = 0 | ||
| for program_criteria_seq, valid_prompt_shapes in program_map.items(): | ||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| if valid_prompt_shape == custom_shape: | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| ) | ||
| prompt_found = 1 | ||
| yield ( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Change 2: yield each prompt instead of building a list; the prompt_found flag is set before yielding. |
||
| program_criteria_seq[0].program_id, | ||
| custom_shape, | ||
| input_ids, | ||
| extra_kwargs, | ||
| sample_key, | ||
| ) | ||
| ] | ||
| break | ||
| if prompt_found: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Change 3: check the prompt_found flag instead of the length of the list. |
||
| break | ||
| if len(valid_prompts) > 0: | ||
| break | ||
| else: | ||
| for program_info in programs: | ||
| program_id = program_info.program_id | ||
| batch_size_limit = program_info.batch_size_limit | ||
| batch_size_limit_type = program_info.batch_size_limit_type | ||
| prompt_length_limit = program_info.prompt_length_limit | ||
| prompt_length_limit_type = program_info.prompt_length_limit_type | ||
|
|
||
| filtered_program_map = program_map | ||
| if program_id.isnumeric(): | ||
| filtered_program_map = { | ||
| k: v | ||
| for k, v in program_map.items() | ||
| if k[0] == program_criteria_list[int(program_id)] | ||
| } | ||
| used_keys = set() | ||
| # for each program, we need to check if we have a shape that satisfies the --programs request | ||
| for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): | ||
| # if ? or numeric => we need to check if we have found at least one valid key to stop | ||
| if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: | ||
| break | ||
| # if * => we need to see if we have found the first key to see if we should skip | ||
| elif program_id == "*" and program_seq_key[0] in used_keys: | ||
| continue | ||
|
|
||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| # make sure the criteria for batch limit and prompt limit is satisfied | ||
| # eval is safe here because we have limited what type and limit can be before | ||
|
|
||
| batch_check = eval( | ||
| f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}" | ||
| ) | ||
| prompt_check = eval( | ||
| f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}" | ||
| ) | ||
| if batch_check and prompt_check: | ||
| # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length | ||
| # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning | ||
| # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| if args.enforce_homogeneous_prompt_programs: | ||
| # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length | ||
| tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) | ||
| possible_seq_lengths = [ | ||
| _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) | ||
| ] | ||
| # favor sequences that are close to the valid prompt length | ||
| possible_seq_lengths.reverse() | ||
| enforce_sizes = enforce_sizes + list( | ||
| itertools.islice( | ||
| itertools.cycle(possible_seq_lengths), | ||
| valid_prompt_shape[0] - 1, | ||
| else: | ||
| for program_info in programs: | ||
| program_id = program_info.program_id | ||
| batch_size_limit = program_info.batch_size_limit | ||
| batch_size_limit_type = program_info.batch_size_limit_type | ||
| prompt_length_limit = program_info.prompt_length_limit | ||
| prompt_length_limit_type = program_info.prompt_length_limit_type | ||
|
|
||
| filtered_program_map = program_map | ||
| if program_id.isnumeric(): | ||
| filtered_program_map = { | ||
| k: v | ||
| for k, v in program_map.items() | ||
| if k[0] == program_criteria_list[int(program_id)] | ||
| } | ||
| used_keys = set() | ||
| # for each program, we need to check if we have a shape that satisfies the --programs request | ||
| for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): | ||
| # if ? or numeric => we need to check if we have found at least one valid key to stop | ||
| if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: | ||
| break | ||
| # if * => we need to see if we have found the first key to see if we should skip | ||
| elif program_id == "*" and program_seq_key[0] in used_keys: | ||
| continue | ||
|
|
||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| # make sure the criteria for batch limit and prompt limit is satisfied | ||
| # eval is safe here because we have limited what type and limit can be before | ||
|
|
||
| batch_check = eval( | ||
| f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}" | ||
| ) | ||
| prompt_check = eval( | ||
| f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}" | ||
| ) | ||
| if batch_check and prompt_check: | ||
| # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length | ||
| # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning | ||
| # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| if args.enforce_homogeneous_prompt_programs: | ||
| # this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length | ||
| tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1) | ||
| possible_seq_lengths = [ | ||
| _ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64) | ||
| ] | ||
| # favor sequences that are close to the valid prompt length | ||
| possible_seq_lengths.reverse() | ||
| enforce_sizes = enforce_sizes + list( | ||
| itertools.islice( | ||
| itertools.cycle(possible_seq_lengths), | ||
| valid_prompt_shape[0] - 1, | ||
| ) | ||
| ) | ||
| try: | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| ) | ||
| ) | ||
| try: | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| ) | ||
| valid_prompts.append( | ||
| ( | ||
| used_keys.add(program_seq_key[0]) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Change 4: call used_keys.add(program_seq_key[0]) before yielding, then yield the tuple. |
||
| yield ( | ||
| program_seq_key[0], | ||
| valid_prompt_shape, | ||
| input_ids, | ||
| extra_kwargs, | ||
| sample_key, | ||
| ) | ||
| ) | ||
| used_keys.add(program_seq_key[0]) | ||
| break | ||
| except ValueError: | ||
| dprint( | ||
| f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" | ||
| ) | ||
|
|
||
| if len(used_keys) == 0 and local_rank == 0: | ||
| dprint( | ||
| f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" | ||
| ) | ||
| break | ||
| except ValueError: | ||
| dprint( | ||
| f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" | ||
| ) | ||
|
|
||
| if len(used_keys) == 0 and local_rank == 0: | ||
| dprint( | ||
| f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" | ||
| ) | ||
|
|
||
|
|
||
| # metric calculator based on the cross-entropy and mean diff for each decode step | ||
|
|
@@ -642,7 +641,13 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): | |
|
|
||
| failed_cases = [] | ||
| # for each program and valid prompt (batch size, sequence length) | ||
| for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: | ||
| for ( | ||
| program_id, | ||
| valid_prompt, | ||
| input_ids, | ||
| extra_kwargs, | ||
| sample_key, | ||
| ) in get_program_prompt_list(): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Change 5: iterate over the generator function directly instead of a materialized list. |
||
| extra_kwargs["attn_name"] = ATTN_NAME | ||
| if ( | ||
| "granite-3.3-8b-instruct" in model_variant | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it's hard to see the git diff changes made in this PR, here is a summary:
Change 1: use a prompt_found flag, since we are now yielding results instead of storing them in a list.