使用OpenAI API 做 Book Summary(三):编写summarize.py

使用OpenAI API 做 Book Summary(三):编写summarize.py


summarize.py

源代码

import functools
import hashlib
import json
import os
import random
import time
from typing import Dict, List

import openai
from openai.error import APIConnectionError, APIError, RateLimitError

from file_process import FileProcessor
from prompt import chunk_prompt_messages, synthesis_prompt_messages


openai.api_key = os.getenv("OPENAI_API_KEY")


class Summarizer:
    """Summarize a processed file with OpenAI chat models.

    Each chunk of the file is summarized with a cheap chunk model, then the
    chunk summaries are synthesized into one final summary with the model
    given at construction time. API results are memoized to a per-file JSON
    cache so reruns don't re-spend tokens.
    """

    def __init__(self, processed_file: FileProcessor, model: str):
        """
        Initialize the summarizer.

        Args:
            processed_file (FileProcessor): The processed file object
                (provides ``chunks``, summary sizes and ``cache_file``).
            model (str): The model to use for the final synthesis step.
        """
        self.file = processed_file
        self.actual_tokens = 0  # Running total of tokens consumed by API calls
        self.MAX_ATTEMPTS = 3  # Maximum retry attempts per API request
        self.chunk_model = "gpt-3.5-turbo-1106"  # Cheaper model for per-chunk summaries
        self.syns_model = model  # Model for the final synthesis summary

    def memoize_to_file(func):
        """
        Memoization decorator that caches a method's result in a JSON file.

        Defined in the class body and applied without ``self``; it decorates
        instance methods, so the wrapper receives ``self`` explicitly and
        reads the cache path from ``self.file.cache_file``.

        Args:
            func: The instance method to be memoized.

        Returns:
            The wrapped method.
        """

        @functools.wraps(func)  # preserve the wrapped method's name/docstring
        def wrapped(self, *args, **kwargs):
            # Ensure the cache directory exists (no exists()+makedirs race).
            # NOTE(review): assumes self.file.cache_file lives under "cache/"
            # — confirm against FileProcessor.
            os.makedirs("cache", exist_ok=True)

            # Load the existing cache (if any) from disk.
            if os.path.exists(self.file.cache_file):
                with open(self.file.cache_file, "r", encoding="utf-8") as f:
                    cache = json.load(f)
            else:
                cache = {}

            # Key on the positional args — same format as before, so existing
            # cache files remain valid. Fold kwargs in only when present, so
            # calls differing only in keyword arguments no longer collide.
            key_source = repr(tuple(args))
            if kwargs:
                key_source += repr(sorted(kwargs.items()))
            arg_hash = hashlib.sha256(key_source.encode("utf-8")).hexdigest()
            print("ASSESSING HASH OF: ", arg_hash)

            # Return the cached result when we have one.
            if arg_hash in cache:
                print(f"Cached result found for {arg_hash}. Returning it.")
                return cache[arg_hash]
            print("CACHE NOT FOUND")

            # Compute the result and persist it to the JSON cache.
            result = func(self, *args, **kwargs)
            cache[arg_hash] = result
            with open(self.file.cache_file, "w", encoding="utf-8") as f:
                json.dump(cache, f)

            return result

        return wrapped

    @memoize_to_file
    def gpt_summarize(self, messages: List[Dict], model: str) -> str:
        """
        Summarize text via OpenAI's chat completion API, with retries.

        Args:
            messages (List[Dict]): Chat messages for the completion request.
            model (str): Name of the GPT model to use.

        Returns:
            str: The summarized text, with a trailing newline appended.

        Raises:
            APIConnectionError: If there is an error connecting to the API.
            APIError: If there is an error with the OpenAI API.
            RateLimitError: If the API rate limit is exceeded.
        """
        tries = 0

        # Retry until successful or MAX_ATTEMPTS is reached.
        while True:
            try:
                tries += 1

                result = openai.ChatCompletion.create(model=model, messages=messages)

                # Track total token usage across all calls.
                self.actual_tokens += result.usage.total_tokens

                return result.choices[0].message.to_dict()["content"] + "\n"

            except (APIConnectionError, APIError, RateLimitError) as e:
                # Give up once the retry budget is exhausted.
                if tries >= self.MAX_ATTEMPTS:
                    print(f"OpenAI exception after {self.MAX_ATTEMPTS} tries. Aborting. {e}")
                    raise e

                # Honor the API's explicit "do not retry" signal.
                if hasattr(e, "should_retry") and not e.should_retry:
                    print(f"OpenAI exception with should_retry false. Aborting. {e}")
                    raise e

                print(f"Summarize failed (Try {tries} of {self.MAX_ATTEMPTS}). {e}")
                # Linear backoff with jitter: 1-5 s base, scaled once by the
                # attempt number. (BUGFIX: the original multiplied by `tries`
                # twice — once when scaling, again in time.sleep() — sleeping
                # up to tries^2 * 5 seconds instead of tries * 5.)
                wait = (random.random() * 4.0 + 1.0) * tries
                time.sleep(wait)

    def run(self):
        """
        Generate a summary for the file using chunk-based summarization.

        Returns:
            str: The final synthesized summary of the whole file.
        """
        chunk_summaries = []
        print('Summarizing each chunk of the file...\n\n')

        # Summarize each chunk independently with the cheap chunk model.
        for chunk in self.file.chunks:
            messages = chunk_prompt_messages(chunk, self.file.chunk_summary_size)
            chunk_summaries.append(self.gpt_summarize(messages, self.chunk_model))

        # Synthesize all chunk summaries into one final summary.
        syn_messages = synthesis_prompt_messages(''.join(chunk_summaries), self.file.summary_size)
        print("\n\nSynthesizing the summaries...\n\n")

        return self.gpt_summarize(syn_messages, self.syns_model)

