diff --git a/demo/backend/server/app.py b/demo/backend/server/app.py index 424e85bb5..515e334ca 100644 --- a/demo/backend/server/app.py +++ b/demo/backend/server/app.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging +import time from typing import Any, Generator from app_conf import ( @@ -11,18 +12,29 @@ GALLERY_PREFIX, POSTERS_PATH, POSTERS_PREFIX, + TEMP_PATH, + TEMP_PREFIX, UPLOADS_PATH, UPLOADS_PREFIX, ) from data.loader import preload_data from data.schema import schema from data.store import set_videos -from flask import Flask, make_response, Request, request, Response, send_from_directory +from flask import ( + Flask, + make_response, + Request, + jsonify, + request, + Response, + send_from_directory, +) from flask_cors import CORS from inference.data_types import PropagateDataResponse, PropagateInVideoRequest from inference.multipart import MultipartResponseBuilder from inference.predictor import InferenceAPI from strawberry.flask.views import GraphQLView +import json logger = logging.getLogger(__name__) @@ -72,6 +84,140 @@ def send_uploaded_video(path: str): except: raise ValueError("resource not found") +@app.route("/precompute_embedding", methods=["POST"]) +def precompute_embedding() -> Response: + data = request.get_json(silent=True) or {} + image_input = data.get("url") or data.get("path") + if not image_input: + return jsonify({"error": "url or path is required"}), 400 + + try: + cache_path, reused = inference_api.precompute_image_embedding(image_input) + except Exception as exc: + logger.exception("failed to precompute embedding") + return jsonify({"error": f"failed to precompute embedding: {exc}"}), 500 + + return jsonify( + { + "cache_path": str(cache_path), + "status": "reused" if reused else "created", + } + ) + +@app.route("/remove_embedding", methods=["POST"]) +def remove_embedding() -> Response: + data = request.get_json(silent=True) or {} + image_input = data.get("url") or data.get("path") + if not image_input: + return jsonify({"error": "url or path is required"}), 400 + + removed = inference_api.remove_embedding_cache(image_input) + return jsonify( + { + "removed": bool(removed), + "cache_path": str(inference_api._get_embedding_cache_path(image_input) or ""), + } + ) + + +@app.route(f"/mask", methods=["POST"]) +def predict_image() -> Response: + data = request.get_json(silent=True) or {} + image_base64 = data.get("base64") or data.get("image_base64") or data.get("url") + if not image_base64: + return jsonify({"error": "base64 is required"}), 400 + cache_dir = data.get("path") + + start_time = time.time() + res = inference_api.predict_image( + image_base64, + data["points"], + data["labels"], + None, + True, + cache_dir=cache_dir, + ) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"mask生成时间: {elapsed_time:.6f} 秒") + return Response( + res, + mimetype="image/jpeg", + headers={ + "Content-Disposition": "attachment; filename=mask.jpg" + } + ) + +@app.route(f"/masks", methods=["POST"]) +def generate_masks() -> Response: + data = request.json + start_time = time.time() + res = inference_api.generate_masks(data["url"]) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"masks生成时间: {elapsed_time:.6f} 秒") + return Response( + res, + mimetype="image/png", + headers={ + "Content-Disposition": "attachment; filename=mask.png" + } + ) + + +@app.route("/image_masks_save", methods=["POST"]) +def image_masks_save() -> Response: + data = request.get_json(silent=True) or {} + image_input = data.get("url") or data.get("path") + filename_prefix = data.get("name") or "mask" + + if not image_input: + return jsonify({"error": "url or path is required"}), 400 + + timestamp = str(int(time.time() * 1000)) + output_dir = TEMP_PATH / timestamp + + try: + saved_files = inference_api.save_masks_to_dir( + image_input=image_input, + output_dir=output_dir, + filename_prefix=filename_prefix, + ) + except Exception as exc: + logger.exception("failed to save image masks") + return jsonify({"error": f"failed to save masks: {exc}"}), 500 + + rel_dir = f"{TEMP_PREFIX}/{timestamp}" + masks_payload = [{"path": f"{rel_dir}/{name}"} for name in saved_files] + return jsonify( + { + "count": len(saved_files), + "saved_dir": rel_dir, + "masks": masks_payload, + } + ) + +@app.route("/image_masks", methods=["POST"]) +def image_masks() -> Response: + """ + 输入图片(url/path/data uri),直接返回全局 mask(base64 PNG 列表),不落盘。 + """ + data = request.get_json(silent=True) or {} + image_input = data.get("url") or data.get("path") + if not image_input: + return jsonify({"error": "url or path is required"}), 400 + + try: + masks_payload = inference_api.generate_masks_base64(image_input=image_input) + except Exception as exc: + logger.exception("failed to generate masks in memory") + return jsonify({"error": f"failed to generate masks: {exc}"}), 500 + + return jsonify({ + "count": len(masks_payload), + "masks": masks_payload, + }) + # TOOD: Protect route with ToS permission check @app.route("/propagate_in_video", methods=["POST"]) @@ -137,4 +283,4 @@ def get_context(self, request: Request, response: Response) -> Any: if __name__ == "__main__": - app.run(host="0.0.0.0", port=5000) + app.run(host="0.0.0.0", port=7263) diff --git a/demo/backend/server/app_conf.py b/demo/backend/server/app_conf.py index eea777289..417bcdca8 100644 --- a/demo/backend/server/app_conf.py +++ b/demo/backend/server/app_conf.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) -APP_ROOT = os.getenv("APP_ROOT", "/opt/sam2") +APP_ROOT = os.getenv("APP_ROOT", "./sam2") API_URL = os.getenv("API_URL", "http://localhost:7263") @@ -20,7 +20,7 @@ FFMPEG_NUM_THREADS = int(os.getenv("FFMPEG_NUM_THREADS", "1")) # Path for all data used in API -DATA_PATH = Path(os.getenv("DATA_PATH", "/data")) +DATA_PATH = Path(os.getenv("DATA_PATH", "./data")) # Max duration an uploaded video can have in seconds. The default is 10 # seconds. @@ -48,8 +48,13 @@ # Path where all posters are stored POSTERS_PATH = DATA_PATH / POSTERS_PREFIX +# Prefix and path for temporary assets (e.g. generated masks) +TEMP_PREFIX = "temp" +TEMP_PATH = DATA_PATH / TEMP_PREFIX + # Make sure any of those paths exist os.makedirs(DATA_PATH, exist_ok=True) os.makedirs(GALLERY_PATH, exist_ok=True) os.makedirs(UPLOADS_PATH, exist_ok=True) os.makedirs(POSTERS_PATH, exist_ok=True) +os.makedirs(TEMP_PATH, exist_ok=True) diff --git a/demo/backend/server/inference/predictor.py b/demo/backend/server/inference/predictor.py index ff0dab233..3c3bddc3d 100644 --- a/demo/backend/server/inference/predictor.py +++ b/demo/backend/server/inference/predictor.py @@ -7,13 +7,15 @@ import logging import os import uuid +import base64 +import hashlib from pathlib import Path from threading import Lock -from typing import Any, Dict, Generator, List +from typing import Any, Dict, Generator, List, Optional, Tuple import numpy as np import torch -from app_conf import APP_ROOT, MODEL_SIZE +from app_conf import APP_ROOT, MODEL_SIZE, TEMP_PATH from inference.data_types import ( AddMaskRequest, AddPointsRequest, @@ -34,8 +36,13 @@ StartSessionResponse, ) from pycocotools.mask import decode as decode_masks, encode as encode_masks -from sam2.build_sam import build_sam2_video_predictor - +from sam2.build_sam import build_sam2_video_predictor, build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from PIL import Image +import requests +from io import BytesIO +import cv2 logger = logging.getLogger(__name__) @@ -89,8 +96,337 @@ def __init__(self) -> None: self.predictor = build_sam2_video_predictor( model_cfg, checkpoint, device=device ) + sam2_model = build_sam2(model_cfg, checkpoint, device=device) + self.img_predictor = SAM2ImagePredictor(sam2_model) + self.mask_generator = SAM2AutomaticMaskGenerator( + sam2_model, + points_per_side=16, + points_per_batch=64, + pred_iou_thresh=0.8, + stability_score_thresh=0.95, + stability_score_offset=0.7, + crop_n_layers=1, + box_nms_thresh=0.4, + crop_n_points_downscale_factor=2, + min_mask_region_area=100.0, + use_m2m=False + ) + self.current_img = None self.inference_lock = Lock() + def _get_embedding_cache_path( + self, image_input: str, cache_dir: Optional[str] = None + ) -> Optional[Path]: + if cache_dir: + cache_base = Path(cache_dir) + if cache_base.suffix == ".pt": + return cache_base + cache_key = self._get_image_cache_key(image_input) + if cache_key is None: + return None + return cache_base / f"sam2_embed_{cache_key}.pt" + + if image_input.startswith("data:") or "://" in image_input: + cache_key = self._get_image_cache_key(image_input) + if cache_key is None: + return None + return TEMP_PATH / f"sam2_embed_{cache_key}.pt" + + image_path = Path(image_input) + try: + exists = image_path.exists() + except OSError: + exists = False + + if not exists: + cache_key = self._get_image_cache_key(image_input) + if cache_key is None: + return None + return TEMP_PATH / f"sam2_embed_{cache_key}.pt" + return image_path.parent / "temp" / f"{image_path.stem}_sam2_embed.pt" + + def _get_image_cache_key(self, image_input: str) -> Optional[str]: + raw_bytes = self._decode_image_bytes_for_cache(image_input) + if raw_bytes is None: + return None + return hashlib.sha256(raw_bytes).hexdigest() + + def _decode_image_bytes_for_cache(self, image_input: str) -> Optional[bytes]: + if image_input.startswith("data:"): + try: + _, encoded = image_input.split(",", 1) + except ValueError: + return None + else: + encoded = image_input + + encoded = "".join(encoded.split()) + if not encoded: + return None + + padding = (-len(encoded)) % 4 + if padding: + encoded += "=" * padding + try: + return base64.b64decode(encoded, validate=True) + except Exception: + return None + + def _load_cached_embedding( + self, image_input: str, cache_dir: Optional[str] = None + ) -> bool: + cache_path = self._get_embedding_cache_path(image_input, cache_dir) + if cache_path is None or not cache_path.is_file(): + return False + + if not image_input.startswith("data:") and "://" not in image_input: + image_path = Path(image_input) + try: + if ( + image_path.exists() + and cache_path.stat().st_mtime < image_path.stat().st_mtime + ): + return False + except OSError: + pass + + try: + cache = torch.load(cache_path, map_location=self.device) + except Exception: + logger.exception("failed to read cached embedding") + return False + + if cache.get("model_size") not in (None, MODEL_SIZE): + logger.warning( + "cached embedding model size %s mismatches current %s", + cache.get("model_size"), + MODEL_SIZE, + ) + return False + + features = cache.get("features") + orig_hw = cache.get("orig_hw") + if not features or orig_hw is None: + return False + + try: + image_embed = features["image_embed"].to(self.device) + high_res_feats = [ + feat.to(self.device) for feat in features["high_res_feats"] + ] + except Exception: + logger.exception("failed to move cached embedding to device") + return False + + self.img_predictor.reset_predictor() + self.img_predictor._features = { + "image_embed": image_embed, + "high_res_feats": high_res_feats, + } + self.img_predictor._orig_hw = orig_hw + self.img_predictor._is_image_set = True + self.img_predictor._is_batch = False + self.current_img = image_input + return True + + def _save_embedding_cache( + self, image_input: str, cache_dir: Optional[str] = None + ) -> Optional[Path]: + cache_path = self._get_embedding_cache_path(image_input, cache_dir) + if cache_path is None or self.img_predictor._features is None: + return None + + cache_path.parent.mkdir(parents=True, exist_ok=True) + features = self.img_predictor._features + try: + payload = { + "model_size": MODEL_SIZE, + "orig_hw": self.img_predictor._orig_hw, + "features": { + "image_embed": features["image_embed"].detach().cpu(), + "high_res_feats": [ + feat.detach().cpu() for feat in features["high_res_feats"] + ], + }, + } + torch.save(payload, cache_path) + except Exception: + logger.exception("failed to save embedding cache") + return None + return cache_path + + def precompute_image_embedding( + self, image_input: str, cache_dir: Optional[str] = None + ) -> Tuple[Optional[Path], bool]: + cache_path = self._get_embedding_cache_path(image_input, cache_dir) + if cache_path is None: + raise FileNotFoundError(f"image not found: {image_input}") + + cache_path.parent.mkdir(parents=True, exist_ok=True) + with self.inference_lock: + if self._load_cached_embedding(image_input, cache_dir): + return cache_path, True + + img = self._load_image(image_input) + self.img_predictor.set_image(img) + self.current_img = image_input + saved_path = self._save_embedding_cache(image_input, cache_dir) + if saved_path is None: + raise RuntimeError("failed to persist embedding cache to disk") + + return saved_path or cache_path, False + + def remove_embedding_cache( + self, image_input: str, cache_dir: Optional[str] = None + ) -> bool: + cache_path = self._get_embedding_cache_path(image_input, cache_dir) + if cache_path is None or not cache_path.exists(): + return False + try: + cache_path.unlink() + return True + except Exception: + logger.exception("failed to delete cache %s", cache_path) + return False + + def _load_image(self, image_input: str) -> Image.Image: + """Load an image from a base64 string, data URI, local path, or URL.""" + if image_input.startswith("data:"): + header, encoded = image_input.split(",", 1) + return Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB") + + local_path = Path(image_input) + if local_path.exists(): + return Image.open(local_path).convert("RGB") + + if "://" in image_input: + response = requests.get(image_input) + response.raise_for_status() + return Image.open(BytesIO(response.content)).convert("RGB") + + try: + decoded = base64.b64decode(image_input) + return Image.open(BytesIO(decoded)).convert("RGB") + except Exception as exc: + raise ValueError("unsupported image input") from exc + + def predict_image( + self, + image_input: str, + input_points: List[List[int]], + input_labels: List[int], + input_box: List[List[int]], + multimask_output: bool, + cache_dir: Optional[str] = None, + ): + print(image_input[:50],input_points,input_labels,input_box,multimask_output) + with self.inference_lock: + if self.current_img != image_input: + loaded_from_cache = self._load_cached_embedding(image_input, cache_dir) + if not loaded_from_cache: + img = self._load_image(image_input) + self.img_predictor.set_image(img) + self.current_img = image_input + self._save_embedding_cache(image_input, cache_dir) + masks, scores, logits = self.img_predictor.predict( + point_coords=np.array(input_points), + point_labels=np.array(input_labels), + box = np.array(input_box) if isinstance(input_box, list) and input_box else None, + multimask_output=multimask_output) + sorted_ind = np.argsort(scores)[::-1] + masks = masks[sorted_ind] + scores = scores[sorted_ind] + logits = logits[sorted_ind] + mask = Image.fromarray(masks[0].astype(np.uint8) * 255) + byte_io = BytesIO() + mask.save(byte_io, format="JPEG") + byte_io.seek(0) + return byte_io.getvalue() + + def generate_masks(self, url:str): + print(url[:50]) + with self.inference_lock: + img = self._load_image(url) + masks = self.mask_generator.generate(np.array(img)) + print(len(masks)) + + maskImgArr = self.combine_masks(masks) + mask = Image.fromarray((maskImgArr * 255).astype(np.uint8), 'RGBA') + byte_io = BytesIO() + mask.save(byte_io, format="PNG") + byte_io.seek(0) + return byte_io.getvalue() + + def generate_masks_base64(self, image_input: str): + """生成全局 mask,直接以 base64 PNG 列表返回,不落盘。""" + print(image_input[:50]) + with self.inference_lock: + img = self._load_image(image_input) + masks = self.mask_generator.generate(np.array(img)) + + payload: List[Dict[str, Any]] = [] + for ann in masks: + seg = ann.get("segmentation") + if seg is None: + continue + mask_arr = np.array(seg).astype(np.uint8) + mask_arr = np.squeeze(mask_arr) + if mask_arr.ndim != 2: + continue + h, w = mask_arr.shape + mask_img = Image.fromarray(mask_arr * 255) + buf = BytesIO() + mask_img.save(buf, format="PNG") + buf.seek(0) + payload.append({ + "size": [int(h), int(w)], + "png_base64": base64.b64encode(buf.getvalue()).decode("ascii"), + }) + return payload + + def save_masks_to_dir( + self, image_input: str, output_dir: Path, filename_prefix: str = "mask" + ) -> List[str]: + with self.autocast_context(), self.inference_lock: + img = self._load_image(image_input) + masks = self.mask_generator.generate(np.array(img)) + + output_dir.mkdir(parents=True, exist_ok=True) + saved_files: List[str] = [] + for idx, ann in enumerate(masks): + segmentation = ann.get("segmentation") + if segmentation is None: + continue + mask_arr = np.array(segmentation).astype(np.uint8) + # squeeze potential extra dimensions (e.g., 1xHxW) + mask_arr = np.squeeze(mask_arr) + mask_img = Image.fromarray(mask_arr * 255) + file_name = f"{filename_prefix}_{idx:04d}.png" + mask_img.save(output_dir / file_name) + saved_files.append(file_name) + + return saved_files + + def combine_masks(self, masks, borders=True): + if len(masks) == 0: + return + np.random.seed(3) + sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True) + + img = np.ones((sorted_masks[0]['segmentation'].shape[0], sorted_masks[0]['segmentation'].shape[1], 4)) + img[:, :, 3] = 0 + for ann in sorted_masks: + m = ann['segmentation'] + color_mask = np.concatenate([np.random.random(3), [0.5]]) + img[m] = color_mask + if borders: + contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + # Try to smooth contours + contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] + cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1) + print(img) + return img + def autocast_context(self): if self.device.type == "cuda": return torch.autocast("cuda", dtype=torch.bfloat16)