Skip to content
150 changes: 148 additions & 2 deletions demo/backend/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,37 @@
# LICENSE file in the root directory of this source tree.

import logging
import time
from typing import Any, Generator

from app_conf import (
GALLERY_PATH,
GALLERY_PREFIX,
POSTERS_PATH,
POSTERS_PREFIX,
TEMP_PATH,
TEMP_PREFIX,
UPLOADS_PATH,
UPLOADS_PREFIX,
)
from data.loader import preload_data
from data.schema import schema
from data.store import set_videos
from flask import Flask, make_response, Request, request, Response, send_from_directory
from flask import (
Flask,
make_response,
Request,
jsonify,
request,
Response,
send_from_directory,
)
from flask_cors import CORS
from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
from inference.multipart import MultipartResponseBuilder
from inference.predictor import InferenceAPI
from strawberry.flask.views import GraphQLView
import json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,6 +84,140 @@ def send_uploaded_video(path: str):
except:
raise ValueError("resource not found")

@app.route("/precompute_embedding", methods=["POST"])
def precompute_embedding() -> Response:
    """Precompute and cache the image embedding for a given image.

    Expects a JSON body with ``url`` or ``path``. Responds with the cache
    file path and whether an existing cache entry was reused.
    """
    payload = request.get_json(silent=True) or {}
    image_input = payload.get("url") or payload.get("path")
    if not image_input:
        return jsonify({"error": "url or path is required"}), 400

    try:
        cache_path, reused = inference_api.precompute_image_embedding(image_input)
    except Exception as exc:
        logger.exception("failed to precompute embedding")
        return jsonify({"error": f"failed to precompute embedding: {exc}"}), 500

    status = "reused" if reused else "created"
    return jsonify({"cache_path": str(cache_path), "status": status})

@app.route("/remove_embedding", methods=["POST"])
def remove_embedding() -> Response:
    """Remove the cached embedding for an image identified by url/path.

    Responds with whether anything was removed and the (possibly empty)
    cache path that was targeted.
    """
    payload = request.get_json(silent=True) or {}
    image_input = payload.get("url") or payload.get("path")
    if not image_input:
        return jsonify({"error": "url or path is required"}), 400

    was_removed = inference_api.remove_embedding_cache(image_input)
    # NOTE(review): this reaches into a private helper of InferenceAPI;
    # consider exposing a public accessor for the cache path.
    cache_path = inference_api._get_embedding_cache_path(image_input) or ""
    return jsonify({"removed": bool(was_removed), "cache_path": str(cache_path)})


@app.route("/mask", methods=["POST"])
def predict_image() -> Response:
    """Run single-image mask prediction and return the mask as a JPEG.

    Expects a JSON body with the image (``base64``/``image_base64``/``url``),
    click ``points`` and ``labels``, and an optional ``path`` used as the
    embedding cache directory.
    """
    data = request.get_json(silent=True) or {}
    image_base64 = data.get("base64") or data.get("image_base64") or data.get("url")
    if not image_base64:
        return jsonify({"error": "base64 is required"}), 400
    # Validate the prompt inputs up front so a malformed request gets a
    # clear 400 instead of an unhandled KeyError turning into a 500.
    points = data.get("points")
    labels = data.get("labels")
    if points is None or labels is None:
        return jsonify({"error": "points and labels are required"}), 400
    cache_dir = data.get("path")

    start_time = time.time()
    res = inference_api.predict_image(
        image_base64,
        points,
        labels,
        None,
        True,
        cache_dir=cache_dir,
    )
    # Use the module logger (lazy %-formatting) instead of print().
    logger.info("mask generation took %.6f s", time.time() - start_time)
    return Response(
        res,
        mimetype="image/jpeg",
        headers={"Content-Disposition": "attachment; filename=mask.jpg"},
    )

