Skip to content

Commit fa9abdd

Browse files
authored
Check dimensions match in getExpanded python utility (#4671)
Previously we hit a segfault in `getExpanded` when using `define_tensor` and providing a `contiguity` vector that does not match the length of `shape`. This provides an error message instead. Fixes #4661
1 parent 1a902f7 commit fa9abdd

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

python/python_common/python_utils.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,21 @@ std::vector<bool> getExpanded(
164164
const std::vector<int64_t>& shape,
165165
const std::vector<std::optional<bool>>& contiguity,
166166
const std::vector<int64_t>& stride_order) {
167+
NVF_CHECK(
168+
contiguity.size() == shape.size(),
169+
"Length of contiguity argument (",
170+
contiguity.size(),
171+
") must match that of shape argument (",
172+
shape.size(),
173+
")");
174+
NVF_CHECK(
175+
stride_order.empty() || stride_order.size() == shape.size(),
176+
"Length of stride_order argument (",
177+
stride_order.size(),
178+
") must be zero or match that of shape argument (",
179+
shape.size(),
180+
")");
181+
167182
size_t rank = shape.size();
168183
std::vector<bool> is_expand(rank);
169184
for (size_t index : arange(rank)) {
@@ -178,9 +193,9 @@ std::vector<bool> getExpanded(
178193
// index `contig_index = rank - 1 - stride_order[index]`
179194
const size_t contig_index = stride_order.empty()
180195
? index
181-
: rank - 1 - static_cast<size_t>(stride_order[index]);
182-
const bool is_broadcast = !contiguity[contig_index].has_value();
183-
const bool has_non_broadcast_size = (shape[index] != 1);
196+
: rank - 1 - static_cast<size_t>(stride_order.at(index));
197+
const bool is_broadcast = !contiguity.at(contig_index).has_value();
198+
const bool has_non_broadcast_size = (shape.at(index) != 1);
184199
// A root dimension is expand dimension if:
185200
// The dimension is marked a broadcast; and
186201
// The dimension has an expanded extent.

0 commit comments

Comments
 (0)