Commit
·
83f7703
1
Parent(s):
96eabf2
Upload folder using huggingface_hub
Browse files
animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc
CHANGED
|
Binary files a/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc and b/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc differ
|
|
|
animatediff/utils/convert_from_ckpt.py
CHANGED
|
@@ -198,20 +198,21 @@ def assign_to_checkpoint(
|
|
| 198 |
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 199 |
|
| 200 |
# proj_attn.weight has to be converted from conv 1D to linear
|
| 201 |
-
if "
|
| 202 |
-
|
|
|
|
| 203 |
else:
|
| 204 |
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 205 |
|
| 206 |
|
| 207 |
def conv_attn_to_linear(checkpoint):
|
| 208 |
keys = list(checkpoint.keys())
|
| 209 |
-
attn_keys = ["
|
| 210 |
for key in keys:
|
| 211 |
if ".".join(key.split(".")[-2:]) in attn_keys:
|
| 212 |
if checkpoint[key].ndim > 2:
|
| 213 |
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 214 |
-
elif "
|
| 215 |
if checkpoint[key].ndim > 2:
|
| 216 |
checkpoint[key] = checkpoint[key][:, :, 0]
|
| 217 |
|
|
@@ -632,7 +633,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|
| 632 |
oldKey = {"old": "key", "new": "to_k"}
|
| 633 |
oldQuery = {"old": "query", "new": "to_q"}
|
| 634 |
oldValue = {"old": "value", "new": "to_v"}
|
| 635 |
-
oldOut = {"old": "proj_attn", "new": "to_out"}
|
| 636 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
| 637 |
conv_attn_to_linear(new_checkpoint)
|
| 638 |
|
|
@@ -669,7 +670,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|
| 669 |
oldKey = {"old": "key", "new": "to_k"}
|
| 670 |
oldQuery = {"old": "query", "new": "to_q"}
|
| 671 |
oldValue = {"old": "value", "new": "to_v"}
|
| 672 |
-
oldOut = {"old": "proj_attn", "new": "to_out"}
|
| 673 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
| 674 |
conv_attn_to_linear(new_checkpoint)
|
| 675 |
return new_checkpoint
|
|
|
|
| 198 |
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 199 |
|
| 200 |
# proj_attn.weight has to be converted from conv 1D to linear
|
| 201 |
+
if "to_out.0.weight" in new_path and "decoder" in new_path:
|
| 202 |
+
# turn [512, 512, 1] into [512, 512]
|
| 203 |
+
checkpoint[new_path] = old_checkpoint[path["old"]].squeeze(-1)
|
| 204 |
else:
|
| 205 |
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 206 |
|
| 207 |
|
| 208 |
def conv_attn_to_linear(checkpoint):
|
| 209 |
keys = list(checkpoint.keys())
|
| 210 |
+
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
|
| 211 |
for key in keys:
|
| 212 |
if ".".join(key.split(".")[-2:]) in attn_keys:
|
| 213 |
if checkpoint[key].ndim > 2:
|
| 214 |
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 215 |
+
elif "to_out.0.weight" in key:
|
| 216 |
if checkpoint[key].ndim > 2:
|
| 217 |
checkpoint[key] = checkpoint[key][:, :, 0]
|
| 218 |
|
|
|
|
| 633 |
oldKey = {"old": "key", "new": "to_k"}
|
| 634 |
oldQuery = {"old": "query", "new": "to_q"}
|
| 635 |
oldValue = {"old": "value", "new": "to_v"}
|
| 636 |
+
oldOut = {"old": "proj_attn", "new": "to_out.0"}
|
| 637 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
| 638 |
conv_attn_to_linear(new_checkpoint)
|
| 639 |
|
|
|
|
| 670 |
oldKey = {"old": "key", "new": "to_k"}
|
| 671 |
oldQuery = {"old": "query", "new": "to_q"}
|
| 672 |
oldValue = {"old": "value", "new": "to_v"}
|
| 673 |
+
oldOut = {"old": "proj_attn", "new": "to_out.0"}
|
| 674 |
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
|
| 675 |
conv_attn_to_linear(new_checkpoint)
|
| 676 |
return new_checkpoint
|