@@ -1096,12 +1096,14 @@ def _apply_masks(logits, mask, is_causal):


 def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
+    original_dtype = key.dtype
     logits_dtype = np.promote_types(query.dtype, np.float32)
-    logits = np.einsum(
-        "BTNH,BSNH->BNTS",
-        query.astype(logits_dtype),
-        key.astype(logits_dtype),
-    )
+    if backend.standardize_dtype(key.dtype) == "bfloat16":
+        # `np.einsum` doesn't support bfloat16
+        key = key.astype("float32")
+        value = value.astype("float32")
+    logits = np.einsum("BTNH,BSNH->BNTS", query, key)
+    logits = logits.astype(logits_dtype)
     logits *= np.array(scale, dtype=logits.dtype)

     if bias is not None:
@@ -1111,7 +1113,7 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):

     # Softmax and it is always carried out in fp32.
     padded_logits = padded_logits.astype(np.float32)
-    probs = softmax(padded_logits, axis=-1).astype(key.dtype)
+    probs = softmax(padded_logits, axis=-1).astype(original_dtype)
     encoded_dtype = probs.dtype
     if backend.standardize_dtype(probs.dtype) == "bfloat16":
         # `np.einsum` doesn't support bfloat16
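For context, here is a minimal standalone sketch of the pattern this patch introduces: detect bfloat16 inputs, upcast around `np.einsum`, and remember the original dtype so the softmax output can be cast back to it later (which is why `probs` now uses `original_dtype` rather than `key.dtype`, since `key` may have been upcast). The helper name `attention_logits`, the direct `ml_dtypes` dtype check (standing in for `backend.standardize_dtype`), and the example shapes are illustrative assumptions, not part of the change.

```python
import numpy as np
import ml_dtypes  # provides the NumPy-compatible bfloat16 dtype


def attention_logits(query, key, scale):
    """Compute scaled attention logits, working around bfloat16 in einsum."""
    # Remember the incoming dtype so probabilities computed later can be
    # cast back to it (the patch stores this as `original_dtype`).
    original_dtype = key.dtype
    if key.dtype == ml_dtypes.bfloat16:
        # `np.einsum` doesn't support bfloat16, so compute in float32.
        # (Unlike the patch, this sketch also upcasts `query` so both einsum
        # operands share a dtype.)
        query = query.astype("float32")
        key = key.astype("float32")
    # The real code derives the compute dtype via
    # `np.promote_types(query.dtype, np.float32)`; float32 is used here.
    logits = np.einsum("BTNH,BSNH->BNTS", query, key).astype("float32")
    logits *= np.array(scale, dtype=logits.dtype)
    return logits, original_dtype


# Tiny example: batch=1, target len=2, source len=3, heads=1, head dim=4.
q = np.random.randn(1, 2, 1, 4).astype(ml_dtypes.bfloat16)
k = np.random.randn(1, 3, 1, 4).astype(ml_dtypes.bfloat16)
logits, original_dtype = attention_logits(q, k, scale=0.5)
print(logits.dtype, logits.shape, original_dtype)  # float32 (1, 1, 2, 3) bfloat16
```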