Commit
·
09790b2
1
Parent(s):
8eeff9c
update
Browse files
nested_attention_pipeline.py
CHANGED
|
@@ -4,6 +4,7 @@ from typing import List
|
|
| 4 |
import torch
|
| 5 |
from PIL import Image
|
| 6 |
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|
|
|
| 7 |
|
| 8 |
from nested_attention_processor import AttnProcessor, NestedAttnProcessor
|
| 9 |
from utils import get_generator
|
|
@@ -110,7 +111,7 @@ class NestedAdapterInference:
|
|
| 110 |
|
| 111 |
def load_nested_adapter(self):
|
| 112 |
state_dict = {"adapter_modules": {}, "qformer": {}}
|
| 113 |
-
f =
|
| 114 |
for key in f.keys():
|
| 115 |
if key.startswith("adapter_modules."):
|
| 116 |
state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[
|
|
|
|
| 4 |
import torch
|
| 5 |
from PIL import Image
|
| 6 |
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 7 |
+
from safetensors import load_file
|
| 8 |
|
| 9 |
from nested_attention_processor import AttnProcessor, NestedAttnProcessor
|
| 10 |
from utils import get_generator
|
|
|
|
| 111 |
|
| 112 |
def load_nested_adapter(self):
|
| 113 |
state_dict = {"adapter_modules": {}, "qformer": {}}
|
| 114 |
+
f = load_file(self.adapter_ckpt)
|
| 115 |
for key in f.keys():
|
| 116 |
if key.startswith("adapter_modules."):
|
| 117 |
state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[
|