TowardsDataScience 2023 博客中文翻译(一百)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

解码美国参议院对 AI 的监督听证会:Python 中的 NLP 分析

原文:towardsdatascience.com/decoding-the-us-senate-hearing-on-oversight-of-ai-nlp-analysis-in-python-2a1e50a1fd0c?source=collection_archive---------7-----------------------#2023-06-02

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

摄影: Harold MendozaUnsplash

使用 NLTK 工具包进行词频分析、可视化和情感评分

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Raul Vizcarra Chirinos

·

关注 发表在 Towards Data Science ·21 分钟阅读·2023 年 6 月 2 日

上周日早晨,当我在换频道找早餐时可以看的节目时,我偶然发现了参议院对 AI 监管听证会的重播。距离开始已经过去了 40 分钟,所以我决定观看剩下的部分(谈到度过一个有趣的周日早晨!)。

当像参议院司法委员会对 AI 监管的听证会这样的事件发生时,如果你想了解关键要点,你有四个选择:观看直播,寻找未来的录音(这两个选项都需要你花费三小时);阅读书面版本(转录本),它们大约有 79 页,超过 29,000 个词;或者在网站或社交媒体上阅读评论以获取不同的观点并形成自己的看法(如果不是来自其他人)。

如今,随着一切变化如此迅速,我们的时间似乎总是过于短暂,人们很容易选择捷径,依赖评论而不是查阅原始来源(我也有过这样的经历)。如果你选择这个听证会的捷径,很可能你在网上或社交媒体上找到的大多数评论都会集中在 OpenAI CEO Sam Altman 呼吁监管 AI 上。然而,看过听证会后,我觉得还有更多内容值得探索,超越头条新闻。

所以,在我完成了周日的休闲早晨活动后,我决定下载参议院听证会的 transcript,并使用 NLTK 包(一个用于自然语言处理的 Python 包——NLP)来分析它,比较最常用的词汇,并对不同的兴趣群体(OpenAI、IBM、学术界、国会)应用一些情感评分,看看是否能发现其中的含义。剧透警告!在分析的 29,000 个词中,只有 70 个(0.24%)与“regulation”,“regulate”,“regulatory”或“legislation”等词相关。

需要注意的是,这篇文章并不是关于我对这次 AI 听证会或 ChatGPT 的 Sam Altman 的看法。相反,它关注的是在国会山这个屋檐下,各个社会部分(私人、学术界、政府)所代表的各方言辞背后的含义,以及我们从这些混杂的言辞中能学到什么。

鉴于未来几个月在人工智能监管方面的有趣时刻,因为 EU AI Act 的最终草案等待在欧洲议会辩论(预计在 6 月进行),探索大西洋这边围绕 AI 的讨论背后的内容是值得的。

步骤-01:获取数据

我使用了 Justin Hendrix 在 Tech Policy Press 发布的 transcript(可在这里访问)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

访问参议院听证会 transcript 这里

虽然亨德里克斯提到这是一个快速转录,并建议通过观看参议院听证会视频来确认引述,但我仍然发现它对这次分析非常准确和有趣。如果你想观看参议院听证会或阅读萨姆·奥特曼(OpenAI)、克里斯蒂娜·蒙哥马利(IBM)和加里·马库斯(纽约大学教授)的证词,你可以在这里找到它们。

最初,我计划将转录文本复制到 Word 文档中,并在 Excel 中手动创建一个包含参与者姓名、他们代表的组织及其评论的表格。然而,这种方法既耗时又低效。所以,我转向了 Python,并将 Microsoft Word 文件中的完整转录文本上传到数据框中。以下是我使用的代码:

# STEP 01-Read the Word document
# remember to install  pip install python-docx

import docx
import pandas as pd

doc = docx.Document('D:\....your word file on microsoft word')

items = []
names = []
comments = []

# Iterate over paragraphs 
for paragraph in doc.paragraphs:
    text = paragraph.text.strip()

    if text.endswith(':'):
        name = text[:-1]  
    else:
        items.append(len(items))
        names.append(name)
        comments.append(text)

dfsenate = pd.DataFrame({'item': items, 'name': names, 'comment': comments})

# Remove rows with empty comments
dfsenate = dfsenate[dfsenate['comment'].str.strip().astype(bool)]

# Reset the index
dfsenate.reset_index(drop=True, inplace=True)
dfsenate['item'] = dfsenate.index + 1
print(dfsenate)

输出应如下所示:

 item name comment
0 1 Sen. Richard Blumenthal (D-CT) Now for some introductory remarks.
1 2 Sen. Richard Blumenthal (D-CT) “Too often we have seen what happens when technology outpaces regulation, the unbridled exploitation of personal data, the proliferation of disinformation, and the deepening of societal inequalities. We have seen how algorithmic biases can perpetuate discrimination and prejudice, and how the lack of transparency can undermine public trust. This is not the future we want.2 3 Sen. Richard Blumenthal (D-CT) If you were listening from home, you might have thought that voice was mine and the words from me, but in fact, that voice was not mine. The words were not mine. And the audio was an AI voice cloning software trained on my floor speeches. The remarks were written by ChatGPT when it was asked how I would open this hearing. And you heard just now the result I asked ChatGPT, why did you pick those themes and that content? And it answered. And I’m quoting, Blumenthal has a strong record in advocating for consumer protection and civil rights. He has been vocal about issues such as data privacy and the potential for discrimination in algorithmic decision making. Therefore, the statement emphasizes these aspects.
3 4 Sen. Richard Blumenthal (D-CT) Mr. Altman, I appreciate ChatGPT’s endorsement. In all seriousness, this apparent reasoning is pretty impressive. I am sure that we’ll look back in a decade and view ChatGPT and GPT-4 like we do the first cell phone, those big clunky things that we used to carry around. But we recognize that we are on the verge, really, of a new era. The audio and my playing, it may strike you as curious or humorous, but what reverberated in my mind was what if I had asked it? And what if it had provided an endorsement of Ukraine, surrendering or Vladimir Putin’s leadership? That would’ve been really frightening. And the prospect is more than a little scary to use the word, Mr. Altman, you have used yourself, and I think you have been very constructive in calling attention to the pitfalls as well as the promise.
4 5 Sen. Richard Blumenthal (D-CT) And that’s the reason why we wanted you to be here today. And we thank you and our other witnesses for joining us for several months. Now, the public has been fascinated with GPT, dally and other AI tools. These examples like the homework done by ChatGPT or the articles and op-eds, that it can write feel like novelties. But the underlying advancement of this era are more than just research experiments. They are no longer fantasies of science fiction. They are real and present the promises of curing cancer or developing new understandings of physics and biology or modeling climate and weather. All very encouraging and hopeful. But we also know the potential harms and we’ve seen them already weaponized disinformation, housing discrimination, harassment of women and impersonation, fraud, voice cloning deep fakes. These are the potential risks despite the other rewards. And for me, perhaps the biggest nightmare is the looming new industrial revolution. The displacement of millions of workers, the loss of huge numbers of jobs, the need to prepare for this new industrial revolution in skill training and relocation that may be required. And already industry leaders are calling attention to those challenges.
5 6 Sen. Richard Blumenthal (D-CT) To quote ChatGPT, this is not necessarily the future that we want. We need to maximize the good over the bad. Congress has a choice. Now. We had the same choice when we face social media. We failed to seize that moment. The result is predators on the internet, toxic content exploiting children, creating dangers for them. And Senator Blackburn and I and others like Senator Durbin on the Judiciary Committee are trying to deal with it in the Kids Online Safety Act. But Congress failed to meet the moment on social media. Now we have the obligation to do it on AI before the threats and the risks become real. Sensible safeguards are not in opposition to innovation. Accountability is not a burden far from it. They are the foundation of how we can move ahead while protecting public trust. They are how we can lead the world in technology and science, but also in promoting our democratic values.
6 7 Sen. Richard Blumenthal (D-CT) Otherwise, in the absence of that trust, I think we may well lose both. These are sophisticated technologies, but there are basic expectations common in our law. We can start with transparency. AI companies ought to be required to test their systems, disclose known risks, and allow independent researcher access. We can establish scorecards and nutrition labels to encourage competition based on safety and trustworthiness, limitations on use. There are places where the risk of AI is so extreme that we ought to restrict or even ban their use, especially when it comes to commercial invasions of privacy for profit and decisions that affect people’s livelihoods. And of course, accountability, reliability. When AI companies and their clients cause harm, they should be held liable. We should not repeat our past mistakes, for example, Section 230, forcing companies to think ahead and be responsible for the ramifications of their business decisions can be the most powerful tool of all. Garbage in, garbage out. The principle still applies. We ought to beware of the garbage, whether it’s going into these platforms or coming out of them.

接下来,我考虑为未来的分析添加一些标签,通过所代表的社会群体来识别个人。

 def assign_sector(name):
    if name in ['Sam Altman', 'Christina Montgomery']:
        return 'Private'
    elif name == 'Gary Marcus':
        return 'Academia'
    else:
        return 'Congress'

# Apply function 
dfsenate['sector'] = dfsenate['name'].apply(assign_sector)

# Assign organizations based on names
def assign_organization(name):
    if name == 'Sam Altman':
        return 'OpenAI'
    elif name == 'Christina Montgomery':
        return 'IBM'
    elif name == 'Gary Marcus':
        return 'Academia'
    else:
        return 'Congress'

# Apply function
dfsenate['Organization'] = dfsenate['name'].apply(assign_organization)

print(dfsenate)

最后,我决定添加一个列来统计每个声明的字数,这也有助于我们进一步分析。

dfsenate['WordCount'] = dfsenate['comment'].apply(lambda x: len(x.split()))

此时,你的数据框应该如下所示:

 item                            name  ... Organization WordCount
0       1  Sen. Richard Blumenthal (D-CT)  ...     Congress         5
1       2  Sen. Richard Blumenthal (D-CT)  ...     Congress        55
2       3  Sen. Richard Blumenthal (D-CT)  ...     Congress       125
3       4  Sen. Richard Blumenthal (D-CT)  ...     Congress       145
4       5  Sen. Richard Blumenthal (D-CT)  ...     Congress       197
..    ...                             ...  ...          ...       ...
399   400         Sen. Cory Booker (D-NJ)  ...     Congress       156
400   401                      Sam Altman  ...       OpenAI       180
401   402         Sen. Cory Booker (D-NJ)  ...     Congress        72
402   403  Sen. Richard Blumenthal (D-CT)  ...     Congress       154
403   404  Sen. Richard Blumenthal (D-CT)  ...     Congress        98

STEP-02: 视觉化数据

让我们看看到目前为止的数据:404 个问题或证词,几乎 29,000 字。这些数字为我们提供了启动所需的材料。重要的是要知道一些声明被分成了较小的部分。当有长声明并且包含不同的段落时,代码将它们分成了独立的声明,即使它们实际上是一个贡献的一部分。为了更好地理解每个参与者的参与程度,我还考虑了他们使用的字数。这提供了另一个角度来衡量他们的参与。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

监督人工智能的听证会:图 01

正如图 01 所示,国会议员的干预占所有听证会的一半以上,其次是萨姆·奥特曼的证词。然而,通过统计每一方的发言字数,另一种观点显示了国会(11 名成员)与由奥特曼(OpenAI)、蒙哥马利(IBM)和马库斯(学界)组成的专家小组之间的更平衡的代表性。

值得注意的是参与参议院听证会的国会议员之间的不同参与程度(见下表)。正如预期的那样,作为分委员会主席的布卢门萨尔参议员参与度很高。但其他成员呢?表格显示所有十一位参与者的参与程度有显著差异。请记住,贡献的数量不一定表示其质量。我会让你在查看数字时自行判断。

最后,尽管萨姆·奥特曼受到了大量关注,但值得注意的是,尽管加里·马库斯可能看似参与机会较少,但他的发言量与奥特曼相当,说明他有很多要说。或者这可能是因为学术界往往提供详细解释,而商业世界则更倾向于实用和直接?

好的,马克斯教授,如果你能具体一点就好了。这是你的机会,伙计。用简单的英语告诉我,我们是否应该实施任何规则。请不要只是使用概念。我需要具体的内容。

参议员约翰·肯尼迪(R-LA)。美国参议院关于 AI 监督的听证会(2023)

#*****************************PIE CHARTS************************************
import pandas as pd
import matplotlib.pyplot as plt

# Pie chart - Grouping by 'Organization' Questions&Testimonies
org_colors = {'Congress': '#6BB6FF', 'OpenAI': 'green', 'IBM': 'lightblue', 'Academia': 'lightyellow'}
org_counts = dfsenate['Organization'].value_counts()

plt.figure(figsize=(8, 6))
patches, text, autotext = plt.pie(org_counts.values, labels=org_counts.index, 
                                  autopct=lambda p: f'{p:.1f}%\n({int(p * sum(org_counts.values) / 100)})', 
                                  startangle=90, colors=[org_colors.get(org, 'gray') for org in org_counts.index])
plt.title('Hearing on Oversight of AI: Questions or Testimonies')
plt.axis('equal')
plt.setp(text, fontsize=12)
plt.setp(autotext, fontsize=12)
plt.show()

# Pie chart - Grouping by 'Organization' (WordCount)
org_colors = {'Congress': '#6BB6FF', 'OpenAI': 'green', 'IBM': 'lightblue', 'Academia': 'lightyellow'}
org_wordcount = dfsenate.groupby('Organization')['WordCount'].sum()

plt.figure(figsize=(8, 6))
patches, text, autotext = plt.pie(org_wordcount.values, labels=org_wordcount.index, 
                                  autopct=lambda p: f'{p:.1f}%\n({int(p * sum(org_wordcount.values) / 100)})', 
                                  startangle=90, colors=[org_colors.get(org, 'gray') for org in org_wordcount.index])

plt.title('Hearing on Oversight of AI: WordCount ')
plt.axis('equal')
plt.setp(text, fontsize=12)
plt.setp(autotext, fontsize=12)
plt.show()

#************Engagement among the members of Congress**********************

# Group by name and count the rows
Summary_Name = dfsenate.groupby('name').agg(comment_count=('comment', 'size')).reset_index()

# WordCount column for each name
Summary_Name ['Total_Words'] = dfsenate.groupby('name')['WordCount'].sum().values

# Percentage distribution for comment_count
Summary_Name ['comment_count_%'] = Summary_Name['comment_count'] / Summary_Name['comment_count'].sum() * 100

# Percentage distribution for total_word_count
Summary_Name ['Word_count_%'] = Summary_Name['Total_Words'] / Summary_Name['Total_Words'].sum() * 100

Summary_Name  = Summary_Name.sort_values('Total_Words', ascending=False)

