Skip to content

Commit

Permalink
Use compile-time promotion to reduce max/min size & build time (pytor…
Browse files Browse the repository at this point in the history
…ch#3459)

Summary:
Pull Request resolved: pytorch#3459

Yet another smaller pair of ops.

Differential Revision: D56807402
  • Loading branch information
swolchok authored and facebook-github-bot committed May 2, 2024
1 parent c87a8d0 commit 0c7c9e6
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 29 deletions.
68 changes: 54 additions & 14 deletions kernels/portable/cpu/op_maximum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,50 @@ const T& max(const T& a, const T& b) {
return (b > a) ? b : a;
}

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MaximumInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = max(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace

Tensor& maximum_out(
Expand All @@ -44,20 +88,16 @@ Tensor& maximum_out(

ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(common_type, ctx, "maximum.out", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = max(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
MaximumInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
Expand Down
69 changes: 54 additions & 15 deletions kernels/portable/cpu/op_minimum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,50 @@ const T& min(const T& a, const T& b) {
return (b < a) ? b : a;
}

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MinimumInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MinimumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = min(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MinimumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace

Tensor& minimum_out(
Expand All @@ -44,22 +88,17 @@ Tensor& minimum_out(

ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() {
using CTYPE_IN =
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool, common_type, ctx, "minimum.out", CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = min(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
MinimumInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
Expand Down

0 comments on commit 0c7c9e6

Please sign in to comment.