定义了两个类:Conversation
和 StoppingCriteriaSub
。
Conversation
类用于存储和操作对话历史记录。以下是它的属性和方法的解释:
属性:
system
:表示对话中系统的角色的字符串。roles
:表示对话中角色的字符串列表。messages
:表示对话消息的列表,每个内部列表包含一个角色和相应的消息。offset
:表示在将对话转换为聊天机器人格式时要从哪个位置开始的整数偏移量。sep_style
:一个SeparatorStyle
枚举实例,用于确定对话中使用的分隔符样式。sep
:表示对话中主要分隔符的字符串。sep2
:表示对话中备用分隔符的字符串。skip_next
:一个布尔标志,指示是否应跳过下一条消息。conv_id
:对话的标识符。
方法:
get_prompt()
:生成并返回格式化的对话提示作为字符串。append_message(role, message)
:向对话中添加一条新消息。to_gradio_chatbot()
:将对话转换为 Gradio 聊天机器人所需的格式。copy()
:创建并返回当前对话的副本。dict()
:将对话的属性和值以字典的形式返回。
StoppingCriteriaSub
类继承自 StoppingCriteria
类,并重写了其中的方法。
StoppingCriteriaSub
类具有一个__init__()
方法,用于初始化对象的属性。__call__(self, input_ids, scores)
方法用于在给定输入和分数时确定是否满足停止条件。
最后的 CONV_VISION
是一个 Conversation
类的实例,表示一个名为 "CONV_VISION" 的对话,其中包含了一些初始属性值。
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
# system_img: List[Image.Image] = []
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
skip_next: bool = False
conv_id: Any = None
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
# system_img=self.system_img,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
conv_id=self.conv_id)
def dict(self):
return {
"system": self.system,
# "system_img": self.system_img,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
"conv_id": self.conv_id,
}
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop):])).item():
return True
return False
CONV_VISION = Conversation(
system="",
roles=("Human", "Assistant"),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)