Why did Gemma7B perform poorly #365

Open
JuJoker opened this issue Apr 15, 2024 · 2 comments
Labels
component:demos (Update demos) · status:triaged (Issue/PR triaged to the corresponding sub-team) · type:bug (Something isn't working)

Comments


JuJoker commented Apr 15, 2024

Description of the bug:

I ran the Gemma-7B model based on the code in the example, and found that the model's answers were rather poor; it didn't seem to understand my question at all. Is this normal? My device is an Nvidia 4090 GPU. My code is as follows:

import os
import sys
import torch
sys.path.append('gemma_pytorch')

from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

VARIANT = '7b'
MACHINE_TYPE = 'cuda'

# Set up model config.
model_config = get_config_for_7b()

# Ensure that the tokenizer is present
tokenizer_path = os.path.join('./gemma-7b', 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join('./gemma-7b', 'gemma-7b.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

model_config.tokenizer = tokenizer_path

# VARIANT is '7b', so 'quant' is not in it and the unquantized weights are expected.
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'
MODEL_CHAT_START = '<start_of_turn>model\n'
MODEL_CHAT_END = '<end_of_turn>\n'

MULTI_CHAT = ''

# while True:
#     input_text = input("User input(press q to exit):")
#     if input_text != 'q':
#         user_prompt = USER_CHAT_TEMPLATE.format(prompt=input_text)
#         MULTI_CHAT = MULTI_CHAT + user_prompt + MODEL_CHAT_START
#         model_response = model.generate(
#             user_prompt,
#             device=device,
#             output_len=64,
#         )
#         print(f'Model reply: {model_response}')
#         MULTI_CHAT = MULTI_CHAT + model_response + MODEL_CHAT_END
#     else:
#         break

prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
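# For reference, the print above should show the assembled multi-turn prompt
# in the Gemma chat-turn format, i.e. something like:
#
# <start_of_turn>user
# What is a good place for travel in the US?<end_of_turn>
# <start_of_turn>model
# California.<end_of_turn>
# <start_of_turn>user
# What can I do in California?<end_of_turn>
# <start_of_turn>model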

res = model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)
print(res)

The resulting response is shown in the attached screenshot.

Actual vs expected behavior:

However, when the same code is run on Kaggle, the result looks correct (see the attached screenshot).
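One quick way to narrow down what differs between the local machine and Kaggle is to print the runtime versions and the detected GPU in both environments and compare them; a minimal sketch using only standard torch calls:

import torch

# Run this both locally and on Kaggle and compare the output.
print(torch.__version__)              # PyTorch version
print(torch.version.cuda)             # CUDA version PyTorch was built against
print(torch.cuda.get_device_name(0))  # e.g. the local 4090 vs. Kaggle's GPU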

Any other information you'd like to share?

No response

JuJoker added the type:bug label on Apr 15, 2024
singhniraj08 added the status:triaged and component:demos labels on Apr 18, 2024
cog-master commented Apr 18, 2024

They not only perform poorly but are also slow. Since you have an Nvidia 4090 GPU, the model does not seem to be working properly.
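A related sanity check (a minimal sketch reusing the model and device names from the snippet above) is to confirm that the weights actually ended up on the GPU and in the dtype the config requests:

import torch

# CUDA should be available and selected on a 4090 machine.
print(torch.cuda.is_available())  # expected: True
print(device)                     # expected: cuda

# Parameters should sit on the GPU and use the config's dtype
# (bfloat16 for the unquantized 7B config, assuming the defaults are unchanged).
param = next(model.parameters())
print(param.device, param.dtype)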

MarkDaoust (Member) commented:

@gustheman
