-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion #93240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion #93240
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Angel Zhang (angelz913) ChangesAdd Full diff: https://github.com./llvm/llvm-project/pull/93240.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorInterleaveOpConvert final
+ : public OpConversionPattern<vector::InterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check the source vector type
+ auto sourceType = interleaveOp.getSourceVectorType();
+ if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported source vector type");
+ }
+
+ // Check the result vector type
+ auto oldResultType = interleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported result vector type");
+
+ // Interleave the indices
+ int n = sourceType.getNumElements();
+ auto seq = llvm::seq<int64_t>(2 * n);
+ auto indices = llvm::to_vector(llvm::map_range(
+ seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+ // Emit a SPIR-V shuffle.
+ rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+ interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+ rewriter.getI32ArrayAttr(indices));
+
+ llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
+
+ return success();
+ }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorSplatPattern,
+ VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
-
- // Need this until vector.interleave is handled.
- vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
|
@llvm/pr-subscribers-mlir-spirv Author: Angel Zhang (angelz913) ChangesAdd Full diff: https://github.com./llvm/llvm-project/pull/93240.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorInterleaveOpConvert final
+ : public OpConversionPattern<vector::InterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check the source vector type
+ auto sourceType = interleaveOp.getSourceVectorType();
+ if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported source vector type");
+ }
+
+ // Check the result vector type
+ auto oldResultType = interleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported result vector type");
+
+ // Interleave the indices
+ int n = sourceType.getNumElements();
+ auto seq = llvm::seq<int64_t>(2 * n);
+ auto indices = llvm::to_vector(llvm::map_range(
+ seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+ // Emit a SPIR-V shuffle.
+ rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+ interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+ rewriter.getI32ArrayAttr(indices));
+
+ llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
+
+ return success();
+ }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorSplatPattern,
+ VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
-
- // Need this until vector.interleave is handled.
- vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
5a54b60
to
cdc9def
Compare
c006ee5
to
cf47613
Compare
24c6d24
to
aed3118
Compare
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
aed3118
to
9a688a7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
vector.interleave
tospirv.VectorShuffle
conversion,vector.interleave
tovector.shuffle
conversion frompopulateVectorToSPIRVPatterns
and CMake/Bazel dependencies