"""FastAPI service that answers/classifies a question with GPT-2.

Exposes ``POST /classify``. The request's ``question`` is lowercased, inserted
into a fixed prompt template and passed to GPT-2's ``generate``. The first
non-blank paragraph of the newly generated text is returned as
``{"answer": ...}``.
"""
# from fastapi.staticfiles import StaticFiles
# from fastapi.responses import FileResponse
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from langchain.prompts import PromptTemplate

app = FastAPI()

# Loaded once at import time and shared by every request.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Upper bound on tokens generated per request. Without an explicit limit,
# ``generate`` uses the default ``max_length`` (20), which counts the prompt
# tokens too and leaves little or no room for an answer.
MAX_NEW_TOKENS = 50

# Built once at import instead of on every request; the template is unchanged.
PROMPT_TEMPLATE = PromptTemplate(
    template="Answer the following question and classify it: {question}",
    input_variables=["question"],
)


class TextRequest(BaseModel):
    """Request body for ``POST /classify``."""

    question: str


def preprocess_text(question: str) -> str:
    """Return *question* lowercased."""
    return question.lower()


def _extract_answer(generated: str) -> str:
    """Return the first non-blank paragraph of *generated*, stripped.

    Paragraphs are separated by blank lines. Returns an empty string when the
    generated text is entirely whitespace. Indexing a fixed paragraph would
    raise ``IndexError`` in that case.
    """
    for paragraph in generated.split("\n\n"):
        paragraph = paragraph.strip()
        if paragraph:
            return paragraph
    return ""


def classify_text(question: str) -> dict:
    """Generate a GPT-2 answer for *question*.

    Args:
        question: The (already preprocessed) question text.

    Returns:
        ``{"answer": str}``, where the answer is the first non-blank paragraph
        of the text GPT-2 generated after the prompt (possibly ``""``).
    """
    format_prompt = PROMPT_TEMPLATE.format(question=question)
    encoded_input = tokenizer(format_prompt, return_tensors='pt')
    prompt_length = encoded_input["input_ids"].shape[-1]

    output = model.generate(
        **encoded_input,
        max_new_tokens=MAX_NEW_TOKENS,
        # GPT-2 defines no pad token; use EOS explicitly instead of relying on
        # generate()'s implicit fallback (which also logs a warning each call).
        pad_token_id=tokenizer.eos_token_id,
    )

    # output[0] is the prompt tokens followed by the new tokens. Decode only
    # the new ones so the prompt text never leaks into the answer.
    generated = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )
    return {"answer": _extract_answer(generated)}


@app.post("/classify")
def classify_text_endpoint(request: TextRequest):
    """Lowercase the question and return the model's answer.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs the
    blocking model call in its threadpool instead of on the event loop.
    """
    preprocessed_text = preprocess_text(request.question)
    return classify_text(preprocessed_text)


# NOTE(review): the disabled code below would not run if uncommented. `t5`
# reads `request`, which is not a parameter, uses `request.text` (TextRequest
# only has `question`), and calls `pipe_flan`, which is not defined in this
# file.
# @app.get("/infer_t5")
# def t5(input):
#     preprocessed_text = preprocess_text(request.text)
#     response = classify_text(preprocessed_text)
#     output = pipe_flan(input)
#     return {"output": output[0]["generated_text"]}
#
# app.mount("/", StaticFiles(directory="static", html=True), name="static")
#
# @app.get("/")
# def index() -> FileResponse:
#     return FileResponse(path="/app/static/index.html", media_type="text/html")