edugp commited on
Commit
52a4ec3
1 Parent(s): caaaeef

Cache model and tokenizer and lock dependencies

Browse files
Files changed (2) hide show
  1. app.py +19 -3
  2. requirements.txt +7 -7
app.py CHANGED
@@ -2,10 +2,10 @@ import os
2
  import sys
3
 
4
  import streamlit as st
 
5
  from huggingface_hub import snapshot_download
6
  from transformers import AutoTokenizer
7
 
8
-
9
  LOCAL_PATH = snapshot_download("flax-community/clip-spanish")
10
  sys.path.append(LOCAL_PATH)
11
 
@@ -15,16 +15,24 @@ from test_on_image import run_inference
15
 
16
  def save_file_to_disk(uplaoded_file):
17
  temp_file = os.path.join("/tmp", uplaoded_file.name)
18
- with open(temp_file,"wb") as f:
19
  f.write(uploaded_file.getbuffer())
20
  return temp_file
21
 
 
 
 
 
 
 
 
22
  def load_tokenizer_and_model():
23
  # load the saved model
24
  tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
25
  model = FlaxHybridCLIP.from_pretrained(LOCAL_PATH)
26
  return tokenizer, model
27
 
 
28
  tokenizer, model = load_tokenizer_and_model()
29
 
30
  st.title("Image-Caption Matching")
@@ -36,7 +44,15 @@ if uploaded_file is not None and text_input:
36
  try:
37
  local_image_path = save_file_to_disk(uploaded_file)
38
  score = run_inference(local_image_path, text_input, model, tokenizer).tolist()
39
- st.image(uploaded_file, caption=text_input, width=None, use_column_width=None, clamp=False, channels='RGB', output_format='auto')
 
 
 
 
 
 
 
 
40
  st.write(f"## Score: {score:.2f}")
41
  finally:
42
  if local_image_path:
 
2
  import sys
3
 
4
  import streamlit as st
5
+ import transformers
6
  from huggingface_hub import snapshot_download
7
  from transformers import AutoTokenizer
8
 
 
9
  LOCAL_PATH = snapshot_download("flax-community/clip-spanish")
10
  sys.path.append(LOCAL_PATH)
11
 
 
15
 
16
  def save_file_to_disk(uplaoded_file):
17
  temp_file = os.path.join("/tmp", uplaoded_file.name)
18
+ with open(temp_file, "wb") as f:
19
  f.write(uploaded_file.getbuffer())
20
  return temp_file
21
 
22
+
23
+ @st.cache(
24
+ hash_funcs={
25
+ transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: id,
26
+ FlaxHybridCLIP: id,
27
+ }
28
+ )
29
  def load_tokenizer_and_model():
30
  # load the saved model
31
  tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
32
  model = FlaxHybridCLIP.from_pretrained(LOCAL_PATH)
33
  return tokenizer, model
34
 
35
+
36
  tokenizer, model = load_tokenizer_and_model()
37
 
38
  st.title("Image-Caption Matching")
 
44
  try:
45
  local_image_path = save_file_to_disk(uploaded_file)
46
  score = run_inference(local_image_path, text_input, model, tokenizer).tolist()
47
+ st.image(
48
+ uploaded_file,
49
+ caption=text_input,
50
+ width=None,
51
+ use_column_width=None,
52
+ clamp=False,
53
+ channels="RGB",
54
+ output_format="auto",
55
+ )
56
  st.write(f"## Score: {score:.2f}")
57
  finally:
58
  if local_image_path:
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- flax
2
- huggingface_hub
3
- jax
4
  streamlit==0.84.1
5
- torch
6
- torchvision
7
- transformers
8
- watchdog
 
1
+ flax==0.3.4
2
+ huggingface-hub==0.0.12
3
+ jax==0.2.17
4
  streamlit==0.84.1
5
+ torch==1.9.0
6
+ torchvision==0.10.0
7
+ transformers==4.8.2
8
+ watchdog==2.1.3