Skip to content

Commit c6c0720

Browse files
Use lower precision in DPA (#20615)
1 parent 5b6b9b0 commit c6c0720

File tree

2 files changed: +10 additions, −12 deletions

keras/src/backend/numpy/nn.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,12 +1096,14 @@ def _apply_masks(logits, mask, is_causal):
10961096

10971097

10981098
def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
1099+
original_dtype = key.dtype
10991100
logits_dtype = np.promote_types(query.dtype, np.float32)
1100-
logits = np.einsum(
1101-
"BTNH,BSNH->BNTS",
1102-
query.astype(logits_dtype),
1103-
key.astype(logits_dtype),
1104-
)
1101+
if backend.standardize_dtype(key.dtype) == "bfloat16":
1102+
# `np.einsum` doesn't support bfloat16
1103+
key = key.astype("float32")
1104+
value = value.astype("float32")
1105+
logits = np.einsum("BTNH,BSNH->BNTS", query, key)
1106+
logits = logits.astype(logits_dtype)
11051107
logits *= np.array(scale, dtype=logits.dtype)
11061108

11071109
if bias is not None:
@@ -1111,7 +1113,7 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
11111113

11121114
# Softmax and it is always carried out in fp32.
11131115
padded_logits = padded_logits.astype(np.float32)
1114-
probs = softmax(padded_logits, axis=-1).astype(key.dtype)
1116+
probs = softmax(padded_logits, axis=-1).astype(original_dtype)
11151117
encoded_dtype = probs.dtype
11161118
if backend.standardize_dtype(probs.dtype) == "bfloat16":
11171119
# `np.einsum` doesn't support bfloat16

keras/src/backend/tensorflow/nn.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,12 +1015,8 @@ def _apply_masks(logits, mask, is_causal):
10151015

10161016
def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
10171017
logits_dtype = backend.result_type(query.dtype, "float32")
1018-
logits = tf.einsum(
1019-
"BTNH,BSNH->BNTS",
1020-
tf.cast(query, dtype=logits_dtype),
1021-
tf.cast(key, dtype=logits_dtype),
1022-
optimize="optimal",
1023-
)
1018+
logits = tf.einsum("BTNH,BSNH->BNTS", query, key, optimize="optimal")
1019+
logits = tf.cast(logits, logits_dtype)
10241020
logits = tf.multiply(logits, tf.cast(scale, logits.dtype))
10251021

10261022
if bias is not None:

0 commit comments

Comments (0)