print (Summary_Name)
+-------+--------------------------------+---------------+-------------+-----------------+--------------+
| index |              name              | Interventions | Total_Words | Interv_%        | Word_count_% |
+-------+--------------------------------+---------------+-------------+-----------------+--------------+
|     2 | Sam Altman                     |            92 |        6355 |     22.77227723 |  22.32252626 |
|     1 | Gary Marcus                    |            47 |        5105 |     11.63366337 |  17.93178545 |
|    15 | Sen. Richard Blumenthal (D-CT) |            58 |        3283 |     14.35643564 |  11.53184165 |
|    10 | Sen. Josh Hawley (R-MO)        |            25 |        2283 |     6.188118812 |  8.019249008 |
|     0 | Christina Montgomery           |            36 |        2162 |     8.910891089 |  7.594225298 |
|     6 | Sen. Cory Booker (D-NJ)        |            20 |        1688 |      4.95049505 |  5.929256384 |
|     7 | Sen. Dick Durbin (D-IL)        |             8 |        1143 |      1.98019802 |  4.014893393 |
|    11 | Sen. Lindsey Graham (R-SC)     |            32 |         880 |     7.920792079 |  3.091081527 |
|     5 | Sen. Christopher Coons (D-CT)  |             6 |         869 |     1.485148515 |  3.052443008 |
|    12 | Sen. Marsha Blackburn (R-TN)   |            14 |         869 |     3.465346535 |  3.052443008 |
|     4 | Sen. Amy Klobuchar (D-MN)      |            11 |         769 |     2.722772277 |  2.701183744 |
|    13 | Sen. Mazie Hirono (D-HI)       |             7 |         755 |     1.732673267 |  2.652007447 |
|    14 | Sen. Peter Welch (D-VT)        |            11 |         704 |     2.722772277 |  2.472865222 |
|     3 | Sen. Alex Padilla (D-CA)       |             7 |         656 |     1.732673267 |  2.304260775 |
+-------+--------------------------------+---------------+-------------+-----------------+--------------+

STEP-03: 标记化

现在是自然语言处理(NLP)有趣的开始阶段。为了分析文本,我们将使用 NLTK Package 这个 Python 库。它提供了用于词频分析和可视化的有用工具。以下库和模块将提供进行词频分析和可视化所需的工具。

 #pip install nltk
#pip install spacy
#pip install wordcloud
#pip install subprocess
#python -m spacy download en

首先,我们将开始标记化,这意味着将文本拆分成单独的词语,也称为“标记”。为此,我们将使用 spaCy,一个开源的 NLP 库,能够处理缩写、标点和特殊字符。接下来,我们将使用来自 NLTK 库的停用词资源去除那些没有太多意义的常见词,如“a”,“an”,“the”,“is”和“and”。最后,我们将应用词形还原,将词语还原为其基本形式,称为词根。例如,“running”变成“run”,“happier”变成“happy”。这种技术帮助我们更有效地处理文本并理解其含义。

总结一下:

o 标记化文本。

o 去除常见词。

o 应用词形还原。

#***************************WORD-FRECUENCY*******************************

import subprocess
import nltk
import spacy
from nltk.probability import FreqDist
from nltk.corpus import stopwords

# Download resources
subprocess.run('python -m spacy download en', shell=True)
nltk.download('punkt')

# Load spaCy model and set stopwords
nlp = spacy.load('en_core_web_sm')
stop_words = set(stopwords.words('english'))

def preprocess_text(text):
    words = nltk.word_tokenize(text)
    words = [word.lower() for word in words if word.isalpha()]
    words = [word for word in words if word not in stop_words]
    lemmas = [token.lemma_ for token in nlp(" ".join(words))]
    return lemmas

# Aggregate words and create Frecuency Distribution
all_comments = ' '.join(dfsenate['comment'])
processed_comments = preprocess_text(all_comments)
fdist = FreqDist(processed_comments)

#**********************HEARING TOP 30 COMMON WORDS*********************
import matplotlib.pyplot as plt
import numpy as np

# Most common words and their frequencies
top_words = fdist.most_common(30)
words = [word for word, freq in top_words]
frequencies = [freq for word, freq in top_words]

# Bar plot-Hearing on Oversight of AI:Top 30 Most Common Words
fig, ax = plt.subplots(figsize=(8, 10))
ax.barh(range(len(words)), frequencies, align='center', color='skyblue')

ax.invert_yaxis()
ax.set_xlabel('Frequency', fontsize=12)
ax.set_ylabel('Words', fontsize=12)
ax.set_title('Hearing on Oversight of AI:Top 30 Most Common Words', fontsize=14)
ax.set_yticks(range(len(words)))
ax.set_yticklabels(words, fontsize=10)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_linewidth(0.5)
ax.spines['bottom'].set_linewidth(0.5)
ax.tick_params(axis='x', labelsize=10)
plt.subplots_adjust(left=0.3)

for i, freq in enumerate(frequencies):
    ax.text(freq + 5, i, str(freq), va='center', fontsize=8)

plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

AI 监督听证会:图 02

正如在条形图(图 02)中所见,“思考”的频率很高。也许前五个词语给我们提供了关于我们今天和未来在 AI 方面应该做些什么的有趣线索:

“我们需要****思考了解****AI应该去哪里。”

正如我在文章开头提到的,乍一看,“监管”在参议院 AI 听证会上并不是一个频繁使用的词。然而,得出它不是主要关注话题的结论可能是不准确的。关于 AI 是否应该受到监管的兴趣以不同的词汇表达,如**“监管”“调控”“机构”“监管”**。因此,让我们对代码进行一些调整,汇总这些词,并重新运行条形图,以查看它如何影响分析。

nlp = spacy.load('en_core_web_sm')
stop_words = set(stopwords.words('english'))

def preprocess_text(text):
    words = nltk.word_tokenize(text)
    words = [word.lower() for word in words if word.isalpha()]
    words = [word for word in words if word not in stop_words]
    lemmas = [token.lemma_ for token in nlp(" ".join(words))]
    return lemmas

# Aggregate words and create Frecuency Distribution
all_comments = ' '.join(dfsenate['comment'])
processed_comments = preprocess_text(all_comments)
fdist = FreqDist(processed_comments)
original_fdist = fdist.copy() # Save the original object

aggregate_words = ['regulation', 'regulate','agency', 'regulatory','legislation']
aggregate_freq = sum(fdist[word] for word in aggregate_words)
df_aggregatereg = pd.DataFrame({'Word': aggregate_words, 'Frequency': [fdist[word] for word in aggregate_words]})

# Remove individual words and add aggregation
for word in aggregate_words:
    del fdist[word]
fdist['regulation+agency'] = aggregate_freq

# Pie chart for Regulation+agency distribution
import matplotlib.pyplot as plt

labels = df_aggregatereg['Word']
values = df_aggregatereg['Frequency']

plt.figure(figsize=(8, 6))
plt.subplots_adjust(top=0.8, bottom=0.25)  

patches, text, autotext = plt.pie(values, labels=labels, 
                                  autopct=lambda p: f'{p:.1f}%\n({int(p * sum(values) / 100)})', 
                                  startangle=90, colors=['#6BB6FF', 'green', 'lightblue', 'lightyellow', 'gray'])

plt.title('Regulation+agency: Distribution', fontsize=14)
plt.axis('equal')
plt.setp(text, fontsize=8)  
plt.setp(autotext, fontsize=8)  
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

AI 监督听证会:图 03

如图 03 所示,监管的话题在参议院 AI 听证会上确实被提及了很多次。

STEP-04: 词语背后的含义

单独的词汇可能给我们一些线索,但词汇的相互关系才真正提供了视角。因此,让我们采用词云的方法,探索是否可以发现简单的条形图和饼图无法显示的见解。

# Word cloud-Senate Hearing on Oversight of AI
from wordcloud import WordCloud
wordcloud = WordCloud(width=800, height=400, background_color='white').generate_from_frequencies(fdist)
plt.figure(figsize=(10, 5))
plt.imshow(wordcloud, interpolation='bilinear')
plt.axis('off')
plt.title('Word Cloud - Senate Hearing on Oversight of AI')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

AI 监督听证会:图 04

让我们进一步探索,并比较 AI 听证会中不同利益集团(私营部门、国会、学术界)的词云,看看这些词汇是否揭示了对 AI 未来的不同看法。

# Word clouds for each group of Interest
organizations = dfsenate['Organization'].unique()
for organization in organizations:
    comments = dfsenate[dfsenate['Organization'] == organization]['comment']
    all_comments = ' '.join(comments)
    processed_comments = preprocess_text(all_comments)
    fdist_organization = FreqDist(processed_comments)

    # Word clouds
    wordcloud = WordCloud(width=800, height=400, background_color='white').generate_from_frequencies(fdist_organization)
    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis('off')
    if organization == 'IBM':
        plt.title(f'Word Cloud: {organization} - Christina Montgomery')
    elif organization == 'OpenAI':
        plt.title(f'Word Cloud: {organization} - Sam Altman')
    elif organization == 'Academia':
        plt.title(f'Word Cloud: {organization} - Gary Marcus')
    else:
        plt.title(f'Word Cloud: {organization}')
    plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

AI 监督听证会:图 05

有趣的是,不同利益集团在参议院 AI 听证会中讨论人工智能时,一些词汇会出现(或消失)。

关于大标题*“萨姆·奥特曼对监管 AI 的呼吁”*;嗯,不管他是否支持监管,我真的无法判断,但在我看来,他的话语中似乎没有太多监管的内容。相反,萨姆·奥特曼谈论 AI 时似乎更关注人本身,重复使用**“思考”、“人们”、“了解”、“重要”“使用”等词汇,更倾向于使用“技术”、“系统”“模型”这样的词汇,而不是使用“AI”**这个词。

说到**“风险”“问题”克里斯蒂娜·蒙哥马利(IBM)在谈论“技术”“公司”“AI”时不断重复这些词。在她的证词中,一个有趣的事实是她提到的词汇最常见的有“信任”“治理”“思考”,以及在 AI 方面“正确”**的看法。

“我们需要立即让公司对他们部署的 AI 负责,并承担责任……”

克里斯蒂娜·蒙哥马利。美国参议院 AI 监督听证会(2023 年)

加里·马库斯在他的初步声明中提到,“我以科学家的身份出现,曾创办 AI 公司,并且真正热爱 AI……” 因此,为了这次 NLP 分析,我们将他视为学术界声音的代表。“需要”、“思考”、“了解”、“进行”,**“人们”等词汇在其中尤为突出。一个有趣的事实是,在他的证词中,“系统”这个词似乎比“AI”**出现得更多。也许 AI 并不是一种单一的技术可以改变未来,未来的影响将来自于多种技术或系统的相互作用(物联网、机器人技术、生物技术等),而不是仅仅依赖于其中某一个。

最后,参议员约翰·肯尼迪提到的第一个假设似乎并非完全错误(不仅仅是对于国会,也对整个社会)。我们仍然处于试图理解 AI 发展方向的阶段。

请允许我向你们分享三个假设,我希望你们暂时接受这些假设为真。第一个假设,许多国会议员不了解人工智能。第二个假设,这种理解的缺乏可能不会阻止国会热情投入,并试图以可能对这一技术造成伤害的方式来监管它。第三个假设,我希望你们假设,人工智能社区中可能有一个失控的派别,无论是有意还是无意,都可能利用人工智能来杀死我们所有人,并在我们死去的整个过程中伤害我们……

参议员约翰·肯尼迪(R-LA)。美国参议院关于人工智能监管的听证会(2023 年)

STEP-05: 你话语背后的情感

我们将使用 NLTK 库中的SentimentIntensityAnalyzer类进行情感分析。这个预训练模型使用基于词典的方法,其中词典中的每个单词(VADER)都有一个预定义的情感极性值。将文本中单词的情感分数汇总以计算总体情感分数。数值范围从-1(负面情感)到+1(正面情感),0 表示中立情感。正面情感反映了有利的情感、态度或热情,而负面情感传达了不利的情感或态度。

#************SENTIMENT ANALYSIS************
from nltk.sentiment import SentimentIntensityAnalyzer
nltk.download('vader_lexicon')

sid = SentimentIntensityAnalyzer()
dfsenate['Sentiment'] = dfsenate['comment'].apply(lambda x: sid.polarity_scores(x)['compound'])

#************BOXPLOT-GROUP OF INTEREST************
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_style('white')
plt.figure(figsize=(12, 7))
sns.boxplot(x='Sentiment', y='Organization', data=dfsenate, color='yellow', 
            width=0.6, showmeans=True, showfliers=True)

# Customize the axis 
def add_cosmetics(title='Sentiment Analysis Distribution by Group of Interest',
                  xlabel='Sentiment'):
    plt.title(title, fontsize=28)
    plt.xlabel(xlabel, fontsize=20)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    sns.despine()

def customize_labels(label):
    if "OpenAI" in label:
        return label + "-Sam Altman"
    elif "IBM" in label:
        return label + "-Christina Montgomery"
    elif "Academia" in label:
        return label + "-Gary Marcus"
    else:
        return label

# Apply customized labels to y-axis
yticks = plt.yticks()[1]
plt.yticks(ticks=plt.yticks()[0], labels=[customize_labels(label.get_text()) 
                                          for label in yticks])

add_cosmetics()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

关于人工智能监管的听证会:图 06

箱型图总是很有趣,因为它显示了最小值和最大值、中位数、第一四分位数(Q1)和第三四分位数(Q3)。此外,添加了一行代码以显示平均值。(感谢 Elena Kosourova 设计了箱型图代码模板;我仅为我的数据集做了调整)。

总体而言,参议院听证会期间,大家似乎心情愉快,尤其是萨姆·奥特曼,他以最高的情感分数脱颖而出,其次是克里斯蒂娜·蒙哥马利。另一方面,加里·马库斯的体验似乎较为中立(中位数约为 0.25),他可能有时感到不太舒服,值接近 0 或甚至为负。此外,国会整体上在情感分数上表现出左偏分布,显示出对中立或积极情感的倾向。有趣的是,如果我们进一步观察,某些干预措施的情感分数非常高或非常低。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

关于人工智能监管的听证会:图 07

也许我们不应将结果解读为参议院人工智能听证会上的人们感到快乐或不安。也许这表明参与听证会的人们对人工智能的未来并不持过于乐观的观点,但与此同时,他们也不悲观。评分可能表明存在一些担忧,并对人工智能的发展方向持谨慎态度。

那么时间线如何呢?听证会期间的情绪是否一直保持不变?每个利益集团的情绪如何变化? 为了分析时间线,我将陈述按捕获顺序进行整理,并进行了情感分析。由于有超过 400 个问题或证词,我定义了每个利益集团(国会、学术界、私人)情感评分的移动平均值,窗口大小为 10。这意味着移动平均值是通过对每 10 个连续陈述的情感评分取平均值来计算的:

#**************************TIMELINE US SENATE AI HEARING**************************************

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import make_interp_spline

# Moving average for each organization
window_size = 10  
organizations = dfsenate['Organization'].unique()

# Create the line plot
color_palette = sns.color_palette('Set2', len(organizations))

plt.figure(figsize=(12, 6))
for i, org in enumerate(organizations):
    df_org = dfsenate[dfsenate['Organization'] == org]

    # moving average
    df_org['Sentiment'].fillna(0, inplace=True) # missing values filled with 0
    df_org['Moving_Average'] = df_org['Sentiment'].rolling(window=window_size, min_periods=1).mean()

    x = np.linspace(df_org.index.min(), df_org.index.max(), 500)
    spl = make_interp_spline(df_org.index, df_org['Moving_Average'], k=3)
    y = spl(x)
    plt.plot(x, y, linewidth=2, label=f'{org} {window_size}-Point Moving Average', color=color_palette[i])

