vsrinivas commited on
Commit
14cb0d3
1 Parent(s): 7999f91

Create funcs.py

Browse files
Files changed (1) hide show
  1. funcs.py +171 -0
funcs.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import io
4
+ import matplotlib.pyplot as plt
5
+ import pandas as pd
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import pipeline
8
+ from datetime import datetime
9
+ from PIL import Image
10
+ import os
11
+ from datetime import datetime
12
+ from openai import OpenAI
13
+ from ai71 import AI71
14
+
15
+ if torch.cuda.is_available():
16
+ model = model.to('cuda')
17
+
18
+ # dials_embeddings = pd.read_pickle('dials_embeddings.pkl')
19
+ # dials_embeddings = pd.read_pickle('https://huggingface.co/datasets/vsrinivas/CBT_dialogue_embed_ds/resolve/main/dials_embeddings.pkl')
20
+ dials_embeddings = pd.read_pickle('https://huggingface.co/datasets/vsrinivas/CBT_dialogue_embed_ds/resolve/main/kaggle_therapy_embeddings.pkl')
21
+ with open ('emotion_group_labels.txt') as file:
22
+ emotion_group_labels = file.read().splitlines()
23
+
24
+ embed_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
25
+ classifier = pipeline("zero-shot-classification", model ='facebook/bart-large-mnli')
26
+
27
+ AI71_BASE_URL = "https://api.ai71.ai/v1/"
28
+ AI71_API_KEY = os.getenv('AI71_API_KEY')
29
+
30
+ # Detect emotions from patient dialogues
31
+ def detect_emotions(text):
32
+ emotion = classifier(text, candidate_labels=emotion_group_labels, batch_size=16)
33
+ top_5_scores = [i/sum(emotion['scores'][:5]) for i in emotion['scores'][:5]]
34
+ top_5_emotions = emotion['labels'][:5]
35
+ emotion_set = {l: "{:.2%}".format(s) for l, s in zip(top_5_emotions, top_5_scores)}
36
+ return emotion_set
37
+
38
+ # Measure cosine similarity between a pair of vectors
39
+ def cosine_distance(vec1,vec2):
40
+ cosine = (np.dot(vec1, vec2)/(np.linalg.norm(vec1)*np.linalg.norm(vec2)))
41
+ return cosine
42
+
43
+ # Generate an image of trigger emotions
44
+ def generate_triggers_img(items):
45
+ labels = list(items.keys())
46
+ values = [float(v.strip('%')) for v in items.values()] # Convert to float for plotting
47
+
48
+ new_items = {k:v for k, v in zip(labels, values)}
49
+ new_items = dict(sorted(new_items.items(), key=lambda item: item[1]))
50
+ labels = list(new_items.keys())
51
+ values = list(new_items.values())
52
+
53
+ fig, ax = plt.subplots(figsize=(10, 6))
54
+ colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
55
+
56
+ bars = ax.barh(labels, values, color=colors)
57
+
58
+ for spine in ax.spines.values():
59
+ spine.set_visible(False)
60
+
61
+ ax.tick_params(axis='y', labelsize=18)
62
+ ax.xaxis.set_visible(False)
63
+ ax.yaxis.set_ticks_position('none')
64
+
65
+ for bar in bars:
66
+ width = bar.get_width()
67
+ ax.text(width, bar.get_y() + bar.get_height()/2, f'{width:.2f}%',
68
+ ha='left', va='center', fontweight='bold', fontsize=18)
69
+
70
+ plt.tight_layout()
71
+ plt.savefig('triggeres.png')
72
+ triggers_img = Image.open('triggeres.png')
73
+ return triggers_img
74
+
75
+ class session_processor:
76
+ def __init__(self):
77
+ self.session_conversation = []
78
+
79
+ # Generate therapist responses and patient triggers
80
+ def get_doc_response_emotions(user_message, therapy_session_conversation):
81
+
82
+ user_messages = []
83
+ user_messages.append(user_message)
84
+ emotion_set = detect_emotions(user_message)
85
+ print(emotion_set)
86
+
87
+ emotions_msg = generate_triggers_img(emotion_set)
88
+ user_embedding = embed_model.encode(user_message, device='cuda' if torch.cuda.is_available() else 'cpu')
89
+
90
+ similarities =[]
91
+ for v in dials_embeddings['embeddings']:
92
+ similarities.append(cosine_distance(user_embedding,v))
93
+
94
+ top_match_index = similarities.index(max(similarities))
95
+ # doc_response = dials_embeddings.iloc[top_match_index+1]['Doctor']
96
+ doc_response = dials_embeddings.iloc[top_match_index]['Doctor']
97
+
98
+ therapy_session_conversation.append(["User: "+user_message, "Therapist: "+doc_response])
99
+
100
+ self.session_conversation.extend(["User: "+user_message, "Therapist: "+doc_response])
101
+
102
+ print(f"User's message: {user_message}")
103
+ print(f"RAG Matching message: {dials_embeddings.iloc[top_match_index]['Patient']}")
104
+ # print(f"Therapist's response: {dials_embeddings.iloc[top_match_index+1]['Doctor']}\n\n")
105
+ print(f"Therapist's response: {dials_embeddings.iloc[top_match_index]['Doctor']}\n\n")
106
+
107
+ return '', therapy_session_conversation, emotions_msg
108
+
109
+ # Generate summarization and recommendations for teh session
110
+ def summarize_and_recommend():
111
+ session_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
112
+ session_conversation_processed = self.session_conversation.copy()
113
+ session_conversation_processed.insert(0, "Session_time: "+session_time)
114
+ session_conversation_processed ='\n'.join(session_conversation_processed)
115
+ print("Session conversation:", session_conversation_processed)
116
+
117
+ AI71_BASE_URL = "https://api.ai71.ai/v1/"
118
+
119
+ client = OpenAI(
120
+ api_key=AI71_API_KEY,
121
+ base_url=AI71_BASE_URL,
122
+ )
123
+
124
+ full_summary = ""
125
+ for chunk in AI71(AI71_API_KEY).chat.completions.create(
126
+ model="tiiuae/falcon-180b-chat",
127
+ messages=[
128
+ {"role": "system", "content": """You are an Expert Cognitive Behavioural Therapist and Precis writer.
129
+ Summarize the below user content <<<session_conversation_processed>>> into useful, ethical, relevant and realistic phrases with a format
130
+ Session Time:
131
+ Summary of the patient messages: #in two to four sentences
132
+ Summary of therapist messages: #in two to three sentences:
133
+ Summary of the whole session: # in two to three sentences. Ensure the entire session summary strictly does not exceed 100 tokens."""},
134
+ {"role": "user", "content": session_conversation_processed},
135
+ ],
136
+ stream=True,
137
+ ):
138
+ if chunk.choices[0].delta.content:
139
+ summary = chunk.choices[0].delta.content
140
+ # print("Chunk summary:", summary, sep="", end="", flush=True)
141
+ full_summary += summary
142
+ full_summary = full_summary.replace('User:', '').strip()
143
+ print("\n")
144
+ print("Full summary:", full_summary)
145
+
146
+ full_recommendations = ""
147
+ for chunk in AI71(AI71_API_KEY).chat.completions.create(
148
+ model="tiiuae/falcon-180b-chat",
149
+ messages=[
150
+ {"role": "system", "content": """You are an expert Cognitive Behavioural Therapist.
151
+ Based on the full summary <<<full_summary>>> provide clinically valid, useful, appropriate action plan for the Patient as a bullted list.
152
+ The list shall contain both medical and non medical prescriptions, dos and donts. The format of response shall be in passive voice with proper tense.
153
+ - The patient is referred to........ #in one sentence
154
+ - The patient is advised to ........ #in one sentence
155
+ - The patient is refrained from........ #in one sentence
156
+ - It is suggested that tha patient ........ #in one sentence
157
+ - Scheduled a follow-up session with the patient........#in one sentence
158
+ *Ensure the list contains NOT MORE THAN 7 points"""},
159
+ {"role": "user", "content": full_summary},
160
+ ],
161
+ stream=True,
162
+ ):
163
+ if chunk.choices[0].delta.content:
164
+ rec = chunk.choices[0].delta.content
165
+ # print("Chunk recommendation:", rec, sep="", end="", flush=True)
166
+ full_recommendations += rec
167
+ full_recommendations = full_recommendations.replace('User:', '').strip()
168
+ print("\n")
169
+ print("Full recommendations:", full_recommendations)
170
+ self.session_conversation=[]
171
+ return full_summary, full_recommendations