[Usage] Must I reload the model when I want to run inference on a new image? #1487

lin-whale opened this issue May 7, 2024 · 2 comments
lin-whale commented May 7, 2024

Describe the issue

Loading the model takes a long time, so I tried to keep the loaded model and reuse it when running inference on a new image. However, I hit the error below. Is this possible, and how should I write the code?
Modified from llava/serve/cli.py

def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    ...

    while True:
        image_file = input("Please input image path:")
        image = load_image(image_file)
        image_size = image.size
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        if type(image_tensor) is list:
            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        while True:
            try:
                inp = input(f"{roles[0]}: ")
            except EOFError:
                inp = ""
            if not inp:
                print("exit...")
                break

            print(f"{roles[1]}: ", end="")

            if image is not None:
                # first message
                if model.config.mm_use_im_start_end:
                    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                else:
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                image = None
            
            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    image_sizes=[image_size],
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    streamer=streamer,
                    use_cache=True)

            outputs = tokenizer.decode(output_ids[0]).strip()
            conv.messages[-1][-1] = outputs

            if args.debug:
                print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

The code works well on first image input, but fails on the second image input.

Output:

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:29<00:00,  1.94s/it]
Please input image path:/home/aistar/llava/data/view.jpg
<|im_start|>user
: hello
<|im_start|>assistant
: Hello! This is a beautiful image of a wooden dock extending into a serene lake. The calm water reflects the surrounding landscape, which includes a forest and mountains in the distance. The sky is partly cloudy, suggesting a pleasant day. The dock appears to be a quiet spot for relaxation or perhaps a starting point for boating or fishing. It's a peaceful scene that evokes a sense of tranquility and connection with nature.
<|im_start|>user
: 
exit...
Please input image path:/home/aistar/llava/data/view.jpg
<|im_start|>user
: hello
<|im_start|>assistant
: Traceback (most recent call last):
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/aistar/llava/LLaVA/llava/serve/cli_multi_turn.py", line 137, in <module>
    main(args)
  File "/home/aistar/llava/LLaVA/llava/serve/cli_multi_turn.py", line 107, in main
    output_ids = model.generate(
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/aistar/llava/LLaVA/llava/model/language_model/llava_llama.py", line 125, in generate
    ) = self.prepare_inputs_labels_for_multimodal(
  File "/home/aistar/llava/LLaVA/llava/model/llava_arch.py", line 260, in prepare_inputs_labels_for_multimodal
    cur_image_features = image_features[cur_image_idx]
IndexError: list index out of range
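
The traceback points at the root cause: conv is never reinitialized between images, so the history of the first conversation, including its image placeholder, is still in the prompt when the second image arrives. The prompt then contains two image tokens while images=image_tensor carries only one image, and prepare_inputs_labels_for_multimodal indexes past the end of image_features. A rough sketch of the mismatch, reusing the names from the snippet above:

# Sketch of the failure mode; conv and DEFAULT_IMAGE_TOKEN as in the snippet above.
# After the first conversation, conv still holds the first turn, whose text begins
# with DEFAULT_IMAGE_TOKEN; the second image then prepends another one.
prompt = conv.get_prompt()
num_image_tokens = prompt.count(DEFAULT_IMAGE_TOKEN)  # 2 on the second image
num_images_passed = image_tensor.shape[0]             # but only 1 image was processed
# prepare_inputs_labels_for_multimodal advances cur_image_idx once per image token
# and reads image_features[cur_image_idx], so the second token overruns the
# single-image list and raises the IndexError above.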

J0eky commented May 30, 2024

@lin-whale Hi, I'm working on the same thing and ran into the same problem. Have you solved it?

lin-whale (Author) commented May 30, 2024

@lin-whale Hi, I'm working on the same thing and ran into the same problem. Have you solved it?

Yes, here is the modified code. You need to reinitialize the conversation before starting each new query, i.e. rebuild it with conv = conv_templates[args.conv_mode].copy() for every image; a minimal sketch of just that change follows the full code below.

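# In addition to the llava/serve/cli.py imports, this version uses os, time,
# traceback, and llava.mm_utils.KeywordsStoppingCriteria.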
def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    while True:
        print("-"*100)
        img_folder = input("please input image folder path: ")
        images = os.listdir(img_folder)
        images = [item for item in images if os.path.isfile(os.path.join(img_folder, item))]
        images.sort()
        save_dir = os.path.join(img_folder, "llava_eval")
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        try:
            ori_inp = input("please input prompt: ")
        except EOFError:
            ori_inp = ""
            
        total_time = 0
        img_count = 0
        yes_img_count = 0

        for index, ori_image in enumerate(images):
            img_count += 1
            
            try:
                conv = conv_templates[args.conv_mode].copy()
                if "mpt" in model_name.lower():
                    roles = ('user', 'assistant')
                else:
                    roles = conv.roles

                image = load_image(os.path.join(img_folder, ori_image))

                # Similar operation in model_worker.py
                image_tensor = process_images([image], image_processor, model.config)
                if type(image_tensor) is list:
                    image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
                else:
                    image_tensor = image_tensor.to(model.device, dtype=torch.float16)            
            except OSError:
                # Missing or unreadable image file.
                image = None
            if image is None:
                print("No image input, exit...")
                break
            
            if not ori_inp:
                print("exit...")
                break

            try:
                # print(f"{roles[1]}: ", end="")
                print("")
                print(ori_image)
                
                T1 = time.perf_counter()
                if image is not None:
                    # first message
                    if model.config.mm_use_im_start_end:
                        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + ori_inp
                    else:
                        inp = DEFAULT_IMAGE_TOKEN + '\n' + ori_inp
                    conv.append_message(conv.roles[0], inp)
                    image = None
                else:
                    # later messages
                    conv.append_message(conv.roles[0], ori_inp)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()

                input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
                stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
                keywords = [stop_str, 'Yes', 'No']
                stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
                streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

                with torch.inference_mode():
                    output_ids = model.generate(
                        input_ids,
                        images=image_tensor,
                        do_sample=True,
                        top_p=1.0,
                        temperature=args.temperature,
                        max_new_tokens=args.max_new_tokens,
                        streamer=streamer,
                        use_cache=True,
                        stopping_criteria=[stopping_criteria])

                outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
                conv.messages[-1][-1] = outputs
                T2 = time.perf_counter()
                print('------------ time in seconds ------------', T2 - T1)
                total_time += T2 - T1
                if args.debug:
                    print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
            
            except Exception as e:
                print('Error:', str(e))
                print('Traceback:', traceback.format_exc())
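
For reference, the fix distills to a few lines. Below is a minimal sketch that assumes the same imports and model loading as llava/serve/cli.py and the simple case where process_images returns a single tensor; it is a sketch of the pattern, not a drop-in replacement for the full script above:

def chat_loop(args, tokenizer, model, image_processor):
    # The model stays loaded across images; only the conversation is rebuilt.
    while True:
        image_file = input("Please input image path: ")
        if not image_file:
            break
        image = load_image(image_file)
        image_tensor = process_images([image], image_processor, model.config)
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        # The essential change: a fresh conversation per image, so the prompt
        # contains exactly one image token.
        conv = conv_templates[args.conv_mode].copy()

        inp = DEFAULT_IMAGE_TOKEN + '\n' + input("Prompt: ")
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        input_ids = tokenizer_image_token(conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).to(model.device)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image.size],
                do_sample=False,
                max_new_tokens=args.max_new_tokens,
                use_cache=True)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True).strip())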
