Skip to content

Commit c8685c0

Browse files
sarahtranfb authored and facebook-github-bot committed
Futures support with cross-tensor attribution 2/n
Summary: A lot of copypasta from `_attribute_with_cross_tensor_feature_masks()`, refactored in the next diff in the stack. Keeping the copying transparent in this diff and refactoring in another diff to make it easier for review (imo). The original `attribute_future()` has this structure: ``` process, format input, etc get initial eval (future) initialize list of ablated output futures for i in perturbed batch generator: get modified eval (future) add callback when modified eval & init eval are completed - async wrapper around _process_ablated_out _generate_async_result(list of all ablated output futures) ``` which is largely preserved in this diff for the new way of creating perturbed inputs. Reviewed By: styusuf Differential Revision: D73466780 fbshipit-source-id: 2b52aa5c96ec42387fc175eb1adb23b6e15cb360
1 parent 1f8f6f3 commit c8685c0

File tree

1 file changed

+309
-2
lines changed

1 file changed

+309
-2
lines changed

captum/attr/_core/feature_ablation.py

+309-2
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,20 @@ def attribute_future(
791791
)
792792

793793
if enable_cross_tensor_attribution:
794-
raise NotImplementedError("Not supported yet")
794+
# pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
795+
# <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
796+
# `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
797+
return self._attribute_with_cross_tensor_feature_masks_future( # type: ignore # noqa: E501 line too long
798+
formatted_inputs=formatted_inputs,
799+
formatted_additional_forward_args=formatted_additional_forward_args,
800+
target=target,
801+
baselines=baselines,
802+
formatted_feature_mask=formatted_feature_mask,
803+
attr_progress=attr_progress,
804+
processed_initial_eval_fut=processed_initial_eval_fut,
805+
is_inputs_tuple=is_inputs_tuple,
806+
perturbations_per_eval=perturbations_per_eval,
807+
)
795808
else:
796809
# pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
797810
# <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
@@ -921,6 +934,213 @@ def _attribute_with_independent_feature_masks_future(
921934

922935
return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple) # type: ignore # noqa: E501 line too long
923936

