File size: 1,491 Bytes
09cc233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

class ImageChatbot:
    """Ask natural-language questions about a local image with a MiniCPM-V model.

    The model and tokenizer are loaded through ``transformers`` with
    ``trust_remote_code=True``; the ``chat`` method used by
    :meth:`chat_with_image` comes from the checkpoint's own remote code, not
    from ``transformers`` itself.
    """

    def __init__(self, model_name='openbmb/MiniCPM-Llama3-V-2_5-int4'):
        """Load the model and tokenizer for *model_name* and put the model in eval mode.

        Args:
            model_name: Hugging Face Hub id (or local path) of the checkpoint.
        """
        # trust_remote_code is needed because the checkpoint ships its own
        # modelling code (including the `chat` method called below).
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Inference only: switch the torch module out of training mode.
        self.model.eval()

    def load_image(self, image_path):
        """Open *image_path* and return it converted to an RGB ``PIL.Image.Image``.

        The source file is opened in a ``with`` block so its handle is released
        right away instead of staying open until the lazily-opened original
        image is garbage-collected.
        """
        with Image.open(image_path) as image:
            # convert() decodes the pixels and returns a new image object, so
            # the result remains usable after the source file is closed.
            return image.convert('RGB')

    def display_image(self, image):
        """Show *image* in a matplotlib figure with the axes hidden."""
        # Imported lazily so matplotlib is only required when displaying.
        import matplotlib.pyplot as plt
        plt.imshow(image)
        plt.axis('off')
        plt.show()

    def chat_with_image(self, image_path, question, sampling=True, temperature=0.7,
                        stream=False, show_image=True):
        """Ask *question* about the image at *image_path* and return the answer.

        The answer is echoed to stdout (with no trailing newline) as it is
        produced.

        Args:
            image_path: Path of the image file to ask about.
            question: Text sent as the single user message.
            sampling: Forwarded to ``model.chat`` (per the MiniCPM-V model
                card, ``False`` selects beam search — confirm for your
                checkpoint).
            temperature: Forwarded to ``model.chat``.
            stream: If True, pass ``stream=True`` to ``model.chat`` and print
                the text fragments as they arrive. The model card states
                streaming requires ``sampling=True`` — confirm.
            show_image: If True (default), display the image with matplotlib
                before querying the model. Set False for non-interactive runs,
                where ``plt.show()`` may block.

        Returns:
            The complete generated answer as one string.
        """
        image = self.load_image(image_path)
        if show_image:
            self.display_image(image)
        msgs = [{'role': 'user', 'content': question}]
        chat_kwargs = dict(
            image=image,
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=sampling,
            temperature=temperature,
        )
        if stream:
            # Only pass `stream` when requested so the default call to the
            # remote `chat` method is exactly the same as before.
            chat_kwargs['stream'] = True
        res = self.model.chat(**chat_kwargs)

        if isinstance(res, str):
            # Non-streaming result: the whole answer arrives at once; iterating
            # it would just re-print it one character at a time.
            print(res, flush=True, end='')
            return res

        # Streaming (or other iterable) result: print each fragment as it
        # arrives and join once at the end instead of repeated `+=`.
        pieces = []
        for new_text in res:
            pieces.append(new_text)
            print(new_text, flush=True, end='')
        return ''.join(pieces)

# Example usage
if __name__ == "__main__":
    image_chatbot = ImageChatbot()
    # Hard-coded sample path (looks like a Google Colab location); point this
    # at your own image file.
    image_path = '/content/sample_data/Cat-profile-picture-35.jpg'
    # Persian: roughly "What is this picture?"
    question = 'این شکل چی هست ؟'
    generated_text = image_chatbot.chat_with_image(image_path, question)
    # chat_with_image echoes the answer with end='' and no final newline;
    # terminate the line so later output (or the shell prompt) starts cleanly.
    print()