@@ -62,11 +62,14 @@ def setUpClass(cls):
        cls.device = torch.device('cuda')

    def _create_mlp_inputs(self):
+        # Initialize in FP32 first for better numerical stability
        hidden_states = torch.rand(
            size=[self.batch_size, self.seq_len, self.hidden_size],
-            dtype=self.torch_dtype,
-            device='cuda')
+            dtype=torch.float32,
+            device='cuda') * 0.01  # Use smaller scale

+        # Convert to target dtype after initialization
+        hidden_states = hidden_states.to(self.torch_dtype)
        return hidden_states

    def _create_lora_params(self):
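Note: below is a minimal standalone sketch of the init-in-FP32-then-cast pattern applied in this hunk. The helper name make_inputs and the default torch.float16 target dtype are illustrative assumptions, not part of the diff.

import torch

def make_inputs(batch_size, seq_len, hidden_size, target_dtype=torch.float16):
    # Assumed helper (not in the diff): draw uniform values in FP32 and
    # shrink them so downstream matmuls stay well inside the target
    # dtype's representable range.
    x = torch.rand(batch_size, seq_len, hidden_size,
                   dtype=torch.float32, device='cuda') * 0.01
    # Cast to the working dtype only after initialization.
    return x.to(target_dtype)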
@@ -83,35 +86,55 @@ def _create_lora_params(self):

        # Create weights for up projection
        lora_weight_ins_up = [
-            (torch.rand(self.hidden_size, lora_rank, device=self.device).to(
-                self.torch_dtype) * 0.1) for lora_rank in lora_ranks_list
+            # Initialize with FP32 and smaller scale (0.01 instead of 0.1)
+            torch.rand(self.hidden_size,
+                       lora_rank,
+                       device=self.device,
+                       dtype=torch.float32) * 0.01
+            for lora_rank in lora_ranks_list
        ]
        lora_weight_outs_up = [
-            (torch.rand(lora_rank, self.intermediate_size,
-                        device=self.device).to(self.torch_dtype) * 0.1)
+            torch.rand(lora_rank,
+                       self.intermediate_size,
+                       device=self.device,
+                       dtype=torch.float32) *
+            (0.01 / max(lora_rank, 1))  # Scale by rank
            for lora_rank in lora_ranks_list
        ]

        # Create weights for down projection
        lora_weight_ins_down = [
-            (torch.rand(self.intermediate_size, lora_rank,
-                        device=self.device).to(self.torch_dtype) * 0.1)
+            torch.rand(self.intermediate_size,
+                       lora_rank,
+                       device=self.device,
+                       dtype=torch.float32) * 0.01
            for lora_rank in lora_ranks_list
        ]
+        # Apply rank-based scaling to output weights
        lora_weight_outs_down = [
-            (torch.rand(lora_rank, self.hidden_size, device=self.device).to(
-                self.torch_dtype) * 0.1) for lora_rank in lora_ranks_list
+            torch.rand(lora_rank,
+                       self.hidden_size,
+                       device=self.device,
+                       dtype=torch.float32) *
+            (0.01 / max(lora_rank, 1))  # Scale by rank
+            for lora_rank in lora_ranks_list
        ]

-        lora_weight_ins_up = [tmp.contiguous() for tmp in lora_weight_ins_up]
+        # Convert to target dtype after initialization
+        lora_weight_ins_up = [
+            tmp.to(self.torch_dtype).contiguous() for tmp in lora_weight_ins_up
+        ]
        lora_weight_outs_up = [
-            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs_up
+            tmp.to(self.torch_dtype).transpose(1, 0).contiguous()
+            for tmp in lora_weight_outs_up
        ]
        lora_weight_ins_down = [
-            tmp.contiguous() for tmp in lora_weight_ins_down
+            tmp.to(self.torch_dtype).contiguous()
+            for tmp in lora_weight_ins_down
        ]
        lora_weight_outs_down = [
-            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs_down
+            tmp.to(self.torch_dtype).transpose(1, 0).contiguous()
+            for tmp in lora_weight_outs_down
        ]

        lora_weights_pointers_up = []
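Note: below is a minimal sketch of the rank-based scaling used for the output-side LoRA weights in this hunk. The helper make_lora_pair and its default arguments are illustrative assumptions, not part of the diff.

import torch

def make_lora_pair(in_features, out_features, lora_rank,
                   target_dtype=torch.float16, device='cuda'):
    # Assumed helper (not in the diff).
    # Input-side weight: FP32 init with a small fixed scale.
    w_in = torch.rand(in_features, lora_rank,
                      dtype=torch.float32, device=device) * 0.01
    # Output-side weight: divide the scale by the rank so the magnitude of
    # (x @ w_in) @ w_out does not grow with lora_rank.
    w_out = torch.rand(lora_rank, out_features,
                       dtype=torch.float32,
                       device=device) * (0.01 / max(lora_rank, 1))
    # Cast to the working dtype; keep the output weight transposed and
    # contiguous, matching the layout produced in the test.
    return (w_in.to(target_dtype).contiguous(),
            w_out.to(target_dtype).transpose(1, 0).contiguous())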
@@ -163,10 +186,8 @@ def _setup_vanilla_pytorch_mlp(self):
        return mlp_module

    def test_mlp(self):
-        hidden_states = torch.rand(
-            size=[self.batch_size, self.seq_len, self.hidden_size],
-            dtype=self.torch_dtype,
-            device='cuda')
+        # Use the _create_mlp_inputs method for consistent initialization
+        hidden_states = self._create_mlp_inputs()

        lora_params = self._create_lora_params()