JIT compile option for binary minimization #1091

Merged — 17 commits from more_jit_compile merged into main on May 22, 2024
Conversation

@awni (Member) commented May 8, 2024

  • Adds a build flag MLX_METAL_JIT to reduce the Metal library size by using runtime compilation.
  • Big refactor of the unary, binary, ternary, copy, scatter, and gather kernels to allow JIT compilation.
  • Current Metal library size: 15 MB (mlx.metallib).

[[kernel]] void {0}_v(
    device const {1}* in,
    device {2}* out,
awni (Member Author):
I put the kernels here (binary kernels in the binary.cpp, and so on). Maybe better to put them in a different file in kernels/ that gets included? Not sure if either of you have a preference there @angeloskath @jagrit06

Member:
I kinda like it like this. If we do want to do this for more complicated kernels, maybe we need a solution like the preamble, but for unary, binary, ternary this is pretty great IMHO.

awni (Member Author):

I think it will likely be a combination of the two. Anything that needs to be formatted at runtime will likely be like this, and we can try to keep that bit to a minimum by having it be essentially just the instantiations.

The other stuff I will probably put in preambles. But I don't think it makes sense to do it as one giant preamble, since that won't scale, so I will likely change the preamble/include machinery to be a little more modular.

@awni (Member Author) commented May 9, 2024

Benchmarks:

No degradation in token generation:

python -m mlx_lm.generate --model mlx-community/NeuralBeagle14-7B-4bit-mlx --prompt "Write a story about Albert Einstein" --temp 0.0 --max-tokens 256
Pre:
Prompt: 219.423 tokens-per-sec
Generation: 107.316 tokens-per-sec

Post:
Prompt: 219.580 tokens-per-sec
Generation: 107.562 tokens-per-sec

Transformer training:

Pre: Iter 30: Train loss 7.943, It/sec 5.911, Peak memory 5.534 (GB)
Post: Iter 30: Train loss 7.923, It/sec 5.912, Peak memory 5.534 (GB)

LeNet training:

Pre: Test accuracy 0.982, Time 2.792 (s)
Post: Test accuracy 0.983, Time 2.798 (s)

MNIST:

Pre: Test accuracy 0.937, Time 0.639 (s)
Post: Test accuracy 0.929, Time 0.638 (s)

@awni awni force-pushed the more_jit_compile branch 2 times, most recently from 61d2f8b to 1ead2a5 Compare May 15, 2024 16:27
auto& d = metal::device(s.device);

std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, out);
awni (Member Author):

@angeloskath @jagrit06 this is where the function gets linked in differently depending on the compile flag MLX_METAL_JIT. You can see the different definitions in jit_kernels.cpp and nojit_kernels.cpp.


MTL::ComputePipelineState* kernel;

if constexpr (mlx_metal_jit()) {
awni (Member Author):

@jagrit06 @angeloskath this is an example using the constexpr to figure out whether we are JITing or not. It's not as messy as I thought it would be, provided the right helper utilities.

#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

#ifndef MLX_METAL_JIT
awni (Member Author):

This is a minor downside of going the constexpr route, but it's the only place we really need to use the preprocessor.

@awni awni force-pushed the more_jit_compile branch 2 times, most recently from bd04991 to f53af4a Compare May 16, 2024 00:53
@awni awni marked this pull request as ready for review May 16, 2024 03:01
@awni awni requested a review from jagrit06 May 16, 2024 03:01
@awni (Member Author) commented May 16, 2024

@jagrit06 @angeloskath I think this is ready for review.

@awni awni changed the title [WIP] More JIT compile for binary minimization JIT compile option for binary minimization May 17, 2024
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=ON -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
Member:

Curious about MinSizeRel (which I believe favors size over code speed) vs. Release in terms of size. Also, have you tried strip?
https://stackoverflow.com/questions/38675403/how-to-config-cmake-for-strip-file/38676023
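One common way to wire strip into CMake (adapted from the linked answer; the target name `mlx` here is an assumption) is a post-build step:

```cmake
# Hypothetical post-build step; assumes the library target is named `mlx`.
# ${CMAKE_STRIP} is CMake's discovered strip tool; -x drops local symbols.
add_custom_command(
  TARGET mlx POST_BUILD
  COMMAND ${CMAKE_STRIP} -x $<TARGET_FILE:mlx>
  COMMENT "Stripping local symbols from the built library")
```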

awni (Member Author):

It doesn't make a difference for what I tried for the GPU back-end. It might for the CPU; I haven't checked. But it's really meant as an option for deploying mostly on the GPU (when you want a really small CPU binary).

I haven't checked strip.. let me try and see if it reduces the size any more.

@awni (Member Author) commented May 22, 2024

For a review, the main thing to look at is:

  • Updated way the compiled includes are made: metal/CMakeLists.txt, metal/jit/includes.h, and metal/make_compiled_preamble.sh
  • The way primitives get or build kernels: metal/kernels.h and the corresponding implementations in metal/jit_kernels.cpp and metal/nojit_kernels.cpp
  • Look at an example for how that works (mostly reorganizing code): e.g. metal/kernels/unary.h, metal/kernels/unary.metal, metal/unary.cpp and the format template in metal/jit/unary.h. They all follow the same pattern.

@angeloskath (Member) left a comment:

It looks great!

In MLX tradition, one can take the simple route and not provide a JIT option for a new kernel, which means no change to the usage at all. That means that even for us, we can add kernels as we're used to and add a JIT option later.

kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)
Member:

Gather and scatter are always JITted, I guess. Honestly, it seems fitting. Do you foresee speedups as well for these going down this route?

awni (Member Author):

Yes.. I kept them in the JIT for a few reasons:

  • The kernels are pretty simple.
  • They take up a disproportionate amount of space for kernels that are mostly never used (e.g. nidx > 2); I think around 30 MB of the Metal library.
  • With the JIT we don't need to worry about a hard ceiling on the number of arguments (up to the buffer limit), which is kind of nice, albeit perhaps not that useful.
  • There is a negligible change to cold-start time from JITing these.
  • I didn't love the way we built these kernels in the preprocessor; it was kind of hard to follow and I didn't feel like adding it back ;)

Do you foresee speedups as well for these going down this route?

If I understand your question correctly - do we plan to improve these kernels and will JITing them make it hard to do? I guess maybe.. if we get to that point and having them in the JIT is too annoying we can always revisit. (It's also pretty easy to add an instantiation which does not get included in the Metal library but is useful just for compile-time compilation.)

Member:

If I understand your question correctly

No I meant the exact opposite. These kernels feel very natural to be jitted. We could even imagine taking advantage of the fact that we are jitting to make on the fly kernels specific to index shapes for instance.

awni (Member Author):

Oh yes, very good point! In fact I had considered doing specializations in the past for different index layouts, but it was too unwieldy from a combinatorial perspective.

@@ -5,9 +5,11 @@
#include <metal_atomic>
#include <metal_simdgroup>

#ifndef MLX_METAL_JIT
Member:

That's a bit annoying. I think we'll waste a fair amount of time forgetting to do that. I don't have anything better to propose... just commenting.

awni (Member Author), May 22, 2024:

Yea it's annoying. Getting the includes right is annoying in general.. perhaps the most annoying part of all of this. There is probably a way to avoid this but it requires some care in what you include where and in what order (which is quite brittle).

awni (Member Author):

I was able to remove a few of these and I think I can remove the rest in #1132.

@awni awni merged commit 226748b into main May 22, 2024
3 checks passed
@awni awni deleted the more_jit_compile branch May 22, 2024 19:57
jkaercher pushed a commit to jkaercher/mlx that referenced this pull request May 30, 2024
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
3 participants