SEPredictor / app.py
YchKhan's picture
Update app.py
3421540
from flask import Flask, jsonify, request, render_template, stream_with_context, Response, flash
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer, util
import torch
import re
# Flask application object; the route and the template helper defined in this
# file are attached to it, and it is what gets run at the bottom of the module.
app = Flask(__name__)
# Decimal number with an optional fraction and an optional *signed* exponent,
# e.g. "-0.12", "3", "1.5e-05" or "2E+03". Compiled once because the function
# is applied to every row of the spreadsheet.
_FLOAT_PATTERN = re.compile(r'-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?')


def extract_embeddings(embeddings_str):
    """Parse every numeric literal found in *embeddings_str* into a float.

    Any text between the numbers (brackets, commas, whitespace, ...) is
    ignored, so the string form of a list or array of floats can be passed
    directly.

    Args:
        embeddings_str: Text containing the serialized numbers.

    Returns:
        A list of floats in the order they appear in the text; an empty list
        when no number is found.
    """
    # The exponent accepts both signs: with only "-" allowed, a value such as
    # "1.5e+02" would be split into two separate numbers (1.5 and 2.0).
    return [float(token) for token in _FLOAT_PATTERN.findall(embeddings_str)]
# Load the reference table. The columns read below are Embeddings, Number,
# Standards, URL and Description.
df = pd.read_excel("ebd4appdom.xlsx")
# Sentence encoder used to embed the user's text at request time.
# NOTE(review): presumably the same model that produced the stored Embeddings
# column -- confirm, otherwise the similarity scores are not meaningful.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
# Each Embeddings cell is parsed from its text form into a list of floats.
df['Embeddings'] = df['Embeddings'].apply(extract_embeddings)
# Row-aligned lists: position i in each list comes from the same table row,
# so an index found by the similarity search can be used on all of them.
descriptions_embeddings = list(df.Embeddings)
patnums = list(df["Number"])
standards = list(df["Standards"])
urls = list(df["URL"])
descriptions = list(df.Description)
#split function
def split_string(s, max_len, overlap, min_words_count=0):
    """Split *s* into overlapping chunks of whitespace-separated words.

    Args:
        s: Text to split.
        max_len: Number of words in every chunk except possibly the last.
        overlap: Number of trailing words of a chunk that are repeated at the
            start of the next one.
        min_words_count: Chunks containing this many words or fewer are
            dropped from the result.

    Returns:
        The chunks as strings (words joined by single spaces), in order.
        An empty or whitespace-only *s* yields an empty list.

    Raises:
        ValueError: If the text is longer than ``max_len`` words and
            ``overlap >= max_len``; the window could never advance and the
            loop would not terminate.
    """
    words = s.split()
    step = max_len - overlap
    # Only reject the arguments when the loop below would actually be entered,
    # so short inputs keep working exactly as before.
    if step <= 0 and len(words) > max_len:
        raise ValueError("overlap must be smaller than max_len")

    chunks = []
    start = 0
    while start + max_len < len(words):
        chunks.append(" ".join(words[start:start + max_len]))
        start += step
    # The final chunk takes whatever is left, so it may be shorter than max_len.
    chunks.append(" ".join(words[start:]))

    return [chunk for chunk in chunks if len(chunk.split()) > min_words_count]
from collections import Counter
def top_five_strings(strings, size):
    """Return the five most frequent entries of *strings* with their weights.

    Args:
        strings: Iterable of hashable items (strings in this file).
        size: Divisor applied to each occurrence count to obtain its weight.

    Returns:
        Up to five dicts of the form ``{'word': item, 'freq': count / size}``,
        ordered by decreasing ``freq``. Items with equal weight keep the order
        in which they first appeared in *strings*.
    """
    weighted = []
    for text, occurrences in Counter(strings).items():
        weighted.append({'word': text, 'freq': occurrences / size})
    # Stable in-place sort, heaviest first.
    weighted.sort(key=lambda entry: entry['freq'], reverse=True)
    return weighted[:5]
