diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 430323bcd5b..1f5f3232cb1 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -44,6 +44,9 @@ def serve( otlp_endpoint: Optional[str] = None, max_input_tokens: Optional[int] = None, ): + # derive sharded from environment variables if not provided + sharded = sharded or os.getenv("WORLD_SIZE", None) is not None + if sharded: assert ( os.getenv("RANK", None) is not None