
Commit 9ef504e

Merge pull request #116 from okotaku/feat/kandinsky
[Feature] Support Kandinsky v2.2
2 parents (0ed5f89 + e72def0), commit 9ef504e

27 files changed: +1483 −15 lines

Dockerfile (+1 −1)

@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:23.11-py3
+FROM nvcr.io/nvidia/pytorch:23.12-py3

 RUN apt update -y && apt install -y \
     git tmux

README.md (+9)

@@ -141,6 +141,7 @@ For detailed user guides and advanced guides, please refer to our [Documentation
 - [Run PixArt-α](https://diffengine.readthedocs.io/en/latest/run_guides/run_pixart_alpha.html)
 - [Run PixArt-α LoRA](https://diffengine.readthedocs.io/en/latest/run_guides/run_pixart_alpha_lora.html)
 - [Run PixArt-α DreamBooth](https://diffengine.readthedocs.io/en/latest/run_guides/run_pixart_alpha_dreambooth.html)
+- [Run Kandinsky 2.2](https://diffengine.readthedocs.io/en/latest/run_guides/run_kandinsky_v22.html)
 - [Inference](https://diffengine.readthedocs.io/en/latest/run_guides/inference.html)

 </details>
@@ -248,6 +249,9 @@ For detailed user guides and advanced guides, please refer to our [Documentation
   <td>
     <b>PixArt-α</b>
   </td>
+  <td>
+    <b>Kandinsky</b>
+  </td>
 </tr>
 <tr valign="top">
   <td>
@@ -269,6 +273,11 @@ For detailed user guides and advanced guides, please refer to our [Documentation
     <li><a href="configs/pixart_alpha_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
   </ul>
 </td>
+<td>
+  <ul>
+    <li><a href="configs/kandinsky_v22/README.md">Kandinsky 2.2 (2023)</a></li>
+  </ul>
+</td>
 </tr>
 </td>
 </tr>
configs/_base_/datasets/pokemon_blip_kandinsky_decoder.py (new file, +30)

train_pipeline = [
    dict(type="CLIPImageProcessor",
         pretrained="kandinsky-community/kandinsky-2-2-prior"),
    dict(type="torchvision/Resize", size=768, interpolation="bicubic"),
    dict(type="RandomCrop", size=768),
    dict(type="RandomHorizontalFlip", p=0.5),
    dict(type="torchvision/ToTensor"),
    dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]),
    dict(type="PackInputs", input_keys=["img", "text", "clip_img"]),
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    dataset=dict(
        type="HFDataset",
        dataset="lambdalabs/pokemon-blip-captions",
        pipeline=train_pipeline),
    sampler=dict(type="DefaultSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(type="VisualizationHook", prompt=["yoda pokemon"] * 4,
         height=768, width=768),
    dict(type="SDCheckpointHook"),
]
configs/_base_/datasets/pokemon_blip_kandinsky_prior.py (new file, +25)

train_pipeline = [
    dict(type="CLIPImageProcessor", output_key="img",
         pretrained="kandinsky-community/kandinsky-2-2-prior"),
    dict(type="PackInputs"),
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    dataset=dict(
        type="HFDataset",
        dataset="lambdalabs/pokemon-blip-captions",
        pipeline=train_pipeline),
    sampler=dict(type="DefaultSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(type="VisualizationHook", prompt=["yoda pokemon"] * 4,
         height=512, width=512),
    dict(type="PriorSaveHook"),
]

configs/_base_/datasets/pokemon_blip_wuerstchen.py (+1 −1)

@@ -25,5 +25,5 @@
 custom_hooks = [
     dict(type="VisualizationHook", prompt=["A robot pokemon, 4k photo"] * 4,
          height=768, width=768),
-    dict(type="WuerstchenSaveHook"),
+    dict(type="PriorSaveHook"),
 ]
configs/_base_/models/kandinsky_v22_decoder.py (new file, +4)

model = dict(
    type="KandinskyV22Decoder",
    decoder_model="kandinsky-community/kandinsky-2-2-decoder",
    prior_model="kandinsky-community/kandinsky-2-2-prior")
configs/_base_/models/kandinsky_v22_prior.py (new file, +4)

model = dict(
    type="KandinskyV22Prior",
    decoder_model="kandinsky-community/kandinsky-2-2-decoder",
    prior_model="kandinsky-community/kandinsky-2-2-prior")

configs/kandinsky_v22/README.md (new file, +72)

# Kandinsky 2.2

[Kandinsky 2.2](https://habr.com/ru/companies/sberbank/articles/747446/)

## Abstract

Kandinsky 2.2 brings substantial improvements over its predecessor, Kandinsky 2.1, by introducing a new, more powerful image encoder, CLIP-ViT-G, and ControlNet support. The switch to CLIP-ViT-G as the image encoder significantly increases the model's capability to generate more aesthetic pictures and better understand text, enhancing the model's overall performance. The addition of the ControlNet mechanism allows the model to effectively control the image generation process, which leads to more accurate and visually appealing outputs and opens new possibilities for text-guided image manipulation.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/b07d82fb-4c2c-4216-a4b1-a64b278cee2a"/>
</div>
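Kandinsky 2.2 is a two-stage system, which is why this commit adds separate prior and decoder configs: the prior maps text to CLIP image embeddings, and the decoder turns those embeddings into pixels. A minimal sketch of the two stages with stock `diffusers` pipelines (model IDs as used in the configs above; this snippet is illustration, not part of the commit):

```py
import torch
from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline

# Stage 1: text -> CLIP image embeddings via the prior.
prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
prior.to("cuda")
image_embeds, negative_image_embeds = prior("yoda pokemon").to_tuple()

# Stage 2: embeddings -> image via the decoder.
decoder = KandinskyV22Pipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
decoder.to("cuda")
image = decoder(
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    height=768, width=768,
).images[0]
image.save("kandinsky22.png")
```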
## Citation

```
```

## Run Training

```
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/kandinsky_v22/kandinsky_v22_prior_pokemon_blip.py
```
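The commit also adds a decoder config in the same directory, so fine-tuning the decoder follows the same pattern:

```
$ mim train diffengine configs/kandinsky_v22/kandinsky_v22_decoder_pokemon_blip.py
```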
## Inference prior with diffusers

Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module.

```py
import torch
from diffusers import AutoPipelineForText2Image, PriorTransformer

prompt = 'yoda pokemon'
checkpoint = 'work_dirs/kandinsky_v22_prior_pokemon_blip/step10450'

prior = PriorTransformer.from_pretrained(
    checkpoint, subfolder="prior",
)
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder",
    prior_prior=prior,
    torch_dtype=torch.float32,
)
pipe.to('cuda')

image = pipe(
    prompt,
    num_inference_steps=50,
    width=512,
    height=512,
).images[0]
image.save('demo.png')
```

You can see more details in [`docs/source/run_guides/run_kandinsky_v22.md`](../../docs/source/run_guides/run_kandinsky_v22.md#inference-with-diffusers).
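For a fine-tuned decoder, the analogous step would be to load the saved UNet and hand it to the combined pipeline. A sketch, assuming `SDCheckpointHook` saves the UNet under a `unet` subfolder and that the work-dir path mirrors the prior example (both assumptions; neither is confirmed by this diff):

```py
import torch
from diffusers import AutoPipelineForText2Image, UNet2DConditionModel

prompt = 'yoda pokemon'
# Hypothetical checkpoint path for a decoder training run.
checkpoint = 'work_dirs/kandinsky_v22_decoder_pokemon_blip/step10450'

# Load the fine-tuned decoder UNet and override the pipeline component.
unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder="unet")
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder",
    unet=unet,
    torch_dtype=torch.float32,
)
pipe.to('cuda')

image = pipe(prompt, num_inference_steps=50, width=768, height=768).images[0]
image.save('demo.png')
```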
## Results Example

#### kandinsky_v22_prior_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/b709f558-5c03-4235-98d7-fe1c663182b8)

#### kandinsky_v22_decoder_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/6c9cce50-9f31-4637-9933-27697d65c830)
configs/kandinsky_v22/kandinsky_v22_decoder_pokemon_blip.py (new file, +6)

_base_ = [
    "../_base_/models/kandinsky_v22_decoder.py",
    "../_base_/datasets/pokemon_blip_kandinsky_decoder.py",
    "../_base_/schedules/stable_diffusion_50e.py",
    "../_base_/default_runtime.py",
]
configs/kandinsky_v22/kandinsky_v22_prior_pokemon_blip.py (new file, +6)

_base_ = [
    "../_base_/models/kandinsky_v22_prior.py",
    "../_base_/datasets/pokemon_blip_kandinsky_prior.py",
    "../_base_/schedules/stable_diffusion_50e.py",
    "../_base_/default_runtime.py",
]

diffengine/datasets/transforms/processing.py (+7 −2)

@@ -409,10 +409,15 @@ class CLIPImageProcessor(BaseTransform):
         results. Defaults to 'clip_img'.
     """

-    def __init__(self, key: str = "img", output_key: str = "clip_img") -> None:
+    def __init__(self, key: str = "img", output_key: str = "clip_img",
+                 pretrained: str | None = None) -> None:
         self.key = key
         self.output_key = output_key
-        self.pipeline = HFCLIPImageProcessor()
+        if pretrained is None:
+            self.pipeline = HFCLIPImageProcessor()
+        else:
+            self.pipeline = HFCLIPImageProcessor.from_pretrained(
+                pretrained, subfolder="image_processor")

     def transform(self, results: dict) -> dict | tuple[list, list] | None:
         """Transform.

diffengine/engine/hooks/__init__.py (+2 −2)

@@ -5,11 +5,11 @@
 from .lcm_ema_update_hook import LCMEMAUpdateHook
 from .peft_save_hook import PeftSaveHook
 from .pixart_checkpoint_hook import PixArtCheckpointHook
+from .prior_save_hook import PriorSaveHook
 from .sd_checkpoint_hook import SDCheckpointHook
 from .t2i_adapter_save_hook import T2IAdapterSaveHook
 from .unet_ema_hook import UnetEMAHook
 from .visualization_hook import VisualizationHook
-from .wuerstchen_save_hook import WuerstchenSaveHook

 __all__ = [
     "VisualizationHook",
@@ -21,7 +21,7 @@
     "T2IAdapterSaveHook",
     "CompileHook",
     "FastNormHook",
-    "WuerstchenSaveHook",
+    "PriorSaveHook",
     "LCMEMAUpdateHook",
     "PixArtCheckpointHook",
 ]

diffengine/engine/hooks/wuerstchen_save_hook.py renamed to diffengine/engine/hooks/prior_save_hook.py (+6 −5)

@@ -7,11 +7,11 @@

 @HOOKS.register_module()
-class WuerstchenSaveHook(Hook):
-    """Wuerstchen Save Hook.
+class PriorSaveHook(Hook):
+    """Prior Save Hook.

-    Save Wuerstchen weights with diffusers format and pick up Wuerstchen
-    weights from checkpoint.
+    Save Prior weights with diffusers format and pick up Prior weights from
+    checkpoint.
     """

     priority = "VERY_LOW"
@@ -30,7 +30,8 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
             model = model.module
         ckpt_path = osp.join(runner.work_dir, f"step{runner.iter}")
         model.prior.save_pretrained(osp.join(ckpt_path, "prior"))
-        if model.finetune_text_encoder:
+        if hasattr(
+                model, "finetune_text_encoder") and model.finetune_text_encoder:
             model.text_encoder.save_pretrained(
                 osp.join(ckpt_path, "text_encoder"))
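Configs select the hook by its registry name, so on the user side the rename is a one-line change, as in the Wuerstchen dataset config updated above:

```py
custom_hooks = [
    dict(type="VisualizationHook", prompt=["A robot pokemon, 4k photo"] * 4,
         height=768, width=768),
    dict(type="PriorSaveHook"),  # previously dict(type="WuerstchenSaveHook")
]
```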

diffengine/models/editors/__init__.py (+1)

@@ -3,6 +3,7 @@
 from .esd import * # noqa: F403
 from .instruct_pix2pix import * # noqa: F403
 from .ip_adapter import * # noqa: F403
+from .kandinsky import * # noqa: F403
 from .lcm import * # noqa: F403
 from .pixart_alpha import * # noqa: F403
 from .ssd_1b import * # noqa: F403

diffengine/models/editors/deepfloyd_if/deepfloyd_if.py (+1)

@@ -46,6 +46,7 @@ class DeepFloydIF(BaseModel):
         training. Choose between 'epsilon' or 'v_prediction' or leave
         `None`. If left to `None` the default prediction type of the
         scheduler: `noise_scheduler.config.prediction_type` is chosen.
+        Defaults to None.
     data_preprocessor (dict, optional): The pre-process config of
         :class:`SDDataPreprocessor`.
     noise_generator (dict, optional): The noise generator config.
diffengine/models/editors/kandinsky/__init__.py (new file, +6)

from .kandinskyv22_decoder import KandinskyV22Decoder
from .kandinskyv22_decoder_preprocessor import KandinskyV22DecoderDataPreprocessor
from .kandinskyv22_prior import KandinskyV22Prior

__all__ = ["KandinskyV22Prior", "KandinskyV22Decoder",
           "KandinskyV22DecoderDataPreprocessor"]