def stream_template(template_name, **context):
    """Return a Jinja template stream for *template_name* with buffering off.

    Args:
        template_name: Name of the template to load from the app's Jinja
            environment.
        **context: Variables made available to the template.

    Returns:
        The stream object produced by the template's ``stream()`` call, after
        ``disable_buffering()`` has been invoked on it.
    """
    # Let the application update the context dict in place before rendering.
    app.update_template_context(context)
    stream = app.jinja_env.get_template(template_name).stream(context)
    stream.disable_buffering()
    return stream
def infer_gen(query):
    """Yield similarity-search results for *query*, one text chunk at a time.

    The query is split into overlapping chunks of 80 words (3 words of
    overlap). Each chunk is embedded and compared with every reference
    embedding, and its best matches (at most five) are recorded.

    Args:
        query: Free text submitted by the user.

    Yields:
        After each chunk, the cumulative ``results`` list (the same list
        object every time). Each entry has the form
        ``[[chunk_text, 'sampleN'], match, ...]`` where every ``match`` is a
        dict with the keys ``score``, ``standards``, ``desc`` and ``url``.
        The last item yielded is the rendered ``index.html`` page.
    """
    user_samples = split_string(query, 80, 3)
    top_k = min(5, len(descriptions))
    # Build the reference tensor once per request; passing the nested Python
    # lists to cos_sim would convert them again for every chunk.
    corpus_embeddings = torch.as_tensor(descriptions_embeddings)
    results = []
    std = []
    for i, user_sample in enumerate(user_samples):
        print(f"processing {i}/{len(user_samples)}")
        sp = [[user_sample, f"sample{i}"]]
        print("embedding ; ", user_sample[:10])
        sample_embedding = embedder.encode(user_sample, convert_to_tensor=True)
        print("scoring : ", user_sample[:10])
        cos_scores = util.cos_sim(sample_embedding, corpus_embeddings)[0]
        print("scored!", sample_embedding[:1])
        top_scores, top_indices = torch.topk(cos_scores, top_k)
        print("topped!")
        # Standards seen for THIS chunk only, each listed once. Sharing one
        # list across chunks would add earlier chunks' standards to ``std``
        # again on every iteration and inflate their frequency.
        sample_specs = []
        for score, idx in zip(top_scores, top_indices):
            row = int(idx)
            print("creating dict for : ", user_sample[:10])
            sp.append(dict(
                score=round(score.item(), 4),
                standards=standards[row],
                desc=descriptions[row],
                url=urls[row],
            ))
            for ts in standards[row].split(", "):
                if ts not in sample_specs:
                    sample_specs.append(ts)
        std.extend(sample_specs)
        results.append(sp)
        yield results
    print("ranking")
    # NOTE(review): the ranking is computed but not passed to the template;
    # add ``rankings=rankings`` below once index.html is meant to display it.
    rankings = top_five_strings(std, len(user_samples))
    print("ranked")
    yield render_template('index.html', results=results)
def infer(query):
    """Yield every item produced by ``infer_gen(query)`` converted to ``str``.

    Args:
        query: Free text forwarded unchanged to ``infer_gen``.

    Yields:
        ``str(item)`` for each item of the underlying generator, in order.
    """
    # infer_gen has to be *called* to obtain a generator; iterating the bare
    # function object raises TypeError and never uses the query.
    for chunk in infer_gen(query):
        yield str(chunk)
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the search page (GET) or stream the inference output (POST).

    Returns:
        For POST, a streaming response whose body is produced by
        ``infer(query)`` with ``query`` taken from the submitted form.
        For GET, the rendered ``index.html`` with empty results.
    """
    if request.method == 'POST':
        query = request.form['query']
        # The generator keeps running after this view has returned and ends
        # by calling render_template(), which needs an active context;
        # stream_with_context keeps the request context alive while the body
        # is produced. Wrapping it in Response makes the streaming explicit
        # rather than relying on Flask accepting a bare generator.
        return Response(stream_with_context(infer(query)))
    return render_template('index.html', results=None, query="", rankings=None)
# Start the server only when this file is executed directly, not on import.
if __name__ == '__main__':
    # Bind to all network interfaces on port 7860.
    app.run(host="0.0.0.0", port=7860)