Commit e24a516

Update SD3 init parameters (replacing height, width with image_shape) (#1951)
* Replace SD3 `height` and `width` with `image_shape`
* Update URI
* Revert comment
* Update SD3 handle
* Replace `height` and `width` with `image_shape`
* Update docstrings
* Fix CI
1 parent 1283e70 commit e24a516

13 files changed: +34 -46 lines
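In user-facing terms, every Stable Diffusion 3 entry point that previously took separate `height` and `width` ints now takes a single `image_shape` tuple of `(height, width, channels)`. A minimal before/after sketch based on the docstring updates in this commit:

```python
import keras_hub

# Before this commit:
# text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
#     "stable_diffusion_3_medium", height=512, width=512
# )

# After this commit: one (height, width, channels) tuple.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
)
```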

keras_hub/src/models/image_to_image.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -234,7 +234,7 @@ def normalize_images(x):
             input_is_scalar = True
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
```
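The same pattern repeats in the preprocessing helpers throughout this commit: the spatial size is read from the first two entries of the backbone's `image_shape` instead of separate `height`/`width` attributes. A standalone sketch of the resize step (the helper name is illustrative; it assumes a channels-last layout and `keras.ops`):

```python
from keras import ops

def resize_to_backbone(x, image_shape, data_format="channels_last"):
    # Resize inputs to the height and width stored in `image_shape`,
    # mirroring the `normalize_images` change above.
    return ops.image.resize(
        x,
        (image_shape[0], image_shape[1]),
        interpolation="nearest",
        data_format=data_format,
    )

x = ops.ones((2, 1024, 1024, 3))
print(resize_to_backbone(x, (512, 512, 3)).shape)  # (2, 512, 512, 3)
```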

keras_hub/src/models/inpaint.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -202,7 +202,7 @@ def normalize(x):
             input_is_scalar = True
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -240,7 +240,7 @@ def normalize(x):
             x = ops.cast(x, "float32")
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -303,7 +303,7 @@ def normalize_images(x):
             input_is_scalar = True
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -323,7 +323,7 @@ def normalize_masks(x):
             x = ops.cast(x, "float32")
             x = ops.image.resize(
                 x,
-                (self.backbone.height, self.backbone.width),
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
                 interpolation="nearest",
                 data_format=data_format,
             )
@@ -384,8 +384,8 @@ def generate(
 
         Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"`
         keys. `"images"` are reference images within a value range of
-        `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and
-        `self.backbone.width`, then encoded into latent space by the VAE
+        `[-1.0, 1.0]`, which will be resized to height and width from
+        `self.backbone.image_shape`, then encoded into latent space by the VAE
         encoder. `"masks"` are mask images with a boolean dtype, where white
         pixels are repainted while black pixels are preserved. `"prompts"` are
         strings that will be tokenized and encoded by the text encoder.
```
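Following the updated `generate()` docstring, a hedged sketch of an inpainting call (the prompt and mask region are illustrative): the `"images"` entry is resized to the height and width from `self.backbone.image_shape` before VAE encoding, and white mask pixels are repainted:

```python
import numpy as np
import keras_hub

inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
# Reference image with values in [-1.0, 1.0].
reference_image = np.random.uniform(
    -1.0, 1.0, size=(1024, 1024, 3)
).astype("float32")
# Boolean mask: True (white) pixels are repainted, False (black) preserved.
reference_mask = np.zeros((1024, 1024), dtype="bool")
reference_mask[256:768, 256:768] = True
inpaint.generate(
    {
        "images": reference_image,
        "masks": reference_mask,
        "prompts": "a red fox sitting in the grass",
    }
)
```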

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

Lines changed: 12 additions & 16 deletions
```diff
@@ -215,8 +215,8 @@ class StableDiffusion3Backbone(Backbone):
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
             `3.0`.
-        height: optional int. The output height of the image.
-        width: optional int. The output width of the image.
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(1024, 1024, 3)`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -270,23 +270,21 @@ def __init__(
         output_channels=3,
         num_train_timesteps=1000,
         shift=3.0,
-        height=None,
-        width=None,
+        image_shape=(1024, 1024, 3),
         data_format=None,
         dtype=None,
         **kwargs,
     ):
-        height = int(height or 1024)
-        width = int(width or 1024)
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(
-                "`height` and `width` must be divisible by 8. "
-                f"Received: height={height}, width={width}"
-            )
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError
-        image_shape = (height, width, int(vae.input_channels))
+        height = image_shape[0]
+        width = image_shape[1]
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                "height and width in `image_shape` must be divisible by 8. "
+                f"Received: image_shape={image_shape}"
+            )
         latent_shape = (height // 8, width // 8, int(latent_channels))
         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
@@ -452,8 +450,7 @@ def __init__(
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
         self.shift = shift
-        self.height = height
-        self.width = width
+        self.image_shape = image_shape
 
     @property
     def latent_shape(self):
@@ -585,8 +582,7 @@ def get_config(self):
                 "output_channels": self.output_channels,
                 "num_train_timesteps": self.num_train_timesteps,
                 "shift": self.shift,
-                "height": self.height,
-                "width": self.width,
+                "image_shape": self.image_shape,
             }
         )
         return config
```
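Height and width now come out of the tuple, and the divisible-by-8 constraint carries over because the latent tensor is downsampled by a factor of 8 on each spatial axis when building `latent_shape`. A standalone sketch of the moved validation (the helper name is illustrative; `latent_channels=16` here assumes SD3's 16-channel latent space):

```python
def check_image_shape(image_shape, latent_channels=16):
    # Mirrors the constructor validation above: spatial dims must be
    # divisible by 8 so the latent grid has an integer size.
    height, width = image_shape[0], image_shape[1]
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(
            "height and width in `image_shape` must be divisible by 8. "
            f"Received: image_shape={image_shape}"
        )
    return (height // 8, width // 8, latent_channels)

print(check_image_shape((1024, 1024, 3)))  # (128, 128, 16)
# check_image_shape((1020, 1024, 3)) raises ValueError: 1020 % 8 == 4
```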

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -11,7 +11,8 @@
 
 class StableDiffusion3BackboneTest(TestCase):
     def setUp(self):
-        height, width = 64, 64
+        image_shape = (64, 64, 3)
+        height, width = image_shape[0], image_shape[1]
         vae = VAEBackbone(
             [32, 32, 32, 32],
             [1, 1, 1, 1],
@@ -36,8 +37,7 @@ def setUp(self):
             "vae": vae,
             "clip_l": clip_l,
             "clip_g": clip_g,
-            "height": height,
-            "width": width,
+            "image_shape": image_shape,
         }
         self.input_data = {
             "images": ops.ones((2, height, width, 3)),
@@ -82,7 +82,6 @@ def test_all_presets(self):
                 preset=preset,
                 input_data=self.input_data,
                 init_kwargs={
-                    "height": self.init_kwargs["height"],
-                    "width": self.init_kwargs["width"],
+                    "image_shape": self.init_kwargs["image_shape"],
                 },
             )
```

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     Use `generate()` to do image generation.
     ```python
     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     image_to_image.generate(
         {
```

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
```

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ class StableDiffusion3Inpaint(Inpaint):
     reference_image = np.ones((1024, 1024, 3), dtype="float32")
     reference_mask = np.ones((1024, 1024), dtype="float32")
     inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     inpaint.generate(
         reference_image,
```

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
        )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
```

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,6 +13,6 @@
             "path": "stable_diffusion_3",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/2",
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3",
     }
 }
```

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
     Use `generate()` to do image generation.
     ```python
     text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
```

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
```

keras_hub/src/utils/preset_utils.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -563,10 +563,8 @@ def get_backbone_kwargs(self, **kwargs):
         backbone_kwargs["dtype"] = kwargs.pop("dtype", None)
 
         # Forward `height` and `width` to backbone when using `TextToImage`.
-        if "height" in kwargs:
-            backbone_kwargs["height"] = kwargs.pop("height", None)
-        if "width" in kwargs:
-            backbone_kwargs["width"] = kwargs.pop("width", None)
+        if "image_shape" in kwargs:
+            backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)
 
         return backbone_kwargs, kwargs
```
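This hook is what lets the task-level `from_preset(..., image_shape=...)` calls in the docstrings above reach the backbone: the kwarg is popped from the task kwargs and forwarded to the backbone constructor. A simplified sketch of the split (the function name is illustrative; the real method also lives on a preset loader and handles more than these two keys):

```python
def split_backbone_kwargs(**kwargs):
    # Simplified view of `get_backbone_kwargs` after this commit.
    backbone_kwargs = {"dtype": kwargs.pop("dtype", None)}
    if "image_shape" in kwargs:
        backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)
    return backbone_kwargs, kwargs

backbone_kwargs, task_kwargs = split_backbone_kwargs(
    image_shape=(512, 512, 3), dtype="bfloat16"
)
print(backbone_kwargs)  # {'dtype': 'bfloat16', 'image_shape': (512, 512, 3)}
print(task_kwargs)      # {}
```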

tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -113,8 +113,7 @@ def convert_model(preset, height, width):
         vae,
         clip_l,
         clip_g,
-        height=height,
-        width=width,
+        image_shape=(height, width, 3),
         name="stable_diffusion_3_backbone",
     )
     return backbone
@@ -532,8 +531,7 @@ def main(_):
 
     keras_preprocessor.save_to_preset(preset)
     # Set the image size to 1024, the same as in huggingface/diffusers.
-    keras_model.height = 1024
-    keras_model.width = 1024
+    keras_model.image_shape = (1024, 1024, 3)
     keras_model.save_to_preset(preset)
     print(f"🏁 Preset saved to ./{preset}.")
```
