Commit 4093fa9

Merge pull request #87 from okotaku/feat/support_noise_method

[Feature] Support Noise Methods

2 parents: e93f733 + 5442062
33 files changed: +501 -93 lines

.pre-commit-config.yaml (+3 -1)

```diff
@@ -18,9 +18,11 @@ repos:
       - id: mixed-line-ending
         args: ["--fix=lf"]
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.1
+    rev: v2.2.4
    hooks:
       - id: codespell
+        additional_dependencies:
+          - tomli
   - repo: https://github.com/executablebooks/mdformat
     rev: 0.7.9
     hooks:
```

README.md (+3)

```diff
@@ -201,6 +201,9 @@ For detailed user guides and advanced guides, please refer to our [Documentation
       <ul>
         <li><a href="configs/min_snr_loss/README.md">Min-SNR Loss (ICCV'2023)</a></li>
         <li><a href="configs/debias_estimation_loss/README.md">DeBias Estimation Loss (2023)</a></li>
+        <li><a href="configs/offset_noise/README.md">Offset Noise (2023)</a></li>
+        <li><a href="configs/pyramid_noise/README.md">Pyramid Noise (2023)</a></li>
+        <li><a href="configs/input_perturbation/README.md">Input Perturbation (2023)</a></li>
       </ul>
     </td>
   </tr>
```

configs/input_perturbation/README.md (new file, +46)

````markdown
# Input Perturbation

[Input Perturbation Reduces Exposure Bias in Diffusion Models](https://arxiv.org/abs/2301.11706)

## Abstract

Denoising Diffusion Probabilistic Models have shown an impressive generation quality, although their long sampling chain leads to high computational costs. In this paper, we observe that a long sampling chain also leads to an error accumulation phenomenon, which is similar to the exposure bias problem in autoregressive text generation. Specifically, we note that there is a discrepancy between training and testing, since the former is conditioned on the ground truth samples, while the latter is conditioned on the previously generated results. To alleviate this problem, we propose a very simple but effective training regularization, consisting in perturbing the ground truth samples to simulate the inference time prediction errors. We empirically show that, without affecting the recall and precision, the proposed input perturbation leads to a significant improvement in the sample quality while reducing both the training and the inference times. For instance, on CelebA 64×64, we achieve a new state-of-the-art FID score of 1.27, while saving 37.5% of the training time.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/60b9a296-6453-4d47-9c06-f40f43766273"/>
</div>

## Citation

```bibtex
@article{ning2023input,
  title={Input Perturbation Reduces Exposure Bias in Diffusion Models},
  author={Ning, Mang and Sangineto, Enver and Porrello, Angelo and Calderara, Simone and Cucchiara, Rita},
  journal={arXiv preprint arXiv:2301.11706},
  year={2023}
}
```

## Run Training

```bash
# single GPU
$ mim train diffengine ${CONFIG_FILE}
# multiple GPUs
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example
$ mim train diffengine configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py
```

## Inference with diffusers

You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers).

## Results Example

#### stable_diffusion_xl_pokemon_blip_input_perturbation

![example1](https://github.com/okotaku/diffengine/assets/24734142/b0a631e7-153c-467a-9cb6-d9155eaa7161)
````
configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py (new file, +12)

```python
_base_ = [
    "../_base_/models/stable_diffusion_xl.py",
    "../_base_/datasets/pokemon_blip_xl.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

model = dict(input_perturbation_gamma=0.1)

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4)  # update every four iterations
```
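Input perturbation only changes how the latents are corrupted during training; the target the model learns to predict is still the unperturbed noise. A minimal sketch of the idea (the standalone function name here is illustrative; the actual helper this PR adds is `_preprocess_model_input`, shown in the deepfloyd_if.py diff below):

```python
import torch


def add_perturbed_noise(scheduler, latents, noise, timesteps, gamma=0.1):
    """Corrupt latents with a perturbed copy of the training noise.

    `noise` is what the model is trained to predict; only the UNet input
    is built from `noise + gamma * fresh_noise`, which simulates the
    inference-time prediction errors behind exposure bias.
    """
    input_noise = noise + gamma * torch.randn_like(noise) if gamma > 0 else noise
    return scheduler.add_noise(latents, input_noise, timesteps)
```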

configs/offset_noise/README.md (new file, +40)

````markdown
# Offset Noise

[Diffusion with Offset Noise](https://www.crosslabs.org/blog/diffusion-with-offset-noise)

## Abstract

Fine-tuning against a modified noise enables Stable Diffusion to generate very dark or light images easily.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/76038bc8-b614-49da-9751-1a9efb83995f"/>
</div>

## Citation

```
```

## Run Training

```bash
# single GPU
$ mim train diffengine ${CONFIG_FILE}
# multiple GPUs
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example
$ mim train diffengine configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py
```

## Inference with diffusers

You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers).

## Results Example

#### stable_diffusion_xl_pokemon_blip_offset_noise

![example1](https://github.com/okotaku/diffengine/assets/24734142/7a3b26ff-618b-46f0-827e-32c2d47cde6f)
````
configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py (new file, +12)

```python
_base_ = [
    "../_base_/models/stable_diffusion_xl.py",
    "../_base_/datasets/pokemon_blip_xl.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

model = dict(noise_generator=dict(type="OffsetNoise", offset_weight=0.05))

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4)  # update every four iterations
```
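The `OffsetNoise` class referenced by `type="OffsetNoise"` is added elsewhere in this PR (the `diffengine/models/utils` re-export below suggests where it lives). A minimal sketch of what it plausibly computes, matching the inline offset-noise logic this PR removes from the editors:

```python
import torch
from torch import nn


class OffsetNoise(nn.Module):
    """Gaussian noise plus a per-channel constant offset (sketch).

    The extra `offset_weight * randn(B, C, 1, 1)` term shifts every
    spatial position of a channel together, which lets the fine-tuned
    model move overall image brightness much more freely.
    """

    def __init__(self, offset_weight: float = 0.05) -> None:
        super().__init__()
        self.offset_weight = offset_weight

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        noise = torch.randn_like(latents)
        offset = torch.randn(
            latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
        return noise + self.offset_weight * offset
```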

configs/pyramid_noise/README.md (new file, +40)

````markdown
# Pyramid Noise

[Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2)

## Abstract

This report proposes a new noising approach that adds multi-resolution noise to an image or latent image during diffusion model training. A model trained with this technique can generate stunning images with a very different aesthetic to the usual diffusion model outputs. This seems like a promising direction for future research.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/943570cf-7283-4536-ae28-cd1cce1220b7"/>
</div>

## Citation

```
```

## Run Training

```bash
# single GPU
$ mim train diffengine ${CONFIG_FILE}
# multiple GPUs
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example
$ mim train diffengine configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py
```

## Inference with diffusers

You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers).

## Results Example

#### stable_diffusion_xl_pokemon_blip_pyramid_noise

![example1](https://github.com/okotaku/diffengine/assets/24734142/8ee2f0b1-6ef6-4b5e-a018-8b0acbd73ec9)
````
configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py (new file, +12)

```python
_base_ = [
    "../_base_/models/stable_diffusion_xl.py",
    "../_base_/datasets/pokemon_blip_xl.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

model = dict(noise_generator=dict(type="PyramidNoise", discount=0.9))

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4)  # update every four iterations
```
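For context, a minimal sketch of a pyramid-noise generator, after the reference implementation in the wandb report linked above; the actual DiffEngine class added by this PR may differ in details:

```python
import random

import torch
from torch import nn


class PyramidNoise(nn.Module):
    """Multi-resolution ("pyramid") noise (sketch, after the wandb report).

    Gaussian noise is drawn at progressively coarser resolutions,
    upsampled to the target size, and accumulated with geometrically
    decaying weight `discount ** level`, then renormalised to unit std.
    """

    def __init__(self, discount: float = 0.9) -> None:
        super().__init__()
        self.discount = discount

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        b, c, h, w = latents.shape
        upsample = nn.Upsample(size=(h, w), mode="bilinear")
        noise = torch.randn_like(latents)
        for i in range(1, 10):
            # random per-level ratio in [2, 4) rather than a fixed 2x pyramid
            ratio = random.random() * 2 + 2
            hi, wi = max(1, int(h / ratio**i)), max(1, int(w / ratio**i))
            noise += upsample(
                torch.randn(b, c, hi, wi, device=latents.device)) * self.discount**i
            if hi == 1 or wi == 1:
                break
        return noise / noise.std()  # keep overall variance roughly 1
```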

diffengine/models/__init__.py (+1)

```diff
@@ -1,2 +1,3 @@
 from .editors import * # noqa: F403
 from .losses import * # noqa: F403
+from .utils import * # noqa: F403
```
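The new `diffengine.models.utils` package is presumably where the noise generators live: the configs and docstrings in this PR name `WhiteNoise`, `OffsetNoise`, and `PyramidNoise`, and the editors instantiate them via `MODELS.build(noise_generator)`. A minimal sketch of the default generator and the interface the editors rely on, under that assumption:

```python
import torch
from torch import nn


class WhiteNoise(nn.Module):
    """Default generator: plain i.i.d. Gaussian noise (sketch).

    Every generator maps a latent batch to a same-shaped noise tensor,
    so switching noising strategies is a one-line config change.
    """

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        return torch.randn_like(latents)
```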

diffengine/models/editors/deepfloyd_if/deepfloyd_if.py (+24 -17)

```diff
@@ -31,10 +31,6 @@ class DeepFloydIF(BaseModel):
             example. dict(rank=4). Defaults to None.
         prior_loss_weight (float): The weight of prior preservation loss.
             It works when training dreambooth with class images.
-        noise_offset_weight (bool, optional):
-            The weight of noise offset introduced in
-            https://www.crosslabs.org/blog/diffusion-with-offset-noise
-            Defaults to 0.
         tokenizer_max_length (int): The max length of tokenizer.
             Defaults to 77.
         prediction_type (str): The prediction_type that shall be used for
@@ -43,6 +39,11 @@ class DeepFloydIF(BaseModel):
             scheduler: `noise_scheduler.config.prediciton_type` is chosen.
         data_preprocessor (dict, optional): The pre-process config of
             :class:`SDDataPreprocessor`.
+        noise_generator (dict, optional): The noise generator config.
+            Defaults to ``dict(type='WhiteNoise')``.
+        input_perturbation_gamma (float): The gamma of input perturbation.
+            The recommended value is 0.1 for Input Perturbation.
+            Defaults to 0.0.
         finetune_text_encoder (bool, optional): Whether to fine-tune text
             encoder. Defaults to False.
         gradient_checkpointing (bool): Whether or not to use gradient
@@ -56,16 +57,19 @@ def __init__(
         loss: dict | None = None,
         lora_config: dict | None = None,
         prior_loss_weight: float = 1.,
-        noise_offset_weight: float = 0,
         tokenizer_max_length: int = 77,
         prediction_type: str | None = None,
         data_preprocessor: dict | nn.Module | None = None,
+        noise_generator: dict | None = None,
+        input_perturbation_gamma: float = 0.0,
         *,
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
     ) -> None:
         if data_preprocessor is None:
             data_preprocessor = {"type": "SDDataPreprocessor"}
+        if noise_generator is None:
+            noise_generator = {"type": "WhiteNoise"}
         if loss is None:
             loss = {"type": "L2Loss", "loss_weight": 1.0}
         super().__init__(data_preprocessor=data_preprocessor)
@@ -75,13 +79,12 @@ def __init__(
         self.prior_loss_weight = prior_loss_weight
         self.gradient_checkpointing = gradient_checkpointing
         self.tokenizer_max_length = tokenizer_max_length
+        self.input_perturbation_gamma = input_perturbation_gamma
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(loss)
         self.loss_module: nn.Module = loss
 
-        self.enable_noise_offset = noise_offset_weight > 0
-        self.noise_offset_weight = noise_offset_weight
         assert prediction_type in [None, "epsilon", "v_prediction"]
         self.prediction_type = prediction_type
 
@@ -94,6 +97,7 @@ def __init__(
             model, subfolder="text_encoder")
         self.unet = UNet2DConditionModel.from_pretrained(
             model, subfolder="unet")
+        self.noise_generator = MODELS.build(noise_generator)
         self.prepare_model()
         self.set_lora()
 
@@ -244,6 +248,17 @@ def loss(self,
         loss_dict["loss"] = loss
         return loss_dict
 
+    def _preprocess_model_input(self,
+                                latents: torch.Tensor,
+                                noise: torch.Tensor,
+                                timesteps: torch.Tensor) -> torch.Tensor:
+        if self.input_perturbation_gamma > 0:
+            input_noise = noise + self.input_perturbation_gamma * torch.randn_like(
+                noise)
+        else:
+            input_noise = noise
+        return self.scheduler.add_noise(latents, input_noise, timesteps)
+
     def forward(
         self,
         inputs: torch.Tensor,
@@ -283,15 +298,7 @@ def forward(
 
         model_input = inputs["img"]
 
-        noise = torch.randn_like(model_input)
-
-        if self.enable_noise_offset:
-            noise = noise + self.noise_offset_weight * torch.randn(
-                model_input.shape[0],
-                model_input.shape[1],
-                1,
-                1,
-                device=noise.device)
+        noise = self.noise_generator(model_input)
 
         num_batches = model_input.shape[0]
         timesteps = torch.randint(
@@ -300,7 +307,7 @@ def forward(
             device=self.device)
         timesteps = timesteps.long()
 
-        noisy_model_input = self.scheduler.add_noise(model_input, noise,
+        noisy_model_input = self._preprocess_model_input(model_input, noise,
                                                      timesteps)
 
         encoder_hidden_states = self.text_encoder(
```
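Note that the generator and the perturbation are independent knobs on the model config, so they can in principle be combined. A purely hypothetical config sketch (not a file in this PR), mirroring the configs added above:

```python
# hypothetical config: pyramid noise plus input perturbation together
_base_ = [
    "../_base_/models/stable_diffusion_xl.py",
    "../_base_/datasets/pokemon_blip_xl.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

model = dict(
    noise_generator=dict(type="PyramidNoise", discount=0.9),
    input_perturbation_gamma=0.1,
)
```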

diffengine/models/editors/distill_sd/distill_sd_xl.py (+2 -6)

```diff
@@ -161,19 +161,15 @@ def forward(
         latents = self.vae.encode(inputs["img"]).latent_dist.sample()
         latents = latents * self.vae.config.scaling_factor
 
-        noise = torch.randn_like(latents)
-
-        if self.enable_noise_offset:
-            noise = noise + self.noise_offset_weight * torch.randn(
-                latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
+        noise = self.noise_generator(latents)
 
         timesteps = torch.randint(
             0,
             self.scheduler.config.num_train_timesteps, (num_batches, ),
             device=self.device)
         timesteps = timesteps.long()
 
-        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = self._preprocess_model_input(latents, noise, timesteps)
 
         if not self.pre_compute_text_embeddings:
             inputs["text_one"] = self.tokenizer_one(
```

diffengine/models/editors/ip_adapter/ip_adapter_xl.py (+4 -12)

```diff
@@ -242,19 +242,15 @@ def forward(
         latents = self.vae.encode(inputs["img"]).latent_dist.sample()
         latents = latents * self.vae.config.scaling_factor
 
-        noise = torch.randn_like(latents)
-
-        if self.enable_noise_offset:
-            noise = noise + self.noise_offset_weight * torch.randn(
-                latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
+        noise = self.noise_generator(latents)
 
         timesteps = torch.randint(
             0,
             self.scheduler.config.num_train_timesteps, (num_batches, ),
             device=self.device)
         timesteps = timesteps.long()
 
-        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = self._preprocess_model_input(latents, noise, timesteps)
 
         prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
             inputs["text_one"], inputs["text_two"])
@@ -401,19 +397,15 @@ def forward(
         latents = self.vae.encode(inputs["img"]).latent_dist.sample()
         latents = latents * self.vae.config.scaling_factor
 
-        noise = torch.randn_like(latents)
-
-        if self.enable_noise_offset:
-            noise = noise + self.noise_offset_weight * torch.randn(
-                latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
+        noise = self.noise_generator(latents)
 
         timesteps = torch.randint(
             0,
             self.scheduler.config.num_train_timesteps, (num_batches, ),
             device=self.device)
         timesteps = timesteps.long()
 
-        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = self._preprocess_model_input(latents, noise, timesteps)
 
         prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
             inputs["text_one"], inputs["text_two"])
```
