Skip to content

Commit ae1e35c

Browse files
committed
Initial experiments (with integer regs for fp16).
1 parent a87640c commit ae1e35c

File tree

9 files changed

+420
-18
lines changed

9 files changed

+420
-18
lines changed

clang/lib/Basic/Targets/SystemZ.h

+12
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,23 @@ class LLVM_LIBRARY_VISIBILITY SystemZTargetInfo : public TargetInfo {
9191
"-v128:64-a:8:16-n32:64");
9292
}
9393
MaxAtomicPromoteWidth = MaxAtomicInlineWidth = 128;
94+
95+
// True if the backend supports operations on the half LLVM IR type.
96+
HasLegalHalfType = false;
97+
// Allow half arguments and return values.
98+
HalfArgsAndReturns = true;
99+
// Support _Float16.
100+
HasFloat16 = true;
101+
94102
HasStrictFP = true;
95103
}
96104

97105
unsigned getMinGlobalAlign(uint64_t Size, bool HasNonWeakDef) const override;
98106

107+
bool useFP16ConversionIntrinsics() const override {
108+
return false;
109+
}
110+
99111
void getTargetDefines(const LangOptions &Opts,
100112
MacroBuilder &Builder) const override;
101113

clang/lib/CodeGen/Targets/SystemZ.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ bool SystemZABIInfo::isFPArgumentType(QualType Ty) const {
185185

186186
if (const BuiltinType *BT = Ty->getAs<BuiltinType>())
187187
switch (BT->getKind()) {
188+
// case BuiltinType::Half: // __fp16 Support __fp16??
189+
case BuiltinType::Float16: // _Float16
188190
case BuiltinType::Float:
189191
case BuiltinType::Double:
190192
return true;
@@ -277,7 +279,8 @@ RValue SystemZABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
277279
} else {
278280
if (AI.getCoerceToType())
279281
ArgTy = AI.getCoerceToType();
280-
InFPRs = (!IsSoftFloatABI && (ArgTy->isFloatTy() || ArgTy->isDoubleTy()));
282+
InFPRs = (!IsSoftFloatABI &&
283+
(ArgTy->isHalfTy() || ArgTy->isFloatTy() || ArgTy->isDoubleTy()));
281284
IsVector = ArgTy->isVectorTy();
282285
UnpaddedSize = TyInfo.Width;
283286
DirectAlign = TyInfo.Align;
@@ -446,10 +449,11 @@ ABIArgInfo SystemZABIInfo::classifyArgumentType(QualType Ty) const {
446449

447450
// The structure is passed as an unextended integer, a float, or a double.
448451
if (isFPArgumentType(SingleElementTy)) {
449-
assert(Size == 32 || Size == 64);
452+
assert(Size == 16 || Size == 32 || Size == 64);
450453
return ABIArgInfo::getDirect(
451-
Size == 32 ? llvm::Type::getFloatTy(getVMContext())
452-
: llvm::Type::getDoubleTy(getVMContext()));
454+
Size == 16 ? llvm::Type::getHalfTy(getVMContext())
455+
: Size == 32 ? llvm::Type::getFloatTy(getVMContext())
456+
: llvm::Type::getDoubleTy(getVMContext()));
453457
} else {
454458
llvm::IntegerType *PassTy = llvm::IntegerType::get(getVMContext(), Size);
455459
return Size <= 32 ? ABIArgInfo::getNoExtend(PassTy)

clang/lib/Sema/SemaExpr.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -16534,7 +16534,7 @@ ExprResult Sema::BuildVAArgExpr(SourceLocation BuiltinLoc,
1653416534
PromoteType = QualType();
1653516535
}
1653616536
}
16537-
if (TInfo->getType()->isSpecificBuiltinType(BuiltinType::Float))
16537+
if (TInfo->getType()->isFloat16Type() || TInfo->getType()->isFloat32Type())
1653816538
PromoteType = Context.DoubleTy;
1653916539
if (!PromoteType.isNull())
1654016540
DiagRuntimeBehavior(TInfo->getTypeLoc().getBeginLoc(), E,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// RUN: %clang_cc1 -triple s390x-linux-gnu \
2+
// RUN: -ffloat16-excess-precision=standard -emit-llvm -o - %s \
3+
// RUN: | FileCheck %s -check-prefix=STANDARD
4+
5+
// RUN: %clang_cc1 -triple s390x-linux-gnu \
6+
// RUN: -ffloat16-excess-precision=none -emit-llvm -o - %s \
7+
// RUN: | FileCheck %s -check-prefix=NONE
8+
9+
// RUN: %clang_cc1 -triple s390x-linux-gnu \
10+
// RUN: -ffloat16-excess-precision=fast -emit-llvm -o - %s \
11+
// RUN: | FileCheck %s -check-prefix=FAST
12+
13+
_Float16 f(_Float16 a, _Float16 b, _Float16 c, _Float16 d) {
14+
return a * b + c * d;
15+
}
16+
17+
// STANDARD-LABEL: define dso_local half @f(half noundef %a, half noundef %b, half noundef %c, half noundef %d) #0 {
18+
// STANDARD-NEXT: entry:
19+
// STANDARD-NEXT: %a.addr = alloca half, align 2
20+
// STANDARD-NEXT: %b.addr = alloca half, align 2
21+
// STANDARD-NEXT: %c.addr = alloca half, align 2
22+
// STANDARD-NEXT: %d.addr = alloca half, align 2
23+
// STANDARD-NEXT: store half %a, ptr %a.addr, align 2
24+
// STANDARD-NEXT: store half %b, ptr %b.addr, align 2
25+
// STANDARD-NEXT: store half %c, ptr %c.addr, align 2
26+
// STANDARD-NEXT: store half %d, ptr %d.addr, align 2
27+
// STANDARD-NEXT: %0 = load half, ptr %a.addr, align 2
28+
// STANDARD-NEXT: %ext = fpext half %0 to float
29+
// STANDARD-NEXT: %1 = load half, ptr %b.addr, align 2
30+
// STANDARD-NEXT: %ext1 = fpext half %1 to float
31+
// STANDARD-NEXT: %mul = fmul float %ext, %ext1
32+
// STANDARD-NEXT: %2 = load half, ptr %c.addr, align 2
33+
// STANDARD-NEXT: %ext2 = fpext half %2 to float
34+
// STANDARD-NEXT: %3 = load half, ptr %d.addr, align 2
35+
// STANDARD-NEXT: %ext3 = fpext half %3 to float
36+
// STANDARD-NEXT: %mul4 = fmul float %ext2, %ext3
37+
// STANDARD-NEXT: %add = fadd float %mul, %mul4
38+
// STANDARD-NEXT: %unpromotion = fptrunc float %add to half
39+
// STANDARD-NEXT: ret half %unpromotion
40+
// STANDARD-NEXT: }
41+
42+
// NONE-LABEL: define dso_local half @f(half noundef %a, half noundef %b, half noundef %c, half noundef %d) #0 {
43+
// NONE-NEXT: entry:
44+
// NONE-NEXT: %a.addr = alloca half, align 2
45+
// NONE-NEXT: %b.addr = alloca half, align 2
46+
// NONE-NEXT: %c.addr = alloca half, align 2
47+
// NONE-NEXT: %d.addr = alloca half, align 2
48+
// NONE-NEXT: store half %a, ptr %a.addr, align 2
49+
// NONE-NEXT: store half %b, ptr %b.addr, align 2
50+
// NONE-NEXT: store half %c, ptr %c.addr, align 2
51+
// NONE-NEXT: store half %d, ptr %d.addr, align 2
52+
// NONE-NEXT: %0 = load half, ptr %a.addr, align 2
53+
// NONE-NEXT: %1 = load half, ptr %b.addr, align 2
54+
// NONE-NEXT: %mul = fmul half %0, %1
55+
// NONE-NEXT: %2 = load half, ptr %c.addr, align 2
56+
// NONE-NEXT: %3 = load half, ptr %d.addr, align 2
57+
// NONE-NEXT: %mul1 = fmul half %2, %3
58+
// NONE-NEXT: %add = fadd half %mul, %mul1
59+
// NONE-NEXT: ret half %add
60+
// NONE-NEXT: }
61+
62+
// FAST-LABEL: define dso_local half @f(half noundef %a, half noundef %b, half noundef %c, half noundef %d) #0 {
63+
// FAST-NEXT: entry:
64+
// FAST-NEXT: %a.addr = alloca half, align 2
65+
// FAST-NEXT: %b.addr = alloca half, align 2
66+
// FAST-NEXT: %c.addr = alloca half, align 2
67+
// FAST-NEXT: %d.addr = alloca half, align 2
68+
// FAST-NEXT: store half %a, ptr %a.addr, align 2
69+
// FAST-NEXT: store half %b, ptr %b.addr, align 2
70+
// FAST-NEXT: store half %c, ptr %c.addr, align 2
71+
// FAST-NEXT: store half %d, ptr %d.addr, align 2
72+
// FAST-NEXT: %0 = load half, ptr %a.addr, align 2
73+
// FAST-NEXT: %ext = fpext half %0 to float
74+
// FAST-NEXT: %1 = load half, ptr %b.addr, align 2
75+
// FAST-NEXT: %ext1 = fpext half %1 to float
76+
// FAST-NEXT: %mul = fmul float %ext, %ext1
77+
// FAST-NEXT: %2 = load half, ptr %c.addr, align 2
78+
// FAST-NEXT: %ext2 = fpext half %2 to float
79+
// FAST-NEXT: %3 = load half, ptr %d.addr, align 2
80+
// FAST-NEXT: %ext3 = fpext half %3 to float
81+
// FAST-NEXT: %mul4 = fmul float %ext2, %ext3
82+
// FAST-NEXT: %add = fadd float %mul, %mul4
83+
// FAST-NEXT: %unpromotion = fptrunc float %add to half
84+
// FAST-NEXT: ret half %unpromotion
85+
// FAST-NEXT: }

clang/test/CodeGen/SystemZ/systemz-abi.c

+44
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ long long pass_longlong(long long arg) { return arg; }
4545
__int128 pass_int128(__int128 arg) { return arg; }
4646
// CHECK-LABEL: define{{.*}} void @pass_int128(ptr dead_on_unwind noalias writable sret(i128) align 8 %{{.*}}, ptr %0)
4747

48+
_Float16 pass__Float16(_Float16 arg) { return arg; }
49+
// CHECK-LABEL: define{{.*}} half @pass__Float16(half %{{.*}})
50+
4851
float pass_float(float arg) { return arg; }
4952
// CHECK-LABEL: define{{.*}} float @pass_float(float %{{.*}})
5053

@@ -72,6 +75,9 @@ _Complex long pass_complex_long(_Complex long arg) { return arg; }
7275
_Complex long long pass_complex_longlong(_Complex long long arg) { return arg; }
7376
// CHECK-LABEL: define{{.*}} void @pass_complex_longlong(ptr dead_on_unwind noalias writable sret({ i64, i64 }) align 8 %{{.*}}, ptr %{{.*}}arg)
7477

78+
_Complex _Float16 pass_complex__Float16(_Complex _Float16 arg) { return arg; }
79+
// CHECK-LABEL: define{{.*}} void @pass_complex__Float16(ptr dead_on_unwind noalias writable sret({ half, half }) align 2 %{{.*}}, ptr %{{.*}}arg)
80+
7581
_Complex float pass_complex_float(_Complex float arg) { return arg; }
7682
// CHECK-LABEL: define{{.*}} void @pass_complex_float(ptr dead_on_unwind noalias writable sret({ float, float }) align 4 %{{.*}}, ptr %{{.*}}arg)
7783

@@ -123,6 +129,11 @@ struct agg_16byte pass_agg_16byte(struct agg_16byte arg) { return arg; }
123129

124130
// Float-like aggregate types
125131

132+
struct agg__Float16 { _Float16 a; };
133+
struct agg__Float16 pass_agg__Float16(struct agg__Float16 arg) { return arg; }
134+
// HARD-FLOAT-LABEL: define{{.*}} void @pass_agg__Float16(ptr dead_on_unwind noalias writable sret(%struct.agg__Float16) align 2 %{{.*}}, half %{{.*}})
135+
// SOFT-FLOAT-LABEL: define{{.*}} void @pass_agg__Float16(ptr dead_on_unwind noalias writable sret(%struct.agg__Float16) align 2 %{{.*}}, i16 noext %{{.*}})
136+
126137
struct agg_float { float a; };
127138
struct agg_float pass_agg_float(struct agg_float arg) { return arg; }
128139
// HARD-FLOAT-LABEL: define{{.*}} void @pass_agg_float(ptr dead_on_unwind noalias writable sret(%struct.agg_float) align 4 %{{.*}}, float %{{.*}})
@@ -137,6 +148,11 @@ struct agg_longdouble { long double a; };
137148
struct agg_longdouble pass_agg_longdouble(struct agg_longdouble arg) { return arg; }
138149
// CHECK-LABEL: define{{.*}} void @pass_agg_longdouble(ptr dead_on_unwind noalias writable sret(%struct.agg_longdouble) align 8 %{{.*}}, ptr %{{.*}})
139150

151+
struct agg__Float16_a8 { _Float16 a __attribute__((aligned (8))); };
152+
struct agg__Float16_a8 pass_agg__Float16_a8(struct agg__Float16_a8 arg) { return arg; }
153+
// HARD-FLOAT-LABEL: define{{.*}} void @pass_agg__Float16_a8(ptr dead_on_unwind noalias writable sret(%struct.agg__Float16_a8) align 8 %{{.*}}, double %{{.*}})
154+
// SOFT-FLOAT-LABEL: define{{.*}} void @pass_agg__Float16_a8(ptr dead_on_unwind noalias writable sret(%struct.agg__Float16_a8) align 8 %{{.*}}, i64 %{{.*}})
155+
140156
struct agg_float_a8 { float a __attribute__((aligned (8))); };
141157
struct agg_float_a8 pass_agg_float_a8(struct agg_float_a8 arg) { return arg; }
142158
// HARD-FLOAT-LABEL: define{{.*}} void @pass_agg_float_a8(ptr dead_on_unwind noalias writable sret(%struct.agg_float_a8) align 8 %{{.*}}, double %{{.*}})
@@ -164,6 +180,10 @@ struct agg_nofloat3 pass_agg_nofloat3(struct agg_nofloat3 arg) { return arg; }
164180

165181
// Union types likewise are *not* float-like aggregate types
166182

183+
union union__Float16 { _Float16 a; };
184+
union union__Float16 pass_union__Float16(union union__Float16 arg) { return arg; }
185+
// CHECK-LABEL: define{{.*}} void @pass_union__Float16(ptr dead_on_unwind noalias writable sret(%union.union__Float16) align 2 %{{.*}}, i16 noext %{{.*}})
186+
167187
union union_float { float a; };
168188
union union_float pass_union_float(union union_float arg) { return arg; }
169189
// CHECK-LABEL: define{{.*}} void @pass_union_float(ptr dead_on_unwind noalias writable sret(%union.union_float) align 4 %{{.*}}, i32 noext %{{.*}})
@@ -441,6 +461,30 @@ struct agg_8byte va_agg_8byte(__builtin_va_list l) { return __builtin_va_arg(l,
441461
// CHECK: [[VA_ARG_ADDR:%[^ ]+]] = phi ptr [ [[RAW_REG_ADDR]], %{{.*}} ], [ [[RAW_MEM_ADDR]], %{{.*}} ]
442462
// CHECK: ret void
443463

464+
struct agg__Float16 va_agg__Float16(__builtin_va_list l) { return __builtin_va_arg(l, struct agg__Float16); }
465+
// CHECK-LABEL: define{{.*}} void @va_agg__Float16(ptr dead_on_unwind noalias writable sret(%struct.agg__Float16) align 2 %{{.*}}, ptr %{{.*}}
466+
// HARD-FLOAT: [[REG_COUNT_PTR:%[^ ]+]] = getelementptr inbounds nuw %struct.__va_list_tag, ptr %{{.*}}, i32 0, i32 1
467+
// SOFT-FLOAT: [[REG_COUNT_PTR:%[^ ]+]] = getelementptr inbounds nuw %struct.__va_list_tag, ptr %{{.*}}, i32 0, i32 0
468+
// CHECK: [[REG_COUNT:%[^ ]+]] = load i64, ptr [[REG_COUNT_PTR]]
469+
// HARD-FLOAT: [[FITS_IN_REGS:%[^ ]+]] = icmp ult i64 [[REG_COUNT]], 4
470+
// SOFT-FLOAT: [[FITS_IN_REGS:%[^ ]+]] = icmp ult i64 [[REG_COUNT]], 5
471+
// CHECK: br i1 [[FITS_IN_REGS]],
472+
// CHECK: [[SCALED_REG_COUNT:%[^ ]+]] = mul i64 [[REG_COUNT]], 8
473+
// HARD-FLOAT: [[REG_OFFSET:%[^ ]+]] = add i64 [[SCALED_REG_COUNT]], 128
474+
// SOFT-FLOAT: [[REG_OFFSET:%[^ ]+]] = add i64 [[SCALED_REG_COUNT]], 22
475+
// CHECK: [[REG_SAVE_AREA_PTR:%[^ ]+]] = getelementptr inbounds nuw %struct.__va_list_tag, ptr %{{.*}}, i32 0, i32 3
476+
// CHECK: [[REG_SAVE_AREA:%[^ ]+]] = load ptr, ptr [[REG_SAVE_AREA_PTR:[^ ]+]]
477+
// CHECK: [[RAW_REG_ADDR:%[^ ]+]] = getelementptr i8, ptr [[REG_SAVE_AREA]], i64 [[REG_OFFSET]]
478+
// CHECK: [[REG_COUNT1:%[^ ]+]] = add i64 [[REG_COUNT]], 1
479+
// CHECK: store i64 [[REG_COUNT1]], ptr [[REG_COUNT_PTR]]
480+
// CHECK: [[OVERFLOW_ARG_AREA_PTR:%[^ ]+]] = getelementptr inbounds nuw %struct.__va_list_tag, ptr %{{.*}}, i32 0, i32 2
481+
// CHECK: [[OVERFLOW_ARG_AREA:%[^ ]+]] = load ptr, ptr [[OVERFLOW_ARG_AREA_PTR]]
482+
// CHECK: [[RAW_MEM_ADDR:%[^ ]+]] = getelementptr i8, ptr [[OVERFLOW_ARG_AREA]], i64 6
483+
// CHECK: [[OVERFLOW_ARG_AREA2:%[^ ]+]] = getelementptr i8, ptr [[OVERFLOW_ARG_AREA]], i64 8
484+
// CHECK: store ptr [[OVERFLOW_ARG_AREA2]], ptr [[OVERFLOW_ARG_AREA_PTR]]
485+
// CHECK: [[VA_ARG_ADDR:%[^ ]+]] = phi ptr [ [[RAW_REG_ADDR]], %{{.*}} ], [ [[RAW_MEM_ADDR]], %{{.*}} ]
486+
// CHECK: ret void
487+
444488
struct agg_float va_agg_float(__builtin_va_list l) { return __builtin_va_arg(l, struct agg_float); }
445489
// CHECK-LABEL: define{{.*}} void @va_agg_float(ptr dead_on_unwind noalias writable sret(%struct.agg_float) align 4 %{{.*}}, ptr %{{.*}}
446490
// HARD-FLOAT: [[REG_COUNT_PTR:%[^ ]+]] = getelementptr inbounds nuw %struct.__va_list_tag, ptr %{{.*}}, i32 0, i32 1

llvm/lib/Target/SystemZ/SystemZCallingConv.td

+3-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def RetCC_SystemZ_ELF : CallingConv<[
5050
// other floating-point argument registers available for code that
5151
// doesn't care about the ABI. All floating-point argument registers
5252
// are call-clobbered, so we can use all of them here.
53+
CCIfType<[f16], CCAssignToReg<[F0S, F2S, F4S, F6S]>>,
5354
CCIfType<[f32], CCAssignToReg<[F0S, F2S, F4S, F6S]>>,
5455
CCIfType<[f64], CCAssignToReg<[F0D, F2D, F4D, F6D]>>,
5556

@@ -115,6 +116,7 @@ def CC_SystemZ_ELF : CallingConv<[
115116
CCIfType<[i64], CCAssignToReg<[R2D, R3D, R4D, R5D, R6D]>>,
116117

117118
// The first 4 float and double arguments are passed in even registers F0-F6.
119+
CCIfType<[f16], CCAssignToReg<[F0S, F2S, F4S, F6S]>>,
118120
CCIfType<[f32], CCAssignToReg<[F0S, F2S, F4S, F6S]>>,
119121
CCIfType<[f64], CCAssignToReg<[F0D, F2D, F4D, F6D]>>,
120122

@@ -138,7 +140,7 @@ def CC_SystemZ_ELF : CallingConv<[
138140
CCAssignToStack<16, 8>>>,
139141

140142
// Other arguments are passed in 8-byte-aligned 8-byte stack slots.
141-
CCIfType<[i32, i64, f32, f64], CCAssignToStack<8, 8>>
143+
CCIfType<[i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>
142144
]>;
143145

144146
//===----------------------------------------------------------------------===//

llvm/lib/Target/SystemZ/SystemZISelLowering.cpp

+58-4
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,13 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
711711
setOperationAction(ISD::BITCAST, MVT::f32, Custom);
712712
}
713713

714+
// Expand FP16 <=> FP32 conversions to libcalls and handle FP16 loads and
715+
// stores in GPRs.
716+
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);
717+
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
718+
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
719+
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
720+
714721
// VASTART and VACOPY need to deal with the SystemZ-specific varargs
715722
// structure, but VAEND is a no-op.
716723
setOperationAction(ISD::VASTART, MVT::Other, Custom);
@@ -784,6 +791,20 @@ bool SystemZTargetLowering::useSoftFloat() const {
784791
return Subtarget.hasSoftFloat();
785792
}
786793

794+
MVT SystemZTargetLowering::getRegisterTypeForCallingConv(
795+
LLVMContext &Context, CallingConv::ID CC,
796+
EVT VT) const {
797+
// 128-bit single-element vector types are passed like other vectors,
798+
// not like their element type.
799+
if (VT.isVector() && VT.getSizeInBits() == 128 &&
800+
VT.getVectorNumElements() == 1)
801+
return MVT::v16i8;
802+
// Keep f16 so that they can be recognized and handled.
803+
if (VT == MVT::f16)
804+
return MVT::f16;
805+
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
806+
}
807+
787808
EVT SystemZTargetLowering::getSetCCResultType(const DataLayout &DL,
788809
LLVMContext &, EVT VT) const {
789810
if (!VT.isVector())
@@ -1597,6 +1618,15 @@ bool SystemZTargetLowering::splitValueIntoRegisterParts(
15971618
return true;
15981619
}
15991620

1621+
// Convert f16 to f32 (Out-arg).
1622+
if (PartVT == MVT::f16) {
1623+
assert(NumParts == 1 && "");
1624+
SDValue I16Val = DAG.getBitcast(MVT::i16, Val);
1625+
SDValue I32Val = DAG.getAnyExtOrTrunc(I16Val, DL, MVT::i32);
1626+
Parts[0] = DAG.getBitcast(MVT::f32, I32Val);
1627+
return true;
1628+
}
1629+
16001630
return false;
16011631
}
16021632

@@ -1612,6 +1642,18 @@ SDValue SystemZTargetLowering::joinRegisterPartsIntoValue(
16121642
return SDValue();
16131643
}
16141644

1645+
// F32Val holds a f16 value in f32, return it as an f16 (In-arg). The
1646+
// CopyFromReg was made into an f32 as required as FP32 registers are used
1647+
// for arguments, now convert it to f16.
1648+
static SDValue convertF32ToF16(SDValue F32Val, SelectionDAG &DAG,
1649+
const SDLoc &DL) {
1650+
assert(F32Val->getOpcode() == ISD::CopyFromReg &&
1651+
"Only expecting to handle f16 with CopyFromReg here.");
1652+
SDValue I32Val = DAG.getBitcast(MVT::i32, F32Val);
1653+
SDValue I16Val = DAG.getAnyExtOrTrunc(I32Val, DL, MVT::i16);
1654+
return DAG.getBitcast(MVT::f16, I16Val);
1655+
}
1656+
16151657
SDValue SystemZTargetLowering::LowerFormalArguments(
16161658
SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
16171659
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -1651,6 +1693,7 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16511693
NumFixedGPRs += 1;
16521694
RC = &SystemZ::GR64BitRegClass;
16531695
break;
1696+
case MVT::f16:
16541697
case MVT::f32:
16551698
NumFixedFPRs += 1;
16561699
RC = &SystemZ::FP32BitRegClass;
@@ -1675,7 +1718,11 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16751718

16761719
Register VReg = MRI.createVirtualRegister(RC);
16771720
MRI.addLiveIn(VA.getLocReg(), VReg);
1678-
ArgValue = DAG.getCopyFromReg(Chain, DL, VReg, LocVT);
1721+
// Special handling is needed for f16.
1722+
MVT ArgVT = VA.getLocVT() == MVT::f16 ? MVT::f32 : VA.getLocVT();
1723+
ArgValue = DAG.getCopyFromReg(Chain, DL, VReg, ArgVT);
1724+
if (VA.getLocVT() == MVT::f16)
1725+
ArgValue = convertF32ToF16(ArgValue, DAG, DL);
16791726
} else {
16801727
assert(VA.isMemLoc() && "Argument not register or memory");
16811728

@@ -1695,9 +1742,12 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16951742
// from this parameter. Unpromoted ints and floats are
16961743
// passed as right-justified 8-byte values.
16971744
SDValue FIN = DAG.getFrameIndex(FI, PtrVT);
1698-
if (VA.getLocVT() == MVT::i32 || VA.getLocVT() == MVT::f32)
1745+
if (VA.getLocVT() == MVT::i32 || VA.getLocVT() == MVT::f32 ||
1746+
VA.getLocVT() == MVT::f16) {
1747+
unsigned SlotOffs = VA.getLocVT() == MVT::f16 ? 6 : 4;
16991748
FIN = DAG.getNode(ISD::ADD, DL, PtrVT, FIN,
1700-
DAG.getIntPtrConstant(4, DL));
1749+
DAG.getIntPtrConstant(SlotOffs, DL));
1750+
}
17011751
ArgValue = DAG.getLoad(LocVT, DL, Chain, FIN,
17021752
MachinePointerInfo::getFixedStack(MF, FI));
17031753
}
@@ -2120,10 +2170,14 @@ SystemZTargetLowering::LowerCall(CallLoweringInfo &CLI,
21202170
// Copy all of the result registers out of their specified physreg.
21212171
for (CCValAssign &VA : RetLocs) {
21222172
// Copy the value out, gluing the copy to the end of the call sequence.
2173+
// Special handling is needed for f16.
2174+
MVT ArgVT = VA.getLocVT() == MVT::f16 ? MVT::f32 : VA.getLocVT();
21232175
SDValue RetValue = DAG.getCopyFromReg(Chain, DL, VA.getLocReg(),
2124-
VA.getLocVT(), Glue);
2176+
ArgVT, Glue);
21252177
Chain = RetValue.getValue(1);
21262178
Glue = RetValue.getValue(2);
2179+
if (VA.getLocVT() == MVT::f16)
2180+
RetValue = convertF32ToF16(RetValue, DAG, DL);
21272181

21282182
// Convert the value of the return register into the value that's
21292183
// being returned.

0 commit comments

Comments
 (0)