@@ -48,8 +48,7 @@ def test_convert_diffusers_instantx_state_dict_to_bfl_format():
48
48
49
49
def test_infer_flux_params_from_state_dict ():
50
50
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
51
- with torch .device ("meta" ):
52
- sd = {k : torch .zeros (v ) for k , v in instantx_sd_shapes .items ()}
51
+ sd = {k : torch .zeros (v , device = "meta" ) for k , v in instantx_sd_shapes .items ()}
53
52
54
53
sd = convert_diffusers_instantx_state_dict_to_bfl_format (sd )
55
54
flux_params = infer_flux_params_from_state_dict (sd )
@@ -70,8 +69,7 @@ def test_infer_flux_params_from_state_dict():
70
69
71
70
def test_infer_instantx_num_control_modes_from_state_dict ():
72
71
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
73
- with torch .device ("meta" ):
74
- sd = {k : torch .zeros (v ) for k , v in instantx_sd_shapes .items ()}
72
+ sd = {k : torch .zeros (v , device = "meta" ) for k , v in instantx_sd_shapes .items ()}
75
73
76
74
sd = convert_diffusers_instantx_state_dict_to_bfl_format (sd )
77
75
num_control_modes = infer_instantx_num_control_modes_from_state_dict (sd )
@@ -81,8 +79,7 @@ def test_infer_instantx_num_control_modes_from_state_dict():
81
79
82
80
def test_load_instantx_from_state_dict ():
83
81
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
84
- with torch .device ("meta" ):
85
- sd = {k : torch .zeros (v ) for k , v in instantx_sd_shapes .items ()}
82
+ sd = {k : torch .zeros (v , device = "meta" ) for k , v in instantx_sd_shapes .items ()}
86
83
87
84
sd = convert_diffusers_instantx_state_dict_to_bfl_format (sd )
88
85
flux_params = infer_flux_params_from_state_dict (sd )
0 commit comments