Skip to content

[Feature] Update lint #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: 3.8
python-version: '3.11'
- name: Install pre-commit hook
run: |
pip install pre-commit
Expand All @@ -30,6 +29,9 @@ jobs:

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.11'
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install system dependencies
Expand Down
12 changes: 9 additions & 3 deletions diffengine/datasets/hf_controlnet_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class HFControlNetDataset(Dataset):
"""Dataset for huggingface datasets.

Args:
----
dataset (str): Dataset name or path to dataset.
image_column (str): Image column name. Defaults to 'image'.
condition_column (str): Condition column name for ControlNet.
Expand All @@ -39,7 +40,7 @@ def __init__(self,
caption_column: str = "text",
csv: str = "metadata.csv",
pipeline: Sequence = (),
cache_dir: str | None = None):
cache_dir: str | None = None) -> None:
self.dataset_name = dataset
if Path(dataset).exists():
# load local folder
Expand All @@ -58,19 +59,24 @@ def __init__(self,
def __len__(self) -> int:
"""Get the length of dataset.

Returns:
Returns
-------
int: The length of filtered dataset.
"""
return len(self.dataset)

def __getitem__(self, idx: int) -> dict:
"""Get the idx-th image and data information of dataset after
"""Get item.

Get the idx-th image and data information of dataset after
``self.pipeline``.

Args:
----
idx (int): The index of self.data_list.

Returns:
-------
dict: The idx-th image and data information of dataset after
``self.pipeline``.
"""
Expand Down
19 changes: 15 additions & 4 deletions diffengine/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class HFDataset(Dataset):
"""Dataset for huggingface datasets.

Args:
----
dataset (str): Dataset name or path to dataset.
image_column (str): Image column name. Defaults to 'image'.
caption_column (str): Caption column name. Defaults to 'text'.
Expand All @@ -48,7 +49,7 @@ def __init__(self,
caption_column: str = "text",
csv: str = "metadata.csv",
pipeline: Sequence = (),
cache_dir: str | None = None):
cache_dir: str | None = None) -> None:
self.dataset_name = dataset
if Path(dataset).exists():
# load local folder
Expand All @@ -66,19 +67,24 @@ def __init__(self,
def __len__(self) -> int:
"""Get the length of dataset.

Returns:
Returns
-------
int: The length of filtered dataset.
"""
return len(self.dataset)

def __getitem__(self, idx: int) -> dict:
"""Get the idx-th image and data information of dataset after
"""Get item.

Get the idx-th image and data information of dataset after
``self.pipeline``.

Args:
----
idx (int): The index of self.data_list.

Returns:
-------
dict: The idx-th image and data information of dataset after
``self.pipeline``.
"""
Expand Down Expand Up @@ -109,6 +115,7 @@ class HFDatasetPreComputeEmbs(HFDataset):
1. pre-compute Text Encoder embeddings to save memory.

