dejanseo committed on
Commit
a3ab355
1 Parent(s): 77c9a6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -20
app.py CHANGED
@@ -22,11 +22,10 @@ except FileNotFoundError:
22
  st.stop()
23
 
24
  def tokenize(text):
25
- # Ensure the text is a string before splitting
26
  if isinstance(text, str):
27
  return text.split()
28
  else:
29
- return [] # Return an empty list if the text is not a string
30
 
31
  def embed_text(text_series, fasttext_model):
32
  embeddings = []
@@ -40,26 +39,21 @@ def embed_text(text_series, fasttext_model):
40
  return np.array(embeddings)
41
 
42
  def preprocess_input(query, title, description, url, fasttext_model):
43
- # Convert None or NaN to an empty string to avoid errors during tokenization
44
  query = str(query) if pd.notna(query) else ''
45
  title = str(title) if pd.notna(title) else ''
46
  description = str(description) if pd.notna(description) else ''
47
  url = str(url) if pd.notna(url) else ''
48
 
49
- # Embed each text field using FastText
50
  query_ft = embed_text(pd.Series([query]), fasttext_model)
51
  title_ft = embed_text(pd.Series([title]), fasttext_model)
52
  description_ft = embed_text(pd.Series([description]), fasttext_model)
53
  url_ft = embed_text(pd.Series([url]), fasttext_model)
54
 
55
- # Combine embeddings into a single array
56
  combined_features = np.hstack([query_ft, title_ft, description_ft, url_ft])
57
 
58
- # Convert combined_features to a DMatrix for XGBoost
59
  dmatrix = xgb.DMatrix(combined_features)
60
  return dmatrix
61
 
62
- # Function to extract title and description from a URL
63
  def extract_title_description(url):
