@@ -24,6 +24,7 @@ class HFDreamBoothDataset(Dataset):
24
24
"""DreamBooth Dataset for huggingface datasets.
25
25
26
26
Args:
27
+ ----
27
28
dataset (str): Dataset name.
28
29
instance_prompt (str):
29
30
The prompt with identifier specifying the instance.
@@ -47,6 +48,7 @@ class as provided instance images. Defaults to None.
47
48
cache_dir (str, optional): The directory where the downloaded datasets
48
49
will be stored.Defaults to None.
49
50
"""
51
+
50
52
default_class_image_config : dict = {
51
53
"model" : "runwayml/stable-diffusion-v1-5" ,
52
54
"data_dir" : "work_dirs/class_image" ,
@@ -63,7 +65,7 @@ def __init__(self,
63
65
class_image_config : dict | None = None ,
64
66
class_prompt : str | None = None ,
65
67
pipeline : Sequence = (),
66
- cache_dir : str | None = None ):
68
+ cache_dir : str | None = None ) -> None :
67
69
68
70
if class_image_config is None :
69
71
class_image_config = {
@@ -108,7 +110,8 @@ def __init__(self,
108
110
f"class_image_config needs a dict with keys { essential_keys } "
109
111
self .generate_class_image (class_image_config )
110
112
111
- def generate_class_image (self , class_image_config ):
113
+ def generate_class_image (self , class_image_config ) -> None :
114
+ """Generate class images for prior preservation loss."""
112
115
class_images_dir = Path (class_image_config ["data_dir" ])
113
116
if class_images_dir .exists (
114
117
) and class_image_config ["recreate_class_images" ]:
@@ -145,19 +148,24 @@ def generate_class_image(self, class_image_config):
145
148
def __len__ (self ) -> int :
146
149
"""Get the length of dataset.
147
150
148
- Returns:
151
+ Returns
152
+ -------
149
153
int: The length of filtered dataset.
150
154
"""
151
155
return len (self .dataset )
152
156
153
157
def __getitem__ (self , idx : int ) -> dict :
154
- """Get the idx-th image and data information of dataset after
158
+ """Get item.
159
+
160
+ Get the idx-th image and data information of dataset after
155
161
``self.pipeline`.
156
162
157
163
Args:
164
+ ----
158
165
idx (int): The index of self.data_list.
159
166
160
167
Returns:
168
+ -------
161
169
dict: The idx-th image and data information of dataset after
162
170
``self.pipeline``.
163
171
"""
0 commit comments