
💾 Store bounding boxes in text file on inference #179


Open
wants to merge 9 commits into base: main
18 changes: 13 additions & 5 deletions yolo/model/yolo.py
@@ -135,27 +135,35 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
if "model_state_dict" in weights:
weights = weights["model_state_dict"]
if "state_dict" in weights:
weights = weights["state_dict"]

model_state_dict = self.model.state_dict()
model_state_dict = self.state_dict()

# TODO1: autoload old version weight
# TODO2: weight transform if num_class difference

error_dict = {"Mismatch": set(), "Not Found": set()}
for model_key, model_weight in model_state_dict.items():
if model_key not in weights:

weights_key = model_key
if weights_key not in weights: # .ckpt
weights_key = "model." + model_key
if weights_key not in weights: # .pt old
weights_key = model_key[6:]
if weights_key not in weights:
error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
continue
if model_weight.shape != weights[model_key].shape:
if model_weight.shape != weights[weights_key].shape:
error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
continue
model_state_dict[model_key] = weights[model_key]
model_state_dict[model_key] = weights[weights_key]

for error_name, error_set in error_dict.items():
for weight_name in error_set:
logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")

self.model.load_state_dict(model_state_dict)
self.load_state_dict(model_state_dict)


def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
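For context, the new lookup above tries each key from self.state_dict() against three checkpoint naming schemes before giving up: the key itself, the Lightning .ckpt style with an extra "model." prefix, and the old .pt style with the prefix stripped. A minimal sketch of that resolution order, with made-up weight dictionaries:

from typing import Optional

def resolve_key(model_key: str, weights: dict) -> Optional[str]:
    """Try the key as-is, then the Lightning .ckpt form, then the old .pt form."""
    for candidate in (model_key, "model." + model_key, model_key[len("model."):]):
        if candidate in weights:
            return candidate
    return None

ckpt_weights = {"model.model.0.conv.weight": None}  # hypothetical Lightning .ckpt key
old_pt_weights = {"0.conv.weight": None}            # hypothetical old .pt key
print(resolve_key("model.0.conv.weight", ckpt_weights))    # -> model.model.0.conv.weight
print(resolve_key("model.0.conv.weight", old_pt_weights))  # -> 0.conv.weight
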
24 changes: 15 additions & 9 deletions yolo/tools/data_loader.py
@@ -234,7 +234,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st


class StreamDataLoader:
def __init__(self, data_cfg: DataConfig):
def __init__(self, data_cfg: DataConfig, asynchronous: bool = True):
self.source = data_cfg.source
self.running = True
self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
@@ -249,8 +249,12 @@ def __init__(self, data_cfg: DataConfig):
else:
self.source = Path(self.source)
self.queue = Queue()
self.thread = Thread(target=self.load_source)
self.thread.start()

if asynchronous:
self.thread = Thread(target=self.load_source)
self.thread.start()
else:
self.load_source()

def load_source(self):
if self.source.is_dir(): # image folder
@@ -272,20 +276,22 @@ def process_image(self, image_path):
image = Image.open(image_path).convert("RGB")
if image is None:
raise ValueError(f"Error loading image: {image_path}")
self.process_frame(image)
self.process_frame(image, image_path)

def load_video_file(self, video_path):
import cv2

cap = cv2.VideoCapture(str(video_path))
frame_idx = 0
while self.running:
ret, frame = cap.read()
if not ret:
break
self.process_frame(frame)
self.process_frame(frame, f"{video_path.stem}_frame{frame_idx:04d}.png")
frame_idx += 1
cap.release()

def process_frame(self, frame):
def process_frame(self, frame, image_path):
if isinstance(frame, np.ndarray):
# TODO: we don't need cv2
import cv2
@@ -297,9 +303,9 @@ def process_frame(self, frame):
frame = frame[None]
rev_tensor = rev_tensor[None]
if not self.is_stream:
self.queue.put((frame, rev_tensor, origin_frame))
self.queue.put((frame, rev_tensor, origin_frame, image_path))
else:
self.current_frame = (frame, rev_tensor, origin_frame)
self.current_frame = (frame, rev_tensor, origin_frame, image_path)

def __iter__(self) -> Generator[Tensor, None, None]:
return self
@@ -310,7 +316,7 @@ def __next__(self) -> Tensor:
if not ret:
self.stop()
raise StopIteration
self.process_frame(frame)
self.process_frame(frame, "stream_frame.png")
return self.current_frame
else:
try:
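The new asynchronous flag can be used directly when every image must be queued before iteration starts. A hedged usage sketch, assuming cfg is an already-loaded Config so that cfg.task.data is a valid DataConfig (as in PredictModel below):

from yolo.tools.data_loader import StreamDataLoader

# asynchronous=False runs load_source() inside the constructor, so every image under
# cfg.task.data.source is queued before iteration begins (no background thread).
# Assumption: `cfg` has already been loaded elsewhere.
loader = StreamDataLoader(cfg.task.data, asynchronous=False)

# Each item is the (frame, rev_tensor, origin_frame, image_path) tuple queued by process_frame.
for frame, rev_tensor, origin_frame, image_path in loader:
    print(image_path, tuple(frame.shape))
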
22 changes: 19 additions & 3 deletions yolo/tools/solver.py
@@ -6,7 +6,7 @@

from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.tools.data_loader import create_dataloader
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
from yolo.tools.drawer import draw_bboxes
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
@@ -112,7 +112,9 @@ def __init__(self, cfg: Config):
super().__init__(cfg)
self.cfg = cfg
# TODO: Add FastModel
self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
# StreamDataLoader has to be synchronous, otherwise not all images are loaded
# TODO: Make this load in parallel
self.predict_loader = StreamDataLoader(cfg.task.data, asynchronous=False)

def setup(self, stage):
self.vec2box = create_converter(
@@ -124,15 +126,29 @@ def predict_dataloader(self):
return self.predict_loader

def predict_step(self, batch, batch_idx):
images, rev_tensor, origin_frame = batch
images, rev_tensor, origin_frame, image_path = batch
predicts = self.post_process(self(images), rev_tensor=rev_tensor)
img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)

if getattr(self.predict_loader, "is_stream", None):
fps = self._display_stream(img)
else:
fps = None

if getattr(self.cfg.task, "save_predict", None):
self._save_image(img, batch_idx)

output_txt_file = Path(getattr(self.cfg, "out_path")) / "results.txt"

# append predictions to results.txt, space separated: image name, class id, bbox coords, confidence
with open(output_txt_file, "a") as f:
for bboxes in predicts:
for bbox in bboxes:
class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
f.write(f"{image_path.name} {int(class_id)} {x_min} {y_min} {x_max} {y_max} {conf[0]}\n")

print(f"💾 Saved predictions at {output_txt_file}")

return img, fps

def _save_image(self, img, batch_idx):
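Each line appended to results.txt is space separated: image name, class id, x_min, y_min, x_max, y_max, confidence. A small sketch of reading the file back; the directory below is a hypothetical stand-in for cfg.out_path, not the project's actual default:

from pathlib import Path

results_file = Path("runs/predict") / "results.txt"  # hypothetical stand-in for cfg.out_path
for line in results_file.read_text().splitlines():
    name, class_id, x_min, y_min, x_max, y_max, conf = line.split()
    box = [float(v) for v in (x_min, y_min, x_max, y_max)]
    print(name, int(class_id), box, float(conf))
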
3 changes: 2 additions & 1 deletion yolo/utils/logging_utils.py
@@ -107,7 +107,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx:
epoch_descript = "[cyan]Train [white]|"
batch_descript = "[green]Train [white]|"
metrics = self.get_metrics(trainer, pl_module)
metrics.pop("v_num")
if "v_num" in metrics:
metrics.pop("v_num")
for metrics_name, metrics_val in metrics.items():
if "Loss_step" in metrics_name:
epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"