@app.route("/masks", methods=["POST"])
def generate_masks() -> Response:
    """Generate all masks for an image and return them as a PNG attachment.

    Expects a JSON body with ``url``.
    """
    # get_json(silent=True) avoids request.json being None or raising on a
    # missing/invalid JSON body; explicit 400 keeps this route consistent
    # with the others instead of a KeyError becoming a 500.
    data = request.get_json(silent=True) or {}
    image_url = data.get("url")
    if not image_url:
        return jsonify({"error": "url is required"}), 400

    start_time = time.time()
    res = inference_api.generate_masks(image_url)
    # Use the module logger (lazy %-formatting) instead of print().
    logger.info("masks generation took %.6f s", time.time() - start_time)
    return Response(
        res,
        mimetype="image/png",
        headers={"Content-Disposition": "attachment; filename=mask.png"},
    )


@app.route("/image_masks_save", methods=["POST"])
def image_masks_save() -> Response:
    """Generate masks for an image and persist them under TEMP_PATH.

    Masks are written into a millisecond-timestamped subdirectory; the
    response lists the relative paths of the saved files.
    """
    payload = request.get_json(silent=True) or {}
    image_input = payload.get("url") or payload.get("path")
    filename_prefix = payload.get("name") or "mask"

    if not image_input:
        return jsonify({"error": "url or path is required"}), 400

    # Millisecond resolution keeps concurrent requests in separate dirs.
    timestamp = str(int(time.time() * 1000))
    output_dir = TEMP_PATH / timestamp

    try:
        saved_files = inference_api.save_masks_to_dir(
            image_input=image_input,
            output_dir=output_dir,
            filename_prefix=filename_prefix,
        )
    except Exception as exc:
        logger.exception("failed to save image masks")
        return jsonify({"error": f"failed to save masks: {exc}"}), 500

    rel_dir = f"{TEMP_PREFIX}/{timestamp}"
    return jsonify(
        {
            "count": len(saved_files),
            "saved_dir": rel_dir,
            "masks": [{"path": f"{rel_dir}/{name}"} for name in saved_files],
        }
    )

@app.route("/image_masks", methods=["POST"])
def image_masks() -> Response:
    """Generate global masks for an image (url/path/data URI) and return
    them as a list of base64-encoded PNGs, without writing to disk."""
    payload = request.get_json(silent=True) or {}
    image_input = payload.get("url") or payload.get("path")
    if not image_input:
        return jsonify({"error": "url or path is required"}), 400

    try:
        masks_payload = inference_api.generate_masks_base64(image_input=image_input)
    except Exception as exc:
        logger.exception("failed to generate masks in memory")
        return jsonify({"error": f"failed to generate masks: {exc}"}), 500

    return jsonify({"count": len(masks_payload), "masks": masks_payload})


# TODO: Protect route with ToS permission check
@app.route("/propagate_in_video", methods=["POST"])
Expand Down Expand Up @@ -137,4 +283,4 @@ def get_context(self, request: Request, response: Response) -> Any:


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
app.run(host="0.0.0.0", port=7263)
9 changes: 7 additions & 2 deletions demo/backend/server/app_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

logger = logging.getLogger(__name__)

APP_ROOT = os.getenv("APP_ROOT", "/opt/sam2")
APP_ROOT = os.getenv("APP_ROOT", "./sam2")

API_URL = os.getenv("API_URL", "http://localhost:7263")

Expand All @@ -20,7 +20,7 @@
FFMPEG_NUM_THREADS = int(os.getenv("FFMPEG_NUM_THREADS", "1"))

# Path for all data used in API
DATA_PATH = Path(os.getenv("DATA_PATH", "/data"))
DATA_PATH = Path(os.getenv("DATA_PATH", "./data"))

# Max duration an uploaded video can have in seconds. The default is 10
# seconds.
Expand Down Expand Up @@ -48,8 +48,13 @@
# Path where all posters are stored
POSTERS_PATH = DATA_PATH / POSTERS_PREFIX

# Prefix and path for temporary assets (e.g. generated masks)
TEMP_PREFIX = "temp"
TEMP_PATH = DATA_PATH / TEMP_PREFIX

# Make sure any of those paths exist
os.makedirs(DATA_PATH, exist_ok=True)
os.makedirs(GALLERY_PATH, exist_ok=True)
os.makedirs(UPLOADS_PATH, exist_ok=True)
os.makedirs(POSTERS_PATH, exist_ok=True)
os.makedirs(TEMP_PATH, exist_ok=True)
Loading