Skip to content

Commit d06a725

Browse files
add default value for backend; fix case where Optimum doesn't pass it (#1334)
* add default value for backend * Update tritonv2.py * Update marlin.py --------- Co-authored-by: Qubitium-ModelCloud <[email protected]>
1 parent a2ac0b0 commit d06a725

File tree

7 files changed

+13
-0
lines changed

7 files changed

+13
-0
lines changed

gptqmodel/nn_modules/qlinear/dynamic_cuda.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020

21+
from ...utils.backend import BACKEND
2122
from ...models._const import DEVICE, PLATFORM
2223
from ...adapter.adapter import Adapter, Lora
2324
from ...nn_modules.qlinear.torch import TorchQuantLinear
@@ -80,6 +81,7 @@ def __init__(
8081
out_features=out_features,
8182
bias=bias,
8283
pack_dtype=pack_dtype,
84+
backend=kwargs.pop("backend", BACKEND.CUDA),
8385
adapter=adapter,
8486
**kwargs)
8587

gptqmodel/nn_modules/qlinear/exllama.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import torch
2323

24+
from ...utils.backend import BACKEND
2425
from ...adapter.adapter import Adapter, Lora
2526
from ...models._const import DEVICE, PLATFORM
2627
from ...nn_modules.qlinear import BaseQuantLinear
@@ -103,6 +104,7 @@ def __init__(
103104
out_features=out_features,
104105
bias=bias,
105106
pack_dtype=pack_dtype,
107+
backend=kwargs.pop("backend", BACKEND.EXLLAMA_V1),
106108
adapter=adapter,
107109
register_buffers=True,
108110
register_buffers_in_features=in_features,

gptqmodel/nn_modules/qlinear/exllamav2.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import torch
2222

23+
from ...utils.backend import BACKEND
2324
from ...adapter.adapter import Adapter, Lora
2425
from ...models._const import DEVICE, PLATFORM
2526
from ...nn_modules.qlinear import BaseQuantLinear
@@ -176,6 +177,7 @@ def __init__(
176177
out_features=out_features,
177178
bias=bias,
178179
pack_dtype=pack_dtype,
180+
backend=kwargs.pop("backend", BACKEND.EXLLAMA_V2),
179181
adapter=adapter,
180182
register_buffers=True,
181183
register_buffers_in_features=in_features,

gptqmodel/nn_modules/qlinear/ipex.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020

21+
from ...utils.backend import BACKEND
2122
from ...utils.logger import setup_logger
2223
from ...utils.torch import torch_compile
2324
from ...adapter.adapter import Adapter, Lora
@@ -127,6 +128,7 @@ def __init__(
127128
pack_dtype=pack_dtype,
128129
adapter=adapter,
129130
register_buffers=True,
131+
backend=kwargs.pop("backend", BACKEND.IPEX),
130132
**kwargs)
131133

132134
self.weight_dtype = torch.float16

gptqmodel/nn_modules/qlinear/marlin.py

+1
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def __init__(
216216
out_features=out_features,
217217
bias=bias,
218218
pack_dtype=pack_dtype,
219+
backend=kwargs.pop("backend", BACKEND.MARLIN),
219220
adapter=adapter,
220221
register_buffers=False,
221222
**kwargs)

gptqmodel/nn_modules/qlinear/torch.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch.nn as nn
2020
from transformers import PreTrainedModel
2121

22+
from ...utils.backend import BACKEND
2223
from ...models._const import DEVICE, PLATFORM
2324
from ...utils.torch import torch_compile
2425
from ...adapter.adapter import Adapter, Lora
@@ -67,6 +68,7 @@ def __init__(
6768
out_features=out_features,
6869
bias=bias,
6970
pack_dtype=pack_dtype,
71+
backend=kwargs.pop("backend", BACKEND.TORCH),
7072
adapter=adapter,
7173
register_buffers=True,
7274
**kwargs)

gptqmodel/nn_modules/qlinear/tritonv2.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
from packaging import version
2121

22+
from ...utils.backend import BACKEND
2223
from ...models._const import DEVICE, PLATFORM
2324
from ...utils.logger import setup_logger
2425
from ...adapter.adapter import Adapter, Lora
@@ -95,6 +96,7 @@ def __init__(
9596
out_features=out_features,
9697
bias=bias,
9798
pack_dtype=pack_dtype,
99+
backend=kwargs.pop("backend", BACKEND.TRITON),
98100
adapter=adapter,
99101
register_buffers=True,
100102
**kwargs)

0 commit comments

Comments
 (0)