File tree: gptqmodel/nn_modules/qlinear
7 files changed, +13 −0 lines changed
File 1 (default backend: BACKEND.CUDA):

@@ -18,6 +18,7 @@

 import torch

+from ...utils.backend import BACKEND
 from ...models._const import DEVICE, PLATFORM
 from ...adapter.adapter import Adapter, Lora
 from ...nn_modules.qlinear.torch import TorchQuantLinear
@@ -80,6 +81,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.CUDA),
             adapter=adapter,
             **kwargs)
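Every file below applies the same one-line pattern: pop "backend" out of **kwargs with a per-module default before forwarding the rest to the parent constructor. A minimal sketch of why the pop matters, using simplified stand-in classes (the real BaseQuantLinear and BACKEND enum live in gptqmodel; only the pattern is shown):

from enum import Enum

class BACKEND(Enum):
    # stand-in for gptqmodel.utils.backend.BACKEND; only two members shown
    TORCH = "torch"
    CUDA = "cuda"

class BaseQuantLinear:
    def __init__(self, backend, **kwargs):
        self.backend = backend

class CudaQuantLinear(BaseQuantLinear):
    def __init__(self, **kwargs):
        # pop() removes "backend" from kwargs (falling back to this module's
        # default) before **kwargs is forwarded, so the parent never sees the
        # keyword twice. Using kwargs.get() instead would leave "backend" in
        # kwargs and raise: TypeError: __init__() got multiple values for
        # keyword argument 'backend'.
        super().__init__(backend=kwargs.pop("backend", BACKEND.CUDA), **kwargs)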
File 2 (default backend: BACKEND.EXLLAMA_V1):

@@ -21,6 +21,7 @@

 import torch

+from ...utils.backend import BACKEND
 from ...adapter.adapter import Adapter, Lora
 from ...models._const import DEVICE, PLATFORM
 from ...nn_modules.qlinear import BaseQuantLinear
@@ -103,6 +104,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.EXLLAMA_V1),
             adapter=adapter,
             register_buffers=True,
             register_buffers_in_features=in_features,
File 3 (default backend: BACKEND.EXLLAMA_V2):

@@ -20,6 +20,7 @@

 import torch

+from ...utils.backend import BACKEND
 from ...adapter.adapter import Adapter, Lora
 from ...models._const import DEVICE, PLATFORM
 from ...nn_modules.qlinear import BaseQuantLinear
@@ -176,6 +177,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.EXLLAMA_V2),
             adapter=adapter,
             register_buffers=True,
             register_buffers_in_features=in_features,
File 4 (default backend: BACKEND.IPEX):

@@ -18,6 +18,7 @@

 import torch

+from ...utils.backend import BACKEND
 from ...utils.logger import setup_logger
 from ...utils.torch import torch_compile
 from ...adapter.adapter import Adapter, Lora
@@ -127,6 +128,7 @@ def __init__(
             pack_dtype=pack_dtype,
             adapter=adapter,
             register_buffers=True,
+            backend=kwargs.pop("backend", BACKEND.IPEX),
             **kwargs)

         self.weight_dtype = torch.float16
File 5 (default backend: BACKEND.MARLIN; the diff shows no import hunk for this file):

@@ -216,6 +216,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.MARLIN),
             adapter=adapter,
             register_buffers=False,
             **kwargs)
File 6 (default backend: BACKEND.TORCH):

@@ -19,6 +19,7 @@
 import torch.nn as nn
 from transformers import PreTrainedModel

+from ...utils.backend import BACKEND
 from ...models._const import DEVICE, PLATFORM
 from ...utils.torch import torch_compile
 from ...adapter.adapter import Adapter, Lora
@@ -67,6 +68,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.TORCH),
             adapter=adapter,
             register_buffers=True,
             **kwargs)
File 7 (default backend: BACKEND.TRITON):

@@ -19,6 +19,7 @@
 import torch
 from packaging import version

+from ...utils.backend import BACKEND
 from ...models._const import DEVICE, PLATFORM
 from ...utils.logger import setup_logger
 from ...adapter.adapter import Adapter, Lora
@@ -95,6 +96,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.TRITON),
             adapter=adapter,
             register_buffers=True,
             **kwargs)
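Hypothetical usage of the sketch above: each quantized-linear subclass now carries a sensible default for its own kernel, while an explicit backend keyword from the caller still takes precedence.

layer = CudaQuantLinear()                       # layer.backend is BACKEND.CUDA
layer = CudaQuantLinear(backend=BACKEND.TORCH)  # explicit override is honored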