Batching #48

Open
SinanAkkoyun opened this issue Apr 18, 2024 · 7 comments

Comments

SinanAkkoyun (Contributor) commented Apr 18, 2024

Is this code "optimal" for batched inference and preprocessing?

SinanAkkoyun (Contributor, Author) commented Apr 19, 2024

import time
import torch
from transformers import AutoModelForCausalLM
from PIL import Image
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images
import concurrent.futures

# Initialize the model and processor
model_path = "deepseek-ai/deepseek-vl-1.3b-chat"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

# Function to load and process images and text per thread
def process_conversation(conversation_piece):
    # Load images directly using the entire conversation piece
    pil_images = load_pil_images(conversation_piece)
    prepare_inputs = vl_chat_processor(
        conversations=conversation_piece,
        images=pil_images,
        force_batchify=True
    ).to(vl_gpt.device)
    return prepare_inputs


n_threads = 8

conversation = [
    [
        {"role": "User", "content": "Thoroughly describe <image_placeholder>.", "images": ["../../man_wave.png"]},
        {"role": "Assistant", "content": ""}
    ] for _ in range(n_threads)
]
start = time.time()

# Using ThreadPoolExecutor to parallelize image loading and input preparation
with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
    futures = [executor.submit(process_conversation, conv) for conv in conversation]
    # Collect results in submission order so they line up with `conversation`
    # (as_completed would return them in completion order instead).
    results = [f.result() for f in futures]

print("Time for preprocessing: ", time.time() - start)

# Aggregate results from threads
input_ids = torch.cat([res.input_ids for res in results], dim=0)
pixel_values = torch.cat([res.pixel_values for res in results], dim=0)
attention_mask = torch.cat([res.attention_mask for res in results], dim=0)
images_seq_mask = torch.cat([res.images_seq_mask for res in results], dim=0)
images_emb_mask = torch.cat([res.images_emb_mask for res in results], dim=0)
sft_format = [res.sft_format for res in results]

# Run model to get the response
inputs_embeds = vl_gpt.prepare_inputs_embeds(
    input_ids=input_ids,
    pixel_values=pixel_values,
    images_seq_mask=images_seq_mask,
    images_emb_mask=images_emb_mask
)

outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=1,  # only a single new token is generated here
    do_sample=False,
    use_cache=True
)

answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)  # decode only the first sequence of the batch
print(sft_format[0], answer)  # Assuming sft_format is consistent across threads

end = time.time()
print("Time taken to process: ", end - start)

SinanAkkoyun reopened this Apr 19, 2024

SinanAkkoyun (Contributor, Author) commented Apr 19, 2024

The vl_chat_processor seems to take up the most time, more than the image embedding model itself. It would be awesome to get help with optimizing the preprocessing, or with parallelizing it better; that would make this great for serving.
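
For reference, a minimal way to check where the time actually goes, reusing the objects from the script above (vl_chat_processor, vl_gpt, load_pil_images) on a single conversation; the torch.cuda.synchronize() calls are there so GPU work is counted in the right bucket. This is only a rough sketch under those assumptions, not a proper benchmark:

import time
import torch

conversation_piece = conversation[0]  # one conversation from the script above

# 1) CPU-side preprocessing: image loading, transforms, tokenization
t0 = time.perf_counter()
pil_images = load_pil_images(conversation_piece)
prepare_inputs = vl_chat_processor(
    conversations=conversation_piece,
    images=pil_images,
    force_batchify=True
).to(vl_gpt.device)
torch.cuda.synchronize()
t1 = time.perf_counter()

# 2) prepare_inputs_embeds: runs the vision encoder and builds the LM input embeddings
with torch.no_grad():
    inputs_embeds = vl_gpt.prepare_inputs_embeds(
        input_ids=prepare_inputs.input_ids,
        pixel_values=prepare_inputs.pixel_values,
        images_seq_mask=prepare_inputs.images_seq_mask,
        images_emb_mask=prepare_inputs.images_emb_mask,
    )
torch.cuda.synchronize()
t2 = time.perf_counter()

print(f"processor: {t1 - t0:.3f}s  image embedding: {t2 - t1:.3f}s")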

soloice commented Apr 22, 2024

Is this code "optimal" for batched inference and preprocessing?

Nope. It's just a toy demo, not meant for production use.

SinanAkkoyun (Contributor, Author) commented

I know, that's why I at least tried to quickly "parallelize" the processor.
Any help with actually optimizing it for batching is greatly appreciated.

gullalc commented May 4, 2024

You could use a custom Dataset class and a DataLoader to do the batching. This is how I run it, and it is quite fast.

import os

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM

from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images


class CustomDataset(Dataset):
    def __init__(self, image_folder, prompt, processor):
        self.image_ids = [os.path.join(image_folder, f) for f in os.listdir(image_folder)]
        self.prompt = prompt
        self.processor = processor
            
    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_path = self.image_ids[idx]
        
        prompt = [
                    {
                        "role": "User",
                        "content": "<image_placeholder>%s"%(self.prompt),
                        "images": [image_path]
                    },
                    {
                        "role": "Assistant",
                        "content": ""
                    }
                ]
        
        pil_images = load_pil_images(prompt)
        res = self.processor(conversations=prompt, images=pil_images, force_batchify=True)
        
        return image_path, res.input_ids.squeeze(0), res.pixel_values.squeeze(0), \
                    res.attention_mask.squeeze(0), res.images_seq_mask.squeeze(0), res.images_emb_mask.squeeze(0)


def process_images(model, tokenizer, loader):
    image_descriptions = {}
    for i, batch in enumerate(loader):

        print(f'Processing batch {i}/{len(loader)}')
        
        image_ids, input_ids, pixel_values, attention_mask, images_seq_mask, images_emb_mask = batch
        
        input_ids, pixel_values, attention_mask, images_seq_mask, images_emb_mask = input_ids.to(model.device), \
                                            pixel_values.to(model.device), attention_mask.to(model.device), \
                                            images_seq_mask.to(model.device), images_emb_mask.to(model.device)
        
        print(input_ids.shape, pixel_values.shape, attention_mask.shape, images_seq_mask.shape, images_emb_mask.shape)
        
        with torch.no_grad():
            inputs_embeds = model.prepare_inputs_embeds(
                input_ids=input_ids,
                pixel_values=pixel_values.to(torch.bfloat16),
                images_seq_mask=images_seq_mask,
                images_emb_mask=images_emb_mask
            )

            outputs = model.language_model.generate(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    pad_token_id=tokenizer.eos_token_id,
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    max_new_tokens=256,
                    do_sample=False,
                    use_cache=True
                )

        for image_id, output in zip(image_ids, outputs):
            answer = tokenizer.decode(output.cpu().tolist(), skip_special_tokens=True)
            answer = answer.split("Assistant:")[-1].strip()
            image_descriptions[image_id] = answer

    return image_descriptions
    
## Model, Processor and Tokenizer
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained("deepseek-ai/deepseek-vl-7b-chat")
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-vl-7b-chat", trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

prompt = "Describe this image in detail."

## Dataloader
image_folder = "path/to/images"  # placeholder: directory of images to describe
dataloader = DataLoader(CustomDataset(image_folder, prompt, vl_chat_processor), batch_size=32, shuffle=False, pin_memory=True, num_workers=8)

image_descriptions = process_images(vl_gpt, tokenizer, dataloader)
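
One caveat with this approach: the default DataLoader collate simply stacks the tensors returned by __getitem__, so every item must produce input_ids (and masks) of the same length. That holds here because the prompt is fixed and each item has exactly one image, but if prompts vary per image you would need a custom collate_fn. A rough sketch of such a helper (pad_collate and its left-padding choice are assumptions, not part of deepseek_vl; it pads with tokenizer.eos_token_id to mirror the pad_token_id used in generate above):

from functools import partial
import torch

def pad_collate(batch, pad_id):
    # batch is a list of the tuples returned by CustomDataset.__getitem__
    image_ids, input_ids, pixel_values, attention_mask, images_seq_mask, images_emb_mask = zip(*batch)
    max_len = max(x.size(0) for x in input_ids)

    def left_pad(seq, fill):
        out = seq.new_full((max_len,), fill)
        out[max_len - seq.size(0):] = seq  # keep the real tokens on the right
        return out

    return (
        list(image_ids),
        torch.stack([left_pad(x, pad_id) for x in input_ids]),
        torch.stack(pixel_values),
        torch.stack([left_pad(x, 0) for x in attention_mask]),
        torch.stack([left_pad(x, False) for x in images_seq_mask]),
        torch.stack(images_emb_mask),
    )

# Example: the same DataLoader as above, but with the padding collate
dataloader = DataLoader(
    CustomDataset(image_folder, prompt, vl_chat_processor),
    batch_size=32, shuffle=False, pin_memory=True, num_workers=8,
    collate_fn=partial(pad_collate, pad_id=tokenizer.eos_token_id),
)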

SinanAkkoyun (Contributor, Author) commented

Oh wow that's a very cool approach, thank you so much for sharing it! I will try it out asap

Youho99 (Contributor) commented May 14, 2024

I'm also interested in making inference faster.
I only use the model for inference.
