"""Streamlit app: personal-color diagnosis (パーソナルカラー診断AI).

The user uploads a photo; a fine-tuned ResNet-152 with a 4-way head
classifies it into a seasonal color type, and the app then shows the
matching palette image and a recommended-items image.
"""

import io
import time

import streamlit as st
import torch
import torch.nn as nn
from PIL import Image, ImageOps
from torchvision import models, transforms

st.title("パーソナルカラー診断AI")

# Preprocessing. NOTE(review): presumably this must match the pipeline used
# when the checkpoint was trained — TODO confirm against the training script.
# MEAN/STD are the standard ImageNet normalization statistics.
SIZE = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

WEIGHTS_PATH = "Resnet_2024_0214_version1"
n_classes = 4
# Model output index -> season label (same order as the original if/elif chain).
SEASON_TYPES = ("秋", "春", "夏", "冬")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model(weights_path: str, num_classes: int, device_name: str) -> nn.Module:
    """Build the ResNet-152 classifier and load the fine-tuned weights.

    Cached with ``st.cache_resource`` so the checkpoint is read once per
    process instead of on every Streamlit rerun.

    Args:
        weights_path: Path to a ``state_dict`` saved with ``torch.save``.
        num_classes: Output size of the replacement ``fc`` layer.
        device_name: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        The model on the requested device, in eval mode.
    """
    # No ImageNet download: every parameter is overwritten by the checkpoint.
    net = models.resnet152()
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    target = torch.device(device_name)
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location=target)
    net.load_state_dict(state_dict)
    net.to(target)
    net.eval()
    return net


model = load_model(WEIGHTS_PATH, n_classes, str(device))

# NOTE(review): not read anywhere in this script; kept in case other code uses them.
view_flag = True
skip = False


def predict_image(img):
    """Classify a PIL image and return the predicted class index as an int."""
    img = img.convert("RGB")  # drop alpha/palette so ToTensor yields 3 channels
    inputs = transform(img).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        outputs = model(inputs)
    return outputs.argmax(dim=1).item()


uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file:
    try:
        img = Image.open(uploaded_file)
        # Honor the EXIF orientation tag so phone photos are upright both on
        # screen and when fed to the model.
        img = ImageOps.exif_transpose(img)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) or OSError
        # for files it cannot decode; show a message instead of a traceback.
        st.error("画像を読み込めませんでした。別の画像を選択してください。")
        st.stop()

    st.image(img, caption="Uploaded Image", use_column_width=True)

    pred = predict_image(img)
    # Out-of-table indices fall back to the last label, mirroring the
    # original chain's final `else`.
    season_type = SEASON_TYPES[pred] if 0 <= pred < len(SEASON_TYPES) else SEASON_TYPES[-1]

    # NOTE(review): "show_video" and "skip" are initialized but never read here.
    for key in ("show_video", "skip", "result"):
        if key not in st.session_state:
            st.session_state[key] = False

    st.write(f"パーソナルカラー診断結果:{season_type} ")
    st.write("あなたにおすすめの色はこちらです")
    st.session_state.result = True
    st.image(f"{season_type}.png")
    st.write(
        """
あなたにおすすめの商品はこちらです
"""
    )
    st.image("服.png")