@@ -1330,3 +1330,107 @@ def test_tutorial_pointwise_broadcast_tma(nvfuser_direct_test):
13301330 ke .compile (fd .fusion , [t0 , t1 ], compile_params = index32bit )
13311331 outputs = ke .run ([t0 , t1 ])
13321332 assert outputs [0 ].equal (t2 )
1333+
1334+
@pytest.mark.skipif(
    is_pre_hopper(), reason="Only supported on Hopper and newer devices."
)
def test_tutorial_tma_bank_conflict_free_transpose(nvfuser_direct_test):
    """Manually schedule a 2D transpose staged through shared memory with TMA.

    The fusion permutes a 2D tensor. It is scheduled by hand as
    input -> smem -> register -> smem -> output using 32x32 tiles, compiled
    with 32-bit indexing, and the result is compared against ``Tensor.t()``.
    """
    with FusionDefinition() as fd:
        gmem_in = fd.define_tensor(shape=[-1, -1], contiguity=[True, True])
        gmem_out = fd.ops.permute(gmem_in, [1, 0])
        fd.add_output(gmem_out)

        # Rewrite the fusion as input -> smem -> register -> smem -> output.
        # The transpose itself happens on the smem -> register hop.
        smem_in = gmem_in.cache_after(LoadStoreOpType.tma)
        smem_in.set_memory_type(MemoryType.shared)

        smem_out = gmem_out.cache_before(LoadStoreOpType.tma)
        smem_out.set_memory_type(MemoryType.shared)

        reg_out = smem_out.cache_before()

        # Tile the problem 32x32. Every CTA owns exactly one tile; TMA loads
        # the whole tile into shared memory and TMA stores it back to global
        # memory.
        tile = 32

        # [I1, I0]
        gmem_out.split(1, tile)
        gmem_out.split(0, tile)
        # [I1, 32', I0, 32]
        gmem_out.reorder({0: 1, 1: 2, 2: 0})
        gmem_out.merge(0, 1)
        # [I0/32 * I1/32', 32', 32]
        gmem_out.axis(0).parallelize(ParallelType.grid_x)
        # [BIDx, 32', 32]

        fd.sched.bounded_transform_backward(
            gmem_out, -1, [gmem_in], propagate_parallel_type=True
        )

        # The fusion output needs no further scheduling: the entire tile is
        # written back to global memory with TMA.
        for tile_axis in (1, 2):
            gmem_out.axis(tile_axis).parallelize(ParallelType.tma)
        # [BIDx, Bulk, Bulk]

        # smem_out and reg_out share one schedule. Each warp loads one column
        # of smem_in with a 16-byte vectorized load, and 8 warps cover the 8
        # columns. The write into smem_out is unrolled: on every iteration
        # each warp writes one row of smem_out, so there is no bank conflict.

        # [BIDx, 32', 32]
        smem_out_loop_domain = smem_out.get_loop_domain()
        smem_out.set_allocation_domain(smem_out_loop_domain, new_contiguity=True)
        smem_out.split(1, 4)
        # [BIDx, 8', 4', 32]

        fd.sched.bounded_transform_backward(smem_out, -1, [gmem_in])

        smem_out.merge(1, 3)
        # [BIDx, 256, 4']
        smem_out.axis(1).parallelize(ParallelType.block_x)

        fd.sched.bounded_transform_backward(
            smem_out, -1, [smem_in], propagate_parallel_type=True
        )

        smem_out.axis(2).parallelize(ParallelType.unroll)
        reg_out.axis(2).parallelize(ParallelType.vectorize)
        reg_out_loop_domain = reg_out.get_loop_domain()
        reg_out.set_allocation_domain(reg_out_loop_domain, new_contiguity=True)

        # Lay out the input tile in shared memory for the 128 byte swizzle.
        # [BIDx, 8', 4', 32]
        smem_in.reorder({3: 1, 1: 2, 2: 3})
        # [BIDx, 32, 8', 4']
        smem_in.split(1, 8)
        # [BIDx, 4, 8, 8', 4']
        smem_in.swizzle(2, 3)
        # [BIDx, 4, 8, 8', 4']
        smem_in_loop_domain = smem_in.get_loop_domain()
        smem_in.set_allocation_domain(smem_in_loop_domain, new_contiguity=True)

        # Every non-grid axis of the input tile is moved by TMA.
        for bulk_axis in range(1, 5):
            smem_in.axis(bulk_axis).parallelize(ParallelType.tma)
        # [BIDx, Bulk, Bulk, Bulk, Bulk]

    if verbose_:
        print(fd.fusion.print_math())
        print(fd.fusion.print_kernel())

    compile_params = CompileParams(
        index_type=DataType.Int32, maxrregcount=255, enable_magic_zero=False
    )
    src = torch.randn(10000, 10000, dtype=torch.float, device="cuda:0")
    executor = KernelExecutor()
    executor.compile(fd.fusion, [src], compile_params=compile_params)
    results = executor.run([src])
    # The kernel output must match PyTorch's transpose exactly.
    assert results[0].equal(src.t())
0 commit comments