Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.

Commit 3dc8081

Browse files
DickJC123 authored and eric-haibin-lin committed
Add round-to-nearest-even rounding to float2half(). (#368)
* Add round-to-nearest-even to float2half(). Disable with -DMSHADOW_HALF_ROUND_TO_EVEN=0 build.
  [Note: the guard macro in the code shown below is named MSHADOW_HALF_ROUND_TO_NEAREST.]
* Correct #if guard name.
* Fix lint.
* Minor syntax fix for MXNet CI.
1 parent 6dc04f7 commit 3dc8081

File tree

1 file changed

+94
-28
lines changed

1 file changed

+94
-28
lines changed

mshadow/half.h

Lines changed: 94 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,12 @@
1313
#include <x86intrin.h>
1414
#endif // MSHADOW_USE_F16C
1515

16+
// This flag dictates rounding for the float2half() routine only (used generally on Windows),
17+
// not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
18+
#ifndef MSHADOW_HALF_ROUND_TO_NEAREST
19+
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
20+
#endif
21+
1622
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
1723
#define MSHADOW_CUDA_HALF 1
1824
#include <cuda_fp16.h>
@@ -159,12 +165,18 @@ class MSHADOW_ALIGNED(2) half_t {
159165
uint32_t ui;
160166
};
161167

162-
static int const shift = 13;
168+
static int const fp16FractionBits = 10;
169+
static int const fp32FractionBits = 23;
170+
static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
171+
static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
172+
static int const shift = fp32FractionBits - fp16FractionBits; // == 13
163173
static int const shiftSign = 16;
174+
static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
164175

165176
static int32_t const infN = 0x7F800000; // flt32 infinity
166-
static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32
177+
static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
167178
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
179+
static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
168180
static int32_t const signN = 0x80000000; // flt32 sign bit
169181

170182
static int32_t const infC = infN >> shift;
@@ -183,37 +195,91 @@ class MSHADOW_ALIGNED(2) half_t {
183195
static int32_t const minD = minC - subC - 1;
184196

185197
MSHADOW_XINLINE uint16_t float2half(const float& value) const {
186-
Bits v, s;
198+
Bits v;
187199
v.f = value;
188-
uint32_t sign = v.si & signN;
189-
v.si ^= sign;
190-
sign >>= shiftSign; // logical shift
191-
s.si = mulN;
192-
s.si = s.f * v.f; // correct subnormals
193-
v.si ^= (s.si ^ v.si) & -(minN > v.si);
194-
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
195-
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
196-
v.ui >>= shift; // logical shift
197-
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
198-
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
199-
return v.ui | sign;
200+
uint32_t sign = v.si & signN; // grab sign bit
201+
v.si ^= sign; // clear sign bit from v
202+
sign >>= shiftSign; // logical shift sign to fp16 position
203+
204+
if (v.si <= maxZ) {
205+
// Handle eventual zeros here; this also keeps vshift below at 11 or less, so the shift is always well defined.
206+
v.ui = 0;
207+
} else if (v.si < minN) {
208+
// Handle denorms
209+
uint32_t exp32 = v.ui >> fp32FractionBits;
210+
int32_t exp16 = exp32 - expAdjust;
211+
// If exp16 == 0 (just into the denorm range), then the significand should be shifted right by 1.
212+
// Smaller (so negative) exp16 values should result in greater right shifts.
213+
uint32_t vshift = 1 - exp16;
214+
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
215+
v.ui = significand >> vshift;
216+
// The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
217+
// when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
218+
// bits to the right of the lsb are 1000... (including flt32 significand bits
219+
// that may be lost during the above vshift). The first term below will always
220+
// be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
221+
// right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
222+
// the proper test of the flt32 significand bits, including those lost during the vshift.
223+
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
224+
// Rounding may increase the exponent to 1, but that's OK.
225+
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
226+
#endif
227+
} else if (v.si <= maxN) {
228+
// Handle norms
229+
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
230+
// Rounding may increase the exponent, possibly creating an inf, but that's OK.
231+
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
232+
#endif
233+
v.ui -= expAdjust << fp32FractionBits;
234+
} else if (v.si <= infN) {
235+
v.si = infN;
236+
} else if (v.si < nanN) {
237+
v.si = nanN;
238+
}
239+
240+
v.ui >>= shift;
241+
return sign | (v.ui & 0x7fff);
200242
}
201243

244+
// Same as above routine, except for addition of volatile keyword
202245
MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
203-
Bits v, s;
246+
Bits v;
204247
v.f = value;
205-
uint32_t sign = v.si & signN;
206-
v.si ^= sign;
207-
sign >>= shiftSign; // logical shift
208-
s.si = mulN;
209-
s.si = s.f * v.f; // correct subnormals
210-
v.si ^= (s.si ^ v.si) & -(minN > v.si);
211-
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
212-
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
213-
v.ui >>= shift; // logical shift
214-
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
215-
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
216-
return v.ui | sign;
248+
uint32_t sign = v.si & signN; // grab sign bit
249+
v.si ^= sign; // clear sign bit from v
250+
sign >>= shiftSign; // logical shift sign to fp16 position
251+
252+
if (v.si <= maxZ) {
253+
// Handle eventual zeros here; this also keeps vshift below at 11 or less, so the shift is always well defined.
254+
v.ui = 0;
255+
} else if (v.si < minN) {
256+
// Handle denorms
257+
uint32_t exp32 = v.ui >> fp32FractionBits;
258+
int32_t exp16 = exp32 - expAdjust;
259+
// If exp16 == 0 (just into the denorm range), then the significand should be shifted right by 1.
260+
// Smaller (so negative) exp16 values should result in greater right shifts.
261+
uint32_t vshift = 1 - exp16;
262+
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
263+
v.ui = significand >> vshift;
264+
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
265+
// Rounding may increase the exponent to 1, but that's OK.
266+
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
267+
#endif
268+
} else if (v.si <= maxN) {
269+
// Handle norms
270+
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
271+
// Rounding may increase the exponent, possibly creating an inf, but that's OK.
272+
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
273+
#endif
274+
v.ui -= expAdjust << fp32FractionBits;
275+
} else if (v.si <= infN) {
276+
v.si = infN;
277+
} else if (v.si < nanN) {
278+
v.si = nanN;
279+
}
280+
281+
v.ui >>= shift;
282+
return sign | (v.ui & 0x7fff);
217283
}
218284

219285
MSHADOW_XINLINE float half2float(const uint16_t& value) const {

0 commit comments

Comments
 (0)