@@ -578,6 +578,47 @@ struct VectorShuffleOpConvert final
578
578
}
579
579
};
580
580
581
+ struct VectorInterleaveOpConvert final
582
+ : public OpConversionPattern<vector::InterleaveOp> {
583
+ using OpConversionPattern::OpConversionPattern;
584
+
585
+ LogicalResult
586
+ matchAndRewrite (vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
587
+ ConversionPatternRewriter &rewriter) const override {
588
+ // Check the result vector type.
589
+ VectorType oldResultType = interleaveOp.getResultVectorType ();
590
+ Type newResultType = getTypeConverter ()->convertType (oldResultType);
591
+ if (!newResultType)
592
+ return rewriter.notifyMatchFailure (interleaveOp,
593
+ " unsupported result vector type" );
594
+
595
+ // Interleave the indices.
596
+ VectorType sourceType = interleaveOp.getSourceVectorType ();
597
+ int n = sourceType.getNumElements ();
598
+
599
+ // Input vectors of size 1 are converted to scalars by the type converter.
600
+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
601
+ // use `spirv::CompositeConstructOp`.
602
+ if (n == 1 ) {
603
+ Value newOperands[] = {adaptor.getLhs (), adaptor.getRhs ()};
604
+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(
605
+ interleaveOp, newResultType, newOperands);
606
+ return success ();
607
+ }
608
+
609
+ auto seq = llvm::seq<int64_t >(2 * n);
610
+ auto indices = llvm::map_to_vector (
611
+ seq, [n](int i) { return (i % 2 ? n : 0 ) + i / 2 ; });
612
+
613
+ // Emit a SPIR-V shuffle.
614
+ rewriter.replaceOpWithNewOp <spirv::VectorShuffleOp>(
615
+ interleaveOp, newResultType, adaptor.getLhs (), adaptor.getRhs (),
616
+ rewriter.getI32ArrayAttr (indices));
617
+
618
+ return success ();
619
+ }
620
+ };
621
+
581
622
struct VectorLoadOpConverter final
582
623
: public OpConversionPattern<vector::LoadOp> {
583
624
using OpConversionPattern::OpConversionPattern;
@@ -822,16 +863,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
822
863
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
823
864
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
824
865
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
825
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
826
- typeConverter, patterns.getContext (), PatternBenefit (1 ));
866
+ VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
867
+ VectorStoreOpConverter>(typeConverter, patterns.getContext (),
868
+ PatternBenefit (1 ));
827
869
828
870
// Make sure that the more specialized dot product pattern has higher benefit
829
871
// than the generic one that extracts all elements.
830
872
patterns.add <VectorReductionToFPDotProd>(typeConverter, patterns.getContext (),
831
873
PatternBenefit (2 ));
832
-
833
- // Need this until vector.interleave is handled.
834
- vector::populateVectorInterleaveToShufflePatterns (patterns);
835
874
}
836
875
837
876
void mlir::populateVectorReductionToSPIRVDotProductPatterns (
0 commit comments