25 commits
b7ea34c
Split the language model register tab panel into a separate js file.
mikeshi80 Mar 13, 2024
18f464e
WIP: fetch model info from the hub
mikeshi80 Mar 15, 2024
25f2b38
Merge branch 'main' into fetch_from_hub
mikeshi80 Mar 15, 2024
cc903e1
add more unit tests for getting model info.
mikeshi80 Mar 15, 2024
c108eb0
fix getting the model info correctly.
mikeshi80 Mar 17, 2024
c24a459
fix the wrong llm family build logic.
mikeshi80 Mar 17, 2024
3e4f765
Merge branch 'main' into fetch_from_hub
mikeshi80 Mar 18, 2024
4433f85
fix the bug that quantizations in lower case cannot be processed.
mikeshi80 Mar 18, 2024
53ca758
add support for GGUF and GGML models from ModelScope
mikeshi80 Mar 18, 2024
5dc38a1
add async runner to run functions asynchronously, and add model hub utility to …
mikeshi80 Mar 18, 2024
3991ae6
Refactor the code that gets gguf and ggml model info.
mikeshi80 Mar 18, 2024
c0936e8
when the model size is a decimal, replace the dot with an underscore.
mikeshi80 Mar 18, 2024
5081c48
when the model size is unknown, return 0 instead of "UNKNOWN"
mikeshi80 Mar 18, 2024
f0bd9af
add support of pytorch and awq format to fetch model info from hub.
mikeshi80 Mar 18, 2024
46e1626
add support of pytorch and awq format to fetch model info from hub.
mikeshi80 Mar 18, 2024
97dbae6
add the REST API for fetching model info from the model hub
mikeshi80 Mar 18, 2024
f915155
Merge branch 'main' into fetch_from_hub
mikeshi80 Mar 18, 2024
b853970
refactor the pytest case to catch the exception with pytest.raises.
mikeshi80 Mar 19, 2024
b354cec
fix the bug that model ids in the '0.x' form cannot be processed correctly
mikeshi80 Mar 19, 2024
e5ba02d
update caniuse-lite
mikeshi80 Mar 19, 2024
9b774fe
Finish LLM custom model registration from hub.
mikeshi80 Mar 19, 2024
8ed9ece
Merge branch 'main' into fetch_from_hub
mikeshi80 Apr 2, 2024
c0a636c
Finish embedding custom model registration from hub.
mikeshi80 Apr 2, 2024
be6e780
Finish rerank custom model registration from hub.
mikeshi80 Apr 2, 2024
dc8eb98
fix the bug that a Future with a generic type will cause an error when it is…
mikeshi80 Apr 3, 2024
36 changes: 36 additions & 0 deletions xinference/api/restful_api.py
@@ -492,6 +492,16 @@ def serve(self, logging_conf: Optional[dict] = None):
else None
),
)
self._router.add_api_route(
"/v1/model_registrations/{model_type}/{model_hub}/{model_format}/{user}/{repo}",
self.get_model_info_from_hub,
methods=["GET"],
dependencies=(
[Security(self._auth_service, scopes=["models:register"])]
if self.is_authenticated()
else None
),
)

# Clear the global Registry for the MetricsMiddleware, or
# the MetricsMiddleware will register duplicated metrics if the port
@@ -1507,6 +1517,32 @@ async def get_cluster_version(self) -> JSONResponse:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def get_model_info_from_hub(
self, model_type: str, model_hub: str, model_format: str, user: str, repo: str
) -> JSONResponse:
try:
if model_type == "LLM":
llm_family_info = await (
await self._get_supervisor_ref()
).get_llm_family_from_hub(f"{user}/{repo}", model_format, model_hub)
return JSONResponse(content=llm_family_info)
if model_type == "embedding":
embed_spec = await (
await self._get_supervisor_ref()
).get_embedding_spec_from_hub(f"{user}/{repo}", model_hub)
return JSONResponse(content=embed_spec)
except ValueError as re:
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

raise HTTPException(
status_code=400,
detail="only LLM and embedding model type supported currently",
)


def run(
supervisor_address: str,
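For reference, a minimal sketch of how a client could call the new route added above. The base URL, user, and repo are placeholders (they are not part of this diff); adjust them to your own deployment and model.

import requests

# Placeholder endpoint and repo; adjust to your own deployment and model.
base_url = "http://127.0.0.1:9997"
model_type = "LLM"          # or "embedding"
model_hub = "huggingface"   # or "modelscope"
model_format = "ggufv2"     # ggmlv3 / ggufv2 / pytorch / awq
user, repo = "TheBloke", "Llama-2-7B-GGUF"

resp = requests.get(
    f"{base_url}/v1/model_registrations/{model_type}/{model_hub}/{model_format}/{user}/{repo}"
)
resp.raise_for_status()
# LLM requests return the HubImportLLMFamilyV1 content; embedding requests
# return the CustomEmbeddingModelSpec content.
print(resp.json())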
133 changes: 130 additions & 3 deletions xinference/core/supervisor.py
@@ -14,13 +14,15 @@

import asyncio
import itertools
import json
import time
import typing
from dataclasses import dataclass
from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union

import xoscar as xo
from typing_extensions import Literal, cast

from ..constants import (
XINFERENCE_DISABLE_HEALTH_CHECK,
@@ -30,11 +32,22 @@
)
from ..core import ModelActor
from ..core.status_guard import InstanceInfo, LaunchStatus
from ..model.embedding import CustomEmbeddingModelSpec
from ..model.embedding.utils import get_language_from_model_id
from ..model.llm import GgmlLLMSpecV1
from ..model.llm.llm_family import (
DEFAULT_CONTEXT_LENGTH,
HubImportLLMFamilyV1,
PytorchLLMSpecV1,
)
from ..model.llm.utils import MODEL_HUB, ModelHubUtil
from .metrics import record_metrics
from .resource import GPUStatus, ResourceStatus
from .utils import (
build_replica_model_uid,
gen_random_string,
get_llama_cpp_quantization_info,
get_model_size_from_model_id,
is_valid_model_uid,
iter_replica_model_uid,
log_async,
@@ -54,7 +67,6 @@

logger = getLogger(__name__)


ASYNC_LAUNCH_TASKS = {} # type: ignore


@@ -79,6 +91,7 @@ class ReplicaInfo:
class SupervisorActor(xo.StatelessActor):
def __init__(self):
super().__init__()
self._model_hub_util = ModelHubUtil()
self._worker_address_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = {}
self._worker_status: Dict[str, WorkerStatus] = {}
self._replica_model_uid_to_worker: Dict[
@@ -665,8 +678,8 @@ async def launch_speculative_llm(
model_uid = self._gen_model_uid(model_name)
logger.debug(
(
f"Enter launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, "
f"draft_model_name: %s, draft_model_size: %s"
"Enter launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, "
"draft_model_name: %s, draft_model_size: %s"
),
model_uid,
model_name,
@@ -1005,3 +1018,117 @@ async def report_worker_status(
@staticmethod
def record_metrics(name, op, kwargs):
record_metrics(name, op, kwargs)

@log_async(logger=logger)
async def get_llm_family_from_hub(
self,
model_id: str,
model_format: str,
model_hub: str,
) -> HubImportLLMFamilyV1:
if model_hub not in ["huggingface", "modelscope"]:
raise ValueError(f"Unsupported model hub: {model_hub}")

model_hub = cast(MODEL_HUB, model_hub)

context_length = DEFAULT_CONTEXT_LENGTH

repo_exists = await self._model_hub_util.a_repo_exists(
model_id,
model_hub,
)
if not repo_exists:
raise ValueError(f"Model {model_id} does not exist")

if config_path := await self._model_hub_util.a_get_config_path(
model_id, model_hub
):
with open(config_path) as f:
config = json.load(f)
if "max_position_embeddings" in config:
context_length = config["max_position_embeddings"]

if model_format in ["ggmlv3", "ggufv2"]:
filenames = await self._model_hub_util.a_list_repo_files(
model_id, model_hub
)

(
model_file_name_template,
model_file_name_split_template,
quantizations,
quantization_parts,
) = get_llama_cpp_quantization_info(
filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format)
)

llm_spec = GgmlLLMSpecV1(
model_id=model_id,
model_format=model_format,
model_hub=model_hub,
quantizations=quantizations,
quantization_parts=quantization_parts,
model_size_in_billions=get_model_size_from_model_id(model_id),
model_file_name_template=model_file_name_template,
model_file_name_split_template=model_file_name_split_template,
)

return HubImportLLMFamilyV1(
version=1, context_length=context_length, model_specs=[llm_spec]
)
elif model_format in ["pytorch", "awq"]:
llm_spec = PytorchLLMSpecV1(
model_id=model_id,
model_format=model_format,
model_hub=model_hub,
model_size_in_billions=get_model_size_from_model_id(model_id),
quantizations=(
["4-bit", "8-bit", "none"]
if model_format == "pytorch"
else ["Int4"]
),
)
return HubImportLLMFamilyV1(
version=1, context_length=context_length, model_specs=[llm_spec]
)
elif model_format == "gptq":
raise NotImplementedError("gptq is not implemented yet")
else:
raise ValueError(f"Unsupported model format: {model_format}")

@log_async(logger=logger)
async def get_embedding_spec_from_hub(
self, model_id: str, model_hub: str
) -> CustomEmbeddingModelSpec:
if model_hub not in ["huggingface", "modelscope"]:
raise ValueError(f"Unsupported model hub: {model_hub}")

model_hub = cast(MODEL_HUB, model_hub)

repo_exists = await self._model_hub_util.a_repo_exists(
model_id,
model_hub,
)

if not repo_exists:
raise ValueError(f"Model {model_id} does not exist")

max_tokens = 512
dimensions = 768
if config_path := await self._model_hub_util.a_get_config_path(
model_id, model_hub
):
with open(config_path) as f:
config = json.load(f)
if "max_position_embeddings" in config:
max_tokens = config["max_position_embeddings"]
if "hidden_size" in config:
dimensions = config["hidden_size"]
return CustomEmbeddingModelSpec(
model_name=model_id.split("/")[-1],
model_id=model_id,
max_tokens=max_tokens,
dimensions=dimensions,
model_hub=model_hub,
language=[get_language_from_model_id(model_id)],
)
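As a reading aid, here is a rough sketch of what the model-size parsing used above (get_model_size_from_model_id) might do, inferred only from the commit messages in this PR ("when the model size is a decimal, replace the dot with an underscore"; "when the model size is unknown, return 0"). It is a hypothetical illustration, not the actual helper in xinference/core/utils.py, which may differ.

import re
from typing import Union

def get_model_size_from_model_id(model_id: str) -> Union[int, str]:
    # Hypothetical sketch, not the implementation shipped in this PR.
    # Look for a size such as "7B" or "0.5B" in the repo part of the model id,
    # e.g. "TheBloke/Llama-2-7B-GGUF" -> 7, "Qwen/Qwen1.5-0.5B-Chat" -> "0_5".
    match = re.search(r"(\d+(?:\.\d+)?)[bB]", model_id.split("/")[-1])
    if match is None:
        return 0  # unknown size falls back to 0 rather than "UNKNOWN"
    size = match.group(1)
    if "." in size:
        return size.replace(".", "_")  # decimal sizes keep the dot as an underscore
    return int(size)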