import os
import warnings

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFaceHub, OpenAI

warnings.filterwarnings("ignore")


class LLMResponseGenerator:
    """Wraps LangChain LLM calls that generate the assistant's replies."""

    def __init__(self):
        print("initialized")
|
def llm_inference( |
|
self, |
|
model_type: str, |
|
question: str, |
|
prompt_template: str, |
|
context: str, |
|
ai_tone: str, |
|
questionnaire: str, |
|
user_text: str, |
|
openai_model_name: str = "", |
|
|
|
hf_repo_id: str = "mistralai/Mistral-7B-Instruct-v0.2", |
|
temperature: float = 0.5, |
|
max_length: int = 128 * 4, |
|
) -> str: |
|
"""Call HuggingFace/OpenAI model for inference |
|
|
|
Given a question, prompt_template, and other parameters, this function calls the relevant |
|
API to fetch LLM inference results. |
|
|
|
Args: |
|
model_str: Denotes the LLM vendor's name. Can be either 'huggingface' or 'openai' |
|
question: The question to be asked to the LLM. |
|
prompt_template: The prompt template itself. |
|
context: Instructions for the LLM. |
|
ai_tone: Can be either empathy, encouragement or suggest medical help. |
|
questionnaire: Can be either depression, anxiety or adhd. |
|
user_text: Response given by the user. |
|
hf_repo_id: The Huggingface model's repo_id |
|
temperature: (Default: 1.0). Range: Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability. |
|
max_length: Integer to define the maximum length in tokens of the output summary. |
|
|
|
Returns: |
|
A Python string which contains the inference result. |
|
|
|
HuggingFace repo_id examples: |
|
- google/flan-t5-xxl |
|
- tiiuae/falcon-7b-instruct |
|
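
        Example (an illustrative sketch, not from the original code; assumes
        the HUGGINGFACEHUB_API_TOKEN environment variable is set and that
        `template` is a prompt string declaring the variables shown below):

            generator = LLMResponseGenerator()
            reply = generator.llm_inference(
                model_type="huggingface",
                question="How often do you lose focus?",
                prompt_template=template,
                context="You are a supportive, non-medical assistant.",
                ai_tone="EMPATHY",
                questionnaire="ADHD",
                user_text="I get distracted constantly.",
            )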
        """
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=[
                "context",
                "ai_tone",
                "questionnaire",
                "question",
                "user_text",
            ],
        )

        if model_type == "openai":
            llm = OpenAI(
                model_name=openai_model_name,
                temperature=temperature,
                max_tokens=max_length,
            )
            llm_chain = LLMChain(prompt=prompt, llm=llm)
            return llm_chain.run(
                context=context,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                question=question,
                user_text=user_text,
            )

        elif model_type == "huggingface":
            llm = HuggingFaceHub(
                repo_id=hf_repo_id,
                model_kwargs={"temperature": temperature, "max_length": max_length},
            )
            llm_chain = LLMChain(prompt=prompt, llm=llm)
            response = llm_chain.run(
                context=context,
                ai_tone=ai_tone,
                questionnaire=questionnaire,
                question=question,
                user_text=user_text,
            )
            print(response)

            # The HuggingFace output can echo the prompt, so keep only the
            # text after the "Response;" sentinel defined in the template;
            # fall back to the full response if the sentinel is absent.
            sentinel = "Response;"
            response_start_index = response.find(sentinel)
            if response_start_index == -1:
                return response.strip()
            return response[response_start_index + len(sentinel):].strip()

        else:
            raise ValueError(
                "Invalid model_type: expected either 'openai' or 'huggingface'."
            )


if __name__ == "__main__":
    # HuggingFaceHub reads this token from the environment, so it only needs
    # to be set; the variable itself is never passed explicitly.
    HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

    context = (
        "You are a mental health supporting non-medical assistant. "
        "DO NOT PROVIDE any medical advice with conviction."
    )
    ai_tone = "EMPATHY"
    questionnaire = "ADHD"
    question = (
        "How often do you find yourself having trouble focusing on tasks or activities?"
    )
    user_text = "I feel distracted all the time, and I am never able to finish"

    template = """INSTRUCTIONS: {context}

Respond to the user with a tone of {ai_tone}.

The user is answering the {questionnaire} questionnaire.

Question asked to the user: {question}

Response by the user: {user_text}

Provide some advice and ask a relevant question back to the user.

Response;
"""

    temperature = 0.5
    max_length = 128 * 4

    model = LLMResponseGenerator()

    llm_response = model.llm_inference(
        model_type="huggingface",
        question=question,
        prompt_template=template,
        context=context,
        ai_tone=ai_tone,
        questionnaire=questionnaire,
        user_text=user_text,
        temperature=temperature,
        max_length=max_length,
    )
    print(llm_response)
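
    # A sketch of the same call via the OpenAI path (not exercised in the
    # original script; assumes OPENAI_API_KEY is set in the environment and
    # that the completion-style model named below is available to you):
    #
    # llm_response = model.llm_inference(
    #     model_type="openai",
    #     openai_model_name="gpt-3.5-turbo-instruct",
    #     question=question,
    #     prompt_template=template,
    #     context=context,
    #     ai_tone=ai_tone,
    #     questionnaire=questionnaire,
    #     user_text=user_text,
    #     temperature=temperature,
    #     max_length=max_length,
    # )
    # print(llm_response)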