diff --git a/.github/workflows/static-check.yml b/.github/workflows/static-check.yml index e1530b5..5bf6c12 100644 --- a/.github/workflows/static-check.yml +++ b/.github/workflows/static-check.yml @@ -15,42 +15,10 @@ jobs: uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - - name: Check requirements.txt exists - id: check_req - run: | - if [ -f requirements.txt ]; then - echo "requirements_exists=true" >> $GITHUB_OUTPUT - else - echo "requirements_exists=false" >> $GITHUB_OUTPUT - fi - if [ -f pyproject.toml ]; then - echo "pyproject_exists=true" >> $GITHUB_OUTPUT - else - echo "pyproject_exists=false" >> $GITHUB_OUTPUT - fi - - name: Install dependencies by requirements.txt + - name: Install pylint id: install_deps_req - if: ${{ steps.check_req.outputs.requirements_exists == 'true' }} run: | python -m pip install --upgrade pylint - python -m pip install --upgrade isort - python -m pip install -r requirements.txt - echo "dependencies_installed=true" >> $GITHUB_OUTPUT - - name: Analysing the code with pylint - if: ${{ steps.check_req.outputs.requirements_exists == 'true' }} - run: | - isort $(git ls-files '*.py') --check-only --diff - pylint $(git ls-files '*.py') - - name: Install dependencies by uv - id: install_deps_uv - if: ${{ steps.check_req.outputs.pyproject_exists == 'true' }} - run: | - python -m pip install uv - uv sync - uv pip install pylint - uv pip install isort - name: Analysing the code with pylint - if: ${{ steps.check_req.outputs.pyproject_exists == 'true' }} run: | - uv run isort $(git ls-files '*.py') --check-only --diff - uv run pylint $(git ls-files '*.py') + grep -r -l "# LINT_ME" **/*.py|xargs pylint \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index 7e930ac..01d1801 100644 --- a/.pylintrc +++ b/.pylintrc @@ -31,7 +31,7 @@ extension-pkg-allow-list= # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. 
(This is an alternative name to extension-pkg-allow-list # for backward compatibility.) -extension-pkg-whitelist=cv2 +extension-pkg-whitelist= # Return non-zero exit code if any of these messages/categories are detected, # even if score is above --fail-under value. Syntax same as enable. Messages @@ -59,15 +59,16 @@ ignore-paths= # Emacs file locks ignore-patterns=^\.# -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules=cv2 +# List of module names for which member attributes should not be checked and +# will not be imported (useful for modules/projects where namespaces are +# manipulated during runtime and thus existing member attributes cannot be +# deduced by static analysis). It supports qualified module names, as well as +# Unix pattern matching. +ignored-modules= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). -init-hook='import sys; sys.path.append(".")' +#init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use, and will cap the count on Windows to @@ -86,9 +87,13 @@ load-plugins= # Pickle collected data for later comparisons. persistent=yes +# Resolve imports to .pyi stubs if available. May reduce no-member messages and +# increase not-an-iterable messages. +prefer-stubs=no + # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. -py-version=3.10 +py-version=3.12 # Discover python modules and packages in the file system subtree. recursive=no @@ -99,10 +104,6 @@ recursive=no # source root. 
source-roots= -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no @@ -229,6 +230,11 @@ name-group= # not require a docstring. no-docstring-rgx=^_ +# Regular expression matching correct parameter specification variable names. +# If left empty, parameter specification variable names will be checked with +# the set naming style. +#paramspec-rgx= + # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. # These decorators are taken in consideration only for invalid-name. @@ -242,13 +248,17 @@ property-classes=abc.abstractproperty # variable names will be checked with the set naming style. #typevar-rgx= +# Regular expression matching correct type variable tuple names. If left empty, +# type variable tuple names will be checked with the set naming style. +#typevartuple-rgx= + # Naming style matching correct variable names. variable-naming-style=snake_case # Regular expression matching correct variable names. Overrides variable- # naming-style. If left empty, variable names will be checked with the set # naming style. -variable-rgx=(_?[a-z][A-Za-z0-9]{0,30})|([A-Z0-9]{1,30}) +#variable-rgx= [CLASSES] @@ -285,10 +295,10 @@ exclude-too-few-public-methods= ignored-parents= # Maximum number of arguments for function / method. -max-args=7 +max-args=5 # Maximum number of attributes for a class (see R0902). -max-attributes=20 +max-attributes=7 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 @@ -302,6 +312,9 @@ max-locals=15 # Maximum number of parents for a class (see R0901). max-parents=7 +# Maximum number of positional arguments for function / method. 
+max-positional-arguments=5 + # Maximum number of public methods for a class (see R0904). max-public-methods=20 @@ -309,10 +322,10 @@ max-public-methods=20 max-returns=6 # Maximum number of statements in function / method body. -max-statements=300 +max-statements=50 # Minimum number of public methods for a class (see R0903). -min-public-methods=1 +min-public-methods=2 [EXCEPTIONS] @@ -336,11 +349,13 @@ indent-after-paren=4 # tab). indent-string=' ' -# Maximum number of characters on a single line. -max-line-length=150 +# Maximum number of characters on a single line. Pylint's default of 100 is +# based on PEP 8's guidance that teams may choose line lengths up to 99 +# characters. +max-line-length=100 # Maximum number of lines in a module. -max-module-lines=2000 +max-module-lines=1000 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. @@ -421,11 +436,16 @@ confidence=HIGH, # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". -disable=too-many-arguments, - too-many-locals, - too-many-branches, - protected-access - +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero, # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option @@ -433,6 +453,13 @@ disable=too-many-arguments, # it should appear only once). See also the "--disable" option for examples. enable= +[MASTER] +# A comma-separated list of package or module names +# from where C extensions may#be loaded. Extensions +# are loading into the active Python interpreter and +# may run arbitrary code . 
+extension-pkg-allow-list=torch,transformers + [METHOD_ARGS] @@ -443,9 +470,13 @@ timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests. [MISCELLANEOUS] +# Whether or not to search for fixme's in docstrings. +check-fixme-in-docstring=no + # List of note tags to take in consideration, separated by a comma. notes=FIXME, - XXX + XXX, + TODO # Regular expression of note tags to take in consideration. notes-rgx= @@ -465,7 +496,7 @@ never-returning-functions=sys.exit,argparse.parse_error # Let 'consider-using-join' be raised when the separator to join on would be # non-empty (resulting in expected fixes of the type: ``"- " + " - # ".join(items)``) -# suggest-join-with-non-empty-separator=yes +suggest-join-with-non-empty-separator=yes [REPORTS] @@ -481,10 +512,10 @@ evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor # used to format the message information. See doc for all details. msg-template= -# Set the output format. Available formats are: text, parseable, colorized, -# json2 (improved json format), json (old json format) and msvs (visual -# studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. +# Set the output format. Available formats are: 'text', 'parseable', +# 'colorized', 'json2' (improved json format), 'json' (old json format), msvs +# (visual studio) and 'github' (GitHub actions). You can also give a reporter +# class, e.g. mypackage.mymodule.MyReporterClass. #output-format= # Tells whether to display a full report or only the messages. @@ -582,11 +613,14 @@ ignored-checks-for-mixins=no-member, # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +ignored-modules=torch,torch.distributed,transformers,transformers.hf_argparser + # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. 
missing-member-hint=yes -# The minimum edit distance a name should have in order to be considered a +# The maximum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 @@ -630,4 +664,4 @@ init-import=no # List of qualified module names which can have objects that can redefine # builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io \ No newline at end of file +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..8cc1b46 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.10.15 diff --git a/README.md b/README.md index 0fcb5dd..4926ca2 100644 --- a/README.md +++ b/README.md @@ -1 +1,127 @@ -# Coming Soon... +

WAM-Diff: A Masked Diffusion VLA Framework with MoE and Online Reinforcement Learning for Autonomous Driving

+
+ Mingwang Xu1*  + Jiahao Cui1*  + Feipeng Cai2*  + Hanlin Shang1*  + Zhihao Zhu1  + Shan Luan1  +
+
+ Yifang Xu1  + Neng Zhang2  + Yaoyi Li2  + Jia Cai2  + Siyu Zhu1  +
+ +
+ 1Fudan University  2Yinwang Intelligent Technology Co., Ltd  +
+ + +## 📅️ Roadmap + +| Status | Milestone | ETA | +| :----: | :----------------------------------------------------------------------------------------------------: | :--------: | +| 🚀 | **[Releasing the inference source code](https://github.com/fudan-generative-vision/WAM-Diff)** | 2025.12.21 | +| 🚀 | **[Releasing the training scripts](https://github.com/fudan-generative-vision/WAM-Diff)** | 2025.12.21 | +| 🚀 | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/WAM-Diff)** | TBD | + + +## 🔧️ Framework +![framework](assets/main_arch.png) + +## 🏆 Qualitative Results on NAVSIM +### NAVSIM-v1 benchmark results +
+ navsim-v1 +
+ +### NAVSIM-v2 benchmark results +
+navsim-v2 +
+ + + +## Quick Inference Demo +The WAM-Diff will be available on Hugging Face Hub soon. To quickly test the model, follow these simple steps: + +1. **Clone the repository** + ```bash + git clone https://github.com/fudan-generative-vision/WAM-Diff + cd WAM-Diff + ``` +2. **Initialize the environment** + If you prefer conda, run the environment setup script to install necessary dependencies: + ```bash + bash init_env.sh + ``` + Or you can use uv to create the environment: + ```bash + uv venv && uv sync + ``` +3. **Prepare the Model** + Download the pretrained WAM-Diff model from Hugging Face (pending release) to the `./model/WAM-Diff` directory: + ``` + https://huggingface.co/fudan-generative-ai/WAM-Diff + ``` + Download the pretrained Siglip2 model from Hugging Face to the `./model/siglip2-so400m-patch14-384` directory: + ``` + https://huggingface.co/google/siglip2-so400m-patch14-384 + ``` + + +3. **Run the demo script** + Execute the demo script to test WAM-Diff on an example image: + ```bash + bash inf.sh + ``` + +## Training +To fine-tune WAM-Diff, please follow these steps: +1. **Set Up the Environment** + Follow the same environment setup steps as in the Quick Inference Demo section. +2. **Prepare the Data** +Prepare your training dataset in JSON format like + ```json + [ + { + "image": ["path/to/image1.png"], + "conversations": [ + { + "from": "human", + "value": "Here is front views of a driving vehicle:\n\nThe navigation information is: straight\nThe current position is (0.00,0.00)\nCurrent velocity is: (13.48,-0.29) and current accelerate is: (0.19,0.05)\nPredict the optimal driving action for the next 4 seconds with 8 new waypoints." + }, + { + "from": "gpt", + "value": "6.60,-0.01,13.12,-0.03,19.58,-0.04,25.95,-0.03,32.27,-0.03,38.56,-0.05,44.88,-0.06,51.16,-0.09" + } + ] + }, + ... + ] + ``` +3. 
**Run the Training Script** + Execute the training script with the following command: + ```bash + cd train + bash ./scripts/llada_v_finetune.sh + ``` + +## 📝 Citation + +If you find our work useful for your research, please consider citing the paper: + +``` +@article{xu2025wam, + title={WAM-Diff: A Masked Diffusion VLA Framework with MoE and Online Reinforcement Learning for Autonomous Driving}, + author={Xu, Mingwang and Cui, Jiahao and Cai, Feipeng and Shang, Hanlin and Zhu, Zhihao and Luan, Shan and Xu, Yifang and Zhang, Neng and Li, Yaoyi and Cai, Jia and others}, + journal={arXiv preprint arXiv:2512.11872}, + year={2025} +} +``` + +## 🤗 Acknowledgements +We gratefully acknowledge the contributors to the [LLaDA-V](https://github.com/ML-GSAI/LLaDA-V), repositories, whose commitment to open source has provided us with their excellent codebases and pretrained models. \ No newline at end of file diff --git a/assets/image.png b/assets/image.png new file mode 100644 index 0000000..3d1b68c Binary files /dev/null and b/assets/image.png differ diff --git a/assets/main_arch.png b/assets/main_arch.png new file mode 100644 index 0000000..2b3f8e4 Binary files /dev/null and b/assets/main_arch.png differ diff --git a/assets/navsim-v1.png b/assets/navsim-v1.png new file mode 100644 index 0000000..8ca33c4 Binary files /dev/null and b/assets/navsim-v1.png differ diff --git a/assets/navsim-v2.png b/assets/navsim-v2.png new file mode 100644 index 0000000..26f5b62 Binary files /dev/null and b/assets/navsim-v2.png differ diff --git a/envs.yml b/envs.yml new file mode 100644 index 0000000..d3215ed --- /dev/null +++ b/envs.yml @@ -0,0 +1,7 @@ +name: WAM-Diff +channels: + - conda-forge + - bioconda + - defaults +dependencies: + - python=3.10 diff --git a/inf.sh b/inf.sh new file mode 100644 index 0000000..2afc538 --- /dev/null +++ b/inf.sh @@ -0,0 +1 @@ +torchrun --master_addr=127.0.0.1 --master_port=12346 --nproc-per-node=1 ./train/generate_demo_navsim_moe.py --pretrained_path 
path_to_WAM-Diff_ckpt \ No newline at end of file diff --git a/init_env.sh b/init_env.sh new file mode 100644 index 0000000..2bd26e4 --- /dev/null +++ b/init_env.sh @@ -0,0 +1,48 @@ +conda env create -f envs.yml +conda activate WAM-Diff + +#!/bin/bash + +# log file for failed installations +FAIL_LOG="failed_packages.log" + +# rm existing log file +> $FAIL_LOG + +while IFS= read -r line; do + # skip empty lines and comments + if [[ -z "$line" || "$line" =~ ^# ]]; then + continue + fi + + echo "=====================================" + echo "Installing: $line" + echo "=====================================" + + # Execute installation command + pip install $line $PIP_MIRROR + + # Check if installation was successful + if [ $? -ne 0 ]; then + echo "Installation failed: $line" >> $FAIL_LOG + echo "=====================================" + echo "⚠️ $line installation failed, logged to $FAIL_LOG" + echo "=====================================" + else + echo "=====================================" + echo "✅ $line installed successfully" + echo "=====================================" + fi + sleep 1 + +done < "requirements.txt" + +echo "=====================================" +echo "Installation complete!" +if [ -s $FAIL_LOG ]; then + echo "❌ The following packages failed to install:" + cat $FAIL_LOG +else + echo "✅ All packages installed successfully!" 
+fi +echo "=====================================" \ No newline at end of file diff --git a/model/.gitkeep b/model/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8715476 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,65 @@ +[project] +name = "wam-diff" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = "==3.10.15" +dependencies = [ + "accelerate==0.29.1", + "aioboto3>=15.5.0", + "aiofiles>=25.1.0", + "bokeh==2.4.3", + "casadi>=3.7.2", + "control==0.9.1", + "datasets==2.16.1", + "deepspeed==0.14.4", + "einops>=0.8.1", + "fiona>=1.10.1", + "geopandas>=0.12.1", + "guppy3==3.1.2", + "hydra-core==1.2.0", + "imageio>=2.37.2", + "joblib>=1.5.2", + "matplotlib>=3.10.8", + "nest-asyncio>=1.6.0", + "notebook>=7.5.0", + "numpy==1.23.4", + "nuplan-devkit", + "opencv-python==4.9.0.80", + "pandas>=2.3.3", + "pillow>=12.0.0", + "positional-encodings==6.0.1", + "protobuf==3.20.0", + "psutil>=7.1.3", + "pyarrow>=22.0.0", + "pyinstrument>=5.1.1", + "pylint>=4.0.4", + "pynvml==12.0.0", + "pyogrio>=0.12.1", + "pyquaternion>=0.9.5", + "pytest>=9.0.2", + "pytorch-lightning==2.2.1", + "rasterio>=1.3.11", + "retry>=0.9.2", + "rtree>=1.4.1", + "scikit-learn==1.2.2", + "scipy>=1.13.1", + "selenium>=4.39.0", + "setuptools==65.5.1", + "shapely>=2.0.0", + "sqlalchemy==1.4.27", + "sympy>=1.14.0", + "tensorboard==2.16.2", + "timm>=1.0.22", + "torch==2.8.0", + "torchvision==0.23.0", + "tornado>=6.5.3", + "tqdm>=4.67.1", + "transformers==4.43.0", + "tyro==0.9.28", + "ujson>=5.11.0", +] +[[tool.uv.index]] +url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0404460 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,203 @@ +accelerate==0.29.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aiosignal==1.4.0 +anls==0.0.2 +annotated-types==0.7.0 +anyio==4.10.0 
+async-timeout==5.0.1 +attrs==25.3.0 +av==15.0.0 +bitsandbytes==0.45.3 +black==24.1.0 +blis==1.3.0 +braceexpand==0.1.7 +capture-metric==0.1.13 +catalogue==2.0.10 +certifi==2025.8.3 +cfgv==3.4.0 +chardet==5.2.0 +charset-normalizer==3.4.3 +click==8.2.2 +cloudpathlib==0.21.1 +colorama==0.4.6 +confection==0.1.5 +cymem==2.0.11 +dataproperty==1.1.0 +datasets==2.16.1 +debugpy==1.8.17 +decord==0.6.0 +deepspeed==0.14.4 +dill==0.3.7 +discosg==0.0.5 +distlib==0.4.0 +distro==1.9.0 +docstring-parser==0.17.0 +einops==0.6.1 +einops-exts==0.0.4 +et-xmlfile==2.0.0 +evaluate==0.4.5 +exceptiongroup==1.3.0 +factualscenegraph==0.7.3 +fastapi==0.116.1 +filelock==3.19.1 +frozenlist==1.7.0 +fsspec==2023.10.0 +ftfy==6.3.1 +gitdb==4.0.12 +gitpython==3.1.45 +gradio-client==0.2.9 +h11==0.14.0 +hf-transfer==0.1.9 +hf-xet==1.1.7 +hjson==3.1.0 +httpcore==0.16.3 +httpx==0.24.0 +huggingface-hub==0.34.4 +identify==2.6.13 +idna==3.10 +imageio==2.37.0 +imageio-ffmpeg==0.6.0 +isort==5.13.2 +jinja2==3.1.5 +jiter==0.10.0 +joblib==1.5.1 +jsonlines==4.0.0 +langcodes==3.5.0 +language-data==1.3.0 +lazy-import==0.2.2 +levenshtein==0.27.1 +loguru==0.7.3 +lxml==6.0.0 +marisa-trie==1.3.0 +markdown-it-py==4.0.0 +markupsafe==3.0.2 +mbstrdecoder==1.1.4 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.6.4 +multiprocess==0.70.15 +murmurhash==1.0.13 +mypy-extensions==1.1.0 +networkx==3.4.2 +ninja==1.13.0 +nltk==3.9.1 +nodeenv==1.9.1 +numexpr==2.11.0 +nvidia-cublas-cu12 +nvidia-cuda-cupti-cu12 +nvidia-cuda-nvrtc-cu12 +nvidia-cuda-runtime-cu12 +nvidia-cudnn-cu12 +nvidia-cufft-cu12 +nvidia-cufile-cu12 +nvidia-curand-cu12 +nvidia-cusolver-cu12 +nvidia-cusparse-cu12 +nvidia-cusparselt-cu12 +nvidia-ml-py==12.575.51 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +open-clip-torch==3.1.0 +openai==1.100.0 +openpyxl==3.1.5 +packaging==25.0 +pandas==2.3.1 +pathlib2==2.3.7.post1 +pathspec==0.12.1 +pathvalidate==3.3.1 +peft==0.9.0 +pillow==11.3.0 +platformdirs==4.3.8 +portalocker==3.2.0 
+pre-commit==4.3.0 +preshed==3.0.10 +propcache==0.3.2 +protobuf==3.20.0 +psutil==7.0.0 +py-cpuinfo==9.0.0 +pyarrow==21.0.0 +pyarrow-hotfix==0.7 +pybind11==3.0.0 +pycocoevalcap==1.2 +pycocotools==2.0.10 +pycryptodome==3.23.0 +pydantic==2.9.2 +pydantic-core==2.33.2 +pygments==2.19.2 +pynvml==12.0.0 +pytablewriter==1.2.1 +python-dateutil==2.9.0.post0 +pytz==2025.2 +pywsd==1.2.5 +pyyaml==6.0.2 +rapidfuzz==3.13.0 +regex==2025.7.34 +requests==2.32.4 +rfc3986==1.5.0 +rich==14.1.0 +rouge==1.0.1 +sacrebleu==2.5.1 +safetensors==0.6.2 +scikit-learn==1.7.1 +scipy==1.15.3 +sentence-transformers==5.1.0 +sentencepiece==0.1.99 +sentry-sdk==2.35.0 +shellingham==1.5.4 +shortuuid==1.0.13 +shtab==1.7.2 +six==1.17.0 +smart-open==7.3.0.post1 +smmap==5.0.2 +sniffio==1.3.1 +spacy==3.8.7 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +sqlitedict==2.1.0 +srsly==2.5.1 +starlette==0.47.2 +sympy==1.13.1 +tabledata==1.3.4 +tabulate==0.9.0 +tcolorpy==0.1.7 +tenacity==8.3.0 +tensorboardx==2.6.4 +thinc==8.3.6 +threadpoolctl==3.6.0 +tiktoken==0.11.0 +timm==1.0.19 +tokenizers==0.19.1 +tomli==2.2.1 +torch==2.6.0 +torchvision==0.21.0 +tqdm==4.67.1 +transformers==4.43.0 +transformers-stream-generator==0.0.5 +triton==3.2.0 +typeguard==4.4.4 +typepy==1.3.4 +typer==0.16.1 +typing-extensions==4.14.1 +typing-inspection==0.4.1 +tyro==0.9.28 +tzdata==2025.2 +urllib3==2.0.0 +uvicorn==0.35.0 +virtualenv==20.34.0 +wandb==0.21.1 +wasabi==1.1.3 +wcwidth==0.2.13 +weasel==0.4.1 +webdataset==1.0.2 +websockets==15.0.1 +wn==0.0.23 +wrapt==1.17.3 +xxhash==3.5.0 +yarl==1.20.1 +yt-dlp==2025.8.11 +zss==1.2.0 +zstandard==0.24.0 +opencv-python==4.12.0.88 +opencv-python-headless==4.12.0.88 \ No newline at end of file diff --git a/train/LICENSE b/train/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/train/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/train/generate_demo_navsim_moe.py b/train/generate_demo_navsim_moe.py new file mode 100644 index 0000000..8f75dbf --- /dev/null +++ b/train/generate_demo_navsim_moe.py @@ -0,0 +1,352 @@ +# LINT_ME +"""Generate demo navigation predictions using NavSim-MOE model in MoE setting.""" + +from __future__ import annotations + +import copy +import json +import os +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import List, Tuple + +# Remove these if not needed in your environment. 
+import llava.s3_ops as st +import torch +import torch.distributed as dist + +# llava / model imports +from llava.cache import dLLMCache, dLLMCacheConfig +from llava.conversation import conv_templates +from llava.hooks import register_cache_LLaDA_V +from llava.hooks.fast_dllm_hook import register_fast_dllm_hook +from llava.mm_utils import IMAGE_TOKEN_INDEX, process_images, tokenizer_image_token +from llava.model.builder import load_pretrained_model +from transformers.hf_argparser import HfArgumentParser + + +# ========================= +# Dataclass Args (HF-style) +# ========================= +@dataclass +class DataArgs: + """ + Data-related arguments. + """ + navsim_root: str = field(default="None") + steps_fut: int = 8 + steps_hist: int = 4 + camera_order: List[str] = field( + default_factory=lambda: [ + "CAM_FRONT", + "CAM_FRONT_LEFT", + "CAM_FRONT_RIGHT", + "CAM_BACK", + "CAM_BACK_LEFT", + "CAM_BACK_RIGHT", + ] + ) + + +@dataclass +class ModelArgs: # pylint: disable=R0902 + """Model Configuration Arguments.""" + pretrained_path: str + conv_template: str = "llava_llada" + use_fast_dllm: bool = True + use_dllm_cache: bool = False + prompt_interval_steps: int = 25 + gen_interval_steps: int = 7 + transfer_ratio: float = 0.25 + token_gen_steps: int = 32 + token_gen_length: int = 32 + token_block_length: int = 32 + + +@dataclass +class RuntimeArgs: + """Runtime / Distributed Training Arguments.""" + save_jsons_folder: str = "output/jsons" + device_type: str = "cuda" # cuda|npu|cpu + backend: str = "nccl" # nccl|gloo + merge_after: bool = False + debug_enable: bool = False + + +# ========================= +# DDP helpers +# ========================= +def setup_distributed(device_type: str = "npu", backend: str = "hccl"): + """ + Initialize torch.distributed, pin local rank to the right device, + and return (rank, world_size, local_rank, device_str). 
+ """ + # Figure out ranks + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + + # Pin device + device_type = device_type.lower() + if device_type == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available.") + torch.cuda.set_device(local_rank) + device_str = f"cuda:{local_rank}" + elif device_type == "npu": + torch.npu.set_device(local_rank) + device_str = f"npu:{local_rank}" + else: + device_str = "cpu" + + # Init process group + if dist.is_available() and not dist.is_initialized(): + dist.init_process_group(backend=backend) + + return rank, world_size, local_rank, device_str + + +def get_rank_world() -> Tuple[int, int]: + """Get the current process rank and world size for distributed training.""" + if dist.is_available() and dist.is_initialized(): + return dist.get_rank(), dist.get_world_size() + return 0, 1 + + +def barrier(): + """DDP barrier.""" + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + +def is_rank0() -> bool: + """Check if current process is rank 0.""" + r, _ = get_rank_world() + return r == 0 + + +def get_local_rank() -> int: + """Get local rank from env vars.""" + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + if "RANK" in os.environ: + num = torch.cuda.device_count() if torch.cuda.is_available() else 1 + return int(os.environ["RANK"]) % max(1, num) + return 0 + + +def pick_device(device_type: str, local_rank: int) -> str: + """Pick device string based on type and local rank.""" + device_type = device_type.lower() + if device_type == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available.") + return f"cuda:{local_rank}" + if device_type == "npu": + # Adjust if using a different NPU stack + return f"npu:{local_rank}" + if device_type == "cpu": + return "cpu" + raise ValueError( + f"Unsupported --device_type '{device_type}'. 
Use one of: cuda|npu|cpu." + ) + + +# ========================= +# Arg parsing (HF) +# ========================= +def parse_args() -> tuple[DataArgs, ModelArgs, RuntimeArgs]: + """ + Supports: + - CLI flags + - Config file (JSON/TOML/YAML): eval_nuscenes_ddp.py cfg.json + - Mixed (file + overrides) + """ + parser = HfArgumentParser((DataArgs, ModelArgs, RuntimeArgs)) + data_args, model_args, runtime_args = parser.parse_args_into_dataclasses( + return_remaining_strings=False + ) + runtime_args.device_type = runtime_args.device_type.lower() + runtime_args.backend = runtime_args.backend.lower() + return data_args, model_args, runtime_args + + +def _int_stem(p: Path) -> int: + """Get integer from path stem.""" + try: + return int(p.stem) + except ValueError: # narrow exception (fixes W0718) + return 1_000_000_000 + + +def merge_rank_jsons(save_jsons_folder: str, world_size: int) -> None: + """Merge per-rank JSON files into a single JSONL file.""" + save_dir = Path(save_jsons_folder) + merged_path = save_dir / "merged.jsonl" + + numbered: list[tuple[int, Path]] = [] + for r in range(world_size): + rank_dir = save_dir / f"rank_{r}" + if not rank_dir.exists(): + continue + + for p in rank_dir.glob("*.json"): + idx = _int_stem(p) + if idx is None: + continue + numbered.append((idx, p)) + + numbered.sort(key=lambda t: t[0]) + + count = 0 + with st.open_file(str(merged_path), "w") as out_f: + for _, p in numbered: + with st.open_file(str(p), "r") as in_f: + out_f.write(in_f.read().rstrip("\n") + "\n") + count += 1 + + print(f"[Rank 0] merged {count} records -> {merged_path}") + +# ========================= +# Main +# ========================= +def main(): # pylint: disable=R0912, R0914, R0915 + """Main function.""" + data_args, model_args, runtime_args = parse_args() + # DDP init + rank, world_size, _, device_str = setup_distributed( + device_type=runtime_args.device_type, + backend=runtime_args.backend, + ) + + # Device & IO + device_map = device_str + + 
st.makedirs(runtime_args.save_jsons_folder) + rank_out_dir = Path(runtime_args.save_jsons_folder) / f"rank_{rank}" + st.makedirs(str(rank_out_dir)) + + if is_rank0(): + print( + f"[DDP] world_size={world_size} | saving to: {runtime_args.save_jsons_folder}" + ) + # Save resolved config + with st.open_file( + str(Path(runtime_args.save_jsons_folder) / "resolved_args.json"), "w" + ) as f: + json.dump( + { + "DataArgs": vars(data_args), + "ModelArgs": vars(model_args), + "RuntimeArgs": vars(runtime_args), + }, + f, + indent=2, + ) + + # Load model (per rank) + model_name = "llava_llada" + tokenizer, model, image_processor, _ = load_pretrained_model( + model_args.pretrained_path, + None, + model_name, + attn_implementation="sdpa", + device_map=device_map, + ) + model = model.to(device_str) + model.config.image_aspect_ratio = "anyres_max_2" + model.eval() + + # Optional speed-ups + if model_args.use_fast_dllm: + register_fast_dllm_hook(model) + if is_rank0(): + print("[Fast-dLLM] enabled") + elif model_args.use_dllm_cache: + dLLMCache.new_instance( + **asdict( + dLLMCacheConfig( + prompt_interval_steps=model_args.prompt_interval_steps, + gen_interval_steps=model_args.gen_interval_steps, + transfer_ratio=model_args.transfer_ratio, + ) + ) + ) + register_cache_LLaDA_V(model, "model.layers") + if is_rank0(): + print("[dLLM-Cache] enabled") + else: + if is_rank0(): + print("[Cache] disabled") + + torch.set_grad_enabled(False) + + surrounding_views = ["./assets/image.png"] + + images = [st.image_open(p) for p in surrounding_views] + + if len(images) > 1: + model.config.image_aspect_ratio = "pad" + + image_tensor = process_images(images, image_processor, model.config) + image_tensor = [ + _img.to(dtype=torch.float16, device=device_str) for _img in image_tensor + ] + image_sizes = [img.size for img in images] + + # Prompt + conv = copy.deepcopy(conv_templates[model_args.conv_template]) + question = """Here is front views of a driving vehicle: + \nThe navigation information is: 
straight + The current position is (0.00,0.00) + Current velocity is: (8.92,-0.15) and current acceleration is: (-1.25,-0.01) + Predict the optimal driving action for the next 4 seconds with 8 new waypoints. + """ + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + + input_ids = ( + tokenizer_image_token( + prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" + ) + .unsqueeze(0) + .to(device_str) + ) + + # Generate + cont, track_x = model.generate( + input_ids, + images=image_tensor, + image_sizes=image_sizes, + steps=model_args.token_gen_steps, + gen_length=model_args.token_gen_length, + block_length=model_args.token_block_length, + tokenizer=tokenizer, + stopping_criteria=["<|eot_id|>"], + prefix_refresh_interval=32, + threshold=1, + ) + text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=False) + print(text_outputs) + + with open("output.txt", "w", encoding="utf-8") as f: + for x_step in track_x: + text_outputs = tokenizer.batch_decode(x_step, skip_special_tokens=False) + + # Write each decoded string on its own line + for text in text_outputs: + f.write(text + "\n") + + # Add separator after each step + f.write("--------\n") + + barrier() + + # Merge (rank 0) + if is_rank0() and runtime_args.merge_after: + merge_rank_jsons(runtime_args.save_jsons_folder, world_size) + + +if __name__ == "__main__": + main() diff --git a/train/init_env.sh b/train/init_env.sh new file mode 100644 index 0000000..0df7957 --- /dev/null +++ b/train/init_env.sh @@ -0,0 +1,2 @@ +pip install -e ".[train]" +pip install webdataset diff --git a/train/llava/__init__.py b/train/llava/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/train/llava/cache/Cache.py b/train/llava/cache/Cache.py new file mode 100644 index 0000000..282bd10 --- /dev/null +++ b/train/llava/cache/Cache.py @@ -0,0 +1,87 @@ +# pylint: disable=C0114,C0115,C0116,C0103 + +from collections import 
defaultdict + +import torch + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class dLLMCache(metaclass=Singleton): + gen_interval_steps: int + prompt_interval_steps: int + cfg_interval_steps: int + prompt_length: int + transfer_ratio: float + __cache: defaultdict + __step_counter: defaultdict + cache_type: str + + @classmethod + def new_instance( + cls, + prompt_interval_steps: int = 1, + gen_interval_steps: int = 1, + cfg_interval_steps: int = 1, + transfer_ratio: float = 0.0, + ) -> "dLLMCache": + ins = cls() + setattr(ins, "prompt_interval_steps", prompt_interval_steps) + setattr(ins, "gen_interval_steps", gen_interval_steps) + setattr(ins, "cfg_interval_steps", cfg_interval_steps) + setattr(ins, "transfer_ratio", transfer_ratio) + ins.init() + return ins + + def init(self) -> None: + self.__cache = defaultdict( + lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + ) + self.__step_counter = defaultdict(lambda: defaultdict(lambda: 0)) + + def reset_cache(self, prompt_length: int = 0) -> None: + self.init() + torch.cuda.empty_cache() + self.prompt_length = prompt_length + self.cache_type = "no_cfg" + + def set_cache( + self, layer_id: int, feature_name: str, features: torch.Tensor, cache_type: str + ) -> None: + self.__cache[self.cache_type][cache_type][layer_id][feature_name] = { + 0: features + } + + def get_cache( + self, layer_id: int, feature_name: str, cache_type: str + ) -> torch.Tensor: + output = self.__cache[self.cache_type][cache_type][layer_id][feature_name][0] + return output + + def update_step(self, layer_id: int) -> None: + self.__step_counter[self.cache_type][layer_id] += 1 + + def refresh_gen(self) -> bool: + return (self.current_step - 1) % self.gen_interval_steps == 0 + + def refresh_prompt(self) -> bool: + return (self.current_step - 1) % 
self.prompt_interval_steps == 0 + + def refresh_cfg(self) -> bool: + return ( + self.current_step - 1 + ) % self.cfg_interval_steps == 0 or self.current_step <= 5 + + @property + def current_step(self) -> int: + return max(list(self.__step_counter[self.cache_type].values()), default=1) + + def __repr__(self): + return "USE dLLMCache" diff --git a/train/llava/cache/Config.py b/train/llava/cache/Config.py new file mode 100644 index 0000000..d48e835 --- /dev/null +++ b/train/llava/cache/Config.py @@ -0,0 +1,11 @@ +# pylint: disable=C0114,C0115,C0116,C0103 + +from dataclasses import dataclass + + +@dataclass +class dLLMCacheConfig: + prompt_interval_steps: int = 1 + gen_interval_steps: int = 1 + transfer_ratio: float = 0.0 + cfg_interval_steps: int = 1 diff --git a/train/llava/cache/__init__.py b/train/llava/cache/__init__.py new file mode 100644 index 0000000..a8ef4bf --- /dev/null +++ b/train/llava/cache/__init__.py @@ -0,0 +1,6 @@ +# pylint: disable=C0114,C0115,C0116 + +from .Cache import dLLMCache +from .Config import dLLMCacheConfig + +__all__ = ["dLLMCache", "dLLMCacheConfig"] diff --git a/train/llava/constants.py b/train/llava/constants.py new file mode 100644 index 0000000..be8cf02 --- /dev/null +++ b/train/llava/constants.py @@ -0,0 +1,12 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." 
+ +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" diff --git a/train/llava/conversation.py b/train/llava/conversation.py new file mode 100644 index 0000000..365bb84 --- /dev/null +++ b/train/llava/conversation.py @@ -0,0 +1,396 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Any, Union +import re +import base64 +from io import BytesIO +from PIL import Image +from transformers import AutoTokenizer + + +class SeparatorStyle(Enum): + """Different separator style.""" + + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + CHATML = auto() + LLAMA_2 = auto() + LLAMA_3 = auto() + QWEN = auto() + GEMMA = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + tokenizer_id: str = "" + tokenizer: Any = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0] + if "mmtag" in self.version: + init_msg = init_msg.replace("", "").strip() + messages[0] = (init_role, init_msg) + messages.insert(0, (self.roles[0], "")) + messages.insert(1, (self.roles[1], "Received.")) + elif not init_msg.startswith(""): + init_msg = init_msg.replace("", "").strip() + messages[0] = (init_role, "\n" + init_msg) + else: + messages[0] = (init_role, init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = 
self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + message, images, _ = message + message = "" * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + + elif self.sep_style == SeparatorStyle.LLAMA_3: + if self.tokenizer is None: + if self.version == "llama_v3": + self.tokenizer = AutoTokenizer.from_pretrained( + "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True + ) + elif self.version == "llava_llada": + self.tokenizer = AutoTokenizer.from_pretrained( + "model/LLaDA-V", trust_remote_code=True + ) + else: + raise ValueError( + "The tokenizer is not available. Make sure you have the necessary permissions." 
+ ) + chat_template_messages = [{"role": "system", "content": self.system}] + for role, message in messages: + if message: + if type(message) is tuple: + message, images = message + message = "" * len(images) + message + chat_template_messages.append({"role": role, "content": message}) + + # print(chat_template_messages) + return self.tokenizer.apply_chat_template( + chat_template_messages, tokenize=False, add_generation_prompt=True + ) + + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + + elif self.sep_style == SeparatorStyle.GEMMA: + ret = "" + for i, (role, message) in enumerate(messages): + assert role == self.roles[i % 2], ( + "Conversation should alternate user/assistant/user/assistant/..." + ) + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = ( + lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + ) + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: + message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise 
ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def process_image( + self, image, image_process_mode, return_pil=False, image_format="PNG" + ): + if image_process_mode == "Pad": + + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((336, 336)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + + if type(image) is not Image.Image: + image = Image.open(image).convert("RGB") + + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 672, 448 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + return image + else: + buffered = BytesIO() + image.save(buffered, format=image_format) + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + return img_b64_str + + def get_images(self, return_pil=False, return_path=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + if type(image) != list: + image = [image] + for img in image: + if not return_path and self.is_image_file(img): + img = self.process_image( + img, 
image_process_mode, return_pil=return_pil + ) + else: + images.append(img) + return images + + def is_image_file(self, filename): + image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"] + return any(filename.lower().endswith(ext) for ext in image_extensions) + + def is_video_file(self, filename): + video_extensions = [ + ".mp4", + ".mov", + ".avi", + ".mkv", + ".wmv", + ".flv", + ".mpeg", + ".mpg", + ] + return any(filename.lower().endswith(ext) for ext in video_extensions) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + if type(image) != list: + image = [image] + if len(image) == 1: + msg = "\n" + msg.replace("", "").strip() + else: + msg = re.sub(r"()\n(?=)", r"\1 ", msg) + + img_str_list = [] + for img in image: + if self.is_image_file(img): + img_b64_str = self.process_image( + img, "Default", return_pil=False, image_format="JPEG" + ) + img_str = f'' + img_str_list.append(img_str) + elif self.is_video_file(img): + ret.append(((img,), None)) + + msg = msg.strip() + img_place_holder = "" + for img_str in img_str_list: + img_place_holder += f"{img_str}\n\n" + + if len(img_str_list) > 0: + msg = f"{img_place_holder}\n\n{msg}" + + if len(msg) > 0: + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version, + ) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [ + [x, y[0] if type(y) is tuple else y] for x, y in self.messages + ], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": 
self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +def safe_load_tokenizer(tokenizer_id): + try: + return AutoTokenizer.from_pretrained(tokenizer_id) + except Exception: + return None + + +conv_llava_plain = Conversation( + system="", + roles=("", ""), + messages=[], + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="\n", +) + +conv_llada_plain = Conversation( + system="", + roles=("", ""), + messages=[], + version="llada_plain", + offset=0, + sep_style=SeparatorStyle.LLAMA_3, + sep="\n", +) + +conv_llava_llada = Conversation( + system="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("user", "assistant"), + version="llava_llada", + messages=[], + offset=0, + sep="<|eot_id|>", + sep_style=SeparatorStyle.LLAMA_3, +) + + +default_conversation = conv_llava_llada +conv_templates = { + "default": conv_llava_llada, + "plain": conv_llava_plain, + "llada_plain": conv_llada_plain, + "llava_llada": conv_llava_llada, + "v0_plain": conv_llava_plain, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/train/llava/hooks/__init__.py b/train/llava/hooks/__init__.py new file mode 100644 index 0000000..01540b5 --- /dev/null +++ b/train/llava/hooks/__init__.py @@ -0,0 +1,5 @@ +# pylint: disable=C0114,C0115,C0116,C0103 + +from .cache_hook_LLaDA_V import register_cache_LLaDA_V + +__all__ = ["register_cache_LLaDA_V"] diff --git a/train/llava/hooks/cache_hook_LLaDA_V.py b/train/llava/hooks/cache_hook_LLaDA_V.py new file mode 100644 index 0000000..d993338 --- /dev/null +++ b/train/llava/hooks/cache_hook_LLaDA_V.py @@ -0,0 +1,867 @@ +# pylint: disable=C0114,C0115,C0116,C0103,R0913,R0917,R0914,R0912,R0915 + +import math +import types +from typing import Optional, Tuple + +import torch +from torch import nn + +from llava.cache 
import dLLMCache + + +# Helper functions from the new LLADa model (need to be accessible) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def register_cache_LLaDA_V(model: nn.Module, tf_block_module_key_name: str) -> None: + """ + Registers cache hooks for a LLaDA-like model. + tf_block_module_key_name is typically 'model.layers' for LLaMA-style models. + """ + target_module_path = tf_block_module_key_name.split(".") + current_module = model + for part in target_module_path: + current_module = getattr(current_module, part) + + target_module: Optional[nn.ModuleList] = current_module + if target_module is None or not isinstance(target_module, nn.ModuleList): + raise ValueError(f"Could not find nn.ModuleList at {tf_block_module_key_name}") + + for layer_index, tf_block in enumerate(target_module): + setattr(tf_block, "layer_idx", layer_index) + + setattr(tf_block, "_old_forward", tf_block.forward) + tf_block.forward = types.MethodType(llada_cache_hook_feature, tf_block) + + setattr(tf_block.self_attn, "_old_forward_main", tf_block.self_attn.forward) + tf_block.self_attn.attention_forward_for_cache = types.MethodType( + llada_attention_hook_for_cache, tf_block.self_attn + ) + + setattr( + tf_block.self_attn.rotary_emb, + "_old_forward", + tf_block.self_attn.rotary_emb.forward, + ) + tf_block.self_attn.rotary_emb.forward = types.MethodType( + llada_RoPe_forward_hook, tf_block.self_attn.rotary_emb + ) + + +def llada_attention_hook_for_cache( + self, # self is LLaDAAttention instance + q_in_proj: torch.Tensor, # Renamed from q to clarify it's post-projection from cache_hook + k_in_proj: 
torch.Tensor, # Renamed from k + v_in_proj: torch.Tensor, # Renamed from v + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + q_index: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + _ = layer_past + _ = use_cache + # q_in_proj, k_in_proj, v_in_proj are (Batch, SeqLen_specific_to_them, HiddenDim_model) + B, q_len_current_q = ( + q_in_proj.shape[0], + q_in_proj.shape[1], + ) # q_len_current_q is the length of the Q for this specific call + + q_num_heads = self.num_heads + k_num_heads = self.num_key_value_heads + v_num_heads = self.num_key_value_heads + head_dim = self.head_dim + + # Reshape q, k, v to (Batch, NumHeads, SeqLen, HeadDim) + q = q_in_proj.view(B, q_len_current_q, q_num_heads, head_dim).transpose(1, 2) + + k_seq_len_current_k = k_in_proj.shape[ + 1 + ] # k_seq_len_current_k is the length of K for this call (e.g., full context) + v_seq_len_current_v = v_in_proj.shape[1] + + k = k_in_proj.view(B, k_seq_len_current_k, k_num_heads, head_dim).transpose(1, 2) + v = v_in_proj.view(B, v_seq_len_current_v, v_num_heads, head_dim).transpose(1, 2) + + if hasattr(self, "rotary_emb"): + # q_index passed here should be for q_in_proj + q, k = self.rotary_emb(q, k, q_index=q_index) + + present = None + + k_repeated = repeat_kv(k, self.num_key_value_groups) + v_repeated = repeat_kv(v, self.num_key_value_groups) + + # q is (B, q_num_heads, q_len_current_q, head_dim) + # k_repeated is (B, q_num_heads, k_seq_len_current_k, head_dim) + + attn_weights = torch.matmul(q, k_repeated.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attention_bias is not None: + bias_q_dim = attention_bias.shape[-2] + bias_k_dim = attention_bias.shape[-1] + + sliced_attention_bias = attention_bias + if ( + q_len_current_q < bias_q_dim + ): # Q is a segment, assume it's the latest tokens + # This assumes the q segment is the last part of + # 
the full query sequence represented by the bias + sliced_attention_bias = attention_bias[:, :, -q_len_current_q:, :] + + if ( + k_seq_len_current_k < bias_k_dim + ): # K is a segment (less likely for full KV cache but possible) + # This assumes K is the latest part of the full key sequence in the bias + sliced_attention_bias = sliced_attention_bias[ + :, :, :, -k_seq_len_current_k: + ] + + # Final check on dimensions before adding + if ( + attn_weights.shape[-2] == sliced_attention_bias.shape[-2] + and attn_weights.shape[-1] == sliced_attention_bias.shape[-1] + ): + attn_weights = attn_weights + sliced_attention_bias + else: + if ( + sliced_attention_bias.shape[1] == 1 + and attn_weights.shape[1] == self.num_heads + ): # Mask has 1 head dim + attn_weights = ( + attn_weights + sliced_attention_bias + ) # Broadcast over heads + elif ( + sliced_attention_bias.shape[1] == self.num_heads + ): # Mask has same num heads + attn_weights = attn_weights + sliced_attention_bias + else: + raise RuntimeError( + f"Attention bias shape {sliced_attention_bias.shape} incompatible with " + f"attn_weights shape {attn_weights.shape} after slicing." 
+ ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + q.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + att_output_heads = torch.matmul(attn_weights, v_repeated) + + # Reshape to (B, q_len_current_q, ModelHiddenDim) + att_output_heads = ( + att_output_heads.transpose(1, 2) + .contiguous() + .view(B, q_len_current_q, q_num_heads * head_dim) + ) + + output = self.o_proj(att_output_heads) + + return output, present + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def llada_RoPe_forward_hook( + self_rope, + q_in: torch.Tensor, + k_in: torch.Tensor, + q_index: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + q_, k_ = q_in.float(), k_in.float() + + _, _, query_len_current, _ = q_.shape + _, _, key_len_current, _ = k_.shape + + max_pos_needed = key_len_current + if q_index is not None: + max_pos_needed = max( + max_pos_needed, int(q_index.max().item()) + 1 if q_index.numel() > 0 else 0 + ) + + max_pos_needed = max(max_pos_needed, query_len_current) + + if max_pos_needed == 0: + return q_in, k_in + + inv_freq_to_use = self_rope.inv_freq.to(q_.device) + t = torch.arange(max_pos_needed, device=q_.device, dtype=torch.float32) + if hasattr(self_rope, "scaling_factor"): + t = t / self_rope.scaling_factor + + freqs = torch.outer(t, inv_freq_to_use.float()) + emb = torch.cat((freqs, freqs), dim=-1) + + pos_cos_table = emb.cos() + pos_sin_table = emb.sin() + + if q_index is not None: + actual_q_indices = q_index[:, :query_len_current] + cos_q = pos_cos_table[actual_q_indices] + sin_q = 
pos_sin_table[actual_q_indices] + q_rotated = (q_ * cos_q.unsqueeze(1)) + (rotate_half(q_) * sin_q.unsqueeze(1)) + else: + q_indices = torch.arange(query_len_current, device=q_.device) + cos_q = pos_cos_table[q_indices].unsqueeze(0) + sin_q = pos_sin_table[q_indices].unsqueeze(0) + q_rotated = (q_ * cos_q.unsqueeze(1)) + (rotate_half(q_) * sin_q.unsqueeze(1)) + + k_indices = torch.arange(key_len_current, device=q_.device) + cos_k = pos_cos_table[k_indices].unsqueeze(0) + sin_k = pos_sin_table[k_indices].unsqueeze(0) + k_rotated = (k_ * cos_k.unsqueeze(1)) + (rotate_half(k_) * sin_k.unsqueeze(1)) + + return q_rotated.type_as(q_in), k_rotated.type_as(k_in) + + +def refresh_index( + new_features: torch.Tensor, + cached_features: torch.Tensor = None, + transfer_ratio: float = 0.5, + layer_id: int = 0, +) -> torch.Tensor: + _ = layer_id + batch_size, gen_len, _ = new_features.shape + num_replace = int(gen_len * transfer_ratio) + if num_replace == 0 or gen_len == 0: + return torch.empty( + (batch_size, 0), dtype=torch.long, device=new_features.device + ) + if cached_features is None or cached_features.shape[1] == 0: + return torch.empty( + (batch_size, 0), dtype=torch.long, device=new_features.device + ) + # pylint: disable=E1102 + cos_sim = torch.nn.functional.cosine_similarity( + new_features, cached_features, dim=-1 + ) + k_actual = min(num_replace, cos_sim.shape[1]) + if k_actual == 0: + return torch.empty( + (batch_size, 0), dtype=torch.long, device=new_features.device + ) + + transfer_index = torch.topk(cos_sim, largest=False, k=k_actual).indices + return transfer_index + + +def llada_cache_hook_feature( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[ + torch.Tensor + ] = None, # This is the original mask for the full layer input + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: 
Optional[torch.LongTensor] = None, +) -> Tuple[ + torch.Tensor, + Optional[Tuple[torch.Tensor, torch.Tensor]], + Optional[Tuple[torch.Tensor, ...]], +]: + current_layer_idx = self.layer_idx + feature_cache = dLLMCache() + feature_cache.update_step(current_layer_idx) + _ = past_key_value + _ = cache_position + + prompt_length = feature_cache.prompt_length + # x_prompt and x_gen are sub-segments of hidden_states + x_prompt = hidden_states[:, :prompt_length, :] + x_gen = hidden_states[:, prompt_length:, :] + + refresh_gen = feature_cache.refresh_gen() + refresh_prompt = feature_cache.refresh_prompt() + transfer_ratio = feature_cache.transfer_ratio + + bs, _, dim = ( + hidden_states.shape + ) # seq_len is the length of hidden_states input to this layer + + transfer = 0 < transfer_ratio <= 1 + + index_from_attn_transfer = None + index_expanded_from_attn_transfer = None + + def project( + x_input: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x_normed = self.input_layernorm(x_input) + q = self.self_attn.q_proj(x_normed) + k = self.self_attn.k_proj(x_normed) + v = self.self_attn.v_proj(x_normed) + return q, k, v + + # This function needs to be smarter about the attention_bias it passes. + # q_tensor_origin_slice: A tuple (start_idx, length) indicating where q_tensor comes from + # relative to the original hidden_states/attention_mask. 
+ def call_attention_on_qkv( + q_tensor, + k_tensor, + v_tensor, + original_full_attention_bias, + q_tensor_start_idx: int, # Start index of q_tensor within the original sequence + q_index: Optional[torch.Tensor] = None, + ): + q_len_for_this_call = q_tensor.shape[1] + k_len_for_this_call = k_tensor.shape[ + 1 + ] # This K is already the full K from dLLM cache + + # Slice the original_full_attention_bias + # original_full_attention_bias is likely + # (B, 1 or H, S_full_layer_input, K_full_from_dllm_or_max) + # We need (B, 1 or H, q_len_for_this_call, k_len_for_this_call) + sliced_bias = original_full_attention_bias + if original_full_attention_bias is not None: + # Slice query dimension based on q_tensor_start_idx and its length + # Slice key dimension to match k_tensor's length (which is k_full from dLLM) + sliced_bias = original_full_attention_bias[ + :, + :, + q_tensor_start_idx : q_tensor_start_idx + q_len_for_this_call, + :k_len_for_this_call, + ] + # Ensure num_heads dim is compatible + # (1 for broadcast, or matches self.self_attn.num_heads) + if ( + sliced_bias.shape[1] != 1 + and sliced_bias.shape[1] != self.self_attn.num_heads + ): + # This might indicate an issue with mask preparation upstream + # if it's not 1 or num_heads + # For now, assume it's (B,1,q,k) and will broadcast + # if num_heads > 1 in attention_hook + pass + + att_output, _ = self.self_attn.attention_forward_for_cache( + q_tensor, + k_tensor, + v_tensor, + attention_bias=sliced_bias, + layer_past=None, + use_cache=False, + q_index=q_index, + ) + return att_output + + def compute_mlp(input_to_mlp: torch.Tensor) -> torch.Tensor: + if input_to_mlp.shape[1] == 0: + return torch.empty_like(input_to_mlp) + x_norm = self.post_attention_layernorm(input_to_mlp) + gate_proj_out = self.mlp.gate_proj(x_norm) + up_proj_out = self.mlp.up_proj(x_norm) + act_out = self.mlp.act_fn(gate_proj_out) + x = act_out * up_proj_out + return self.mlp.down_proj(x) + + residual_pre_attn = hidden_states + + if 
refresh_gen and refresh_prompt: + q_full, k_full, v_full = project(hidden_states) + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="kv_cache", + features={ + "k": k_full[:, :prompt_length, :], + "v": v_full[:, :prompt_length, :], + }, + cache_type="prompt", + ) + if hidden_states.shape[1] > prompt_length: + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="kv_cache", + features={ + "k": k_full[:, prompt_length:, :], + "v": v_full[:, prompt_length:, :], + }, + cache_type="gen", + ) + + # Q is all of hidden_states, K,V also from all of hidden_states + # q_start_idx is 0 because q_full corresponds to the start of hidden_states + att = call_attention_on_qkv( + q_full, + k_full, + v_full, + attention_mask, + q_tensor_start_idx=0, + q_index=position_ids, + ) + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="attn", + features=att[:, :prompt_length, :], + cache_type="prompt", + ) + if hidden_states.shape[1] > prompt_length: + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="attn", + features=att[:, prompt_length:, :], + cache_type="gen", + ) + + elif refresh_gen and not refresh_prompt: + att_gen_part = torch.empty((bs, 0, dim), device=hidden_states.device) + if x_gen.shape[1] > 0: + q_gen, k_gen, v_gen = project(x_gen) + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="kv_cache", + features={"k": k_gen, "v": v_gen}, + cache_type="gen", + ) + kv_cache_prompt = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="kv_cache", cache_type="prompt" + ) + k_prompt_val = kv_cache_prompt.get( + "k", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + v_prompt_val = kv_cache_prompt.get( + "v", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + + k_full_ctx = torch.cat([k_prompt_val, k_gen], dim=1) + v_full_ctx = torch.cat([v_prompt_val, v_gen], dim=1) + + q_gen_pos_ids = ( + position_ids[:, prompt_length:] + if position_ids is not None and 
position_ids.shape[1] > prompt_length + else None + ) + + # q_gen starts at prompt_length in the original hidden_states sequence + att_gen_part = call_attention_on_qkv( + q_gen, + k_full_ctx, + v_full_ctx, + attention_mask, + q_tensor_start_idx=prompt_length, + q_index=q_gen_pos_ids, + ) + + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="attn", + features=att_gen_part, + cache_type="gen", + ) + + att_prompt_cache = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="attn", cache_type="prompt" + ) + att = torch.cat([att_prompt_cache, att_gen_part], dim=1) + + elif not refresh_gen and refresh_prompt: + q_prompt, k_prompt, v_prompt = project(x_prompt) + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="kv_cache", + features={"k": k_prompt, "v": v_prompt}, + cache_type="prompt", + ) + kv_cache_gen = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="kv_cache", cache_type="gen" + ) + att_gen_cache = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="attn", cache_type="gen" + ) + + k_gen_current = kv_cache_gen.get( + "k", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + v_gen_current = kv_cache_gen.get( + "v", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + + q_for_attn_segments = [q_prompt] + q_idx_for_attn_segments = ( + [position_ids[:, :prompt_length]] if position_ids is not None else [None] + ) + + if transfer and x_gen.shape[1] > 0 and k_gen_current.shape[1] > 0: + _, _, v_gen_for_transfer = project(x_gen) + index_from_attn_transfer = refresh_index( + v_gen_for_transfer, v_gen_current, transfer_ratio, current_layer_idx + ) + + if index_from_attn_transfer.numel() > 0: + index_expanded_from_attn_transfer = index_from_attn_transfer.unsqueeze( + -1 + ).expand(-1, -1, dim) + + x_gen_normed_selected = torch.gather( + self.input_layernorm(x_gen), + dim=1, + index=index_expanded_from_attn_transfer, + ) + q_gen_index = 
self.self_attn.q_proj(x_gen_normed_selected) + k_gen_index = self.self_attn.k_proj(x_gen_normed_selected) + v_gen_index_part = self.self_attn.v_proj(x_gen_normed_selected) + + k_gen_current = k_gen_current.scatter( + dim=1, index=index_expanded_from_attn_transfer, src=k_gen_index + ) + v_gen_current = v_gen_current.scatter( + dim=1, index=index_expanded_from_attn_transfer, src=v_gen_index_part + ) + + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="kv_cache", + features={"k": k_gen_current, "v": v_gen_current}, + cache_type="gen", + ) + + q_for_attn_segments.append(q_gen_index) + if position_ids is not None and position_ids.shape[1] > prompt_length: + gen_abs_positions_all = position_ids[:, prompt_length:] + gen_abs_positions_selected = torch.gather( + gen_abs_positions_all, 1, index_from_attn_transfer + ) + q_idx_for_attn_segments.append(gen_abs_positions_selected) + else: + q_idx_for_attn_segments.append(None) + + q_combined_for_attn = torch.cat(q_for_attn_segments, dim=1) + q_idx_combined_for_rope = ( + torch.cat(q_idx_for_attn_segments, dim=1) + if all(s is not None for s in q_idx_for_attn_segments) + else None + ) + + k_full_ctx = torch.cat([k_prompt, k_gen_current], dim=1) + v_full_ctx = torch.cat([v_prompt, v_gen_current], dim=1) + + # q_combined_for_attn effectively starts at index 0 of a conceptual sequence. + # Its RoPE is handled by q_idx_combined_for_rope. 
+ att_for_q_combined = call_attention_on_qkv( + q_combined_for_attn, + k_full_ctx, + v_full_ctx, + attention_mask, + q_tensor_start_idx=0, # Since Q is combined and starts from effective 0 + q_index=q_idx_combined_for_rope, + ) + + att_prompt_new = att_for_q_combined[:, : q_prompt.shape[1], :] + if ( + transfer + and index_from_attn_transfer is not None + and index_from_attn_transfer.numel() > 0 + ): + att_gen_index_new = att_for_q_combined[ + :, q_prompt.shape[1] :, : + ] # Segment for transferred Qs + if att_gen_cache.shape[1] > 0: + att_gen_cache = att_gen_cache.scatter( + dim=1, + index=index_expanded_from_attn_transfer, + src=att_gen_index_new, + ) + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="attn", + features=att_gen_cache, + cache_type="gen", + ) + + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="attn", + features=att_prompt_new, + cache_type="prompt", + ) + att = torch.cat([att_prompt_new, att_gen_cache], dim=1) + + else: # Not refresh gen, not refresh prompt + att_prompt_cache = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="attn", cache_type="prompt" + ) + att_gen_cache = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="attn", cache_type="gen" + ) + kv_cache_gen = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="kv_cache", cache_type="gen" + ) + kv_cache_prompt = feature_cache.get_cache( + layer_id=current_layer_idx, feature_name="kv_cache", cache_type="prompt" + ) + + k_gen_current = kv_cache_gen.get( + "k", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + v_gen_current = kv_cache_gen.get( + "v", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + k_prompt_val = kv_cache_prompt.get( + "k", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + v_prompt_val = kv_cache_prompt.get( + "v", torch.empty(bs, 0, dim, device=hidden_states.device) + ) + + if transfer and x_gen.shape[1] > 0 and k_gen_current.shape[1] > 0: + 
x_gen_normed = self.input_layernorm(x_gen) + v_gen_for_transfer = self.self_attn.v_proj(x_gen_normed) + index_from_attn_transfer = refresh_index( + v_gen_for_transfer, v_gen_current, transfer_ratio, current_layer_idx + ) + + if index_from_attn_transfer.numel() > 0: + index_expanded_from_attn_transfer = index_from_attn_transfer.unsqueeze( + -1 + ).expand(-1, -1, dim) + + x_gen_normed_selected = torch.gather( + x_gen_normed, dim=1, index=index_expanded_from_attn_transfer + ) + q_gen_index_only = self.self_attn.q_proj( + x_gen_normed_selected + ) # Q only for transferred items + k_gen_index = self.self_attn.k_proj(x_gen_normed_selected) + v_gen_index_part = self.self_attn.v_proj(x_gen_normed_selected) + + k_gen_current = k_gen_current.scatter( + dim=1, index=index_expanded_from_attn_transfer, src=k_gen_index + ) + v_gen_current = v_gen_current.scatter( + dim=1, index=index_expanded_from_attn_transfer, src=v_gen_index_part + ) + + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="kv_cache", + features={"k": k_gen_current, "v": v_gen_current}, + cache_type="gen", + ) + + q_idx_for_transferred_rope = None + if position_ids is not None and position_ids.shape[1] > prompt_length: + gen_abs_positions_all = position_ids[:, prompt_length:] + q_idx_for_transferred_rope = torch.gather( + gen_abs_positions_all, 1, index_from_attn_transfer + ) + + k_full_ctx = torch.cat([k_prompt_val, k_gen_current], dim=1) + v_full_ctx = torch.cat([v_prompt_val, v_gen_current], dim=1) + + att_gen_index_new = call_attention_on_qkv( + q_gen_index_only, + k_full_ctx, + v_full_ctx, + attention_mask, + q_tensor_start_idx=prompt_length, # Approximate for mask slicing + q_index=q_idx_for_transferred_rope, + ) + + if att_gen_cache.shape[1] > 0: # Make sure att_gen_cache has a gen part + att_gen_cache = att_gen_cache.scatter( + dim=1, + index=index_expanded_from_attn_transfer, + src=att_gen_index_new, + ) + elif ( + x_gen.shape[1] > 0 + ): # If original att_gen_cache was for an empty 
gen part, but x_gen is not empty + # Initialize att_gen_cache to be of x_gen's length before scattering + att_gen_cache = torch.zeros( + (bs, x_gen.shape[1], dim), + device=hidden_states.device, + dtype=att_gen_index_new.dtype, + ) + att_gen_cache = att_gen_cache.scatter( + dim=1, + index=index_expanded_from_attn_transfer, + src=att_gen_index_new, + ) + # Else: if x_gen.shape[1] is 0, att_gen_cache remains empty, scatter won't happen. + + feature_cache.set_cache( + layer_id=current_layer_idx, + feature_name="attn", + features=att_gen_cache, + cache_type="gen", + ) + + att = torch.cat([att_prompt_cache, att_gen_cache], dim=1) + + # ... rest of the llada_cache_hook_feature (MLP part) remains the same ... + hidden_states_after_attn = residual_pre_attn + att + residual_pre_mlp = hidden_states_after_attn + + x_prompt_mlp = hidden_states_after_attn[:, :prompt_length, :] + x_gen_mlp = hidden_states_after_attn[:, prompt_length:, :] + + mlp_out_prompt_part = torch.empty( + (bs, prompt_length, dim), + device=hidden_states.device, + dtype=hidden_states_after_attn.dtype, + ) + mlp_out_gen_part = torch.empty( + (bs, x_gen_mlp.shape[1], dim), + device=hidden_states.device, + dtype=hidden_states_after_attn.dtype, + ) + + if refresh_gen and refresh_prompt: + mlp_out_full = compute_mlp(hidden_states_after_attn) + mlp_out_prompt_part = mlp_out_full[:, :prompt_length, :] + if x_gen_mlp.shape[1] > 0: + mlp_out_gen_part = mlp_out_full[:, prompt_length:, :] + + feature_cache.set_cache( + current_layer_idx, "mlp", mlp_out_prompt_part, cache_type="prompt" + ) + if x_gen_mlp.shape[1] > 0: + feature_cache.set_cache( + current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen" + ) + if mlp_out_gen_part.shape[1] > 0: + mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1) + else: + mlp_out = mlp_out_prompt_part + + elif refresh_gen and not refresh_prompt: + mlp_out_prompt_part = feature_cache.get_cache( + current_layer_idx, "mlp", cache_type="prompt" + ) + if 
x_gen_mlp.shape[1] > 0: + mlp_out_gen_part = compute_mlp(x_gen_mlp) + feature_cache.set_cache( + current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen" + ) + + if mlp_out_gen_part.shape[1] > 0: + mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1) + else: + mlp_out = mlp_out_prompt_part + + elif refresh_prompt and not refresh_gen: + mlp_gen_cache_data = feature_cache.get_cache( + current_layer_idx, "mlp", cache_type="gen" + ) + if x_gen_mlp.shape[1] > 0: + mlp_out_gen_part = mlp_gen_cache_data + + mlp_input_for_prompt_path = x_prompt_mlp + # Use index_expanded_from_attn_transfer which was set in the attention block + if ( + transfer + and index_expanded_from_attn_transfer is not None + and index_expanded_from_attn_transfer.numel() > 0 + and x_gen_mlp.shape[1] > 0 + ): + x_gen_mlp_selected = torch.gather( + x_gen_mlp, dim=1, index=index_expanded_from_attn_transfer + ) + mlp_input_for_prompt_path = torch.cat( + [x_prompt_mlp, x_gen_mlp_selected], dim=1 + ) + + mlp_out_prompt_path_processed = compute_mlp(mlp_input_for_prompt_path) + mlp_out_prompt_part = mlp_out_prompt_path_processed[ + :, : x_prompt_mlp.shape[1], : + ] + + if ( + transfer + and index_expanded_from_attn_transfer is not None + and index_expanded_from_attn_transfer.numel() > 0 + and x_gen_mlp.shape[1] > 0 + ): + mlp_gen_index_new = mlp_out_prompt_path_processed[ + :, x_prompt_mlp.shape[1] :, : + ] + if mlp_out_gen_part.shape[1] > 0: # If gen part exists + mlp_out_gen_part = mlp_out_gen_part.scatter( + dim=1, + index=index_expanded_from_attn_transfer, + src=mlp_gen_index_new, + ) + + feature_cache.set_cache( + current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen" + ) + + feature_cache.set_cache( + current_layer_idx, "mlp", mlp_out_prompt_part, cache_type="prompt" + ) + if mlp_out_gen_part.shape[1] > 0: + mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1) + else: + mlp_out = mlp_out_prompt_part + + else: + mlp_out_prompt_part = feature_cache.get_cache( + 
current_layer_idx, "mlp", cache_type="prompt" + ) + mlp_gen_cache_data = feature_cache.get_cache( + current_layer_idx, "mlp", cache_type="gen" + ) + if x_gen_mlp.shape[1] > 0: + mlp_out_gen_part = mlp_gen_cache_data + + # Use index_expanded_from_attn_transfer + if ( + transfer + and index_expanded_from_attn_transfer is not None + and index_expanded_from_attn_transfer.numel() > 0 + and x_gen_mlp.shape[1] > 0 + ): + x_gen_mlp_selected = torch.gather( + x_gen_mlp, dim=1, index=index_expanded_from_attn_transfer + ) + mlp_out_gen_index = compute_mlp(x_gen_mlp_selected) + if mlp_out_gen_part.shape[1] > 0: + mlp_out_gen_part = mlp_out_gen_part.scatter( + dim=1, + index=index_expanded_from_attn_transfer, + src=mlp_out_gen_index, + ) + + feature_cache.set_cache( + current_layer_idx, "mlp", mlp_out_gen_part, cache_type="gen" + ) + + if mlp_out_gen_part.shape[1] > 0: + mlp_out = torch.cat([mlp_out_prompt_part, mlp_out_gen_part], dim=1) + else: + mlp_out = mlp_out_prompt_part + + final_hidden_states = residual_pre_mlp + mlp_out + + returned_outputs = (final_hidden_states,) + if output_attentions: + returned_outputs += (None,) + if use_cache: + returned_outputs += (None,) + + return returned_outputs diff --git a/train/llava/hooks/fast_dllm_hook.py b/train/llava/hooks/fast_dllm_hook.py new file mode 100644 index 0000000..c4d9dbe --- /dev/null +++ b/train/llava/hooks/fast_dllm_hook.py @@ -0,0 +1,702 @@ +# pylint: disable=C0114,C0115,C0116,C0103,R0913,R0917,R0914,R0912,R0915,E1102 +""" +Fast dLLM Cache Hook for LLaDA Model + +This module contains hooks and utilities for implementing fast distributed LLM caching +in the LLaDA model architecture. +""" + +from typing import List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from .cache_hook_LLaDA_V import repeat_kv + + +class FastDLLMGenerationHook: + """ + Hook class for implementing fast dLLM caching functionality in LLaDA model. 
+ This class handles both attention-level caching and generation-level optimizations. + """ + + def __init__(self, model): + self.model = model + self.original_methods = {} + self.is_registered = False + + def register_hooks(self): + """Register fast dLLM hooks to the model.""" + if self.is_registered: + return + + # Store original methods + # 新增:保存原始的generate方法 + self.original_methods["generate"] = self.model.generate + + # Store original attention forwards + for layer_idx, layer in enumerate(self.model.model.layers): + self.original_methods[f"attention_{layer_idx}"] = layer.self_attn.forward + # Replace attention forward with fast cache version + layer.self_attn.forward = self._create_fast_attention_forward( + layer.self_attn, layer_idx + ) + + self.model.generate = self._fast_generate # 新增:替换generate方法 + + self.is_registered = True + + def unregister_hooks(self): + """Unregister fast dLLM hooks from the model.""" + if not self.is_registered: + return + + # Restore original methods + # 新增:恢复原始的generate方法 + self.model.generate = self.original_methods["generate"] + + # Restore original attention forwards + for layer_idx, layer in enumerate(self.model.model.layers): + layer.self_attn.forward = self.original_methods[f"attention_{layer_idx}"] + + self.is_registered = False + + def _create_fast_attention_forward(self, attention_layer, layer_idx): + """Create fast attention forward method with caching support.""" + + def fast_attention_forward( + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + fast_dllm_cache: Optional[ + Sequence[Tuple[torch.Tensor, torch.Tensor]] + ] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # If not using fast cache or output_attentions is 
needed, use original method + if output_attentions: + return self.original_methods[f"attention_{layer_idx}"]( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = attention_layer.q_proj(hidden_states) + key_states = attention_layer.k_proj(hidden_states) + value_states = attention_layer.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, attention_layer.num_heads, attention_layer.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, + q_len, + attention_layer.num_key_value_heads, + attention_layer.head_dim, + ).transpose(1, 2) + value_states = value_states.view( + bsz, + q_len, + attention_layer.num_key_value_heads, + attention_layer.head_dim, + ).transpose(1, 2) + + # Apply rotary position embedding with fast cache consideration + cache_offset = 0 + if fast_dllm_cache and len(fast_dllm_cache) > layer_idx: + cache_offset = fast_dllm_cache[layer_idx][0].shape[-2] + + cos, sin = attention_layer.rotary_emb( + value_states, + position_ids + cache_offset if cache_offset > 0 else position_ids, + ) + query_states, key_states = self._apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # Handle past key values + past_key_value = getattr(attention_layer, "past_key_value", past_key_value) + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, layer_idx, cache_kwargs + ) + + # Fast dLLM cache logic + if fast_dllm_cache is not None: + if len(fast_dllm_cache) <= layer_idx: + fast_dllm_cache.append((key_states, value_states)) + else: + past_key, past_value = fast_dllm_cache[layer_idx] + key_states = torch.cat([past_key, key_states], dim=-2) + 
value_states = torch.cat([past_value, value_states], dim=-2) + + # Repeat key-value pairs for multi-head attention + key_states = repeat_kv(key_states, attention_layer.num_key_value_groups) + value_states = repeat_kv(value_states, attention_layer.num_key_value_groups) + + # Apply causal mask + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # Ensure contiguous tensors for CUDA + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # Scaled dot-product attention + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=None, + is_causal=False, + dropout_p=attention_layer.attention_dropout + if attention_layer.training + else 0.0, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, attention_layer.hidden_size) + attn_output = attention_layer.o_proj(attn_output) + + return attn_output, None, past_key_value + + return fast_attention_forward + + @torch.no_grad() + def _fast_generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ): + modalities = ( + kwargs.pop("modalities", None) + if "modalities" in kwargs and modalities is None + else modalities + ) + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = ( + self.model.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + modalities, + 
image_sizes=image_sizes, + ) + ) + else: + inputs_embeds = self.model.get_model().embed_tokens(inputs) + output = self._fast_generate_with_embeds(inputs_embeds=inputs_embeds, **kwargs) + return output + + def reverse_causal_x0_p_comma(self, x0, logits, block_start, block_end): + batch_size, seq_len = x0.shape + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) + if block_start > block_end: + block_start = 0 + assert seq_len == block_end + # 计算 block 内的长度 + block_len = block_end - block_start + assert block_len == 24, "block len is not 24! is that true?" + even_indices = torch.tensor( + [block_start + i for i in range(0, block_len, 2)], device=x0.device + ).repeat(batch_size, 1) + num_even = even_indices.shape[1] + decreasing_values = torch.linspace( + 0.1, 0.99, steps=num_even, device=x0.device, dtype=x0_p.dtype + ).repeat(batch_size, 1) + x0_p = x0_p.scatter_(1, even_indices, decreasing_values) + return x0_p + + @torch.no_grad() + def _fast_generate_with_embeds( + self, + inputs_embeds, + steps=128, + gen_length=128, + block_length=128, + temperature=0.0, + cfg_scale=0.0, + remasking="low_confidence", + mask_id=126336, + tokenizer=None, + stopping_criteria=None, + generation_suffix=None, + threshold=None, + prefix_refresh_interval=32, + **kwargs, + ): + """ + Fast generation with embeddings using dLLM cache optimization. + This method incorporates all fast dLLM related optimizations. 
+ """ + _ = kwargs # Unused + # Use mixed precision for faster computation + with torch.autocast(enabled=True, device_type="cuda", dtype=torch.bfloat16): + # Handle generation suffix + suffix_embeds = None + suffix_token_ids = None + suffix_len = 0 + if ( + generation_suffix is not None + and tokenizer is not None + and len(generation_suffix) > 0 + ): + suffix_token_ids = tokenizer.encode( + generation_suffix, add_special_tokens=False + ) + suffix_token_ids = torch.tensor( + suffix_token_ids, dtype=torch.long, device=inputs_embeds.device + ).unsqueeze(0) + suffix_embeds = self.model.model.embed_tokens(suffix_token_ids) + suffix_len = suffix_embeds.shape[1] + + # Create input in embedding space + total_length = inputs_embeds.shape[1] + gen_length + suffix_len + masked_embed = self.model.model.embed_tokens( + torch.tensor([mask_id]).to(inputs_embeds.device) + ) + x_embeds = masked_embed.repeat(1, total_length, 1).to(inputs_embeds.device) + x_embeds[:, : inputs_embeds.shape[1]] = inputs_embeds.clone() + if suffix_embeds is not None: + x_embeds[:, -suffix_len:] = suffix_embeds + + # Create tracking tensor for token IDs + x = torch.full( + (1, total_length), + mask_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + track_x = [] # record all step of x + if suffix_token_ids is not None: + x[:, -suffix_len:] = suffix_token_ids + + # Prompt index tracking + prompt_index = torch.zeros( + (1, total_length), dtype=torch.bool, device=inputs_embeds.device + ) + prompt_index[:, : inputs_embeds.shape[1]] = 1 + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # Initialize stop tracking + stop_position = inputs_embeds.shape[1] + gen_length + found_stop_seq = False + stop_tokens = [] + + if stopping_criteria is not None: + assert tokenizer is not None, ( + "tokenizer is required when stopping_criteria is not None" + ) + for stop_str in stopping_criteria: + tokens = 
tokenizer.encode(stop_str, add_special_tokens=False) + stop_tokens.append(tokens) + + # Process each block + for num_block in range(num_blocks): + block_start = inputs_embeds.shape[1] + num_block * block_length + block_end = inputs_embeds.shape[1] + (num_block + 1) * block_length + + # Skip if stop found before current block + if found_stop_seq and stop_position <= block_start: + break + + block_embeds = x_embeds[:, block_start:block_end] + block_mask_index = torch.all( + torch.abs(block_embeds - masked_embed) < 1e-5, dim=2 + ) + num_transfer_tokens = self._get_num_transfer_tokens( + block_mask_index, steps + ) + + i = 0 + + while True: + if threshold is None and i >= steps: + break + + # Check mask state + mask_index = torch.all( + torch.abs(x_embeds - masked_embed) < 1e-5, dim=2 + ) + + if found_stop_seq: + pre_stop_masks = mask_index[ + 0, inputs_embeds.shape[1] : stop_position + ] + if not pre_stop_masks.any(): + break + + current_block_masks = mask_index[0, block_start:block_end] + if not current_block_masks.any(): + break + + # Handle CFG + if cfg_scale > 0.0: + un_embeds = x_embeds.clone() + un_mask = prompt_index.unsqueeze(-1).expand_as(x_embeds) + un_embeds[un_mask] = masked_embed.repeat( + x_embeds.shape[0], x_embeds.shape[1], 1 + )[un_mask] + combined_embeds = torch.cat([x_embeds, un_embeds], dim=0) + + if i % prefix_refresh_interval == 0: + fast_dllm_cache = [] + outputs = self.model.model( + inputs_embeds=combined_embeds, + fast_dllm_cache=fast_dllm_cache, + ) + fast_dllm_cache = self._create_cache_slice( + fast_dllm_cache, block_start + ) + else: + outputs = self.model.model( + inputs_embeds=combined_embeds[:, block_start:], + fast_dllm_cache=fast_dllm_cache, + ) + + logits = self.model.lm_head(outputs[0]).float() + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + if i % prefix_refresh_interval == 0: + fast_dllm_cache = [] + outputs = self.model.model( + inputs_embeds=x_embeds, 
fast_dllm_cache=fast_dllm_cache + ) + # Slice cache to block start + fast_dllm_cache = self._create_cache_slice( + fast_dllm_cache, block_start + ) + else: + # Incremental forward pass + outputs = self.model.model( + inputs_embeds=x_embeds[:, block_start:], + fast_dllm_cache=fast_dllm_cache, + ) + logits = self.model.lm_head(outputs[0]).float() + + # Filter forbidden tokens + forbidden_tokens = [126081, 126080, 126346, 126347] + if i % prefix_refresh_interval == 0: + for token_id in forbidden_tokens: + logits[:, :, token_id] = torch.where( + mask_index, -float("inf"), logits[:, :, token_id] + ) + else: + for token_id in forbidden_tokens: + logits[:, :, token_id] = torch.where( + mask_index[:, block_start:], + -float("inf"), + logits[:, :, token_id], + ) + + # Get transfer indices and update + if i % prefix_refresh_interval == 0: + x0, transfer_index = self._get_transfer_index( + logits, + temperature, + remasking, + mask_index, + x, + num_transfer_tokens[:, i] if threshold is None else None, + found_stop_seq, + stop_position, + block_end, + suffix_len, + block_start, + threshold, + ) + x0_embeds = self.model.model.embed_tokens(x0) + x0_embeds = torch.where( + mask_index.unsqueeze(-1).expand_as(x_embeds), + x0_embeds, + x_embeds, + ) + x_embeds[transfer_index] = x0_embeds[transfer_index] + x[transfer_index] = x0[transfer_index] + else: + x0, transfer_index = self._get_transfer_index( + logits, + temperature, + remasking, + mask_index[:, block_start:], + x[:, block_start:], + num_transfer_tokens[:, i] if threshold is None else None, + found_stop_seq, + stop_position - block_start, + block_end - block_start, + suffix_len, + block_start, + threshold, + ) + x0_embeds = self.model.model.embed_tokens(x0) + x0_embeds = torch.where( + mask_index[:, block_start:] + .unsqueeze(-1) + .expand_as(x_embeds[:, block_start:]), + x0_embeds, + x_embeds[:, block_start:], + ) + x_embeds[:, block_start:][transfer_index] = x0_embeds[ + transfer_index + ] + x[:, 
block_start:][transfer_index] = x0[transfer_index] + + track_x.append( + x.clone()[ + :, + inputs_embeds.shape[1] : inputs_embeds.shape[1] + + gen_length, + ] + ) + + # Check for stop words + if stopping_criteria is not None: + generated_part = x[ + 0, + inputs_embeds.shape[1] : inputs_embeds.shape[1] + + gen_length, + ] + for stop_seq in stop_tokens: + if not isinstance(stop_seq, list): + stop_seq = [stop_seq] + for start_idx in range( + generated_part.size(0) - len(stop_seq) + 1 + ): + if torch.all( + generated_part[ + start_idx : start_idx + len(stop_seq) + ] + == torch.tensor(stop_seq, device=x.device) + ): + current_position = ( + inputs_embeds.shape[1] + start_idx + ) + if ( + not found_stop_seq + or current_position < stop_position + ): + stop_position = current_position + found_stop_seq = True + break + if found_stop_seq: + break + i += 1 + + if threshold is not None: + print(f"Number of steps: {i}") + + # Return results + if found_stop_seq: + if suffix_len > 0: + return torch.cat( + [ + x[:, inputs_embeds.shape[1] : stop_position], + x[:, -suffix_len:], + ], + dim=1, + ) + return x[:, inputs_embeds.shape[1] : stop_position], track_x + if suffix_len > 0: + return torch.cat( + [ + x[ + :, + inputs_embeds.shape[1] : inputs_embeds.shape[1] + + gen_length, + ], + x[:, -suffix_len:], + ], + dim=1, + ) + return x[ + :, inputs_embeds.shape[1] : inputs_embeds.shape[1] + gen_length + ], track_x + + @staticmethod + def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Apply rotary position embedding to query and key tensors.""" + _ = position_ids # Unused in this implementation + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (FastDLLMGenerationHook._rotate_half(q) * sin) + k_embed = (k * cos) + (FastDLLMGenerationHook._rotate_half(k) * sin) + return q_embed, k_embed + + @staticmethod + def _rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 
2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _get_transfer_index( + self, + logits, + temperature, + remasking, + mask_index, + x, + num_transfer_tokens, + found_stop_seq, + stop_position, + block_end, + suffix_len, + block_start, + threshold=None, + ): + """Get transfer indices for token updates during generation.""" + logits_with_noise = self._add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) + + if remasking == "low_confidence": + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1 + ) + elif remasking == "reverse_causal_comma": + x0_p = self.reverse_causal_x0_p_comma(x0, logits, block_start, block_end) + else: + raise NotImplementedError(remasking) + + # Handle stop sequences and block boundaries + if found_stop_seq: + x0_p[:, stop_position:] = -np.inf + else: + x0_p[:, block_end:] = -np.inf + + # Prevent overwriting suffix + if suffix_len > 0: + x0_p[:, -suffix_len:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + + if threshold is not None: + num_transfer_tokens = mask_index.sum(dim=1, keepdim=True) + + for j in range(confidence.shape[0]): + if threshold is None: + top_i = num_transfer_tokens[j] + else: + ns = list(range(1, num_transfer_tokens[j] + 1)) + es = [threshold / (n + 1) for n in ns] + threshs = [1 - e for e in es] + threshs[0] = -1 # at least one token is transferred + + sorted_confidence = torch.sort( + confidence[j][mask_index[j]], dim=-1, descending=True + )[0] + assert len(sorted_confidence) == len(threshs) + + top_i = 0 + for top_i, _ in enumerate(threshs): + if sorted_confidence[top_i] < threshs[top_i]: + break + + if top_i in (0, len(threshs) - 1): + top_i += 1 + + _, select_index = torch.topk(confidence[j], k=top_i) + transfer_index[j, 
select_index] = True + + return x0, transfer_index + + @staticmethod + def _add_gumbel_noise(logits, temperature): + """Add Gumbel noise for categorical sampling.""" + if temperature == 0: + return logits + + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (-torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + @staticmethod + def _get_num_transfer_tokens(mask_index, steps): + """Precompute the number of tokens to transition at each step.""" + mask_num = mask_index.sum(dim=1, keepdim=True) + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = base.expand(-1, steps).clone() + + if remainder.sum() > 0: + indices = torch.arange(steps, device=mask_index.device) + mask = indices.unsqueeze(0) < remainder + num_transfer_tokens[mask] += 1 + + return num_transfer_tokens.to(torch.int64) + + def _create_cache_slice(self, fast_dllm_cache, block_start): + """Create a sliced version of fast_dllm_cache for block processing.""" + new_past_key_values = [] + for i, cache_i in enumerate(fast_dllm_cache): + new_past_key_values.append([]) + for _, cache_j in enumerate(cache_i): + new_past_key_values[i].append(cache_j[:, :, :block_start]) + return new_past_key_values + + +def register_fast_dllm_hook(model): + """ + Register fast dLLM cache hooks to the model. + + Args: + model: The LLaDA model to register hooks to + + Returns: + FastDLLMGenerationHook: The hook instance for management + """ + hook = FastDLLMGenerationHook(model) + hook.register_hooks() + return hook + + +def unregister_fast_dllm_hook(hook): + """ + Unregister fast dLLM cache hooks from the model. 
+ + Args: + hook: The FastDLLMGenerationHook instance to unregister + """ + hook.unregister_hooks() diff --git a/train/llava/mm_utils.py b/train/llava/mm_utils.py new file mode 100644 index 0000000..52b4503 --- /dev/null +++ b/train/llava/mm_utils.py @@ -0,0 +1,396 @@ +from PIL import Image +from io import BytesIO +import base64 +import math +import ast +import re +import torch +from transformers import StoppingCriteria +from llava.constants import IMAGE_TOKEN_INDEX +from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord + + +def resize_and_center_crop(image, shortest_edge_length): + # Calculate new dimensions and resize + aspect_ratio = float(image.width) / float(image.height) + if aspect_ratio > 1: + new_width = int(shortest_edge_length * aspect_ratio) + new_height = shortest_edge_length + else: + new_width = shortest_edge_length + new_height = int(shortest_edge_length / aspect_ratio) + resized_image = image.resize((new_width, new_height), Image.ANTIALIAS) + + # Calculate the position and perform the center crop + left = (new_width - shortest_edge_length) / 2 + top = (new_height - shortest_edge_length) / 2 + right = (new_width + shortest_edge_length) / 2 + bottom = (new_height + shortest_edge_length) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + + return cropped_image + + +def auto_pad_images(image, grid_params): + assert isinstance(image, Image.Image), "Input should be a Pillow Image" + assert len(grid_params) > 0, "Grid parameters should not be empty" + + # Step 1: Calculate and find the closest aspect ratio + input_width, input_height = image.size + input_aspect_ratio = input_width / input_height + candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params] + closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0])) + + candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3] + + 
target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1)) + + resize_width, resize_height = target_resolution + if input_width > input_height: + resize_height = int(resize_width / input_aspect_ratio) + else: + resize_width = int(resize_height * input_aspect_ratio) + resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS) + + # Step 5: Pad the resized image if necessary to match the target resolution + pad_width = target_resolution[0] - resize_width + pad_height = target_resolution[1] - resize_height + padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0)) + padded_image.paste(resized_image, (pad_width // 2, pad_height // 2)) + + return padded_image + + +def extract_patches(image, patch_size, overlap_ratio): + assert isinstance(image, Image.Image), "Input should be a Pillow Image" + assert patch_size > 0, "Patch size should be greater than 0" + assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1" + + W, H = image.size + patches = [] + + stride = int(patch_size * (1 - overlap_ratio)) + + num_patches_y = (H - patch_size) // stride + 1 + num_patches_x = (W - patch_size) // stride + 1 + + y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2 + x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2 + + for y in range(y_start, y_start + num_patches_y * stride, stride): + for x in range(x_start, x_start + num_patches_x * stride, stride): + patch = image.crop((x, y, x + patch_size, y + patch_size)) + patches.append(patch) + + return patches + + +def process_highres_image_crop_split(image, data_args, processor=None): + crop_resolution = data_args.image_crop_resolution + split_resolution = data_args.image_split_resolution + if processor is None: + processor = data_args.image_processor + image_crop = resize_and_center_crop(image, crop_resolution) + image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0) + 
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + + +def process_highres_image(image, processor, grid_pinpoints): + grid_params = [int(x) for x in grid_pinpoints.split(",")] + width_height = max(image.size) + fit_grid_params = [x for x in grid_params if x >= width_height] + if len(fit_grid_params) == 0: + select_size = max(grid_params) + else: + select_size = min(fit_grid_params) + # FIXME: always select the 448 + select_size = max(grid_params) + image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + + # FIXME: this seems to be a bug that it always resizes instead of padding + image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"])) + image_padded = image_padded.resize((select_size, select_size)) + image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0) + image_patches = [image_original_resize] + image_patches + image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + + +def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). 
+ """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for width, height in possible_resolutions: + # Calculate the downscaled size to keep the aspect ratio + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + + # Calculate effective and wasted resolutions + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. 
+ """ + original_width, original_height = image.size + target_width, target_height = target_resolution + + # Determine which dimension (width or height) to fill + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + # Width will be filled completely + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + # Height will be filled completely + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + # Create a new image with the target size and paste the resized image onto it + new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """ + Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). 
+ """ + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width, height = select_best_resolution(image_size, possible_resolutions) + return width // patch_size, height // patch_size + + +def process_anyres_image(image, processor, grid_pinpoints): + """ + Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. 
+ """ + # Convert grid_pinpoints from string to list + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + try: + patch_size = processor.size[0] + except Exception as e: + patch_size = processor.size["shortest_edge"] + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size["height"]) + + # FIXME: this seems to be a bug that it resizes instead of pad. 
+ # but to keep it consistent with previous, i will keep it as it is + # TODO: uncomment below to ablate with the padding + if isinstance(processor.size, dict): + shortest_edge = processor.size["shortest_edge"] + else: + shortest_edge = min(processor.size) + image_original_resize = image.resize((shortest_edge, shortest_edge)) + # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) + new_images = [] + if image_aspect_ratio == "highres": + for image in images: + image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints) + new_images.append(image) + elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + for image in images: + image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) + new_images.append(image) + elif image_aspect_ratio == "crop_split": + for image in images: + image = process_highres_image_crop_split(image, model_cfg, image_processor) + 
new_images.append(image) + elif image_aspect_ratio == "pad": + for image in images: + image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + new_images.append(image) + else: + return image_processor.preprocess(images, return_tensors="pt")["pixel_values"] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == "pt": + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f"Unsupported tensor type: {return_tensors}") + return input_ids + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + 
self.start_len = input_ids.shape[1] + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO + offset = min(output_ids.shape[1] - self.start_len, 3) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if output_ids[0, -keyword_id.shape[0] :] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False diff --git a/train/llava/model/__init__.py b/train/llava/model/__init__.py new file mode 100644 index 0000000..0a9ae48 --- /dev/null +++ b/train/llava/model/__init__.py @@ -0,0 +1,8 @@ +import os + +AVAILABLE_MODELS = { + "llava_llada": "LlavaLLaDAModelLM, LlavaLLaDAConfig", +} + +for model_name, model_classes in AVAILABLE_MODELS.items(): + exec(f"from .language_model.{model_name} import {model_classes}") \ No newline at end of file diff --git a/train/llava/model/apply_delta.py b/train/llava/model/apply_delta.py new file mode 100644 index 0000000..c183ba1 --- /dev/null +++ b/train/llava/model/apply_delta.py @@ -0,0 +1,47 @@ +""" +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta +""" + +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from llava import LlavaLlamaForCausalLM + + +def apply_delta(base_model_path, target_model_path, delta_path): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading delta") + delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + delta_tokenizer = 
AutoTokenizer.from_pretrained(delta_path) + + print("Applying delta") + for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): + if name not in base.state_dict(): + assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" + continue + if param.data.shape == base.state_dict()[name].shape: + param.data += base.state_dict()[name] + else: + assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" + bparam = base.state_dict()[name] + param.data[: bparam.shape[0], : bparam.shape[1]] += bparam + + print("Saving target model") + delta.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + + args = parser.parse_args() + + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/train/llava/model/builder.py b/train/llava/model/builder.py new file mode 100644 index 0000000..264534c --- /dev/null +++ b/train/llava/model/builder.py @@ -0,0 +1,145 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +import warnings + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +import torch +from llava.model import * +from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.utils import rank0_print +from trl.import_utils import is_npu_available + + +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16",attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs): + kwargs["device_map"] = device_map + + if load_8bit: + kwargs["load_in_8bit"] = True + elif load_4bit: + kwargs["load_in_4bit"] = True + kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4") + elif torch_dtype == "float16": + kwargs["torch_dtype"] = torch.float16 + elif torch_dtype == "bfloat16": + kwargs["torch_dtype"] = torch.bfloat16 + else: + import pdb;pdb.set_trace() + + if customized_config is not None: + kwargs["config"] = customized_config + + if "multimodal" in kwargs: + if kwargs["multimodal"] is True: + is_multimodal = True + kwargs.pop("multimodal") + else: + is_multimodal = False + + if "llava" in model_name.lower() or is_multimodal: + # Load LLaVA model + rank0_print(f"Loaded LLaVA model: {model_path}") + if "llada" in model_name.lower() or "exp" in model_name.lower(): + from llava.model.language_model.llava_llada import LlavaLLaDAConfig + + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + if customized_config is None: + llada_cfg = LlavaLLaDAConfig.from_pretrained(model_path) + else: + llada_cfg = customized_config + + if overwrite_config is not None: + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(llada_cfg, k, v) + + model = 
LlavaLLaDAModelLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llada_cfg, **kwargs) + + else: + try: + from llava.model.language_model.llava_llama import LlavaConfig + + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + if customized_config is None: + llava_cfg = LlavaConfig.from_pretrained(model_path) + if "v1.5" in model_path.lower(): + llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models + else: + llava_cfg = customized_config + + if overwrite_config is not None: + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(llava_cfg, k, v) + model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) + except: + raise ValueError(f"Model {model_name} not supported") + + else: + # Load language model + if model_base is not None: + # PEFT model + from peft import PeftModel + + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") + print(f"Loading LoRA weights from {model_path}") + model = PeftModel.from_pretrained(model, model_path) + print(f"Merging weights") + model = model.merge_and_unload() + print("Convert to FP16...") + model.to(torch.float16) + else: + use_fast = False + if "mpt" in model_name.lower().replace("prompt", ""): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + + rank0_print(f"Model Class: {model.__class__.__name__}") + image_processor = None + + if "llava" 
in model_name.lower() or is_multimodal: + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) + if mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + model.resize_token_embeddings(len(tokenizer)) + + vision_tower = model.get_vision_tower() + if not vision_tower.is_loaded: + vision_tower.load_model(device_map=device_map) + if device_map != "auto": + if is_npu_available(): + vision_tower.to(device="npu", dtype=torch.float16) + else: + vision_tower.to(device="cuda", dtype=torch.bfloat16) + image_processor = vision_tower.image_processor + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + elif hasattr(model.config, "max_position_embeddings"): + context_len = model.config.max_position_embeddings + elif hasattr(model.config, "tokenizer_model_max_length"): + context_len = model.config.tokenizer_model_max_length + else: + context_len = 2048 + + return tokenizer, model, image_processor, context_len diff --git a/train/llava/model/consolidate.py b/train/llava/model/consolidate.py new file mode 100644 index 0000000..f02e575 --- /dev/null +++ b/train/llava/model/consolidate.py @@ -0,0 +1,30 @@ +""" +Usage: +python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate +""" + +import argparse + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from llava.model import * +from llava.model.utils import auto_upgrade + + +def consolidate_ckpt(src_path, dst_path): + print("Loading model") + auto_upgrade(src_path) + src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) + 
src_model.save_pretrained(dst_path) + src_tokenizer.save_pretrained(dst_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, required=True) + parser.add_argument("--dst", type=str, required=True) + + args = parser.parse_args() + + consolidate_ckpt(args.src, args.dst) diff --git a/train/llava/model/language_model/configuration_llada.py b/train/llava/model/language_model/configuration_llada.py new file mode 100644 index 0000000..7fc5604 --- /dev/null +++ b/train/llava/model/language_model/configuration_llada.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" LLaDA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +LLaDA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class LLaDAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LLaDAModel`]. 
It is used to instantiate an LLaDA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaDA-8B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaDA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LLaDAModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+ max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. 
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + """ + + model_type = "llada" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + num_experts_per_tok=-1, + num_experts=-1, + moe_choice='expert', + capacity_factor=0.1, + moe_router_enable_expert_bias=None, + moe_router_score_function=None, + moe_lora_rank=32, + moe_lora_alpha=32, + moe_lora_in_features=4096, + moe_lora_out_features=4096, + moe_lora_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.num_experts_per_tok = 
num_experts_per_tok + self.num_experts = num_experts + self.moe_choice = moe_choice + self.capacity_factor = capacity_factor + self.moe_router_enable_expert_bias = moe_router_enable_expert_bias + self.moe_router_score_function = moe_router_score_function + self.moe_lora_rank=moe_lora_rank + self.moe_lora_alpha=moe_lora_alpha + self.moe_lora_in_features=moe_lora_in_features + self.moe_lora_out_features=moe_lora_out_features + self.moe_lora_dropout=moe_lora_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/train/llava/model/language_model/llava_llada.py b/train/llava/model/language_model/llava_llada.py new file mode 100644 index 0000000..22b1e10 --- /dev/null +++ b/train/llava/model/language_model/llava_llada.py @@ -0,0 +1,211 @@ +# Copyright 2023 Zebin You +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from llava.model.language_model.configuration_llada import LLaDAConfig +from llava.model.language_model.modeling_llada import LLaDAModel, LLaDAModelLM +from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel +from transformers import ( + AutoConfig, + AutoModelForCausalLM, +) +from transformers.generation.utils import GenerateOutput +from transformers.modeling_outputs import CausalLMOutputWithPast + + +class LlavaLLaDAConfig(LLaDAConfig): + model_type = "llava_llada" + temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna + max_new_tokens: int = 1024 + do_sample: bool = False + top_p: Optional[float] = None + # rope_scaling: Optional[dict] = {} + + +class LlavaLLaDAModel(LlavaMetaModel, LLaDAModel): + config_class = LlavaLLaDAConfig + + def __init__(self, config: LLaDAConfig): + super(LlavaLLaDAModel, self).__init__(config) + + +class LlavaLLaDAModelLM(LLaDAModelLM, LlavaMetaForCausalLM): + config_class = LlavaLLaDAConfig + + def __init__(self, config): + LLaDAModelLM.__init__(self, config) + + # configure default generation settings + config.model_type = "llava_llada" + # config.rope_scaling = None + + self.model = LlavaLLaDAModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, 
+ position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + modalities: Optional[List[str]] = ["image"], + dpo_forward: Optional[bool] = None, + cache_position=None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + if inputs_embeds is None and attention_mask is not None: + # donate multi-dialogue + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + conversation_ids, + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + modalities, + image_sizes, + is_llada=True, + ) + elif inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + modalities, + image_sizes, + ) + conversation_ids = None + if dpo_forward: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + return logits, labels + + else: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + conversation_ids=conversation_ids, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + modalities = ( + kwargs.pop("modalities", None) + if "modalities" in kwargs and modalities is None + else modalities + ) + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = ( + self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + modalities, + image_sizes=image_sizes, + ) + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate_with_embeds(inputs_embeds=inputs_embeds, **kwargs) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + if images is not None: + inputs["images"] = images + if image_sizes is not None: + inputs["image_sizes"] = image_sizes + return inputs + + +AutoConfig.register("llava_llada", LlavaLLaDAConfig) +AutoModelForCausalLM.register(LlavaLLaDAConfig, LlavaLLaDAModelLM) diff --git a/train/llava/model/language_model/modeling_llada.py b/train/llava/model/language_model/modeling_llada.py new file mode 100644 index 0000000..997bad5 --- /dev/null +++ b/train/llava/model/language_model/modeling_llada.py @@ 
-0,0 +1,2330 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaDA model.""" + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from llava.cache import dLLMCache +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .configuration_llada import LLaDAConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding 
import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LLaDAConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LLaDARMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LLaDARMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LLaDARMSNorm) + + +class LLaDARotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=torch.int64 + ).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same 
calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False + ) + self.register_buffer( + "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False + ) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LLaDAAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LLaDAAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LLaDALinearScalingRotaryEmbedding(LLaDARotaryEmbedding): + """LLaDARotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LLaDADynamicNTKScalingRotaryEmbedding(LLaDARotaryEmbedding): + """LLaDARotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) + / self.dim + ) + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LLaDAMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [ + F.linear(x, gate_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ], + dim=-1, + ) + up_proj = torch.cat( + [ + F.linear(x, up_proj_slices[i]) + for i in 
range(self.config.pretraining_tp) + ], + dim=-1, + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LLaDAAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LLaDAConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + # self.is_causal = True + # Modify: MDM set causal to False. + self.is_causal = False + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=config.attention_bias + ) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LLaDARotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LLaDALinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LLaDADynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise 
ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) + for i in range(self.config.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) + for i in range(self.config.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) + for i in range(self.config.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = 
apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split( + self.hidden_size // self.config.pretraining_tp, dim=2 + ) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.config.pretraining_tp, dim=1 + ) + attn_output = sum( + [ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class 
LLaDAFlashAttention2(LLaDAAttention): + """ + LLaDA flash attention module. This module inherits from `LLaDAAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 
2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LLaDARMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LLaDAFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + assert causal is False # Modify: MDM + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + 
self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LLaDASdpaAttention(LLaDAAttention): + """ + LLaDA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LLaDAAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + # Adapted from LLaDAAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LLaDAModel is using LLaDASdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # In case static cache is used, it is an instance attribute. 
+ past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + is_causal=False, # Modify: MDM + dropout_p=self.attention_dropout if self.training else 0.0, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLaDA_ATTENTION_CLASSES = { + "eager": LLaDAAttention, + "flash_attention_2": LLaDAFlashAttention2, + "sdpa": LLaDASdpaAttention, +} + + +class LoRALayer(nn.Module): + def __init__(self, config): + super().__init__() + self.r = config.moe_lora_rank + self.alpha = config.moe_lora_alpha + self.scaling = self.alpha / self.r + + in_features = config.moe_lora_in_features + out_features = config.moe_lora_out_features + + if self.r > 0: + self.lora_A = 
nn.Parameter(torch.randn(self.r, in_features)) + self.lora_B = nn.Parameter(torch.zeros(out_features, self.r)) + self.dropout = nn.Dropout(config.moe_lora_dropout) + else: + self.lora_A = None + self.lora_B = None + self.dropout = nn.Identity() + + def forward(self, x): + return self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling + + +class LLaDAMoESparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.capacity_factor = config.capacity_factor + self.norm_topk_prob = config.moe_choice == "token" + + self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) + self.experts = nn.ModuleList( + [LoRALayer(config) for _ in range(self.num_experts)] + ) + # self.score_func = config.moe_router_score_function + if ( + hasattr(config, "moe_router_enable_expert_bias") + and config.moe_router_enable_expert_bias + ): + self.register_buffer("expert_bias", torch.zeros(self.num_experts)) + else: + self.expert_bias = None + if config.moe_choice == "expert": + self.forward = self.forward_expert_choice + else: + self.forward = self.forward_token_choice + + # expert choice + def forward_expert_choice(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + num_tokens = batch_size * sequence_length + hidden_states = hidden_states.view(-1, hidden_dim) # flatten operation + # router_logits: (batch * sequence_length, n_experts) + router_logits = F.softmax( + self.gate(hidden_states), dim=1 + ) # importance of each token and experts + + router_logits_T = router_logits.t() # expert_num, sequence_length + + capacity = max( + 1, int(round(self.capacity_factor * num_tokens / self.num_experts)) + ) # find out the capacity works? 
+ capacity = min(capacity, num_tokens) # clamp to valid range + # import pdb; pdb.set_trace() + expert_weights, expert_indices = torch.topk( + router_logits_T, + k=capacity, + dim=1, + sorted=False, # no need to sort for efficiency + ) + + # Initialize output and token processing counter + output = torch.zeros_like(hidden_states) + # import pdb; pdb.set_trace() + + # Process each expert in parallel (if possible) or sequentially + for expert_idx in range(self.num_experts): + # Get tokens selected by this expert + token_indices = expert_indices[expert_idx] # [capacity] + weights = expert_weights[expert_idx] # [capacity] + + # Get hidden states for selected tokens + selected_hidden_states = hidden_states[ + token_indices + ] # [capacity, hidden_size] + + # Expert forward pass + expert_output = self.experts[expert_idx]( + selected_hidden_states + ) # [capacity, hidden_size] -> + + # Weighted accumulation to output + # Each token may be processed by multiple experts, so we accumulate + weighted_output = ( + weights.unsqueeze(-1) * expert_output + ) # [1, capacity] * [capacity, hidden_size] + + # Scatter-add to output tensor + output.index_add_(0, token_indices, weighted_output) + # break + # Reshape back to original shape + output = output.view(batch_size, sequence_length, hidden_dim) + + # print(f"in mlp output_router_logits: {output_router_logits}") + # import pdb; pdb.set_trace() + # print("router logits: ", router_logits) + + return output + + # token choice + def forward_token_choice(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass with improved routing and load balancing. 
+ + Args: + hidden_states: [batch_size, sequence_length, hidden_dim] + + Returns: + output: [batch_size, sequence_length, hidden_dim] + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + num_tokens = batch_size * sequence_length + + # Reshape to [num_tokens, hidden_dim] + hidden_states = hidden_states.view(-1, hidden_dim) + + # Compute routing weights + router_logits_raw, routing_probs = self._compute_routing_weights(hidden_states) + + # Select top-k experts per token + # topk_weights: [num_tokens, top_k] + # selected_experts: [num_tokens, top_k] + topk_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1) + + # Normalize top-k probabilities (Fix 1) + if self.norm_topk_prob: + topk_weights = topk_weights / topk_weights.sum( + dim=-1, keepdim=True + ).clamp_min(1e-9) + + # Cast to input dtype + topk_weights = topk_weights.to(hidden_states.dtype) + + # Compute auxiliary loss (Fix 3) + if self.training: + self.aux_loss = self._compute_aux_loss( + routing_probs, selected_experts, num_tokens + ) + else: + self.aux_loss = torch.tensor( + 0.0, device=hidden_states.device, dtype=hidden_states.dtype + ) + + # Initialize output + final_hidden_states = torch.zeros( + (num_tokens, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # Efficient per-expert dispatch (Fix 2 - avoid one_hot) + # Process each expert + for expert_idx in range(self.num_experts): + # Find which tokens route to this expert + # expert_mask: [num_tokens, top_k] - boolean mask + expert_mask = selected_experts == expert_idx + + # Get token indices that route to this expert (at any of their top-k positions) + # token_indices: [num_routed_tokens] + token_indices = expert_mask.any(dim=-1).nonzero(as_tuple=True)[0] + + if token_indices.numel() == 0: + # No tokens routed to this expert, skip + continue + + # For each routed token, find which k position corresponds to this expert + # k_positions: [num_routed_tokens] + k_positions = 
expert_mask[token_indices].float().argmax(dim=-1) + + # Get the routing weights for this expert + # weights: [num_routed_tokens, 1] + weights = topk_weights[token_indices, k_positions].unsqueeze(-1) + + # Get the hidden states for routed tokens + # current_states: [num_routed_tokens, hidden_dim] + current_states = hidden_states.index_select(0, token_indices) + + # Compute expert output + # expert_output: [num_routed_tokens, hidden_dim] + expert_output = self.experts[expert_idx](current_states) + + # Weight the expert output + weighted_output = expert_output * weights + + # Accumulate to final output + final_hidden_states.index_add_( + 0, token_indices, weighted_output.to(hidden_states.dtype) + ) + + # Reshape back to [batch_size, sequence_length, hidden_dim] + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + + return final_hidden_states + + +class LLaDADecoderLayer(nn.Module): + def __init__(self, config: LLaDAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.mlp_type = "moe" + + self.self_attn = LLaDA_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = LLaDAMLP(config) + self.moe = LLaDAMoESparseMoeBlock(config) if config.num_experts > 0 else None + self.input_layernorm = LLaDARMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LLaDARMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): 
input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states_mlp = self.mlp(hidden_states) + if self.moe: + hidden_states_moe = self.moe(hidden_states) + hidden_states = residual + hidden_states_mlp + hidden_states_moe + else: + hidden_states = residual + hidden_states_mlp + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLaDA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LLaDAConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaDA Model outputting raw hidden-states without any specific head on top.", + LLaDA_START_DOCSTRING, +) +class LLaDAPreTrainedModel(PreTrainedModel): + config_class = LLaDAConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LLaDADecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache( + self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None + ): + if ( + self.config._attn_implementation == "flash_attention_2" + and cache_cls == StaticCache + ): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at 
https://github.com/huggingface/transformers" + ) + + for layer in self.model.layers: + device = layer.input_layernorm.weight.device + if hasattr(self.config, "_pre_quantization_dtype"): + dtype = self.config._pre_quantization_dtype + else: + dtype = layer.self_attn.o_proj.weight.dtype + layer.self_attn.past_key_value = cache_cls( + self.config, max_batch_size, max_cache_len, device=device, dtype=dtype + ) + + def _reset_cache(self): + for layer in self.model.layers: + layer.self_attn.past_key_value = None + + +LLaDA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaDA Model outputting raw hidden-states without any specific head on top.", + LLaDA_START_DOCSTRING, +) +class LLaDAModel(LLaDAPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LLaDADecoderLayer`] + + Args: + config: LLaDAConfig + """ + + def __init__(self, config: LLaDAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + LLaDADecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LLaDARMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLaDA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + # Add Basic MDM Model config check + assert past_key_values is None and not use_cache, ( + "The kvcache is not suppotred for MDM." 
+ ) + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError( + "cache_position is a required argument when using StaticCache." 
+ ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, is_causal=False + ) # Modify: MDM + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, Cache) + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + 
attentions=all_self_attns, + ) + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + def _update_causal_mask( + self, attention_mask, input_tensor, cache_position, is_causal=True + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[-1] + 1 + ) + + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + if is_causal == False: + causal_mask = torch.zeros( + (sequence_length, target_length), dtype=dtype, device=device + ) + + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand( + input_tensor.shape[0], 1, -1, -1 + ) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + # The position with 1 in attention_mask represents the place to be attended to, so here we need 
to mask the place where attention_mask is 0 + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[ + :, None, None, : + ].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[ + ..., :mask_length + ].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # The position with 1 in attention_mask represents the place to be attended to, so here we need to mask the place where attention_mask is 0 + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], + : mask_shape[1], + offset : mask_shape[2] + offset, + : mask_shape[3], + ] = mask_slice + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = ( + torch.jit.is_tracing() + or isinstance(input_tensor, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + if not is_tracing and torch.any(attention_mask != 1): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + +class LLaDAModelLM(LLaDAPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LLaDAModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def _build_conversation_mask_optimized(self, conversation_ids): + # Reshape conversation_ids for broadcasting + ids_i = conversation_ids.unsqueeze(-1) # [batch_size, seq_len, 1] + ids_j = conversation_ids.unsqueeze(-2) # [batch_size, 1, seq_len] + + # Use broadcasting to compare all pairs of conversation IDs + conv_mask = ids_j <= ids_i # [batch_size, seq_len, seq_len] + + # Add the attention head dimension + return conv_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len] + + @staticmethod + def add_gumbel_noise(logits, temperature): + """ + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + """ + if temperature == 0: + # When temperature=0, we can directly return the original logits. 
+ # without any noise or transformation + return logits + + # use float64 for more stable computation + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (-torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + @staticmethod + def get_num_transfer_tokens(mask_index, steps): + """ + Precompute the number of tokens to transition at each step. + Optimized to be more efficient. + """ + mask_num = mask_index.sum(dim=1, keepdim=True) + base = mask_num // steps + remainder = mask_num % steps + + # Create tensor once and modify in-place (via clone) + num_transfer_tokens = base.expand(-1, steps).clone() + + # Handle remainder more efficiently + if remainder.sum() > 0: # Optimization: only proceed if there are remainders + indices = torch.arange(steps, device=mask_index.device) + # Create mask using broadcasting + # indices shape: [steps] -> [1, steps] + # remainder shape: [batch_size, 1] + # mask shape: [batch_size, steps] + mask = indices.unsqueeze(0) < remainder + num_transfer_tokens[mask] += 1 + + return num_transfer_tokens.to(torch.int64) + + @staticmethod + def get_masked_indices_from_embeds(noisy_embeds, masked_embed): + # Get shape information + b, l, d = noisy_embeds.shape + # Expand masked_embed to the same shape as noisy_embeds [b, l, d] + masked_embed_expanded = masked_embed.expand(b, l, d) + # Calculate absolute difference + abs_diff = torch.abs(noisy_embeds - masked_embed_expanded) + # Calculate tolerance boundary (atol + rtol * abs(masked_embed)) + tolerance = 1e-5 + 1e-5 * torch.abs(masked_embed_expanded) + # Check if all dimensions at each position are within tolerance + # all(dim=-1) ensures all dimensions of each embedding meet the condition + masked_indices = (abs_diff <= tolerance).all(dim=-1) + + return masked_indices + + @torch.no_grad() + def generate( + self, + prompt, + steps=128, + gen_length=128, + block_length=128, + temperature=0.0, + cfg_scale=0.0, + 
remasking="low_confidence", + mask_id=126336, + ): + """ + Args: + prompt: A tensor of shape (1, l). + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + """ + x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to( + prompt.device + ) + x[:, : prompt.shape[1]] = prompt.clone() + + prompt_index = x != mask_id + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = ( + x[ + :, + prompt.shape[1] + num_block * block_length : prompt.shape[1] + + (num_block + 1) * block_length :, + ] + == mask_id + ) + num_transfer_tokens = self.get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = x == mask_id + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self.model(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self.model(x).logits + + logits_with_noise = self.add_gumbel_noise( + logits, temperature=temperature + ) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + + if remasking == "low_confidence": + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1 + ) # b, l + elif remasking == "random": + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, 
prompt.shape[1] + (num_block + 1) * block_length :] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like( + x0, dtype=torch.bool, device=x0.device + ) + for j in range(confidence.shape[0]): + _, select_index = torch.topk( + confidence[j], k=num_transfer_tokens[j, i] + ) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + + @torch.no_grad() + def generate_with_embeds( + self, + inputs_embeds, + steps=128, + gen_length=128, + block_length=128, + temperature=0.0, + cfg_scale=0.0, + remasking="low_confidence", + mask_id=126336, + tokenizer=None, + stopping_criteria=None, + generation_suffix=None, + **kwargs, + ): + """ + Args: + inputs_embeds: A tensor of shape (1, l, d). + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. 
+ generation_suffix: (str or None) Generation suffix, such as "The answer is xxx", will be appended to the end + """ + # Use mixed precision for faster computation + with torch.cuda.amp.autocast(enabled=True): + # Handle generation suffix + suffix_embeds = None + suffix_token_ids = None + suffix_len = 0 + if ( + generation_suffix is not None + and tokenizer is not None + and len(generation_suffix) > 0 + ): + # Encode as token id + suffix_token_ids = tokenizer.encode( + generation_suffix, add_special_tokens=False + ) + suffix_token_ids = torch.tensor( + suffix_token_ids, dtype=torch.long, device=inputs_embeds.device + ).unsqueeze(0) # (1, s) + # Convert to embedding + suffix_embeds = self.model.embed_tokens(suffix_token_ids) # (1, s, d) + suffix_len = suffix_embeds.shape[1] + else: + suffix_len = 0 + + # Create input in embedding space + total_length = inputs_embeds.shape[1] + gen_length + suffix_len + masked_embed = self.model.embed_tokens( + torch.tensor([mask_id]).to(inputs_embeds.device) + ) # shape (1, d) + x_embeds = masked_embed.repeat(1, total_length, 1).to( + inputs_embeds.device + ) # shape (1, l + gen_length + suffix_len, d) + x_embeds[:, : inputs_embeds.shape[1]] = inputs_embeds.clone() + if suffix_embeds is not None: + x_embeds[:, -suffix_len:] = suffix_embeds + + # Create a tracking tensor for token IDs for final output + x = torch.full( + (1, total_length), + mask_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + if suffix_token_ids is not None: + x[:, -suffix_len:] = suffix_token_ids + + # prompt_index: A tensor of shape (1, l + gen_length + suffix_len) where the first l elements are 1 (representing the prompt) + # and the remaining gen_length+suffix_len elements are 0 (representing the generated part) + prompt_index = torch.zeros( + (1, total_length), dtype=torch.bool, device=inputs_embeds.device + ) + prompt_index[:, : inputs_embeds.shape[1]] = ( + 1 # shape (1, l + gen_length + suffix_len) + ) + + assert gen_length % block_length == 0 + 
num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # New: Initialize stop position variable (default to maximum length) + stop_position = inputs_embeds.shape[1] + gen_length + found_stop_seq = False + + stop_tokens = [] + if stopping_criteria is not None: + assert tokenizer is not None, ( + "tokenizer is required when stopping_criteria is not None" + ) + for stop_str in stopping_criteria: + # Use tokenizer to convert stop words to token IDs + tokens = tokenizer.encode(stop_str, add_special_tokens=False) + stop_tokens.append(tokens) + + feature_cache = dLLMCache() + feature_cache.reset_cache(inputs_embeds.shape[1]) + for num_block in range(num_blocks): + # Create mask index for the current block + block_start = inputs_embeds.shape[1] + num_block * block_length + block_end = inputs_embeds.shape[1] + (num_block + 1) * block_length + + # If a stop word is found and the stop word position is before the current block, do not process the current block + if found_stop_seq and stop_position <= block_start: + break + + block_embeds = x_embeds[:, block_start:block_end] + block_mask_index = torch.all( + torch.abs(block_embeds - masked_embed) < 1e-5, dim=2 + ) + + num_transfer_tokens = self.get_num_transfer_tokens( + block_mask_index, steps + ) + + for i in range(steps): + # Determine which positions are mask embeddings + mask_index = torch.all( + torch.abs(x_embeds - masked_embed) < 1e-5, dim=2 + ) + + # If a stop word has been found, check if the masks before the stop word are all filled + if found_stop_seq: + # Get the mask state before the stop word + pre_stop_masks = mask_index[ + 0, inputs_embeds.shape[1] : stop_position + ] + # If the masks before the stop word are all filled, exit generation + if not pre_stop_masks.any(): + break + + # Check if there are any masks left to fill in the current block + current_block_masks = mask_index[0, block_start:block_end] + if not current_block_masks.any(): + break + + # Handle 
CFG + if cfg_scale > 0.0: + un_embeds = ( + x_embeds.clone() + ) # shape (1, l + gen_length + suffix_len, d) + un_mask = prompt_index.unsqueeze(-1).expand_as( + x_embeds + ) # shape (1, l + gen_length + suffix_len, d) + un_embeds[un_mask] = masked_embed.repeat( + x_embeds.shape[0], x_embeds.shape[1], 1 + )[un_mask] # Use repeat to avoid the complexity of expand_as + combined_embeds = torch.cat([x_embeds, un_embeds], dim=0) + + # Forward pass + outputs = self.model(inputs_embeds=combined_embeds) + logits = self.lm_head(outputs[0]).float() + + # Split and apply CFG + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + # Forward pass + outputs = self.model(inputs_embeds=x_embeds) + logits = self.lm_head(outputs[0]).float() + + for token_id in [126081, 126080, 126346, 126347]: + logits[:, :, token_id] = torch.where( + mask_index, -float("inf"), logits[:, :, token_id] + ) + + # Add noise and get the most likely token + logits_with_noise = self.add_gumbel_noise( + logits, temperature=temperature + ) # shape (1, l + gen_length + suffix_len, vocab_size) + x0 = torch.argmax( + logits_with_noise, dim=-1 + ) # 1, l + gen_length + suffix_len + + # Get confidence scores + if remasking == "low_confidence": + p = F.softmax( + logits.to(torch.float64), dim=-1 + ) # shape (1, l + gen_length + suffix_len, vocab_size) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1 + ) # 1, l + gen_length + suffix_len represents the confidence of each x0 + elif remasking == "random": + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + # If a stop word is found, only process positions before the stop word + if found_stop_seq: + x0_p[:, stop_position:] = -np.inf + else: + # Prevent processing future blocks + x0_p[:, block_end:] = -np.inf + + # Do not allow the generated suffix part to be overwritten + if suffix_len > 0: + x0_p[:, 
-suffix_len:] = -np.inf + + # Update predictions only at mask positions + x0_embeds = self.model.embed_tokens( + x0 + ) # shape (1, l + gen_length + suffix_len, d) + x0_embeds = torch.where( + mask_index.unsqueeze(-1).expand_as(x_embeds), + x0_embeds, + x_embeds, + ) + x0 = torch.where( + mask_index, x0, x + ) # shape (1, l + gen_length + suffix_len) + + # Calculate confidence and determine transfer index + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like( + x0, dtype=torch.bool, device=x0.device + ) + for j in range(confidence.shape[0]): + _, select_index = torch.topk( + confidence[j], k=num_transfer_tokens[j, i] + ) + transfer_index[j, select_index] = True + + # Update embeddings and token IDs + x_embeds[transfer_index] = x0_embeds[transfer_index] + x[transfer_index] = x0[transfer_index] + + # New: Check for stop words after each update + if stopping_criteria is not None: + # Only check the generated part (excluding the suffix) + generated_part = x[ + 0, + inputs_embeds.shape[1] : inputs_embeds.shape[1] + + gen_length, + ] + current_stop_position = None + + for stop_seq in stop_tokens: + if not isinstance(stop_seq, list): + stop_seq = [stop_seq] + # Check if the generated sequence contains stop words + for start_idx in range( + generated_part.size(0) - len(stop_seq) + 1 + ): + if torch.all( + generated_part[ + start_idx : start_idx + len(stop_seq) + ] + == torch.tensor(stop_seq, device=x.device) + ): + # Calculate the position of the currently found stop word + current_position = ( + inputs_embeds.shape[1] + start_idx + ) + # If it is the first time a stop word is found, or this stop word is earlier than the previously found one + if ( + not found_stop_seq + or current_position < stop_position + ): + stop_position = current_position + found_stop_seq = True + break + if found_stop_seq and current_stop_position is None: + break + + # Return the generated result, up to stop_position, and append the suffix + if found_stop_seq: 
+ if suffix_len > 0: + return torch.cat( + [ + x[:, inputs_embeds.shape[1] : stop_position], + x[:, -suffix_len:], + ], + dim=1, + ) + else: + return x[:, inputs_embeds.shape[1] : stop_position] + else: + if suffix_len > 0: + return torch.cat( + [ + x[ + :, + inputs_embeds.shape[1] : inputs_embeds.shape[1] + + gen_length, + ], + x[:, -suffix_len:], + ], + dim=1, + ) + else: + return x[ + :, inputs_embeds.shape[1] : inputs_embeds.shape[1] + gen_length + ] + + @add_start_docstrings_to_model_forward(LLaDA_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + conversation_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + def forward_process_embeds(input_embeds, labels, eps=1e-3): + b, l, d = input_embeds.shape + t = torch.rand(b, device=input_embeds.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_embeds.device) < p_mask + # Add label condition filtering + valid_mask = labels != -100 # Create valid encoding + masked_indices = ( + masked_indices & valid_mask + ) # Combine random encoding and valid encoding + # Magic number 126336 stands for the tokenizer special token, + # Magic embeddings, which is used for [MASK] token here, + masked_embed = self.model.embed_tokens( + torch.tensor([126336]).to(input_embeds.device) + ) + noisy_embeds = torch.where( + masked_indices.unsqueeze(-1), masked_embed, input_embeds + ) + + return noisy_embeds, p_mask, masked_embed + + noisy_embeds, p_mask, masked_embed = forward_process_embeds( + inputs_embeds, labels + ) + + masked_indices = self.get_masked_indices_from_embeds( + noisy_embeds, masked_embed + ) # shape (b, l) + prompt_index = (labels == -100).to(torch.int64) # shape (b, l) + + noisy_data_length = torch.sum( + (1 - prompt_index), dim=-1, keepdim=True + ) # shape (b, 1) + noisy_data_length = noisy_data_length.repeat( + 1, noisy_embeds.shape[1] + ) # shape (b, l) + + if conversation_ids is not None: + conversation_mask = self._build_conversation_mask_optimized( + conversation_ids + ) + if attention_mask is not None: + # 1. 
Dimension expansion + attention_mask = attention_mask.unsqueeze(1).unsqueeze( + 2 + ) # (batch, length) -> (batch, 1, 1, length) + attention_mask = attention_mask.expand_as( + conversation_mask + ) # (batch, 1, 1, length) -> (batch, 1, length, length) + # 2. Mask combination (element-wise multiplication) + combined_mask = conversation_mask * attention_mask + else: + # If attention_mask is None, directly use conversation_mask + combined_mask = conversation_mask + attention_mask = combined_mask + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=noisy_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split( + self.vocab_size // self.config.pretraining_tp, dim=0 + ) + logits = [ + F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.config.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Change for MDM + token_loss = ( + F.cross_entropy( + logits[masked_indices], + labels[masked_indices], + ignore_index=-100, + reduction="none", + ) + / p_mask[masked_indices] + ) + loss = ( + torch.sum(token_loss / noisy_data_length[masked_indices]) + / labels.shape[0] + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + 
past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + **kwargs, + ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr( + self.model.layers[0].self_attn, "past_key_value", None + ) + has_static_cache = past_key_values is not None + + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = ( + cache_position[0] + if cache_position is not None + else past_key_values.get_seq_length() + ) + max_cache_length = ( + torch.tensor( + past_key_values.get_max_length(), device=input_ids.device + ) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length + if max_cache_length is None + else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
+ model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = ( + position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + ) + if cache_position is None: + cache_position = torch.arange( + past_length, past_length + input_length, device=input_ids.device + ) + else: + cache_position = cache_position[-input_length:] + + if has_static_cache: + past_key_values = None + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past diff --git a/train/llava/model/llava_arch.py b/train/llava/model/llava_arch.py new file mode 100644 index 0000000..2467aa7 --- /dev/null +++ b/train/llava/model/llava_arch.py @@ -0,0 +1,701 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from abc import ABC, abstractmethod + +import math +import re +import time +import torch +import torch.nn as nn +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_resampler.builder import build_vision_resampler +from .multimodal_projector.builder import build_vision_projector + +from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +from llava.mm_utils import get_anyres_image_grid_shape +from llava.utils import rank0_print, rank_print +import random + + +class LlavaMetaModel: + + def __init__(self, config): + super(LlavaMetaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + delay_load = getattr(config, "delay_load", False) + self.vision_tower = build_vision_tower(config, delay_load=delay_load) + self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower) + self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) + + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype)) + + def get_vision_tower(self): + vision_tower = getattr(self, "vision_tower", None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.mm_vision_tower = vision_tower + self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "") + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) 
+ for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear") + self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size) + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + + + if not hasattr(self.config, 'add_faster_video'): + if model_args.add_faster_video: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.faster_token = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + + if getattr(self, "mm_projector", None) is None: + self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) + + if "unpad" in mm_patch_merge_type: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu") + + def get_w(weights, keyword): + return {k.split(keyword + 
".")[1]: v for k, v in weights.items() if keyword in k} + + incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector")) + rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") + incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False) + rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of the image (height, width). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + # Compute aspect ratios + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + # Determine padding size and direction + if original_aspect_ratio > current_aspect_ratio: + # Padding was added to the height + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + # Padding was added to the width + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +class LlavaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_2dPool(self, image_feature, stride=2): + height = width = self.get_vision_tower().num_patches_per_side + 
num_frames, num_tokens, num_dim = image_feature.shape + image_feature = image_feature.view(num_frames, height, width, -1) + image_feature = image_feature.permute(0, 3, 1, 2).contiguous() + # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) + if self.config.mm_spatial_pool_mode == "average": + image_feature = nn.functional.avg_pool2d(image_feature, stride) + elif self.config.mm_spatial_pool_mode == "max": + image_feature = nn.functional.max_pool2d(image_feature, stride) + elif self.config.mm_spatial_pool_mode == "bilinear": + height, width = image_feature.shape[2:] + scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] + image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear') + + else: + raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}") + image_feature = image_feature.permute(0, 2, 3, 1) + image_feature = image_feature.view(num_frames, -1, num_dim) + return image_feature + + def encode_images(self, images): + image_features = self.get_model().get_vision_tower()(images) + # image_features = self.get_model().vision_resampler(image_features, images=images) + image_features = self.get_model().mm_projector(image_features) + return image_features + + def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None): + videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images) + per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096) + all_videos_or_images_features = [] + all_faster_video_features = [] + cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride + + for idx, feat in enumerate(per_videos_or_images_features): + + feat = self.get_model().mm_projector(feat) + faster_video_feature = 0 + slower_img_feat = 0 + if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1: + slower_img_feat = 
self.get_2dPool(feat,cur_mm_spatial_pool_stride) + if self.config.add_faster_video: + cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2 + faster_video_feature = self.get_2dPool(feat,cur_mm_spatial_pool_stride) + if slower_img_feat != 0: + all_videos_or_images_features.append(slower_img_feat) + else: + all_videos_or_images_features.append(feat) + all_faster_video_features.append(faster_video_feature) + return all_videos_or_images_features,all_faster_video_features + + def add_token_per_grid(self, image_feature): + resize_h = int(math.sqrt(image_feature.shape[1])) + num_frames = image_feature.shape[0] + feature_dim = image_feature.shape[-1] + + image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + if getattr(self.config, "add_faster_video", False): + # import pdb; pdb.set_trace() + # (3584, 832, 14) -> (3584, 64, 13, 14) + image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1) + # (3584, 64, 13, 14) -> (64, 13, 14, 3584) + image_feature = image_feature.permute(1, 2, 3, 0).contiguous() + # (64, 13, 14, 3584) -> (64, 13*14, 3584) + image_feature = image_feature.flatten(1, 2) + # import pdb; pdb.set_trace() + return image_feature + # import pdb; pdb.set_trace() + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + return image_feature + + def add_token_per_frame(self, image_feature): + image_feature = image_feature.permute(2, 0, 1).contiguous() + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.permute(1, 2, 0).contiguous() + return image_feature + + def generate_conversation_ids(self, labels): + 
""" + Args: + labels: Label tensor, can be one-dimensional or two-dimensional + Returns: + Conversation ID tensor with the same shape as labels + """ + # Process input dimensions + original_shape = labels.shape + if labels.ndim == 1: + labels = labels.unsqueeze(0) + + batch_size, seq_len = labels.shape + device = labels.device + conversation_ids = torch.zeros_like(labels) + + # Special token IDs + start_header_id = 126346 + eot_id = 126348 + assistant_role_id1 = 598 + + # Process all sequences in batch + for b in range(batch_size): + # Pre-search all special token positions to reduce repeated searches + start_positions = (labels[b] == start_header_id).nonzero(as_tuple=True)[0] + end_positions = (labels[b] == eot_id).nonzero(as_tuple=True)[0] + + # If no boundaries are found, continue to the next sequence + if len(start_positions) == 0 or len(end_positions) == 0: + continue + + # Pair all message start and end positions + message_boundaries = [] + for start_pos in start_positions: + # Find the nearest end position + end_indices = (end_positions >= start_pos).nonzero(as_tuple=True)[0] + if len(end_indices) == 0: + continue + + end_pos = end_positions[end_indices[0]] + + # Quickly check if it's an assistant message + start_idx = start_pos.item() + is_assistant = (start_idx + 1 < seq_len and + labels[b, start_idx + 1] == assistant_role_id1) + + message_boundaries.append((start_idx, end_pos.item(), is_assistant)) + + # Sort by start position + message_boundaries.sort(key=lambda x: x[0]) + + # Determine if there is a system message + has_system = len(message_boundaries) > 0 and not message_boundaries[0][2] + + # Assign conversation turn IDs + current_turn = 0 + prev_was_assistant = False + + # Efficiently handle BOS token (usually at the start of the sequence) + if labels[b, 0] == 126080: # BOS ID + conversation_ids[b, 0] = 0 + + # Assign IDs to all messages at once + for i, (start_pos, end_pos, is_assistant) in enumerate(message_boundaries): + # The first non-assistant 
message is a system message + is_system = i == 0 and not is_assistant and has_system + + if is_system: + # System message belongs to the first conversation turn + conversation_ids[b, start_pos:end_pos+1] = 0 + else: + # If it's a user message and the previous one was an assistant message, increase the turn + if not is_assistant and prev_was_assistant: + current_turn += 1 + + # Assign ID to the entire message block at once + conversation_ids[b, start_pos:end_pos+1] = current_turn + + prev_was_assistant = is_assistant + + # Fill gaps between messages - use cumulative max method + # This is much faster than looping element by element + for i in range(1, seq_len): + if conversation_ids[b, i] == 0 and conversation_ids[b, i-1] > 0: + conversation_ids[b, i] = conversation_ids[b, i-1] + + # New: Handle end padding + non_zero_mask = (conversation_ids[b] != 0) + if non_zero_mask.any(): + last_non_zero_idx = torch.nonzero(non_zero_mask, as_tuple=True)[0][-1] + last_turn = conversation_ids[b, last_non_zero_idx] + conversation_ids[b, last_non_zero_idx+1:] = last_turn + + # Return a tensor with the same dimensions as the input + if len(original_shape) == 1: + return conversation_ids.squeeze(0) + + return conversation_ids + + def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None, is_llada=False): + vision_tower = self.get_vision_tower() + # rank_print(modalities) + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if isinstance(modalities, str): + modalities = [modalities] + + # import pdb; pdb.set_trace() + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == "video": + video_idx_in_batch.append(_) + + 
images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + concat_images = torch.cat([image for image in images_list], dim=0) + split_sizes = [image.shape[0] for image in images_list] + encoded_image_features = self.encode_images(concat_images) + # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + + # This is a list, each element is [num_images, patch * patch, dim] + # rank_print(f"Concat images : {concat_images.shape}") + encoded_image_features = torch.split(encoded_image_features, split_sizes) + image_features = [] + for idx, image_feat in enumerate(encoded_image_features): + if idx in video_idx_in_batch: + image_features.append(self.get_2dPool(image_feat)) + else: + image_features.append(image_feat) + # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}") + # image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + mm_newline_position = getattr(self.config, "mm_newline_position", "one_token") + + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + # FIXME: now assume the image is square, and split to 2x2 patches + # num_patches = h * w, where h = w = sqrt(num_patches) + # currently image_feature is a tensor of shape (4, num_patches, hidden_size) + # we want to first unflatten it to (2, 2, h, w, hidden_size) + # rank0_print("At least we are reaching here") + # import pdb; pdb.set_trace() + if image_idx in video_idx_in_batch: # video operations + # 
rank0_print("Video") + if mm_newline_position == "grid": + # Grid-wise + image_feature = self.add_token_per_grid(image_feature) + if getattr(self.config, "add_faster_video", False): + faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx]) + # Add a token for each frame + concat_slow_fater_token = [] + # import pdb; pdb.set_trace() + for _ in range(image_feature.shape[0]): + if _ % self.config.faster_token_stride == 0: + concat_slow_fater_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0)) + else: + concat_slow_fater_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0)) + # import pdb; pdb.set_trace() + image_feature = torch.cat(concat_slow_fater_token) + + # print("!!!!!!!!!!!!") + + new_image_features.append(image_feature) + elif mm_newline_position == "frame": + # Frame-wise + image_feature = self.add_token_per_frame(image_feature) + + new_image_features.append(image_feature.flatten(0, 1)) + + elif mm_newline_position == "one_token": + # one-token + image_feature = image_feature.flatten(0, 1) + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + new_image_features.append(image_feature) + elif mm_newline_position == "no_token": + new_image_features.append(image_feature.flatten(0, 1)) + else: + raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}") + elif image_feature.shape[0] > 1: # multi patches and multi images operations + # rank0_print("Single-images") + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio) + if 
matched_anyres_max_num_patches: + max_num_patches = int(matched_anyres_max_num_patches.group(1)) + + if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + if hasattr(self.get_vision_tower(), "image_size"): + vision_tower_image_size = self.get_vision_tower().image_size + else: + raise ValueError("vision_tower_image_size is not found in the vision tower.") + try: + num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size) + except Exception as e: + rank0_print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + image_feature = image_feature.view(2, 2, height, width, -1) + + if "maxpool2x2" in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = nn.functional.max_pool2d(image_feature, 2) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches: + unit = image_feature.shape[2] + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + c, h, w = image_feature.shape + times = math.sqrt(h * w / (max_num_patches * unit**2)) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0] + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 
3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + if "nobase" in mm_patch_merge_type: + pass + else: + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + new_image_features.append(image_feature) + else: # single image operations + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0) + + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False): + raise NotImplementedError + # rank_print(f"Total images : {len(image_features)}") + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. 
+ _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + # rank_print("Inserting Images embedding") + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + # rank0_print(num_images) + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + 
cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + try: + cur_image_features = image_features[cur_image_idx] + except IndexError: + cur_image_features = image_features[cur_image_idx - 1] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + # import pdb; pdb.set_trace() + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + # rank_print("Finishing Inserting") + + new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] + # TODO: Hard code for control loss spike + # if tokenizer_model_max_length is not None: + # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, 
device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + # rank0_print("Prepare pos id") + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + # rank0_print("tokenizer padding") + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + if getattr(self.config, "use_pos_skipping", False) and self.training: + position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device) + split_position = random.randint(0, new_input_embeds.size(1)) + left_add = random.randint(0, self.config.pos_skipping_range) + right_add = random.randint(left_add, self.config.pos_skipping_range) + position_ids[:, :split_position] += left_add + 
position_ids[:, split_position:] += right_add + + # add conversation_ids + if is_llada and attention_mask is not None: + conversation_ids = self.generate_conversation_ids(new_labels) + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, conversation_ids + # import pdb; pdb.set_trace() + # rank0_print("Finish preparing") + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu") + embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = 
embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False diff --git a/train/llava/model/make_delta.py b/train/llava/model/make_delta.py new file mode 100644 index 0000000..7b3fbab --- /dev/null +++ b/train/llava/model/make_delta.py @@ -0,0 +1,52 @@ +""" +Usage: +python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta +""" + +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from llava.model.utils import auto_upgrade + + +def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading target model") + auto_upgrade(target_model_path) + target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Calculating delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + if name not in base.state_dict(): + assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" + continue + if param.data.shape == base.state_dict()[name].shape: + param.data -= base.state_dict()[name] + else: + assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" + bparam = base.state_dict()[name] + param.data[: bparam.shape[0], : 
bparam.shape[1]] -= bparam + + print("Saving delta") + if hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str, default=None) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) diff --git a/train/llava/model/multimodal_encoder/builder.py b/train/llava/model/multimodal_encoder/builder.py new file mode 100644 index 0000000..d1cab60 --- /dev/null +++ b/train/llava/model/multimodal_encoder/builder.py @@ -0,0 +1,36 @@ +import os +from .clip_encoder import CLIPVisionTower +from .imagebind import ImageBindWrapper +from .open_clip_encoder import OpenCLIPVisionTower +from .hf_vision import HFVisionTower +from .siglip_encoder import SigLipVisionTower +from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 + +# from .eva_clip.eva_clip_encoder import EvaClipVisionTower +# from .dev_eva_clip.eva_vit import EvaViTWrapper + + +def build_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) + is_absolute_path_exists = os.path.exists(vision_tower) + use_s2 = getattr(vision_tower_cfg, "s2", False) + if vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: # is_absolute_path_exists or + if use_s2: + return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) + else: + return CLIPVisionTower(vision_tower, args=vision_tower_cfg, 
**kwargs) + elif "siglip" in vision_tower: + vision_tower = "../model/siglip2-so400m-patch14-384" + return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) + elif vision_tower.startswith("hf:"): + return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + elif vision_tower in ["imagebind_huge"]: + return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) + elif vision_tower.startswith("open_clip_hub"): + return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): + # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: + # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) + + raise ValueError(f"Unknown vision tower: {vision_tower}") diff --git a/train/llava/model/multimodal_encoder/clip_encoder.py b/train/llava/model/multimodal_encoder/clip_encoder.py new file mode 100644 index 0000000..212b262 --- /dev/null +++ b/train/llava/model/multimodal_encoder/clip_encoder.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +from llava.utils import rank0_print +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig + +try: + from s2wrapper import forward as multiscale_forward +except: + pass + + +class CLIPVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. 
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + select_feature_type = self.select_feature + + if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + select_every_k_layer = len(image_forward_outs.hidden_states) // 4 + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) + select_feature_type = select_feature_type.replace("slicefour_", "") + elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: + select_layers = [-2, -5, -8, -11, 6] + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) + select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + 
return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + _hidden_size = self.config.hidden_size + if "slicefour" in self.select_feature: + _hidden_size *= 4 + if "slice_m25811_f6" in self.select_feature: + _hidden_size *= 5 + return _hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + _num_patches = (self.config.image_size // self.config.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def image_size(self): + return self.config.image_size + + +class CLIPVisionTowerS2(CLIPVisionTower): + def __init__(self, vision_tower, args, delay_load=False): + + self.s2_scales = getattr(args, "s2_scales", "336,672,1008") + self.s2_scales = list(map(int, self.s2_scales.split(","))) + self.s2_scales.sort() + self.s2_split_size = self.s2_scales[0] + self.s2_image_size = self.s2_scales[-1] + + super().__init__(vision_tower, args, delay_load) + + # change resize/crop size in 
preprocessing to the largest image size in s2_scale + if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False): + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size + + self.is_loaded = True + + def forward_feature(self, images): + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) + image_features.append(image_feature) + else: + image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) + + return image_features + + @property + def hidden_size(self): + return self.config.hidden_size * len(self.s2_scales) diff --git a/train/llava/model/multimodal_encoder/hf_vision.py b/train/llava/model/multimodal_encoder/hf_vision.py new file mode 100644 index 0000000..a413208 --- /dev/null +++ b/train/llava/model/multimodal_encoder/hf_vision.py @@ -0,0 +1,111 @@ +import 
torch +import torch.nn as nn + +from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor +from llava.utils import rank0_print + + +class HFVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower.replace("hf:", "", 1) + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + self.load_model() + else: + self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) + + def load_model(self): + try: + self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) + except Exception as e: + if "448" in self.vision_tower_name: + image_size = 448 + # use image processor with conig + self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) + else: + self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + rank0_print(f"Loaded image processor: {self.image_processor}") + self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") + self.device = self.vision_tower.device + self.dtype = self.vision_tower.dtype + self.config = self.vision_tower.config + + if hasattr(self.vision_tower, "vision_model"): + self.vision_tower = self.vision_tower.vision_model + self.vision_tower.requires_grad_(False) + # self.vision_tower.eval() + self.is_loaded = True + + def feature_select(self, image_forward_outs): + select_feature_type = self.select_feature + + if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + select_every_k_layer = len(image_forward_outs.hidden_states) // 4 + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), 
select_every_k_layer)], dim=-1) + select_feature_type = select_feature_type.replace("slicefour_", "") + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + # @property + # def dtype(self): + # return self.vision_tower.dtype + + # @property + # def device(self): + # return self.vision_tower.device + + @property + def hidden_size(self): + try: + _hidden_size = self.config.hidden_size + except: + _hidden_size = self.config.vision_config.hidden_size + if "slicefour" in self.select_feature: + _hidden_size *= 4 + return _hidden_size + + @property + def num_patches(self): + _num_patches = (self.config.image_size // self.config.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def image_size(self): + return self.config.image_size diff --git a/train/llava/model/multimodal_encoder/imagebind.py b/train/llava/model/multimodal_encoder/imagebind.py 
new file mode 100644 index 0000000..8bbe71c --- /dev/null +++ b/train/llava/model/multimodal_encoder/imagebind.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn + +from transformers import CLIPImageProcessor + +try: + from imagebind.models import imagebind_model + from imagebind.models.imagebind_model import ModalityType + from imagebind.data import load_and_transform_audio_data +except ImportError: + pass + + +class ImageBindWrapper(nn.Module): + def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = select_layer + self.select_feature = select_feature + + if not delay_load: + self.load_model() + + def load_model(self): + self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) + for p in self.vision_tower.parameters(): + p.requires_grad = False + self.vision_tower.eval() + self.is_loaded = True + + def train(self, mode=True): + self.training = mode + + if self.is_loaded: + self.vision_tower.eval() + + @torch.no_grad() + def forward(self, x): + if type(x) == dict: + if x["audios"] is not None: + inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} + embeddings = self.vision_tower(inputs) + audio_embedding = embeddings[ModalityType.AUDIO] + return audio_embedding.unsqueeze(1) + else: + inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} + embeddings = self.vision_tower(inputs) + vision_embedding = embeddings[ModalityType.VISION] + if vision_embedding.ndim == 2: + return vision_embedding.unsqueeze(1) + if vision_embedding.shape[1] == 257: + return vision_embedding[:, 1:] + raise ValueError(f"Unexpected shape: {vision_embedding.shape}") + + @property + def dummy_feature(self): + return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) + + @property + def 
dtype(self): + return self.vision_tower.modality_preprocessors.vision.cls_token.dtype + + @property + def device(self): + return self.vision_tower.modality_preprocessors.vision.cls_token.device + + @property + def hidden_size(self): + return 1024 diff --git a/train/llava/model/multimodal_encoder/open_clip_encoder.py b/train/llava/model/multimodal_encoder/open_clip_encoder.py new file mode 100644 index 0000000..17a3277 --- /dev/null +++ b/train/llava/model/multimodal_encoder/open_clip_encoder.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +from transformers import CLIPImageProcessor +from llava.utils import rank0_print + +try: + import open_clip + import torchvision + from open_clip.transformer import _expand_token +except ImportError: + print("OpenCLIP not installed") + open_clip = None + +HIDDEN_SIZE_DICT = { + "ViT-H-14-378-quickgelu": 1280, +} + + +class OpenCLIPVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + self.model_name = vision_tower.replace("open_clip_hub:", "") + self.pretrained = args.vision_tower_pretrained + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. 
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + + def load_model(self, device_map="auto"): + rank0_print(f"Loading OpenCLIP model: {self.model_name}") + rank0_print(f"Pretrained: {self.pretrained}") + vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda") + + resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] + normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] + self.resize_transform_size = resize_transform.size # 224 or 384 + self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16 + + self.image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-large-patch14", + crop_size=resize_transform.size, + size={"shortest_edge": resize_transform.size}, + image_mean=list(normalize_transform.mean), + image_std=list(normalize_transform.std), + ) + rank0_print(f"Loaded image processor: {self.image_processor}") + self.vision_tower = vision_tower.visual + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + image_features = image_forward_outs[self.select_layer] + if self.select_feature == "patch": + image_features = image_features[:, 1:] + elif self.select_feature == "cls_patch": + image_features = image_features + elif self.select_feature == "conv_flatten": + image_features = image_features.flatten(2).transpose(1, 2) + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + return image_features + + def forward_visual(self, x, 
output_hidden_states=False): + if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"): + return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer)) + else: + + def forward_openclip(self, x: torch.Tensor): + features = [] + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], + dim=1, + ) + # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + for r in self.transformer.resblocks: + x = r(x, attn_mask=None) + features.append(x) + return features + + return forward_openclip(self.vision_tower, x) + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + if hasattr(self.vision_tower, "conv1"): + return self.vision_tower.conv1.weight.dtype + if hasattr(self.vision_tower, "trunk"): + return self.vision_tower.trunk.patch_embed.proj.weight.dtype + raise NotImplementedError + + @property + def device(self): + if hasattr(self.vision_tower, "conv1"): + return self.vision_tower.conv1.weight.device + if 
hasattr(self.vision_tower, "trunk"): + return self.vision_tower.trunk.patch_embed.proj.weight.device + raise NotImplementedError + + @property + def config(self): + return None + + @property + def hidden_size(self): + if self.model_name in HIDDEN_SIZE_DICT: + return HIDDEN_SIZE_DICT[self.model_name] + else: + raise NotImplementedError + + @property + def num_patches(self): + image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0] + _num_patches = (image_size // self.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def image_size(self): + return self.resize_transform_size + + @property + def num_patches_per_side(self): + return self.resize_transform_size // self.patch_size diff --git a/train/llava/model/multimodal_encoder/siglip_encoder.py b/train/llava/model/multimodal_encoder/siglip_encoder.py new file mode 100644 index 0000000..f1e101a --- /dev/null +++ b/train/llava/model/multimodal_encoder/siglip_encoder.py @@ -0,0 +1,620 @@ +""" +# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py +""" + +from typing import Optional, Tuple, Union, Dict +from dataclasses import dataclass +from functools import partial, reduce +from PIL import Image +import torch +import torch.utils.checkpoint +from torch import nn +import os +from transformers.image_processing_utils import BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + normalize, + rescale, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + ChannelDimension, + PILImageResampling, + to_numpy_array, +) +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers import PretrainedConfig +from transformers.utils import ModelOutput +from 
llava.utils import rank0_print + + +class SigLipImageProcessor: + def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST): + crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.image_mean = image_mean + self.image_std = image_std + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.data_format = data_format + self.crop_size = crop_size + + def preprocess(self, images, return_tensors): + if isinstance(images, Image.Image): + images = [images] + else: + # to adapt video data + images = [to_numpy_array(image) for image in images] + assert isinstance(images, list) + + transforms = [ + convert_to_rgb, + to_numpy_array, + partial(resize, size=self.size, resample=self.resample, data_format=self.data_format), + partial(rescale, scale=self.rescale_factor, data_format=self.data_format), + partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format), + partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format), + ] + + images = reduce(lambda x, f: [*map(f, x)], transforms, images) + data = {"pixel_values": images} + + return BatchFeature(data=data, tensor_type=return_tensors) + + +class SigLipVisionConfig(PretrainedConfig): + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=1152, + image_mean=(0.5, 0.5, 0.5), + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=384, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + 
self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.image_mean = image_mean + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SigLipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.") + + return cls.from_dict(config_dict, **kwargs) + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip +class SigLipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SigLipVisionEmbeddings(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + embeddings = 
patch_embeds.flatten(2).transpose(1, 2) + + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SigLipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): 
+ raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip +class SigLipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip +class SigLipEncoderLayer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = 
SigLipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SigLipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SigLipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + pass + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip +class SigLipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SigLipEncoderLayer`]. + + Args: + config: SigLipVisionConfig + """ + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) + + +class SigLipVisionTransformer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SigLipVisionEmbeddings(config) + self.encoder = SigLipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SigLipMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SigLipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class 
SigLipVisionModel(SigLipPreTrainedModel): + config_class = SigLipVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["SigLipEncoderLayer"] + + def __init__(self, config: SigLipVisionConfig): + super().__init__(config) + + self.vision_model = SigLipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SigLipVisionModel + + >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SigLipVisionTower(nn.Module): + def __init__(self, vision_tower, vision_tower_cfg, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.config = SigLipVisionConfig() + + self.vision_tower_name = vision_tower + + self.image_processor = SigLipImageProcessor() + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + 
self.load_model()
+        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
+            # TODO: better detector is needed.
+            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+            self.load_model()
+        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
+            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+            self.load_model()
+        else:
+            self.cfg_only = self.config
+
+    def load_model(self, device_map=None):
+        if self.is_loaded:
+            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
+            return
+
+        self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+
+        del self.vision_tower.vision_model.encoder.layers[-1:]
+        self.vision_tower.vision_model.head = nn.Identity()
+        self.vision_tower.requires_grad_(False)
+
+        self.is_loaded = True
+
+    def forward(self, images):
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
+                assert image_feature.shape[-2] == 729
+                image_features.append(image_feature)
+        else:
+            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
+            assert image_features.shape[-2] == 729
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        for p in self.vision_tower.parameters():
+            return p.dtype
+
+    @property
+    def device(self):
+        for p in self.vision_tower.parameters():
+            return p.device
+
+    @property
+    def 
hidden_size(self): + return self.config.hidden_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size) ** 2 + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] + + @property + def image_size(self): + return self.config.image_size diff --git a/train/llava/model/multimodal_projector/builder.py b/train/llava/model/multimodal_projector/builder.py new file mode 100644 index 0000000..3122a0c --- /dev/null +++ b/train/llava/model/multimodal_projector/builder.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import re + +from .pooler_projector import PoolerProjector + + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": "identity"} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) + + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, "mm_projector_type", "linear") + + if projector_type == "linear": + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + if projector_type == "pooler": + return PoolerProjector(config, kwargs["vision_cfg"]) + + mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + 
mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) + if mlp_gelu_resnet_match: + mlp_depth = int(mlp_gelu_resnet_match.group(1)) + res_depth = int(mlp_gelu_resnet_match.group(2)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + for _ in range(res_depth): + modules.append(SimpleResBlock(config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == "identity": + return IdentityMap() + + raise ValueError(f"Unknown projector type: {projector_type}") diff --git a/train/llava/model/multimodal_projector/pooler_projector.py b/train/llava/model/multimodal_projector/pooler_projector.py new file mode 100644 index 0000000..ce5a2e0 --- /dev/null +++ b/train/llava/model/multimodal_projector/pooler_projector.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + +import math + +from transformers.models.clip.modeling_clip import CLIPVisionModel + + +class PoolerProjector(nn.Module): + def __init__(self, config, vision_cfg): + super().__init__() + self._config = config + self.hw = vision_cfg.image_size // vision_cfg.patch_size + + self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) + + self.proj = nn.Sequential( + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + + def forward(self, x, *args, **kwargs): + height = width = self.hw + assert height * width == x.shape[1] + x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) + x = self.conv_pool(x) + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + @property + def config(self): + return {"mm_projector_type": "pooler"} diff --git a/train/llava/model/multimodal_resampler/builder.py b/train/llava/model/multimodal_resampler/builder.py new file mode 100644 index 0000000..7a4b207 --- /dev/null +++ b/train/llava/model/multimodal_resampler/builder.py @@ -0,0 
+1,34 @@ +import torch + +from .masked_drop import MaskedDrop +from .spatial_pool import SpatialPool +from .perceiver import PerceiverResampler +from .qformer import Qformer + + +class IdentityMap(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_resampler_type": None} + + +def build_vision_resampler(model_args, delay_load=False, **kwargs): + resampler_type = getattr(model_args, "mm_resampler_type", None) + if resampler_type == "masked_drop": + return MaskedDrop(model_args) + elif resampler_type == "spatial_pool": + return SpatialPool(model_args, **kwargs) + elif resampler_type == "perceiver": + return PerceiverResampler(model_args, **kwargs) + elif resampler_type == "qformer": + return Qformer(model_args, **kwargs) + elif resampler_type is None: + return IdentityMap() + + raise ValueError(f"Unknown resampler type: {resampler_type}") diff --git a/train/llava/model/multimodal_resampler/masked_drop.py b/train/llava/model/multimodal_resampler/masked_drop.py new file mode 100644 index 0000000..03f0bf0 --- /dev/null +++ b/train/llava/model/multimodal_resampler/masked_drop.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn + +import random + + +class MaskedDrop(nn.Module): + def __init__(self, model_args): + super().__init__() + + self.mode = model_args.mm_mask_drop_mode + self.skip_percentage = model_args.mm_mask_drop_skip_percentage + self.ratio = model_args.mm_mask_drop_ratio + self.ratio_upper = model_args.mm_mask_drop_ratio_upper + self.ratio_lower = model_args.mm_mask_drop_ratio_lower + + def forward(self, image_features, *args, **kwargs): + + if not self.training: + return image_features + + if self.skip_percentage > random.random(): + return image_features + + masked_features = [] + + for image_feature in image_features: + num_tokens = image_feature.shape[0] + if self.mode == "fixed": + num_keep = int(num_tokens * self.ratio) + 
masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) + elif self.mode == "range": + num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) + masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) + elif self.mode == "cls_only": + masked_features.append(image_feature[0:1]) + else: + raise ValueError(f"Unexpected masked drop mode: {self.mode}") + + if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): + masked_features = torch.stack(masked_features, dim=0) + + return masked_features + + @property + def config(self): + return { + "mm_resampler_type": "masked_drop", + "mm_mask_drop_mode": self.mode, + "mm_mask_drop_skip_percentage": self.skip_percentage, + "mm_mask_drop_ratio": self.ratio, + "mm_mask_drop_ratio_upper": self.ratio_upper, + "mm_mask_drop_ratio_lower": self.ratio_lower, + } + + def random_masking(self, x, len_keep): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
+        x: [N, L, D], sequence
+        """
+        N, L, D = x.shape  # batch, length, dim
+
+        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        mask = torch.ones([N, L], device=x.device)
+        mask[:, :len_keep] = 0
+        # unshuffle to get the binary mask
+        mask = torch.gather(mask, dim=1, index=ids_restore)
+
+        return x_masked, mask, ids_restore
diff --git a/train/llava/model/multimodal_resampler/perceiver.py b/train/llava/model/multimodal_resampler/perceiver.py
new file mode 100644
index 0000000..d6b17a5
--- /dev/null
+++ b/train/llava/model/multimodal_resampler/perceiver.py
@@ -0,0 +1,155 @@
+"""
+Taken from https://github.com/lucidrains/flamingo-pytorch
+"""
+
+import torch
+from einops import rearrange, repeat
+
+try:
+    from einops_exts import rearrange_many
+except ImportError:
+    pass
+
+from torch import einsum, nn
+
+
+def exists(val):
+    return val is not None
+
+
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm_media = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, n1, D)
+            latent 
(torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + + +class PerceiverResamplerModule(nn.Module): + def __init__( + self, + *, + dim, + depth=6, + dim_head=64, + heads=8, + num_latents=64, + max_num_media=None, + max_num_frames=None, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None + self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, T, F, v, D) + Returns: + shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if exists(self.frame_embs): + frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) + x = x + frame_embs + x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions + if exists(self.media_time_embs): + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, "n d -> b 
T n d", b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +class PerceiverResampler(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.depth = model_args.mm_perceiver_depth + self.num_latents = model_args.mm_perceiver_latents + self.ff_mult = model_args.mm_perceiver_ff_mult + self.pretrained = model_args.mm_perceiver_pretrained + + self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) + + if self.pretrained is not None: + self.load_state_dict(torch.load(self.pretrained)) + + def forward(self, image_features, *args, **kwargs): + return self.perceiver(image_features[:, None, None]).squeeze(1) + + @property + def config(self): + return { + "mm_resampler_type": "perceiver", + "mm_perceiver_depth": self.depth, + "mm_perceiver_latents": self.num_latents, + "mm_perceiver_ff_mult": self.ff_mult, + "mm_perceiver_pretrained": self.pretrained, + } diff --git a/train/llava/model/multimodal_resampler/qformer.py b/train/llava/model/multimodal_resampler/qformer.py new file mode 100644 index 0000000..b86754c --- /dev/null +++ b/train/llava/model/multimodal_resampler/qformer.py @@ -0,0 +1,1160 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any 
TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = 
nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, 
positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + 
return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = 
self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = 
self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + 
attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions, query_length) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + 
return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 
+ use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert query_embeds is not None, "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. 
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else 
output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + 
encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qformer(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.depth = model_args.mm_qformer_depth + self.num_latents = model_args.mm_qformer_latents + self.pretrained = 
model_args.mm_qformer_pretrained + + self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents) + + if self.pretrained is not None: + pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"] + pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")} + self.load_state_dict(pretrained_dict) + + def build_Qformer(self, vision_width, cross_attention_freq, num_query_token): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size)) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + Qformer.cls = None + Qformer.bert.embeddings.word_embeddings = None + Qformer.bert.embeddings.position_embeddings = None + for layer in Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + return Qformer, query_tokens, nn.LayerNorm(vision_width) + + def forward(self, image_features, *args, **kwargs): + x = self.ln_vision(image_features) + image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device) + + query_tokens = self.query_tokens.expand(x.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=x, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return query_output.last_hidden_state + + @property + def hidden_size(self): + return 768 + + @property + def config(self): + return { + "mm_resampler_type": "qformer", + "mm_qformer_depth": self.depth, + "mm_qformer_latents": self.num_latents, + "mm_qformer_pretrained": self.pretrained, + } diff --git 
a/train/llava/model/multimodal_resampler/spatial_pool.py b/train/llava/model/multimodal_resampler/spatial_pool.py new file mode 100644 index 0000000..4bdbe3a --- /dev/null +++ b/train/llava/model/multimodal_resampler/spatial_pool.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +import math + + +class SpatialPool(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.mode = model_args.mm_spatial_pool_mode + self.stride = model_args.mm_spatial_pool_stride + self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) + + if self.mode == "average": + self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) + elif self.mode == "max": + self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) + elif self.mode == "conv": + self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) + else: + raise ValueError(f"Unknown pooling mode: {self.pool}.") + + def forward(self, image_features, images, *args, **kwargs): + ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) + ori_H = int(ori_W * images.shape[2] // images.shape[3]) + + B, _, F = image_features.shape + + image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2) + image_features_spatial_pool = self.pool(image_features_spatial) + + return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() + + @property + def config(self): + return { + "mm_resampler_type": "spatial_pool", + "mm_spatial_pool_stride": self.stride, + "mm_spatial_pool_mode": self.mode, + "mm_spatial_pool_out_channels": self.out_channels, + } + + @property + def hidden_size(self): + return self.out_channels diff --git a/train/llava/model/simple_heading_mlp.py b/train/llava/model/simple_heading_mlp.py new file mode 100644 index 0000000..d977db6 --- /dev/null +++ b/train/llava/model/simple_heading_mlp.py 
@@ -0,0 +1,37 @@ +import torch +import torch.nn as nn + + +class TrajectoryHeadingSimpleMLP(nn.Module): + def __init__(self, hidden_layer_dim=512, num_poses=8): + super().__init__() + self.num_poses = num_poses + self.mlp = nn.Sequential( + nn.Linear(8 * 2, hidden_layer_dim), + nn.ReLU(), + nn.Linear(hidden_layer_dim, hidden_layer_dim), + nn.ReLU(), + nn.Linear(hidden_layer_dim, hidden_layer_dim), + nn.ReLU(), + nn.Linear(hidden_layer_dim, num_poses * 3), + ) + + def forward(self, traj_xy): + x = traj_xy.reshape(traj_xy.shape[0], -1) + out = self.mlp(x) + out = out.view(-1, self.num_poses, 3) + heading_pred = torch.tanh(out[:, :, 2:]) * 3.14159 + return heading_pred + + +def load_heading_model( + ckpt_path: str, device: str = "cpu", hidden_layer_dim=512, num_poses=8 +) -> nn.Module: + model = TrajectoryHeadingSimpleMLP( + hidden_layer_dim=hidden_layer_dim, num_poses=num_poses + ) + state_dict = torch.load(ckpt_path, map_location=device) + model.load_state_dict(state_dict) + model.to(device) + model.eval() + return model diff --git a/train/llava/model/utils.py b/train/llava/model/utils.py new file mode 100644 index 0000000..10652a5 --- /dev/null +++ b/train/llava/model/utils.py @@ -0,0 +1,20 @@ +from transformers import AutoConfig + + +def auto_upgrade(config): + cfg = AutoConfig.from_pretrained(config) + if "llava" in config and "llava" not in cfg.model_type: + assert cfg.model_type == "llama" + print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") + print("You must upgrade the checkpoint to the new code base (this can be done automatically).") + confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") + if confirm.lower() in ["y", "yes"]: + print("Upgrading checkpoint...") + assert len(cfg.architectures) == 1 + setattr(cfg.__class__, "model_type", "llava") + cfg.architectures[0] = "LlavaLlamaForCausalLM" + cfg.save_pretrained(config) + print("Checkpoint upgraded.") + else: + print("Checkpoint upgrade aborted.") + exit(1) diff --git a/train/llava/s3_ops.py b/train/llava/s3_ops.py new file mode 100644 index 0000000..8569aed --- /dev/null +++ b/train/llava/s3_ops.py @@ -0,0 +1,217 @@ +# LINT_ME +"""Filesystem operations supporting local and remote (obs://, s3://, gs://) paths.""" +from __future__ import annotations + +import io +import json +import posixpath +from pathlib import Path +from typing import IO, List, Union +from urllib.parse import urlsplit, urlunsplit + +import pyarrow as pa +import pyarrow.parquet as pq +from PIL import Image + +try: + import moxing as mox + _HAS_MOX = True +except ImportError: + mox = None # type: ignore + _HAS_MOX = False + +PathLike = Union[str, Path] + +# ---------- helpers ---------- + + +def _is_remote(p: PathLike) -> bool: + """ + Check if path is a remote storage URL (obs://, s3://, gs://). + """ + s = str(p) + if s.startswith(("obs://", "s3://", "gs://")): + return True + return False + + +def join(base: PathLike, *parts: PathLike) -> str: + """ + Join path components for local or remote paths. + """ + b = str(base) + if _is_remote(b): + scheme, netloc, path, query, frag = urlsplit(b) + segs = [path] + [str(p) for p in parts] + new_path = posixpath.join(*[s.lstrip("/") for s in segs if s]) + if not new_path.startswith("/"): + new_path = "/" + new_path + return urlunsplit((scheme, netloc, new_path, query, frag)) + return str(Path(b).joinpath(*map(str, parts))) + +def exists(p: PathLike) -> bool: + """ + Check if path exists (local or remote). 
+ """ + if _is_remote(p): + if not _HAS_MOX: + raise RuntimeError("moxing not installed.") + return bool(mox.file.exists(str(p))) + return Path(p).exists() + + +def isdir(p: PathLike) -> bool: + """ + Check if path is a directory (local or remote). + """ + if _is_remote(p): + if not _HAS_MOX: + raise RuntimeError("moxing not installed.") + return bool(mox.file.is_directory(str(p))) + return Path(p).is_dir() + + +def listdir(p: PathLike) -> List[str]: + """ + List directory contents for local or obs:// paths. + Remote returns full paths, local returns names (like os.listdir). + """ + if _is_remote(p): + if not _HAS_MOX: + raise RuntimeError("moxing not installed.") + return mox.file.list_directory(str(p)) + return [x.name for x in Path(p).iterdir()] + + +def basename(p: PathLike) -> str: + """ + Return the final component of a path, works for local and obs:// URLs. + """ + s = str(p) + if _is_remote(s): + # Strip trailing slash if present, then take the last segment + return s.rstrip("/").rsplit("/",maxsplit=1)[-1] + return Path(s).name + + +def split(p: PathLike) -> tuple[str, str]: + """ + Split path into (head, tail). + For remote (obs://...), head keeps scheme+netloc+parent, tail is the last component. + For local, same as os.path.split. + """ + s = str(p) + if _is_remote(s): + scheme, netloc, path, query, frag = urlsplit(s) + parts = path.rstrip("/").rsplit("/",maxsplit=1) + tail = parts[-1] if parts else "" + head_path = "/".join(parts[:-1]) + if head_path and not head_path.startswith("/"): + head_path = "/" + head_path + head = urlunsplit((scheme, netloc, head_path, query, frag)) + return head, tail + return str(Path(s).parent), Path(s).name + + +def splitext(p: PathLike) -> tuple[str, str]: + """ + Split path into (root, ext). + Works like os.path.splitext, for both local and remote. 
+ """ + s = str(p) + if _is_remote(s): + base = basename(s) # last part only + root, ext = posixpath.splitext(base) + # Reconstruct full root path (without extension) + head, _ = split(s) + return join(head, root), ext + root, ext = Path(s).with_suffix("").as_posix(), Path(s).suffix + return root, ext + + +def open_file(p: PathLike, mode: str = "r", **kwargs) -> IO: + """ + Open a file for local or remote (obs://, s3://, gs://) paths. + """ + if _is_remote(p): + if not _HAS_MOX: + raise RuntimeError("moxing not installed.") + return mox.file.File(str(p), mode, **kwargs) # type: ignore + return open(p, mode, **kwargs) # pylint: disable=W1514 + + +def json_load(p: PathLike, **json_kwargs): + """ + Load JSON or JSONL files automatically. + Returns: + - list[dict]: for JSONL or list-based JSON + - dict: for JSON object + """ + encoding = json_kwargs.pop("encoding", "utf-8") + + with open_file(p, "r", encoding=encoding) as f: + f.seek(0) + + # JSONL: detect line-delimited JSON + if p.endswith(".jsonl"): + f.seek(0) + return [json.loads(line) for line in f if line.strip()] + f.seek(0) + return json.load(f, **json_kwargs) + + +def image_open(p: PathLike): + """ + Open an image from local or remote path. 
+ """ + + if str(p).endswith(".parquet"): + with open_file(p, "rb") as fp: + file_data = fp.read() + table = pq.read_table(pa.py_buffer(file_data)).to_pandas() + images = {} + for cam_id in table.columns: + data = table[cam_id].iloc[0] # 取第一行的 bytes 数据 + try: + img = Image.open(io.BytesIO(data)) + img.load() # 强制加载,释放文件句柄 + images[cam_id] = img + except Exception as e: + raise RuntimeError( + f"Failed to decode image from camera '{cam_id}' in {p}: {e}" + ) from e + return images["cam_1"] + + f = open_file(p, "rb") + try: + img = Image.open(f) + # Force load into memory so we can close the file immediately + img.load() + f.close() + return img + except Exception: + f.close() + raise + + +def makedirs(p: PathLike, exist_ok: bool = True) -> None: + """ + Create directories (local or remote). + - Local: same as Path(...).mkdir(parents=True, exist_ok=...) + - Remote: uses mox.file.make_dirs + """ + s = str(p) + if _is_remote(s): + if not _HAS_MOX: + raise RuntimeError("moxing not installed.") + # moxing's make_dirs creates all necessary parent dirs automatically + if exist_ok: + # No harm if it already exists + if not mox.file.exists(s): + mox.file.make_dirs(s) + else: + if mox.file.exists(s): + raise FileExistsError(f"Remote path already exists: {s}") + mox.file.make_dirs(s) + else: + Path(s).mkdir(parents=True, exist_ok=exist_ok) diff --git a/train/llava/train/llama_flash_attn_monkey_patch.py b/train/llava/train/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000..c88fe34 --- /dev/null +++ b/train/llava/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple +import warnings + +import torch + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +except ImportError: + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as 
flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) + qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] + key_padding_mask = 
attention_mask + + if key_padding_mask is None: + qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) + max_s = q_len + output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) + output = output.view(bsz, q_len, -1) + else: + qkv = qkv.reshape(bsz, q_len, -1) + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/train/llava/train/llava_trainer.py b/train/llava/train/llava_trainer.py new file mode 100644 index 0000000..e989297 --- /dev/null +++ b/train/llava/train/llava_trainer.py @@ -0,0 +1,591 @@ +import os +import torch +import torch.nn as nn +import datetime +import torch.distributed as dist + +from accelerate import Accelerator +from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin +from torch.utils.data import Dataset, Sampler, DataLoader + +from trl.trainer import DPOTrainer +from trl.trainer.utils import DPODataCollatorWithPadding + +from transformers import Trainer +from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, GradientAccumulationPlugin, logger, is_accelerate_available, is_datasets_available +from transformers.trainer_utils import seed_worker +from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf +from transformers.trainer_pt_utils import AcceleratorConfig +from typing import List, Optional +from datetime import timedelta + +if is_accelerate_available(): + from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs + +if is_datasets_available(): + import datasets + +from llava.utils import rank0_print + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + print(name, "no ignore status") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = 
param.detach().cpu().clone() + return param + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} + return to_return + + +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + + return chunks + + +def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult=8, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 
+ indices = torch.randperm(len(lengths), generator=generator) + sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True) + megabatch_size = world_size * batch_size * megabatch_mult + megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches] + shuffled_indices = [i for megabatch in megabatches for i in megabatch] + world_batch_size = world_size * batch_size + batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)] + batch_indices = torch.randperm(len(batches), generator=generator) + batches = [batches[i] for i in batch_indices] + + return [i for batch in batches for i in batch] + + +def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): + """ + Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - reorder by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." 
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) > 0: + megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): + """ + Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - reorder by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. 
+ """ + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None): + indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator) + + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + batch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in batch_indices] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." 
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + # FIXME: Hard code to avoid last batch mixed with different modalities + # if len(additional_batch) > 0: + # megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. 
+ """ + + def __init__( + self, + batch_size: int, + world_size: int, + lengths: Optional[List[int]] = None, + generator=None, + variable_length: bool = False, + group_by_modality: bool = False, + group_by_modality_auto: bool = False, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.world_size = world_size + self.lengths = lengths + self.generator = generator + self.variable_length = variable_length + self.group_by_modality = group_by_modality + self.group_by_modality_auto = group_by_modality_auto + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + if self.variable_length: + assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping." + indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + if self.group_by_modality: + indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + elif self.group_by_modality_auto: + indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator) + return iter(indices) + + +class LLaVATrainer(Trainer): + def __init__(self, obs_upload_path=None, **kwargs) -> None: + super().__init__(**kwargs) + self.obs_upload_path = obs_upload_path + + def create_accelerator_and_postprocess(self): + grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps} + grad_acc_kwargs["sync_with_dataloader"] = False + gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + rank0_print("Setting NCCL timeout to INF to avoid running errors.") + + # create accelerator object + self.accelerator = 
Accelerator( + split_batches=self.args.split_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, kwargs_handlers=[accelerator_kwargs] + ) + # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag + self.gather_function = self.accelerator.gather_for_metrics + + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", fsdp_plugin.limit_all_gathers) + if is_accelerate_available("0.23.0"): + fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get("activation_checkpointing", fsdp_plugin.activation_checkpointing) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError("The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " "can't be set to True simultaneously. 
Please use FSDP's activation_checkpointing logic " "when using FSDP.") + + if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: + self.propagate_args_to_deepspeed() + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_length: + lengths = self.train_dataset.lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + ) + elif self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + group_by_modality=True, + ) + elif self.args.group_by_modality_length_auto: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? 
+ lengths=lengths, + group_by_modality_auto=True, + ) + elif self.args.group_by_varlen: + lengths = self.train_dataset.lengths + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + # self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + variable_length=True, + ) + else: + return super()._get_train_sampler() + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None 
+ + if hasattr(self.args, "use_webdataset") and self.args.use_webdataset: + dataloader = DataLoader(train_dataset, **dataloader_params) + else: + dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + return dataloader + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + lr_mapper = {} + if self.args.mm_projector_lr is not None: + lr_mapper["mm_projector"] = self.args.mm_projector_lr + if self.args.mm_vision_tower_lr is not None: + lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr + if len(lr_mapper) > 0: + special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + for module_keyword, lr in lr_mapper.items(): + module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name] + optimizer_grouped_parameters.extend( + [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + "lr": lr, + }, + { + 
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)], + "weight_decay": 0.0, + "lr": lr, + }, + ] + ) + else: + optimizer_grouped_parameters = [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + + def _save_checkpoint(self, model, trial, metrics=None): + if getattr(self.args, "tune_mm_mlp_adapter", False) or ( + hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts)) + ): + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + + # Only save Adapter + keys_to_match = ["mm_projector", "vision_resampler"] + if getattr(self.args, "use_im_start_end", False): + 
keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) + + if self.args.local_rank == 0 or self.args.local_rank == -1: + self.model.config.save_pretrained(output_dir) + torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin")) + + else: + #super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + + self.save_model(output_dir, _internal_call=True) + + if not self.args.save_only_model: + # Save optimizer and scheduler + self._save_optimizer_and_scheduler(output_dir) + # Save RNG state + self._save_rng_state(output_dir) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + import numpy as np + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.args.should_save: + TRAINER_STATE_NAME = "trainer_state.json" + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + if self.obs_upload_path: + if dist.get_rank()==0: + log_dir = os.path.join(os.path.dirname(output_dir),"runs") + mox.file.set_auth(ak="",sk="") 
+ mox.file.copy_parallel(log_dir, self.obs_upload_path+log_dir) + mox.file.copy_parallel(output_dir, self.obs_upload_path+output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + # Solely rely on numerical checkpoint id for rotation. + # mtime is not reliable especially on some fuse fs in cloud environments. + self._rotate_checkpoints(use_mtime=False, output_dir=run_dir) + + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + if getattr(self.args, "tune_mm_mlp_adapter", False): + pass + else: + super(LLaVATrainer, self)._save(output_dir, state_dict) + + +class LLaVADPOTrainer(DPOTrainer): + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + world_size=self.args.world_size, + lengths=lengths, + group_by_modality=True, + ) + else: + return super()._get_train_sampler() + + def _save_checkpoint(self, model, trial, metrics=None): + if getattr(self.args, "tune_mm_mlp_adapter", False) or ( + hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts)) + ): + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + + # Only save Adapter + keys_to_match = ["mm_projector", "vision_resampler"] + if getattr(self.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = 
get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) + + if self.args.local_rank == 0 or self.args.local_rank == -1: + self.model.config.save_pretrained(output_dir) + torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin")) + else: + # super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics) + # print(type(model)) + # from transformers.modeling_utils import unwrap_model + # print(type(unwrap_model(model))) + # print(unwrap_model(model).config) + if self.args.lora_enable: + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + from transformers.modeling_utils import unwrap_model + + unwrapped_model = unwrap_model(model) + self.save_my_lora_ckpt(output_dir, self.args, unwrapped_model) + else: + super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + if getattr(self.args, "tune_mm_mlp_adapter", False): + pass + else: + super(LLaVADPOTrainer, self)._save(output_dir, state_dict) diff --git a/train/llava/train/llava_trainer_eval.py b/train/llava/train/llava_trainer_eval.py new file mode 100644 index 0000000..e822258 --- /dev/null +++ b/train/llava/train/llava_trainer_eval.py @@ -0,0 +1,76 @@ +import json +import subprocess + +from llava.train.llava_trainer import LLaVATrainer + + +class LLaVAEvalTrainer(LLaVATrainer): + def evaluate(self, evaluate_args): + cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ + --model {evaluate_args.model} \ + --model_args {evaluate_args.model_args} \ + --tasks {evaluate_args.task_names} \ + --batch_size {evaluate_args.batch_size} \ + --log_samples_suffix {evaluate_args.log_samples_suffix} \ + --output_path {evaluate_args.output_path}" + if evaluate_args.limit: + cmd += f" 
--limit {evaluate_args.limit}" + if evaluate_args.num_fewshot: + cmd += f" --num_fewshot {evaluate_args.num_fewshot}" + if evaluate_args.gen_kwargs != "": + cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" + if evaluate_args.log_samples: + cmd += f" --log_samples" + else: + assert False, "Please log samples so that the result can be parsed" + results = subprocess.run([cmd], shell=True, capture_output=True, text=True) + try: + result_file_index_start = results.stdout.index("Saved samples to ") + result_file_index_end = results.stdout.index(f".json") + result_file_index_start += len("Saved samples to ") + file = results.stdout[result_file_index_start:result_file_index_end] + except: + result_file_index_start = results.stderr.index("Saved samples to ") + result_file_index_end = results.stderr.index(f".json") + result_file_index_start += len("Saved samples to ") + file = results.stderr[result_file_index_start:result_file_index_end] + file = file.split("/")[:-1] + file = "/".join(file) + "/results.json" + with open(file, "r") as f: + lmms_eval_results = json.load(f) + result_dict = {} + tasks_list = evaluate_args.task_names.split(",") + for task in tasks_list: + task_results = lmms_eval_results["results"][task] + for k, v in task_results.items(): + if k != "alias" and "stderr" not in k: + metric = k.split(",")[0] + result_dict[f"{task}_{metric}"] = v + return result_dict + + """def evaluate(self, evaluate_args): + initialize_tasks() + tasks_list = evaluate_args.task_names.split(",") + result_dict = {} + results = evaluator.simple_evaluate( + model=evaluate_args.model, + model_args=evaluate_args.model_args, + tasks=tasks_list, + num_fewshot=evaluate_args.num_fewshot, + batch_size=evaluate_args.batch_size, + device=evaluate_args.device, + limit=evaluate_args.limit, + check_integrity=evaluate_args.check_integrity, + show_task_to_terminal=evaluate_args.show_task_to_terminal, + log_samples=evaluate_args.log_samples, + gen_kwargs=evaluate_args.gen_kwargs, + 
cli_args=evaluate_args, + ) + for task in tasks_list: + task_results = results["results"][task] + for k,v in task_results.items(): + if k != "alias" and "stderr" not in k: + metric = k.split(",")[0] + result_dict[f"{task}_{metric}"] = v + + return result_dict""" diff --git a/train/llava/train/train.py b/train/llava/train/train.py new file mode 100644 index 0000000..bda1869 --- /dev/null +++ b/train/llava/train/train.py @@ -0,0 +1,1773 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import ast +import copy +import logging +import os +import pathlib +import random +import re +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence + +import numpy as np +import tokenizers +import torch +import torch.distributed as dist +import transformers +from llava import conversation as conversation_lib +from llava.constants import ( + DEFAULT_IM_END_TOKEN, + DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_TOKEN, + IGNORE_INDEX, + IMAGE_TOKEN_INDEX, +) +from llava.mm_utils import ( + process_anyres_image, + process_highres_image, + process_highres_image_crop_split, + tokenizer_image_token, +) +from llava.model import * +from llava.train.llava_trainer import LLaVATrainer +from llava.utils import ( + rank0_print, + read_frames_gif, + read_frames_pyav, +) +from packaging import version +from PIL import Image, ImageFile +from torch.utils.data import Dataset +from transformers import AutoConfig + +import train.llava.s3_ops as st + +torch.multiprocessing.set_sharing_strategy("file_system") + +ImageFile.LOAD_TRUNCATED_IMAGES = True +local_rank = None + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse( + "0.14" +) + +import warnings + +warnings.filterwarnings("ignore") + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_class_name: Optional[str] = field( + default=None, + metadata={ + "help": "Used to init model class, format is XXXXForCausalLM. e.g. 
currently XXXX is chosen from LlavaLlama, LlavaMixtral, LlavaMistral, Llama" + }, + ) + + mm_tunable_parts: Optional[str] = field( + default=None, + metadata={ + "help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"' + }, + ) + # deciding which part of the multimodal model to tune, will overwrite other previous settings + + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + tune_mm_vision_resampler: bool = field(default=False) + vision_tower: Optional[str] = field(default=None) + vision_tower_pretrained: Optional[str] = field( + default=None + ) # default to the last layer + + num_token: Optional[bool] = field(default=False) + + unfreeze_mm_vision_tower: bool = field(default=False) + unfreeze_language_model: bool = field(default=False) + mm_vision_select_layer: Optional[int] = field( + default=-1 + ) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default="linear") + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default="flat") + mm_vision_select_feature: Optional[str] = field(default="patch") + mm_resampler_type: Optional[str] = field(default=None) + mm_mask_drop_mode: str = field(default="fixed") + mm_mask_drop_skip_percentage: float = field(default=0.0) + mm_mask_drop_ratio: float = field(default=0.25) + mm_mask_drop_ratio_upper: Optional[float] = field(default=None) + mm_mask_drop_ratio_lower: Optional[float] = field(default=None) + mm_spatial_pool_stride: Optional[int] = field(default=None) + mm_spatial_pool_mode: str = field(default="bilinear") + mm_spatial_pool_out_channels: Optional[int] = field(default=None) + mm_perceiver_depth: 
Optional[int] = field(default=3) + mm_perceiver_latents: Optional[int] = field(default=32) + mm_perceiver_ff_mult: Optional[float] = field(default=4) + mm_perceiver_pretrained: Optional[str] = field(default=None) + mm_qformer_depth: Optional[int] = field(default=3) + mm_qformer_latents: Optional[int] = field(default=32) + mm_qformer_pretrained: Optional[str] = field(default=None) + + rope_scaling_factor: Optional[float] = field(default=None) + rope_scaling_type: Optional[str] = field(default=None) + + s2: Optional[bool] = field(default=False) + s2_scales: Optional[str] = field(default="336,672,1008") + + use_pos_skipping: Optional[bool] = field(default=False) + pos_skipping_range: Optional[int] = field(default=4096) + + mm_newline_position: Optional[str] = field(default="grid") + delay_load: Optional[bool] = field(default=True) + add_faster_video: Optional[bool] = field(default=False) + faster_token_stride: Optional[int] = field(default=10) + + num_experts_per_tok: Optional[int] = field(default=-1) + num_experts: Optional[int] = field(default=-1) + moe_choice: Optional[str] = field(default="expert") + moe_router_score_function: Optional[str] = field(default=None) + moe_lora_rank: Optional[int] = field(default=32) + moe_lora_alpha: Optional[int] = field(default=32) + moe_lora_in_features: Optional[int] = field(default=4096) + moe_lora_out_features: Optional[int] = field(default=4096) + moe_lora_dropout: Optional[float] = field(default=0.0) + capacity_factor: Optional[float] = field(default=0.1) + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, + metadata={ + "help": "Path to the training data, in llava's instruction.json format. 
Supporting multiple json files via /path/to/{a,b,c}.json" + }, + ) + + lazy_preprocess: bool = False + is_multimodal: bool = False + early_mix_text: bool = False + use_webdataset: bool = False + total_samples: Optional[int] = field(default=None) + + image_folder: Optional[str] = field(default=None) + image_folder_2: Optional[str] = field( + default=None, + metadata={"help": "Path to the second image folder, used as a fallback."}, + ) + image_aspect_ratio: str = "square" + image_grid_pinpoints: Optional[str] = field(default=None) + image_crop_resolution: Optional[int] = field(default=None) + image_split_resolution: Optional[int] = field(default=None) + + video_folder: Optional[str] = field(default=None) + video_fps: Optional[int] = field(default=1) + frames_upbound: Optional[int] = field(default=0) + add_time_instruction: Optional[bool] = field(default=False) + force_sample: Optional[bool] = field(default=False) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + freeze_mm_vision_resampler: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=4096, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={ + "help": "Compress the quantization statistics through double quantization." + }, + ) + quant_type: str = field( + default="nf4", + metadata={ + "help": "Quantization data type to use. Should be one of `fp4` or `nf4`." 
+ }, + ) + bits: int = field(default=16, metadata={"help": "How many bits to use."}) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + mm_vision_tower_lr: Optional[float] = None + group_by_varlen: bool = field(default=False) + group_by_modality_length: bool = field(default=False) + group_by_modality_length_auto: bool = field(default=False) + auto_find_batch_size: bool = field(default=False) + gradient_checkpointing: bool = field(default=True) + verbose_logging: bool = field(default=False) + attn_implementation: str = field( + default="flash_attention_2", + metadata={"help": "Use transformers attention implementation."}, + ) + use_conversation_mask: bool = field(default=True) + split_batches: bool = field(default=False) + torch_empty_cache_steps: int = field(default=1) + fp16: bool = field(default=False) + debug_mode: bool = field(default=False) + obs_upload_path: str = field( + default="", metadata={"help": "Path to the obs upload path."} + ) + output_dir: str = field( + default="exp", metadata={"help": "Path to the output directory."} + ) + save_only_model: bool = True + + +# @dataclass +# class EvaluationArguments: +# eval_num_processes: int = field(default=1) +# task_names: str = field(default=None) +# model: str = field(default="llava") +# model_args: Optional[str] = field(default=None) +# num_fewshot: Optional[int] = field(default=None) +# batch_size: int = field(default=1) +# device: Optional[str] = field(default=None) +# limit: Optional[int] = field(default=None) +# check_integrity: Optional[bool] = field(default=False) +# show_task_to_terminal: Optional[bool] = field(default=False) +# log_samples: Optional[bool] = field(default=True) +# gen_kwargs: Optional[str] = field(default="") +# log_samples_suffix: Optional[str] = field(default="") +# output_path: Optional[str] = field(default="./logs/") + + +def 
maybe_zero_3(param, ignore_status=False, name=None):
+    from deepspeed import zero
+    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+    if hasattr(param, "ds_id"):
+        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+            if not ignore_status:
+                logging.warning(
+                    f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
+                )
+        with zero.GatheredParameters([param]):
+            param = param.data.detach().cpu().clone()
+    else:
+        param = param.detach().cpu().clone()
+    return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+    if bias == "none":
+        to_return = {k: t for k, t in named_params if "lora_" in k}
+    elif bias == "all":
+        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = {}
+        maybe_lora_bias = {}
+        lora_bias_names = set()
+        for k, t in named_params:
+            if "lora_" in k:
+                to_return[k] = t
+                bias_name = k.split("lora_")[0] + "bias"
+                lora_bias_names.add(bias_name)
+            elif "bias" in k:
+                maybe_lora_bias[k] = t
+        # iterate the collected bias tensors (.items(), not the dict itself) and
+        # key on each candidate name k — not the stale bias_name from the loop above
+        for k, t in maybe_lora_bias.items():
+            if k in lora_bias_names:
+                to_return[k] = t
+    else:
+        raise NotImplementedError
+    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+    return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+    to_return = {k: t for k, t in named_params if "lora_" not in k}
+    if require_grad_only:
+        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+    to_return = {
+        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
+    }
+    return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+    to_return = {
+        k: t
+        for k, t in named_params
+        if any(key_match in k for key_match in keys_to_match)
+    }
+    to_return = {
+        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
+    }
+    return to_return
+
+
+def 
find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + if ( + hasattr(trainer.args, "tune_mm_mlp_adapter") + and trainer.args.tune_mm_mlp_adapter + ): + check_only_save_mm_adapter_tunnable = True + # only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts + elif hasattr(trainer.args, "mm_tunable_parts") and ( + len(trainer.args.mm_tunable_parts.split(",")) == 1 + and ( + "mm_mlp_adapter" in trainer.args.mm_tunable_parts + or "mm_vision_resampler" in trainer.args.mm_tunable_parts + ) + ): + check_only_save_mm_adapter_tunnable = True + else: + check_only_save_mm_adapter_tunnable = False + + trainer.accelerator.wait_for_everyone() + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif hasattr(torch, "npu") and torch.npu.is_available(): + torch.npu.synchronize() + # torch.cuda.synchronize() + + rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}") + if check_only_save_mm_adapter_tunnable: + # Only save Adapter + keys_to_match = ["mm_projector", "vision_resampler"] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3( + trainer.model.named_parameters(), keys_to_match + ) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split("/")[-1] + parent_folder = 
os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith("checkpoint-"): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save( + weight_to_save, + os.path.join(mm_projector_folder, f"{current_folder}.bin"), + ) + else: + torch.save( + weight_to_save, os.path.join(output_dir, f"mm_projector.bin") + ) + return + + if trainer.deepspeed: + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
+ """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn( + strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer +) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = 
conversation_lib.default_conversation.roles[1] + else: + from_str = "unknown" + sentence["value"] = ( + BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL + ) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + # TODO maybe this should be changed for interleaved data? + # if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN): + # only check for num_im=1 + num_im = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"])) + if ( + num_im == 1 + and DEFAULT_IMAGE_TOKEN in sentence["value"] + and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN) + ): + sentence["value"] = ( + sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + ) + sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"] + sentence["value"] = sentence["value"].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence["value"] = sentence["value"].replace( + DEFAULT_IMAGE_TOKEN, + "" + DEFAULT_IMAGE_TOKEN + "", + ) + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = ( + DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + ) + sentence["value"] = sentence["value"].replace( + DEFAULT_IMAGE_TOKEN, replace_token + ) + + # For videoInstruct-100k noisy_data. TODO: Ask Yuanhan to clean the data instead of leaving the noise code here. + sentence["value"] = sentence["value"].replace( + "QA_GT_caption_based_noisy", "" + ) + + return sources + + +def preprocess_llada( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + max_len=2048, + system_message: str = "You are a helpful language and vision assistant. 
You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", +) -> Dict: + # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"} + roles = {"human": "user", "gpt": "assistant", "system": "system"} + + # Add image tokens to tokenizer as a special tokens + # Use a deepcopy of tokenizer so that we don't modify on the tokenizer + # tokenizer = copy.deepcopy(tokenizer) + # # When there is actually an image, we add the image tokens as a special token + # if has_image: + # tokenizer.add_tokens([""], special_tokens=True) + image_token_index = tokenizer.convert_tokens_to_ids("") + bos_token_id = tokenizer.convert_tokens_to_ids("<|startoftext|>") + start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>") + end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>") + eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") + + unmask_tokens = [ + "<|startoftext|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eot_id|>", + "\n\n", + ] + unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens] + # Reset LLaDA chat templates so that it won't include assistant message every time we apply + chat_template = "{% for message in messages %}{{'<|startoftext|>' + '<|start_header_id|>' + message['role'] + '<|end_header_id|>' + '\n\n' + message['content'] + '<|eot_id|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" + tokenizer.chat_template = chat_template + + # After update, calling tokenizer of llama3 will + # auto add bos id for the tokens. 
ヽ(`⌒´)ノ + def safe_tokenizer_llama3(text): + input_ids = tokenizer(text).input_ids + if input_ids[0] == bos_token_id: + input_ids = input_ids[1:] + return input_ids + + nl_tokens = tokenizer.convert_tokens_to_ids("\n\n") + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] == roles["system"]: + try: + system_message = source[0]["content"] + except: + system_message = source[0]["value"] + source = source[1:] + + input_id, target = [], [] + + # New version, use apply chat template + # Build system message for each sentence + input_id += tokenizer.apply_chat_template( + [{"role": "system", "content": system_message}] + ) + target += [IGNORE_INDEX] * len(input_id) + + for conv in source: + # Make sure llava data can load + try: + role = conv["role"] + content = conv["content"] + except: + role = conv["from"] + content = conv["value"] + + role = roles.get(role, role) + + if "video" in content: + content = content.replace("