@@ -67,6 +67,23 @@ Returns
6767-------
6868Val
6969 The extent of this domain.
70+ )" )
71+ .def (
72+ " parallelize" ,
73+ &IterDomain::parallelize,
74+ py::arg (" parallel_type" ),
75+ R"(
76+ Set the parallel type of this domain.
77+
78+ Parameters
79+ ----------
80+ parallel_type : ParallelType
81+ The type of parallelization to apply (e.g., BIDx, TIDx, etc.).
82+
83+ Notes
84+ -----
85+ This is a key function used in scheduling to specify how the domain should be parallelized
86+ across CUDA threads and blocks.
7087)" );
7188
7289 // TensorDomain
@@ -111,6 +128,72 @@ TensorDomain
111128 - Logical domain (The original dimensions. It may contain rFactor iterDomains.)
112129 - Allocation domain (How the memory is allocated for the tensor?)
113130 - Loop domain (The for-loop structure for the tensor.)
131+ )" )
132+ .def (
133+ " get_loop_domain" ,
134+ &TensorView::getLoopDomain,
135+ R"(
136+ Get the loop domain of this tensor.
137+
138+ Returns
139+ -------
140+ list of IterDomain
141+ The loop iteration domains.
142+ )" )
143+ .def (
144+ " split" ,
145+ static_cast <TensorView* (TensorView::*)(int64_t , int64_t , bool )>(
146+ &TensorView::split),
147+ py::arg (" axis" ),
148+ py::arg (" factor" ),
149+ py::arg (" inner_split" ) = true ,
150+ py::return_value_policy::reference,
151+ R"(
152+ Split an axis into two axes.
153+
154+ Parameters
155+ ----------
156+ axis : int
157+ The axis to split.
158+ factor : int
159+ The factor to split by.
160+ inner_split : bool, optional
161+ If True, the factor determines the size of the inner domain.
162+ If False, the factor determines the size of the outer domain.
163+ Default is True.
164+
165+ Returns
166+ -------
167+ TensorView
168+ A TensorView with the split axes in its loop domain.
169+ )" )
170+ .def (
171+ " set_allocation_domain" ,
172+ static_cast <void (TensorView::*)(std::vector<IterDomain*>, bool )>(
173+ &TensorView::setAllocationDomain),
174+ py::arg (" new_allocation_domain" ),
175+ py::arg (" new_contiguity" ),
176+ R"(
177+ Set the allocation domain of this tensor.
178+
179+ Parameters
180+ ----------
181+ new_allocation_domain : list of IterDomain
182+ The new allocation iteration domains.
183+ new_contiguity : bool
184+ The new contiguity flag.
185+ )" )
186+ .def (
187+ " set_device_mesh" ,
188+ &TensorView::setDeviceMesh,
189+ py::arg (" mesh" ),
190+ R"(
191+ Set the device mesh of this tensor.
192+
193+ Parameters
194+ ----------
195+ mesh : DeviceMesh
196+ The device mesh to set.
114197)" )
115198 .def (
116199 " axis" ,
0 commit comments