ROCm and sliding windows fixes #2033

Merged · 10 commits · Jun 10, 2024
Changes from 5 commits
4 changes: 2 additions & 2 deletions server/Makefile-vllm
@@ -1,5 +1,5 @@
 commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
-commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
+commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
 build-vllm-cuda:
     if [ ! -d 'vllm' ]; then \
         pip install -U ninja packaging --no-cache-dir && \
@@ -19,5 +19,5 @@ build-vllm-rocm:
     PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

 install-vllm-rocm: build-vllm-rocm
-    cd vllm && git fetch && git checkout $(commit_rocm) && \
+    cd vllm && git fetch && git checkout $(commit_rocm) && \
     PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
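
Not part of the PR, but as a quick sanity check the new commit_rocm pin can be compared against a local vllm checkout. A minimal sketch, assuming the checkout lives in a ./vllm directory as in the Makefile; the helper name is hypothetical:

# Hypothetical helper (not part of this PR): confirm the vllm checkout used for the
# ROCm build matches the commit_rocm pin from server/Makefile-vllm.
import subprocess

EXPECTED_ROCM_COMMIT = "559200c1a028de990c1ddea761b0ccd62109e3a0"  # commit_rocm above

def checked_out_commit(repo_dir: str = "vllm") -> str:
    # `git rev-parse HEAD` prints the full SHA of the current checkout.
    return subprocess.check_output(
        ["git", "rev-parse", "HEAD"], cwd=repo_dir, text=True
    ).strip()

if __name__ == "__main__":
    if checked_out_commit() != EXPECTED_ROCM_COMMIT:
        raise SystemExit("vllm checkout does not match the pinned ROCm commit")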
3 changes: 1 addition & 2 deletions server/text_generation_server/models/__init__.py
@@ -412,9 +412,8 @@ def get_model(
     sliding_window = config_dict.get("sliding_window", -1)
     if sliding_window != -1 and not SUPPORTS_WINDOWING:
         logger.warning(
-            f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
+            f"Flash attention is available, but doesn't support windowing which is required by model {model_id} for long contexts."
         )
-        FLASH_ATTENTION = False

     if model_type == MAMBA:
         return Mamba(
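
The behavioural change in this hunk: when a model requests a sliding window but the attention backend lacks windowing support, the server now only warns instead of also disabling flash attention. A minimal sketch of that gate, assuming SUPPORTS_WINDOWING comes from the attention backend and loguru is the logger; the standalone wrapper function is hypothetical:

# Sketch of the sliding-window gate after this change; the wrapper is illustrative only.
from loguru import logger

SUPPORTS_WINDOWING = False  # e.g. a backend without sliding-window attention

def warn_if_windowing_unsupported(config_dict: dict, model_id: str) -> None:
    sliding_window = config_dict.get("sliding_window", -1)
    if sliding_window != -1 and not SUPPORTS_WINDOWING:
        # Flash attention stays enabled; previously FLASH_ATTENTION was also set to False here.
        logger.warning(
            f"Flash attention is available, but doesn't support windowing "
            f"which is required by model {model_id} for long contexts."
        )

warn_if_windowing_unsupported({"sliding_window": 4096}, "some/sliding-window-model")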
7 changes: 6 additions & 1 deletion server/text_generation_server/models/flash_causal_lm.py
@@ -899,6 +899,8 @@ def warmup(self, batch: FlashCausalLMBatch):
                 os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                 or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
             ):
+                torch.cuda.tunable.enable()
+
                 if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                     torch.cuda.tunable.tuning_enable(True)

@@ -907,8 +909,11 @@ def warmup(self, batch: FlashCausalLMBatch):
                         int(val)
                         for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                     ]
-                else:
+                elif CUDA_GRAPHS is not None:
                     tuning_sequences = CUDA_GRAPHS
+                else:
+                    # For seqlen = 1, we dispatch to LLMM1 kernel.
+                    tuning_sequences = [2, 3, 4, 5, 6, 7]

                 tunableop_filepath = os.path.join(
                     HUGGINGFACE_HUB_CACHE,
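
Taken together, the two hunks above enable TunableOp explicitly (torch.cuda.tunable.enable()) and pick the sequence lengths to tune with a three-way precedence: an explicit PYTORCH_TUNABLEOP_SEQLENS override, then the configured CUDA graph sizes if any, then a small default range that skips length 1 because that case is dispatched to the LLMM1 kernel. A standalone sketch of the selection, with CUDA_GRAPHS stubbed as an example value (in TGI it comes from the launcher configuration):

# Illustrative version of the tuning-sequence selection above; CUDA_GRAPHS is a stub here.
import os

CUDA_GRAPHS = [1, 2, 4, 8]  # example value; may also be None

def select_tuning_sequences() -> list:
    if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
        # Explicit override: comma-separated sequence lengths to tune.
        return [
            int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
        ]
    if CUDA_GRAPHS is not None:
        # Fall back to the configured CUDA graph sizes.
        return CUDA_GRAPHS
    # Sequence length 1 is dispatched to the LLMM1 kernel, so tuning starts at 2.
    return [2, 3, 4, 5, 6, 7]

print(select_tuning_sequences())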