diff --git a/.github/workflows/static-check.yml b/.github/workflows/static-check.yml
index e1530b5..5bf6c12 100644
--- a/.github/workflows/static-check.yml
+++ b/.github/workflows/static-check.yml
@@ -15,42 +15,10 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- - name: Check requirements.txt exists
- id: check_req
- run: |
- if [ -f requirements.txt ]; then
- echo "requirements_exists=true" >> $GITHUB_OUTPUT
- else
- echo "requirements_exists=false" >> $GITHUB_OUTPUT
- fi
- if [ -f pyproject.toml ]; then
- echo "pyproject_exists=true" >> $GITHUB_OUTPUT
- else
- echo "pyproject_exists=false" >> $GITHUB_OUTPUT
- fi
- - name: Install dependencies by requirements.txt
+ - name: Install pylint
id: install_deps_req
- if: ${{ steps.check_req.outputs.requirements_exists == 'true' }}
run: |
python -m pip install --upgrade pylint
- python -m pip install --upgrade isort
- python -m pip install -r requirements.txt
- echo "dependencies_installed=true" >> $GITHUB_OUTPUT
- - name: Analysing the code with pylint
- if: ${{ steps.check_req.outputs.requirements_exists == 'true' }}
- run: |
- isort $(git ls-files '*.py') --check-only --diff
- pylint $(git ls-files '*.py')
- - name: Install dependencies by uv
- id: install_deps_uv
- if: ${{ steps.check_req.outputs.pyproject_exists == 'true' }}
- run: |
- python -m pip install uv
- uv sync
- uv pip install pylint
- uv pip install isort
- name: Analysing the code with pylint
- if: ${{ steps.check_req.outputs.pyproject_exists == 'true' }}
run: |
- uv run isort $(git ls-files '*.py') --check-only --diff
- uv run pylint $(git ls-files '*.py')
+        grep -rl "# LINT_ME" . --include="*.py" | xargs -r pylint
\ No newline at end of file
diff --git a/.pylintrc b/.pylintrc
index 7e930ac..01d1801 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -31,7 +31,7 @@ extension-pkg-allow-list=
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
-extension-pkg-whitelist=cv2
+extension-pkg-whitelist=
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
@@ -59,15 +59,16 @@ ignore-paths=
# Emacs file locks
ignore-patterns=^\.#
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis). It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=cv2
+# List of module names for which member attributes should not be checked and
+# will not be imported (useful for modules/projects where namespaces are
+# manipulated during runtime and thus existing member attributes cannot be
+# deduced by static analysis). It supports qualified module names, as well as
+# Unix pattern matching.
+ignored-modules=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
-init-hook='import sys; sys.path.append(".")'
+#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
@@ -86,9 +87,13 @@ load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
+# Resolve imports to .pyi stubs if available. May reduce no-member messages and
+# increase not-an-iterable messages.
+prefer-stubs=no
+
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
-py-version=3.10
+py-version=3.10
# Discover python modules and packages in the file system subtree.
recursive=no
@@ -99,10 +104,6 @@ recursive=no
# source root.
source-roots=
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages.
-suggestion-mode=yes
-
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
@@ -229,6 +230,11 @@ name-group=
# not require a docstring.
no-docstring-rgx=^_
+# Regular expression matching correct parameter specification variable names.
+# If left empty, parameter specification variable names will be checked with
+# the set naming style.
+#paramspec-rgx=
+
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
@@ -242,13 +248,17 @@ property-classes=abc.abstractproperty
# variable names will be checked with the set naming style.
#typevar-rgx=
+# Regular expression matching correct type variable tuple names. If left empty,
+# type variable tuple names will be checked with the set naming style.
+#typevartuple-rgx=
+
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
-variable-rgx=(_?[a-z][A-Za-z0-9]{0,30})|([A-Z0-9]{1,30})
+#variable-rgx=
[CLASSES]
@@ -285,10 +295,10 @@ exclude-too-few-public-methods=
ignored-parents=
# Maximum number of arguments for function / method.
-max-args=7
+max-args=5
# Maximum number of attributes for a class (see R0902).
-max-attributes=20
+max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
@@ -302,6 +312,9 @@ max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
+# Maximum number of positional arguments for function / method.
+max-positional-arguments=5
+
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
@@ -309,10 +322,10 @@ max-public-methods=20
max-returns=6
# Maximum number of statements in function / method body.
-max-statements=300
+max-statements=50
# Minimum number of public methods for a class (see R0903).
-min-public-methods=1
+min-public-methods=2
[EXCEPTIONS]
@@ -336,11 +349,13 @@ indent-after-paren=4
# tab).
indent-string=' '
-# Maximum number of characters on a single line.
-max-line-length=150
+# Maximum number of characters on a single line. Pylint's default of 100 is
+# based on PEP 8's guidance that teams may choose line lengths up to 99
+# characters.
+max-line-length=100
# Maximum number of lines in a module.
-max-module-lines=2000
+max-module-lines=1000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
@@ -421,11 +436,16 @@ confidence=HIGH,
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
-disable=too-many-arguments,
- too-many-locals,
- too-many-branches,
- protected-access
-
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ use-symbolic-message-instead,
+ use-implicit-booleaness-not-comparison-to-string,
+ use-implicit-booleaness-not-comparison-to-zero,
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
@@ -433,6 +453,13 @@ disable=too-many-arguments,
# it should appear only once). See also the "--disable" option for examples.
enable=
+[MASTER]
+# A comma-separated list of package or module names
+# from where C extensions may be loaded. Extensions
+# are loading into the active Python interpreter and
+# may run arbitrary code.
+extension-pkg-allow-list=torch,transformers
+
[METHOD_ARGS]
@@ -443,9 +470,13 @@ timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.
[MISCELLANEOUS]
+# Whether or not to search for fixme's in docstrings.
+check-fixme-in-docstring=no
+
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
- XXX
+ XXX,
+ TODO
# Regular expression of note tags to take in consideration.
notes-rgx=
@@ -465,7 +496,7 @@ never-returning-functions=sys.exit,argparse.parse_error
# Let 'consider-using-join' be raised when the separator to join on would be
# non-empty (resulting in expected fixes of the type: ``"- " + " -
# ".join(items)``)
-# suggest-join-with-non-empty-separator=yes
+suggest-join-with-non-empty-separator=yes
[REPORTS]
@@ -481,10 +512,10 @@ evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor
# used to format the message information. See doc for all details.
msg-template=
-# Set the output format. Available formats are: text, parseable, colorized,
-# json2 (improved json format), json (old json format) and msvs (visual
-# studio). You can also give a reporter class, e.g.
-# mypackage.mymodule.MyReporterClass.
+# Set the output format. Available formats are: 'text', 'parseable',
+# 'colorized', 'json2' (improved json format), 'json' (old json format), msvs
+# (visual studio) and 'github' (GitHub actions). You can also give a reporter
+# class, e.g. mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
@@ -582,11 +613,14 @@ ignored-checks-for-mixins=no-member,
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+
+ignored-modules=torch,torch.distributed,transformers,transformers.hf_argparser
+
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
-# The minimum edit distance a name should have in order to be considered a
+# The maximum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
@@ -630,4 +664,4 @@ init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
\ No newline at end of file
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..8cc1b46
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.10.15
diff --git a/README.md b/README.md
index 0fcb5dd..4926ca2 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,127 @@
-# Coming Soon...
+
+WAM-Diff: A Masked Diffusion VLA Framework with MoE and Online Reinforcement Learning for Autonomous Driving
+
+
+
+
+ 1Fudan University 2Yinwang Intelligent Technology Co., Ltd
+
+
+
+## 📅️ Roadmap
+
+| Status | Milestone | ETA |
+| :----: | :----------------------------------------------------------------------------------------------------: | :--------: |
+| 🚀 | **[Releasing the inference source code](https://github.com/fudan-generative-vision/WAM-Diff)** | 2025.12.21 |
+| 🚀 | **[Releasing the training scripts](https://github.com/fudan-generative-vision/WAM-Diff)** | 2025.12.21 |
+| 🚀 | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/WAM-Diff)** | TBD |
+
+
+## 🔧️ Framework
+
+
+## 🏆 Qualitative Results on NAVSIM
+### NAVSIM-v1 benchmark results
+
+

+
+
+### NAVSIM-v2 benchmark results
+
+

+
+
+
+
+## Quick Inference Demo
+The WAM-Diff will be available on Hugging Face Hub soon. To quickly test the model, follow these simple steps:
+
+1. **Clone the repository**
+ ```bash
+ git clone https://github.com/fudan-generative-vision/WAM-Diff
+ cd WAM-Diff
+ ```
+2. **Initialize the environment**
+ If you prefer conda, run the environment setup script to install necessary dependencies:
+ ```bash
+ bash init_env.sh
+ ```
+ Or you can use uv to create the environment:
+ ```bash
+ uv venv && uv sync
+ ```
+3. **Prepare the Model**
+ Download the pretrained WAM-Diff model from Hugging Face (pending release) to the `./model/WAM-Diff` directory:
+ ```
+ https://huggingface.co/fudan-generative-ai/WAM-Diff
+ ```
+ Download the pretrained Siglip2 model from Hugging Face to the `./model/siglip2-so400m-patch14-384` directory:
+ ```
+ https://huggingface.co/google/siglip2-so400m-patch14-384
+ ```
+
+
+4. **Run the demo script**
+ Execute the demo script to test WAM-Diff on an example image:
+ ```bash
+ bash inf.sh
+ ```
+
+## Training
+To fine-tune WAM-Diff, please follow these steps:
+1. **Set Up the Environment**
+ Follow the same environment setup steps as in the Quick Inference Demo section.
+2. **Prepare the Data**
+Prepare your training dataset in JSON format like
+ ```json
+ [
+ {
+ "image": ["path/to/image1.png"],
+ "conversations": [
+ {
+ "from": "human",
+ "value": "Here is front views of a driving vehicle:\n\nThe navigation information is: straight\nThe current position is (0.00,0.00)\nCurrent velocity is: (13.48,-0.29) and current accelerate is: (0.19,0.05)\nPredict the optimal driving action for the next 4 seconds with 8 new waypoints."
+ },
+ {
+ "from": "gpt",
+ "value": "6.60,-0.01,13.12,-0.03,19.58,-0.04,25.95,-0.03,32.27,-0.03,38.56,-0.05,44.88,-0.06,51.16,-0.09"
+ }
+ ]
+ },
+ ...
+ ]
+ ```
+3. **Run the Training Script**
+ Execute the training script with the following command:
+ ```bash
+ cd train
+ bash ./scripts/llada_v_finetune.sh
+ ```
+
+## 📝 Citation
+
+If you find our work useful for your research, please consider citing the paper:
+
+```
+@article{xu2025wam,
+ title={WAM-Diff: A Masked Diffusion VLA Framework with MoE and Online Reinforcement Learning for Autonomous Driving},
+ author={Xu, Mingwang and Cui, Jiahao and Cai, Feipeng and Shang, Hanlin and Zhu, Zhihao and Luan, Shan and Xu, Yifang and Zhang, Neng and Li, Yaoyi and Cai, Jia and others},
+ journal={arXiv preprint arXiv:2512.11872},
+ year={2025}
+}
+```
+
+## 🤗 Acknowledgements
+We gratefully acknowledge the contributors to the [LLaDA-V](https://github.com/ML-GSAI/LLaDA-V) repository, whose commitment to open source has provided us with their excellent codebase and pretrained models.
\ No newline at end of file
diff --git a/assets/image.png b/assets/image.png
new file mode 100644
index 0000000..3d1b68c
Binary files /dev/null and b/assets/image.png differ
diff --git a/assets/main_arch.png b/assets/main_arch.png
new file mode 100644
index 0000000..2b3f8e4
Binary files /dev/null and b/assets/main_arch.png differ
diff --git a/assets/navsim-v1.png b/assets/navsim-v1.png
new file mode 100644
index 0000000..8ca33c4
Binary files /dev/null and b/assets/navsim-v1.png differ
diff --git a/assets/navsim-v2.png b/assets/navsim-v2.png
new file mode 100644
index 0000000..26f5b62
Binary files /dev/null and b/assets/navsim-v2.png differ
diff --git a/envs.yml b/envs.yml
new file mode 100644
index 0000000..d3215ed
--- /dev/null
+++ b/envs.yml
@@ -0,0 +1,7 @@
+name: WAM-Diff
+channels:
+ - conda-forge
+ - bioconda
+ - defaults
+dependencies:
+ - python=3.10
diff --git a/inf.sh b/inf.sh
new file mode 100644
index 0000000..2afc538
--- /dev/null
+++ b/inf.sh
@@ -0,0 +1 @@
+torchrun --master_addr=127.0.0.1 --master_port=12346 --nproc-per-node=1 ./train/generate_demo_navsim_moe.py --pretrained_path path_to_WAM-Diff_ckpt
\ No newline at end of file
diff --git a/init_env.sh b/init_env.sh
new file mode 100644
index 0000000..2bd26e4
--- /dev/null
+++ b/init_env.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+source "$(conda info --base)/etc/profile.d/conda.sh"
+conda env create -f envs.yml
+conda activate WAM-Diff
+
+# log file for failed installations
+FAIL_LOG="failed_packages.log"
+
+# rm existing log file
+> $FAIL_LOG
+
+while IFS= read -r line; do
+ # skip empty lines and comments
+ if [[ -z "$line" || "$line" =~ ^# ]]; then
+ continue
+ fi
+
+ echo "====================================="
+ echo "Installing: $line"
+ echo "====================================="
+
+ # Execute installation command
+ pip install $line $PIP_MIRROR
+
+ # Check if installation was successful
+ if [ $? -ne 0 ]; then
+ echo "Installation failed: $line" >> $FAIL_LOG
+ echo "====================================="
+ echo "⚠️ $line installation failed, logged to $FAIL_LOG"
+ echo "====================================="
+ else
+ echo "====================================="
+ echo "✅ $line installed successfully"
+ echo "====================================="
+ fi
+ sleep 1
+
+done < "requirements.txt"
+
+echo "====================================="
+echo "Installation complete!"
+if [ -s $FAIL_LOG ]; then
+ echo "❌ The following packages failed to install:"
+ cat $FAIL_LOG
+else
+ echo "✅ All packages installed successfully!"
+fi
+echo "====================================="
\ No newline at end of file
diff --git a/model/.gitkeep b/model/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..8715476
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,65 @@
+[project]
+name = "wam-diff"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = "==3.10.15"
+dependencies = [
+ "accelerate==0.29.1",
+ "aioboto3>=15.5.0",
+ "aiofiles>=25.1.0",
+ "bokeh==2.4.3",
+ "casadi>=3.7.2",
+ "control==0.9.1",
+ "datasets==2.16.1",
+ "deepspeed==0.14.4",
+ "einops>=0.8.1",
+ "fiona>=1.10.1",
+ "geopandas>=0.12.1",
+ "guppy3==3.1.2",
+ "hydra-core==1.2.0",
+ "imageio>=2.37.2",
+ "joblib>=1.5.2",
+ "matplotlib>=3.10.8",
+ "nest-asyncio>=1.6.0",
+ "notebook>=7.5.0",
+ "numpy==1.23.4",
+ "nuplan-devkit",
+ "opencv-python==4.9.0.80",
+ "pandas>=2.3.3",
+ "pillow>=12.0.0",
+ "positional-encodings==6.0.1",
+ "protobuf==3.20.0",
+ "psutil>=7.1.3",
+ "pyarrow>=22.0.0",
+ "pyinstrument>=5.1.1",
+ "pylint>=4.0.4",
+ "pynvml==12.0.0",
+ "pyogrio>=0.12.1",
+ "pyquaternion>=0.9.5",
+ "pytest>=9.0.2",
+ "pytorch-lightning==2.2.1",
+ "rasterio>=1.3.11",
+ "retry>=0.9.2",
+ "rtree>=1.4.1",
+ "scikit-learn==1.2.2",
+ "scipy>=1.13.1",
+ "selenium>=4.39.0",
+ "setuptools==65.5.1",
+ "shapely>=2.0.0",
+ "sqlalchemy==1.4.27",
+ "sympy>=1.14.0",
+ "tensorboard==2.16.2",
+ "timm>=1.0.22",
+ "torch==2.8.0",
+ "torchvision==0.23.0",
+ "tornado>=6.5.3",
+ "tqdm>=4.67.1",
+ "transformers==4.43.0",
+ "tyro==0.9.28",
+ "ujson>=5.11.0",
+]
+[[tool.uv.index]]
+url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0404460
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,203 @@
+accelerate==0.29.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.15
+aiosignal==1.4.0
+anls==0.0.2
+annotated-types==0.7.0
+anyio==4.10.0
+async-timeout==5.0.1
+attrs==25.3.0
+av==15.0.0
+bitsandbytes==0.45.3
+black==24.1.0
+blis==1.3.0
+braceexpand==0.1.7
+capture-metric==0.1.13
+catalogue==2.0.10
+certifi==2025.8.3
+cfgv==3.4.0
+chardet==5.2.0
+charset-normalizer==3.4.3
+click==8.2.2
+cloudpathlib==0.21.1
+colorama==0.4.6
+confection==0.1.5
+cymem==2.0.11
+dataproperty==1.1.0
+datasets==2.16.1
+debugpy==1.8.17
+decord==0.6.0
+deepspeed==0.14.4
+dill==0.3.7
+discosg==0.0.5
+distlib==0.4.0
+distro==1.9.0
+docstring-parser==0.17.0
+einops==0.6.1
+einops-exts==0.0.4
+et-xmlfile==2.0.0
+evaluate==0.4.5
+exceptiongroup==1.3.0
+factualscenegraph==0.7.3
+fastapi==0.116.1
+filelock==3.19.1
+frozenlist==1.7.0
+fsspec==2023.10.0
+ftfy==6.3.1
+gitdb==4.0.12
+gitpython==3.1.45
+gradio-client==0.2.9
+h11==0.14.0
+hf-transfer==0.1.9
+hf-xet==1.1.7
+hjson==3.1.0
+httpcore==0.16.3
+httpx==0.24.0
+huggingface-hub==0.34.4
+identify==2.6.13
+idna==3.10
+imageio==2.37.0
+imageio-ffmpeg==0.6.0
+isort==5.13.2
+jinja2==3.1.5
+jiter==0.10.0
+joblib==1.5.1
+jsonlines==4.0.0
+langcodes==3.5.0
+language-data==1.3.0
+lazy-import==0.2.2
+levenshtein==0.27.1
+loguru==0.7.3
+lxml==6.0.0
+marisa-trie==1.3.0
+markdown-it-py==4.0.0
+markupsafe==3.0.2
+mbstrdecoder==1.1.4
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.6.4
+multiprocess==0.70.15
+murmurhash==1.0.13
+mypy-extensions==1.1.0
+networkx==3.4.2
+ninja==1.13.0
+nltk==3.9.1
+nodeenv==1.9.1
+numexpr==2.11.0
+nvidia-cublas-cu12
+nvidia-cuda-cupti-cu12
+nvidia-cuda-nvrtc-cu12
+nvidia-cuda-runtime-cu12
+nvidia-cudnn-cu12
+nvidia-cufft-cu12
+nvidia-cufile-cu12
+nvidia-curand-cu12
+nvidia-cusolver-cu12
+nvidia-cusparse-cu12
+nvidia-cusparselt-cu12
+nvidia-ml-py==12.575.51
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+open-clip-torch==3.1.0
+openai==1.100.0
+openpyxl==3.1.5
+packaging==25.0
+pandas==2.3.1
+pathlib2==2.3.7.post1
+pathspec==0.12.1
+pathvalidate==3.3.1
+peft==0.9.0
+pillow==11.3.0
+platformdirs==4.3.8
+portalocker==3.2.0
+pre-commit==4.3.0
+preshed==3.0.10
+propcache==0.3.2
+protobuf==3.20.0
+psutil==7.0.0
+py-cpuinfo==9.0.0
+pyarrow==21.0.0
+pyarrow-hotfix==0.7
+pybind11==3.0.0
+pycocoevalcap==1.2
+pycocotools==2.0.10
+pycryptodome==3.23.0
+pydantic==2.9.2
+pydantic-core==2.33.2
+pygments==2.19.2
+pynvml==12.0.0
+pytablewriter==1.2.1
+python-dateutil==2.9.0.post0
+pytz==2025.2
+pywsd==1.2.5
+pyyaml==6.0.2
+rapidfuzz==3.13.0
+regex==2025.7.34
+requests==2.32.4
+rfc3986==1.5.0
+rich==14.1.0
+rouge==1.0.1
+sacrebleu==2.5.1
+safetensors==0.6.2
+scikit-learn==1.7.1
+scipy==1.15.3
+sentence-transformers==5.1.0
+sentencepiece==0.1.99
+sentry-sdk==2.35.0
+shellingham==1.5.4
+shortuuid==1.0.13
+shtab==1.7.2
+six==1.17.0
+smart-open==7.3.0.post1
+smmap==5.0.2
+sniffio==1.3.1
+spacy==3.8.7
+spacy-legacy==3.0.12
+spacy-loggers==1.0.5
+sqlitedict==2.1.0
+srsly==2.5.1
+starlette==0.47.2
+sympy==1.13.1
+tabledata==1.3.4
+tabulate==0.9.0
+tcolorpy==0.1.7
+tenacity==8.3.0
+tensorboardx==2.6.4
+thinc==8.3.6
+threadpoolctl==3.6.0
+tiktoken==0.11.0
+timm==1.0.19
+tokenizers==0.19.1
+tomli==2.2.1
+torch==2.6.0
+torchvision==0.21.0
+tqdm==4.67.1
+transformers==4.43.0
+transformers-stream-generator==0.0.5
+triton==3.2.0
+typeguard==4.4.4
+typepy==1.3.4
+typer==0.16.1
+typing-extensions==4.14.1
+typing-inspection==0.4.1
+tyro==0.9.28
+tzdata==2025.2
+urllib3==2.0.0
+uvicorn==0.35.0
+virtualenv==20.34.0
+wandb==0.21.1
+wasabi==1.1.3
+wcwidth==0.2.13
+weasel==0.4.1
+webdataset==1.0.2
+websockets==15.0.1
+wn==0.0.23
+wrapt==1.17.3
+xxhash==3.5.0
+yarl==1.20.1
+yt-dlp==2025.8.11
+zss==1.2.0
+zstandard==0.24.0
+opencv-python==4.12.0.88
+opencv-python-headless==4.12.0.88
\ No newline at end of file
diff --git a/train/LICENSE b/train/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/train/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/train/generate_demo_navsim_moe.py b/train/generate_demo_navsim_moe.py
new file mode 100644
index 0000000..8f75dbf
--- /dev/null
+++ b/train/generate_demo_navsim_moe.py
@@ -0,0 +1,352 @@
+# LINT_ME
+"""Generate demo navigation predictions using NavSim-MOE model in MoE setting."""
+
+from __future__ import annotations
+
+import copy
+import json
+import os
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from typing import List, Tuple
+
+# Remove these if not needed in your environment.
+import llava.s3_ops as st
+import torch
+import torch.distributed as dist
+
+# llava / model imports
+from llava.cache import dLLMCache, dLLMCacheConfig
+from llava.conversation import conv_templates
+from llava.hooks import register_cache_LLaDA_V
+from llava.hooks.fast_dllm_hook import register_fast_dllm_hook
+from llava.mm_utils import IMAGE_TOKEN_INDEX, process_images, tokenizer_image_token
+from llava.model.builder import load_pretrained_model
+from transformers.hf_argparser import HfArgumentParser
+
+
+# =========================
+# Dataclass Args (HF-style)
+# =========================
+@dataclass
+class DataArgs:
+    """
+    Data-related arguments.
+    """
+    # Root of the NavSim dataset.
+    # NOTE(review): default is the *string* "None", not the None object —
+    # confirm downstream code treats "None" as "unset".
+    navsim_root: str = field(default="None")
+    # Number of future waypoints to predict.
+    steps_fut: int = 8
+    # Number of history steps provided as context.
+    steps_hist: int = 4
+    # Fixed ordering of surround-view cameras when assembling inputs.
+    camera_order: List[str] = field(
+        default_factory=lambda: [
+            "CAM_FRONT",
+            "CAM_FRONT_LEFT",
+            "CAM_FRONT_RIGHT",
+            "CAM_BACK",
+            "CAM_BACK_LEFT",
+            "CAM_BACK_RIGHT",
+        ]
+    )
+
+
+@dataclass
+class ModelArgs: # pylint: disable=R0902
+    """Model Configuration Arguments."""
+    # Path to the pretrained checkpoint (required; no default).
+    pretrained_path: str
+    # Key into llava.conversation.conv_templates.
+    conv_template: str = "llava_llada"
+    # Enable the Fast-dLLM hook (takes precedence over dLLM-Cache in main()).
+    use_fast_dllm: bool = True
+    # Enable dLLM-Cache feature caching (used only when use_fast_dllm is False).
+    use_dllm_cache: bool = False
+    # dLLM-Cache refresh intervals and feature-transfer ratio.
+    prompt_interval_steps: int = 25
+    gen_interval_steps: int = 7
+    transfer_ratio: float = 0.25
+    # Diffusion generation schedule: steps, total generated length, block size.
+    token_gen_steps: int = 32
+    token_gen_length: int = 32
+    token_block_length: int = 32
+
+
+@dataclass
+class RuntimeArgs:
+    """Runtime / Distributed Training Arguments."""
+    # Directory where per-rank JSON outputs (and resolved_args.json) are written.
+    save_jsons_folder: str = "output/jsons"
+    device_type: str = "cuda" # cuda|npu|cpu
+    backend: str = "nccl" # nccl|gloo
+    # When true, rank 0 merges per-rank JSONs into merged.jsonl at the end.
+    merge_after: bool = False
+    debug_enable: bool = False
+
+
+# =========================
+# DDP helpers
+# =========================
+def setup_distributed(device_type: str = "npu", backend: str = "hccl"):
+    """
+    Initialize torch.distributed, pin local rank to the right device,
+    and return (rank, world_size, local_rank, device_str).
+
+    NOTE(review): defaults here (npu/hccl) differ from the RuntimeArgs
+    defaults (cuda/nccl); main() always passes explicit values, so these
+    defaults are effectively unused — confirm before relying on them.
+    """
+    # Figure out ranks (torchrun-style environment variables, 0/0/1 fallback).
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    rank = int(os.environ.get("RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+    # Pin device
+    device_type = device_type.lower()
+    if device_type == "cuda":
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA requested but not available.")
+        torch.cuda.set_device(local_rank)
+        device_str = f"cuda:{local_rank}"
+    elif device_type == "npu":
+        # Assumes torch.npu exists (Ascend torch plugin); AttributeError otherwise.
+        torch.npu.set_device(local_rank)
+        device_str = f"npu:{local_rank}"
+    else:
+        # Anything else falls back to CPU without error.
+        device_str = "cpu"
+
+    # Init process group (skipped when one is already initialized).
+    if dist.is_available() and not dist.is_initialized():
+        dist.init_process_group(backend=backend)
+
+    return rank, world_size, local_rank, device_str
+
+
+def get_rank_world() -> Tuple[int, int]:
+ """Get the current process rank and world size for distributed training."""
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_rank(), dist.get_world_size()
+ return 0, 1
+
+
+def barrier():
+ """DDP barrier."""
+ if dist.is_available() and dist.is_initialized():
+ dist.barrier()
+
+
+def is_rank0() -> bool:
+ """Check if current process is rank 0."""
+ r, _ = get_rank_world()
+ return r == 0
+
+
+def get_local_rank() -> int:
+ """Get local rank from env vars."""
+ if "LOCAL_RANK" in os.environ:
+ return int(os.environ["LOCAL_RANK"])
+ if "RANK" in os.environ:
+ num = torch.cuda.device_count() if torch.cuda.is_available() else 1
+ return int(os.environ["RANK"]) % max(1, num)
+ return 0
+
+
+def pick_device(device_type: str, local_rank: int) -> str:
+ """Pick device string based on type and local rank."""
+ device_type = device_type.lower()
+ if device_type == "cuda":
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA requested but not available.")
+ return f"cuda:{local_rank}"
+ if device_type == "npu":
+ # Adjust if using a different NPU stack
+ return f"npu:{local_rank}"
+ if device_type == "cpu":
+ return "cpu"
+ raise ValueError(
+ f"Unsupported --device_type '{device_type}'. Use one of: cuda|npu|cpu."
+ )
+
+
+# =========================
+# Arg parsing (HF)
+# =========================
+def parse_args() -> tuple[DataArgs, ModelArgs, RuntimeArgs]:
+    """
+    Parse CLI arguments into (DataArgs, ModelArgs, RuntimeArgs).
+
+    Supports:
+      - CLI flags
+      - Config file (JSON/TOML/YAML): eval_nuscenes_ddp.py cfg.json
+      - Mixed (file + overrides)
+    """
+    parser = HfArgumentParser((DataArgs, ModelArgs, RuntimeArgs))
+    # return_remaining_strings=False: unknown flags raise instead of being ignored.
+    data_args, model_args, runtime_args = parser.parse_args_into_dataclasses(
+        return_remaining_strings=False
+    )
+    # Normalize case so later string comparisons can be exact.
+    runtime_args.device_type = runtime_args.device_type.lower()
+    runtime_args.backend = runtime_args.backend.lower()
+    return data_args, model_args, runtime_args
+
+
+def _int_stem(p: Path) -> int:
+ """Get integer from path stem."""
+ try:
+ return int(p.stem)
+ except ValueError: # narrow exception (fixes W0718)
+ return 1_000_000_000
+
+
+def merge_rank_jsons(save_jsons_folder: str, world_size: int) -> None:
+    """Merge per-rank JSON files into a single JSONL file.
+
+    Scans rank_0..rank_{world_size-1} subfolders, orders files globally by
+    their integer stem, and concatenates them one-JSON-per-line into
+    merged.jsonl. I/O goes through llava.s3_ops (st), so paths may be
+    local or remote — presumably S3; confirm against llava.s3_ops.
+    """
+    save_dir = Path(save_jsons_folder)
+    merged_path = save_dir / "merged.jsonl"
+
+    numbered: list[tuple[int, Path]] = []
+    for r in range(world_size):
+        rank_dir = save_dir / f"rank_{r}"
+        if not rank_dir.exists():
+            # A rank may legitimately have produced no output.
+            continue
+
+        for p in rank_dir.glob("*.json"):
+            idx = _int_stem(p)
+            # Skip files whose stem is not a usable integer index.
+            if idx is None:
+                continue
+            numbered.append((idx, p))
+
+    # Global ordering across ranks by numeric file stem.
+    numbered.sort(key=lambda t: t[0])
+
+    count = 0
+    with st.open_file(str(merged_path), "w") as out_f:
+        for _, p in numbered:
+            with st.open_file(str(p), "r") as in_f:
+                # Normalize each record to exactly one trailing newline.
+                out_f.write(in_f.read().rstrip("\n") + "\n")
+            count += 1
+
+    print(f"[Rank 0] merged {count} records -> {merged_path}")
+
+# =========================
+# Main
+# =========================
+def main(): # pylint: disable=R0912, R0914, R0915
+    """Run single-demo inference under DDP and optionally merge outputs.
+
+    Pipeline: parse args -> init distributed -> load LLaDA-V model ->
+    optionally register Fast-dLLM or dLLM-Cache hooks -> build a demo
+    prompt with one front-view image -> generate -> dump decoded diffusion
+    steps to output.txt -> barrier -> merge per-rank JSONs on rank 0.
+    """
+    data_args, model_args, runtime_args = parse_args()
+    # DDP init
+    rank, world_size, _, device_str = setup_distributed(
+        device_type=runtime_args.device_type,
+        backend=runtime_args.backend,
+    )
+
+    # Device & IO
+    device_map = device_str
+
+    st.makedirs(runtime_args.save_jsons_folder)
+    rank_out_dir = Path(runtime_args.save_jsons_folder) / f"rank_{rank}"
+    st.makedirs(str(rank_out_dir))
+
+    if is_rank0():
+        print(
+            f"[DDP] world_size={world_size} | saving to: {runtime_args.save_jsons_folder}"
+        )
+        # Save resolved config
+        with st.open_file(
+            str(Path(runtime_args.save_jsons_folder) / "resolved_args.json"), "w"
+        ) as f:
+            json.dump(
+                {
+                    "DataArgs": vars(data_args),
+                    "ModelArgs": vars(model_args),
+                    "RuntimeArgs": vars(runtime_args),
+                },
+                f,
+                indent=2,
+            )
+
+    # Load model (per rank)
+    model_name = "llava_llada"
+    tokenizer, model, image_processor, _ = load_pretrained_model(
+        model_args.pretrained_path,
+        None,
+        model_name,
+        attn_implementation="sdpa",
+        device_map=device_map,
+    )
+    model = model.to(device_str)
+    model.config.image_aspect_ratio = "anyres_max_2"
+    model.eval()
+
+    # Optional speed-ups: Fast-dLLM takes precedence over dLLM-Cache.
+    if model_args.use_fast_dllm:
+        register_fast_dllm_hook(model)
+        if is_rank0():
+            print("[Fast-dLLM] enabled")
+    elif model_args.use_dllm_cache:
+        dLLMCache.new_instance(
+            **asdict(
+                dLLMCacheConfig(
+                    prompt_interval_steps=model_args.prompt_interval_steps,
+                    gen_interval_steps=model_args.gen_interval_steps,
+                    transfer_ratio=model_args.transfer_ratio,
+                )
+            )
+        )
+        register_cache_LLaDA_V(model, "model.layers")
+        if is_rank0():
+            print("[dLLM-Cache] enabled")
+    else:
+        if is_rank0():
+            print("[Cache] disabled")
+
+    # Inference only — no autograd state needed.
+    torch.set_grad_enabled(False)
+
+    surrounding_views = ["./assets/image.png"]
+
+    images = [st.image_open(p) for p in surrounding_views]
+
+    # With multiple views, switch to padded aspect ratio instead of anyres.
+    if len(images) > 1:
+        model.config.image_aspect_ratio = "pad"
+
+    image_tensor = process_images(images, image_processor, model.config)
+    image_tensor = [
+        _img.to(dtype=torch.float16, device=device_str) for _img in image_tensor
+    ]
+    image_sizes = [img.size for img in images]
+
+    # Prompt
+    conv = copy.deepcopy(conv_templates[model_args.conv_template])
+    question = """Here is front views of a driving vehicle:
+    \nThe navigation information is: straight
+    The current position is (0.00,0.00)
+    Current velocity is: (8.92,-0.15) and current acceleration is: (-1.25,-0.01)
+    Predict the optimal driving action for the next 4 seconds with 8 new waypoints.
+    """
+    conv.append_message(conv.roles[0], question)
+    conv.append_message(conv.roles[1], None)
+    prompt_question = conv.get_prompt()
+
+    input_ids = (
+        tokenizer_image_token(
+            prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+        )
+        .unsqueeze(0)
+        .to(device_str)
+    )
+
+    # Generate; track_x holds intermediate diffusion states per step.
+    cont, track_x = model.generate(
+        input_ids,
+        images=image_tensor,
+        image_sizes=image_sizes,
+        steps=model_args.token_gen_steps,
+        gen_length=model_args.token_gen_length,
+        block_length=model_args.token_block_length,
+        tokenizer=tokenizer,
+        stopping_criteria=["<|eot_id|>"],
+        prefix_refresh_interval=32,
+        threshold=1,
+    )
+    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=False)
+    print(text_outputs)
+
+    # NOTE(review): every rank writes the same "output.txt" in the CWD —
+    # ranks on one node overwrite each other; confirm intended.
+    with open("output.txt", "w", encoding="utf-8") as f:
+        for x_step in track_x:
+            text_outputs = tokenizer.batch_decode(x_step, skip_special_tokens=False)
+
+            # Write each decoded string on its own line
+            for text in text_outputs:
+                f.write(text + "\n")
+
+            # Add separator after each step
+            f.write("--------\n")
+
+    # Wait for all ranks before merging.
+    barrier()
+
+    # Merge (rank 0)
+    if is_rank0() and runtime_args.merge_after:
+        merge_rank_jsons(runtime_args.save_jsons_folder, world_size)
diff --git a/train/init_env.sh b/train/init_env.sh
new file mode 100644
index 0000000..0df7957
--- /dev/null
+++ b/train/init_env.sh
@@ -0,0 +1,2 @@
+pip install -e ".[train]"
+pip install webdataset
diff --git a/train/llava/__init__.py b/train/llava/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/train/llava/cache/Cache.py b/train/llava/cache/Cache.py
new file mode 100644
index 0000000..282bd10
--- /dev/null
+++ b/train/llava/cache/Cache.py
@@ -0,0 +1,87 @@
+# pylint: disable=C0114,C0115,C0116,C0103
+
+from collections import defaultdict
+
+import torch
+
+
+class Singleton(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
+class dLLMCache(metaclass=Singleton):
+    """Process-wide feature cache for diffusion-LLM decoding (singleton)."""
+
+    gen_interval_steps: int
+    prompt_interval_steps: int
+    cfg_interval_steps: int
+    prompt_length: int
+    transfer_ratio: float
+    # Name-mangled private state; layout documented in init().
+    __cache: defaultdict
+    __step_counter: defaultdict
+    cache_type: str
+
+    @classmethod
+    def new_instance(
+        cls,
+        prompt_interval_steps: int = 1,
+        gen_interval_steps: int = 1,
+        cfg_interval_steps: int = 1,
+        transfer_ratio: float = 0.0,
+    ) -> "dLLMCache":
+        """Create (or fetch) the singleton and (re)configure its intervals."""
+        ins = cls()
+        setattr(ins, "prompt_interval_steps", prompt_interval_steps)
+        setattr(ins, "gen_interval_steps", gen_interval_steps)
+        setattr(ins, "cfg_interval_steps", cfg_interval_steps)
+        setattr(ins, "transfer_ratio", transfer_ratio)
+        ins.init()
+        return ins
+
+    def init(self) -> None:
+        """Reset the nested cache and per-layer step counters."""
+        # Layout: cache[self.cache_type][cache_type][layer_id][feature_name] -> {0: tensor}
+        self.__cache = defaultdict(
+            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
+        )
+        # step_counter[self.cache_type][layer_id] -> int (starts at 0)
+        self.__step_counter = defaultdict(lambda: defaultdict(lambda: 0))
+
+    def reset_cache(self, prompt_length: int = 0) -> None:
+        """Clear all cached features and counters for a fresh decode."""
+        self.init()
+        torch.cuda.empty_cache()
+        self.prompt_length = prompt_length
+        self.cache_type = "no_cfg"
+
+    def set_cache(
+        self, layer_id: int, feature_name: str, features: torch.Tensor, cache_type: str
+    ) -> None:
+        """Store a feature tensor for (cache_type, layer_id, feature_name)."""
+        self.__cache[self.cache_type][cache_type][layer_id][feature_name] = {
+            0: features
+        }
+
+    def get_cache(
+        self, layer_id: int, feature_name: str, cache_type: str
+    ) -> torch.Tensor:
+        """Fetch the tensor stored by set_cache; KeyError if never stored."""
+        output = self.__cache[self.cache_type][cache_type][layer_id][feature_name][0]
+        return output
+
+    def update_step(self, layer_id: int) -> None:
+        """Advance the step counter for one layer under the active cache_type."""
+        self.__step_counter[self.cache_type][layer_id] += 1
+
+    def refresh_gen(self) -> bool:
+        """True on steps where generated-token features must be recomputed."""
+        return (self.current_step - 1) % self.gen_interval_steps == 0
+
+    def refresh_prompt(self) -> bool:
+        """True on steps where prompt features must be recomputed."""
+        return (self.current_step - 1) % self.prompt_interval_steps == 0
+
+    def refresh_cfg(self) -> bool:
+        """True on CFG refresh steps; always true for the first 5 steps."""
+        return (
+            self.current_step - 1
+        ) % self.cfg_interval_steps == 0 or self.current_step <= 5
+
+    @property
+    def current_step(self) -> int:
+        """Max per-layer step count for the active cache_type (1 when empty)."""
+        return max(list(self.__step_counter[self.cache_type].values()), default=1)
+
+    def __repr__(self):
+        return "USE dLLMCache"
diff --git a/train/llava/cache/Config.py b/train/llava/cache/Config.py
new file mode 100644
index 0000000..d48e835
--- /dev/null
+++ b/train/llava/cache/Config.py
@@ -0,0 +1,11 @@
+# pylint: disable=C0114,C0115,C0116,C0103
+
+from dataclasses import dataclass
+
+
+@dataclass
+class dLLMCacheConfig:
+    """Interval/ratio settings consumed by dLLMCache.new_instance(**asdict(cfg))."""
+    # Recompute prompt features every N steps.
+    prompt_interval_steps: int = 1
+    # Recompute generated-token features every N steps.
+    gen_interval_steps: int = 1
+    # Fraction of cached gen features refreshed per step (see refresh_index).
+    transfer_ratio: float = 0.0
+    # Recompute CFG features every N steps.
+    cfg_interval_steps: int = 1
diff --git a/train/llava/cache/__init__.py b/train/llava/cache/__init__.py
new file mode 100644
index 0000000..a8ef4bf
--- /dev/null
+++ b/train/llava/cache/__init__.py
@@ -0,0 +1,6 @@
+# pylint: disable=C0114,C0115,C0116
+
+from .Cache import dLLMCache
+from .Config import dLLMCacheConfig
+
+__all__ = ["dLLMCache", "dLLMCacheConfig"]
diff --git a/train/llava/constants.py b/train/llava/constants.py
new file mode 100644
index 0000000..be8cf02
--- /dev/null
+++ b/train/llava/constants.py
@@ -0,0 +1,12 @@
+# Seconds before a controller heartbeat is considered expired.
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+# Seconds between worker heartbeats.
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100  # label value ignored by the loss
+IMAGE_TOKEN_INDEX = -200  # placeholder id substituted for image tokens
+# NOTE(review): the four token strings below are empty; upstream LLaVA uses
+# "<image>", "<im_patch>", "<im_start>", "<im_end>" — these look stripped
+# during export. Confirm against the upstream constants before use.
+DEFAULT_IMAGE_TOKEN = ""
+DEFAULT_IMAGE_PATCH_TOKEN = ""
+DEFAULT_IM_START_TOKEN = ""
+DEFAULT_IM_END_TOKEN = ""
diff --git a/train/llava/conversation.py b/train/llava/conversation.py
new file mode 100644
index 0000000..365bb84
--- /dev/null
+++ b/train/llava/conversation.py
@@ -0,0 +1,396 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any, Union
+import re
+import base64
+from io import BytesIO
+from PIL import Image
+from transformers import AutoTokenizer
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    # One prompt-formatting scheme per model family;
+    # Conversation.get_prompt dispatches on this value.
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+    PLAIN = auto()
+    CHATML = auto()
+    LLAMA_2 = auto()
+    LLAMA_3 = auto()
+    QWEN = auto()
+    GEMMA = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history.
+
+    NOTE(review): many empty string literals in this class ("" in replace/
+    startswith/multiplication) appear to be "<image>"-style tags stripped
+    during export — compare with upstream LLaVA before trusting them.
+    """
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    tokenizer_id: str = ""
+    tokenizer: Any = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Union[str, List[str]] = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] = None
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        """Render the message history into a single prompt string per sep_style."""
+        messages = self.messages
+        # When the first user message carries images (tuple payload), rewrite it.
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0]
+            if "mmtag" in self.version:
+                init_msg = init_msg.replace("", "").strip()
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], ""))
+                messages.insert(1, (self.roles[1], "Received."))
+            elif not init_msg.startswith(""):
+                # NOTE(review): startswith("") is always True, so this branch
+                # is unreachable — likely startswith("<image>") originally.
+                init_msg = init_msg.replace("", "").strip()
+                messages[0] = (init_role, "\n" + init_msg)
+            else:
+                messages[0] = (init_role, init_msg)
+
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+
+        elif self.sep_style == SeparatorStyle.TWO:
+            # Alternates two separators (e.g. " " and EOS) between turns.
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+
+        elif self.sep_style == SeparatorStyle.CHATML:
+            ret = "" if self.system == "" else self.system + self.sep + "\n"
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, images, _ = message
+                        message = "" * len(images) + message
+                    ret += role + "\n" + message + self.sep + "\n"
+                else:
+                    ret += role + "\n"
+            # CHATML returns directly instead of falling through.
+            return ret
+
+        elif self.sep_style == SeparatorStyle.LLAMA_3:
+            # Lazily load a tokenizer so apply_chat_template can format turns.
+            if self.tokenizer is None:
+                if self.version == "llama_v3":
+                    self.tokenizer = AutoTokenizer.from_pretrained(
+                        "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
+                    )
+                elif self.version == "llava_llada":
+                    self.tokenizer = AutoTokenizer.from_pretrained(
+                        "model/LLaDA-V", trust_remote_code=True
+                    )
+                else:
+                    raise ValueError(
+                        "The tokenizer is not available. Make sure you have the necessary permissions."
+                    )
+            chat_template_messages = [{"role": "system", "content": self.system}]
+            for role, message in messages:
+                if message:
+                    # NOTE(review): this branch unpacks a 2-tuple while every
+                    # other branch unpacks 3 values — confirm payload shape.
+                    if type(message) is tuple:
+                        message, images = message
+                        message = "" * len(images) + message
+                    chat_template_messages.append({"role": role, "content": message})
+
+            # print(chat_template_messages)
+            return self.tokenizer.apply_chat_template(
+                chat_template_messages, tokenize=False, add_generation_prompt=True
+            )
+
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+
+        elif self.sep_style == SeparatorStyle.GEMMA:
+            ret = ""
+            for i, (role, message) in enumerate(messages):
+                assert role == self.roles[i % 2], (
+                    "Conversation should alternate user/assistant/user/assistant/..."
+                )
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            # NOTE(review): "<>" below looks like a stripped "<<SYS>>" marker.
+            wrap_sys = (
+                lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg
+            )
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i == 0:
+                        message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        """Append one [role, message] turn to the history."""
+        self.messages.append([role, message])
+
+    def process_image(
+        self, image, image_process_mode, return_pil=False, image_format="PNG"
+    ):
+        """Resize/pad one image; return a PIL image or a base64 string."""
+        if image_process_mode == "Pad":
+
+            def expand2square(pil_img, background_color=(122, 116, 104)):
+                # Pad the shorter side to make the image square, centered.
+                width, height = pil_img.size
+                if width == height:
+                    return pil_img
+                elif width > height:
+                    result = Image.new(pil_img.mode, (width, width), background_color)
+                    result.paste(pil_img, (0, (width - height) // 2))
+                    return result
+                else:
+                    result = Image.new(pil_img.mode, (height, height), background_color)
+                    result.paste(pil_img, ((height - width) // 2, 0))
+                    return result
+
+            image = expand2square(image)
+        elif image_process_mode in ["Default", "Crop"]:
+            pass
+        elif image_process_mode == "Resize":
+            image = image.resize((336, 336))
+        else:
+            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+
+        # A non-PIL value is treated as a path / file-like and opened.
+        if type(image) is not Image.Image:
+            image = Image.open(image).convert("RGB")
+
+        # Downscale preserving aspect ratio, bounded by 672 / 448.
+        max_hw, min_hw = max(image.size), min(image.size)
+        aspect_ratio = max_hw / min_hw
+        max_len, min_len = 672, 448
+        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+        longest_edge = int(shortest_edge * aspect_ratio)
+        W, H = image.size
+        if H > W:
+            H, W = longest_edge, shortest_edge
+        else:
+            H, W = shortest_edge, longest_edge
+        image = image.resize((W, H))
+        if return_pil:
+            return image
+        else:
+            buffered = BytesIO()
+            image.save(buffered, format=image_format)
+            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+            return img_b64_str
+
+    def get_images(self, return_pil=False, return_path=False):
+        """Collect image payloads from user turns (even indices after offset)."""
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg, image, image_process_mode = msg
+                    if type(image) != list:
+                        image = [image]
+                    for img in image:
+                        if not return_path and self.is_image_file(img):
+                            # NOTE(review): the processed image is assigned but
+                            # never appended — only the else branch appends.
+                            # Looks like a missing images.append(img); confirm
+                            # against upstream LLaVA.
+                            img = self.process_image(
+                                img, image_process_mode, return_pil=return_pil
+                            )
+                        else:
+                            images.append(img)
+        return images
+
+    def is_image_file(self, filename):
+        """True when filename has a known image extension (case-insensitive)."""
+        image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
+        return any(filename.lower().endswith(ext) for ext in image_extensions)
+
+    def is_video_file(self, filename):
+        """True when filename has a known video extension (case-insensitive)."""
+        video_extensions = [
+            ".mp4",
+            ".mov",
+            ".avi",
+            ".mkv",
+            ".wmv",
+            ".flv",
+            ".mpeg",
+            ".mpg",
+        ]
+        return any(filename.lower().endswith(ext) for ext in video_extensions)
+
+    def to_gradio_chatbot(self):
+        """Convert history to Gradio chatbot pairs, inlining base64 images."""
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg, image, image_process_mode = msg
+                    if type(image) != list:
+                        image = [image]
+                    if len(image) == 1:
+                        msg = "\n" + msg.replace("", "").strip()
+                    else:
+                        msg = re.sub(r"()\n(?=)", r"\1 ", msg)
+
+                    img_str_list = []
+                    for img in image:
+                        if self.is_image_file(img):
+                            img_b64_str = self.process_image(
+                                img, "Default", return_pil=False, image_format="JPEG"
+                            )
+                            # NOTE(review): the f-string below was mangled in
+                            # export (an <img src=...> tag was stripped) —
+                            # restore it from upstream before running.
+                            img_str = f'
'
+                            img_str_list.append(img_str)
+                        elif self.is_video_file(img):
+                            ret.append(((img,), None))
+
+                    msg = msg.strip()
+                    img_place_holder = ""
+                    for img_str in img_str_list:
+                        img_place_holder += f"{img_str}\n\n"
+
+                    if len(img_str_list) > 0:
+                        msg = f"{img_place_holder}\n\n{msg}"
+
+                    if len(msg) > 0:
+                        ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                # Assistant turn fills the reply slot of the last pair.
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        """Deep-ish copy: new message list, shared scalar fields."""
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version,
+        )
+
+    def dict(self):
+        """Serializable view; image tuples collapse to their text part."""
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [
+                    [x, y[0] if type(y) is tuple else y] for x, y in self.messages
+                ],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+def safe_load_tokenizer(tokenizer_id):
+ try:
+ return AutoTokenizer.from_pretrained(tokenizer_id)
+ except Exception:
+ return None
+
+
+# Bare template: no system prompt, no role prefixes, newline separator.
+conv_llava_plain = Conversation(
+    system="",
+    roles=("", ""),
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.PLAIN,
+    sep="\n",
+)
+
+# Plain template routed through the LLaMA-3 chat-template path.
+conv_llada_plain = Conversation(
+    system="",
+    roles=("", ""),
+    messages=[],
+    version="llada_plain",
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_3,
+    sep="\n",
+)
+
+# Main LLaDA-V template: user/assistant roles, <|eot_id|> turn separator.
+conv_llava_llada = Conversation(
+    system="You are a helpful language and vision assistant. "
+    "You are able to understand the visual content that the user provides, "
+    "and assist the user with a variety of tasks using natural language.",
+    roles=("user", "assistant"),
+    version="llava_llada",
+    messages=[],
+    offset=0,
+    sep="<|eot_id|>",
+    sep_style=SeparatorStyle.LLAMA_3,
+)
+
+
+default_conversation = conv_llava_llada
+# Lookup table used by callers (e.g. --conv_template); several keys alias
+# the same template object, so mutate copies, not these instances.
+conv_templates = {
+    "default": conv_llava_llada,
+    "plain": conv_llava_plain,
+    "llada_plain": conv_llada_plain,
+    "llava_llada": conv_llava_llada,
+    "v0_plain": conv_llava_plain,
+}
+
+
+if __name__ == "__main__":
+    # Smoke check: render the default template's (empty) prompt.
+    print(default_conversation.get_prompt())
diff --git a/train/llava/hooks/__init__.py b/train/llava/hooks/__init__.py
new file mode 100644
index 0000000..01540b5
--- /dev/null
+++ b/train/llava/hooks/__init__.py
@@ -0,0 +1,5 @@
+# pylint: disable=C0114,C0115,C0116,C0103
+
+from .cache_hook_LLaDA_V import register_cache_LLaDA_V
+
+__all__ = ["register_cache_LLaDA_V"]
diff --git a/train/llava/hooks/cache_hook_LLaDA_V.py b/train/llava/hooks/cache_hook_LLaDA_V.py
new file mode 100644
index 0000000..d993338
--- /dev/null
+++ b/train/llava/hooks/cache_hook_LLaDA_V.py
@@ -0,0 +1,867 @@
+# pylint: disable=C0114,C0115,C0116,C0103,R0913,R0917,R0914,R0912,R0915
+
+import math
+import types
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from llava.cache import dLLMCache
+
+
+# Helper functions from the new LLADa model (need to be accessible)
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def register_cache_LLaDA_V(model: nn.Module, tf_block_module_key_name: str) -> None:
+    """
+    Registers cache hooks for a LLaDA-like model.
+    tf_block_module_key_name is typically 'model.layers' for LLaMA-style models.
+
+    For every transformer block this monkey-patches:
+      * block.forward -> llada_cache_hook_feature
+      * block.self_attn.attention_forward_for_cache -> llada_attention_hook_for_cache
+      * block.self_attn.rotary_emb.forward -> llada_RoPe_forward_hook
+    Originals are stashed as _old_forward / _old_forward_main for delegation.
+    """
+    # Walk the attribute path (e.g. "model.layers") down from the model root.
+    target_module_path = tf_block_module_key_name.split(".")
+    current_module = model
+    for part in target_module_path:
+        current_module = getattr(current_module, part)
+
+    target_module: Optional[nn.ModuleList] = current_module
+    if target_module is None or not isinstance(target_module, nn.ModuleList):
+        raise ValueError(f"Could not find nn.ModuleList at {tf_block_module_key_name}")
+
+    for layer_index, tf_block in enumerate(target_module):
+        # Tag each block with its index so cache entries can be keyed by layer.
+        setattr(tf_block, "layer_idx", layer_index)
+
+        setattr(tf_block, "_old_forward", tf_block.forward)
+        tf_block.forward = types.MethodType(llada_cache_hook_feature, tf_block)
+
+        setattr(tf_block.self_attn, "_old_forward_main", tf_block.self_attn.forward)
+        tf_block.self_attn.attention_forward_for_cache = types.MethodType(
+            llada_attention_hook_for_cache, tf_block.self_attn
+        )
+
+        setattr(
+            tf_block.self_attn.rotary_emb,
+            "_old_forward",
+            tf_block.self_attn.rotary_emb.forward,
+        )
+        tf_block.self_attn.rotary_emb.forward = types.MethodType(
+            llada_RoPe_forward_hook, tf_block.self_attn.rotary_emb
+        )
+
+
+def llada_attention_hook_for_cache(
+    self,  # self is LLaDAAttention instance
+    q_in_proj: torch.Tensor,  # Renamed from q to clarify it's post-projection from cache_hook
+    k_in_proj: torch.Tensor,  # Renamed from k
+    v_in_proj: torch.Tensor,  # Renamed from v
+    attention_bias: Optional[torch.Tensor] = None,
+    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    use_cache: bool = False,
+    q_index: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+    """Eager attention over already-projected Q/K/V segments.
+
+    Accepts Q shorter than K/V (cached-prefix decoding), applies RoPE via
+    the patched rotary_emb, expands GQA heads with repeat_kv, slices the
+    attention bias to match segment lengths, and returns
+    (o_proj output, None) — no KV cache tuple is produced here.
+    """
+    _ = layer_past
+    _ = use_cache
+    # q_in_proj, k_in_proj, v_in_proj are (Batch, SeqLen_specific_to_them, HiddenDim_model)
+    B, q_len_current_q = (
+        q_in_proj.shape[0],
+        q_in_proj.shape[1],
+    )  # q_len_current_q is the length of the Q for this specific call
+
+    q_num_heads = self.num_heads
+    k_num_heads = self.num_key_value_heads
+    v_num_heads = self.num_key_value_heads
+    head_dim = self.head_dim
+
+    # Reshape q, k, v to (Batch, NumHeads, SeqLen, HeadDim)
+    q = q_in_proj.view(B, q_len_current_q, q_num_heads, head_dim).transpose(1, 2)
+
+    k_seq_len_current_k = k_in_proj.shape[
+        1
+    ]  # k_seq_len_current_k is the length of K for this call (e.g., full context)
+    v_seq_len_current_v = v_in_proj.shape[1]
+
+    k = k_in_proj.view(B, k_seq_len_current_k, k_num_heads, head_dim).transpose(1, 2)
+    v = v_in_proj.view(B, v_seq_len_current_v, v_num_heads, head_dim).transpose(1, 2)
+
+    if hasattr(self, "rotary_emb"):
+        # q_index passed here should be for q_in_proj
+        q, k = self.rotary_emb(q, k, q_index=q_index)
+
+    # No KV-cache tuple is returned by this hook.
+    present = None
+
+    # Expand grouped KV heads so K/V match the query head count (GQA).
+    k_repeated = repeat_kv(k, self.num_key_value_groups)
+    v_repeated = repeat_kv(v, self.num_key_value_groups)
+
+    # q is (B, q_num_heads, q_len_current_q, head_dim)
+    # k_repeated is (B, q_num_heads, k_seq_len_current_k, head_dim)
+
+    attn_weights = torch.matmul(q, k_repeated.transpose(2, 3)) / math.sqrt(
+        self.head_dim
+    )
+
+    if attention_bias is not None:
+        bias_q_dim = attention_bias.shape[-2]
+        bias_k_dim = attention_bias.shape[-1]
+
+        sliced_attention_bias = attention_bias
+        if (
+            q_len_current_q < bias_q_dim
+        ):  # Q is a segment, assume it's the latest tokens
+            # This assumes the q segment is the last part of
+            # the full query sequence represented by the bias
+            sliced_attention_bias = attention_bias[:, :, -q_len_current_q:, :]
+
+        if (
+            k_seq_len_current_k < bias_k_dim
+        ):  # K is a segment (less likely for full KV cache but possible)
+            # This assumes K is the latest part of the full key sequence in the bias
+            sliced_attention_bias = sliced_attention_bias[
+                :, :, :, -k_seq_len_current_k:
+            ]
+
+        # Final check on dimensions before adding
+        if (
+            attn_weights.shape[-2] == sliced_attention_bias.shape[-2]
+            and attn_weights.shape[-1] == sliced_attention_bias.shape[-1]
+        ):
+            attn_weights = attn_weights + sliced_attention_bias
+        else:
+            # Sizes disagree: accept only head-broadcastable or same-head masks.
+            if (
+                sliced_attention_bias.shape[1] == 1
+                and attn_weights.shape[1] == self.num_heads
+            ):  # Mask has 1 head dim
+                attn_weights = (
+                    attn_weights + sliced_attention_bias
+                )  # Broadcast over heads
+            elif (
+                sliced_attention_bias.shape[1] == self.num_heads
+            ):  # Mask has same num heads
+                attn_weights = attn_weights + sliced_attention_bias
+            else:
+                raise RuntimeError(
+                    f"Attention bias shape {sliced_attention_bias.shape} incompatible with "
+                    f"attn_weights shape {attn_weights.shape} after slicing."
+                )
+
+    # Softmax in float32 for stability, then cast back to the query dtype.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        q.dtype
+    )
+    attn_weights = nn.functional.dropout(
+        attn_weights, p=self.attention_dropout, training=self.training
+    )
+    att_output_heads = torch.matmul(attn_weights, v_repeated)
+
+    # Reshape to (B, q_len_current_q, ModelHiddenDim)
+    att_output_heads = (
+        att_output_heads.transpose(1, 2)
+        .contiguous()
+        .view(B, q_len_current_q, q_num_heads * head_dim)
+    )
+
+    output = self.o_proj(att_output_heads)
+
+    return output, present
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def llada_RoPe_forward_hook(
+    self_rope,
+    q_in: torch.Tensor,
+    k_in: torch.Tensor,
+    q_index: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """RoPE that supports explicit per-token query positions (q_index).
+
+    Builds cos/sin tables in float32 up to the largest position needed,
+    rotates Q at q_index positions (or 0..len-1 when q_index is None) and
+    K at 0..key_len-1, then casts back to the input dtypes.
+    """
+    # Compute in float32 regardless of input dtype.
+    q_, k_ = q_in.float(), k_in.float()
+
+    _, _, query_len_current, _ = q_.shape
+    _, _, key_len_current, _ = k_.shape
+
+    # Table must cover key length, the max explicit q position, and q length.
+    max_pos_needed = key_len_current
+    if q_index is not None:
+        max_pos_needed = max(
+            max_pos_needed, int(q_index.max().item()) + 1 if q_index.numel() > 0 else 0
+        )
+
+    max_pos_needed = max(max_pos_needed, query_len_current)
+
+    if max_pos_needed == 0:
+        # Nothing to rotate; return inputs untouched.
+        return q_in, k_in
+
+    inv_freq_to_use = self_rope.inv_freq.to(q_.device)
+    t = torch.arange(max_pos_needed, device=q_.device, dtype=torch.float32)
+    if hasattr(self_rope, "scaling_factor"):
+        # Linear RoPE scaling when the module defines a scaling_factor.
+        t = t / self_rope.scaling_factor
+
+    freqs = torch.outer(t, inv_freq_to_use.float())
+    emb = torch.cat((freqs, freqs), dim=-1)
+
+    pos_cos_table = emb.cos()
+    pos_sin_table = emb.sin()
+
+    if q_index is not None:
+        # Gather per-token positions; unsqueeze(1) broadcasts over heads.
+        actual_q_indices = q_index[:, :query_len_current]
+        cos_q = pos_cos_table[actual_q_indices]
+        sin_q = pos_sin_table[actual_q_indices]
+        q_rotated = (q_ * cos_q.unsqueeze(1)) + (rotate_half(q_) * sin_q.unsqueeze(1))
+    else:
+        # Default: consecutive positions starting at 0.
+        q_indices = torch.arange(query_len_current, device=q_.device)
+        cos_q = pos_cos_table[q_indices].unsqueeze(0)
+        sin_q = pos_sin_table[q_indices].unsqueeze(0)
+        q_rotated = (q_ * cos_q.unsqueeze(1)) + (rotate_half(q_) * sin_q.unsqueeze(1))
+
+    # Keys always use consecutive positions 0..key_len-1.
+    k_indices = torch.arange(key_len_current, device=q_.device)
+    cos_k = pos_cos_table[k_indices].unsqueeze(0)
+    sin_k = pos_sin_table[k_indices].unsqueeze(0)
+    k_rotated = (k_ * cos_k.unsqueeze(1)) + (rotate_half(k_) * sin_k.unsqueeze(1))
+
+    return q_rotated.type_as(q_in), k_rotated.type_as(k_in)
+
+
+def refresh_index(
+    new_features: torch.Tensor,
+    cached_features: Optional[torch.Tensor] = None,
+    transfer_ratio: float = 0.5,
+    layer_id: int = 0,
+) -> torch.Tensor:
+    """Pick generation positions to refresh based on feature drift.
+
+    Computes cosine similarity between ``new_features`` and
+    ``cached_features`` along the last dim and returns, per batch row, the
+    indices of the ``int(gen_len * transfer_ratio)`` positions with the
+    LOWEST similarity — i.e. the positions whose features changed the most
+    since they were cached.
+
+    Args:
+        new_features: (batch, gen_len, dim) freshly computed features.
+        cached_features: (batch, gen_len, dim) previously cached features,
+            or ``None``/empty when nothing was cached yet.
+        transfer_ratio: Fraction of positions to refresh.
+        layer_id: Unused; kept for a stable hook signature.
+
+    Returns:
+        LongTensor of shape (batch, k) with k = min(num_replace, gen_len);
+        an empty (batch, 0) tensor when nothing should be refreshed.
+    """
+    _ = layer_id
+    batch_size, gen_len, _ = new_features.shape
+    num_replace = int(gen_len * transfer_ratio)
+    # Nothing requested, or nothing to compare against: return empty indices.
+    if num_replace == 0 or gen_len == 0:
+        return torch.empty(
+            (batch_size, 0), dtype=torch.long, device=new_features.device
+        )
+    if cached_features is None or cached_features.shape[1] == 0:
+        return torch.empty(
+            (batch_size, 0), dtype=torch.long, device=new_features.device
+        )
+    # pylint: disable=E1102
+    cos_sim = torch.nn.functional.cosine_similarity(
+        new_features, cached_features, dim=-1
+    )
+    # Never ask topk for more elements than exist per row.
+    k_actual = min(num_replace, cos_sim.shape[1])
+    if k_actual == 0:
+        return torch.empty(
+            (batch_size, 0), dtype=torch.long, device=new_features.device
+        )
+
+    # largest=False selects the LEAST similar (most drifted) positions.
+    transfer_index = torch.topk(cos_sim, largest=False, k=k_actual).indices
+    return transfer_index
+
+
+def llada_cache_hook_feature(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[
+        torch.Tensor
+    ] = None,  # This is the original mask for the full layer input
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    use_cache: Optional[bool] = False,
+    output_attentions: Optional[bool] = False,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[
+    torch.Tensor,
+    Optional[Tuple[torch.Tensor, torch.Tensor]],
+    Optional[Tuple[torch.Tensor, ...]],
+]:
+    """Decoder-layer forward hook with dLLM feature caching.
+
+    The sequence is split at ``prompt_length`` into a prompt segment and a
+    generation segment.  Depending on the cache refresh flags
+    (``refresh_prompt`` / ``refresh_gen``), attention and MLP outputs for
+    each segment are either recomputed and stored in ``dLLMCache`` or read
+    back from it.  When ``transfer_ratio`` is in (0, 1], a subset of
+    generation positions (selected by ``refresh_index``) is recomputed and
+    scattered back into the cached tensors.
+
+    Returns:
+        A tuple whose first element is the layer output hidden states;
+        ``None`` placeholders are appended when ``output_attentions`` or
+        ``use_cache`` is requested (this cached path produces neither).
+    """
+    current_layer_idx = self.layer_idx
+    feature_cache = dLLMCache()
+    feature_cache.update_step(current_layer_idx)
+    # HF-style arguments this hook does not use; consumed to keep the
+    # original layer-forward signature intact.
+    _ = past_key_value
+    _ = cache_position
+
+    prompt_length = feature_cache.prompt_length
+    # x_prompt and x_gen are sub-segments of hidden_states
+    x_prompt = hidden_states[:, :prompt_length, :]
+    x_gen = hidden_states[:, prompt_length:, :]
+
+    refresh_gen = feature_cache.refresh_gen()
+    refresh_prompt = feature_cache.refresh_prompt()
+    transfer_ratio = feature_cache.transfer_ratio
+
+    bs, _, dim = (
+        hidden_states.shape
+    )  # seq_len is the length of hidden_states input to this layer
+
+    # Partial refresh ("transfer") is active only for a ratio in (0, 1].
+    transfer = 0 < transfer_ratio <= 1
+
+    index_from_attn_transfer = None
+    index_expanded_from_attn_transfer = None
+
+    def project(
+        x_input: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # LayerNorm then Q/K/V projections for a sub-segment of the input.
+        x_normed = self.input_layernorm(x_input)
+        q = self.self_attn.q_proj(x_normed)
+        k = self.self_attn.k_proj(x_normed)
+        v = self.self_attn.v_proj(x_normed)
+        return q, k, v
+
+    # This function needs to be smarter about the attention_bias it passes.
+    # q_tensor_origin_slice: A tuple (start_idx, length) indicating where q_tensor comes from
+    # relative to the original hidden_states/attention_mask.
+    def call_attention_on_qkv(
+        q_tensor,
+        k_tensor,
+        v_tensor,
+        original_full_attention_bias,
+        q_tensor_start_idx: int,  # Start index of q_tensor within the original sequence
+        q_index: Optional[torch.Tensor] = None,
+    ):
+        q_len_for_this_call = q_tensor.shape[1]
+        k_len_for_this_call = k_tensor.shape[
+            1
+        ]  # This K is already the full K from dLLM cache
+
+        # Slice the original_full_attention_bias
+        # original_full_attention_bias is likely
+        # (B, 1 or H, S_full_layer_input, K_full_from_dllm_or_max)
+        # We need (B, 1 or H, q_len_for_this_call, k_len_for_this_call)
+        sliced_bias = original_full_attention_bias
+        if original_full_attention_bias is not None:
+            # Slice query dimension based on q_tensor_start_idx and its length
+            # Slice key dimension to match k_tensor's length (which is k_full from dLLM)
+            sliced_bias = original_full_attention_bias[
+                :,
+                :,
+                q_tensor_start_idx : q_tensor_start_idx + q_len_for_this_call,
+                :k_len_for_this_call,
+            ]
+            # Ensure num_heads dim is compatible
+            # (1 for broadcast, or matches self.self_attn.num_heads)
+            if (
+                sliced_bias.shape[1] != 1
+                and sliced_bias.shape[1] != self.self_attn.num_heads
+            ):
+                # This might indicate an issue with mask preparation upstream
+                # if it's not 1 or num_heads
+                # For now, assume it's (B,1,q,k) and will broadcast
+                # if num_heads > 1 in attention_hook
+                pass
+
+        att_output, _ = self.self_attn.attention_forward_for_cache(
+            q_tensor,
+            k_tensor,
+            v_tensor,
+            attention_bias=sliced_bias,
+            layer_past=None,
+            use_cache=False,
+            q_index=q_index,
+        )
+        return att_output
+
+    def compute_mlp(input_to_mlp: torch.Tensor) -> torch.Tensor:
+        # Gated MLP: act(gate_proj(x)) * up_proj(x) -> down_proj.
+        if input_to_mlp.shape[1] == 0:
+            return torch.empty_like(input_to_mlp)
+        x_norm = self.post_attention_layernorm(input_to_mlp)
+        gate_proj_out = self.mlp.gate_proj(x_norm)
+        up_proj_out = self.mlp.up_proj(x_norm)
+        act_out = self.mlp.act_fn(gate_proj_out)
+        x = act_out * up_proj_out
+        return self.mlp.down_proj(x)
+
+    residual_pre_attn = hidden_states
+
+    # Case 1: recompute both segments from scratch and repopulate the caches.
+    if refresh_gen and refresh_prompt:
+        q_full, k_full, v_full = project(hidden_states)
+        feature_cache.set_cache(
+            layer_id=current_layer_idx,
+            feature_name="kv_cache",
+            features={
+                "k": k_full[:, :prompt_length, :],
+                "v": v_full[:, :prompt_length, :],
+            },
+            cache_type="prompt",
+        )
+        if hidden_states.shape[1] > prompt_length:
+            feature_cache.set_cache(
+                layer_id=current_layer_idx,
+                feature_name="kv_cache",
+                features={
+                    "k": k_full[:, prompt_length:, :],
+                    "v": v_full[:, prompt_length:, :],
+                },
+                cache_type="gen",
+            )
+
+        # Q is all of hidden_states, K,V also from all of hidden_states
+        # q_start_idx is 0 because q_full corresponds to the start of hidden_states
+        att = call_attention_on_qkv(
+            q_full,
+            k_full,
+            v_full,
+            attention_mask,
+            q_tensor_start_idx=0,
+            q_index=position_ids,
+        )
+        feature_cache.set_cache(
+            layer_id=current_layer_idx,
+            feature_name="attn",
+            features=att[:, :prompt_length, :],
+            cache_type="prompt",
+        )
+        if hidden_states.shape[1] > prompt_length:
+            feature_cache.set_cache(
+                layer_id=current_layer_idx,
+                feature_name="attn",
+                features=att[:, prompt_length:, :],
+                cache_type="gen",
+            )
+
+    # Case 2: only the generation segment is recomputed; the prompt K/V and
+    # attention output are reused from the cache.
+    elif refresh_gen and not refresh_prompt:
+        att_gen_part = torch.empty((bs, 0, dim), device=hidden_states.device)
+        if x_gen.shape[1] > 0:
+            q_gen, k_gen, v_gen = project(x_gen)
+            feature_cache.set_cache(
+                layer_id=current_layer_idx,
+                feature_name="kv_cache",
+                features={"k": k_gen, "v": v_gen},
+                cache_type="gen",
+            )
+            kv_cache_prompt = feature_cache.get_cache(
+                layer_id=current_layer_idx, feature_name="kv_cache", cache_type="prompt"
+            )
+            k_prompt_val = kv_cache_prompt.get(
+                "k", torch.empty(bs, 0, dim, device=hidden_states.device)
+            )
+            v_prompt_val = kv_cache_prompt.get(
+                "v", torch.empty(bs, 0, dim, device=hidden_states.device)
+            )
+
+            k_full_ctx = torch.cat([k_prompt_val, k_gen], dim=1)
+            v_full_ctx = torch.cat([v_prompt_val, v_gen], dim=1)
+
+            q_gen_pos_ids = (
+                position_ids[:, prompt_length:]
+                if position_ids is not None and position_ids.shape[1] > prompt_length
+                else None
+            )
+
+            # q_gen starts at prompt_length in the original hidden_states sequence
+            att_gen_part = call_attention_on_qkv(
+                q_gen,
+                k_full_ctx,
+                v_full_ctx,
+                attention_mask,
+                q_tensor_start_idx=prompt_length,
+                q_index=q_gen_pos_ids,
+            )
+
+        feature_cache.set_cache(
+            layer_id=current_layer_idx,
+            feature_name="attn",
+            features=att_gen_part,
+            cache_type="gen",
+        )
+
+        att_prompt_cache = feature_cache.get_cache(
+            layer_id=current_layer_idx, feature_name="attn", cache_type="prompt"
+        )
+        att = torch.cat([att_prompt_cache, att_gen_part], dim=1)
+
+    # Case 3: recompute the prompt segment; optionally transfer-refresh a
+    # subset of generation positions and scatter them into the gen caches.
+    elif not refresh_gen and refresh_prompt:
+        q_prompt, k_prompt, v_prompt = project(x_prompt)
+        feature_cache.set_cache(
+            layer_id=current_layer_idx,
+            feature_name="kv_cache",
+            features={"k": k_prompt, "v": v_prompt},
+            cache_type="prompt",
+        )
+        kv_cache_gen = feature_cache.get_cache(
+            layer_id=current_layer_idx, feature_name="kv_cache", cache_type="gen"
+        )
+        att_gen_cache = feature_cache.get_cache(
+            layer_id=current_layer_idx, feature_name="attn", cache_type="gen"
+        )
+
+        k_gen_current = kv_cache_gen.get(
+            "k", torch.empty(bs, 0, dim, device=hidden_states.device)
+        )
+        v_gen_current = kv_cache_gen.get(
+            "v", torch.empty(bs, 0, dim, device=hidden_states.device)
+        )
+
+        q_for_attn_segments = [q_prompt]
+        q_idx_for_attn_segments = (
+            [position_ids[:, :prompt_length]] if position_ids is not None else [None]
+        )
+
+        if transfer and x_gen.shape[1] > 0 and k_gen_current.shape[1] > 0:
+            # Select the most-drifted gen positions (by V-feature similarity).
+            _, _, v_gen_for_transfer = project(x_gen)
+            index_from_attn_transfer = refresh_index(
+                v_gen_for_transfer, v_gen_current, transfer_ratio, current_layer_idx
+            )
+
+            if index_from_attn_transfer.numel() > 0:
+                index_expanded_from_attn_transfer = index_from_attn_transfer.unsqueeze(
+                    -1
+                ).expand(-1, -1, dim)
+
+                # Recompute Q/K/V only for the selected gen positions and
+                # scatter the fresh K/V back into the cached tensors.
+                x_gen_normed_selected = torch.gather(
+                    self.input_layernorm(x_gen),
+                    dim=1,
+                    index=index_expanded_from_attn_transfer,
+                )
+                q_gen_index = self.self_attn.q_proj(x_gen_normed_selected)
+                k_gen_index = self.self_attn.k_proj(x_gen_normed_selected)
+                v_gen_index_part = self.self_attn.v_proj(x_gen_normed_selected)
+
+                k_gen_current = k_gen_current.scatter(
+                    dim=1, index=index_expanded_from_attn_transfer, src=k_gen_index
+                )
+                v_gen_current = v_gen_current.scatter(
+                    dim=1, index=index_expanded_from_attn_transfer, src=v_gen_index_part
+                )
+
+                feature_cache.set_cache(
+                    layer_id=current_layer_idx,
+                    feature_name="kv_cache",
+                    features={"k": k_gen_current, "v": v_gen_current},
+                    cache_type="gen",
+                )
+
+                q_for_attn_segments.append(q_gen_index)
+                if position_ids is not None and position_ids.shape[1] > prompt_length:
+                    gen_abs_positions_all = position_ids[:, prompt_length:]
+                    gen_abs_positions_selected = torch.gather(
+                        gen_abs_positions_all, 1, index_from_attn_transfer
+                    )
+                    q_idx_for_attn_segments.append(gen_abs_positions_selected)
+                else:
+                    q_idx_for_attn_segments.append(None)
+
+        q_combined_for_attn = torch.cat(q_for_attn_segments, dim=1)
+        q_idx_combined_for_rope = (
+            torch.cat(q_idx_for_attn_segments, dim=1)
+            if all(s is not None for s in q_idx_for_attn_segments)
+            else None
+        )
+
+        k_full_ctx = torch.cat([k_prompt, k_gen_current], dim=1)
+        v_full_ctx = torch.cat([v_prompt, v_gen_current], dim=1)
+
+        # q_combined_for_attn effectively starts at index 0 of a conceptual sequence.
+        # Its RoPE is handled by q_idx_combined_for_rope.
+        att_for_q_combined = call_attention_on_qkv(
+            q_combined_for_attn,
+            k_full_ctx,
+            v_full_ctx,
+            attention_mask,
+            q_tensor_start_idx=0,  # Since Q is combined and starts from effective 0
+            q_index=q_idx_combined_for_rope,
+        )
+
+        att_prompt_new = att_for_q_combined[:, : q_prompt.shape[1], :]
+        if (
+            transfer
+            and index_from_attn_transfer is not None
+            and index_from_attn_transfer.numel() > 0
+        ):
+            att_gen_index_new = att_for_q_combined[
+                :, q_prompt.shape[1] :, :
+            ]  # Segment for transferred Qs
+            if att_gen_cache.shape[1] > 0:
+                att_gen_cache = att_gen_cache.scatter(
+                    dim=1,
+                    index=index_expanded_from_attn_transfer,
+                    src=att_gen_index_new,
+                )
+            feature_cache.set_cache(
+                layer_id=current_layer_idx,
+                feature_name="attn",
+                features=att_gen_cache,
+                cache_type="gen",
+            )
+
+        feature_cache.set_cache(
+            layer_id=current_layer_idx,
+            feature_name="attn",
+            features=att_prompt_new,
+            cache_type="prompt",
+        )
+        att = torch.cat([att_prompt_new, att_gen_cache], dim=1)
+
+    # Case 4: both segments cached; optionally transfer-refresh selected
+    # generation positions using the cached prompt/gen K,V as context.
+    else:  # Not refresh gen, not refresh prompt
+        att_prompt_cache = feature_cache.get_cache(
+            layer_id=current_layer_idx, feature_name="attn", cache_type="prompt"
+        )
+        att_gen_cache = feature_cache.get_cache(
+            layer_id=current_layer_idx, feature_name="attn", cache_type="gen"
+        )
+        kv_cache_gen = feature_cache.get_cache(
+            layer_id=current_layer_idx, feature_name="kv_cache", cache_type="gen"
+        )
+        kv_cache_prompt = feature_cache.get_cache(
+            layer_id=current_layer_idx, feature_name="kv_cache", cache_type="prompt"
+        )
+
+        k_gen_current = kv_cache_gen.get(
+            "k", torch.empty(bs, 0, dim, device=hidden_states.device)
+        )
+        v_gen_current = kv_cache_gen.get(
+            "v", torch.empty(bs, 0, dim, device=hidden_states.device)
+        )
+        k_prompt_val = kv_cache_prompt.get(
+            "k", torch.empty(bs, 0, dim, device=hidden_states.device)
+        )
+        v_prompt_val = kv_cache_prompt.get(
+            "v", torch.empty(bs, 0, dim, device=hidden_states.device)
+        )
+
+        if transfer and x_gen.shape[1] > 0 and k_gen_current.shape[1] > 0:
+            x_gen_normed = self.input_layernorm(x_gen)
+            v_gen_for_transfer = self.self_attn.v_proj(x_gen_normed)
+            index_from_attn_transfer = refresh_index(
+                v_gen_for_transfer, v_gen_current, transfer_ratio, current_layer_idx
+            )
+
+            if index_from_attn_transfer.numel() > 0:
+                index_expanded_from_attn_transfer = index_from_attn_transfer.unsqueeze(
+                    -1
+                ).expand(-1, -1, dim)
+
+                x_gen_normed_selected = torch.gather(
+                    x_gen_normed, dim=1, index=index_expanded_from_attn_transfer
+                )
+                q_gen_index_only = self.self_attn.q_proj(
+                    x_gen_normed_selected
+                )  # Q only for transferred items
+                k_gen_index = self.self_attn.k_proj(x_gen_normed_selected)
+                v_gen_index_part = self.self_attn.v_proj(x_gen_normed_selected)
+
+                k_gen_current = k_gen_current.scatter(
+                    dim=1, index=index_expanded_from_attn_transfer, src=k_gen_index
+                )
+                v_gen_current = v_gen_current.scatter(
+                    dim=1, index=index_expanded_from_attn_transfer, src=v_gen_index_part
+                )
+
+                feature_cache.set_cache(
+                    layer_id=current_layer_idx,
+                    feature_name="kv_cache",
+                    features={"k": k_gen_current, "v": v_gen_current},
+                    cache_type="gen",
+                )
+
+                q_idx_for_transferred_rope = None
+                if position_ids is not None and position_ids.shape[1] > prompt_length:
+                    gen_abs_positions_all = position_ids[:, prompt_length:]
+                    q_idx_for_transferred_rope = torch.gather(
+                        gen_abs_positions_all, 1, index_from_attn_transfer
+                    )
+
+                k_full_ctx = torch.cat([k_prompt_val, k_gen_current], dim=1)
+                v_full_ctx = torch.cat([v_prompt_val, v_gen_current], dim=1)
+
+                att_gen_index_new = call_attention_on_qkv(
+                    q_gen_index_only,
+                    k_full_ctx,
+                    v_full_ctx,
+                    attention_mask,
+                    q_tensor_start_idx=prompt_length,  # Approximate for mask slicing
+                    q_index=q_idx_for_transferred_rope,
+                )
+
+                if att_gen_cache.shape[1] > 0:  # Make sure att_gen_cache has a gen part
+                    att_gen_cache = att_gen_cache.scatter(
+                        dim=1,
+                        index=index_expanded_from_attn_transfer,
+                        src=att_gen_index_new,
+                    )
+                elif (
+                    x_gen.shape[1] > 0
+                ):  # If original att_gen_cache was for an empty gen part, but x_gen is not empty
+                    # Initialize att_gen_cache to be of x_gen's length before scattering
+                    att_gen_cache = torch.zeros(
+                        (bs, x_gen.shape[1], dim),
+                        device=hidden_states.device,
+                        dtype=att_gen_index_new.dtype,
+                    )
+                    att_gen_cache = att_gen_cache.scatter(
+                        dim=1,
+                        index=index_expanded_from_attn_transfer,
+                        src=att_gen_index_new,
+                    )
+                # Else: if x_gen.shape[1] is 0, att_gen_cache remains empty, scatter won't happen.
+
+                feature_cache.set_cache(
+                    layer_id=current_layer_idx,
+                    feature_name="attn",
+                    features=att_gen_cache,
+                    cache_type="gen",
+                )
+
+        att = torch.cat([att_prompt_cache, att_gen_cache], dim=1)
+
+    # ... rest of the llada_cache_hook_feature (MLP part) remains the same ...
+    # Residual connection around attention; the same refresh/transfer logic
+    # is then repeated for the MLP outputs.
+    hidden_states_after_attn = residual_pre_attn + att
+    residual_pre_mlp = hidden_states_after_attn
+
+    x_prompt_mlp = hidden_states_after_attn[:, :prompt_length, :]
+    x_gen_mlp = hidden_states_after_attn[:, prompt_length:, :]
+
+    mlp_out_prompt_part = torch.empty(
+        (bs, prompt_length, dim),
+        device=hidden_states.device,
+        dtype=hidden_states_after_attn.dtype,
+    )
+    mlp_out_gen_part = torch.empty(
+        (bs, x_gen_mlp.shape[1], dim),
+        device=hidden_states.device,
+        dtype=hidden_states_after_attn.dtype,
+    )
+
+    # MLP case 1: recompute both segments and repopulate the MLP caches.
+    if refresh_gen and refresh_prompt:
+        mlp_out_full = compute_mlp(hidden_states_after_attn)
+        mlp_out_prompt_part = mlp_out_full[:, :prompt_length, :]
+        if x_gen_mlp.shape[1] > 0:
+            mlp_out_gen_part = mlp_out_full[:, prompt_length:, :]
+
+        feature_cache.set_cache(
+            current_layer_idx, "mlp", mlp_out_prompt_part, cache_type="prompt"
+        )
+        if x_gen_mlp.shape[1] > 0:
+            feature_cache.set_cache(
+                current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen"
+            )
+        if mlp_out_gen_part.shape[1] > 0:
+            mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1)
+        else:
+            mlp_out = mlp_out_prompt_part
+
+    # MLP case 2: recompute only the generation segment.
+    elif refresh_gen and not refresh_prompt:
+        mlp_out_prompt_part = feature_cache.get_cache(
+            current_layer_idx, "mlp", cache_type="prompt"
+        )
+        if x_gen_mlp.shape[1] > 0:
+            mlp_out_gen_part = compute_mlp(x_gen_mlp)
+            feature_cache.set_cache(
+                current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen"
+            )
+
+        if mlp_out_gen_part.shape[1] > 0:
+            mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1)
+        else:
+            mlp_out = mlp_out_prompt_part
+
+    # MLP case 3: recompute the prompt; piggyback transferred gen positions
+    # onto the same compute_mlp call and scatter them into the gen cache.
+    elif refresh_prompt and not refresh_gen:
+        mlp_gen_cache_data = feature_cache.get_cache(
+            current_layer_idx, "mlp", cache_type="gen"
+        )
+        if x_gen_mlp.shape[1] > 0:
+            mlp_out_gen_part = mlp_gen_cache_data
+
+        mlp_input_for_prompt_path = x_prompt_mlp
+        # Use index_expanded_from_attn_transfer which was set in the attention block
+        if (
+            transfer
+            and index_expanded_from_attn_transfer is not None
+            and index_expanded_from_attn_transfer.numel() > 0
+            and x_gen_mlp.shape[1] > 0
+        ):
+            x_gen_mlp_selected = torch.gather(
+                x_gen_mlp, dim=1, index=index_expanded_from_attn_transfer
+            )
+            mlp_input_for_prompt_path = torch.cat(
+                [x_prompt_mlp, x_gen_mlp_selected], dim=1
+            )
+
+        mlp_out_prompt_path_processed = compute_mlp(mlp_input_for_prompt_path)
+        mlp_out_prompt_part = mlp_out_prompt_path_processed[
+            :, : x_prompt_mlp.shape[1], :
+        ]
+
+        if (
+            transfer
+            and index_expanded_from_attn_transfer is not None
+            and index_expanded_from_attn_transfer.numel() > 0
+            and x_gen_mlp.shape[1] > 0
+        ):
+            mlp_gen_index_new = mlp_out_prompt_path_processed[
+                :, x_prompt_mlp.shape[1] :, :
+            ]
+            if mlp_out_gen_part.shape[1] > 0:  # If gen part exists
+                mlp_out_gen_part = mlp_out_gen_part.scatter(
+                    dim=1,
+                    index=index_expanded_from_attn_transfer,
+                    src=mlp_gen_index_new,
+                )
+
+            feature_cache.set_cache(
+                current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen"
+            )
+
+        feature_cache.set_cache(
+            current_layer_idx, "mlp", mlp_out_prompt_part, cache_type="prompt"
+        )
+        if mlp_out_gen_part.shape[1] > 0:
+            mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1)
+        else:
+            mlp_out = mlp_out_prompt_part
+
+    # MLP case 4: both segments cached; optionally recompute only the
+    # transferred gen positions.
+    else:
+        mlp_out_prompt_part = feature_cache.get_cache(
+            current_layer_idx, "mlp", cache_type="prompt"
+        )
+        mlp_gen_cache_data = feature_cache.get_cache(
+            current_layer_idx, "mlp", cache_type="gen"
+        )
+        if x_gen_mlp.shape[1] > 0:
+            mlp_out_gen_part = mlp_gen_cache_data
+
+        # Use index_expanded_from_attn_transfer
+        if (
+            transfer
+            and index_expanded_from_attn_transfer is not None
+            and index_expanded_from_attn_transfer.numel() > 0
+            and x_gen_mlp.shape[1] > 0
+        ):
+            x_gen_mlp_selected = torch.gather(
+                x_gen_mlp, dim=1, index=index_expanded_from_attn_transfer
+            )
+            mlp_out_gen_index = compute_mlp(x_gen_mlp_selected)
+            if mlp_out_gen_part.shape[1] > 0:
+                mlp_out_gen_part = mlp_out_gen_part.scatter(
+                    dim=1,
+                    index=index_expanded_from_attn_transfer,
+                    src=mlp_out_gen_index,
+                )
+
+            feature_cache.set_cache(
+                current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen"
+            )
+
+        if mlp_out_gen_part.shape[1] > 0:
+            mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1)
+        else:
+            mlp_out = mlp_out_prompt_part
+
+    final_hidden_states = residual_pre_mlp + mlp_out
+
+    # Match the HF decoder-layer return convention: None placeholders stand
+    # in for attentions / present key-values, which this path does not build.
+    returned_outputs = (final_hidden_states,)
+    if output_attentions:
+        returned_outputs += (None,)
+    if use_cache:
+        returned_outputs += (None,)
+
+    return returned_outputs
diff --git a/train/llava/hooks/fast_dllm_hook.py b/train/llava/hooks/fast_dllm_hook.py
new file mode 100644
index 0000000..c4d9dbe
--- /dev/null
+++ b/train/llava/hooks/fast_dllm_hook.py
@@ -0,0 +1,702 @@
+# pylint: disable=C0114,C0115,C0116,C0103,R0913,R0917,R0914,R0912,R0915,E1102
+"""
+Fast dLLM Cache Hook for LLaDA Model
+
+This module contains hooks and utilities for implementing Fast-dLLM-style
+(diffusion LLM) caching in the LLaDA model architecture.
+"""
+
+from typing import List, Optional, Sequence, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .cache_hook_LLaDA_V import repeat_kv
+
+
+class FastDLLMGenerationHook:
+ """
+ Hook class for implementing fast dLLM caching functionality in LLaDA model.
+ This class handles both attention-level caching and generation-level optimizations.
+ """
+
+    def __init__(self, model):
+        """Wrap ``model`` so its attention/generate methods can be patched.
+
+        Args:
+            model: LLaDA(-V) model whose ``generate`` and per-layer
+                ``self_attn.forward`` methods will be replaced by
+                ``register_hooks``.
+        """
+        self.model = model
+        # Maps keys ("generate", "attention_<idx>") to the original bound
+        # methods so unregister_hooks() can restore them.
+        self.original_methods = {}
+        # True once register_hooks() has patched the model.
+        self.is_registered = False
+
+    def register_hooks(self):
+        """Register fast dLLM hooks to the model.
+
+        Saves the original ``generate`` and per-layer attention ``forward``
+        methods in ``self.original_methods``, then patches them with the
+        fast cached versions. Idempotent: calling twice is a no-op.
+        """
+        if self.is_registered:
+            return
+
+        # Store original methods
+        # Save the original generate method so it can be restored later.
+        self.original_methods["generate"] = self.model.generate
+
+        # Store original attention forwards
+        for layer_idx, layer in enumerate(self.model.model.layers):
+            self.original_methods[f"attention_{layer_idx}"] = layer.self_attn.forward
+            # Replace attention forward with fast cache version
+            layer.self_attn.forward = self._create_fast_attention_forward(
+                layer.self_attn, layer_idx
+            )
+
+        self.model.generate = self._fast_generate  # Replace generate with the fast version
+
+        self.is_registered = True
+
+    def unregister_hooks(self):
+        """Unregister fast dLLM hooks from the model.
+
+        Restores the ``generate`` and per-layer attention ``forward``
+        methods saved by ``register_hooks``. No-op when not registered.
+        """
+        if not self.is_registered:
+            return
+
+        # Restore original methods
+        # Restore the original generate method.
+        self.model.generate = self.original_methods["generate"]
+
+        # Restore original attention forwards
+        for layer_idx, layer in enumerate(self.model.model.layers):
+            layer.self_attn.forward = self.original_methods[f"attention_{layer_idx}"]
+
+        self.is_registered = False
+
+    def _create_fast_attention_forward(self, attention_layer, layer_idx):
+        """Create fast attention forward method with caching support.
+
+        Returns a closure that replaces ``attention_layer.forward``. The
+        closure reuses K/V states stored in ``fast_dllm_cache`` (one
+        ``(key, value)`` tuple per layer) and falls back to the original
+        forward whenever ``output_attentions`` is requested.
+        """
+
+        def fast_attention_forward(
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
+            cache_position: Optional[torch.LongTensor] = None,
+            fast_dllm_cache: Optional[
+                Sequence[Tuple[torch.Tensor, torch.Tensor]]
+            ] = None,
+            **kwargs,
+        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+            # If not using fast cache or output_attentions is needed, use original method
+            if output_attentions:
+                return self.original_methods[f"attention_{layer_idx}"](
+                    hidden_states=hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    **kwargs,
+                )
+
+            bsz, q_len, _ = hidden_states.size()
+
+            query_states = attention_layer.q_proj(hidden_states)
+            key_states = attention_layer.k_proj(hidden_states)
+            value_states = attention_layer.v_proj(hidden_states)
+
+            # Reshape to (B, heads, q_len, head_dim); K/V use the (possibly
+            # smaller) number of key-value heads.
+            query_states = query_states.view(
+                bsz, q_len, attention_layer.num_heads, attention_layer.head_dim
+            ).transpose(1, 2)
+            key_states = key_states.view(
+                bsz,
+                q_len,
+                attention_layer.num_key_value_heads,
+                attention_layer.head_dim,
+            ).transpose(1, 2)
+            value_states = value_states.view(
+                bsz,
+                q_len,
+                attention_layer.num_key_value_heads,
+                attention_layer.head_dim,
+            ).transpose(1, 2)
+
+            # Apply rotary position embedding with fast cache consideration
+            # When a cached prefix exists, shift positions by its length so
+            # the incremental chunk is rotated at its absolute positions.
+            cache_offset = 0
+            if fast_dllm_cache and len(fast_dllm_cache) > layer_idx:
+                cache_offset = fast_dllm_cache[layer_idx][0].shape[-2]
+
+            cos, sin = attention_layer.rotary_emb(
+                value_states,
+                position_ids + cache_offset if cache_offset > 0 else position_ids,
+            )
+            query_states, key_states = self._apply_rotary_pos_emb(
+                query_states, key_states, cos, sin
+            )
+
+            # Handle past key values
+            past_key_value = getattr(attention_layer, "past_key_value", past_key_value)
+            if past_key_value is not None:
+                cache_kwargs = {
+                    "sin": sin,
+                    "cos": cos,
+                    "cache_position": cache_position,
+                }
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, layer_idx, cache_kwargs
+                )
+
+            # Fast dLLM cache logic
+            # First pass appends this layer's K/V; subsequent passes prepend
+            # the cached prefix to the freshly computed chunk.
+            if fast_dllm_cache is not None:
+                if len(fast_dllm_cache) <= layer_idx:
+                    fast_dllm_cache.append((key_states, value_states))
+                else:
+                    past_key, past_value = fast_dllm_cache[layer_idx]
+                    key_states = torch.cat([past_key, key_states], dim=-2)
+                    value_states = torch.cat([past_value, value_states], dim=-2)
+
+            # Repeat key-value pairs for multi-head attention
+            key_states = repeat_kv(key_states, attention_layer.num_key_value_groups)
+            value_states = repeat_kv(value_states, attention_layer.num_key_value_groups)
+
+            # Apply causal mask
+            # NOTE(review): causal_mask is sliced here but never passed to
+            # scaled_dot_product_attention below (attn_mask=None); confirm
+            # this is intentional for LLaDA's bidirectional attention.
+            causal_mask = attention_mask
+            if attention_mask is not None:
+                causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+            # Ensure contiguous tensors for CUDA
+            if query_states.device.type == "cuda" and causal_mask is not None:
+                query_states = query_states.contiguous()
+                key_states = key_states.contiguous()
+                value_states = value_states.contiguous()
+
+            # Scaled dot-product attention
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=None,
+                is_causal=False,
+                dropout_p=attention_layer.attention_dropout
+                if attention_layer.training
+                else 0.0,
+            )
+
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.view(bsz, q_len, attention_layer.hidden_size)
+            attn_output = attention_layer.o_proj(attn_output)
+
+            return attn_output, None, past_key_value
+
+        return fast_attention_forward
+
+ @torch.no_grad()
+ def _fast_generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ modalities: Optional[List[str]] = ["image"],
+ **kwargs,
+ ):
+ modalities = (
+ kwargs.pop("modalities", None)
+ if "modalities" in kwargs and modalities is None
+ else modalities
+ )
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
+ self.model.prepare_inputs_labels_for_multimodal(
+ inputs,
+ position_ids,
+ attention_mask,
+ None,
+ None,
+ images,
+ modalities,
+ image_sizes=image_sizes,
+ )
+ )
+ else:
+ inputs_embeds = self.model.get_model().embed_tokens(inputs)
+ output = self._fast_generate_with_embeds(inputs_embeds=inputs_embeds, **kwargs)
+ return output
+
+    def reverse_causal_x0_p_comma(self, x0, logits, block_start, block_end):
+        """Per-token probabilities of ``x0`` with an injected confidence ramp.
+
+        Gathers p(x0) from softmax(logits) (computed in float64), then
+        overwrites the even-offset positions inside
+        [block_start, block_end) with ``torch.linspace(0.1, 0.99)`` values
+        so those positions are unmasked in a fixed order.
+
+        NOTE(review): the variable is named ``decreasing_values`` but the
+        linspace ramp INCREASES from 0.1 to 0.99 — confirm which order is
+        intended.
+        """
+        batch_size, seq_len = x0.shape
+        p = F.softmax(logits.to(torch.float64), dim=-1)
+        x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
+        # Inverted bounds: treat the whole sequence as the block.
+        if block_start > block_end:
+            block_start = 0
+            assert seq_len == block_end
+        # Length of the current block.
+        block_len = block_end - block_start
+        # NOTE(review): hard-coded guard — this method currently supports
+        # only blocks of exactly 24 tokens.
+        assert block_len == 24, "block len is not 24! is that true?"
+        even_indices = torch.tensor(
+            [block_start + i for i in range(0, block_len, 2)], device=x0.device
+        ).repeat(batch_size, 1)
+        num_even = even_indices.shape[1]
+        decreasing_values = torch.linspace(
+            0.1, 0.99, steps=num_even, device=x0.device, dtype=x0_p.dtype
+        ).repeat(batch_size, 1)
+        x0_p = x0_p.scatter_(1, even_indices, decreasing_values)
+        return x0_p
+
+ @torch.no_grad()
+ def _fast_generate_with_embeds(
+ self,
+ inputs_embeds,
+ steps=128,
+ gen_length=128,
+ block_length=128,
+ temperature=0.0,
+ cfg_scale=0.0,
+ remasking="low_confidence",
+ mask_id=126336,
+ tokenizer=None,
+ stopping_criteria=None,
+ generation_suffix=None,
+ threshold=None,
+ prefix_refresh_interval=32,
+ **kwargs,
+ ):
+ """
+ Fast generation with embeddings using dLLM cache optimization.
+ This method incorporates all fast dLLM related optimizations.
+ """
+ _ = kwargs # Unused
+ # Use mixed precision for faster computation
+ with torch.autocast(enabled=True, device_type="cuda", dtype=torch.bfloat16):
+ # Handle generation suffix
+ suffix_embeds = None
+ suffix_token_ids = None
+ suffix_len = 0
+ if (
+ generation_suffix is not None
+ and tokenizer is not None
+ and len(generation_suffix) > 0
+ ):
+ suffix_token_ids = tokenizer.encode(
+ generation_suffix, add_special_tokens=False
+ )
+ suffix_token_ids = torch.tensor(
+ suffix_token_ids, dtype=torch.long, device=inputs_embeds.device
+ ).unsqueeze(0)
+ suffix_embeds = self.model.model.embed_tokens(suffix_token_ids)
+ suffix_len = suffix_embeds.shape[1]
+
+ # Create input in embedding space
+ total_length = inputs_embeds.shape[1] + gen_length + suffix_len
+ masked_embed = self.model.model.embed_tokens(
+ torch.tensor([mask_id]).to(inputs_embeds.device)
+ )
+ x_embeds = masked_embed.repeat(1, total_length, 1).to(inputs_embeds.device)
+ x_embeds[:, : inputs_embeds.shape[1]] = inputs_embeds.clone()
+ if suffix_embeds is not None:
+ x_embeds[:, -suffix_len:] = suffix_embeds
+
+ # Create tracking tensor for token IDs
+ x = torch.full(
+ (1, total_length),
+ mask_id,
+ dtype=torch.long,
+ device=inputs_embeds.device,
+ )
+ track_x = [] # record all step of x
+ if suffix_token_ids is not None:
+ x[:, -suffix_len:] = suffix_token_ids
+
+ # Prompt index tracking
+ prompt_index = torch.zeros(
+ (1, total_length), dtype=torch.bool, device=inputs_embeds.device
+ )
+ prompt_index[:, : inputs_embeds.shape[1]] = 1
+
+ assert gen_length % block_length == 0
+ num_blocks = gen_length // block_length
+ assert steps % num_blocks == 0
+ steps = steps // num_blocks
+
+ # Initialize stop tracking
+ stop_position = inputs_embeds.shape[1] + gen_length
+ found_stop_seq = False
+ stop_tokens = []
+
+ if stopping_criteria is not None:
+ assert tokenizer is not None, (
+ "tokenizer is required when stopping_criteria is not None"
+ )
+ for stop_str in stopping_criteria:
+ tokens = tokenizer.encode(stop_str, add_special_tokens=False)
+ stop_tokens.append(tokens)
+
+ # Process each block
+ for num_block in range(num_blocks):
+ block_start = inputs_embeds.shape[1] + num_block * block_length
+ block_end = inputs_embeds.shape[1] + (num_block + 1) * block_length
+
+ # Skip if stop found before current block
+ if found_stop_seq and stop_position <= block_start:
+ break
+
+ block_embeds = x_embeds[:, block_start:block_end]
+ block_mask_index = torch.all(
+ torch.abs(block_embeds - masked_embed) < 1e-5, dim=2
+ )
+ num_transfer_tokens = self._get_num_transfer_tokens(
+ block_mask_index, steps
+ )
+
+ i = 0
+
+ while True:
+ if threshold is None and i >= steps:
+ break
+
+ # Check mask state
+ mask_index = torch.all(
+ torch.abs(x_embeds - masked_embed) < 1e-5, dim=2
+ )
+
+ if found_stop_seq:
+ pre_stop_masks = mask_index[
+ 0, inputs_embeds.shape[1] : stop_position
+ ]
+ if not pre_stop_masks.any():
+ break
+
+ current_block_masks = mask_index[0, block_start:block_end]
+ if not current_block_masks.any():
+ break
+
+ # Handle CFG
+ if cfg_scale > 0.0:
+ un_embeds = x_embeds.clone()
+ un_mask = prompt_index.unsqueeze(-1).expand_as(x_embeds)
+ un_embeds[un_mask] = masked_embed.repeat(
+ x_embeds.shape[0], x_embeds.shape[1], 1
+ )[un_mask]
+ combined_embeds = torch.cat([x_embeds, un_embeds], dim=0)
+
+ if i % prefix_refresh_interval == 0:
+ fast_dllm_cache = []
+ outputs = self.model.model(
+ inputs_embeds=combined_embeds,
+ fast_dllm_cache=fast_dllm_cache,
+ )
+ fast_dllm_cache = self._create_cache_slice(
+ fast_dllm_cache, block_start
+ )
+ else:
+ outputs = self.model.model(
+ inputs_embeds=combined_embeds[:, block_start:],
+ fast_dllm_cache=fast_dllm_cache,
+ )
+
+ logits = self.model.lm_head(outputs[0]).float()
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
+ logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+ else:
+ if i % prefix_refresh_interval == 0:
+ fast_dllm_cache = []
+ outputs = self.model.model(
+ inputs_embeds=x_embeds, fast_dllm_cache=fast_dllm_cache
+ )
+ # Slice cache to block start
+ fast_dllm_cache = self._create_cache_slice(
+ fast_dllm_cache, block_start
+ )
+ else:
+ # Incremental forward pass
+ outputs = self.model.model(
+ inputs_embeds=x_embeds[:, block_start:],
+ fast_dllm_cache=fast_dllm_cache,
+ )
+ logits = self.model.lm_head(outputs[0]).float()
+
+ # Filter forbidden tokens
+ forbidden_tokens = [126081, 126080, 126346, 126347]
+ if i % prefix_refresh_interval == 0:
+ for token_id in forbidden_tokens:
+ logits[:, :, token_id] = torch.where(
+ mask_index, -float("inf"), logits[:, :, token_id]
+ )
+ else:
+ for token_id in forbidden_tokens:
+ logits[:, :, token_id] = torch.where(
+ mask_index[:, block_start:],
+ -float("inf"),
+ logits[:, :, token_id],
+ )
+
+ # Get transfer indices and update
+ if i % prefix_refresh_interval == 0:
+ x0, transfer_index = self._get_transfer_index(
+ logits,
+ temperature,
+ remasking,
+ mask_index,
+ x,
+ num_transfer_tokens[:, i] if threshold is None else None,
+ found_stop_seq,
+ stop_position,
+ block_end,
+ suffix_len,
+ block_start,
+ threshold,
+ )
+ x0_embeds = self.model.model.embed_tokens(x0)
+ x0_embeds = torch.where(
+ mask_index.unsqueeze(-1).expand_as(x_embeds),
+ x0_embeds,
+ x_embeds,
+ )
+ x_embeds[transfer_index] = x0_embeds[transfer_index]
+ x[transfer_index] = x0[transfer_index]
+ else:
+ x0, transfer_index = self._get_transfer_index(
+ logits,
+ temperature,
+ remasking,
+ mask_index[:, block_start:],
+ x[:, block_start:],
+ num_transfer_tokens[:, i] if threshold is None else None,
+ found_stop_seq,
+ stop_position - block_start,
+ block_end - block_start,
+ suffix_len,
+ block_start,
+ threshold,
+ )
+ x0_embeds = self.model.model.embed_tokens(x0)
+ x0_embeds = torch.where(
+ mask_index[:, block_start:]
+ .unsqueeze(-1)
+ .expand_as(x_embeds[:, block_start:]),
+ x0_embeds,
+ x_embeds[:, block_start:],
+ )
+ x_embeds[:, block_start:][transfer_index] = x0_embeds[
+ transfer_index
+ ]
+ x[:, block_start:][transfer_index] = x0[transfer_index]
+
+ track_x.append(
+ x.clone()[
+ :,
+ inputs_embeds.shape[1] : inputs_embeds.shape[1]
+ + gen_length,
+ ]
+ )
+
+ # Check for stop words
+ if stopping_criteria is not None:
+ generated_part = x[
+ 0,
+ inputs_embeds.shape[1] : inputs_embeds.shape[1]
+ + gen_length,
+ ]
+ for stop_seq in stop_tokens:
+ if not isinstance(stop_seq, list):
+ stop_seq = [stop_seq]
+ for start_idx in range(
+ generated_part.size(0) - len(stop_seq) + 1
+ ):
+ if torch.all(
+ generated_part[
+ start_idx : start_idx + len(stop_seq)
+ ]
+ == torch.tensor(stop_seq, device=x.device)
+ ):
+ current_position = (
+ inputs_embeds.shape[1] + start_idx
+ )
+ if (
+ not found_stop_seq
+ or current_position < stop_position
+ ):
+ stop_position = current_position
+ found_stop_seq = True
+ break
+ if found_stop_seq:
+ break
+ i += 1
+
+ if threshold is not None:
+ print(f"Number of steps: {i}")
+
+ # Return results
+ if found_stop_seq:
+ if suffix_len > 0:
+ return torch.cat(
+ [
+ x[:, inputs_embeds.shape[1] : stop_position],
+ x[:, -suffix_len:],
+ ],
+ dim=1,
+                ), track_x
+ return x[:, inputs_embeds.shape[1] : stop_position], track_x
+ if suffix_len > 0:
+ return torch.cat(
+ [
+ x[
+ :,
+ inputs_embeds.shape[1] : inputs_embeds.shape[1]
+ + gen_length,
+ ],
+ x[:, -suffix_len:],
+ ],
+ dim=1,
+            ), track_x
+ return x[
+ :, inputs_embeds.shape[1] : inputs_embeds.shape[1] + gen_length
+ ], track_x
+
+ @staticmethod
+ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Apply rotary position embedding to query and key tensors."""
+ _ = position_ids # Unused in this implementation
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (FastDLLMGenerationHook._rotate_half(q) * sin)
+ k_embed = (k * cos) + (FastDLLMGenerationHook._rotate_half(k) * sin)
+ return q_embed, k_embed
+
+ @staticmethod
+ def _rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _get_transfer_index(
+ self,
+ logits,
+ temperature,
+ remasking,
+ mask_index,
+ x,
+ num_transfer_tokens,
+ found_stop_seq,
+ stop_position,
+ block_end,
+ suffix_len,
+ block_start,
+ threshold=None,
+ ):
+ """Get transfer indices for token updates during generation."""
+ logits_with_noise = self._add_gumbel_noise(logits, temperature=temperature)
+ x0 = torch.argmax(logits_with_noise, dim=-1)
+
+ if remasking == "low_confidence":
+ p = F.softmax(logits.to(torch.float64), dim=-1)
+ x0_p = torch.squeeze(
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
+ )
+ elif remasking == "reverse_causal_comma":
+ x0_p = self.reverse_causal_x0_p_comma(x0, logits, block_start, block_end)
+ else:
+ raise NotImplementedError(remasking)
+
+ # Handle stop sequences and block boundaries
+ if found_stop_seq:
+ x0_p[:, stop_position:] = -np.inf
+ else:
+ x0_p[:, block_end:] = -np.inf
+
+ # Prevent overwriting suffix
+ if suffix_len > 0:
+ x0_p[:, -suffix_len:] = -np.inf
+
+ x0 = torch.where(mask_index, x0, x)
+ confidence = torch.where(mask_index, x0_p, -np.inf)
+
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+
+ if threshold is not None:
+ num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
+
+ for j in range(confidence.shape[0]):
+ if threshold is None:
+ top_i = num_transfer_tokens[j]
+ else:
+ ns = list(range(1, num_transfer_tokens[j] + 1))
+ es = [threshold / (n + 1) for n in ns]
+ threshs = [1 - e for e in es]
+ threshs[0] = -1 # at least one token is transferred
+
+ sorted_confidence = torch.sort(
+ confidence[j][mask_index[j]], dim=-1, descending=True
+ )[0]
+ assert len(sorted_confidence) == len(threshs)
+
+ top_i = 0
+ for top_i, _ in enumerate(threshs):
+ if sorted_confidence[top_i] < threshs[top_i]:
+ break
+
+ if top_i in (0, len(threshs) - 1):
+ top_i += 1
+
+ _, select_index = torch.topk(confidence[j], k=top_i)
+ transfer_index[j, select_index] = True
+
+ return x0, transfer_index
+
+ @staticmethod
+ def _add_gumbel_noise(logits, temperature):
+ """Add Gumbel noise for categorical sampling."""
+ if temperature == 0:
+ return logits
+
+ logits = logits.to(torch.float64)
+ noise = torch.rand_like(logits, dtype=torch.float64)
+ gumbel_noise = (-torch.log(noise)) ** temperature
+ return logits.exp() / gumbel_noise
+
+ @staticmethod
+ def _get_num_transfer_tokens(mask_index, steps):
+ """Precompute the number of tokens to transition at each step."""
+ mask_num = mask_index.sum(dim=1, keepdim=True)
+ base = mask_num // steps
+ remainder = mask_num % steps
+
+ num_transfer_tokens = base.expand(-1, steps).clone()
+
+ if remainder.sum() > 0:
+ indices = torch.arange(steps, device=mask_index.device)
+ mask = indices.unsqueeze(0) < remainder
+ num_transfer_tokens[mask] += 1
+
+ return num_transfer_tokens.to(torch.int64)
+
+ def _create_cache_slice(self, fast_dllm_cache, block_start):
+ """Create a sliced version of fast_dllm_cache for block processing."""
+ new_past_key_values = []
+ for i, cache_i in enumerate(fast_dllm_cache):
+ new_past_key_values.append([])
+ for _, cache_j in enumerate(cache_i):
+ new_past_key_values[i].append(cache_j[:, :, :block_start])
+ return new_past_key_values
+
+
+def register_fast_dllm_hook(model):
+ """
+ Register fast dLLM cache hooks to the model.
+
+ Args:
+ model: The LLaDA model to register hooks to
+
+ Returns:
+ FastDLLMGenerationHook: The hook instance for management
+ """
+ hook = FastDLLMGenerationHook(model)
+ hook.register_hooks()
+ return hook
+
+
+def unregister_fast_dllm_hook(hook):
+ """
+ Unregister fast dLLM cache hooks from the model.
+
+ Args:
+ hook: The FastDLLMGenerationHook instance to unregister
+ """
+ hook.unregister_hooks()
diff --git a/train/llava/mm_utils.py b/train/llava/mm_utils.py
new file mode 100644
index 0000000..52b4503
--- /dev/null
+++ b/train/llava/mm_utils.py
@@ -0,0 +1,396 @@
+from PIL import Image
+from io import BytesIO
+import base64
+import math
+import ast
+import re
+import torch
+from transformers import StoppingCriteria
+from llava.constants import IMAGE_TOKEN_INDEX
+from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord
+
+
+def resize_and_center_crop(image, shortest_edge_length):
+ # Calculate new dimensions and resize
+ aspect_ratio = float(image.width) / float(image.height)
+ if aspect_ratio > 1:
+ new_width = int(shortest_edge_length * aspect_ratio)
+ new_height = shortest_edge_length
+ else:
+ new_width = shortest_edge_length
+ new_height = int(shortest_edge_length / aspect_ratio)
+    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+
+ # Calculate the position and perform the center crop
+ left = (new_width - shortest_edge_length) / 2
+ top = (new_height - shortest_edge_length) / 2
+ right = (new_width + shortest_edge_length) / 2
+ bottom = (new_height + shortest_edge_length) / 2
+ cropped_image = resized_image.crop((left, top, right, bottom))
+
+ return cropped_image
+
+
+def auto_pad_images(image, grid_params):
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
+ assert len(grid_params) > 0, "Grid parameters should not be empty"
+
+ # Step 1: Calculate and find the closest aspect ratio
+ input_width, input_height = image.size
+ input_aspect_ratio = input_width / input_height
+ candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
+ closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
+
+ candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
+
+ target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
+
+ resize_width, resize_height = target_resolution
+ if input_width > input_height:
+ resize_height = int(resize_width / input_aspect_ratio)
+ else:
+ resize_width = int(resize_height * input_aspect_ratio)
+    resized_image = image.resize((resize_width, resize_height), Image.LANCZOS)
+
+ # Step 5: Pad the resized image if necessary to match the target resolution
+ pad_width = target_resolution[0] - resize_width
+ pad_height = target_resolution[1] - resize_height
+ padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
+ padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
+
+ return padded_image
+
+
+def extract_patches(image, patch_size, overlap_ratio):
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
+ assert patch_size > 0, "Patch size should be greater than 0"
+ assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
+
+ W, H = image.size
+ patches = []
+
+ stride = int(patch_size * (1 - overlap_ratio))
+
+ num_patches_y = (H - patch_size) // stride + 1
+ num_patches_x = (W - patch_size) // stride + 1
+
+ y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
+ x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
+
+ for y in range(y_start, y_start + num_patches_y * stride, stride):
+ for x in range(x_start, x_start + num_patches_x * stride, stride):
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
+ patches.append(patch)
+
+ return patches
+
+
+def process_highres_image_crop_split(image, data_args, processor=None):
+ crop_resolution = data_args.image_crop_resolution
+ split_resolution = data_args.image_split_resolution
+ if processor is None:
+ processor = data_args.image_processor
+ image_crop = resize_and_center_crop(image, crop_resolution)
+ image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def process_highres_image(image, processor, grid_pinpoints):
+ grid_params = [int(x) for x in grid_pinpoints.split(",")]
+ width_height = max(image.size)
+ fit_grid_params = [x for x in grid_params if x >= width_height]
+ if len(fit_grid_params) == 0:
+ select_size = max(grid_params)
+ else:
+ select_size = min(fit_grid_params)
+ # FIXME: always select the 448
+ select_size = max(grid_params)
+ image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+
+ # FIXME: this seems to be a bug that it always resizes instead of padding
+ image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
+ image_padded = image_padded.resize((select_size, select_size))
+ image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
+ image_patches = [image_original_resize] + image_patches
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def select_best_resolution(original_size, possible_resolutions):
+ """
+ Selects the best resolution from a list of possible resolutions based on the original size.
+
+ Args:
+ original_size (tuple): The original size of the image in the format (width, height).
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (width, height).
+ """
+ original_width, original_height = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float("inf")
+
+ for width, height in possible_resolutions:
+ # Calculate the downscaled size to keep the aspect ratio
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+
+ # Calculate effective and wasted resolutions
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (width, height)
+
+ return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+ """
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ target_resolution (tuple): The target resolution (width, height) of the image.
+
+ Returns:
+ PIL.Image.Image: The resized and padded image.
+ """
+ original_width, original_height = image.size
+ target_width, target_height = target_resolution
+
+ # Determine which dimension (width or height) to fill
+ scale_w = target_width / original_width
+ scale_h = target_height / original_height
+
+ if scale_w < scale_h:
+ # Width will be filled completely
+ new_width = target_width
+ new_height = min(math.ceil(original_height * scale_w), target_height)
+ else:
+ # Height will be filled completely
+ new_height = target_height
+ new_width = min(math.ceil(original_width * scale_h), target_width)
+
+ # Resize the image
+ resized_image = image.resize((new_width, new_height))
+
+ # Create a new image with the target size and paste the resized image onto it
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
+ paste_x = (target_width - new_width) // 2
+ paste_y = (target_height - new_height) // 2
+ new_image.paste(resized_image, (paste_x, paste_y))
+
+ return new_image
+
+
+def divide_to_patches(image, patch_size):
+ """
+ Divides an image into patches of a specified size.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ patch_size (int): The size of each patch.
+
+ Returns:
+ list: A list of PIL.Image.Image objects representing the patches.
+ """
+ patches = []
+ width, height = image.size
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ box = (j, i, j + patch_size, i + patch_size)
+ patch = image.crop(box)
+ patches.append(patch)
+
+ return patches
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (tuple): The size of the input image in the format (width, height).
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+ patch_size (int): The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+ # Use regex to extract the range from the input string
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+ # Multiply all elements by patch_size
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ width, height = select_best_resolution(image_size, possible_resolutions)
+ return width // patch_size, height // patch_size
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+ """
+ Process an image with variable resolutions.
+
+ Args:
+ image (PIL.Image.Image): The input image to be processed.
+ processor: The image processor object.
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+ Returns:
+ torch.Tensor: A tensor containing the processed image patches.
+ """
+ # Convert grid_pinpoints from string to list
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+ try:
+ patch_size = processor.size[0]
+        except Exception:
+ patch_size = processor.size["shortest_edge"]
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+ # Use regex to extract the range from the input string
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+ # Multiply all elements by patch_size
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
+ image_padded = resize_and_pad_image(image, best_resolution)
+
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
+
+ # FIXME: this seems to be a bug that it resizes instead of pad.
+ # but to keep it consistent with previous, i will keep it as it is
+ # TODO: uncomment below to ablate with the padding
+ if isinstance(processor.size, dict):
+ shortest_edge = processor.size["shortest_edge"]
+ else:
+ shortest_edge = min(processor.size)
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+ image_patches = [image_original_resize] + patches
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ new_images = []
+ if image_aspect_ratio == "highres":
+ for image in images:
+ image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+    elif image_aspect_ratio is not None and (image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio):
+ for image in images:
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+ elif image_aspect_ratio == "crop_split":
+ for image in images:
+ image = process_highres_image_crop_split(image, model_cfg, image_processor)
+ new_images.append(image)
+ elif image_aspect_ratio == "pad":
+ for image in images:
+ image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ new_images.append(image)
+ else:
+ return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == "pt":
+ return torch.tensor(input_ids, dtype=torch.long)
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith("checkpoint-"):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
+ offset = min(output_ids.shape[1] - self.start_len, 3)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+            if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
diff --git a/train/llava/model/__init__.py b/train/llava/model/__init__.py
new file mode 100644
index 0000000..0a9ae48
--- /dev/null
+++ b/train/llava/model/__init__.py
@@ -0,0 +1,8 @@
+import os
+
+AVAILABLE_MODELS = {
+ "llava_llada": "LlavaLLaDAModelLM, LlavaLLaDAConfig",
+}
+
+for model_name, model_classes in AVAILABLE_MODELS.items():
+ exec(f"from .language_model.{model_name} import {model_classes}")
\ No newline at end of file
diff --git a/train/llava/model/apply_delta.py b/train/llava/model/apply_delta.py
new file mode 100644
index 0000000..c183ba1
--- /dev/null
+++ b/train/llava/model/apply_delta.py
@@ -0,0 +1,47 @@
+"""
+Usage:
+python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+"""
+
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava import LlavaLlamaForCausalLM
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading delta")
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+ print("Applying delta")
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+ if name not in base.state_dict():
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data += base.state_dict()[name]
+ else:
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+ bparam = base.state_dict()[name]
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
+
+ print("Saving target model")
+ delta.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/train/llava/model/builder.py b/train/llava/model/builder.py
new file mode 100644
index 0000000..264534c
--- /dev/null
+++ b/train/llava/model/builder.py
@@ -0,0 +1,145 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+import torch
+from llava.model import *
+from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.utils import rank0_print
+from trl.import_utils import is_npu_available
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16",attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
+ kwargs["device_map"] = device_map
+
+ if load_8bit:
+ kwargs["load_in_8bit"] = True
+ elif load_4bit:
+ kwargs["load_in_4bit"] = True
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
+ elif torch_dtype == "float16":
+ kwargs["torch_dtype"] = torch.float16
+ elif torch_dtype == "bfloat16":
+ kwargs["torch_dtype"] = torch.bfloat16
+ else:
+        raise ValueError(f"Unsupported torch_dtype: {torch_dtype}")
+
+ if customized_config is not None:
+ kwargs["config"] = customized_config
+
+    if "multimodal" in kwargs:
+        # Pop regardless of value so the key is never forwarded to
+        # from_pretrained and is_multimodal is always bound.
+        is_multimodal = kwargs.pop("multimodal") is True
+    else:
+        is_multimodal = False
+
+ if "llava" in model_name.lower() or is_multimodal:
+ # Load LLaVA model
+ rank0_print(f"Loaded LLaVA model: {model_path}")
+ if "llada" in model_name.lower() or "exp" in model_name.lower():
+ from llava.model.language_model.llava_llada import LlavaLLaDAConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ if customized_config is None:
+ llada_cfg = LlavaLLaDAConfig.from_pretrained(model_path)
+ else:
+ llada_cfg = customized_config
+
+ if overwrite_config is not None:
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(llada_cfg, k, v)
+
+ model = LlavaLLaDAModelLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llada_cfg, **kwargs)
+
+ else:
+ try:
+ from llava.model.language_model.llava_llama import LlavaConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ if customized_config is None:
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
+ if "v1.5" in model_path.lower():
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
+ else:
+ llava_cfg = customized_config
+
+ if overwrite_config is not None:
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(llava_cfg, k, v)
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
+            except Exception as exc:
+                raise ValueError(f"Model {model_name} not supported") from exc
+
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print("Convert to FP16...")
+ model.to(torch.float16)
+ else:
+ use_fast = False
+ if "mpt" in model_name.lower().replace("prompt", ""):
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+ rank0_print(f"Model Class: {model.__class__.__name__}")
+ image_processor = None
+
+ if "llava" in model_name.lower() or is_multimodal:
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+ if mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ if mm_use_im_start_end:
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ model.resize_token_embeddings(len(tokenizer))
+
+ vision_tower = model.get_vision_tower()
+ if not vision_tower.is_loaded:
+ vision_tower.load_model(device_map=device_map)
+ if device_map != "auto":
+ if is_npu_available():
+ vision_tower.to(device="npu", dtype=torch.float16)
+ else:
+ vision_tower.to(device="cuda", dtype=torch.bfloat16)
+ image_processor = vision_tower.image_processor
+
+ if hasattr(model.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ elif hasattr(model.config, "max_position_embeddings"):
+ context_len = model.config.max_position_embeddings
+ elif hasattr(model.config, "tokenizer_model_max_length"):
+ context_len = model.config.tokenizer_model_max_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, image_processor, context_len
diff --git a/train/llava/model/consolidate.py b/train/llava/model/consolidate.py
new file mode 100644
index 0000000..f02e575
--- /dev/null
+++ b/train/llava/model/consolidate.py
@@ -0,0 +1,30 @@
+"""
+Usage:
+python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+"""
+
+import argparse
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model import *
+from llava.model.utils import auto_upgrade
+
+
+def consolidate_ckpt(src_path, dst_path):
+ print("Loading model")
+ auto_upgrade(src_path)
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+ src_model.save_pretrained(dst_path)
+ src_tokenizer.save_pretrained(dst_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src", type=str, required=True)
+ parser.add_argument("--dst", type=str, required=True)
+
+ args = parser.parse_args()
+
+ consolidate_ckpt(args.src, args.dst)
diff --git a/train/llava/model/language_model/configuration_llada.py b/train/llava/model/language_model/configuration_llada.py
new file mode 100644
index 0000000..7fc5604
--- /dev/null
+++ b/train/llava/model/language_model/configuration_llada.py
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" LLaDA model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+LLaDA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+class LLaDAConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LLaDAModel`]. It is used to instantiate an LLaDA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the LLaDA-8B.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the LLaDA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`LLaDAModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ """
+
+ model_type = "llada"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ num_experts_per_tok=-1,
+ num_experts=-1,
+ moe_choice='expert',
+ capacity_factor=0.1,
+ moe_router_enable_expert_bias=None,
+ moe_router_score_function=None,
+ moe_lora_rank=32,
+ moe_lora_alpha=32,
+ moe_lora_in_features=4096,
+ moe_lora_out_features=4096,
+ moe_lora_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_experts = num_experts
+ self.moe_choice = moe_choice
+ self.capacity_factor = capacity_factor
+ self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
+ self.moe_router_score_function = moe_router_score_function
+ self.moe_lora_rank=moe_lora_rank
+ self.moe_lora_alpha=moe_lora_alpha
+ self.moe_lora_in_features=moe_lora_in_features
+ self.moe_lora_out_features=moe_lora_out_features
+ self.moe_lora_dropout=moe_lora_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
+ f"got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
diff --git a/train/llava/model/language_model/llava_llada.py b/train/llava/model/language_model/llava_llada.py
new file mode 100644
index 0000000..22b1e10
--- /dev/null
+++ b/train/llava/model/language_model/llava_llada.py
@@ -0,0 +1,211 @@
+# Copyright 2023 Zebin You
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from llava.model.language_model.configuration_llada import LLaDAConfig
+from llava.model.language_model.modeling_llada import LLaDAModel, LLaDAModelLM
+from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+)
+from transformers.generation.utils import GenerateOutput
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class LlavaLLaDAConfig(LLaDAConfig):
+ model_type = "llava_llada"
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
+ max_new_tokens: int = 1024
+ do_sample: bool = False
+ top_p: Optional[float] = None
+ # rope_scaling: Optional[dict] = {}
+
+
+class LlavaLLaDAModel(LlavaMetaModel, LLaDAModel):
+ config_class = LlavaLLaDAConfig
+
+ def __init__(self, config: LLaDAConfig):
+ super(LlavaLLaDAModel, self).__init__(config)
+
+
+class LlavaLLaDAModelLM(LLaDAModelLM, LlavaMetaForCausalLM):
+ config_class = LlavaLLaDAConfig
+
+ def __init__(self, config):
+ LLaDAModelLM.__init__(self, config)
+
+ # configure default generation settings
+ config.model_type = "llava_llada"
+ # config.rope_scaling = None
+
+ self.model = LlavaLLaDAModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ modalities: Optional[List[str]] = ["image"],
+ dpo_forward: Optional[bool] = None,
+ cache_position=None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ if inputs_embeds is None and attention_mask is not None:
+ # donate multi-dialogue
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels,
+ conversation_ids,
+ ) = self.prepare_inputs_labels_for_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ modalities,
+ image_sizes,
+ is_llada=True,
+ )
+ elif inputs_embeds is None:
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels,
+ ) = self.prepare_inputs_labels_for_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ modalities,
+ image_sizes,
+ )
+ conversation_ids = None
+ if dpo_forward:
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ return logits, labels
+
+ else:
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ conversation_ids=conversation_ids,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ modalities: Optional[List[str]] = ["image"],
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ modalities = (
+ kwargs.pop("modalities", None)
+ if "modalities" in kwargs and modalities is None
+ else modalities
+ )
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
+ self.prepare_inputs_labels_for_multimodal(
+ inputs,
+ position_ids,
+ attention_mask,
+ None,
+ None,
+ images,
+ modalities,
+ image_sizes=image_sizes,
+ )
+ )
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate_with_embeds(inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
+ ):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+ if images is not None:
+ inputs["images"] = images
+ if image_sizes is not None:
+ inputs["image_sizes"] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_llada", LlavaLLaDAConfig)
+AutoModelForCausalLM.register(LlavaLLaDAConfig, LlavaLLaDAModelLM)
diff --git a/train/llava/model/language_model/modeling_llada.py b/train/llava/model/language_model/modeling_llada.py
new file mode 100644
index 0000000..997bad5
--- /dev/null
+++ b/train/llava/model/language_model/modeling_llada.py
@@ -0,0 +1,2330 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch LLaDA model."""
+
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from llava.cache import dLLMCache
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from .configuration_llada import LLaDAConfig
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LLaDAConfig"
+
+
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+class LLaDARMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LLaDARMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+ALL_LAYERNORM_LAYERS.append(LLaDARMSNorm)
+
+
+class LLaDARotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ ):
+ super().__init__()
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (
+ self.base
+ ** (
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
+ / self.dim
+ )
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(
+ self.max_seq_len_cached, device=device, dtype=torch.int64
+ ).type_as(self.inv_freq)
+ t = t / self.scaling_factor
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer(
+ "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False
+ )
+ self.register_buffer(
+ "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False
+ )
+
+ @property
+ def sin_cached(self):
+ logger.warning_once(
+ "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+ "the forward method of RoPE from now on instead. It is not used in the `LLaDAAttention` class"
+ )
+ return self._sin_cached
+
+ @property
+ def cos_cached(self):
+ logger.warning_once(
+ "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+ "the forward method of RoPE from now on instead. It is not used in the `LLaDAAttention` class"
+ )
+ return self._cos_cached
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ )
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = (
+ device_type
+ if isinstance(device_type, str) and device_type != "mps"
+ else "cpu"
+ )
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (
+ inv_freq_expanded.float() @ position_ids_expanded.float()
+ ).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class LLaDALinearScalingRotaryEmbedding(LLaDARotaryEmbedding):
+ """LLaDARotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+        # difference to the original RoPE: a scaling factor is applied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class LLaDADynamicNTKScalingRotaryEmbedding(LLaDARotaryEmbedding):
+ """LLaDARotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
+ - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base
+ ** (
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device)
+ / self.dim
+ )
+ )
+ self.register_buffer(
+ "inv_freq", inv_freq, persistent=False
+ ) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LLaDAMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ if self.config.pretraining_tp > 1:
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat(
+ [
+ F.linear(x, gate_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ],
+ dim=-1,
+ )
+ up_proj = torch.cat(
+ [
+ F.linear(x, up_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ],
+ dim=-1,
+ )
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class LLaDAAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LLaDAConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ # self.is_causal = True
+ # Modify: MDM set causal to False.
+ self.is_causal = False
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ )
+ self.o_proj = nn.Linear(
+ self.hidden_size, self.hidden_size, bias=config.attention_bias
+ )
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = LLaDARotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = LLaDALinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = LLaDADynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (
+ self.num_key_value_heads * self.head_dim
+ ) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split(
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+ )
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [
+ F.linear(hidden_states, query_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [
+ F.linear(hidden_states, key_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [
+ F.linear(hidden_states, value_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ past_key_value = getattr(self, "past_key_value", past_key_value)
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.attention_dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(
+ self.hidden_size // self.config.pretraining_tp, dim=2
+ )
+ o_proj_slices = self.o_proj.weight.split(
+ self.hidden_size // self.config.pretraining_tp, dim=1
+ )
+ attn_output = sum(
+ [
+ F.linear(attn_output[i], o_proj_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ )
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LLaDAFlashAttention2(LLaDAAttention):
+    """
+    LLaDA flash attention module. This module inherits from `LLaDAAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # Flash attention kernels never materialize the full attention matrix,
+        # so attention weights cannot be returned from this implementation.
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
+        # In case static cache is used, it is an instance attribute.
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LLaDARMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        # attn_weights is always None on this path (see output_attentions above).
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LLaDAFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # NOTE(review): the diffusion (MDM) objective uses bidirectional attention,
+        # so causal masking must always be off here — confirm is_causal is False
+        # for every supported config before relying on this assert.
+        assert causal is False  # Modify: MDM
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            (
+                query_states,
+                key_states,
+                value_states,
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            # Scatter the unpadded rows back into a dense (batch, seq) layout.
+            attn_output = pad_input(
+                attn_output_unpad, indices_q, batch_size, query_length
+            )
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+        return attn_output
+
+    def _upad_input(
+        self, query_layer, key_layer, value_layer, attention_mask, query_length
+    ):
+        """Strip padding tokens ahead of `flash_attn_varlen_func`: returns the
+        unpadded q/k/v plus the index and cu_seqlen metadata needed to re-pad
+        the attention output afterwards."""
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        if query_length == kv_seq_len:
+            # No cache offset: queries share the key metadata.
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask
+            )
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+class LLaDASdpaAttention(LLaDAAttention):
+    """
+    LLaDA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `LLaDAAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from LLaDAAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # SDPA cannot return attention weights, so fall back to the eager
+            # (parent) implementation when the caller asks for them.
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "LLaDAModel is using LLaDASdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Split hidden dim into (num_heads, head_dim) and move heads to dim 1.
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
+        # In case static cache is used, it is an instance attribute.
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # Expand grouped KV heads so each query head has a matching KV head.
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        # if attention_mask is not None and cache_position is not None:
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=False,  # Modify: MDM — bidirectional attention for the diffusion objective
+            dropout_p=self.attention_dropout if self.training else 0.0,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        # attn_weights slot is always None on the SDPA path.
+        return attn_output, None, past_key_value
+
+
+# Maps config._attn_implementation to the attention class instantiated by
+# LLaDADecoderLayer.
+LLaDA_ATTENTION_CLASSES = {
+    "eager": LLaDAAttention,
+    "flash_attention_2": LLaDAFlashAttention2,
+    "sdpa": LLaDASdpaAttention,
+}
+
+
+class LoRALayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.r = config.moe_lora_rank
+ self.alpha = config.moe_lora_alpha
+ self.scaling = self.alpha / self.r
+
+ in_features = config.moe_lora_in_features
+ out_features = config.moe_lora_out_features
+
+ if self.r > 0:
+ self.lora_A = nn.Parameter(torch.randn(self.r, in_features))
+ self.lora_B = nn.Parameter(torch.zeros(out_features, self.r))
+ self.dropout = nn.Dropout(config.moe_lora_dropout)
+ else:
+ self.lora_A = None
+ self.lora_B = None
+ self.dropout = nn.Identity()
+
+ def forward(self, x):
+ return self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling
+
+
+class LLaDAMoESparseMoeBlock(nn.Module):
+    """Sparse MoE block over LoRALayer experts.
+
+    Routing strategy is fixed at construction time via `config.moe_choice`:
+    "expert" binds `forward` to expert-choice routing (each expert picks its
+    top-`capacity` tokens), anything else binds token-choice routing (each
+    token picks its top-k experts).
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.capacity_factor = config.capacity_factor
+        # Top-k probability renormalisation only applies to token-choice routing.
+        self.norm_topk_prob = config.moe_choice == "token"
+
+        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
+        self.experts = nn.ModuleList(
+            [LoRALayer(config) for _ in range(self.num_experts)]
+        )
+        # self.score_func = config.moe_router_score_function
+        if (
+            hasattr(config, "moe_router_enable_expert_bias")
+            and config.moe_router_enable_expert_bias
+        ):
+            self.register_buffer("expert_bias", torch.zeros(self.num_experts))
+        else:
+            self.expert_bias = None
+        # Bind the routing strategy once; `forward` is an instance attribute
+        # from here on (shadows any class-level forward).
+        if config.moe_choice == "expert":
+            self.forward = self.forward_expert_choice
+        else:
+            self.forward = self.forward_token_choice
+
+    # expert choice
+    def forward_expert_choice(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Expert-choice routing: each expert selects its top-`capacity` tokens.
+
+        A token may be selected by several experts (contributions accumulate)
+        or by none (it receives a zero MoE update).
+        """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens = batch_size * sequence_length
+        hidden_states = hidden_states.view(-1, hidden_dim)  # flatten operation
+        # router_logits: (batch * sequence_length, n_experts)
+        # NOTE: despite the name, these are post-softmax routing probabilities.
+        router_logits = F.softmax(
+            self.gate(hidden_states), dim=1
+        )  # importance of each token for each expert
+
+        router_logits_T = router_logits.t()  # expert_num, sequence_length
+
+        capacity = max(
+            1, int(round(self.capacity_factor * num_tokens / self.num_experts))
+        )  # per-expert token budget: ~capacity_factor * tokens / num_experts
+        capacity = min(capacity, num_tokens)  # clamp to valid range
+        # import pdb; pdb.set_trace()
+        expert_weights, expert_indices = torch.topk(
+            router_logits_T,
+            k=capacity,
+            dim=1,
+            sorted=False,  # no need to sort for efficiency
+        )
+
+        # Initialize output and token processing counter
+        output = torch.zeros_like(hidden_states)
+        # import pdb; pdb.set_trace()
+
+        # Process each expert in parallel (if possible) or sequentially
+        for expert_idx in range(self.num_experts):
+            # Get tokens selected by this expert
+            token_indices = expert_indices[expert_idx]  # [capacity]
+            weights = expert_weights[expert_idx]  # [capacity]
+
+            # Get hidden states for selected tokens
+            selected_hidden_states = hidden_states[
+                token_indices
+            ]  # [capacity, hidden_size]
+
+            # Expert forward pass
+            expert_output = self.experts[expert_idx](
+                selected_hidden_states
+            )  # -> [capacity, hidden_size]
+
+            # Weighted accumulation to output
+            # Each token may be processed by multiple experts, so we accumulate
+            weighted_output = (
+                weights.unsqueeze(-1) * expert_output
+            )  # [capacity, 1] * [capacity, hidden_size]
+
+            # Scatter-add to output tensor
+            output.index_add_(0, token_indices, weighted_output)
+            # break
+        # Reshape back to original shape
+        output = output.view(batch_size, sequence_length, hidden_dim)
+
+        # print(f"in mlp output_router_logits: {output_router_logits}")
+        # import pdb; pdb.set_trace()
+        # print("router logits: ", router_logits)
+
+        return output
+
+    # token choice
+    def forward_token_choice(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass with improved routing and load balancing.
+
+        Args:
+            hidden_states: [batch_size, sequence_length, hidden_dim]
+
+        Returns:
+            output: [batch_size, sequence_length, hidden_dim]
+        """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens = batch_size * sequence_length
+
+        # Reshape to [num_tokens, hidden_dim]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # Compute routing weights
+        # NOTE(review): `_compute_routing_weights` and `_compute_aux_loss` are
+        # not defined in this class as shown — presumably added elsewhere in
+        # the patch; confirm they exist before merging.
+        router_logits_raw, routing_probs = self._compute_routing_weights(hidden_states)
+
+        # Select top-k experts per token
+        # topk_weights: [num_tokens, top_k]
+        # selected_experts: [num_tokens, top_k]
+        topk_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)
+
+        # Normalize top-k probabilities (Fix 1)
+        if self.norm_topk_prob:
+            topk_weights = topk_weights / topk_weights.sum(
+                dim=-1, keepdim=True
+            ).clamp_min(1e-9)
+
+        # Cast to input dtype
+        topk_weights = topk_weights.to(hidden_states.dtype)
+
+        # Compute auxiliary loss (Fix 3); stored on self for the training loop to read.
+        if self.training:
+            self.aux_loss = self._compute_aux_loss(
+                routing_probs, selected_experts, num_tokens
+            )
+        else:
+            self.aux_loss = torch.tensor(
+                0.0, device=hidden_states.device, dtype=hidden_states.dtype
+            )
+
+        # Initialize output
+        final_hidden_states = torch.zeros(
+            (num_tokens, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # Efficient per-expert dispatch (Fix 2 - avoid one_hot)
+        # Process each expert
+        for expert_idx in range(self.num_experts):
+            # Find which tokens route to this expert
+            # expert_mask: [num_tokens, top_k] - boolean mask
+            expert_mask = selected_experts == expert_idx
+
+            # Get token indices that route to this expert (at any of their top-k positions)
+            # token_indices: [num_routed_tokens]
+            token_indices = expert_mask.any(dim=-1).nonzero(as_tuple=True)[0]
+
+            if token_indices.numel() == 0:
+                # No tokens routed to this expert, skip
+                continue
+
+            # For each routed token, find which k position corresponds to this expert
+            # k_positions: [num_routed_tokens]
+            k_positions = expert_mask[token_indices].float().argmax(dim=-1)
+
+            # Get the routing weights for this expert
+            # weights: [num_routed_tokens, 1]
+            weights = topk_weights[token_indices, k_positions].unsqueeze(-1)
+
+            # Get the hidden states for routed tokens
+            # current_states: [num_routed_tokens, hidden_dim]
+            current_states = hidden_states.index_select(0, token_indices)
+
+            # Compute expert output
+            # expert_output: [num_routed_tokens, hidden_dim]
+            expert_output = self.experts[expert_idx](current_states)
+
+            # Weight the expert output
+            weighted_output = expert_output * weights
+
+            # Accumulate to final output
+            final_hidden_states.index_add_(
+                0, token_indices, weighted_output.to(hidden_states.dtype)
+            )
+
+        # Reshape back to [batch_size, sequence_length, hidden_dim]
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+
+        return final_hidden_states
+
+
+class LLaDADecoderLayer(nn.Module):
+    """Single transformer decoder layer: pre-norm self-attention followed by a
+    pre-norm feed-forward stage where the dense MLP and (optionally) the MoE
+    block run in parallel on the same normalised input and are summed into the
+    residual."""
+
+    def __init__(self, config: LLaDAConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.mlp_type = "moe"
+
+        # Attention implementation (eager / flash_attention_2 / sdpa) is chosen
+        # from the config.
+        self.self_attn = LLaDA_ATTENTION_CLASSES[config._attn_implementation](
+            config=config, layer_idx=layer_idx
+        )
+
+        self.mlp = LLaDAMLP(config)
+        # MoE branch is disabled entirely when the config declares no experts.
+        self.moe = LLaDAMoESparseMoeBlock(config) if config.num_experts > 0 else None
+        self.input_layernorm = LLaDARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LLaDARMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected: MLP and MoE run in parallel on the same normalised
+        # input; their outputs are summed into the residual stream.
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states_mlp = self.mlp(hidden_states)
+        if self.moe:
+            hidden_states_moe = self.moe(hidden_states)
+            hidden_states = residual + hidden_states_mlp + hidden_states_moe
+        else:
+            hidden_states = residual + hidden_states_mlp
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+LLaDA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LLaDAConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaDA Model outputting raw hidden-states without any specific head on top.",
+    LLaDA_START_DOCSTRING,
+)
+class LLaDAPreTrainedModel(PreTrainedModel):
+    # HF PreTrainedModel plumbing: config class, weight prefix, and feature flags.
+    config_class = LLaDAConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["LLaDADecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+
+    def _init_weights(self, module):
+        """Initialise Linear/Embedding weights from N(0, initializer_range);
+        biases and padding embeddings are zeroed."""
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _setup_cache(
+        self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None
+    ):
+        """Attach a per-layer KV cache instance (`layer.self_attn.past_key_value`)
+        to every decoder layer, matching each layer's device and dtype."""
+        if (
+            self.config._attn_implementation == "flash_attention_2"
+            and cache_cls == StaticCache
+        ):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
+        for layer in self.model.layers:
+            device = layer.input_layernorm.weight.device
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                dtype = self.config._pre_quantization_dtype
+            else:
+                dtype = layer.self_attn.o_proj.weight.dtype
+            layer.self_attn.past_key_value = cache_cls(
+                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
+            )
+
+    def _reset_cache(self):
+        """Drop every layer's attached KV cache (see `_setup_cache`)."""
+        for layer in self.model.layers:
+            layer.self_attn.past_key_value = None
+
+
+LLaDA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaDA Model outputting raw hidden-states without any specific head on top.",
+ LLaDA_START_DOCSTRING,
+)
+class LLaDAModel(LLaDAPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LLaDADecoderLayer`]
+
+ Args:
+ config: LLaDAConfig
+ """
+
+ def __init__(self, config: LLaDAConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.hidden_size, self.padding_idx
+ )
+ self.layers = nn.ModuleList(
+ [
+ LLaDADecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = LLaDARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLaDA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ # Add Basic MDM Model config check
+ assert past_key_values is None and not use_cache, (
+ "The kvcache is not suppotred for MDM."
+ )
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ past_seen_tokens = 0
+ if use_cache: # kept for BC (cache positions)
+ if not isinstance(past_key_values, StaticCache):
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+
+ if cache_position is None:
+ if isinstance(past_key_values, StaticCache):
+ raise ValueError(
+ "cache_position is a required argument when using StaticCache."
+ )
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, is_causal=False
+ ) # Modify: MDM
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ **kwargs,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = (
+ next_decoder_cache.to_legacy_cache()
+ if isinstance(next_decoder_cache, Cache)
+ else next_decoder_cache
+ )
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+ def _update_causal_mask(
+ self, attention_mask, input_tensor, cache_position, is_causal=True
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
+ target_length = self.config.max_position_embeddings
+ else: # dynamic cache
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else cache_position[-1] + 1
+ )
+
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device,
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+
+ if is_causal == False:
+ causal_mask = torch.zeros(
+ (sequence_length, target_length), dtype=dtype, device=device
+ )
+
+ causal_mask *= torch.arange(
+ target_length, device=device
+ ) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(
+ input_tensor.shape[0], 1, -1, -1
+ )
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
+ if attention_mask.dim() == 2:
+ # The position with 1 in attention_mask represents the place to be attended to, so here we need to mask the place where attention_mask is 0
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
+ :, None, None, :
+ ].eq(0.0)
+ causal_mask[..., :mask_length] = causal_mask[
+ ..., :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+ elif attention_mask.dim() == 4:
+ # The position with 1 in attention_mask represents the place to be attended to, so here we need to mask the place where attention_mask is 0
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+ offset = cache_position[0]
+ else:
+ offset = 0
+ mask_shape = attention_mask.shape
+ mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+ causal_mask[
+ : mask_shape[0],
+ : mask_shape[1],
+ offset : mask_shape[2] + offset,
+ : mask_shape[3],
+ ] = mask_slice
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ ):
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(input_tensor, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
+ if not is_tracing and torch.any(attention_mask != 1):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(
+ causal_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+class LLaDAModelLM(LLaDAPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        """Wrap the LLaDA decoder with an (untied-by-default) LM head."""
+        super().__init__(config)
+        self.model = LLaDAModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        # HF accessor: the token embedding table lives on the inner decoder.
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        # HF accessor: output projection (listed in _tied_weights_keys so it
+        # can be tied to the input embeddings when the config requests it).
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+ def _build_conversation_mask_optimized(self, conversation_ids):
+ # Reshape conversation_ids for broadcasting
+ ids_i = conversation_ids.unsqueeze(-1) # [batch_size, seq_len, 1]
+ ids_j = conversation_ids.unsqueeze(-2) # [batch_size, 1, seq_len]
+
+ # Use broadcasting to compare all pairs of conversation IDs
+ conv_mask = ids_j <= ids_i # [batch_size, seq_len, seq_len]
+
+ # Add the attention head dimension
+ return conv_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
+
+    @staticmethod
+    def add_gumbel_noise(logits, temperature):
+        """
+        The Gumbel max is a method for sampling categorical distributions.
+        According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+        Thus, we use float64.
+
+        NOTE(review): `(-log(u)) ** temperature` with argmax of
+        exp(logits)/noise is the form used by the LLaDA reference sampler,
+        not the textbook `logits + Gumbel` trick — presumably intentional;
+        confirm against the upstream implementation before changing.
+        """
+        if temperature == 0:
+            # When temperature=0, we can directly return the original logits.
+            # without any noise or transformation
+            return logits
+
+        # use float64 for more stable computation
+        logits = logits.to(torch.float64)
+        noise = torch.rand_like(logits, dtype=torch.float64)
+        gumbel_noise = (-torch.log(noise)) ** temperature
+        return logits.exp() / gumbel_noise
+
+ @staticmethod
+ def get_num_transfer_tokens(mask_index, steps):
+ """
+ Precompute the number of tokens to transition at each step.
+ Optimized to be more efficient.
+ """
+ mask_num = mask_index.sum(dim=1, keepdim=True)
+ base = mask_num // steps
+ remainder = mask_num % steps
+
+ # Create tensor once and modify in-place (via clone)
+ num_transfer_tokens = base.expand(-1, steps).clone()
+
+ # Handle remainder more efficiently
+ if remainder.sum() > 0: # Optimization: only proceed if there are remainders
+ indices = torch.arange(steps, device=mask_index.device)
+ # Create mask using broadcasting
+ # indices shape: [steps] -> [1, steps]
+ # remainder shape: [batch_size, 1]
+ # mask shape: [batch_size, steps]
+ mask = indices.unsqueeze(0) < remainder
+ num_transfer_tokens[mask] += 1
+
+ return num_transfer_tokens.to(torch.int64)
+
+ @staticmethod
+ def get_masked_indices_from_embeds(noisy_embeds, masked_embed):
+ # Get shape information
+ b, l, d = noisy_embeds.shape
+ # Expand masked_embed to the same shape as noisy_embeds [b, l, d]
+ masked_embed_expanded = masked_embed.expand(b, l, d)
+ # Calculate absolute difference
+ abs_diff = torch.abs(noisy_embeds - masked_embed_expanded)
+ # Calculate tolerance boundary (atol + rtol * abs(masked_embed))
+ tolerance = 1e-5 + 1e-5 * torch.abs(masked_embed_expanded)
+ # Check if all dimensions at each position are within tolerance
+ # all(dim=-1) ensures all dimensions of each embedding meet the condition
+ masked_indices = (abs_diff <= tolerance).all(dim=-1)
+
+ return masked_indices
+
+ @torch.no_grad()
+ def generate(
+ self,
+ prompt,
+ steps=128,
+ gen_length=128,
+ block_length=128,
+ temperature=0.0,
+ cfg_scale=0.0,
+ remasking="low_confidence",
+ mask_id=126336,
+ ):
+ """
+ Args:
+ prompt: A tensor of shape (1, l).
+ steps: Sampling steps, less than or equal to gen_length.
+ gen_length: Generated answer length.
+ block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
+ temperature: Categorical distribution sampling temperature.
+ cfg_scale: Unsupervised classifier-free guidance scale.
+ remasking: Remasking strategy. 'low_confidence' or 'random'.
+ mask_id: The toke id of [MASK] is 126336.
+ """
+ x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(
+ prompt.device
+ )
+ x[:, : prompt.shape[1]] = prompt.clone()
+
+ prompt_index = x != mask_id
+
+ assert gen_length % block_length == 0
+ num_blocks = gen_length // block_length
+
+ assert steps % num_blocks == 0
+ steps = steps // num_blocks
+
+ for num_block in range(num_blocks):
+ block_mask_index = (
+ x[
+ :,
+ prompt.shape[1] + num_block * block_length : prompt.shape[1]
+ + (num_block + 1) * block_length :,
+ ]
+ == mask_id
+ )
+ num_transfer_tokens = self.get_num_transfer_tokens(block_mask_index, steps)
+ for i in range(steps):
+ mask_index = x == mask_id
+ if cfg_scale > 0.0:
+ un_x = x.clone()
+ un_x[prompt_index] = mask_id
+ x_ = torch.cat([x, un_x], dim=0)
+ logits = self.model(x_).logits
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
+ logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+ else:
+ logits = self.model(x).logits
+
+ logits_with_noise = self.add_gumbel_noise(
+ logits, temperature=temperature
+ )
+ x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
+
+ if remasking == "low_confidence":
+ p = F.softmax(logits.to(torch.float64), dim=-1)
+ x0_p = torch.squeeze(
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
+ ) # b, l
+ elif remasking == "random":
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+ else:
+ raise NotImplementedError(remasking)
+
+ x0_p[:, prompt.shape[1] + (num_block + 1) * block_length :] = -np.inf
+
+ x0 = torch.where(mask_index, x0, x)
+ confidence = torch.where(mask_index, x0_p, -np.inf)
+
+ transfer_index = torch.zeros_like(
+ x0, dtype=torch.bool, device=x0.device
+ )
+ for j in range(confidence.shape[0]):
+ _, select_index = torch.topk(
+ confidence[j], k=num_transfer_tokens[j, i]
+ )
+ transfer_index[j, select_index] = True
+ x[transfer_index] = x0[transfer_index]
+
+ return x
+
+    @torch.no_grad()
+    def generate_with_embeds(
+        self,
+        inputs_embeds,
+        steps=128,
+        gen_length=128,
+        block_length=128,
+        temperature=0.0,
+        cfg_scale=0.0,
+        remasking="low_confidence",
+        mask_id=126336,
+        tokenizer=None,
+        stopping_criteria=None,
+        generation_suffix=None,
+        **kwargs,
+    ):
+        """
+        Block-wise MDM sampling driven entirely in embedding space, with
+        optional stop-word early termination and a fixed generation suffix.
+
+        Args:
+            inputs_embeds: A tensor of shape (1, l, d).
+            steps: Sampling steps, less than or equal to gen_length.
+            gen_length: Generated answer length.
+            block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
+            temperature: Categorical distribution sampling temperature.
+            cfg_scale: Unsupervised classifier-free guidance scale.
+            remasking: Remasking strategy. 'low_confidence' or 'random'.
+            mask_id: The token id of [MASK] is 126336.
+            tokenizer: Required when stopping_criteria or generation_suffix is used.
+            stopping_criteria: Iterable of stop strings; generation before the
+                earliest occurrence is kept.
+            generation_suffix: (str or None) Generation suffix, such as "The answer is xxx", will be appended to the end
+        """
+        # Use mixed precision for faster computation
+        # NOTE(review): torch.cuda.amp.autocast is deprecated in favor of
+        # torch.amp.autocast("cuda", ...) — consider migrating.
+        with torch.cuda.amp.autocast(enabled=True):
+            # Handle generation suffix
+            suffix_embeds = None
+            suffix_token_ids = None
+            suffix_len = 0
+            if (
+                generation_suffix is not None
+                and tokenizer is not None
+                and len(generation_suffix) > 0
+            ):
+                # Encode as token id
+                suffix_token_ids = tokenizer.encode(
+                    generation_suffix, add_special_tokens=False
+                )
+                suffix_token_ids = torch.tensor(
+                    suffix_token_ids, dtype=torch.long, device=inputs_embeds.device
+                ).unsqueeze(0)  # (1, s)
+                # Convert to embedding
+                suffix_embeds = self.model.embed_tokens(suffix_token_ids)  # (1, s, d)
+                suffix_len = suffix_embeds.shape[1]
+            else:
+                suffix_len = 0
+
+            # Create input in embedding space
+            total_length = inputs_embeds.shape[1] + gen_length + suffix_len
+            masked_embed = self.model.embed_tokens(
+                torch.tensor([mask_id]).to(inputs_embeds.device)
+            )  # shape (1, d)
+            x_embeds = masked_embed.repeat(1, total_length, 1).to(
+                inputs_embeds.device
+            )  # shape (1, l + gen_length + suffix_len, d)
+            x_embeds[:, : inputs_embeds.shape[1]] = inputs_embeds.clone()
+            if suffix_embeds is not None:
+                x_embeds[:, -suffix_len:] = suffix_embeds
+
+            # Create a tracking tensor for token IDs for final output
+            x = torch.full(
+                (1, total_length),
+                mask_id,
+                dtype=torch.long,
+                device=inputs_embeds.device,
+            )
+            if suffix_token_ids is not None:
+                x[:, -suffix_len:] = suffix_token_ids
+
+            # prompt_index: A tensor of shape (1, l + gen_length + suffix_len) where the first l elements are 1 (representing the prompt)
+            # and the remaining gen_length+suffix_len elements are 0 (representing the generated part)
+            prompt_index = torch.zeros(
+                (1, total_length), dtype=torch.bool, device=inputs_embeds.device
+            )
+            prompt_index[:, : inputs_embeds.shape[1]] = (
+                1  # shape (1, l + gen_length + suffix_len)
+            )
+
+            assert gen_length % block_length == 0
+            num_blocks = gen_length // block_length
+
+            assert steps % num_blocks == 0
+            steps = steps // num_blocks
+
+            # New: Initialize stop position variable (default to maximum length)
+            stop_position = inputs_embeds.shape[1] + gen_length
+            found_stop_seq = False
+
+            # Pre-tokenize each stop string once.
+            stop_tokens = []
+            if stopping_criteria is not None:
+                assert tokenizer is not None, (
+                    "tokenizer is required when stopping_criteria is not None"
+                )
+                for stop_str in stopping_criteria:
+                    # Use tokenizer to convert stop words to token IDs
+                    tokens = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_tokens.append(tokens)
+
+            # dLLMCache: project-local feature cache for dLLM acceleration.
+            feature_cache = dLLMCache()
+            feature_cache.reset_cache(inputs_embeds.shape[1])
+            for num_block in range(num_blocks):
+                # Create mask index for the current block
+                block_start = inputs_embeds.shape[1] + num_block * block_length
+                block_end = inputs_embeds.shape[1] + (num_block + 1) * block_length
+
+                # If a stop word is found and the stop word position is before the current block, do not process the current block
+                if found_stop_seq and stop_position <= block_start:
+                    break
+
+                block_embeds = x_embeds[:, block_start:block_end]
+                block_mask_index = torch.all(
+                    torch.abs(block_embeds - masked_embed) < 1e-5, dim=2
+                )
+
+                num_transfer_tokens = self.get_num_transfer_tokens(
+                    block_mask_index, steps
+                )
+
+                for i in range(steps):
+                    # Determine which positions are mask embeddings
+                    mask_index = torch.all(
+                        torch.abs(x_embeds - masked_embed) < 1e-5, dim=2
+                    )
+
+                    # If a stop word has been found, check if the masks before the stop word are all filled
+                    if found_stop_seq:
+                        # Get the mask state before the stop word
+                        pre_stop_masks = mask_index[
+                            0, inputs_embeds.shape[1] : stop_position
+                        ]
+                        # If the masks before the stop word are all filled, exit generation
+                        if not pre_stop_masks.any():
+                            break
+
+                    # Check if there are any masks left to fill in the current block
+                    current_block_masks = mask_index[0, block_start:block_end]
+                    if not current_block_masks.any():
+                        break
+
+                    # Handle CFG
+                    if cfg_scale > 0.0:
+                        un_embeds = (
+                            x_embeds.clone()
+                        )  # shape (1, l + gen_length + suffix_len, d)
+                        un_mask = prompt_index.unsqueeze(-1).expand_as(
+                            x_embeds
+                        )  # shape (1, l + gen_length + suffix_len, d)
+                        un_embeds[un_mask] = masked_embed.repeat(
+                            x_embeds.shape[0], x_embeds.shape[1], 1
+                        )[un_mask]  # Use repeat to avoid the complexity of expand_as
+                        combined_embeds = torch.cat([x_embeds, un_embeds], dim=0)
+
+                        # Forward pass
+                        outputs = self.model(inputs_embeds=combined_embeds)
+                        logits = self.lm_head(outputs[0]).float()
+
+                        # Split and apply CFG
+                        logits, un_logits = torch.chunk(logits, 2, dim=0)
+                        logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+                    else:
+                        # Forward pass
+                        outputs = self.model(inputs_embeds=x_embeds)
+                        logits = self.lm_head(outputs[0]).float()
+
+                    # Presumably special/control token ids that must never be
+                    # sampled at masked positions — TODO confirm against the
+                    # tokenizer's special-token map.
+                    for token_id in [126081, 126080, 126346, 126347]:
+                        logits[:, :, token_id] = torch.where(
+                            mask_index, -float("inf"), logits[:, :, token_id]
+                        )
+
+                    # Add noise and get the most likely token
+                    logits_with_noise = self.add_gumbel_noise(
+                        logits, temperature=temperature
+                    )  # shape (1, l + gen_length + suffix_len, vocab_size)
+                    x0 = torch.argmax(
+                        logits_with_noise, dim=-1
+                    )  # 1, l + gen_length + suffix_len
+
+                    # Get confidence scores
+                    if remasking == "low_confidence":
+                        p = F.softmax(
+                            logits.to(torch.float64), dim=-1
+                        )  # shape (1, l + gen_length + suffix_len, vocab_size)
+                        x0_p = torch.squeeze(
+                            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
+                        )  # 1, l + gen_length + suffix_len represents the confidence of each x0
+                    elif remasking == "random":
+                        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+                    else:
+                        raise NotImplementedError(remasking)
+
+                    # If a stop word is found, only process positions before the stop word
+                    if found_stop_seq:
+                        x0_p[:, stop_position:] = -np.inf
+                    else:
+                        # Prevent processing future blocks
+                        x0_p[:, block_end:] = -np.inf
+
+                    # Do not allow the generated suffix part to be overwritten
+                    if suffix_len > 0:
+                        x0_p[:, -suffix_len:] = -np.inf
+
+                    # Update predictions only at mask positions
+                    x0_embeds = self.model.embed_tokens(
+                        x0
+                    )  # shape (1, l + gen_length + suffix_len, d)
+                    x0_embeds = torch.where(
+                        mask_index.unsqueeze(-1).expand_as(x_embeds),
+                        x0_embeds,
+                        x_embeds,
+                    )
+                    x0 = torch.where(
+                        mask_index, x0, x
+                    )  # shape (1, l + gen_length + suffix_len)
+
+                    # Calculate confidence and determine transfer index
+                    confidence = torch.where(mask_index, x0_p, -np.inf)
+
+                    transfer_index = torch.zeros_like(
+                        x0, dtype=torch.bool, device=x0.device
+                    )
+                    for j in range(confidence.shape[0]):
+                        _, select_index = torch.topk(
+                            confidence[j], k=num_transfer_tokens[j, i]
+                        )
+                        transfer_index[j, select_index] = True
+
+                    # Update embeddings and token IDs
+                    x_embeds[transfer_index] = x0_embeds[transfer_index]
+                    x[transfer_index] = x0[transfer_index]
+
+                    # New: Check for stop words after each update
+                    if stopping_criteria is not None:
+                        # Only check the generated part (excluding the suffix)
+                        generated_part = x[
+                            0,
+                            inputs_embeds.shape[1] : inputs_embeds.shape[1]
+                            + gen_length,
+                        ]
+                        # NOTE(review): current_stop_position is never assigned
+                        # below, so the `found_stop_seq and current_stop_position
+                        # is None` check effectively breaks on the first found
+                        # stop — looks like dead/vestigial state; confirm intent.
+                        current_stop_position = None
+
+                        for stop_seq in stop_tokens:
+                            if not isinstance(stop_seq, list):
+                                stop_seq = [stop_seq]
+                            # Check if the generated sequence contains stop words
+                            for start_idx in range(
+                                generated_part.size(0) - len(stop_seq) + 1
+                            ):
+                                if torch.all(
+                                    generated_part[
+                                        start_idx : start_idx + len(stop_seq)
+                                    ]
+                                    == torch.tensor(stop_seq, device=x.device)
+                                ):
+                                    # Calculate the position of the currently found stop word
+                                    current_position = (
+                                        inputs_embeds.shape[1] + start_idx
+                                    )
+                                    # If it is the first time a stop word is found, or this stop word is earlier than the previously found one
+                                    if (
+                                        not found_stop_seq
+                                        or current_position < stop_position
+                                    ):
+                                        stop_position = current_position
+                                        found_stop_seq = True
+                                    break
+                            if found_stop_seq and current_stop_position is None:
+                                break
+
+            # Return the generated result, up to stop_position, and append the suffix
+            if found_stop_seq:
+                if suffix_len > 0:
+                    return torch.cat(
+                        [
+                            x[:, inputs_embeds.shape[1] : stop_position],
+                            x[:, -suffix_len:],
+                        ],
+                        dim=1,
+                    )
+                else:
+                    return x[:, inputs_embeds.shape[1] : stop_position]
+            else:
+                if suffix_len > 0:
+                    return torch.cat(
+                        [
+                            x[
+                                :,
+                                inputs_embeds.shape[1] : inputs_embeds.shape[1]
+                                + gen_length,
+                            ],
+                            x[:, -suffix_len:],
+                        ],
+                        dim=1,
+                    )
+                else:
+                    return x[
+                        :, inputs_embeds.shape[1] : inputs_embeds.shape[1] + gen_length
+                    ]
+
+ @add_start_docstrings_to_model_forward(LLaDA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ conversation_ids: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ def forward_process_embeds(input_embeds, labels, eps=1e-3):
+ b, l, d = input_embeds.shape
+ t = torch.rand(b, device=input_embeds.device)
+ p_mask = (1 - eps) * t + eps
+ p_mask = p_mask[:, None].repeat(1, l)
+
+ masked_indices = torch.rand((b, l), device=input_embeds.device) < p_mask
+ # Add label condition filtering
+ valid_mask = labels != -100 # Create valid encoding
+ masked_indices = (
+ masked_indices & valid_mask
+ ) # Combine random encoding and valid encoding
+ # Magic number 126336 stands for the tokenizer special token,
+ # Magic embeddings, which is used for [MASK] token here,
+ masked_embed = self.model.embed_tokens(
+ torch.tensor([126336]).to(input_embeds.device)
+ )
+ noisy_embeds = torch.where(
+ masked_indices.unsqueeze(-1), masked_embed, input_embeds
+ )
+
+ return noisy_embeds, p_mask, masked_embed
+
+ noisy_embeds, p_mask, masked_embed = forward_process_embeds(
+ inputs_embeds, labels
+ )
+
+ masked_indices = self.get_masked_indices_from_embeds(
+ noisy_embeds, masked_embed
+ ) # shape (b, l)
+ prompt_index = (labels == -100).to(torch.int64) # shape (b, l)
+
+ noisy_data_length = torch.sum(
+ (1 - prompt_index), dim=-1, keepdim=True
+ ) # shape (b, 1)
+ noisy_data_length = noisy_data_length.repeat(
+ 1, noisy_embeds.shape[1]
+ ) # shape (b, l)
+
+ if conversation_ids is not None:
+ conversation_mask = self._build_conversation_mask_optimized(
+ conversation_ids
+ )
+ if attention_mask is not None:
+ # 1. Dimension expansion
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(
+ 2
+ ) # (batch, length) -> (batch, 1, 1, length)
+ attention_mask = attention_mask.expand_as(
+ conversation_mask
+ ) # (batch, 1, 1, length) -> (batch, 1, length, length)
+ # 2. Mask combination (element-wise multiplication)
+ combined_mask = conversation_mask * attention_mask
+ else:
+ # If attention_mask is None, directly use conversation_mask
+ combined_mask = conversation_mask
+ attention_mask = combined_mask
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=noisy_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(
+ self.vocab_size // self.config.pretraining_tp, dim=0
+ )
+ logits = [
+ F.linear(hidden_states, lm_head_slices[i])
+ for i in range(self.config.pretraining_tp)
+ ]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Change for MDM
+ token_loss = (
+ F.cross_entropy(
+ logits[masked_indices],
+ labels[masked_indices],
+ ignore_index=-100,
+ reduction="none",
+ )
+ / p_mask[masked_indices]
+ )
+ loss = (
+ torch.sum(token_loss / noisy_data_length[masked_indices])
+ / labels.shape[0]
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        **kwargs,
+    ):
+        """Standard HF generate() hook: trim already-cached tokens, derive
+        position_ids/cache_position, and assemble the model kwargs dict."""
+        # With static cache, the `past_key_values` is None
+        # TODO joao: standardize interface for the different Cache classes and remove of this if
+        has_static_cache = False
+        if past_key_values is None:
+            past_key_values = getattr(
+                self.model.layers[0].self_attn, "past_key_value", None
+            )
+            has_static_cache = past_key_values is not None
+
+        past_length = 0
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                past_length = (
+                    cache_position[0]
+                    if cache_position is not None
+                    else past_key_values.get_seq_length()
+                )
+                max_cache_length = (
+                    torch.tensor(
+                        past_key_values.get_max_length(), device=input_ids.device
+                    )
+                    if past_key_values.get_max_length() is not None
+                    else None
+                )
+                cache_length = (
+                    past_length
+                    if max_cache_length is None
+                    else torch.min(max_cache_length, past_length)
+                )
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+            else:
+                # Legacy tuple cache: seq length is dim 2 of the first key.
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+        # Keep only the unprocessed tokens:
+        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+        # input)
+        if (
+            attention_mask is not None
+            and attention_mask.shape[1] > input_ids.shape[1]
+        ):
+            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+        # input_ids based on the past_length.
+        elif past_length < input_ids.shape[1]:
+            input_ids = input_ids[:, past_length:]
+        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+        if (
+            max_cache_length is not None
+            and attention_mask is not None
+            and cache_length + input_ids.shape[1] > max_cache_length
+        ):
+            attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        input_length = (
+            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        )
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_length, past_length + input_length, device=input_ids.device
+            )
+        else:
+            cache_position = cache_position[-input_length:]
+
+        # Static cache tensors live on the attention modules, not in kwargs.
+        if has_static_cache:
+            past_key_values = None
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx.to(past_state.device))
+ for past_state in layer_past
+ ),
+ )
+ return reordered_past
diff --git a/train/llava/model/llava_arch.py b/train/llava/model/llava_arch.py
new file mode 100644
index 0000000..2467aa7
--- /dev/null
+++ b/train/llava/model/llava_arch.py
@@ -0,0 +1,701 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+
+import math
+import re
+import time
+import torch
+import torch.nn as nn
+from .multimodal_encoder.builder import build_vision_tower
+from .multimodal_resampler.builder import build_vision_resampler
+from .multimodal_projector.builder import build_vision_projector
+
+from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+from llava.mm_utils import get_anyres_image_grid_shape
+from llava.utils import rank0_print, rank_print
+import random
+
+
+class LlavaMetaModel:
+
+ def __init__(self, config):
+ super(LlavaMetaModel, self).__init__(config)
+
+ if hasattr(config, "mm_vision_tower"):
+ delay_load = getattr(config, "delay_load", False)
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
+
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+ self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
+
+ def get_vision_tower(self):
+ vision_tower = getattr(self, "vision_tower", None)
+ if type(vision_tower) is list:
+ vision_tower = vision_tower[0]
+ return vision_tower
+
+ def initialize_vision_modules(self, model_args, fsdp=None):
+ vision_tower = model_args.vision_tower
+ mm_vision_select_layer = model_args.mm_vision_select_layer
+ mm_vision_select_feature = model_args.mm_vision_select_feature
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+ mm_patch_merge_type = model_args.mm_patch_merge_type
+
+ self.config.mm_vision_tower = vision_tower
+ self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
+
+ if self.get_vision_tower() is None:
+ vision_tower = build_vision_tower(model_args)
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
+ for k, v in vision_resampler.config.items():
+ setattr(self.config, k, v)
+
+ if fsdp is not None and len(fsdp) > 0:
+ self.vision_tower = [vision_tower]
+ self.vision_resampler = [vision_resampler]
+ else:
+ self.vision_tower = vision_tower
+ self.vision_resampler = vision_resampler
+ else:
+ if fsdp is not None and len(fsdp) > 0:
+ vision_resampler = self.vision_resampler[0]
+ vision_tower = self.vision_tower[0]
+ else:
+ vision_resampler = self.vision_resampler
+ vision_tower = self.vision_tower
+ vision_tower.load_model()
+
+ # In case it is frozen by LoRA
+ for p in self.vision_resampler.parameters():
+ p.requires_grad = True
+
+ self.config.use_mm_proj = True
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
+ self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
+ self.config.mm_vision_select_layer = mm_vision_select_layer
+ self.config.mm_vision_select_feature = mm_vision_select_feature
+ self.config.mm_patch_merge_type = mm_patch_merge_type
+
+
+ if not hasattr(self.config, 'add_faster_video'):
+ if model_args.add_faster_video:
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
+ self.faster_token = nn.Parameter(
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
+ )
+
+ if getattr(self, "mm_projector", None) is None:
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
+
+ if "unpad" in mm_patch_merge_type:
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
+ self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
+ else:
+ # In case it is frozen by LoRA
+ for p in self.mm_projector.parameters():
+ p.requires_grad = True
+
+ if pretrain_mm_mlp_adapter is not None:
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
+
+ def get_w(weights, keyword):
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
+
+ incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
+ rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
+ rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
+        original_size (tuple): The original size of the image (width, height).
+
+ Returns:
+ torch.Tensor: The unpadded image tensor.
+ """
+ original_width, original_height = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ # Compute aspect ratios
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ # Determine padding size and direction
+ if original_aspect_ratio > current_aspect_ratio:
+ # Padding was added to the height
+ scale_factor = current_width / original_width
+ new_height = int(original_height * scale_factor)
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
+ else:
+ # Padding was added to the width
+ scale_factor = current_height / original_height
+ new_width = int(original_width * scale_factor)
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+ return unpadded_tensor
+
+
+class LlavaMetaForCausalLM(ABC):
+
+ @abstractmethod
+ def get_model(self):
+ pass
+
+ def get_vision_tower(self):
+ return self.get_model().get_vision_tower()
+
+ def get_2dPool(self, image_feature, stride=2):
+ height = width = self.get_vision_tower().num_patches_per_side
+ num_frames, num_tokens, num_dim = image_feature.shape
+ image_feature = image_feature.view(num_frames, height, width, -1)
+ image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
+ # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
+ if self.config.mm_spatial_pool_mode == "average":
+ image_feature = nn.functional.avg_pool2d(image_feature, stride)
+ elif self.config.mm_spatial_pool_mode == "max":
+ image_feature = nn.functional.max_pool2d(image_feature, stride)
+ elif self.config.mm_spatial_pool_mode == "bilinear":
+ height, width = image_feature.shape[2:]
+ scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
+ image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
+
+ else:
+ raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
+ image_feature = image_feature.permute(0, 2, 3, 1)
+ image_feature = image_feature.view(num_frames, -1, num_dim)
+ return image_feature
+
+ def encode_images(self, images):
+ image_features = self.get_model().get_vision_tower()(images)
+ # image_features = self.get_model().vision_resampler(image_features, images=images)
+ image_features = self.get_model().mm_projector(image_features)
+ return image_features
+
+ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
+ videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
+ per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
+ all_videos_or_images_features = []
+ all_faster_video_features = []
+ cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride
+
+ for idx, feat in enumerate(per_videos_or_images_features):
+
+ feat = self.get_model().mm_projector(feat)
+ faster_video_feature = 0
+ slower_img_feat = 0
+ if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1:
+ slower_img_feat = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
+ if self.config.add_faster_video:
+ cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2
+ faster_video_feature = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
+ if slower_img_feat != 0:
+ all_videos_or_images_features.append(slower_img_feat)
+ else:
+ all_videos_or_images_features.append(feat)
+ all_faster_video_features.append(faster_video_feature)
+ return all_videos_or_images_features,all_faster_video_features
+
+ def add_token_per_grid(self, image_feature):
+ resize_h = int(math.sqrt(image_feature.shape[1]))
+ num_frames = image_feature.shape[0]
+ feature_dim = image_feature.shape[-1]
+
+ image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+ if getattr(self.config, "add_faster_video", False):
+ # import pdb; pdb.set_trace()
+ # (3584, 832, 14) -> (3584, 64, 13, 14)
+ image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1)
+ # (3584, 64, 13, 14) -> (64, 13, 14, 3584)
+ image_feature = image_feature.permute(1, 2, 3, 0).contiguous()
+ # (64, 13, 14, 3584) -> (64, 13*14, 3584)
+ image_feature = image_feature.flatten(1, 2)
+ # import pdb; pdb.set_trace()
+ return image_feature
+ # import pdb; pdb.set_trace()
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ return image_feature
+
+ def add_token_per_frame(self, image_feature):
+ image_feature = image_feature.permute(2, 0, 1).contiguous()
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+ image_feature = image_feature.permute(1, 2, 0).contiguous()
+ return image_feature
+
+ def generate_conversation_ids(self, labels):
+ """
+ Args:
+ labels: Label tensor, can be one-dimensional or two-dimensional
+ Returns:
+ Conversation ID tensor with the same shape as labels
+ """
+ # Process input dimensions
+ original_shape = labels.shape
+ if labels.ndim == 1:
+ labels = labels.unsqueeze(0)
+
+ batch_size, seq_len = labels.shape
+ device = labels.device
+ conversation_ids = torch.zeros_like(labels)
+
+ # Special token IDs
+ start_header_id = 126346
+ eot_id = 126348
+ assistant_role_id1 = 598
+
+ # Process all sequences in batch
+ for b in range(batch_size):
+ # Pre-search all special token positions to reduce repeated searches
+ start_positions = (labels[b] == start_header_id).nonzero(as_tuple=True)[0]
+ end_positions = (labels[b] == eot_id).nonzero(as_tuple=True)[0]
+
+ # If no boundaries are found, continue to the next sequence
+ if len(start_positions) == 0 or len(end_positions) == 0:
+ continue
+
+ # Pair all message start and end positions
+ message_boundaries = []
+ for start_pos in start_positions:
+ # Find the nearest end position
+ end_indices = (end_positions >= start_pos).nonzero(as_tuple=True)[0]
+ if len(end_indices) == 0:
+ continue
+
+ end_pos = end_positions[end_indices[0]]
+
+ # Quickly check if it's an assistant message
+ start_idx = start_pos.item()
+ is_assistant = (start_idx + 1 < seq_len and
+ labels[b, start_idx + 1] == assistant_role_id1)
+
+ message_boundaries.append((start_idx, end_pos.item(), is_assistant))
+
+ # Sort by start position
+ message_boundaries.sort(key=lambda x: x[0])
+
+ # Determine if there is a system message
+ has_system = len(message_boundaries) > 0 and not message_boundaries[0][2]
+
+ # Assign conversation turn IDs
+ current_turn = 0
+ prev_was_assistant = False
+
+ # Efficiently handle BOS token (usually at the start of the sequence)
+ if labels[b, 0] == 126080: # BOS ID
+ conversation_ids[b, 0] = 0
+
+ # Assign IDs to all messages at once
+ for i, (start_pos, end_pos, is_assistant) in enumerate(message_boundaries):
+ # The first non-assistant message is a system message
+ is_system = i == 0 and not is_assistant and has_system
+
+ if is_system:
+ # System message belongs to the first conversation turn
+ conversation_ids[b, start_pos:end_pos+1] = 0
+ else:
+ # If it's a user message and the previous one was an assistant message, increase the turn
+ if not is_assistant and prev_was_assistant:
+ current_turn += 1
+
+ # Assign ID to the entire message block at once
+ conversation_ids[b, start_pos:end_pos+1] = current_turn
+
+ prev_was_assistant = is_assistant
+
+            # Fill gaps between messages by carrying the previous turn id
+            # forward over positions that were never assigned a turn
+ for i in range(1, seq_len):
+ if conversation_ids[b, i] == 0 and conversation_ids[b, i-1] > 0:
+ conversation_ids[b, i] = conversation_ids[b, i-1]
+
+ # New: Handle end padding
+ non_zero_mask = (conversation_ids[b] != 0)
+ if non_zero_mask.any():
+ last_non_zero_idx = torch.nonzero(non_zero_mask, as_tuple=True)[0][-1]
+ last_turn = conversation_ids[b, last_non_zero_idx]
+ conversation_ids[b, last_non_zero_idx+1:] = last_turn
+
+ # Return a tensor with the same dimensions as the input
+ if len(original_shape) == 1:
+ return conversation_ids.squeeze(0)
+
+ return conversation_ids
+
+ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None, is_llada=False):
+ vision_tower = self.get_vision_tower()
+ # rank_print(modalities)
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+ if isinstance(modalities, str):
+ modalities = [modalities]
+
+ # import pdb; pdb.set_trace()
+ if type(images) is list or images.ndim == 5:
+ if type(images) is list:
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
+
+ video_idx_in_batch = []
+ for _ in range(len(modalities)):
+ if modalities[_] == "video":
+ video_idx_in_batch.append(_)
+
+ images_list = []
+ for image in images:
+ if image.ndim == 4:
+ images_list.append(image)
+ else:
+ images_list.append(image.unsqueeze(0))
+
+ concat_images = torch.cat([image for image in images_list], dim=0)
+ split_sizes = [image.shape[0] for image in images_list]
+ encoded_image_features = self.encode_images(concat_images)
+ # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
+
+ # This is a list, each element is [num_images, patch * patch, dim]
+ # rank_print(f"Concat images : {concat_images.shape}")
+ encoded_image_features = torch.split(encoded_image_features, split_sizes)
+ image_features = []
+ for idx, image_feat in enumerate(encoded_image_features):
+ if idx in video_idx_in_batch:
+ image_features.append(self.get_2dPool(image_feat))
+ else:
+ image_features.append(image_feat)
+ # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
+ # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
+ # image_features = torch.split(image_features, split_sizes, dim=0)
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+ mm_newline_position = getattr(self.config, "mm_newline_position", "one_token")
+
+ if mm_patch_merge_type == "flat":
+ image_features = [x.flatten(0, 1) for x in image_features]
+
+ elif mm_patch_merge_type.startswith("spatial"):
+ new_image_features = []
+ for image_idx, image_feature in enumerate(image_features):
+ # FIXME: now assume the image is square, and split to 2x2 patches
+ # num_patches = h * w, where h = w = sqrt(num_patches)
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
+ # rank0_print("At least we are reaching here")
+ # import pdb; pdb.set_trace()
+ if image_idx in video_idx_in_batch: # video operations
+ # rank0_print("Video")
+ if mm_newline_position == "grid":
+ # Grid-wise
+ image_feature = self.add_token_per_grid(image_feature)
+ if getattr(self.config, "add_faster_video", False):
+ faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx])
+ # Add a token for each frame
+                            concat_slow_faster_token = []
+                            # import pdb; pdb.set_trace()
+                            for _ in range(image_feature.shape[0]):
+                                if _ % self.config.faster_token_stride == 0:
+                                    concat_slow_faster_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
+                                else:
+                                    concat_slow_faster_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
+                            # import pdb; pdb.set_trace()
+                            image_feature = torch.cat(concat_slow_faster_token)
+
+ # print("!!!!!!!!!!!!")
+
+ new_image_features.append(image_feature)
+ elif mm_newline_position == "frame":
+ # Frame-wise
+ image_feature = self.add_token_per_frame(image_feature)
+
+ new_image_features.append(image_feature.flatten(0, 1))
+
+ elif mm_newline_position == "one_token":
+ # one-token
+ image_feature = image_feature.flatten(0, 1)
+ if 'unpad' in mm_patch_merge_type:
+ image_feature = torch.cat((
+ image_feature,
+ self.model.image_newline[None].to(image_feature.device)
+ ), dim=0)
+ new_image_features.append(image_feature)
+ elif mm_newline_position == "no_token":
+ new_image_features.append(image_feature.flatten(0, 1))
+ else:
+ raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")
+ elif image_feature.shape[0] > 1: # multi patches and multi images operations
+ # rank0_print("Single-images")
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.get_vision_tower().num_patches_per_side
+ assert height * width == base_image_feature.shape[0]
+
+ if "anyres_max" in image_aspect_ratio:
+ matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
+ if matched_anyres_max_num_patches:
+ max_num_patches = int(matched_anyres_max_num_patches.group(1))
+
+ if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+ if hasattr(self.get_vision_tower(), "image_size"):
+ vision_tower_image_size = self.get_vision_tower().image_size
+ else:
+ raise ValueError("vision_tower_image_size is not found in the vision tower.")
+ try:
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
+ except Exception as e:
+ rank0_print(f"Error: {e}")
+ num_patch_width, num_patch_height = 2, 2
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ else:
+ image_feature = image_feature.view(2, 2, height, width, -1)
+
+ if "maxpool2x2" in mm_patch_merge_type:
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
+ unit = image_feature.shape[2]
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ c, h, w = image_feature.shape
+ times = math.sqrt(h * w / (max_num_patches * unit**2))
+ if times > 1.1:
+ image_feature = image_feature[None]
+ image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ elif "unpad" in mm_patch_merge_type:
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ else:
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+ image_feature = image_feature.flatten(0, 3)
+ if "nobase" in mm_patch_merge_type:
+ pass
+ else:
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ new_image_features.append(image_feature)
+ else: # single image operations
+ image_feature = image_feature[0]
+ if "unpad" in mm_patch_merge_type:
+ image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
+
+ new_image_features.append(image_feature)
+ image_features = new_image_features
+ else:
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+ else:
+ image_features = self.encode_images(images)
+
+ # TODO: image start / end is not implemented here to support pretraining.
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
+ raise NotImplementedError
+ # rank_print(f"Total images : {len(image_features)}")
+
+ # Let's just add dummy tensors if they do not exist,
+ # it is a headache to deal with None all the time.
+ # But it is not ideal, and if you have a better idea,
+ # please open an issue / submit a PR, thanks.
+ _labels = labels
+ _position_ids = position_ids
+ _attention_mask = attention_mask
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+ else:
+ attention_mask = attention_mask.bool()
+ if position_ids is None:
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+ if labels is None:
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+ # remove the padding using attention_mask -- FIXME
+ _input_ids = input_ids
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+ new_input_embeds = []
+ new_labels = []
+ cur_image_idx = 0
+ # rank_print("Inserting Images embedding")
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+ # rank0_print(num_images)
+ if num_images == 0:
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+ new_labels.append(labels[batch_idx])
+ cur_image_idx += 1
+ continue
+
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+ cur_input_ids_noim = []
+ cur_labels = labels[batch_idx]
+ cur_labels_noim = []
+ for i in range(len(image_token_indices) - 1):
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+ cur_new_input_embeds = []
+ cur_new_labels = []
+
+ for i in range(num_images + 1):
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+ cur_new_labels.append(cur_labels_noim[i])
+ if i < num_images:
+ try:
+ cur_image_features = image_features[cur_image_idx]
+ except IndexError:
+ cur_image_features = image_features[cur_image_idx - 1]
+ cur_image_idx += 1
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+ # import pdb; pdb.set_trace()
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+ cur_new_labels = torch.cat(cur_new_labels)
+
+ new_input_embeds.append(cur_new_input_embeds)
+ new_labels.append(cur_new_labels)
+
+ # Truncate sequences to max length as image embeddings can make the sequence longer
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+ # rank_print("Finishing Inserting")
+
+ new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
+ new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
+ # TODO: Hard code for control loss spike
+ # if tokenizer_model_max_length is not None:
+ # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
+ # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
+
+ # Combine them
+ max_len = max(x.shape[0] for x in new_input_embeds)
+ batch_size = len(new_input_embeds)
+
+ new_input_embeds_padded = []
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+ # rank0_print("Prepare pos id")
+
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+ cur_len = cur_new_embed.shape[0]
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+ new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, -cur_len:] = cur_new_labels
+ attention_mask[i, -cur_len:] = True
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+ else:
+ new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, :cur_len] = cur_new_labels
+ attention_mask[i, :cur_len] = True
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+ # rank0_print("tokenizer padding")
+
+ if _labels is None:
+ new_labels = None
+ else:
+ new_labels = new_labels_padded
+
+ if _attention_mask is None:
+ attention_mask = None
+ else:
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+ if _position_ids is None:
+ position_ids = None
+ if getattr(self.config, "use_pos_skipping", False) and self.training:
+ position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
+ split_position = random.randint(0, new_input_embeds.size(1))
+ left_add = random.randint(0, self.config.pos_skipping_range)
+ right_add = random.randint(left_add, self.config.pos_skipping_range)
+ position_ids[:, :split_position] += left_add
+ position_ids[:, split_position:] += right_add
+
+ # add conversation_ids
+ if is_llada and attention_mask is not None:
+ conversation_ids = self.generate_conversation_ids(new_labels)
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, conversation_ids
+ # import pdb; pdb.set_trace()
+ # rank0_print("Finish preparing")
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
+ if model_args.mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
+
+ if model_args.pretrain_mm_mlp_adapter:
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
+ assert num_new_tokens == 2
+ if input_embeddings.shape == embed_tokens_weight.shape:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
+ else:
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
+ elif model_args.mm_use_im_patch_token:
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = False
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
diff --git a/train/llava/model/make_delta.py b/train/llava/model/make_delta.py
new file mode 100644
index 0000000..7b3fbab
--- /dev/null
+++ b/train/llava/model/make_delta.py
@@ -0,0 +1,52 @@
+"""
+Usage:
+python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
+"""
+
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model.utils import auto_upgrade
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading target model")
+ auto_upgrade(target_model_path)
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Calculating delta")
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+ if name not in base.state_dict():
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data -= base.state_dict()[name]
+ else:
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+ bparam = base.state_dict()[name]
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
+
+ print("Saving delta")
+ if hub_repo_id:
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+ else:
+ kwargs = {}
+ target.save_pretrained(delta_path, **kwargs)
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+ parser.add_argument("--hub-repo-id", type=str, default=None)
+ args = parser.parse_args()
+
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
diff --git a/train/llava/model/multimodal_encoder/builder.py b/train/llava/model/multimodal_encoder/builder.py
new file mode 100644
index 0000000..d1cab60
--- /dev/null
+++ b/train/llava/model/multimodal_encoder/builder.py
@@ -0,0 +1,36 @@
+import os
+from .clip_encoder import CLIPVisionTower
+from .imagebind import ImageBindWrapper
+from .open_clip_encoder import OpenCLIPVisionTower
+from .hf_vision import HFVisionTower
+from .siglip_encoder import SigLipVisionTower
+from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
+
+# from .eva_clip.eva_clip_encoder import EvaClipVisionTower
+# from .dev_eva_clip.eva_vit import EvaViTWrapper
+
+
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the vision encoder named by the config.

    Dispatches on the tower name found in `mm_vision_tower` (falling back to
    `vision_tower`) and returns the matching wrapper module.

    Raises:
        ValueError: when the config names no tower, or names an unknown one.
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if vision_tower is None:
        # Without this guard the `startswith`/`in` checks below raise TypeError.
        raise ValueError("vision_tower_cfg specifies no vision tower (mm_vision_tower / vision_tower missing)")
    use_s2 = getattr(vision_tower_cfg, "s2", False)
    if vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        if use_s2:
            return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    if "siglip" in vision_tower:
        # NOTE(review): the configured name is overridden by a hard-coded local
        # checkpoint path — confirm this matches the repository layout.
        vision_tower = "../model/siglip2-so400m-patch14-384"
        return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
    if vision_tower.startswith("hf:"):
        return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    if vision_tower in ["imagebind_huge"]:
        return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
    if vision_tower.startswith("open_clip_hub"):
        return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower():
    #     return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]:
    #     return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f"Unknown vision tower: {vision_tower}")
diff --git a/train/llava/model/multimodal_encoder/clip_encoder.py b/train/llava/model/multimodal_encoder/clip_encoder.py
new file mode 100644
index 0000000..212b262
--- /dev/null
+++ b/train/llava/model/multimodal_encoder/clip_encoder.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+from llava.utils import rank0_print
+from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+
+try:
+ from s2wrapper import forward as multiscale_forward
+except:
+ pass
+
+
class CLIPVisionTower(nn.Module):
    """Wraps a Hugging Face `CLIPVisionModel` as a frozen vision encoder.

    Supports delayed loading: with `delay_load` set and no tunable vision
    parts requested, only the config is fetched so the tower can be
    materialized later via `load_model`.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        # Flipped to True once weights and processor are materialized.
        self.is_loaded = False

        self.vision_tower_name = vision_tower
        # Index into `hidden_states` to read features from (typically negative).
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            # Load eagerly despite delay_load, since the tower will be trained.
            # TODO: better detector is needed.
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # Delay-load path: keep only the config; served by the `config` property.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Fetch processor + weights, freeze all parameters, set `is_loaded`."""
        if self.is_loaded:
            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        # The tower is a frozen feature extractor by default.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick hidden states per `select_feature`, optionally dropping the CLS token."""
        select_feature_type = self.select_feature

        if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
            # Concatenate 4 evenly spaced layers (offset by select_layer) channel-wise.
            select_every_k_layer = len(image_forward_outs.hidden_states) // 4
            image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
            select_feature_type = select_feature_type.replace("slicefour_", "")
        elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
            # Concatenate this fixed mix of shallow and deep layers channel-wise.
            select_layers = [-2, -5, -8, -11, 6]
            image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
            select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if select_feature_type == "patch":
            # Drop the leading CLS token, keep patch tokens only.
            image_features = image_features[:, 1:]
        elif select_feature_type == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {select_feature_type}")
        return image_features

    def forward(self, images):
        """Encode a batched tensor, or a list of images one at a time.

        Features are cast back to the input dtype after selection.
        """
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        # Placeholder feature used when no image is present.
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Before load_model runs, only `cfg_only` exists (delay-load path).
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        # Layer-concatenating select modes multiply the feature width.
        _hidden_size = self.config.hidden_size
        if "slicefour" in self.select_feature:
            _hidden_size *= 4
        if "slice_m25811_f6" in self.select_feature:
            _hidden_size *= 5
        return _hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        _num_patches = (self.config.image_size // self.config.patch_size) ** 2
        if "cls_patch" in self.select_feature:
            # CLS token counts as one extra position.
            _num_patches += 1
        return _num_patches

    @property
    def image_size(self):
        return self.config.image_size
+
+
class CLIPVisionTowerS2(CLIPVisionTower):
    """CLIP tower evaluated at multiple scales via the S2 wrapper.

    Features from each scale are concatenated, so `hidden_size` grows by
    `len(s2_scales)`.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        # Scales must exist before super().__init__, which may call load_model
        # (and load_model reads self.s2_image_size).
        self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
        self.s2_scales = list(map(int, self.s2_scales.split(",")))
        self.s2_scales.sort()
        self.s2_split_size = self.s2_scales[0]
        self.s2_image_size = self.s2_scales[-1]

        super().__init__(vision_tower, args, delay_load)

        # change resize/crop size in preprocessing to the largest image size in s2_scale
        if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
            self.image_processor.size["shortest_edge"] = self.s2_image_size
            self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size

    def load_model(self, device_map=None):
        """Load via the parent, then resize the processor to the largest scale.

        Fixes a DRY defect: the original duplicated the parent's loading code
        verbatim; delegating to `super().load_model` keeps the paths in sync.
        """
        if self.is_loaded:
            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
            return

        super().load_model(device_map=device_map)

        self.image_processor.size["shortest_edge"] = self.s2_image_size
        self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size

    def forward_feature(self, images):
        # Single-scale forward used as the callback for multiscale_forward.
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        """Encode at every configured scale; accepts a tensor or list of images."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
                image_features.append(image_feature)
        else:
            image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)

        return image_features

    @property
    def hidden_size(self):
        # One feature block per scale, concatenated channel-wise.
        return self.config.hidden_size * len(self.s2_scales)
