
Add GPU implementation of QR factorization [wip] #975

Draft · wants to merge 2 commits into main from nicolov-qr-gpu

Conversation

@nicolov (Contributor) commented Apr 9, 2024

Proposed changes

Add a GPU implementation of QR factorization using the blocked Householder reflection algorithm; see:

  • Andrew Kerr, Dan Campbell, Mark Richards, QR Decomposition on GPUs
  • Jan Priessnitz, GPU acceleration of matrix factorization

Here is the reference code in numpy for the algorithm.
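
For orientation, here is a minimal sketch of the unblocked Householder QR in plain C++ (an illustration only, not this PR's code; the GPU implementation uses the blocked variant from the papers above). It uses column-major storage to match the indexing in this diff:

#include <cmath>
#include <vector>

// In-place Householder QR: on return the upper triangle of `a` (m x n,
// column-major) holds R, and `q` (passed in as the m x m identity) holds Q.
void householder_qr(std::vector<float>& a, std::vector<float>& q, int m, int n) {
  auto at = [m](std::vector<float>& x, int i, int j) -> float& {
    return x[i + static_cast<size_t>(j) * m]; // column-major index
  };
  std::vector<float> v(m, 0.0f);
  for (int k = 0; k < n && k < m; k++) {
    // Householder vector v = x - alpha * e_k, chosen to zero a[k+1.., k].
    float norm = 0;
    for (int i = k; i < m; i++) norm += at(a, i, k) * at(a, i, k);
    norm = std::sqrt(norm);
    float alpha = at(a, k, k) > 0 ? -norm : norm; // sign avoids cancellation
    float vnorm2 = 0;
    for (int i = k; i < m; i++) {
      v[i] = at(a, i, k) - (i == k ? alpha : 0.0f);
      vnorm2 += v[i] * v[i];
    }
    if (vnorm2 == 0) continue; // column already zero below the diagonal
    // Apply H_k = I - 2 v v^T / (v^T v) to A from the left.
    for (int j = k; j < n; j++) {
      float dot = 0;
      for (int i = k; i < m; i++) dot += v[i] * at(a, i, j);
      float s = 2 * dot / vnorm2;
      for (int i = k; i < m; i++) at(a, i, j) -= s * v[i];
    }
    // Accumulate Q <- Q H_k, so that Q = H_0 H_1 ... and A = Q R.
    for (int i = 0; i < m; i++) {
      float dot = 0;
      for (int j = k; j < m; j++) dot += at(q, i, j) * v[j];
      float s = 2 * dot / vnorm2;
      for (int j = k; j < m; j++) at(q, i, j) -= s * v[j];
    }
  }
}

A quick sanity check is to verify that Q R reproduces the input and that Q is orthogonal, e.g. against numpy.linalg.qr.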

Left todo

  • clean up handling of batched inputs: slice the inputs/outputs and only pass the slice to the algorithm. Temporaries need only be sized for a single input matrix.
  • share some constants between the kernel and the driver function.
  • consider merging the two kernels to compute W.
  • benchmark and optimize grid/block sizes.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@nicolov marked this pull request as draft on April 11, 2024 at 14:32
@nicolov force-pushed the nicolov-qr-gpu branch 2 times, most recently from 67816c3 to 5c205fb on April 11, 2024 at 14:44
mlx/backend/metal/qrf.cpp (outdated)
Comment on lines 122 to 200
auto compute_encoder =
    metal::CommandEncoder(command_buffer->computeCommandEncoder());
Member: Remove this, see above.

Comment on lines 127 to 194
compute_encoder.set_input_array(betas, 0);
compute_encoder.set_input_array(Y, 1);
compute_encoder.set_input_array(a, 2);
compute_encoder.set_input_array(Wp, 3);
Member: Thanks for rebasing, this is good!

const MTL::Size threads_per_threadgroup(8, 1, 1);
compute_encoder->dispatchThreads(
    threads_per_grid, threads_per_threadgroup);
compute_encoder->endEncoding();
Member: Don't end the encoding. MLX device will handle this.

Comment on lines 189 to 240
auto command_buffer = device.new_command_buffer(stream.index);

auto compute_encoder =
    metal::CommandEncoder(command_buffer->computeCommandEncoder());
Member: Same comments as above.

Comment on lines 211 to 260
compute_encoder->endEncoding();

device.commit_command_buffer(stream.index);
Member: Don't end encoding or commit here. MLX device will handle it.

@nicolov (Contributor, Author) commented Apr 15, 2024

@awni I tried to apply your comments and pushed 501b889 to avoid creating a new command buffer for each kernel, but I get:

-[AGXG13XFamilyCommandBuffer tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1015: failed assertion `A command encoder is already encoding to this command buffer'

@awni (Member) commented Apr 15, 2024

Did you manually make a command encoder from the command buffer? MLX manages an active command encoder, so you should not make one directly. Rather, call device.get_command_encoder() to get the active encoder.
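
A hedged sketch of that pattern, built from the calls already visible in this diff (the kernel variable and stream handling are placeholder assumptions, not MLX's exact API surface):

auto& d = metal::device(s.device);
// The active encoder is owned and managed by MLX; don't construct one manually.
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); // kernel obtained elsewhere
compute_encoder.set_input_array(betas, 0);
compute_encoder.set_input_array(Y, 1);
compute_encoder->dispatchThreads(threads_per_grid, threads_per_threadgroup);
// No computeCommandEncoder(), endEncoding(), or commit_command_buffer():
// the MLX device ends the encoding and commits the buffer itself.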

@nicolov (Contributor, Author) commented Apr 15, 2024

> Rather, call device.get_command_encoder() to get the active encoder.

I also tried doing that in b979ccf, which just produces the wrong result.

@nicolov (Contributor, Author) commented Apr 15, 2024

I also tried tracing, and Xcode complains about redundant bindings. Should I somehow refactor how I bind buffers to the encoder?

[Screenshot: Xcode trace showing redundant buffer binding warnings]

@nicolov force-pushed the nicolov-qr-gpu branch 2 times, most recently from 6aecb32 to 729e011 on April 15, 2024 at 19:12
Comment on lines +585 to +581
for (int k = 0; k < batch_size; k++) {
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < m; j++) {
      const auto batch_offset = m * n * k;
      const auto loc = batch_offset + colmajor_idx(i, j, m);
      q.data<float>()[loc] = i == j ? 1 : 0;
    }
  }
}
Member: That should probably be a kernel.
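
For illustration, a hedged Metal-kernel sketch of that suggestion (the kernel name, buffer indices, and launch shape are assumptions, not part of this PR), dispatched with one thread per element over an (m, m, batch_size) grid:

#include <metal_stdlib>
using namespace metal;

[[kernel]] void init_identity(
    device float* q [[buffer(0)]],
    const constant int& m [[buffer(1)]],
    const constant int& n [[buffer(2)]],
    uint3 tid [[thread_position_in_grid]]) {
  // Mirrors the host loop above: tid.z is the batch index k, and elements
  // are laid out column-major within each batch slice.
  const uint batch_offset = m * n * tid.z;
  const uint loc = batch_offset + tid.x + tid.y * m; // colmajor_idx(i, j, m)
  q[loc] = (tid.x == tid.y) ? 1.0f : 0.0f;
}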

@@ -172,7 +172,7 @@ inline size_t elem_to_loc(

template <typename T>
void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
int num_print = 3;
Member: Why did you change that?

Member: Just for debugging?

@nicolov (Contributor, Author) commented Apr 17, 2024

I fixed the code (needed to introduce one more kernel to ensure the atomics were synchronized properly across different threadgroups). It's a bit slow, so I'll try to improve it now:

device     n  time_ms
   cpu  2000    99.39
   gpu  2000   283.36
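
Background on the synchronization fix mentioned above: Metal's threadgroup_barrier only synchronizes threads within a single threadgroup, and there is no grid-wide barrier, so results accumulated across threadgroups via device-scope atomics are only guaranteed visible once the dispatch finishes. Splitting the work into two kernels uses the dispatch boundary as the grid-wide sync point. A hedged sketch of that pattern with the encoder calls used elsewhere in this PR (kernel names are placeholders), assuming dispatches on the encoder execute in order:

// Kernel 1: each threadgroup accumulates its partial result into a
// device-memory atomic.
compute_encoder->setComputePipelineState(accumulate_kernel);
compute_encoder->dispatchThreads(threads_per_grid, threads_per_threadgroup);
// The dispatch boundary is the grid-wide synchronization point: by the
// time kernel 2 runs, all of kernel 1's atomic writes are visible.
compute_encoder->setComputePipelineState(consume_kernel);
compute_encoder->dispatchThreads(threads_per_grid, threads_per_threadgroup);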

@awni (Member) commented Apr 25, 2024

@nicolov are you planning to come back to this?
