forked from rasbt/LLMs-from-scratch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add standalone finetuning and evaluation scripts for chapter 7 (rasbt…
…#234) * add finetuning and eval scripts * update link * update links * fix link
- Loading branch information
Showing
10 changed files
with
153 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). | ||
# Source for "Build a Large Language Model From Scratch" | ||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch | ||
# Code: https://github.com/rasbt/LLMs-from-scratch | ||
# | ||
# A minimal script for scoring instruction-finetuned model responses with Ollama, based on the code in chapter 7
|
||
import json | ||
import psutil | ||
from tqdm import tqdm | ||
import urllib.request | ||
|
||
|
||
def query_model(prompt, model="llama3", url="http://localhost:11434/api/chat"):
    """Send ``prompt`` as a single user message to a chat endpoint and return the reply.

    The response body is read line by line; each non-blank line is parsed as
    JSON and its ``["message"]["content"]`` fragment is appended to the result.

    Args:
        prompt: Text sent as the content of the user message.
        model: Model name placed in the request payload.
        url: Endpoint that receives the POST request.

    Returns:
        The concatenated content fragments as a single string.
    """
    # Sampling parameters are placed under "options": the Ollama chat API
    # reads model parameters such as seed and temperature from there, not
    # from the top level of the payload.
    data = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "options": {
            "seed": 123,       # for deterministic responses
            "temperature": 0,  # for deterministic responses
        },
    }

    # Convert the dictionary to a JSON formatted string and encode it to bytes
    payload = json.dumps(data).encode("utf-8")

    # Create a request object, setting the method to POST and adding necessary headers
    request = urllib.request.Request(url, data=payload, method="POST")
    request.add_header("Content-Type", "application/json")

    # Send the request and collect the content fragment of every streamed line
    fragments = []
    with urllib.request.urlopen(request) as response:
        for raw_line in response:
            line = raw_line.decode("utf-8").strip()
            if not line:
                # A blank line carries no JSON; skip it instead of letting
                # json.loads fail on it.
                continue
            fragments.append(json.loads(line)["message"]["content"])

    return "".join(fragments)
|
||
|
||
def check_if_running(process_name):
    """Return True if the name of any running process contains ``process_name``.

    Args:
        process_name: Substring to look for in the process names.

    Returns:
        True on the first match, False when no process name matches.
    """
    for proc in psutil.process_iter(["name"]):
        # psutil fills in None when a process name cannot be retrieved
        # (e.g. access denied); treat that as an empty name so the
        # substring test does not raise TypeError.
        name = proc.info["name"] or ""
        if process_name in name:
            return True
    return False
|
||
|
||
def format_input(entry):
    """Build the instruction prompt text for one dataset entry.

    Args:
        entry: Mapping with an ``"instruction"`` key and an optional
            ``"input"`` key.

    Returns:
        The fixed preamble followed by the ``### Instruction:`` section and,
        when the entry has a non-empty input, an ``### Input:`` section.
    """
    instruction_text = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    # A missing "input" key is treated the same as an empty input, so
    # entries that omit the field do not raise KeyError.
    extra_input = entry.get("input")
    input_text = f"\n\n### Input:\n{extra_input}" if extra_input else ""

    return instruction_text + input_text
|
||
|
||
def main(file_path):
    """Score the model responses stored in a JSON file and print a summary.

    Args:
        file_path: Path to a JSON file whose entries are passed to
            ``generate_model_scores`` with the ``"model_response"`` key.

    Raises:
        RuntimeError: If no running process name contains "ollama".
    """
    ollama_running = check_if_running("ollama")

    if not ollama_running:
        raise RuntimeError("Ollama not running. Launch ollama before proceeding.")
    # Reuse the result instead of scanning the process list a second time
    print("Ollama running:", ollama_running)

    # JSON text is read as UTF-8 explicitly rather than with the
    # platform-dependent default encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        test_data = json.load(file)

    model = "llama3"
    scores = generate_model_scores(test_data, "model_response", model)
    print(f"Number of scores: {len(scores)} of {len(test_data)}")

    if not scores:
        # No response could be converted to an integer (or the dataset is
        # empty); avoid dividing by zero when computing the average.
        print("Average score: n/a (no valid scores)\n")
        return
    print(f"Average score: {sum(scores)/len(scores):.2f}\n")
|
||
|
||
def generate_model_scores(json_data, json_key, model="llama3"):
    """Query the model for a 0-100 score of each entry's response.

    Args:
        json_data: Iterable of entries; each needs an ``"output"`` key, the
            ``json_key`` key, and the fields read by ``format_input``.
        json_key: Key under which the response to be scored is stored.
        model: Model name forwarded to ``query_model``.

    Returns:
        List of integer scores. Entries whose reply cannot be converted
        with ``int`` are reported on stdout and left out.
    """
    collected = []
    for item in tqdm(json_data, desc="Scoring entries"):
        given = format_input(item)
        expected = item['output']
        candidate = item[json_key]
        question = "".join([
            f"Given the input `{given}` ",
            f"and correct output `{expected}`, ",
            f"score the model response `{candidate}`",
            " on a scale from 0 to 100, where 100 is the best score. ",
            "Respond with the integer number only.",
        ])
        reply = query_model(question, model)
        try:
            parsed = int(reply)
        except ValueError:
            print(f"Could not convert score: {reply}")
        else:
            collected.append(parsed)

    return collected
|
||
|
||
if __name__ == "__main__":

    import argparse

    # Command-line entry point: the only option is the path to the JSON
    # file that is handed to main().
    parser = argparse.ArgumentParser(
        description="Score the model responses in a test dataset with a model served by Ollama"
    )
    parser.add_argument(
        "--file_path",
        required=True,
        help=(
            "The path to the test dataset `.json` file with the"
            " `'output'` and `'model_response'` keys"
        )
    )
    args = parser.parse_args()

    main(file_path=args.file_path)