Commit 2939140

Author: Ubuntu
Commit message: wip

Signed-off-by: Ubuntu <[email protected]>

1 parent 1934e90 · commit 2939140

File tree: 1 file changed, +39 −18 lines

tests/unittest/_torch/modules/tests_lora_modules/test_lora_mlp_pytorch_flow_vs_torch.py (+39 −18)
@@ -62,11 +62,14 @@ def setUpClass(cls):
         cls.device = torch.device('cuda')
 
     def _create_mlp_inputs(self):
+        # Initialize in FP32 first for better numerical stability
         hidden_states = torch.rand(
             size=[self.batch_size, self.seq_len, self.hidden_size],
-            dtype=self.torch_dtype,
-            device='cuda')
+            dtype=torch.float32,
+            device='cuda') * 0.01  # Use smaller scale
 
+        # Convert to target dtype after initialization
+        hidden_states = hidden_states.to(self.torch_dtype)
         return hidden_states
 
     def _create_lora_params(self):
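Note on the FP32-first pattern in the hunk above: the random draw and the scaling happen in float32, and the tensor is cast to the target dtype only once at the end, which avoids doing that arithmetic in half precision. A minimal, self-contained sketch of the pattern; the sizes and dtype below are illustrative, not the test's actual configuration, and it runs on CPU so CUDA is not required:

import torch

# Illustrative sizes and target dtype; the real test uses its class-level
# batch_size / seq_len / hidden_size, its torch_dtype, and device='cuda'.
batch_size, seq_len, hidden_size = 2, 8, 64
target_dtype = torch.bfloat16

# Draw and scale in float32, then downcast once at the end.
hidden_states = torch.rand(batch_size, seq_len, hidden_size,
                           dtype=torch.float32) * 0.01
hidden_states = hidden_states.to(target_dtype)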
@@ -83,35 +86,55 @@ def _create_lora_params(self):
 
         # Create weights for up projection
         lora_weight_ins_up = [
-            (torch.rand(self.hidden_size, lora_rank, device=self.device).to(
-                self.torch_dtype) * 0.1) for lora_rank in lora_ranks_list
+            # Initialize with FP32 and smaller scale (0.01 instead of 0.1)
+            torch.rand(self.hidden_size,
+                       lora_rank,
+                       device=self.device,
+                       dtype=torch.float32) * 0.01
+            for lora_rank in lora_ranks_list
         ]
         lora_weight_outs_up = [
-            (torch.rand(lora_rank, self.intermediate_size,
-                        device=self.device).to(self.torch_dtype) * 0.1)
+            torch.rand(lora_rank,
+                       self.intermediate_size,
+                       device=self.device,
+                       dtype=torch.float32) *
+            (0.01 / max(lora_rank, 1))  # Scale by rank
             for lora_rank in lora_ranks_list
         ]
 
         # Create weights for down projection
         lora_weight_ins_down = [
-            (torch.rand(self.intermediate_size, lora_rank,
-                        device=self.device).to(self.torch_dtype) * 0.1)
+            torch.rand(self.intermediate_size,
+                       lora_rank,
+                       device=self.device,
+                       dtype=torch.float32) * 0.01
             for lora_rank in lora_ranks_list
         ]
+        # Apply rank-based scaling to output weights
         lora_weight_outs_down = [
-            (torch.rand(lora_rank, self.hidden_size, device=self.device).to(
-                self.torch_dtype) * 0.1) for lora_rank in lora_ranks_list
+            torch.rand(lora_rank,
+                       self.hidden_size,
+                       device=self.device,
+                       dtype=torch.float32) *
+            (0.01 / max(lora_rank, 1))  # Scale by rank
+            for lora_rank in lora_ranks_list
         ]
 
-        lora_weight_ins_up = [tmp.contiguous() for tmp in lora_weight_ins_up]
+        # Convert to target dtype after initialization
+        lora_weight_ins_up = [
+            tmp.to(self.torch_dtype).contiguous() for tmp in lora_weight_ins_up
+        ]
         lora_weight_outs_up = [
-            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs_up
+            tmp.to(self.torch_dtype).transpose(1, 0).contiguous()
+            for tmp in lora_weight_outs_up
         ]
         lora_weight_ins_down = [
-            tmp.contiguous() for tmp in lora_weight_ins_down
+            tmp.to(self.torch_dtype).contiguous()
+            for tmp in lora_weight_ins_down
         ]
         lora_weight_outs_down = [
-            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs_down
+            tmp.to(self.torch_dtype).transpose(1, 0).contiguous()
+            for tmp in lora_weight_outs_down
         ]
 
         lora_weights_pointers_up = []
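Note on the 0.01 / max(lora_rank, 1) factor in the hunk above: the LoRA contribution x @ A @ B sums over lora_rank terms, so scaling the output ("out") weights by the inverse rank keeps the magnitude of the update roughly comparable across the ranks in lora_ranks_list instead of growing with the rank. A small standalone sketch of that effect; sizes are illustrative and it runs on CPU:

import torch

hidden_size, intermediate_size = 64, 256
x = torch.rand(1, hidden_size) * 0.01

for lora_rank in (4, 8, 16):
    # "in" weight: fixed small scale; "out" weight: scaled by 1 / rank.
    w_in = torch.rand(hidden_size, lora_rank) * 0.01
    w_out = torch.rand(lora_rank, intermediate_size) * (0.01 / max(lora_rank, 1))
    delta = x @ w_in @ w_out
    # The mean magnitude of the LoRA update stays on a similar scale
    # across ranks, rather than increasing linearly with lora_rank.
    print(lora_rank, delta.abs().mean().item())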
@@ -163,10 +186,8 @@ def _setup_vanilla_pytorch_mlp(self):
         return mlp_module
 
     def test_mlp(self):
-        hidden_states = torch.rand(
-            size=[self.batch_size, self.seq_len, self.hidden_size],
-            dtype=self.torch_dtype,
-            device='cuda')
+        # Use the _create_mlp_inputs method for consistent initialization
+        hidden_states = self._create_mlp_inputs()
 
         lora_params = self._create_lora_params()
 
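Note on the test_mlp change in the last hunk: routing the test input through _create_mlp_inputs keeps the FP32-first, small-scale initialization in a single helper that every test reuses. A hypothetical miniature of that pattern, not the real test class; the configuration values below are illustrative and it runs on CPU:

import unittest

import torch


class MlpInputHelperSketch(unittest.TestCase):
    # Illustrative configuration; the real test sets these up in setUpClass.
    torch_dtype = torch.bfloat16
    batch_size, seq_len, hidden_size = 2, 8, 64

    def _create_mlp_inputs(self):
        # Single place that owns the FP32-first, small-scale initialization.
        hidden_states = torch.rand(self.batch_size, self.seq_len,
                                   self.hidden_size,
                                   dtype=torch.float32) * 0.01
        return hidden_states.to(self.torch_dtype)

    def test_mlp(self):
        # Tests reuse the helper instead of building inputs inline.
        hidden_states = self._create_mlp_inputs()
        self.assertEqual(hidden_states.dtype, self.torch_dtype)


if __name__ == '__main__':
    unittest.main()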