Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/multidevice/communication.cpp
${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp
${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp
${NVFUSER_SRCS_DIR}/multidevice/dispatch_combine.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp
${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp
Expand Down Expand Up @@ -1143,6 +1144,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
Expand Down
2 changes: 2 additions & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class Val;
f(Merge); \
f(Partition); \
f(Combine); \
f(MoeDispatch); \
f(MoeCombine); \
f(Swizzle); \
f(Resize); \
f(MatmulOp); \
Expand Down
53 changes: 53 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "multidevice/allocation_utils.h"
#include "multidevice/communication.h"
#include "multidevice/cuda_p2p.h"
#include "multidevice/dispatch_combine.h"
#include "multidevice/execution_utils.h"
#include "multidevice/symmetric_tensor.h"
#include "multidevice/utils.h"
Expand Down Expand Up @@ -386,6 +387,58 @@ void HostIrEvaluator::handle(P2PCommunication* communication) {
}
}

// Evaluates a MoeDispatch node: fetches the concrete input tensors, runs
// doMoeDispatch, and binds every field of the returned result to the
// corresponding output of the node.
void HostIrEvaluator::handle(MoeDispatch* dispatch) {
  // communicator_ is handed to doMoeDispatch below, so it has to be usable.
  NVF_ERROR(
      communicator_ != nullptr && communicator_->is_available(),
      "A valid communicator must be provided");

  // Resolve the three inputs to concrete tensors (copied out of the
  // evaluated values, in the order x, topk_idx, is_token_in_rank).
  at::Tensor tokens = getKnownConcreteValue(dispatch->inX()).as<at::Tensor>();
  at::Tensor expert_ids =
      getKnownConcreteValue(dispatch->inTopkIdx()).as<at::Tensor>();
  at::Tensor token_in_rank_mask =
      getKnownConcreteValue(dispatch->inIsTokenInRank()).as<at::Tensor>();

  auto dispatched = doMoeDispatch(
      tokens,
      expert_ids,
      token_in_rank_mask,
      dispatch->numExperts(),
      communicator_,
      dispatch->backend());

  // Publish each result tensor under the matching output of the node.
  expr_evaluator_.bind(dispatch->outX(), dispatched.recv_x);
  expr_evaluator_.bind(dispatch->outTopkIdx(), dispatched.recv_topk_idx);
  expr_evaluator_.bind(dispatch->outSrcIdx(), dispatched.recv_src_idx);
  expr_evaluator_.bind(dispatch->outSrcRank(), dispatched.recv_src_rank);
  expr_evaluator_.bind(
      dispatch->outTokensToRank(), dispatched.n_tokens_to_rank);
  expr_evaluator_.bind(
      dispatch->outTokensFromRank(), dispatched.n_tokens_from_rank);
}

// Evaluates a MoeCombine node: fetches the concrete input tensors, runs
// doMoeCombine, and binds the combined tensor to the node's single output.
void HostIrEvaluator::handle(MoeCombine* combine) {
  // communicator_ is handed to doMoeCombine below, so it has to be usable.
  NVF_ERROR(
      communicator_ != nullptr && communicator_->is_available(),
      "A valid communicator must be provided");

  // Resolve the five inputs to concrete tensors, in declaration order.
  at::Tensor tokens = getKnownConcreteValue(combine->inX()).as<at::Tensor>();
  at::Tensor source_index =
      getKnownConcreteValue(combine->inSrcIdx()).as<at::Tensor>();
  at::Tensor source_rank =
      getKnownConcreteValue(combine->inSrcRank()).as<at::Tensor>();
  at::Tensor tokens_to_rank =
      getKnownConcreteValue(combine->inTokensToRank()).as<at::Tensor>();
  at::Tensor tokens_from_rank =
      getKnownConcreteValue(combine->inTokensFromRank()).as<at::Tensor>();

  auto combined = doMoeCombine(
      tokens,
      source_index,
      source_rank,
      tokens_to_rank,
      tokens_from_rank,
      communicator_,
      combine->backend());

  expr_evaluator_.bind(combine->outX(), combined.combined_x);
}

void HostIrEvaluator::handle(Wait* wait) {
Expr* expr = wait->communication();
auto* p2p_comm = dynamic_cast<P2PCommunication*>(expr);
Expand Down
2 changes: 2 additions & 0 deletions csrc/host_ir/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
void handle(LaunchKernel*) override;
void handle(Communication*) override;
void handle(P2PCommunication*) override;
void handle(MoeDispatch*) override;
void handle(MoeCombine*) override;
void handle(Wait*) override;
void handle(kir::ForLoop*) override;
void handle(hir::ForLoop*) override;
Expand Down
137 changes: 137 additions & 0 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,143 @@ std::string P2PCommunication::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

// Builds a MoeDispatch expression.
//
// Registers three inputs (in_x, in_topk_idx, in_is_token_in_rank), six
// outputs (out_x, out_topk_idx, out_src_idx, out_src_rank,
// out_n_tokens_to_rank, out_n_tokens_from_rank) and two data attributes
// (num_experts, backend), then runs validate().
//
// NOTE(review): the registration order below is kept exactly as before; the
// accessors presumably look operands up by position — confirm in the header
// before reordering anything.
MoeDispatch::MoeDispatch(
    IrBuilderPasskey passkey,
    TensorView* out_x,
    TensorView* out_topk_idx,
    TensorView* out_src_idx,
    TensorView* out_src_rank,
    TensorView* out_n_tokens_to_rank,
    TensorView* out_n_tokens_from_rank,
    TensorView* in_x,
    TensorView* in_topk_idx,
    TensorView* in_is_token_in_rank,
    int64_t num_experts,
    CommunicatorBackend backend)
    : Expr(passkey) {
  for (TensorView* input : {in_x, in_topk_idx, in_is_token_in_rank}) {
    addInput(input);
  }
  for (TensorView* output :
       {out_x,
        out_topk_idx,
        out_src_idx,
        out_src_rank,
        out_n_tokens_to_rank,
        out_n_tokens_from_rank}) {
    addOutput(output);
  }
  addDataAttribute(num_experts);
  addDataAttribute(backend);
  validate();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MoeDispatch)

// Renders this node on a single line (no trailing newline), prefixed by the
// indentation that indent() emits for indent_size. Only the attributes, the
// three inputs and the out_x output are printed.
std::string MoeDispatch::toInlineString(int indent_size) const {
  std::stringstream buffer;
  auto& os = indent(buffer, indent_size);
  os << "Dispatch " << name() << " (num_experts=" << numExperts() << ", ";
  os << "backend=" << backend() << ", ";
  os << "in=" << inX() << ", ";
  os << "topk_idx=" << inTopkIdx() << ", ";
  os << "is_token_in_rank=" << inIsTokenInRank() << ", ";
  os << "out=" << outX() << ")";
  return buffer.str();
}

// Same text as toInlineString(), terminated by a newline.
std::string MoeDispatch::toString(int indent_size) const {
  std::string line = toInlineString(indent_size);
  line += "\n";
  return line;
}

// Checks the invariants of a MoeDispatch node and raises via NVF_CHECK on
// the first violation:
//   - num_experts is strictly positive;
//   - in_x and topk_idx are TensorViews;
//   - topk_idx and every index/count output have an integral dtype;
//   - is_token_in_rank has dtype Bool.
void MoeDispatch::validate() {
  // True when the operand carries a dtype and that dtype is integral.
  auto has_integral_dtype = [](auto* operand) {
    return operand->getDataType().has_value() &&
        isIntegralType(*operand->getDataType());
  };

  NVF_CHECK(numExperts() > 0, "num_experts must be positive.");
  NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
  NVF_CHECK(inTopkIdx()->isA<TensorView>(), "topk_idx must be a TensorView.");
  NVF_CHECK(has_integral_dtype(inTopkIdx()), "topk_idx must be integral.");

  const auto mask_dtype = inIsTokenInRank()->getDataType();
  NVF_CHECK(
      mask_dtype.has_value() && mask_dtype == DataType::Bool,
      "is_token_in_rank must be Bool.");

  NVF_CHECK(
      has_integral_dtype(outTopkIdx()), "out_topk_idx must be integral.");
  NVF_CHECK(has_integral_dtype(outSrcIdx()), "out_src_idx must be integral.");
  NVF_CHECK(
      has_integral_dtype(outSrcRank()), "out_src_rank must be integral.");
  NVF_CHECK(
      has_integral_dtype(outTokensToRank()),
      "out_n_tokens_to_rank must be integral.");
  NVF_CHECK(
      has_integral_dtype(outTokensFromRank()),
      "out_n_tokens_from_rank must be integral.");
}

// Builds a MoeCombine expression.
//
// Registers five inputs (in_x, in_src_idx, in_src_rank, in_n_tokens_to_rank,
// in_n_tokens_from_rank), one output (out_x) and one data attribute
// (backend), then runs validate().
//
// NOTE(review): the registration order below is kept exactly as before; the
// accessors presumably look operands up by position — confirm in the header
// before reordering anything.
MoeCombine::MoeCombine(
    IrBuilderPasskey passkey,
    TensorView* out_x,
    TensorView* in_x,
    TensorView* in_src_idx,
    TensorView* in_src_rank,
    TensorView* in_n_tokens_to_rank,
    TensorView* in_n_tokens_from_rank,
    CommunicatorBackend backend)
    : Expr(passkey) {
  for (TensorView* input :
       {in_x,
        in_src_idx,
        in_src_rank,
        in_n_tokens_to_rank,
        in_n_tokens_from_rank}) {
    addInput(input);
  }
  addOutput(out_x);
  addDataAttribute(backend);
  validate();
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MoeCombine)

// Renders this node on a single line (no trailing newline), prefixed by the
// indentation that indent() emits for indent_size. Prints the backend, the
// in_x / src_idx / src_rank inputs and the output; the two token-count
// inputs are not part of the printed form.
std::string MoeCombine::toInlineString(int indent_size) const {
  std::stringstream buffer;
  auto& os = indent(buffer, indent_size);
  os << "Combine " << name() << " (backend=" << backend() << ", ";
  os << "in=" << inX() << ", ";
  os << "src_idx=" << inSrcIdx() << ", ";
  os << "src_rank=" << inSrcRank() << ", ";
  os << "out=" << outX() << ")";
  return buffer.str();
}

// Same text as toInlineString(), terminated by a newline.
std::string MoeCombine::toString(int indent_size) const {
  std::string line = toInlineString(indent_size);
  line += "\n";
  return line;
}

// Checks the invariants of a MoeCombine node and raises via NVF_CHECK on the
// first violation: in_x must be a TensorView, and the src_idx, src_rank and
// both token-count inputs must have an integral dtype.
void MoeCombine::validate() {
  // True when the operand carries a dtype and that dtype is integral.
  auto has_integral_dtype = [](auto* operand) {
    return operand->getDataType().has_value() &&
        isIntegralType(*operand->getDataType());
  };

  NVF_CHECK(inX()->isA<TensorView>(), "in_x must be a TensorView.");
  NVF_CHECK(has_integral_dtype(inSrcIdx()), "in_src_idx must be integral.");
  NVF_CHECK(has_integral_dtype(inSrcRank()), "in_src_rank must be integral.");
  NVF_CHECK(
      has_integral_dtype(inTokensToRank()),
      "in_n_tokens_to_rank must be integral.");
  NVF_CHECK(
      has_integral_dtype(inTokensFromRank()),
      "in_n_tokens_from_rank must be integral.");
}

namespace {
c10::intrusive_ptr<c10d::Work> postBroadcast(
Communication* communication,
Expand Down
Loading