-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Usage] Must I reload the model when I want to inference on a new image? #1487
Comments
@lin-whale hi, I'm working on the same thing as you and ran into the same problem — have you solved it yet? |
Yeah, here is the code modified. You need to reinitialize the conv before starting a new query. Just in this line def main(args):
# Model
disable_torch_init()
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
if "llama-2" in model_name.lower():
conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name.lower():
conv_mode = "chatml_direct"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"
if args.conv_mode is not None and conv_mode != args.conv_mode:
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
else:
args.conv_mode = conv_mode
while True:
print("-"*100)
img_folder = input("please input image folder path: ")
images = os.listdir(img_folder)
images = [item for item in images if os.path.isfile(os.path.join(img_folder, item))]
images.sort()
save_dir = os.path.join(img_folder, "llava_eval")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
try:
ori_inp = input("please input prompt: ")
except EOFError:
ori_inp = ""
total_time = 0
img_count = 0
yes_img_count = 0
for index, ori_image in enumerate(images):
img_count += 1
try:
conv = conv_templates[args.conv_mode].copy()
if "mpt" in model_name.lower():
roles = ('user', 'assistant')
else:
roles = conv.roles
image = load_image(os.path.join(img_folder, ori_image))
# Similar operation in model_worker.py
image_tensor = process_images([image], image_processor, args)
if type(image_tensor) is list:
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
else:
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
except EOFError:
image = ""
if not image:
print("NO image input, exit...")
break
if not ori_inp:
print("exit...")
break
try:
# print(f"{roles[1]}: ", end="")
print("")
print(ori_image)
T1 = time.perf_counter()
if image is not None:
# first message
if model.config.mm_use_im_start_end:
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + ori_inp
else:
inp = DEFAULT_IMAGE_TOKEN + '\n' + ori_inp
conv.append_message(conv.roles[0], inp)
image = None
else:
# later messages
conv.append_message(conv.roles[0], ori_inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str, 'Yes', 'No']
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
# print('------------------------')
# print('input_ids:', input_ids)
# print('image_tensor:', image_tensor)
# print('args.temperature:', args.temperature)
# print('max_new_tokens:', args.max_new_tokens)
# print('streamer:', streamer)
# print('stopping_criteria:', [stopping_criteria])
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=True,
top_p = 1.0,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria])
# input('Debug...')
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
conv.messages[-1][-1] = outputs
T2 =time.perf_counter()
print('------------time in second------------', (T2 - T1))
total_time += T2-T1
# print(outputs)
if args.debug:
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
except Exception as e:
print('Error:',str(e))
print('Traceback:', traceback.format_exc()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Describe the issue
The time needed to load the model is very long, so I am trying to reuse the loaded model when running inference on a new image. But I encounter the issue below — is it possible to do this? How should I write the code?
Modified from llava/serve/cli.py
The code works well on first image input, but fails on the second image input.
Output:
The text was updated successfully, but these errors were encountered: