[ET-VK][15/n] reconcile Dim4D and NchwDim into DimIndex
Pull Request resolved: #3489

1. Adapt @SSJia's idea of representing `Dim4D` as a "negative index", and rename it to `DimIndex` (a minimal illustrative sketch follows this list).
2. Merge `NchwDim`'s functionality into `Dim4D` (now `DimIndex`).
3. Clean up `dim_at` calls to take only `DimIndex` as input.
4. Further clean up some usages of `uint`, converting them to `int`.
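
For context, here is a minimal sketch of the negative-index idea and of a rank-agnostic `dim_at` built on it. Everything in this block — the enum values, helper signatures, and the small check in `main` — is an assumption for illustration only, not copied from the headers this stack modifies.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sketch only: a DimIndex names a dimension of a (conceptually)
// 4-D NCHW tensor by its negative offset from the innermost dimension, so the
// same constant works for tensors of rank 1 through 4.
enum DimIndex : int32_t {
  kWidth4D = -1,   // W, innermost
  kHeight4D = -2,  // H
  kChannel4D = -3, // C
  kBatch4D = -4,   // N, outermost
};

// Hypothetical helper: map a non-negative dim (0 .. ndim-1) to its DimIndex,
// e.g. dim == ndim - 1 always becomes kWidth4D regardless of rank.
inline DimIndex normalize_to_dim_index(int64_t dim, int64_t ndim) {
  return static_cast<DimIndex>(dim - ndim);
}

// Hypothetical helper: read a size by DimIndex; dimensions beyond the
// tensor's rank count as size 1.
inline int64_t dim_at(const std::vector<int64_t>& sizes, DimIndex dim_index) {
  const int64_t d = static_cast<int64_t>(sizes.size()) + dim_index;
  return d >= 0 ? sizes[static_cast<size_t>(d)] : 1;
}

int main() {
  const std::vector<int64_t> sizes = {2, 3, 4, 5};  // N, C, H, W
  assert(dim_at(sizes, kChannel4D) == 3);
  assert(dim_at({4, 5}, kBatch4D) == 1);  // rank-2 tensor: implicit batch of 1
  assert(normalize_to_dim_index(/*dim=*/1, /*ndim=*/4) == kChannel4D);
  return 0;
}
```

With this shape, call sites such as `dim_at(t_in->sizes(), kChannel4D)` in the hunks below no longer need to know the tensor's rank, which is what lets the former `NchwDim` branches and `Dim4D` template arguments collapse onto the same constants.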


ghstack-source-id: 225521662

Differential Revision: [D56778340](https://our.internmc.facebook.com/intern/diff/D56778340/)
yipjustin committed May 8, 2024
1 parent c001f59 commit 59bf7b4
Showing 10 changed files with 134 additions and 135 deletions.
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/permute.glsl
@@ -29,9 +29,9 @@ layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {

layout(set = 0, binding = 4) uniform PRECISION restrict Block {
// output dims
uvec4 out_ndims;
ivec4 out_ndims;
// x = output channels aligned to 4, y = input channels aligned to 4
uvec2 ch_info;
ivec2 ch_info;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
14 changes: 7 additions & 7 deletions backends/vulkan/runtime/graph/ops/impl/Cat.cpp
@@ -31,10 +31,10 @@ void add_cat_default_node(
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
vTensorPtr t_out = graph.get_tensor(out);

NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim);
DimIndex dim_index = normalize_to_dim_index(*t_out, dim);

// TODO: Find ways to factor out the similar code for width, height, and batch
if (nchw_dim == DimWidth) {
if (dim_index == kWidth4D) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

@@ -46,7 +46,7 @@ void add_cat_default_node(
dst_offset.data[0] += range.data[0];
}

} else if (nchw_dim == DimHeight) {
} else if (dim_index == kHeight4D) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

@@ -57,7 +57,7 @@ void add_cat_default_node(
graph, input_ref, range, src_offset, dst_offset, out);
dst_offset.data[1] += range.data[1];
}
} else if (nchw_dim == DimBatch) {
} else if (dim_index == kBatch4D) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

@@ -68,19 +68,19 @@ void add_cat_default_node(
graph, input_ref, range, src_offset, dst_offset, out);
dst_offset.data[2] += range.data[2];
}
} else if (nchw_dim == DimChannel) {
} else if (dim_index == kChannel4D) {
int32_t src_offset = 0;
int32_t dst_offset = 0;

for (ValueRef input_ref : *input_list) {
vTensorPtr t_in = graph.get_tensor(input_ref);
int32_t range = dim_at<Dim4D::Channel>(t_in->sizes());
int32_t range = dim_at(t_in->sizes(), kChannel4D);
add_copy_channel_offset_node(
graph, input_ref, range, src_offset, dst_offset, out);
dst_offset += range;
}
} else {
VK_THROW("Unexpected value of nchw_dim=", nchw_dim);
VK_THROW("Unexpected value of dim_index=", dim_index);
}
}

17 changes: 8 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Copy.cpp
@@ -92,23 +92,23 @@ void add_copy_channel_offset_node(
VK_CHECK_COND(t_out->dim() >= 3, "Dst dim should be at least 3");

VK_CHECK_COND(
dim_at<Dim4D::Channel>(in_sizes) >= src_channel_offset + channel_range,
dim_at<kChannel4D>(in_sizes) >= src_channel_offset + channel_range,
"Src channel (",
src_channel_offset,
") and range (",
channel_range,
") should be less than or equal to input tensor's channel size (",
dim_at<Dim4D::Channel>(in_sizes),
dim_at<kChannel4D>(in_sizes),
")");

VK_CHECK_COND(
dim_at<Dim4D::Channel>(out_sizes) >= dst_channel_offset + channel_range,
dim_at<kChannel4D>(out_sizes) >= dst_channel_offset + channel_range,
"Dst channel (",
dst_channel_offset,
") and range (",
channel_range,
") should be less than or equal to input tensor's channel size (",
dim_at<Dim4D::Channel>(out_sizes),
dim_at<kChannel4D>(out_sizes),
")");

VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative");
@@ -121,11 +121,10 @@ void add_copy_channel_offset_node(
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);

int32_t out_channels = dim_at<Dim4D::Channel>(out_sizes);
int32_t out_channels = dim_at<kChannel4D>(out_sizes);

// Copy one batch at a time.
for (int batch_idx = 0; batch_idx < dim_at<Dim4D::Batch>(in_sizes);
batch_idx++) {
for (int batch_idx = 0; batch_idx < dim_at<kBatch4D>(in_sizes); batch_idx++) {
// Mapping the tensor NCHW coordinates into texture XYZ coordinates
int32_t dst_first_z = dst_channel_offset / 4;
int32_t dst_last_z = (dst_channel_offset + channel_range - 1) / 4;
@@ -139,8 +138,8 @@
0, 0, dst_first_z + batch_idx * api::utils::div_up(out_channels, 4)};

uvec3 global_size{
dim_at<Dim4D::Width>(in_sizes),
dim_at<Dim4D::Height>(in_sizes),
api::utils::safe_downcast<uint32_t>(dim_at<kWidth4D>(in_sizes)),
api::utils::safe_downcast<uint32_t>(dim_at<kHeight4D>(in_sizes)),
api::utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};

uvec3 local_size = adaptive_work_group_size(global_size);
19 changes: 10 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -17,8 +17,9 @@

namespace vkcompute {

using api::utils::ivec2;
using api::utils::ivec3;
using api::utils::uvec2;
using api::utils::ivec4;
using api::utils::uvec4;

namespace {
@@ -53,7 +54,7 @@ void add_permute_node(

check_args(*t_in, permute_dims, *t_out);

uvec4 out_dims{0u, 1u, 2u, 3u};
ivec4 out_dims{0, 1, 2, 3};

int64_t out_dim = t_out->dim();
std::vector<bool> seen(out_dim);
@@ -63,22 +64,22 @@
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
seen[permute_dim] = true;

out_dims.data[(4u - out_dim) + i] = permute_dim + (4u - out_dim);
out_dims.data[(4u - out_dim) + i] = permute_dim + (4 - out_dim);
}

std::string kernel_name = "permute";
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);

uint32_t out_channels = dim_at<Dim4D::Channel>(t_out->sizes());
uint32_t in_channels = dim_at<Dim4D::Channel>(t_in->sizes());
int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());

uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);
int32_t out_c_aligned = api::utils::align_up(out_channels, 4);
int32_t in_c_aligned = api::utils::align_up(in_channels, 4);

const struct Block final {
uvec4 out_ndims;
uvec2 ch_info;
ivec4 out_ndims;
ivec2 ch_info;
} params{
out_dims,
{out_c_aligned, in_c_aligned},
52 changes: 24 additions & 28 deletions backends/vulkan/runtime/graph/ops/impl/Repeat.cpp
@@ -32,23 +32,23 @@ void check_args(
"Input tensor dim size must be not greater than the repeat argument's size");

VK_CHECK_COND(
dim_at<Dim4D::Width>(in.sizes()) * dim_at<Dim4D::Width>(repeats) ==
dim_at<Dim4D::Width>(out.sizes()),
dim_at<kWidth4D>(in.sizes()) * dim_at<kWidth4D>(repeats) ==
dim_at<kWidth4D>(out.sizes()),
"Output's width doesn't match input's width * repeat count");

VK_CHECK_COND(
dim_at<Dim4D::Height>(in.sizes()) * dim_at<Dim4D::Height>(repeats) ==
dim_at<Dim4D::Height>(out.sizes()),
dim_at<kHeight4D>(in.sizes()) * dim_at<kHeight4D>(repeats) ==
dim_at<kHeight4D>(out.sizes()),
"Output's height doesn't match input's height * repeat count");

VK_CHECK_COND(
dim_at<Dim4D::Channel>(in.sizes()) * dim_at<Dim4D::Channel>(repeats) ==
dim_at<Dim4D::Channel>(out.sizes()),
dim_at<kChannel4D>(in.sizes()) * dim_at<kChannel4D>(repeats) ==
dim_at<kChannel4D>(out.sizes()),
"Output's channel doesn't match input's channel * repeat count");

VK_CHECK_COND(
dim_at<Dim4D::Batch>(in.sizes()) * dim_at<Dim4D::Batch>(repeats) ==
dim_at<Dim4D::Batch>(out.sizes()),
dim_at<kBatch4D>(in.sizes()) * dim_at<kBatch4D>(repeats) ==
dim_at<kBatch4D>(out.sizes()),
"Output's batch doesn't match input's batch * repeat count");
}

@@ -70,13 +70,13 @@ void add_repeat_channel_node(
const std::vector<int64_t>& in_sizes = t_in->sizes();

int32_t in_width =
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(in_sizes));
api::utils::safe_downcast<int32_t>(dim_at<kWidth4D>(in_sizes));
int32_t in_height =
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(in_sizes));
api::utils::safe_downcast<int32_t>(dim_at<kHeight4D>(in_sizes));
int32_t in_channel =
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(in_sizes));
api::utils::safe_downcast<int32_t>(dim_at<kChannel4D>(in_sizes));
int32_t in_batch =
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Batch>(in_sizes));
api::utils::safe_downcast<int32_t>(dim_at<kBatch4D>(in_sizes));

int32_t out_channel = repeat_channel * in_channel;

@@ -142,11 +142,11 @@ void add_repeat_node(
// dimension, we copy over the input texture to the output. In subsequent
// dimensions, we read and write from the same tensor.

if (int64_t channel_repeat = dim_at<Dim4D::Channel>(repeats);
if (int64_t channel_repeat = dim_at<kChannel4D>(repeats);
channel_repeat == 1) {
// If no repeat, short-cut to a direct copy
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 src_offset{0, 0, 0};
api::utils::ivec3 dst_offset{0, 0, 0};

add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);

@@ -156,12 +156,11 @@

// TODO: refactor width, height, and batch into a common helper function.
// Width
if (int64_t width_repeat = dim_at<Dim4D::Width>(repeats); width_repeat > 1) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
if (int64_t width_repeat = dim_at<kWidth4D>(repeats); width_repeat > 1) {
api::utils::ivec3 src_offset{0, 0, 0};

for (int i = 1; i < width_repeat; ++i) {
api::utils::ivec3 dst_offset = api::utils::make_ivec3(
{i * dim_at<Dim4D::Width>(in_sizes), 0, 0}, false);
api::utils::ivec3 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0};

add_copy_offset_node(
graph, out, running_range, src_offset, dst_offset, out);
@@ -171,13 +170,11 @@
}

// Height
if (int64_t height_repeat = dim_at<Dim4D::Height>(repeats);
height_repeat > 1) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
if (int64_t height_repeat = dim_at<kHeight4D>(repeats); height_repeat > 1) {
api::utils::ivec3 src_offset{0, 0, 0};

for (int i = 1; i < height_repeat; ++i) {
api::utils::ivec3 dst_offset = api::utils::make_ivec3(
{0, i * dim_at<Dim4D::Height>(in_sizes), 0}, false);
api::utils::ivec3 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0};

add_copy_offset_node(
graph, out, running_range, src_offset, dst_offset, out);
@@ -187,12 +184,11 @@
}

// Batch
if (int64_t batch_repeat = dim_at<Dim4D::Batch>(repeats); batch_repeat > 1) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
if (int64_t batch_repeat = dim_at<kBatch4D>(repeats); batch_repeat > 1) {
api::utils::ivec3 src_offset{0, 0, 0};

for (int i = 1; i < batch_repeat; ++i) {
api::utils::ivec3 dst_offset =
api::utils::make_ivec3({0, 0, i * running_range.data[2]}, false);
api::utils::ivec3 dst_offset = {0, 0, i * running_range.data[2]};

add_copy_offset_node(
graph, out, running_range, src_offset, dst_offset, out);
13 changes: 6 additions & 7 deletions backends/vulkan/runtime/graph/ops/impl/Slice.cpp
@@ -43,8 +43,7 @@ void add_slice_tensor_out_node(

dim = normalize(dim, t_in->dim());

// Create a dim value as in the underlying dim is 4-dimension.
int64_t nchw_dim = dim + (4 - t_in->dim());
DimIndex dim_index = normalize_to_dim_index(*t_in, dim);

std::optional<int64_t> opt_start =
graph.extract_optional_scalar<int64_t>(opt_start_ref);
@@ -61,7 +60,7 @@
VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));

if (nchw_dim == 1) {
if (dim_index == kChannel4D) {
// slice by channel
std::string kernel_name = "slice_channel";
kernel_name.reserve(kShaderNameReserve);
@@ -93,17 +92,17 @@
// GPU's coordinate is in x, y, z
int64_t gpu_dim = -1;
int64_t stride = 1;
if (nchw_dim == 3) {
if (dim_index == kWidth4D) {
gpu_dim = 0; // width: x dimension in gpu
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
} else if (nchw_dim == 2) {
} else if (dim_index == kHeight4D) {
gpu_dim = 1; // height: y dimension
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
} else if (nchw_dim == 0) {
} else if (dim_index == kBatch4D) {
gpu_dim = 2; // batch: z dimension

// Due to channel packing, each batch value is span over stride planes
int64_t n_channels = dim_at<Dim4D::Channel>(in_sizes);
int64_t n_channels = dim_at(in_sizes, kChannel4D);
stride = api::utils::div_up<int64_t>(n_channels, 4ll);
} else {
VK_THROW("Unexpected ncwh_dim!");
18 changes: 9 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Split.cpp
@@ -29,7 +29,7 @@ void add_split_with_sizes_default_node(

ValueListPtr out_list = graph.get_value_list(out_list_ref);

NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
DimIndex dim_index = normalize_to_dim_index(*t_in, dim);

VK_CHECK_COND(out_list->size() == split_sizes.size());

@@ -39,10 +39,10 @@

vTensorPtr t_out = graph.get_tensor(out_ref);
VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
VK_CHECK_COND(dim_at(*t_out, nchw_dim) == split_size);
VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size);
}

if (nchw_dim == DimWidth) {
if (dim_index == kWidth4D) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

@@ -55,7 +55,7 @@

src_offset.data[0] += range.data[0];
}
} else if (nchw_dim == DimHeight) {
} else if (dim_index == kHeight4D) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

@@ -66,7 +66,7 @@

src_offset.data[1] += range.data[1];
}
} else if (nchw_dim == DimBatch) {
} else if (dim_index == kBatch4D) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

@@ -77,13 +77,13 @@

src_offset.data[2] += range.data[2];
}
} else if (nchw_dim == DimChannel) {
} else if (dim_index == kChannel4D) {
int32_t src_offset = 0;
int32_t dst_offset = 0;

for (ValueRef out_ref : *out_list) {
vTensorPtr t_out = graph.get_tensor(out_ref);
int32_t range = dim_at<Dim4D::Channel>(t_out->sizes());
int32_t range = dim_at<kChannel4D>(t_out->sizes());
add_copy_channel_offset_node(
graph, in, range, src_offset, dst_offset, out_ref);
src_offset += range;
@@ -122,8 +122,8 @@
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);

vTensorPtr t_in = graph.get_tensor(in);
NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
int64_t size = dim_at(*t_in, nchw_dim);
DimIndex dim_index = normalize_to_dim_index(*t_in, dim);
int64_t size = dim_at(*t_in, dim_index);
std::vector<int64_t> split_sizes(size / split_size, split_size);

add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);