Skip to content

Commit

Permalink
Skip invalid gen config (#618)
Browse files Browse the repository at this point in the history
* chore: bump dev version

* feat(decoder): rebuild invalid generation config

* docs(tgi): use more recent model
  • Loading branch information
dacorvo authored Jun 3, 2024
1 parent 08e0499 commit 5b311a3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/source/guides/neuronx_tgi.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,15 @@ possible to export it dynamically, pending some conditions:
The snippet below shows how you can deploy a service from a hub standard model:

```
export HF_TOKEN=<YOUR_TOKEN>
docker run -p 8080:80 \
-v $(pwd)/data:/data \
--privileged \
-e HF_TOKEN=${HF_TOKEN} \
-e HF_AUTO_CAST_TYPE="fp16" \
-e HF_NUM_CORES=2 \
ghcr.io/huggingface/neuronx-tgi:latest \
--model-id NousResearch/Llama-2-7b-chat-hf \
--model-id meta-llama/Meta-Llama-3-8B \
--max-batch-size 1 \
--max-input-length 3164 \
--max-total-tokens 4096
Expand Down
8 changes: 8 additions & 0 deletions optimum/neuron/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import re
import shutil
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Optional, Tuple, Union
Expand Down Expand Up @@ -250,6 +251,13 @@ def _create_checkpoint(
**kwargs,
)

if model.generation_config is not None:
with warnings.catch_warnings(record=True) as caught_warnings:
model.generation_config.validate()
if len(caught_warnings) > 0:
logger.warning("Invalid generation config: recreating it from model config.")
model.generation_config = GenerationConfig.from_model_config(model.config)

# Save the model checkpoint in a temporary directory
checkpoint_dir = TemporaryDirectory()
os.chmod(checkpoint_dir.name, 0o775)
Expand Down
2 changes: 1 addition & 1 deletion optimum/neuron/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.23.dev0"
__version__ = "0.0.24.dev0"

__sdk_version__ = "2.18.0"

0 comments on commit 5b311a3

Please sign in to comment.