import streamlit as st
from PIL import Image
import torch
from transformers import AutoModel, AutoTokenizer
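
# How to run (assuming this file is saved as web_demo_streamlit.py):
#   streamlit run web_demo_streamlit.py
# A CUDA-capable GPU is required as written, since the fp16 model is moved to
# "cuda" below (roughly 16 GB of VRAM for the 8B weights).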

# Model path
model_path = "openbmb/MiniCPM-Llama3-V-2_5"

# User and assistant names
U_NAME = "User"
A_NAME = "Assistant"

# Set page configuration
st.set_page_config(
    page_title="MiniCPM-Llama3-V-2_5 Streamlit",
    page_icon=":robot:",
    layout="wide"
)


# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    print(f"load_model_and_tokenizer from {model_path}")
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device="cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
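

# Lower-memory alternative (a sketch, not used by this demo): openbmb also
# publishes a quantized int4 checkpoint that can be loaded the same way,
# assuming the quantization dependencies it requires (e.g. bitsandbytes)
# are installed:
#
#   model = AutoModel.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5-int4",
#                                     trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5-int4",
#                                             trust_remote_code=True)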


# Initialize session state
if 'model' not in st.session_state:
    st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
    st.session_state.model.eval()
    print("model and tokenizer had loaded completed!")

# Initialize session state
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Sidebar settings
st.sidebar.title("MiniCPM-Llama3-V-2_5 Streamlit")
max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
repetition_penalty = st.sidebar.slider("repetition_penalty", 0.0, 2.0, 1.05, step=0.01)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
top_k = st.sidebar.slider("top_k", 0, 100, 100, step=1)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)
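
# These sampling settings are forwarded to model.chat() below; max_length is
# passed as max_new_tokens and caps how many tokens each reply may contain.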

# Clear chat history button
buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
    st.session_state.chat_history = []
    st.session_state.response = ""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.rerun()

# Display chat history
for i, message in enumerate(st.session_state.chat_history):
    if message["role"] == "user":
        with st.chat_message(name="user", avatar="user"):
            if message["image"] is not None:
                st.image(message["image"], caption='User uploaded image', width=448, use_column_width=False)
            elif message["content"] is not None:
                st.markdown(message["content"])
    else:
        with st.chat_message(name="model", avatar="assistant"):
            st.markdown(message["content"])

# Select mode
selected_mode = st.sidebar.selectbox("Select mode", ["Text", "Image"])
if selected_mode == "Image":
    # Image mode
    uploaded_image = st.sidebar.file_uploader("Upload image", key=1, type=["jpg", "jpeg", "png"],
                                              accept_multiple_files=False)
    if uploaded_image is not None:
        st.image(uploaded_image, caption='User uploaded image', width=448, use_column_width=False)
        # Append the uploaded image to chat history only once; without this
        # guard, every Streamlit rerun re-appends it as a new entry.
        image_key = (uploaded_image.name, uploaded_image.size)
        if st.session_state.get("last_uploaded_key") != image_key:
            st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image})
            st.session_state.last_uploaded_key = image_key

# User input box
user_text = st.chat_input("Enter your question")
if user_text:
    with st.chat_message(U_NAME, avatar="user"):
        st.session_state.chat_history.append({"role": "user", "content": user_text, "image": None})
        st.markdown(f"{U_NAME}: {user_text}")

    # Generate reply using the model
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer

    with st.chat_message(A_NAME, avatar="assistant"):
        # Pass the most recently uploaded image (if any) to the model;
        # imagefile stays None for text-only turns
        imagefile = None
        if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
            uploaded_image = st.session_state.chat_history[-2]["image"]
            imagefile = Image.open(uploaded_image).convert('RGB')

        # Only the current turn is sent; earlier turns are not passed as context
        msgs = [{"role": "user", "content": user_text}]
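        # A sketch (assumption, not wired in): to give the model multi-turn
        # context, msgs could instead be rebuilt from the stored history,
        # mapping this app's "model" role onto the "assistant" role:
        #   msgs = [{"role": "assistant" if m["role"] == "model" else "user",
        #            "content": m["content"]}
        #           for m in st.session_state.chat_history if m["content"] is not None]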
        # max_length from the sidebar caps the reply length via max_new_tokens
        res = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
                         sampling=True, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
                         temperature=temperature, max_new_tokens=max_length, stream=True)

        # Collect the generated_text str
        generated_text = st.write_stream(res)

        st.session_state.chat_history.append({"role": "model", "content": generated_text, "image": None})

    st.divider()