@@ -1,14 +1,22 @@
+import functools
+import gc
 import os
 import random
 from pathlib import Path
 from typing import Optional, Sequence

 import numpy as np
+import torch
 from datasets import load_dataset
+from datasets.fingerprint import Hasher
 from mmengine.dataset.base_dataset import Compose
 from PIL import Image
 from torch.utils.data import Dataset
+from transformers import AutoTokenizer

+from diffengine.datasets.utils import encode_prompt_sdxl
+from diffengine.models.editors.stable_diffusion_xl.stable_diffusion_xl import \
+    import_model_class_from_model_name_or_path
 from diffengine.registry import DATASETS

 Image.MAX_IMAGE_PIXELS = 1000000000
@@ -88,3 +96,84 @@ def __getitem__(self, idx: int) -> dict:
         result = self.pipeline(result)

         return result
+
+
+@DATASETS.register_module()
+class HFDatasetPreComputeEmbs(HFDataset):
+    """Dataset for huggingface datasets.
+
+    The difference from HFDataset is
+    1. pre-compute Text Encoder embeddings to save memory.
+
+    Args:
+        model (str): Pretrained model name of stable diffusion xl.
+            Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
+        text_hasher (str): Text embeddings hasher name. Defaults to 'text'.
+        device (str): Device used to compute embeddings. Defaults to 'cuda'.
+        proportion_empty_prompts (float): Probability of replacing prompts
+            with empty text. Defaults to 0.0.
+    """
+
+    def __init__(self,
+                 *args,
+                 model: str = 'stabilityai/stable-diffusion-xl-base-1.0',
+                 text_hasher: str = 'text',
+                 device: str = 'cuda',
+                 proportion_empty_prompts: float = 0.0,
+                 **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
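+        # Load the two tokenizers that SDXL pairs with its two text encoders.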
+        tokenizer_one = AutoTokenizer.from_pretrained(
+            model, subfolder='tokenizer', use_fast=False)
+        tokenizer_two = AutoTokenizer.from_pretrained(
+            model, subfolder='tokenizer_2', use_fast=False)
+
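+        # Resolve the concrete text encoder classes from the model config
+        # and load both encoders on the device for a one-off encoding pass.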
+        text_encoder_cls_one = import_model_class_from_model_name_or_path(
+            model)
+        text_encoder_cls_two = import_model_class_from_model_name_or_path(
+            model, subfolder='text_encoder_2')
+        text_encoder_one = text_encoder_cls_one.from_pretrained(
+            model, subfolder='text_encoder').to(device)
+        text_encoder_two = text_encoder_cls_two.from_pretrained(
+            model, subfolder='text_encoder_2').to(device)
+
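+        # Derive a fixed fingerprint from the hasher label so `datasets.map`
+        # caches the computed embeddings and skips re-encoding on reruns.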
+        new_fingerprint = Hasher.hash(text_hasher)
+        compute_embeddings_fn = functools.partial(
+            encode_prompt_sdxl,
+            text_encoders=[text_encoder_one, text_encoder_two],
+            tokenizers=[tokenizer_one, tokenizer_two],
+            proportion_empty_prompts=proportion_empty_prompts,
+            caption_column=self.caption_column,
+        )
+        self.dataset = self.dataset.map(
+            compute_embeddings_fn,
+            batched=True,
+            new_fingerprint=new_fingerprint)
+
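+        # The encoders and tokenizers are only needed for the precompute;
+        # drop them and release the GPU memory they hold.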
+        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def __getitem__(self, idx: int) -> dict:
+        """Get the idx-th image and data information of dataset after
+        ``self.train_transforms``.
+
+        Args:
+            idx (int): The index of self.data_list.
+
+        Returns:
+            dict: The idx-th image and data information of dataset after
+                ``self.train_transforms``.
+        """
+        data_info = self.dataset[idx]
+        image = data_info[self.image_column]
+        if isinstance(image, str):
+            image = Image.open(os.path.join(self.dataset_name, image))
+        image = image.convert('RGB')
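+        # Hand the precomputed embeddings to the pipeline alongside the
+        # image, so no text encoder runs at training time.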
+        result = dict(
+            img=image,
+            prompt_embeds=data_info['prompt_embeds'],
+            pooled_prompt_embeds=data_info['pooled_prompt_embeds'])
+        result = self.pipeline(result)
+
+        return result
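
For reference, here is a minimal sketch of how this class might be wired into an mmengine-style config. The HuggingFace dataset name, the `dataset` keyword, and `train_pipeline` are illustrative assumptions, not part of this diff:

# Hypothetical config snippet; the names below are placeholders.
train_dataset = dict(
    type='HFDatasetPreComputeEmbs',
    dataset='lambdalabs/pokemon-blip-captions',  # assumed HFDataset kwarg
    model='stabilityai/stable-diffusion-xl-base-1.0',
    # The map fingerprint is derived from this string alone, so use a
    # distinct label per model/dataset pair to avoid reusing stale caches.
    text_hasher='text_pokemon_blip_xl',
    device='cuda',
    proportion_empty_prompts=0.0,
    pipeline=train_pipeline)  # transform pipeline defined elsewhere

Design note: since `new_fingerprint = Hasher.hash(text_hasher)` keys the `datasets` cache only on the label string, changing the model or caption column without changing `text_hasher` will silently reuse previously cached embeddings.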