Skip to content

Commit

Permalink
Use compile-time promotion to reduce remainder size & build time (#3458)
Browse files Browse the repository at this point in the history
Summary:

Yet another op that can benefit from compile-time type promotion.

Differential Revision: D56831293
  • Loading branch information
swolchok authored and facebook-github-bot committed May 3, 2024
1 parent e0156df commit 199d6d1
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 25 deletions.
81 changes: 56 additions & 25 deletions kernels/portable/cpu/op_remainder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@ namespace native {

using Tensor = exec_aten::Tensor;

namespace {
template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct RemainderInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct RemainderInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = utils::remainder_override(a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct RemainderInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace
Tensor& remainder_Tensor_out(
RuntimeContext& ctx,
const Tensor& a,
Expand All @@ -45,32 +91,17 @@ Tensor& remainder_Tensor_out(
Bool, a_type, ctx, "remainder.Tensor_out", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool, b_type, ctx, "remainder.Tensor_out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES(
common_type, ctx, "remainder.Tensor_out", CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES(
out_type,
ctx,
"remainder.Tensor_out",
CTYPE_OUT,
[&]() {
apply_binary_elementwise_fn<
CTYPE_A,
CTYPE_B,
CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted =
static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted =
static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = utils::remainder_override(
a_casted, b_casted);

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
out_type, ctx, "remainder.Tensor_out", CTYPE_OUT, [&]() {
RemainderInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});
Expand Down
14 changes: 14 additions & 0 deletions kernels/test/op_remainder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using exec_aten::Tensor;
using torch::executor::testing::TensorFactory;

class OpRemainderOutTest : public OperatorTest {
protected:
Tensor& op_remainder_tensor_out(
const Tensor& self,
const Tensor& other,
Expand All @@ -35,3 +36,16 @@ class OpRemainderOutTest : public OperatorTest {
return torch::executor::aten::remainder_outf(context_, self, other, out);
}
};

TEST_F(OpRemainderOutTest, SmokeTest) {
TensorFactory<ScalarType::Long> tfDouble;
TensorFactory<ScalarType::Long> tfLong;
TensorFactory<ScalarType::Int> tfInt;

Tensor self = tfLong.full({2, 2}, 46);
Tensor other = tfInt.full({2, 2}, 4);
Tensor out = tfDouble.zeros({2, 2});
Tensor out_expected = tfDouble.full({2, 2}, 2.0);
op_remainder_tensor_out(self, other, out);
EXPECT_TENSOR_CLOSE(out, out_expected);
}

0 comments on commit 199d6d1

Please sign in to comment.