@@ -1141,3 +1141,192 @@ def test_tutorial_basic_tma_example6(nvfuser_direct_test):
11411141 ke .compile (fd .fusion , [t0 ], compile_params = index32bit )
11421142 outputs = ke .run ([t0 ])
11431143 assert outputs [0 ].equal (t0 )
1144+
1145+
@pytest.mark.skipif(
    is_pre_hopper(), reason="Only supported on Hopper and newer devices."
)
def test_tutorial_vectorize_store_pointwise_tma(nvfuser_direct_test):
    """
    Pointwise add of two 2D tensors with TMA loads and a vectorized store.

    Both inputs are staged through shared-memory caches created with
    LoadStoreOpType.tma, and the innermost axis of the output is
    parallelized with ParallelType.vectorize. The fusion is scheduled by
    hand, compiled directly with KernelExecutor, and the result is compared
    for exact equality against `t0 + t1` computed by PyTorch.
    """
    # Scheduling constants. The TMA tile covers one vector per thread.
    threads_per_cta = 128
    vector_width = 2
    tile_extent = threads_per_cta * vector_width
    stage_count = 4
    hopper_cta_count = 132

    with FusionDefinition() as fd:
        lhs = fd.define_tensor(shape=[-1, -1], contiguity=[True, True])
        rhs = fd.define_tensor(shape=[-1, -1], contiguity=[True, True])
        total = fd.ops.add(lhs, rhs)
        fd.add_output(total)

        # Cache tensors: a TMA load after each input and a cache in front of
        # the output.
        lhs_smem = lhs.cache_after(LoadStoreOpType.tma)
        rhs_smem = rhs.cache_after(LoadStoreOpType.tma)
        total_cache = total.cache_before()

        for smem_tv in (lhs_smem, rhs_smem):
            smem_tv.set_memory_type(MemoryType.shared)

        # The fusion output is the reference every other tensor is
        # transformed like.
        ref = total

        # Step 1: TMA domain.
        # The root domain [I0, I1] is used as the TMA domain as-is.

        # Step 2: Box.
        ref.split(-1, tile_extent)  # [I0, I3, 256]
        ref.split(0, stage_count)  # [I2, 4, I3, 256]

        # Step 3: Tile.
        # Nothing to do because box == tile.

        # Step 4: Shared-memory tensor schedule.
        ref.split(-1, vector_width)  # [I2, 4, I3, 128, 2]
        ref.split(0, hopper_cta_count)  # [I4, 132, 4, I3, 128, 2]
        ref.reorder({3: 2, 2: 3})  # [I4, 132, I3, 4, 128, 2]

        # Replay the reference transformations on the cache tensors.
        fd.sched.transform_like(ref)

        # Only the grid axis is set before propagation, so it is the only
        # parallel type shared with the other tensors at this point.
        ref.axis(1).parallelize(ParallelType.grid_x)
        fd.sched.parallelize_like(ref)

        total_cache.axis(-2).parallelize(ParallelType.block_x)

        # Output tensor: unrolled stages, one thread per vector, vectorized
        # write of the results to gmem.
        for position, parallel_type in (
            (-3, ParallelType.unroll),
            (-2, ParallelType.block_x),
            (-1, ParallelType.vectorize),
        ):
            ref.axis(position).parallelize(parallel_type)

        # TMA tensors: the three innermost axes get the tma parallel type.
        for smem_tv in (lhs_smem, rhs_smem):
            for position in (-1, -2, -3):
                smem_tv.axis(position).parallelize(ParallelType.tma)

        # ComputeAt
        fd.sched.inline_most()

    if verbose_:
        print(fd.fusion.print_math())
        print(fd.fusion.print_kernel())

    input_shape = (16384, 16384)

    # Compile with KernelExecutor directly to avoid scheduling
    index32bit = CompileParams(
        index_type=DataType.Int32, maxrregcount=255, enable_magic_zero=False
    )
    t0 = torch.randn(*input_shape, dtype=torch.float, device="cuda:0")
    t1 = torch.randn(*input_shape, dtype=torch.float, device="cuda:0")
    expected = t0 + t1

    ke = KernelExecutor()
    ke.compile(fd.fusion, [t0, t1], compile_params=index32bit)
    outputs = ke.run([t0, t1])
    assert outputs[0].equal(expected)
