commented out unnecessary code
yfzhang114 committed Mar 28, 2024
1 parent 20fe460 commit 27e4892
Showing 3 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -2,12 +2,12 @@
<a target="_blank"><img src="figs/VCD_logo_title.png" alt="Visual Contrastive Decoding" style="width: 75%; min-width: 200px; display: block; margin: auto;"></a>
</p> -->

# Debiasing Large Visual Language Models
# Debiasing Large Visual Language Models / Debiasing Multimodal Large Language Models
<!-- **Debiasing Large Visual Language Models** -->
This is the official repo for Debiasing Large Visual Language Models, including a Post-Hoc debiasing method and a Visual Debias Decoding strategy. These strategies not only help minimize hallucinations but also lead to more helpful and precise responses (see the sketch after this diff).

## 🔥 Update
* [2024-03-08]: ⭐️ Paper online. Check out [Debiasing Large Visual Language Models](https://arxiv.org/abs/2403.05262) for details.
* [2024-03-08]: ⭐️ Paper online. Check out [Debiasing Multimodal Large Language Models](https://arxiv.org/abs/2403.05262) for details.
* [2024-03-11]: 🚀🚀 Codes released.

## 🎯 Overview
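The decoding idea mentioned in the README description can be illustrated with a generic contrastive combination of two logit distributions: one conditioned on the original image and one on a debiased visual input (for example, a blank or noised image). The sketch below is in the spirit of visual contrastive decoding and is not necessarily the repository's exact formulation; the function name, `alpha`, and tensor shapes are assumptions for illustration.

```python
import torch


def visual_debias_logits(
    logits_image: torch.Tensor,     # next-token logits conditioned on the real image, [batch, vocab]
    logits_debiased: torch.Tensor,  # next-token logits from a debiased visual input, [batch, vocab]
    alpha: float = 1.0,
) -> torch.Tensor:
    # Keep evidence that depends on the actual image and down-weight the
    # prior that is shared with the debiased input.
    return (1 + alpha) * logits_image - alpha * logits_debiased
```

In practice the debiased pass might use a blank, blurred, or noised image, and `alpha` controls how aggressively the shared prior is subtracted.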
2 changes: 1 addition & 1 deletion experiments/scripts/mme/run_llava.sh
@@ -5,7 +5,7 @@ model=liuhaotian/llava-v1.5-${size}b

# using SFT model
root=your_cache_dir
model=$root/LLaVA-RLHF-7b-v1.5-224/sft_model
model=liuhaotian/llava-v1.5-7b # $root/LLaVA-RLHF-7b-v1.5-224/sft_model

# naive
python eval/MME/run_llava.py \
54 changes: 27 additions & 27 deletions vcd_utils/vcd_sample.py
@@ -206,33 +206,33 @@ def sample(
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

use_calibrate = model_kwargs.get("use_calibrate", False)
if use_calibrate:
model_kwargs_custom = model_kwargs.copy()
model_inputs_custom = self.prepare_inputs_for_generation_custom(input_ids, **model_kwargs_custom)
outputs_custom = self(
**model_inputs_custom,
return_dict=True,
output_attentions=output_attentions_wo_img,
output_hidden_states=output_hidden_states_wo_img,
)
next_token_logits_custom = outputs_custom.logits[:, -1, :]
# use_calibrate = model_kwargs.get("use_calibrate", False)
# if use_calibrate:
# model_kwargs_custom = model_kwargs.copy()
# model_inputs_custom = self.prepare_inputs_for_generation_custom(input_ids, **model_kwargs_custom)
# outputs_custom = self(
# **model_inputs_custom,
# return_dict=True,
# output_attentions=output_attentions_wo_img,
# output_hidden_states=output_hidden_states_wo_img,
# )
# next_token_logits_custom = outputs_custom.logits[:, -1, :]

cb_cut_weight = model_kwargs.get("cb_cut_weight") if model_kwargs.get("cb_cut_weight") is not None else 0.5
cb_m_weight = model_kwargs.get("cb_m_weight") if model_kwargs.get("cb_m_weight") is not None else 0.5
# cb_cut_weight = model_kwargs.get("cb_cut_weight") if model_kwargs.get("cb_cut_weight") is not None else 0.5
# cb_m_weight = model_kwargs.get("cb_m_weight") if model_kwargs.get("cb_m_weight") is not None else 0.5

cutoff = cb_cut_weight * next_token_logits.max(dim=-1, keepdim=True).values
next_token_logits = next_token_logits.masked_fill(next_token_logits < cutoff, -float("inf"))
# print(f'cnt non inf {torch.sum(next_token_logits != -float("inf")).item()}')
next_token_logits[:,eos_token_id[0] + 1:] = next_token_logits[:,eos_token_id[0] + 1:] - cb_m_weight*next_token_logits_custom[:,eos_token_id[0] + 1:]
# custom_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))
# cutoff = cb_cut_weight * next_token_logits.max(dim=-1, keepdim=True).values
# next_token_logits = next_token_logits.masked_fill(next_token_logits < cutoff, -float("inf"))
# # print(f'cnt non inf {torch.sum(next_token_logits != -float("inf")).item()}')
# next_token_logits[:,eos_token_id[0] + 1:] = next_token_logits[:,eos_token_id[0] + 1:] - cb_m_weight*next_token_logits_custom[:,eos_token_id[0] + 1:]
# # custom_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))

custom_logits = logits_processor(input_ids, next_token_logits)
custom_logits = logits_warper(input_ids, custom_logits)
# custom_logits = logits_processor(input_ids, next_token_logits)
# custom_logits = logits_warper(input_ids, custom_logits)

next_token_scores = custom_logits
probs = nn.functional.softmax(custom_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# next_token_scores = custom_logits
# probs = nn.functional.softmax(custom_logits, dim=-1)
# next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
@@ -276,10 +276,10 @@ def sample(
outputs_dd, model_kwargs_dd, is_encoder_decoder=self.config.is_encoder_decoder
)

if use_calibrate:
model_kwargs_custom = self._update_model_kwargs_for_generation(
outputs_custom, model_kwargs_custom, is_encoder_decoder=self.config.is_encoder_decoder
)
# if use_calibrate:
# model_kwargs_custom = self._update_model_kwargs_for_generation(
# outputs_custom, model_kwargs_custom, is_encoder_decoder=self.config.is_encoder_decoder
# )

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
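For reference, the calibration path that this commit comments out can be summarized as a small standalone function. In the original code, `cb_cut_weight` and `cb_m_weight` are read from `model_kwargs` and the calibration logits come from a second forward pass built with `prepare_inputs_for_generation_custom`. The sketch below is only an illustration reconstructed from those now-disabled lines, not code from the repository; the function name, tensor shapes, and the dummy inputs in the usage block are assumptions.

```python
import torch
import torch.nn.functional as F


def calibrated_next_token_logits(
    next_token_logits: torch.Tensor,         # logits from the full (image-conditioned) pass, [batch, vocab]
    next_token_logits_custom: torch.Tensor,  # logits from the calibration (image-free) pass, [batch, vocab]
    eos_token_id: int,
    cb_cut_weight: float = 0.5,
    cb_m_weight: float = 0.5,
) -> torch.Tensor:
    # Adaptive cutoff: tokens whose logit falls below a fraction of the per-row
    # maximum are masked out (assumes the row maximum is positive, as is
    # typical for LM logits).
    cutoff = cb_cut_weight * next_token_logits.max(dim=-1, keepdim=True).values
    logits = next_token_logits.masked_fill(next_token_logits < cutoff, float("-inf"))
    # Debias the remaining vocabulary (positions past the EOS id) by
    # subtracting a weighted copy of the calibration logits.
    logits[:, eos_token_id + 1:] = (
        logits[:, eos_token_id + 1:]
        - cb_m_weight * next_token_logits_custom[:, eos_token_id + 1:]
    )
    return logits


if __name__ == "__main__":
    # Dummy tensors standing in for the two forward passes (illustrative only).
    batch, vocab = 2, 32000
    logits_img = torch.randn(batch, vocab)
    logits_blind = torch.randn(batch, vocab)
    logits = calibrated_next_token_logits(logits_img, logits_blind, eos_token_id=2)
    probs = F.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    print(next_tokens.shape)  # torch.Size([2])
```

In the committed file this path is gated by `use_calibrate` in `model_kwargs`; with it commented out, sampling falls through to the standard softmax/multinomial step shown at the top of the hunk.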
