Spaces:
Build error
Build error
Add files
Browse files- .gitattributes +1 -0
- README.md +1 -1
- app.py +3 -4
- cc3m_embeddings.pkl +3 -0
- fromage/models.py +1 -1
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
cc3m_embeddings.pkl filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
---
|
2 |
title: FROMAGe
|
3 |
emoji: 🧀
|
4 |
-
sdk:
|
5 |
colorFrom: blue
|
6 |
colorTo: red
|
7 |
pinned: true
|
|
|
1 |
---
|
2 |
title: FROMAGe
|
3 |
emoji: 🧀
|
4 |
+
sdk: gradio
|
5 |
colorFrom: blue
|
6 |
colorTo: red
|
7 |
pinned: true
|
app.py
CHANGED
@@ -13,10 +13,9 @@ import tempfile
|
|
13 |
class FromageChatBot:
|
14 |
def __init__(self):
|
15 |
# Download model from HF Hub.
|
16 |
-
huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
|
17 |
-
huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
|
18 |
-
|
19 |
-
self.model = models.load_fromage('./')
|
20 |
self.chat_history = ''
|
21 |
self.input_image = None
|
22 |
|
|
|
13 |
class FromageChatBot:
|
14 |
def __init__(self):
|
15 |
# Download model from HF Hub.
|
16 |
+
ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
|
17 |
+
args_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
|
18 |
+
self.model = models.load_fromage('./', args_path, ckpt_path)
|
|
|
19 |
self.chat_history = ''
|
20 |
self.input_image = None
|
21 |
|
cc3m_embeddings.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a20fa8168bd72e848ff088820b767383dded455a57ac5dd2d97d43e600402195
|
3 |
+
size 2979901225
|
fromage/models.py
CHANGED
@@ -594,7 +594,7 @@ class Fromage(nn.Module):
|
|
594 |
return return_outputs
|
595 |
|
596 |
|
597 |
-
def load_fromage(
|
598 |
model_args_path = os.path.join(model_dir, 'model_args.json')
|
599 |
model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
|
600 |
embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]
|
|
|
594 |
return return_outputs
|
595 |
|
596 |
|
597 |
+
def load_fromage(embeddings_dir: str, args_path: str, ckpt_path: str) -> Fromage:
|
598 |
model_args_path = os.path.join(model_dir, 'model_args.json')
|
599 |
model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
|
600 |
embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]
|