Skip to content

Commit 618673a

Browse files
authored
Migrate Tutorial.VectorizeStorePointwiseTMA and Tutorial.PointwiseBroadcastTMA to direct bindings (#5248)
* Add `parallelize_like` and `inline_most` to schedule API PR stack * #5247 * #5248 **<< This PR** * #5249
1 parent 33337e9 commit 618673a

2 files changed

Lines changed: 256 additions & 0 deletions

File tree

python/python_direct/schedule.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
*/
77
// clang-format on
88
#include <bindings.h>
9+
#include <scheduler/tools/inlining.h>
10+
#include <scheduler/utils.h>
911
#include <transform_replay.h>
1012

1113
namespace nvfuser::python {
@@ -49,6 +51,71 @@ void bindTensorviewScheduleOps(py::module_& schedule) {
4951
)",
5052
py::arg("reference_tv"),
5153
py::arg("selected_tensors") = std::vector<TensorView*>());
54+
55+
schedule.def(
56+
"parallelize_like",
57+
[](TensorView* reference_tv,
58+
int64_t pos,
59+
const std::vector<TensorView*>& selected_tensors,
60+
const std::unordered_set<ParallelType>& selected_parallel_types,
61+
bool propagate_padding) {
62+
scheduler_utils::parallelizeAllLike(
63+
reference_tv,
64+
pos,
65+
selected_tensors,
66+
selected_parallel_types,
67+
propagate_padding);
68+
},
69+
R"(
70+
Propagate the parallelization from the selected dimensions of the
71+
reference tensor to their corresponding dimensions in all selected
72+
tensors in the DAG.
73+
74+
Parameters
75+
----------
76+
reference_tv : TensorView
77+
The reference TensorView whose parallelization will be propagated.
78+
pos : int, optional
79+
The position up to which dimensions should be selected. -1 means all dimensions.
80+
selected_tensors : List[TensorView], optional
81+
List of TensorViews to propagate parallelization to. If empty, propagates to all TensorViews.
82+
selected_parallel_types : Set[ParallelType], optional
83+
Set of parallel types to propagate. If empty, propagates all parallel types.
84+
propagate_padding : bool, optional
85+
Whether to propagate padding (default: True).
86+
87+
Returns
88+
-------
89+
None
90+
)",
91+
py::arg("reference_tv"),
92+
py::arg("pos") = -1,
93+
py::arg("selected_tensors") = std::vector<TensorView*>(),
94+
py::arg("selected_parallel_types") = std::unordered_set<ParallelType>(),
95+
py::arg("propagate_padding") = true);
96+
97+
schedule.def(
98+
"inline_most",
99+
[](const std::vector<TensorView*>& selected_tensors) {
100+
if (selected_tensors.empty()) {
101+
inlineMost();
102+
} else {
103+
inlineMost(selected_tensors);
104+
}
105+
},
106+
R"(
107+
Inline operations to the right most allowed position for the selected tensors.
108+
109+
Parameters
110+
----------
111+
selected_tensors : List[TensorView], optional
112+
List of TensorViews to inline. If empty, inlines all operations.
113+
114+
Returns
115+
-------
116+
None
117+
)",
118+
py::arg("selected_tensors") = std::vector<TensorView*>());
52119
}
53120

54121
} // namespace

