Commit cf21fa0

Merge pull request #60 from okotaku/feat/pre-compute-embs
[Feature] Pre compute Text Embeddings
2 parents fd1c9c3 + 6167caa commit cf21fa0

11 files changed: +344 −43 lines

.devcontainer/devcontainer.json (+2 −1)

@@ -25,7 +25,8 @@
       "ms-vscode-remote.remote-ssh-edit",
       "ms-vscode.remote-explorer",
       "wayou.vscode-todo-highlight",
-      "Gruntfuggly.todo-tree"
+      "Gruntfuggly.todo-tree",
+      "streetsidesoftware.code-spell-checker"
     ]
   }
 }
configs/_base_/datasets/pokemon_blip_xl_pre_compute.py (+39)

@@ -0,0 +1,39 @@
+train_pipeline = [
+    dict(type='SaveImageShape'),
+    dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
+    dict(type='RandomCrop', size=1024),
+    dict(type='RandomHorizontalFlip', p=0.5),
+    dict(type='ComputeTimeIds'),
+    dict(type='torchvision/ToTensor'),
+    dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
+    dict(
+        type='PackInputs',
+        input_keys=[
+            'img', 'time_ids', 'prompt_embeds', 'pooled_prompt_embeds'
+        ]),
+]
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=2,
+    dataset=dict(
+        type='HFDatasetPreComputeEmbs',
+        dataset='lambdalabs/pokemon-blip-captions',
+        text_hasher='text_pokemon_blip',
+        model='stabilityai/stable-diffusion-xl-base-1.0',
+        pipeline=train_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=True),
+)
+
+val_dataloader = None
+val_evaluator = None
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
+
+custom_hooks = [
+    dict(
+        type='VisualizationHook',
+        prompt=['yoda pokemon'] * 4,
+        height=1024,
+        width=1024),
+    dict(type='SDCheckpointHook')
+]
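In this dataset config, `PackInputs` carries `prompt_embeds` and `pooled_prompt_embeds` rather than raw captions, and the dataloader uses the new `HFDatasetPreComputeEmbs` class. As a quick sanity check, such a config can be loaded with MMEngine's `Config` API (a minimal sketch; the file path is an assumption based on the `_base_` reference in the top-level config further down):

```python
from mmengine.config import Config

# Load the dataset base config and confirm which dataset class and input
# keys the pipeline will produce (path assumes the repo layout above).
cfg = Config.fromfile('configs/_base_/datasets/pokemon_blip_xl_pre_compute.py')
print(cfg.train_dataloader.dataset.type)   # HFDatasetPreComputeEmbs
print(cfg.train_pipeline[-1]['input_keys'])
# ['img', 'time_ids', 'prompt_embeds', 'pooled_prompt_embeds']
```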

configs/stable_diffusion_xl/README.md (+4)

@@ -76,3 +76,7 @@ You can see more details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md)
 #### stable_diffusion_xl_pokemon_blip
 
 ![example1](https://github.com/okotaku/diffengine/assets/24734142/dd04fb22-64fb-4c4f-8164-b8391d94abab)
+
+#### stable_diffusion_xl_pokemon_blip_pre_compute
+
+![example2](https://github.com/okotaku/diffengine/assets/24734142/5da59a56-ce36-48cc-b113-007f8b9faeba)
configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_pre_compute.py (+12)

@@ -0,0 +1,12 @@
+_base_ = [
+    '../_base_/models/stable_diffusion_xl.py',
+    '../_base_/datasets/pokemon_blip_xl_pre_compute.py',
+    '../_base_/schedules/stable_diffusion_xl_50e.py',
+    '../_base_/default_runtime.py'
+]
+
+model = dict(pre_compute_text_embeddings=True)
+
+train_dataloader = dict(batch_size=1)
+
+optim_wrapper_cfg = dict(accumulative_counts=4)  # update every four iterations
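With `batch_size=1` and `accumulative_counts=4`, gradients are accumulated over four forward passes before each optimizer step, giving an effective batch size of 4 at lower peak memory. A minimal standalone sketch of how MMEngine's `OptimWrapper` handles accumulation (toy model and data; not DiffEngine's actual training loop):

```python
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper

model = nn.Linear(4, 1)
optim_wrapper = OptimWrapper(
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    accumulative_counts=4)

for step in range(8):
    loss = model(torch.randn(2, 4)).mean()
    # update_params() backpropagates every step, but only calls
    # optimizer.step() and zero_grad() once every 4 iterations.
    optim_wrapper.update_params(loss)
```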

diffengine/datasets/__init__.py (+5 −2)

@@ -1,7 +1,10 @@
 from .hf_controlnet_datasets import HFControlNetDataset
-from .hf_datasets import HFDataset
+from .hf_datasets import HFDataset, HFDatasetPreComputeEmbs
 from .hf_dreambooth_datasets import HFDreamBoothDataset
 from .samplers import *  # noqa: F401, F403
 from .transforms import *  # noqa: F401, F403
 
-__all__ = ['HFDataset', 'HFDreamBoothDataset', 'HFControlNetDataset']
+__all__ = [
+    'HFDataset', 'HFDreamBoothDataset', 'HFControlNetDataset',
+    'HFDatasetPreComputeEmbs'
+]

diffengine/datasets/hf_datasets.py (+89)

@@ -1,14 +1,22 @@
+import functools
+import gc
 import os
 import random
 from pathlib import Path
 from typing import Optional, Sequence
 
 import numpy as np
+import torch
 from datasets import load_dataset
+from datasets.fingerprint import Hasher
 from mmengine.dataset.base_dataset import Compose
 from PIL import Image
 from torch.utils.data import Dataset
+from transformers import AutoTokenizer
 
+from diffengine.datasets.utils import encode_prompt_sdxl
+from diffengine.models.editors.stable_diffusion_xl.stable_diffusion_xl import \
+    import_model_class_from_model_name_or_path
 from diffengine.registry import DATASETS
 
 Image.MAX_IMAGE_PIXELS = 1000000000

@@ -88,3 +96,84 @@ def __getitem__(self, idx: int) -> dict:
         result = self.pipeline(result)
 
         return result
+
+
+@DATASETS.register_module()
+class HFDatasetPreComputeEmbs(HFDataset):
+    """Dataset for huggingface datasets.
+
+    The difference from HFDataset is:
+    1. pre-computes Text Encoder embeddings to save memory.
+
+    Args:
+        model (str): pretrained model name of stable diffusion xl.
+            Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
+        text_hasher (str): Text embeddings hasher name. Defaults to 'text'.
+        device (str): Device used to compute embeddings. Defaults to 'cuda'.
+        proportion_empty_prompts (float): Probability of replacing a caption
+            with an empty string. Defaults to 0.0.
+    """
+
+    def __init__(self,
+                 *args,
+                 model: str = 'stabilityai/stable-diffusion-xl-base-1.0',
+                 text_hasher: str = 'text',
+                 device: str = 'cuda',
+                 proportion_empty_prompts: float = 0.0,
+                 **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        tokenizer_one = AutoTokenizer.from_pretrained(
+            model, subfolder='tokenizer', use_fast=False)
+        tokenizer_two = AutoTokenizer.from_pretrained(
+            model, subfolder='tokenizer_2', use_fast=False)
+
+        text_encoder_cls_one = import_model_class_from_model_name_or_path(
+            model)
+        text_encoder_cls_two = import_model_class_from_model_name_or_path(
+            model, subfolder='text_encoder_2')
+        text_encoder_one = text_encoder_cls_one.from_pretrained(
+            model, subfolder='text_encoder').to(device)
+        text_encoder_two = text_encoder_cls_two.from_pretrained(
+            model, subfolder='text_encoder_2').to(device)
+
+        new_fingerprint = Hasher.hash(text_hasher)
+        compute_embeddings_fn = functools.partial(
+            encode_prompt_sdxl,
+            text_encoders=[text_encoder_one, text_encoder_two],
+            tokenizers=[tokenizer_one, tokenizer_two],
+            proportion_empty_prompts=proportion_empty_prompts,
+            caption_column=self.caption_column,
+        )
+        self.dataset = self.dataset.map(
+            compute_embeddings_fn,
+            batched=True,
+            new_fingerprint=new_fingerprint)
+
+        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def __getitem__(self, idx: int) -> dict:
+        """Get the idx-th image and data information of dataset after
+        ``self.train_transforms``.
+
+        Args:
+            idx (int): The index of self.data_list.
+
+        Returns:
+            dict: The idx-th image and data information of dataset after
+                ``self.train_transforms``.
+        """
+        data_info = self.dataset[idx]
+        image = data_info[self.image_column]
+        if isinstance(image, str):
+            image = Image.open(os.path.join(self.dataset_name, image))
+        image = image.convert('RGB')
+        result = dict(
+            img=image,
+            prompt_embeds=data_info['prompt_embeds'],
+            pooled_prompt_embeds=data_info['pooled_prompt_embeds'])
+        result = self.pipeline(result)
+
+        return result
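Because the embeddings are computed once in `__init__` via `datasets.map` (cached under the `text_hasher` fingerprint) and only read back in `__getitem__`, both SDXL text encoders can be deleted before training starts. A minimal usage sketch (names taken from the config above; assumes a CUDA device, since embeddings are computed on 'cuda' by default, and a simplified pipeline):

```python
from diffengine.datasets import HFDatasetPreComputeEmbs

# First instantiation runs the text encoders over every caption; later runs
# reuse the cached map() result keyed by text_hasher.
dataset = HFDatasetPreComputeEmbs(
    dataset='lambdalabs/pokemon-blip-captions',
    text_hasher='text_pokemon_blip',
    model='stabilityai/stable-diffusion-xl-base-1.0',
    pipeline=[dict(type='torchvision/ToTensor')])

sample = dataset[0]
# 'prompt_embeds' and 'pooled_prompt_embeds' ship with each sample, so the
# training loop never has to run the text encoders again.
print(sample.keys())
```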

diffengine/datasets/utils.py (+56)

@@ -0,0 +1,56 @@
+import random
+from typing import Dict
+
+import numpy as np
+import torch
+
+
+def encode_prompt_sdxl(batch,
+                       text_encoders,
+                       tokenizers,
+                       proportion_empty_prompts,
+                       caption_column,
+                       is_train: bool = True) -> Dict[str, torch.Tensor]:
+    # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+    prompt_embeds_list = []
+    prompt_batch = batch[caption_column]
+
+    captions = []
+    for caption in prompt_batch:
+        if random.random() < proportion_empty_prompts:
+            captions.append('')
+        elif isinstance(caption, str):
+            captions.append(caption)
+        elif isinstance(caption, (list, np.ndarray)):
+            # take a random caption if there are multiple
+            captions.append(random.choice(caption) if is_train else caption[0])
+
+    with torch.no_grad():
+        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+            text_inputs = tokenizer(
+                captions,
+                padding='max_length',
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors='pt',
+            )
+            text_input_ids = text_inputs.input_ids
+            prompt_embeds = text_encoder(
+                text_input_ids.to(text_encoder.device),
+                output_hidden_states=True,
+            )
+
+            # We are only interested in the pooled output of the final
+            # text encoder
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+            bs_embed, seq_len, _ = prompt_embeds.shape
+            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+            prompt_embeds_list.append(prompt_embeds)
+
+    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+    return {
+        'prompt_embeds': prompt_embeds.cpu(),
+        'pooled_prompt_embeds': pooled_prompt_embeds.cpu()
+    }
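As in the standard SDXL recipe, the penultimate hidden states of both text encoders are concatenated along the feature dimension, and only the second encoder's pooled output is kept. The resulting shapes, assuming the usual SDXL base-checkpoint encoder widths (768 and 1280) and the 77-token context:

```python
# Shape sketch for a batch of B captions (widths are assumptions about the
# stabilityai/stable-diffusion-xl-base-1.0 checkpoint):
#   encoder one hidden_states[-2]: (B, 77, 768)
#   encoder two hidden_states[-2]: (B, 77, 1280)
#   prompt_embeds after concat:    (B, 77, 2048)
#   pooled_prompt_embeds:          (B, 1280)
```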

diffengine/models/editors/stable_diffusion_xl/sdxl_data_preprocessor.py (+7)

@@ -34,4 +34,11 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
         data['inputs']['img'] = torch.stack(data['inputs']['img'])
         data['inputs']['time_ids'] = torch.stack(data['inputs']['time_ids'])
+        # pre-computed text embeddings
+        if 'prompt_embeds' in data['inputs']:
+            data['inputs']['prompt_embeds'] = torch.stack(
+                data['inputs']['prompt_embeds'])
+        if 'pooled_prompt_embeds' in data['inputs']:
+            data['inputs']['pooled_prompt_embeds'] = torch.stack(
+                data['inputs']['pooled_prompt_embeds'])
         return super().forward(data)  # type: ignore
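The preprocessor's only new job is batching the per-sample embedding tensors that the dataloader collected into lists, which is exactly what `torch.stack` does. A self-contained illustration (shapes assume the SDXL base model, as above):

```python
import torch

# Four per-sample prompt embeddings, as the dataloader would collect them.
prompt_embeds = [torch.randn(77, 2048) for _ in range(4)]

batch = torch.stack(prompt_embeds)
print(batch.shape)  # torch.Size([4, 77, 2048])
```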
