|
from flask import Flask, request, jsonify, render_template
|
|
import torch
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
import os
|
|
import torch.nn as nn
|
|
import timm
|
|
from torchvision.models import swin_t, Swin_T_Weights, vit_b_16, ViT_B_16_Weights
|
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
|
|
app = Flask(__name__)

# Directory (relative to the current working directory) where uploaded
# images are saved before they are classified.
UPLOAD_FOLDER = os.path.join('static', 'uploads')
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)

# The models and input tensors in this module are moved to this device:
# a CUDA GPU when PyTorch can see one, otherwise the CPU.
_device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_type)
|
|
|
|
|
|
# Directory holding the local GPT-2 checkpoint and tokenizer files. Built with
# os.path.join instead of a hard-coded Windows '\\' separator so the path also
# resolves on Linux/macOS (on Windows it is still 'models\\LLM').
LLM_DIR = os.path.join('models', 'LLM')

model = GPT2LMHeadModel.from_pretrained(LLM_DIR).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(LLM_DIR)

# GPT-2's end-of-text token; generate_answer appends it after each question
# to mark where the prompt ends.
separator_token = tokenizer.eos_token
|
|
|
|
|
|
|
|
|
|
def _load_classifier_weights(classifier, checkpoint_name):
    """Move *classifier* to ``device``, load ``models/<checkpoint_name>`` into it and set eval mode.

    The state dict is still loaded with ``strict=False`` (as before), but any
    missing or unexpected keys are now logged instead of being silently
    ignored: a missing key leaves that parameter at its current value, which
    for the freshly created ``nn.Linear`` heads means random initialisation.

    Args:
        classifier: The ``nn.Module`` to load weights into.
        checkpoint_name: File name of the checkpoint inside ``models/``.

    Returns:
        The same module, on ``device`` and in eval mode.
    """
    # os.path.join keeps the path portable; the original hard-coded '\\'.
    checkpoint_path = os.path.join('models', checkpoint_name)
    classifier = classifier.to(device)
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    result = classifier.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        app.logger.warning(
            "Checkpoint %s did not match the model: missing keys=%s, unexpected keys=%s",
            checkpoint_path,
            result.missing_keys,
            result.unexpected_keys,
        )
    classifier.eval()
    return classifier


# Gastrointestinal classifier: timm Swin-B backbone with a new 4-way head.
gastrointestinal_classes = ['Diverticulosis', 'Neoplasm', 'Peritonitis', 'Ureters']
gastrointestinal_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
# NOTE(review): depending on the timm version, the Swin `head` may be the
# module that also does spatial pooling; replacing it with a bare nn.Linear can
# then yield logits with more than two dimensions, which is why the prediction
# route checks for outputs with more than two dimensions.
gastrointestinal_model.head = nn.Linear(
    gastrointestinal_model.head.in_features, len(gastrointestinal_classes)
)
gastrointestinal_model = _load_classifier_weights(
    gastrointestinal_model, 'gastrointestinal_model_swin.pth'
)

# Chest CT classifier: torchvision Swin-T backbone with a new 4-way head.
chest_ct_classes = ['Adenocarcinoma', 'Large Cell Carcinoma', 'Normal', 'Squamous Cell Carcinoma']
chest_ct_model = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
chest_ct_model.head = nn.Linear(chest_ct_model.head.in_features, len(chest_ct_classes))
chest_ct_model = _load_classifier_weights(chest_ct_model, 'best_model.pth')

