Skip to content

Commit ca9cb1c

Browse files
authored
Flux VAE broke for float16; force bfloat16 or float32 where compatible (#7213)
## Summary The Flux VAE, like many VAEs, is broken if run using float16 inputs, returning black images due to NaNs. This fixes the issue by forcing the VAE to run in bfloat16 or float32 where compatible. ## Related Issues / Discussions Fix for issue #7208 ## QA Instructions Tested on macOS: the VAE works with float16 set in invoke.yaml and when left at the default. I also briefly forced it down the float32 route to check that too. Needs testing on CUDA / ROCm. ## Merge Plan It should be a straightforward merge.
2 parents fb19621 + b89caa0 commit ca9cb1c

File tree

4 files changed

+14
-5
lines changed

4 files changed

+14
-5
lines changed

invokeai/app/invocations/flux_vae_decode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
4141
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
4242
with vae_info as vae:
4343
assert isinstance(vae, AutoEncoder)
44-
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
44+
vae_dtype = next(iter(vae.parameters())).dtype
45+
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
4546
img = vae.decode(latents)
4647

4748
img = img.clamp(-1, 1)

invokeai/app/invocations/flux_vae_encode.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tenso
4444
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
4545
with vae_info as vae:
4646
assert isinstance(vae, AutoEncoder)
47-
image_tensor = image_tensor.to(
48-
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
49-
)
47+
vae_dtype = next(iter(vae.parameters())).dtype
48+
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
5049
latents = vae.encode(image_tensor, sample=True, generator=generator)
5150
return latents
5251

invokeai/backend/model_manager/load/load_default.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
self._logger = logger
3636
self._ram_cache = ram_cache
3737
self._torch_dtype = TorchDevice.choose_torch_dtype()
38+
self._torch_device = TorchDevice.choose_torch_device()
3839

3940
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
4041
"""

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,15 @@ def _load_model(
8484
model = AutoEncoder(ae_params[config.config_path])
8585
sd = load_file(model_path)
8686
model.load_state_dict(sd, assign=True)
87-
model.to(dtype=self._torch_dtype)
87+
# VAE is broken in float16, which mps defaults to
88+
if self._torch_dtype == torch.float16:
89+
try:
90+
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
91+
except TypeError:
92+
vae_dtype = torch.float32
93+
else:
94+
vae_dtype = self._torch_dtype
95+
model.to(vae_dtype)
8896

8997
return model
9098

0 commit comments

Comments
 (0)