Using token sequences as stop criteria does not work in mlx_lm #524

Open
Y4hL opened this issue Mar 3, 2024 · 7 comments · May be fixed by #541

Comments

@Y4hL
Contributor

Y4hL commented Mar 3, 2024

The implementation of stop_criteria in mlx_lm.server is inherently flawed. Stop sequences only get matched when the newest tokens generated perfectly match a stop sequence; it does not stop if the stop sequence appears anywhere else within the generated tokens.

This only checks whether the tail of the generated tokens perfectly matches a stop sequence:

for stop_ids in stop_id_sequences:
    if len(tokens) >= len(stop_ids):
        # only the last len(stop_ids) tokens are compared, so earlier matches are missed
        if np.array_equal(tokens[-len(stop_ids) :], stop_ids):
            return StopCondition(stop_met=True, trim_length=len(stop_ids))

stopping_criteria only gets called once max_stop_id_sequence_len tokens have accumulated in the buffer, where max_stop_id_sequence_len is the length of the longest stop sequence.

if len(stop_sequence_buffer) > max_stop_id_sequence_len:
    if REPLACEMENT_CHAR in _tokenizer.decode(token):
        continue
    stop_condition = stopping_criteria(
        tokens,
        stop_id_sequences,
        eos_token_id,
    )

Example where it breaks:

I have two stop sequences, one of length 4 and one of length 6. Once 6 tokens have been generated, stop_criteria is called. However, the tokens I have generated only match the stop sequence of length 4, and the match happens at the start of the new tokens, not at the end. Since stop_criteria only checks the end of the full token list, the match is missed and generation does not stop.

stop sequence = [1, 2, 3, 4]

new tokens = [1, 2, 3, 4, 5, 6]

# stop_criteria gets called and checks the last len(stop_sequence) tokens of new tokens

[3, 4, 5, 6] != [1, 2, 3, 4]
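
To make the miss concrete, here is a minimal, self-contained sketch (a simplified stand-in for the server's check, not the server code itself) that reproduces the tail-only comparison and shows it skipping a stop sequence sitting at the start of the newly generated tokens:

import numpy as np

def tail_only_check(tokens, stop_id_sequences):
    # Stand-in for the server's stopping_criteria: only the last
    # len(stop_ids) tokens are compared against each stop sequence.
    for stop_ids in stop_id_sequences:
        if len(tokens) >= len(stop_ids) and np.array_equal(tokens[-len(stop_ids):], stop_ids):
            return True
    return False

stop_id_sequences = [[1, 2, 3, 4], [7, 8, 9, 10, 11, 12]]  # lengths 4 and 6
tokens = [1, 2, 3, 4, 5, 6]  # the length-4 stop sequence is present, but not at the tail

print(tail_only_check(tokens, stop_id_sequences))  # False: the match is missed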
@Y4hL
Contributor Author

Y4hL commented Mar 4, 2024

Additionally, the condition for checking the stop criteria is:

if len(stop_sequence_buffer) > max_stop_id_sequence_len:

It is not inclusive, so an extra token is always appended before the check runs, meaning no stop sequence can ever be matched.

@mzbac
Contributor

mzbac commented Mar 4, 2024

Additionally, the condition for checking the stop criteria is:

if len(stop_sequence_buffer) > max_stop_id_sequence_len:

It is not inclusive, so an extra token is always appended before the check runs, meaning no stop sequence can ever be matched.

Just curious, how did you find the issue? I ran a few tests before and didn't see any extra tokens added.

@Y4hL
Contributor Author

Y4hL commented Mar 4, 2024

I'll try to sum it up:

max_stop_id_sequence_len is the length of the longest stop id sequence. Now let's assume the buffer has grown to the same size as the longest stop sequence.

if len(stop_sequence_buffer) > max_stop_id_sequence_len:

Since this check is non-inclusive, the loop runs one extra time before the body of the if statement executes, so by then the buffer is one token longer than the longest stop sequence.

Now, taking into account that the stop_criteria function only checks for a perfect match between the tail of tokens and a stop sequence, it can no longer ever match, because the extra token was generated before stop_criteria is called.
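
For concreteness, here is a small sketch of that off-by-one, assuming the model emits the longest stop sequence right away (variable names are borrowed from the excerpts above; this is an illustration of the bookkeeping, not the server loop):

stop_ids = [1, 2, 3, 4]
max_stop_id_sequence_len = len(stop_ids)

tokens = []
stop_sequence_buffer = []

for token in [1, 2, 3, 4, 5]:  # the model emits the stop sequence, then one more token
    tokens.append(token)
    stop_sequence_buffer.append(token)
    if len(stop_sequence_buffer) > max_stop_id_sequence_len:  # non-inclusive check
        tail = tokens[-len(stop_ids):]
        print(tail, tail == stop_ids)  # [2, 3, 4, 5] False: the extra token spoils the match

With an inclusive >= the check would already run at [1, 2, 3, 4] and the stop sequence would be matched.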

@mzbac
Contributor

mzbac commented Mar 4, 2024

I see, it would only happen when the model starts with the stop word, which is probably why my testing didn't pick it up. This edge case is difficult to catch with manual testing, but it becomes obvious when the model begins its output with a stop word.

@Y4hL
Contributor Author

Y4hL commented Mar 4, 2024

Yeah, I think the check was intended to run on every generated token, and we can throw away the buffer entirely. That would address most of the issues.

It still needs a special case for streaming, since we need to anticipate a stop word, or we might stream parts of it to the client. I have a prototype for that check, but I'm still testing it.
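
One possible shape for such an anticipation check (a sketch only, not Y4hL's prototype or the linked PR; the sequence_overlap helper below is hypothetical) is to hold streamed text back whenever a suffix of the generated tokens is a prefix of some stop sequence:

def sequence_overlap(s1, s2):
    # Hypothetical helper: True if some suffix of s1 is a prefix of s2,
    # i.e. s1 might currently be partway through producing the stop sequence s2.
    max_overlap = min(len(s1), len(s2))
    return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))

stop_ids = [1, 2, 3, 4]
# While streaming, buffer tokens instead of sending them to the client
# whenever they could still turn into a stop sequence.
print(sequence_overlap([7, 8, 1, 2], stop_ids))  # True  -> hold back, do not stream yet
print(sequence_overlap([7, 8, 9, 5], stop_ids))  # False -> safe to stream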

@mzbac
Contributor

mzbac commented Mar 4, 2024

Yeah, you are right. The original implementation didn't have a buffer, so it ended up sending the stop word back to the client when streaming. The buffer was introduced to solve that issue, but its implementation wasn't well thought out.

@Y4hL Y4hL linked a pull request Mar 6, 2024 that will close this issue
@s-smits

s-smits commented Apr 21, 2024

I probably have kind of the same problem. The generate function outputs a single piece of text per token; here is some pseudocode for the problem:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

temperature = 0.2
prompt = "Extract in JSON format: "
generated_tokens = []

# model and max_tokens are assumed to be loaded/defined elsewhere
for token in generate(model, tokenizer, prompt, temperature, max_tokens, verbose=True):
    generated_tokens.append(tokenizer.encode(token))

text = "Here is the extracted text"

# per-piece ids collected by the loop above for this text:
generated_tokens = [[220], [39], [68], [81], [68], [220], [72], [82], [220], [83], [71], [68], [220], [68], [87], [83], [81], [64], [66], [83], [68], [67], [220], [83], [68], [87], [83]]

# ids from encoding the whole text at once:
how_it_should_be = tokenizer.encode(text)

output = [8586, 374, 279, 28532, 1495]

Is there an easier way / an example of how to retrieve the produced token ids per generated token, such as 8586 or 374, and use those as stop criteria? Ideally, multiple tokens as a stop criterion would be best, so that synthetic JSON generation can be stopped just as it finishes.
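
For what it's worth, the mismatch in the pseudocode above comes from re-encoding each streamed text piece on its own: a BPE tokenizer can only merge characters into full tokens when it sees the whole string at once. A small sketch of the effect (this assumes access to the gated Meta-Llama-3-8B repo; any BPE tokenizer behaves the same way):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

text = "Here is the extracted text"

# Encoding the whole string lets the tokenizer merge characters into full tokens.
whole = tokenizer.encode(text, add_special_tokens=False)

# Re-encoding each streamed piece separately (here: one character at a time)
# prevents those merges, so the ids no longer line up with what the model produced.
per_piece = [tokenizer.encode(piece, add_special_tokens=False) for piece in text]

print(whole)      # a few multi-character token ids
print(per_piece)  # mostly single-character ids, one list per piece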
