From 928f7282113d16e0395ca473565d6784d46c0056 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Thu, 23 May 2024 21:09:49 +0000 Subject: [PATCH 1/9] [mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion --- .../VectorToSPIRV/VectorToSPIRV.cpp | 44 ++++++++++++++++--- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index c2dd37f481466..95464ef6d438e 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -578,6 +578,42 @@ struct VectorShuffleOpConvert final } }; +struct VectorInterleaveOpConvert final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Check the source vector type + auto sourceType = interleaveOp.getSourceVectorType(); + if (sourceType.getRank() != 1 || sourceType.isScalable()) { + return rewriter.notifyMatchFailure(interleaveOp, + "unsupported source vector type"); + } + + // Check the result vector type + auto oldResultType = interleaveOp.getResultVectorType(); + Type newResultType = getTypeConverter()->convertType(oldResultType); + if (!newResultType) + return rewriter.notifyMatchFailure(interleaveOp, + "unsupported result vector type"); + + // Interleave the indices + int n = sourceType.getNumElements(); + auto seq = llvm::seq(2 * n); + auto indices = llvm::to_vector( + llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; })); + + // Emit a SPIR-V shuffle. + rewriter.replaceOpWithNewOp( + interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(), + rewriter.getI32ArrayAttr(indices)); + + return success(); + } +}; + struct VectorLoadOpConverter final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -822,16 +858,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, VectorReductionFloatMinMax, VectorReductionFloatMinMax, VectorShapeCast, VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>( - typeConverter, patterns.getContext(), PatternBenefit(1)); + VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter, + VectorStoreOpConverter>(typeConverter, patterns.getContext(), + PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit // than the generic one that extracts all elements. patterns.add(typeConverter, patterns.getContext(), PatternBenefit(2)); - - // Need this until vector.interleave is handled. - vector::populateVectorInterleaveToShufflePatterns(patterns); } void mlir::populateVectorReductionToSPIRVDotProductPatterns( From 1ff747a49e1e71b73a5de5fec5c6f4db430f9830 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 08:33:14 -0400 Subject: [PATCH 2/9] Use VectorType for sourceType Co-authored-by: Jakub Kuderski --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 95464ef6d438e..aa3670f81fea3 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -586,7 +586,7 @@ struct VectorInterleaveOpConvert final matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check the source vector type - auto sourceType = interleaveOp.getSourceVectorType(); + VectorType sourceType = interleaveOp.getSourceVectorType(); if (sourceType.getRank() != 1 || sourceType.isScalable()) { return rewriter.notifyMatchFailure(interleaveOp, "unsupported source vector type"); From e6eb044c38a88e965ffebe60877e571fbd97d0a1 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 12:42:35 +0000 Subject: [PATCH 3/9] Use VectorType for oldResultType --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index aa3670f81fea3..0af0595eebe0d 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -593,7 +593,7 @@ struct VectorInterleaveOpConvert final } // Check the result vector type - auto oldResultType = interleaveOp.getResultVectorType(); + VectorType oldResultType = interleaveOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(interleaveOp, From 7e9ea7fb8768582432c2b53dd10c0264ea23357c Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 16:06:10 +0000 Subject: [PATCH 4/9] Handle one-element input vector case and remove cmake/bazel dependencies --- mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt | 1 - .../Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 15 ++++++++++++++- .../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 13 +++++++++++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 - 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt index 113983146f5be..bb9f793d7fe0f 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt @@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV MLIRSPIRVDialect MLIRSPIRVConversion MLIRVectorDialect - MLIRVectorTransforms MLIRTransforms ) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 0af0595eebe0d..a63ef5ab451eb 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -601,6 +601,19 @@ struct VectorInterleaveOpConvert final // Interleave the indices int n = sourceType.getNumElements(); + + // Input vectors of size 1 are converted to scalars by the type converter. + // We cannot use spirv::VectorShuffleOp directly in this case, and need to + // use spirv::CompositeConstructOp. + if (n == 1) { + SmallVector newOperands(2); + newOperands[0] = adaptor.getLhs(); + newOperands[1] = adaptor.getRhs(); + rewriter.replaceOpWithNewOp( + interleaveOp, newResultType, newOperands); + return success(); + } + auto seq = llvm::seq(2 * n); auto indices = llvm::to_vector( llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; })); @@ -609,7 +622,7 @@ struct VectorInterleaveOpConvert final rewriter.replaceOpWithNewOp( interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(), rewriter.getI32ArrayAttr(indices)); - + return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index b24088d951259..f52e771f1d4a8 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -494,6 +494,19 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> { // ----- +// CHECK-LABEL: func @interleave_size1 +// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>) +// CHECK: %[[V0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32 +// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32 +// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32> +// CHECK: return %[[RES]] +func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> { + %0 = vector.interleave %a, %b : vector<1xf32> + return %0 : vector<2xf32> +} + +// ----- + // CHECK-LABEL: func @reduction_add // CHECK-SAME: (%[[V:.+]]: vector<4xi32>) // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index a7bbe459fd9d7..f31f75ca5c74a 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4976,7 +4976,6 @@ cc_library( ":VectorToLLVM", ":VectorToSCF", ":VectorTransformOpsIncGen", - ":VectorTransforms", ":X86VectorTransforms", ], ) From ff34b53a136f8af220a444d98426a817dfff9224 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 16:39:02 +0000 Subject: [PATCH 5/9] Remove check for source type and reformat code --- .../lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index a63ef5ab451eb..69f89d087dd3c 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -585,13 +585,6 @@ struct VectorInterleaveOpConvert final LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Check the source vector type - VectorType sourceType = interleaveOp.getSourceVectorType(); - if (sourceType.getRank() != 1 || sourceType.isScalable()) { - return rewriter.notifyMatchFailure(interleaveOp, - "unsupported source vector type"); - } - // Check the result vector type VectorType oldResultType = interleaveOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); @@ -600,10 +593,11 @@ struct VectorInterleaveOpConvert final "unsupported result vector type"); // Interleave the indices + VectorType sourceType = interleaveOp.getSourceVectorType(); int n = sourceType.getNumElements(); - // Input vectors of size 1 are converted to scalars by the type converter. - // We cannot use spirv::VectorShuffleOp directly in this case, and need to + // Input vectors of size 1 are converted to scalars by the type converter. + // We cannot use spirv::VectorShuffleOp directly in this case, and need to // use spirv::CompositeConstructOp. if (n == 1) { SmallVector newOperands(2); @@ -622,7 +616,7 @@ struct VectorInterleaveOpConvert final rewriter.replaceOpWithNewOp( interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(), rewriter.getI32ArrayAttr(indices)); - + return success(); } }; From a7e9433eedf1bd0b63371670d0a367791c358838 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 13:39:33 -0400 Subject: [PATCH 6/9] Reformat code Co-authored-by: Jakub Kuderski --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 69f89d087dd3c..043b0741729d6 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -585,24 +585,22 @@ struct VectorInterleaveOpConvert final LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Check the result vector type + // Check the result vector type. VectorType oldResultType = interleaveOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(interleaveOp, "unsupported result vector type"); - // Interleave the indices + // Interleave the indices. VectorType sourceType = interleaveOp.getSourceVectorType(); int n = sourceType.getNumElements(); // Input vectors of size 1 are converted to scalars by the type converter. - // We cannot use spirv::VectorShuffleOp directly in this case, and need to - // use spirv::CompositeConstructOp. + // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to + // use `spirv::CompositeConstructOp`. if (n == 1) { - SmallVector newOperands(2); - newOperands[0] = adaptor.getLhs(); - newOperands[1] = adaptor.getRhs(); + Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()}; rewriter.replaceOpWithNewOp( interleaveOp, newResultType, newOperands); return success(); From a8e806ad2e6df3f817f30b04677487a59fb86bd1 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 20:37:22 +0000 Subject: [PATCH 7/9] Use llvm::map_to_vector --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 043b0741729d6..7c17042917ff9 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -607,8 +607,8 @@ struct VectorInterleaveOpConvert final } auto seq = llvm::seq(2 * n); - auto indices = llvm::to_vector( - llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; })); + auto indices = llvm::map_to_vector( + seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }); // Emit a SPIR-V shuffle. rewriter.replaceOpWithNewOp( From 9a688a71a12f7f0ea3028c66d2a7003ddb33b888 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 20:53:31 +0000 Subject: [PATCH 8/9] Modify vector.interleave assembly format in the LIT test --- mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index f52e771f1d4a8..2592d0fc04111 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -501,7 +501,7 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> { // CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32> // CHECK: return %[[RES]] func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> { - %0 = vector.interleave %a, %b : vector<1xf32> + %0 = vector.interleave %a, %b : vector<1xf32> -> vector<2xf32> return %0 : vector<2xf32> } From 08e191aa1dc042a62cb6286ea71083fbaee606a7 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Mon, 27 May 2024 20:58:57 +0000 Subject: [PATCH 9/9] Reformat code --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 7c17042917ff9..a9ed25fbfbe0c 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -608,7 +608,7 @@ struct VectorInterleaveOpConvert final auto seq = llvm::seq(2 * n); auto indices = llvm::map_to_vector( - seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }); + seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }); // Emit a SPIR-V shuffle. rewriter.replaceOpWithNewOp(