ROCm and sliding windows fixes (#2033)
* update vllm commit & fix models using sliding window

* update

* update commit

* fix bug where tunableop is bound to cuda graphs even when cuda graphs are disabled

* enable tunableop by default

* fix sliding window

* address review

* dead code

* precise comment

* is it flaky?
fxmarty authored Jun 10, 2024
1 parent bf3c813 commit 9b3674d
Showing 8 changed files with 36 additions and 24 deletions.
launcher/src/main.rs (8 additions, 0 deletions)
@@ -481,6 +481,7 @@ fn shard_manager(
rope_factor: Option<f32>,
max_total_tokens: usize,
max_batch_size: Option<usize>,
max_input_tokens: usize,
otlp_endpoint: Option<String>,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>,
@@ -553,6 +554,10 @@ fn shard_manager(
shard_args.push(otlp_endpoint);
}

// In case the model uses a sliding window, some backends may ignore it in flash attention depending on this parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());

// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

@@ -1009,6 +1014,7 @@ fn spawn_shards(
args: &Args,
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
max_input_tokens: usize,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
@@ -1066,6 +1072,7 @@ fn spawn_shards(
rope_factor,
max_total_tokens,
max_batch_size,
max_input_tokens,
otlp_endpoint,
max_log_level,
status_sender,
@@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
&args,
cuda_graphs,
max_total_tokens,
max_input_tokens,
max_log_level,
shutdown.clone(),
&shutdown_receiver,
server/Makefile-vllm (2 additions, 2 deletions)
@@ -1,5 +1,5 @@
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@@ -19,5 +19,5 @@ build-vllm-rocm:
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

install-vllm-rocm: build-vllm-rocm
cd vllm && git fetch && git checkout $(commit_rocm) && \
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
server/text_generation_server/cli.py (2 additions, 0 deletions)
@@ -42,6 +42,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
max_input_tokens: Optional[int] = None,
):
if sharded:
assert (
@@ -98,6 +99,7 @@ def serve(
dtype,
trust_remote_code,
uds_path,
max_input_tokens,
)


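The new `max_input_tokens` CLI parameter is plain plumbing: the launcher now always passes `--max-input-tokens`, and the server CLI forwards it down to model loading. As a standalone illustration only (a hypothetical toy command, not the actual serve() implementation), a typer signature with an Optional[int] default behaves like the flag added above:

from typing import Optional

import typer

app = typer.Typer()


@app.command()
def serve(model_id: str, max_input_tokens: Optional[int] = None):
    # None means the flag was omitted; the launcher now always supplies it,
    # so the server side can rely on an explicit value when loading the model.
    typer.echo(f"Loading {model_id} with max_input_tokens={max_input_tokens}")


if __name__ == "__main__":
    app()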
server/text_generation_server/layers/attention/rocm.py (3 additions, 8 deletions)
@@ -169,10 +169,8 @@ def attention(
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)

# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
@@ -204,10 +202,7 @@ def attention(
window_size_left=-1,
causal=True,
):
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
k,
server/text_generation_server/layers/attention/xpu.py (1 addition, 4 deletions)
@@ -14,10 +14,7 @@ def attention(
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
return ipex.llm.functional.varlen_attention(
q,
k,
server/text_generation_server/models/__init__.py (12 additions, 9 deletions)
@@ -24,6 +24,8 @@
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

from text_generation_server.utils.import_utils import SYSTEM

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
@@ -257,6 +259,7 @@ def get_model(
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
global FLASH_ATTENTION
if dtype is None:
@@ -410,11 +413,15 @@ def get_model(
"Sharding is currently not supported with `exl2` quantization"
)
sliding_window = config_dict.get("sliding_window", -1)
if sliding_window != -1 and not SUPPORTS_WINDOWING:
logger.warning(
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"

if (
(sliding_window is not None and sliding_window != -1)
and not SUPPORTS_WINDOWING
and max_input_tokens > sliding_window
):
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
FLASH_ATTENTION = False

if model_type == MAMBA:
return Mamba(
@@ -701,7 +708,6 @@ get_model(
)

if model_type == MISTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMistral(
model_id,
@@ -724,7 +730,6 @@ get_model(
)

if model_type == MIXTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMixtral(
model_id,
@@ -747,7 +752,6 @@ get_model(
)

if model_type == STARCODER2:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashStarcoder2(
model_id,
@@ -771,8 +775,7 @@ get_model(
)

if model_type == QWEN2:
sliding_window = config_dict.get("sliding_window", -1)
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
if FLASH_ATTENTION:
return FlashQwen2(
model_id,
revision,
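The guard added above replaces both the per-backend ValueError in rocm.py/xpu.py and the old warning-plus-fallback: a sliding-window model is now rejected at load time only when the backend cannot window and the requested input length actually exceeds the window. A minimal standalone sketch of that rule (hypothetical helper name, simplified arguments, not the actual get_model code):

def sliding_window_supported(sliding_window, supports_windowing, max_input_tokens):
    """Return True when a model can be served on this backend.

    Mirrors the check added in models/__init__.py: a config without a sliding
    window (None or -1) is always fine; otherwise the backend must either
    support windowing or be asked for prompts no longer than the window.
    """
    uses_window = sliding_window is not None and sliding_window != -1
    if uses_window and not supports_windowing and max_input_tokens > sliding_window:
        return False
    return True


# e.g. a sliding_window=4096 model on a backend without windowing support:
assert sliding_window_supported(4096, False, 4096) is True
assert sliding_window_supported(4096, False, 8192) is False
assert sliding_window_supported(None, False, 8192) is True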
server/text_generation_server/models/flash_causal_lm.py (6 additions, 1 deletion)
@@ -902,6 +902,8 @@ def warmup(self, batch: FlashCausalLMBatch):
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
):
torch.cuda.tunable.enable()

if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
torch.cuda.tunable.tuning_enable(True)

@@ -910,8 +912,11 @@ def warmup(self, batch: FlashCausalLMBatch):
int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
else:
elif CUDA_GRAPHS is not None:
tuning_sequences = CUDA_GRAPHS
else:
# For seqlen = 1, we dispatch to LLMM1 kernel.
tuning_sequences = [2, 3, 4, 5, 6, 7]

tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
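The warmup change decouples TunableOp from CUDA graphs: TunableOp is now enabled by default, and the tuning sequence lengths fall back to a small default list when CUDA graphs are disabled instead of assuming CUDA_GRAPHS is set. A rough standalone sketch of the selection order (hypothetical helper, same environment variables as the diff):

import os


def pick_tuning_sequences(cuda_graphs):
    """Choose the sequence lengths used to tune TunableOp GEMMs at warmup.

    Priority mirrors the diff: an explicit PYTORCH_TUNABLEOP_SEQLENS wins,
    then the CUDA graph sizes if graphs are enabled, otherwise a small default.
    """
    if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
        return [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
    if cuda_graphs is not None:
        return cuda_graphs
    # seqlen = 1 is dispatched to the LLMM1 kernel, so tuning starts at 2.
    return [2, 3, 4, 5, 6, 7]


print(pick_tuning_sequences(None))          # [2, 3, 4, 5, 6, 7]
print(pick_tuning_sequences([1, 2, 4, 8]))  # [1, 2, 4, 8]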
server/text_generation_server/server.py (2 additions, 0 deletions)
@@ -199,6 +199,7 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
max_input_tokens: int,
):
async def serve_inner(
model_id: str,
@@ -229,6 +230,7 @@ async def serve_inner(
speculate,
dtype,
trust_remote_code,
max_input_tokens,
)
except Exception:
logger.exception("Error when initializing model")