Skip to content

Commit

Permalink
support Half in minimum and clamp
Browse files Browse the repository at this point in the history
Summary: If I understand correctly, these ops need to support Half but currently don't. Noticed this as a gap relative to maximum, which already supports Half.

Reviewed By: manuelcandales

Differential Revision: D56846242
  • Loading branch information
swolchok authored and facebook-github-bot committed May 1, 2024
1 parent 495c927 commit 2c0e1ef
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 27 deletions.
18 changes: 9 additions & 9 deletions kernels/portable/cpu/op_clamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ bool is_out_of_bounds(CTYPE_VAL val) {
}
});
} else if (isFloatingType(out_type)) {
ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
if (std::isfinite(val) &&
is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
ET_LOG(Error, "%s value out of bounds", val_name);
Expand Down Expand Up @@ -119,7 +119,7 @@ Tensor& clamp_out(

ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
// Extract optional min value
CTYPE_OUT min = 0;
if (has_min) {
Expand All @@ -140,7 +140,7 @@ Tensor& clamp_out(
});
}

ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() {
apply_unary_map_fn(
[has_min, min, has_max, max](const CTYPE_IN val_in) {
CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
Expand Down Expand Up @@ -195,20 +195,20 @@ Tensor& clamp_tensor_out(
ScalarType out_type = out.scalar_type();

if (has_min) {
common_type = promoteTypes(common_type, min_type);
common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true);
}
if (has_max) {
common_type = promoteTypes(common_type, max_type);
common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true);
}

ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

constexpr auto name = "clamp.Tensor_out";

ET_SWITCH_REALB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
ET_SWITCH_REALB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
ET_SWITCH_REALB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
ET_SWITCH_REALB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
apply_ternary_elementwise_fn<
CTYPE_IN,
CTYPE_MIN,
Expand Down
27 changes: 13 additions & 14 deletions kernels/portable/cpu/op_minimum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,24 @@ Tensor& minimum_out(

ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
ScalarType common_type = promoteTypes(a_type, b_type);
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
ScalarType out_type = out.scalar_type();

ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() {
using CTYPE_IN =
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
ET_SWITCH_REALHB_TYPES(a_type, ctx, "minimum.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "minimum.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REAL_TYPES_AND(
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
MinimumInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
ET_SWITCH_REALHB_TYPES(out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
MinimumInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
});

Expand Down
42 changes: 42 additions & 0 deletions kernels/portable/cpu/util/math_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,48 @@ INT_T max_override(INT_T a, INT_T b) {
return std::max(a, b);
}

template <
    typename T,
    typename std::enable_if<
        std::is_same<T, torch::executor::Half>::value,
        bool>::type = true>
T min_override(T a, T b) {
  // Half has no native comparison here, so compare via float. If either
  // operand is NaN, that operand is returned, propagating NaN to the result.
  const float fa = static_cast<float>(a);
  if (std::isnan(fa)) {
    return a;
  }
  const float fb = static_cast<float>(b);
  if (std::isnan(fb)) {
    return b;
  }
  // Neither is NaN: return the smaller of the two (b wins ties).
  return fa < fb ? a : b;
}

template <
    typename T,
    typename std::enable_if<
        std::is_same<T, torch::executor::Half>::value,
        bool>::type = true>
T max_override(T a, T b) {
  // Half has no native comparison here, so compare via float. If either
  // operand is NaN, that operand is returned, propagating NaN to the result.
  const float fa = static_cast<float>(a);
  if (std::isnan(fa)) {
    return a;
  }
  const float fb = static_cast<float>(b);
  if (std::isnan(fb)) {
    return b;
  }
  // Neither is NaN: return the larger of the two (b wins ties).
  return fa > fb ? a : b;
}

/**
* There is a slight difference in how std::fmod works compared to how ATen
* determines remainders:
Expand Down
25 changes: 21 additions & 4 deletions kernels/test/op_clamp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,16 @@ class OpClampOutTest : public OperatorTest {
// Test cases that are compatible with float and double.
template <ScalarType DTYPE>
void run_floating_point_test_cases() {
constexpr auto kInfinity =
std::numeric_limits<typename TensorFactory<DTYPE>::ctype>::infinity();
using ctype = typename TensorFactory<DTYPE>::ctype;
using opt_infinity_type = std::conditional_t<
std::is_same<ctype, exec_aten::Half>::value,
float,
ctype>;
constexpr auto kInfinity = std::numeric_limits<ctype>::infinity();
const auto kOptInfinity =
OptScalar(static_cast<opt_infinity_type>(kInfinity));
const auto kOptMinusInfinity =
OptScalar(static_cast<opt_infinity_type>(-kInfinity));
std::vector<ClampTestCase<DTYPE>> test_cases = {
{
std::string(__func__) + ": Simple negative/positive clamp",
Expand Down Expand Up @@ -178,7 +186,7 @@ class OpClampOutTest : public OperatorTest {
std::string(__func__) + ": Infinite min",
{2, 2}, // sizes
{-10.1, -1.1, 1.1, 10.1}, // input_data
OptScalar(-kInfinity), // min
kOptMinusInfinity, // min
OptScalar(5.5), // max
{-10.1, -1.1, 1.1, 5.5}, // expected_data
},
Expand All @@ -187,7 +195,7 @@ class OpClampOutTest : public OperatorTest {
{2, 2}, // sizes
{-10.1, -1.1, 1.1, 10.1}, // input_data
OptScalar(-5.5), // min
OptScalar(kInfinity), // max
kOptInfinity, // max
{-5.5, -1.1, 1.1, 10.1}, // expected_data
},
{
Expand Down Expand Up @@ -285,6 +293,15 @@ TEST_F(OpClampOutTest, LongTensors) {
run_signed_integer_test_cases<ScalarType::Long>();
}

// Runs all three shared test-case suites with Half-typed tensors, mirroring
// the coverage the Float/Double tests get.
TEST_F(OpClampOutTest, HalfTensors) {
  // Note that the integer test cases test the situation where the min/max value
  // Scalars are integer types, demonstrating that floating point types can be
  // clamped to integer values.
  run_unsigned_integer_test_cases<ScalarType::Half>();
  run_signed_integer_test_cases<ScalarType::Half>();
  run_floating_point_test_cases<ScalarType::Half>();
}

TEST_F(OpClampOutTest, FloatTensors) {
// Note that the integer test cases test the situation where the min/max value
// Scalars are integer types, demonstrating that floating point types can be
Expand Down
4 changes: 4 additions & 0 deletions kernels/test/op_minimum_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ TEST_F(OpMinimumOutTest, LongTensors) {
test_minimum_out_same_size<ScalarType::Long>();
}

// Exercises minimum_out with Half-typed tensors via the shared same-size
// helper, matching the coverage of the other dtype tests in this file.
TEST_F(OpMinimumOutTest, HalfTensors) {
  test_minimum_out_same_size<ScalarType::Half>();
}

TEST_F(OpMinimumOutTest, FloatTensors) {
test_minimum_out_same_size<ScalarType::Float>();
}
Expand Down

0 comments on commit 2c0e1ef

Please sign in to comment.