
Feature complete Metal FFT #1102

Merged
merged 7 commits into from
Jun 6, 2024

Conversation

barronalex
Collaborator

@barronalex barronalex commented May 10, 2024

Proposed changes

A feature complete GPU FFT implementation in Metal.

Supports

  • All n < 2^20
  • Real and Inverse FFTs: fft, ifft, rfft, irfft
  • ND FFTs: fft2, ifft2, rfft2, irfft2, fftn, ifftn, rfftn, irfftn

Algorithms

  • A mixed-radix out-of-place Stockham FFT for n where all prime factors p satisfy 2 <= p <= 13.
  • Rader's Algorithm for n with one prime factor p > 13, where the size p-1 FFT can itself be computed via Stockham.
  • Bluestein's Algorithm for all other n.
  • Four Step FFT for n > 4096 when the FFT can no longer be done purely in GPU shared memory.
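For intuition, Rader's trick (second bullet) maps a prime-length DFT onto a cyclic convolution of length p-1. The sketch below is a hypothetical NumPy reference (`rader_fft` is not part of this PR, which implements the idea in Metal), handling only prime n:

```python
import numpy as np

def rader_fft(x):
    # DFT of prime length p via a length p-1 cyclic convolution (Rader, 1968).
    # Illustrative reference only -- not the PR's Metal implementation.
    p = len(x)
    # Find a primitive root (generator) of the multiplicative group mod p
    # by brute force; fine for small p.
    for g in range(2, p):
        if len({pow(g, e, p) for e in range(p - 1)}) == p - 1:
            break
    q = [pow(g, e, p) for e in range(p - 1)]        # g^e mod p
    qinv = [pow(g, -e, p) for e in range(p - 1)]    # g^(-e) mod p
    a = x[q]                                        # input permuted by the generator
    b = np.exp(-2j * np.pi * np.array(qinv) / p)    # permuted twiddle kernel
    # Cyclic convolution via FFTs of length p-1 (in MLX, done with Stockham).
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    X = np.empty(p, dtype=complex)
    X[0] = x.sum()            # DC bin is just the sum
    X[qinv] = x[0] + conv     # remaining bins come out generator-permuted
    return X
```

The inner length p-1 transforms are why the PR requires p-1 to factor into radices 2-13: they are computed with the Stockham kernel.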

Performance

For 2 <= n < 512, 1D complex to complex FFTs on my M1 Max, the average bandwidths are:

MLX GPU: 162.9 GB/s
MPS GPU: 69.3 GB/s
MLX CPU: 5.9 GB/s

So this implementation is about 2.3x faster than MPS on average and about 27x faster than CPU MLX which uses pocketfft.

This implementation specializes kernels for different values of n with Metal function constants, so the first call for a new Stockham/Rader size has more overhead than MPS.

[Benchmark plots: Radix 2-13, Bluestein's]

@@ -255,6 +257,96 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

Collaborator Author

Feels a bit wrong having this as a primitive, but I wasn't sure if there's a better way to do it.

@@ -357,7 +357,6 @@ MTL::Function* Device::get_function_(
}

mtl_func_consts->release();

Collaborator Author

@jagrit06 I was getting segfaults caused by this release when using function constants, but couldn't figure out the best place in the code to move it to. Any idea where it should fit in?

@awni
Member

awni commented May 13, 2024

Very impressive perf!

Regarding the design, there is a big style difference from other MLX ops which we should change if possible. Basically you do the dispatch at the op level rather than the Primitive level. I see how this might be easier since you have access to all the ops you need for the different FFT algorithms, but I don't think we should do it this way. The compute graph should be independent of the implementation details. Also, the FFT plans themselves should not be part of the compute graph (they're an implementation detail).

This redesign may require some changes to our existing backend to make it workable for you to use the requisite back-end ops from the FFT primitive's eval_gpu.

@barronalex
Collaborator Author

That makes sense to me, it did feel a little anti-pattern bloating out the graph but the MLX api is just really convenient!
Let me give the re-write a go today, I don't think it'll be too bad.

@awni
Member

awni commented May 13, 2024

> That makes sense to me, it did feel a little anti-pattern bloating out the graph but the MLX api is just really convenient!
> Let me give the re-write a go today, I don't think it'll be too bad.

We have really bad support for doing stuff on arrays inside primitives (MLX wasn't really designed with that in mind 😓 ). But I think we can improve it a lot if needed.

Member

@angeloskath angeloskath left a comment


Awesome perf and generally very nice work! Kudos!

I left a comment on BluesteinFFTSetup to maybe avoid the double precision math. I think it should be doable, let me know if I am missing something or if it feels too experimental.

// In numpy:
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
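That precomputed setup can be expanded into a runnable NumPy sketch of Bluestein's chirp-z transform (illustrative only, not the MLX kernel; `bluestein_fft` is a hypothetical helper):

```python
import numpy as np

def bluestein_fft(x):
    # Arbitrary-length DFT via Bluestein: X[k] = w[k] * ((x * w) conv (1/w))[k],
    # a linear convolution evaluated with FFTs of length 2N-1.
    N = len(x)
    # Chirp as in the setup above: w_k = exp(-1j*pi/N * m^2) for m = -N+1 .. N-1.
    m = np.arange(-N + 1, N)
    w_k = np.exp(-1j * np.pi / N * m**2)
    L = 2 * N - 1
    w_q = np.fft.fft(1.0 / w_k, L)      # FFT of the kernel, reusable across calls
    a = np.zeros(L, dtype=complex)
    a[:N] = x * w_k[N - 1:]             # x[n] * exp(-1j*pi*n^2/N)
    conv = np.fft.ifft(np.fft.fft(a) * w_q)
    # The desired output bins sit at offset N-1 in the convolution.
    return w_k[N - 1:] * conv[N - 1:]
```

Because m^2 can be large relative to N, computing the chirp phase accurately is where the double-precision concern below comes from.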
Member

What do you think of section IV.E of https://mc.stanford.edu/cgi-bin/images/7/75/SC08_FFT_on_GPUs.pdf . Would it solve our problem here to avoid double precision arithmetic?

Member

Fig 6 is very promising :-)

Collaborator Author

This looks nice! I simplified the double precision part a bit so I think I'm going to keep it for now since it's not really an accuracy or performance bottleneck. Happy to revisit in the future though.

@barronalex
Collaborator Author

OK that took a little while but I think the FFTs are in a reasonable state now:

  • All the GPU planning/running logic has been moved to metal/fft.cpp so we're not bloating the graph at all
  • Added a no-transpose four step FFT implementation so big powers of two are fast now (~100-140 GB/s on M1 Max)
  • Added FFT to the JIT
  • Refactored the reading/writing so we now support RFFT/IRFFT for Stockham/Rader/Bluestein/4 Step directly in the kernel
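For reference, the four step decomposition mentioned above splits a length n = n1*n2 FFT into pieces that each fit in shared memory. A generic NumPy sketch (`four_step_fft` is a hypothetical helper; it does the final re-indexing explicitly, which the PR's no-transpose variant avoids):

```python
import numpy as np

def four_step_fft(x, n1, n2):
    # Cooley-Tukey four-step decomposition of a length n = n1*n2 FFT.
    n = n1 * n2
    X = x.reshape(n1, n2)
    # Step 1: n2 FFTs of length n1 down the columns.
    X = np.fft.fft(X, axis=0)
    # Step 2: pointwise twiddle multiply by exp(-2j*pi * k1 * i2 / n).
    k1 = np.arange(n1)[:, None]
    i2 = np.arange(n2)[None, :]
    X = X * np.exp(-2j * np.pi * k1 * i2 / n)
    # Step 3: n1 FFTs of length n2 along the rows.
    X = np.fft.fft(X, axis=1)
    # Step 4: element [k1, k2] is output bin k1 + n1*k2, i.e. a transpose.
    return X.flatten(order='F')
```

Steps 1 and 3 are what the shared-memory Stockham kernels compute; only the twiddle multiply and re-indexing touch global memory.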

mlx/fft.cpp Outdated
Comment on lines 84 to 86
// GPU scatter for complex64 is NYI
in =
scatter(tmp, std::vector<array>{}, in, std::vector<int>{}, Device::cpu);
Member

Can we do that with a slice_update instead?

Collaborator Author

That sounds nicer -- I'll update it

#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
Member

So this is why you don't need to use the utils() in the JIT: it's already included here by the preprocessor.

To keep the JIT source small, it would be better to move the includes that we already have in the JIT out of this file (e.g. kernels/utils.h) and use the utils() when constructing the JIT source.

You can include kernels/utils.h in fft.metal before you include fft.h. I would just turn off clang formatting for that whole file and it won't mess with the include order.


#include <metal_common>

#include "mlx/backend/metal/kernels/utils.h"
Member

Note you should also remove the include here.

Comment on lines 343 to 357
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}

// Complex mul followed by conjugate
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
}

// Compute an FFT twiddle factor
METAL_FUNC float2 get_twiddle(int k, int p) {
float theta = -2.0f * k * M_PI_F / p;

float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
return twiddle;
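A quick Python cross-check of what these Metal helpers compute (reference sketch only, not part of the PR):

```python
import math

def complex_mul(a, b):
    # (a.x + i*a.y) * (b.x + i*b.y), as (re, im) pairs -- mirrors the Metal helper.
    return (a[0] * b[0] - a[1] * b[1], a[0] * b[1] + a[1] * b[0])

def complex_mul_conj(a, b):
    # conj(a * b): same real part as complex_mul, negated imaginary part.
    return (a[0] * b[0] - a[1] * b[1], -a[0] * b[1] - a[1] * b[0])

def get_twiddle(k, p):
    # DFT twiddle factor e^{-2*pi*i*k/p} as an (re, im) pair.
    theta = -2.0 * math.pi * k / p
    return (math.cos(theta), math.sin(theta))
```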
Member

If the only reason you are using utils.h is for these, it might be cleaner to just put them in fft.h instead. I think they also fit better in fft.h if that works. We have complex64_t, which should be used in general for complex muls.

Collaborator Author

Agreed it's definitely a bit confusing otherwise. I've removed the utils.h import.

Member

@awni awni left a comment

🚀 🚀

@barronalex barronalex merged commit 27d70c7 into ml-explore:main Jun 6, 2024
5 checks passed
@barronalex barronalex deleted the ab-metal-fft-complete branch June 6, 2024 19:57