jeffeux committed on
Commit 30f0864
1 Parent(s): e0fbe2a

balloons debug

Files changed (1)
  1. main.py +11 -0
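
st.balloons() is Streamlit's celebratory balloon animation, and this commit uses it as a debugging probe: a call is dropped after each stage of the script so that, on the next rerun, every animation that fires confirms execution reached that point. That localizes a silent hang or crash without reading the server logs. A minimal sketch of the same checkpoint pattern (the step_one/step_two helpers are hypothetical stand-ins for the app's init and inference stages, not part of this commit):

import streamlit as st

def step_one():
    # hypothetical stand-in for the app's model_init()
    return "model ready"

def step_two(text):
    # hypothetical stand-in for tokenize / generate / decode
    return text.upper()

state = step_one()
st.balloons()    # checkpoint: initialization finished

result = step_two(state)
st.balloons()    # checkpoint: inference finished

st.markdown(result)

The single st.snow() after the except block plays a visibly different animation, presumably to mark "ran to the end" apart from the intermediate balloon checkpoints. (The commented-out default prompt in the diff is Chinese for "Q: What is the tallest building in Taiwan? A:".)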
main.py CHANGED
@@ -2,6 +2,7 @@
 import os, logging, torch, streamlit as st
 from transformers import (
     AutoTokenizer, AutoModelForCausalLM)
+st.balloons()
 
 # --------------------- HELPER --------------------- #
 def C(text, color="yellow"):
@@ -17,12 +18,14 @@ def C(text, color="yellow"):
     return (
         f"{color_dict.get(color, None)}"
         f"{text}{color_dict[None]}")
+st.balloons()
 
 # ------------------ ENVIORNMENT ------------------- #
 os.environ["HF_ENDPOINT"] = "https://huggingface.co"
 device = ("cuda"
     if torch.cuda.is_available() else "cpu")
 logging.info(C("[INFO] "f"device = {device}"))
+st.balloons()
 
 # ------------------ INITITALIZE ------------------- #
 @st.cache
@@ -41,20 +44,28 @@ def model_init():
     return tokenizer, model
 
 tokenizer, model = model_init()
+st.balloons()
 
 try:
     # ===================== INPUT ====================== #
     # prompt = "問：台灣最高的建築物是？答："  #@param {type:"string"}
     prompt = st.text_input("Prompt: ")
+    st.balloons()
+
 
     # =================== INFERENCE ==================== #
     if prompt:
+        st.balloons()
         with torch.no_grad():
             [texts_out] = model.generate(
                 **tokenizer(
                     prompt, return_tensors="pt"
                 ).to(device))
+        st.balloons()
         output_text = tokenizer.decode(texts_out)
+        st.balloons()
         st.markdown(output_text)
+        st.balloons()
 except Exception as err:
     st.write(str(err))
+st.snow()
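
To reproduce, launch the app with streamlit run main.py and note the last animation that fires before the hang or error. One limitation of the approach: every st.balloons() call looks identical, so intermediate checkpoints can only be told apart by counting. A variant with labeled markers that persist on the page (a sketch; the CHECK helper is hypothetical and not part of this commit):

import streamlit as st

def CHECK(tag):
    # a labeled, persistent marker instead of a transient animation
    st.write(f"checkpoint reached: {tag}")

CHECK("imports done")
prompt = st.text_input("Prompt: ")
CHECK("got input")
if prompt:
    CHECK("entering inference")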