"""Flask app serving three medical-image classifiers and a fine-tuned GPT-2 Q&A endpoint.

Routes:
    POST /predict_gastrointestinal, /predict_chest_ct, /predict_chest_xray
        multipart form with a ``file`` field -> ``{"prediction": <class name>}``
    POST /ask_llm
        JSON body ``{"question": ...}`` -> ``{"answer": ...}``
    GET  /
        renders ``index.html``

All models are loaded once at import time from the ``models`` directory,
resolved relative to the current working directory.
"""

import os
import uuid

import timm
import torch
import torch.nn as nn
from flask import Flask, jsonify, render_template, request
from PIL import Image
from torchvision import transforms
from torchvision.models import Swin_T_Weights, ViT_B_16_Weights, swin_t, vit_b_16
from transformers import GPT2LMHeadModel, GPT2Tokenizer

app = Flask(__name__)

# Uploaded images are stored under static/uploads. Create the folder up front so
# the first upload does not fail with FileNotFoundError.
UPLOAD_FOLDER = os.path.join('static', 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Model artefacts live under ./models. Paths are built with os.path.join so they
# resolve on every OS (hard-coded 'models\\...' separators only work on Windows).
MODELS_DIR = 'models'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_checkpoint(net, checkpoint_name):
    """Move *net* to ``device``, load ``MODELS_DIR/checkpoint_name`` into it, set eval mode.

    Loading stays non-strict (as before), but any missing/unexpected keys are
    logged: with ``strict=False`` a checkpoint whose key names do not match the
    architecture is otherwise ignored without a trace, which would leave e.g. a
    freshly initialised classification head in place.

    Returns the module so the call can be assigned back.
    """
    net = net.to(device)
    checkpoint_path = os.path.join(MODELS_DIR, checkpoint_name)
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        app.logger.warning(
            "Checkpoint %s does not fully match its model: missing=%s unexpected=%s",
            checkpoint_path,
            incompatible.missing_keys,
            incompatible.unexpected_keys,
        )
    net.eval()
    return net


# --- Language model -------------------------------------------------------------
LLM_DIR = os.path.join(MODELS_DIR, 'LLM')
model = GPT2LMHeadModel.from_pretrained(LLM_DIR).to(device)
model.eval()  # inference only; nothing in this app trains the model
tokenizer = GPT2Tokenizer.from_pretrained(LLM_DIR)
# Appended to every question to separate it from the generated answer.
separator_token = tokenizer.eos_token

# --- Gastrointestinal model (timm Swin-B; Diverticulosis/Neoplasm/Peritonitis/Ureters)
gastrointestinal_classes = ['Diverticulosis', 'Neoplasm', 'Peritonitis', 'Ureters']
gastrointestinal_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
# NOTE(review): recent timm releases pool inside Swin's ClassifierHead, so replacing
# the whole head with a bare nn.Linear leaves the logits spatial, e.g.
# (batch, H, W, classes). _to_class_logits() pools those; confirm against the timm
# version the checkpoint was trained with.
gastrointestinal_model.head = nn.Linear(
    gastrointestinal_model.head.in_features, len(gastrointestinal_classes)
)
gastrointestinal_model = _load_checkpoint(
    gastrointestinal_model, 'gastrointestinal_model_swin.pth'
)

# --- Chest CT model (torchvision Swin-T; 4 classes) ------------------------------
chest_ct_classes = [
    'Adenocarcinoma',
    'Large Cell Carcinoma',
    'Normal',
    'Squamous Cell Carcinoma',
]
chest_ct_model = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
chest_ct_model.head = nn.Linear(chest_ct_model.head.in_features, len(chest_ct_classes))
chest_ct_model = _load_checkpoint(chest_ct_model, 'best_model.pth')

# --- Chest X-ray model (torchvision ViT-B/16; Normal/Pneumonia) -------------------
chest_xray_classes = ['Normal', 'Pneumonia']
chest_xray_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
chest_xray_model.heads.head = nn.Linear(
    chest_xray_model.heads.head.in_features, len(chest_xray_classes)
)
chest_xray_model = _load_checkpoint(chest_xray_model, 'best_model_vit_chest_xray.pth')

# Preprocessing shared by all three classifiers: resize to 224x224, convert to a
# tensor, normalise with the standard ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def process_image(image_path):
    """Load an image as RGB and return a (1, 3, 224, 224) float tensor on ``device``.

    Raises OSError (including PIL's UnidentifiedImageError) if the file is not a
    readable image.
    """
    # The context manager closes the file handle; convert() returns an
    # independent, fully loaded copy, so it stays usable after the file closes.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    return transform(rgb_image).unsqueeze(0).to(device)


def generate_answer(question, max_length=1024):
    """Return the language model's answer to *question*.

    The prompt is ``question + separator_token``; only the tokens generated after
    the prompt are decoded (special tokens skipped). ``max_length`` is the total
    token budget passed to ``generate`` (prompt plus answer).
    """
    input_ids = tokenizer.encode(question + separator_token, return_tensors='pt').to(device)
    # Single unpadded prompt: every position is a real token, so the mask is all
    # ones. Passing it explicitly avoids transformers inferring it from pad ids
    # (pad == eos here, and the prompt itself ends with eos).
    attention_mask = torch.ones_like(input_ids)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    prompt_length = input_ids.shape[-1]
    return tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)


