-
Notifications
You must be signed in to change notification settings - Fork 859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Comms #1097
Comms #1097
Conversation
I might prefer the name I also think |
mlx/dist/primitives.cpp
Outdated
auto ensure_row_contiguous = [](const array& arr) { | ||
if (arr.flags().row_contiguous) { | ||
return arr; | ||
} else { | ||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); | ||
copy(arr, arr_copy, CopyType::General); | ||
return arr_copy; | ||
} | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if there is a better strategy for this. The CPU copy could be kind of slow / might be better to use a GPU copy prior to the comm.
I wonder if we should consider adding a ensure_contiguous(inputs)
op (which is meant for internal use only), but actually puts the copy in the graph if its needed.
mlx/dist/dist.h
Outdated
struct Group { | ||
virtual int rank() = 0; | ||
virtual int size() = 0; | ||
virtual std::shared_ptr<Group> split(int n) = 0; | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we plan to compile with multiple communication backends simultaneously right?
In that case, it might be cleaner from a user perspective to make Group non-virtual and give it a payload which is like the implementation specific bit. Kind of like how Event
/ Buffer
are implemented.
Just a thought to keep the shared pointers out of the interface..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why that way would be better 🤷♂️ but I implemented it. It is pretty much the same thing as it just forces us to write the MPIGroup
in mpi.cpp (as we would) and hide it behind a std::shared_ptr<void>
. It is still pretty clean if not a little bit more cryptic and hard to follow and also forces us to have only one type of group at any given point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I'm not a big fan of the pattern of virtualization to change behavior at compile-time, feels like the wrong tool for the job. Maybe there is another option that is cleaner than the payload (which is indeed a bit cryptic).
Also it's meant as a suggestion.. if you think it's a lot more readable in the previous version feel free to revert it.
This is very nice and simple! Looks great! |
a43c554
to
7405b52
Compare
Is there any example? |
30501e0
to
e5f0e46
Compare
.gitignore
Outdated
# Negate mlx/dist | ||
!mlx/dist |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: don't need that anymore
@@ -167,6 +167,11 @@ else() | |||
set(MLX_BUILD_ACCELERATE OFF) | |||
endif() | |||
|
|||
find_package(MPI) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still needed now that you do it dynamically?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well it depends :-) We can skip it and then I 'd add an mpi.h defining all the functions that I am using. It is basically finding the header.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Functions are fine actually, the types need defining.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, let's leave it then! It should still build without MPI installed which is good.
mlx/distributed/CMakeLists.txt
Outdated
if (MPI_FOUND) | ||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) | ||
else() | ||
target_sources( | ||
mlx | ||
PRIVATE | ||
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp | ||
) | ||
endif() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little confused by the building cases. If MPI is not available we build without it. But we also do linking at run time. Does it make sense to have a no distributed option in that case?
Another idea is that maybe we should build no_distributed
if MLX_BUILD_CPU=OFF
(rather than throwing in the copy function. It does not seem too odd to me to disable MPI if the CPU is disabled..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah! Very good point. I will do that and move the copy inside the mpi implementation where it belongs.
Maan that was a very nice suggestion. It feels so much better now with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚀
This is huge, wish someone could write a tutorial of how to connect 2 Macs use MLX |
Usage docs coming soon! |
I can't wait to try this out!! |
Awesome work, so excited for this! Any idea how much throughput will be necessary for various use cases? Also, can MPI aggregate Thunderbolt links? |
* Start the communications branch using MPI * Add ops and primitives * Add python bindings for distributed
Beginning of communication namespace (perhaps it should be named comms instead of dist). This is mostly to get feedback while implementing the rest of the primitives and figuring out how to package this in the distribution.
Interesting bits:
mlx::core::dist
defines a bunch of functions that are optionally implemented by a communication backend. Currently mpi.Stream communication_stream()
and all communication operations go in that CPU stream.