# UvicornGpt2 / main.py
# vakodiya's picture
# Update main.py
# 4065212
# from fastapi.staticfiles import StaticFiles
# from fastapi.responses import FileResponse
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from langchain.prompts import PromptTemplate
app = FastAPI()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Request payload accepted by the POST /classify endpoint.
class TextRequest(BaseModel):
    # Free-text question supplied by the client; no default, so it is required.
    question: str
def preprocess_text(question: str) -> str:
    """Return *question* converted to lowercase.

    This is the only normalization applied to the request text before it is
    handed to ``classify_text``.
    """
    lowered = question.lower()
    return lowered
def classify_text(question: str, max_new_tokens: int = 50) -> dict:
    """Generate a GPT-2 answer for *question* and return it in a dict.

    The question is inserted into a fixed prompt, the module-level GPT-2
    model continues that prompt, and one paragraph of the continuation is
    returned as the answer.

    Args:
        question: The (already preprocessed) question text.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Bounding *new* tokens (instead of relying on the
            default total-length limit) keeps long questions from leaving no
            room for an answer.

    Returns:
        ``{"answer": <text>}``. The text is an empty string when the model
        produces nothing but whitespace.
    """
    prompt_template = PromptTemplate(
        template="Answer the following question and classify it: {question}",
        input_variables=["question"],
    )
    format_prompt = prompt_template.format(question=question)

    encoded_input = tokenizer(format_prompt, return_tensors='pt')
    # Remember how many tokens belong to the prompt so that only the model's
    # continuation is decoded below.
    prompt_length = encoded_input["input_ids"].shape[-1]

    # Run the model.
    output = model.generate(
        **encoded_input,
        max_new_tokens=max_new_tokens,
        # GPT-2 has no dedicated padding token; reuse EOS explicitly instead
        # of letting generate() warn and pick it implicitly.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Splitting the full text
    # (prompt + continuation) would return part of the prompt whenever the
    # question itself contains a blank line.
    continuation = tokenizer.decode(
        output[0][prompt_length:],
        skip_special_tokens=True,
    )

    # Prefer the paragraph that follows the first blank line, as before; when
    # there is no blank line, fall back to the whole continuation rather than
    # raising IndexError.
    head, separator, tail = continuation.partition('\n\n')
    answer = tail.split('\n\n')[0] if separator else head
    return {"answer": answer.strip()}
@app.post("/classify")
async def classify_text_endpoint(request: TextRequest):
    """Handle POST /classify: lowercase the question and return the model answer.

    Args:
        request: Parsed request body carrying the ``question`` string.

    Returns:
        The dict produced by ``classify_text`` (``{"answer": ...}``).
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    preprocessed_text = preprocess_text(request.question)
    # classify_text performs blocking model inference. Calling it directly
    # inside a coroutine would stall the event loop (and every other request)
    # for the duration of generation, so run it on the worker thread pool.
    response = await run_in_threadpool(classify_text, preprocessed_text)
    return response
#
#
# @app.get("/infer_t5")
# def t5(input):
# preprocessed_text = preprocess_text(request.text)
# response = classify_text(preprocessed_text)
# output = pipe_flan(input)
# return {"output": output[0]["generated_text"]}
#
# app.mount("/", StaticFiles(directory="static", html=True), name="static")
#
# @app.get("/")
# def index() -> FileResponse:
# return FileResponse(path="/app/static/index.html", media_type="text/html")