14
14
# limitations under the License.
15
15
16
16
import math
17
-
18
17
import numpy as np
19
18
import torch
20
19
import torch .nn as nn
26
25
27
26
from ...models ._const import DEVICE , PLATFORM
28
27
29
-
30
28
logger = setup_logger ()
31
29
32
30
class TorchQuantLinear (BaseQuantLinear ):
@@ -62,9 +60,7 @@ def __init__(
62
60
63
61
self .infeatures = infeatures
64
62
self .outfeatures = outfeatures
65
-
66
63
self .padded_infeatures = infeatures + (- infeatures % group_size )
67
-
68
64
self .bits = bits
69
65
self .group_size = group_size if group_size != - 1 else infeatures
70
66
self .maxq = 2 ** self .bits - 1
@@ -99,7 +95,6 @@ def __init__(
99
95
else :
100
96
self .bias = None
101
97
102
- # is performed by unpacking the weights and using torch.matmul
103
98
if self .bits in [2 , 4 , 8 ]:
104
99
self .wf = torch .tensor (list (range (0 , 32 , self .bits )), dtype = torch .int32 ).unsqueeze (0 )
105
100
elif self .bits == 3 :
@@ -140,77 +135,61 @@ def pack(self, linear, scales, zeros, g_idx=None):
140
135
self .bias = linear .bias .clone ().to (dtype = linear .weight .dtype )
141
136
142
137
intweight = torch .round ((W + scale_zeros [self .g_idx ].T ) / scales [self .g_idx ].T ).to (torch .int )
143
-
144
138
intweight = intweight .t ().contiguous ()
145
139
intweight = intweight .numpy ().astype (np .uint32 )
146
140
147
- i = 0
148
- row = 0
149
141
qweight = np .zeros ((intweight .shape [0 ] // 32 * self .bits , intweight .shape [1 ]), dtype = np .uint32 )
150
- while row < qweight .shape [0 ]:
151
- if self .bits in [2 , 4 , 8 ]:
152
- for j in range (i , i + (32 // self .bits )):
153
- qweight [row ] |= intweight [j ] << (self .bits * (j - i ))
154
- i += 32 // self .bits
155
- row += 1
156
- elif self .bits == 3 :
157
- for j in range (i , i + 10 ):
158
- qweight [row ] |= intweight [j ] << (3 * (j - i ))
159
- i += 10
160
- qweight [row ] |= intweight [i ] << 30
161
- row += 1
162
- qweight [row ] |= (intweight [i ] >> 2 ) & 1
163
- i += 1
164
- for j in range (i , i + 10 ):
165
- qweight [row ] |= intweight [j ] << (3 * (j - i ) + 1 )
166
- i += 10
167
- qweight [row ] |= intweight [i ] << 31
142
+ if self .bits in [2 , 4 , 8 ]:
143
+ bits_div = 32 // self .bits
144
+ for row in range (qweight .shape [0 ]):
145
+ for j in range (bits_div ):
146
+ qweight [row ] |= intweight [row * bits_div + j ] << (self .bits * j )
147
+ elif self .bits == 3 :
148
+ for row in range (qweight .shape [0 ]):
149
+ row_offset = row * 10 # Cache row * 10
150
+ row_offset_plus_10 = row_offset + 10 # Cache row * 10 + 10
151
+ for j in range (10 ):
152
+ qweight [row ] |= intweight [row_offset + j ] << (3 * j )
153
+ qweight [row ] |= intweight [row_offset_plus_10 ] << 30
168
154
row += 1
169
- qweight [row ] |= (intweight [i ] >> 1 ) & 0x3
170
- i += 1
171
- for j in range (i , i + 10 ):
172
- qweight [row ] |= intweight [j ] << (3 * (j - i ) + 2 )
173
- i += 10
155
+ qweight [row ] |= (intweight [row_offset_plus_10 ] >> 2 ) & 1
156
+ for j in range (10 ):
157
+ qweight [row ] |= intweight [row_offset + j ] << (3 * j + 1 )
158
+ qweight [row ] |= intweight [row_offset_plus_10 ] << 31
174
159
row += 1
160
+ qweight [row ] |= (intweight [row_offset_plus_10 ] >> 1 ) & 0x3
161
+ for j in range (10 ):
162
+ qweight [row ] |= intweight [row_offset + j ] << (3 * j + 2 )
175
163
176
- qweight = qweight .astype (np .int32 )
177
- self .qweight = torch .from_numpy (qweight )
164
+ self .qweight = torch .from_numpy (qweight .astype (np .int32 ))
178
165
179
166
zeros = zeros .numpy ().astype (np .uint32 )
180
167
qzeros = np .zeros ((zeros .shape [0 ], zeros .shape [1 ] // 32 * self .bits ), dtype = np .uint32 )
181
- i = 0
182
- col = 0
183
- while col < qzeros .shape [1 ]:
184
- if self .bits in [2 , 4 , 8 ]:
185
- for j in range (i , i + (32 // self .bits )):
186
- qzeros [:, col ] |= zeros [:, j ] << (self .bits * (j - i ))
187
- i += 32 // self .bits
188
- col += 1
189
- elif self .bits == 3 :
190
- for j in range (i , i + 10 ):
191
- qzeros [:, col ] |= zeros [:, j ] << (3 * (j - i ))
192
- i += 10
193
- qzeros [:, col ] |= zeros [:, i ] << 30
194
- col += 1
195
- qzeros [:, col ] |= (zeros [:, i ] >> 2 ) & 1
196
- i += 1
197
- for j in range (i , i + 10 ):
198
- qzeros [:, col ] |= zeros [:, j ] << (3 * (j - i ) + 1 )
199
- i += 10
200
- qzeros [:, col ] |= zeros [:, i ] << 31
168
+ if self .bits in [2 , 4 , 8 ]:
169
+ bits_div = 32 // self .bits
170
+ for col in range (qzeros .shape [1 ]):
171
+ for j in range (bits_div ):
172
+ qzeros [:, col ] |= zeros [:, col * bits_div + j ] << (self .bits * j )
173
+ elif self .bits == 3 :
174
+ for col in range (qzeros .shape [1 ]):
175
+ col_offset = col * 10 # Cache col * 10
176
+ col_offset_plus_10 = col_offset + 10 # Cache col * 10 + 10
177
+ for j in range (10 ):
178
+ qzeros [:, col ] |= zeros [:, col_offset + j ] << (3 * j )
179
+ qzeros [:, col ] |= zeros [:, col_offset_plus_10 ] << 30
201
180
col += 1
202
- qzeros [:, col ] |= (zeros [:, i ] >> 1 ) & 0x3
203
- i += 1
204
- for j in range (i , i + 10 ):
205
- qzeros [:, col ] |= zeros [:, j ] << (3 * (j - i ) + 2 )
206
- i += 10
181
+ qzeros [:, col ] |= (zeros [:, col_offset_plus_10 ] >> 2 ) & 1
182
+ for j in range (10 ):
183
+ qzeros [:, col ] |= zeros [:, col_offset + j ] << (3 * j + 1 )
184
+ qzeros [:, col ] |= zeros [:, col_offset_plus_10 ] << 31
207
185
col += 1
186
+ qzeros [:, col ] |= (zeros [:, col_offset_plus_10 ] >> 1 ) & 0x3
187
+ for j in range (10 ):
188
+ qzeros [:, col ] |= zeros [:, col_offset + j ] << (3 * j + 2 )
208
189
209
- qzeros = qzeros .astype (np .int32 )
210
- self .qzeros = torch .from_numpy (qzeros )
190
+ self .qzeros = torch .from_numpy (qzeros .astype (np .int32 ))
211
191
212
192
def forward (self , x : torch .Tensor ):
213
- # if infeatures is padded, we need to pad the input as well
214
193
if x .size (- 1 ) != self .padded_infeatures :
215
194
x = F .pad (x , (0 , self .padded_infeatures - self .infeatures ))
216
195
@@ -241,6 +220,7 @@ def _empty_gptq_only_weights(self):
241
220
def dequantize_weight (self , num_itr = 1 ):
242
221
if self .wf .device != self .qzeros .device :
243
222
self .wf = self .wf .to (self .qzeros .device )
223
+
244
224
if self .bits in [2 , 4 , 8 ]:
245
225
zeros = torch .bitwise_right_shift (
246
226
torch .unsqueeze (self .qzeros , 2 ).expand (- 1 , - 1 , 32 // self .bits ),
@@ -293,4 +273,4 @@ def dequantize_weight(self, num_itr=1):
293
273
294
274
return weights
295
275
296
- __all__ = ["TorchQuantLinear" ]
276
+ __all__ = ["TorchQuantLinear" ]
0 commit comments