
Add check for _fsdp_wrap for FSDP2 #3826


Open · wants to merge 16 commits into main

Conversation

rithwik-db (Contributor) commented on Apr 17, 2025

What does this PR do?

Modified existing code to support _fsdp_wrap and _fsdp_wrap_fn, similar to how FSDP1 handles them (with the slight caveat that we allow None as a valid value, meaning skip the current module but still check its descendants).

Reworked module wrapping into a recursive pass: a pre-order check legalizes and validates params (making sure nothing that shouldn't be weight tied is actually tied), and then a post-order pass applies fully_shard to the valid parameters.
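
To make the shape of that recursion concrete, here is a minimal sketch (assumptions, not the PR's actual code: the function name shard_model_recursive and the validate / should_wrap callables are hypothetical, and the fully_shard import path assumes a recent PyTorch release):

from typing import Callable, Optional

import torch.nn as nn
# Recent PyTorch releases export FSDP2's fully_shard here; older ones keep it
# under torch.distributed._composable.fsdp.
from torch.distributed.fsdp import fully_shard


def shard_model_recursive(
    module: nn.Module,
    validate: Callable[[nn.Module], None],
    should_wrap: Callable[[nn.Module], Optional[bool]],
    **fsdp_kwargs,
) -> None:
    """Illustrative only: pre-order validation, post-order fully_shard."""
    # Pre-order: legalize/validate this module's params before touching any
    # descendants, so unintended weight tying is caught before sharding
    # rearranges parameter storage.
    validate(module)

    # Recurse into children first ...
    for child in module.children():
        shard_model_recursive(child, validate, should_wrap, **fsdp_kwargs)

    # ... then shard this module last (post-order). should_wrap encodes the
    # _fsdp_wrap / _fsdp_wrap_fn policy; None means "skip this module, but its
    # descendants were still considered above".
    if should_wrap(module) is True:
        fully_shard(module, **fsdp_kwargs)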

rithwik-db requested a review from bowenyang008 on Apr 17, 2025 23:56
rithwik-db changed the title from "[WIP] Add check for _fsdp_wrap for FSDP2" to "Add check for _fsdp_wrap for FSDP2" on Apr 23, 2025
rithwik-db requested a review from dakinggg on Apr 23, 2025 03:29
# This should wrap the model fine; we set m1._fsdp_wrap to False so that general
# FSDP wrapping of the root module does not happen.
m1 = DeepNestedModel()
m1._fsdp_wrap = False  # type: ignore
Contributor:

Will come back to this once I've looked at the new implementation (just looking at the tests right now), but I thought the previous implementation was such that if you set wrap to False on the top module, it would get skipped. Probably I'm just misunderstanding something while reading this test.

Contributor Author:

Yeah, this is what FSDP1 does, but @bowenyang008 mentioned that we should go for a policy where we still recurse even if the root module has this attribute set to False. That makes sense in situations where the user does not want all tensors to be DTensors but rather a mix of tensors and DTensors (in those situations, we can't fully_shard(root_module)). To address this, we added the code:

elif current_module == parent_model:
    # Unless the user specifically sets the _fsdp_wrap attribute to False for the parent model,
    # we default to wrapping the parent model.
    ret = True

So unless the user specifically sets the attribute _fsdp_wrap=False on the root module, we still end up wrapping the root model at the end.
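
For readers skimming the thread, a hypothetical sketch of the decision helper this excerpt sits in (the name _should_wrap, its signature, and the None-means-defer convention are illustrative, not the PR's actual implementation):

from typing import Optional

import torch.nn as nn


def _should_wrap(current_module: nn.Module, parent_model: nn.Module) -> Optional[bool]:
    ret: Optional[bool] = None
    if hasattr(current_module, '_fsdp_wrap'):
        # An explicit _fsdp_wrap attribute always wins, including on the root module.
        ret = bool(current_module._fsdp_wrap)
    elif current_module == parent_model:
        # Unless the user specifically sets _fsdp_wrap to False on the parent
        # model, we default to wrapping the parent model.
        ret = True
    # None: no opinion here; keep recursing and let descendants decide.
    return ret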

Contributor:

Yup, so if the root module and all submodules have _fsdp_wrap = False, the behavior should be the same; but if the root module has _fsdp_wrap = True while submodules have _fsdp_wrap = False (a combination that should never have occurred before), we can respect it with the FSDP2 wrapper.

@@ -141,6 +141,10 @@ def __init__(self, num_features: int, device: Union[str, torch.device], num_clas
        net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2)
        super().__init__(num_classes=num_classes, module=net)

    def add_fsdp_wrap_attribute_to_children(self):
        for child in self.module.children():
            child._fsdp_wrap = True  # type: ignore
Contributor:

Do we have a test that also covers fsdp_wrap_fn? Actually, this might be a question for @dakinggg: if most of the code base just relies on fsdp_wrap_fn, should we just remove support for _fsdp_wrap?

Contributor:

We use both in LLM Foundry (although of course we could switch to only using one or the other).

Contributor:

@rithwik-db I just discussed this with @dakinggg: we can keep supporting both for the moment, but let's raise a DeprecationWarning for use of _fsdp_wrap in generate_default_policy, so that we only support fsdp_wrap_fn in the future.
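
A minimal sketch of what such a deprecation warning could look like (the helper name and its placement inside generate_default_policy are assumptions, not the agreed change):

import warnings

import torch.nn as nn


def _warn_if_fsdp_wrap_attribute_used(model: nn.Module) -> None:
    # Hypothetical helper: emit a DeprecationWarning if any submodule still
    # relies on the _fsdp_wrap attribute instead of fsdp_wrap_fn.
    if any(hasattr(m, '_fsdp_wrap') for m in model.modules()):
        warnings.warn(
            '_fsdp_wrap is deprecated; use fsdp_wrap_fn instead.',
            DeprecationWarning,
            stacklevel=2,
        )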

Comment on lines +120 to +122
NOTE: We take the recursive approach since .named_parameters() has a weird behavior where if you do
m1.m2.m3.weight = m1.m2.m4.weight and then call m1.named_parameters(), it will only return the FQN for m1.m2.m3.weight
but not m1.m2.m4.weight.
Contributor:

This is expected, as torch only returns params without duplication, so maybe we can simply check that set(m.named_parameters()) is the same before and after the context?

Contributor:

Your current implementation is also valid, so either way is fine.
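
For context, a small standalone illustration of the deduplication behavior and the suggested set-based check (toy modules, not code from the PR):

import torch.nn as nn

# Tie two Linear weights by hand.
m3 = nn.Linear(4, 4, bias=False)
m4 = nn.Linear(4, 4, bias=False)
m4.weight = m3.weight
root = nn.Sequential(m3, m4)

# Default (remove_duplicate=True): the tied parameter appears under one FQN only.
print([name for name, _ in root.named_parameters()])
# -> ['0.weight']

# remove_duplicate=False surfaces every FQN, including the tied duplicate.
print([name for name, _ in root.named_parameters(remove_duplicate=False)])
# -> ['0.weight', '1.weight']

# The suggested check: the set of named parameters should be unchanged before
# and after the (sharding) context under test.
before = {name for name, _ in root.named_parameters()}
# ... enter and exit the context here ...
after = {name for name, _ in root.named_parameters()}
assert before == after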

Comment on lines +16 to +18
# FSDP2 Weight Tying Functions
# TODO: These functions are all relatively similar to each other, we should consider
# refactoring them in the future to be simpler.
Contributor:

Agree; maybe it's also worth putting the weight-tying-related functions into a new file in a follow-up PR. We may need a dir for FSDP utils instead.

Comment on lines +143 to +144
tying_groups = [fqns for fqns in param_object_to_fqns.values() if len(fqns) > 1]
return tying_groups
Contributor:

We can probably just compare all groups, e.g., that way we also avoid the case where a param was not tied before but became tied afterwards for some reason.
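
A sketch of the comparison being suggested, assuming a hypothetical helper that maps each parameter object to all of its FQNs (the PR's actual param_object_to_fqns may be built differently):

from collections import defaultdict

import torch.nn as nn


def _param_object_to_fqns(model: nn.Module) -> dict[int, list[str]]:
    # Group every FQN by the identity of its parameter object; tied params end
    # up sharing one group. remove_duplicate=False keeps every FQN visible.
    groups: dict[int, list[str]] = defaultdict(list)
    for fqn, param in model.named_parameters(remove_duplicate=False):
        groups[id(param)].append(fqn)
    return dict(groups)


def tying_groups_unchanged(before: dict[int, list[str]], after: dict[int, list[str]]) -> bool:
    # Compare *all* FQN groups, not just those with len > 1, so a param that was
    # untied before the context but became tied afterwards is also caught.
    return sorted(before.values()) == sorted(after.values())


# Usage: snapshot = _param_object_to_fqns(model); run the context under test;
# then assert tying_groups_unchanged(snapshot, _param_object_to_fqns(model)).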

Comment on lines +325 to +329
def check_reshard_after_forward(module: nn.Module):
    fsdp_state = module._get_fsdp_state()  # type: ignore
    param_group = fsdp_state._fsdp_param_group
    assert param_group.post_forward_mesh_info is None, \
        f'reshard_after_forward should be False, but got {param_group.post_forward_mesh_info}'
Contributor:

Nice test, though I am just wondering if we can directly check that the forward hooks are empty (assuming they are empty if reshard_after_forward is False), so we don't depend on these internal states prefixed with _xxx, which are subject to change.
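
A sketch of that alternative, under the stated assumption that FSDP2 registers no post-forward hook when reshard_after_forward is False (if that assumption does not hold, the original check is safer); note that _forward_hooks is still a private attribute, just a generic nn.Module one rather than FSDP2-internal state:

import torch.nn as nn


def check_no_reshard_forward_hooks(module: nn.Module) -> None:
    # nn.Module keeps registered forward hooks in this (private) dict.
    hooks = list(module._forward_hooks.values())
    assert len(hooks) == 0, (
        f'expected no post-forward hooks when reshard_after_forward=False, got {hooks}'
    )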

m1 = DeepNestedModel()

def wrap_fn(module: nn.Module):
    return {'tacos': False}
Contributor:

❤️ 🌮

bowenyang008 (Contributor) left a comment:

Amazing work, thanks @rithwik-db!
