
Segment Anything Model #552

Merged
merged 24 commits into from
Jun 2, 2024
Conversation

gathierry
Contributor

@gathierry gathierry commented Mar 9, 2024

Add Segment Anything Model (SAM) in MLX.

@awni
Member

awni commented May 20, 2024

@gathierry should we review and merge this? I'm not sure what the current status is, but let me know what you think.

@gathierry
Contributor Author

Hi @awni , yes please review it at your convenience.
I think it's almost done, except that the ConvTranspose2d implementation is very naive and may not generalize well. I can only confirm that it works in this model.
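(For reference, the "naive" approach being discussed can be illustrated by the textbook zero-insertion view of a transposed convolution. This is a hypothetical single-channel NumPy sketch of the general idea, not the PR's actual MLX implementation:)

```python
import numpy as np

def conv_transpose2d_naive(x, w, stride=2):
    """Naive 2D transposed convolution (single channel, no padding).

    Zero-insert `stride - 1` rows/cols between input pixels, then run a
    full correlation with the flipped kernel -- the standard equivalence
    conv_transpose(x, w, s) == conv(zero_insert(x, s), w, pad=k-1).
    """
    h, wd = x.shape
    kh, kw = w.shape
    # Zero-insertion (dilating the input by the stride).
    up = np.zeros(((h - 1) * stride + 1, (wd - 1) * stride + 1))
    up[::stride, ::stride] = x
    # Pad by kernel_size - 1 so every kernel position overlaps the input.
    up = np.pad(up, ((kh - 1, kh - 1), (kw - 1, kw - 1)))
    oh, ow = up.shape[0] - kh + 1, up.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    wf = w[::-1, ::-1]  # flip the kernel: correlation -> convolution
    for i in range(oh):
        for j in range(ow):
            out[i, j] = (up[i:i + kh, j:j + kw] * wf).sum()
    return out
```

With a 1x1 input the output is just the kernel itself, which makes the definition easy to sanity-check.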

@awni
Member

awni commented May 26, 2024

@gathierry this is really nicely done! Thanks for the example. I started looking at it today and sent some changes to your branch.

There are two high-level things I think we should aim to improve:

  1. Simplify as much as possible. I understand the model has a lot of moving pieces, but to the extent that we can simplify how they all interact (maybe at the cost of some flexibility), that will make the example much more approachable.
  2. I would like to remove the dependence on torch in the amg class. I notice that there are still some non-trivial sections done using torch. Is that something we could work towards?

@gathierry
Contributor Author

Thanks for the comments and improvements, @awni.
For 2:

I would like to remove the dependence on torch in the amg class. I notice that there is still some non-trivial sections done using torch. Is that something we could work towards?

I tried to use MLX first but found it much slower than torch. I thought it was because of some Boolean-masking and nonzero operators that MLX didn't support yet, so I had to convert to numpy and back.
But that was three months ago; I don't know if it's still the case.
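(The numpy round-trip being described looks roughly like this. A minimal sketch with hypothetical helper names; NumPy stands in on both ends so the snippet runs anywhere, whereas the real code converts from and back to `mx.array`:)

```python
import numpy as np

def mask_filter_via_numpy(values, keep):
    """Boolean-mask indexing done on the NumPy side.

    In the actual example, `values` and `keep` would be mx.array; the
    round-trip is np.array(values)[np.array(keep)], then back to mx.array.
    """
    v = np.asarray(values)
    k = np.asarray(keep, dtype=bool)
    # Output shape depends on the data in `keep`, which is exactly the
    # property that made this op hard for MLX at the time.
    return v[k]

def nonzero_via_numpy(mask):
    """`nonzero` equivalent done on the NumPy side."""
    return np.nonzero(np.asarray(mask))[0]
```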

@awni
Member

awni commented May 26, 2024

I tried to use mlx first but found it much slower than torch.

Do you have a branch with that by any chance? We can profile and improve the ops if it's a bottleneck in MLX.

But that was 3 months ago.

PS sorry for the long delay on this. It kind of fell through the cracks. But I am quite keen to get an object segmentation example working!

@gathierry
Contributor Author

I don't have an existing branch for that, but I can try to write one and compare them. I remember the gap was pretty big, but maybe we can try to improve it.

@gathierry
Contributor Author

gathierry commented May 27, 2024

Hi @awni, I have amg implemented in torch in this branch.
I tested it again, though only roughly, in the notebook, and the speed is just a little slower than torch (<20%). This likely comes from the overhead of converting mx.array to numpy and back in the three places not implemented in MLX:

  • indexing with a boolean mask
  • nonzero
  • torchvision.batched_nms

What do you think? Maybe the third one is the easiest to start with?

@awni
Member

awni commented May 27, 2024

Thanks for adding that!

Which notebook did you test, the amg one?

What do you think? Maybe the third one is the easiest one to start?

Each of those ops is a bit tricky because their output shapes depend on the input data. But I would like to take a look and see where the slowdown is coming from. It might be from the conversion to numpy, but it could be something else, so it would be good to verify.

@gathierry
Contributor Author

Yes, the amg one.

@gathierry
Contributor Author

I tried profiling the filter function while running the amg notebook, and there does indeed seem to be a gap between MLX (with a workaround) and torch. Please correct me if I'm wrong.
For MLX:
0.04520195908344471
For torch:
[screenshot of the torch timing]

    def filter(self, keep: mx.array) -> None:
        import time
        t1 = time.perf_counter()
        for k, v in self._stats.items():
            if v is None:
                self._stats[k] = None
            elif isinstance(v, mx.array):
                # TODO: fix this with mlx
                # self._stats[k] = mx.array(np.array(v)[np.array(keep)])
                self._stats[k] = mx.array([a for i, a in enumerate(v) if keep[i]])
            elif isinstance(v, np.ndarray):
                self._stats[k] = v[np.array(keep)]
            elif isinstance(v, list) and keep.dtype == mx.bool_:
                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
            elif isinstance(v, list):
                self._stats[k] = [v[i] for i in keep]
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
        t2 = time.perf_counter()
        print(type(v), keep.shape, t2 - t1)
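(For comparison, the commented-out TODO above replaces the per-element Python loop with a single boolean-mask gather. A standalone sketch of that vectorized path, with NumPy standing in for `mx.array` so it runs anywhere:)

```python
import numpy as np

def filter_stats(stats, keep):
    """Vectorized version of the filter above.

    One boolean-mask gather per array replaces the O(N) Python loop;
    in the real code the arrays would be mx.array round-tripped
    through np.array(...).
    """
    keep = np.asarray(keep, dtype=bool)
    out = {}
    for k, v in stats.items():
        if v is None:
            out[k] = None
        elif isinstance(v, np.ndarray):
            out[k] = v[keep]  # single vectorized gather
        elif isinstance(v, list):
            out[k] = [a for a, flag in zip(v, keep) if flag]
        else:
            raise TypeError(f"unsupported type {type(v)} for key {k}")
    return out
```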

@awni
Member

awni commented May 28, 2024

On an M2 Ultra, the fully MLX version seems to be faster overall (looking at total time).

Hybrid:

python main.py --model mlx_model --input notebooks/images/dog.jpg --output   29.91s user 86.16s system 1465% cpu 7.921 total

All MLX:

python main.py --model mlx_model --input notebooks/images/dog.jpg --output   2.06s user 6.41s system 127% cpu 6.628 total

@awni
Member

awni commented May 28, 2024

@gathierry would you mind updating this PR to use the MLX version? Then we can just focus on optimizing it. I don't think we will merge the torch version anyway. You can keep it in a side branch for reference (or it will also be in the git history).

@gathierry
Contributor Author

Updated to the pure MLX version.

@awni
Member

awni commented May 30, 2024

Thanks for sending the pure MLX version. I'm noticing the main script isn't working 😓. I tried:

python main.py --model mlx_model/sam-vit-base --input notebooks/images/dog.jpg --output dogs

It's possible I broke something in a previous refactor, but let me know if you have any ideas.

Member

@awni awni left a comment

Thanks for this amazing addition!!

@awni awni merged commit 8353bbb into ml-explore:main Jun 2, 2024
4 checks passed