@@ -578,6 +578,42 @@ struct VectorShuffleOpConvert final
578
578
}
579
579
};
580
580
581
+ struct VectorInterleaveOpConvert final
582
+ : public OpConversionPattern<vector::InterleaveOp> {
583
+ using OpConversionPattern::OpConversionPattern;
584
+
585
+ LogicalResult
586
+ matchAndRewrite (vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
587
+ ConversionPatternRewriter &rewriter) const override {
588
+ // Check the source vector type
589
+ auto sourceType = interleaveOp.getSourceVectorType ();
590
+ if (sourceType.getRank () != 1 || sourceType.isScalable ()) {
591
+ return rewriter.notifyMatchFailure (interleaveOp,
592
+ " unsupported source vector type" );
593
+ }
594
+
595
+ // Check the result vector type
596
+ auto oldResultType = interleaveOp.getResultVectorType ();
597
+ Type newResultType = getTypeConverter ()->convertType (oldResultType);
598
+ if (!newResultType)
599
+ return rewriter.notifyMatchFailure (interleaveOp,
600
+ " unsupported result vector type" );
601
+
602
+ // Interleave the indices
603
+ int n = sourceType.getNumElements ();
604
+ auto seq = llvm::seq<int64_t >(2 * n);
605
+ auto indices = llvm::to_vector (
606
+ llvm::map_range (seq, [n](int i) { return (i % 2 ? n : 0 ) + i / 2 ; }));
607
+
608
+ // Emit a SPIR-V shuffle.
609
+ rewriter.replaceOpWithNewOp <spirv::VectorShuffleOp>(
610
+ interleaveOp, newResultType, adaptor.getLhs (), adaptor.getRhs (),
611
+ rewriter.getI32ArrayAttr (indices));
612
+
613
+ return success ();
614
+ }
615
+ };
616
+
581
617
struct VectorLoadOpConverter final
582
618
: public OpConversionPattern<vector::LoadOp> {
583
619
using OpConversionPattern::OpConversionPattern;
@@ -822,16 +858,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
822
858
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
823
859
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
824
860
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
825
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
826
- typeConverter, patterns.getContext (), PatternBenefit (1 ));
861
+ VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
862
+ VectorStoreOpConverter>(typeConverter, patterns.getContext (),
863
+ PatternBenefit (1 ));
827
864
828
865
// Make sure that the more specialized dot product pattern has higher benefit
829
866
// than the generic one that extracts all elements.
830
867
patterns.add <VectorReductionToFPDotProd>(typeConverter, patterns.getContext (),
831
868
PatternBenefit (2 ));
832
-
833
- // Need this until vector.interleave is handled.
834
- vector::populateVectorInterleaveToShufflePatterns (patterns);
835
869
}
836
870
837
871
void mlir::populateVectorReductionToSPIRVDotProductPatterns (
0 commit comments