added the option to choose the model
lab176344 committed Jun 12, 2024
1 parent 9e361e0 commit cb1da92
Showing 2 changed files with 11 additions and 2 deletions.
README.md: 1 change (1 addition, 0 deletions)
@@ -51,6 +51,7 @@ base_model = Gemini(
     ),
     gcp_region="us-central1",
     gcp_project="project-name",
+    model="gemini-1.5-flash"
 )
 
 # run inference on an image
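For context, a minimal sketch of how the README example reads with the new `model` argument. It assumes the package exposes `Gemini` at the top level and uses placeholder ontology entries, project ID, and image path:

```python
from autodistill.detection import CaptionOntology
from autodistill_gemini import Gemini  # assumed top-level export

# placeholder ontology: caption prompt -> class label
base_model = Gemini(
    ontology=CaptionOntology({"person": "person", "car": "car"}),
    gcp_region="us-central1",
    gcp_project="project-name",   # placeholder GCP project ID
    model="gemini-1.5-flash",     # any entry in Gemini.AVAILABLE_MODELS
)

# run inference on an image (placeholder path)
detections = base_model.predict("image.jpeg", prompt="person")
print(detections)
```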
autodistill_gemini/gemini_model.py: 12 changes (10 additions, 2 deletions)
@@ -13,18 +13,26 @@
 
 @dataclass
 class Gemini(DetectionBaseModel):
+    AVAILABLE_MODELS = ["gemini-1.5-flash", "gemini-1.5-pro", "gemini-pro-vision"]
     ontology: CaptionOntology
     api_key: str
     gcp_region: str
     gcp_project: str
+    model: str
 
     def __init__(
-        self, ontology: CaptionOntology, gcp_region: str, gcp_project: str
+        self, ontology: CaptionOntology, gcp_region: str, gcp_project: str, model: str
     ) -> None:
         self.ontology = ontology
         self.gcp_region = gcp_region
         self.gcp_project = gcp_project
 
+        if model in self.AVAILABLE_MODELS:
+            self.model = model
+        else:
+            raise ValueError(f"Choose one of the available models from {self.AVAILABLE_MODELS}")
+
+
     def predict(
         self, input: str, prompt: str = "", confidence: int = 0.5
     ) -> sv.Detections:
@@ -40,7 +48,7 @@ def predict(
 
         vertexai.init(project=self.gcp_project, location=self.gcp_region)
 
-        multimodal_model = GenerativeModel("gemini-pro-vision")
+        multimodal_model = GenerativeModel(self.model)
 
         response = multimodal_model.generate_content(
             [prompt, Image.load_from_file(input)]
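The new constructor check rejects anything outside AVAILABLE_MODELS before any Vertex AI call is made (vertexai.init only runs inside predict). A small sketch of the expected behavior, under the same assumption that `Gemini` is importable from the package root and with placeholder values:

```python
from autodistill.detection import CaptionOntology
from autodistill_gemini import Gemini  # assumed top-level export

try:
    Gemini(
        ontology=CaptionOntology({"person": "person"}),
        gcp_region="us-central1",
        gcp_project="project-name",  # placeholder GCP project ID
        model="gemini-ultra",        # not listed in AVAILABLE_MODELS
    )
except ValueError as err:
    # the constructor raises before vertexai.init is ever reached
    print(err)
```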
