
Replace module hooking with tree-defined targeting #1527


Merged: 5 commits into main on Apr 11, 2025

Conversation

@Qubitium (Collaborator) commented Apr 10, 2025

Even though the existing code does not error, we should never hook/wrap a module that will never participate in quantization. Doing so makes no sense. The full fix is much more complicated and will likely require 2-3 PRs of refactoring.

This PR adds optional pinpoint module targeting via a tree defined in base.layers_modules_tree. This will play nicely with more generic multi-modal support, since mm models essentially embed up to 3 separate models in 1; the existing static, non-tree-based config would be harder to extend to support that.

Tree syntax

```python
# Full tree of quantizable modules.
# `#` matches any number: useful for layer and MoE indexing.
# List[str]: serially linked nodes; a linear chain of modules with no divergence.
# Dict[str, List | Dict | Tuple]: a diverging node, where one path splits into multiple sub-paths.
# Tuple[str]: a leaf node; the strings name the final targeted modules.
layers_modules_tree = [
    "model",
    "layers",
    "#",
    {
        "self_attn": ("k_proj", "v_proj", "q_proj", "o_proj"),
        "mlp": ("up_proj", "gate_proj", "down_proj"),
    }
]
```
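For illustration, here is a minimal sketch of how such a tree could be expanded into fully qualified module paths. This is not the PR's actual implementation; the `expand_tree` helper and the fixed `num_layers` are assumptions for this example.

```python
# Illustrative sketch only: expand a layers_modules_tree into dotted module paths.
def expand_tree(tree, prefix="", num_layers=2):
    if not tree:
        return [prefix.rstrip(".")]
    head, rest = tree[0], tree[1:]
    if isinstance(head, str):
        # "#" matches any numeric index (layer or MoE expert index).
        keys = [str(i) for i in range(num_layers)] if head == "#" else [head]
        return [p for k in keys for p in expand_tree(rest, f"{prefix}{k}.", num_layers)]
    if isinstance(head, dict):
        # Diverging node: the path splits into one branch per key.
        paths = []
        for key, sub in head.items():
            branch = sub if isinstance(sub, list) else [sub]
            paths += expand_tree(branch + rest, f"{prefix}{key}.", num_layers)
        return paths
    if isinstance(head, tuple):
        # Leaf node: the final targeted modules.
        return [f"{prefix}{name}" for name in head]
    raise TypeError(f"unsupported node type: {type(head)}")

tree = [
    "model",
    "layers",
    "#",
    {
        "self_attn": ("k_proj", "v_proj", "q_proj", "o_proj"),
        "mlp": ("up_proj", "gate_proj", "down_proj"),
    },
]

for path in expand_tree(tree):
    print(path)  # e.g. model.layers.0.self_attn.k_proj
```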

@Qubitium (Collaborator, Author) commented:

@Cecilwang While testing conv1d/conv2d support I found deep issues/oversights in the current hook-targeting mechanism, so this PR kind of spiraled out of control. But the code relevant to you is minimal: you just need to add the layers_modules_tree definition to your model for pinpoint control of hooks, and check whether the hooked_linear.py changes are usable for your model. I could not find a model that has nn.Conv1d or nn.Conv2d and is actually quantized. Let me know if you know of any, other than Mamba, which I don't think we support yet.
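For a concrete picture, here is a hedged sketch of what attaching the tree to a model definition might look like; the base-class name `BaseGPTQModel` is an assumption for illustration, so check the repo for the actual class.

```python
# Hypothetical sketch of a model definition carrying the tree; the base-class
# name `BaseGPTQModel` is assumed here and may differ in the repo.
class LlamaGPTQ(BaseGPTQModel):
    # Only modules reachable through this tree are hooked/wrapped for
    # quantization; everything else is left untouched.
    layers_modules_tree = [
        "model",
        "layers",
        "#",
        {
            "self_attn": ("k_proj", "v_proj", "q_proj", "o_proj"),
            "mlp": ("up_proj", "gate_proj", "down_proj"),
        },
    ]
```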

@Qubitium changed the title from "Refactor replace_linear with replace_module but using tree guide" to "Refactor replace_linear with tree-defined targeting" on Apr 10, 2025
@Qubitium changed the title from "Refactor replace_linear with tree-defined targeting" to "Replace module hooking with tree-defined targeting" on Apr 10, 2025
Signed-off-by: Qubitium <[email protected]>
@Qubitium merged commit 444b277 into main on Apr 11, 2025
40 of 111 checks passed
@Qubitium deleted the fix-hooked-conv12D branch on April 11, 2025 at 01:34