功能讲解

导入api_key
  • 首先将大家的api_key设置为环境变量,我使用的是windows 11,设置位置如下,设置完之后需要重启电脑;

    在这里插入图片描述

  • 然后,导入api_key

openai.api_key = os.getenv("OPENAI_API_KEY")
调用openai api

gpt_summarize()函数用来调用OpenAI API 发送请求并获取结果:

  • 调用OpenAI API的openai.ChatCompletion.create()函数,将messages和model发送过去,返回结果是chat completion 对象,通过result.choices[0].message.to_dict()["content"],提取结果的文本内容。关于openai.ChatCompletion.create()具体用法请参考官方文档:API Reference - OpenAI API
  • 通过try … except … 结构,捕捉异常并设置retry的次数和时间;
  • 用@memoize_to_file进行装饰,实现cache检查。
    @memoize_to_file
    def gpt_summarize(self, messages: List[Dict], model: str) -> str:
        """
        Summarizes the text using OpenAI's GPT model.

        Args:
            messages (List[Dict]): List of messages for the chatbot.
            model (str): Name of the GPT model to use.

        Returns:
            str: The summarized text.

        Raises:
            APIConnectionError: If there is an error connecting to the OpenAI API.
            APIError: If there is an error with the OpenAI API.
            RateLimitError: If the API rate limit is exceeded.
        """

        # Initialize variables
        tries = 0

        # Retry until successful or maximum attempts reached
        while True:
            try:
                tries += 1

                # Generate chat completion using OpenAI API
                result = openai.ChatCompletion.create(model=model, messages=messages)

                # Update the total token count
                self.actual_tokens += result.usage.total_tokens

                # Return the summarized text
                return result.choices[0].message.to_dict()["content"] + "\n"

            except (APIConnectionError, APIError, RateLimitError) as e:
                # Check if maximum attempts reached
                if tries >= self.MAX_ATTEMPTS:
                    print(f"OpenAI exception after {self.MAX_ATTEMPTS} tries. Aborting. {e}")
                    raise e

                # Check if should_retry flag is set to False
                if hasattr(e, "should_retry") and not e.should_retry:
                    print(f"OpenAI exception with should_retry false. Aborting. {e}")
                    raise e

                # Retry with exponential backoff
                else:
                    print(f"Summarize failed (Try {tries} of {self.MAX_ATTEMPTS}). {e}")
                    random_wait = random.random() * 4.0 + 1.0  # Wait between 1 and 5 seconds
                    random_wait = random_wait * tries  # Scale that up by the number of tries (more tries, longer wait)
                    time.sleep(random_wait * tries)