plt.xlabel('Statement Number', fontsize=12)
plt.ylabel('Sentiment Score', fontsize=12)
plt.title('Sentiment Score Evolution during the Hearing on Oversight of AI', fontsize=16)
plt.legend(fontsize=12)
plt.grid(color='lightgray', linestyle='--', linewidth=0.5)
plt.axhline(0, color='black', linewidth=0.5, alpha=0.5)

for org in organizations:
    df_org = dfsenate[dfsenate['Organization'] == org]
    plt.text(df_org.index[-1], df_org['Moving_Average'].iloc[-1], f'{df_org["Moving_Average"].iloc[-1]:.2f}', ha='right', va='top', fontsize=12, color='black')

plt.tight_layout()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

人工智能监管听证会:图 08

起初,会议似乎友好而乐观,大家讨论人工智能的未来。但随着会议的进行,情绪开始发生变化。国会议员变得不那么乐观,问题也变得更加具有挑战性。这影响了讨论小组的评分,甚至有些评分较低(你可以在会议结束时看到这一点)。有趣的是,即使在与国会议员的紧张时刻,模型仍将 Altman 视为中立或略微积极。

重要的是要记住,模型有其局限性,可能带有一定主观性。虽然情感分析并不完美,但它为我们提供了对那一天国会山上情感强度的有趣窥视。

最后的想法

在我看来,这次美国参议院关于人工智能的听证会背后的教训在于五个最常出现的词汇:“我们 需要 思考 知道 人工智能 应该 哪里*”*。值得注意的是,像**“人们”“重要性”这样的词在 Sam Altman 的词云中意外出现,超出了“呼吁监管”的标题范围。虽然我希望在 Altman 的 NLP 分析中看到更多“透明度”“问责制”“信任”“治理”“公平”**等词,但发现这些词在 Christina Montgomery 的证词中经常出现,还是让人感到宽慰。这正是我们在讨论人工智能时期待听到的。

加里·马库斯强调了**“system”“AI”一样多,或许是在邀请我们从更广的视角来看待人工智能。目前有多种技术正在出现,它们对社会、工作和未来就业的综合影响将来自这些技术之间的冲突,而不仅仅是某一种技术。学术界在引导这一过程方面发挥着至关重要的作用,如果需要某种形式的监管的话。我说的是“字面上的” 而不是 “精神上的”(来自六个月停顿信的内部笑话)。

最后,“Agency” 这个词在不同形式中被重复使用的频率与**“Regulation”相当。这表明“Agency for AI”**的概念及其作用可能会在不久的将来成为讨论的话题。理查德·布卢门撒尔参议员在参议院人工智能听证会上提到对此挑战的有趣反思:

“…我职业生涯的大部分时间都在执法。我告诉你们,你们可以创建 10 个新机构,但如果不给他们资源,我说的不是仅仅是资金,还包括科学专业知识,你们将把他们绕圈子。而且不仅仅是模型或生成性人工智能会把他们绕圈子,而是你们公司里的科学家。对于政府监管中的每一个成功故事,你可以想到五个失败案例……我希望我们这里的经验会有所不同……”

理查德·布卢门撒尔(D-CT)参议员。美国参议院关于人工智能监督的听证会(2023 年)

尽管对我来说,调和创新、意识和监管是具有挑战性的,我完全支持提升对人工智能在我们现在和未来角色的意识,但也要理解**“research”“development”是不同的。前者应该得到鼓励和推广,而不是限制,后者则是需要额外努力在“thinking”“knowing”**上的地方。

我希望你觉得这篇自然语言处理分析有趣,并且我想感谢 贾斯廷·亨德里克斯Tech Policy Press 允许我在本文中使用他们的稿件。你可以在这个 GitHub 库中访问完整代码。(同时感谢 ChatGPT 帮助我优化了一些代码,使其展示更佳)

我有遗漏什么吗? 欢迎提出建议,让对话持续进行。

面向 ChatGPT 的 LLM 聊天机器人解耦前端——后端微服务架构

原文:towardsdatascience.com/decoupled-frontend-backend-microservices-architecture-for-chatgpt-based-llm-chatbot-61637dc5c7ea

使用 Streamlit、FastAPI 和 OpenAI API 构建无头 ChatGPT 应用程序的实用指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Marie Stephen Leo

·发表于Towards Data Science ·阅读时间 8 分钟·2023 年 5 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者使用 Midjourney V5.1 生成,提示词:“解耦前端后端软件应用”

我之前的文章中,我讨论了基于 LLM 的聊天机器人应用程序的单体与微服务架构模式之间的差异。选择微服务架构模式的一个显著优势是,它允许前端代码与数据科学逻辑分离,使得数据科学家可以专注于数据科学逻辑,而不必担心前端代码。在这篇文章中,我将向你展示如何使用 Streamlit、FastAPI 和 OpenAI API 构建微服务聊天机器人应用程序。我们将前端和后端代码解耦,以便可以轻松地将前端替换为其他前端框架,如 React、Swift、Dash、Gradio 等。

首先,创建一个新的 conda 环境并安装所需的库。

# Create and activate a conda environment
conda create -n openai_chatbot python=3.10
conda activate openai_chatbot

# Install the necessary libraries
pip install ipykernel streamlit "fastapi[all]" openai

后端:数据科学逻辑

像我之前的博客文章一样,我们将使用 FastAPI 构建后端。任何 API 中最关键的部分是 API 契约,它定义了 API 接受的输入格式和 API 将发送回客户端的输出格式。定义并遵循一个健全的 API 契约可以使前端开发人员独立于 API 开发人员进行工作,只要双方都尊重契约。这就是将前端与后端解耦的好处。FastAPI 允许我们使用 Pydantic 模型轻松地指定和验证 API 契约。我们的后端 API 契约如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

API 合同细节。图片由作者提供

后端将负责以下任务:

  1. 首先,我们初始化一个新的 FastAPI 应用,加载 OpenAI API 密钥,并定义一个系统提示,以告知 ChatGPT 我们希望它扮演的角色。在这种情况下,我们希望 ChatGPT 扮演漫画书助手的角色,因此我们这样提示它。可以随意“设计”不同的提示,并查看 ChatGPT 的回应!

  2. 接下来,我们创建两个 Pydantic 模型,ConversationConversationHistory,用于验证 API 负载。Conversation 模型将验证对话历史记录中的每条消息,而 ConversationHistory 模型只是一个对话列表,用于验证整个对话历史记录。OpenAI ChatGPT API 只接受 assistantuser 作为 role 参数,因此我们在 Conversation 模型中指定了这个限制。如果尝试在 role 参数中发送其他值,API 将返回错误。使用 Pydantic 模型与 FastAPI 配合使用的好处之一就是验证。

  3. 接下来,我们为健康检查保留根路由。

  4. 最后,我们定义一个 /chat 路由,该路由接受一个 POST 请求。该路由将接收一个 ConversationHistory 负载,这是一系列对话。然后,该路由将负载转换为 Python 字典,使用系统提示和负载中的消息列表初始化对话历史记录,使用 OpenAI ChatGPT API 生成响应,并将生成的响应和令牌使用情况返回给 API 调用者。

# %%writefile backend.py
import os
from typing import Literal

import openai
from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()

# Load your API key from an environment variable or secret management service
openai.api_key = os.getenv("OPENAI_API_KEY")

system_prompt = "You are a comic book assistant. You reply to the user's question strictly from the perspective of a comic book assistant. If the question is not related to comic books, you politely decline to answer."

class Conversation(BaseModel):
    role: Literal["assistant", "user"]
    content: str

class ConversationHistory(BaseModel):
    history: list[Conversation] = Field(
        example=[
            {"role": "user", "content": "tell me a quote from DC comics about life"},
        ]
    )

@app.get("/")
async def health_check():
    return {"status": "OK!"}

@app.post("/chat")
async def llm_response(history: ConversationHistory) -> dict:
    # Step 0: Receive the API payload as a dictionary
    history = history.dict()

    # Step 1: Initialize messages with a system prompt and conversation history
    messages = [{"role": "system", "content": system_prompt}, *history["history"]]

    # Step 2: Generate a response
    llm_response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo", messages=messages
    )

    # Step 3: Return the generated response and the token usage
    return {
        "message": llm_response.choices[0]["message"],
        "token_usage": llm_response["usage"],
    }

就这样!我们现在可以使用 uvicorn backend:app — reload 在本地机器上运行后端,并通过 127.0.0.1:8000/docs. 使用 Swagger UI 进行测试。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

FastAPI 后端文档。图片由作者提供

前端:用户界面

我们将构建前端,使其完全独立于后端。我们只需要遵守后端使用的 API 合同。在构建前端用户界面之前,让我们定义一些辅助函数。

  1. display_conversation() 将帮助我们使用 原生 streamlit 聊天元素 显示对话历史记录。我们可以使用表情符号或文件路径为用户和助手消息选择独特的头像。

  2. clear_conversation() 将帮助我们清除对话历史记录。它还将初始化 conversation_history 会话状态变量以存储对话历史记录,以及 total_cost 会话状态变量以保存总对话成本。

  3. download_conversation() 将允许我们将对话历史记录下载为 CSV 文件。

  4. calc_cost(): 将帮助我们根据使用的令牌数量计算对话成本。OpenAI API 对每 1000 个输出令牌收费 $0.002,对每 1000 个输入令牌收费 $0.0015,所以我们将使用这些费用来计算对话成本。

# %%writefile utils.py
from datetime import datetime

import pandas as pd
import streamlit as st

user_avatar = "😃"
assistant_avatar = "🦸"

def display_conversation(conversation_history):
    """Display the conversation history"""

    # Loop over all messages in the conversation
    for message in conversation_history:
        # Change avatar based on the role
        avatar = user_avatar if message["role"] == "user" else assistant_avatar

        # Display the message content
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])

            if "api_call_cost" in message:
                st.caption(f"Cost: US${message['api_call_cost']:.5f}")

def clear_conversation():
    """Clear the conversation history."""
    if (
        st.button("🧹 Clear conversation", use_container_width=True)
        or "conversation_history" not in st.session_state
    ):
        st.session_state.conversation_history = []
        st.session_state.total_cost = 0

def download_conversation():
    """Download the conversation history as a CSV file."""
    conversation_df = pd.DataFrame(
        st.session_state.conversation_history, columns=["role", "content"]
    )
    csv = conversation_df.to_csv(index=False)

    st.download_button(
        label="💾 Download conversation",
        data=csv,
        file_name=f"conversation_{datetime.now().strftime('%Y%m%d%H%M%S')}.csv",
        mime="text/csv",
        use_container_width=True,
    )

def calc_cost(token_usage):
    # https://openai.com/pricing

    return (token_usage["prompt_tokens"] * 0.0015 / 1000) + (
        token_usage["completion_tokens"] * 0.002 / 1000
    )

现在我们拥有了使用 Streamlit 构建用户界面所需的一切。让我们创建一个 frontend.py 文件并导入我们之前定义的助手函数。

  1. 首先,我们将定义我们 FastAPI 后端的 URL。

  2. openai_llm_response() 将使用 user 角色将最新的用户输入附加到 conversation_history 会话状态变量中。然后,我们将创建一个符合我们后端 FastAPI 应用程序期望的格式的有效负载,包含 history 字段。最后,我们将有效负载发送到后端,并将生成的响应及单次 API 调用的成本附加到 conversation_history 会话状态变量中。我们还将用生成响应的成本增加总成本。

  3. main(): 是 UI 设计的主要部分。在标题下方,我们使用 utils.py 中的助手函数添加了清除和下载对话的按钮。接着我们有一个聊天输入框,用户可以在其中输入问题。按下回车将把输入框中输入的文本发送到后端。最后,我们展示对话的成本和对话历史。

# %%writefile frontend.py
import requests
import streamlit as st
import utils

# Replace with the URL of your backend
app_url = "http://127.0.0.1:8000/chat"

@st.cache_data(show_spinner="🤔 Thinking...")
def openai_llm_response(user_input):
    """Send the user input to the LLM API and return the response."""

    # Append user question to the conversation history
    st.session_state.conversation_history.append(
        {"role": "user", "content": user_input}
    )

    # Send the entire conversation history to the backend
    payload = {"history": st.session_state.conversation_history}
    response = requests.post(app_url, json=payload).json()

    # Generate the unit api call cost and add it to the response
    api_call_cost = utils.calc_cost(response["token_usage"])
    api_call_response = response["message"]
    api_call_response["api_call_cost"] = api_call_cost

    # Add everything to the session state
    st.session_state.conversation_history.append(api_call_response)
    st.session_state.total_cost += api_call_cost

def main():
    st.title("🦸 ChatGPT Comic Book Assistant")

    col1, col2 = st.columns(2)
    with col1:
        utils.clear_conversation()

    # Get user input
    if user_input := st.chat_input("Ask me any comic book question!", max_chars=50):
        openai_llm_response(user_input)

    # Display the total cost
    st.caption(f"Total cost of this session: US${st.session_state.total_cost:.5f}")

    # Display the entire conversation on the frontend
    utils.display_conversation(st.session_state.conversation_history)

    # Download conversation code runs last to ensure the latest messages are captured
    with col2:
        utils.download_conversation()

if __name__ == "__main__":
    main()

就这样!我们已完成前端应用程序。现在我们可以使用 streamlit run frontend.py 进行测试。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Streamlit App 界面。图像来源:作者

结论

使用 OpenAI API 构建一个聊天机器人,采用微服务架构通过将前端与后端解耦是很简单的。以下是一些考虑何时采用解耦架构的想法:

  1. 你的应用相对复杂或需要支持中到大规模的流量。解耦架构允许前端和后端独立扩展,以处理大规模流量。

  2. 你有专门的前端开发资源来构建 UI,或者需要为外部客户提供高度精致的 UI。在本教程中,我们使用了 Streamlit 构建了一个简单的用户界面,但构建更复杂的 UI 可能会变得困难甚至不可能。最好使用像 React、Swift 等专业 UI 框架来构建面向客户的应用程序。

  3. 你想独立于前端改进数据科学逻辑。例如,你可以更新提示词或添加多个微服务,所有这些都由 API 服务器入口点进行协调,只要你遵守与前端工程师达成的相同 API 合同,就无需担心前端代码。

然而,可能会有一些情况下,解耦不是你应用的最佳架构选择。以下是一些关于何时不使用解耦架构的想法:

  1. 你的应用很简单或流量较低。你可以使用单体应用程序,因为扩展不是问题。

  2. 你没有专门的前端开发资源来构建用户界面,或者你的应用程序仅服务于内部客户,这些客户可能对粗糙的用户界面设计更为宽容。尤其是在构建最小可行产品或原型时,这一点尤为明显。

  3. 你是一个想要同时提升数据科学逻辑和前端界面的独角兽!

深度确定性策略梯度(DDPG)解释

原文:towardsdatascience.com/deep-deterministic-policy-gradients-explained-4643c1f71b2e

一种基于梯度的强化学习算法,用于学习连续动作空间中的确定性策略

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Wouter van Heeswijk, PhD

·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 4 月 5 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Jonathan Ford 提供,来源于 Unsplash

本文介绍了深度确定性策略梯度(DDPG)——一种适用于连续动作空间中的确定性策略的强化学习算法。通过将演员-评论家范式与深度神经网络结合,可以在不依赖随机策略的情况下处理连续动作空间。

尤其适用于动作中的随机性不受欢迎的连续控制任务——例如,机器人技术或导航——DDPG 可能正是你需要的算法。

DDPG 在强化学习领域中适合什么位置?

DDPG 结合了基于策略的方法和基于价值的方法,形成了一种混合策略类。

## 强化学习的四种策略类别

[towardsdatascience.com

策略梯度方法REINFORCETRPO和 PPO 使用随机策略 π:a~P(a|s) 来探索和比较动作。这些方法从可微分分布 P_θ(a|s) 中提取动作,从而能够计算相对于 θ 的梯度。这些决策中的固有随机性可能在实际应用中不受欢迎。DDPG 消除了这种随机性,产生了更简单和更可预测的策略。

基于价值的方法如 SARSA、蒙特卡罗学习和深度 Q 学习基于确定性策略,该策略始终根据输入状态返回一个单一的动作。然而,这些方法假设动作的数量是有限的,这使得在具有无限多个动作的连续动作空间中评估它们的价值函数和选择最有回报的动作变得困难。

正如你猜测的那样,深度确定性策略梯度(DDPG)填补了这一空白,结合了深度 Q 学习和策略梯度方法的元素。DDPG 有效地处理连续动作空间,并已成功应用于机器人控制和游戏任务中。

如果你不熟悉策略梯度算法(特别是 REINFORCE)或基于价值的方法(特别是 DQN),建议在探讨 DDPG 之前先了解它们。

DDPG: 评论者

DDPG 与深度 Q 学习非常接近,共享了符号和概念。让我们快速了解一下。

DQN 对连续动作空间的适用性?

在普通(即,表格)Q 学习中,我们使用 Q 值来逼近贝尔曼价值函数 V。Q 值为每个状态-动作对定义,因此用 Q(s,a) 表示。表格 Q 学习需要一个查找表来包含每对的 Q 值,因此需要离散的状态空间和离散的动作空间。

是时候将 ‘深度’ 融入深度强化学习中了。与查找表相比,引入神经网络有两个优点:(i)它为整个状态空间提供了一个通用表达式,(ii)因此,它还可以处理连续的 状态 空间。

当然,我们需要处理连续的 动作 空间;因此我们不能为每个动作输出 Q 值。相反,我们提供一个动作作为 输入 并计算状态-动作对的 Q 值(这个过程也称为简单 DQN)。用数学术语来说,我们可以将网络表示为 Q*:(s,a)→Q(s,a)*,即为给定的状态-动作对输出一个单一的 Q 值。

相应的评论员网络如下所示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 中的评论员网络 Q:(s,a)示例。网络同时接受状态向量和动作向量作为输入,并输出一个 Q 值[图片由作者提供]

尽管提供了泛化能力,神经网络也引入了一些稳定性问题。作为所有状态的单一表示,每次更新也会影响所有 Q 值。由于观察元组*(s,a,r,s’)*是顺序收集的,它们之间往往存在高时间相关性,使得过拟合变得非常可能。在这里不深入细节,正确训练价值网络需要以下三种技术:

  • 经验回放:从经验缓冲区中采样观察数据*(s,a,r,s’)*,打破随后收集的元组之间的相关性。

  • 批量学习:用观察数据批次训练价值网络,使更新更可靠且有影响力。

  • 目标网络:使用不同的网络来计算Q(s’,a’)Q(s,a),减少期望与观察之间的相关性。

如何建模经验回放、批量学习和目标网络 [## 如何建模经验回放、批量学习和目标网络

关于稳定且成功的深度 Q 学习的三个基本技巧的快速教程,使用 TensorFlow 2.0

如何建模经验回放、批量学习和目标网络

评论员网络更新

现在基础知识已经更新,让我们将上述概念与 DDPG 结合起来。我们定义一个评论员网络 Q_ϕ,如前所述,由ϕ(代表网络权重)参数化。

我们设定了一个损失函数,目标是最小化,这对于有 Q 学习经验的人应该是熟悉的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与 DQN 相比,关键区别在于对于* s’对应的动作——而不是在动作空间中最大化——**我们通过目标演员网络μ_{θ_targ}确定动作a’***(稍后会详细讲解)。在这个旁道之后,我们像往常一样更新评论员网络。

除了更新主评论员网络,我们还必须更新目标评论员网络。在深度 Q 学习中,这通常是主价值网络的周期性副本(例如,每 100 集复制一次)。在 DDPG 中,通常使用滞后目标网络进行 Polyak 平均,使目标网络落后于主价值网络:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由于ρ通常接近 1,目标网络适应非常缓慢,这逐渐提高了训练稳定性。

DDPG:演员

在 DDPG 中,actor 和 critic 以离策略的方式紧密相连。我们首先探索算法的离策略特性,然后再讨论动作生成和 actor 网络更新。

离策略训练

在纯策略梯度方法中,我们直接更新策略 μ_θ(由θ 参数化)以最大化期望奖励,而不依赖于显式的价值函数来捕捉这些奖励。DDPG 是一种混合方法,它还使用 Q 值,但从 actor 的角度来看,最大化目标 J(θ) 表面上看是相似的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

然而,仔细观察期望值会发现 DDPG 是一种离策略方法,而典型的 actor-critic 方法是同策略的。大多数 actor-critic 模型最大化期望值 E_{τ~π_θ},其中τ 是由策略 π_θ 生成的状态-动作轨迹。相对而言,DDPG 对从经验缓冲区中抽取的样本状态进行期望值优化(E_{s~D})。由于 DDPG 使用不同策略生成的经验来优化策略,因此它是一种离策略算法

在这种离策略背景下,重放缓冲区的作用需要一些关注。为什么可以重用旧经验,为什么应该这样做?

首先,让我们探讨为什么缓冲区可以包含与当前策略不同的经验。随着策略的不断更新,重放缓冲区保存了源自过时策略的经验。由于最优 Q 值适用于任何过渡,因此生成这些经验的策略并不重要。

其次,重放缓冲区应该包含多样化的经验的原因在于我们部署的是确定性策略。如果算法是同策略的,我们可能会有有限的探索。通过借鉴过去的经验,我们还会在当前策略下不太可能遇到的观察值上进行训练。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 中的 actor 网络 μ_θ:s 示例。该网络将状态向量作为输入,并输出确定性动作 μ(s)。在训练过程中,通常会添加一个单独的随机噪声 ϵ [图像由作者提供]

动作探索

那么,政策梯度方法中著名的探索机制如何呢?毕竟,我们现在部署的是确定性策略而非随机策略,对吧?DDPG 通过在训练过程中添加一些噪声 ϵ 来解决这个问题,在部署策略时去除这些噪声。

早期的 DDPG 实现使用了相当复杂的噪声结构(例如,时间相关的奥恩斯坦-乌伦贝克噪声),但后来的实验证明,普通高斯噪声 ϵ~N(0,σ²) 的效果同样良好。噪声可能会随着时间的推移逐渐减少,但不像随机策略中的σ_θ那样是一个可训练的组件。最后一点是,我们可能会裁剪动作范围。显然,探索过程中涉及一些调优工作。

总之,演员生成动作如下。它以状态作为输入,输出一个确定的值,并添加一些随机噪声:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

演员网络更新

关于策略更新的最终说明,这并不一定是简单的。我们根据评论家网络(由 ϕ 参数化)返回的 Q 值来更新演员网络参数 θ。因此,我们保持 Q 值不变——即,我们在这一步不更新 ϕ——并通过改变动作来最大化预期奖励。这意味着我们 假设评论家网络对动作是可微的,以便我们可以在一个最大化 Q 值的方向上更新动作:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

尽管第二个梯度 ∇_θ 常常为了可读性而省略,但它提供了一些说明。我们训练演员网络以 输出更好的动作,从而改进获得的 Q 值。如果愿意,你可以通过应用链式法则来详细说明这个过程。

演员目标网络使用 Polyak 平均进行更新,与评论家目标网络类似。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

综合整理

我们有一个演员,也有一个评论家,因此现在没有什么可以阻止我们完成算法!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 大纲 [作者提供的图像,初步大纲由 ChatGPT 生成]

让我们一步一步地进行详细说明。

初始化 [第 1–4 行]

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 初始化 [作者提供的图像]

我们从以下四个网络开始:

**演员网络 μ_**θ

  • 由 θ 参数化

  • 基于输入 s 输出确定性动作

演员目标网络 μ_{θ_targ}

  • θ_targ 参数化

  • 在训练评论家网络时提供 s’ 的动作

评论家网络 Q_ϕ(s,a)

  • ϕ 参数化

  • 基于输入 (s,a) 输出 Q 值 Q(s,a)(期望)

评论家目标网络 μ

  • ϕ_tar 参数化

  • 在训练评论家网络时输出 Q 值 Q(s’,a’)(目标)

我们从一个空的重放缓冲区 D 开始。与策略方法不同,我们 在更新策略后不会清空缓冲区, 因为我们会重复使用旧的过渡。

最后,我们将学习率 ρ 设置为 更新目标网络。 为了简单起见,我们假设两个目标网络的学习率相同。请记住,ρ 应设置接近 1(例如,0.995),以便网络更新缓慢,目标保持相对稳定。

数据收集 [第 9–11 行]

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 数据收集 [作者提供的图像]

动作通过演员网络生成,演员网络输出确定性动作。为了增加探索,向这些动作中添加噪声。生成的观测元组存储在重放缓冲区中。

更新演员和评论家网络 [第 12–17 行]

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 主网络更新 [作者提供的图片]

从重放缓冲区中采样一个随机小批量 B⊆D(包括源自较旧策略的观察值)。

要更新评论者,我们最小化平方误差,即观察值(通过目标网络获得)和期望值(通过主网络获得)之间的误差。

为了更新演员,我们计算样本策略梯度,同时保持 Q 值固定。在神经网络设置中,我们计算伪损失,即生成动作的累积 Q 值。

训练过程可以通过下面的 Keras 代码片段进行澄清:

更新目标网络 [第 18–19 行]

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 目标网络更新 [作者提供的图片]

演员目标网络和评论者目标网络使用Polyak 平均进行更新,它们的权重略微靠近更新后的主网络。

返回训练后的网络 [第 23 行]

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DDPG 训练后的演员网络 [作者提供的图片]

尽管我们经历了一些麻烦,但结果策略看起来非常干净。与 DQN 不同,我们不对动作空间进行显式最大化,因此不再需要 Q 值 [注意我们从未用它们来做决策,只是用来改进它们]。

我们不再需要目标网络,这些网络只是为了稳定训练和防止振荡。理想情况下,主网络和目标网络将会收敛,从而使得我们有μ_θ=μ_{θ_targ}(以及 Q_ϕ=Q_{ϕ_targ})。这样,我们知道我们的策略确实已经收敛。

最后,我们去掉了探索噪声 ϵ,它从未是策略的一个核心部分。我们得到一个接受状态作为输入并输出确定性动作的演员网络,这在许多应用中正是我们所希望的简单性。

什么使 DDPG 与其他算法不同?

我们确定 DDPG 是一种混合类方法,结合了策略梯度方法和基于值的方法。这适用于所有演员-评论者方法,那么究竟是什么使 DDPG 独特呢?

  • DDPG 处理连续动作空间: 该算法专门设计用于处理连续动作空间,而不依赖于随机策略。确定性策略可能更容易学习,并且在实际应用中没有固有随机性的策略通常是更可取的。

  • DDPG 是离线策略的。 与常见的演员-评论者算法不同,经验来自于包括较旧策略的观察值的重放缓冲区。离线策略的性质对于充分探索是必要的(因为动作是确定性生成的)。它还提供了更高的样本效率和更好的稳定性。

  • DDPG 在概念上非常接近 DQN: 从本质上讲,DDPG 是 DQN 的一个变体,适用于连续动作空间。为了避免明确地在所有动作上进行最大化——DQN 通过枚举整个动作空间来识别最高的 Q(s,a) 值——动作由一个单独优化的演员网络提供。

  • DDPG 输出一个演员网络: 尽管在训练上接近 DQN,但在部署过程中我们只需要训练好的演员网络。这个网络将状态作为输入,并确定性地输出一个动作。

尽管乍一看可能不那么明显,DDPG 的确定性特性往往简化了训练,比在线方法更稳定、更高效。输出是一个全面的演员网络,该网络确定性地生成动作。由于这些特性,它已成为连续控制任务中的重要工具。

想了解更多关于 DDPG 构建模块的背景?查看以下文章。

深度 Q 学习(DQN):

## TensorFlow 2.0 中深度 Q 学习的最小工作示例

一个多臂赌博机的例子,用于训练 Q 网络。更新过程只需要几行代码,使用 TensorFlow。

## 深度 Q 学习在悬崖行走问题中的应用 ## 深度 Q 学习在悬崖行走问题中的应用 [## 深度 Q 学习在悬崖行走问题中的应用

一个完整的 Python 实现,使用 TensorFlow 2.0 进行悬崖导航。

## 深度 Q 学习在悬崖行走问题中的应用

策略梯度方法:

## 强化学习中的策略梯度解释 ## 强化学习中的策略梯度解释 [## 强化学习中的策略梯度解释

了解基于似然比的策略梯度算法(REINFORCE):直觉、推导…

## 深度策略梯度用于悬崖行走 ## 深度策略梯度用于悬崖行走 [## 深度策略梯度用于悬崖行走

一个用 Python 实现的 TensorFlow 2.0 解决方案。在这个方案中,演员由一个神经网络表示,该网络…

## 深度策略梯度用于悬崖行走

参考文献

OpenAI (2018). 深度确定性策略梯度。 spinningup.openai.com/en/latest/algorithms/ddpg.html

Keras (2020). 深度确定性策略梯度(DDPG)由 amifunny 制作。 keras.io/examples/rl/ddpg_pendulum/

Lillicrap, T. P., Hunt, J. J., Pritzel, A., Heess, N., Erez, T., Tassa, Y., … & Wierstra, D. (2015). 基于深度强化学习的连续控制。 arXiv 预印本 arXiv:1509.02971

Yang, A. & Philion, J. (2020). 深度强化学习中的连续控制www.pair.toronto.edu/csc2621-w20/assets/slides/lec3_ddpg.pdf

深入了解 ESA 的哨兵 API

原文:towardsdatascience.com/deep-dive-into-esas-sentinel-api-e6ff4f9d0730

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于 10 米分辨率哨兵数据的布达佩斯 RGB 卫星地图片段。

如何使用 Python 获取、分析和可视化卫星图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Milan Janosov

·发表于Towards Data Science ·13 分钟阅读·2023 年 10 月 26 日

本文中的所有图像均由作者创建。

欧洲航天局一直在运行其哨兵任务,支持哥白尼,即欧盟空间计划的地球观测组件,提供各种类型的遥感数据,如雷达和多光谱成像仪器,用于陆地、海洋和大气监测。

目前有六个进行中的哨兵任务,其中三个可以通过它们的Python API轻松访问。这些任务,引用官方来源

哨兵-1是一个极轨、全天候、昼夜雷达成像任务,用于陆地和海洋服务。哨兵-1A 于 2014 年 4 月 3 日发射,哨兵-1B 于 2016 年 4 月 25 日发射。两者均通过苏联运载火箭从欧洲法属圭亚那的航天发射场送入轨道。哨兵-1B 任务于 2022 年结束,并计划尽快发射哨兵-1C。

哨兵-2是一个极轨、多光谱高分辨率成像任务,专用于陆地监测,例如提供植被、土壤和水体覆盖、内陆水道和沿海地区的图像。哨兵-2 还可以提供紧急服务的信息。哨兵-2A 于 2015 年 6 月 23 日发射,哨兵-2B 则于 2017 年 3 月 7 日发射。

Sentinel-3 是一个多仪器任务,用于高端准确性和可靠性地测量海表地形、海洋和陆地表面温度、海洋颜色和陆地颜色。该任务支持海洋预测系统以及环境和气候监测。Sentinel-3A 于 2016 年 2 月 16 日发射,Sentinel-3B 于 2018 年 4 月 25 日与其双胞胎一起进入轨道。

经过一些额外的挖掘,我们可以了解到Sentinel-1的数据在空间分辨率方面可以达到几米。而Sentinel-2的视觉数据最高分辨率为 10 米,Sentinel-3则根据传感器类型在 100 公里的规模上运行。

好的,所以我们知道如何获取卫星数据,看起来还有很多的源(传感器)和空间分辨率可以选择。有人可能会指出,这仅仅是冰山一角,正如这卫星数据源列表所概述的那样。那么,我们将这些不同类型的卫星数据用于什么呢?首先,这里有 50 多个用例的选择

一般来说,我认为用例、问题的具体细节以及目标区域的地理空间特征和地形都是确定适合的数据源的重要因素。然而,在实际操作中,根据我的经验,这些是主要的驱动因素:

  • 价格(最好是免费探索,适用于 Sentinel)

  • 具有几米的空间分辨率,甚至较小的城市结构也可以被捕捉到。

  • 至少具有几个波段,例如可见光和近红外。

  • 时间频率

这些方面使得 Sentinel-2 可能是地理空间数据社区中使用最广泛的卫星数据源。基于这些组件,在这篇文章中,我将向你展示如何获取 Sentinel 数据以及下载时应该期待什么。我还将深入探讨不同的可能性以及图像记录和存储的信息的时间演变。

在这篇文章中,使用了 2023 年的 Copernicus Sentinel 数据,因为欧盟法律允许免费访问 Copernicus Sentinel 数据和服务信息。

1. 数据获取

首先,我将按照官方文档和示例代码设置 API 连接。此外,我还需要一个目标区域来下载图像。为了方便调试,我选择了我的家乡布达佩斯。我将使用 OSMNx 下载其行政边界。

import osmnx as ox # version: 1.0.1
import matplotlib.pyplot as plt # version: 3.7.1

city = 'Budapest'
admin = ox.geocode_to_gdf(city)
admin.plot()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

布达佩斯的行政边界。

现在点击 Sentinel API:

from sentinelsat import SentinelAPI, read_geojson, geojson_to_wkt # version 0.14

# to get an account, sign up here: https://apihub.copernicus.eu/apihub
user = <add your user name 
password = < add your password >
api = SentinelAPI(user, password, 'https://apihub.copernicus.eu/apihub') 

为了进行查询,最好有一个平滑的多边形来指定位置。为此,我创建了布达佩斯行政区的凸包:

# to simplify the query, I extract the convex hull of the input polygon
admin_polygon = admin.convex_hull.geometry.to_list()[0]
admin_polygon

输出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

布达佩斯的凸包。

在我们选择的平台和给定的时间框架内搜索卫星图像。后者应为 Sentinel-A。此外,我们还可以根据云覆盖进行筛选——这意味着如果图像过于多云,我们将立即丢弃它。

# here we can specifcy the location (based on a polygon)
# the time frame
# the space probe
# and the level of cloud-coverage accepted 

products = api.query(admin_polygon,
                     date=('20150623', '20231006'),
                     platformname='Sentinel-2',
                     cloudcoverpercentage=(0, 100))

len(products)

正如这些单元的输出所示,遵循 Sentinel 文档,结果显示在 2015 年 6 月 23 日(任务开始)和 2023 年 10 月 6 日(我撰写本文时),总共记录了 3876 张与布达佩斯行政边界重叠的卫星图像。我将云覆盖百分比设置为 100,这意味着没有基于云覆盖的筛选。因此,我们应该拥有过去八年的所有图像标识符。

我还注意到,结果产品列表包含了所有卫星图像的标识符和元数据,但不包含图像本身。此外,如果我用 Sentinel-3 重复相同的操作,结果将得到近 2 万条图像记录——尽管分辨率要低得多。

2. 探索元数据

让我们将产品列表转换为 Pandas DataFrame 并开始分析吧!

import pandas as pd # version: 1.4.2

products_gdf = api.to_geodataframe(products)
products_gdf = products_gdf.sort_values(['beginposition'], ascending=[True])
print(products_gdf.keys())
print(len(products_gdf.keys()))
products_gdf.head(3)

此块的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

查询结果预览。

通过计算表中由卫星图像标识符索引的键的数量,人们可以感受到这些数据有多丰富,其中包含 41 个特征列。

虽然这些领域中有很多技术信息,但我希望仔细查看几个特征。一方面,空间和时间维度编码在生成日期和开始位置(作为日期时间信息)以及几何形状(作为多边形、GIS、数据类型)中。另一方面,有几个有趣的指标基于图像描述土地覆盖类型:cloudcoverpercentage(我们在查询中已经看到过),vegetationpercentagewaterpercentagesnowicepercentage。这些环境指数是从不同材料的光谱特性中得出的。注意:这些值都是汇总得分,捕捉了整个瓦片的总体平均值。更多信息请见 这里

3. 空间维度

由于我们有几何维度,让我们看看这在地图上的样子!我将通过可视化一组随机图块来做到这一点,这些图块在几次运行后完全具有代表性。为了可视化,我使用了带有 CartoDB Dark_Matter 基础地图的 Folium。

import folium
import geopandas as gpd

x, y = admin_polygon.centroid.xy
m = folium.Map(location=[y[0], x[0]], zoom_start=8, tiles='CartoDB Dark_Matter')

# visualize a set of random tiles
polygon_style = { 'fillColor': '#39FF14', 'color': 'black',  'weight': 3, 'opacity': 0}
geojson_data = products_gdf[['geometry']].sample(10).to_json()
folium.GeoJson(
    geojson_data,
    style_function=lambda feature: polygon_style
).add_to(m)

# add the admin boundaries on top
admin_style = {'fillColor': '#00FFFF',  'color': 'black','weight': 3, 'opacity': 100.0  }
admin_geojson_data = admin[['geometry']].to_json()
folium.GeoJson(
    admin_geojson_data,
    style_function=lambda feature: admin_style
).add_to(m)

# show the map
m

该代码块的输出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与布达佩斯行政区重叠或交叉的卫星图块的随机样本。

从这幅可视图可以看出,几部分图块不断重复。同时也明显有些图块将城市的行政边界分成了两半。这可能导致无法避免的情况,即你想分析完全覆盖你目标区域的数据,却发现它被分成了两半。一种可能的解决方法是过滤掉那些没有完全覆盖所需行政区域的图块:

def compute_overlapping_area(tile, admin):
    return tile.intersection(admin_polygon).area / admin_polygon.area

products_gdf['overlapping_area_fraction'] = products_gdf.geometry.apply(lambda x: compute_overlapping_area(x, admin_polygon))
products_gdf_f = products_gdf[products_gdf.overlapping_area_fraction==1]
print(len(products_gdf))
print(len(products_gdf_f))
products_gdf_f.head(3)

该单元格的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

过滤后的卫星图像产品数据集预览

通过应用此过滤器,我去掉了大约一半的图块。现在让我们看看地图,看看它们与城市边界的重叠情况,以及没有图块将城市分成两半的情况:

import folium
import geopandas as gpd

x, y = admin_polygon.centroid.xy
m = folium.Map(location=[y[0], x[0]], zoom_start=8, tiles='CartoDB Dark_Matter')

# visualize a set of random tiles
polygon_style = { 'fillColor': '#39FF14', 'color': 'black',  'weight': 3, 'opacity': 0}
geojson_data = products_gdf_f[['geometry']].sample(10).to_json()
folium.GeoJson(
    geojson_data,
    style_function=lambda feature: polygon_style
).add_to(m)

# add the admin boundaries on top
admin_style = {'fillColor': '#00FFFF',  'color': 'black','weight': 3, 'opacity': 100.0  }
admin_geojson_data = admin[['geometry']].to_json()
folium.GeoJson(
    admin_geojson_data,
    style_function=lambda feature: admin_style
).add_to(m)

# show the map
m

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

完全覆盖布达佩斯行政区的卫星图块的随机样本。

4. 数据集的时间维度

首先,让我们查看每天、每周和每月覆盖布达佩斯的图像数量。为了测量时间,我将依赖字段beginposition

# Assuming 'beginposition' is a Timestamp column in your GeoDataFrame
# You can convert it to a DateTime index
products_gdf_f_cntr = products_gdf_f.copy()
products_gdf_f_cntr['beginposition'] = pd.to_datetime(products_gdf_f_cntr['beginposition'])
products_gdf_f_cntr.set_index('beginposition', inplace=True)

# Resample the data to count rows per day, week, and month
daily_counts = products_gdf_f_cntr.resample('D').count()
weekly_counts = products_gdf_f_cntr.resample('W').count()
monthly_counts = products_gdf_f_cntr.resample('M').count()

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for idx, (count_name, count_val) in enumerate([('Daily Counts', daily_counts), ('Weekly Counts', weekly_counts), ('Monthly Counts', monthly_counts), ]): 

    ax[idx].plot(count_val.index[0:250], count_val['geometry'].to_list()[0:250])
    ax[idx].set_xlabel('Date')
    ax[idx].set_ylabel('Count')
    ax[idx].set_title(count_name)

plt.tight_layout()
plt.suptitle('Number of satellite images taken in various time-frames', fontsize = 20, y = 1.15)
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每天、每周和每月在布达佩斯目标区域捕获的卫星图像数量。

这些图形展示了 Sentinel-2 探测器的前 250 天、前 250 周和前 250 个月(整个时间跨度)。第一幅图显示了每隔一天拍摄一次快照。第二幅图显示了对前一幅图的每周平均值的计算,显示在前两年中,卫星每周拍摄布达佩斯一次或两次,然后从 2017 年到 2018 年,拍摄次数增加到每周 5-6 张。最后一幅图展示了整个时间跨度,显示了相同的趋势以及在工作了 3 年后,这 25 张每月图像成为了标准水平。

5. 土地覆盖变量的时间维度

现在,来看一下植被百分比水体百分比雪冰百分比云量百分比的时间演变。如前图所示,早期年份可能会显示不同的,可能是噪声结果,所以我们要保持谨慎。在这里,我不会丢弃那些年份的数据,因为我们总共有八年,去掉其中 3 年可能会丢失太多信息。首先,只需查看随时间变化的原始值,并进行每周聚合:

import pandas as pd
import matplotlib.pyplot as plt

# Assuming 'beginposition' is a Timestamp column in your GeoDataFrame
# You can convert it to a DateTime index
products_gdf_f_cntr = products_gdf_f.copy()
products_gdf_f_cntr['beginposition'] = pd.to_datetime(products_gdf_f_cntr['beginposition'])
products_gdf_f_cntr.set_index('beginposition', inplace=True)

# Resample the data to calculate weekly averages
weekly_averages = products_gdf_f_cntr[['vegetationpercentage', 'waterpercentage', 'snowicepercentage', 'cloudcoverpercentage']].resample('W').mean()

# Create a multi-plot figure with four subplots
fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(10, 15))

# Plot 'vegetationpercentage' with a green line
ax1.plot(weekly_averages.index, weekly_averages['vegetationpercentage'], color='green', label='Weekly Average Vegetation Percentage')
ax1.set_xlabel('Date')
ax1.set_ylabel('Percentage')
ax1.set_title('Weekly Average Vegetation Percentage')
ax1.legend()

# Plot 'waterpercentage' with a blue line
ax2.plot(weekly_averages.index, weekly_averages['waterpercentage'], color='blue', label='Weekly Average Water Percentage')
ax2.set_xlabel('Date')
ax2.set_ylabel('Percentage')
ax2.set_title('Weekly Average Water Percentage')
ax2.legend()

# Plot 'snowicepercentage' with a cyan line
ax3.plot(weekly_averages.index, weekly_averages['snowicepercentage'], color='cyan', label='Weekly Average Snow/Ice Percentage')
ax3.set_xlabel('Date')
ax3.set_ylabel('Percentage')
ax3.set_title('Weekly Average Snow/Ice Percentage')
ax3.legend()

# Plot 'cloudcoverpercentage' with a gray line
ax4.plot(weekly_averages.index, weekly_averages['cloudcoverpercentage'], color='gray', label='Weekly Average Cloud Cover Percentage')
ax4.set_xlabel('Date')
ax4.set_ylabel('Percentage')
ax4.set_title('Weekly Average Cloud Cover Percentage')
ax4.legend()

plt.tight_layout()
plt.show() 

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

植被、水、雪和云量百分比的时间演变,以周为单位进行聚合。

以及月度聚合的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

植被、水、雪和云量百分比的时间演变,以月度聚合为单位。

这些时间序列告诉我们一些有趣的事情:

  • 植被百分比清楚地显示了季节性的变化,每年春天一切变绿,然后在秋天这种绿意逐渐消退,从 50-60%降到接近零。

  • 相比之下,水的百分比在全年和整个观察期间波动在 0.8%左右。这是因为研究区域的地表水量非常小。尽管如此,冬季的降水似乎更频繁,这意味着一些淡水体结冰。

  • 关于雪,最突出的峰值——大约 4-8%出现在冬季。尽管如此,基于个人经验,我可以说我们没有很多雪。因此,测量值仅为 1-2%,尤其是在非冬季,可能会导致一些噪声甚至云的错误分类。

  • 关于云层,我们看到它们大多与植被同步,遵循季节性模式。

一些观察结果在这些指标的相关性中也很明显:

products_gdf_f_cntr[['vegetationpercentage', 'waterpercentage', 'snowicepercentage', 'cloudcoverpercentage']].corr()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

环境变量随时间变化的相关性。

6. 下载卫星图像

首先,对 Sentinel-2 和 Sentinel-3 进行查询,选择今年八月的同一周,并尽可能限制云覆盖。查看可用的快照数量:

# query tile product ids
products_sent = api.query(admin_polygon, date=('20230806', '20230813'), platformname='Sentinel-2', cloudcoverpercentage=(0, 1))
products_sent = api.to_geodataframe(products_sent)

f, ax = plt.subplots(1,1,figsize=(6,4))
admin.plot(ax=ax, color = 'none', edgecolor = 'k')
ax.set_title('Sentinel-2, number of tiles = ' + str(len(products_sent)))
products_sent.plot(ax=ax, alpha = 0.3)

# filter out the tiles not fully overlapping with Budapest
products_sent['overlapping_area_fraction'] = products_sent.geometry.apply(lambda x: compute_overlapping_area(x, admin_polygon))
products_sent = products_sent[products_sent.overlapping_area_fraction==1]

f, ax = plt.subplots(1,1,figsize=(6,4))
admin.plot(ax=ax, color = 'none', edgecolor = 'k')
ax.set_title('Sentinel-2, number of tiles = ' + str(len(products_sent)))
products_sent.plot(ax=ax, alpha = 0.3)

len(products_sent)

这个代码块的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

查询的瓷砖。

现在根据

# download the first tiles as sat images
product_ids = products_sent.index.to_list()

for prod in product_ids:
    api.download(prod)

注意 — 在这里你可能会收到这个通知,这种情况下只需等待几个小时,然后再次运行下载器。

Product a3c61497-d77d-48da-9a4d-394986d2fe1d is not online. Triggering retrieval from long term archive.

7. 打开并可视化下载的图像

这里你可以找到关于数据格式的详细描述,以及关于文件夹结构的漂亮可视化图像。打开图像目录后,可以找到不同的波段。每个波段的含义以及其空间分辨率在这篇文章中有很好的总结,13 个波段的空间分辨率范围从 10 到 60 米。几个亮点:

  • 蓝色(B2)、绿色(B3)、红色(B4)和近红外(B8)频道具有 10 米分辨率。

  • 然后,其植被红边(B5)、近红外(B6、B7 和 B8A)以及短波红外(B11 和 B12)具有 10 米分辨率。

  • 最终,其海岸气溶胶(B1)和短波红外气溶胶(B10)的像素大小为 60 米。

这就是它

# after unzipping the downloaded folder:
import os
image_path = 'S2B_MSIL1C_20230810T094549_N0509_R079_T34TCT_20230810T124346.SAFE/GRANULE/L1C_T34TCT_A033567_20230810T095651/IMG_DATA'
sorted(os.listdir(image_path))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

存储在 .jp2 格式中的卫星图像波段列表。

这是一整张瓦片在使用rasterio可视化 B4 红色波段时的样子:

import rasterio
from rasterio.plot import show

image_file = 'T34TCT_20230810T094549_B04.jp2'

with rasterio.open(image_path + '/' + image_file) as src:

    image = src.read(1)  # Change the band index as needed
    plt.figure(figsize=(10, 10))
    plt.imshow(image, cmap='Reds')  # You can change the colormap
    plt.title(image_file)
    plt.colorbar()
    plt.show()

输出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

该瓦片的红色波段包含

现在集中在布达佩斯,并按城市的行政边界分别对 R、G 和 B 波段进行掩膜处理:

from rasterio import mask

f, ax = plt.subplots(1,3,figsize=(15,5))

for idx, (band_name, band_num, color_map) in enumerate([('Blue', 'B02', 'Blues'), ('Green', 'B03', 'Greens'), ('Red', 'B04', 'Reds')]):

    raster_path = image_path + '/T34TCT_20230810T094549_' + band_num + '.jp2'

    with rasterio.open(raster_path) as src:
        polygons = admin.copy().to_crs(src.crs)
        geom = polygons.geometry.iloc[0]
        masked_image, _ = mask.mask(src, [geom], crop=True)

    ax[idx].imshow(masked_image[0], cmap=color_map)
    ax[idx].set_title('Budapest Sentinel 2 - ' + band_name + ' band')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

布达佩斯的三种可视化卫星波段。

最后,我们将这些图像拼接成一张布达佩斯的 RGB 图像。首先,我将完整的瓦片拼接成 RGB 图像,然后读取它,按照官方说明进行直方图均衡化,最后得到最终的图像。

# Get the band locations
band_blue = '/T34TCT_20230810T094549_B02.jp2'
band_green = '/T34TCT_20230810T094549_B03.jp2'
band_red = '/T34TCT_20230810T094549_B04.jp2'

# Read in the bands and create the full RGB tile
b2   = rasterio.open(image_path + '/' + band_blue)
b3   = rasterio.open(image_path + '/' + band_green)
b4   = rasterio.open(image_path + '/' + band_red)

# export the full tile as a tif file
meta = b4.meta
meta.update({"count": 3})
prefire_rgb_path = 'budapest_rgb.tif'
with rasterio.open(prefire_rgb_path, 'w', **meta) as dest:
    dest.write(b2.read(1),1)
    dest.write(b3.read(1),2)
    dest.write(b4.read(1),3)

# crop and save it to the admin boundaries of budapest
with rasterio.open('budapest_rgb.tif') as src:
    polygons = admin.copy().to_crs(src.crs)
    geom = polygons.geometry.iloc[0]
    out_image, out_transform  = mask.mask(src, [geom], crop=True)
    out_meta = src.meta.copy()
    out_meta.update({"driver": "GTiff",
                     "height": out_image.shape[1],
                     "width" : out_image.shape[2],
                     "transform": out_transform})

with rasterio.open('budapest_rgb_cropped.tif', "w", **out_meta) as dest:
    dest.write(out_image)

# read and show the cropped version
import numpy as np
from skimage import exposure

img = rasterio.open('budapest_rgb_cropped.tif')
image = np.array([img.read(3), img.read(2), img.read(1)])
image = image.transpose(1,2,0)

# do the histogram equalization
p2, p98 = np.percentile(image, (2,98))
image = exposure.rescale_intensity(image, in_range=(p2, p98)) / 100000

f, ax = plt.subplots(1,1,figsize=(15,15))
rasterio.plot.show(image.transpose(2,0,1), transform=img.transform, ax = ax)
ax.axis('off')
plt.savefig('budapest_rgb_cropped_2.png', dpi = 100, bbox_inches = 'tight')

输出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于 10 米分辨率 Sentinel 数据的布达佩斯 RGB 卫星图像。

结论

快速总结一下,看看这篇文章中发生了什么:

  • Sentinel 卫星平台的快速概述

  • 查询多个图像标识符及其元数据的示例

  • 如何仅基于瓦片的汇总信息在元数据中进行时间分析

  • 如何下载、存储和可视化单张图像

所有这些步骤的目的是将卫星图像处理和分析添加到你每天使用的地理空间数据科学工具中,这可以涵盖从城市规划到环境监测和农业等众多应用场景。

深入探讨 Apache Spark 数据倾斜的处理方法

原文:towardsdatascience.com/deep-dive-into-handling-apache-spark-data-skew-57ce0d94ee38

分布式计算中处理数据倾斜的终极指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Chengzhi Zhao

·发表于 Towards Data Science ·阅读时间 10 分钟·2023 年 1 月 3 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Lizzi Sassman 提供,来源于 Unsplash

为什么我的 Spark 作业运行很慢? 是使用 Apache Spark 时一个不可避免的问题。关于 Apache Spark 性能调优的常见场景之一是 数据倾斜。在本文中,我们将讨论如何识别 Spark 作业的慢速是否由数据倾斜引起,并深入探讨如何通过代码处理 Apache Spark 数据倾斜,解释包括“加盐”技术在内的三种处理数据倾斜的方法。

如何识别 Spark 中的数据倾斜

在 Spark 性能调优方面,有许多因素需要考虑。鉴于分布式计算的复杂性,如果你能将问题缩小到瓶颈位置,那么你已经成功了一半。

数据倾斜通常发生在分区需要处理的数据不均匀时。假设我们在 Spark 中有三个分区来处理 150 万条记录。理想情况下,每个分区均匀地处理 50 万条记录(图片 1 左侧)。然而,也可能出现某个分区处理的数据远多于其他分区的情况(图片 1 右侧)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片 1 | 作者提供的图片

为什么一个分区会处理比其他分区更多的数据? 这与分布系统的工作方式有关。在许多数据处理框架中,数据倾斜是由于数据洗牌引起的,即将数据从一个分区移动到另一个分区。数据洗牌需要在性能调优时加以关注,因为它涉及在集群中的节点之间转移数据。这可能会导致数据管道中出现不必要的延迟,并且难以发现。

数据洗牌很昂贵,但有时执行 宽操作(如 groupBy 和 joins)是不可避免的。这些操作通常是基于键的,即键被哈希后映射到分区。相同的哈希值会被保证洗牌到相同的分区。在上述示例中,许多键被哈希到体量巨大的分区 A 中,分区 A 成为处理近 99% 数据的“热点”。这就是为什么整个作业运行缓慢的原因——数据分布不均,分区 B 和 C 大部分时间闲置,而分区 A 成为处理重负荷的“试验品”。

如何识别 Spark 中的数据倾斜? 我们不能将所有的慢速归咎于数据倾斜。Spark Web UI 是识别 Spark 作业中数据倾斜的最佳本地解决方案。当你在 Spark UI 的 Stages 标签页时,倾斜的分区会在一个阶段内停滞,几乎没有进展。如果我们查看摘要指标,最大列的值通常远大于中位数,并且记录数更多。那么我们就知道我们遇到了数据倾斜问题。

如何知道代码中的哪个部分导致了数据倾斜? Spark UI 中的阶段详情页面只给我们 DAG 的可视化表示。

你怎么知道 Spark 中的哪个部分代码运行缓慢?这在 Spark 官方文档中提到:“整个阶段代码生成操作也会标注 代码生成 ID*。对于 Spark DataFrame 或 SQL 执行的阶段,这允许将阶段执行详情与 Web-UI SQL 标签页中报告的 SQL 计划图和执行计划相关联。*

在以下情况下,我们可以使用 WholeStageCodegen IDs:2、4 或 5。我们可以前往 Spark Data Frame 标签页找到代码,并在 SQL 计划图上悬停以了解代码中正在运行的详细信息。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

代码生成 ID 示例 | 图片由作者提供

数据倾斜示例设置

我们首先通过设置具有数据倾斜的 Spark 环境来演示问题。我们将只为spark.executor.memory设置 1G,并设置一个具有三个核心的执行器,spark.sql.shuffle.partitions也设置为三,因此最终我们将得到三个分区。我们可以使用spark_partition_id来确定记录属于哪个分区,以验证数据分布。为了确保 Spark 不会自作聪明地进行更多优化,比如增加分区数量或将物理计划转换为广播连接,我们将通过将spark.sql.adaptive.enabled设置为 false 来关闭自适应查询执行(AQE)。

我们不需要导入额外的数据源来设置示例。我们可以创建随机数据并在本文中作为示例进行操作。

案例 1:均匀分布情况

我们将在 Spark 中创建一个包含 1,000,000 行的数据框。在这种情况下,从 0 到 999,999 的值是被哈希和洗牌的键。请注意,这里的键是唯一的,这意味着没有任何重复。这些确保了键是非确定性的。不能保证两个不同的键总是位于同一个分区。

df_evenly = spark.createDataFrame([i for i in range(1000000)], IntegerType())
df_evenly = df_evenly.withColumn("partitionId", spark_partition_id())

你可以通过使用getNumPartitions来验证分区数量,在这种情况下,它应该是三,因为我们只有一个执行器和三个核心。

df_evenly.rdd.getNumPartitions()
//output 3

如果一切均匀分布,我们将得到一个良好分布的计数,如果按 partitionId 分组。这是我们上面提到的完美情况图片 1 左

df_evenly.groupby([df_evenly.partitionId]).count().sort(df_evenly.partitionId).show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按 PartitionId 平均分区的数据 | 图像由作者提供

然后我们可以执行自连接来查看计划是什么样的,我们期望会看到SortMergJoin,这是当两个数据集同等重要时通常能做到的最优计划。

df_evenly.alias(“left”).join(df_evenly.alias(“right”),”value”, “inner”).count()

在以下结果中,我们可以看到数据总大小在三个分区之间分布良好,如果查看每个分区所需的时间,它们似乎没有显著的差距。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按照 Spark 物理计划平均分区的数据 | 图像由作者提供

案例 2:倾斜情况

现在,让我们走极端来展示图片 1 右所示的情况,其中我们有一个极度倾斜的数据集。

我们仍将在 Spark 中创建一个包含 1000000 行的数据框。然而,我们不会让所有键都有不同的值,而是将大多数键设为相同。这确保了我们创建一个**“热”键,无论我们尝试多少个哈希函数,都可能成为问题。它保证会在同一个分区中。**

df0 = spark.createDataFrame([0] * 999998, IntegerType()).repartition(1)
df1 = spark.createDataFrame([1], IntegerType()).repartition(1)
df2 = spark.createDataFrame([2], IntegerType()).repartition(1)
df_skew = df0.union(df1).union(df2)
df_skew = df_skew.withColumn("partitionId", spark_partition_id())
## If we apply the same function call again, we get what we want to see for the one partition with much more data than the other two.
df_skew.groupby([df_skew.partitionId]).count().sort(df_skew.partitionId).show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据倾斜按 PartitionId| 作者提供的图像

在这种情况下,99.99%的数据在一个分区中。让我们用均匀分布的数据集进行连接,以检查计划是什么样的。在运行连接之前,让我们将倾斜数据集以轮询方式重新分区为三个分区,以模拟实际使用情况中的数据读取方式。

//simulate reading to first round robin distribute the key
df_skew = df_skew.repartition(3)

df_skew.join(df_evenly.select(“value”),”value”, “inner”).count() 

检查 Spark 物理计划,我们可以看到在一个分区中分布不均的大量数据(最大时间),而连接时间是指数级的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据分区倾斜| 作者提供的图像

如何解决数据倾斜问题

数据倾斜会导致 Spark 性能缓慢,作业会被卡在几个分区中而永远挂起。有多种策略可以解决倾斜问题。如今,借助 Spark 的自适应查询执行(AQE),Spark 更容易找到优化的方式。在极端情况下,AQE 并不能 100%提供最佳优化。此时,我们仍需介入,并熟悉需要使用的方法。

1. 利用分区数量

spark.sql.shuffle.partitions可能是 Spark 中最关键的配置之一。它配置了在洗牌数据用于连接或聚合时使用的分区数量。 配置这个 值并不总是意味着可以解决倾斜问题,但它可能是对 Spark 作业的一般优化。默认值为 200,这对于许多大数据项目在过去是合适的,现在仍然适用于小型/中型数据项目。

将其视为在数据在洗牌阶段被处理时的箱子数量。是否有过多的数据需要单个箱子处理,或者它们是否几乎已满?

2. 广播连接

广播连接可能是避免倾斜的最快连接类型。通过提供BROADCAST提示,我们明确向 Spark 提供了需要将哪个数据框发送到每个执行器的信息。

广播连接通常适用于较小的数据框,例如维度表或具有元数据的数据。它不适合具有百万行的事务表。

df_skew.join(**broadcast**(df_evenly.select(“value”)),”value”, “inner”).count()

3. Salting

来自密码学的 SALT 理念引入了随机性到密钥中,而无需了解数据集的上下文。这个理念是,对于给定的热点键,如果它与不同的随机数结合,我们将不会在单个分区中处理所有的该键的数据。SALT 的一个重要好处是它与任何键无关,你不必担心某些具有相似上下文的键再次出现相同的值。

我已经在Skewed Data in Spark? Add SALT to Compensate上发布了另一篇文章。你可以阅读更多内容了解详细信息。

## Spark 中的数据倾斜?添加 SALT 进行补偿

逐步指南:使用 SALT 技术处理数据倾斜

[towardsdatascience.com

然而,在上述文章中,我仅提供了聚合操作中的加盐代码。仍然没有提及如何在连接操作中执行加盐,这留下了一些问题:“我明白我们可以对键进行加盐以均匀分布数据,但这改变了我的连接键。加盐后如何将数据连接回原始键?” 我将在这篇文章中提供一些代码示例。

利用键加盐的核心思想是考虑空间与时间的权衡。

  • 将盐键作为新列的一部分添加到键中。我们还称原始键和盐键为复合键。新增的键迫使 Spark 对新键进行哈希处理,从而生成不同的哈希值,使数据被打乱到不同的分区。请注意,我们也可以通过从 spark.sql.shuffle.partitions 中获取值来动态获取盐键的随机性数量。
df_left = df_skew.withColumn(“salt”, (rand() * spark.conf.get(“spark.sql.shuffle.partitions”)).cast(int))

如下所示,尽管值和 partitionId 相同,我们还是创建了一个额外的“salt”列,以提供更多指导给 Spark 进行连接。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将盐键作为新列的一部分添加到键中 | 图片来源:作者

  • 将所有潜在盐键的数组作为新列添加。你可以选择一个行数较少的数据框(如果行数相同,随机选择一个),并且
df_right = df_evenly.withColumn(“salt_temp”, array([lit(i) for i in range(int(spark.conf.get(“spark.sql.shuffle.partitions”)))]))

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将所有潜在盐键的数组作为新列添加 | 图片来源:作者

  • 使用该数组探索数据框。这将现有行复制 n 次(n=你选择的盐的数量)。当两个数据框连接时,由于我们已经在一侧(通常是右侧)有了复制的数据框,连接依然会被验证。这会产生与使用原始键相同的结果。

我们还可以在连接后验证最终分布。相同键“0”的连接数据框在三个分区中均匀分布。这种均匀分布展示了键加盐的技术。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

连接后的最终 PartitionId | 图片来源:作者

从物理计划来看,数据分布均匀,并且在百分位指标上的处理时间类似。如果选择一个更大的数据集,我们可以明显看出差异。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

键加盐以提高 Spark 性能物理计划 | 图片来源:作者

最终思考

在 Apache Spark 中,数据倾斜可以通过多种方式处理。它可以通过 Spark 配置、Spark 计划优化,或者通过“盐”键引导 Spark 平均分配数据来解决。识别 Spark 作业变慢的原因是任何 Spark 调优的基础。在这些原因中,数据倾斜是常见的罪魁祸首之一。

我写这篇文章是为了帮助大家更好地理解 Spark 中的数据倾斜及其潜在解决方案。然而,当涉及到 Spark 性能优化时,并没有万灵药。你需要投入更多精力查看查询计划,并弄清楚代码中发生了什么。更多知识是通过反复试验获得的。

我希望这篇文章对你有所帮助。这篇文章是我工程与数据科学故事系列的一部分,目前包括以下内容:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

程志赵

数据工程与数据科学故事

查看列表53 篇故事!外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

你也可以订阅我的新文章或成为推荐的 Medium 会员,享受 Medium 上所有故事的无限访问权限。

如有疑问/评论,请随时在本文评论区留言或**直接通过Linkedin**或Twitter联系我。

深入探讨 pandas Copy-on-Write 模式:第一部分

原文:towardsdatascience.com/deep-dive-into-pandas-copy-on-write-mode-part-i-26982e7408c6?source=collection_archive---------5-----------------------#2023-08-09

解释 Copy-on-Write 内部是如何工作的

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Patrick Hoefler

·

关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 8 月 9 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Clint AdairUnsplash 提供

介绍

pandas 2.0于四月初发布,并带来了许多对新 Copy-on-Write (CoW) 模式的改进。该功能预计将成为 pandas 3.0 的默认模式,计划于 2024 年四月发布。目前没有遗留或非 CoW 模式的计划。

这一系列文章将解释 Copy-on-Write 如何在内部工作,以帮助用户了解发生了什么,展示如何有效使用它,并说明如何调整你的代码。这将包括如何利用机制以获得最有效的性能的示例,同时展示一些会导致不必要瓶颈的反模式。我几个月前写了一篇简短的介绍 介绍 Copy-on-Write。

我写了一篇简短的文章,解释了 pandas 的数据结构,这将帮助你理解 CoW 所需的一些术语。

我是 pandas 核心团队的一员,至今参与了 CoW 的实施和改进。我是Coiled的开源工程师,负责 Dask,包括改进 pandas 集成并确保 Dask 符合 CoW 标准。

Copy-on-Write 如何改变 pandas 行为

你们中的许多人可能已经熟悉 pandas 中的以下注意事项:

import pandas as pd

df = pd.DataFrame({"student_id": [1, 2, 3], "grade": ["A", "C", "D"]})

让我们选择 grade 列并用 "E" 覆盖第一行。

grades = df["grade"]
grades.iloc[0] = "E"
df

   student_id grade
0           1     E
1           2     C
2           3     D

不幸的是,这也更新了df而不仅仅是grades,这可能会引入难以发现的错误。CoW 将不允许这种行为,并确保仅更新grades。我们还看到一个无用的SettingWithCopyWarning,对我们没有帮助。

让我们看一个ChainedIndexing的示例,这个示例没有做任何事情:

df[df["student_id"] > 2]["grades"] = "F"
df

   student_id grade
0           1     A
1           2     C
2           3     D

我们再次得到SettingWithCopyWarning,但在这个示例中df 没有发生任何变化。所有这些问题都归结于 NumPy 中的复制和视图规则,这是 pandas 在底层使用的。pandas 用户必须了解这些规则以及它们如何应用于 pandas DataFrame,以理解类似的代码模式为何会产生不同的结果。

CoW 清理了所有这些不一致性。启用 CoW 时,用户只能一次更新一个对象,例如,在第一个示例中,df 将保持不变,因为那时只更新了grades,而第二个示例会引发ChainedAssignmentError,而不是什么都不做。一般来说,不可能一次更新两个对象,例如,每个对象的行为都像是前一个对象的副本。

还有更多这样的情况,但在这里讨论所有这些超出了范围。

如何工作

让我们更详细地探讨 Copy-on-Write,并突出一些值得了解的事实。这是本文的主要部分,相当技术性。

Copy-on-Write 承诺任何从其他 DataFrame 或 Series 派生的 对象始终表现为副本。这意味着不可能通过单个操作修改多个对象,例如我们上面的第一个示例仅会修改grades

为了保证这一点,采取非常防御性的方式是每次操作时复制 DataFrame 及其数据,这样可以完全避免 pandas 中的视图。这将保证 CoW 语义,但也会带来巨大的性能损失,因此这不是一个可行的选项。

我们现在将深入了解确保不会有两个对象通过单次操作更新 我们的数据不会不必要地被复制的机制。第二部分是使实现变得有趣的部分。

我们必须准确知道何时触发复制,以避免不必要的复制。只有在尝试改变一个 pandas 对象的值而不复制其数据时,潜在的复制才是必要的。如果这个对象的数据与另一个 pandas 对象共享,我们必须触发一个复制。这意味着我们必须跟踪是否一个 NumPy 数组被两个 DataFrame 引用(通常,我们需要知道一个 NumPy 数组是否被两个 pandas 对象引用,但为了简单起见,我将使用 DataFrame 这个术语)。

df = pd.DataFrame({"student_id": [1, 2, 3], "grade": [1, 2, 3]})
df2 = df[:]

这个语句创建了一个 DataFrame df 和这个 DataFrame df2 的视图。视图意味着这两个 DataFrame 是由相同的底层 NumPy 数组支撑的。当我们用 CoW 看待这个问题时,df 必须知道 df2 也引用了它的 NumPy 数组。但这还不够。df2 也必须知道 df 引用了它的 NumPy 数组。如果两个对象都知道另一个 DataFrame 引用相同的 NumPy 数组,我们可以在其中一个被修改时触发一个复制,例如:

df.iloc[0, 0] = 100

df 在这里被就地修改。df 知道有另一个对象引用相同的数据,例如,它触发了一个复制。它不知道哪个对象引用相同的数据,只是知道外面有另一个对象。

让我们看看如何实现这一点。我们创建了一个内部类 BlockValuesRefs,用于存储这些信息,它指向所有引用给定 NumPy 数组的 DataFrames。

创建 DataFrame 可以通过三种不同类型的操作:

  • DataFrame 是从外部数据创建的,例如通过 pd.DataFrame(...) 或通过任何 I/O 方法。

  • 通过一个 pandas 操作创建一个新的 DataFrame,这个操作会触发对原始数据的复制,例如 dropna 在几乎所有情况下都会创建一个复制。

  • 通过一个 pandas 操作创建一个新的 DataFrame,这个操作 不会 触发对原始数据的复制,例如 df2 = df.reset_index()

前两个案例很简单。当创建 DataFrame 时,支撑它的 NumPy 数组会连接到一个新的 BlockValuesRefs 对象。这些数组仅被新对象引用,因此我们不必跟踪任何其他对象。该对象创建一个 weakref,指向包裹 NumPy 数组的 Block 并在内部存储这个引用。Blocks 的概念在这里进行了解释。

weakref创建对任何 Python 对象的引用。它不会在对象通常超出作用域时保持该对象存活。

import weakref

class Dummy:
    def __init__(self, a):
        self.a = a

In[1]: obj = Dummy(1)
In[2]: ref = weakref.ref(obj)
In[3]: ref()
Out[3]: <__main__.Dummy object at 0x108187d60>
In[4]: obj = Dummy(2)

这个示例创建了一个 Dummy 对象及其弱引用。随后,我们将另一个对象赋给相同的变量,例如初始对象超出作用域并被垃圾回收。弱引用不会干扰这一过程。如果你解析弱引用,它将指向None而不是原始对象。

In[5]: ref()
Out[5]: None

这确保了我们不会保留任何本应被垃圾回收的数组。

让我们来看看这些对象是如何组织的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我们的示例有两列"a""b",它们的 dtype 都是"int64"。它们由一个 Block 支持,该 Block 保存这两列的数据。Block 持有对引用跟踪对象的硬引用,确保只要 Block 没有被垃圾回收,它就会保持活跃。引用跟踪对象持有对 Block 的弱引用。这使得该对象能够跟踪此 Block 的生命周期,但不会阻止垃圾回收。引用跟踪对象尚未持有对任何其他 Block 的弱引用。

这些是简单的场景。我们知道没有其他 pandas 对象共享相同的 NumPy 数组,因此我们可以简单地实例化一个新的引用跟踪对象。

第三种情况更复杂。新对象查看的数据与原始对象相同。这意味着两个对象指向相同的内存。我们的操作将创建一个新的 Block,该 Block 引用相同的 NumPy 数组,这称为浅拷贝。我们现在必须在我们的引用跟踪机制中注册这个新的Block。我们将使用与旧对象连接的引用跟踪对象来注册我们的新Block

df2 = df.reset_index(drop=True)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我们的BlockValuesRefs现在指向支持初始df的 Block 和支持df2的新增 Block。这确保了我们始终了解所有指向相同内存的 DataFrame。

我们现在可以询问引用跟踪对象有多少个指向相同 NumPy 数组的 Block 仍然存在。引用跟踪对象评估弱引用,并告诉我们有多个对象引用相同的数据。这使我们能够在其中一个对象在原地修改时内部触发复制。

df2.iloc[0, 0] = 100

df2中的 Block 通过深拷贝进行复制,创建了一个新的 Block,该 Block 拥有自己的数据和引用跟踪对象。原始的 Block 现在可以被垃圾回收,这确保了dfdf2所支持的数组不会共享任何内存。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

让我们看一个不同的场景。

df = None
df2.iloc[0, 0] = 100

在我们修改df2之前,df已被失效。因此,我们引用跟踪对象的弱引用,指向支持df的 Block,评估结果为None。这使我们能够在不触发复制的情况下修改df2

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

我们的引用跟踪对象仅指向一个 DataFrame,这使我们能够在不触发复制的情况下进行就地操作。

上述reset_index创建了一个视图。如果我们有一个内部触发复制的操作,机制会简单一些。

df2 = df.copy()

这立即为我们的 DataFrame df2 实例化了一个新的引用跟踪对象。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

结论

我们已经研究了 Copy-on-Write 跟踪机制是如何工作的以及何时触发复制。该机制尽可能推迟 pandas 中的复制,这与非 CoW 行为有很大不同。引用跟踪机制跟踪所有共享内存的 DataFrame,从而在 pandas 中实现更一致的行为。

本系列的下一部分将解释用于提高此机制效率的技术。

感谢阅读。如有意见和反馈,请随时联系以分享您对 Copy-on-Write 的看法。

深入探讨 pandas Copy-on-Write 模式—第 II 部分

原文:towardsdatascience.com/deep-dive-into-pandas-copy-on-write-mode-part-ii-b023432a5334?source=collection_archive---------6-----------------------#2023-08-17

解释 Copy-on-Write 如何优化性能

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Patrick Hoefler

·

跟进 发表在 Towards Data Science ·6 分钟阅读·2023 年 8 月 17 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Joshua Brown 拍摄于 Unsplash

介绍

第一篇文章 first post 解释了 Copy-on-Write 机制的工作原理。它重点介绍了在工作流程中引入副本的一些领域。本文将专注于确保这些优化不会减慢平均工作流程的优化。

我们利用了 pandas 内部使用的一种技术,以避免在不必要时复制整个 DataFrame,从而提高性能。

我是 pandas 核心团队的一员,并且在实现和改进 CoW 方面参与了很多。我是 Coiled 的开源工程师,主要负责 Dask 的相关工作,包括改进 pandas 集成,并确保 Dask 符合 CoW 标准。

防御性复制的移除

从最具影响力的改进开始。许多 pandas 方法进行防御性复制以避免副作用,从而保护后续的就地修改。

df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df2 = df.reset_index()
df2.iloc[0, 0] = 100

reset_index 中不需要复制数据,但返回视图在修改结果时会引入副作用,例如 df 也会被更新。因此,在 reset_index 中执行了防御性复制。

启用 Copy-on-Write 后,这些防御性复制将不再存在。这影响了许多方法。完整列表可以在这里找到。

此外,选择 DataFrame 的列子集现在总是返回视图,而不是像之前那样返回复制。

让我们看看当我们结合这些方法时,性能方面会有什么变化:

import pandas as pd
import numpy as np

N = 2_000_000
int_df = pd.DataFrame(
    np.random.randint(1, 100, (N, 10)), 
    columns=[f"col_{i}" for i in range(10)],
)
float_df = pd.DataFrame(
    np.random.random((N, 10)), 
    columns=[f"col_{i}" for i in range(10, 20)],
)
str_df = pd.DataFrame(
    "a", 
    index=range(N), 
    columns=[f"col_{i}" for i in range(20, 30)],
)

df = pd.concat([int_df, float_df, str_df], axis=1)

这会创建一个具有 30 列、3 种不同数据类型和 200 万行的 DataFrame。让我们在这个 DataFrame 上执行以下方法链:

%%timeit
(
    df.rename(columns={"col_1": "new_index"})
    .assign(sum_val=df["col_1"] + df["col_2"])
    .drop(columns=["col_10", "col_20"])
    .astype({"col_5": "int32"})
    .reset_index()
    .set_index("new_index")
)

启用 CoW 前,所有这些方法都会执行防御性复制。

没有 CoW 的性能:

2.45 s ± 293 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

启用 CoW 的性能:

13.7 ms ± 286 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

大约提升了 200 倍。我特意选择了这个例子来说明 CoW 的潜在好处,并不是每个方法的速度都会提升这么多。

优化因就地修改触发的复制

上一部分展示了许多方法在不再需要防御性复制的情况下。CoW 保证你不能同时修改两个对象。这意味着当两个 DataFrame 参考相同数据时,我们必须引入复制。让我们来看看如何使这些复制尽可能高效。

上一篇文章展示了以下情况可能会触发复制:

df.iloc[0, 0] = 100

如果 df 的数据被另一个 DataFrame 参考,则会触发复制。我们假设我们的 DataFrame 有 n 个整数列,例如由一个 Block 支持。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我们的参考跟踪对象也引用了另一个 Block,因此我们不能在不修改其他对象的情况下就地修改 DataFrame。一个简单的方法是复制整个块然后完成。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

这将设置一个新的引用跟踪对象,并创建一个由新的 NumPy 数组支持的新块。这个块没有任何其他引用,因此另一个操作将能够再次原地修改它。这种方法复制了n-1列,而我们不一定需要复制这些列。我们利用一种称为块拆分的技术来避免这种情况。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

内部只复制了第一列。所有其他列都作为对先前数组的视图。新块与其他列没有共享引用。旧块仍与其他对象共享引用,因为它只是对先前值的视图。

这种技术有一个缺点。初始数组有n列。我们创建了从列2n的视图,但这会保持整个数组的存在。我们还添加了一个只有一列的新数组用于第一列。这将比必要时多占用一点内存。

这个系统直接转换为具有不同数据类型的 DataFrames。所有未被修改的块会原样返回,只有被原地修改的块才会被拆分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

我们现在在列n+1的浮点块中设置一个新值,以创建对列n+2m的视图。新块将只支持列n+1

df.iloc[0, n+1] = 100.5

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

可以原地操作的方法

我们查看的索引操作通常不会创建新对象;它们会原地修改现有对象,包括该对象的数据。另一组 pandas 方法则完全不涉及 DataFrame 的数据。一个显著的例子是rename。Rename 只会更改标签。这些方法可以利用上述提到的惰性复制机制。

还有第三组方法实际上可以原地操作,如replacefillna。这些方法将始终触发复制。

df2 = df.replace(...)

修改数据时如果不触发复制,则会修改dfdf2,这违反了 CoW 规则。这是我们考虑保留这些方法的inplace关键字的原因之一。

df.replace(..., inplace=True)

这将解决这个问题。这仍然是一个开放提案,可能会朝不同的方向发展。也就是说,这仅涉及实际被更改的列;所有其他列仍然以视图形式返回。这意味着,如果你的值只出现在一列中,则只会复制一列。

结论

我们研究了 CoW 如何改变 pandas 的内部行为,以及这将如何转化为代码的改进。许多方法在使用 CoW 时会变得更快,而我们会看到一些与索引相关的操作变慢。以前,这些操作总是原地进行的,这可能产生副作用。这些副作用在 CoW 中消失了,对一个 DataFrame 对象的修改将永远不会影响另一个对象。

本系列的下一篇文章将解释如何更新你的代码以符合 CoW 标准。此外,我们还将说明未来应该避免哪些模式。

感谢阅读。如有任何关于写时复制(Copy-on-Write)的想法和反馈,请随时联系我们。

深入探讨 Pandas 的 Copy-on-Write 模式 — 第三部分

原文:towardsdatascience.com/deep-dive-into-pandas-copy-on-write-mode-part-iii-c024eaa16ed4?source=collection_archive---------10-----------------------#2023-09-29

解释 Copy-on-Write 的迁移路径

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Patrick Hoefler

·

关注 发布于 Towards Data Science ·4 分钟阅读·2023 年 9 月 29 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Zoe Nicolaou 提供,来源于 Unsplash

介绍

引入写时复制(CoW)是一个重大变更,将对现有的 pandas 代码产生一定影响。我们将研究如何调整我们的代码以避免在 CoW 默认启用时出现错误。这目前计划在 2024 年 4 月发布的 pandas 3.0 版本中实现。本系列的第一篇帖子解释了写时复制的行为,而第二篇帖子则深入探讨了与写时复制相关的性能优化。

我们计划添加一个警告模式,该模式将对所有使用 CoW(写时复制)可能改变行为的操作发出警告。由于警告可能会对用户产生很大的干扰,因此必须谨慎处理。本文解释了常见情况以及如何调整代码以避免行为变化。

链式赋值

链式赋值是一种通过两个连续操作更新一个对象的技术。

import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3]})

