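# NOTE: the next seven lines are file-viewer metadata captured along with the source; they are not part of the module.
# liushaojie
# Add application file
# 360d784
# raw
# history blame
# No virus
# 3.09 kB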
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/8 22:12
@Author : alexanderwu
@File : schema.py
@Desc : mashenquan, 2023/8/22. Add tags to enable custom message classification.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Type, TypedDict, Set, Optional, List
from pydantic import BaseModel
from metagpt.logs import logger
class MessageTag(Enum):
    """Enumeration of tag values that can be attached to a message."""

    # Currently the only declared tag; its value is the lowercase string.
    Prerequisite = "prerequisite"
class RawMessage(TypedDict):
    """Typed dictionary shape for a plain message with two string keys."""

    # Text of the message.
    content: str
    # Role associated with the message.
    role: str
@dataclass
class Message:
    """A single message, rendered as ``<role>: <content>``.

    Fields (in constructor order):
        content: Text of the message (required).
        instruct_content: Optional structured payload; defaults to ``None``.
        role: Role string (system / user / assistant); defaults to ``'user'``.
        cause_by: Annotated as an ``Action`` type; defaults to ``""``.
        sent_from: Defaults to ``""``.
        send_to: Defaults to ``""``.
        tags: Optional set of tags; stays ``None`` until a tag is added.
    """

    content: str
    instruct_content: BaseModel = field(default=None)
    role: str = field(default='user')  # system / user / assistant
    cause_by: Type["Action"] = field(default="")
    sent_from: str = field(default="")
    send_to: str = field(default="")
    tags: Optional[Set] = field(default=None)

    def __str__(self):
        """Return the message formatted as ``role: content``."""
        speaker, text = self.role, self.content
        return f"{speaker}: {text}"

    def __repr__(self):
        """Use the same text as ``__str__`` for the representation."""
        return str(self)

    def to_dict(self) -> dict:
        """Return a dictionary holding only ``role`` and ``content``."""
        payload = {"role": self.role, "content": self.content}
        return payload

    def add_tag(self, tag):
        """Add ``tag`` to the message, creating the tag set on first use."""
        if self.tags is None:
            # The set is created lazily because the field defaults to None.
            self.tags = set()
        self.tags.add(tag)

    def remove_tag(self, tag):
        """Remove ``tag`` if present; do nothing when it is absent or no tags exist."""
        if self.tags is not None and tag in self.tags:
            self.tags.remove(tag)

    def is_contain_tags(self, tags: list) -> bool:
        """Return True when at least one of ``tags`` is attached to this message."""
        if not (tags and self.tags):
            # Either side being empty/None means there can be no overlap.
            return False
        shared = set(tags) & self.tags
        return bool(shared)

    def is_contain(self, tag):
        """Return True when the single ``tag`` is attached to this message."""
        candidates = [tag]
        return self.is_contain_tags(candidates)

    def dict(self):
        """pydantic-like ``dict`` function.

        Always includes ``content``; ``instruct_content``, ``sent_from``,
        ``send_to`` and ``tags`` are included only when their value is truthy.
        """
        optional_items = (
            ("instruct_content", self.instruct_content),
            ("sent_from", self.sent_from),
            ("send_to", self.send_to),
            ("tags", self.tags),
        )
        result = {"content": self.content}
        # Keep only the optional fields that carry a truthy value.
        result.update({key: value for key, value in optional_items if value})
        return result
@dataclass
class UserMessage(Message):
    """Convenience message whose role is fixed to ``'user'``.

    Facilitates support for OpenAI messages.

    Args:
        content: Text of the message.
    """

    def __init__(self, content: str):
        # Pass the role by keyword: the second positional parameter of
        # ``Message`` is ``instruct_content``, so a positional 'user' would
        # land in the wrong field.
        super().__init__(content, role='user')
@dataclass
class SystemMessage(Message):
    """Convenience message whose role is fixed to ``'system'``.

    Facilitates support for OpenAI messages.

    Args:
        content: Text of the message.
    """

    def __init__(self, content: str):
        # Pass the role by keyword: the second positional parameter of
        # ``Message`` is ``instruct_content``, so a positional 'system' would
        # leave ``role`` at its default of 'user'.
        super().__init__(content, role='system')
@dataclass
class AIMessage(Message):
    """Convenience message whose role is fixed to ``'assistant'``.

    Facilitates support for OpenAI messages.

    Args:
        content: Text of the message.
    """

    def __init__(self, content: str):
        # Pass the role by keyword: the second positional parameter of
        # ``Message`` is ``instruct_content``, so a positional 'assistant'
        # would leave ``role`` at its default of 'user'.
        super().__init__(content, role='assistant')
if __name__ == '__main__':
    # Manual smoke check: build one message of each kind and log the list.
    sample_text = 'test_message'
    demo_messages = [
        message_cls(sample_text)
        for message_cls in (UserMessage, SystemMessage, AIMessage)
    ]
    demo_messages.append(Message(sample_text, role='QA'))
    logger.info(demo_messages)