LLM evaluator setting any repetitionPenalty crashes the program #71

Closed
shawiz opened this issue May 12, 2024 · 3 comments

Comments

@shawiz

shawiz commented May 12, 2024

I'm adding a repetitionPenalty to the GenerateParameters constructor. Regardless of the value I set (I tried 0.5, 1, 1.2, and 1.5), the program crashes immediately when the evaluator runs. I was testing various Qwen1.5 models. The error message I got is:

-[MTLDebugComputeCommandEncoder dispatchThreads:threadsPerThreadgroup:]:1441: failed assertion '(threadsPerGrid.width(0) * threadsPerGrid.y(1) * threadsPerGrid.depth(0))(0) must not be 0.'

@davidkoski
Collaborator

I tried this on mlx-community/Mistral-7B-v0.1-hf-4bit-mlx without error. The Qwen1.5 models are not loading for me because of #53 -- I will try it again once that is resolved.

@davidkoski
Collaborator

OK, I can reproduce it. I won't check it in yet because of the quantization issues in #53 -- it will be included there.

If you want to try it locally, here is the change:

diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift
index 6c85558..89212a7 100644
--- a/Libraries/LLM/Evaluate.swift
+++ b/Libraries/LLM/Evaluate.swift
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
 ) -> MLXArray {
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]
 
         selectedLogits = MLX.where(
             selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
             if prompt.shape[0] <= parameters.repetitionContextSize {
                 self.repetitionContext = prompt
             } else {
-                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+                self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
             }
         } else {
             self.repetitionContext = []
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and check repetitionPenalty context see if need to remove the first token
         if parameters.repetitionContextSize > 1 {
-            repetitionContext = concatenated([repetitionContext, y], axis: 0)
             if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                repetitionContext = repetitionContext[1...]
+                repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
             }
         }

I just switched it to use full array indexing and made it conform to the Python code. I don't know whether this is a bug in the MLX core code or a bug in the calling code -- certainly the calling code requires some changes, and I don't think it is logically the same.
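
For anyone following along, here is a minimal plain-Swift sketch of what the repetition penalty step is meant to do, going by the `MLX.where(...)` expression and the code comment in the diff. It is not the library's implementation: `[Float]` and `[Int]` stand in for `MLXArray`, and the names `penalizedLogits` and `updatedContext` are hypothetical, chosen only for illustration.

```swift
// Plain-Swift sketch with no MLX dependency.
// Assumption: [Float] / [Int] stand in for MLXArray; both function
// names are hypothetical and do not exist in the library.

/// Rescales the logits of tokens that already appear in `context`.
/// Negative logits are multiplied by `penalty` and non-negative ones are
/// divided by it, mirroring the `MLX.where(...)` expression in the diff.
func penalizedLogits(_ logits: [Float], context: [Int], penalty: Float) -> [Float] {
    // An empty context means there is nothing to penalize.
    guard !context.isEmpty else { return logits }
    var result = logits
    for token in context where logits.indices.contains(token) {
        // Always read from the original logits so that a token repeated
        // in the context is penalized once, not once per occurrence.
        let value = logits[token]
        result[token] = value < 0 ? value * penalty : value / penalty
    }
    return result
}

/// Appends `token` and keeps only the last `size` entries, as described by
/// the "append the current token to the context" comment in the diff.
func updatedContext(_ context: [Int], appending token: Int, size: Int) -> [Int] {
    Array((context + [token]).suffix(size))
}

let logits: [Float] = [2.0, -1.0, 0.5, 3.0]
let adjusted = penalizedLogits(logits, context: [0, 1], penalty: 2.0)
// adjusted == [1.0, -2.0, 0.5, 3.0]

let window = updatedContext([1, 2, 3], appending: 4, size: 3)
// window == [2, 3, 4]
```

With a penalty above 1, tokens already in the context become less likely whatever the sign of their logit. The guard on an empty context plays the same role as the `repetitionContext.shape[0] > 0` check in the diff.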

davidkoski added a commit that referenced this issue May 20, 2024
- fix for #53 #71 #69 #74
- in order to test the models
	- I added a default prompt of an appropriate form
	- while working on the model configuration also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
davidkoski added a commit that referenced this issue May 28, 2024
* handle partially quantized models

- fix for #53 #71 #69 #74
- in order to test the models
	- I added a default prompt of an appropriate form
	- while working on the model configuration also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
@davidkoski
Collaborator

#76 should fix this.