df["x"][df["x"] > 1] = 100

第一个操作选择了列"x",而第二个操作限制了行数。这些操作有许多不同的组合(例如,与lociloc结合使用)。在 CoW 下,这些组合都不会起作用。相反,它们会引发ChainedAssignmentError警告,以便移除这些模式,而不是默默无闻地什么也不做。

通常,你可以使用loc来代替:

df.loc[df["x"] > 1, "x"] = 100

loc的第一个维度总是对应于row-indexer。这意味着你可以选择一个行的子集。第二个维度对应于column-indexer,这使你能够选择一个列的子集。

当你想要在行的子集上设置值时,使用loc通常会更快,因此这将清理你的代码并提供性能提升。

这是 CoW 将产生影响的明显情况。它也会影响链式的就地操作:

df["x"].replace(1, 100)

模式与上述相同。列选择是第一个操作。replace方法尝试在临时对象上操作,这将无法更新初始对象。你也可以通过指定要操作的列来轻松移除这些模式。

df = df.replace({"x": 1}, {"x": 100})

避免的模式

我之前的帖子解释了 CoW 机制如何工作以及 DataFrames 如何共享底层数据。如果两个对象共享相同的数据,而你在就地修改一个对象时,将会执行防御性复制。

df2 = df.reset_index()
df2.iloc[0, 0] = 100

