Spaces:
Runtime error
Runtime error
Update pipelines/flux_pipeline/transformer.py
Browse files
pipelines/flux_pipeline/transformer.py
CHANGED
|
@@ -126,12 +126,7 @@ class FluxAttnProcessor2_0:
|
|
| 126 |
if neg_mode and FLEX_ATTENTION_AVAILABLE:
|
| 127 |
# Apply flex_attention with the block mask
|
| 128 |
global block_mask
|
| 129 |
-
need_new_mask =
|
| 130 |
-
block_mask is None
|
| 131 |
-
or block_mask.shape[-2] != query.shape[2]
|
| 132 |
-
or block_mask.shape[-1] != query.shape[2]
|
| 133 |
-
or block_mask.device != query.device
|
| 134 |
-
)
|
| 135 |
|
| 136 |
if need_new_mask:
|
| 137 |
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|
|
|
|
| 126 |
if neg_mode and FLEX_ATTENTION_AVAILABLE:
|
| 127 |
# Apply flex_attention with the block mask
|
| 128 |
global block_mask
|
| 129 |
+
need_new_mask = block_mask is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
if need_new_mask:
|
| 132 |
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|