from flask import Flask, request, jsonify, render_template
import torch
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import timm
from torchvision.models import swin_t, Swin_T_Weights, vit_b_16, ViT_B_16_Weights
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from werkzeug.utils import secure_filename
app = Flask(__name__)
# Set up directories for uploads and models
UPLOAD_FOLDER = os.path.join('static', 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # make sure the upload directory exists
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the LLM model and tokenizer
model = GPT2LMHeadModel.from_pretrained(os.path.join('models', 'LLM')).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(os.path.join('models', 'LLM'))
separator_token = tokenizer.eos_token # Separator token for the model
# Define and load the pre-trained Swin models
# Gastrointestinal Model (4 classes: Diverticulosis, Neoplasm, Peritonitis, Ureters)
gastrointestinal_classes = ['Diverticulosis', 'Neoplasm', 'Peritonitis', 'Ureters']
gastrointestinal_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
gastrointestinal_model.head = nn.Linear(gastrointestinal_model.head.in_features, len(gastrointestinal_classes))
gastrointestinal_model = gastrointestinal_model.to(device)
gastrointestinal_model.load_state_dict(torch.load(os.path.join('models', 'gastrointestinal_model_swin.pth'), map_location=device, weights_only=True), strict=False)
gastrointestinal_model.eval()
# Chest CT Model (4 classes: Adenocarcinoma, Large cell carcinoma, Normal, Squamous cell carcinoma)
chest_ct_classes = ['Adenocarcinoma', 'Large Cell Carcinoma', 'Normal', 'Squamous Cell Carcinoma']
chest_ct_model = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
chest_ct_model.head = nn.Linear(chest_ct_model.head.in_features, len(chest_ct_classes))
chest_ct_model = chest_ct_model.to(device)
chest_ct_model.load_state_dict(torch.load(os.path.join('models', 'best_model.pth'), map_location=device, weights_only=True), strict=False)
chest_ct_model.eval()
# Chest X-ray Model (2 classes: Normal, Pneumonia)
chest_xray_classes = ['Normal', 'Pneumonia']
chest_xray_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
chest_xray_model.heads.head = nn.Linear(chest_xray_model.heads.head.in_features, len(chest_xray_classes))
chest_xray_model = chest_xray_model.to(device)
chest_xray_model.load_state_dict(torch.load(os.path.join('models', 'best_model_vit_chest_xray.pth'), map_location=device, weights_only=True), strict=False)
chest_xray_model.eval()
# Image transformation shared by all three models (224x224 resize, ImageNet mean/std)
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Helper function to load and transform images
def process_image(image_path):
image = Image.open(image_path).convert('RGB')
return transform(image).unsqueeze(0).to(device)
# LLM helper function to generate answers
def generate_answer(question, max_length=1024):
model.eval() # Set the model to evaluation mode
input_text = question + separator_token
input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
output = model.generate(input_ids, max_length=max_length, pad_token_id=tokenizer.eos_token_id)
answer = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
return answer
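# Example (illustrative only; the question string is a placeholder):
#   generate_answer("What are the symptoms of pneumonia?")
# returns the text the fine-tuned GPT-2 produces after the EOS separator.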
# Prediction routes for each model
@app.route('/predict_gastrointestinal', methods=['POST'])
def predict_gastrointestinal():
if 'file' not in request.files:
return jsonify({"error": "No file uploaded"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400
    # Sanitize the client-supplied filename before writing to disk
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
    file.save(file_path)
# Preprocess the image
image_tensor = process_image(file_path)
# Make prediction using the gastrointestinal model
with torch.no_grad():
output = gastrointestinal_model(image_tensor)
# Ensure the output tensor has the right shape and handle it
# If the output has extra dimensions, flatten it
if len(output.shape) > 2:
output = output.view(output.size(0), -1)
# Check if output is for a batch or single sample
if output.size(0) != 1:
return jsonify({"error": "Unexpected output size"}), 500
# Get the predicted class (ensure it's scalar)
_, predicted = torch.max(output, 1)
predicted_class = gastrointestinal_classes[predicted.item()]
return jsonify({'prediction': predicted_class})
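# Example success response (illustrative): {"prediction": "Neoplasm"}
# The chest CT and chest X-ray routes below return the same shape.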
@app.route('/predict_chest_ct', methods=['POST'])
def predict_chest_ct():
if 'file' not in request.files:
return jsonify({"error": "No file uploaded"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
    file.save(file_path)
# Preprocess the image
image_tensor = process_image(file_path)
# Make prediction using the chest CT model
with torch.no_grad():
output = chest_ct_model(image_tensor)
_, predicted = torch.max(output, 1)
predicted_class = chest_ct_classes[predicted.item()]
return jsonify({'prediction': predicted_class})
@app.route('/predict_chest_xray', methods=['POST'])
def predict_chest_xray():
if 'file' not in request.files:
return jsonify({"error": "No file uploaded"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(file.filename))
    file.save(file_path)
# Preprocess the image
image_tensor = process_image(file_path)
# Make prediction using the chest X-ray model
with torch.no_grad():
output = chest_xray_model(image_tensor)
_, predicted = torch.max(output, 1)
predicted_class = chest_xray_classes[predicted.item()]
return jsonify({'prediction': predicted_class})
# LLM route for asking questions
@app.route('/ask_llm', methods=['POST'])
def ask_llm():
    data = request.get_json(silent=True) or {}
    user_question = data.get('question')
if not user_question:
return jsonify({"error": "No question provided"}), 400
try:
# Generate answer using the fine-tuned GPT-2 model
answer = generate_answer(user_question)
except Exception as e:
return jsonify({"error": f"An error occurred: {str(e)}"}), 500
return jsonify({'answer': answer})
# Main route for the homepage
@app.route('/')
def index():
return render_template('index.html')
if __name__ == "__main__":
app.run(debug=True)
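# Example requests against the development server (illustrative; Flask's default
# address is http://127.0.0.1:5000 and the file/question values are placeholders):
#   curl -X POST -F "file=@sample_scan.png" http://127.0.0.1:5000/predict_chest_ct
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"question": "What is diverticulosis?"}' http://127.0.0.1:5000/ask_llm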