Args:
----
model (str): pretrained model name of stable diffusion xl.
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
text_hasher (str): Text embeddings hasher name. Defaults to 'text'.
Expand Down Expand Up @@ -158,13 +165,17 @@ def __init__(self,
torch.cuda.empty_cache()

def __getitem__(self, idx: int) -> dict:
"""Get the idx-th image and data information of dataset after
"""Get item.

Get the idx-th image and data information of dataset after
``self.train_transforms``.

Args:
----
idx (int): The index of self.data_list.

Returns:
-------
dict: The idx-th image and data information of dataset after
``self.train_transforms``.
"""
Expand Down
16 changes: 12 additions & 4 deletions diffengine/datasets/hf_dreambooth_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class HFDreamBoothDataset(Dataset):
"""DreamBooth Dataset for huggingface datasets.

Args:
----
dataset (str): Dataset name.
instance_prompt (str):
The prompt with identifier specifying the instance.
Expand All @@ -47,6 +48,7 @@ class as provided instance images. Defaults to None.
cache_dir (str, optional): The directory where the downloaded datasets
will be stored. Defaults to None.
"""

default_class_image_config: dict = {
"model": "runwayml/stable-diffusion-v1-5",
"data_dir": "work_dirs/class_image",
Expand All @@ -63,7 +65,7 @@ def __init__(self,
class_image_config: dict | None = None,
class_prompt: str | None = None,
pipeline: Sequence = (),
cache_dir: str | None = None):
cache_dir: str | None = None) -> None:

if class_image_config is None:
class_image_config = {
Expand Down Expand Up @@ -108,7 +110,8 @@ def __init__(self,
f"class_image_config needs a dict with keys {essential_keys}"
self.generate_class_image(class_image_config)

def generate_class_image(self, class_image_config):
def generate_class_image(self, class_image_config) -> None:
"""Generate class images for prior preservation loss."""
class_images_dir = Path(class_image_config["data_dir"])
if class_images_dir.exists(
) and class_image_config["recreate_class_images"]:
Expand Down Expand Up @@ -145,19 +148,24 @@ def generate_class_image(self, class_image_config):
def __len__(self) -> int:
"""Get the length of dataset.

Returns:
Returns
-------
int: The length of filtered dataset.
"""
return len(self.dataset)

def __getitem__(self, idx: int) -> dict:
"""Get the idx-th image and data information of dataset after
"""Get item.

Get the idx-th image and data information of dataset after
``self.pipeline``.

Args:
----
idx (int): The index of self.data_list.

Returns:
-------
dict: The idx-th image and data information of dataset after
``self.pipeline``.
"""
Expand Down
10 changes: 8 additions & 2 deletions diffengine/datasets/hf_esd_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@

@DATASETS.register_module()
class HFESDDatasetPreComputeEmbs(Dataset):
"""Dataset of huggingface datasets for Erasing Concepts from Diffusion
"""Huggingface Erasing Concepts from Diffusion Models Dataset.

Dataset of huggingface datasets for Erasing Concepts from Diffusion
Models.

Args:
----
forget_caption (str): The caption used to forget.
model (str): pretrained model name of stable diffusion xl.
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
Expand Down Expand Up @@ -68,7 +71,8 @@ def __init__(self,
def __len__(self) -> int:
"""Get the length of dataset.

Returns:
Returns
-------
int: The length of filtered dataset.
"""
return 1
Expand All @@ -77,9 +81,11 @@ def __getitem__(self, idx: int) -> dict:
"""Get the dataset after ``self.pipeline`.

Args:
----
idx (int): The index.

Returns:
-------
dict: The idx-th data information of dataset after
``self.pipeline``.
"""
Expand Down
7 changes: 6 additions & 1 deletion diffengine/datasets/samplers/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio into a
"""Aspect ratio batch sampler.

A sampler wrapper for grouping images with similar aspect ratio into a
same batch.

Args:
----
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
Expand Down Expand Up @@ -46,6 +49,7 @@ def __init__(self,
self.bucket_ids.append(bucket_id)

def __iter__(self) -> Generator:
"""Get the iterator of the sampler."""
for idx in self.sampler:
bucket_id = self.bucket_ids[idx]
if bucket_id not in self._aspect_ratio_buckets:
Expand All @@ -66,6 +70,7 @@ def __iter__(self) -> Generator:
self._aspect_ratio_buckets = {}

def __len__(self) -> int:
"""Get the length of the sampler."""
total_sample = 0
_, counts = np.unique(self.bucket_ids, return_counts=True)
for c in counts:
Expand Down
8 changes: 6 additions & 2 deletions diffengine/datasets/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ class BaseTransform(metaclass=ABCMeta):
"""Base class for all transformations."""

def __call__(self, results: dict) -> dict | tuple[list, list] | None:

"""Call function to transform data."""
return self.transform(results)

@abstractmethod
def transform(self, results: dict) -> dict | tuple[list, list] | None:
"""The transform function. All subclass of BaseTransform should
"""Transform the data.

The transform function. All subclass of BaseTransform should
override this method.

This function takes the result dict as the input, and can add new
Expand All @@ -21,8 +23,10 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
transforms into a pipeline.

Args:
----
results (dict): The result dict.

Returns:
-------
dict: The result dict.
"""
10 changes: 7 additions & 3 deletions diffengine/datasets/transforms/dump_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,31 @@ class DumpImage:
"""Dump the image processed by the pipeline.

Args:
----
max_imgs (int): Maximum value of output.
dump_dir (str): Dump output directory.
"""

def __init__(self, max_imgs: int, dump_dir: str):
def __init__(self, max_imgs: int, dump_dir: str) -> None:
self.max_imgs = max_imgs
self.dump_dir = dump_dir
mmengine.mkdir_or_exist(self.dump_dir)
self.num_dumped_imgs = Value("i", 0)

def __call__(self, results):
def __call__(self, results) -> dict:
"""Dump the input image to the specified directory.

No changes will be
made.

Args:
----
results (dict): Result dict from loading pipeline.

Returns:
-------
results (dict): Result dict from loading pipeline. (same as input)
"""

enable_dump = False
with self.num_dumped_imgs.get_lock():
if self.num_dumped_imgs.value < self.max_imgs:
Expand Down
8 changes: 4 additions & 4 deletions diffengine/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from diffengine.registry import TRANSFORMS


def to_tensor(data):
def to_tensor(data) -> torch.Tensor:
"""Convert objects of various python types to :obj:`torch.Tensor`.

Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
Expand Down Expand Up @@ -45,6 +45,7 @@ class PackInputs(BaseTransform):
All other keys in the dict.

Args:
----
input_keys (List[str]): The key of element to feed into the model
forwarding. Defaults to ['img', 'text'].
skip_to_tensor_key (List[str]): The key of element to skip to_tensor.
Expand All @@ -53,7 +54,7 @@ class PackInputs(BaseTransform):

def __init__(self,
input_keys: list[str] | None = None,
skip_to_tensor_key: list[str] | None = None):
skip_to_tensor_key: list[str] | None = None) -> None:
if skip_to_tensor_key is None:
skip_to_tensor_key = ["text"]
if input_keys is None:
Expand All @@ -62,8 +63,7 @@ def __init__(self,
self.skip_to_tensor_key = skip_to_tensor_key

def transform(self, results: dict) -> dict:
"""Method to pack the input data."""

"""Transform the data."""
packed_results = {}
for k in self.input_keys:
if k in results and k not in self.skip_to_tensor_key:
Expand Down
Loading