YchKhan commited on
Commit
b7f5159
1 Parent(s): c778b9c

Update classification.py

Browse files
Files changed (1) hide show
  1. classification.py +5 -0
classification.py CHANGED
@@ -177,8 +177,13 @@ def match_categories(df, category_df, treshold=0.45):
177
  if isinstance(ebd_content, torch.Tensor):
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
 
 
 
 
180
  for j in high_score_indices:
181
  df.loc[index, category_df.loc[j, 'topic']] = float(cos_scores[j])
 
182
  return df
183
 
184
  def save_data(df, filename):
 
177
  if isinstance(ebd_content, torch.Tensor):
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
180
+ categories_list.append([category_df.loc[index, 'description'] for index in high_score_indices])
181
+ experts_list.append([category_df.loc[index, 'experts'] for index in high_score_indices])
182
+ topic_list.append([category_df.loc[index, 'topic'] for index in high_score_indices])
183
+ scores_list.append([float(cos_scores[index]) for index in high_score_indices])
184
  for j in high_score_indices:
185
  df.loc[index, category_df.loc[j, 'topic']] = float(cos_scores[j])
186
+
187
  return df
188
 
189
  def save_data(df, filename):