 See the License for the specific language governing permissions and
 limitations under the License.
 """
-
-from typing import Sequence
+import os
 import torch
+import torch.nn.functional as F
+import argparse
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import os
-from absl import app  # Removed flags
+from tabulate import tabulate
 
 from MaxText.utils.ckpt_conversion.utils.hf_utils import (
-    check_predicted_tokens_match,
+    # check_predicted_tokens_match,
     check_arrays_match,
 )
+from MaxText import max_logging
 # Read Hugging Face token from environment variable
 hf_token = os.environ.get("HF_AUTH_TOKEN")
 
@@ -40,6 +41,7 @@
     huggingface_hub
     transformers
     accelerate
+    tabulate
 """
 
 
@@ -70,46 +72,180 @@ def get_logits(inputs, model, golden_model):
   return logits, golden_logits
 
 
-def main(argv: Sequence[str]) -> None:
-  # Parse arguments from argv
-  # Default values
-  parsed_args = {"golden_model_id": "google/gemma-2-2b-it", "hf_checkpoint_path": os.path.expanduser("~/.hf_output/")}
-  for arg in argv[1:]:
-    if "=" in arg:
-      key, value = arg.split("=", 1)
-      if key in parsed_args:
-        parsed_args[key] = value
-      else:
-        print(f"Warning: Unknown argument '{key}' found in argv. Ignoring.")
-
-  golden_model = AutoModelForCausalLM.from_pretrained(parsed_args["golden_model_id"], torch_dtype=torch.float32)
-
-  tokenizer = AutoTokenizer.from_pretrained(parsed_args["hf_checkpoint_path"])
-  model = AutoModelForCausalLM.from_pretrained(parsed_args["hf_checkpoint_path"], torch_dtype=torch.float32)
-
-  # TODO: (@yixuannwang) use 3 prompts to verify
-  input_text = "I love to"
-  inputs = tokenizer(input_text, return_tensors="pt")
-  # --- Generate Output ---
-  with torch.no_grad():
-    outputs = model.generate(**inputs, max_new_tokens=8)
-  # --- Decode and Print ---
-  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
-
-  # Check weights match
-  print("########### check weights match ############### ")
-  check_weights_match(model, golden_model)
+def get_top_k_tokens_scores(logits_tensor, tokenizer_instance, k=10, description=""):
+  """Get the top-k tokens and their scores from a given logits tensor."""
+  max_logging.log(f"\n--- {description} top {k} tokens ---")
+  tokens = []
+  # Ensure logits_tensor is on CPU for operations like topk and item()
+  logits_tensor = logits_tensor.cpu()
+  topk_results = torch.topk(logits_tensor[0, -1], k=k)
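+  # torch.topk returns a (values, indices) pair holding the k largest logits
+  # of the last position and their token IDs; both are read in the loop below.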
+  for i in range(k):
+    tok_id = topk_results.indices[i].item()
+    score = topk_results.values[i].item()
+    tok = tokenizer_instance.decode(tok_id)
+    tokens.append({"id": int(tok_id), "token": tok.strip(), "score": float(score)})
+
+  # Prepare data for tabulate: a list of lists
+  table_data = [[d["id"], d["token"], d["score"]] for d in tokens]
+  max_logging.log(tabulate(table_data, headers=["Token ID", "Token", "Score"], tablefmt="orgtbl"))
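+  # The logged table looks roughly like this (orgtbl format, illustrative values):
+  # | Token ID | Token | Score |
+  # |----------+-------+-------|
+  # |     1234 | foo   |  12.3 |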
+  return tokens
+
+
+def compare_top_tokens(converted_tokens, golden_tokens):
+  """
+  Compares two lists of top tokens and reports similarity metrics.
+
+  Args:
+      converted_tokens: top tokens from the converted model.
+      golden_tokens: top tokens from the golden model.
+  """
+  # Extract the sets of token IDs for comparison
+  converted_ids = {token["id"] for token in converted_tokens}
+  golden_ids = {token["id"] for token in golden_tokens}
+
+  # --- Metric 1: Overlap Count & Jaccard Similarity ---
+  intersection = converted_ids.intersection(golden_ids)
+  union = converted_ids.union(golden_ids)
+
+  overlap_count = len(intersection)
+  jaccard_similarity = overlap_count / len(union) if union else 0.0
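+  # Example: top-3 ID sets {5, 9, 12} and {5, 12, 40} share 2 IDs out of a
+  # union of 4, giving a Jaccard similarity of 2 / 4 = 0.5.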
+
+  # --- Metric 2: Rank Agreement ---
+  rank_matches = 0
+  min_len = min(len(converted_tokens), len(golden_tokens))
+  for i in range(min_len):
+    if converted_tokens[i]["id"] == golden_tokens[i]["id"]:
+      rank_matches += 1
+
+  rank_agreement = (rank_matches / min_len) * 100 if min_len > 0 else 0.0
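+  # Example: if 4 of the 10 ranked positions hold the same token ID in both
+  # lists, rank agreement is 40%.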
+
+  metrics = {
+      "overlap_count": f"{overlap_count}/{min_len}",
+      "jaccard_similarity": jaccard_similarity,
+      "rank_agreement_percentage": rank_agreement,
+  }
+
+  max_logging.log("\n--- Similarity Metrics of Top Tokens ---")
+  table = [[key, value] for key, value in metrics.items()]
+  max_logging.log(tabulate(table, headers=["Metric", "Value"], tablefmt="orgtbl"))
+
+
+def check_kl_divergence(model_logits, golden_logits, atol=0.02):
+  """
+  Calculates KL divergence D_KL(P_golden || Q_model) over a batch of sequences.
+
+  Args:
+      model_logits: Logits from the converted model (Batch, SeqLen, VocabSize).
+      golden_logits: Logits from the golden model (Batch, SeqLen, VocabSize).
+      atol: Maximum allowed KL divergence for any single token distribution.
+  """
+  # 1. Select the relevant vocabulary slice from the logits.
+  token_size = min(model_logits.shape[2], golden_logits.shape[2])
+  model_logits_sliced = model_logits[..., :token_size]
+  golden_logits_sliced = golden_logits[..., :token_size]
+
+  # 2. Reshape to (Batch * SeqLen, VocabSize).
+  b, s, v = model_logits_sliced.shape
+  model_logits_reshaped = model_logits_sliced.view(b * s, v)
+  golden_logits_reshaped = golden_logits_sliced.view(b * s, v)
+
+  # 3. Get the probability distributions.
+  golden_probabilities = F.softmax(golden_logits_reshaped, dim=-1)
+  model_log_probabilities = F.log_softmax(model_logits_reshaped, dim=-1)
+
+  # 4. Calculate the average KL divergence over all token distributions.
+  # 'batchmean' sums the KL divergences of every token in the batch and then
+  # divides by the number of tokens (b * s).
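+  # Note: F.kl_div expects log-probabilities as `input` and, with log_target=False,
+  # probabilities as `target`; it computes target * (log(target) - input) elementwise,
+  # which is D_KL(P_golden || Q_model) once summed over the vocabulary.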
+  kl_div_value = F.kl_div(
+      input=model_log_probabilities,
+      target=golden_probabilities,
+      reduction="batchmean",  # average KL per token
+      log_target=False,
+  )
+
+  max_logging.log(f"\nAverage KL divergence per token (D_KL(P_golden || Q_model)): {kl_div_value.item():.6f}")
+
+  # To find the max KL divergence for any single token in the set, use reduction='none'
+  # and sum over the vocab dim to get a single KL value per token.
+  kl_divs_per_token = F.kl_div(
+      input=model_log_probabilities, target=golden_probabilities, reduction="none", log_target=False
+  ).sum(dim=-1)
+
+  max_kl_div = kl_divs_per_token.max()
+  max_logging.log(f"\nMax KL divergence for a single token in the set: {max_kl_div.item():.6f}")
+
+  assert max_kl_div < atol, f"KL divergence value {max_kl_div.item():.6f} exceeds the threshold {atol}"
+
+
+def run_prompts(args: argparse.Namespace) -> None:
+  """
+  Runs sample prompts through the converted and golden models and compares their outputs.
+
+  Args:
+      args: Parsed command-line arguments with:
+          - golden_model_id (str): HF model ID for the golden model.
+          - hf_checkpoint_path (str): Path to the converted HF checkpoint.
+          - max_kl_div (float): Maximum allowed KL divergence.
+  """
+  golden_model = AutoModelForCausalLM.from_pretrained(args.golden_model_id, torch_dtype=torch.bfloat16)
+  golden_tokenizer = AutoTokenizer.from_pretrained(args.golden_model_id)
+
+  tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint_path)
+  model, loading_info = AutoModelForCausalLM.from_pretrained(
+      args.hf_checkpoint_path, trust_remote_code=True, torch_dtype=torch.bfloat16, output_loading_info=True
+  )
+
+  # Uncomment to inspect missing/unexpected keys reported while loading the checkpoint:
+  # max_logging.log(loading_info)
+
+  prompts = ["I love to", "Today is a", "What is the"]
+  for input_text in prompts:
+    max_logging.log(f"\n--- Prompt: {input_text} ---")
+    inputs = tokenizer(input_text, return_tensors="pt")
+    # --- Generate Output ---
+    with torch.no_grad():
+      outputs = model.generate(**inputs, max_new_tokens=15, do_sample=False)
+    # --- Decode and Print ---
+    max_logging.log(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
+
+    # --- Compare tokens ---
+    model_logits, golden_model_logits = get_logits(inputs, model, golden_model)
+    tokens = get_top_k_tokens_scores(model_logits, tokenizer, k=10, description="converted model")
+    golden_tokens = get_top_k_tokens_scores(golden_model_logits, golden_tokenizer, k=10, description="golden model")
+    compare_top_tokens(converted_tokens=tokens, golden_tokens=golden_tokens)
+
+    check_kl_divergence(model_logits, golden_model_logits, atol=args.max_kl_div)
+
+  """
+  If the model's structure exactly matches the golden model (layers, vocab_size, etc.),
+  you can additionally run the following more detailed weight and logit checks:
 
-  # Run forward pass to get logits
-  logits, golden_logits = get_logits(inputs, model, golden_model)
+  check_weights_match(model, golden_model)
 
   # Check logits from the first 5 tokens match
-  print("########### check logits match ############### ")
-  check_arrays_match(logits[0, :5, :], golden_logits[0, :5, :], atol=0.2)
+  check_arrays_match(model_logits[0, :5, :], golden_model_logits[0, :5, :], atol=0.2)
 
-  print("########### check predicted token match ############### ")
-  check_predicted_tokens_match(logits, golden_logits)
+  check_predicted_tokens_match(model_logits, golden_model_logits)
+  """
 
 
 if __name__ == "__main__":
-  app.run(main)
+  parser = argparse.ArgumentParser(description="Verify HuggingFace checkpoints converted from MaxText.")
+  parser.add_argument(
+      "--golden_model_id",
+      type=str,
+      default="google/gemma-2-2b-it",
+      help="The HuggingFace model ID for the golden/reference model.",
+  )
+  parser.add_argument(
+      "--hf_checkpoint_path",
+      type=str,
+      default=os.path.expanduser("~/.hf_output/"),
+      help="Path to the converted HuggingFace checkpoint directory.",
+  )
+  parser.add_argument("--max_kl_div", type=float, default=0.02, help="Maximum allowed KL divergence between model logits.")
+
+  parsed_args = parser.parse_args()
+
+  run_prompts(parsed_args)
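+
+# Example invocation (illustrative; substitute the actual path of this script,
+# the flags are the ones defined above):
+#   python3 verify_hf_checkpoint.py \
+#     --golden_model_id=google/gemma-2-2b-it \
+#     --hf_checkpoint_path=~/.hf_output/ \
+#     --max_kl_div=0.02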