Skip to content

Commit 9b3875a

Browse files
authored
Migrate Tutorial.TMABankConflictFreeTranspose to direct bindings (#5249)
1 parent 9d70894 commit 9b3875a

4 files changed

Lines changed: 165 additions & 1 deletion

File tree

csrc/scheduler/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ void transformPropagateToAllFrom(TensorView* from_tv, int64_t pos);
565565
//!
566566
//! There are currently three modes of propagation: forward, backward and
567567
//! both-way, see comment on the interface functions for details.
568-
struct BoundedDirectionalTransformPropagator {
568+
struct NVF_API BoundedDirectionalTransformPropagator {
569569
//! Custom option container for configuring
570570
//! the transform propagation actions.
571571
//! All option values default to false unless

python/python_direct/ir.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,29 @@ Returns
403403
-------
404404
TensorView
405405
A TensorView with the reordered axes in its loop domain.
406+
)")
407+
.def(
408+
"swizzle",
409+
[](TensorView* self, int64_t x, int64_t y) {
410+
return self->swizzle(SwizzleType::XOR, x, y);
411+
},
412+
py::return_value_policy::reference,
413+
py::arg("x"),
414+
py::arg("y"),
415+
R"(
416+
Swizzle the axes of this tensor.
417+
418+
Parameters
419+
----------
420+
x : int
421+
The x axis to swizzle.
422+
y : int
423+
The y axis to swizzle.
424+
425+
Returns
426+
-------
427+
TensorView
428+
A TensorView with the swizzled axes in its loop domain.
406429
)")
407430
.def(
408431
"rfactor",

python/python_direct/schedule.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,43 @@ namespace nvfuser::python {
1515
namespace {
1616

1717
void bindTensorviewScheduleOps(py::module_& schedule) {
18+
schedule.def(
19+
"bounded_transform_backward",
20+
[](TensorView* from,
21+
int64_t pos,
22+
std::vector<TensorView*> to,
23+
bool propagate_parallel_type) {
24+
using TransformPropagator =
25+
scheduler_utils::BoundedDirectionalTransformPropagator;
26+
TransformPropagator::Options options;
27+
if (propagate_parallel_type) {
28+
options.propagateParallelType();
29+
}
30+
TransformPropagator::backward(from, pos, to, options);
31+
},
32+
R"(
33+
Propagate scheduler transformations from a reference TensorView to other TensorViews.
34+
35+
Parameters
36+
----------
37+
from : TensorView
38+
The reference TensorView whose transformations will be propagated.
39+
pos : int
40+
The position up to which dimensions should be selected. -1 means all dimensions.
41+
to : List[TensorView]
42+
List of TensorViews to propagate transformations to.
43+
propagate_parallel_type : bool
44+
Whether to propagate parallel type.
45+
46+
Returns
47+
-------
48+
None
49+
)",
50+
py::arg("from"),
51+
py::arg("pos"),
52+
py::arg("to"),
53+
py::arg("propagate_parallel_type") = false);
54+
1855
schedule.def(
1956
"transform_like",
2057
[](TensorView* reference_tv,

tests/python/direct/test_tutorial.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,3 +1330,107 @@ def test_tutorial_pointwise_broadcast_tma(nvfuser_direct_test):
13301330
ke.compile(fd.fusion, [t0, t1], compile_params=index32bit)
13311331
outputs = ke.run([t0, t1])
13321332
assert outputs[0].equal(t2)
1333+
1334+
1335+
@pytest.mark.skipif(
1336+
is_pre_hopper(), reason="Only supported on Hopper and newer devices."
1337+
)
1338+
def test_tutorial_tma_bank_conflict_free_transpose(nvfuser_direct_test):
1339+
with FusionDefinition() as fd:
1340+
input = fd.define_tensor(shape=[-1, -1], contiguity=[True, True])
1341+
output = fd.ops.permute(input, [1, 0])
1342+
fd.add_output(output)
1343+
1344+
# Change the fusion to input->smem->register->smem->output where the
1345+
# smem->register part does the transpose
1346+
input_smem_cache = input.cache_after(LoadStoreOpType.tma)
1347+
input_smem_cache.set_memory_type(MemoryType.shared)
1348+
1349+
output_smem_cache = output.cache_before(LoadStoreOpType.tma)
1350+
output_smem_cache.set_memory_type(MemoryType.shared)
1351+
1352+
output_reg_cache = output_smem_cache.cache_before()
1353+
1354+
# Create 32x32 tile. Each CTA has one tile, and the entire tile will be
1355+
# loaded to shared memory by TMA, and stored back to global memory by TMA.
1356+
1357+
# [I1, I0]
1358+
output.split(1, 32)
1359+
output.split(0, 32)
1360+
# [I1, 32', I0, 32]
1361+
output.reorder({0: 1, 1: 2, 2: 0})
1362+
output.merge(0, 1)
1363+
# [I0/32 * I1/32', 32', 32]
1364+
output.axis(0).parallelize(ParallelType.grid_x)
1365+
# [BIDx, 32', 32]
1366+
1367+
fd.sched.bounded_transform_backward(
1368+
output, -1, [input], propagate_parallel_type=True
1369+
)
1370+
1371+
# For fusion output, we just use TMA to store the entire tile back to global
1372+
# memory. There is no need to further schedule the output tensor.
1373+
output.axis(1).parallelize(ParallelType.tma)
1374+
output.axis(2).parallelize(ParallelType.tma)
1375+
# [BIDx, Bulk, Bulk]
1376+
1377+
# output_smem_cache and output_reg_cache are scheduled in the same way.
1378+
# We use each warp to load one column of input_smem_cache. We vectorize
1379+
# the load to 16 bytes, and use 8 warps to load all these 8 columns. Then,
1380+
# when we write to output_smem_cache, we unroll the write. Each warp writes
1381+
# one row in output_smem_cache in each iteration, so there is no bank
1382+
# conflict.
1383+
1384+
# [BIDx, 32', 32]
1385+
output_smem_cache.set_allocation_domain(
1386+
output_smem_cache.get_loop_domain(), new_contiguity=True
1387+
)
1388+
output_smem_cache.split(1, 4)
1389+
# [BIDx, 8', 4', 32]
1390+
1391+
fd.sched.bounded_transform_backward(output_smem_cache, -1, [input])
1392+
1393+
output_smem_cache.merge(1, 3)
1394+
# [BIDx, 256, 4']
1395+
output_smem_cache.axis(1).parallelize(ParallelType.block_x)
1396+
1397+
fd.sched.bounded_transform_backward(
1398+
output_smem_cache, -1, [input_smem_cache], propagate_parallel_type=True
1399+
)
1400+
1401+
output_smem_cache.axis(2).parallelize(ParallelType.unroll)
1402+
output_reg_cache.axis(2).parallelize(ParallelType.vectorize)
1403+
output_reg_cache.set_allocation_domain(
1404+
output_reg_cache.get_loop_domain(), new_contiguity=True
1405+
)
1406+
1407+
# Schedule the memory format for 128 byte swizzle
1408+
# [BIDx, 8', 4', 32]
1409+
input_smem_cache.reorder({3: 1, 1: 2, 2: 3})
1410+
# [BIDx, 32, 8', 4']
1411+
input_smem_cache.split(1, 8)
1412+
# [BIDx, 4, 8, 8', 4']
1413+
input_smem_cache.swizzle(2, 3)
1414+
# [BIDx, 4, 8, 8', 4']
1415+
input_smem_cache.set_allocation_domain(
1416+
input_smem_cache.get_loop_domain(), new_contiguity=True
1417+
)
1418+
1419+
input_smem_cache.axis(1).parallelize(ParallelType.tma)
1420+
input_smem_cache.axis(2).parallelize(ParallelType.tma)
1421+
input_smem_cache.axis(3).parallelize(ParallelType.tma)
1422+
input_smem_cache.axis(4).parallelize(ParallelType.tma)
1423+
# [BIDx, Bulk, Bulk, Bulk, Bulk]
1424+
1425+
if verbose_:
1426+
print(fd.fusion.print_math())
1427+
print(fd.fusion.print_kernel())
1428+
1429+
index32bit = CompileParams(
1430+
index_type=DataType.Int32, maxrregcount=255, enable_magic_zero=False
1431+
)
1432+
t0 = torch.randn(10000, 10000, dtype=torch.float, device="cuda:0")
1433+
ke = KernelExecutor()
1434+
ke.compile(fd.fusion, [t0], compile_params=index32bit)
1435+
outputs = ke.run([t0])
1436+
assert outputs[0].equal(t0.t())

0 commit comments

Comments
 (0)