64
  headers = {
65
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.81 Safari/537.36'
@@ -74,11 +68,10 @@ def extract_title_description(url):
74
  except Exception as e:
75
  return 'Error extracting title', 'Error extracting description'
76
 
77
- # Function to make predictions
78
  def predict(query, title, description, url, fasttext_model):
79
  dmatrix = preprocess_input(query, title, description, url, fasttext_model)
80
- probability = model.predict(dmatrix, validate_features=False)[0] # Get the probability prediction
81
- binary_prediction = int(probability >= 0.5) # Convert to binary: 1 if >= 0.5, else 0
82
  return binary_prediction, probability
83
 
84
  # Streamlit interface
@@ -101,8 +94,6 @@ with tab1:
101
  binary_result, confidence = predict(query, title, description, url, fasttext_model)
102
  st.write(f'Predicted +/-: {binary_result}')
103
  st.write(f'Conf.: {confidence:.2%}')
104
-
105
- # Convert confidence to a percentage and cast to int
106
  confidence_percentage = int(confidence * 100)
107
  st.progress(confidence_percentage)
108
  else:
@@ -115,8 +106,6 @@ with tab2:
115
 
116
  if uploaded_file is not None:
117
  df = pd.read_csv(uploaded_file)
118
-
119
- # Select only the columns necessary for inference
120
  required_columns = ['Query', 'Title', 'Description', 'URL']
121
 
122
  if set(required_columns).issubset(df.columns):
@@ -127,15 +116,12 @@ with tab2:
127
  predictions.append(binary_result)
128
  confidences.append(confidence)
129
 
130
- # Add binary predictions and confidence to the DataFrame
131
  df['+/-'] = predictions
132
  df['Conf.'] = [f"{conf:.2%}" for conf in confidences]
133
 
134
- # Reorder the columns to put '+/-' and 'Conf.' at the front
135
  cols = ['+/-', 'Conf.'] + [col for col in df.columns if col not in ['+/-', 'Conf.']]
136
  df = df[cols]
137
 
138
- # Display and allow download of results
139
  st.write(df)
140
  st.download_button("Download Predictions", df.to_csv(index=False), "predictions.csv")
141
  else:
@@ -149,11 +135,13 @@ with tab3:
149
 
150
  if st.button('Scrape A/B'):
151
  title_A, description_A = extract_title_description(url)
 
 
152
  st.write(f'Extracted Title A: {title_A}')
153
  st.write(f'Extracted Description A: {description_A}')
154
 
155
- title_B = st.text_input('Title B', value=title_A)
156
- description_B = st.text_area('Description B', value=description_A)
157
 
158
  if st.button('Predict A/B'):
159
  if query and url:
@@ -163,7 +151,6 @@ with tab3:
163
  st.write(f'Results for A: Predicted +/-: {binary_result_A}, Conf.: {confidence_A:.2%}')
164
  st.write(f'Results for B: Predicted +/-: {binary_result_B}, Conf.: {confidence_B:.2%}')
165
 
166
- # Determine improvement
167
  if binary_result_A == 1 and binary_result_B == 0:
168
  st.write("B is worse than A")
169
  elif binary_result_A == 0 and binary_result_B == 1:
 
22
  st.stop()
23
 
24
  def tokenize(text):
 
25
  if isinstance(text, str):
26
  return text.split()
27
  else:
28
+ return []
29
 
30
  def embed_text(text_series, fasttext_model):
31
  embeddings = []
 
39
  return np.array(embeddings)
40
 
41
  def preprocess_input(query, title, description, url, fasttext_model):
 
42
  query = str(query) if pd.notna(query) else ''
43
  title = str(title) if pd.notna(title) else ''
44
  description = str(description) if pd.notna(description) else ''
45
  url = str(url) if pd.notna(url) else ''
46
 
 
47
  query_ft = embed_text(pd.Series([query]), fasttext_model)
48
  title_ft = embed_text(pd.Series([title]), fasttext_model)
49
  description_ft = embed_text(pd.Series([description]), fasttext_model)
50
  url_ft = embed_text(pd.Series([url]), fasttext_model)
51
 
 
52
  combined_features = np.hstack([query_ft, title_ft, description_ft, url_ft])
53
 
 
54
  dmatrix = xgb.DMatrix(combined_features)
55
  return dmatrix
56
 
 
57
  def extract_title_description(url):
58
  headers = {
59
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.81 Safari/537.36'
 
68
  except Exception as e:
69
  return 'Error extracting title', 'Error extracting description'
70
 
 
71
  def predict(query, title, description, url, fasttext_model):
72
  dmatrix = preprocess_input(query, title, description, url, fasttext_model)
73
+ probability = model.predict(dmatrix, validate_features=False)[0]
74
+ binary_prediction = int(probability >= 0.5)
75
  return binary_prediction, probability
76
 
77
  # Streamlit interface
 
94
  binary_result, confidence = predict(query, title, description, url, fasttext_model)
95
  st.write(f'Predicted +/-: {binary_result}')
96
  st.write(f'Conf.: {confidence:.2%}')
 
 
97
  confidence_percentage = int(confidence * 100)
98
  st.progress(confidence_percentage)
99
  else:
 
106
 
107
  if uploaded_file is not None:
108
  df = pd.read_csv(uploaded_file)
 
 
109
  required_columns = ['Query', 'Title', 'Description', 'URL']
110
 
111
  if set(required_columns).issubset(df.columns):
 
116
  predictions.append(binary_result)
117
  confidences.append(confidence)
118
 
 
119
  df['+/-'] = predictions
120
  df['Conf.'] = [f"{conf:.2%}" for conf in confidences]
121
 
 
122
  cols = ['+/-', 'Conf.'] + [col for col in df.columns if col not in ['+/-', 'Conf.']]
123
  df = df[cols]
124
 
 
125
  st.write(df)
126
  st.download_button("Download Predictions", df.to_csv(index=False), "predictions.csv")
127
  else:
 
135
 
136
  if st.button('Scrape A/B'):
137
  title_A, description_A = extract_title_description(url)
138
+ st.session_state['title_A'] = title_A
139
+ st.session_state['description_A'] = description_A
140
  st.write(f'Extracted Title A: {title_A}')
141
  st.write(f'Extracted Description A: {description_A}')
142
 
143
+ title_B = st.text_input('Title B', value=st.session_state.get('title_A', ''))
144
+ description_B = st.text_area('Description B', value=st.session_state.get('description_A', ''))
145
 
146
  if st.button('Predict A/B'):
147
  if query and url:
 
151
  st.write(f'Results for A: Predicted +/-: {binary_result_A}, Conf.: {confidence_A:.2%}')
152
  st.write(f'Results for B: Predicted +/-: {binary_result_B}, Conf.: {confidence_B:.2%}')
153
 
 
154
  if binary_result_A == 1 and binary_result_B == 0:
155
  st.write("B is worse than A")
156
  elif binary_result_A == 0 and binary_result_B == 1: