Commit af27b21

optimize pack (#1153)

* optimize torch packer
* revert change
* optimize exllama pack
* optimize triton pack
* optimize ipex pack
* clean test
* update test codes
* add test_pack_speed

Co-authored-by: CSY <[email protected]>

1 parent e478d58 · commit af27b21

5 files changed, +127 -137 lines changed
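
All four kernels below receive the same refactor: the `while` loops that packed quantized values by threading two hand-advanced counters (a destination index `row`/`col` plus a flat source index `i`) become `for` loops that derive `i` from the loop variable, so every iteration is independent and the per-iteration counter bookkeeping disappears. A minimal standalone sketch of the before/after for the 2/4/8-bit case (shapes and names here are illustrative, not taken from the repo):

    import numpy as np

    bits = 4
    pack_factor = 32 // bits  # eight 4-bit values per 32-bit word

    rng = np.random.default_rng(0)
    intweight = rng.integers(0, 2**bits, size=(64, 16), dtype=np.uint32)

    # Before: while loop, counters advanced by hand.
    qweight_old = np.zeros((intweight.shape[0] // pack_factor, intweight.shape[1]), dtype=np.uint32)
    i = row = 0
    while row < qweight_old.shape[0]:
        for j in range(i, i + pack_factor):
            qweight_old[row] |= intweight[j] << (bits * (j - i))
        i += pack_factor
        row += 1

    # After: for loop, flat source index derived from the loop variable.
    qweight_new = np.zeros_like(qweight_old)
    for row in range(qweight_new.shape[0]):
        i = row * pack_factor
        for j in range(pack_factor):
            qweight_new[row] |= intweight[i + j] << (bits * j)

    assert np.array_equal(qweight_old, qweight_new)  # same bits, simpler loop

Each output row now depends only on `row`, so the loop body can be reasoned about (and in principle parallelized) independently.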

gptqmodel/nn_modules/qlinear/exllama.py (+8, -15)

@@ -180,28 +180,21 @@ def pack(self, linear, scales, zeros, g_idx=None):
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
 
-        i = 0
-        row = 0
         qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
-        while row < qweight.shape[0]:
-            for j in range(i, i + (32 // self.bits)):
-                qweight[row] |= intweight[j] << (self.bits * (j - i))
-            i += 32 // self.bits
-            row += 1
+        for row in range(qweight.shape[0]):
+            i = row * (32 // self.bits)
+            for j in range(32 // self.bits):
+                qweight[row] |= intweight[i + j] << (self.bits * j)
 
         qweight = qweight.astype(np.int32)
         self.qweight = torch.from_numpy(qweight)
 
         zeros = zeros.numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
-        i = 0
-        col = 0
-        while col < qzeros.shape[1]:
-            for j in range(i, i + (32 // self.bits)):
-                qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
-            i += 32 // self.bits
-            col += 1
-
+        for col in range(qzeros.shape[1]):
+            i = col * (32 // self.bits)
+            for j in range(32 // self.bits):
+                qzeros[:, col] |= zeros[:, i + j] << (self.bits * j)
 
         qzeros = qzeros.astype(np.int32)
         self.qzeros = torch.from_numpy(qzeros)
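
The `qzeros` loop is the column-wise twin of the `qweight` loop above. For intuition, the whole loop is equivalent to a reshape-shift-OR over groups of `32 // bits` adjacent columns; a sketch of that equivalence follows (the vectorized form is mine for illustration, not code from the commit):

    import numpy as np

    bits = 4
    pack_factor = 32 // bits

    rng = np.random.default_rng(1)
    zeros = rng.integers(0, 2**bits, size=(8, 32), dtype=np.uint32)

    # Loop form, as in the new pack():
    qzeros_loop = np.zeros((zeros.shape[0], zeros.shape[1] // pack_factor), dtype=np.uint32)
    for col in range(qzeros_loop.shape[1]):
        i = col * pack_factor
        for j in range(pack_factor):
            qzeros_loop[:, col] |= zeros[:, i + j] << (bits * j)

    # Vectorized form: group adjacent columns, shift each lane into place, OR-reduce.
    shifts = np.arange(pack_factor, dtype=np.uint32) * np.uint32(bits)
    grouped = zeros.reshape(zeros.shape[0], -1, pack_factor)  # (rows, out_cols, lanes)
    qzeros_vec = np.bitwise_or.reduce(grouped << shifts, axis=-1)

    assert np.array_equal(qzeros_loop, qzeros_vec)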

gptqmodel/nn_modules/qlinear/ipex.py (+8, -15)

@@ -202,32 +202,25 @@ def pack(self, linear, scales, zeros, g_idx=None):
             self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)
 
         intweight = torch.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(torch.int)
-
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
 
-        i = 0
-        row = 0
         qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
-        while row < qweight.shape[0]:
-            for j in range(i, i + (32 // self.bits)):
-                qweight[row] |= intweight[j] << (self.bits * (j - i))
-            i += 32 // self.bits
-            row += 1
+        for row in range(qweight.shape[0]):
+            i = row * (32 // self.bits)
+            for j in range(32 // self.bits):
+                qweight[row] |= intweight[i + j] << (self.bits * j)
 
         qweight = qweight.astype(np.int32)
         self.qweight = torch.from_numpy(qweight)
 
         zeros -= 1
         zeros = zeros.numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
-        i = 0
-        col = 0
-        while col < qzeros.shape[1]:
-            for j in range(i, i + (32 // self.bits)):
-                qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
-            i += 32 // self.bits
-            col += 1
+        for col in range(qzeros.shape[1]):
+            i = col * (32 // self.bits)
+            for j in range(32 // self.bits):
+                qzeros[:, col] |= zeros[:, i + j] << (self.bits * j)
 
         qzeros = qzeros.astype(np.int32)
         self.qzeros = torch.from_numpy(qzeros)
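
The ipex version gets the identical loop rewrite; the only behavioral difference from the exllama file is the pre-existing `zeros -= 1` line (the GPTQ convention of storing zero-points offset by one, which the dequantization path compensates for). A tiny round trip on a single packed word makes the bit lanes concrete (illustrative only, assuming 4-bit):

    import numpy as np

    bits = 4
    pack_factor = 32 // bits
    vals = np.array([3, 7, 0, 15, 1, 2, 9, 4], dtype=np.uint32)  # eight 4-bit values

    # Pack: value j occupies bit positions [bits * j, bits * (j + 1)).
    word = np.uint32(0)
    for j in range(pack_factor):
        word |= vals[j] << np.uint32(bits * j)

    # Unpack: shift each lane back down and mask off `bits` bits.
    mask = np.uint32((1 << bits) - 1)
    unpacked = np.array([(word >> np.uint32(bits * j)) & mask for j in range(pack_factor)],
                        dtype=np.uint32)
    assert np.array_equal(vals, unpacked)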

gptqmodel/nn_modules/qlinear/torch.py (+42, -62)

@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import math
-
 import numpy as np
 import torch
 import torch.nn as nn
@@ -26,7 +25,6 @@
 
 from ...models._const import DEVICE, PLATFORM
 
-
 logger = setup_logger()
 
 class TorchQuantLinear(BaseQuantLinear):
@@ -62,9 +60,7 @@ def __init__(
 
         self.infeatures = infeatures
         self.outfeatures = outfeatures
-
         self.padded_infeatures = infeatures + (-infeatures % group_size)
-
         self.bits = bits
         self.group_size = group_size if group_size != -1 else infeatures
         self.maxq = 2**self.bits - 1
@@ -99,7 +95,6 @@ def __init__(
         else:
             self.bias = None
 
-        # is performed by unpacking the weights and using torch.matmul
         if self.bits in [2, 4, 8]:
             self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
         elif self.bits == 3:
@@ -140,77 +135,61 @@ def pack(self, linear, scales, zeros, g_idx=None):
             self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)
 
         intweight = torch.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(torch.int)
-
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
 
-        i = 0
-        row = 0
         qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
-        while row < qweight.shape[0]:
-            if self.bits in [2, 4, 8]:
-                for j in range(i, i + (32 // self.bits)):
-                    qweight[row] |= intweight[j] << (self.bits * (j - i))
-                i += 32 // self.bits
-                row += 1
-            elif self.bits == 3:
-                for j in range(i, i + 10):
-                    qweight[row] |= intweight[j] << (3 * (j - i))
-                i += 10
-                qweight[row] |= intweight[i] << 30
-                row += 1
-                qweight[row] |= (intweight[i] >> 2) & 1
-                i += 1
-                for j in range(i, i + 10):
-                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
-                i += 10
-                qweight[row] |= intweight[i] << 31
+        if self.bits in [2, 4, 8]:
+            bits_div = 32 // self.bits
+            for row in range(qweight.shape[0]):
+                for j in range(bits_div):
+                    qweight[row] |= intweight[row * bits_div + j] << (self.bits * j)
+        elif self.bits == 3:
+            for row in range(qweight.shape[0]):
+                row_offset = row * 10  # Cache row * 10
+                row_offset_plus_10 = row_offset + 10  # Cache row * 10 + 10
+                for j in range(10):
+                    qweight[row] |= intweight[row_offset + j] << (3 * j)
+                qweight[row] |= intweight[row_offset_plus_10] << 30
                 row += 1
-                qweight[row] |= (intweight[i] >> 1) & 0x3
-                i += 1
-                for j in range(i, i + 10):
-                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
-                i += 10
+                qweight[row] |= (intweight[row_offset_plus_10] >> 2) & 1
+                for j in range(10):
+                    qweight[row] |= intweight[row_offset + j] << (3 * j + 1)
+                qweight[row] |= intweight[row_offset_plus_10] << 31
                 row += 1
+                qweight[row] |= (intweight[row_offset_plus_10] >> 1) & 0x3
+                for j in range(10):
+                    qweight[row] |= intweight[row_offset + j] << (3 * j + 2)
 
-        qweight = qweight.astype(np.int32)
-        self.qweight = torch.from_numpy(qweight)
+        self.qweight = torch.from_numpy(qweight.astype(np.int32))
 
         zeros = zeros.numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
-        i = 0
-        col = 0
-        while col < qzeros.shape[1]:
-            if self.bits in [2, 4, 8]:
-                for j in range(i, i + (32 // self.bits)):
-                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
-                i += 32 // self.bits
-                col += 1
-            elif self.bits == 3:
-                for j in range(i, i + 10):
-                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
-                i += 10
-                qzeros[:, col] |= zeros[:, i] << 30
-                col += 1
-                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
-                i += 1
-                for j in range(i, i + 10):
-                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
-                i += 10
-                qzeros[:, col] |= zeros[:, i] << 31
+        if self.bits in [2, 4, 8]:
+            bits_div = 32 // self.bits
+            for col in range(qzeros.shape[1]):
+                for j in range(bits_div):
+                    qzeros[:, col] |= zeros[:, col * bits_div + j] << (self.bits * j)
+        elif self.bits == 3:
+            for col in range(qzeros.shape[1]):
+                col_offset = col * 10  # Cache col * 10
+                col_offset_plus_10 = col_offset + 10  # Cache col * 10 + 10
+                for j in range(10):
+                    qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j)
+                qzeros[:, col] |= zeros[:, col_offset_plus_10] << 30
                 col += 1
-                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
-                i += 1
-                for j in range(i, i + 10):
-                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
-                i += 10
+                qzeros[:, col] |= (zeros[:, col_offset_plus_10] >> 2) & 1
+                for j in range(10):
+                    qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j + 1)
+                qzeros[:, col] |= zeros[:, col_offset_plus_10] << 31
                 col += 1
+                qzeros[:, col] |= (zeros[:, col_offset_plus_10] >> 1) & 0x3
+                for j in range(10):
+                    qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j + 2)
 
-        qzeros = qzeros.astype(np.int32)
-        self.qzeros = torch.from_numpy(qzeros)
+        self.qzeros = torch.from_numpy(qzeros.astype(np.int32))
 
     def forward(self, x: torch.Tensor):
-        # if infeatures is padded, we need to pad the input as well
         if x.size(-1) != self.padded_infeatures:
             x = F.pad(x, (0, self.padded_infeatures - self.infeatures))
@@ -241,6 +220,7 @@ def _empty_gptq_only_weights(self):
     def dequantize_weight(self, num_itr=1):
         if self.wf.device != self.qzeros.device:
             self.wf = self.wf.to(self.qzeros.device)
+
         if self.bits in [2, 4, 8]:
             zeros = torch.bitwise_right_shift(
                 torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
@@ -293,4 +273,4 @@ def dequantize_weight(self, num_itr=1):
 
         return weights
 
-__all__ = ["TorchQuantLinear"]
+__all__ = ["TorchQuantLinear"]
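
torch.py is the only kernel here with a 3-bit path, and it is the awkward one: 32 three-bit values fill exactly three 32-bit words (96 bits), and the 11th and 22nd values straddle word boundaries, which is why each packed row ends with a shifted tail and the next row begins by picking up the leftover high bits. A standalone sketch of that layout for one 3-row group, following the original while-loop's indexing (illustrative, not the repo's code):

    import numpy as np

    def pack3(vals: np.ndarray) -> np.ndarray:
        """Pack 32 three-bit values (uint32) into three 32-bit words."""
        assert vals.shape == (32,)
        w = np.zeros(3, dtype=np.uint32)
        for j in range(10):                      # values 0..9 -> bits 0..29 of word 0
            w[0] |= vals[j] << (3 * j)
        w[0] |= vals[10] << 30                   # low 2 bits of value 10
        w[1] |= (vals[10] >> 2) & 1              # high bit of value 10
        for j in range(10):                      # values 11..20 -> bits 1..30 of word 1
            w[1] |= vals[11 + j] << (3 * j + 1)
        w[1] |= vals[21] << 31                   # low bit of value 21
        w[2] |= (vals[21] >> 1) & 0x3            # high 2 bits of value 21
        for j in range(10):                      # values 22..31 -> bits 2..31 of word 2
            w[2] |= vals[22 + j] << (3 * j + 2)
        return w

    vals = np.arange(32, dtype=np.uint32) % 8
    packed = pack3(vals)
    # Reassemble value 10 from its two halves to verify the straddle:
    v10 = (packed[0] >> 30) | ((packed[1] & 1) << 2)
    assert v10 == vals[10]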

gptqmodel/nn_modules/qlinear/tritonv2.py (+8, -16)

@@ -139,7 +139,6 @@ def post_init(self):
             self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32,
                                       device=self.g_idx.device)
 
-
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
         if isinstance(linear, nn.Conv2d):
@@ -157,31 +156,24 @@ def pack(self, linear, scales, zeros, g_idx=None):
             self.bias = linear.bias.clone().half()
 
         intweight = torch.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(torch.int)
-
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
 
-        i = 0
-        row = 0
         qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
-        while row < qweight.shape[0]:
-            for j in range(i, i + (32 // self.bits)):
-                qweight[row] |= intweight[j] << (self.bits * (j - i))
-            i += 32 // self.bits
-            row += 1
+        for row in range(qweight.shape[0]):
+            i = row * (32 // self.bits)
+            for j in range(32 // self.bits):
+                qweight[row] |= intweight[i + j] << (self.bits * j)
 
         qweight = qweight.astype(np.int32)
         self.qweight = torch.from_numpy(qweight)
 
         zeros = zeros.numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
-        i = 0
-        col = 0
-        while col < qzeros.shape[1]:
-            for j in range(i, i + (32 // self.bits)):
-                qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
-            i += 32 // self.bits
-            col += 1
+        for col in range(qzeros.shape[1]):
+            i = col * (32 // self.bits)
+            for j in range(32 // self.bits):
+                qzeros[:, col] |= zeros[:, i + j] << (self.bits * j)
 
         qzeros = qzeros.astype(np.int32)
         self.qzeros = torch.from_numpy(qzeros)
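
The commit message also mentions an added `test_pack_speed`; that file is not shown on this page, but a micro-benchmark in its spirit could look like the following (a hypothetical sketch, not the repo's test; the vectorized reference exists only to check the packed bits):

    import time
    import numpy as np

    def pack_rows(intweight: np.ndarray, bits: int) -> np.ndarray:
        """New-style row packer, matching the for-loops in the diffs above."""
        qweight = np.zeros((intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=np.uint32)
        pf = 32 // bits
        for row in range(qweight.shape[0]):
            i = row * pf
            for j in range(pf):
                qweight[row] |= intweight[i + j] << (bits * j)
        return qweight

    def test_pack_speed():
        bits = 4
        rng = np.random.default_rng(0)
        intweight = rng.integers(0, 2**bits, size=(4096, 4096), dtype=np.uint32)

        t0 = time.perf_counter()
        qweight = pack_rows(intweight, bits)
        print(f"pack_rows: {time.perf_counter() - t0:.3f}s")

        # Bit-exact check against a vectorized reference.
        shifts = np.arange(32 // bits, dtype=np.uint32) * np.uint32(bits)
        grouped = intweight.reshape(-1, 32 // bits, intweight.shape[1])
        ref = np.bitwise_or.reduce(grouped << shifts[None, :, None], axis=1)
        assert np.array_equal(qweight, ref)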
