From 297d3bf020eaa395ebe17fae39c372692215b8fc Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 31 Oct 2024 12:26:37 +0100
Subject: [PATCH 1/7] add pydantic support in response_format

---
 src/huggingface_hub/_webhooks_payload.py  | 16 +++++-
 src/huggingface_hub/inference/_client.py  | 55 +++++++++++++++++--
 .../_generated/types/chat_completion.py   |  3 +
 3 files changed, 67 insertions(+), 7 deletions(-)

diff --git a/src/huggingface_hub/_webhooks_payload.py b/src/huggingface_hub/_webhooks_payload.py
index 288f4b08b9..cad7572869 100644
--- a/src/huggingface_hub/_webhooks_payload.py
+++ b/src/huggingface_hub/_webhooks_payload.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Contains data structures to parse the webhooks payload."""
 
-from typing import List, Literal, Optional
+from typing import Any, List, Literal, Optional
 
 from .utils import is_pydantic_available
 
@@ -32,6 +32,20 @@ def __init__(self, *args, **kwargs) -> None:
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
             )
 
+        @classmethod
+        def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
+        @classmethod
+        def model_validate_json(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
 
 # This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
 # are not in use anymore. To keep in sync when format is updated in
diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index ed473e6d11..6872bf4c96 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -37,11 +37,12 @@
 import re
 import time
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Type, Union, overload
 
 from requests import HTTPError
 from requests.structures import CaseInsensitiveDict
 
+from huggingface_hub._webhooks_payload import BaseModel
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
 from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
@@ -538,7 +539,7 @@ def chat_completion(
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[ChatCompletionInputGrammarType] = None,
+        response_format: Optional[Union[ChatCompletionInputGrammarType, Type[BaseModel]]] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
@@ -590,8 +591,8 @@ def chat_completion(
             presence_penalty (`float`, *optional*):
                 Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
                 text so far, increasing the model's likelihood to talk about new topics.
-            response_format ([`ChatCompletionInputGrammarType`], *optional*):
-                Grammar constraints. Can be either a JSONSchema or a regex.
+            response_format ([`ChatCompletionInputGrammarType`] or `pydantic.BaseModel` class, *optional*):
+                Grammar constraints. Can be either a JSONSchema, a regex or a Pydantic schema.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
             stop (Optional[`str`], *optional*):
@@ -820,7 +821,7 @@ def chat_completion(
         )
         ```
 
-        Example using response_format:
+        Example using response_format (dict):
         ```py
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
@@ -850,7 +851,41 @@ def chat_completion(
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
+
+        Example using response_format (pydantic):
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> from pydantic import BaseModel, conint
+        >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": "I saw a puppy, a cat and a raccoon during my bike ride in the park. What did I see and when?",
+        ...     },
+        ... ]
+        >>> class OutputFormat(BaseModel):
+        ...     location: str
+        ...     activity: str
+        ...     animals_seen: conint(ge=1, le=5)
+        ...     animals: list[str]
+        >>> response = client.chat_completion(
+        ...     messages=messages,
+        ...     response_format=OutputFormat,
+        ...     max_tokens=500,
+        ... )
+        >>> response.choices[0].message.parsed
+        OutputFormat(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
+        ```
         """
+        if issubclass(response_format, BaseModel):
+            base_model = response_format
+            response_format = ChatCompletionInputGrammarType(
+                type="json",
+                value=base_model.model_json_schema(),
+            )
+        else:
+            base_model = None
+
         model_url = self._resolve_chat_completion_url(model)
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
@@ -886,7 +921,15 @@ def chat_completion(
         if stream:
             return _stream_chat_completion_response(data)  # type: ignore[arg-type]
 
-        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        if base_model:
+            for choice in chat_completion_output.choices:
+                if choice.message.content:
+                    try:
+                        choice.message.parsed = base_model.model_validate_json(choice.message.content)
+                    except ValueError:
+                        pass
+        return chat_completion_output
 
     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
diff --git a/src/huggingface_hub/inference/_generated/types/chat_completion.py b/src/huggingface_hub/inference/_generated/types/chat_completion.py
index 7a1f297e4f..ce6d19cc80 100644
--- a/src/huggingface_hub/inference/_generated/types/chat_completion.py
+++ b/src/huggingface_hub/inference/_generated/types/chat_completion.py
@@ -6,6 +6,8 @@
 from dataclasses import dataclass
 from typing import Any, List, Literal, Optional, Union
 
+from huggingface_hub._webhooks_payload import BaseModel
+
 from .base import BaseInferenceType
 
 
@@ -196,6 +198,7 @@ class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
+    parsed: Optional[BaseModel] = None
 
 
 @dataclass

From 893e7d4ee61470d8f6df41e30ab4c75bf6243073 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 31 Oct 2024 12:29:06 +0100
Subject: [PATCH 2/7] style

---
 .../inference/_generated/_async_client.py | 68 +++++++++++++++++--
 1 file changed, 62 insertions(+), 6 deletions(-)

diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 74888bc0b8..56833b525a 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -24,10 +24,23 @@
 import re
 import time
 import warnings
-from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Type,
+    Union,
+    overload,
+)
 
 from requests.structures import CaseInsensitiveDict
 
+from huggingface_hub._webhooks_payload import BaseModel
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
 from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
@@ -574,7 +587,7 @@ async def chat_completion(
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[ChatCompletionInputGrammarType] = None,
+        response_format: Optional[Union[ChatCompletionInputGrammarType, Type[BaseModel]]] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
@@ -626,8 +639,8 @@ async def chat_completion(
             presence_penalty (`float`, *optional*):
                 Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
                 text so far, increasing the model's likelihood to talk about new topics.
-            response_format ([`ChatCompletionInputGrammarType`], *optional*):
-                Grammar constraints. Can be either a JSONSchema or a regex.
+            response_format ([`ChatCompletionInputGrammarType`] or `pydantic.BaseModel` class, *optional*):
+                Grammar constraints. Can be either a JSONSchema, a regex or a Pydantic schema.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
             stop (Optional[`str`], *optional*):
@@ -861,7 +874,7 @@ async def chat_completion(
         )
         ```
 
-        Example using response_format:
+        Example using response_format (dict):
         ```py
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
@@ -892,7 +905,42 @@ async def chat_completion(
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
+
+        Example using response_format (pydantic):
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> from pydantic import BaseModel, conint
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": "I saw a puppy, a cat and a raccoon during my bike ride in the park. What did I see and when?",
+        ...     },
+        ... ]
+        >>> class OutputFormat(BaseModel):
+        ...     location: str
+        ...     activity: str
+        ...     animals_seen: conint(ge=1, le=5)
+        ...     animals: list[str]
+        >>> response = await client.chat_completion(
+        ...     messages=messages,
+        ...     response_format=OutputFormat,
+        ...     max_tokens=500,
+        ... )
+        >>> response.choices[0].message.parsed
+        OutputFormat(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
+        ```
         """
+        if issubclass(response_format, BaseModel):
+            base_model = response_format
+            response_format = ChatCompletionInputGrammarType(
+                type="json",
+                value=base_model.model_json_schema(),
+            )
+        else:
+            base_model = None
+
         model_url = self._resolve_chat_completion_url(model)
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
@@ -928,7 +976,15 @@ async def chat_completion(
         if stream:
             return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]
 
-        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        if base_model:
+            for choice in chat_completion_output.choices:
+                if choice.message.content:
+                    try:
+                        choice.message.parsed = base_model.model_validate_json(choice.message.content)
+                    except ValueError:
+                        pass
+        return chat_completion_output
 
     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.

From 1a5f13417a072efcb9d73c4184c385158e9e0bd2 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 31 Oct 2024 12:30:53 +0100
Subject: [PATCH 3/7] minor

---
 src/huggingface_hub/inference/_client.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index 6872bf4c96..65ae505e7b 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -863,18 +863,18 @@ def chat_completion(
         ...         "content": "I saw a puppy, a cat and a raccoon during my bike ride in the park. What did I see and when?",
         ...     },
         ... ]
-        >>> class OutputFormat(BaseModel):
+        >>> class ActivitySummary(BaseModel):
         ...     location: str
         ...     activity: str
         ...     animals_seen: conint(ge=1, le=5)
         ...     animals: list[str]
         >>> response = client.chat_completion(
         ...     messages=messages,
-        ...     response_format=OutputFormat,
+        ...     response_format=ActivitySummary,
         ...     max_tokens=500,
         ... )
         >>> response.choices[0].message.parsed
-        OutputFormat(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
+        ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
         ```
         """

From 94888316ed6463fdb20bbb5c4e94d6bbe76a63ef Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 31 Oct 2024 12:33:50 +0100
Subject: [PATCH 4/7] update async client

---
 src/huggingface_hub/inference/_generated/_async_client.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 56833b525a..7d47b89820 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -918,18 +918,18 @@ async def chat_completion(
         ...         "content": "I saw a puppy, a cat and a raccoon during my bike ride in the park. What did I see and when?",
         ...     },
         ... ]
-        >>> class OutputFormat(BaseModel):
+        >>> class ActivitySummary(BaseModel):
         ...     location: str
         ...     activity: str
         ...     animals_seen: conint(ge=1, le=5)
         ...     animals: list[str]
         >>> response = await client.chat_completion(
         ...     messages=messages,
-        ...     response_format=OutputFormat,
+        ...     response_format=ActivitySummary,
         ...     max_tokens=500,
         ... )
         >>> response.choices[0].message.parsed
-        OutputFormat(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
+        ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
         ```
         """

From 90b1c5d590b1f2efdff280894549f77b59d0e8dc Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 31 Oct 2024 12:48:10 +0100
Subject: [PATCH 5/7] add refusal

---
 src/huggingface_hub/_webhooks_payload.py   | 14 ++++++++++++++
 src/huggingface_hub/inference/_client.py   | 19 +++++++++++++------
 .../inference/_generated/_async_client.py  | 19 +++++++++++++------
 .../_generated/types/chat_completion.py    |  1 +
 4 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/src/huggingface_hub/_webhooks_payload.py b/src/huggingface_hub/_webhooks_payload.py
index cad7572869..fa0933275b 100644
--- a/src/huggingface_hub/_webhooks_payload.py
+++ b/src/huggingface_hub/_webhooks_payload.py
@@ -39,6 +39,13 @@ def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
             )
 
+        @classmethod
+        def schema(cls, *args, **kwargs) -> dict[str, Any]:
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
         @classmethod
         def model_validate_json(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
             raise ImportError(
                 "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
             )
 
+        @classmethod
+        def parse_raw(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
 
 # This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
 # are not in use anymore. To keep in sync when format is updated in
diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index 65ae505e7b..7c19cc3b50 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -878,13 +878,15 @@ def chat_completion(
         ```
         """
         if issubclass(response_format, BaseModel):
-            base_model = response_format
+            response_model = response_format
             response_format = ChatCompletionInputGrammarType(
                 type="json",
-                value=base_model.model_json_schema(),
+                value=response_model.model_json_schema()
+                if hasattr(response_model, "model_json_schema")
+                else response_model.schema(),
             )
         else:
-            base_model = None
+            response_model = None
 
         model_url = self._resolve_chat_completion_url(model)
@@ -922,13 +924,18 @@ def chat_completion(
             return _stream_chat_completion_response(data)  # type: ignore[arg-type]
 
         chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
-        if base_model:
+        if response_model:
             for choice in chat_completion_output.choices:
                 if choice.message.content:
                     try:
-                        choice.message.parsed = base_model.model_validate_json(choice.message.content)
+                        # pydantic v2 uses model_validate_json
+                        choice.message.parsed = (
+                            response_model.model_validate_json(choice.message.content)
+                            if hasattr(response_model, "model_validate_json")
+                            else response_model.parse_raw(choice.message.content)
+                        )
                     except ValueError:
-                        pass
+                        choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
         return chat_completion_output
 
     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 7d47b89820..5679c255c1 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -933,13 +933,15 @@ async def chat_completion(
         ```
         """
         if issubclass(response_format, BaseModel):
-            base_model = response_format
+            response_model = response_format
             response_format = ChatCompletionInputGrammarType(
                 type="json",
-                value=base_model.model_json_schema(),
+                value=response_model.model_json_schema()
+                if hasattr(response_model, "model_json_schema")
+                else response_model.schema(),
             )
         else:
-            base_model = None
+            response_model = None
 
         model_url = self._resolve_chat_completion_url(model)
@@ -977,13 +979,18 @@ async def chat_completion(
             return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]
 
         chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
-        if base_model:
+        if response_model:
             for choice in chat_completion_output.choices:
                 if choice.message.content:
                     try:
-                        choice.message.parsed = base_model.model_validate_json(choice.message.content)
+                        # pydantic v2 uses model_validate_json
+                        choice.message.parsed = (
+                            response_model.model_validate_json(choice.message.content)
+                            if hasattr(response_model, "model_validate_json")
+                            else response_model.parse_raw(choice.message.content)
+                        )
                     except ValueError:
-                        pass
+                        choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
         return chat_completion_output
 
     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
diff --git a/src/huggingface_hub/inference/_generated/types/chat_completion.py b/src/huggingface_hub/inference/_generated/types/chat_completion.py
index ce6d19cc80..a15b6c4887 100644
--- a/src/huggingface_hub/inference/_generated/types/chat_completion.py
+++ b/src/huggingface_hub/inference/_generated/types/chat_completion.py
@@ -199,6 +199,7 @@ class ChatCompletionOutputMessage(BaseInferenceType):
     content: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
     parsed: Optional[BaseModel] = None
+    refusal: Optional[str] = None
 
 
 @dataclass

From 30144293ef9bec249292d20c19ff4094bcb08fe0 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 31 Oct 2024 12:52:34 +0100
Subject: [PATCH 6/7] mypy

---
 src/huggingface_hub/_webhooks_payload.py                  | 6 +++---
 src/huggingface_hub/inference/_client.py                  | 2 +-
 src/huggingface_hub/inference/_generated/_async_client.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/huggingface_hub/_webhooks_payload.py b/src/huggingface_hub/_webhooks_payload.py
index fa0933275b..1279297c8c 100644
--- a/src/huggingface_hub/_webhooks_payload.py
+++ b/src/huggingface_hub/_webhooks_payload.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Contains data structures to parse the webhooks payload."""
 
-from typing import Any, List, Literal, Optional
+from typing import Any, List, Literal, Optional, Union
 
 from .utils import is_pydantic_available
 
@@ -47,14 +47,14 @@ def schema(cls, *args, **kwargs) -> dict[str, Any]:
             )
 
         @classmethod
-        def model_validate_json(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
+        def model_validate_json(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
             raise ImportError(
                 "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
             )
 
         @classmethod
-        def parse_raw(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
+        def parse_raw(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
             raise ImportError(
                 "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 7c19cc3b50..8c5993310f 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -877,7 +877,7 @@ def chat_completion( ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon']) ``` """ - if issubclass(response_format, BaseModel): + if isinstance(response_format, type) and issubclass(response_format, BaseModel): response_model = response_format response_format = ChatCompletionInputGrammarType( type="json", diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 5679c255c1..5b32261631 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -932,7 +932,7 @@ async def chat_completion( ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon']) ``` """ - if issubclass(response_format, BaseModel): + if isinstance(response_format, type) and issubclass(response_format, BaseModel): response_model = response_format response_format = ChatCompletionInputGrammarType( type="json", From f7b813a32172692d0f7cda72bb550c032b8ab5ff Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 31 Oct 2024 13:09:08 +0100 Subject: [PATCH 7/7] comment --- src/huggingface_hub/inference/_client.py | 1 + src/huggingface_hub/inference/_generated/_async_client.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 8c5993310f..567adb1f61 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -879,6 +879,7 @@ def chat_completion( """ if isinstance(response_format, type) and issubclass(response_format, BaseModel): response_model = response_format + # pydantic v2 uses model_json_schema response_format = ChatCompletionInputGrammarType( type="json", value=response_model.model_json_schema() diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 5b32261631..c017ee8145 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -934,6 +934,7 @@ async def chat_completion( """ if isinstance(response_format, type) and issubclass(response_format, BaseModel): response_model = response_format + # pydantic v2 uses model_json_schema response_format = ChatCompletionInputGrammarType( type="json", value=response_model.model_json_schema()
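
A note for readers of this series: the snippet below is a minimal end-to-end sketch of the feature these patches add, pieced together only from the diffs above. The model ID and prompt are illustrative, and it assumes a `huggingface_hub` build with this series applied, `pydantic` v2 installed, and Python >= 3.9.

```py
# Usage sketch for the pydantic response_format feature (assumptions: this patch
# series applied, pydantic v2 installed; the model ID below is illustrative).
from huggingface_hub import InferenceClient
from pydantic import BaseModel


class ActivitySummary(BaseModel):
    location: str
    activity: str
    animals_seen: int
    animals: list[str]


client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
response = client.chat_completion(
    messages=[
        {
            "role": "user",
            "content": "I saw a puppy, a cat and a raccoon on my bike ride in the park. What did I see?",
        }
    ],
    # A pydantic class is converted internally to a "json" grammar via its JSON schema.
    response_format=ActivitySummary,
    max_tokens=500,
)

message = response.choices[0].message
if message.parsed is not None:
    print(message.parsed)  # a validated ActivitySummary instance
else:
    print(message.refusal)  # set by patch 5 when validating the raw content fails
```

The split between `parsed` and `refusal` mirrors the `try/except ValueError` added in patch 5: on validation failure the raw JSON string is still available in `message.content` either way.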