Unintuitive Logic in masked_fill Function of DistilBERT Model Implementation #2721

Open
sondalex opened this issue Jan 16, 2025 · 0 comments

The masked_fill function in the DistilBERT model implementation currently has unintuitive logic:

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

With the current implementation, the user must invert the attention mask obtained from the tokenizer before passing it to model.forward. This requirement is confusing because it differs from the transformers implementation, which accepts the tokenizer's attention mask as-is.

...
let text: Vec<&str> = vec![...];
let encoded = tokenizer.encode_batch(text.clone(), true)?;
let input_ids = encoded.iter().map(|v| v.get_ids().to_vec()).collect::<Vec<_>>();
let input_ids = Tensor::new(input_ids, &device)?;
let attention_mask = encoded.iter().map(|encoding| encoding.get_attention_mask().to_vec()).collect::<Vec<_>>();
let attention_mask = Tensor::new(attention_mask, &device)?;

let (batch_size, feature_size) = input_ids.dims2()?;

// Invert the attention mask so that the model behaves correctly (counterintuitive)
let attention_mask = attention_mask.eq(0u32)?.reshape((batch_size, 1, 1, feature_size))?;

let output = model.forward(&input_ids, &attention_mask)?;
...
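
To make the inversion requirement concrete, here is a minimal std-only sketch that models the current behavior on plain slices instead of candle tensors. It assumes that where_cond picks its first argument wherever the mask is non-zero; the names masked_fill_current and invert_mask are hypothetical and exist only for this illustration.

```rust
/// Std-only model of the current `masked_fill` (hypothetical helper, not the
/// candle API): positions where `mask` is non-zero receive `on_true`; all
/// other positions keep their score.
fn masked_fill_current(scores: &[f32], mask: &[u32], on_true: f32) -> Vec<f32> {
    assert_eq!(scores.len(), mask.len(), "scores and mask must have equal length");
    scores
        .iter()
        .zip(mask)
        .map(|(&s, &m)| if m != 0 { on_true } else { s })
        .collect()
}

/// Invert a tokenizer-style mask (1 = real token, 0 = padding), mirroring
/// the `attention_mask.eq(0)` step in the snippet above.
fn invert_mask(mask: &[u32]) -> Vec<u32> {
    mask.iter().map(|&m| u32::from(m == 0)).collect()
}
```

Under these assumptions, passing the tokenizer mask unchanged would overwrite the real tokens and keep the padding, which is why the caller has to call something like invert_mask first.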

Proposal:

Replace the masked_fill function with:

fn masked_fill(on_true: &Tensor, mask: &Tensor, on_false: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_false = Tensor::new(on_false, on_true.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(on_true, &on_false)?;
    Ok(m)
}
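
As a rough illustration of the intended effect, the following std-only sketch models the proposed semantics on plain slices rather than candle tensors. It assumes that where_cond picks its first argument wherever the mask is non-zero; masked_fill_proposed is a hypothetical name used only here.

```rust
/// Std-only model of the proposed `masked_fill` (hypothetical helper, not the
/// candle API): positions where `mask` is non-zero keep their score; all
/// other positions receive `on_false`.
fn masked_fill_proposed(scores: &[f32], mask: &[u32], on_false: f32) -> Vec<f32> {
    assert_eq!(scores.len(), mask.len(), "scores and mask must have equal length");
    scores
        .iter()
        .zip(mask)
        .map(|(&s, &m)| if m != 0 { s } else { on_false })
        .collect()
}
```

With this convention, a tokenizer mask such as [1, 1, 0] can be passed as-is: the two real tokens keep their scores and only the padding position is filled, for example with f32::NEG_INFINITY.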