Add GPU implementation of QR factorization [wip] #975
base: main
Conversation
Force-pushed from 67816c3 to 5c205fb
mlx/backend/metal/qrf.cpp
Outdated
auto compute_encoder =
    metal::CommandEncoder(command_buffer->computeCommandEncoder());
Remove this, see above
mlx/backend/metal/qrf.cpp
Outdated
compute_encoder.set_input_array(betas, 0);
compute_encoder.set_input_array(Y, 1);
compute_encoder.set_input_array(a, 2);
compute_encoder.set_input_array(Wp, 3);
Thanks for rebasing, this is good!
mlx/backend/metal/qrf.cpp
Outdated
const MTL::Size threads_per_threadgroup(8, 1, 1);
compute_encoder->dispatchThreads(
    threads_per_grid, threads_per_threadgroup);
compute_encoder->endEncoding();
Don't end the encoding; the MLX device will handle this.
mlx/backend/metal/qrf.cpp
Outdated
auto command_buffer = device.new_command_buffer(stream.index);

auto compute_encoder =
    metal::CommandEncoder(command_buffer->computeCommandEncoder());
Same comments as above.
mlx/backend/metal/qrf.cpp
Outdated
compute_encoder->endEncoding(); | ||
|
||
device.commit_command_buffer(stream.index); |
Don't end the encoding or commit here. The MLX device will handle it.
@awni I tried to apply your comments and pushed 501b889 to avoid creating a new command buffer for each kernel, but I get:
Did you manually make a command encoder from the command buffer? MLX manages an active command encoder, so you should not make it directly. Rather call the
I also tried doing that in b979ccf, which just produces the wrong result.
Force-pushed from 6aecb32 to 729e011
for (int k = 0; k < batch_size; k++) {
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < m; j++) {
      const auto batch_offset = m * n * k;
      const auto loc = batch_offset + colmajor_idx(i, j, m);
      q.data<float>()[loc] = i == j ? 1 : 0;
    }
  }
}
That should probably be a kernel.
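For comparison, the batched identity initialization above can be expressed without the triple loop. A hedged numpy sketch (sizes are hypothetical); the per-element assignment it performs is exactly what a simple elementwise Metal kernel would do on-device:

```python
import numpy as np

# Hypothetical sizes for illustration
batch_size, m = 4, 8

# One m x m identity per batch element, matching the loop above
# (the identity is symmetric, so row- vs. column-major layout is moot)
q = np.broadcast_to(np.eye(m, dtype=np.float32), (batch_size, m, m)).copy()
```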
@@ -172,7 +172,7 @@ inline size_t elem_to_loc(

template <typename T>
void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
  int num_print = 3;
Why did you change that?
Just for debugging?
I fixed the code (I needed to introduce one more kernel to ensure the atomics were synchronized properly across different threadgroups). It's a bit slow, so I'll try to improve it now.
@nicolov are you planning to come back to this?
Proposed changes
Add a GPU implementation of QR factorization using the blocked Householder reflection algorithm, see:
- Andrew Kerr, Dan Campbell, Mark Richards, "QR Decomposition on GPUs"
- Jan Priessnitz, "GPU acceleration of matrix factorization"
Here is the reference code in numpy for the algorithm.
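The linked numpy reference is not reproduced on this page. As a stand-in, here is a minimal sketch of the unblocked Householder algorithm it is based on (the PR's blocked variant instead accumulates the reflectors in compact WY form and applies them in panels):

```python
import numpy as np

def householder_qr(a):
    """Unblocked Householder QR: returns q (m x m orthogonal) and r (m x n
    upper triangular) with a = q @ r. Sketch only, not the PR's exact code."""
    m, n = a.shape
    r = a.astype(float).copy()
    q = np.eye(m)
    for k in range(min(m, n)):
        x = r[k:, k]
        normx = np.linalg.norm(x)
        if normx == 0.0:
            continue
        # Pick the sign of alpha opposite to x[0] to avoid cancellation
        alpha = -normx if x[0] >= 0 else normx
        v = x.copy()
        v[0] -= alpha
        beta = 2.0 / (v @ v)
        # Apply H = I - beta * v v^T from the left to the trailing block of r
        r[k:, k:] -= beta * np.outer(v, v @ r[k:, k:])
        # Accumulate q = H_0 H_1 ... by applying H from the right
        q[:, k:] -= beta * np.outer(q[:, k:] @ v, v)
    return q, r
```

The blocked version groups several such reflectors into matrices W and Y so the trailing update becomes matrix-matrix products, which is what makes it map well to a GPU.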
Left todo
Checklist
Put an x in the boxes that apply.
[ ] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes