
Getting attention weights for generated text from llama3-8b-instruct #155

Open
bear96 opened this issue Apr 26, 2024 · 6 comments

bear96 commented Apr 26, 2024

Hello,

I'm trying to visualize the attention weights for Llama 3 when it generates text, but I am facing some complications. I slightly modified the Attention class to also output the scores variable (which I am guessing holds the attention weights, since it is multiplied with the values to produce the attention output), and I save those values in the TransformerBlock class as an attribute. I also modified this step in the forward function, h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask), to

y, self.weights = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
h = x + y

where self.weights holds the attention weights.
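
For reference, a minimal sketch of what the tail of Attention.forward looks like with that change applied, based on the structure of model.py in this repo (only the end of the method is shown; exact variable names may differ slightly in your copy):

    # End of Attention.forward (sketch): also return the post-softmax scores.
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
    if mask is not None:
        scores = scores + mask
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)  # (bsz, n_heads, q_len, kv_len)
    output = torch.matmul(scores, values)
    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    return self.wo(output), scores  # second value is stored as self.weights in TransformerBlock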

Now, in generation.py, in the generate function (starting from line 175), I modify the for loop in this way:

attention_dict = dict()
for cur_pos in range(min_prompt_len, total_len):
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        # taking the last transformer block
        attention_dict[cur_pos] = self.model.layers[-1].weights.float().cpu().numpy()

        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)

        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        if logprobs:
            token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens[:, prev_pos + 1 : cur_pos + 1],
                reduction="none",
                ignore_index=pad_id,
            )
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
            torch.isin(next_token, stop_tokens)
        )
        prev_pos = cur_pos
        if all(eos_reached):
            break

The idea is to get the attention weights from the last transformer block for each step of token generation, so that I can go back to any generated token and see how the attention weights are distributed along the generated sequence length.
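
One detail worth keeping in mind with this loop (assuming the scores are grabbed after the softmax, as in the sketch above): because of the KV cache, every forward call after the first one only processes the newly generated token, so each saved tensor has a query length of 1 and a key length that grows with cur_pos. A quick, purely illustrative way to check what ended up in attention_dict:

# Illustrative only: inspect what was captured in attention_dict.
for pos in sorted(attention_dict):
    w = attention_dict[pos]
    # Expected roughly (batch, n_heads, q_len, kv_len); q_len is 1 after the first
    # iteration, and kv_len equals pos (the number of cached positions).
    print(pos, w.shape)

# Per-head attention of the token generated at step pos over everything before it:
pos = max(attention_dict)
per_head = attention_dict[pos][0, :, -1, :]  # (n_heads, kv_len)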

However, the problem I am facing is that there are 32 heads in Llama 3. If I average across all 32 heads to reduce dimensionality for visualization purposes and then apply a softmax to the output, the attention weights I get at every step of generation have the exact same distribution: the first few tokens and the last two tokens have much higher values, while every other token has the exact same, very minuscule weight. I have also tried max pooling across heads instead of averaging, but it yielded similar results.

This doesn't seem right, because the pattern also remains the same across different prompts, which suggests there is something wrong with my approach. Could you please guide me in the right direction?

Thanks!

Icamd commented Apr 28, 2024

Hi, I am also trying to output the attention weights of Llama 3. Have you tried getting the attention weights from the model itself? (for example, outputs = model.generate(tokens, output_attentions=True))
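
If you go the Transformers route, a sketch along these lines may help (illustrative only: the checkpoint name, attn_implementation="eager", and the generation arguments are assumptions on my part, not something from this repo). It returns the attentions in a structured form:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # or the NousResearch mirror
tokenizer = AutoTokenizer.from_pretrained(model_id)
# "eager" attention is needed to get weights back; SDPA/flash kernels don't return them.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=20,
    do_sample=False,
    output_attentions=True,
    return_dict_in_generate=True,
)

# out.attentions: one entry per generated token; each entry is a tuple over layers;
# each tensor is (batch, n_heads, q_len, kv_len).
print(len(out.attentions), len(out.attentions[0]), out.attentions[0][-1].shape)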

bear96 commented Apr 28, 2024

> Hi, I am also trying to output the attention weights of Llama 3. Have you tried getting the attention weights from the model itself? (for example, outputs = model.generate(tokens, output_attentions=True))

Hi @Icamd, I think you're using the Huggingface version? I have tried the same thing, but the attention weights I get have a strange shape. Usually, attention weights have the shape (batch_size, num_heads, seq_length, seq_length), but in the Huggingface Llama's case I get a mismatch along the batch_size axis. My guess is that since output_attentions is not actually a parameter in the model architecture shown in this repo, Huggingface does something internally to calculate the attention weights, and thus provides wrong values. I could be wrong, of course. I also get some warnings whenever I try this with Huggingface. That's why I am using the PyTorch version of this model instead.

Icamd commented Apr 29, 2024

> > Hi, I am also trying to output the attention weights of Llama 3. Have you tried getting the attention weights from the model itself? (for example, outputs = model.generate(tokens, output_attentions=True))

> Hi @Icamd, I think you're using the Huggingface version? I have tried the same thing, but the attention weights I get have a strange shape. Usually, attention weights have the shape (batch_size, num_heads, seq_length, seq_length), but in the Huggingface Llama's case I get a mismatch along the batch_size axis. My guess is that since output_attentions is not actually a parameter in the model architecture shown in this repo, Huggingface does something internally to calculate the attention weights, and thus provides wrong values. I could be wrong, of course. I also get some warnings whenever I try this with Huggingface. That's why I am using the PyTorch version of this model instead.

I think the Huggingface version's attention weights have the shape (output_token_number, layers, batch_size, heads, input_token_number, input_token_number), for example (50, 32, 1, 32, 251, 251), but I am not sure. I am still trying to visualize the attention between tokens to find out if there is any connection. However, I get strange model outputs using NousResearch/Meta-Llama-3-8B-Instruct in Google Colab :(
[screenshot of the unexpected model output]
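
For what it's worth, with return_dict_in_generate=True the attentions come back as nested Python tuples rather than one stacked tensor, so it may be easier to reason about the shapes step by step (a sketch, assuming out is the result of a generate call like the one sketched earlier in this thread):

# out.attentions is nested: (num_generated_tokens) x (num_layers) x tensor
for step, per_layer in enumerate(out.attentions):
    last_layer = per_layer[-1]  # (batch, n_heads, q_len, kv_len)
    # Step 0 attends over the whole prompt (q_len == prompt length); later steps
    # have q_len == 1 and kv_len == prompt length + step, because of the KV cache.
    print(step, last_layer.shape)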

Icamd commented Apr 29, 2024

@bear96 Hi! I found the paper "Analyzing the Structure of Attention in a Transformer Language Model", which mentions something called "null attention", i.e. attention focused on the first token. Maybe you can try masking out the first token's attention so it doesn't dominate the overall attention weights? (I'm not sure.)
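
A minimal sketch of that idea (purely illustrative: w is assumed to be one step's post-softmax weights as a NumPy array whose last axis is the key positions, like the values saved in attention_dict above):

import numpy as np

def mask_first_token(w, renormalize=True):
    """Zero out attention paid to position 0 and optionally renormalize each head."""
    w = w.copy()
    w[..., 0] = 0.0
    if renormalize:
        w = w / np.clip(w.sum(axis=-1, keepdims=True), 1e-12, None)
    return w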

bear96 commented Apr 29, 2024

> @bear96 Hi! I found the paper "Analyzing the Structure of Attention in a Transformer Language Model", which mentions something called "null attention", i.e. attention focused on the first token. Maybe you can try masking out the first token's attention so it doesn't dominate the overall attention weights? (I'm not sure.)

I'll definitely check that out! Thanks!

bear96 commented May 3, 2024

I believe I have solved the issue. I was taking an average across all 32 heads and then applying a softmax function to make the values look like probabilities, but the scores are already normalized by the softmax inside the attention layer, so the second softmax washed out the small differences between the attention weights and left an almost uniform distribution. I'm now visualizing the attention weights of individual heads instead. Due to the null attention cited by @Icamd, the first token has extremely high attention weights while the rest of the weights fall off roughly exponentially, so I take the log of the weights for better visualizations.

I am not sure why null attention occurs, however. If someone knows more about this, please let me know!
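
To illustrate both points with synthetic numbers (not the author's actual plotting code; matplotlib and the exaggerated "null attention" on position 0 are assumptions for the demo): re-applying softmax to weights that are already probabilities collapses the contrast between positions, while a per-head log-scale heatmap keeps the small weights visible.

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)

# Synthetic post-softmax attention for one generation step: (n_heads, kv_len).
n_heads, kv_len = 32, 40
w = rng.exponential(scale=1.0, size=(n_heads, kv_len))
w[:, 0] += 50.0                          # exaggerated "null attention" on token 0
w = w / w.sum(axis=-1, keepdims=True)    # rows are already probability distributions

# 1) Head-average followed by another softmax: the spread between the largest and
#    smallest weights collapses to under 2x, which looks nearly uniform when plotted.
avg = w.mean(axis=0)
double_softmax = np.exp(avg) / np.exp(avg).sum()
print("spread of head average:   ", avg.max() / avg.min())
print("spread after extra softmax:", double_softmax.max() / double_softmax.min())

# 2) Per-head heatmap on a log scale keeps the small differences visible.
plt.imshow(np.log(w + 1e-9), aspect="auto", cmap="viridis")
plt.xlabel("key position")
plt.ylabel("head")
plt.colorbar(label="log attention weight")
plt.show()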