reset_index 操作将创建底层数据的视图。结果分配给一个新变量 df2,这意味着两个对象共享相同的数据。这在 df 被垃圾回收之前一直有效。因此,setitem 操作会触发复制。如果你不再需要初始对象 df,这完全没有必要。只需重新分配到相同的变量将使对象所持有的引用失效。

df = df.reset_index()
df.iloc[0, 0] = 100

总结来说,在同一方法中创建多个引用会保持不必要的引用活跃。

链接不同方法时创建的临时引用是可以的。

df = df.reset_index().drop(...)

这只会保持一个引用活跃。

访问底层 NumPy 数组

pandas 目前通过 to_numpy.values 让我们访问底层的 NumPy 数组。如果你的 DataFrame 由不同的数据类型组成,例如:

df = pd.DataFrame({"a": [1, 2], "b": [1.5, 2.5]})
df.to_numpy()

[[1\.  1.5]
 [2\.  2.5]]

DataFrame 由两个数组支持,这两个数组必须合并为一个。这会触发复制。

另一种情况是 DataFrame 仅由一个 NumPy 数组支持,例如:

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
df.to_numpy()

[[1 3]
 [2 4]]

我们可以直接访问数组并获取视图,而不是复制。这比复制所有数据要快得多。我们现在可以对 NumPy 数组进行操作,并可能就地修改它,这也会更新 DataFrame,并可能更新所有共享数据的其他 DataFrame。由于我们移除了许多防御性复制,这在写时复制的情况下变得更加复杂。现在,更多的 DataFrame 将相互共享内存。