diff --git a/train/llava/model/multimodal_encoder/hf_vision.py b/train/llava/model/multimodal_encoder/hf_vision.py
new file mode 100644
index 0000000..a413208
--- /dev/null
+++ b/train/llava/model/multimodal_encoder/hf_vision.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+
+from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor
+from llava.utils import rank0_print
+
+
class HFVisionTower(nn.Module):
    """Generic Hugging Face `AutoModel` vision tower (name prefixed `hf:`).

    Falls back to a CLIP image processor when the checkpoint ships none.
    Weights are frozen and moved to CUDA at load time.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        # Strip only the leading "hf:" scheme marker.
        self.vision_tower_name = vision_tower.replace("hf:", "", 1)
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        """Materialize processor + weights; freezes params and sets `is_loaded`."""
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
        except Exception:
            # Checkpoint ships no processor: build a CLIP-style fallback.
            # (Fix: the bound exception variable `e` was never used.)
            if "448" in self.vision_tower_name:
                image_size = 448
                # use image processor with config
                self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size)
            else:
                self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        rank0_print(f"Loaded image processor: {self.image_processor}")
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
        # Cached as plain attributes (these shadow any inherited properties).
        self.device = self.vision_tower.device
        self.dtype = self.vision_tower.dtype
        self.config = self.vision_tower.config

        if hasattr(self.vision_tower, "vision_model"):
            # Some wrappers (e.g. CLIP-style dual encoders) nest the visual trunk.
            self.vision_tower = self.vision_tower.vision_model
        self.vision_tower.requires_grad_(False)
        # self.vision_tower.eval()
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick hidden states per `select_feature`, optionally dropping the CLS token."""
        select_feature_type = self.select_feature

        if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
            # Concatenate 4 evenly spaced layers (offset by select_layer) channel-wise.
            select_every_k_layer = len(image_forward_outs.hidden_states) // 4
            image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
            select_feature_type = select_feature_type.replace("slicefour_", "")
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if select_feature_type == "patch":
            # Drop the leading CLS token, keep patch tokens only.
            image_features = image_features[:, 1:]
        elif select_feature_type == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {select_feature_type}")
        return image_features

    def forward(self, images):
        """Encode a batched tensor, or a list of images one at a time."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    # @property
    # def dtype(self):
    #     return self.vision_tower.dtype

    # @property
    # def device(self):
    #     return self.vision_tower.device

    @property
    def hidden_size(self):
        """Feature width; multiplied by 4 for `slicefour_*` selection."""
        try:
            _hidden_size = self.config.hidden_size
        except AttributeError:
            # Fix: was a bare `except:` that swallowed every error here.
            # Composite configs nest the vision width under `vision_config`.
            _hidden_size = self.config.vision_config.hidden_size
        if "slicefour" in self.select_feature:
            _hidden_size *= 4
        return _hidden_size

    @property
    def num_patches(self):
        _num_patches = (self.config.image_size // self.config.patch_size) ** 2
        if "cls_patch" in self.select_feature:
            _num_patches += 1
        return _num_patches

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size
diff --git a/train/llava/model/multimodal_encoder/imagebind.py b/train/llava/model/multimodal_encoder/imagebind.py
new file mode 100644
index 0000000..8bbe71c
--- /dev/null
+++ b/train/llava/model/multimodal_encoder/imagebind.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+
+from transformers import CLIPImageProcessor
+
+try:
+ from imagebind.models import imagebind_model
+ from imagebind.models.imagebind_model import ModalityType
+ from imagebind.data import load_and_transform_audio_data
+except ImportError:
+ pass
+
+
class ImageBindWrapper(nn.Module):
    """Frozen ImageBind (huge) encoder returning vision or audio embeddings.

    NOTE(review): `build_vision_tower` constructs this wrapper with
    `args=vision_tower_cfg`, but this __init__ takes `select_layer` as its
    second parameter and has no `args` — confirm the intended call contract.
    """

    def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = select_layer
        self.select_feature = select_feature

        if not delay_load:
            self.load_model()

    def load_model(self):
        """Load the pretrained encoder, freeze it, and mark the tower loaded."""
        # Image preprocessing reuses CLIP's processor; the encoder is ImageBind.
        self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)
        for p in self.vision_tower.parameters():
            p.requires_grad = False
        self.vision_tower.eval()
        self.is_loaded = True

    def train(self, mode=True):
        """Record the mode but keep the frozen tower in eval.

        Deliberately does not call super().train(), so child modules never
        flip back to training mode.
        """
        self.training = mode

        if self.is_loaded:
            self.vision_tower.eval()

    @torch.no_grad()
    def forward(self, x):
        """Encode a batch of images, or a dict carrying audio under "audios".

        NOTE(review): a dict input whose "audios" is None falls through every
        branch and returns None — confirm that is intended.
        """
        if type(x) == dict:
            if x["audios"] is not None:
                inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()}
                embeddings = self.vision_tower(inputs)
                audio_embedding = embeddings[ModalityType.AUDIO]
                # Insert a singleton token dimension: (B, D) -> (B, 1, D).
                return audio_embedding.unsqueeze(1)
        else:
            inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
            embeddings = self.vision_tower(inputs)
            vision_embedding = embeddings[ModalityType.VISION]
            if vision_embedding.ndim == 2:
                # Pooled output: expose it as a single-token sequence.
                return vision_embedding.unsqueeze(1)
            if vision_embedding.shape[1] == 257:
                # 257-token sequence: drop the leading CLS token.
                return vision_embedding[:, 1:]
            raise ValueError(f"Unexpected shape: {vision_embedding.shape}")

    @property
    def dummy_feature(self):
        return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # ImageBind exposes no unified dtype; read it off a known parameter.
        return self.vision_tower.modality_preprocessors.vision.cls_token.dtype

    @property
    def device(self):
        return self.vision_tower.modality_preprocessors.vision.cls_token.device

    @property
    def hidden_size(self):
        # Fixed embedding width of imagebind_huge as used here.
        return 1024
diff --git a/train/llava/model/multimodal_encoder/open_clip_encoder.py b/train/llava/model/multimodal_encoder/open_clip_encoder.py
new file mode 100644
index 0000000..17a3277
--- /dev/null
+++ b/train/llava/model/multimodal_encoder/open_clip_encoder.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+from transformers import CLIPImageProcessor
+from llava.utils import rank0_print
+
+try:
+ import open_clip
+ import torchvision
+ from open_clip.transformer import _expand_token
+except ImportError:
+ print("OpenCLIP not installed")
+ open_clip = None
+
+HIDDEN_SIZE_DICT = {
+ "ViT-H-14-378-quickgelu": 1280,
+}
+
+
class OpenCLIPVisionTower(nn.Module):
    """Vision tower backed by an OpenCLIP checkpoint (name prefixed `open_clip_hub:`).

    NOTE(review): unlike CLIPVisionTower, the delay-load path stores no config
    fallback, and the `config` property always returns None — callers must use
    the explicit size/patch properties below.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.model_name = vision_tower.replace("open_clip_hub:", "")
        # Pretrain tag/checkpoint understood by open_clip.
        self.pretrained = args.vision_tower_pretrained
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()

    def load_model(self, device_map="auto"):
        """Create the OpenCLIP model and mirror its preprocessing into a CLIPImageProcessor.

        NOTE(review): `device_map` is accepted but unused — the model is
        created directly on "cuda".
        """
        rank0_print(f"Loading OpenCLIP model: {self.model_name}")
        rank0_print(f"Pretrained: {self.pretrained}")
        vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda")

        # Lift resize/normalize parameters out of the torchvision pipeline so a
        # HF-style processor can reproduce the same preprocessing.
        resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
        normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
        self.resize_transform_size = resize_transform.size  # 224 or 384
        self.patch_size = vision_tower.visual.conv1.kernel_size[0]  # 14 or 16

        self.image_processor = CLIPImageProcessor.from_pretrained(
            "openai/clip-vit-large-patch14",
            crop_size=resize_transform.size,
            size={"shortest_edge": resize_transform.size},
            image_mean=list(normalize_transform.mean),
            image_std=list(normalize_transform.std),
        )
        rank0_print(f"Loaded image processor: {self.image_processor}")
        # Keep only the visual trunk, frozen.
        self.vision_tower = vision_tower.visual
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Index the collected per-layer outputs and slice tokens per `select_feature`."""
        image_features = image_forward_outs[self.select_layer]
        if self.select_feature == "patch":
            # Drop the leading CLS token.
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "conv_flatten":
            image_features = image_features.flatten(2).transpose(1, 2)
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward_visual(self, x, output_hidden_states=False):
        """Run the visual trunk and return a list of per-block outputs.

        NOTE(review): `output_hidden_states` is accepted but never read.
        """
        if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"):
            # timm-backed towers expose intermediate layers directly.
            return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer))
        else:
            # Re-implements the OpenCLIP ViT forward, collecting every
            # residual-block output. Written with an explicit `self` parameter
            # and invoked below with the visual trunk bound manually.
            def forward_openclip(self, x: torch.Tensor):
                features = []
                x = self.conv1(x)  # shape = [*, width, grid, grid]
                x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
                x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

                # class embeddings and positional embeddings
                x = torch.cat(
                    [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
                    dim=1,
                )
                # shape = [*, grid ** 2 + 1, width]
                x = x + self.positional_embedding.to(x.dtype)

                x = self.patch_dropout(x)
                x = self.ln_pre(x)

                x = x.permute(1, 0, 2)  # NLD -> LND
                for r in self.transformer.resblocks:
                    x = r(x, attn_mask=None)
                    features.append(x)
                return features

            return forward_openclip(self.vision_tower, x)

    def forward(self, images):
        """Encode a batched tensor, or a list of images one at a time."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Probe whichever parameter exists for this backbone flavor.
        if hasattr(self.vision_tower, "conv1"):
            return self.vision_tower.conv1.weight.dtype
        if hasattr(self.vision_tower, "trunk"):
            return self.vision_tower.trunk.patch_embed.proj.weight.dtype
        raise NotImplementedError

    @property
    def device(self):
        if hasattr(self.vision_tower, "conv1"):
            return self.vision_tower.conv1.weight.device
        if hasattr(self.vision_tower, "trunk"):
            return self.vision_tower.trunk.patch_embed.proj.weight.device
        raise NotImplementedError

    @property
    def config(self):
        # OpenCLIP models carry no HF-style config object.
        return None

    @property
    def hidden_size(self):
        # Only models listed in HIDDEN_SIZE_DICT are supported here.
        if self.model_name in HIDDEN_SIZE_DICT:
            return HIDDEN_SIZE_DICT[self.model_name]
        else:
            raise NotImplementedError

    @property
    def num_patches(self):
        image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0]
        _num_patches = (image_size // self.patch_size) ** 2
        if "cls_patch" in self.select_feature:
            _num_patches += 1
        return _num_patches

    @property
    def image_size(self):
        return self.resize_transform_size

    @property
    def num_patches_per_side(self):
        return self.resize_transform_size // self.patch_size
diff --git a/train/llava/model/multimodal_encoder/siglip_encoder.py b/train/llava/model/multimodal_encoder/siglip_encoder.py
new file mode 100644
index 0000000..f1e101a
--- /dev/null
+++ b/train/llava/model/multimodal_encoder/siglip_encoder.py
@@ -0,0 +1,620 @@
+"""
+# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
+"""
+
+from typing import Optional, Tuple, Union, Dict
+from dataclasses import dataclass
+from functools import partial, reduce
+from PIL import Image
+import torch
+import torch.utils.checkpoint
+from torch import nn
+import os
+from transformers.image_processing_utils import BatchFeature, get_size_dict
+from transformers.image_transforms import (
+ convert_to_rgb,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from transformers.image_utils import (
+ ChannelDimension,
+ PILImageResampling,
+ to_numpy_array,
+)
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from transformers import PretrainedConfig
+from transformers.utils import ModelOutput
+from llava.utils import rank0_print
+
+
class SigLipImageProcessor:
    """Minimal SigLip preprocessor: RGB convert, resize, rescale, normalize.

    Produces channel-first arrays packed into a `BatchFeature` under
    "pixel_values".
    """

    def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size

    def preprocess(self, images, return_tensors):
        """Apply the transform pipeline to one PIL image or a sequence of frames.

        The transform order is significant: RGB/numpy conversion must precede
        resize, rescale, and normalize.
        """
        if isinstance(images, Image.Image):
            images = [images]
        else:
            # to adapt video data
            images = [to_numpy_array(image) for image in images]
        assert isinstance(images, list)

        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]

        # Left fold: apply each transform to every image in turn.
        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
        data = {"pixel_values": images}

        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
class SigLipVisionConfig(PretrainedConfig):
    """Configuration for the SigLip vision encoder.

    Defaults correspond to the so400m/patch14 model at 384px resolution.
    """

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=1152,
        image_mean=(0.5, 0.5, 0.5),
        intermediate_size=4304,
        num_hidden_layers=27,
        num_attention_heads=16,
        num_channels=3,
        image_size=384,
        patch_size=14,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Input geometry.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.image_mean = image_mean
        # Regularization / numerics.
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the vision sub-config, descending into a full SigLip config if given one."""
        cls._set_token_in_kwargs(kwargs)

        cfg, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A full "siglip" checkpoint nests the vision settings one level down.
        if cfg.get("model_type") == "siglip":
            cfg = cfg["vision_config"]

        if "model_type" in cfg and hasattr(cls, "model_type") and cfg["model_type"] != cls.model_type:
            print(f"You are using a model of type {cfg['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.")

        return cls.from_dict(cfg, **kwargs)
+
+
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
class SigLipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # All fields default to None; only what the model computes gets populated.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
class SigLipVisionEmbeddings(nn.Module):
    """Patchify pixel values and add learned position embeddings (no CLS token)."""

    def __init__(self, config: SigLipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patch projection; "valid" padding keeps the grid exact.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Constant index row [0, num_positions); excluded from the state dict.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Return patch embeddings of shape (batch, num_patches, embed_dim)."""
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
        # (B, width, grid, grid) -> (B, grid*grid, width)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
+
+
class SigLipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
        # 1/sqrt(head_dim) scaling applied to the query-key product.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Input shape: Batch x Time x Channel

        Returns (attn_output, attn_weights). NOTE(review): `output_attentions`
        is accepted but never read — weights are always computed and returned.
        (Return annotation corrected: the original declared a 3-tuple.)
        """

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (B, L, E) -> (B, num_heads, L, head_dim)
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        # Scaled dot-product scores: (B, heads, q_len, k_v_seq_len).
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}")

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
            # Additive mask (large negative values suppress positions).
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
            raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")

        # (B, heads, L, head_dim) -> (B, L, E), then final projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
+
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
class SigLipMLP(nn.Module):
    """Position-wise feed-forward block: fc1 -> activation -> fc2."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Nonlinearity is chosen by name from the config (e.g. gelu_pytorch_tanh).
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand to the intermediate width, apply the configured nonlinearity,
        # then project back down to the model width.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
+
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
class SigLipEncoderLayer(nn.Module):
    """Pre-norm transformer block: LN -> attention -> residual, LN -> MLP -> residual."""

    def __init__(self, config: SigLipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SigLipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SigLipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # Attention sub-block with pre-norm and residual connection.
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # MLP sub-block with pre-norm and residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
+
+
class SigLipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Config class `from_pretrained` uses to instantiate this model family.
    config_class = SigLipVisionConfig
    # Prefix stripped/added when mapping checkpoint keys onto this model.
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Intentionally a no-op: weights are expected to come from a pretrained
        # checkpoint, so no custom initialization scheme is applied here.
        pass
+
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
class SigLipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SigLipEncoderLayer`].

    Args:
        config: SigLipVisionConfig
    """

    def __init__(self, config: SigLipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""Run `inputs_embeds` through every encoder layer.

        Args:
            inputs_embeds: embedded inputs of shape `(batch_size, sequence_length, hidden_size)`.
            attention_mask: optional mask over padding tokens (1 = keep, 0 = masked).
            output_attentions: whether to collect per-layer attention weights.
            output_hidden_states: whether to collect per-layer hidden states.
            return_dict: whether to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Fall back to the config defaults for any unspecified flag.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        collected_states = () if output_hidden_states else None
        collected_attns = () if output_attentions else None

        current = inputs_embeds
        for layer in self.layers:
            if output_hidden_states:
                collected_states = collected_states + (current,)
            if self.gradient_checkpointing and self.training:
                # Re-compute activations in the backward pass to save memory.
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    current,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer(
                    current,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            current = layer_outputs[0]
            if output_attentions:
                collected_attns = collected_attns + (layer_outputs[1],)

        if output_hidden_states:
            collected_states = collected_states + (current,)

        if not return_dict:
            return tuple(v for v in (current, collected_states, collected_attns) if v is not None)
        return BaseModelOutput(last_hidden_state=current, hidden_states=collected_states, attentions=collected_attns)
+
+
class SigLipVisionTransformer(nn.Module):
    """SigLip vision backbone: patch embedding -> transformer encoder -> LayerNorm -> pooling head."""

    def __init__(self, config: SigLipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SigLipVisionEmbeddings(config)
        self.encoder = SigLipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.head = SigLipMultiheadAttentionPoolingHead(config)

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""Encode `pixel_values`; returns `(last_hidden_state, pooled_output, *extras)` as a
        tuple, or a `BaseModelOutputWithPooling` when `return_dict` is true.
        """
        # Fall back to the config defaults for any unspecified flag.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        patch_embeds = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            inputs_embeds=patch_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.post_layernorm(encoder_outputs[0])
        pooled_output = self.head(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
class SigLipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SigLipVisionConfig):
        super().__init__()

        # Learned query ("probe") that attends over the token sequence to pool it.
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SigLipMLP(config)

    def forward(self, hidden_state):
        """Pool a (batch, seq, dim) sequence into a single (batch, dim) vector."""
        queries = self.probe.repeat(hidden_state.shape[0], 1, 1)

        # Probe attends over the full sequence; keys and values are the sequence itself.
        pooled, _ = self.attention(queries, hidden_state, hidden_state)

        # Pre-norm residual MLP, then drop the singleton query dimension.
        pooled = pooled + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]
+
+
class SigLipVisionModel(SigLipPreTrainedModel):
    # Config class used by `from_pretrained`.
    config_class = SigLipVisionConfig
    # Name of the tensor argument treated as the primary input.
    main_input_name = "pixel_values"
    # Modules kept intact (not sharded) under model-parallel device maps.
    _no_split_modules = ["SigLipEncoderLayer"]

    def __init__(self, config: SigLipVisionConfig):
        super().__init__(config)

        # Entire computation is delegated to the inner vision transformer.
        self.vision_model = SigLipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding layer that maps pixels to token embeddings."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SigLipVisionModel

        >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Pure delegation: all flags are forwarded to the inner transformer.
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
+
+
class SigLipVisionTower(nn.Module):
    """Frozen SigLip vision tower that returns per-patch hidden states.

    After loading, the last encoder layer is deleted and the pooling head is
    replaced with an identity, so `forward` yields the final remaining hidden
    state (expected to contain 729 patch tokens, enforced by assertion).
    """

    def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.config = SigLipVisionConfig()

        self.vision_tower_name = vision_tower

        self.image_processor = SigLipImageProcessor()

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # Delay loading: expose only the config until `load_model` is called.
            self.cfg_only = self.config

    def load_model(self, device_map=None):
        """Load pretrained weights, strip the last encoder layer and the head, and freeze the tower."""
        if self.is_loaded:
            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
            return

        self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)

        # Only patch features are needed: drop the last transformer layer and
        # neutralize the attention-pooling head.
        del self.vision_tower.vision_model.encoder.layers[-1:]
        self.vision_tower.vision_model.head = nn.Identity()
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def forward(self, images):
        """Return final hidden states: a list of (1, 729, C) tensors for a list of
        images, or a single (B, 729, C) tensor for a batched input tensor."""
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                # Bug fix: validate the feature just computed, not the accumulator
                # list (the old `image_features.shape` raised AttributeError).
                assert image_feature.shape[-2] == 729
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
            assert image_features.shape[-2] == 729

        return image_features

    @property
    def dummy_feature(self):
        """Zero tensor with the tower's hidden size — placeholder feature."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # dtype of the first parameter; None if the tower has no parameters.
        for p in self.vision_tower.parameters():
            return p.dtype

    @property
    def device(self):
        # device of the first parameter; None if the tower has no parameters.
        for p in self.vision_tower.parameters():
            return p.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size
        # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]

    @property
    def image_size(self):
        return self.config.image_size
diff --git a/train/llava/model/multimodal_projector/builder.py b/train/llava/model/multimodal_projector/builder.py
new file mode 100644
index 0000000..3122a0c
--- /dev/null
+++ b/train/llava/model/multimodal_projector/builder.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import re
+
+from .pooler_projector import PoolerProjector
+
+
class IdentityMap(nn.Module):
    """No-op projector: returns its input unchanged."""

    def forward(self, x, *args, **kwargs):
        """Pass `x` through untouched; extra arguments are accepted and ignored."""
        return x

    @property
    def config(self):
        """Projector-type tag used when the configuration is serialized."""
        return {"mm_projector_type": "identity"}
+
+
class SimpleResBlock(nn.Module):
    """Pre-norm residual MLP block: y = LN(x) + proj(LN(x))."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """Normalize, then add the projection of the normalized input as a residual."""
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
+
+
def _mlp_layers(config, depth):
    """Build [Linear(mm_hidden->hidden)] followed by (depth-1) x [GELU, Linear(hidden->hidden)]."""
    layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
    for _ in range(1, depth):
        layers.append(nn.GELU())
        layers.append(nn.Linear(config.hidden_size, config.hidden_size))
    return layers


def build_vision_projector(config, delay_load=False, **kwargs):
    """Construct the multimodal projector selected by `config.mm_projector_type`.

    Supported types: "linear", "pooler", "identity", "mlp{N}x_gelu", and
    "mlp{N}x_res{M}x_gelu". Raises ValueError for anything else.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    if projector_type == "pooler":
        return PoolerProjector(config, kwargs["vision_cfg"])

    if projector_type == "identity":
        return IdentityMap()

    mlp_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_match:
        return nn.Sequential(*_mlp_layers(config, int(mlp_match.group(1))))

    resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
    if resnet_match:
        layers = _mlp_layers(config, int(resnet_match.group(1)))
        # Append the requested number of residual blocks after the MLP stack.
        layers.extend(SimpleResBlock(config.hidden_size) for _ in range(int(resnet_match.group(2))))
        return nn.Sequential(*layers)

    raise ValueError(f"Unknown projector type: {projector_type}")
diff --git a/train/llava/model/multimodal_projector/pooler_projector.py b/train/llava/model/multimodal_projector/pooler_projector.py
new file mode 100644
index 0000000..ce5a2e0
--- /dev/null
+++ b/train/llava/model/multimodal_projector/pooler_projector.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+import math
+
+from transformers.models.clip.modeling_clip import CLIPVisionModel
+
+
+class PoolerProjector(nn.Module):
+ def __init__(self, config, vision_cfg):
+ super().__init__()
+ self._config = config
+ self.hw = vision_cfg.image_size // vision_cfg.patch_size
+
+ self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
+
+ self.proj = nn.Sequential(
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+
+ def forward(self, x, *args, **kwargs):
+ height = width = self.hw
+ assert height * width == x.shape[1]
+ x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
+ x = self.conv_pool(x)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": "pooler"}
diff --git a/train/llava/model/multimodal_resampler/builder.py b/train/llava/model/multimodal_resampler/builder.py
new file mode 100644
index 0000000..7a4b207
--- /dev/null
+++ b/train/llava/model/multimodal_resampler/builder.py
@@ -0,0 +1,34 @@
+import torch
+
+from .masked_drop import MaskedDrop
+from .spatial_pool import SpatialPool
+from .perceiver import PerceiverResampler
+from .qformer import Qformer
+
+
class IdentityMap(torch.nn.Module):
    """Pass-through resampler used when no resampling is configured."""

    def forward(self, x, *args, **kwargs):
        """Return `x` unchanged; any extra positional/keyword arguments are ignored."""
        return x

    @property
    def config(self):
        """Resampler-type tag; None means "no resampler"."""
        return {"mm_resampler_type": None}
+
+
def build_vision_resampler(model_args, delay_load=False, **kwargs):
    """Instantiate the resampler named by `model_args.mm_resampler_type`.

    A missing attribute or `None` yields a pass-through `IdentityMap`; any
    unrecognized name raises ValueError.
    """
    resampler_type = getattr(model_args, "mm_resampler_type", None)
    if resampler_type is None:
        return IdentityMap()
    if resampler_type == "masked_drop":
        return MaskedDrop(model_args)
    if resampler_type == "spatial_pool":
        return SpatialPool(model_args, **kwargs)
    if resampler_type == "perceiver":
        return PerceiverResampler(model_args, **kwargs)
    if resampler_type == "qformer":
        return Qformer(model_args, **kwargs)
    raise ValueError(f"Unknown resampler type: {resampler_type}")
diff --git a/train/llava/model/multimodal_resampler/masked_drop.py b/train/llava/model/multimodal_resampler/masked_drop.py
new file mode 100644
index 0000000..03f0bf0
--- /dev/null
+++ b/train/llava/model/multimodal_resampler/masked_drop.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+
+import random
+
+
class MaskedDrop(nn.Module):
    """Resampler that randomly drops a subset of image tokens during training."""

    def __init__(self, model_args):
        super().__init__()

        self.mode = model_args.mm_mask_drop_mode
        self.skip_percentage = model_args.mm_mask_drop_skip_percentage
        self.ratio = model_args.mm_mask_drop_ratio
        self.ratio_upper = model_args.mm_mask_drop_ratio_upper
        self.ratio_lower = model_args.mm_mask_drop_ratio_lower

    def forward(self, image_features, *args, **kwargs):
        """Apply token dropping per image; a no-op at eval time, and skipped with
        probability `skip_percentage` during training."""
        if not self.training:
            return image_features

        if self.skip_percentage > random.random():
            return image_features

        masked_features = []

        for feature in image_features:
            token_count = feature.shape[0]
            if self.mode == "fixed":
                keep_count = int(token_count * self.ratio)
                # random_masking expects a batch dim: unsqueeze, mask, then drop it again.
                masked_features.append(self.random_masking(feature.unsqueeze(0), keep_count)[0][0])
            elif self.mode == "range":
                keep_count = int(token_count * random.uniform(self.ratio_lower, self.ratio_upper))
                masked_features.append(self.random_masking(feature.unsqueeze(0), keep_count)[0])
            elif self.mode == "cls_only":
                masked_features.append(feature[0:1])
            else:
                raise ValueError(f"Unexpected masked drop mode: {self.mode}")

        # "range" produces per-image lengths, so its results cannot be stacked.
        if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
            masked_features = torch.stack(masked_features, dim=0)

        return masked_features

    @property
    def config(self):
        """Serializable snapshot of this resampler's settings."""
        return {
            "mm_resampler_type": "masked_drop",
            "mm_mask_drop_mode": self.mode,
            "mm_mask_drop_skip_percentage": self.skip_percentage,
            "mm_mask_drop_ratio": self.ratio,
            "mm_mask_drop_ratio_upper": self.ratio_upper,
            "mm_mask_drop_ratio_lower": self.ratio_lower,
        }

    def random_masking(self, x, len_keep):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        batch, length, dim = x.shape

        # One random score per token; tokens with the smallest scores are kept.
        noise = torch.rand(batch, length, device=x.device)

        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Gather the kept subset in shuffled order.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

        # Binary mask in the original token order: 0 = kept, 1 = removed.
        mask = torch.ones([batch, length], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
diff --git a/train/llava/model/multimodal_resampler/perceiver.py b/train/llava/model/multimodal_resampler/perceiver.py
new file mode 100644
index 0000000..d6b17a5
--- /dev/null
+++ b/train/llava/model/multimodal_resampler/perceiver.py
@@ -0,0 +1,155 @@
+"""
+Taken from https://github.com/lucidrains/flamingo-pytorch
+"""
+
+import torch
+from einops import rearrange, repeat
+
try:
    from einops_exts import rearrange_many
except ImportError:
    # einops_exts is optional; `rearrange_many` is only needed by PerceiverAttention.
    # Catch ImportError specifically instead of a bare `except:` so real errors
    # (KeyboardInterrupt, SystemExit, syntax errors in the package) still surface.
    pass
+
+from torch import einsum, nn
+
+
def exists(val):
    """Return True when `val` is not None."""
    return val is not None
+
+
def FeedForward(dim, mult=4):
    """Pre-norm MLP used in the Perceiver: LayerNorm -> Linear(dim*mult) -> GELU -> Linear(dim).

    Both linear layers are bias-free, matching the flamingo-pytorch reference.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
+
+
class PerceiverAttention(nn.Module):
    """Cross-attention from latent queries to media features (Flamingo-style)."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # Standard 1/sqrt(d_head) attention scaling.
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        # Bias-free projections, per the flamingo-pytorch reference.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latent (torch.Tensor): latent features
                shape (b, T, n2, D)
        Returns:
            torch.Tensor of shape (b, T, n2, D): latents updated by attending over x.
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)
        # Keys/values cover both the media tokens and the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        # Split heads: (b, t, n, h*d) -> (b, h, t, n, d).
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        # Subtract the (detached) row max for numerical stability before softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        # Merge heads back: (b, h, t, n, d) -> (b, t, n, h*d).
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        return self.to_out(out)
+
+
class PerceiverResamplerModule(nn.Module):
    """Stack of Perceiver attention + feed-forward layers that compresses a variable
    number of media tokens into a fixed set of `num_latents` latent tokens."""

    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        super().__init__()
        # Learned latent queries, shared across the batch and time dimensions.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        # Optional learned embeddings, only created when the corresponding maximum is given.
        self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
        self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        # ff_mult <= 0 disables the feed-forward sub-layer.
                        FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
        Returns:
            shape (b, T, n, D) where n is the number of latents
        """
        b, T, F, v = x.shape[:4]

        # frame and media time embeddings
        if exists(self.frame_embs):
            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
            x = x + frame_embs
        x = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:T]

        # blocks
        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            # Residual attention followed by a residual feed-forward.
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        return self.norm(latents)
+
+
class PerceiverResampler(nn.Module):
    """Thin wrapper that configures `PerceiverResamplerModule` from `model_args`
    and adapts single-image features to the module's 5-D input layout."""

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.depth = model_args.mm_perceiver_depth
        self.num_latents = model_args.mm_perceiver_latents
        self.ff_mult = model_args.mm_perceiver_ff_mult
        self.pretrained = model_args.mm_perceiver_pretrained

        self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)

        if self.pretrained is not None:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoints (consider weights_only=True on newer torch).
            self.load_state_dict(torch.load(self.pretrained))

    def forward(self, image_features, *args, **kwargs):
        # Insert singleton time/frame dims to match the module's (b, T, F, v, D)
        # input, then squeeze the time dim back out of the result.
        # assumes image_features is (b, v, D) — TODO confirm with callers.
        return self.perceiver(image_features[:, None, None]).squeeze(1)

    @property
    def config(self):
        """Serializable snapshot of this resampler's settings."""
        return {
            "mm_resampler_type": "perceiver",
            "mm_perceiver_depth": self.depth,
            "mm_perceiver_latents": self.num_latents,
            "mm_perceiver_ff_mult": self.ff_mult,
            "mm_perceiver_pretrained": self.pretrained,
        }
diff --git a/train/llava/model/multimodal_resampler/qformer.py b/train/llava/model/multimodal_resampler/qformer.py
new file mode 100644
index 0000000..b86754c
--- /dev/null
+++ b/train/llava/model/multimodal_resampler/qformer.py
@@ -0,0 +1,1160 @@
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+logger = logging.get_logger(__name__)
+
+
def disabled_train(self, mode=True):
    """Replacement for `nn.Module.train` that ignores `mode`, pinning the module
    in its current train/eval state. Returns the module unchanged so calls can
    still be chained like the original method."""
    return self
+
+
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed `input_ids` (plus absolute position embeddings), optionally
        prepending `query_embeds`; the result is LayerNorm-ed and dropout-ed.

        NOTE(review): if both `input_ids` and `query_embeds` are None,
        `embeddings` is None and LayerNorm will fail — callers must supply at
        least one of them.
        """
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            # Positions continue after any cached prefix of length `past_key_values_length`.
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                # Query tokens are prepended in front of the text embeddings.
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
+
+
class BertSelfAttention(nn.Module):
    """Self-/cross-attention for the Q-Former's BERT layers.

    When constructed with `is_cross_attention=True`, keys and values are
    projected from `config.encoder_width`-sized encoder states instead of
    `hidden_size`-sized decoder states.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Cross-attention: K/V come from the encoder, which may have a different width.
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True (and cross-attending), attention maps and their gradients are
        # stashed on the module via the save_*/get_* helpers below.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Hook target: stores gradients flowing through the attention probabilities.
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Compute (cross-)attention; returns (context, [probs,] present_key_value)."""

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Append the new K/V to the cached ones along the sequence axis.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Cache returned to the caller for the next decoding step.
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Relative-position scores added to the content scores.
            # NOTE(review): distances are derived from `hidden_states` length only —
            # presumably not meant to combine with `past_key_value`; confirm.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
+
+
class BertSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual LayerNorm
    (post-norm, as in the original BERT)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Returns LayerNorm(dropout(dense(hidden_states)) + input_tensor)."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
+
+
class BertAttention(nn.Module):
    """Attention block: BertSelfAttention followed by the residual output
    projection, with support for pruning attention heads."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Removes the given attention heads in place; no-op for empty input."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Shrink the projections so the pruned head slices disappear.
        for name in ("query", "key", "value"):
            setattr(self.self, name, prune_linear_layer(getattr(self.self, name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping so repeated pruning calls compose correctly.
        self.self.num_attention_heads -= len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Runs self/cross attention and the output projection; returns
        (attention_output, [attention maps...], past_key_value)."""
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        # Prepend the projected output; pass attention maps / cache through unchanged.
        return (attention_output,) + self_outputs[1:]
+
+
class BertIntermediate(nn.Module):
    """First half of the feed-forward block: up-projection followed by the
    configured activation (an ACT2FN name or a callable)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # config.hidden_act is either the string key of an activation in
        # ACT2FN or already a callable.
        self.intermediate_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
class BertOutput(nn.Module):
    """Second half of the feed-forward block: down-projection from the
    intermediate size, dropout, then a residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Returns LayerNorm(dropout(dense(hidden_states)) + input_tensor)."""
        down = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(down + input_tensor)
+
+
class BertLayer(nn.Module):
    # One Q-Former transformer layer: self-attention over the joint
    # (query tokens + text) sequence, cross-attention from the query tokens
    # into the encoder output on every `cross_attention_freq`-th layer, and
    # *separate* feed-forward branches for the query and text positions.
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # chunked feed-forward is applied along the sequence axis
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        # Cross-attention is only inserted every `cross_attention_freq` layers.
        if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        # Feed-forward branch for text positions (Qformer.build_Qformer may set these to None).
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Dedicated feed-forward branch for the learned query tokens.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Attention maps (when requested) sit between the output and the cache entry.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # The first `query_length` positions are the learned query tokens.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # Query tokens go through their dedicated feed-forward branch.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            # Remaining (text) positions use the standard branch; the two
            # results are re-joined along the sequence axis.
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            # No query tokens: standard BERT layer behaviour.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Standard FFN with residual, used for text positions.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # FFN with residual for the query tokens (separate parameters).
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
+
+
class BertEncoder(nn.Module):
    """Stack of ``num_hidden_layers`` BertLayer modules with optional gradient
    checkpointing, key/value caching and attention-map collection."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Runs all layers over ``hidden_states``.

        Returns a :obj:`BaseModelOutputWithPastAndCrossAttentions` when
        ``return_dict`` is True, otherwise a tuple of the non-None fields in
        the order (hidden_states, cache, all_hidden_states, self-attentions,
        cross-attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    # Fix: logger.warn is a deprecated alias of logger.warning.
                    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                def create_custom_forward(module):
                    # Binds the non-tensor arguments so checkpoint() only
                    # receives tensor inputs.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # Fix: only layers that actually computed cross-attention
                # return a cross-attention map (a 4-tuple: output, self-attn,
                # cross-attn, cache). The original indexed layer_outputs[2]
                # unconditionally, which raised TypeError (None + tuple) when
                # add_cross_attention was False and mistook the key/value
                # cache for an attention map on self-attention-only layers.
                if all_cross_attentions is not None and len(layer_outputs) > 3:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
class BertPooler(nn.Module):
    """Pools a sequence by passing its first token's hidden state through a
    dense layer with tanh activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only the first token's representation is used for pooling.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
+
+
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied to hidden states
    before the LM decoder head."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # config.hidden_act is an ACT2FN key or already a callable.
        self.transform_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
+
+
class BertLMPredictionHead(nn.Module):
    """Transforms hidden states and decodes them to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder weights are tied to the input embeddings elsewhere;
        # only the per-token output bias is owned by this module.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Link the two so the bias is resized together with the decoder by
        # `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
+
+
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper holding the LM prediction head (mirrors the HF BERT
    head layout so checkpoints load under the `predictions.` prefix)."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
+
+
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base that wires up the BERT config class, the state-dict prefix
    and the standard weight-initialization scheme for the models in this file.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initializes one submodule following the original BERT recipe."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal init; the TF original uses truncated_normal
            # (cf. https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm starts out as the identity transform.
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        # Linear biases are zeroed regardless of the weight init above.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
+
+
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model is called with the :obj:`is_decoder` forward
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        # The pooler is optional; the Q-Former usage builds this model without it.
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        # May be None: Qformer.build_Qformer strips the word embeddings.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                Whether to combine a causal (uni-directional) mask with the padding mask.
            has_query (:obj:`bool`, `optional`):
                When decoding with prepended query tokens, build a UniLM-style mask in which the
                query prefix is not restricted by the causal mask.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                # Lower-triangular mask: position i may attend to positions <= i.
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                # attention_mask may cover extra prefix positions (query tokens
                # and/or cached past keys) that the causal mask must be padded to.
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        # Prefix rows: query positions do not attend to the text.
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    # Prefix columns: every position may attend to the prefix.
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"

        # past_key_values_length
        # The cache also holds the query-token positions, which do not count
        # as past *text* positions, hence the query_length subtraction.
        past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # NOTE(review): this branch reads input_ids.shape, so the decoder
            # path appears to require input_ids is not None — verify callers.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # encoder_hidden_states may be a list of per-source feature tensors.
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                # Default: attend to every encoder position.
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
+
+
class BertLMHeadModel(BertPreTrainedModel):
    # BERT decoder with a causal LM head. Supports Q-Former query embeddings
    # prepended to the text sequence and generation-time key/value caching.

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        # LM decoder projection (bias is linked to cls.predictions.bias).
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # Training with labels never uses the generation cache.
            use_cache = False
        if past_key_values is not None:
            # When resuming from a cache the query tokens are already encoded
            # inside past_key_values, so they must not be fed in again.
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Logits are only produced for the text positions after the query tokens.
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            # Drop the last position: it has no next-token target.
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Per-sample loss: sum the token losses within each sequence.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        # The query tokens are always attendable; prepend their mask.
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        # Reorders the cached key/values along the batch axis to follow beam search.
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
+
+
class BertForMaskedLM(BertPreTrainedModel):
    """BERT with a masked-language-modeling head on top (Q-Former variant
    that may additionally consume learned query embeddings)."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        # LM decoder projection (bias linked to cls.predictions.bias).
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Fix: the original only assigned sequence_output inside the
        # `query_embeds is not None` branch, so calling this head without
        # query embeddings raised NameError at self.cls(sequence_output).
        # Mirror BertLMHeadModel: start from the full sequence output and
        # strip the query positions when they are present.
        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
class Qformer(nn.Module):
    # Querying Transformer (BLIP-2 style): a BertLMHeadModel whose learned
    # query tokens cross-attend into vision-tower features; only the query
    # hidden states are returned as resampled visual features.
    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.depth = model_args.mm_qformer_depth
        self.num_latents = model_args.mm_qformer_latents
        self.pretrained = model_args.mm_qformer_pretrained

        # NOTE(review): build_Qformer's second parameter is named
        # `cross_attention_freq` but receives self.depth here, and the third
        # (`num_query_token`) receives self.num_latents — confirm this
        # argument mapping is intended.
        self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)

        if self.pretrained is not None:
            pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"]
            # Drop the T5 projection weights from the checkpoint; the rest is
            # expected to match exactly (strict load).
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")}
            self.load_state_dict(pretrained_dict)

    def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
        # Configure a BERT with cross-attention into the vision stream.
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        # Learned query embeddings, shared across the batch via expand().
        query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        # Strip components unused in feature-extraction mode: the LM head,
        # the text embeddings, and the per-layer text feed-forward branches.
        Qformer.cls = None
        Qformer.bert.embeddings.word_embeddings = None
        Qformer.bert.embeddings.position_embeddings = None
        for layer in Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        return Qformer, query_tokens, nn.LayerNorm(vision_width)

    def forward(self, image_features, *args, **kwargs):
        # image_features: vision-tower output; assumed (batch, num_patches,
        # vision_width) — TODO confirm against the caller.
        x = self.ln_vision(image_features)
        # Attend to every visual position.
        image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)

        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        # Only the query-token states are returned (no text was fed in).
        return query_output.last_hidden_state

    @property
    def hidden_size(self):
        # Matches the bert-base-uncased hidden size used in build_Qformer.
        return 768

    @property
    def config(self):
        # Serializable resampler configuration (persisted with model args).
        return {
            "mm_resampler_type": "qformer",
            "mm_qformer_depth": self.depth,
            "mm_qformer_latents": self.num_latents,
            "mm_qformer_pretrained": self.pretrained,
        }
diff --git a/train/llava/model/multimodal_resampler/spatial_pool.py b/train/llava/model/multimodal_resampler/spatial_pool.py
new file mode 100644
index 0000000..4bdbe3a
--- /dev/null
+++ b/train/llava/model/multimodal_resampler/spatial_pool.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import math
+
+
+class SpatialPool(nn.Module):
+ def __init__(self, model_args, vision_tower):
+ super().__init__()
+
+ self.mode = model_args.mm_spatial_pool_mode
+ self.stride = model_args.mm_spatial_pool_stride
+ self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)
+
+ if self.mode == "average":
+ self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
+ elif self.mode == "max":
+ self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
+ elif self.mode == "conv":
+ self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
+ else:
+ raise ValueError(f"Unknown pooling mode: {self.pool}.")
+
+ def forward(self, image_features, images, *args, **kwargs):
+ ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
+ ori_H = int(ori_W * images.shape[2] // images.shape[3])
+
+ B, _, F = image_features.shape
+
+ image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2)
+ image_features_spatial_pool = self.pool(image_features_spatial)
+
+ return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
+
+ @property
+ def config(self):
+ return {
+ "mm_resampler_type": "spatial_pool",
+ "mm_spatial_pool_stride": self.stride,
+ "mm_spatial_pool_mode": self.mode,
+ "mm_spatial_pool_out_channels": self.out_channels,
+ }
+
+ @property
+ def hidden_size(self):
+ return self.out_channels
diff --git a/train/llava/model/simple_heading_mlp.py b/train/llava/model/simple_heading_mlp.py
new file mode 100644
index 0000000..d977db6
--- /dev/null
+++ b/train/llava/model/simple_heading_mlp.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+
+class TrajectoryHeadingSimpleMLP(nn.Module):
+ def __init__(self, hidden_layer_dim=512, num_poses=8):
+ super().__init__()
+ self.num_poses = num_poses
+ self.mlp = nn.Sequential(
+ nn.Linear(8 * 2, hidden_layer_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_layer_dim, hidden_layer_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_layer_dim, hidden_layer_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_layer_dim, num_poses * 3),
+ )
+
+ def forward(self, traj_xy):
+ x = traj_xy.reshape(traj_xy.shape[0], -1)
+ out = self.mlp(x)
+ out = out.view(-1, self.num_poses, 3)
+ heading_pred = torch.tanh(out[:, :, 2:]) * 3.14159
+ return heading_pred
+
+
+def load_heading_model(
+ ckpt_path: str, device: str = "cpu", hidden_layer_dim=512, num_poses=8
+) -> nn.Module:
+ model = TrajectoryHeadingSimpleMLP(
+ hidden_layer_dim=hidden_layer_dim, num_poses=num_poses
+ )
+ state_dict = torch.load(ckpt_path, map_location=device)
+ model.load_state_dict(state_dict)
+ model.to(device)
+ model.eval()
+ return model
diff --git a/train/llava/model/utils.py b/train/llava/model/utils.py
new file mode 100644
index 0000000..10652a5
--- /dev/null
+++ b/train/llava/model/utils.py
@@ -0,0 +1,20 @@
+from transformers import AutoConfig
+
+
+def auto_upgrade(config):
+ cfg = AutoConfig.from_pretrained(config)
+ if "llava" in config and "llava" not in cfg.model_type:
+ assert cfg.model_type == "llama"
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+ if confirm.lower() in ["y", "yes"]:
+ print("Upgrading checkpoint...")
+ assert len(cfg.architectures) == 1
+ setattr(cfg.__class__, "model_type", "llava")
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
+ cfg.save_pretrained(config)
+ print("Checkpoint upgraded.")
+ else:
+ print("Checkpoint upgrade aborted.")
+ exit(1)
diff --git a/train/llava/s3_ops.py b/train/llava/s3_ops.py
new file mode 100644
index 0000000..8569aed
--- /dev/null
+++ b/train/llava/s3_ops.py
@@ -0,0 +1,217 @@
+# LINT_ME
+"""Filesystem operations supporting local and remote (obs://, s3://, gs://) paths."""
+from __future__ import annotations
+
+import io
+import json
+import posixpath
+from pathlib import Path
+from typing import IO, List, Union
+from urllib.parse import urlsplit, urlunsplit
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+from PIL import Image
+
+try:
+ import moxing as mox
+ _HAS_MOX = True
+except ImportError:
+ mox = None # type: ignore
+ _HAS_MOX = False
+
+PathLike = Union[str, Path]
+
+# ---------- helpers ----------
+
+
+def _is_remote(p: PathLike) -> bool:
+ """
+ Check if path is a remote storage URL (obs://, s3://, gs://).
+ """
+ s = str(p)
+ if s.startswith(("obs://", "s3://", "gs://")):
+ return True
+ return False
+
+
+def join(base: PathLike, *parts: PathLike) -> str:
+ """
+ Join path components for local or remote paths.
+ """
+ b = str(base)
+ if _is_remote(b):
+ scheme, netloc, path, query, frag = urlsplit(b)
+ segs = [path] + [str(p) for p in parts]
+ new_path = posixpath.join(*[s.lstrip("/") for s in segs if s])
+ if not new_path.startswith("/"):
+ new_path = "/" + new_path
+ return urlunsplit((scheme, netloc, new_path, query, frag))
+ return str(Path(b).joinpath(*map(str, parts)))
+
+def exists(p: PathLike) -> bool:
+ """
+ Check if path exists (local or remote).
+ """
+ if _is_remote(p):
+ if not _HAS_MOX:
+ raise RuntimeError("moxing not installed.")
+ return bool(mox.file.exists(str(p)))
+ return Path(p).exists()
+
+
+def isdir(p: PathLike) -> bool:
+ """
+ Check if path is a directory (local or remote).
+ """
+ if _is_remote(p):
+ if not _HAS_MOX:
+ raise RuntimeError("moxing not installed.")
+ return bool(mox.file.is_directory(str(p)))
+ return Path(p).is_dir()
+
+
+def listdir(p: PathLike) -> List[str]:
+ """
+ List directory contents for local or obs:// paths.
+ Remote returns full paths, local returns names (like os.listdir).
+ """
+ if _is_remote(p):
+ if not _HAS_MOX:
+ raise RuntimeError("moxing not installed.")
+ return mox.file.list_directory(str(p))
+ return [x.name for x in Path(p).iterdir()]
+
+
+def basename(p: PathLike) -> str:
+ """
+ Return the final component of a path, works for local and obs:// URLs.
+ """
+ s = str(p)
+ if _is_remote(s):
+ # Strip trailing slash if present, then take the last segment
+ return s.rstrip("/").rsplit("/",maxsplit=1)[-1]
+ return Path(s).name
+
+
+def split(p: PathLike) -> tuple[str, str]:
+ """
+ Split path into (head, tail).
+ For remote (obs://...), head keeps scheme+netloc+parent, tail is the last component.
+ For local, same as os.path.split.
+ """
+ s = str(p)
+ if _is_remote(s):
+ scheme, netloc, path, query, frag = urlsplit(s)
+ parts = path.rstrip("/").rsplit("/",maxsplit=1)
+ tail = parts[-1] if parts else ""
+ head_path = "/".join(parts[:-1])
+ if head_path and not head_path.startswith("/"):
+ head_path = "/" + head_path
+ head = urlunsplit((scheme, netloc, head_path, query, frag))
+ return head, tail
+ return str(Path(s).parent), Path(s).name
+
+
+def splitext(p: PathLike) -> tuple[str, str]:
+ """
+ Split path into (root, ext).
+ Works like os.path.splitext, for both local and remote.
+ """
+ s = str(p)
+ if _is_remote(s):
+ base = basename(s) # last part only
+ root, ext = posixpath.splitext(base)
+ # Reconstruct full root path (without extension)
+ head, _ = split(s)
+ return join(head, root), ext
+ root, ext = Path(s).with_suffix("").as_posix(), Path(s).suffix
+ return root, ext
+
+
+def open_file(p: PathLike, mode: str = "r", **kwargs) -> IO:
+ """
+ Open a file for local or remote (obs://, s3://, gs://) paths.
+ """
+ if _is_remote(p):
+ if not _HAS_MOX:
+ raise RuntimeError("moxing not installed.")
+ return mox.file.File(str(p), mode, **kwargs) # type: ignore
+ return open(p, mode, **kwargs) # pylint: disable=W1514
+
+
+def json_load(p: PathLike, **json_kwargs):
+ """
+ Load JSON or JSONL files automatically.
+ Returns:
+ - list[dict]: for JSONL or list-based JSON
+ - dict: for JSON object
+ """
+ encoding = json_kwargs.pop("encoding", "utf-8")
+
+ with open_file(p, "r", encoding=encoding) as f:
+ f.seek(0)
+
+ # JSONL: detect line-delimited JSON
+ if p.endswith(".jsonl"):
+ f.seek(0)
+ return [json.loads(line) for line in f if line.strip()]
+ f.seek(0)
+ return json.load(f, **json_kwargs)
+
+
+def image_open(p: PathLike):
+ """
+ Open an image from local or remote path.
+ """
+
+ if str(p).endswith(".parquet"):
+ with open_file(p, "rb") as fp:
+ file_data = fp.read()
+ table = pq.read_table(pa.py_buffer(file_data)).to_pandas()
+ images = {}
+ for cam_id in table.columns:
+ data = table[cam_id].iloc[0] # 取第一行的 bytes 数据
+ try:
+ img = Image.open(io.BytesIO(data))
+ img.load() # 强制加载,释放文件句柄
+ images[cam_id] = img
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to decode image from camera '{cam_id}' in {p}: {e}"
+ ) from e
+ return images["cam_1"]
+
+ f = open_file(p, "rb")
+ try:
+ img = Image.open(f)
+ # Force load into memory so we can close the file immediately
+ img.load()
+ f.close()
+ return img
+ except Exception:
+ f.close()
+ raise
+
+
+def makedirs(p: PathLike, exist_ok: bool = True) -> None:
+ """
+ Create directories (local or remote).
+ - Local: same as Path(...).mkdir(parents=True, exist_ok=...)
+ - Remote: uses mox.file.make_dirs
+ """
+ s = str(p)
+ if _is_remote(s):
+ if not _HAS_MOX:
+ raise RuntimeError("moxing not installed.")
+ # moxing's make_dirs creates all necessary parent dirs automatically
+ if exist_ok:
+ # No harm if it already exists
+ if not mox.file.exists(s):
+ mox.file.make_dirs(s)
+ else:
+ if mox.file.exists(s):
+ raise FileExistsError(f"Remote path already exists: {s}")
+ mox.file.make_dirs(s)
+ else:
+ Path(s).mkdir(parents=True, exist_ok=exist_ok)
diff --git a/train/llava/train/llama_flash_attn_monkey_patch.py b/train/llava/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000..c88fe34
--- /dev/null
+++ b/train/llava/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,87 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  # FlashAttention drop-in for LlamaAttention.forward
+    if output_attentions:
+        warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.")
+
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  # shape: (b, num_heads, s, head_dim)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]  # include the cached prefix length for RoPE
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    # Transform the data into the format required by flash attention
+    qkv = torch.stack([query_states, key_states, value_states], dim=2)
+    qkv = qkv.transpose(1, 3)  # shape: [b, s, 3, num_heads, head_dim]
+    key_padding_mask = attention_mask
+
+    if key_padding_mask is None:
+        qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)  # no padding: flatten batch and sequence dims
+        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)  # uniform sequence boundaries
+        max_s = q_len
+        output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
+        output = output.view(bsz, q_len, -1)
+    else:
+        qkv = qkv.reshape(bsz, q_len, -1)
+        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)  # strip padded positions before the kernel
+        qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+        output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
+        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+        output = pad_input(output_unpad, indices, bsz, q_len)  # re-insert padding rows at original positions
+
+    return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    # [bsz, seq_len]
+    return attention_mask  # pass through unchanged: the flash-attn forward consumes the raw key padding mask
+
+
+def replace_llama_attn_with_flash_attn():  # monkey-patch transformers' LlamaAttention with the flash-attn forward above
+    cuda_major, cuda_minor = torch.cuda.get_device_capability()
+    if cuda_major < 8:
+        warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. " "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593")
+    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/train/llava/train/llava_trainer.py b/train/llava/train/llava_trainer.py
new file mode 100644
index 0000000..e989297
--- /dev/null
+++ b/train/llava/train/llava_trainer.py
@@ -0,0 +1,591 @@
+import os
+import torch
+import torch.nn as nn
+import datetime
+import torch.distributed as dist
+
+from accelerate import Accelerator
+from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin
+from torch.utils.data import Dataset, Sampler, DataLoader
+
+from trl.trainer import DPOTrainer
+from trl.trainer.utils import DPODataCollatorWithPadding
+
+from transformers import Trainer
+from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, GradientAccumulationPlugin, logger, is_accelerate_available, is_datasets_available
+from transformers.trainer_utils import seed_worker
+from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
+from transformers.trainer_pt_utils import AcceleratorConfig
+from typing import List, Optional
+from datetime import timedelta
+
+if is_accelerate_available():
+ from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs
+
+if is_datasets_available():
+ import datasets
+
+from llava.utils import rank0_print
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):  # materialize a (possibly ZeRO-3 partitioned) param on CPU
+    from deepspeed import zero
+    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+    if hasattr(param, "ds_id"):  # ds_id marks a DeepSpeed ZeRO-3 partitioned parameter
+        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+            if not ignore_status:
+                print(name, "no ignore status")
+        with zero.GatheredParameters([param]):  # gather the full weights across ranks
+            param = param.data.detach().cpu().clone()
+    else:
+        param = param.detach().cpu().clone()  # plain tensor: just detach and copy to CPU
+    return param
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):  # collect params whose names match any key, ZeRO-3 safe
+    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}  # gather shards + CPU copy
+    return to_return
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+    """
+    Split a list of indices into `chunks` chunks of roughly equal lengths.
+    """
+
+    if len(indices) % num_chunks != 0:
+        return [indices[i::num_chunks] for i in range(num_chunks)]  # not evenly divisible: fall back to simple striding
+
+    num_indices_per_chunk = len(indices) // num_chunks
+
+    chunks = [[] for _ in range(num_chunks)]
+    chunks_lengths = [0 for _ in range(num_chunks)]
+    for index in indices:
+        shortest_chunk = chunks_lengths.index(min(chunks_lengths))  # greedy: always fill the lightest chunk
+        chunks[shortest_chunk].append(index)
+        chunks_lengths[shortest_chunk] += lengths[index]
+        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+            chunks_lengths[shortest_chunk] = float("inf")  # chunk is full: exclude it from further filling
+
+    return chunks
+
+
+def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult=8, generator=None):
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)  # global sort, longest first
+    megabatch_size = world_size * batch_size * megabatch_mult
+    megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches]  # random permutation as shuffle key within each megabatch
+    shuffled_indices = [i for megabatch in megabatches for i in megabatch]
+    world_batch_size = world_size * batch_size
+    batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)]
+    batch_indices = torch.randperm(len(batches), generator=generator)  # shuffle the order of global batches
+    batches = [batches[i] for i in batch_indices]
+
+    return [i for batch in batches for i in batch]
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+    """
+    Group sample indices by modality (multimodal: length > 0, language-only:
+    length < 0) and then by length, so each megabatch is single-modality.
+
+    - each modality is length-grouped independently via get_length_grouped_indices
+    - full megabatches from both modalities are shuffled together
+    - the leftover partial megabatches of both modalities are merged last
+
+    If all samples share one modality, this reduces to plain length grouping.
+    The sign of each length encodes the sample's modality.
+    """
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    assert all(l != 0 for l in lengths), "Should not have zero length."
+    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+        # all samples are in the same modality
+        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
+    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
+    megabatch_size = world_size * batch_size
+    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+    last_mm = mm_megabatches[-1]  # possibly partial trailing megabatch of each modality
+    last_lang = lang_megabatches[-1]
+    additional_batch = last_mm + last_lang
+    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+    megabatches = [megabatches[i] for i in megabatch_indices]
+
+    if len(additional_batch) > 0:
+        megabatches.append(sorted(additional_batch))  # leftovers appended as one final (mixed) megabatch
+
+    return [i for megabatch in megabatches for i in megabatch]
+
+
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+    """
+    Return indices so that each contiguous slice of `world_size * batch_size`
+    indices holds samples of similar length. The indices are:
+
+    - randomly permuted
+    - grouped in megabatches of size `world_size * batch_size`
+    - sorted by decreasing length within each megabatch
+
+    Each megabatch is finally split into `world_size` chunks of roughly equal
+    total length (see split_to_even_chunks) before being flattened.
+    """
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = world_size * batch_size
+    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+    return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None):
+    indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator)  # HF's base length grouping
+
+    megabatch_size = world_size * batch_size
+    megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]  # longest first within each megabatch
+    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]  # balance total length per rank
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    batch_indices = torch.randperm(len(megabatches), generator=generator)
+    megabatches = [megabatches[i] for i in batch_indices]
+
+    return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None):
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    assert all(l != 0 for l in lengths), "Should not have zero length."
+    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+        # all samples are in the same modality
+        return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator)
+    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])  # positive length = multimodal sample
+    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])  # negative length = language-only sample
+
+    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)]
+    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)]
+    megabatch_size = world_size * batch_size
+    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+    last_mm = mm_megabatches[-1]
+    last_lang = lang_megabatches[-1]
+    additional_batch = last_mm + last_lang  # computed but deliberately dropped, see FIXME below
+    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+    megabatches = [megabatches[i] for i in megabatch_indices]
+
+    # FIXME: Hard code to avoid last batch mixed with different modalities
+    # if len(additional_batch) > 0:
+    #     megabatches.append(sorted(additional_batch))
+
+    return [i for megabatch in megabatches for i in megabatch]
+
+
+class LengthGroupedSampler(Sampler):
+    r"""
+    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+    keeping a bit of randomness.
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        world_size: int,
+        lengths: Optional[List[int]] = None,
+        generator=None,
+        variable_length: bool = False,
+        group_by_modality: bool = False,
+        group_by_modality_auto: bool = False,
+    ):
+        if lengths is None:
+            raise ValueError("Lengths must be provided.")
+
+        self.batch_size = batch_size
+        self.world_size = world_size
+        self.lengths = lengths  # per-sample lengths; sign encodes modality in the modality-grouped modes
+        self.generator = generator
+        self.variable_length = variable_length
+        self.group_by_modality = group_by_modality
+        self.group_by_modality_auto = group_by_modality_auto
+
+    def __len__(self):
+        return len(self.lengths)
+
+    def __iter__(self):
+        if self.variable_length:
+            assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping."
+            indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+        else:
+            if self.group_by_modality:
+                indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+            elif self.group_by_modality_auto:
+                indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+            else:
+                indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator)  # default: plain length grouping
+        return iter(indices)
+
+
+class LLaVATrainer(Trainer):
+    def __init__(self, obs_upload_path=None, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.obs_upload_path = obs_upload_path  # NOTE(review): presumably an OBS destination for artifacts - confirm with callers
+
+    def create_accelerator_and_postprocess(self):  # override: build Accelerator with an effectively infinite NCCL timeout
+        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
+        grad_acc_kwargs["sync_with_dataloader"] = False
+        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
+
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))  # 1-year timeout stands in for "never"
+        rank0_print("Setting NCCL timeout to INF to avoid running errors.")
+
+        # create accelerator object
+        self.accelerator = Accelerator(
+            split_batches=self.args.split_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, kwargs_handlers=[accelerator_kwargs]
+        )
+        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
+        self.gather_function = self.accelerator.gather_for_metrics
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", fsdp_plugin.limit_all_gathers)
+            if is_accelerate_available("0.23.0"):  # activation_checkpointing requires accelerate >= 0.23
+                fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get("activation_checkpointing", fsdp_plugin.activation_checkpointing)
+                if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
+                    raise ValueError("The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " "when using FSDP.")
+
+        if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
+            self.propagate_args_to_deepspeed()  # sync HF TrainingArguments into the deepspeed config
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:  # choose a length/modality-grouped sampler per training args
+        if self.train_dataset is None or not has_length(self.train_dataset):
+            return None
+
+        if self.args.group_by_length:
+            lengths = self.train_dataset.lengths
+            return LengthGroupedSampler(
+                # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+                self.args.train_batch_size,
+                # world_size=self.args.world_size,
+                world_size=self.args.world_size * self.args.gradient_accumulation_steps,  # TODO: seems that this may work?
+                lengths=lengths,
+            )
+        elif self.args.group_by_modality_length:
+            lengths = self.train_dataset.modality_lengths
+            return LengthGroupedSampler(
+                # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+                self.args.train_batch_size,
+                # world_size=self.args.world_size,
+                world_size=self.args.world_size * self.args.gradient_accumulation_steps,  # TODO: seems that this may work?
+                lengths=lengths,
+                group_by_modality=True,
+            )
+        elif self.args.group_by_modality_length_auto:
+            lengths = self.train_dataset.modality_lengths
+            return LengthGroupedSampler(
+                # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+                self.args.train_batch_size,
+                # world_size=self.args.world_size,
+                world_size=self.args.world_size * self.args.gradient_accumulation_steps,  # TODO: seems that this may work?
+                lengths=lengths,
+                group_by_modality_auto=True,
+            )
+        elif self.args.group_by_varlen:
+            lengths = self.train_dataset.lengths
+            return LengthGroupedSampler(
+                self.args.train_batch_size * self.args.gradient_accumulation_steps,
+                # self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps
+                # world_size=self.args.world_size,
+                world_size=self.args.world_size * self.args.gradient_accumulation_steps,  # TODO: seems that this may work?
+                lengths=lengths,
+                variable_length=True,
+            )
+        else:
+            return super()._get_train_sampler()  # fall back to HF's default random sampler
+
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training [`~torch.utils.data.DataLoader`].
+
+        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+        training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+        """
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+
+        train_dataset = self.train_dataset
+        data_collator = self.data_collator
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")  # strip columns at collate time instead
+
+        dataloader_params = {
+            "batch_size": self._train_batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(train_dataset, torch.utils.data.IterableDataset):  # iterable datasets manage their own ordering
+            dataloader_params["sampler"] = self._get_train_sampler()
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["worker_init_fn"] = seed_worker
+            dataloader_params["prefetch_factor"] = self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None
+
+        if hasattr(self.args, "use_webdataset") and self.args.use_webdataset:
+            dataloader = DataLoader(train_dataset, **dataloader_params)  # NOTE(review): bypasses accelerator.prepare - presumably webdataset handles sharding; confirm
+        else:
+            dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+        return dataloader
+
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+
+        Parameters whose names contain "mm_projector" / "vision_tower" can be given
+        dedicated learning rates via args.mm_projector_lr / args.mm_vision_tower_lr;
+        every group is additionally split into decay vs no-decay (biases and
+        layer-norm weights get weight_decay 0.0).
+        """
+        if is_sagemaker_mp_enabled():
+            return super().create_optimizer()
+
+        opt_model = self.model
+
+        if self.optimizer is None:
+            # Weight decay applies to everything except bias and layer-norm parameters.
+            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            # Optional per-module learning-rate overrides.
+            lr_mapper = {}
+            if self.args.mm_projector_lr is not None:
+                lr_mapper["mm_projector"] = self.args.mm_projector_lr
+            if self.args.mm_vision_tower_lr is not None:
+                lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr
+            if len(lr_mapper) > 0:
+                # Parameters matched by lr_mapper are excluded from the default groups
+                # and re-added below with their dedicated learning rate.
+                special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
+                optimizer_grouped_parameters = [
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+                        "weight_decay": self.args.weight_decay,
+                    },
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+                        "weight_decay": 0.0,
+                    },
+                ]
+                for module_keyword, lr in lr_mapper.items():
+                    module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
+                    optimizer_grouped_parameters.extend(
+                        [
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
+                                "weight_decay": self.args.weight_decay,
+                                "lr": lr,
+                            },
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
+                                "weight_decay": 0.0,
+                                "lr": lr,
+                            },
+                        ]
+                    )
+            else:
+                # Plain two-group split when no per-module LR override is requested.
+                optimizer_grouped_parameters = [
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
+                        "weight_decay": self.args.weight_decay,
+                    },
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
+                        "weight_decay": 0.0,
+                    },
+                ]
+
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            if optimizer_cls.__name__ == "Adam8bit":
+                import bitsandbytes
+
+                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                skipped = 0
+                for module in opt_model.modules():
+                    if isinstance(module, nn.Embedding):
+                        # Keep embedding weights in 32-bit optimizer state under Adam8bit.
+                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                        logger.info(f"skipped {module}: {skipped/2**20}M params")
+                        manager.register_module_override(module, "weight", {"optim_bits": 32})
+                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+                logger.info(f"skipped: {skipped/2**20}M params")
+
+        return self.optimizer
+
+    def _save_checkpoint(self, model, trial, metrics=None):
+        """Save a training checkpoint.
+
+        Adapter-only mode (tune_mm_mlp_adapter set, or mm_tunable_parts naming only
+        the mlp adapter / vision resampler): only the projector weights and the
+        config are written. Otherwise a full Trainer-style checkpoint is produced:
+        model, optimizer/scheduler, RNG state, best-metric bookkeeping, optional
+        OBS upload, and old-checkpoint rotation.
+        """
+        if getattr(self.args, "tune_mm_mlp_adapter", False) or (
+            hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
+        ):
+            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+            run_dir = self._get_output_dir(trial=trial)
+            output_dir = os.path.join(run_dir, checkpoint_folder)
+
+            # Only save Adapter
+            keys_to_match = ["mm_projector", "vision_resampler"]
+            if getattr(self.args, "use_im_start_end", False):
+                keys_to_match.extend(["embed_tokens", "embed_in"])
+
+            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+            # Only the main process (rank 0 or single-process runs) touches disk.
+            if self.args.local_rank == 0 or self.args.local_rank == -1:
+                self.model.config.save_pretrained(output_dir)
+                torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
+
+        else:
+            #super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
+            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+            if self.hp_search_backend is None and trial is None:
+                self.store_flos()
+
+
+            run_dir = self._get_output_dir(trial=trial)
+            output_dir = os.path.join(run_dir, checkpoint_folder)
+
+            self.save_model(output_dir, _internal_call=True)
+
+            if not self.args.save_only_model:
+                # Save optimizer and scheduler
+                self._save_optimizer_and_scheduler(output_dir)
+                # Save RNG state
+                self._save_rng_state(output_dir)
+
+            # Determine the new best metric / best model checkpoint
+            if metrics is not None and self.args.metric_for_best_model is not None:
+                import numpy as np
+                metric_to_check = self.args.metric_for_best_model
+                if not metric_to_check.startswith("eval_"):
+                    metric_to_check = f"eval_{metric_to_check}"
+                metric_value = metrics[metric_to_check]
+
+                operator = np.greater if self.args.greater_is_better else np.less
+                if (
+                    self.state.best_metric is None
+                    or self.state.best_model_checkpoint is None
+                    or operator(metric_value, self.state.best_metric)
+                ):
+                    self.state.best_metric = metric_value
+                    self.state.best_model_checkpoint = output_dir
+
+            # Save the Trainer state
+            if self.args.should_save:
+                TRAINER_STATE_NAME = "trainer_state.json"
+                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+            if self.args.push_to_hub:
+                self._push_from_checkpoint(output_dir)
+
+            # NOTE(review): mox auth is called with empty ak/sk — presumably the
+            # credentials come from the runtime environment; confirm before relying
+            # on the OBS upload path.
+            if self.obs_upload_path:
+                if dist.get_rank()==0:
+                    log_dir = os.path.join(os.path.dirname(output_dir),"runs")
+                    mox.file.set_auth(ak="",sk="")
+                    mox.file.copy_parallel(log_dir, self.obs_upload_path+log_dir)
+                    mox.file.copy_parallel(output_dir, self.obs_upload_path+output_dir)
+
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                # Solely rely on numerical checkpoint id for rotation.
+                # mtime is not reliable especially on some fuse fs in cloud environments.
+                self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, "tune_mm_mlp_adapter", False):
+ pass
+ else:
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
+
+
+class LLaVADPOTrainer(DPOTrainer):
+    """DPO trainer with LLaVA-specific sampling and checkpointing.
+
+    Mirrors LLaVATrainer: modality-grouped length sampling, adapter-only
+    checkpoints, and LoRA-aware checkpoint saving.
+    """
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        # Group samples of the same modality (and similar length) into batches
+        # when requested; otherwise defer to the stock DPOTrainer sampler.
+        if self.train_dataset is None or not has_length(self.train_dataset):
+            return None
+
+        if self.args.group_by_modality_length:
+            lengths = self.train_dataset.modality_lengths
+            return LengthGroupedSampler(
+                # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+                self.args.train_batch_size,
+                world_size=self.args.world_size,
+                lengths=lengths,
+                group_by_modality=True,
+            )
+        else:
+            return super()._get_train_sampler()
+
+    def _save_checkpoint(self, model, trial, metrics=None):
+        # Adapter-only mode: persist just the projector weights and config.
+        if getattr(self.args, "tune_mm_mlp_adapter", False) or (
+            hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
+        ):
+            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+            run_dir = self._get_output_dir(trial=trial)
+            output_dir = os.path.join(run_dir, checkpoint_folder)
+
+            # Only save Adapter
+            keys_to_match = ["mm_projector", "vision_resampler"]
+            if getattr(self.args, "use_im_start_end", False):
+                keys_to_match.extend(["embed_tokens", "embed_in"])
+
+            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+            # Main process only (rank 0, or -1 in single-process runs).
+            if self.args.local_rank == 0 or self.args.local_rank == -1:
+                self.model.config.save_pretrained(output_dir)
+                torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
+        else:
+            # super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
+            # print(type(model))
+            # from transformers.modeling_utils import unwrap_model
+            # print(type(unwrap_model(model)))
+            # print(unwrap_model(model).config)
+            if self.args.lora_enable:
+                # LoRA runs save via a custom helper on the unwrapped model.
+                from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+                run_dir = self._get_output_dir(trial=trial)
+                output_dir = os.path.join(run_dir, checkpoint_folder)
+                from transformers.modeling_utils import unwrap_model
+
+                unwrapped_model = unwrap_model(model)
+                self.save_my_lora_ckpt(output_dir, self.args, unwrapped_model)
+            else:
+                super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # Adapter-only mode skips the full-weights save (handled in _save_checkpoint).
+        if getattr(self.args, "tune_mm_mlp_adapter", False):
+            pass
+        else:
+            super(LLaVADPOTrainer, self)._save(output_dir, state_dict)
diff --git a/train/llava/train/llava_trainer_eval.py b/train/llava/train/llava_trainer_eval.py
new file mode 100644
index 0000000..e822258
--- /dev/null
+++ b/train/llava/train/llava_trainer_eval.py
@@ -0,0 +1,76 @@
+import json
+import subprocess
+
+from llava.train.llava_trainer import LLaVATrainer
+
+
+class LLaVAEvalTrainer(LLaVATrainer):
+ def evaluate(self, evaluate_args):
+ cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \
+ --model {evaluate_args.model} \
+ --model_args {evaluate_args.model_args} \
+ --tasks {evaluate_args.task_names} \
+ --batch_size {evaluate_args.batch_size} \
+ --log_samples_suffix {evaluate_args.log_samples_suffix} \
+ --output_path {evaluate_args.output_path}"
+ if evaluate_args.limit:
+ cmd += f" --limit {evaluate_args.limit}"
+ if evaluate_args.num_fewshot:
+ cmd += f" --num_fewshot {evaluate_args.num_fewshot}"
+ if evaluate_args.gen_kwargs != "":
+ cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}"
+ if evaluate_args.log_samples:
+ cmd += f" --log_samples"
+ else:
+ assert False, "Please log samples so that the result can be parsed"
+ results = subprocess.run([cmd], shell=True, capture_output=True, text=True)
+ try:
+ result_file_index_start = results.stdout.index("Saved samples to ")
+ result_file_index_end = results.stdout.index(f".json")
+ result_file_index_start += len("Saved samples to ")
+ file = results.stdout[result_file_index_start:result_file_index_end]
+ except:
+ result_file_index_start = results.stderr.index("Saved samples to ")
+ result_file_index_end = results.stderr.index(f".json")
+ result_file_index_start += len("Saved samples to ")
+ file = results.stderr[result_file_index_start:result_file_index_end]
+ file = file.split("/")[:-1]
+ file = "/".join(file) + "/results.json"
+ with open(file, "r") as f:
+ lmms_eval_results = json.load(f)
+ result_dict = {}
+ tasks_list = evaluate_args.task_names.split(",")
+ for task in tasks_list:
+ task_results = lmms_eval_results["results"][task]
+ for k, v in task_results.items():
+ if k != "alias" and "stderr" not in k:
+ metric = k.split(",")[0]
+ result_dict[f"{task}_{metric}"] = v
+ return result_dict
+
+ """def evaluate(self, evaluate_args):
+ initialize_tasks()
+ tasks_list = evaluate_args.task_names.split(",")
+ result_dict = {}
+ results = evaluator.simple_evaluate(
+ model=evaluate_args.model,
+ model_args=evaluate_args.model_args,
+ tasks=tasks_list,
+ num_fewshot=evaluate_args.num_fewshot,
+ batch_size=evaluate_args.batch_size,
+ device=evaluate_args.device,
+ limit=evaluate_args.limit,
+ check_integrity=evaluate_args.check_integrity,
+ show_task_to_terminal=evaluate_args.show_task_to_terminal,
+ log_samples=evaluate_args.log_samples,
+ gen_kwargs=evaluate_args.gen_kwargs,
+ cli_args=evaluate_args,
+ )
+ for task in tasks_list:
+ task_results = results["results"][task]
+ for k,v in task_results.items():
+ if k != "alias" and "stderr" not in k:
+ metric = k.split(",")[0]
+ result_dict[f"{task}_{metric}"] = v
+
+ return result_dict"""
diff --git a/train/llava/train/train.py b/train/llava/train/train.py
new file mode 100644
index 0000000..bda1869
--- /dev/null
+++ b/train/llava/train/train.py
@@ -0,0 +1,1773 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import copy
+import logging
+import os
+import pathlib
+import random
+import re
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Sequence
+
+import numpy as np
+import tokenizers
+import torch
+import torch.distributed as dist
+import transformers
+from llava import conversation as conversation_lib
+from llava.constants import (
+ DEFAULT_IM_END_TOKEN,
+ DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_TOKEN,
+ IGNORE_INDEX,
+ IMAGE_TOKEN_INDEX,
+)
+from llava.mm_utils import (
+ process_anyres_image,
+ process_highres_image,
+ process_highres_image_crop_split,
+ tokenizer_image_token,
+)
+from llava.model import *
+from llava.train.llava_trainer import LLaVATrainer
+from llava.utils import (
+ rank0_print,
+ read_frames_gif,
+ read_frames_pyav,
+)
+from packaging import version
+from PIL import Image, ImageFile
+from torch.utils.data import Dataset
+from transformers import AutoConfig
+
+import train.llava.s3_ops as st
+
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+local_rank = None
+
+IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse(
+ "0.14"
+)
+
+import warnings
+
+warnings.filterwarnings("ignore")
+
+
+@dataclass
+class ModelArguments:
+    """CLI/config arguments controlling model construction and tunable parts."""
+
+    # Base checkpoint and the concrete *ForCausalLM class used to instantiate it.
+    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    model_class_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Used to init model class, format is XXXXForCausalLM. e.g. currently XXXX is chosen from LlavaLlama, LlavaMixtral, LlavaMistral, Llama"
+        },
+    )
+
+    mm_tunable_parts: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"'
+        },
+    )
+    # deciding which part of the multimodal model to tune, will overwrite other previous settings
+
+    # Conversation template version and coarse freeze/tune switches.
+    version: Optional[str] = field(default="v0")
+    freeze_backbone: bool = field(default=False)
+    tune_mm_mlp_adapter: bool = field(default=False)
+    tune_mm_vision_resampler: bool = field(default=False)
+    # Vision tower selection and pretrained weights.
+    vision_tower: Optional[str] = field(default=None)
+    vision_tower_pretrained: Optional[str] = field(
+        default=None
+    )  # default to the last layer
+
+    num_token: Optional[bool] = field(default=False)
+
+    unfreeze_mm_vision_tower: bool = field(default=False)
+    unfreeze_language_model: bool = field(default=False)
+    mm_vision_select_layer: Optional[int] = field(
+        default=-1
+    )  # default to the last layer
+    # Multimodal projector / image-token handling options.
+    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+    mm_projector_type: Optional[str] = field(default="linear")
+    mm_use_im_start_end: bool = field(default=False)
+    mm_use_im_patch_token: bool = field(default=True)
+    mm_patch_merge_type: Optional[str] = field(default="flat")
+    mm_vision_select_feature: Optional[str] = field(default="patch")
+    # Resampler / mask-drop / spatial-pool knobs.
+    mm_resampler_type: Optional[str] = field(default=None)
+    mm_mask_drop_mode: str = field(default="fixed")
+    mm_mask_drop_skip_percentage: float = field(default=0.0)
+    mm_mask_drop_ratio: float = field(default=0.25)
+    mm_mask_drop_ratio_upper: Optional[float] = field(default=None)
+    mm_mask_drop_ratio_lower: Optional[float] = field(default=None)
+    mm_spatial_pool_stride: Optional[int] = field(default=None)
+    mm_spatial_pool_mode: str = field(default="bilinear")
+    mm_spatial_pool_out_channels: Optional[int] = field(default=None)
+    # Perceiver / Q-Former resampler hyper-parameters.
+    mm_perceiver_depth: Optional[int] = field(default=3)
+    mm_perceiver_latents: Optional[int] = field(default=32)
+    mm_perceiver_ff_mult: Optional[float] = field(default=4)
+    mm_perceiver_pretrained: Optional[str] = field(default=None)
+    mm_qformer_depth: Optional[int] = field(default=3)
+    mm_qformer_latents: Optional[int] = field(default=32)
+    mm_qformer_pretrained: Optional[str] = field(default=None)
+
+    # RoPE scaling (long-context) settings.
+    rope_scaling_factor: Optional[float] = field(default=None)
+    rope_scaling_type: Optional[str] = field(default=None)
+
+    # S2 multi-scale vision features.
+    s2: Optional[bool] = field(default=False)
+    s2_scales: Optional[str] = field(default="336,672,1008")
+
+    use_pos_skipping: Optional[bool] = field(default=False)
+    pos_skipping_range: Optional[int] = field(default=4096)
+
+    mm_newline_position: Optional[str] = field(default="grid")
+    delay_load: Optional[bool] = field(default=True)
+    add_faster_video: Optional[bool] = field(default=False)
+    faster_token_stride: Optional[int] = field(default=10)
+
+    # Mixture-of-experts / MoE-LoRA configuration (-1 disables expert routing).
+    num_experts_per_tok: Optional[int] = field(default=-1)
+    num_experts: Optional[int] = field(default=-1)
+    moe_choice: Optional[str] = field(default="expert")
+    moe_router_score_function: Optional[str] = field(default=None)
+    moe_lora_rank: Optional[int] = field(default=32)
+    moe_lora_alpha: Optional[int] = field(default=32)
+    moe_lora_in_features: Optional[int] = field(default=4096)
+    moe_lora_out_features: Optional[int] = field(default=4096)
+    moe_lora_dropout: Optional[float] = field(default=0.0)
+    capacity_factor: Optional[float] = field(default=0.1)
+
+
+@dataclass
+class DataArguments:
+    """Arguments describing the training data and image/video preprocessing."""
+
+    data_path: str = field(
+        default=None,
+        metadata={
+            "help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"
+        },
+    )
+
+    lazy_preprocess: bool = False
+    is_multimodal: bool = False
+    early_mix_text: bool = False
+    use_webdataset: bool = False
+    # Total sample count; NOTE(review): presumably needed when use_webdataset is
+    # set since iterable datasets have no length — confirm against the loader.
+    total_samples: Optional[int] = field(default=None)
+
+    # Image sources and resolution handling.
+    image_folder: Optional[str] = field(default=None)
+    image_folder_2: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the second image folder, used as a fallback."},
+    )
+    image_aspect_ratio: str = "square"
+    image_grid_pinpoints: Optional[str] = field(default=None)
+    image_crop_resolution: Optional[int] = field(default=None)
+    image_split_resolution: Optional[int] = field(default=None)
+
+    # Video sources and frame-sampling behavior.
+    video_folder: Optional[str] = field(default=None)
+    video_fps: Optional[int] = field(default=1)
+    frames_upbound: Optional[int] = field(default=0)
+    add_time_instruction: Optional[bool] = field(default=False)
+    force_sample: Optional[bool] = field(default=False)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    """Extends transformers.TrainingArguments with LLaVA / LoRA / quantization knobs."""
+
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    remove_unused_columns: bool = field(default=False)
+    freeze_mm_mlp_adapter: bool = field(default=False)
+    freeze_mm_vision_resampler: bool = field(default=False)
+    mpt_attn_impl: Optional[str] = field(default="triton")
+    model_max_length: int = field(
+        default=4096,
+        metadata={
+            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
+    )
+    # Quantization (bitsandbytes) settings.
+    double_quant: bool = field(
+        default=True,
+        metadata={
+            "help": "Compress the quantization statistics through double quantization."
+        },
+    )
+    quant_type: str = field(
+        default="nf4",
+        metadata={
+            "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
+        },
+    )
+    bits: int = field(default=16, metadata={"help": "How many bits to use."})
+    # LoRA configuration.
+    lora_enable: bool = False
+    lora_r: int = 64
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+    lora_weight_path: str = ""
+    lora_bias: str = "none"
+    # Per-module learning-rate overrides (consumed by create_optimizer).
+    mm_projector_lr: Optional[float] = None
+    mm_vision_tower_lr: Optional[float] = None
+    # Batch-grouping strategies for the train sampler.
+    group_by_varlen: bool = field(default=False)
+    group_by_modality_length: bool = field(default=False)
+    group_by_modality_length_auto: bool = field(default=False)
+    auto_find_batch_size: bool = field(default=False)
+    gradient_checkpointing: bool = field(default=True)
+    verbose_logging: bool = field(default=False)
+    attn_implementation: str = field(
+        default="flash_attention_2",
+        metadata={"help": "Use transformers attention implementation."},
+    )
+    use_conversation_mask: bool = field(default=True)
+    split_batches: bool = field(default=False)
+    torch_empty_cache_steps: int = field(default=1)
+    fp16: bool = field(default=False)
+    debug_mode: bool = field(default=False)
+    # OBS (cloud object store) upload target; empty string disables uploading.
+    obs_upload_path: str = field(
+        default="", metadata={"help": "Path to the obs upload path."}
+    )
+    output_dir: str = field(
+        default="exp", metadata={"help": "Path to the output directory."}
+    )
+    save_only_model: bool = True
+
+
+# @dataclass
+# class EvaluationArguments:
+# eval_num_processes: int = field(default=1)
+# task_names: str = field(default=None)
+# model: str = field(default="llava")
+# model_args: Optional[str] = field(default=None)
+# num_fewshot: Optional[int] = field(default=None)
+# batch_size: int = field(default=1)
+# device: Optional[str] = field(default=None)
+# limit: Optional[int] = field(default=None)
+# check_integrity: Optional[bool] = field(default=False)
+# show_task_to_terminal: Optional[bool] = field(default=False)
+# log_samples: Optional[bool] = field(default=True)
+# gen_kwargs: Optional[str] = field(default="")
+# log_samples_suffix: Optional[str] = field(default="")
+# output_path: Optional[str] = field(default="./logs/")
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+    """Return a detached CPU clone of `param`, gathering it first when it is a
+    DeepSpeed ZeRO-3 partitioned parameter (identified by the `ds_id` attribute).
+
+    ignore_status: suppress the warning emitted when the partitioned parameter
+    is currently NOT_AVAILABLE. `name` is only used in that warning message.
+    """
+    from deepspeed import zero
+    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+    if hasattr(param, "ds_id"):
+        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+            if not ignore_status:
+                logging.warning(
+                    f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
+                )
+        # Gather the full parameter from all ranks before copying to CPU.
+        with zero.GatheredParameters([param]):
+            param = param.data.detach().cpu().clone()
+    else:
+        param = param.detach().cpu().clone()
+    return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias:
+ if bias_name in lora_bias_names:
+ to_return[bias_name] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+    """Collect all non-LoRA tensors from named_params as a CPU state dict.
+
+    With require_grad_only (default) only trainable tensors are kept. Values are
+    gathered from DeepSpeed ZeRO-3 partitions via maybe_zero_3.
+    """
+    to_return = {k: t for k, t in named_params if "lora_" not in k}
+    if require_grad_only:
+        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+    to_return = {
+        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
+    }
+    return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+    """Collect tensors whose names contain any of keys_to_match, as a CPU state dict.
+
+    Used to extract multimodal adapter weights (e.g. "mm_projector",
+    "vision_resampler"); values are gathered from ZeRO-3 partitions.
+    """
+    to_return = {
+        k: t
+        for k, t in named_params
+        if any(key_match in k for key_match in keys_to_match)
+    }
+    to_return = {
+        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
+    }
+    return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ names = name.split(".")
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ if "lm_head" in lora_module_names: # needed for 16-bit
+ lora_module_names.remove("lm_head")
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """Collects the state dict and dump to disk.
+
+    In adapter-only tuning mode only the projector weights and the config are
+    written; otherwise the full model is saved (via DeepSpeed when active, or a
+    CPU-copied state dict otherwise).
+    """
+    if (
+        hasattr(trainer.args, "tune_mm_mlp_adapter")
+        and trainer.args.tune_mm_mlp_adapter
+    ):
+        check_only_save_mm_adapter_tunnable = True
+    # only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts
+    elif hasattr(trainer.args, "mm_tunable_parts") and (
+        len(trainer.args.mm_tunable_parts.split(",")) == 1
+        and (
+            "mm_mlp_adapter" in trainer.args.mm_tunable_parts
+            or "mm_vision_resampler" in trainer.args.mm_tunable_parts
+        )
+    ):
+        check_only_save_mm_adapter_tunnable = True
+    else:
+        check_only_save_mm_adapter_tunnable = False
+
+    # Sync all ranks and pending device work before touching the filesystem.
+    trainer.accelerator.wait_for_everyone()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.synchronize()
+    # torch.cuda.synchronize()
+
+    rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}")
+    if check_only_save_mm_adapter_tunnable:
+        # Only save Adapter
+        keys_to_match = ["mm_projector", "vision_resampler"]
+        if getattr(trainer.args, "use_im_start_end", False):
+            keys_to_match.extend(["embed_tokens", "embed_in"])
+
+        weight_to_save = get_mm_adapter_state_maybe_zero_3(
+            trainer.model.named_parameters(), keys_to_match
+        )
+        trainer.model.config.save_pretrained(output_dir)
+
+        current_folder = output_dir.split("/")[-1]
+        parent_folder = os.path.dirname(output_dir)
+        # Main process only: mid-run checkpoints go under a shared mm_projector/
+        # folder, final saves go directly into output_dir.
+        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+            if current_folder.startswith("checkpoint-"):
+                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+                os.makedirs(mm_projector_folder, exist_ok=True)
+                torch.save(
+                    weight_to_save,
+                    os.path.join(mm_projector_folder, f"{current_folder}.bin"),
+                )
+            else:
+                torch.save(
+                    weight_to_save, os.path.join(output_dir, f"mm_projector.bin")
+                )
+        return
+
+    # DeepSpeed knows how to consolidate partitioned weights itself.
+    if trainer.deepspeed:
+        trainer.save_model(output_dir)
+        return
+
+    # Plain path: copy the state dict to CPU and let the Trainer persist it.
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+    special_tokens_dict: Dict,
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+):
+    """Resize tokenizer and embedding.
+
+    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+
+    New special tokens are added to the tokenizer, the embedding matrices are
+    resized accordingly, and the new rows are initialized to the mean of all
+    pre-existing embeddings (for both input and output embeddings).
+    """
+    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+    model.resize_token_embeddings(len(tokenizer))
+
+    if num_new_tokens > 0:
+        input_embeddings = model.get_input_embeddings().weight.data
+        output_embeddings = model.get_output_embeddings().weight.data
+
+        # Mean of the original vocabulary rows (everything except the new tail).
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+
+        input_embeddings[-num_new_tokens:] = input_embeddings_avg
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(
+    strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
+) -> Dict:
+    """Tokenize a list of strings.
+
+    Returns a dict with input_ids/labels (identical lists of 1-D tensors, one
+    per string) and input_ids_lens/labels_lens (non-pad token counts). Each
+    string is truncated to tokenizer.model_max_length.
+    """
+    tokenized_list = [
+        tokenizer(
+            text,
+            return_tensors="pt",
+            padding="longest",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+        )
+        for text in strings
+    ]
+    # Labels are the input ids themselves (the caller masks non-target spans).
+    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+    input_ids_lens = labels_lens = [
+        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+        for tokenized in tokenized_list
+    ]
+    return dict(
+        input_ids=input_ids,
+        labels=labels,
+        input_ids_lens=input_ids_lens,
+        labels_lens=labels_lens,
+    )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+    """In place, set label positions for the header and all human turns to IGNORE_INDEX.
+
+    target: 1-D label tensor aligned with the concatenated conversation tokens.
+    tokenized_lens: per-segment token counts, first entry being the header.
+    speakers: speaker tag per non-header segment ("human" turns are masked).
+    """
+    # cur_idx = 0
+    cur_idx = tokenized_lens[0]
+    tokenized_lens = tokenized_lens[1:]
+    # The header is never a training target.
+    target[:cur_idx] = IGNORE_INDEX
+    for tokenized_len, speaker in zip(tokenized_lens, speakers):
+        if speaker == "human":
+            # NOTE(review): the +2 offset presumably skips the leading
+            # speaker/signal tokens of the turn — confirm against _tokenize_fn.
+            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
+        cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = "unknown"
+ sentence["value"] = (
+ BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
+ )
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
+    """Normalize image-token placement in conversations (in place).
+
+    For single-image sentences the <image> token is moved to the front on its
+    own line; optionally wraps it in im_start/im_end tokens when
+    data_args.mm_use_im_start_end is set. Returns `sources` unchanged when the
+    data is not multimodal.
+    """
+    is_multimodal = data_args.is_multimodal
+    if not is_multimodal:
+        return sources
+
+    for source in sources:
+        for sentence in source:
+            # TODO maybe this should be changed for interleaved data?
+            # if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
+            # only check for num_im=1
+            num_im = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
+            if (
+                num_im == 1
+                and DEFAULT_IMAGE_TOKEN in sentence["value"]
+                and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN)
+            ):
+                # Move the single image token to the very front, on its own line.
+                sentence["value"] = (
+                    sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+                )
+                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
+                sentence["value"] = sentence["value"].strip()
+                if "mmtag" in conversation_lib.default_conversation.version:
+                    sentence["value"] = sentence["value"].replace(
+                        DEFAULT_IMAGE_TOKEN,
+                        "" + DEFAULT_IMAGE_TOKEN + "",
+                    )
+            replace_token = DEFAULT_IMAGE_TOKEN
+            if data_args.mm_use_im_start_end:
+                replace_token = (
+                    DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+                )
+            sentence["value"] = sentence["value"].replace(
+                DEFAULT_IMAGE_TOKEN, replace_token
+            )
+
+            # For videoInstruct-100k noisy_data. TODO: Ask Yuanhan to clean the data instead of leaving the noise code here.
+            sentence["value"] = sentence["value"].replace(
+                "QA_GT_caption_based_noisy", ""
+            )
+
+    return sources
+
+
+def preprocess_llada(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ max_len=2048,
+ system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+) -> Dict:
+ # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
+ roles = {"human": "user", "gpt": "assistant", "system": "system"}
+
+ # Add image tokens to tokenizer as a special tokens
+ # Use a deepcopy of tokenizer so that we don't modify on the tokenizer
+ # tokenizer = copy.deepcopy(tokenizer)
+ # # When there is actually an image, we add the image tokens as a special token
+ # if has_image:
+ # tokenizer.add_tokens([""], special_tokens=True)
+ image_token_index = tokenizer.convert_tokens_to_ids("")
+ bos_token_id = tokenizer.convert_tokens_to_ids("<|startoftext|>")
+ start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
+ end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
+ eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
+
+ unmask_tokens = [
+ "<|startoftext|>",
+ "<|start_header_id|>",
+ "<|end_header_id|>",
+ "<|eot_id|>",
+ "\n\n",
+ ]
+ unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
+ # Reset LLaDA chat templates so that it won't include assistant message every time we apply
+ chat_template = "{% for message in messages %}{{'<|startoftext|>' + '<|start_header_id|>' + message['role'] + '<|end_header_id|>' + '\n\n' + message['content'] + '<|eot_id|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
+ tokenizer.chat_template = chat_template
+
+ # After update, calling tokenizer of llama3 will
+ # auto add bos id for the tokens. ヽ(`⌒´)ノ
+ def safe_tokenizer_llama3(text):
+ input_ids = tokenizer(text).input_ids
+ if input_ids[0] == bos_token_id:
+ input_ids = input_ids[1:]
+ return input_ids
+
+ nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
+ # Apply prompt templates
+ input_ids, targets = [], []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] == roles["system"]:
+ try:
+ system_message = source[0]["content"]
+ except:
+ system_message = source[0]["value"]
+ source = source[1:]
+
+ input_id, target = [], []
+
+ # New version, use apply chat template
+ # Build system message for each sentence
+ input_id += tokenizer.apply_chat_template(
+ [{"role": "system", "content": system_message}]
+ )
+ target += [IGNORE_INDEX] * len(input_id)
+
+ for conv in source:
+ # Make sure llava data can load
+ try:
+ role = conv["role"]
+ content = conv["content"]
+ except:
+ role = conv["from"]
+ content = conv["value"]
+
+ role = roles.get(role, role)
+
+ if "video" in content:
+ content = content.replace("