937+
def _attribute_with_cross_tensor_feature_masks_future(
    self,
    formatted_inputs: Tuple[Tensor, ...],
    formatted_additional_forward_args: Optional[Tuple[object, ...]],
    target: TargetType,
    baselines: BaselineType,
    formatted_feature_mask: Tuple[Tensor, ...],
    attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
    processed_initial_eval_fut: Future[
        Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
    ],
    is_inputs_tuple: bool,
    perturbations_per_eval: int,
    **kwargs: Any,
) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
    """Async (Future-based) counterpart of cross-tensor feature ablation.

    Groups features by id across all input tensors, ablates each group,
    runs the forward function (which must return a ``torch.Future``), and
    chains a callback per group that combines the modified eval with the
    already-processed initial eval. Returns a Future resolving to the
    attribution tensor(s).

    Args:
        formatted_inputs: Tuple of input tensors; all are assumed to share
            batch size ``formatted_inputs[0].shape[0]``.
        formatted_additional_forward_args: Extra args forwarded (and
            expanded per batch) to ``self.forward_func``, or None.
        target: Target specifier, expanded alongside the inputs.
        baselines: Baseline values; tensor baselines are reshaped with a
            leading broadcast dim of 1 below.
        formatted_feature_mask: One mask tensor per input tensor; equal
            mask values across tensors form one feature group.
        attr_progress: Optional progress bar, updated once per forward.
        processed_initial_eval_fut: Future of the 6-tuple
            (total_attrib, weights, initial_eval, flattened_initial_eval,
            n_outputs, attrib_type) produced from the unablated eval.
        is_inputs_tuple: Whether the caller passed a tuple of inputs
            (controls the shape of the final result).
        perturbations_per_eval: Max feature groups ablated per forward call.
        **kwargs: Unused here; accepted for signature parity with the
            sync path — TODO confirm.

    Returns:
        Future resolving to a single Tensor or a tuple of Tensors.

    Raises:
        AssertionError: If ``self.forward_func`` does not return a Future.
    """
    # Map each feature id to the indices of the input tensors whose mask
    # contains it, so a feature group can span multiple tensors.
    feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
    for i, mask in enumerate(formatted_feature_mask):
        for feature_idx in torch.unique(mask):
            if feature_idx.item() not in feature_idx_to_tensor_idx:
                feature_idx_to_tensor_idx[feature_idx.item()] = []
            feature_idx_to_tensor_idx[feature_idx.item()].append(i)
    all_feature_idxs = list(feature_idx_to_tensor_idx.keys())

    additional_args_repeated: object
    if perturbations_per_eval > 1:
        # Repeat features and additional args for batch size.
        all_features_repeated = tuple(
            torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
            for j in range(len(formatted_inputs))
        )
        additional_args_repeated = (
            _expand_additional_forward_args(
                formatted_additional_forward_args, perturbations_per_eval
            )
            if formatted_additional_forward_args is not None
            else None
        )
        target_repeated = _expand_target(target, perturbations_per_eval)
    else:
        # No batching of perturbations: use the originals as-is.
        all_features_repeated = formatted_inputs
        additional_args_repeated = formatted_additional_forward_args
        target_repeated = target
    num_examples = formatted_inputs[0].shape[0]

    current_additional_args: object
    if isinstance(baselines, tuple):
        # Give tensor baselines a leading dim of 1 so they broadcast
        # against the (expanded) batch dimension during ablation.
        reshaped = False
        reshaped_baselines: list[Union[Tensor, int, float]] = []
        for baseline in baselines:
            if isinstance(baseline, Tensor):
                reshaped = True
                reshaped_baselines.append(
                    baseline.reshape((1,) + tuple(baseline.shape))
                )
            else:
                reshaped_baselines.append(baseline)
        baselines = tuple(reshaped_baselines) if reshaped else baselines

    all_modified_eval_futures: List[Future[Tuple[List[Tensor], List[Tensor]]]] = []
    # Process feature groups in chunks of up to perturbations_per_eval.
    for i in range(0, len(all_feature_idxs), perturbations_per_eval):
        current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
        current_num_ablated_features = min(
            perturbations_per_eval, len(current_feature_idxs)
        )

        # Skip the group if every touched tensor is empty, or if any
        # touched tensor's batch dim is below the configured minimum.
        should_skip = False
        all_empty = True
        tensor_idx_list = []
        for feature_idx in current_feature_idxs:
            tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
        for tensor_idx in set(tensor_idx_list):
            if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
                all_empty = False
            if self._min_examples_per_batch_grouped is not None and (
                formatted_inputs[tensor_idx].shape[0]
                # pyre-ignore[58]: Type has been narrowed to int
                < self._min_examples_per_batch_grouped
            ):
                should_skip = True
                break
        if all_empty:
            logger.info(
                f"Skipping feature group {current_feature_idxs} since all "
                f"input tensors are empty"
            )
            continue

        if should_skip:
            logger.warning(
                f"Skipping feature group {current_feature_idxs} since it contains "
                f"at least one input tensor with 0th dim less than "
                f"{self._min_examples_per_batch_grouped}"
            )
            continue

        # Store appropriate inputs and additional args based on batch size.
        # The final (short) chunk needs everything re-expanded to its
        # actual size rather than the full perturbations_per_eval.
        if current_num_ablated_features != perturbations_per_eval:
            current_additional_args = (
                _expand_additional_forward_args(
                    formatted_additional_forward_args, current_num_ablated_features
                )
                if formatted_additional_forward_args is not None
                else None
            )
            current_target = _expand_target(target, current_num_ablated_features)
            expanded_inputs = tuple(
                feature_repeated[0 : current_num_ablated_features * num_examples]
                for feature_repeated in all_features_repeated
            )
        else:
            current_additional_args = additional_args_repeated
            current_target = target_repeated
            expanded_inputs = all_features_repeated

        current_inputs, current_masks = (
            self._construct_ablated_input_across_tensors(
                expanded_inputs,
                formatted_feature_mask,
                baselines,
                current_feature_idxs,
                feature_idx_to_tensor_idx,
                current_num_ablated_features,
            )
        )

        # modified_eval has (n_feature_perturbed * n_outputs) elements
        # shape:
        # agg mode: (*initial_eval.shape)
        # non-agg mode:
        # (feature_perturbed * batch_size, *initial_eval.shape[1:])
        modified_eval = _run_forward(
            self.forward_func,
            current_inputs,
            current_target,
            current_additional_args,
        )

        if attr_progress is not None:
            attr_progress.update()

        if not isinstance(modified_eval, torch.Future):
            raise AssertionError(
                "when using attribute_future, modified_eval should have "
                f"Future type rather than {type(modified_eval)}"
            )

        # Need to collect both initial eval and modified_eval
        eval_futs: Future[
            List[
                Future[
                    Union[
                        Tuple[
                            List[Tensor],
                            List[Tensor],
                            Tensor,
                            Tensor,
                            int,
                            dtype,
                        ],
                        Tensor,
                    ]
                ]
            ]
        ] = collect_all(
            [
                processed_initial_eval_fut,
                modified_eval,
            ]
        )

        # NOTE: current_inputs/current_mask/i are bound as lambda defaults
        # to capture this iteration's values (avoids late-binding closure).
        ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = eval_futs.then(
            lambda eval_futs, current_inputs=current_inputs, current_mask=current_masks, i=i: self._eval_fut_to_ablated_out_fut_cross_tensor(  # type: ignore # noqa: E501 line too long
                eval_futs=eval_futs,
                current_inputs=current_inputs,
                current_mask=current_mask,
                perturbations_per_eval=perturbations_per_eval,
                num_examples=num_examples,
            )
        )

        all_modified_eval_futures.append(ablated_out_fut)

    if attr_progress is not None:
        attr_progress.close()

    # Aggregate all per-group futures into the final attribution result.
    return self._generate_async_result_cross_tensor(
        all_modified_eval_futures,
        is_inputs_tuple,
    )
