diff --git a/awq/kernels/csrc/pybind.cpp b/awq/kernels/csrc/pybind.cpp
index 30424ee0..a7193f58 100644
--- a/awq/kernels/csrc/pybind.cpp
+++ b/awq/kernels/csrc/pybind.cpp
@@ -30,9 +30,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("w8a8_gemm_forward_cuda", &w8a8_gemm_forward_cuda, "our w8a8 gemm kernel");
     m.def("w8a8_gemm_fuse_bias_forward_cuda", &w8a8_gemm_fuse_bias_forward_cuda, "our w8a8 gemm fused bias kernel");
     m.def("invoke_quant", &invoke_quant, "fp16->int8 quantization");
-    m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
+    m.def("layer_norm_general", &layer_norm_general, py::arg("out"), py::arg("input"),
          py::arg("weight"), py::arg("bias"),py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = true,
-          "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");
+          "Apply Layer Normalization to the input tensor (TRTLLM kernel) and quantize the output to 8 bits.");
     m.def("silu_and_mul", &silu_and_mul, "Activation function.");
     m.def("gelu_and_quant",&gelu_and_quant, "Apply gelu act and quant output");
-}
+}
\ No newline at end of file
diff --git a/awq/kernels/csrc/w8a8/layernorm.cu b/awq/kernels/csrc/w8a8/layernorm.cu
index 510d4e1e..b73304da 100644
--- a/awq/kernels/csrc/w8a8/layernorm.cu
+++ b/awq/kernels/csrc/w8a8/layernorm.cu
@@ -190,7 +190,7 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
 
 } // namespace vllm
 
-void rms_norm_general(torch::Tensor &out,    // [..., hidden_size]
+void layer_norm_general(torch::Tensor &out,  // [..., hidden_size]
                 torch::Tensor &input,        // [..., hidden_size]
                 torch::Tensor &weight,       // [hidden_size]
                 torch::Tensor &bias,         // [hidden_size]
diff --git a/awq/kernels/csrc/w8a8/layernorm.h b/awq/kernels/csrc/w8a8/layernorm.h
index 9e5d7405..be0fe478 100644
--- a/awq/kernels/csrc/w8a8/layernorm.h
+++ b/awq/kernels/csrc/w8a8/layernorm.h
@@ -11,7 +11,7 @@
 #include
 #include
 
-void rms_norm_general(torch::Tensor &out,    // [..., hidden_size]
+void layer_norm_general(torch::Tensor &out,  // [..., hidden_size]
                 torch::Tensor &input,        // [..., hidden_size]
                 torch::Tensor &weight,       // [hidden_size]
                 torch::Tensor &bias,         // [hidden_size]
diff --git a/tinychat/README.md b/tinychat/README.md
index 9302b664..c9974633 100644
--- a/tinychat/README.md
+++ b/tinychat/README.md
@@ -184,7 +184,21 @@ Time-To-First-Token (TTFT) of Llama-2-7B (Unit: Seconds):
 | ----------- |:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|
 | FP16        | 0.029   | 0.058   | 0.100   | 0.211   | 0.329   | 0.441   |
 | TinyChat    | 0.018   | 0.031   | 0.060   | 0.124   | 0.193   | 0.265   |
-| Speedup     | 1.57x   | 1.83x   | 1.66x   | 1.70x   | 1.70x   | 1.66x   |
+| Speedup     | 1.57x   | 1.83x   | 1.66x   | 1.70x   | 1.70x   | 1.66x   |
+
+Time-To-First-Token (TTFT) of NVILA models processing 8-image inputs (Unit: Seconds):
+
+| Model         | Precision | VisionTower | LLM   | Total |
+|:-------------:|:---------:|:-----------:|:-----:|:-----:|
+| NVILA-lite-2B | FP16      | 0.074       | 0.024 | 0.097 |
+|               | TinyChat  | 0.045       | 0.016 | 0.060 |
+|               | Speedup   | 1.65x       | 1.52x | 1.62x |
+| NVILA-lite-8B | FP16      | 0.073       | 0.098 | 0.172 |
+|               | TinyChat  | 0.045       | 0.059 | 0.104 |
+|               | Speedup   | 1.63x       | 1.67x | 1.65x |
+| NVILA-8B      | FP16      | 0.075       | 0.205 | 0.280 |
+|               | TinyChat  | 0.046       | 0.122 | 0.168 |
+|               | Speedup   | 1.61x       | 1.69x | 1.66x |
 
 #### Jetson Orin Results
 
@@ -197,6 +211,21 @@ Time-To-First-Token (TTFT) of Llama-3-8B (Unit: Seconds):
 | TinyChat    | 0.166   | 0.315   | 0.623   | 1.248   | 1.907   | 2.573   |
 | Speedup     | 1.24x   | 1.26x   | 0.91x   | 1.22x   | 1.21x   | 1.21x   |
 
+Time-To-First-Token (TTFT) of NVILA models processing 8-image inputs (Unit: Seconds):
+
+| Model         | Precision | VisionTower | LLM   | Total |
+|:-------------:|:---------:|:-----------:|:-----:|:-----:|
+| NVILA-lite-2B | FP16      | 0.449       | 0.155 | 0.605 |
+|               | TinyChat  | 0.419       | 0.145 | 0.564 |
+|               | Speedup   | 1.07x       | 1.07x | 1.07x |
+| NVILA-lite-8B | FP16      | 0.449       | 0.733 | 1.183 |
+|               | TinyChat  | 0.419       | 0.620 | 1.040 |
+|               | Speedup   | 1.07x       | 1.18x | 1.14x |
+| NVILA-8B      | FP16      | 0.449       | 1.798 | 2.247 |
+|               | TinyChat  | 0.419       | 1.200 | 1.620 |
+|               | Speedup   | 1.07x       | 1.50x | 1.39x |
+
 
 #### Comparison with Other Systems
 
@@ -500,11 +529,11 @@ python -m awq.entry --model_path PATH/TO/NVILA/llm \
 ```
 
 Next, try chatting with it using the command below to experience shorter Time To First Token (TTFT) and higher decoding throughput.
 ```bash
-python nvila_demo.py --model-path EPATH/TO/NVILA \
+python nvila_demo.py --model-path PATH/TO/NVILA \
     --quant_path PATH/TO/NVILA-w4-g128-v2.pt \
     --media PATH/TO/MEDIA \
    --act_scale_path PATH/TO/NVILA-smooth-scale.pt \
-    --quant_llm --chunk --model_type nvila
+    --all --chunk --model_type nvila
 ```
 
diff --git a/tinychat/modules/fused_siglipdecoder.py b/tinychat/modules/fused_siglipdecoder.py
index 4614baa4..37de198b 100644
--- a/tinychat/modules/fused_siglipdecoder.py
+++ b/tinychat/modules/fused_siglipdecoder.py
@@ -14,7 +14,7 @@
 from typing import Optional, Tuple, Union
 from flash_attn import flash_attn_func
 import time
-
+import argparse
 
 CLIP_RANGE = 5
 
@@ -24,7 +24,11 @@
 class QuantSiglipEncoder(nn.Module):
     def __init__(self, module: SiglipEncoder, bsz=64, seqlen=1024):
         super().__init__()
         self.config = module.config
+        if "output_hidden_states" not in self.config:
+            self.config["output_hidden_states"] = False
+            self.config["use_return_dict"] = False
+
         self.layers = [QuantSiglipEncoderLayer(layer) for layer in module.layers]
         self.buffer = ActivationBuffer(module)
         self.bsz = bsz
@@ -40,14 +44,12 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
-        # TODO Find why this code is necessary
-        # torch.sum(inputs_embeds!=inputs_embeds)
+        inputs_embeds = inputs_embeds.contiguous()
         bsz, seqlen, _ = inputs_embeds.shape
         if self.bsz != bsz or self.seqlen != seqlen:
             self.buffer.allocate_activation_buffer(bsz * seqlen)
             self.bsz = bsz
             self.seqlen = seqlen
-
         output_hidden_states = (
             output_hidden_states
             if output_hidden_states is not None
@@ -68,7 +70,6 @@ def forward(
             hidden_states = encoder_layer(
                 hidden_states, self.buffer, attention_mask, bsz, seqlen
             )
-
         if output_hidden_states:
             encoder_states = encoder_states + (hidden_states.reshape(bsz, seqlen, -1),)
         if not return_dict:
@@ -84,7 +85,7 @@ class QuantSiglipMLP(nn.Module):
     def __init__(self, siglipmlp, init_only=False):
         super().__init__()
         self.config = siglipmlp.config
-        self.activation_fn = siglipmlp.activation_fn
+        self.activation_fn = getattr(siglipmlp, "activation_fn", None)
         self.fc1 = W8A8OF16LinearDynamicInputScale.from_linear(
             siglipmlp.fc1, init_only=init_only, fc1=False
         )
@@ -182,14 +183,14 @@ def __init__(self, module: SiglipEncoderLayer):
         super().__init__()
         self.embed_dim = module.embed_dim
         self.self_attn = QuantSiglipFlashAttention2(module.self_attn)
-        self.layer_norm1 = RMSNormGeneral(
+        self.layer_norm1 = LayerNormGeneral(
             module.layer_norm1.weight.data,
             module.layer_norm1.bias.data,
             module.layer_norm1.eps,
             True,
         ).cuda()
         self.mlp = QuantSiglipMLP(module.mlp)
-        self.layer_norm2 = RMSNormGeneral(
+        self.layer_norm2 = LayerNormGeneral(
             module.layer_norm2.weight.data,
             module.layer_norm2.bias.data,
             module.layer_norm2.eps,
@@ -220,7 +221,6 @@ def forward(
             buffer.quantized_hidden_states_buffer,
             buffer.quantized_scale_buffer,
         )
-
         # INT8 -> FP16
         self.self_attn(buffer, bsz, seqlen)
         hidden_states = (
@@ -234,7 +234,6 @@ def forward(
             buffer.quantized_hidden_states_buffer,
             buffer.quantized_scale_buffer,
         )
-
         # INT8 -> FP16
         self.mlp(buffer)
         hidden_states = (
@@ -243,11 +242,11 @@ def forward(
         return hidden_states
 
 
-class RMSNormGeneral(nn.Module):
-    """Root mean square normalization (w/ per-token or per-tensor quant).
+class LayerNormGeneral(nn.Module):
+    """Layer normalization (w/ per-token or per-tensor quant).
 
-    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
-    Refer to https://arxiv.org/abs/1910.07467
+    Computes x -> w * (x - E[x]) / sqrt(Var[x] + eps) + b, where w and b are the learned weight and bias.
+    Refer to https://arxiv.org/abs/1607.06450
     """
 
     def __init__(
@@ -271,7 +270,8 @@ def forward(
         quantized_sum_buffer: torch.Tensor = None,
     ) -> torch.Tensor:
         # quantized_sum_buffer is not used, only to keep the consistency of the interface
-        awq_inference_engine.rms_norm_general(
+
+        awq_inference_engine.layer_norm_general(
             quantized_hidden_states_buffer,
             x,
             self.weight.data,
@@ -279,4 +279,4 @@ def forward(
             quantized_scale_buffer,
             self.variance_epsilon,
             self.use_per_token_quant,
-        )
+        )
\ No newline at end of file
diff --git a/tinychat/scripts/nvila_demo.sh b/tinychat/scripts/nvila_demo.sh
index 58f8bbc7..f034d548 100755
--- a/tinychat/scripts/nvila_demo.sh
+++ b/tinychat/scripts/nvila_demo.sh
@@ -17,6 +17,6 @@ python -m awq.entry --model_path $MODEL_PATH/llm \
 # Run the TinyChat demo:
 python nvila_demo.py --model-path $MODEL_PATH \
     --quant_path quant_cache/$MODEL_NAME-w4-g128-awq.pt \
-    --media ../figures/nvila-logo.jpg \
+    --media ../figures/vila-logo.jpg \
     --act_scale_path awq_cache/$MODEL_NAME-smooth-scale.pt \
     --all --chunk --model_type nvila --vis_image
\ No newline at end of file
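
For reference on the kernel this diff renames, below is a minimal, unfused PyTorch sketch of the semantics that `layer_norm_general` / `LayerNormGeneral` is described as providing: layer normalization over the hidden dimension followed by per-token int8 quantization. The helper name, the epsilon default, and the symmetric rounding/clamping convention are illustrative assumptions, not the extension's actual fused implementation (which writes into preallocated buffers and also supports per-tensor scaling when `use_per_token_quant` is false).

```python
import torch


def layer_norm_quant_reference(x: torch.Tensor,
                               weight: torch.Tensor,
                               bias: torch.Tensor,
                               eps: float = 1e-6):
    """Unfused sketch (hypothetical helper): LayerNorm over the last dim,
    then per-token symmetric int8 quantization.
    Returns (int8 tensor, one fp scale per token)."""
    # LayerNorm: w * (x - E[x]) / sqrt(Var[x] + eps) + b over the hidden dimension
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    normed = (x - mean) / torch.sqrt(var + eps) * weight + bias

    # Per-token quantization: one scale per row so that max |value| maps to 127
    scale = normed.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(normed / scale), -127, 127).to(torch.int8)
    return q, scale.squeeze(-1)
```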