-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgit.py
119 lines (105 loc) · 4.36 KB
/
git.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
from transformers import pipeline
# Global variable for storing the loaded image-to-text pipeline.
# It stays None until load_model() assigns a transformers pipeline to it;
# generate_caption() checks for None and asks the user to load a model first.
generator = None
def load_model(model_choice: str) -> str:
    """
    Load the selected GIT model into the module-level ``generator`` pipeline.

    Args:
        model_choice (str): Model identifier ("git-large-r" or "git-base").
            Unrecognized values (e.g. ``None`` from a cleared dropdown) fall
            back to "git-base".

    Returns:
        str: Status message naming the model that was actually loaded.
    """
    global generator
    model_dict = {
        "git-large-r": "microsoft/git-large-r",
        "git-base": "microsoft/git-base",
    }
    if model_choice in model_dict:
        loaded_choice = model_choice
        status = f"Model '{loaded_choice}' loaded!"
    else:
        # Report the fallback explicitly instead of echoing the unrecognized
        # choice back as if that model had been loaded.
        loaded_choice = "git-base"
        status = f"Unknown model '{model_choice}'; loaded '{loaded_choice}' instead."
    generator = pipeline("image-to-text", model=model_dict[loaded_choice])
    return status
def generate_caption(image, max_length, num_beams, temperature, top_k, top_p,
                     repetition_penalty, do_sample=None) -> str:
    """
    Generate an image caption using the loaded model with specified generation parameters.

    Generation settings are passed per call through the pipeline's
    ``generate_kwargs`` rather than written into the model's shared
    generation config. That way one request's settings cannot leak into the
    next, and they still apply if the pipeline keeps its own config copy
    (NOTE(review): some transformers versions do this; per-call kwargs
    work either way).

    Args:
        image (PIL.Image | None): Input image; ``None`` when nothing was uploaded.
        max_length (int): Maximum length of the generated caption.
        num_beams (int): Beam width for beam search.
        temperature (float): Sampling temperature.
        top_k (int): Top-k sampling parameter.
        top_p (float): Top-p (nucleus) sampling parameter.
        repetition_penalty (float): Repetition penalty factor.
        do_sample (bool | None): Whether to sample instead of greedy/beam
            decoding. ``None`` (default) leaves the model's own default in
            place. temperature, top_k and top_p only influence the output
            when sampling is enabled.

    Returns:
        str: Generated caption text, or a user-facing hint when no model is
        loaded or no image was provided.
    """
    if generator is None:
        return "Please load the model first!"
    # Gradio passes None when the button is clicked before an image is uploaded.
    if image is None:
        return "Please upload an image first!"

    generate_kwargs = {
        "max_length": int(max_length),
        "num_beams": int(num_beams),
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
    }
    if do_sample is not None:
        generate_kwargs["do_sample"] = bool(do_sample)

    results = generator(image, generate_kwargs=generate_kwargs)
    if not results:
        return ""
    return results[0]["generated_text"]
def main():
    """
    Build and launch the GIT image-captioning Gradio app.

    The page has three sections: a model picker with a load button and
    status box (wired to load_model), a grid of generation-parameter
    sliders, and an image upload with a caption box and button (wired to
    generate_caption).
    """
    # (minimum, maximum, step, value, label) for each slider, two per row,
    # listed in the same order generate_caption takes its parameters.
    slider_rows = [
        [(16, 256, 1, 64, "Max Length"), (1, 10, 1, 5, "Num Beams")],
        [(0.1, 2.0, 0.1, 1.0, "Temperature"), (0, 100, 1, 50, "Top-k")],
        [(0.0, 1.0, 0.05, 0.9, "Top-p"), (1.0, 2.0, 0.1, 1.0, "Repetition Penalty")],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# GIT Image-to-Text Demo")
        gr.Markdown("Select a model and click 'Load Model' before uploading an image to generate a caption.")

        # Model selection and loading
        with gr.Row():
            picker = gr.Dropdown(
                choices=["git-large-r", "git-base"],
                value="git-base",
                label="Select Model",
            )
            load_btn = gr.Button("Load Model")
            status_box = gr.Textbox(label="Status", interactive=False)
        load_btn.click(fn=load_model, inputs=picker, outputs=status_box)

        gr.Markdown("---")
        gr.Markdown("## Generation Parameters")
        param_sliders = []
        for row_spec in slider_rows:
            with gr.Row():
                for lo, hi, step, default, title in row_spec:
                    param_sliders.append(
                        gr.Slider(minimum=lo, maximum=hi, step=step, value=default, label=title)
                    )

        gr.Markdown("---")
        # Image upload and caption generation
        with gr.Row():
            uploaded = gr.Image(type="pil", label="Upload Image")
            caption_box = gr.Textbox(label="Generated Caption", interactive=False)
        caption_btn = gr.Button("Generate Caption")
        caption_btn.click(
            fn=generate_caption,
            inputs=[uploaded, *param_sliders],
            outputs=caption_box,
        )

    demo.launch()


if __name__ == "__main__":
    main()