From fd516c2d6fc1d5189a58d72309f2bb74bc12570e Mon Sep 17 00:00:00 2001
From: chenhaijin
Date: Fri, 12 Sep 2025 11:28:07 +0800
Subject: [PATCH 1/9] Fix sam2 file permission issue

---
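With relative defaults the demo no longer needs the root-owned /opt/sam2 and
/data locations. A minimal sketch of how the lookup behaves (APP_ROOT and
DATA_PATH are the real env-var names from app_conf.py; the printed values are
illustrative only and depend on where the process is started):

    import os
    from pathlib import Path

    # Same lookups app_conf.py performs at import time; set the env vars
    # before starting the server to override the relative defaults.
    app_root = os.getenv("APP_ROOT", "./sam2")
    data_path = Path(os.getenv("DATA_PATH", "./data"))

    # Relative paths resolve against the current working directory.
    print(Path(app_root).resolve(), data_path.resolve())
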
 demo/backend/server/app_conf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/demo/backend/server/app_conf.py b/demo/backend/server/app_conf.py
index eea777289..1b38bcc3c 100644
--- a/demo/backend/server/app_conf.py
+++ b/demo/backend/server/app_conf.py
@@ -9,7 +9,7 @@
 
 logger = logging.getLogger(__name__)
 
-APP_ROOT = os.getenv("APP_ROOT", "/opt/sam2")
+APP_ROOT = os.getenv("APP_ROOT", "./sam2")
 
 API_URL = os.getenv("API_URL", "http://localhost:7263")
 
@@ -20,7 +20,7 @@
 FFMPEG_NUM_THREADS = int(os.getenv("FFMPEG_NUM_THREADS", "1"))
 
 # Path for all data used in API
-DATA_PATH = Path(os.getenv("DATA_PATH", "/data"))
+DATA_PATH = Path(os.getenv("DATA_PATH", "./data"))
 
 # Max duration an uploaded video can have in seconds. The default is 10
 # seconds.

From cc5abed8510a3b5062985898ea360d01c55c8dda Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=90=AF=E6=AD=A3=20=E7=8E=8B?=
Date: Wed, 3 Dec 2025 08:51:19 +0800
Subject: [PATCH 2/9] add masks predictor for single img

---
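Example client for the two new endpoints (a sketch only; the server address
and sample image are placeholders, while the paths /mask and /masks and the
payload keys "url", "points" and "labels" are the ones the routes read —
note app.run still listens on port 5000 at this point in the series):

    import requests

    BASE = "http://localhost:5000"  # assumed server address

    # Point-prompted mask: returns a JPEG of the best-scoring mask.
    resp = requests.post(f"{BASE}/mask", json={
        "url": "/data/gallery/example.jpg",  # path, URL, or data URI
        "points": [[450, 300]],              # one foreground click (x, y)
        "labels": [1],                       # 1 = foreground, 0 = background
    })
    open("mask.jpg", "wb").write(resp.content)

    # Automatic masks for the whole image: returns one combined RGBA PNG.
    resp = requests.post(f"{BASE}/masks", json={"url": "/data/gallery/example.jpg"})
    open("masks.png", "wb").write(resp.content)
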
 demo/backend/server/app.py                 |  79 +++++++++++++-
 demo/backend/server/app_conf.py            |   5 +
 demo/backend/server/inference/predictor.py | 119 ++++++++++++++++++++-
 3 files changed, 200 insertions(+), 3 deletions(-)

diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py
index 424e85bb5..483a82035 100644
--- a/demo/backend/server/app.py
+++ b/demo/backend/server/app.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import time
 from typing import Any, Generator
 
 from app_conf import (
@@ -11,18 +12,29 @@
     GALLERY_PREFIX,
     POSTERS_PATH,
     POSTERS_PREFIX,
+    TEMP_PATH,
+    TEMP_PREFIX,
     UPLOADS_PATH,
     UPLOADS_PREFIX,
 )
 from data.loader import preload_data
 from data.schema import schema
 from data.store import set_videos
-from flask import Flask, make_response, Request, request, Response, send_from_directory
+from flask import (
+    Flask,
+    make_response,
+    Request,
+    jsonify,
+    request,
+    Response,
+    send_from_directory,
+)
 from flask_cors import CORS
 from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
 from inference.multipart import MultipartResponseBuilder
 from inference.predictor import InferenceAPI
 from strawberry.flask.views import GraphQLView
+import json
 
 logger = logging.getLogger(__name__)
 
@@ -71,6 +83,71 @@ def send_uploaded_video(path: str):
         )
     except:
         raise ValueError("resource not found")
+
+@app.route("/mask", methods=["POST"])
+def predict_image() -> Response:
+    data = request.json
+    start_time = time.time()
+    res = inference_api.predict_image(data["url"],data["points"],data["labels"],None,True)
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"mask generation time: {elapsed_time:.6f} s")
+    return Response(
+        res,
+        mimetype="image/jpeg",
+        headers={
+            "Content-Disposition": "attachment; filename=mask.jpg"
+        }
+    )
+
+@app.route("/masks", methods=["POST"])
+def generate_masks() -> Response:
+    data = request.json
+    start_time = time.time()
+    res = inference_api.generate_masks(data["url"])
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"masks generation time: {elapsed_time:.6f} s")
+    return Response(
+        res,
+        mimetype="image/png",
+        headers={
+            "Content-Disposition": "attachment; filename=mask.png"
+        }
+    )
+
+
+@app.route("/image_masks_save", methods=["POST"])
+def image_masks_save() -> Response:
+    data = request.get_json(silent=True) or {}
+    image_input = data.get("url") or data.get("path")
+    filename_prefix = data.get("name") or "mask"
+
+    if not image_input:
+        return jsonify({"error": "url or path is required"}), 400
+
+    timestamp = str(int(time.time() * 1000))
+    output_dir = TEMP_PATH / timestamp
+
+    try:
+        saved_files = inference_api.save_masks_to_dir(
+            image_input=image_input,
+            output_dir=output_dir,
+            filename_prefix=filename_prefix,
+        )
+    except Exception as exc:
+        logger.exception("failed to save image masks")
+        return jsonify({"error": f"failed to save masks: {exc}"}), 500
+
+    rel_dir = f"{TEMP_PREFIX}/{timestamp}"
+    masks_payload = [{"path": f"{rel_dir}/{name}"} for name in saved_files]
+    return jsonify(
+        {
+            "count": len(saved_files),
+            "saved_dir": rel_dir,
+            "masks": masks_payload,
+        }
+    )
 
 
 # TODO: Protect route with ToS permission check

diff --git a/demo/backend/server/app_conf.py b/demo/backend/server/app_conf.py
index 1b38bcc3c..417bcdca8 100644
--- a/demo/backend/server/app_conf.py
+++ b/demo/backend/server/app_conf.py
@@ -48,8 +48,13 @@
 # Path where all posters are stored
 POSTERS_PATH = DATA_PATH / POSTERS_PREFIX
 
+# Prefix and path for temporary assets (e.g. generated masks)
+TEMP_PREFIX = "temp"
+TEMP_PATH = DATA_PATH / TEMP_PREFIX
+
 # Make sure any of those paths exist
 os.makedirs(DATA_PATH, exist_ok=True)
 os.makedirs(GALLERY_PATH, exist_ok=True)
 os.makedirs(UPLOADS_PATH, exist_ok=True)
 os.makedirs(POSTERS_PATH, exist_ok=True)
+os.makedirs(TEMP_PATH, exist_ok=True)

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index ff0dab233..efc6c3635 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -7,6 +7,7 @@
 import logging
 import os
 import uuid
+import base64
 from pathlib import Path
 from threading import Lock
 from typing import Any, Dict, Generator, List
@@ -34,8 +35,13 @@
     StartSessionResponse,
 )
 from pycocotools.mask import decode as decode_masks, encode as encode_masks
-from sam2.build_sam import build_sam2_video_predictor
-
+from sam2.build_sam import build_sam2_video_predictor, build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+from PIL import Image
+import requests
+from io import BytesIO
+import cv2
 
 logger = logging.getLogger(__name__)
 
@@ -89,8 +95,117 @@ def __init__(self) -> None:
         self.predictor = build_sam2_video_predictor(
             model_cfg, checkpoint, device=device
         )
+        sam2_model = build_sam2(model_cfg, checkpoint, device=device)
+        self.img_predictor = SAM2ImagePredictor(sam2_model)
+        self.mask_generator = SAM2AutomaticMaskGenerator(
+            sam2_model,
+            points_per_side=32,
+            points_per_batch=64,
+            pred_iou_thresh=0.7,
+            stability_score_thresh=0.92,
+            stability_score_offset=0.7,
+            crop_n_layers=1,
+            box_nms_thresh=0.7,
+            crop_n_points_downscale_factor=2,
+            min_mask_region_area=100.0,
+            use_m2m=True
+        )
+        self.current_img = None
         self.inference_lock = Lock()
 
+    def _load_image(self, image_input: str) -> Image.Image:
+        """Load an image from a data URI, local path, or URL."""
+        if image_input.startswith("data:"):
+            header, encoded = image_input.split(",", 1)
+            return Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
+
+        local_path = Path(image_input)
+        if local_path.exists():
+            return Image.open(local_path).convert("RGB")
+
+        response = requests.get(image_input)
+        response.raise_for_status()
+        return Image.open(BytesIO(response.content)).convert("RGB")
+
+    def predict_image(self,url:str,input_points:List[List[int]],input_labels:List[int],input_box:List[List[int]],multimask_output:bool):
+        print(url[:50],input_points,input_labels,input_box,multimask_output)
+        with self.inference_lock:
+            if self.current_img != url:
+                img = self._load_image(url)
+                self.img_predictor.set_image(img)
+                self.current_img = url
+            masks, scores, logits = self.img_predictor.predict(
+                point_coords=np.array(input_points),
+                point_labels=np.array(input_labels),
+                box = np.array(input_box) if isinstance(input_box,List) else None,
+                multimask_output=multimask_output)
+            sorted_ind = np.argsort(scores)[::-1]
+            masks = masks[sorted_ind]
+            scores = scores[sorted_ind]
+            logits = logits[sorted_ind]
+            mask = Image.fromarray(masks[0].astype(np.uint8) * 255)
+            byte_io = BytesIO()
+            mask.save(byte_io, format="JPEG")
+            byte_io.seek(0)
+            return byte_io.getvalue()
+
+    def generate_masks(self, url:str):
+        print(url[:50])
+        with self.inference_lock:
+            img = self._load_image(url)
+            masks = self.mask_generator.generate(np.array(img))
+            print(len(masks))
+
+            maskImgArr = self.combine_masks(masks)
+            mask = Image.fromarray((maskImgArr * 255).astype(np.uint8), 'RGBA')
+            byte_io = BytesIO()
+            mask.save(byte_io, format="PNG")
+            byte_io.seek(0)
+            return byte_io.getvalue()
+
+    def save_masks_to_dir(
+        self, image_input: str, output_dir: Path, filename_prefix: str = "mask"
+    ) -> List[str]:
+        with self.autocast_context(), self.inference_lock:
+            img = self._load_image(image_input)
+            masks = self.mask_generator.generate(np.array(img))
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        saved_files: List[str] = []
+        for idx, ann in enumerate(masks):
+            segmentation = ann.get("segmentation")
+            if segmentation is None:
+                continue
+            mask_arr = np.array(segmentation).astype(np.uint8)
+            # squeeze potential extra dimensions (e.g., 1xHxW)
+            mask_arr = np.squeeze(mask_arr)
+            mask_img = Image.fromarray(mask_arr * 255)
+            file_name = f"{filename_prefix}_{idx:04d}.png"
+            mask_img.save(output_dir / file_name)
+            saved_files.append(file_name)
+
+        return saved_files
+
+    def combine_masks(self, masks, borders=True):
+        if len(masks) == 0:
+            # No masks found: return a 1x1 transparent canvas so that
+            # generate_masks() does not crash on a None result.
+            return np.zeros((1, 1, 4))
+        np.random.seed(3)
+        sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
+
+        img = np.ones((sorted_masks[0]['segmentation'].shape[0], sorted_masks[0]['segmentation'].shape[1], 4))
+        img[:, :, 3] = 0
+        for ann in sorted_masks:
+            m = ann['segmentation']
+            color_mask = np.concatenate([np.random.random(3), [0.5]])
+            img[m] = color_mask
+            if borders:
+                contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+                # Try to smooth contours
+                contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
+                cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)
+        return img
+
     def autocast_context(self):
         if self.device.type == "cuda":
             return torch.autocast("cuda", dtype=torch.bfloat16)

From 7bcc5b60ccad44f38ff029f64de83c64ed44beff Mon Sep 17 00:00:00 2001
From: weimingz996
Date: Mon, 8 Dec 2025 08:54:29 +0800
Subject: [PATCH 3/9] change sam2 setting for speed

---
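The generator trades recall for speed here: fewer prompt points, stricter
IoU/stability thresholds, and no mask-to-mask refinement. A rough way to
compare the two configurations (sketch only; SAM2AutomaticMaskGenerator and
generate() are the real API, the commented-out values and image are
illustrative):

    import time

    import numpy as np

    def time_generator(gen, img):
        # Returns (seconds, number of masks) for one generate() call.
        start = time.time()
        masks = gen.generate(np.array(img))
        return time.time() - start, len(masks)

    # gen_fast = SAM2AutomaticMaskGenerator(sam2_model, points_per_side=16, ...)
    # gen_full = SAM2AutomaticMaskGenerator(sam2_model, points_per_side=32, ...)
    # print(time_generator(gen_fast, img), time_generator(gen_full, img))
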
 demo/backend/server/app.py                 |  2 +-
 demo/backend/server/inference/predictor.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py
index 483a82035..5037ff362 100644
--- a/demo/backend/server/app.py
+++ b/demo/backend/server/app.py
@@ -214,4 +214,4 @@ def get_context(self, request: Request, response: Response) -> Any:
 
 
 if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=5000)
+    app.run(host="0.0.0.0", port=7263)

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index efc6c3635..38b514c5b 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -99,16 +99,16 @@ def __init__(self) -> None:
         self.img_predictor = SAM2ImagePredictor(sam2_model)
         self.mask_generator = SAM2AutomaticMaskGenerator(
             sam2_model,
-            points_per_side=32,
+            points_per_side=16,
             points_per_batch=64,
-            pred_iou_thresh=0.7,
-            stability_score_thresh=0.92,
+            pred_iou_thresh=0.8,
+            stability_score_thresh=0.95,
             stability_score_offset=0.7,
             crop_n_layers=1,
-            box_nms_thresh=0.7,
+            box_nms_thresh=0.4,
             crop_n_points_downscale_factor=2,
             min_mask_region_area=100.0,
-            use_m2m=True
+            use_m2m=False
         )
         self.current_img = None
         self.inference_lock = Lock()

From 797225b9972af3f998e9d565072ff0e5b4517fcf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=90=AF=E6=AD=A3=20=E7=8E=8B?=
Date: Tue, 9 Dec 2025 16:08:21 +0800
Subject: [PATCH 4/9] add sam api for ref model

---
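Decoding the response on the client side (a sketch; the endpoint path and the
"count", "masks", "size" and "png_base64" fields match the route and method
below, while the server address and image path are placeholders):

    import base64
    from io import BytesIO

    import requests
    from PIL import Image

    resp = requests.post(
        "http://localhost:7263/image_masks",
        json={"url": "/data/gallery/example.jpg"},
    ).json()

    # Each entry is one binary mask, re-inflated from its base64 PNG.
    masks = [
        Image.open(BytesIO(base64.b64decode(m["png_base64"])))
        for m in resp["masks"]
    ]
    print(resp["count"], [m.size for m in masks])
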
 demo/backend/server/app.py                 | 21 +++++++++++++++
 demo/backend/server/inference/predictor.py | 27 ++++++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py
index 5037ff362..b020bca6c 100644
--- a/demo/backend/server/app.py
+++ b/demo/backend/server/app.py
@@ -149,6 +149,27 @@ def image_masks_save() -> Response:
         }
     )
 
+@app.route("/image_masks", methods=["POST"])
+def image_masks() -> Response:
+    """
+    Take an image (url/path/data URI) and return the full-image masks
+    directly as a list of base64 PNGs, without writing them to disk.
+    """
+    data = request.get_json(silent=True) or {}
+    image_input = data.get("url") or data.get("path")
+    if not image_input:
+        return jsonify({"error": "url or path is required"}), 400
+
+    try:
+        masks_payload = inference_api.generate_masks_base64(image_input=image_input)
+    except Exception as exc:
+        logger.exception("failed to generate masks in memory")
+        return jsonify({"error": f"failed to generate masks: {exc}"}), 500
+
+    return jsonify({
+        "count": len(masks_payload),
+        "masks": masks_payload,
+    })
+
 
 # TODO: Protect route with ToS permission check
 @app.route("/propagate_in_video", methods=["POST"])

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index 38b514c5b..37927a748 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -163,6 +163,33 @@ def generate_masks(self, url:str):
             byte_io.seek(0)
             return byte_io.getvalue()
 
+    def generate_masks_base64(self, image_input: str):
+        """Generate full-image masks and return them as a list of base64 PNGs, without writing to disk."""
+        print(image_input[:50])
+        with self.inference_lock:
+            img = self._load_image(image_input)
+            masks = self.mask_generator.generate(np.array(img))
+
+        payload: List[Dict[str, Any]] = []
+        for ann in masks:
+            seg = ann.get("segmentation")
+            if seg is None:
+                continue
+            mask_arr = np.array(seg).astype(np.uint8)
+            mask_arr = np.squeeze(mask_arr)
+            if mask_arr.ndim != 2:
+                continue
+            h, w = mask_arr.shape
+            mask_img = Image.fromarray(mask_arr * 255)
+            buf = BytesIO()
+            mask_img.save(buf, format="PNG")
+            buf.seek(0)
+            payload.append({
+                "size": [int(h), int(w)],
+                "png_base64": base64.b64encode(buf.getvalue()).decode("ascii"),
+            })
+        return payload
+
     def save_masks_to_dir(
         self, image_input: str, output_dir: Path, filename_prefix: str = "mask"
     ) -> List[str]:

From 8c5d14d6171a2bc7439670b05f6ff27efaf5eb7b Mon Sep 17 00:00:00 2001
From: qzwang
Date: Thu, 18 Dec 2025 09:01:14 +0800
Subject: [PATCH 5/9] add cache for sam2

---
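Warming the cache before the first click (a sketch; the endpoint and the
"cache_path"/"status" response fields are from the route below, the server
address and image path are placeholders). The cache file lands next to the
image in a temp/ subdirectory and is ignored when the image is newer:

    import requests

    resp = requests.post(
        "http://localhost:7263/precompute_embedding",
        json={"path": "/data/gallery/example.jpg"},
    ).json()

    # First call:  {"status": "created", "cache_path": ".../temp/example_sam2_embed.pt"}
    # Later calls: {"status": "reused", ...} as long as the image is unchanged.
    print(resp["status"], resp["cache_path"])
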
 demo/backend/server/app.py                 |  26 ++++-
 demo/backend/server/inference/predictor.py | 119 ++++++++++++++++++++-
 2 files changed, 139 insertions(+), 6 deletions(-)

diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py
index b020bca6c..ee95a9ee0 100644
--- a/demo/backend/server/app.py
+++ b/demo/backend/server/app.py
@@ -83,12 +83,34 @@ def send_uploaded_video(path: str):
         )
     except:
         raise ValueError("resource not found")
-    
+
+@app.route("/precompute_embedding", methods=["POST"])
+def precompute_embedding() -> Response:
+    data = request.get_json(silent=True) or {}
+    image_input = data.get("url") or data.get("path")
+    if not image_input:
+        return jsonify({"error": "url or path is required"}), 400
+
+    try:
+        cache_path, reused = inference_api.precompute_image_embedding(image_input)
+    except Exception as exc:
+        logger.exception("failed to precompute embedding")
+        return jsonify({"error": f"failed to precompute embedding: {exc}"}), 500
+
+    return jsonify(
+        {
+            "cache_path": str(cache_path),
+            "status": "reused" if reused else "created",
+        }
+    )
+
+
 @app.route("/mask", methods=["POST"])
 def predict_image() -> Response:
     data = request.json
+
     start_time = time.time()
-    res = inference_api.predict_image(data["url"],data["points"],data["labels"],None,True)
+    res = inference_api.predict_image(data["url"], data["points"], data["labels"], None, True)
     end_time = time.time()
     elapsed_time = end_time - start_time
     print(f"mask generation time: {elapsed_time:.6f} s")

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index 37927a748..ea72a7c2c 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -10,7 +10,7 @@
 import base64
 from pathlib import Path
 from threading import Lock
-from typing import Any, Dict, Generator, List
+from typing import Any, Dict, Generator, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -113,6 +113,114 @@ def __init__(self) -> None:
         self.current_img = None
         self.inference_lock = Lock()
 
+    def _get_embedding_cache_path(self, image_input: str) -> Optional[Path]:
+        """Return cache file path for a local image, or None if unsupported."""
+        image_path = Path(image_input)
+        if not image_path.exists():
+            return None
+        return image_path.parent / "temp" / f"{image_path.stem}_sam2_embed.pt"
+
+    def _load_cached_embedding(self, image_input: str) -> bool:
+        """Load cached embedding into the predictor if it exists and is fresh."""
+        cache_path = self._get_embedding_cache_path(image_input)
+        if cache_path is None or not cache_path.is_file():
+            return False
+
+        image_path = Path(image_input)
+        try:
+            if cache_path.stat().st_mtime < image_path.stat().st_mtime:
+                return False
+        except OSError:
+            return False
+
+        try:
+            cache = torch.load(cache_path, map_location=self.device)
+        except Exception:
+            logger.exception("failed to read cached embedding")
+            return False
+
+        if cache.get("model_size") not in (None, MODEL_SIZE):
+            logger.warning(
+                "cached embedding model size %s mismatches current %s",
+                cache.get("model_size"),
+                MODEL_SIZE,
+            )
+            return False
+
+        features = cache.get("features")
+        orig_hw = cache.get("orig_hw")
+        if not features or orig_hw is None:
+            return False
+
+        try:
+            image_embed = features["image_embed"].to(self.device)
+            high_res_feats = [
+                feat.to(self.device) for feat in features["high_res_feats"]
+            ]
+        except Exception:
+            logger.exception("failed to move cached embedding to device")
+            return False
+
+        self.img_predictor.reset_predictor()
+        self.img_predictor._features = {
+            "image_embed": image_embed,
+            "high_res_feats": high_res_feats,
+        }
+        self.img_predictor._orig_hw = orig_hw
+        self.img_predictor._is_image_set = True
+        self.img_predictor._is_batch = False
+        self.current_img = image_input
+        return True
+
+    def _save_embedding_cache(self, image_input: str) -> Optional[Path]:
+        """Persist current image embedding to disk under the image's temp/ dir."""
+        cache_path = self._get_embedding_cache_path(image_input)
+        if cache_path is None or self.img_predictor._features is None:
+            return None
+
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+        features = self.img_predictor._features
+        try:
+            payload = {
+                "model_size": MODEL_SIZE,
+                "orig_hw": self.img_predictor._orig_hw,
+                "features": {
+                    "image_embed": features["image_embed"].detach().cpu(),
+                    "high_res_feats": [
+                        feat.detach().cpu() for feat in features["high_res_feats"]
+                    ],
+                },
+            }
+            torch.save(payload, cache_path)
+        except Exception:
+            logger.exception("failed to save embedding cache")
+            return None
+        return cache_path
+
+    def precompute_image_embedding(self, image_input: str) -> Tuple[Optional[Path], bool]:
+        """
+        Compute and cache image embedding ahead of time.
+
+        Returns a tuple of (cache_path, reused_cache_flag).
+        """
+        cache_path = self._get_embedding_cache_path(image_input)
+        if cache_path is None:
+            raise FileNotFoundError(f"image not found: {image_input}")
+
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+        with self.inference_lock:
+            if self._load_cached_embedding(image_input):
+                return cache_path, True
+
+            img = self._load_image(image_input)
+            self.img_predictor.set_image(img)
+            self.current_img = image_input
+            saved_path = self._save_embedding_cache(image_input)
+            if saved_path is None:
+                raise RuntimeError("failed to persist embedding cache to disk")
+
+        return saved_path or cache_path, False
+
     def _load_image(self, image_input: str) -> Image.Image:
         """Load an image from a data URI, local path, or URL."""
         if image_input.startswith("data:"):
@@ -131,9 +239,12 @@ def predict_image(self,url:str,input_points:List[List[in
         print(url[:50],input_points,input_labels,input_box,multimask_output)
         with self.inference_lock:
             if self.current_img != url:
-                img = self._load_image(url)
-                self.img_predictor.set_image(img)
-                self.current_img = url
+                loaded_from_cache = self._load_cached_embedding(url)
+                if not loaded_from_cache:
+                    img = self._load_image(url)
+                    self.img_predictor.set_image(img)
+                    self.current_img = url
+                    self._save_embedding_cache(url)
             masks, scores, logits = self.img_predictor.predict(
                 point_coords=np.array(input_points),
                 point_labels=np.array(input_labels),

From 012c3146fb098230faa36ef943e2a5d1f029edd3 Mon Sep 17 00:00:00 2001
From: qzwang
Date: Thu, 18 Dec 2025 09:31:51 +0800
Subject: [PATCH 6/9] add cache for sam2

---
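Dropping a stale cache entry from a client (a sketch; the endpoint and the
"removed"/"cache_path" response fields are from the route below, the server
address and image path are placeholders):

    import requests

    resp = requests.post(
        "http://localhost:7263/remove_embedding",
        json={"path": "/data/gallery/example.jpg"},
    ).json()

    # "removed" is False when no cache file existed for the image.
    print(resp["removed"], resp["cache_path"])
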
 demo/backend/server/app.py                 | 15 +++++++++++++++
 demo/backend/server/inference/predictor.py | 19 +++++++++++--------
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py
index ee95a9ee0..4ca059ce9 100644
--- a/demo/backend/server/app.py
+++ b/demo/backend/server/app.py
@@ -104,6 +104,21 @@ def precompute_embedding() -> Response:
         }
     )
 
+@app.route("/remove_embedding", methods=["POST"])
+def remove_embedding() -> Response:
+    data = request.get_json(silent=True) or {}
+    image_input = data.get("url") or data.get("path")
+    if not image_input:
+        return jsonify({"error": "url or path is required"}), 400
+
+    removed = inference_api.remove_embedding_cache(image_input)
+    return jsonify(
+        {
+            "removed": bool(removed),
+            "cache_path": str(inference_api._get_embedding_cache_path(image_input) or ""),
+        }
+    )
+
 
 @app.route("/mask", methods=["POST"])
 def predict_image() -> Response:

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index ea72a7c2c..854ab7d0c 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -114,14 +114,12 @@ def __init__(self) -> None:
         self.inference_lock = Lock()
 
     def _get_embedding_cache_path(self, image_input: str) -> Optional[Path]:
-        """Return cache file path for a local image, or None if unsupported."""
         image_path = Path(image_input)
         if not image_path.exists():
             return None
         return image_path.parent / "temp" / f"{image_path.stem}_sam2_embed.pt"
 
     def _load_cached_embedding(self, image_input: str) -> bool:
-        """Load cached embedding into the predictor if it exists and is fresh."""
         cache_path = self._get_embedding_cache_path(image_input)
         if cache_path is None or not cache_path.is_file():
             return False
@@ -173,7 +171,6 @@ def _load_cached_embedding(self, image_input: str) -> bool:
         return True
 
     def _save_embedding_cache(self, image_input: str) -> Optional[Path]:
-        """Persist current image embedding to disk under the image's temp/ dir."""
         cache_path = self._get_embedding_cache_path(image_input)
         if cache_path is None or self.img_predictor._features is None:
             return None
@@ -198,11 +195,6 @@ def _save_embedding_cache(self, image_input: str) -> Optional[Path]:
         return cache_path
 
     def precompute_image_embedding(self, image_input: str) -> Tuple[Optional[Path], bool]:
-        """
-        Compute and cache image embedding ahead of time.
-
-        Returns a tuple of (cache_path, reused_cache_flag).
-        """
         cache_path = self._get_embedding_cache_path(image_input)
         if cache_path is None:
             raise FileNotFoundError(f"image not found: {image_input}")
@@ -221,6 +213,17 @@ def precompute_image_embedding(self, image_input: str) -> Tuple[Optional[Path],
 
         return saved_path or cache_path, False
 
+    def remove_embedding_cache(self, image_input: str) -> bool:
+        cache_path = self._get_embedding_cache_path(image_input)
+        if cache_path is None or not cache_path.exists():
+            return False
+        try:
+            cache_path.unlink()
+            return True
+        except Exception:
+            logger.exception("failed to delete cache %s", cache_path)
+            return False
+
     def _load_image(self, image_input: str) -> Image.Image:
         """Load an image from a data URI, local path, or URL."""
         if image_input.startswith("data:"):

From 66023060b1291472dbfffa64be424f126832196a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=90=AF=E6=AD=A3=20=E7=8E=8B?=
Date: Wed, 31 Dec 2025 15:45:30 +0800
Subject: [PATCH 7/9] add base64 for mask api

---
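The /mask route now accepts the image bytes themselves instead of a path (a
sketch; the "base64", "points" and "labels" keys are the ones the route
reads, the server address and file name are placeholders):

    import base64

    import requests

    with open("example.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("ascii")

    resp = requests.post("http://localhost:7263/mask", json={
        "base64": img_b64,          # also accepted: "image_base64" or "url"
        "points": [[450, 300]],
        "labels": [1],
    })
    open("mask.jpg", "wb").write(resp.content)
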
 demo/backend/server/app.py                 | 13 +++++++--
 demo/backend/server/inference/predictor.py | 31 +++++++++++++---------
 2 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py
index 4ca059ce9..ddf83a04d 100644
--- a/demo/backend/server/app.py
+++ b/demo/backend/server/app.py
@@ -122,10 +122,19 @@ def remove_embedding() -> Response:
 
 @app.route("/mask", methods=["POST"])
 def predict_image() -> Response:
-    data = request.json
+    data = request.get_json(silent=True) or {}
+    image_base64 = data.get("base64") or data.get("image_base64") or data.get("url")
+    if not image_base64:
+        return jsonify({"error": "base64 is required"}), 400
 
     start_time = time.time()
-    res = inference_api.predict_image(data["url"], data["points"], data["labels"], None, True)
+    res = inference_api.predict_image(
+        image_base64,
+        data["points"],
+        data["labels"],
+        None,
+        True,
+    )
     end_time = time.time()
     elapsed_time = end_time - start_time
     print(f"mask generation time: {elapsed_time:.6f} s")

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index 854ab7d0c..9091475f4 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -225,7 +225,7 @@ def remove_embedding_cache(self, image_input: str) -> bool:
         return False
 
     def _load_image(self, image_input: str) -> Image.Image:
-        """Load an image from a data URI, local path, or URL."""
+        """Load an image from a base64 string, data URI, local path, or URL."""
         if image_input.startswith("data:"):
             header, encoded = image_input.split(",", 1)
             return Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
@@ -234,24 +234,31 @@ def _load_image(self, image_input: str) -> Image.Image:
         if local_path.exists():
             return Image.open(local_path).convert("RGB")
 
-        response = requests.get(image_input)
-        response.raise_for_status()
-        return Image.open(BytesIO(response.content)).convert("RGB")
+        if "://" in image_input:
+            response = requests.get(image_input)
+            response.raise_for_status()
+            return Image.open(BytesIO(response.content)).convert("RGB")
 
-    def predict_image(self,url:str,input_points:List[List[int]],input_labels:List[int],input_box:List[List[int]],multimask_output:bool):
-        print(url[:50],input_points,input_labels,input_box,multimask_output)
+        try:
+            decoded = base64.b64decode(image_input)
+            return Image.open(BytesIO(decoded)).convert("RGB")
+        except Exception as exc:
+            raise ValueError("unsupported image input") from exc
+
+    def predict_image(self,image_input:str,input_points:List[List[int]],input_labels:List[int],input_box:List[List[int]],multimask_output:bool):
+        print(image_input[:50],input_points,input_labels,input_box,multimask_output)
         with self.inference_lock:
-            if self.current_img != url:
-                loaded_from_cache = self._load_cached_embedding(url)
+            if self.current_img != image_input:
+                loaded_from_cache = self._load_cached_embedding(image_input)
                 if not loaded_from_cache:
-                    img = self._load_image(url)
+                    img = self._load_image(image_input)
                     self.img_predictor.set_image(img)
-                    self.current_img = url
-                    self._save_embedding_cache(url)
+                    self.current_img = image_input
+                    self._save_embedding_cache(image_input)
             masks, scores, logits = self.img_predictor.predict(
                 point_coords=np.array(input_points),
                 point_labels=np.array(input_labels),
-                box = np.array(input_box) if isinstance(input_box,List) else None,
+                box = np.array(input_box) if isinstance(input_box, list) and input_box else None,
                 multimask_output=multimask_output)

From 3ae8c75e96e91baf2331a30ff4b472f49dbef725 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=90=AF=E6=AD=A3=20=E7=8E=8B?=
Date: Tue, 6 Jan 2026 15:05:15 +0800
Subject: [PATCH 8/9] add base64 support

---
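For inputs that are not files on disk, the cache key becomes a SHA-256 of
the decoded image bytes, so the same base64 payload always maps to the same
cache file under DATA_PATH/temp. A sketch of the key derivation mirroring
_get_image_cache_key below (the sample string is illustrative):

    import base64
    import hashlib

    encoded = "aGVsbG8"                      # base64 without padding
    encoded += "=" * ((-len(encoded)) % 4)   # re-pad, as the helper does
    raw = base64.b64decode(encoded, validate=True)
    print(f"sam2_embed_{hashlib.sha256(raw).hexdigest()}.pt")
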
 demo/backend/server/inference/predictor.py | 44 +++++++++++++++++++---
 1 file changed, 38 insertions(+), 6 deletions(-)

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index 9091475f4..364e62bc5 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -8,13 +8,14 @@
 import os
 import uuid
 import base64
+import hashlib
 from pathlib import Path
 from threading import Lock
 from typing import Any, Dict, Generator, List, Optional, Tuple
 
 import numpy as np
 import torch
-from app_conf import APP_ROOT, MODEL_SIZE
+from app_conf import APP_ROOT, MODEL_SIZE, TEMP_PATH
 from inference.data_types import (
     AddMaskRequest,
     AddPointsRequest,
@@ -116,20 +117,51 @@ def __init__(self) -> None:
     def _get_embedding_cache_path(self, image_input: str) -> Optional[Path]:
         image_path = Path(image_input)
         if not image_path.exists():
-            return None
+            cache_key = self._get_image_cache_key(image_input)
+            if cache_key is None:
+                return None
+            return TEMP_PATH / f"sam2_embed_{cache_key}.pt"
         return image_path.parent / "temp" / f"{image_path.stem}_sam2_embed.pt"
 
+    def _get_image_cache_key(self, image_input: str) -> Optional[str]:
+        raw_bytes = self._decode_image_bytes_for_cache(image_input)
+        if raw_bytes is None:
+            return None
+        return hashlib.sha256(raw_bytes).hexdigest()
+
+    def _decode_image_bytes_for_cache(self, image_input: str) -> Optional[bytes]:
+        if image_input.startswith("data:"):
+            try:
+                _, encoded = image_input.split(",", 1)
+            except ValueError:
+                return None
+        else:
+            encoded = image_input
+
+        encoded = "".join(encoded.split())
+        if not encoded:
+            return None
+
+        padding = (-len(encoded)) % 4
+        if padding:
+            encoded += "=" * padding
+        try:
+            return base64.b64decode(encoded, validate=True)
+        except Exception:
+            return None
+
     def _load_cached_embedding(self, image_input: str) -> bool:
         cache_path = self._get_embedding_cache_path(image_input)
         if cache_path is None or not cache_path.is_file():
             return False
 
         image_path = Path(image_input)
-        try:
-            if cache_path.stat().st_mtime < image_path.stat().st_mtime:
-                return False
-        except OSError:
-            return False
+        if image_path.exists():
+            try:
+                if cache_path.stat().st_mtime < image_path.stat().st_mtime:
+                    return False
+            except OSError:
+                return False
 
         try:
             cache = torch.load(cache_path, map_location=self.device)

From b2f39569b9a60e17ece6f35f0a2f39adca7f6b2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=90=AF=E6=AD=A3=20=E7=8E=8B?=
Date: Tue, 6 Jan 2026 15:56:21 +0800
Subject: [PATCH 9/9] update for base64 input

---
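Callers can now pin the cache location per request (a sketch; "path" is the
request key the route maps to cache_dir, the server address, directory and
image are placeholders):

    import base64

    import requests

    with open("example.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("ascii")

    resp = requests.post("http://localhost:7263/mask", json={
        "base64": img_b64,
        "points": [[450, 300]],
        "labels": [1],
        # The embedding cache is written to / read from this directory;
        # a value ending in .pt is used as the exact cache file.
        "path": "/data/temp/session-123",
    })
    open("mask.jpg", "wb").write(resp.content)
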
 demo/backend/server/app.py                 |  2 +
 demo/backend/server/inference/predictor.py | 79 ++++++++++++++++------
 2 files changed, 62 insertions(+), 19 deletions(-)

diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py
index ddf83a04d..515e334ca 100644
--- a/demo/backend/server/app.py
+++ b/demo/backend/server/app.py
@@ -126,6 +126,7 @@ def predict_image() -> Response:
     image_base64 = data.get("base64") or data.get("image_base64") or data.get("url")
     if not image_base64:
         return jsonify({"error": "base64 is required"}), 400
+    cache_dir = data.get("path")
 
     start_time = time.time()
     res = inference_api.predict_image(
@@ -134,6 +135,7 @@ def predict_image() -> Response:
         data["labels"],
         None,
         True,
+        cache_dir=cache_dir,
     )
     end_time = time.time()
     elapsed_time = end_time - start_time

diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py
index 364e62bc5..3c3bddc3d 100644
--- a/demo/backend/server/inference/predictor.py
+++ b/demo/backend/server/inference/predictor.py
@@ -114,9 +114,31 @@ def __init__(self) -> None:
         self.current_img = None
         self.inference_lock = Lock()
 
-    def _get_embedding_cache_path(self, image_input: str) -> Optional[Path]:
+    def _get_embedding_cache_path(
+        self, image_input: str, cache_dir: Optional[str] = None
+    ) -> Optional[Path]:
+        if cache_dir:
+            cache_base = Path(cache_dir)
+            if cache_base.suffix == ".pt":
+                return cache_base
+            cache_key = self._get_image_cache_key(image_input)
+            if cache_key is None:
+                return None
+            return cache_base / f"sam2_embed_{cache_key}.pt"
+
+        if image_input.startswith("data:") or "://" in image_input:
+            cache_key = self._get_image_cache_key(image_input)
+            if cache_key is None:
+                return None
+            return TEMP_PATH / f"sam2_embed_{cache_key}.pt"
+
         image_path = Path(image_input)
-        if not image_path.exists():
+        try:
+            exists = image_path.exists()
+        except OSError:
+            exists = False
+
+        if not exists:
             cache_key = self._get_image_cache_key(image_input)
             if cache_key is None:
                 return None
@@ -150,18 +172,23 @@ def _decode_image_bytes_for_cache(self, image_input: str) -> Optional[bytes]:
         except Exception:
             return None
 
-    def _load_cached_embedding(self, image_input: str) -> bool:
-        cache_path = self._get_embedding_cache_path(image_input)
+    def _load_cached_embedding(
+        self, image_input: str, cache_dir: Optional[str] = None
+    ) -> bool:
+        cache_path = self._get_embedding_cache_path(image_input, cache_dir)
         if cache_path is None or not cache_path.is_file():
             return False
 
-        image_path = Path(image_input)
-        if image_path.exists():
-            try:
-                if cache_path.stat().st_mtime < image_path.stat().st_mtime:
-                    return False
-            except OSError:
-                return False
+        if not image_input.startswith("data:") and "://" not in image_input:
+            image_path = Path(image_input)
+            try:
+                if (
+                    image_path.exists()
+                    and cache_path.stat().st_mtime < image_path.stat().st_mtime
+                ):
+                    return False
+            except OSError:
+                pass
 
         try:
             cache = torch.load(cache_path, map_location=self.device)
@@ -202,8 +229,10 @@ def _load_cached_embedding(
         self.current_img = image_input
         return True
 
-    def _save_embedding_cache(self, image_input: str) -> Optional[Path]:
-        cache_path = self._get_embedding_cache_path(image_input)
+    def _save_embedding_cache(
+        self, image_input: str, cache_dir: Optional[str] = None
+    ) -> Optional[Path]:
+        cache_path = self._get_embedding_cache_path(image_input, cache_dir)
         if cache_path is None or self.img_predictor._features is None:
             return None
 
@@ -226,27 +255,31 @@ def _save_embedding_cache(
             return None
         return cache_path
 
-    def precompute_image_embedding(self, image_input: str) -> Tuple[Optional[Path], bool]:
-        cache_path = self._get_embedding_cache_path(image_input)
+    def precompute_image_embedding(
+        self, image_input: str, cache_dir: Optional[str] = None
+    ) -> Tuple[Optional[Path], bool]:
+        cache_path = self._get_embedding_cache_path(image_input, cache_dir)
         if cache_path is None:
             raise FileNotFoundError(f"image not found: {image_input}")
 
         cache_path.parent.mkdir(parents=True, exist_ok=True)
         with self.inference_lock:
-            if self._load_cached_embedding(image_input):
+            if self._load_cached_embedding(image_input, cache_dir):
                 return cache_path, True
 
             img = self._load_image(image_input)
             self.img_predictor.set_image(img)
             self.current_img = image_input
-            saved_path = self._save_embedding_cache(image_input)
+            saved_path = self._save_embedding_cache(image_input, cache_dir)
             if saved_path is None:
                 raise RuntimeError("failed to persist embedding cache to disk")
 
         return saved_path or cache_path, False
 
-    def remove_embedding_cache(self, image_input: str) -> bool:
-        cache_path = self._get_embedding_cache_path(image_input)
+    def remove_embedding_cache(
+        self, image_input: str, cache_dir: Optional[str] = None
+    ) -> bool:
+        cache_path = self._get_embedding_cache_path(image_input, cache_dir)
         if cache_path is None or not cache_path.exists():
             return False
         try:
@@ -277,16 +310,24 @@ def _load_image(self, image_input: str) -> Image.Image:
         except Exception as exc:
             raise ValueError("unsupported image input") from exc
 
-    def predict_image(self,image_input:str,input_points:List[List[int]],input_labels:List[int],input_box:List[List[int]],multimask_output:bool):
+    def predict_image(
+        self,
+        image_input: str,
+        input_points: List[List[int]],
+        input_labels: List[int],
+        input_box: List[List[int]],
+        multimask_output: bool,
+        cache_dir: Optional[str] = None,
+    ):
         print(image_input[:50],input_points,input_labels,input_box,multimask_output)
         with self.inference_lock:
             if self.current_img != image_input:
-                loaded_from_cache = self._load_cached_embedding(image_input)
+                loaded_from_cache = self._load_cached_embedding(image_input, cache_dir)
                 if not loaded_from_cache:
                     img = self._load_image(image_input)
                     self.img_predictor.set_image(img)
                     self.current_img = image_input
-                    self._save_embedding_cache(image_input)
+                    self._save_embedding_cache(image_input, cache_dir)
             masks, scores, logits = self.img_predictor.predict(
                 point_coords=np.array(input_points),
                 point_labels=np.array(input_labels),