
Attempt at OpenElm #6986

Closed
joshcarp wants to merge 14 commits

Conversation

joshcarp

Currently failing on line 821 of sgemm.cpp; some parsing of the ffn/attention-head info still needs to happen. Some values are currently hard-coded.

Fixes: #6868

Raising this PR as a draft because I need help. Will be adding comments to the original source for reference purposes.
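
For the ffn/attention-head parsing, a rough standalone sketch of pulling the per-layer values out of the reference config (this is not the PR's code; the field names num_query_heads, num_kv_heads, ffn_multipliers and head_dim are assumptions about the OpenELM config.json):

import json

# Sketch only: field names are assumptions about the OpenELM config.json.
# OpenELM varies head counts and FFN width per layer, so they can't be single
# scalar hyperparameters the way they are for most llama-style models.
with open("config.json") as f:
    cfg = json.load(f)

head_dim = cfg["head_dim"]
for i, (n_q, n_kv, ffn_mult) in enumerate(
        zip(cfg["num_query_heads"], cfg["num_kv_heads"], cfg["ffn_multipliers"])):
    print(f"layer {i:2d}: n_head={n_q:2d} n_head_kv={n_kv} "
          f"head_dim={head_dim} ffn_multiplier={ffn_mult}")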

joshcarp mentioned this pull request Apr 29, 2024
joshcarp (Author)

Okay, I think it might have something to do with how I'm calculating the offsets from kqv.

github-actions bot (Contributor) commented Apr 29, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 415 iterations 🚀

Expand details for performance-related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=11355.74ms p(95)=31249.96ms fails=, finish reason: stop=361 truncated=54
  • Prompt processing (pp): avg=130.28tk/s p(95)=561.86tk/s
  • Token generation (tg): avg=27.12tk/s p(95)=34.82tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=openelm-impl commit=8d2dead6819589281227694a6d75b8c1dc825936

[Benchmark charts: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, requests_processing (llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 415 iterations)]

llama.cpp Outdated
// So because our original wo matrix wasn't 3x, the below function fails because there aren't enough elems in it.
// Got: [head_dim][n_tokens][n_head_v]
// Want: [n_embd_v_gqa(384)][n_tokens]
// I guess this means that I need to be able to repeat them
joshcarp (Author)

This is the problem: these tensors need to be repeated like they are in the Python part on line 10806.

joshcarp (Author)

if self.num_groups != 1:
    # GQA
    # [B, k_h, S, h] --> [B, q_h, S, h] // so, k=3 -> q=12
    keys = keys.repeat_interleave(self.num_groups, dim=1)
    # [B, v_h, S, h] --> [B, q_h, S, h] // so, v=3 -> q=12

joshcarp (Author)

trying to do something similar to this
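
For reference, a minimal standalone PyTorch sketch (not from the PR) of what that repeat_interleave does, using the head counts of this layer (3 K/V heads expanded to match 12 query heads, head_dim 64):

import torch

B, S, h = 1, 3, 64               # batch, tokens, head dim (matches the debug trace below)
n_head, n_head_kv = 12, 3        # query heads vs. key/value heads
num_groups = n_head // n_head_kv # 4

keys = torch.randn(B, n_head_kv, S, h)     # [B, k_h, S, h] = [1, 3, 3, 64]
values = torch.randn(B, n_head_kv, S, h)   # [B, v_h, S, h] = [1, 3, 3, 64]

# Duplicate each K/V head num_groups times so they pair 1:1 with the 12 query heads.
keys = keys.repeat_interleave(num_groups, dim=1)       # [1, 12, 3, 64]
values = values.repeat_interleave(num_groups, dim=1)   # [1, 12, 3, 64]

assert keys.shape == values.shape == (B, n_head, S, h)

Note that repeat_interleave repeats each head in place (k0, k0, k0, k0, k1, ...), whereas a plain tiling repeat produces (k0, k1, k2, k0, k1, k2, ...); worth keeping in mind when reproducing it with a tiling-style repeat op.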

joshcarp (Author)

I think this is solved now

LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);

cur = llm_build_ffn(ctx0, cur,
joshcarp (Author)

Now I think I just need to implement LLM_FFN_SWIGLU, which is slightly different from the existing activation functions:
[image]
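
For reference, a minimal PyTorch sketch of the generic SwiGLU block, down(SiLU(gate(x)) * up(x)). This is the standard formulation rather than the llama.cpp code, and the n_ff value below is purely illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    # down( silu(gate(x)) * up(x) ): the up projection is gated elementwise
    # by the SiLU-activated gate projection, then projected back to n_embd.
    def __init__(self, n_embd: int, n_ff: int):
        super().__init__()
        self.gate = nn.Linear(n_embd, n_ff, bias=False)
        self.up = nn.Linear(n_embd, n_ff, bias=False)
        self.down = nn.Linear(n_ff, n_embd, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

x = torch.randn(3, 1280)           # 3 tokens, n_embd=1280 as in the trace below
print(SwiGLU(1280, 768)(x).shape)  # n_ff=768 is illustrative; prints torch.Size([3, 1280])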

joshcarp (Author)

Okay, so this builds purely because the tensor sizes all match up, which doesn't necessarily mean it's the correct implementation, as shown by the fact that there's still a SIGSEGV somewhere.

joshcarp (Author)

Debugging this because some of the transformations are obviously wrong:

ggml_debug:                 inp_embd = (f32)   GET_ROWS(token_embd.weight{1280, 32000, 1, 1}, inp_tokens{3, 1, 1, 1}}) = {1280, 3, 1, 1}
                                     sum = 138204.000000
ggml_debug:                   norm-0 = (f32)   RMS_NORM(inp_embd{1280, 3, 1, 1}, }) = {1280, 3, 1, 1}
                                     sum = 138204.015625
ggml_debug:              attn_norm-0 = (f32)        MUL(norm-0{1280, 3, 1, 1}, blk.0.attn_norm.weight{1280, 1, 1, 1}}) = {1280, 3, 1, 1}
                                     sum = 138204.000000
ggml_debug:                   wqkv-0 = (f32)    MUL_MAT(blk.0.attn_qkv.weight{1280, 1152, 1, 1}, attn_norm-0{1280, 3, 1, 1}}) = {1152, 3, 1, 1}
                                     sum = 124380.968750
ggml_debug:            wqkv-0 (view) = (f32)       VIEW(wqkv-0{1152, 3, 1, 1}, }) = {64, 3, 12, 1}
                                     sum = 8722732.000000
ggml_debug:     wqkv-0 (view) (cont) = (f32)       CONT(wqkv-0 (view){64, 3, 12, 1}, }) = {64, 3, 12, 1}
                                     sum = 497453.906250
ggml_debug:                   norm-0 = (f32)   RMS_NORM(wqkv-0 (view) (cont){64, 3, 12, 1}, }) = {64, 3, 12, 1}
                                     sum = 497532.062500
ggml_debug:                   Qcur-0 = (f32)        MUL(norm-0{64, 3, 12, 1}, blk.0.attn_q_norm.weight{64, 1, 1, 1}}) = {64, 3, 12, 1}
                                     sum = 497535.562500
ggml_debug:        Qcur-0 (reshaped) = (f32)    RESHAPE(Qcur-0{64, 3, 12, 1}, }) = {64, 12, 3, 1}
                                     sum = 497535.593750
ggml_debug:                   Qcur-0 = (f32)       ROPE(Qcur-0 (reshaped){64, 12, 3, 1}, inp_pos{3, 1, 1, 1}}) = {64, 12, 3, 1}
                                     sum = 497535.593750
ggml_debug:                   Qcur-0 = (f32)      SCALE(Qcur-0{64, 12, 3, 1}, }) = {64, 12, 3, 1}
                                     sum = 497458.562500
ggml_debug:            wqkv-0 (view) = (f32)       VIEW(wqkv-0{1152, 3, 1, 1}, }) = {64, 3, 3, 1}
                                     sum = 1002128.312500
ggml_debug:     wqkv-0 (view) (cont) = (f32)       CONT(wqkv-0 (view){64, 3, 3, 1}, }) = {64, 3, 3, 1}
                                     sum = 62096.066406
ggml_debug:                   norm-0 = (f32)   RMS_NORM(wqkv-0 (view) (cont){64, 3, 3, 1}, }) = {64, 3, 3, 1}
                                     sum = 62041.996094
ggml_debug:                   Kcur-0 = (f32)        MUL(norm-0{64, 3, 3, 1}, blk.0.attn_k_norm.weight{64, 1, 1, 1}}) = {64, 3, 3, 1}
                                     sum = 62035.175781
ggml_debug:        Kcur-0 (reshaped) = (f32)    RESHAPE(Kcur-0{64, 3, 3, 1}, }) = {64, 3, 3, 1}
                                     sum = 62035.175781
ggml_debug:                  node_16 = (f32)       ROPE(Kcur-0 (reshaped){64, 3, 3, 1}, inp_pos{3, 1, 1, 1}}) = {64, 3, 3, 1}
                                     sum = 62035.175781
ggml_debug:                  node_17 = (f32)     REPEAT(node_16{64, 3, 3, 1}, }) = {64, 3, 12, 1}
                                     sum = 497318.656250
ggml_debug:            wqkv-0 (view) = (f32)       VIEW(wqkv-0{1152, 3, 1, 1}, }) = {64, 3, 3, 1}
                                     sum = 1002132.000000
ggml_debug:                   Vcur-0 = (f32)       CONT(wqkv-0 (view){64, 3, 3, 1}, }) = {64, 3, 3, 1}
                                     sum = 62100.000000
ggml_debug:                  node_20 = (f32)     REPEAT(Vcur-0{64, 3, 3, 1}, }) = {64, 3, 12, 1}
                                     sum = 497448.000000
ggml_debug:                      v-0 = (f16)       VIEW(cache_v_l0{655360, 1, 1, 1}, }) = {32, 64, 6, 1}
                                     sum = nan
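
The shapes above come from the fused qkv projection: 1152 = 64 * (12 + 3 + 3). A standalone sketch (not the PR's code; the exact split order in the reference implementation may differ) of how those offsets carve the fused tensor into Q/K/V:

import torch

n_tokens, head_dim = 3, 64
n_head, n_head_k, n_head_v = 12, 3, 3            # 64 * (12 + 3 + 3) = 1152

qkv = torch.randn(n_tokens, head_dim * (n_head + n_head_k + n_head_v))   # [3, 1152]

q_end = head_dim * n_head                        # 768
k_end = q_end + head_dim * n_head_k              # 960

# Slice at the q/k/v boundaries, then split each slice into per-head vectors.
q = qkv[:, :q_end].view(n_tokens, n_head, head_dim)         # [3, 12, 64]
k = qkv[:, q_end:k_end].view(n_tokens, n_head_k, head_dim)  # [3, 3, 64]
v = qkv[:, k_end:].view(n_tokens, n_head_v, head_dim)       # [3, 3, 64]
print(q.shape, k.shape, v.shape)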

uhlhosting

How is progress on this?

hunt-47 commented May 3, 2024

Man, we are all waiting for you...!!
@joshcarp nice work, by the way.

joshcarp (Author) commented May 3, 2024

Still slaving away at debugging.
Like I said before, this is my first time in this codebase, so things are a bit wack.
I'm pretty certain all of the parameter and tensor parsing/shaping is correct; I'm comparing against the reference implementation through corenet + Hugging Face.
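
For anyone following along, one way to get comparable numbers out of the Hugging Face side is to hook every module and print shape/sum, analogous to the ggml_debug output above. This is a rough sketch; the checkpoint/tokenizer names and the trust_remote_code requirement are assumptions about the reference setup:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumptions: checkpoint/tokenizer ids (OpenELM's model card points at the
# Llama tokenizer) and that the custom modeling code needs trust_remote_code=True.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("apple/OpenELM-270M",
                                             trust_remote_code=True).eval()

def dump(name):
    def hook(_module, _inputs, output):
        t = output[0] if isinstance(output, tuple) else output
        if torch.is_tensor(t):
            # Print shape and sum, the same quantities the ggml_debug lines show.
            print(f"{name}: shape={tuple(t.shape)} sum={t.float().sum().item():.3f}")
    return hook

for name, module in model.named_modules():
    if name:
        module.register_forward_hook(dump(name))

with torch.no_grad():
    model(**tok("Hello world", return_tensors="pt"))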

joshcarp (Author) commented May 7, 2024

Closing because I'm not making any meaningful progress on this.

joshcarp closed this May 7, 2024
icecream95 mentioned this pull request May 18, 2024
Development

Successfully merging this pull request may close these issues.

Support for OpenELM of Apple