Skip to content

Commit d35d9a1

Browse files
sararobcopybara-github
authored andcommitted
feat: add support for model_selection_config to GenerateContentConfig
PiperOrigin-RevId: 745109883
1 parent 0af470b commit d35d9a1

File tree

3 files changed

+124
-0
lines changed

3 files changed

+124
-0
lines changed

google/genai/models.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,20 @@ def _Schema_to_mldev(
176176
return to_object
177177

178178

179+
def _ModelSelectionConfig_to_mldev(
180+
api_client: BaseApiClient,
181+
from_object: Union[dict, object],
182+
parent_object: Optional[dict] = None,
183+
) -> dict:
184+
to_object: dict[str, Any] = {}
185+
if getv(from_object, ['feature_selection_preference']) is not None:
186+
raise ValueError(
187+
'feature_selection_preference parameter is not supported in Gemini API.'
188+
)
189+
190+
return to_object
191+
192+
179193
def _SafetySetting_to_mldev(
180194
api_client: BaseApiClient,
181195
from_object: Union[dict, object],
@@ -500,6 +514,11 @@ def _GenerateContentConfig_to_mldev(
500514
if getv(from_object, ['routing_config']) is not None:
501515
raise ValueError('routing_config parameter is not supported in Gemini API.')
502516

517+
if getv(from_object, ['model_selection_config']) is not None:
518+
raise ValueError(
519+
'model_selection_config parameter is not supported in Gemini API.'
520+
)
521+
503522
if getv(from_object, ['safety_settings']) is not None:
504523
setv(
505524
parent_object,
@@ -1273,6 +1292,22 @@ def _Schema_to_vertex(
12731292
return to_object
12741293

12751294

1295+
def _ModelSelectionConfig_to_vertex(
1296+
api_client: BaseApiClient,
1297+
from_object: Union[dict, object],
1298+
parent_object: Optional[dict] = None,
1299+
) -> dict:
1300+
to_object: dict[str, Any] = {}
1301+
if getv(from_object, ['feature_selection_preference']) is not None:
1302+
setv(
1303+
to_object,
1304+
['featureSelectionPreference'],
1305+
getv(from_object, ['feature_selection_preference']),
1306+
)
1307+
1308+
return to_object
1309+
1310+
12761311
def _SafetySetting_to_vertex(
12771312
api_client: BaseApiClient,
12781313
from_object: Union[dict, object],
@@ -1603,6 +1638,15 @@ def _GenerateContentConfig_to_vertex(
16031638
if getv(from_object, ['routing_config']) is not None:
16041639
setv(to_object, ['routingConfig'], getv(from_object, ['routing_config']))
16051640

1641+
if getv(from_object, ['model_selection_config']) is not None:
1642+
setv(
1643+
to_object,
1644+
['modelConfig'],
1645+
_ModelSelectionConfig_to_vertex(
1646+
api_client, getv(from_object, ['model_selection_config']), to_object
1647+
),
1648+
)
1649+
16061650
if getv(from_object, ['safety_settings']) is not None:
16071651
setv(
16081652
parent_object,
@@ -2665,6 +2709,16 @@ def _GenerateVideosParameters_to_vertex(
26652709
return to_object
26662710

26672711

2712+
def _FeatureSelectionPreference_to_mldev_enum_validate(enum_value: Any):
2713+
if enum_value in set([
2714+
'FEATURE_SELECTION_PREFERENCE_UNSPECIFIED',
2715+
'PRIORITIZE_QUALITY',
2716+
'BALANCED',
2717+
'PRIORITIZE_COST',
2718+
]):
2719+
raise ValueError(f'{enum_value} enum value is not supported in Gemini API.')
2720+
2721+
26682722
def _SafetyFilterLevel_to_mldev_enum_validate(enum_value: Any):
26692723
if enum_value in set(['BLOCK_NONE']):
26702724
raise ValueError(f'{enum_value} enum value is not supported in Gemini API.')

google/genai/tests/models/test_generate_content.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,36 @@ def test_simple_config(client):
441441
assert response.text
442442

443443

444+
def test_model_selection_config_dict(client):
445+
if not client.vertexai:
446+
return
447+
response = client.models.generate_content(
448+
model='gemini-1.5-flash',
449+
contents='Give me a Taylor Swift lyric and explain its meaning.',
450+
config={
451+
'model_selection_config': {
452+
'feature_selection_preference': 'PRIORITIZE_COST'
453+
}
454+
},
455+
)
456+
assert response.text
457+
458+
459+
def test_model_selection_config_pydantic(client):
460+
if not client.vertexai:
461+
return
462+
response = client.models.generate_content(
463+
model='gemini-1.5-flash',
464+
contents='Give me a Taylor Swift lyric and explain its meaning.',
465+
config=types.GenerateContentConfig(
466+
model_selection_config=types.ModelSelectionConfig(
467+
feature_selection_preference=types.FeatureSelectionPreference.PRIORITIZE_QUALITY
468+
)
469+
),
470+
)
471+
assert response.text
472+
473+
444474
def test_sdk_logger_logs_warnings(client, caplog):
445475
caplog.set_level(logging.DEBUG, logger='gemini_sdk_logger')
446476
sdk_logger = logging.getLogger('gemini_sdk_logger')

google/genai/types.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,17 @@ class AdapterSize(_common.CaseInSensitiveEnum):
236236
ADAPTER_SIZE_THIRTY_TWO = 'ADAPTER_SIZE_THIRTY_TWO'
237237

238238

239+
class FeatureSelectionPreference(_common.CaseInSensitiveEnum):
240+
"""Options for feature selection preference."""
241+
242+
FEATURE_SELECTION_PREFERENCE_UNSPECIFIED = (
243+
'FEATURE_SELECTION_PREFERENCE_UNSPECIFIED'
244+
)
245+
PRIORITIZE_QUALITY = 'PRIORITIZE_QUALITY'
246+
BALANCED = 'BALANCED'
247+
PRIORITIZE_COST = 'PRIORITIZE_COST'
248+
249+
239250
class DynamicRetrievalConfigMode(_common.CaseInSensitiveEnum):
240251
"""Config for the dynamic retrieval config mode."""
241252

@@ -1226,6 +1237,26 @@ class SchemaDict(TypedDict, total=False):
12261237
SchemaOrDict = Union[Schema, SchemaDict]
12271238

12281239

1240+
class ModelSelectionConfig(_common.BaseModel):
1241+
"""Config for model selection."""
1242+
1243+
feature_selection_preference: Optional[FeatureSelectionPreference] = Field(
1244+
default=None, description="""Options for feature selection preference."""
1245+
)
1246+
1247+
1248+
class ModelSelectionConfigDict(TypedDict, total=False):
1249+
"""Config for model selection."""
1250+
1251+
feature_selection_preference: Optional[FeatureSelectionPreference]
1252+
"""Options for feature selection preference."""
1253+
1254+
1255+
ModelSelectionConfigOrDict = Union[
1256+
ModelSelectionConfig, ModelSelectionConfigDict
1257+
]
1258+
1259+
12291260
class SafetySetting(_common.BaseModel):
12301261
"""Safety settings."""
12311262

@@ -2387,6 +2418,11 @@ class GenerateContentConfig(_common.BaseModel):
23872418
description="""Configuration for model router requests.
23882419
""",
23892420
)
2421+
model_selection_config: Optional[ModelSelectionConfig] = Field(
2422+
default=None,
2423+
description="""Configuration for model selection.
2424+
""",
2425+
)
23902426
safety_settings: Optional[list[SafetySetting]] = Field(
23912427
default=None,
23922428
description="""Safety settings in the request to block unsafe content in the
@@ -2552,6 +2588,10 @@ class GenerateContentConfigDict(TypedDict, total=False):
25522588
"""Configuration for model router requests.
25532589
"""
25542590

2591+
model_selection_config: Optional[ModelSelectionConfigDict]
2592+
"""Configuration for model selection.
2593+
"""
2594+
25552595
safety_settings: Optional[list[SafetySettingDict]]
25562596
"""Safety settings in the request to block unsafe content in the
25572597
response.

0 commit comments

Comments
 (0)