tests/python/direct/test_tutorial.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,3 +1141,192 @@ def test_tutorial_basic_tma_example6(nvfuser_direct_test):
11411141
ke.compile(fd.fusion, [t0], compile_params=index32bit)
11421142
outputs = ke.run([t0])
11431143
assert outputs[0].equal(t0)
1144+
1145+
1146+
@pytest.mark.skipif(
1147+
is_pre_hopper(), reason="Only supported on Hopper and newer devices."
1148+
)
1149+
def test_tutorial_vectorize_store_pointwise_tma(nvfuser_direct_test):
1150+
with FusionDefinition() as fd:
1151+
tv0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True])
1152+
tv1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True])
1153+
tv2 = fd.ops.add(tv0, tv1)
1154+
fd.add_output(tv2)
1155+
1156+
# Create cache_tvs
1157+
tv0a = tv0.cache_after(LoadStoreOpType.tma)
1158+
tv1a = tv1.cache_after(LoadStoreOpType.tma)
1159+
tv2b = tv2.cache_before()
1160+
1161+
tv0a.set_memory_type(MemoryType.shared)
1162+
tv1a.set_memory_type(MemoryType.shared)
1163+
1164+
reference_tv = tv2
1165+
1166+
# Step 1: Create tma domain
1167+
# Use the root domain as TMA domain
1168+
# root domain: [I0, I1]
1169+
1170+
num_threads = 128
1171+
vectorization = 2
1172+
tma_tile = num_threads * vectorization
1173+
num_stages = 4
1174+
num_ctas_for_hopper = 132
1175+
1176+
# Step 2: Create Box
1177+
# After TMA domain creation
1178+
# split: [I0, I3, 256]
1179+
reference_tv.split(-1, tma_tile)
1180+
# split: [I2, 4, I3, 256]
1181+
reference_tv.split(0, num_stages)
1182+
1183+
# Step 3: Create Tile
1184+
# Do nothing here because box == tile
1185+
1186+
# Step 4: Schedule Shared Memory Tensor
1187+
# split: [I2, 4, I3, 128, 2]
1188+
reference_tv.split(-1, vectorization)
1189+
# split: [I4, 132, 4, I3, 128, 2]
1190+
reference_tv.split(0, num_ctas_for_hopper)
1191+
# reorder: [I4, 132, I3, 4, 128, 2]
1192+
reference_tv.reorder({3: 2, 2: 3})
1193+
1194+
# Transform Operations between cache operations and output reference
1195+
fd.sched.transform_like(reference_tv)
1196+
1197+
# Propagate common parallel dimensions
1198+
reference_tv.axis(1).parallelize(ParallelType.grid_x)
1199+
fd.sched.parallelize_like(reference_tv)
1200+
1201+
tv2b.axis(-2).parallelize(ParallelType.block_x)
1202+
1203+
# Vectorization for writing results to gmem
1204+
reference_tv.axis(-3).parallelize(ParallelType.unroll)
1205+
reference_tv.axis(-2).parallelize(ParallelType.block_x)
1206+
reference_tv.axis(-1).parallelize(ParallelType.vectorize)
1207+
1208+
# Apply bulk type to TMA tensors
1209+
tv0a.axis(-1).parallelize(ParallelType.tma)
1210+
tv0a.axis(-2).parallelize(ParallelType.tma)
1211+
tv0a.axis(-3).parallelize(ParallelType.tma)
1212+
1213+
tv1a.axis(-1).parallelize(ParallelType.tma)
1214+
tv1a.axis(-2).parallelize(ParallelType.tma)
1215+
tv1a.axis(-3).parallelize(ParallelType.tma)
1216+
1217+
# ComputeAt
1218+
fd.sched.inline_most()
1219+
1220+
if verbose_:
1221+
print(fd.fusion.print_math())
1222+
print(fd.fusion.print_kernel())
1223+
1224+
dim0 = 16384
1225+
dim1 = 16384
1226+
1227+
# Compile with KernelExecutor directly to avoid scheduling
1228+
index32bit = CompileParams(
1229+
index_type=DataType.Int32, maxrregcount=255, enable_magic_zero=False
1230+
)
1231+
t0 = torch.randn(dim0, dim1, dtype=torch.float, device="cuda:0")
1232+
t1 = torch.randn(dim0, dim1, dtype=torch.float, device="cuda:0")
1233+
t2 = t0 + t1
1234+
ke = KernelExecutor()
1235+
ke.compile(fd.fusion, [t0, t1], compile_params=index32bit)
1236+
outputs = ke.run([t0, t1])
1237+
assert outputs[0].equal(t2)
1238+
1239+
1240+
@pytest.mark.skipif(
1241+
is_pre_hopper(), reason="Only supported on Hopper and newer devices."
1242+
)
1243+
def test_tutorial_pointwise_broadcast_tma(nvfuser_direct_test):
1244+
with FusionDefinition() as fd:
1245+
tv0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True])
1246+
tv1 = fd.define_tensor(
1247+
shape=[-1, -1, -1, -1], contiguity=[True, False, True, True]
1248+
)
1249+
tv2 = fd.ops.broadcast(tv0, [True, False, False, False])
1250+
tv3 = fd.ops.add(tv2, tv1)
1251+
fd.add_output(tv3)
1252+
1253+
# Create cache_tvs
1254+
tv0a = tv0.cache_after(LoadStoreOpType.tma)
1255+
tv1a = tv1.cache_after(LoadStoreOpType.tma)
1256+
tv3b = tv3.cache_before(LoadStoreOpType.tma)
1257+
1258+
tv0a.set_memory_type(MemoryType.shared)
1259+
tv1a.set_memory_type(MemoryType.shared)
1260+
tv3b.set_memory_type(MemoryType.shared)
1261+
1262+
reference_tv = tv3
1263+
1264+
# Step 1: Create tma domain
1265+
# root domain: [I0, I1, I2, I3]
1266+
# TMA domain: [I0, I1, I4]
1267+
reference_tv.merge(-2, -1)
1268+
1269+
# Step 2: Define TMA Box
1270+
# split: [I0, I1, I5, 256]
1271+
reference_tv.split(-1, 256)
1272+
1273+
# Step 3: Define Tile
1274+
# Do nothing here because tile == box.
1275+
1276+
# Step 4: Schedule Shared Memory Tensor
1277+
# merge: [I10, I5, 256]
1278+
reference_tv.merge(0, 1)
1279+
# split: [I10, I7, 4, 256]
1280+
reference_tv.split(-2, 4)
1281+
# merge: [I11, 4, 256]
1282+
reference_tv.merge(0, 1)
1283+
1284+
# Transform Operations between cache operations and output reference
1285+
fd.sched.transform_like(reference_tv)
1286+
1287+
# Define Parallelization Schema
1288+
# Intermediate Tensors
1289+
tv3b.axis(0).parallelize(ParallelType.grid_x)
1290+
tv3b.axis(1).parallelize(ParallelType.unroll)
1291+
tv3b.axis(2).parallelize(ParallelType.block_x)
1292+
1293+
tv2.axis(0).parallelize(ParallelType.grid_x)
1294+
tv2.axis(1).parallelize(ParallelType.unroll)
1295+
tv2.axis(2).parallelize(ParallelType.block_x)
1296+
1297+
# TMA Tensors
1298+
tv1a.axis(0).parallelize(ParallelType.grid_x)
1299+
tv1a.axis(1).parallelize(ParallelType.block_x)
1300+
tv1a.axis(2).parallelize(ParallelType.tma)
1301+
1302+
tv0a.axis(0).parallelize(ParallelType.grid_x)
1303+
tv0a.axis(1).parallelize(ParallelType.block_x)
1304+
tv0a.axis(2).parallelize(ParallelType.tma)
1305+
1306+
tv3.axis(0).parallelize(ParallelType.grid_x)
1307+
tv3.axis(1).parallelize(ParallelType.block_x)
1308+
tv3.axis(2).parallelize(ParallelType.tma)
1309+
1310+
# ComputeAt
1311+
fd.sched.inline_most()
1312+
1313+
if verbose_:
1314+
print(fd.fusion.print_math())
1315+
print(fd.fusion.print_kernel())
1316+
1317+
dim0 = 32
1318+
dim1 = 2
1319+
dim2 = 4
1320+
dim3 = 256
1321+
1322+
# Compile with KernelExecutor directly to avoid scheduling
1323+
index32bit = CompileParams(
1324+
index_type=DataType.Int32, maxrregcount=255, enable_magic_zero=False
1325+
)
1326+
t0 = torch.randn(dim1, dim2, dim3, dtype=torch.float, device="cuda:0")
1327+
t1 = torch.randn(dim0, dim1, dim2, dim3, dtype=torch.float, device="cuda:0")
1328+
t2 = t0 + t1
1329+
ke = KernelExecutor()
1330+
ke.compile(fd.fusion, [t0, t1], compile_params=index32bit)
1331+
outputs = ke.run([t0, t1])
1332+
assert outputs[0].equal(t2)

0 commit comments

Comments
 (0)