to_numpy.values 将因而返回只读数组。这意味着结果数组是不可写的。

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
arr = df.to_numpy()

arr[0, 0] = 1

这将触发一个 ValueError

ValueError: assignment destination is read-only

你可以通过两种不同的方式避免这种情况:

  • 如果你想避免更新与数组共享内存的 DataFrame,请手动触发复制。

  • 使数组可写。这是一种更高效的解决方案,但会绕过写时复制(Copy-on-Write)规则,因此应谨慎使用。

arr.flags.writeable = True

在某些情况下,这不可能实现。一个常见的情况是,当你访问由 PyArrow 支持的单列时:

ser = pd.Series([1, 2], dtype="int64[pyarrow]")
arr = ser.to_numpy()
arr.flags.writeable = True

这将返回一个 ValueError

ValueError: cannot set WRITEABLE flag to True of this array

Arrow 数组是不可变的,因此无法使 NumPy 数组可写。在这种情况下,Arrow 到 NumPy 的转换是零复制的。

结论

我们已经看到了最具侵入性的写时复制相关更改。这些更改将成为 pandas 3.0 的默认行为。我们还研究了如何调整代码,以避免在启用写时复制时破坏代码。如果你能避免这些模式,升级过程应该会非常顺利。

深入探讨模型可解释性的 PFI

