tttoaster committed on
Commit
6d3af86
1 Parent(s): c1c47cb

Upload conversation.py

Browse files
Files changed (1) hide show
  1. conversation.py +182 -0
conversation.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import copy
import dataclasses
import io
import os
from enum import auto, Enum
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image
10
+
11
# Placeholder that marks where an image sits inside a message's text.
IMG_FLAG = '<image>'
12
+
13
+
14
class SeparatorStyle(Enum):
    """Prompt layouts a conversation can be rendered with.

    Values are assigned by ``auto()``, so they are the consecutive integers
    1..5 in declaration order.
    """

    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
21
+
22
+
23
def decode_image(encoded_image: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    Args:
        encoded_image: Base64 text of an encoded image file (for example the
            output of ``encode_image``).

    Returns:
        The image opened from the decoded bytes.

    Raises:
        binascii.Error: If ``encoded_image`` is not valid base64.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            format Pillow recognizes.
    """
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    # The buffer is deliberately left open: Image.open reads pixel data
    # lazily, so the returned image still needs it.
    buffer = io.BytesIO(decoded_bytes)
    return Image.open(buffer)
28
+
29
+
30
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    """Serialize a PIL image and return the result as base64 text.

    Args:
        image: The image to serialize.
        format: Image file format handed to ``image.save``; PNG by default.

    Returns:
        The base64 encoding of the saved image bytes, as a ``str``.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        # getvalue() must run before the buffer is closed by the `with`.
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
35
+
36
+
37
@dataclasses.dataclass
class Conversation:
    """Keep the full history of a multi-turn, image-aware conversation.

    Every entry of ``messages`` is read by the methods of this class as::

        {'role': str, 'message': {'text': str, 'images': [path, ...]}}

    where each image is a file path. ``to_gradio_chatbot`` renders an empty
    path (``''``) as ``<corrupt_image>``.

    Attributes:
        system: System prompt placed before the first turn; ``None`` or ``''``
            means no system prompt.
        roles: The two speaker names; ``roles[1]`` is appended to SINGLE-style
            prompts as the speaker whose reply is requested.
        messages: Conversation turns in the layout shown above.
        offset: Number of leading messages skipped by ``to_gradio_chatbot``.
        sep_style: Prompt layout used by ``get_prompt``.
        sep: Separator appended after the system prompt and after every turn.
        sep2: Not read by any method of this class.
        version: Not read by any method of this class.
        skip_next: Not read by any method of this class.
    """

    system: str
    roles: List[str]
    messages: List[Dict[str, Any]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None
    version: str = "Unknown"

    skip_next: bool = False

    @staticmethod
    def _encode_images(image_paths, size=None) -> List[str]:
        """Base64-encode the images at ``image_paths``, resizing to ``size`` if given."""
        encoded = []
        for image_path in image_paths:
            # NOTE(review): an empty path (shown as '<corrupt_image>' by
            # to_gradio_chatbot) makes Image.open raise here - confirm callers
            # never build a prompt from such a message.
            # The context manager closes the underlying file deterministically.
            with Image.open(image_path) as opened:
                image = opened if size is None else opened.resize(size)
                encoded.append(encode_image(image))
        return encoded

    def get_prompt(self) -> Dict[str, Any]:
        """Render the conversation as a text prompt plus its encoded images.

        Returns:
            ``{'text': prompt, 'images': [...]}`` where ``images`` holds the
            base64 encoding of every image referenced by the messages, in
            message order. SINGLE style resizes each image to 256x256 first;
            LLAMA_2 style encodes images at their original size.

        Raises:
            NotImplementedError: If ``sep_style`` is neither ``SINGLE`` nor
                ``LLAMA_2``.
        """
        has_system = self.system is not None and self.system != ''
        images: List[str] = []

        if self.sep_style == SeparatorStyle.SINGLE:
            parts = [self.system + self.sep] if has_system else []
            for message in self.messages:
                body = message['message']
                parts.append(message['role'] + ": " + body['text'] + self.sep)
                images.extend(self._encode_images(body['images'], size=(256, 256)))
            # Finish with the second role's tag so the model continues as it.
            parts.append(self.roles[1] + ":")
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            b_token = "[INST] "
            e_token = " [/INST]"
            parts = [f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"] if has_system else []
            for idx, message in enumerate(self.messages):
                body = message['message']
                if idx % 2 == 0:
                    # Even positions are wrapped in the instruction tokens.
                    parts.append(b_token + body['text'] + e_token + self.sep)
                else:
                    parts.append(body['text'] + self.sep)
                images.extend(self._encode_images(body['images']))
        else:
            raise NotImplementedError(
                f"get_prompt does not support separator style {self.sep_style}"
            )

        return {'text': ''.join(parts), 'images': images}

    def append_message(self, role, message):
        """Append one turn, stored in the dict layout the other methods read.

        Args:
            role: Speaker name for the turn.
            message: The turn payload, ``{'text': str, 'images': [path, ...]}``.
        """
        self.messages.append({'role': role, 'message': message})

    def to_gradio_chatbot(self) -> List[list]:
        """Convert the messages after ``offset`` into chatbot display pairs.

        Each ``IMG_FLAG`` in a message's text is replaced, in order, by a
        markdown image link to the matching path, or by ``<corrupt_image>``
        when that path is the empty string.

        Returns:
            A list of ``[first_message, second_message]`` pairs built from
            alternating messages; the second element is ``None`` when the
            last pair has no second message yet.

        Raises:
            AssertionError: If a message's number of ``IMG_FLAG`` placeholders
                differs from its number of images.
        """
        dialog = []
        for i, entry in enumerate(self.messages[self.offset:]):
            turn = entry['message']
            image_paths = turn['images']
            text_list = turn['text'].split(IMG_FLAG)
            assert len(text_list) == len(image_paths) + 1, (
                f"text has {len(text_list) - 1} {IMG_FLAG!r} placeholder(s) "
                f"but the message carries {len(image_paths)} image(s): {text_list!r}"
            )

            pieces = []
            # zip stops at the shorter image list, leaving the trailing text
            # segment to be appended afterwards.
            for text, image_path in zip(text_list, image_paths):
                pieces.append(text)
                if image_path == '':
                    pieces.append('<corrupt_image>')
                else:
                    pieces.append(f'![](file={image_path})')
            pieces.append(text_list[-1])
            message = ''.join(pieces)

            if i % 2 == 0:
                dialog.append([message, None])
            else:
                dialog[-1][-1] = message

        return dialog

    def copy(self) -> "Conversation":
        """Return a new conversation with a deep copy of the messages.

        ``roles`` is passed through by reference, and ``skip_next`` is not
        carried over, so the copy starts with the default ``False``.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=copy.deepcopy(self.messages),
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot with image paths reduced to base names.

        The stored messages are left untouched; the snapshot works on a deep
        copy. ``sep_style``, ``version`` and ``skip_next`` are not included.
        """
        messages = copy.deepcopy(self.messages)
        for message in messages:
            body = message['message']
            body['images'] = [os.path.basename(path) for path in body['images']]
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
152
+
153
+
154
# Ready-made conversation templates. Each one owns a module-level `messages`
# list, so callers should presumably work on a `.copy()` rather than append to
# a template directly.

# SINGLE-style template without a system prompt.
conv_seed_vicuna = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep='\n',
)

# SINGLE-style template with a system prompt.
conv_seed_vicuna_system = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. ",
    roles=("USER", "ASSISTANT"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep='\n',
)

# LLAMA_2-style template without a system prompt.
conv_seed_llama2 = Conversation(
    system="",
    roles=("[INST]", "[/INST]"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep='\n',
)