Pinv #875

Open · adhulipa wants to merge 28 commits into main from pinv
Conversation

adhulipa

@adhulipa adhulipa commented Mar 22, 2024

Proposed changes

Add a Moore-Penrose pseudo-inverse function. Inspired by the recent PRs from @nicolov adding svd and inv, this PR adds the pinv primitive.

Tests

Ran some tests locally and included them in the PR:

>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0]])
... A_pinv = mx.linalg.pinv(A)
>>> A @ A_pinv @ A
array([[1, 2, 1, 1, 9],
       [3, 4, 2, 2, 8],
       [2, 2, 1, 0.999999, 4],
       [5, 6, 7, 2, 3]], dtype=float32)

Re-tested, and everything looks good:


>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2], [3, 4], [2, 2], [5, 3.0] ])
... A_pinv = mx.linalg.pinv(A)
... print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
...
... import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0] ])
... A_pinv = mx.linalg.pinv(A)
... print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
...
... import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1.0, 2], [3, 4.0]])
... A_pinv = mx.linalg.pinv(A)
... print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
(4, 2) ,	allclose?  array(True, dtype=bool)
(4, 5) ,	allclose?  array(True, dtype=bool)
(2, 2) ,	allclose?  array(True, dtype=bool)

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@angeloskath
Member

@adhulipa I think this implementation is the wrong way to go about it. Using SVD to compute the pseudo-inverse means we don't need a primitive, kernels, etc. It is just an op that can reside in the mlx::core::linalg namespace.

Basically something like the following (off the top of my head, so YMMV):

def pinv(x):
    U, S, V = mx.linalg.svd(x)
    return (V[:len(x)].T * 1/S) @ U.T
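
For reference, a runnable reading of that sketch (an illustrative version, not the final API: it assumes mx.linalg.svd returns U, S, Vt with A = U @ diag(S) @ Vt, and a square, full-rank input so dividing by S is safe; pinv_sketch is a made-up name):

import mlx.core as mx
mx.set_default_device(mx.cpu)

def pinv_sketch(x):
    # pinv(A) = V @ diag(1/S) @ U.T, computed from the SVD.
    U, S, Vt = mx.linalg.svd(x)
    return (Vt.T * (1 / S)) @ U.T  # scale the columns of V by the reciprocal singular values

A = mx.array([[1.0, 2], [3, 4]])
print(mx.allclose(A, A @ pinv_sketch(A) @ A))  # array(True, dtype=bool)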

@adhulipa
Author

Ahh I see! I didn’t think about that. Thanks for the review @angeloskath

I suppose we could modify this PR to merge a Python-only version as a first step and then investigate whether a custom kernel is necessary.

Would you recommend such a direction?

@awni
Member

awni commented Mar 25, 2024

The op should be in C++ and then do a binding (we try to keep the C++ and Python APIs reasonably consistent). I think the Python impl from @angeloskath is just intended as pseudo-code for that.

@adhulipa
Author

Ah yes, that makes sense. I should add the Python API that matches the C++ API for pinv(). I haven't gotten around to it yet. Thank you for taking a look, folks!

@adhulipa adhulipa force-pushed the pinv branch 2 times, most recently from 1837c9a to 1b513c7 Compare April 2, 2024 21:54
@adhulipa
Author

adhulipa commented Apr 5, 2024

I made a few updates. Still gotta figure out how to fix the C++ op: svd(A) returns u, s, vt where u has the same dims as A (when rectangular), which makes the matmul incompatible.

I have a path to green where I need to tweak u to match the expected end shape.

(I'm positive the SVD approach works accurately, because I validated it in the Python MLX API, and in a few other languages such as MATLAB, to be certain.)

Also, the PyTorch impl can serve as a reference: https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/aten/src/ATen/native/LinearAlgebra.cpp#L532

Just need to get around to it in due time.

// v* 5x4
auto inner = transpose(matmul(s_plus, u));
auto result = matmul(v, inner);
copy(result, pinv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
@adhulipa
Author

There's a bug here (or at its call site, where the pinv array is allocated): the size/shape of the pinv array is not set up correctly.

terminated by signal SIGSEGV (Address boundary error)

@awni
Member

awni commented Apr 25, 2024

@adhulipa are you planning to come back to this?

@awni
Member

awni commented May 3, 2024

@adhulipa are you planning to return to this one?

@adhulipa
Author

adhulipa commented May 8, 2024

Hi @awni, yes, I will update this one. I'm running into an issue: I haven't figured out how to allocate the rectangular array for the output/result before passing it off to the pinv function.

Apologies for the delay; other priorities took precedence lately.

I think I should be able to dedicate a few hours this weekend -- likely 4-8 hours on 5/11

@adhulipa
Author

@awni do you think it's better to close this PR and reopen it against a newer mainline commit? Happy to do so if it helps keep your PR to-do list clean.

@awni
Member

awni commented May 13, 2024

It's more up to you. If you plan to work on it in the near future, then you can keep it open (or start a new one if you prefer). If not, I would close it.

@adhulipa
Author

I'll keep it open for now. I'll close and reopen if it falls too far behind -- for now these changes are additive, so that's not a risk. It just needs a bit of polish/bug-fixing.

@adhulipa
Author

Made some progress. Need to fix a few more things.

@adhulipa
Author

I suspect there's something I need to figure out with how I'm using mx.linalg.svd(A) -- or rather, in C++, svd_imp() -- and then the matmuls for getting the pinv.

I'm seeing:

>>> A = mx.array([[1.0, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3] ])
>>> U, S, Vt = mx.linalg.svd(A)
>>> U @ mx.diag(S) @ Vt
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: [matmul] Last dimension of first input with shape (4,4) must match second to last dimension of second input with shape (5,5).

This seems to contradict the MLX svd docs:

Returns
The U, S, and Vt matrices, such that A = U @ diag(S) @ Vt

Of course, I ack that MLX mimics the NumPy API, and NumPy indeed produces a similar result. But NumPy supports a full_matrices: bool = True kwarg, which I suppose was designed to help with these kinds of cases:

>>> A = np.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3]])
>>> U, S, Vt = np.linalg.svd(A, full_matrices=False)
>>> U @ np.diag(S) @ Vt
array([[1., 2., 1., 1., 9.],
       [3., 4., 2., 2., 8.],
       [2., 2., 1., 1., 4.],
       [5., 6., 7., 2., 3.]])

>>> np.allclose(A, U @ np.diag(S) @ Vt)
True

(Fwiw, without full_matrices=False, the error is the same as in MLX:)

>>> U, S, Vt = np.linalg.svd(A)
>>> U @ np.diag(S) @ Vt
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 5 is different from 4)

@awni
Member

awni commented May 23, 2024

I think you can just do something like this:

U, S, V = mx.linalg.svd(A)
K = min(A.shape[0], A.shape[1])
Atilde = (U[:, :K] * S) @ V[:K, :]

We could add the slicing as an option like Numpy if it's useful.
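
For the pseudo-inverse itself, the same slicing works (a minimal sketch, assuming a full-rank input so 1/S is safe, and the full U/Vt shapes discussed above):

A = mx.array([[1.0, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3]])
U, S, Vt = mx.linalg.svd(A)
K = min(A.shape[0], A.shape[1])
A_pinv = (Vt[:K, :].T * (1 / S)) @ U[:, :K].T  # (n, K) columns scaled by 1/S, then @ (K, m)
mx.allclose(A, A @ A_pinv @ A)  # array(True, dtype=bool)

This mirrors Atilde above but inverts the singular values; a production version would also guard against tiny values in S.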

Also I would recommend you rebase before making further progress to make it easier to resolve conflicts.

@adhulipa
Author

Ah, thanks Awni! Will use that.
Ack on the rebase.

@adhulipa
Author

adhulipa commented May 27, 2024

Small update: got a local build that correctly computes pinv in most of the tests. Cleaning up some things and polishing the code.

>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0] ])
... A_pinv = mx.linalg.pinv(A)
>>> A_pinv
array([[2.25371e-07, -1, 2, -7.40725e-08],
       [-0.408602, 1.2043, -1.44086, -0.064516],
       [0.363441, -0.48172, -0.0236563, 0.225806],
       [-0.346237, 0.873119, -0.894624, -0.0967741],
       [0.2, -0.2, 0.2, 1.00349e-08]], dtype=float32)

>>> A @ A_pinv @ A
array([[1, 2, 1, 1, 9],
       [3, 4, 2, 2, 8],
       [2, 2, 1, 0.999999, 4],
       [5, 6, 7, 2, 3]], dtype=float32)

>>> ans = A @ A_pinv @ A
>>> mx.allclose(A, A @ A_pinv @ A)
array(True, dtype=bool)

Turns out I was incorrectly relying on the array computation graph API instead of computing the actual matrix products (d'oh!). Now I have some code locally using LAPACK's matrix-multiply functions (such as sgemm) to compute the final pinv product.

Will update this PR soon

@adhulipa adhulipa marked this pull request as ready for review May 27, 2024 07:10
@adhulipa
Author

adhulipa commented May 27, 2024

Updated the PR. It is in good enough shape for a review from @awni and other MLX folks. Thanks!

Perhaps there's one more thing for me to check in the Python tests. Will look into it. But in the meantime, this PR is still good for a review.

@adhulipa
Author

Drats. I have another bug to fix. I updated the tests to catch it. Will look into it and fix. Essentially, long rectangular matrices hit a matmul dim mismatch -- which means I've made an error in the m, n, k calculations and/or the slice selections of U/Vt.

@adhulipa
Author

Fixed the bug for rectangular matrices where M > N 🎉
Will publish commit soon

@adhulipa
Author

This PR is ready for a review from Awni, Angelos and other MLX folks. Thanks!

@angeloskath
Member

Hi @adhulipa, I think there shouldn't be a primitive for this operation. It can really just be an op in the linalg namespace.

@adhulipa
Author

adhulipa commented May 29, 2024

Hi @angeloskath, ohh I see. I think I may have misinterpreted something in the thread here, then -- particularly what @awni shared after your earlier comment:

The op should be in C++ and then do a binding

Is it accurate to say that you meant this should go in linalg.cpp, where we add something relatively simple like:

auto outs = linalg::svd(x);
array U = outs[0];
array S = outs[1];
// ... etc.

// return (V[:len(x)].T * 1/S) @ U.T  -- the C++ variant of the python-esque sketch above

@adhulipa
Author

Actually, @angeloskath, do you mean that we don't need a primitive, but all the logic -- calling linalg::svd(), ensuring the svd() call is eval()'d, and then passing to LAPACK's sgemm() function calls -- should live in the linalg.cpp file (and namespace)?

It seems like the recommendation here is to keep the core logic intact but just not make this a primitive. Am I understanding that correctly?

@adhulipa
Author

adhulipa commented May 29, 2024

I think I'm starting to understand the motivation behind the "C++ op sans primitive" recommendation from Angelos. Pardon the roundabout way I needed to understand this 😅

The following change in linalg.cpp does pass the tests. Just checking a few more things before I can publish a new commit.

array pinv(const array& a, StreamOrDevice s /* = {} */) {
  ....

  const auto m = a.shape(-2);
  const auto n = a.shape(-1);
  const auto k = std::min(m, n);
  const auto rank = a.ndim();

  auto outs = linalg::svd(a, Device::cpu);
  auto U = outs[0];
  auto S = outs[1];
  auto Vt = outs[2];

  ....
  ....

  const auto U_slice = slice(U, {0, 0}, {m, k});
  const auto Vt_slice = slice(Vt, {0, 0}, {k, n});
  return matmul(matmul(transpose(Vt_slice), diag(1.0 / S)), transpose(U_slice));
}
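
A quick way to sanity-check the op from Python once it's bound (a sketch; the NumPy cross-check and tolerance are illustrative, not from the PR's test suite):

import numpy as np
import mlx.core as mx
mx.set_default_device(mx.cpu)

A = mx.array([[1.0, 2], [3, 4], [2, 2], [5, 3]])
A_pinv = mx.linalg.pinv(A)

# Defining property of the pseudo-inverse: A @ A+ @ A == A.
print(mx.allclose(A, A @ A_pinv @ A))

# Cross-check against NumPy's reference implementation.
print(np.allclose(np.array(A_pinv), np.linalg.pinv(np.array(A)), atol=1e-5))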

@adhulipa
Author

adhulipa commented Jun 4, 2024

@angeloskath @awni -- question: do you folks feel this is in good shape for a review? Of course, no rush from my pov; just thought I'd check.
