diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py index 1b3a43de..8ad96dc0 100644 --- a/src/art/dev/get_model_config.py +++ b/src/art/dev/get_model_config.py @@ -38,6 +38,10 @@ def get_model_config( disable_log_requests=True, enable_sleep_mode=enable_sleep_mode, generation_config="vllm", + # Default tensor parallel to visible GPU count (respects CUDA_VISIBLE_DEVICES) + tensor_parallel_size=( + torch.cuda.device_count() if torch.cuda.is_available() else 1 + ), ) engine_args.update(config.get("engine_args", {})) init_args.update(config.get("init_args", {}))