diff --git a/vllm/config.py b/vllm/config.py index a2cb9b32c65fc..275814d72e6c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1063,6 +1063,7 @@ def _get_and_verify_dtype( if config_dtype == torch.float32: # Following the common practice, we use float16 for float32 # models. + logger.info("Casting torch.float32 to torch.float16.") torch_dtype = torch.float16 else: torch_dtype = config_dtype @@ -1087,9 +1088,11 @@ def _get_and_verify_dtype( if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) pass else: # Casting between float16 and bfloat16 is allowed with a warning.