From e8991a86f1255b15d308a44509feb6e772dfca2d Mon Sep 17 00:00:00 2001 From: ryan-qiyu-jiang Date: Thu, 16 Dec 2021 12:54:07 -0800 Subject: [PATCH] [docs] Add ViLT model code usage example Add an example of running ViLT from code outside of MMF cli at the end of the ViLT tutorial. Example is of ViLT vqa on raw image and text. [ghstack-poisoned] --- website/docs/projects/vilt.md | 100 ++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/website/docs/projects/vilt.md b/website/docs/projects/vilt.md index 4a16923fa..7da003e1c 100644 --- a/website/docs/projects/vilt.md +++ b/website/docs/projects/vilt.md @@ -45,3 +45,103 @@ To pretrain a ViLT model from scratch on the COCO dataset, ``` mmf_run config=projects/vilt/configs/masked_coco/pretrain.yaml run_type=train_val dataset=masked_coco model=vilt ``` + +## Using the ViLT model from code +Here is an example of running the ViLT model from code, to do visual question answering (vqa) on a raw image and text. +The forward pass takes ~15ms which is very fast compared to UNITER's ~600ms. + +```python +from argparse import Namespace + +import torch +from mmf.common.sample import SampleList +from mmf.datasets.processors.bert_processors import VILTTextTokenizer +from mmf.datasets.processors.image_processors import VILTImageProcessor +from mmf.utils.build import build_model +from mmf.utils.configuration import Configuration, load_yaml +from mmf.utils.general import get_current_device +from mmf.utils.text import VocabDict +from omegaconf import OmegaConf +from PIL import Image +``` + +A way to make model configs and instantiate the ViLT model. +```python +# make model config for vilt vqa2 +model_name = "vilt" +config_args = Namespace( + config_override=None, + opts=["model=vilt", "dataset=vqa2", "config=configs/defaults.yaml"], +) +default_config = Configuration(config_args).get_config() +model_vqa_config = load_yaml( + "/private/home/your/path/to/mmf/projects/vilt/configs/vqa2/defaults.yaml" +) +config = OmegaConf.merge(default_config, model_vqa_config) +OmegaConf.resolve(config) +model_config = config.model_config[model_name] +model_config.model = model_name +vilt_model = build_model(model_config) +``` + +Load model weights, `model_checkpoint_path` is the model checkpoint downloaded at model zoo path `vilt.vqa`, +with current url `s3://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz` +```python +# build model and load weights +model_checkpoint_path = './vilt_vqa2.pth' +state_dict = torch.load(model_checkpoint_path) +vilt_model.load_state_dict(state_dict, strict=False) +vilt_model.eval() +vilt_model = vilt_model.to(get_current_device()) +``` + +Prepare input image and text. +This example is using an image of a man with a hat kissing his daughter. +The text is the question posed to the ViLT model for visual question answering. +```python +# get image input +image_processor = VILTImageProcessor({"size": [384, 384]}) +image_path = "./kissing_image.jpg" +raw_img = Image.open(image_path).convert("RGB") +image = image_processor(raw_img) + +# get text input +text_tokenizer = VILTTextTokenizer({}) +question = "What is on his head?" +processed_text_dict = text_tokenizer({"text": question}) +``` + +Wrap everything up in a sample list as expected by the ViLT BaseModel. +```python +# make batch inputs +sample_dict = {**processed_text_dict, "image": image} +sample_dict = { + k: v.unsqueeze(0) for k, v in sample_dict.items() if isinstance(v, torch.Tensor) +} +sample_dict["targets"] = torch.zeros((1, 3129)) +sample_dict["targets"][0,1358] = 1 +sample_dict["dataset_name"] = "vqa2" +sample_dict["dataset_type"] = "test" +sample_list = SampleList(sample_dict).to(get_current_device()) +``` + +Load the vqa answer -> word string map to understand what it says! +Currently file url at `s3://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt` +```python +# load vqa2 id -> answers +vocab_file_path = "/private/home/path/to/answers_vqa.txt" +answer_vocab = VocabDict(vocab_file_path) +``` + +And heres the part you've been waiting for! +```python +# do prediction +with torch.no_grad(): + vqa_logits = vilt_model(sample_list)["scores"] + answer_id = vqa_logits.argmax().item() + answer = answer_vocab.idx2word(answer_id) + print(chr(27) + "[2J") # clear the terminal + print(f"{question}: {answer}") +``` + +Expected output `What is on his head?: hat`