Skip to content

Commit

Permalink
add bbox support
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Aug 27, 2024
1 parent c9cf57a commit 8516cf7
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 4 deletions.
2 changes: 1 addition & 1 deletion autodistill_gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from autodistill_gemini.gemini_model import Gemini
from autodistill_gemini.gemini_model import GeminiForClassification, GeminiForObjectDetection

__version__ = "0.1.0"
Binary file not shown.
Binary file not shown.
62 changes: 59 additions & 3 deletions autodistill_gemini/gemini_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,73 @@
from dataclasses import dataclass

import vertexai
from vertexai.preview.generative_models import GenerativeModel, Image
from vertexai.preview.generative_models import GenerativeModel
from PIL import Image
import supervision as sv
from autodistill.helpers import load_image
import google.generativeai as genai
import numpy as np
from autodistill.detection import CaptionOntology, DetectionBaseModel
from autodistill.classification import ClassificationBaseModel

HOME = os.path.expanduser("~")


@dataclass
class Gemini(DetectionBaseModel):
class GeminiForObjectDetection(DetectionBaseModel):
ontology: CaptionOntology
api_key: str
gcp_region: str
gcp_project: str
model: str

def __init__(
self, ontology: CaptionOntology, model: str = "gemini-1.5-pro-latest", api_key: str = None
) -> None:
genai.configure(api_key=api_key)
self.ontology = ontology
self.model = genai.GenerativeModel(model_name=model)

def predict(
self, input: str, prompt: str = "", confidence: int = 0.5
) -> sv.Detections:
if not prompt:
prompt = "Return bounding boxes around every instance of the following labels in the image:\n" + "\n".join(
self.ontology.prompts()
) + """\nReturn in the format {label: [x1, y1, x2, y2]}"""

response = self.model.generate_content(
[Image.open(input), prompt]
)

# "text": "- [person, 275, 0, 999, 918]\n- [a forklift, 201, 95, 728, 851]\n
# extract

text_response = response.text.strip()
import json
print(text_response)
text_as_json = json.loads(text_response)

detection_bboxes = []
detection_classes = []

for detection in text_as_json:
detection_class = detection

if detection_class in self.ontology.prompts():
detection_classes.append(self.ontology.prompts().index(detection_class))
detection_bboxes.append(text_as_json[detection])

# detection_bboxes = sv.xyxy_(np.array(detection_bboxes))

return sv.Detections(
class_id=np.array(detection_classes),
xyxy=np.array(detection_bboxes),
confidence=np.ones(len(detection_classes)),
)

@dataclass
class GeminiForClassification(ClassificationBaseModel):
AVAILABLE_MODELS = ["gemini-1.5-flash", "gemini-1.5-pro", "gemini-pro-vision"]
ontology: CaptionOntology
api_key: str
Expand All @@ -30,7 +86,7 @@ def __init__(
if model in self.AVAILABLE_MODELS:
self.model = model
else:
raise ValueError(f"Choose one of the available models from {available_models}")
raise ValueError(f"Choose one of the available models from {self.AVAILABLE_MODELS}")


def predict(
Expand Down

0 comments on commit 8516cf7

Please sign in to comment.