forked from MikeBirdTech/ai-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaicommit.py
224 lines (194 loc) · 7.37 KB
/
aicommit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#!/usr/bin/env python
import sys
import os
import ollama
import subprocess
import argparse
import time
from typing import List, Optional
from groq import Groq
# Model names used for inference. Each can be overridden through an
# environment variable; the default shown is used when the variable is unset.
OLLAMA_MODEL = os.getenv("FAST_OLLAMA_MODEL", "llama3.1")
GROQ_MODEL = os.getenv("FAST_GROQ_MODEL", "llama-3.1-8b-instant")
def get_git_diff(max_chars: int = 5000) -> str:
    """Return the diff to describe: staged changes, or unstaged ones if none are staged.

    Args:
        max_chars: Truncate the diff to at most this many characters to keep
            the LLM prompt small. Defaults to 5000 (the previous fixed limit).

    Returns:
        The (possibly truncated) diff text; an empty string when there are no
        changes.

    Exits the process with status 1 if git reports an error (e.g. not inside a
    repository) or the ``git`` executable cannot be found.

    NOTE(review): when nothing is staged the unstaged diff is described, but
    ``create_commit`` runs plain ``git commit -m`` (no ``-a``), which only
    commits staged changes — confirm this fallback is intended.
    """
    # Decode as UTF-8 and replace undecodable bytes: diffs can contain file
    # content in arbitrary encodings, and strict/locale decoding would raise.
    decode = {"encoding": "utf-8", "errors": "replace"}
    try:
        diff = subprocess.check_output(["git", "diff", "--cached"], **decode)
        if not diff:
            diff = subprocess.check_output(["git", "diff"], **decode)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError is what subprocess raises when `git` is not on PATH.
        print("Error: Not a git repository or git is not installed.")
        sys.exit(1)
    return diff[:max_chars]
def query_ollama(prompt: str) -> str:
    """Ask the local Ollama server to answer *prompt* and return its text.

    Uses the ``OLLAMA_MODEL`` model with a 128-token generation cap and keeps
    the model loaded for two minutes afterwards. A progress indicator is
    printed while waiting; on any failure the error is printed and the process
    exits with status 1.
    """
    system_prompt = (
        "You are an expert programmer that values clear, unambiguous "
        "communication and are specialized in generating concise and "
        "informative git commit messages."
    )
    generation_options = {"num_predict": 128}
    try:
        print("Generating commit messages...", end="", flush=True)
        reply = ollama.generate(
            model=OLLAMA_MODEL,
            prompt=prompt,
            system=system_prompt,
            options=generation_options,
            keep_alive="2m",
        )
        print("Done!")
        text = reply["response"]
    except Exception as err:  # CLI boundary: report any client/server error and stop.
        print(f"\nError querying Ollama: {err}")
        sys.exit(1)
    return text