1238+
1239+
@pytest.mark.skipif(
    is_pre_hopper(), reason="Only supported on Hopper and newer devices."
)
def test_tutorial_pointwise_broadcast_tma(nvfuser_direct_test):
    """
    Broadcast-add of a 3D tensor onto a 4D tensor using TMA cache tensors.

    The 3D input is broadcast along a new outermost dimension and added to
    the 4D input. Both input caches and the output cache are created with
    LoadStoreOpType.tma and placed in shared memory. The fusion is scheduled
    by hand, compiled directly with KernelExecutor, and the result is
    compared for exact equality against `t0 + t1` computed by PyTorch.
    """
    with FusionDefinition() as fd:
        small_in = fd.define_tensor(
            shape=[-1, -1, -1], contiguity=[True, True, True]
        )
        large_in = fd.define_tensor(
            shape=[-1, -1, -1, -1], contiguity=[True, False, True, True]
        )
        small_bcast = fd.ops.broadcast(small_in, [True, False, False, False])
        total = fd.ops.add(small_bcast, large_in)
        fd.add_output(total)

        # Cache tensors, all created with the TMA load/store op type.
        small_smem = small_in.cache_after(LoadStoreOpType.tma)
        large_smem = large_in.cache_after(LoadStoreOpType.tma)
        total_smem = total.cache_before(LoadStoreOpType.tma)

        for smem_tv in (small_smem, large_smem, total_smem):
            smem_tv.set_memory_type(MemoryType.shared)

        # The fusion output is the reference every other tensor is
        # transformed like.
        ref = total

        # Step 1: TMA domain.
        # root domain: [I0, I1, I2, I3] -> TMA domain: [I0, I1, I4]
        ref.merge(-2, -1)

        # Step 2: TMA box.
        ref.split(-1, 256)  # [I0, I1, I5, 256]

        # Step 3: Tile.
        # Nothing to do because tile == box.

        # Step 4: Shared-memory tensor schedule.
        ref.merge(0, 1)  # [I10, I5, 256]
        ref.split(-2, 4)  # [I10, I7, 4, 256]
        ref.merge(0, 1)  # [I11, 4, 256]

        # Replay the reference transformations on the cache tensors.
        fd.sched.transform_like(ref)

        # Parallelization schema, listed per axis position 0, 1, 2.
        intermediate_layout = (
            ParallelType.grid_x,
            ParallelType.unroll,
            ParallelType.block_x,
        )
        tma_layout = (
            ParallelType.grid_x,
            ParallelType.block_x,
            ParallelType.tma,
        )

        # Intermediate tensors
        for tv in (total_smem, small_bcast):
            for position, parallel_type in enumerate(intermediate_layout):
                tv.axis(position).parallelize(parallel_type)

        # TMA tensors
        for tv in (large_smem, small_smem, total):
            for position, parallel_type in enumerate(tma_layout):
                tv.axis(position).parallelize(parallel_type)

        # ComputeAt
        fd.sched.inline_most()

    if verbose_:
        print(fd.fusion.print_math())
        print(fd.fusion.print_kernel())

    outer_extent = 32
    inner_shape = (2, 4, 256)

    # Compile with KernelExecutor directly to avoid scheduling
    index32bit = CompileParams(
        index_type=DataType.Int32, maxrregcount=255, enable_magic_zero=False
    )
    t0 = torch.randn(*inner_shape, dtype=torch.float, device="cuda:0")
    t1 = torch.randn(
        outer_extent, *inner_shape, dtype=torch.float, device="cuda:0"
    )
    expected = t0 + t1

    ke = KernelExecutor()
    ke.compile(fd.fusion, [t0, t1], compile_params=index32bit)
    outputs = ke.run([t0, t1])
    assert outputs[0].equal(expected)
0 commit comments