Skip to content

Commit

Permalink
🔀 [Merge] branch 'DEPLOY'
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytsui000 committed Jan 3, 2025
2 parents f080104 + 5f0e785 commit 2b6f538
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions demo/hf_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from yolo import (
AugmentationComposer,
NMSConfig,
PostProccess,
PostProcess,
create_converter,
create_model,
draw_bboxes,
Expand All @@ -20,36 +20,35 @@
IMAGE_SIZE = (640, 640)


def load_model(model_name, device):
def load_model(model_name):
model_cfg = OmegaConf.load(f"yolo/config/model/{model_name}.yaml")
model_cfg.model.auxiliary = {}
model = create_model(model_cfg, True)
model.to(device).eval()
return model, model_cfg
converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
model = model.to(device).eval()
return model, converter


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, model_cfg = load_model(DEFAULT_MODEL, device)
converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
model, converter = load_model(DEFAULT_MODEL)
class_list = OmegaConf.load("yolo/config/dataset/coco.yaml").class_list

transform = AugmentationComposer([])


def predict(model_name, image, nms_confidence, nms_iou):
def predict(model_name, image, nms_confidence, nms_iou, max_bbox):
global DEFAULT_MODEL, model, device, converter, class_list, post_proccess
if model_name != DEFAULT_MODEL:
model, model_cfg = load_model(model_name, device)
converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
model, converter = load_model(model_name)
DEFAULT_MODEL = model_name

image_tensor, _, rev_tensor = transform(image)

image_tensor = image_tensor.to(device)[None]
rev_tensor = rev_tensor.to(device)[None]

nms_config = NMSConfig(nms_confidence, nms_iou)
post_proccess = PostProccess(converter, nms_config)
nms_config = NMSConfig(nms_confidence, nms_iou, max_bbox)
post_proccess = PostProcess(converter, nms_config)

with torch.no_grad():
predict = model(image_tensor)
Expand All @@ -67,6 +66,7 @@ def predict(model_name, image, nms_confidence, nms_iou):
gradio.components.Image(type="pil", label="Input Image"),
gradio.components.Slider(0, 1, step=0.01, value=0.5, label="NMS Confidence Threshold"),
gradio.components.Slider(0, 1, step=0.01, value=0.5, label="NMS IoU Threshold"),
gradio.components.Slider(0, 1000, step=10, value=400, label="Max Bounding Box Number"),
],
outputs=gradio.components.Image(type="pil", label="Output Image"),
)
Expand Down

0 comments on commit 2b6f538

Please sign in to comment.