Skip to content

Commit 43e586f

Browse files
authored
Add Multi-Gpu Support to Direct Python Bindings (#4689)
This PR add MultiGpu Support to Direct Python Bindings. PR Stack: - #4689 **<<< This PR.** - #4697 - #4698 - #4704 - #4701 cc: @kshitij12345
1 parent 5927e55 commit 43e586f

15 files changed

Lines changed: 413 additions & 81 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,14 @@ endif()
335335

336336
if(BUILD_PYTHON)
337337
list(APPEND NVFUSER_SRCS
338-
${NVFUSER_PYTHON_BINDINGS}/distributed_tensor.cpp
339338
${NVFUSER_PYTHON_BINDINGS}/fusion_cache.cpp
340339
${NVFUSER_PYTHON_BINDINGS}/fusion_definition.cpp
341340
${NVFUSER_PYTHON_BINDINGS}/fusion_state.cpp
342341
${NVFUSER_PYTHON_BINDINGS}/segmentation.cpp
343342
${NVFUSER_PYTHON_BINDINGS}/translation.cpp
344343
${NVFUSER_PYTHON_BINDINGS}/translation_utils.cpp
345344
${NVFUSER_SRCS_DIR}/serde/fusion_record.cpp
345+
${NVFUSER_PYTHON_COMMON}/distributed_tensor.cpp
346346
${NVFUSER_PYTHON_COMMON}/python_utils.cpp
347347
${NVFUSER_PYTHON_COMMON}/translation_names.cpp
348348
)
@@ -608,6 +608,7 @@ if(BUILD_PYTHON)
608608
${NVFUSER_PYTHON_DIRECT_BINDINGS}/bindings.cpp
609609
${NVFUSER_PYTHON_DIRECT_BINDINGS}/enum.cpp
610610
${NVFUSER_PYTHON_DIRECT_BINDINGS}/ir.cpp
611+
${NVFUSER_PYTHON_DIRECT_BINDINGS}/multidevice.cpp
611612
${NVFUSER_PYTHON_DIRECT_BINDINGS}/ops.cpp
612613
${NVFUSER_PYTHON_DIRECT_BINDINGS}/runtime.cpp
613614
${NVFUSER_PYTHON_DIRECT_BINDINGS}/direct_utils.cpp

csrc/multidevice/executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <multidevice/communication.h>
1818
#include <multidevice/communicator.h>
1919
#include <multidevice/multidevice.h>
20+
#include <runtime/fusion_kernel_runtime.h>
2021

2122
namespace nvfuser {
2223

python/nvfuser_direct/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,42 @@
2121
from ._C_DIRECT import * # noqa: F401,F403
2222

2323

24+
def execute_with_dtensors(fd, in_dtensors):
25+
"""
26+
Execute a fusion on a list of DTensor inputs.
27+
28+
Parameters
29+
----------
30+
fd : FusionDefinition
31+
The fusion definition to execute
32+
in_dtensors : list of DTensor
33+
The list of DTensor inputs to the fusion
34+
35+
Returns
36+
-------
37+
list of DTensor
38+
The list of DTensor outputs from the fusion
39+
"""
40+
import torch.distributed as dist
41+
from torch.distributed.tensor import DTensor
42+
from torch.distributed.tensor.placement_types import Placement, Shard, Replicate
43+
44+
inputs = [in_dtensor.to_local() for in_dtensor in in_dtensors]
45+
out_tensors = self.execute(inputs, auto_schedule=True)
46+
out_shardings = self.fec.get_output_shardings()
47+
assert len(out_tensors) == len(out_shardings)
48+
49+
out_dtensors: list[DTensor] = []
50+
for out_tensor, out_sharding in zip(out_tensors, out_shardings):
51+
mesh = dist.device_mesh.init_device_mesh("cuda", (out_sharding.mesh.size,))
52+
placements: list[Placement] = []
53+
for parallel_type in [_C_DIRECT.ParallelType.mesh_x]:
54+
axis: int = out_sharding.axis_sharded_on(parallel_type)
55+
placements.append(Replicate() if axis == -1 else Shard(axis))
56+
out_dtensors.append(DTensor.from_local(out_tensor, mesh, placements))
57+
return out_dtensors
58+
59+
2460
class FusionDefinition:
2561
"""
2662
A class for defining and executing fused operations in nvFuser.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// clang-format off
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
4+
* All rights reserved.
5+
* SPDX-License-Identifier: BSD-3-Clause
6+
*/
7+
// clang-format on
8+
9+
#include <distributed_tensor.h>
10+
#include <exceptions.h>
11+
#include <ir/interface_nodes.h>
12+
#include <type.h>
13+
#include <utils.h>
14+
15+
namespace nvfuser {
16+
17+
void Sharding::setAxisIsShardedOn(
18+
const int64_t axis,
19+
const ParallelType parallel_type) {
20+
NVF_CHECK(isParallelTypeDeviceDim(parallel_type));
21+
NVF_CHECK(mesh_.size() > 0, "Cannot shard a non-distributed tensor.");
22+
const auto i = axis_sharded_on_.find(parallel_type);
23+
NVF_CHECK(
24+
i == axis_sharded_on_.end(),
25+
"Parallel type ",
26+
parallel_type,
27+
" was already used to shard axis ",
28+
i->second);
29+
axis_sharded_on_[parallel_type] = axis;
30+
}
31+
32+
int64_t Sharding::axisShardedOn(const ParallelType parallel_type) const {
33+
return getOrDefault(axis_sharded_on_, parallel_type, -1L);
34+
}
35+
36+
std::vector<Sharding> getOutputShardings(Fusion* fusion) {
37+
std::vector<TensorView*> all_tvs = fusion->allTvs();
38+
if (std::none_of(
39+
all_tvs.begin(),
40+
all_tvs.end(),
41+
std::mem_fn(&TensorView::hasDeviceMesh))) {
42+
return {};
43+
}
44+
45+
std::vector<Sharding> output_shardings;
46+
output_shardings.reserve(fusion->outputs().size());
47+
for (Val* out_val : fusion->outputs()) {
48+
if (auto* out_tv = dynamic_cast<TensorView*>(out_val)) {
49+
if (fusion->getOutputAlias(out_tv).hide_output) {
50+
continue;
51+
}
52+
const DeviceMesh& mesh = out_tv->getDeviceMesh();
53+
Sharding& output_sharding = output_shardings.emplace_back(mesh);
54+
if (mesh.size() > 0) {
55+
for (const ParallelType parallel_type : kParallelTypeDIDs) {
56+
if (const auto axis = getShardedLogicalAxis(out_tv, parallel_type);
57+
axis != -1) {
58+
output_sharding.setAxisIsShardedOn(axis, parallel_type);
59+
}
60+
}
61+
}
62+
} else {
63+
output_shardings.emplace_back(DeviceMesh());
64+
}
65+
}
66+
67+
return output_shardings;
68+
}
69+
70+
} // namespace nvfuser

python/python_frontend/distributed_tensor.h renamed to python/python_common/distributed_tensor.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010

1111
#include <ATen/core/TensorBody.h>
1212

13+
#include <fusion.h>
1314
#include <multidevice/device_mesh.h>
15+
#include <multidevice/utils.h>
1416
#include <type.h>
1517

16-
namespace nvfuser::python_frontend {
18+
namespace nvfuser {
1719

1820
class Sharding {
1921
public:
@@ -36,4 +38,9 @@ class Sharding {
3638
std::unordered_map<ParallelType, int64_t> axis_sharded_on_;
3739
};
3840

39-
} // namespace nvfuser::python_frontend
41+
// Returns the output shardings of the given fusion. As a short cut, if none of
42+
// the outputs have a device mesh, returns an empty vector indicating single-GPU
43+
// execution.
44+
std::vector<Sharding> getOutputShardings(Fusion* fusion);
45+
46+
} // namespace nvfuser

python/python_direct/bindings.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// clang-format on
88

99
#include <bindings.h>
10+
#include <multidevice/communicator.h>
1011

1112
namespace nvfuser::python {
1213

@@ -16,7 +17,11 @@ void initNvFuserPythonBindings(PyObject* module) {
1617
bindFusionIr(nvfuser);
1718
bindRuntime(nvfuser);
1819
bindOperations(nvfuser);
20+
bindMultiDevice(nvfuser);
1921
nvfuser.def("translate_fusion", &translateFusion);
22+
23+
auto cleanup = []() -> void { Communicator::getInstance().cleanup(); };
24+
nvfuser.add_object("_cleanup", py::capsule(cleanup));
2025
}
2126

2227
} // namespace nvfuser::python

python/python_direct/bindings.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ void bindRuntime(py::module& nvfuser);
2727
// Add bindings for CPP Fusion Operations
2828
void bindOperations(py::module& nvfuser);
2929

30+
// Add bindings for MultiDevice features
31+
void bindMultiDevice(py::module& nvfuser);
32+
3033
// Translate a CPP Fusion to a bindings python function
3134
std::string translateFusion(Fusion* f);
3235

python/python_direct/enum.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,26 @@ void bindEnums(py::module& nvfuser) {
3333
.value("ComplexFloat", DataType::ComplexFloat)
3434
.value("ComplexDouble", DataType::ComplexDouble)
3535
.value("Null", DataType::Null);
36+
37+
py::enum_<ParallelType>(nvfuser, "ParallelType")
38+
.value("mesh_x", ParallelType::DIDx)
39+
.value("grid_x", ParallelType::BIDx)
40+
.value("grid_y", ParallelType::BIDy)
41+
.value("grid_z", ParallelType::BIDz)
42+
.value("block_x", ParallelType::TIDx)
43+
.value("block_y", ParallelType::TIDy)
44+
.value("block_z", ParallelType::TIDz)
45+
.value("mma", ParallelType::Mma)
46+
.value("serial", ParallelType::Serial)
47+
.value("tma", ParallelType::Bulk)
48+
.value("unroll", ParallelType::Unroll)
49+
.value("unswitch", ParallelType::Unswitch)
50+
.value("vectorize", ParallelType::Vectorize)
51+
.value("stream", ParallelType::Stream);
52+
53+
py::enum_<CommunicatorBackend>(nvfuser, "CommunicatorBackend")
54+
.value("nccl", CommunicatorBackend::kNccl)
55+
.value("ucc", CommunicatorBackend::kUcc);
3656
}
3757

3858
} // namespace nvfuser::python

python/python_direct/ir.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,23 @@ Returns
6767
-------
6868
Val
6969
The extent of this domain.
70+
)")
71+
.def(
72+
"parallelize",
73+
&IterDomain::parallelize,
74+
py::arg("parallel_type"),
75+
R"(
76+
Set the parallel type of this domain.
77+
78+
Parameters
79+
----------
80+
parallel_type : ParallelType
81+
The type of parallelization to apply (e.g., BIDx, TIDx, etc.).
82+
83+
Notes
84+
-----
85+
This is a key function used in scheduling to specify how the domain should be parallelized
86+
across CUDA threads and blocks.
7087
)");
7188

7289
// TensorDomain
@@ -111,6 +128,72 @@ TensorDomain
111128
- Logical domain (The original dimensions. It may contain rFactor iterDomains.)
112129
- Allocation domain (How the memory is allocated for the tensor?)
113130
- Loop domain (The for-loop structure for the tensor.)
131+
)")
132+
.def(
133+
"get_loop_domain",
134+
&TensorView::getLoopDomain,
135+
R"(
136+
Get the loop domain of this tensor.
137+
138+
Returns
139+
-------
140+
list of IterDomain
141+
The loop iteration domains.
142+
)")
143+
.def(
144+
"split",
145+
static_cast<TensorView* (TensorView::*)(int64_t, int64_t, bool)>(
146+
&TensorView::split),
147+
py::arg("axis"),
148+
py::arg("factor"),
149+
py::arg("inner_split") = true,
150+
py::return_value_policy::reference,
151+
R"(
152+
Split an axis into two axes.
153+
154+
Parameters
155+
----------
156+
axis : int
157+
The axis to split.
158+
factor : int
159+
The factor to split by.
160+
inner_split : bool, optional
161+
If True, the factor determines the size of the inner domain.
162+
If False, the factor determines the size of the outer domain.
163+
Default is True.
164+
165+
Returns
166+
-------
167+
TensorView
168+
A TensorView with the split axes in its loop domain.
169+
)")
170+
.def(
171+
"set_allocation_domain",
172+
static_cast<void (TensorView::*)(std::vector<IterDomain*>, bool)>(
173+
&TensorView::setAllocationDomain),
174+
py::arg("new_allocation_domain"),
175+
py::arg("new_contiguity"),
176+
R"(
177+
Set the allocation domain of this tensor.
178+
179+
Parameters
180+
----------
181+
new_allocation_domain : list of IterDomain
182+
The new allocation iteration domains.
183+
new_contiguity : bool
184+
The new contiguity flag.
185+
)")
186+
.def(
187+
"set_device_mesh",
188+
&TensorView::setDeviceMesh,
189+
py::arg("mesh"),
190+
R"(
191+
Set the device mesh of this tensor.
192+
193+
Parameters
194+
----------
195+
mesh : DeviceMesh
196+
The device mesh to set.
114197
)")
115198
.def(
116199
"axis",

0 commit comments

Comments
 (0)