Update pipelines/flux_pipeline/transformer.py
pipelines/flux_pipeline/transformer.py
CHANGED
@@ -125,9 +125,15 @@ class FluxAttnProcessor2_0:
 
         if neg_mode and FLEX_ATTENTION_AVAILABLE:
             # Apply flex_attention with the block mask
-            global
-
-
+            global block_mask
+            need_new_mask = (
+                block_mask is None
+                or block_mask.shape[-2] != query.shape[2]
+                or block_mask.shape[-1] != query.shape[2]
+                or block_mask.device != query.device
+            )
+
+            if need_new_mask:
                 res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
                 seq_len = query.shape[2]
 
@@ -155,7 +161,7 @@ class FluxAttnProcessor2_0:
                 block_mask = create_block_mask(block_diagonal_mask, B=1, H=None,
                                                Q_LEN=seq_len, KV_LEN=seq_len, device=query.device)
 
-            hidden_states =
+            hidden_states = flex_attention(query, key, value, block_mask=block_mask)
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif neg_mode:
             # Fallback to original implementation if flex_attention is not available