Commit e26763a

Minor changes
1 parent 7994b58 commit e26763a

File tree

3 files changed: +4 -10 lines changed

  .gitignore
  modules/models.py
  modules/text_generation.py

.gitignore (+2 -3)
```diff
@@ -9,6 +9,8 @@ torch-dumps/*
 *pycache*
 */*pycache*
 */*/pycache*
+venv/
+.venv/
 
 settings.json
 img_bot*
```
```diff
@@ -19,6 +21,3 @@ img_me*
 !models/place-your-models-here.txt
 !softprompts/place-your-softprompts-here.txt
 !torch-dumps/place-your-pt-models-here.txt
-
-venv/
-.venv/
```
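The two ignored directories (moved from the bottom of the file up next to the cache patterns) are the conventional names for Python virtual environments, which should not be committed. A minimal sketch of how such a directory comes to exist, using only the standard library; the path ".venv" here is just the common convention:

```python
import venv

# Create a virtual environment in ./.venv -- one of the directories
# the updated .gitignore excludes from version control.
venv.create(".venv", with_pip=True)
```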

modules/models.py (+2 -6)
```diff
@@ -47,17 +47,13 @@ def load_model(model_name):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
-            model = AutoModelForCausalLM.from_pretrained(
-                Path(f"models/{shared.model_name}"),
-                low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16
-            )
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
             if torch.has_mps:
                 device = torch.device('mps')
                 model = model.to(device)
             else:
                 model = model.cuda()
 
-
     # FlexGen
     elif shared.args.flexgen:
         # Initialize environment
```
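The single-line call this hunk introduces performs the same steps as the multi-line version it replaces: load the checkpoint in half precision (bfloat16 when the --bf16 flag is set, float16 otherwise), then move it to an MPS or CUDA device. A minimal standalone sketch of that path, with the hypothetical MODEL_DIR and use_bf16 standing in for shared.model_name and shared.args.bf16:

```python
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM

MODEL_DIR = Path("models/my-model")  # stand-in for models/{shared.model_name}
use_bf16 = False                     # stand-in for shared.args.bf16

# Load in half precision without first materializing a full fp32 copy in RAM.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
)

# Prefer Apple-silicon MPS when available, otherwise fall back to CUDA,
# mirroring the torch.has_mps branch in the diff.
if torch.has_mps:
    model = model.to(torch.device("mps"))
else:
    model = model.cuda()
```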
```diff
@@ -106,7 +102,7 @@ def load_model(model_name):
     # Custom
     else:
         params = {"low_cpu_mem_usage": True}
-        if not shared.args.cpu and not torch.cuda.is_available() and not torch.has_mps:
+        if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
             print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
             shared.args.cpu = True
 
```
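The rewritten condition is a direct application of De Morgan's law: a conjunction of negations equals the negation of a disjunction, which is exactly what not any(...) computes. A quick exhaustive check:

```python
from itertools import product

# (not a) and (not b) and (not c)  <=>  not any((a, b, c))
for a, b, c in product((False, True), repeat=3):
    assert (not a and not b and not c) == (not any((a, b, c)))
print("equivalent for all 8 combinations")
```

One subtle difference: the tuple passed to any() is built eagerly, so torch.cuda.is_available() is evaluated even when shared.args.cpu is set, whereas the chained and short-circuits past it; the resulting truth value is the same either way.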

modules/text_generation.py (-1)
```diff
@@ -39,7 +39,6 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     else:
         return input_ids.cuda()
 
-
 def decode(output_ids):
     # Open Assistant relies on special tokens like <|endoftext|>
     if re.match('(oasst|galactica)-*', shared.model_name.lower()):
```
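The context lines show how decode() decides whether to keep special tokens: re.match anchors at the start of the string, so the pattern asks whether the model name begins with "oasst" or "galactica" (the trailing '-*' matches zero or more hyphens, making the hyphen optional). A small self-contained check with illustrative model names:

```python
import re

# True for names starting with "oasst" or "galactica", False otherwise.
for name in ("oasst-sft-1-pythia-12b", "galactica-6.7b", "llama-13b"):
    print(name, "->", bool(re.match('(oasst|galactica)-*', name.lower())))
```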