1126+
1127+
def _fut_tuple_to_accumulate_fut_list_cross_tensor(
    self,
    total_attrib: List[Tensor],
    weights: List[Tensor],
    fut_tuple: Future[Tuple[List[Tensor], List[Tensor]]],
) -> None:
    """Copy one resolved (attribs, weights) pair into the shared lists.

    Used as a ``Future.then`` callback: the completed ``fut_tuple`` holds
    already-accumulated totals, so the shared ``total_attrib`` and
    ``weights`` lists are overwritten in place (slice assignment keeps
    the list objects other callbacks/closures reference).

    Raises:
        FeatureAblationFutureError: Re-raised (chained) if retrieving the
            future's value failed with that error.
    """
    try:
        # process_ablated_out_* already accumulates the total attribution.
        # Just get the latest value
        attribs, this_weights = fut_tuple.value()
        total_attrib[:] = attribs
        weights[:] = this_weights
    except FeatureAblationFutureError as e:
        raise FeatureAblationFutureError(
            "_fut_tuple_to_accumulate_fut_list_cross_tensor failed"
        ) from e
1143+
9241144
# pyre-fixme[3] return type must be annotated
9251145
def _attribute_progress_setup(
9261146
self,
@@ -950,7 +1170,6 @@ def _attribute_progress_setup(
9501170

9511171
def _eval_fut_to_ablated_out_fut(
9521172
self,
953-
# pyre-ignore Invalid type parameters [24]
9541173
eval_futs: Future[List[Future[List[object]]]],
9551174
current_inputs: Tuple[Tensor, ...],
9561175
current_mask: Tensor,
@@ -1012,6 +1231,94 @@ def _eval_fut_to_ablated_out_fut(
10121231
) from e
10131232
return result
10141233

1234+
def _generate_async_result_cross_tensor(
    self,
    futs: List[Future[Tuple[List[Tensor], List[Tensor]]]],
    is_inputs_tuple: bool,
) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
    """Chain all per-group ablation futures into one final result future.

    Each future in ``futs`` resolves to an (attribs, weights) pair that
    already includes accumulation; its callback writes the latest pair
    into the shared ``total_attrib``/``weights`` lists. Once every
    callback has run, ``_generate_result`` formats the final output.

    Args:
        futs: Per-feature-group futures of (total_attrib, weights).
        is_inputs_tuple: Whether the final result should be a tuple.

    Returns:
        Future resolving to the formatted attribution tensor(s).
    """
    accumulate_fut_list: List[Future[None]] = []
    # Shared accumulators, mutated in place by each callback below.
    total_attrib: List[Tensor] = []
    weights: List[Tensor] = []

    for fut_tuple in futs:
        accumulate_fut_list.append(
            fut_tuple.then(
                lambda fut_tuple: self._fut_tuple_to_accumulate_fut_list_cross_tensor(  # noqa: E501 line too long
                    total_attrib, weights, fut_tuple
                )
            )
        )

    # Only build the final result after every accumulation callback ran.
    result_fut = collect_all(accumulate_fut_list).then(
        lambda x: self._generate_result(
            total_attrib,
            weights,
            is_inputs_tuple,
        )
    )

    return result_fut
1261+
1262+
def _eval_fut_to_ablated_out_fut_cross_tensor(
    self,
    eval_futs: Future[List[Future[List[object]]]],
    current_inputs: Tuple[Tensor, ...],
    current_mask: Tuple[Optional[Tensor], ...],
    perturbations_per_eval: int,
    num_examples: int,
) -> Tuple[List[Tensor], List[Tensor]]:
    """Combine the initial eval and one group's modified eval into attributions.

    Callback attached after ``collect_all([processed_initial_eval_fut,
    modified_eval])``: unpacks both resolved futures, validates their
    shapes, and delegates to ``_process_ablated_out_full`` to update the
    running attribution totals and weights for this feature group.

    Args:
        eval_futs: Completed future whose element 0 is the processed
            initial-eval 6-tuple and element 1 is the modified eval Tensor.
        current_inputs: Ablated input tensors used for this forward call.
        current_mask: Per-tensor ablation masks (None where untouched).
        perturbations_per_eval: Max feature groups per forward call.
        num_examples: Batch size of the original (unexpanded) inputs.

    Returns:
        Updated (total_attrib, weights) lists.

    Raises:
        AssertionError: If the initial-eval tuple does not have 6 elements
            or the modified eval is not a Tensor.
        FeatureAblationFutureError: Chained re-raise on downstream failure.
    """
    try:
        # Element 1 of the collected list is the forward pass's output.
        modified_eval = cast(Tensor, eval_futs.value()[1].value())
        # Element 0 is the processed initial (unablated) evaluation.
        initial_eval_tuple = cast(
            Tuple[
                List[Tensor],
                List[Tensor],
                Tensor,
                Tensor,
                int,
                dtype,
            ],
            eval_futs.value()[0].value(),
        )
        if len(initial_eval_tuple) != 6:
            raise AssertionError(
                "eval_fut_to_ablated_out_fut_cross_tensor: "
                "initial_eval_tuple should have 6 elements: "
                "total_attrib, weights, initial_eval, "
                "flattened_initial_eval, n_outputs, attrib_type "
            )
        if not isinstance(modified_eval, Tensor):
            raise AssertionError(
                "_eval_fut_to_ablated_out_fut_cross_tensor: "
                "modified eval should be a Tensor"
            )
        (
            total_attrib,
            weights,
            initial_eval,
            flattened_initial_eval,
            n_outputs,
            attrib_type,
        ) = initial_eval_tuple
        total_attrib, weights = self._process_ablated_out_full(
            modified_eval=modified_eval,
            inputs=current_inputs,
            current_mask=current_mask,
            perturbations_per_eval=perturbations_per_eval,
            num_examples=num_examples,
            initial_eval=initial_eval,
            flattened_initial_eval=flattened_initial_eval,
            n_outputs=n_outputs,
            total_attrib=total_attrib,
            weights=weights,
            attrib_type=attrib_type,
        )
    except FeatureAblationFutureError as e:
        raise FeatureAblationFutureError(
            "_eval_fut_to_ablated_out_fut_cross_tensor func failed"
        ) from e
    return total_attrib, weights
1321+
10151322
def _ith_input_ablation_generator(
10161323
self,
10171324
i: int,

0 commit comments

Comments
 (0)