Skip to content

Commit 4aeca8d

Browse files
authored
Merge pull request #77 from okotaku/feat/update_lint
[Feature] Update lint
2 parents c462ada + e2c83fa commit 4aeca8d

File tree

88 files changed

+908
-659
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+908
-659
lines changed

.github/workflows/build.yml

+6-4
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ jobs:
66
lint:
77
runs-on: ubuntu-latest
88
steps:
9-
- uses: actions/checkout@v2
10-
- name: Set up Python 3.7
11-
uses: actions/setup-python@v2
9+
- uses: actions/checkout@v3
10+
- uses: actions/setup-python@v3
1211
with:
13-
python-version: 3.8
12+
python-version: '3.11'
1413
- name: Install pre-commit hook
1514
run: |
1615
pip install pre-commit
@@ -30,6 +29,9 @@ jobs:
3029

3130
steps:
3231
- uses: actions/checkout@v3
32+
- uses: actions/setup-python@v3
33+
with:
34+
python-version: '3.11'
3335
- name: Upgrade pip
3436
run: pip install pip --upgrade
3537
- name: Install system dependencies

diffengine/datasets/hf_controlnet_datasets.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class HFControlNetDataset(Dataset):
2020
"""Dataset for huggingface datasets.
2121
2222
Args:
23+
----
2324
dataset (str): Dataset name or path to dataset.
2425
image_column (str): Image column name. Defaults to 'image'.
2526
condition_column (str): Condition column name for ControlNet.
@@ -39,7 +40,7 @@ def __init__(self,
3940
caption_column: str = "text",
4041
csv: str = "metadata.csv",
4142
pipeline: Sequence = (),
42-
cache_dir: str | None = None):
43+
cache_dir: str | None = None) -> None:
4344
self.dataset_name = dataset
4445
if Path(dataset).exists():
4546
# load local folder
@@ -58,19 +59,24 @@ def __init__(self,
5859
def __len__(self) -> int:
5960
"""Get the length of dataset.
6061
61-
Returns:
62+
Returns
63+
-------
6264
int: The length of filtered dataset.
6365
"""
6466
return len(self.dataset)
6567

6668
def __getitem__(self, idx: int) -> dict:
67-
"""Get the idx-th image and data information of dataset after
69+
"""Get item.
70+
71+
Get the idx-th image and data information of dataset after
6872
``self.pipeline`.
6973
7074
Args:
75+
----
7176
idx (int): The index of self.data_list.
7277
7378
Returns:
79+
-------
7480
dict: The idx-th image and data information of dataset after
7581
``self.pipeline``.
7682
"""

