@@ -791,7 +791,20 @@ def attribute_future(
791
791
)
792
792
793
793
if enable_cross_tensor_attribution :
794
- raise NotImplementedError ("Not supported yet" )
794
+ # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
795
+ # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
796
+ # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
797
+ return self ._attribute_with_cross_tensor_feature_masks_future ( # type: ignore # noqa: E501 line too long
798
+ formatted_inputs = formatted_inputs ,
799
+ formatted_additional_forward_args = formatted_additional_forward_args ,
800
+ target = target ,
801
+ baselines = baselines ,
802
+ formatted_feature_mask = formatted_feature_mask ,
803
+ attr_progress = attr_progress ,
804
+ processed_initial_eval_fut = processed_initial_eval_fut ,
805
+ is_inputs_tuple = is_inputs_tuple ,
806
+ perturbations_per_eval = perturbations_per_eval ,
807
+ )
795
808
else :
796
809
# pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
797
810
# <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
@@ -921,6 +934,213 @@ def _attribute_with_independent_feature_masks_future(
921
934
922
935
return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
923
936
937
+ def _attribute_with_cross_tensor_feature_masks_future (
938
+ self ,
939
+ formatted_inputs : Tuple [Tensor , ...],
940
+ formatted_additional_forward_args : Optional [Tuple [object , ...]],
941
+ target : TargetType ,
942
+ baselines : BaselineType ,
943
+ formatted_feature_mask : Tuple [Tensor , ...],
944
+ attr_progress : Optional [Union [SimpleProgress [IterableType ], tqdm ]],
945
+ processed_initial_eval_fut : Future [
946
+ Tuple [List [Tensor ], List [Tensor ], Tensor , Tensor , int , dtype ]
947
+ ],
948
+ is_inputs_tuple : bool ,
949
+ perturbations_per_eval : int ,
950
+ ** kwargs : Any ,
951
+ ) -> Future [Union [Tensor , Tuple [Tensor , ...]]]:
952
+ feature_idx_to_tensor_idx : Dict [int , List [int ]] = {}
953
+ for i , mask in enumerate (formatted_feature_mask ):
954
+ for feature_idx in torch .unique (mask ):
955
+ if feature_idx .item () not in feature_idx_to_tensor_idx :
956
+ feature_idx_to_tensor_idx [feature_idx .item ()] = []
957
+ feature_idx_to_tensor_idx [feature_idx .item ()].append (i )
958
+ all_feature_idxs = list (feature_idx_to_tensor_idx .keys ())
959
+
960
+ additional_args_repeated : object
961
+ if perturbations_per_eval > 1 :
962
+ # Repeat features and additional args for batch size.
963
+ all_features_repeated = tuple (
964
+ torch .cat ([formatted_inputs [j ]] * perturbations_per_eval , dim = 0 )
965
+ for j in range (len (formatted_inputs ))
966
+ )
967
+ additional_args_repeated = (
968
+ _expand_additional_forward_args (
969
+ formatted_additional_forward_args , perturbations_per_eval
970
+ )
971
+ if formatted_additional_forward_args is not None
972
+ else None
973
+ )
974
+ target_repeated = _expand_target (target , perturbations_per_eval )
975
+ else :
976
+ all_features_repeated = formatted_inputs
977
+ additional_args_repeated = formatted_additional_forward_args
978
+ target_repeated = target
979
+ num_examples = formatted_inputs [0 ].shape [0 ]
980
+
981
+ current_additional_args : object
982
+ if isinstance (baselines , tuple ):
983
+ reshaped = False
984
+ reshaped_baselines : list [Union [Tensor , int , float ]] = []
985
+ for baseline in baselines :
986
+ if isinstance (baseline , Tensor ):
987
+ reshaped = True
988
+ reshaped_baselines .append (
989
+ baseline .reshape ((1 ,) + tuple (baseline .shape ))
990
+ )
991
+ else :
992
+ reshaped_baselines .append (baseline )
993
+ baselines = tuple (reshaped_baselines ) if reshaped else baselines
994
+
995
+ all_modified_eval_futures : List [Future [Tuple [List [Tensor ], List [Tensor ]]]] = []
996
+ for i in range (0 , len (all_feature_idxs ), perturbations_per_eval ):
997
+ current_feature_idxs = all_feature_idxs [i : i + perturbations_per_eval ]
998
+ current_num_ablated_features = min (
999
+ perturbations_per_eval , len (current_feature_idxs )
1000
+ )
1001
+
1002
+ should_skip = False
1003
+ all_empty = True
1004
+ tensor_idx_list = []
1005
+ for feature_idx in current_feature_idxs :
1006
+ tensor_idx_list += feature_idx_to_tensor_idx [feature_idx ]
1007
+ for tensor_idx in set (tensor_idx_list ):
1008
+ if all_empty and torch .numel (formatted_inputs [tensor_idx ]) != 0 :
1009
+ all_empty = False
1010
+ if self ._min_examples_per_batch_grouped is not None and (
1011
+ formatted_inputs [tensor_idx ].shape [0 ]
1012
+ # pyre-ignore[58]: Type has been narrowed to int
1013
+ < self ._min_examples_per_batch_grouped
1014
+ ):
1015
+ should_skip = True
1016
+ break
1017
+ if all_empty :
1018
+ logger .info (
1019
+ f"Skipping feature group { current_feature_idxs } since all "
1020
+ f"input tensors are empty"
1021
+ )
1022
+ continue
1023
+
1024
+ if should_skip :
1025
+ logger .warning (
1026
+ f"Skipping feature group { current_feature_idxs } since it contains "
1027
+ f"at least one input tensor with 0th dim less than "
1028
+ f"{ self ._min_examples_per_batch_grouped } "
1029
+ )
1030
+ continue
1031
+
1032
+ # Store appropriate inputs and additional args based on batch size.
1033
+ if current_num_ablated_features != perturbations_per_eval :
1034
+ current_additional_args = (
1035
+ _expand_additional_forward_args (
1036
+ formatted_additional_forward_args , current_num_ablated_features
1037
+ )
1038
+ if formatted_additional_forward_args is not None
1039
+ else None
1040
+ )
1041
+ current_target = _expand_target (target , current_num_ablated_features )
1042
+ expanded_inputs = tuple (
1043
+ feature_repeated [0 : current_num_ablated_features * num_examples ]
1044
+ for feature_repeated in all_features_repeated
1045
+ )
1046
+ else :
1047
+ current_additional_args = additional_args_repeated
1048
+ current_target = target_repeated
1049
+ expanded_inputs = all_features_repeated
1050
+
1051
+ current_inputs , current_masks = (
1052
+ self ._construct_ablated_input_across_tensors (
1053
+ expanded_inputs ,
1054
+ formatted_feature_mask ,
1055
+ baselines ,
1056
+ current_feature_idxs ,
1057
+ feature_idx_to_tensor_idx ,
1058
+ current_num_ablated_features ,
1059
+ )
1060
+ )
1061
+
1062
+ # modified_eval has (n_feature_perturbed * n_outputs) elements
1063
+ # shape:
1064
+ # agg mode: (*initial_eval.shape)
1065
+ # non-agg mode:
1066
+ # (feature_perturbed * batch_size, *initial_eval.shape[1:])
1067
+ modified_eval = _run_forward (
1068
+ self .forward_func ,
1069
+ current_inputs ,
1070
+ current_target ,
1071
+ current_additional_args ,
1072
+ )
1073
+
1074
+ if attr_progress is not None :
1075
+ attr_progress .update ()
1076
+
1077
+ if not isinstance (modified_eval , torch .Future ):
1078
+ raise AssertionError (
1079
+ "when using attribute_future, modified_eval should have "
1080
+ f"Future type rather than { type (modified_eval )} "
1081
+ )
1082
+
1083
+ # Need to collect both initial eval and modified_eval
1084
+ eval_futs : Future [
1085
+ List [
1086
+ Future [
1087
+ Union [
1088
+ Tuple [
1089
+ List [Tensor ],
1090
+ List [Tensor ],
1091
+ Tensor ,
1092
+ Tensor ,
1093
+ int ,
1094
+ dtype ,
1095
+ ],
1096
+ Tensor ,
1097
+ ]
1098
+ ]
1099
+ ]
1100
+ ] = collect_all (
1101
+ [
1102
+ processed_initial_eval_fut ,
1103
+ modified_eval ,
1104
+ ]
1105
+ )
1106
+
1107
+ ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = eval_futs .then (
1108
+ lambda eval_futs , current_inputs = current_inputs , current_mask = current_masks , i = i : self ._eval_fut_to_ablated_out_fut_cross_tensor ( # type: ignore # noqa: E501 line too long
1109
+ eval_futs = eval_futs ,
1110
+ current_inputs = current_inputs ,
1111
+ current_mask = current_mask ,
1112
+ perturbations_per_eval = perturbations_per_eval ,
1113
+ num_examples = num_examples ,
1114
+ )
1115
+ )
1116
+
1117
+ all_modified_eval_futures .append (ablated_out_fut )
1118
+
1119
+ if attr_progress is not None :
1120
+ attr_progress .close ()
1121
+
1122
+ return self ._generate_async_result_cross_tensor (
1123
+ all_modified_eval_futures ,
1124
+ is_inputs_tuple ,
1125
+ )
1126
+
1127
+ def _fut_tuple_to_accumulate_fut_list_cross_tensor (
1128
+ self ,
1129
+ total_attrib : List [Tensor ],
1130
+ weights : List [Tensor ],
1131
+ fut_tuple : Future [Tuple [List [Tensor ], List [Tensor ]]],
1132
+ ) -> None :
1133
+ try :
1134
+ # process_ablated_out_* already accumlates the total attribution.
1135
+ # Just get the latest value
1136
+ attribs , this_weights = fut_tuple .value ()
1137
+ total_attrib [:] = attribs
1138
+ weights [:] = this_weights
1139
+ except FeatureAblationFutureError as e :
1140
+ raise FeatureAblationFutureError (
1141
+ "_fut_tuple_to_accumulate_fut_list_cross_tensor failed"
1142
+ ) from e
1143
+
924
1144
# pyre-fixme[3] return type must be annotated
925
1145
def _attribute_progress_setup (
926
1146
self ,
@@ -950,7 +1170,6 @@ def _attribute_progress_setup(
950
1170
951
1171
def _eval_fut_to_ablated_out_fut (
952
1172
self ,
953
- # pyre-ignore Invalid type parameters [24]
954
1173
eval_futs : Future [List [Future [List [object ]]]],
955
1174
current_inputs : Tuple [Tensor , ...],
956
1175
current_mask : Tensor ,
@@ -1012,6 +1231,94 @@ def _eval_fut_to_ablated_out_fut(
1012
1231
) from e
1013
1232
return result
1014
1233
1234
+ def _generate_async_result_cross_tensor (
1235
+ self ,
1236
+ futs : List [Future [Tuple [List [Tensor ], List [Tensor ]]]],
1237
+ is_inputs_tuple : bool ,
1238
+ ) -> Future [Union [Tensor , Tuple [Tensor , ...]]]:
1239
+ accumulate_fut_list : List [Future [None ]] = []
1240
+ total_attrib : List [Tensor ] = []
1241
+ weights : List [Tensor ] = []
1242
+
1243
+ for fut_tuple in futs :
1244
+ accumulate_fut_list .append (
1245
+ fut_tuple .then (
1246
+ lambda fut_tuple : self ._fut_tuple_to_accumulate_fut_list_cross_tensor ( # noqa: E501 line too long
1247
+ total_attrib , weights , fut_tuple
1248
+ )
1249
+ )
1250
+ )
1251
+
1252
+ result_fut = collect_all (accumulate_fut_list ).then (
1253
+ lambda x : self ._generate_result (
1254
+ total_attrib ,
1255
+ weights ,
1256
+ is_inputs_tuple ,
1257
+ )
1258
+ )
1259
+
1260
+ return result_fut
1261
+
1262
+ def _eval_fut_to_ablated_out_fut_cross_tensor (
1263
+ self ,
1264
+ eval_futs : Future [List [Future [List [object ]]]],
1265
+ current_inputs : Tuple [Tensor , ...],
1266
+ current_mask : Tuple [Optional [Tensor ], ...],
1267
+ perturbations_per_eval : int ,
1268
+ num_examples : int ,
1269
+ ) -> Tuple [List [Tensor ], List [Tensor ]]:
1270
+ try :
1271
+ modified_eval = cast (Tensor , eval_futs .value ()[1 ].value ())
1272
+ initial_eval_tuple = cast (
1273
+ Tuple [
1274
+ List [Tensor ],
1275
+ List [Tensor ],
1276
+ Tensor ,
1277
+ Tensor ,
1278
+ int ,
1279
+ dtype ,
1280
+ ],
1281
+ eval_futs .value ()[0 ].value (),
1282
+ )
1283
+ if len (initial_eval_tuple ) != 6 :
1284
+ raise AssertionError (
1285
+ "eval_fut_to_ablated_out_fut_cross_tensor: "
1286
+ "initial_eval_tuple should have 6 elements: "
1287
+ "total_attrib, weights, initial_eval, "
1288
+ "flattened_initial_eval, n_outputs, attrib_type "
1289
+ )
1290
+ if not isinstance (modified_eval , Tensor ):
1291
+ raise AssertionError (
1292
+ "_eval_fut_to_ablated_out_fut_cross_tensor: "
1293
+ "modified eval should be a Tensor"
1294
+ )
1295
+ (
1296
+ total_attrib ,
1297
+ weights ,
1298
+ initial_eval ,
1299
+ flattened_initial_eval ,
1300
+ n_outputs ,
1301
+ attrib_type ,
1302
+ ) = initial_eval_tuple
1303
+ total_attrib , weights = self ._process_ablated_out_full (
1304
+ modified_eval = modified_eval ,
1305
+ inputs = current_inputs ,
1306
+ current_mask = current_mask ,
1307
+ perturbations_per_eval = perturbations_per_eval ,
1308
+ num_examples = num_examples ,
1309
+ initial_eval = initial_eval ,
1310
+ flattened_initial_eval = flattened_initial_eval ,
1311
+ n_outputs = n_outputs ,
1312
+ total_attrib = total_attrib ,
1313
+ weights = weights ,
1314
+ attrib_type = attrib_type ,
1315
+ )
1316
+ except FeatureAblationFutureError as e :
1317
+ raise FeatureAblationFutureError (
1318
+ "_eval_fut_to_ablated_out_fut_cross_tensor func failed"
1319
+ ) from e
1320
+ return total_attrib , weights
1321
+
1015
1322
def _ith_input_ablation_generator (
1016
1323
self ,
1017
1324
i : int ,
0 commit comments