2024.6.10 山东大学项目实训纪实

定义了两个类:Conversation 和 StoppingCriteriaSub

Conversation 类用于存储和操作对话历史记录。以下是它的属性和方法的解释:

属性:

  • system:表示对话中系统的角色的字符串。
  • roles:表示对话中角色的字符串列表。
  • messages:表示对话消息的列表,每个内部列表包含一个角色和相应的消息。
  • offset:表示在将对话转换为聊天机器人格式时要从哪个位置开始的整数偏移量。
  • sep_style:一个 SeparatorStyle 枚举实例,用于确定对话中使用的分隔符样式。
  • sep:表示对话中主要分隔符的字符串。
  • sep2:表示对话中备用分隔符的字符串。
  • skip_next:一个布尔标志,指示是否应跳过下一条消息。
  • conv_id:对话的标识符。

方法:

  • get_prompt():生成并返回格式化的对话提示作为字符串。
  • append_message(role, message):向对话中添加一条新消息。
  • to_gradio_chatbot():将对话转换为 Gradio 聊天机器人所需的格式。
  • copy():创建并返回当前对话的副本。
  • dict():将对话的属性和值以字典的形式返回。

StoppingCriteriaSub 类继承自 StoppingCriteria 类,并重写了其中的方法。

  • StoppingCriteriaSub 类具有一个 __init__() 方法,用于初始化对象的属性。
  • __call__(self, input_ids, scores) 方法用于在给定输入和分数时确定是否满足停止条件。

最后的 CONV_VISION 是一个 Conversation 类的实例,表示一个名为 "CONV_VISION" 的对话,其中包含了一些初始属性值。

class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


CONV_VISION = Conversation(
    system="",
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值