def query_groq(prompt: str) -> str:
    """Query Groq's chat-completions API with *prompt* and return the reply text.

    The API key is read from the ``GROQ_API_KEY`` environment variable and the
    ``GROQ_MODEL`` model is used. A progress indicator is printed while
    waiting; on any failure the error is printed and the process exits with
    status 1.

    Returns:
        The assistant's reply, or ``""`` if the response carried no text
        content.
    """
    conversation = [
        {
            "role": "system",
            "content": "You are an expert programmer that values clear, unambiguous communication and are specialized in generating concise and informative git commit messages.",
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]
    try:
        print("Generating commit messages...", end="", flush=True)
        client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        chat_completion = client.chat.completions.create(
            messages=conversation,
            model=GROQ_MODEL,
        )
        print("Done!")
        content = chat_completion.choices[0].message.content
    except Exception as e:  # CLI boundary: report any SDK/network error and stop.
        print(f"\nError querying Groq: {e}")
        sys.exit(1)
    # `message.content` may be None in the chat-completions schema; return ""
    # so callers that split the text see "no messages" instead of crashing.
    return content or ""
def parse_commit_messages(response: str) -> List[str]:
    """Extract the numbered commit messages from an LLM reply.

    Only lines that start (after surrounding whitespace) with ``1.``, ``2.``
    or ``3.`` are considered; the number prefix is removed. If the remaining
    text is wrapped in a matching pair of ``"``, ``'`` or backticks that do not
    also appear inside it, the wrapper is removed. Entries that end up empty
    are skipped so a bare ``"1."`` line never becomes a blank commit message.

    Args:
        response: Raw reply text; ``None`` or ``""`` yields an empty list.

    Returns:
        The messages in the order they appear.
    """
    if not response:
        return []

    messages = []
    for line in response.splitlines():
        stripped = line.strip()
        if not stripped.startswith(("1.", "2.", "3.")):
            continue
        message = stripped.split(".", 1)[1].strip()
        # LLMs often quote each suggestion; unwrap a single matching pair.
        if (
            len(message) >= 2
            and message[0] == message[-1]
            and message[0] in "\"'`"
            and message[0] not in message[1:-1]
        ):
            message = message[1:-1].strip()
        if message:
            messages.append(message)
    return messages
def select_message_with_fzf(
    messages: List[str], use_vim: bool = False, use_num: bool = False
) -> Optional[str]:
    """Let the user pick one of *messages* in fzf, or ask for new ones.

    A "Regenerate messages" entry is appended to the choices shown. The
    caller's *messages* list is not modified.

    Args:
        messages: Candidate commit messages, one per fzf line.
        use_vim: Bind ``j``/``k`` to move down/up in the list.
        use_num: Prefix entries with ``1.``, ``2.``, ... and bind the digit
            keys 1-9 to jump to and accept the matching entry.

    Returns:
        The chosen message text; ``"regenerate"`` if the regenerate entry was
        chosen; ``None`` if the user cancelled (fzf exit code 130) or fzf could
        not be started; ``""`` if fzf exited without printing a selection.
    """
    regenerate_label = "Regenerate messages"
    # Build a new list instead of appending to `messages`, so the caller's
    # list is never mutated.
    options = [*messages, regenerate_label]

    fzf_args = [
        "fzf",
        "--height=10",
        "--layout=reverse",
        "--prompt=Select a commit message (ESC to cancel): ",
        "--no-info",
        "--margin=1,2",
        "--border",
        "--color=prompt:#D73BC9,pointer:#D73BC9",
    ]
    if use_vim:
        fzf_args.extend(["--bind", "j:down,k:up"])
    if use_num:
        options = [f"{i}. {option}" for i, option in enumerate(options, start=1)]
        # `accept-non-empty` alone accepts whichever entry the cursor is on, so
        # pressing "3" used to pick the highlighted entry rather than entry 3.
        # `pos(N)` moves the cursor to entry N first (fzf >= 0.36). Bindings
        # are single keys, so at most entries 1-9 get a digit.
        digit_bindings = ",".join(
            f"{i}:pos({i})+accept-non-empty"
            for i in range(1, min(len(options), 9) + 1)
        )
        fzf_args.extend(["--bind", digit_bindings])

    try:
        result = subprocess.run(
            fzf_args,
            input="\n".join(options),
            capture_output=True,
            text=True,
        )
    except OSError:
        # e.g. FileNotFoundError when fzf is not on PATH. (run() without
        # check=True never raises CalledProcessError.)
        print("Error: fzf selection failed.")
        return None

    if result.returncode == 130:  # User pressed ESC
        return None

    selected = result.stdout.strip()
    if use_num and selected:
        # Drop the "N. " prefix added above.
        selected = selected.partition(". ")[2]
    # Compare after removing the prefix so this works for any number of
    # messages (a fixed "4. Regenerate messages" assumed exactly three).
    if selected == regenerate_label:
        return "regenerate"
    return selected
def create_commit(message: str) -> None:
    """Run ``git commit -m <message>`` and report the outcome.

    On success the message is echoed back; if git exits with a non-zero
    status an error is printed and the process exits with status 1.
    """
    command = ["git", "commit", "-m", message]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError:
        print("Error: Failed to create commit.")
        sys.exit(1)
    else:
        print(f"Committed with message: {message}")
def _generate_commit_messages(
    prompt: str, use_groq: bool, show_analytics: bool, regenerating: bool = False
) -> List[str]:
    """Query the chosen backend once and parse its reply into commit messages.

    Prints timing information when *show_analytics* is set; the first
    generation also reports the backend and model name. Exits the process with
    status 1 if the reply contains no usable numbered messages.
    """
    # Time only the LLM round-trip (previously the first measurement also
    # included the `git diff` calls made before the prompt was built).
    start_time = time.time()
    response = query_groq(prompt) if use_groq else query_ollama(prompt)
    elapsed = time.time() - start_time

    if show_analytics:
        if regenerating:
            print("\nRegeneration Analytics:")
            print(
                "Time taken to regenerate commit messages: "
                f"{elapsed:.2f} seconds"
            )
        else:
            print("\nAnalytics:")
            print(f"Time taken to generate commit messages: {elapsed:.2f} seconds")
            print(f"Inference used: {'Groq' if use_groq else 'Ollama'}")
            print(f"Model name: {GROQ_MODEL if use_groq else OLLAMA_MODEL}")
        print("")  # Add a blank line for better readability

    messages = parse_commit_messages(response)
    if not messages:
        print("Error: Could not generate commit messages.")
        sys.exit(1)
    return messages


def main():
    """Command-line entry point: generate commit messages, let the user pick one, commit it.

    Flags: ``--groq`` (use Groq instead of Ollama), ``--analytics`` (print
    timing and model info), ``--vim`` (j/k navigation in fzf), ``--num``
    (number prefixes and digit-key selection in fzf).
    """
    parser = argparse.ArgumentParser(
        description="Generate git commit messages using LLMs."
    )
    parser.add_argument(
        "--groq", action="store_true", help="Use Groq API instead of Ollama"
    )
    parser.add_argument(
        "--analytics", action="store_true", help="Display performance analytics"
    )
    parser.add_argument(
        "--vim", action="store_true", help="Use vim-style navigation in fzf"
    )
    parser.add_argument(
        "--num", action="store_true", help="Use number selection for commit messages"
    )
    args = parser.parse_args()

    diff = get_git_diff()
    if not diff:
        print("No changes to commit.")
        sys.exit(0)

    prompt = (
        "\n"
        "Your task is to generate three concise, informative git commit messages "
        "based on the following git diff.\n"
        "Be sure that each commit message reflects the entire diff.\n"
        "It is very important that the entire commit is clear and understandable "
        "with each of the three options.\n"
        "Each message should be on a new line, starting with a number and a period "
        "(e.g., '1.', '2.', '3.').\n"
        f"Here's the diff:\n\n{diff}"
    )

    commit_messages = _generate_commit_messages(prompt, args.groq, args.analytics)
    while True:
        selected_message = select_message_with_fzf(
            commit_messages, use_vim=args.vim, use_num=args.num
        )
        if selected_message == "regenerate":
            commit_messages = _generate_commit_messages(
                prompt, args.groq, args.analytics, regenerating=True
            )
        elif selected_message:
            create_commit(selected_message)
            break
        else:
            # Cancelled (None) or fzf produced no selection ("").
            print("Commit messages rejected. Please create commit message manually.")
            break
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()