Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable uploading multiple images in demo.py #232

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,32 @@ def setup_seeds(config):
# Gradio Setting
# ========================================

def gradio_reset(chat_state, img_list):
if chat_state is not None:
chat_state.messages = []
if img_list is not None:
img_list = []
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list

def upload_img(gr_img, text_input, chat_state):
def gradio_reset():
# reset chatbot, image, text_input, upload_button, chat_state, img_list, img_emb_list, gallery
return None, \
gr.update(value=None, interactive=True), \
gr.update(placeholder='Please upload your image first', interactive=False), \
gr.update(value="Upload & Start Chat", interactive=True), \
CONV_VISION.copy(), \
[], \
[], \
[]


def upload_img(gr_img, chat_state, img_list, img_emb_list):
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
return None, None, gr.update(interactive=True), chat_state, img_list, img_emb_list
img_list.append(gr_img)
# upload an image to the chat
chat.upload_img(gr_img, chat_state, img_emb_list)
# update image, text_input, upload_button, chat_state, gallery, img_emb_list
return gr.update(value=None, interactive=False), \
gr.update(interactive=True, placeholder='Type and press Enter'), \
gr.update(value="Send more images after sending a message", interactive=False), \
chat_state, \
img_list, \
img_emb_list


def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
Expand All @@ -99,7 +111,12 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
max_new_tokens=300,
max_length=2000)[0]
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
# update chatbot, chat_state, image, upload_button
return chatbot, \
chat_state, \
gr.update(interactive=True), \
gr.update(value="Send more image", interactive=True)


title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
Expand Down Expand Up @@ -138,16 +155,26 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
)

with gr.Column():
chat_state = gr.State()
img_list = gr.State()
chat_state = gr.State(CONV_VISION.copy())
img_list = gr.State([])
img_emb_list = gr.State([])
gallery = gr.Gallery(label="Uploaded Images", show_label=True) \
.style(rows=[1], object_fit="scale-down", height="500px", preview=True)
chatbot = gr.Chatbot(label='MiniGPT-4')
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])

text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
)
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)

upload_button.click(upload_img, [image, chat_state, img_list, img_emb_list],
[image, text_input, upload_button, chat_state, gallery, img_emb_list])

text_input \
.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]) \
.then(gradio_answer,
[chatbot, chat_state, img_emb_list, num_beams, temperature],
[chatbot, chat_state, image, upload_button])

clear.click(gradio_reset,
None,
[chatbot, image, text_input, upload_button, chat_state, img_list, img_emb_list, gallery],
queue=False)

demo.launch(share=True, enable_queue=True)