[Bug] Shape mismatch when running MPO on multiple nodes #856

Open
3 tasks done
JackeyHRan opened this issue Jan 15, 2025 · 0 comments

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if a bug report lacks the corresponding environment info and a minimal reproducible demo, it will be difficult for us to reproduce and resolve the issue, which reduces the likelihood of receiving feedback.

Describe the bug

When I run MPO training with InternVL2.5-8B-MPO on a single node with 8 GPUs, training runs normally. But when I use 3 nodes with 24 GPUs, I get the warning: shape mismatch: value tensor of shape [4608, 4096] cannot be broadcast to indexing result of shape [1098, 4096], input_embeds[selected].shape=torch.Size([1098, 4096]), vit_embeds.shape=torch.Size([4608, 4096]), and the loss is 0.
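The warning text suggests the training code writes the vision-encoder outputs into the positions of the image-context tokens with a boolean mask, roughly `input_embeds[selected] = vit_embeds`. That kind of assignment only succeeds when the mask selects exactly as many rows as `vit_embeds` has, so 1098 selected positions cannot receive 4608 embedding rows. The dependency-free sketch below emulates that rule on plain Python lists. `masked_assign` is a hypothetical helper written for illustration, not InternVL code, and its error text only imitates the PyTorch message.

```python
from typing import List, Sequence

Row = List[float]


def masked_assign(target: List[Row], mask: Sequence[bool], values: Sequence[Row]) -> None:
    """Overwrite the rows of `target` where `mask` is True with `values`, in order.

    Emulates the shape rule of boolean-mask assignment (`target[mask] = values`):
    the number of selected rows must equal the number of value rows. On a
    mismatch nothing is written and a ValueError is raised.
    """
    if len(mask) != len(target):
        raise ValueError("mask length must match the number of target rows")
    selected = [i for i, keep in enumerate(mask) if keep]
    if len(selected) != len(values):
        width = len(target[0]) if target else 0
        value_width = len(values[0]) if values else width
        raise ValueError(
            f"shape mismatch: value tensor of shape [{len(values)}, {value_width}] "
            f"cannot be broadcast to indexing result of shape [{len(selected)}, {width}]"
        )
    for row_idx, value in zip(selected, values):
        target[row_idx] = list(value)


# Demo with the row counts from the warning (hypothetical sequence layout).
hidden = 2  # stand-in for the real hidden size 4096, to keep the demo light
mask = [i < 1098 for i in range(1200)]  # 1098 image-context positions
input_embeds = [[0.0] * hidden for _ in range(1200)]
vit_embeds = [[1.0] * hidden for _ in range(4608)]
try:
    masked_assign(input_embeds, mask, vit_embeds)
except ValueError as err:
    print(err)
# prints: shape mismatch: value tensor of shape [4608, 2] cannot be broadcast to indexing result of shape [1098, 2]
```

Shrinking the hidden size from 4096 to 2 does not change the outcome, since only the row counts disagree. Because the check runs before any write, `input_embeds` is left untouched on failure.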

Reproduction

torchrun --master_port=${MASTER_PORT} \
--nnodes=${NNODES} \
--node_rank=${NODE_RANK} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=${NPROC_PER_NODE} \
internvl/train/internvl_chat_dpo.py \
--model_name_or_path "./pretrained/InternVL2_5-8B-MPO" \
--conv_style "internvl2_5" \
--use_fast_tokenizer False \
--output_dir ${OUTPUT_DIR} \
--meta_path ./shell/data/relevance_v2.4.1_dpo.json \
--overwrite_output_dir True \
--force_image_size 448 \
--down_sample_ratio 0.5 \
--drop_path_rate 0.1 \
--pad2square False \
--freeze_llm False \
--freeze_mlp False \
--freeze_backbone False \
--vision_select_layer -1 \
--use_data_resampling False \
--dataloader_num_workers 8 \
--bf16 True \
--num_train_epochs 1 \
--per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
--gradient_accumulation_steps ${GRADIENT_ACC} \
--evaluation_strategy "no" \
--save_strategy "no" \
--save_steps 100 \
--save_total_limit 100 \
--learning_rate 1e-6 \
--weight_decay 0.05 \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--max_seq_length 1024 \
--do_train True \
--grad_checkpoint True \
--group_by_length False \
--dynamic_image_size True \
--use_thumbnail True \
--ps_version 'v2' \
--deepspeed "zero_stage1_config.json" \
--report_to "tensorboard" \
--loss_type sigmoid,bco_pair \
--sigmoid_loss_weight 0.8 \
--bco_pair_loss_weight 0.2 \
--rpo_alpha 1 \
--use_liger True \
2>&1 | tee -a "$LOG_FILE"
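Under `--force_image_size 448` and `--down_sample_ratio 0.5`, each image tile maps to a fixed number of image-context tokens. Assuming a vision-encoder patch size of 14 (not stated in this issue), a 448-pixel tile gives 32 × 32 = 1024 patches, and a 0.5 downsample keeps 1024 × 0.25 = 256 tokens. The sketch below applies that arithmetic to the two shapes in the warning and to `--max_seq_length 1024`. It is a plausibility check, not a diagnosis; `tokens_per_tile` and `fits_in_context` are hypothetical helpers.

```python
def tokens_per_tile(image_size: int = 448, patch_size: int = 14, down_sample_ratio: float = 0.5) -> int:
    """Image-context tokens per tile: (image_size // patch_size)^2 patches, scaled by down_sample_ratio^2.

    patch_size=14 is an assumption about the vision encoder, not a flag in the command.
    """
    patches_per_side = image_size // patch_size
    return int(patches_per_side ** 2 * down_sample_ratio ** 2)


def fits_in_context(num_tiles: int, text_tokens: int = 0, max_seq_length: int = 1024) -> bool:
    """True if the image tokens of num_tiles tiles plus text_tokens fit in max_seq_length."""
    return num_tiles * tokens_per_tile() + text_tokens <= max_seq_length


per_tile = tokens_per_tile()
print(per_tile)  # 256
print(4608 / per_tile)  # 18.0: vit_embeds covers a whole number of tiles
print(1098 / per_tile)  # 4.2890625: the selected positions do not
print(fits_in_context(4))  # True, but 4 x 256 = 1024 leaves no room for text
print(fits_in_context(4, text_tokens=1))  # False
```

Under these assumptions, 4608 rows correspond to 18 whole tiles, while 1098 selected positions do not correspond to a whole number of tiles. That pattern is what one would expect if some image-context tokens were cut off, for example by truncation at `--max_seq_length`, though the issue itself does not establish that this happens in the 3-node run.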

Environment

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
absl-py                   2.1.0                    pypi_0    pypi
accelerate                1.1.1                    pypi_0    pypi
addict                    2.4.0                    pypi_0    pypi
aiohappyeyeballs          2.4.4                    pypi_0    pypi
aiohttp                   3.11.9                   pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
attrs                     24.2.0                   pypi_0    pypi
blas                      2.116                       mkl    conda-forge
blas-devel                3.9.0            16_linux64_mkl    conda-forge
boto3                     1.35.73                  pypi_0    pypi
botocore                  1.35.73                  pypi_0    pypi
brotli-python             1.1.0           py311hfdbb021_2    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
ca-certificates           2024.8.30            hbcca054_0    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
cffi                      1.17.1          py311hf29c0ef_0    conda-forge
charset-normalizer        3.4.0              pyhd8ed1ab_0    conda-forge
coloredlogs               15.0.1                   pypi_0    pypi
contourpy                 1.3.1                    pypi_0    pypi
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.6.77                       0    nvidia
cuda-runtime              12.1.0                        0    nvidia
cuda-version              12.6                          3    nvidia
cycler                    0.12.1                   pypi_0    pypi
datasets                  3.1.0                    pypi_0    pypi
decord                    0.6.0                    pypi_0    pypi
deepspeed                 0.15.4                   pypi_0    pypi
dill                      0.3.8                    pypi_0    pypi
docstring-parser          0.16                     pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
environs                  11.0.0                   pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.16.1             pyhd8ed1ab_0    conda-forge
flash-attn                2.7.0.post2              pypi_0    pypi
fonttools                 4.55.0                   pypi_0    pypi
freetype                  2.12.1               h267a509_2    conda-forge
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2024.9.0                 pypi_0    pypi
giflib                    5.2.2                hd590300_0    conda-forge
gmp                       6.3.0                hac33072_2    conda-forge
gmpy2                     2.1.5           py311h0f6cedb_2    conda-forge
gnutls                    3.6.13               h85f3911_1    conda-forge
grpcio                    1.68.1                   pypi_0    pypi
h2                        4.1.0              pyhd8ed1ab_0    conda-forge
hjson                     3.1.0                    pypi_0    pypi
hpack                     4.0.0              pyh9f0ad1d_0    conda-forge
huggingface-hub           0.26.3                   pypi_0    pypi
humanfriendly             10.0                     pypi_0    pypi
humanize                  4.11.0                   pypi_0    pypi
hyperframe                6.0.1              pyhd8ed1ab_0    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
idna                      3.10               pyhd8ed1ab_0    conda-forge
imageio                   2.36.1                   pypi_0    pypi
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
jmespath                  1.0.1                    pypi_0    pypi
jpeg                      9e                   h166bdaf_2    conda-forge
kiwisolver                1.4.7                    pypi_0    pypi
lame                      3.100             h166bdaf_1003    conda-forge
lcms2                     2.15                 hfd0df8a_0    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_2    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libblas                   3.9.0            16_linux64_mkl    conda-forge
libcblas                  3.9.0            16_linux64_mkl    conda-forge
libcublas                 12.1.0.26                     0    nvidia
libcufft                  11.0.2.4                      0    nvidia
libcufile                 1.11.1.6                      0    nvidia
libcurand                 10.3.7.77                     0    nvidia
libcusolver               11.4.4.55                     0    nvidia
libcusparse               12.0.2.55                     0    nvidia
libdeflate                1.17                 h0b41bf4_0    conda-forge
libexpat                  2.6.4                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran               14.2.0               h69a702a_1    conda-forge
libgfortran-ng            14.2.0               h69a702a_1    conda-forge
libgfortran5              14.2.0               hd5240d6_1    conda-forge
libhwloc                  2.11.2          default_he43201b_1000    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
liblapack                 3.9.0            16_linux64_mkl    conda-forge
liblapacke                3.9.0            16_linux64_mkl    conda-forge
libnpp                    12.0.2.50                     0    nvidia
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.1.105                      0    nvidia
libnvjpeg                 12.1.1.14                     0    nvidia
libpng                    1.6.43               h2797004_0    conda-forge
libsqlite                 3.46.0               hde9e2c9_0    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libtiff                   4.5.0                h6adf6a1_2    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp                   1.2.4                h1daa5a0_1    conda-forge
libwebp-base              1.2.4                h166bdaf_0    conda-forge
libxcb                    1.13              h7f98852_1004    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.7               hc051c1a_1    conda-forge
libzlib                   1.2.13               h4ab18f5_6    conda-forge
liger-kernel              0.4.2                    pypi_0    pypi
llvm-openmp               15.0.7               h0cdce71_0    conda-forge
lmdeploy                  0.6.4                    pypi_0    pypi
loguru                    0.7.2                    pypi_0    pypi
markdown                  3.7                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                3.0.2           py311h2dc5d0c_0    conda-forge
marshmallow               3.23.1                   pypi_0    pypi
matplotlib                3.9.3                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mkl                       2022.1.0           h84fe81f_915    conda-forge
mkl-devel                 2022.1.0           ha770c72_916    conda-forge
mkl-include               2022.1.0           h84fe81f_915    conda-forge
mmengine                  0.10.5                   pypi_0    pypi
mpc                       1.3.1                h24ddda3_1    conda-forge
mpfr                      4.2.1                h90cbb55_3    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
msgpack                   1.1.0                    pypi_0    pypi
multidict                 6.1.0                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
multiprocessing-logging   0.3.4                    pypi_0    pypi
ncurses                   6.5                  he02047a_1    conda-forge
nettle                    3.6                  he412f7d_0    conda-forge
networkx                  3.4.2              pyh267e887_2    conda-forge
ninja                     1.11.1.2                 pypi_0    pypi
numpy                     2.1.3           py311h71ddf71_0    conda-forge
opencv-python             4.10.0.84                pypi_0    pypi
openh264                  2.1.1                h780b84a_0    conda-forge
openjpeg                  2.5.0                hfec8fc6_2    conda-forge
openssl                   3.4.0                hb9d3cd8_0    conda-forge
packaging                 24.2                     pypi_0    pypi
pandas                    2.2.3                    pypi_0    pypi
peft                      0.13.2                   pypi_0    pypi
pillow                    9.4.0           py311h50def17_1    conda-forge
pip                       24.3.1             pyh8b19718_0    conda-forge
platformdirs              4.3.6                    pypi_0    pypi
propcache                 0.2.1                    pypi_0    pypi
protobuf                  5.29.0                   pypi_0    pypi
psutil                    6.1.0                    pypi_0    pypi
pthread-stubs             0.4               hb9d3cd8_1002    conda-forge
py-cpuinfo                9.0.0                    pypi_0    pypi
pyarrow                   18.1.0                   pypi_0    pypi
pycparser                 2.22               pyhd8ed1ab_0    conda-forge
pydantic                  2.10.2                   pypi_0    pypi
pydantic-core             2.27.1                   pypi_0    pypi
pygments                  2.18.0                   pypi_0    pypi
pyparsing                 3.2.0                    pypi_0    pypi
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.11.9          hb806964_0_cpython    conda-forge
python-dateutil           2.9.0.post0              pypi_0    pypi
python-dotenv             1.0.1                    pypi_0    pypi
python_abi                3.11                    5_cp311    conda-forge
pytorch                   2.5.1           py3.11_cuda12.1_cudnn9.1.0_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_6    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.2                   pypi_0    pypi
pyyaml                    6.0.2           py311h9ecbd09_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3             pyhd8ed1ab_0    conda-forge
rich                      13.9.4                   pypi_0    pypi
s3transfer                0.10.4                   pypi_0    pypi
safetensors               0.4.5                    pypi_0    pypi
sentencepiece             0.2.0                    pypi_0    pypi
setuptools                75.6.0             pyhff2d567_1    conda-forge
shtab                     1.7.1                    pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
sympy                     1.13.1                   pypi_0    pypi
tbb                       2021.13.0            hceb3a55_1    conda-forge
tensorboard               2.18.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
termcolor                 2.5.0                    pypi_0    pypi
timm                      1.0.11                   pypi_0    pypi
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tokenizers                0.20.3                   pypi_0    pypi
torchaudio                2.5.1               py311_cu121    pytorch
torchtriton               3.1.0                     py311    pytorch
torchvision               0.20.1              py311_cu121    pytorch
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.45.1                   pypi_0    pypi
trl                       0.10.1                   pypi_0    pypi
typeguard                 4.4.1                    pypi_0    pypi
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tyro                      0.9.2                    pypi_0    pypi
tzdata                    2024.2                   pypi_0    pypi
urllib3                   2.2.3              pyhd8ed1ab_0    conda-forge
werkzeug                  3.1.3                    pypi_0    pypi
wheel                     0.45.1             pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.11               hb9d3cd8_1    conda-forge
xorg-libxdmcp             1.1.5                hb9d3cd8_0    conda-forge
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
yapf                      0.43.0                   pypi_0    pypi
yarl                      1.18.3                   pypi_0    pypi
zlib                      1.2.13               h4ab18f5_6    conda-forge
zstandard                 0.23.0          py311hbc35293_1    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge

Error traceback

No response