原文:towardsdatascience.com/deep-dive-into-pfi-for-model-interpretability-f12f0c64226c?source=collection_archive---------11-----------------------#2023-07-20

另一个可供选择的可解释性工具

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Tiago Toledo Jr.

·

关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 7 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 fabio 提供,来源于 Unsplash

了解如何评估你的模型对于数据科学家的工作至关重要。如果你不能完全理解并向利益相关者传达你的解决方案,没有人会批准它。这就是为什么了解可解释性方法如此重要的原因。

缺乏可解释性可能会毁掉一个非常好的模型。我还没有开发过一个我的利益相关者不关心理解预测如何产生的模型。因此,知道如何解释模型并将其传达给业务是数据科学家的核心能力。

在这篇文章中,我们将深入探讨置换特征重要性(PFI),这是一种与模型无关的方法,可以帮助我们识别模型中最重要的特征,从而更好地沟通模型在进行预测时的考虑因素。

置换特征重要性是什么

PFI 方法尝试估计一个特征对模型结果的重要性,基于我们改变与目标变量相关的特征时模型的表现。

为了做到这一点,对于每个特征,我们要分析其重要性,我们将其随机打乱,同时保持其他特征和目标不变。

这使得特征在预测目标时变得无用,因为我们通过改变它们的联合分布打破了它们之间的关系。

然后,我们可以使用模型来预测我们打乱的数据集。模型性能的减少量将指示该特征的重要性。

算法大致如下:

  • 我们在训练数据集上训练一个模型,然后在训练集和测试集上评估其表现。

  • 对于每个特征,我们创建一个新的数据集,其中该特征被打乱。

  • 然后我们使用训练好的模型来预测新数据集的输出。

  • 新的性能指标与旧指标的比值给出了特征的重要性。

请注意,如果一个特征不重要,模型的表现不应有太大变化。如果它重要,那么表现应该会有很大变化。

解释 PFI

现在我们知道如何计算 PFI,我们如何解释它呢?

这取决于我们将 PFI 应用到哪个折叠。我们通常有两个选项:将其应用于训练数据集或测试数据集。

训练解释

在训练过程中,我们的模型学习数据的模式并尝试表示它。当然,在训练过程中,我们无法知道我们的模型对未见数据的泛化效果如何。

因此,通过将 PFI 应用于训练数据集,我们将看到哪些特征对模型学习数据表示最为相关。

从业务角度来看,这表明哪些特征对模型构建最为重要。

测试解释

现在,如果我们将方法应用于测试数据集,我们将看到特征对模型泛化的影响。

让我们考虑一下。如果我们在打乱某个特征后看到模型在测试集上的表现下降,这意味着该特征对该数据集的表现很重要。由于测试集是我们用来测试泛化的(如果你做得对的话),那么我们可以说它对泛化很重要。

PFI 的问题

PFI 分析了特征对模型性能的影响,因此,它并没有说明原始数据的任何信息。如果模型性能较差,那么你通过 PFI 找到的任何关系都是没有意义的。

这对两种情况都适用,如果你的模型出现欠拟合(训练集上的预测能力低)或过拟合(测试集上的预测能力低),那么你不能从这个方法中获得有用的见解。

此外,当两个特征高度相关时,PFI 可能会误导你的解释。如果你打乱一个特征,但所需的信息编码在另一个特征中,那么性能可能不会受到影响,这可能会让你认为这个特征是无用的,但实际上可能并非如此。

在 Python 中实现 PFI

要在 Python 中实现 PFI,我们首先必须导入所需的库。为此,我们主要使用 numpy、pandas、tqdm 和 sklearn 这些库:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes, load_iris
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import accuracy_score, r2_score

现在,我们必须加载数据集,将使用 Iris 数据集。然后,我们将对数据进行随机森林拟合。

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
  X, y, test_size=0.3, random_state=12, shuffle=True
)

rf = RandomForestClassifier(
  n_estimators=3, random_state=32
).fit(X_train, y_train)

在模型拟合完成后,让我们分析其性能,以确定是否可以安全地应用 PFI 来查看特征如何影响我们的模型:

print(accuracy_score(rf.predict(X_train), y_train))
print(accuracy_score(rf.predict(X_test), y_test))

我们可以看到,在训练集上我们达到了 99% 的准确率,在测试集上达到了 95.5% 的准确率。目前看来不错。让我们获取原始错误评分以便后续比较:

original_error_train = 1 - accuracy_score(rf.predict(X_train), y_train)
original_error_test = 1 - accuracy_score(rf.predict(X_test), y_test)

现在让我们计算置换得分。为此,通常需要对每个特征进行多次打乱,以获得特征得分的统计数据,从而避免任何巧合。在我们的案例中,让我们对每个特征进行 10 次重复:

n_steps = 10

feature_values = {}
for feature in range(X.shape[1]):
  # We will save each new performance point for each feature
    errors_permuted_train = []
    errors_permuted_test = []

    for step in range(n_steps):
        # We grab the data again because the np.random.shuffle function shuffles in place
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=12, shuffle=True)
        np.random.shuffle(X_train[:, feature])
        np.random.shuffle(X_test[:, feature])

    # Apply our previously fitted model on the new data to get the performance
        errors_permuted_train.append(1 - accuracy_score(rf.predict(X_train), y_train))
        errors_permuted_test.append(1 - accuracy_score(rf.predict(X_test), y_test))

    feature_values[f'{feature}_train'] = errors_permuted_train
    feature_values[f'{feature}_test'] = errors_permuted_test

现在我们有了一个包含每次打乱性能的字典。接下来,让我们生成一个表格,该表格对每个折叠中的每个特征显示其性能的平均值和标准差,并与模型的原始性能进行比较:

PFI = pd.DataFrame()
for feature in feature_values:
    if 'train' in feature:
        aux = feature_values[feature] / original_error_train
        fold = 'train'
    elif 'test' in feature:
        aux = feature_values[feature] / original_error_test
        fold = 'test'

    PFI = PFI.append({
        'feature': feature.replace(f'_{fold}', ''),
        'pfold': fold,
        'mean':np.mean(aux),
        'std':np.std(aux),
    }, ignore_index=True)

PFI = PFI.pivot(index='feature', columns='fold', values=['mean', 'std']).reset_index().sort_values(('mean', 'test'), ascending=False)

我们将得到如下结果:

我们可以看到,特征 2 似乎是数据集中最重要的特征,其次是特征 3。由于我们没有固定 numpy 打乱函数的随机种子,我们可以预期这个数字会有所变化。

然后我们可以绘制一个图表,以更好地可视化重要性:

结论

PFI 是一种简单的方法论,可以帮助你快速识别最重要的特征。继续尝试将它应用到你正在开发的模型中,看看它的表现如何。

但也要注意方法的局限性。如果不了解方法的不足之处,将会导致错误的解释。

另外,注意 PFI 显示的是特征的重要性,但并没有说明它对模型输出的影响方向。

那么,告诉我,你打算如何在下一个模型中使用这个方法?

敬请关注更多关于可解释性方法的帖子,这些方法可以提高你对模型的整体理解。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值