forked from lucasjinreal/FruitsNutsSeg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
40 lines (34 loc) · 1.37 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import random
from detectron2.utils.visualizer import Visualizer
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
import fruitsnuts_data
import cv2
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
import os
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode
fruits_nuts_metadata = MetadataCatalog.get("fruits_nuts")
if __name__ == "__main__":
cfg = get_cfg()
cfg.merge_from_file(
"../../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
)
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
print('loading from: {}'.format(cfg.MODEL.WEIGHTS))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set the testing threshold for this model
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
cfg.DATASETS.TEST = ("fruits_nuts", )
predictor = DefaultPredictor(cfg)
data_f = './data/images/2.jpg'
im = cv2.imread(data_f)
outputs = predictor(im)
v = Visualizer(im[:, :, ::-1],
metadata=fruits_nuts_metadata,
scale=0.8,
instance_mode=ColorMode.IMAGE_BW # remove the colors of unsegmented pixels
)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
img = v.get_image()[:, :, ::-1]
cv2.imshow('rr', img)
cv2.waitKey(0)