ROCm and sliding windows fixes #2033

Merged: 10 commits, Jun 10, 2024
8 changes: 8 additions & 0 deletions launcher/src/main.rs
@@ -476,6 +476,7 @@ fn shard_manager(
rope_factor: Option<f32>,
max_total_tokens: usize,
max_batch_size: Option<usize>,
+max_input_tokens: usize,
otlp_endpoint: Option<String>,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>,
@@ -548,6 +549,10 @@ fn shard_manager(
shard_args.push(otlp_endpoint);
}

+// When a sliding window is used, some backends may ignore it in flash attention depending on this parameter.
+shard_args.push("--max-input-tokens".to_string());
+shard_args.push(max_input_tokens.to_string());

// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

@@ -1004,6 +1009,7 @@ fn spawn_shards(
args: &Args,
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
+max_input_tokens: usize,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
@@ -1061,6 +1067,7 @@ fn spawn_shards(
rope_factor,
max_total_tokens,
max_batch_size,
+max_input_tokens,
otlp_endpoint,
max_log_level,
status_sender,
@@ -1535,6 +1542,7 @@ fn main() -> Result<(), LauncherError> {
&args,
cuda_graphs,
max_total_tokens,
+max_input_tokens,
max_log_level,
shutdown.clone(),
&shutdown_receiver,
4 changes: 2 additions & 2 deletions server/Makefile-vllm
@@ -1,5 +1,5 @@
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
-commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
+commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@@ -19,5 +19,5 @@ build-vllm-rocm:
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

install-vllm-rocm: build-vllm-rocm
-cd vllm && git fetch && git checkout $(commit_rocm) && \
+cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
2 changes: 2 additions & 0 deletions server/text_generation_server/cli.py
@@ -41,6 +41,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
+max_input_tokens: Optional[int] = None,
):
if sharded:
assert (
@@ -97,6 +98,7 @@ def serve(
dtype,
trust_remote_code,
uds_path,
+max_input_tokens,
)


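The launcher and CLI changes above are two halves of the same plumbing: the launcher appends `--max-input-tokens` to every shard command, and the Python entrypoint accepts it as an optional parameter and passes it down to the server. Below is a minimal, self-contained sketch of how such a parameter becomes a `--max-input-tokens` flag; it assumes the typer-style CLI that cli.py appears to use and is not TGI's actual file.

# Hypothetical sketch, not TGI's cli.py: a typer command whose Optional[int]
# parameter is exposed on the command line as --max-input-tokens.
from typing import Optional

import typer

app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    max_input_tokens: Optional[int] = None,  # filled in by the launcher
):
    # In the real server this value is threaded through serve() into get_model(),
    # where the sliding-window check further below consumes it.
    typer.echo(f"serving {model_id} with max_input_tokens={max_input_tokens}")


if __name__ == "__main__":
    app()

In the real binary the shard invocation is along the lines of `text-generation-server serve <model_id> --max-input-tokens <n>`; this sketch would just echo the values it received.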
11 changes: 3 additions & 8 deletions server/text_generation_server/layers/attention/rocm.py
@@ -169,10 +169,8 @@ def attention(
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
-if window_size_left != -1:
-raise ValueError(
-f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-)

+# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
@@ -204,10 +202,7 @@ def attention(
window_size_left=-1,
causal=True,
):
-if window_size_left != -1:
-raise ValueError(
-f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-)
+# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
k,
5 changes: 1 addition & 4 deletions server/text_generation_server/layers/attention/xpu.py
@@ -14,10 +14,7 @@ def attention(
softmax_scale,
window_size_left=-1,
):
-if window_size_left != -1:
-raise ValueError(
-f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-)
+# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
return ipex.llm.functional.varlen_attention(
q,
k,
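With these two files, neither the ROCm nor the XPU attention wrapper rejects windowed calls at runtime any more; both rely on a check performed once at model load. Judging from the models/__init__.py hunk below, that check keys on a per-backend SUPPORTS_WINDOWING flag. The following is a sketch of the pattern only; module layout and names are illustrative, not TGI's exact code.

# Illustrative pattern, not TGI's exact modules: a backend whose flash-attention
# build has no windowed kernels advertises that once, and its attention wrapper
# no longer validates window_size_left on every call.
SUPPORTS_WINDOWING = False


def attention(q, k, v, softmax_scale, window_size_left=-1, causal=True):
    # Safe to ignore window_size_left here: model load already guaranteed that
    # either the model has no sliding window or max_input_tokens fits inside it.
    ...  # call into the backend kernel (varlen_fwd, triton_attention, ...)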
21 changes: 12 additions & 9 deletions server/text_generation_server/models/__init__.py
@@ -24,6 +24,8 @@
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

+from text_generation_server.utils.import_utils import SYSTEM

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
@@ -257,6 +259,7 @@ def get_model(
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
+max_input_tokens: int,
) -> Model:
global FLASH_ATTENTION
if dtype is None:
@@ -410,11 +413,15 @@
"Sharding is currently not supported with `exl2` quantization"
)
sliding_window = config_dict.get("sliding_window", -1)
-if sliding_window != -1 and not SUPPORTS_WINDOWING:
-logger.warning(
-f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"

+if (
+(sliding_window is not None and sliding_window != -1)
+and not SUPPORTS_WINDOWING
+and max_input_tokens > sliding_window
+):
+raise ValueError(
+f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got max_input_tokens={max_input_tokens})."
)
-FLASH_ATTENTION = False

if model_type == MAMBA:
return Mamba(
@@ -701,7 +708,6 @@ def get_model(
)

if model_type == MISTRAL:
-sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMistral(
model_id,
@@ -724,7 +730,6 @@ def get_model(
)

if model_type == MIXTRAL:
-sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMixtral(
model_id,
@@ -747,7 +752,6 @@ def get_model(
)

if model_type == STARCODER2:
-sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashStarcoder2(
model_id,
@@ -771,8 +775,7 @@ def get_model(
)

if model_type == QWEN2:
-sliding_window = config_dict.get("sliding_window", -1)
-if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
+if FLASH_ATTENTION:
return FlashQwen2(
model_id,
revision,
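The net effect of this file is that a sliding-window model is only rejected when the window would actually be exceeded: on a backend without windowing support, such a model still loads as long as `--max-input-tokens` does not exceed `sliding_window`. Below is a standalone sketch of the condition with illustrative numbers; the 4096-token window is an example, not read from any real config, and the function name is hypothetical.

# Standalone sketch of the load-time guard added above; names and the example
# window size are illustrative.
def check_sliding_window(sliding_window, max_input_tokens, supports_windowing, system="rocm"):
    uses_window = sliding_window is not None and sliding_window != -1
    if uses_window and not supports_windowing and max_input_tokens > sliding_window:
        raise ValueError(
            f"The {system} backend does not support sliding window attention; "
            f"relaunch with --max-input-tokens <= {sliding_window} "
            f"(got {max_input_tokens})."
        )


check_sliding_window(4096, 3072, supports_windowing=False)  # ok: window never exceeded
check_sliding_window(4096, 8192, supports_windowing=True)   # ok: backend can window
check_sliding_window(-1, 8192, supports_windowing=False)    # ok: model has no window
# check_sliding_window(4096, 8192, supports_windowing=False) would raise ValueError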
7 changes: 6 additions & 1 deletion server/text_generation_server/models/flash_causal_lm.py
@@ -899,6 +899,8 @@ def warmup(self, batch: FlashCausalLMBatch):
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
):
+torch.cuda.tunable.enable()

if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
torch.cuda.tunable.tuning_enable(True)

@@ -907,8 +909,11 @@ def warmup(self, batch: FlashCausalLMBatch):
int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
-else:
+elif CUDA_GRAPHS is not None:
tuning_sequences = CUDA_GRAPHS
+else:
+# For seqlen = 1, we dispatch to LLMM1 kernel.
+tuning_sequences = [2, 3, 4, 5, 6, 7]

tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
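The warmup hunks do two things for the ROCm TunableOp path: torch.cuda.tunable.enable() is now called explicitly whenever TunableOp is not disabled, and the tuning sequence lengths come from, in order, the PYTORCH_TUNABLEOP_SEQLENS environment variable, the configured CUDA graph sizes, or a small default range starting at 2 (sequence length 1 is skipped because it dispatches to the LLMM1 kernel). Below is a condensed sketch of that selection logic; cuda_graphs stands in for the list of CUDA graph batch sizes TGI resolves elsewhere, and the surrounding conditions of the full method are omitted.

import os

import torch


def select_tunableop_seqlens(cuda_graphs):
    # Condensed sketch of the warmup logic above, not the full method.
    enabled = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
    if enabled is not None and enabled != "1":
        return None  # TunableOp disabled via the environment

    torch.cuda.tunable.enable()
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
        torch.cuda.tunable.tuning_enable(True)

    if "PYTORCH_TUNABLEOP_SEQLENS" in os.environ:
        return [int(v) for v in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
    if cuda_graphs is not None:
        return cuda_graphs
    # Sequence length 1 is handled by the LLMM1 kernel, so start tuning at 2.
    return [2, 3, 4, 5, 6, 7]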
2 changes: 2 additions & 0 deletions server/text_generation_server/server.py
@@ -199,6 +199,7 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
+max_input_tokens: int,
):
async def serve_inner(
model_id: str,
@@ -229,6 +230,7 @@ async def serve_inner(
speculate,
dtype,
trust_remote_code,
+max_input_tokens,
)
except Exception:
logger.exception("Error when initializing model")