Commit f1090b0

first crack at optimized op_where
ghstack-comment-id: 2691805026
ghstack-source-id: c88f9387a18951f40fffb1cc9971daafe7b82122
ghstack-comment-id: 2691808920
Pull Request resolved: #8866
1 parent 1273532 commit f1090b0

4 files changed: +109 -0

kernels/optimized/cpu/op_where.cpp (new file, +97)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <iostream>

namespace torch {
namespace executor {
namespace native {

Tensor& opt_where_out(
    KernelRuntimeContext& ctx,
    const Tensor& cond,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "where.self_out";

  if (a.scalar_type() == b.scalar_type() &&
      a.scalar_type() == out.scalar_type() && a.scalar_type() == compute_type &&
      // Using a Byte tensor for cond has been deprecated for a long time.
      cond.scalar_type() == ScalarType::Bool) {
    auto out_numel = out.numel();
    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
      const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
      const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
      const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
      const bool any_is_broadcasted =
          (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);
      const CTYPE_COMPUTE* const data_a = a.const_data_ptr<CTYPE_COMPUTE>();
      const CTYPE_COMPUTE* const data_b = b.const_data_ptr<CTYPE_COMPUTE>();
      const bool* const data_cond = cond.const_data_ptr<bool>();
      CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
      if (any_is_broadcasted) {
        for (const auto [out_index, a_index, b_index, cond_index] :
             BroadcastIndexesRange<3>(out, a, b, cond)) {
          data_out[out_index] =
              data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
        }
      } else {
        for (const auto i : c10::irange(out_numel)) {
          data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
        }
      }
    });
  } else {
    // Fall back for mixed dtype to keep code size and compile time
    // reasonable.
    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
      utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
          [](const CTYPE_COMPUTE val_a,
             const CTYPE_COMPUTE val_b,
             const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
          ctx,
          a,
          utils::SupportedTensorDtypes::REALHBBF16,
          b,
          utils::SupportedTensorDtypes::REALHBBF16,
          cond,
          utils::SupportedTensorDtypes::BOOL_OR_BYTE,
          out,
          utils::SupportedTensorDtypes::SAME_AS_COMMON);
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
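
The fast path's key tool is BroadcastIndexesRange, which walks the flat output index space and yields the matching flat index into each (possibly broadcasted) input, so the inner loop needs no per-element shape logic. The standalone sketch below illustrates the underlying stride-zero idea for a single broadcasted input; the stride math is illustrative only, not the library's implementation.

#include <array>
#include <cstdio>

int main() {
  // out is 2x3; a is 1x3 and broadcasts along dim 0. Giving a's size-1
  // dimension a stride of 0 makes every output row read the same row of a.
  constexpr std::array<int, 2> out_sizes{2, 3};
  constexpr std::array<int, 2> a_strides{0, 1}; // stride 0 = repeat

  for (int i = 0; i < out_sizes[0]; ++i) {
    for (int j = 0; j < out_sizes[1]; ++j) {
      const int out_index = i * out_sizes[1] + j;
      const int a_index = i * a_strides[0] + j * a_strides[1];
      std::printf("data_out[%d] reads data_a[%d]\n", out_index, a_index);
    }
  }
  return 0;
}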

kernels/optimized/cpu/targets.bzl (+6)

@@ -95,6 +95,12 @@ _OPTIMIZED_ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
     ),
+    op_target(
+        name = "op_where",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+        ],
+    ),
 )
kernels/optimized/optimized.yaml (+5)

@@ -101,3 +101,8 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::opt_sub_scalar_out
+
+- op: where.self_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_where_out

kernels/test/CMakeLists.txt (+1)

@@ -275,6 +275,7 @@ set(_optimized_kernels_test_sources
   "op_native_layer_norm_test.cpp"
   "op_neg_test.cpp"
   "op_sub_test.cpp"
+  "op_where_test.cpp"
   "UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp"
   ${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.cpp
)