|
|
|
|
|
""" |
|
@Time : 2023/5/8 22:12 |
|
@Author : alexanderwu |
|
@File : schema.py |
|
@Desc : mashenquan, 2023/8/22. Add tags to enable custom message classification. |
|
""" |
|
from __future__ import annotations |
|
|
|
from dataclasses import dataclass, field |
|
from enum import Enum |
|
from typing import Type, TypedDict, Set, Optional, List |
|
|
|
from pydantic import BaseModel |
|
|
|
from metagpt.logs import logger |
|
|
|
|
|
class MessageTag(Enum): |
|
Prerequisite = "prerequisite" |
|
|
|
|
|
class RawMessage(TypedDict): |
|
content: str |
|
role: str |
|
|
|
|
|
@dataclass |
|
class Message: |
|
"""list[<role>: <content>]""" |
|
content: str |
|
instruct_content: BaseModel = field(default=None) |
|
role: str = field(default='user') |
|
cause_by: Type["Action"] = field(default="") |
|
sent_from: str = field(default="") |
|
send_to: str = field(default="") |
|
tags: Optional[Set] = field(default=None) |
|
|
|
def __str__(self): |
|
|
|
return f"{self.role}: {self.content}" |
|
|
|
def __repr__(self): |
|
return self.__str__() |
|
|
|
def to_dict(self) -> dict: |
|
return { |
|
"role": self.role, |
|
"content": self.content |
|
} |
|
|
|
def add_tag(self, tag): |
|
if self.tags is None: |
|
self.tags = set() |
|
self.tags.add(tag) |
|
|
|
def remove_tag(self, tag): |
|
if self.tags is None or tag not in self.tags: |
|
return |
|
self.tags.remove(tag) |
|
|
|
def is_contain_tags(self, tags: list) -> bool: |
|
"""Determine whether the message contains tags.""" |
|
if not tags or not self.tags: |
|
return False |
|
intersection = set(tags) & self.tags |
|
return len(intersection) > 0 |
|
|
|
def is_contain(self, tag): |
|
return self.is_contain_tags([tag]) |
|
|
|
def dict(self): |
|
"""pydantic-like `dict` function""" |
|
full = { |
|
"instruct_content": self.instruct_content, |
|
"sent_from": self.sent_from, |
|
"send_to": self.send_to, |
|
"tags": self.tags |
|
} |
|
|
|
m = {"content": self.content} |
|
for k, v in full.items(): |
|
if v: |
|
m[k] = v |
|
return m |
|
|
|
|
|
@dataclass |
|
class UserMessage(Message): |
|
"""便于支持OpenAI的消息 |
|
Facilitate support for OpenAI messages |
|
""" |
|
|
|
def __init__(self, content: str): |
|
super().__init__(content, 'user') |
|
|
|
|
|
@dataclass |
|
class SystemMessage(Message): |
|
"""便于支持OpenAI的消息 |
|
Facilitate support for OpenAI messages |
|
""" |
|
|
|
def __init__(self, content: str): |
|
super().__init__(content, 'system') |
|
|
|
|
|
@dataclass |
|
class AIMessage(Message): |
|
"""便于支持OpenAI的消息 |
|
Facilitate support for OpenAI messages |
|
""" |
|
|
|
def __init__(self, content: str): |
|
super().__init__(content, 'assistant') |
|
|
|
|
|
if __name__ == '__main__': |
|
test_content = 'test_message' |
|
msgs = [ |
|
UserMessage(test_content), |
|
SystemMessage(test_content), |
|
AIMessage(test_content), |
|
Message(test_content, role='QA') |
|
] |
|
logger.info(msgs) |
|
|