Skip to content

Commit 363a965

Browse files
committed
Fix image processing in cli.
1 parent df3a78d commit 363a965

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

llava/eval/run_llava.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from llava.model.builder import load_pretrained_model
1313
from llava.utils import disable_torch_init
1414
from llava.mm_utils import (
15+
process_images,
1516
tokenizer_image_token,
1617
get_model_name_from_path,
1718
KeywordsStoppingCriteria,
@@ -94,11 +95,11 @@ def eval_model(args):
9495

9596
image_files = image_parser(args)
9697
images = load_images(image_files)
97-
images_tensor = (
98-
image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
99-
.half()
100-
.cuda()
101-
)
98+
images_tensor = process_images(
99+
images,
100+
image_processor,
101+
model.config
102+
).to(model.device, dtype=torch.float16)
102103

103104
input_ids = (
104105
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")

llava/serve/cli.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ def main(args):
5353

5454
image = load_image(args.image_file)
5555
# Similar operation in model_worker.py
56-
image_tensor = process_images([image], image_processor, args)
56+
image_tensor = process_images([image], image_processor, model.config)
5757
if type(image_tensor) is list:
5858
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
5959
else:
@@ -94,7 +94,7 @@ def main(args):
9494
output_ids = model.generate(
9595
input_ids,
9696
images=image_tensor,
97-
do_sample=True,
97+
do_sample=True if args.temperature > 0 else False,
9898
temperature=args.temperature,
9999
max_new_tokens=args.max_new_tokens,
100100
streamer=streamer,
@@ -120,6 +120,5 @@ def main(args):
120120
parser.add_argument("--load-8bit", action="store_true")
121121
parser.add_argument("--load-4bit", action="store_true")
122122
parser.add_argument("--debug", action="store_true")
123-
parser.add_argument("--image-aspect-ratio", type=str, default='pad')
124123
args = parser.parse_args()
125124
main(args)

0 commit comments

Comments
 (0)