
Inference on my own data? #1

Open
puzzlecollector opened this issue Jun 23, 2021 · 21 comments

@puzzlecollector

Hello! First of all, thank you so much for your work. I have read your paper and I want to carry out some open-ended VQA / answer-generation experiments with the model you proposed (VL-T5). However, I am unsure where to start with the provided code. Would it be possible for you to provide example code for extracting image features and text features for a custom dataset (not data from VQA v2.0)? I want to test whether it can generate answers based on images and questions that I have prepared.

Thank you so much and I am so sorry for troubling you.

@j-min
Owner

j-min commented Jun 23, 2021

  1. Feature extraction
    You can refer to these lines, which extract VG-trained Faster R-CNN features from images based on this repo.
  2. Datum/batch creation
    You can refer to the __getitem__ and collate_fn methods of the Dataset class in vqa_data.py to create batches from your custom VQA input. Make sure each batch is a dictionary containing the keys 'vis_feats', 'input_ids', and 'boxes' (see the sketch below).
  3. Inference
    Then you can feed the batches to models created or adapted from vqa_model.py.
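
For illustration, here is a rough sketch of what a single inference batch could look like (the shapes are placeholders of my own; the actual padding and stacking is handled by collate_fn in vqa_data.py):

import torch

# Placeholder single-example batch; 36 regions with 2048-dim features and
# 4 box coordinates each are just example shapes, not fixed requirements.
batch = {
    'vis_feats': torch.randn(1, 36, 2048),                # Faster R-CNN region features
    'boxes': torch.rand(1, 36, 4),                         # normalized box coordinates
    'input_ids': torch.LongTensor([[100, 200, 300, 1]]),   # tokenized question (dummy ids)
}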

@puzzlecollector
Author

@j-min thanks for the kind reply. When you have time, would it be possible for you to provide a Jupyter (or Colab) example of how to run inference on a given image-question pair, for example like this detectron2 tutorial? I think it would greatly help people like me.

Once again, thank you so much for your work!

@puzzlecollector
Author

@j-min Also, what information does 'vis_feats' actually hold?

The input_ids must be the IDs of the tokenized question, and the boxes must be the bounding-box coordinates of the objects detected in the image, right? I am unsure what information vis_feats contains. I am not very familiar with Faster R-CNN or detectron2, so I hope you'll understand even if my question is very basic :(

@j-min
Owner

j-min commented Jun 28, 2021

As you can see in the __getitem__ of the Dataset class, vis_feats is the 2048-dim feature from Faster R-CNN, and boxes are the 4-point coordinates of the bounding boxes. The Faster R-CNN features are often called 'bottom-up' features, in that they are not conditioned on the text input (similarly, text-conditioned features are called 'top-down'). I suggest you check out Peter Anderson's original repo and his paper. Bottom-up features have been widely used in vision-and-language research.
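
Concretely, each detected region gives one 2048-dim feature vector plus one (x1, y1, x2, y2) box. As a rough sketch (my own illustration; it assumes boxes are normalized to [0, 1] by the image size before being fed to the model):

import torch

def normalize_boxes(boxes, img_w, img_h):
    """Scale raw (x1, y1, x2, y2) pixel boxes to [0, 1] by the image size."""
    boxes = boxes.clone().float()
    boxes[:, [0, 2]] /= img_w   # x coordinates
    boxes[:, [1, 3]] /= img_h   # y coordinates
    return boxes

# e.g. for 36 detected regions: vis_feats has shape (36, 2048), boxes has shape (36, 4)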

@puzzlecollector
Author

@j-min
I tried using the feature extraction from the LXMERT demo for VL-T5, and I am running the following code on my data sample, but I am getting the error "TypeError: forward() got an unexpected keyword argument 'head_mask'". Could you help me resolve this issue? Once again, I'm so sorry for troubling you :(

This is the Faster R-CNN provided by the Hugging Face LXMERT demo.

from IPython.display import clear_output, Image, display 
import PIL.Image 
import io 
import json 
import torch 
import torch.nn as nn 
import numpy as np 
import pandas as pd 
import wget 
import pickle 
import os 
import random 
import math 
import urllib.request 
from processing_image import Preprocess
from visualizing_image import SingleImageViz 
from modeling_frcnn import GeneralizedRCNN 
import utils 
from utils import Config 

# URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/images/input.jpg",
URL = "https://vqa.cloudcv.org/media/test2014/COCO_test2014_000000262567.jpg"
OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt"
ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt"
GQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/gqa/trainval_label2ans.json"
VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"

# load object, attribute and answer labels 

objids = utils.get_data(OBJ_URL) 
attrids = utils.get_data(ATTR_URL)
gqa_answers = utils.get_data(GQA_URL) 
vqa_answers = utils.get_data(VQA_URL) 
frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=frcnn_cfg) 
image_preprocess = Preprocess(frcnn_cfg) 

# used for visualizing output 
def showarray(a, fmt='jpeg'): 
    a = np.uint8(np.clip(a,0,255)) 
    f = io.BytesIO() 
    PIL.Image.fromarray(a).save(f,fmt) 
    display(Image(data=f.getvalue()))


image_dirname = 'stop_sign.png'
frcnn_visualizer = SingleImageViz(image_dirname, id2obj=objids, id2attr=attrids) 

images, sizes, scales_yx = image_preprocess(image_dirname) 

output_dict = frcnn(
    images, 
    sizes, 
    scales_yx = scales_yx, 
    padding = 'max_detections', 
    max_detections = frcnn_cfg.max_detections, 
    return_tensors = 'pt' 
)

# add boxes and labels to the image 
frcnn_visualizer.draw_boxes(
    output_dict.get("boxes"), 
    output_dict.get("obj_ids"),
    output_dict.get("obj_probs"),
    output_dict.get("attr_ids"), 
    output_dict.get("attr_probs"),
)

showarray(frcnn_visualizer._get_buffer())

normalized_boxes = output_dict.get("normalized_boxes") 
features = output_dict.get("roi_features") 

from tokenization import VLT5Tokenizer 

tokenizer = VLT5Tokenizer.from_pretrained('t5-base')

question = "vqa:What are the men doing?" 

outputs = tokenizer.encode_plus(question, 
                                max_length=100,
                                truncation=True, 
                                return_attention_mask=True,
                                return_token_type_ids=True) 


input_ids = outputs['input_ids'] 
batch = {}
batch['vis_feats'] = features 
batch['input_ids'] = torch.tensor(input_ids).reshape((1,-1)) 
batch['boxes'] = normalized_boxes

# checkpoint provided in your repo
checkpoint = torch.load('VL_T5_Epoch30.pth')

from vqa_model import VLT5VQA 

from param import parse_args  
args = parse_args( 
    backbone='t5-base', # Backbone architecture
    load=checkpoint, # Pretrained checkpoint 
    parse=False, # False for interactive env (ex. jupyter)
)
args.gpu = 0 # assign GPU 
# create config file 
from transformers import T5Config  

config_class = T5Config 
config = config_class.from_pretrained(args.backbone)
config.feat_dim = args.feat_dim 
config.pos_dim = args.pos_dim 
config.n_images = 1 
config.use_vis_order_embedding = args.use_vis_order_embedding 

config.dropout_rate = args.dropout 
config.dropout = args.dropout 
config.attention_dropout = args.dropout 
config.activation_dropout = args.dropout 

config.use_vis_layer_norm = args.use_vis_layer_norm 
config.individual_vis_layer_norm = args.individual_vis_layer_norm 
config.losses = args.losses 

config.share_vis_lang_layer_norm = args.share_vis_lang_layer_norm  

config.classifier = args.classifier # False, we want to do a generation task

model = VLT5VQA(config)
result = model.generate(input_ids=batch['input_ids'], 
                        vis_inputs=(batch['vis_feats'], batch['boxes'])) 

@puzzlecollector
Author

@j-min

I am using detectron2 for feature extraction for generative VQA, but I am getting empty outputs. Can you let me know if I am passing in the data correctly?

# define model 
from vqa_model import VLT5VQA 

from param import parse_args  
args = parse_args( 
    backbone='t5-base', # Backbone architecture
    load='Epoch30.pth', # Pretrained checkpoint 
    parse=False, # False for interactive env (ex. jupyter)
)
args.gpu = 0 # assign GPU 


# create config file 
from transformers import T5Config  

config_class = T5Config 
config = config_class.from_pretrained(args.backbone)
config.feat_dim = args.feat_dim 
config.pos_dim = args.pos_dim 
config.n_images = 1 
config.use_vis_order_embedding = args.use_vis_order_embedding 

config.dropout_rate = args.dropout 
config.dropout = args.dropout 
config.attention_dropout = args.dropout 
config.activation_dropout = args.dropout 

config.use_vis_layer_norm = args.use_vis_layer_norm 
config.individual_vis_layer_norm = args.individual_vis_layer_norm 
config.losses = args.losses 

config.share_vis_lang_layer_norm = args.share_vis_lang_layer_norm  

config.classifier = args.classifier # False, we want to do a generation task

model = VLT5VQA(config) 
model.cuda()


# display image (cv2 plus showarray, doit, and predictor come from the
# detectron2 / py-bottom-up-attention feature-extraction demo)
import cv2

im = cv2.imread("bananas.jpeg")
im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
showarray(im_rgb)

instances, features = doit(im_rgb, predictor)

# vqa 
questions = ["vqa:What fruit is in the picture?", 
             "vqa:What color is the fruit?", 
             "vqa:How much is the fruit?"] 


from tokenization import VLT5Tokenizer 

tokenizer = VLT5Tokenizer.from_pretrained('t5-base')




for question in questions: 
  outputs = tokenizer.encode_plus(question, 
                                max_length=100,
                                truncation=True, 
                                return_attention_mask=True,
                                return_token_type_ids=True) 


  input_ids = outputs['input_ids']  
  batch = {}
  batch['vis_feats'] = features.reshape((1,36,2048)).cuda() 
  batch['input_ids'] = torch.tensor(input_ids).reshape((1,-1)).cuda() 
  batch['boxes'] = instances.pred_boxes.tensor.reshape((1,36,4)).cuda()

  batch['obj_id'] = instances.pred_classes
  batch['obj_conf'] = instances.scores
  batch['attr_id'] = instances.attr_classes
  batch['attr_conf'] = instances.attr_scores
  batch['img_w'] = im_rgb.shape[1] 
  batch['img_h'] = im_rgb.shape[0] 
  batch['img'] = im_rgb 
  batch['img_ids'] = 1


  result = model.test_step(batch)
  

  print("generated answer = {}".format(result['pred_ans']))

The output is

generated answer = ['']
generated answer = ['']
generated answer = ['']

@j-min
Owner

j-min commented Jul 6, 2021

I'm afraid you may not have loaded the pretrained checkpoint properly. Please check out load_checkpoint, which is used in vqa.py and defined in trainer_base.py.
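
For reference, here is a minimal loading sketch, assuming the .pth file stores a plain state_dict (this is not the repo's load_checkpoint helper, which may also handle key renaming; prefer that helper when possible):

import torch

# Minimal checkpoint-loading sketch; load_checkpoint in trainer_base.py
# may additionally rename or strip key prefixes.
state_dict = torch.load('Epoch30.pth', map_location='cpu')
load_result = model.load_state_dict(state_dict, strict=False)
print('missing keys:', load_result.missing_keys)
print('unexpected keys:', load_result.unexpected_keys)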

@j-min
Owner

j-min commented Jul 6, 2021

I created a Google Colab notebook for custom image processing. Hope this helps.

@puzzlecollector
Author

puzzlecollector commented Jul 7, 2021

@j-min
Thank you so much! Your answers helped me a lot!
I have a few more questions:

  1. Is the detectron2 model (in the airsplay py-bottom-up-attention repo) pretty much the same as the Faster R-CNN provided by the Hugging Face Transformers LXMERT demo? If not, is it better to use one over the other?

  2. I realized that the answers generated by VL-T5 for VQA tasks are pretty short (probably because the answers in the VQA v2.0 dataset are quite short). So I am thinking about fine-tuning the VLT5VQA model with questions and answers from the VCR (Visual Commonsense Reasoning) dataset, since its answers tend to be longer and more complex than those in the VQA dataset. What script should I refer to if I want to fine-tune the VLT5VQA model on my custom data? After creating my own batches of data, would simply calling

# Training
model.train_step(train_batch)

be enough to fine-tune the model?

@puzzlecollector
Author

@j-min
Also, I wonder what you think about fine-tuning the VLT5VQA model with (image, question, answer) triplets from the VCR dataset. Do you think that by doing so, VL-T5 could generate more detailed answers to questions (instead of the single-word answers typical of conventional VQA)? I guess the Visual7W dataset is another option.

@j-min
Owner

j-min commented Jul 9, 2021

Yes, the py-bottom-up-attention repo is compatible with the Hugging Face Transformers LXMERT demo.

VCR questions (https://visualcommonsense.com/explore/?im=2519) have a different format than VQA, for example person grounding and multiple choice, so I don't think such fine-tuning is trivial. You could search for or create a new dataset with a format similar to VQA but with longer answers. Once you have a custom VQA dataset, you can fine-tune the VLT5 or VLT5VQA model on it. You can start by modifying the Dataset class in vqa_data.py (see the sketch below).
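
As a rough starting point, here is a minimal custom Dataset sketch (my own illustration, not the repo's class; it assumes region features and boxes have been precomputed per example, and that a collate_fn stacks/pads the items into the batch keys discussed above):

import torch
from torch.utils.data import Dataset

class MyVQADataset(Dataset):
    """Rough custom VQA dataset sketch: one (features, boxes, question, answer) per item."""

    def __init__(self, examples, tokenizer, max_length=100):
        # examples: list of dicts with precomputed 'feats' (n_boxes, 2048),
        # 'boxes' (n_boxes, 4, normalized), 'question' (str), 'answer' (str).
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        input_ids = self.tokenizer.encode(
            ex['question'], max_length=self.max_length, truncation=True)
        target_ids = self.tokenizer.encode(
            ex['answer'], max_length=self.max_length, truncation=True)
        return {
            'input_ids': torch.LongTensor(input_ids),
            'target_ids': torch.LongTensor(target_ids),   # key name is an assumption
            'vis_feats': torch.as_tensor(ex['feats'], dtype=torch.float),
            'boxes': torch.as_tensor(ex['boxes'], dtype=torch.float),
        }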

@puzzlecollector
Author

@j-min

Looking at the VL-T5 paper, it seems like the decoder generates text in an autoregressive manner, i.e. it predicts the probability of the next token (over its known vocabulary) conditioned on the encoder outputs and the previously generated tokens. Is this understanding of mine correct? If so, and I want to fine-tune with additional training examples (with longer answers), do I need to modify the vocab size of the output space (so that it is tailored to the additional fine-tuning data that I have)?


@j-min
Owner

j-min commented Aug 13, 2021

  • Looking at the VL-T5 paper, it seems like the decoder generates text in an autoregressive manner, i.e. it predicts the probability of the next token conditioned on the encoder outputs and the previously generated tokens. Is this understanding of mine correct?
    -> Yes, you're right.

  • If so, and I want to fine-tune with additional training examples (with longer answers), do I need to modify the vocab size of the output space (so that it is tailored to the additional fine-tuning data that I have)?
    -> As long as your training examples are written in English, you won't need to modify the vocabulary, since T5's tokenizer is based on SentencePiece trained on a large corpus. You might want to modify the vocabulary when training on other languages or on non-natural languages such as equations or programming code.
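
As a quick illustration (a sketch using the standard Hugging Face T5Tokenizer rather than the repo's VLT5Tokenizer), a longer free-form English answer round-trips through the existing vocabulary without any changes:

from transformers import T5Tokenizer

# Sanity check: a longer free-form English answer is already covered by the
# pretrained SentencePiece vocabulary, so no vocabulary changes are needed.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
answer = "The man in the red jacket is peeling a banana while sitting on a bench."
ids = tokenizer.encode(answer)
print(len(ids))                                          # number of subword tokens
print(tokenizer.decode(ids, skip_special_tokens=True))   # recovers the original text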

@puzzlecollector
Author

@j-min
So I have managed to create my custom dataset (VQA with longer answers), and I have formed the examples into batches (with batch['input_ids'], batch['vis_feats'], and batch['boxes']).

If I want to train, then do I simply have to create a VLT5VQA model instance like

model = VLT5VQA(config)

and then simply do

model.train_step(batch)

in order to train?

@puzzlecollector
Author

puzzlecollector commented Aug 24, 2021

@j-min
Also in https://github.com/j-min/VL-T5/blob/2b02c32cf156881bf235d12abe64b558d43b1b5c/VL-T5/src/vqa_model.py

what information does batch['scores'] hold? Also, how do I obtain the scores?

I am not so familiar with this concept, so I'm sorry if the question is too basic. Are the scores the intersection over union between reference boxes and the boxes detected by Faster R-CNN? If so, how can I obtain the reference boxes? I am using the COCO dataset for images (train2014.zip: http://images.cocodataset.org/zips/train2014.zip); I just have longer questions and answers.

@j-min
Owner

j-min commented Aug 24, 2021

The scores come from the VQA evaluation. Many VQA methods train models by directly regressing the soft scores (e.g. LXMERT).
But in our text-generation-based method, I just used score = 1 whenever a question has an answer label. In your case, you can ignore it by commenting out the score.
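
Tying this back to the earlier train_step question, a rough fine-tuning loop might look like the sketch below; the key names 'target_ids' and 'loss' are my assumptions here, so check train_step in vqa_model.py and collate_fn in vqa_data.py for the exact interface:

import torch
from torch.optim import AdamW

# Rough fine-tuning loop; 'target_ids' and 'loss' are assumed key names.
optimizer = AdamW(model.parameters(), lr=1e-4)
model.train()

for batch in my_dataloader:                     # my_dataloader: your own DataLoader (assumed)
    batch['input_ids'] = batch['input_ids'].cuda()
    batch['vis_feats'] = batch['vis_feats'].cuda()
    batch['boxes'] = batch['boxes'].cuda()
    batch['target_ids'] = batch['target_ids'].cuda()   # tokenized answer text

    output = model.train_step(batch)            # assumed to return a dict with 'loss'
    loss = output['loss']

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()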

@puzzlecollector
Author

puzzlecollector commented Aug 27, 2021

@j-min
Also, when I train the VLT5VQA model, I guess I should add the "vqa:" tag to the questions?

I accidentally fine-tuned the model using data that does not have the 'vqa:' tag on the questions, but after training it for 10 epochs and testing it, it seems to be working fine.

@j-min
Owner

j-min commented Aug 27, 2021

You don't have to use the vqa: prefix, especially if you fine-tune on enough data.

@puzzlecollector
Author

@j-min
I have a research question: is it possible to create a VQA system that also uses visual grounding information? For example, suppose I have person tags, and those tags are also fed into VL-T5 (or some other VQA model) so that the model can use them to answer specific questions. For instance, person1, person2, and person3 are tagged in the picture, and the question asks what person1 is doing. I need to look into this a bit more, but I feel like VCR requires something similar to the task I just described. Would you be able to give me a few pointers on where to start if I want to build such a system/model? Thanks

@j-min
Owner

j-min commented Sep 28, 2021

Yes, you can check out VCR for such a setting. You might also want to check Visual7W and how models tackle these datasets.

@alice-cool

I want to test if it can generate answers based on images and questions that I have prepared.

Regarding the download links for the VL-T5 VCR task dataset (the 36-box image features): "train_boxes36.h5" can't be downloaded. Could you give me another link?
