Skip to content

[mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion #93240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 27, 2024
1 change: 0 additions & 1 deletion mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
MLIRSPIRVDialect
MLIRSPIRVConversion
MLIRVectorDialect
MLIRVectorTransforms
MLIRTransforms
)
49 changes: 44 additions & 5 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,47 @@ struct VectorShuffleOpConvert final
}
};

struct VectorInterleaveOpConvert final
: public OpConversionPattern<vector::InterleaveOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check the result vector type.
VectorType oldResultType = interleaveOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(interleaveOp,
"unsupported result vector type");

// Interleave the indices.
VectorType sourceType = interleaveOp.getSourceVectorType();
int n = sourceType.getNumElements();

// Input vectors of size 1 are converted to scalars by the type converter.
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
// use `spirv::CompositeConstructOp`.
if (n == 1) {
Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
interleaveOp, newResultType, newOperands);
return success();
}

auto seq = llvm::seq<int64_t>(2 * n);
auto indices = llvm::map_to_vector(
seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });

// Emit a SPIR-V shuffle.
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
rewriter.getI32ArrayAttr(indices));

return success();
}
};

struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -822,16 +863,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
VectorStoreOpConverter>(typeConverter, patterns.getContext(),
PatternBenefit(1));

// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));

// Need this until vector.interleave is handled.
vector::populateVectorInterleaveToShufflePatterns(patterns);
}

void mlir::populateVectorReductionToSPIRVDotProductPatterns(
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,19 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {

// -----

// CHECK-LABEL: func @interleave_size1
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>)
// CHECK: %[[V0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32>
// CHECK: return %[[RES]]
func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> {
%0 = vector.interleave %a, %b : vector<1xf32> -> vector<2xf32>
return %0 : vector<2xf32>
}

// -----

// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
Expand Down
1 change: 0 additions & 1 deletion utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4976,7 +4976,6 @@ cc_library(
":VectorToLLVM",
":VectorToSCF",
":VectorTransformOpsIncGen",
":VectorTransforms",
":X86VectorTransforms",
],
)
Expand Down
Loading