13
13
#include < x86intrin.h>
14
14
#endif // MSHADOW_USE_F16C
15
15
16
+ // This flag dictates rounding for the float2half() routine only (used generally on Windows),
17
+ // not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
18
+ #ifndef MSHADOW_HALF_ROUND_TO_NEAREST
19
+ #define MSHADOW_HALF_ROUND_TO_NEAREST 1
20
+ #endif
21
+
16
22
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
17
23
#define MSHADOW_CUDA_HALF 1
18
24
#include < cuda_fp16.h>
@@ -159,12 +165,18 @@ class MSHADOW_ALIGNED(2) half_t {
159
165
uint32_t ui;
160
166
};
161
167
162
- static int const shift = 13 ;
168
+ static int const fp16FractionBits = 10 ;
169
+ static int const fp32FractionBits = 23 ;
170
+ static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
171
+ static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
172
+ static int const shift = fp32FractionBits - fp16FractionBits; // == 13
163
173
static int const shiftSign = 16 ;
174
+ static int32_t const expAdjust = 127 - 15 ; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
164
175
165
176
static int32_t const infN = 0x7F800000 ; // flt32 infinity
166
- static int32_t const maxN = 0x477FE000 ; // max flt16 normal as a flt32
177
+ static int32_t const maxN = 0x477FFFFF ; // max flt32 that's a flt16 normal after >> by shift
167
178
static int32_t const minN = 0x38800000 ; // min flt16 normal as a flt32
179
+ static int32_t const maxZ = 0x33000000 ; // max fp32 number that's still rounded to zero in fp16
168
180
static int32_t const signN = 0x80000000 ; // flt32 sign bit
169
181
170
182
static int32_t const infC = infN >> shift;
@@ -183,37 +195,91 @@ class MSHADOW_ALIGNED(2) half_t {
183
195
static int32_t const minD = minC - subC - 1 ;
184
196
185
197
MSHADOW_XINLINE uint16_t float2half (const float & value) const {
186
- Bits v, s ;
198
+ Bits v;
187
199
v.f = value;
188
- uint32_t sign = v.si & signN;
189
- v.si ^= sign;
190
- sign >>= shiftSign; // logical shift
191
- s.si = mulN;
192
- s.si = s.f * v.f ; // correct subnormals
193
- v.si ^= (s.si ^ v.si ) & -(minN > v.si );
194
- v.si ^= (infN ^ v.si ) & -((infN > v.si ) & (v.si > maxN));
195
- v.si ^= (nanN ^ v.si ) & -((nanN > v.si ) & (v.si > infN));
196
- v.ui >>= shift; // logical shift
197
- v.si ^= ((v.si - maxD) ^ v.si ) & -(v.si > maxC);
198
- v.si ^= ((v.si - minD) ^ v.si ) & -(v.si > subC);
199
- return v.ui | sign;
200
+ uint32_t sign = v.si & signN; // grab sign bit
201
+ v.si ^= sign; // clear sign bit from v
202
+ sign >>= shiftSign; // logical shift sign to fp16 position
203
+
204
+ if (v.si <= maxZ) {
205
+ // Handle eventual zeros here so that vshift in the denorm branch below stays under 32 (a shift of 32 or more bits is undefined).
206
+ v.ui = 0 ;
207
+ } else if (v.si < minN) {
208
+ // Handle denorms
209
+ uint32_t exp32 = v.ui >> fp32FractionBits;
210
+ int32_t exp16 = exp32 - expAdjust;
211
+ // If exp16 == 0 (just into the denorm range), then the significand should be shifted right by 1.
212
+ // Smaller (so negative) exp16 values should result in greater right shifts.
213
+ uint32_t vshift = 1 - exp16;
214
+ uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
215
+ v.ui = significand >> vshift;
216
+ // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
217
+ // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
218
+ // bits to the right of the lsb are 1000... (including flt32 significand bits
219
+ // that may be lost during the above vshift). The first term below will always
220
+ // be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
221
+ // right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
222
+ // the proper test of the flt32 significand bits, including those lost during the vshift.
223
+ #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
224
+ // Rounding may increase the exponent to 1, but that's OK.
225
+ v.ui += (v.ui & 0x3fff ) != 0x1000 || (significand & 0x7ff ) ? 0x1000 : 0 ;
226
+ #endif
227
+ } else if (v.si <= maxN) {
228
+ // Handle norms
229
+ #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
230
+ // Rounding may increase the exponent, possibly creating an inf, but that's OK.
231
+ v.ui += (v.ui & 0x3fff ) != 0x1000 ? 0x1000 : 0 ;
232
+ #endif
233
+ v.ui -= expAdjust << fp32FractionBits;
234
+ } else if (v.si <= infN) {
235
+ v.si = infN;
236
+ } else if (v.si < nanN) {
237
+ v.si = nanN;
238
+ }
239
+
240
+ v.ui >>= shift;
241
+ return sign | (v.ui & 0x7fff );
200
242
}
201
243
244
+ // Same as above routine, except for addition of volatile keyword
202
245
MSHADOW_XINLINE uint16_t float2half (const volatile float & value) const volatile { // NOLINT (*)
203
- Bits v, s ;
246
+ Bits v;
204
247
v.f = value;
205
- uint32_t sign = v.si & signN;
206
- v.si ^= sign;
207
- sign >>= shiftSign; // logical shift
208
- s.si = mulN;
209
- s.si = s.f * v.f ; // correct subnormals
210
- v.si ^= (s.si ^ v.si ) & -(minN > v.si );
211
- v.si ^= (infN ^ v.si ) & -((infN > v.si ) & (v.si > maxN));
212
- v.si ^= (nanN ^ v.si ) & -((nanN > v.si ) & (v.si > infN));
213
- v.ui >>= shift; // logical shift
214
- v.si ^= ((v.si - maxD) ^ v.si ) & -(v.si > maxC);
215
- v.si ^= ((v.si - minD) ^ v.si ) & -(v.si > subC);
216
- return v.ui | sign;
248
+ uint32_t sign = v.si & signN; // grab sign bit
249
+ v.si ^= sign; // clear sign bit from v
250
+ sign >>= shiftSign; // logical shift sign to fp16 position
251
+
252
+ if (v.si <= maxZ) {
253
+ // Handle eventual zeros here so that vshift in the denorm branch below stays under 32 (a shift of 32 or more bits is undefined).
254
+ v.ui = 0 ;
255
+ } else if (v.si < minN) {
256
+ // Handle denorms
257
+ uint32_t exp32 = v.ui >> fp32FractionBits;
258
+ int32_t exp16 = exp32 - expAdjust;
259
+ // If exp16 == 0 (just into the denorm range), then the significand should be shifted right by 1.
260
+ // Smaller (so negative) exp16 values should result in greater right shifts.
261
+ uint32_t vshift = 1 - exp16;
262
+ uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
263
+ v.ui = significand >> vshift;
264
+ #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
265
+ // Rounding may increase the exponent to 1, but that's OK.
266
+ v.ui += (v.ui & 0x3fff ) != 0x1000 || (significand & 0x7ff ) ? 0x1000 : 0 ;
267
+ #endif
268
+ } else if (v.si <= maxN) {
269
+ // Handle norms
270
+ #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
271
+ // Rounding may increase the exponent, possibly creating an inf, but that's OK.
272
+ v.ui += (v.ui & 0x3fff ) != 0x1000 ? 0x1000 : 0 ;
273
+ #endif
274
+ v.ui -= expAdjust << fp32FractionBits;
275
+ } else if (v.si <= infN) {
276
+ v.si = infN;
277
+ } else if (v.si < nanN) {
278
+ v.si = nanN;
279
+ }
280
+
281
+ v.ui >>= shift;
282
+ return sign | (v.ui & 0x7fff );
217
283
}
218
284
219
285
MSHADOW_XINLINE float half2float (const uint16_t & value) const {
0 commit comments