diffengine/datasets/hf_datasets.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class HFDataset(Dataset):
3232
"""Dataset for huggingface datasets.
3333
3434
Args:
35+
----
3536
dataset (str): Dataset name or path to dataset.
3637
image_column (str): Image column name. Defaults to 'image'.
3738
caption_column (str): Caption column name. Defaults to 'text'.
@@ -48,7 +49,7 @@ def __init__(self,
4849
caption_column: str = "text",
4950
csv: str = "metadata.csv",
5051
pipeline: Sequence = (),
51-
cache_dir: str | None = None):
52+
cache_dir: str | None = None) -> None:
5253
self.dataset_name = dataset
5354
if Path(dataset).exists():
5455
# load local folder
@@ -66,19 +67,24 @@ def __init__(self,
6667
def __len__(self) -> int:
6768
"""Get the length of dataset.
6869
69-
Returns:
70+
Returns
71+
-------
7072
int: The length of filtered dataset.
7173
"""
7274
return len(self.dataset)
7375

7476
def __getitem__(self, idx: int) -> dict:
75-
"""Get the idx-th image and data information of dataset after
77+
"""Get item.
78+
79+
Get the idx-th image and data information of dataset after
7680
``self.pipeline`.
7781
7882
Args:
83+
----
7984
idx (int): The index of self.data_list.
8085
8186
Returns:
87+
-------
8288
dict: The idx-th image and data information of dataset after
8389
``self.pipeline``.
8490
"""
@@ -109,6 +115,7 @@ class HFDatasetPreComputeEmbs(HFDataset):
109115
1. pre-compute Text Encoder embeddings to save memory.
110116
111117
Args:
118+
----
112119
model (str): pretrained model name of stable diffusion xl.
113120
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
114121
text_hasher (str): Text embeddings hasher name. Defaults to 'text'.
@@ -158,13 +165,17 @@ def __init__(self,
158165
torch.cuda.empty_cache()
159166

160167
def __getitem__(self, idx: int) -> dict:
161-
"""Get the idx-th image and data information of dataset after
168+
"""Get item.
169+
170+
Get the idx-th image and data information of dataset after
162171
``self.train_transforms`.
163172
164173
Args:
174+
----
165175
idx (int): The index of self.data_list.
166176
167177
Returns:
178+
-------
168179
dict: The idx-th image and data information of dataset after
169180
``self.train_transforms``.
170181
"""

diffengine/datasets/hf_dreambooth_datasets.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class HFDreamBoothDataset(Dataset):
2424
"""DreamBooth Dataset for huggingface datasets.
2525
2626
Args:
27+
----
2728
dataset (str): Dataset name.
2829
instance_prompt (str):
2930
The prompt with identifier specifying the instance.
@@ -47,6 +48,7 @@ class as provided instance images. Defaults to None.
4748
cache_dir (str, optional): The directory where the downloaded datasets
4849
will be stored.Defaults to None.
4950
"""
51+
5052
default_class_image_config: dict = {
5153
"model": "runwayml/stable-diffusion-v1-5",
5254
"data_dir": "work_dirs/class_image",
@@ -63,7 +65,7 @@ def __init__(self,
6365
class_image_config: dict | None = None,
6466
class_prompt: str | None = None,
6567
pipeline: Sequence = (),
66-
cache_dir: str | None = None):
68+
cache_dir: str | None = None) -> None:
6769

6870
if class_image_config is None:
6971
class_image_config = {
@@ -108,7 +110,8 @@ def __init__(self,
108110
f"class_image_config needs a dict with keys {essential_keys}"
109111
self.generate_class_image(class_image_config)
110112

111-
def generate_class_image(self, class_image_config):
113+
def generate_class_image(self, class_image_config) -> None:
114+
"""Generate class images for prior preservation loss."""
112115
class_images_dir = Path(class_image_config["data_dir"])
113116
if class_images_dir.exists(
114117
) and class_image_config["recreate_class_images"]:
@@ -145,19 +148,24 @@ def generate_class_image(self, class_image_config):
145148
def __len__(self) -> int:
146149
"""Get the length of dataset.
147150
148-
Returns:
151+
Returns
152+
-------
149153
int: The length of filtered dataset.
150154
"""
151155
return len(self.dataset)
152156

153157
def __getitem__(self, idx: int) -> dict:
154-
"""Get the idx-th image and data information of dataset after
158+
"""Get item.
159+
160+
Get the idx-th image and data information of dataset after
155161
``self.pipeline`.
156162
157163
Args:
164+
----
158165
idx (int): The index of self.data_list.
159166
160167
Returns:
168+
-------
161169
dict: The idx-th image and data information of dataset after
162170
``self.pipeline``.
163171
"""

diffengine/datasets/hf_esd_datasets.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121

2222
@DATASETS.register_module()
2323
class HFESDDatasetPreComputeEmbs(Dataset):
24-
"""Dataset of huggingface datasets for Erasing Concepts from Diffusion
24+
"""Huggingface Erasing Concepts from Diffusion Models Dataset.
25+
26+
Dataset of huggingface datasets for Erasing Concepts from Diffusion
2527
Models.
2628
2729
Args:
30+
----
2831
forget_caption (str): The caption used to forget.
2932
model (str): pretrained model name of stable diffusion xl.
3033
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
@@ -68,7 +71,8 @@ def __init__(self,
6871
def __len__(self) -> int:
6972
"""Get the length of dataset.
7073
71-
Returns:
74+
Returns
75+
-------
7276
int: The length of filtered dataset.
7377
"""
7478
return 1
@@ -77,9 +81,11 @@ def __getitem__(self, idx: int) -> dict:
7781
"""Get the dataset after ``self.pipeline`.
7882
7983
Args:
84+
----
8085
idx (int): The index.
8186
8287
Returns:
88+
-------
8389
dict: The idx-th data information of dataset after
8490
``self.pipeline``.
8591
"""

diffengine/datasets/samplers/batch_sampler.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99

1010
@DATA_SAMPLERS.register_module()
1111
class AspectRatioBatchSampler(BatchSampler):
12-
"""A sampler wrapper for grouping images with similar aspect ratio into a
12+
"""Aspect ratio batch sampler.
13+
14+
A sampler wrapper for grouping images with similar aspect ratio into a
1315
same batch.
1416
1517
Args:
18+
----
1619
sampler (Sampler): Base sampler.
1720
batch_size (int): Size of mini-batch.
1821
drop_last (bool): If ``True``, the sampler will drop the last batch if
@@ -46,6 +49,7 @@ def __init__(self,
4649
self.bucket_ids.append(bucket_id)
4750

4851
def __iter__(self) -> Generator:
52+
"""Get the iterator of the sampler."""
4953
for idx in self.sampler:
5054
bucket_id = self.bucket_ids[idx]
5155
if bucket_id not in self._aspect_ratio_buckets:
@@ -66,6 +70,7 @@ def __iter__(self) -> Generator:
6670
self._aspect_ratio_buckets = {}
6771

6872
def __len__(self) -> int:
73+
"""Get the length of the sampler."""
6974
total_sample = 0
7075
_, counts = np.unique(self.bucket_ids, return_counts=True)
7176
for c in counts:

diffengine/datasets/transforms/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ class BaseTransform(metaclass=ABCMeta):
77
"""Base class for all transformations."""
88

99
def __call__(self, results: dict) -> dict | tuple[list, list] | None:
10-
10+
"""Call function to transform data."""
1111
return self.transform(results)
1212

1313
@abstractmethod
1414
def transform(self, results: dict) -> dict | tuple[list, list] | None:
15-
"""The transform function. All subclass of BaseTransform should
15+
"""Transform the data.
16+
17+
The transform function. All subclass of BaseTransform should
1618
override this method.
1719
1820
This function takes the result dict as the input, and can add new
@@ -21,8 +23,10 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
2123
transforms into a pipeline.
2224
2325
Args:
26+
----
2427
results (dict): The result dict.
2528
2629
Returns:
30+
-------
2731
dict: The result dict.
2832
"""

diffengine/datasets/transforms/dump_image.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,31 @@ class DumpImage:
1414
"""Dump the image processed by the pipeline.
1515
1616
Args:
17+
----
1718
max_imgs (int): Maximum value of output.
1819
dump_dir (str): Dump output directory.
1920
"""
2021

21-
def __init__(self, max_imgs: int, dump_dir: str):
22+
def __init__(self, max_imgs: int, dump_dir: str) -> None:
2223
self.max_imgs = max_imgs
2324
self.dump_dir = dump_dir
2425
mmengine.mkdir_or_exist(self.dump_dir)
2526
self.num_dumped_imgs = Value("i", 0)
2627

27-
def __call__(self, results):
28+
def __call__(self, results) -> dict:
2829
"""Dump the input image to the specified directory.
2930
3031
No changes will be
3132
made.
33+
3234
Args:
35+
----
3336
results (dict): Result dict from loading pipeline.
37+
3438
Returns:
39+
-------
3540
results (dict): Result dict from loading pipeline. (same as input)
3641
"""
37-
3842
enable_dump = False
3943
with self.num_dumped_imgs.get_lock():
4044
if self.num_dumped_imgs.value < self.max_imgs:

diffengine/datasets/transforms/formatting.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from diffengine.registry import TRANSFORMS
1010

1111

12-
def to_tensor(data):
12+
def to_tensor(data) -> torch.Tensor:
1313
"""Convert objects of various python types to :obj:`torch.Tensor`.
1414
1515
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
@@ -45,6 +45,7 @@ class PackInputs(BaseTransform):
4545
All other keys in the dict.
4646
4747
Args:
48+
----
4849
input_keys (List[str]): The key of element to feed into the model
4950
forwarding. Defaults to ['img', 'text'].
5051
skip_to_tensor_key (List[str]): The key of element to skip to_tensor.
@@ -53,7 +54,7 @@ class PackInputs(BaseTransform):
5354

5455
def __init__(self,
5556
input_keys: list[str] | None = None,
56-
skip_to_tensor_key: list[str] | None = None):
57+
skip_to_tensor_key: list[str] | None = None) -> None:
5758
if skip_to_tensor_key is None:
5859
skip_to_tensor_key = ["text"]
5960
if input_keys is None:
@@ -62,8 +63,7 @@ def __init__(self,
6263
self.skip_to_tensor_key = skip_to_tensor_key
6364

6465
def transform(self, results: dict) -> dict:
65-
"""Method to pack the input data."""
66-
66+
"""Transform the data."""
6767
packed_results = {}
6868
for k in self.input_keys:
6969
if k in results and k not in self.skip_to_tensor_key:

0 commit comments

Comments
 (0)