# Chest X-ray classifier: torchvision ViT-B/16 backbone with a new 2-way head.
chest_xray_classes = ['Normal', 'Pneumonia']
chest_xray_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
chest_xray_model.heads.head = nn.Linear(
    chest_xray_model.heads.head.in_features, len(chest_xray_classes)
)
chest_xray_model = _load_classifier_weights(chest_xray_model, 'best_model_vit_chest_xray.pth')
|
|
|
|
|
|
# Standard ImageNet per-channel statistics used for normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing shared by all image classifiers: resize to 224x224, convert
# the PIL image to a float tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
|
|
|
def process_image(image_path):
    """Load an image file and turn it into a single-image batch for the classifiers.

    Args:
        image_path: Path of the image file on disk.

    Returns:
        The RGB image passed through the module-level ``transform``, with a
        leading batch dimension added, on ``device``. With the ``transform``
        defined in this module that is a (1, 3, 224, 224) tensor.

    Raises:
        OSError: If the file cannot be opened or is not a readable image
            (PIL's ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    # Use a context manager so the file handle is released promptly instead of
    # leaking one per request; convert() returns a new, fully loaded image, so
    # it stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0).to(device)
|
|
|
|
|
|
def generate_answer(question, max_length=1024):
    """Generate an answer to *question* with the GPT-2 model.

    The prompt is the question followed by ``separator_token``; only the
    tokens generated after the prompt are decoded and returned.

    Args:
        question: The user's question text.
        max_length: Passed to ``model.generate`` as ``max_length``, i.e. the
            upper bound on prompt tokens plus generated tokens.

    Returns:
        The decoded continuation with special tokens removed.
    """
    model.eval()
    encoded = tokenizer(question + separator_token, return_tensors='pt').to(device)
    prompt_length = encoded['input_ids'].shape[-1]

    # Pass the attention mask explicitly: pad_token_id is set to the EOS id and
    # the prompt itself ends with EOS, so generate() cannot infer the mask from
    # the input and would otherwise warn on every call.
    output = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; decode only the continuation
    # of the single sequence in the batch.
    return tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)
|
|
|
|
|
|
def _safe_upload_name(filename):
    """Reduce a client-supplied filename to a bare name that is safe to save.

    Directory components are stripped, for both ``/`` and ``\\`` separators,
    so an upload cannot be written outside the upload folder (path
    traversal). Returns ``None`` when nothing usable remains, e.g. the empty
    filename sent when no file was selected.
    """
    name = os.path.basename((filename or '').replace('\\', '/'))
    if name in ('', '.', '..'):
        return None
    return name


def _classify_upload(classifier, class_names):
    """Save the uploaded ``file`` form field, classify it and build the JSON response.

    Shared implementation of the ``/predict_*`` routes.

    Args:
        classifier: Model to run on the preprocessed image.
        class_names: Label for each output index of *classifier*.

    Returns:
        ``{"prediction": <class name>}`` on success. Otherwise an
        ``{"error": ...}`` body with status 400 when the upload is missing,
        has no usable filename, or is not a readable image, and status 500
        when the model output has an unexpected shape.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files['file']
    filename = _safe_upload_name(file.filename)
    if filename is None:
        return jsonify({"error": "No file selected"}), 400

    upload_dir = app.config['UPLOAD_FOLDER']
    # The upload folder is not guaranteed to exist (e.g. on a fresh checkout).
    os.makedirs(upload_dir, exist_ok=True)
    file_path = os.path.join(upload_dir, filename)
    file.save(file_path)

    try:
        image_tensor = process_image(file_path)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError) for non-image or
        # corrupt uploads: a client error, not a server fault.
        return jsonify({"error": "Uploaded file is not a readable image"}), 400

    with torch.no_grad():
        output = classifier(image_tensor)

    if output.dim() > 2:
        # A bare nn.Linear head applied to an unpooled feature map emits one
        # score vector per spatial position, class axis last. Average over the
        # positions instead of flattening them into the class axis: flattening
        # let argmax range over positions * classes and return indices past the
        # end of class_names. Since the head is linear, this equals pooling the
        # features before the head.
        output = output.flatten(1, -2).mean(dim=1)

    if output.dim() != 2 or output.size(0) != 1 or output.size(1) != len(class_names):
        return jsonify({"error": "Unexpected output size"}), 500

    predicted_index = output.argmax(dim=1).item()
    return jsonify({'prediction': class_names[predicted_index]})


@app.route('/predict_gastrointestinal', methods=['POST'])
def predict_gastrointestinal():
    """Classify an uploaded gastrointestinal image (multipart field ``file``)."""
    return _classify_upload(gastrointestinal_model, gastrointestinal_classes)


@app.route('/predict_chest_ct', methods=['POST'])
def predict_chest_ct():
    """Classify an uploaded chest CT image (multipart field ``file``)."""
    return _classify_upload(chest_ct_model, chest_ct_classes)


@app.route('/predict_chest_xray', methods=['POST'])
def predict_chest_xray():
    """Classify an uploaded chest X-ray image (multipart field ``file``)."""
    return _classify_upload(chest_xray_model, chest_xray_classes)
|
|
|
|
|
|
|
|
@app.route('/ask_llm', methods=['POST'])
def ask_llm():
    """Answer a question posted as JSON ``{"question": "..."}`` using the LLM.

    Returns:
        ``{"answer": ...}`` on success. ``{"error": ...}`` with status 400
        when the body carries no usable question, or status 500 when answer
        generation raises.
    """
    # silent=True returns None instead of raising for a missing or malformed
    # JSON body or a wrong Content-Type. Those cases, and non-object JSON
    # (lists, scalars), then get the same 400 as a missing question instead
    # of an unhandled error.
    payload = request.get_json(silent=True)
    user_question = payload.get('question') if isinstance(payload, dict) else None

    if user_question is not None and not isinstance(user_question, str):
        return jsonify({"error": "Question must be a string"}), 400
    if not user_question or not user_question.strip():
        return jsonify({"error": "No question provided"}), 400

    try:
        answer = generate_answer(user_question)
    except Exception as e:
        # Request boundary: record the full traceback server-side and return a
        # short message to the client.
        app.logger.exception("LLM answer generation failed")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500

    return jsonify({'answer': answer})
|
|
|
|
|
|
|
|
@app.route('/')
def index():
    """Render the ``index.html`` template for requests to the site root."""
    page = 'index.html'
    return render_template(page)
|
|
|
|
if __name__ == "__main__":
    # Launch Flask's built-in development server with debug mode on.
    # NOTE(review): debug mode typically also enables the auto-reloader, which
    # re-executes this script in a child process (so every model above may be
    # loaded twice), and the interactive debugger; do not use it for a server
    # reachable by others.
    dev_server_options = {'debug': True}
    app.run(**dev_server_options)
|
|
|