Skip to content

Commit 57c10fa

Browse files
angelz913kuhar
andauthored
[mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion (#93240)
- Add `vector.interleave` to `spirv.VectorShuffle` conversion - Remove the `vector.interleave` to `vector.shuffle` conversion from `populateVectorToSPIRVPatterns` and CMake/Bazel dependencies --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 21a39df commit 57c10fa

File tree

4 files changed

+57
-7
lines changed

4 files changed

+57
-7
lines changed

mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
1414
MLIRSPIRVDialect
1515
MLIRSPIRVConversion
1616
MLIRVectorDialect
17-
MLIRVectorTransforms
1817
MLIRTransforms
1918
)

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

+44-5
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,47 @@ struct VectorShuffleOpConvert final
578578
}
579579
};
580580

581+
struct VectorInterleaveOpConvert final
582+
: public OpConversionPattern<vector::InterleaveOp> {
583+
using OpConversionPattern::OpConversionPattern;
584+
585+
LogicalResult
586+
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
587+
ConversionPatternRewriter &rewriter) const override {
588+
// Check the result vector type.
589+
VectorType oldResultType = interleaveOp.getResultVectorType();
590+
Type newResultType = getTypeConverter()->convertType(oldResultType);
591+
if (!newResultType)
592+
return rewriter.notifyMatchFailure(interleaveOp,
593+
"unsupported result vector type");
594+
595+
// Interleave the indices.
596+
VectorType sourceType = interleaveOp.getSourceVectorType();
597+
int n = sourceType.getNumElements();
598+
599+
// Input vectors of size 1 are converted to scalars by the type converter.
600+
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
601+
// use `spirv::CompositeConstructOp`.
602+
if (n == 1) {
603+
Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
604+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
605+
interleaveOp, newResultType, newOperands);
606+
return success();
607+
}
608+
609+
auto seq = llvm::seq<int64_t>(2 * n);
610+
auto indices = llvm::map_to_vector(
611+
seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
612+
613+
// Emit a SPIR-V shuffle.
614+
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
615+
interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
616+
rewriter.getI32ArrayAttr(indices));
617+
618+
return success();
619+
}
620+
};
621+
581622
struct VectorLoadOpConverter final
582623
: public OpConversionPattern<vector::LoadOp> {
583624
using OpConversionPattern::OpConversionPattern;
@@ -822,16 +863,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
822863
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
823864
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
824865
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
825-
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
826-
typeConverter, patterns.getContext(), PatternBenefit(1));
866+
VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
867+
VectorStoreOpConverter>(typeConverter, patterns.getContext(),
868+
PatternBenefit(1));
827869

828870
// Make sure that the more specialized dot product pattern has higher benefit
829871
// than the generic one that extracts all elements.
830872
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
831873
PatternBenefit(2));
832-
833-
// Need this until vector.interleave is handled.
834-
vector::populateVectorInterleaveToShufflePatterns(patterns);
835874
}
836875

837876
void mlir::populateVectorReductionToSPIRVDotProductPatterns(

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

+13
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,19 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
494494

495495
// -----
496496

497+
// CHECK-LABEL: func @interleave_size1
498+
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>)
499+
// CHECK: %[[V0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
500+
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
501+
// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32>
502+
// CHECK: return %[[RES]]
503+
func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> {
504+
%0 = vector.interleave %a, %b : vector<1xf32> -> vector<2xf32>
505+
return %0 : vector<2xf32>
506+
}
507+
508+
// -----
509+
497510
// CHECK-LABEL: func @reduction_add
498511
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
499512
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

-1
Original file line numberDiff line numberDiff line change
@@ -4976,7 +4976,6 @@ cc_library(
49764976
":VectorToLLVM",
49774977
":VectorToSCF",
49784978
":VectorTransformOpsIncGen",
4979-
":VectorTransforms",
49804979
":X86VectorTransforms",
49814980
],
49824981
)

0 commit comments

Comments
 (0)