def _sanitize_filename(filename):
    """Reduce a client-supplied filename to a safe basename ('' if nothing usable remains).

    Both '/' and '\\' are treated as separators regardless of host OS, and every
    character outside alphanumerics and '._-' (e.g. ':' which would form a
    Windows drive-relative path) becomes '_'. Leading/trailing dots are removed
    so '.' and '..' collapse to ''.
    """
    if not filename:
        return ''
    base_name = filename.replace('\\', '/').rsplit('/', 1)[-1]
    cleaned = ''.join(ch if ch.isalnum() or ch in '._-' else '_' for ch in base_name)
    return cleaned.strip('.')


def _to_class_logits(output, num_classes):
    """Return *output* as a (1, num_classes) tensor, or None if its shape is unusable."""
    if output.dim() > 2 and output.shape[-1] == num_classes:
        # A bare nn.Linear head on features that are still spatial gives
        # (batch, ..., classes). A linear layer commutes with averaging, so the mean
        # over the spatial positions equals applying the head after average pooling.
        # (The previous flatten produced H*W*C scores, not num_classes.)
        output = output.flatten(1, -2).mean(dim=1)
    if tuple(output.shape) != (1, num_classes):
        return None
    return output


def _classify_upload(classifier, class_names):
    """Shared implementation of the /predict_* routes.

    Saves the uploaded ``file`` under a unique, sanitised name in the upload
    folder, runs *classifier* on it and returns ``{"prediction": <class name>}``.
    Client errors (no file, unreadable image) produce a 400 JSON error; a model
    output of unexpected shape produces a 500 JSON error.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    upload = request.files['file']
    safe_name = _sanitize_filename(upload.filename)
    if not safe_name:
        # Browsers submit an empty filename when no file was selected.
        return jsonify({"error": "No file uploaded"}), 400

    # Never trust the client's filename as a path; the random prefix also stops
    # concurrent uploads with the same name from overwriting each other mid-request.
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}_{safe_name}")
    upload.save(file_path)

    try:
        image_tensor = process_image(file_path)
    except OSError:
        return jsonify({"error": "Uploaded file is not a readable image"}), 400

    with torch.no_grad():
        raw_output = classifier(image_tensor)
    logits = _to_class_logits(raw_output, len(class_names))
    if logits is None:
        return jsonify({"error": "Unexpected output size"}), 500

    predicted_index = int(logits.argmax(dim=1).item())
    return jsonify({'prediction': class_names[predicted_index]})


@app.route('/predict_gastrointestinal', methods=['POST'])
def predict_gastrointestinal():
    """Classify an uploaded gastrointestinal image."""
    return _classify_upload(gastrointestinal_model, gastrointestinal_classes)


@app.route('/predict_chest_ct', methods=['POST'])
def predict_chest_ct():
    """Classify an uploaded chest CT image."""
    return _classify_upload(chest_ct_model, chest_ct_classes)


@app.route('/predict_chest_xray', methods=['POST'])
def predict_chest_xray():
    """Classify an uploaded chest X-ray image."""
    return _classify_upload(chest_xray_model, chest_xray_classes)


@app.route('/ask_llm', methods=['POST'])
def ask_llm():
    """Answer ``{"question": "<text>"}`` with ``{"answer": "<text>"}`` from the language model."""
    # silent=True: a missing or malformed JSON body yields None instead of raising,
    # so the client gets the JSON 400 below rather than an HTML error page.
    payload = request.get_json(silent=True)
    user_question = payload.get('question') if isinstance(payload, dict) else None
    if not user_question or not isinstance(user_question, str):
        return jsonify({"error": "No question provided"}), 400

    try:
        answer = generate_answer(user_question)
    except Exception as e:  # request boundary: report any generation failure as JSON
        app.logger.exception("LLM generation failed")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500
    return jsonify({'answer': answer})


@app.route('/')
def index():
    """Render the homepage template ``index.html``."""
    return render_template('index.html')


if __name__ == "__main__":
    # Debug mode enables Werkzeug's interactive debugger (arbitrary code execution
    # from the browser), so it is opt-in: app.run() honours FLASK_DEBUG=1.
    app.run()