diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt index 113983146f5be..bb9f793d7fe0f 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt @@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV MLIRSPIRVDialect MLIRSPIRVConversion MLIRVectorDialect - MLIRVectorTransforms MLIRTransforms ) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index c2dd37f481466..a9ed25fbfbe0c 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -578,6 +578,47 @@ struct VectorShuffleOpConvert final } }; +struct VectorInterleaveOpConvert final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Check the result vector type. + VectorType oldResultType = interleaveOp.getResultVectorType(); + Type newResultType = getTypeConverter()->convertType(oldResultType); + if (!newResultType) + return rewriter.notifyMatchFailure(interleaveOp, + "unsupported result vector type"); + + // Interleave the indices. + VectorType sourceType = interleaveOp.getSourceVectorType(); + int n = sourceType.getNumElements(); + + // Input vectors of size 1 are converted to scalars by the type converter. + // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to + // use `spirv::CompositeConstructOp`. + if (n == 1) { + Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()}; + rewriter.replaceOpWithNewOp( + interleaveOp, newResultType, newOperands); + return success(); + } + + auto seq = llvm::seq(2 * n); + auto indices = llvm::map_to_vector( + seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }); + + // Emit a SPIR-V shuffle. + rewriter.replaceOpWithNewOp( + interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(), + rewriter.getI32ArrayAttr(indices)); + + return success(); + } +}; + struct VectorLoadOpConverter final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -822,16 +863,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, VectorReductionFloatMinMax, VectorReductionFloatMinMax, VectorShapeCast, VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>( - typeConverter, patterns.getContext(), PatternBenefit(1)); + VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter, + VectorStoreOpConverter>(typeConverter, patterns.getContext(), + PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit // than the generic one that extracts all elements. patterns.add(typeConverter, patterns.getContext(), PatternBenefit(2)); - - // Need this until vector.interleave is handled. - vector::populateVectorInterleaveToShufflePatterns(patterns); } void mlir::populateVectorReductionToSPIRVDotProductPatterns( diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index b24088d951259..2592d0fc04111 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -494,6 +494,19 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> { // ----- +// CHECK-LABEL: func @interleave_size1 +// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>) +// CHECK: %[[V0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32 +// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32 +// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32> +// CHECK: return %[[RES]] +func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> { + %0 = vector.interleave %a, %b : vector<1xf32> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// ----- + // CHECK-LABEL: func @reduction_add // CHECK-SAME: (%[[V:.+]]: vector<4xi32>) // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index a7bbe459fd9d7..f31f75ca5c74a 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4976,7 +4976,6 @@ cc_library( ":VectorToLLVM", ":VectorToSCF", ":VectorTransformOpsIncGen", - ":VectorTransforms", ":X86VectorTransforms", ], )