cache 检查

创建一个装饰器 memoize_to_file(func),用来进行cache检查,如果cache中已含,直接返回结果,避免不必要的资金和时间的浪费。

  • 在类中写装饰器有个要点,我也是尝试了多次终于写对了。外层的函数参数没有self,而内层函数参数需要加self,PyCharm会报错,不用理它。

  • 该函数会为每个文件创建一个json格式的cache文件,调用hashlib.sha256()函数,计算每次请求的hash值,如果已在cache文件中就直接返回结果,最后会把新的结果写入cache文件。

    def memoize_to_file(func):
        """
        Memoization decorator that caches the output of a method in a JSON file.

        Args:
            func: The function to be memoized.

        Returns:
            The wrapped function.
        """

        def wrapped(self, *args, **kwargs):
            """
            Wrapper function that caches the output of the decorated function in a JSON file.

            Args:
                self: The instance of the class.
                *args: The positional arguments passed to the decorated function.
                **kwargs: The keyword arguments passed to the decorated function.

            Returns:
                The result of the decorated function.
            """
            # check if the "cache" directory exists
            if not os.path.exists("cache"):
                os.makedirs("cache")
                
            # Load the cache from the JSON file
            if os.path.exists(self.file.cache_file):
                with open(self.file.cache_file, "r") as f:
                    cache = json.load(f)
            else:
                cache = {}

            # Compute the hash of the argument
            arg_hash = hashlib.sha256(repr(tuple(args)).encode("utf-8")).hexdigest()
            print("ASSESSING HASH OF: ", arg_hash)

            # Check if the result is already cached
            if arg_hash in cache:
                print(f"Cached result found for {arg_hash}. Returning it.")
                return cache[arg_hash]
            else:
                print("CACHE NOT FOUND")

            # Compute the result and cache it
            result = func(self, *args, **kwargs)
            cache[arg_hash] = result

            # write the cache to the JSON file
            with open(self.file.cache_file, "w") as f:
                json.dump(cache, f)

            return result

        return wrapped
summarize主函数

run()函数是Summarizer类的主函数,功能非常简单,首先对每个chunk调用gpt_summarize()函数生成chunk summary,然后将所有的chunk summary,再调用gpt_summarize(),整合成synthesis summary。

    def run(self):
        """
        Generate a summary for the file using chunk-based summarization.
        """
        # Initialize an empty list to store the summaries for each chunk
        chunk_summaries = []
        print('Summarizing each chunk of the file...\n\n')

        # Iterate over each chunk in the file
        for chunk in self.file.chunks:
            # Generate prompt messages for summarization
            messages = chunk_prompt_messages(chunk, self.file.chunk_summary_size)

            # Generate a summary for the chunk using GPT model
            chunk_summary = self.gpt_summarize(messages, self.chunk_model)

            # Append the chunk summary to the list of summaries
            chunk_summaries.append(chunk_summary)

        # Generate prompt messages for synthesizing the chunk summaries
        syn_messages = synthesis_prompt_messages(''.join(chunk_summaries), self.file.summary_size)
        print("\n\nSynthesizing the summaries...\n\n")

        # Generate a final summary for the file using GPT model
        summary = self.gpt_summarize(syn_messages, self.syns_model)

        # Return the generated summary
        return summary
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值