nupurkmr9 committed
Commit 71c8eae · verified · 1 Parent(s): da28f5f

Update pipelines/flux_pipeline/transformer.py

pipelines/flux_pipeline/transformer.py CHANGED

@@ -125,9 +125,15 @@ class FluxAttnProcessor2_0:
 
         if neg_mode and FLEX_ATTENTION_AVAILABLE:
             # Apply flex_attention with the block mask
-            global flex_attention_func, block_mask
-            if flex_attention_func is None:
-                flex_attention_func = torch.compile(flex_attention, dynamic=False)
+            global block_mask
+            need_new_mask = (
+                block_mask is None
+                or block_mask.shape[-2] != query.shape[2]
+                or block_mask.shape[-1] != query.shape[2]
+                or block_mask.device != query.device
+            )
+
+            if need_new_mask:
             res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
             seq_len = query.shape[2]
 
@@ -155,7 +161,7 @@ class FluxAttnProcessor2_0:
             block_mask = create_block_mask(block_diagonal_mask, B=1, H=None,
                                            Q_LEN=seq_len, KV_LEN=seq_len, device=query.device)
 
-            hidden_states = flex_attention_func(query, key, value, block_mask=block_mask)
+            hidden_states = flex_attention(query, key, value, block_mask=block_mask)
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif neg_mode:
             # Fallback to original implementation if flex_attention is not available
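For context, here is a minimal sketch of the block-mask caching pattern this commit introduces, assuming PyTorch's `torch.nn.attention.flex_attention` API (`create_block_mask` / `flex_attention`). The helper name `attend_with_cached_mask` and the `mask_fn` argument are illustrative stand-ins for the pipeline's `block_diagonal_mask` logic, not code from this repository.

```python
# Illustrative sketch (not the pipeline's code): rebuild the flex_attention
# BlockMask only when the sequence length or device changes, then call
# flex_attention directly instead of a torch.compile-wrapped copy.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

_cached_block_mask = None  # module-level cache, mirroring the commit's global


def attend_with_cached_mask(query, key, value, mask_fn):
    """query/key/value: (B, H, S, D); mask_fn: a mask_mod such as block_diagonal_mask."""
    global _cached_block_mask
    seq_len = query.shape[2]
    need_new_mask = (
        _cached_block_mask is None
        or _cached_block_mask.shape[-2] != seq_len
        or _cached_block_mask.shape[-1] != seq_len
        or _cached_block_mask.device != query.device
    )
    if need_new_mask:
        # Recreate the mask only when it no longer matches the current query.
        _cached_block_mask = create_block_mask(
            mask_fn, B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=query.device
        )
    return flex_attention(query, key, value, block_mask=_cached_block_mask)
```

Compared with the previous version, the change avoids rebuilding the BlockMask on every call and drops the `torch.compile` wrapper around `flex_attention`, calling it directly instead.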