dhmeltzer commited on
Commit
d21a4cc
1 Parent(s): fbd29f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -5,13 +5,19 @@ import streamlit as st
5
  #def main():
6
  st.title("Scientific Question Generation")
7
 
8
- API_URL = "https://api-inference.huggingface.co/models/dhmeltzer/bart-large_askscience-qg"
9
- headers = {"Authorization": "Bearer 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"}
 
10
 
11
- def query(payload):
 
 
 
 
12
  response = requests.post(API_URL,
13
  headers=headers,
14
  json=payload)
 
15
  return response.json()
16
 
17
  # User search
@@ -22,18 +28,17 @@ user_input = st.text_area("Question Generator",
22
  st.sidebar.markdown("**Filters**")
23
 
24
  temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.0,.1)
25
- num_results = st.sidebar.slider("Number of search results", 1, 50, 1)
26
-
27
  vector = query([user_input])
28
 
29
  if user_input:
30
-
31
- output = query({
32
- "inputs": user_input,
33
- "temperature":temperature,
34
- "wait_for_model":True})
35
-
36
- st.write(output)
 
37
 
38
 
39
  #if __name__ == "__main__":
 
5
  #def main():
6
  st.title("Scientific Question Generation")
7
 
8
+ checkpoints = ['dhmeltzer/bart-large_askscience-qg',
9
+ 'dhmeltzer/flan-t5-base_askscience-qg',
10
+ 'google/flan-t5-xxl']
11
 
12
+ headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}
13
+
14
+ def query(checkpoint, payload):
15
+ API_URL = f"https://api-inference.huggingface.co/models/{checkpoint}}"
16
+
17
  response = requests.post(API_URL,
18
  headers=headers,
19
  json=payload)
20
+
21
  return response.json()
22
 
23
  # User search
 
28
  st.sidebar.markdown("**Filters**")
29
 
30
  temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.0,.1)
 
 
31
  vector = query([user_input])
32
 
33
  if user_input:
34
+ for checkpoint in checkpoints:
35
+ output = query(checkpoint,{
36
+ "inputs": user_input,
37
+ "temperature":temperature,
38
+ "wait_for_model":True})[0][0]['generated_text']
39
+
40
+ model_name = checkpoints.split('/')[1]
41
+ st.write(f'Model {model_name}: output')
42
 
43
 
44
  #if __name__ == "__main__":