Skip to content

Commit 5a54b60

Browse files
committed
[mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion
1 parent 10be6c9 commit 5a54b60

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

+39-5
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,42 @@ struct VectorShuffleOpConvert final
578578
}
579579
};
580580

581+
struct VectorInterleaveOpConvert final
582+
: public OpConversionPattern<vector::InterleaveOp> {
583+
using OpConversionPattern::OpConversionPattern;
584+
585+
LogicalResult
586+
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
587+
ConversionPatternRewriter &rewriter) const override {
588+
// Check the source vector type
589+
auto sourceType = interleaveOp.getSourceVectorType();
590+
if (sourceType.getRank() != 1 || sourceType.isScalable()) {
591+
return rewriter.notifyMatchFailure(interleaveOp,
592+
"unsupported source vector type");
593+
}
594+
595+
// Check the result vector type
596+
auto oldResultType = interleaveOp.getResultVectorType();
597+
Type newResultType = getTypeConverter()->convertType(oldResultType);
598+
if (!newResultType)
599+
return rewriter.notifyMatchFailure(interleaveOp,
600+
"unsupported result vector type");
601+
602+
// Interleave the indices
603+
int n = sourceType.getNumElements();
604+
auto seq = llvm::seq<int64_t>(2 * n);
605+
auto indices = llvm::to_vector(
606+
llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
607+
608+
// Emit a SPIR-V shuffle.
609+
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
610+
interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
611+
rewriter.getI32ArrayAttr(indices));
612+
613+
return success();
614+
}
615+
};
616+
581617
struct VectorLoadOpConverter final
582618
: public OpConversionPattern<vector::LoadOp> {
583619
using OpConversionPattern::OpConversionPattern;
@@ -822,16 +858,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
822858
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
823859
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
824860
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
825-
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
826-
typeConverter, patterns.getContext(), PatternBenefit(1));
861+
VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
862+
VectorStoreOpConverter>(typeConverter, patterns.getContext(),
863+
PatternBenefit(1));
827864

828865
// Make sure that the more specialized dot product pattern has higher benefit
829866
// than the generic one that extracts all elements.
830867
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
831868
PatternBenefit(2));
832-
833-
// Need this until vector.interleave is handled.
834-
vector::populateVectorInterleaveToShufflePatterns(patterns);
835869
}
836870

837871
void mlir::populateVectorReductionToSPIRVDotProductPatterns(

0 commit comments

Comments
 (0)