Skip to content

Commit 764977b

Browse files
committed
Update
[ghstack-poisoned]
1 parent efd1a06 commit 764977b

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

kernels/portable/cpu/op_argmax.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ Tensor& argmax_out(
5050
for (const auto out_ix : c10::irange(out.numel())) {
5151
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
5252
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
53-
if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) {
53+
// the below condition as written is equivalent to
54+
// !isnan(accval) && (isnan(v) || v > acc_val). See
55+
// argument in op_argmin.cpp.
56+
if (!std::isnan(acc_val) && !(v <= acc_val)) {
5457
acc_val = v;
5558
acc_ix = ix;
5659
}

kernels/portable/cpu/op_argmin.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,17 @@ Tensor& argmin_out(
5050
for (const auto out_ix : c10::irange(out.numel())) {
5151
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
5252
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
53-
if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
53+
// the below condition as written is equivalent to !isnan(accval) &&
54+
// (isnan(v) || v < acc_val). cases:
55+
// - if neither acc_val nor v is NaN, !(v >= acc_val) is
56+
// trivially equivalent to v < acc_val.
57+
// - if acc_val is NaN, the whole thing is trivially false.
58+
// - if acc_val is not NaN and v is NaN, then v >= acc_val
59+
// - is false because all comparisons involving NaN are
60+
// - false, so the result is true. The result is trivially
61+
// - true for the above condition that uses isnan(v) as
62+
// - well.
63+
if (!std::isnan(acc_val) && !(v >= acc_val)) {
5464
acc_val = v;
5565
acc_ix = ix;
5666
}

kernels/test/op_argmax_test.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,16 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) {
9090
EXPECT_TENSOR_EQ(out, expected);
9191
// clang-format on
9292
}
93+
94+
TEST_F(OpArgmaxTest, FirstNaNWins) {
95+
TensorFactory<ScalarType::Float> tf_float;
96+
Tensor in = tf_float.make({4}, {1, NAN, -4, NAN});
97+
98+
TensorFactory<ScalarType::Long> tf_long;
99+
Tensor out = tf_long.zeros({});
100+
Tensor expected = tf_long.make({}, {1});
101+
102+
Tensor ret = op_argmax_out(in, {}, false, out);
103+
EXPECT_TENSOR_EQ(out, ret);
104+
EXPECT_TENSOR_EQ(out, expected);
105+
}

kernels/test/op_argmin_test.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,16 @@ TEST_F(OpArgminTest, SanityCheckNullDim) {
9090
EXPECT_TENSOR_EQ(out, expected);
9191
// clang-format on
9292
}
93+
94+
TEST_F(OpArgminTest, FirstNaNWins) {
95+
TensorFactory<ScalarType::Float> tf_float;
96+
Tensor in = tf_float.make({4}, {1, NAN, -4, NAN});
97+
98+
TensorFactory<ScalarType::Long> tf_long;
99+
Tensor out = tf_long.zeros({});
100+
Tensor expected = tf_long.make({}, {1});
101+
102+
Tensor ret = op_argmin_out(in, {}, false, out);
103+
EXPECT_TENSOR_EQ(out, ret);
104+
EXPECT_TENSOR_EQ(out, expected);
105+
}

0 commit comments

Comments
 (0)