Add sana sprint #15

Merged · 6 commits · Jun 4, 2025

164 changes: 164 additions & 0 deletions Sana/README_sCM.md
@@ -0,0 +1,164 @@
# SCM (Continuous-Time Consistency Model) Sampler for ComfyUI

A custom sampler implementation for SCM models, modeled on ComfyUI's existing LCM sampler architecture.

## Overview

This implementation provides:
- **SCM Sampling Algorithm**: Uses trigflow timesteps (π/2 ≈ 1.57080 down to 0)
- **CFG Scale Support**: Direct CFG scale passing to model forward method
- **SanaMSCM Integration**: Optimized for SanaMSCM models
- **ComfyUI Compatible**: Seamless integration with existing workflows

## Features

- ✅ Custom SCM sampling with trigflow parameterization
- ✅ Bypasses ComfyUI's standard CFG mixing
- ✅ Direct CFG scale injection to model
- ✅ Multi-step sampling support (2-step and N-step)
- ✅ Compatible with SanaMSCM models

## Installation

1. Place files in your ComfyUI custom nodes directory:
```
ComfyUI/custom_nodes/ComfyUI_ExtraModels/Sana/
├── scm_sampler.py
├── models/sana_multi_scale.py
└── loader.py
```

2. Restart ComfyUI

3. The "scm" sampler will be automatically registered on startup (see the sketch below)
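
As a rough, hedged sketch, the registration in step 3 can be done by exposing the sampling function where ComfyUI looks samplers up by name. The registry attributes used below (`comfy.samplers.SAMPLER_NAMES` and the `sample_<name>` lookup in `comfy.k_diffusion.sampling`) are assumptions about ComfyUI internals, not guaranteed API, and `sample_scm` stands in for the actual sampling loop in `scm_sampler.py`:

```python
import comfy.samplers
import comfy.k_diffusion.sampling as k_diffusion_sampling

def register_scm_sampler():
    # Expose the sampling loop where ComfyUI resolves samplers by name ...
    k_diffusion_sampling.sample_scm = sample_scm  # sample_scm: the sampling loop defined elsewhere in scm_sampler.py (assumed name)
    # ... and advertise "scm" in KSampler's sampler_name dropdown.
    if "scm" not in comfy.samplers.SAMPLER_NAMES:
        comfy.samplers.SAMPLER_NAMES.append("scm")
```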

## Usage

### Basic Workflow

```
SanaCheckpointLoader → ScmModelSampling → KSampler → VAEDecode → Image
```

### Node Configuration

#### 1. ScmModelSampling Node
- **Input**: Model from SanaCheckpointLoader
- **cfg_scale**: CFG scale value (default: 4.5, range: 0-30)
- **zsnr**: Zero SNR option (default: False)
- **Output**: Patched model with SCM sampling (see the sketch below)
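
Below is a minimal sketch of a node exposing these inputs. It is illustrative only: the class name, option ranges, and the `zsnr` handling are assumptions, and the shipped `ScmModelSampling` node may differ.

```python
class ScmModelSamplingSketch:
    """Illustrative ComfyUI node skeleton, not the shipped implementation."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "cfg_scale": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.1}),
                "zsnr": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "ExtraModels/Sana"

    def patch(self, model, cfg_scale, zsnr):
        # Clone so the upstream model stays untouched, then stash cfg_scale where the
        # SCM sampler and the model's forward pass can read it (see "CFG Handling").
        m = model.clone()
        opts = m.model_options.setdefault("transformer_options", {})
        opts["cfg_scale"] = cfg_scale
        opts["zsnr"] = zsnr
        return (m,)
```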

#### 2. KSampler Settings
- **sampler_name**: Select "scm"
- **cfg**: Set to **1.0** (important: disables KSampler's CFG)
- **steps**: Number of sampling steps
- **scheduler**: Any scheduler (normal/karras/etc.)

### Example Workflow
```
1. Load model with SanaCheckpointLoader
2. Connect to ScmModelSampling (cfg_scale=7.0)
3. Connect to KSampler:
- sampler_name: "scm"
- cfg: 1.0
- steps: 4-20
4. Connect to VAEDecode for final image
```

## Technical Details

### SCM Algorithm
- **Timestep Range**: 1.57080 (π/2) → 0 (trigflow parameterization)
- **Denoising Formula**: `pred_x0 = cos(s) * x - sin(s) * model_output`
- **Special Cases** (see the sketch below):
  - 2 steps: [1.57080, 1.3, 0]
  - N steps: Linear interpolation between 1.57080 and 0
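
A small, self-contained sketch of the schedule and denoising step above (illustrative only; `scm_timesteps` and `scm_denoise` are hypothetical helper names, not the shipped sampler):

```python
import math

import torch

def scm_timesteps(steps: int) -> list[float]:
    # Trigflow runs from t_max = pi/2 (~1.57080) down to 0.
    t_max = math.pi / 2
    if steps == 2:
        # Special-cased 2-step schedule
        return [t_max, 1.3, 0.0]
    # N-step case: linear interpolation from t_max down to 0 (steps + 1 boundaries)
    return [t_max * (1.0 - i / steps) for i in range(steps + 1)]

def scm_denoise(x: torch.Tensor, model_output: torch.Tensor, s: float) -> torch.Tensor:
    # pred_x0 = cos(s) * x - sin(s) * model_output
    return math.cos(s) * x - math.sin(s) * model_output
```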

### CFG Handling
- CFG scale bypasses ComfyUI's standard mixing
- Direct injection via `transformer_options`
- Model handles CFG internally (no output mixing); see the illustration below
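
For intuition on why the KSampler `cfg` must stay at 1.0: ComfyUI's standard guidance mix is `uncond + cfg * (cond - uncond)`, which collapses to the conditioned output when `cfg` is 1.0, so the only guidance actually applied is the `cfg_scale` the model consumes internally. A tiny generic illustration (not project code):

```python
import torch

def cfg_mix(cond: torch.Tensor, uncond: torch.Tensor, cfg: float) -> torch.Tensor:
    # Standard classifier-free guidance combination.
    return uncond + cfg * (cond - uncond)

cond, uncond = torch.randn(1, 4), torch.randn(1, 4)
# With cfg == 1.0 the mix reduces to the conditioned prediction unchanged.
assert torch.allclose(cfg_mix(cond, uncond, 1.0), cond)
```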

### Model Integration
The SCM sampler works with models that implement:
```python
def forward_raw(self, x, timestep, y, **kwargs):
# Get CFG scale from transformer_options
cfg_scale = kwargs.get("transformer_options", {}).get("cfg_scale", 4.5)
# Your model logic with cfg_scale
return output
```
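
In the SanaMSCM configs added by this PR, `cfg_embed` is enabled and `Sana/models/sana.py` gains a dedicated `cfg_embedder` (a `TimestepEmbedder`), so the guidance strength is embedded into the conditioning of a single forward pass rather than mixed across two model outputs.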

## Supported Models

- **SanaMSCM**: Primary target model
- **Custom SCM Models**: Any model implementing SCM forward interface

## Configuration Options

| Parameter | Description | Default | Range |
|-----------|-------------|---------|-------|
| cfg_scale | CFG guidance scale | 4.5 | 0.0-30.0 |
| zsnr | Zero signal-to-noise ratio | False | Boolean |
| steps | Sampling steps | - | 2-50 |

## Performance Notes

- **Recommended Steps**: 4-20 for most cases
- **2-Step Mode**: Uses a specially optimized timestep schedule
- **Memory Efficient**: No additional CFG computation overhead
- **Speed**: Faster than standard diffusion sampling because far fewer steps are needed

## Troubleshooting

### Common Issues

**Problem**: CFG scale not working
- **Solution**: Ensure KSampler cfg=1.0 and model supports transformer_options

**Problem**: "scm" sampler not found
- **Solution**: Restart ComfyUI after installing files

**Problem**: Poor output quality
- **Solution**: Adjust the cfg_scale value and the number of sampling steps

### Debug Information

Enable debug prints by modifying the sampler:
```python
print(f"SCM timesteps: {timesteps}")
print(f"CFG scale: {cfg_scale}")
```

## Architecture

```
ScmModelSampling (Node)
↓ (patches model)
Model with SCM sampling
↓ (cfg_scale in model_options)
SCM Sampler Function
↓ (transformer_options)
SanaMSCM Forward Method
↓ (uses cfg_scale)
Model Output
```

## Limitations

- Requires compatible SCM models
- CFG scale must be handled by model internally
- Limited to trigflow parameterization
- KSampler CFG must be disabled (set to 1.0)

## Contributing

Feel free to submit issues and improvements for:
- Additional SCM model support
- Performance optimizations
- New sampling schedules
- Bug fixes

## License

Same as ComfyUI license terms.
6 changes: 6 additions & 0 deletions Sana/__init__.py
@@ -0,0 +1,6 @@
try:
from .scm_sampler import register_scm_sampler
register_scm_sampler()
print("SCM sampler registered successfully")
except Exception as e:
print(f"Failed to register SCM sampler: {e}")
52 changes: 52 additions & 0 deletions Sana/conf.py
@@ -147,6 +147,58 @@
},
"sampling_settings" : sampling_settings,
},
"SanaSprint_1600M_P1_D20": {
"target": "SanaMSCM",
"unet_config": {
"in_channels": 32,
"depth": 20,
"hidden_size": 2240,
"patch_size": 1,
"num_heads": 20,
"linear_head_dim": 32,
"model_max_length": 300,
"y_norm": True,
"qk_norm": True,
"cross_norm": True,
"attn_type": "linear",
"ffn_type": "glumbconv",
"mlp_ratio": 2.5,
"mlp_acts": ["silu", "silu", None],
"use_pe": False,
"pred_sigma": False,
"learn_sigma": False,
"fp32_attention": True,
"cross_attn_type": "vanilla",
"cfg_embed": True,
"cfg_embed_scale": 0.1,
"timestep_norm_scale_factor": 1000,
},
},
"SanaSprint_600M_P1_D28": {
"target": "SanaMSCM",
"unet_config": {
"in_channels": 32,
"depth": 28,
"hidden_size": 1152,
"patch_size": 1,
"num_heads": 16,
"linear_head_dim": 32,
"model_max_length": 300,
"y_norm": True,
"attn_type": "linear",
"ffn_type": "glumbconv",
"mlp_ratio": 2.5,
"mlp_acts": ["silu", "silu", None],
"use_pe": False,
"pred_sigma": False,
"learn_sigma": False,
"fp32_attention": True,
"cross_attn_type": "vanilla",
"cfg_embed": True,
"cfg_embed_scale": 0.1,
"timestep_norm_scale_factor": 1000,
},
},
}

sana_res = {
3 changes: 3 additions & 0 deletions Sana/loader.py
@@ -84,6 +84,9 @@ def load_sana(model_path, model_conf, dtype):
if model_conf.model_target == "SanaMS":
from .models.sana_multi_scale import SanaMS
model.diffusion_model = SanaMS(**model_conf.unet_config)
elif model_conf.model_target == "SanaMSCM":
from .models.sana_multi_scale import SanaMSCM
model.diffusion_model = SanaMSCM(**model_conf.unet_config)
else:
raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")

6 changes: 6 additions & 0 deletions Sana/models/sana.py
@@ -170,6 +170,8 @@ def __init__(
patch_embed_kernel=None,
mlp_acts=("silu", "silu", None),
linear_head_dim=32,
cfg_embed=False,
timestep_norm_scale_factor=1.0,
**kwargs,
):
super().__init__()
@@ -184,12 +186,16 @@ def __init__(
self.y_norm = y_norm
self.model_max_length = model_max_length
self.fp32_attention = kwargs.get("use_fp32_attention", False)
self.timestep_norm_scale_factor = timestep_norm_scale_factor

kernel_size = patch_embed_kernel or patch_size
self.x_embedder = PatchEmbed(
input_size, patch_size, in_channels, hidden_size, kernel_size=kernel_size, bias=True
)
self.t_embedder = TimestepEmbedder(hidden_size)
self.cfg_embedder = None
if cfg_embed:
self.cfg_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size // self.patch_size
# Will use fixed sin-cos embedding:
79 changes: 79 additions & 0 deletions Sana/models/sana_blocks.py
@@ -135,6 +135,85 @@ def forward(self, x, cond, mask=None):
return x


class MultiHeadCrossVallinaAttention(MultiHeadCrossAttention):
@staticmethod
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
) -> torch.Tensor:
B, H, L, S = *query.size()[:-1], key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value

def forward(self, x, cond, mask=None):
# query: img tokens; key/value: condition; mask: if padding tokens
B, N, C = x.shape

q = self.q_linear(x)
kv = self.kv_linear(cond).view(B, -1, 2, C)
k, v = kv.unbind(2)
q = self.q_norm(q).view(B, -1, self.num_heads, self.head_dim)
k = self.k_norm(k).view(B, -1, self.num_heads, self.head_dim)
v = v.view(B, -1, self.num_heads, self.head_dim)

# Cast for sCM
dtype = q.dtype
q, k, v = q.float(), k.float(), v.float()

# vanilla attention
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

attn_mask = None
if mask is not None and len(mask) > 1:
# Create equivalent of xformer diagonal block mask, still only correct for square masks
# But depth doesn't matter as tensors can expand in that dimension
attn_mask_template = torch.ones(
[q.shape[2] // B, mask[0]],
dtype=torch.bool,
device=q.device
)
attn_mask = torch.block_diag(attn_mask_template)

# create a mask on the diagonal for each mask in the batch
for _ in range(B - 1):
attn_mask = torch.block_diag(attn_mask, attn_mask_template)
elif mask is not None and len(mask) == 1:
# Handle single mask length case - all batches use the same mask length
attn_mask_template = torch.ones(
[q.shape[2] // B, mask[0]],
dtype=torch.bool,
device=q.device
)
attn_mask = torch.block_diag(attn_mask_template)

# create a mask on the diagonal for each mask in the batch (all same length)
for _ in range(B - 1):
attn_mask = torch.block_diag(attn_mask, attn_mask_template)
elif mask is not None and mask.ndim == 2:
# Handle 2D mask case (original logic)
mask = (1 - mask.to(q.dtype)) * -10000.0
attn_mask = mask[:, None, None].repeat(1, self.num_heads, 1, 1)

x = self.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.to(dtype)
x = x.transpose(1, 2).contiguous()

x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)

return x

class LiteLA(Attention_):
r"""Lightweight linear attention"""
