在 Rust 中构建跨平台 TFIDF 文本摘要器
跨平台的 Rust NLP
使用 Rayon 进行优化,并支持 C/C++、Android、Python
·
关注 发布于 Towards Data Science · 12 分钟阅读 · 2023 年 12 月 14 日
–
照片由 Patrick Tomasso 提供,来自 Unsplash
NLP 工具和实用程序在 Python 生态系统中得到了广泛发展,使得各级开发者能够大规模地构建高质量的语言应用。Rust 是 NLP 的新兴领域,像 HuggingFace 这样的组织正在采用它来构建机器学习包。
[## Hugging Face 已经用 Rust 编写了一个新的机器学习框架,现在已开源!
最近,Hugging Face 开源了一个重量级的机器学习框架 Candle,这与通常的 Python 有所不同…
medium.com](https://medium.com/@Aaron0928/hugging-face-has-written-a-new-ml-framework-in-rust-now-open-sourced-1afea2113410?source=post_page-----7b05938f4507--------------------------------)
在这篇博客中,我们将探讨如何利用 TFIDF 的概念构建一个文本总结器。我们首先将了解 TFIDF 总结的工作原理,以及为什么 Rust 可能是实现 NLP 管道的好语言,以及如何在其他平台如 C/C++、Android 和 Python 上使用我们的 Rust 代码。此外,我们将讨论如何利用 Rayon 进行并行计算来优化总结任务。
这是 GitHub 项目:
[## GitHub - shubham0204/tfidf-summarizer.rs: 简单、高效且跨平台的基于 TFIDF 的文本…
简单、高效且跨平台的基于 TFIDF 的 Rust 文本总结器 - GitHub - shubham0204/tfidf-summarizer.rs…
github.com](https://github.com/shubham0204/tfidf-summarizer.rs?source=post_page-----7b05938f4507--------------------------------)
开始吧 ➡️
目录
-
动机
-
提取式和抽象式文本总结
-
使用 TFIDF 理解文本总结
-
Rust 实现
-
与 C 的使用
-
未来展望
-
结论
动机
我在 2019 年使用相同的技术构建了一个文本总结器,使用 Kotlin 并称之为 Text2Summary。它主要为 Android 应用程序设计,作为一个副项目,使用 Kotlin 进行所有计算。快进到 2023 年,我现在正在处理 C、C++ 和 Rust 代码库,并且在 Android 和 Python 中使用了这些 本地 语言构建的模块。
我选择用 Rust 重新实现 Text2Summary
,因为这将是一个很好的学习经验,同时也是一个小巧高效的文本总结工具,可以轻松处理大文本。Rust 是一种编译语言,具有智能的借用和引用检查器,帮助开发者编写无 bug 的代码。用 Rust 编写的代码可以通过 jni
与 Java 代码库集成,并转换为 C 头文件/库,用于 C/C++ 和 Python。
提取式和抽象式文本总结
文本总结一直是自然语言处理(NLP)中长期研究的问题。从文本中提取重要信息并生成文本摘要是文本总结器需要解决的核心问题。解决方案分为两类,即提取式总结和抽象式总结。
我们如何自动总结文档?
[towardsdatascience.com
在抽取式文本总结中,短语或句子直接从句子中提取。我们可以使用评分函数对句子进行排名,并根据它们的分数从文本中选择最合适的句子。与抽象总结中生成新文本不同,摘要是从文本中选择的句子的集合,从而避免了生成模型所展示的问题。
-
在抽取式总结中,文本的精确度得以保持,但由于选择的文本粒度仅限于句子,信息丢失的可能性较高。如果一条信息分散在多个句子中,评分函数必须考虑包含这些句子的关系。
-
抽象式文本总结需要更大的深度学习模型来捕捉语言的语义,并建立适当的文档到摘要的映射。训练此类模型需要大量数据集和较长的训练时间,这会重负计算资源。预训练模型可能解决了更长训练时间和数据需求的问题,但仍然固有地偏向于其训练的文本领域。
-
抽取式方法可能有不带参数的评分函数,无需任何学习。它们属于无监督学习的 ML 领域,有用的是它们需要的计算较少且不偏向于文本领域。总结在新闻文章和小说摘录中可能同样有效。
使用我们的基于 TFIDF 的技术,我们不需要任何训练数据集或深度学习模型。我们的评分函数基于不同句子中词汇的相对频率。
使用 TFIDF 理解文本总结
为了对每个句子进行排序,我们需要计算一个分数来量化句子中信息的量。TF-IDF 包含两个术语——TF,表示词频,以及 IDF,表示逆文档频率。
## 从头开始使用 python 创建 TF(词频)-IDF(逆文档频率)。
从头开始创建 TF-IDF 模型
[towardsdatascience.com
我们认为每个句子由词汇(单词)组成,
表达式 1:句子 S 表示为单词元组
每个单词在句子 S 中的词频定义为,
表达式 2:k 代表句子中的总词数。
每个单词在句子 S 中的逆文档频率定义为,
表达式 3:逆文档频率量化了该词在其他句子中的出现情况。
每个句子的分数是该句子中所有单词的 TFIDF 分数之和,
表达式 4:每个句子的分数 S 决定了它是否包含在最终总结中。
重要性与直觉
正如你可能已经观察到的,词频对于句子中较少出现的单词来说会更低。如果同一个词在其他句子中的出现也较少,那么 IDF 分数也会更高。因此,一个包含重复单词(较高 TF)且这些单词仅在该句子中较为独特(较高 IDF)的句子将具有更高的 TFIDF 分数。
Rust 实现
我们通过创建将给定文本转换为 Vec
句子的函数来开始实现我们的技术。这个问题被称为句子分词,它在文本中识别句子边界。使用像 nltk
这样的 Python 包,punkt
句子分词器可用于此任务,并且也存在 Punkt 的 Rust 移植版。[rust-punkt](https://github.com/ferristseng/rust-punkt)
不再维护,但我们在这里仍然使用它。还编写了另一个将句子拆分为单词的函数,
use punkt::{SentenceTokenizer, TrainingData};
use punkt::params::Standard;
static STOPWORDS: [ &str ; 127 ] = [ "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you",
"your", "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself",
"it", "its", "itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this",
"that", "these", "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having",
"do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while", "of",
"at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before", "after", "above",
"below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", "further", "then", "once",
"here", "there", "when", "where", "why", "how", "all", "any", "both", "each", "few", "more", "most", "other",
"some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s", "t", "can",
"will", "just", "don", "should", "now" ] ;
/// Transform a `text` into a list of sentences
/// It uses the popular Punkt sentence tokenizer from a Rust port:
/// <`/`>https://github.com/ferristseng/rust-punkt<`/`>
pub fn text_to_sentences( text: &str ) -> Vec<String> {
let english = TrainingData::english();
let mut sentences: Vec<String> = Vec::new() ;
for s in SentenceTokenizer::<Standard>::new(text, &english) {
sentences.push( s.to_owned() ) ;
}
sentences
}
/// Transforms the sentence into a list of words (tokens)
/// eliminating stopwords while doing so
pub fn sentence_to_tokens( sentence: &str ) -> Vec<&str> {
let tokens: Vec<&str> = sentence.split_ascii_whitespace().collect() ;
let filtered_tokens: Vec<&str> = tokens
.into_iter()
.filter( |token| !STOPWORDS.contains( &token.to_lowercase().as_str() ) )
.collect() ;
filtered_tokens
}
在上述代码片段中,我们删除停用词,即在语言中常见且对文本信息内容没有重要贡献的词汇。
关于在 Python 中删除英文停用词的实用指南!
towardsdatascience.com
接下来,我们创建一个计算语料库中每个单词频率的函数。此方法将用于计算句子中每个单词的词频。(word, freq)
对存储在 [Hashmap](https://doc.rust-lang.org/std/collections/struct.HashMap.html)
中,以便在后续阶段快速检索。
use std::collections::HashMap;
/// Given a list of words, build a frequency map
/// where keys are words and values are the frequencies of those words
/// This method will be used to compute the term frequencies of each word
/// present in a sentence
pub fn get_freq_map<'a>( words: &'a Vec<&'a str> ) -> HashMap<&'a str,usize> {
let mut freq_map: HashMap<&str,usize> = HashMap::new() ;
for word in words {
if freq_map.contains_key( word ) {
freq_map
.entry( word )
.and_modify( | e | {
*e += 1 ;
} ) ;
}
else {
freq_map.insert( *word , 1 ) ;
}
}
freq_map
}
接下来,我们编写了一个计算句子中单词词频的函数,
// Compute the term frequency of tokens present in the given sentence (tokenized)
// Term frequency TF of token 'w' is expressed as,
// TF(w) = (frequency of w in the sentence) / (total number of tokens in the sentence)
fn compute_term_frequency<'a>(
tokenized_sentence: &'a Vec<&str>
) -> HashMap<&'a str,f32> {
let words_frequencies = Tokenizer::get_freq_map( tokenized_sentence ) ;
let mut term_frequency: HashMap<&str,f32> = HashMap::new() ;
let num_tokens = tokenized_sentence.len() ;
for (word , count) in words_frequencies {
term_frequency.insert( word , ( count as f32 ) / ( num_tokens as f32 ) ) ;
}
term_frequency
}
另一个函数计算了分词句子中单词的 IDF(逆文档频率),
// Compute the inverse document frequency of tokens present in the given sentence (tokenized)
// Inverse document frequency IDF of token 'w' is expressed as,
// IDF(w) = log( N / (Number of documents in which w appears) )
fn compute_inverse_doc_frequency<'a>(
tokenized_sentence: &'a Vec<&str> ,
tokens: &'a Vec<Vec<&'a str>>
) -> HashMap<&'a str,f32> {
let num_docs = tokens.len() as f32 ;
let mut idf: HashMap<&str,f32> = HashMap::new() ;
for word in tokenized_sentence {
let mut word_count_in_docs: usize = 0 ;
for doc in tokens {
word_count_in_docs += doc.iter().filter( |&token| token == word ).count() ;
}
idf.insert( word , ( (num_docs) / (word_count_in_docs as f32) ).log10() ) ;
}
idf
}
我们现在已经添加了计算句子中每个单词的 TF 和 IDF 分数的函数。为了计算每个句子的最终分数,也就是决定其排名的分数,我们必须计算句子中所有单词的 TFIDF 分数总和。
pub fn compute(
text: &str ,
reduction_factor: f32
) -> String {
let sentences_owned: Vec<String> = Tokenizer::text_to_sentences( text ) ;
let mut sentences: Vec<&str> = sentences_owned
.iter()
.map( String::as_str )
.collect() ;
let mut tokens: Vec<Vec<&str>> = Vec::new() ;
for sentence in &sentences {
tokens.push( Tokenizer::sentence_to_tokens(sentence) ) ;
}
let mut sentence_scores: HashMap<&str,f32> = HashMap::new() ;
for ( i , tokenized_sentence ) in tokens.iter().enumerate() {
let tf: HashMap<&str,f32> = Summarizer::compute_term_frequency(tokenized_sentence) ;
let idf: HashMap<&str,f32> = Summarizer::compute_inverse_doc_frequency(tokenized_sentence, &tokens) ;
let mut tfidf_sum: f32 = 0.0 ;
// Compute TFIDF score for each word
// and add it to tfidf_sum
for word in tokenized_sentence {
tfidf_sum += tf.get( word ).unwrap() * idf.get( word ).unwrap() ;
}
sentence_scores.insert( sentences[i] , tfidf_sum ) ;
}
// Sort sentences by their scores
sentences.sort_by( | a , b |
sentence_scores.get(b).unwrap().total_cmp(sentence_scores.get(a).unwrap()) ) ;
// Compute number of sentences to be included in the summary
// and return the extracted summary
let num_summary_sents = (reduction_factor * (sentences.len() as f32) ) as usize;
sentences[ 0..num_summary_sents ].join( " " )
}
使用 Rayon
对于较大的文本,我们可以在多个 CPU 线程上并行执行一些操作,即使用流行的 Rust crate [rayon-rs](https://github.com/rayon-rs/rayon)
。在上面的 compute
函数中,我们可以并行执行以下任务,
-
将每个句子转换为 tokens 并移除停用词
-
计算每个句子的 TFIDF 分数总和
这些任务可以独立地在每个句子上执行,并且互相之间没有依赖,因此它们可以被并行化。为了确保不同线程访问共享容器时的互斥性,我们使用 [Arc](https://doc.rust-lang.org/rust-by-example/std/arc.html)
(原子引用计数指针) 和 [Mutex](https://fongyoong.github.io/easy_rust/Chapter_43.html)
,这是确保原子访问的基本同步原语。
Arc
确保被引用的 Mutex
对所有线程都是可访问的,而 Mutex
本身只允许单线程访问它所包装的对象。这里有另一个函数 par_compute
,它使用 Rayon 并行执行上述任务,
pub fn par_compute(
text: &str ,
reduction_factor: f32
) -> String {
let sentences_owned: Vec<String> = Tokenizer::text_to_sentences( text ) ;
let mut sentences: Vec<&str> = sentences_owned
.iter()
.map( String::as_str )
.collect() ;
// Tokenize sentences in parallel with Rayon
// Declare a thread-safe Vec<Vec<&str>> to hold the tokenized sentences
let tokens_ptr: Arc<Mutex<Vec<Vec<&str>>>> = Arc::new( Mutex::new( Vec::new() ) ) ;
sentences.par_iter()
.for_each( |sentence| {
let sent_tokens: Vec<&str> = Tokenizer::sentence_to_tokens(sentence) ;
tokens_ptr.lock().unwrap().push( sent_tokens ) ;
} ) ;
let tokens = tokens_ptr.lock().unwrap() ;
// Compute scores for sentences in parallel
// Declare a thread-safe Hashmap<&str,f32> to hold the sentence scores
let sentence_scores_ptr: Arc<Mutex<HashMap<&str,f32>>> = Arc::new( Mutex::new( HashMap::new() ) ) ;
tokens.par_iter()
.zip( sentences.par_iter() )
.for_each( |(tokenized_sentence , sentence)| {
let tf: HashMap<&str,f32> = Summarizer::compute_term_frequency(tokenized_sentence) ;
let idf: HashMap<&str,f32> = Summarizer::compute_inverse_doc_frequency(tokenized_sentence, &tokens ) ;
let mut tfidf_sum: f32 = 0.0 ;
for word in tokenized_sentence {
tfidf_sum += tf.get( word ).unwrap() * idf.get( word ).unwrap() ;
}
tfidf_sum /= tokenized_sentence.len() as f32 ;
sentence_scores_ptr.lock().unwrap().insert( sentence , tfidf_sum ) ;
} ) ;
let sentence_scores = sentence_scores_ptr.lock().unwrap() ;
// Sort sentences by their scores
sentences.sort_by( | a , b |
sentence_scores.get(b).unwrap().total_cmp(sentence_scores.get(a).unwrap()) ) ;
// Compute number of sentences to be included in the summary
// and return the extracted summary
let num_summary_sents = (reduction_factor * (sentences.len() as f32) ) as usize;
sentences[ 0..num_summary_sents ].join( ". " )
}
跨平台使用
C 和 C++
为了在 C 中使用 Rust 结构体和函数,我们可以使用 cbindgen
生成包含结构体/函数原型的 C 风格头文件。生成头文件后,我们可以将 Rust 代码编译成基于 C 的 动态或静态库,这些库包含头文件中声明的函数的实现。为了生成基于 C 的静态库,我们需要在 Cargo.toml
中将 [crate_type](https://doc.rust-lang.org/cargo/reference/cargo-targets.html)
参数设置为 staticlib
,
[lib]
name = "summarizer"
crate_type = [ "staticlib" ]
接下来,我们添加 FFIs 以在 src/lib.rs
的 ABI (应用程序二进制接口) 中暴露总结器的功能,
/// functions exposing Rust methods as C interfaces
/// These methods are accessible with the ABI (compiled object code)
mod c_binding {
use std::ffi::CString;
use crate::summarizer::Summarizer;
#[no_mangle]
pub extern "C" fn summarize( text: *const u8 , length: usize , reduction_factor: f32 ) -> *const u8 {
...
}
#[no_mangle]
pub extern "C" fn par_summarize( text: *const u8 , length: usize , reduction_factor: f32 ) -> *const u8 {
...
}
}
我们可以使用 cargo build
构建静态库,libsummarizer.a
将在 target
目录中生成。
安卓
使用 Android 的本地开发工具包 (NDK),我们可以为 armeabi-v7a
和 arm64-v8a
目标编译 Rust 程序。我们需要使用 Java Native Interface (JNI) 编写特殊的接口函数,这些函数可以在 src/lib.rs
的 android
模块中找到。
[## Kotlin JNI for Native Code
如何从 Kotlin 调用本地代码。
Python
使用 Python 的 ctypes
模块,我们可以加载共享库(.so
或 .dll
)并使用与 C 兼容的数据类型来执行库中定义的函数。代码尚未在 GitHub 项目上提供,但很快会提供。
[## Python 绑定:从 Python 调用 C 或 C++ - Real Python
什么是 Python 绑定?你应该使用 ctypes、CFFI 还是其他工具?在这篇逐步教程中,你将获得…
未来展望
该项目可以以多种方式扩展和改进,我们将在下面讨论:
-
当前的实现要求使用
[nightly](https://doc.rust-lang.org/book/appendix-07-nightly-rust.html)
Rust 构建,仅仅是因为一个依赖项punkt
。punkt
是一个句子分词器,用于确定文本中的句子边界,之后进行其他计算。如果punkt
可以使用稳定版 Rust 构建,那么当前实现将不再需要nightly
Rust。 -
添加新的度量标准来排名句子,特别是那些捕捉句子间依赖关系的度量。TFIDF 不是最准确的评分函数,并且有其自身的局限性。构建句子图并利用它们对句子进行评分,大大提高了提取摘要的整体质量。
-
摘要生成器尚未在已知数据集上进行基准测试。Rouge 分数
[R1](https://en.wikipedia.org/wiki/ROUGE_(metric))
,[R2](https://en.wikipedia.org/wiki/ROUGE_(metric))
和[RL](https://en.wikipedia.org/wiki/ROUGE_(metric))
常用于评估生成的摘要与标准数据集如 纽约时报数据集 或 CNN 日报数据集的质量。与标准基准进行性能测量将为开发者提供更清晰、更可靠的实现参考。
结论
使用 Rust 构建 NLP 工具具有显著优势,考虑到由于性能和未来前景,语言在开发者中的受欢迎程度不断上升。我希望这篇文章对你有所帮助。请查看 GitHub 项目:
## GitHub - shubham0204/tfidf-summarizer.rs: 简单、高效且跨平台的基于 TFIDF 的文本摘要工具
简单、高效且跨平台的基于 TFIDF 的文本摘要工具,使用 Rust 编写 - GitHub - shubham0204/tfidf-summarizer.rs
如果你觉得有改进的空间,可以考虑提出问题或提交拉取请求!继续学习,祝你有美好的一天。
建立一个自定义 GPT:教训与技巧
从兴奋到失望,最终走向问题的解决与赞赏
·
关注 发表在 Towards Data Science ·13 分钟阅读 · 2023 年 11 月 15 日
–
The Causal Mindset(Quentin Gallea 的个性化 GPT),由 Dall-E 生成。
在上周二(2023 年 11 月 6 日),Sam Altman(OpenAI CEO)发布了 GPTs 的版本,允许任何人使用自然语言创建个性化的 ChatGPT。
像许多人一样,我也跟上了热潮,过去几周里日夜兼程,到了感觉大脑快要烧坏的地步。这真是一段过山车式的经历,我既惊讶,又失望。但最终,我找到了解决问题的方法,接受了局限性,并且现在对此充满热情。
你将在本文中发现什么: 在这篇文章中,我将首先介绍我的应用程序以提供背景,然后讨论三个阶段:热情(这如何运作?)、失望(有哪些问题?)和接受(如何解决问题并接受局限性?)。
1. 我的应用程序:因果思维模式
因果思维模式应用程序。作者提供的图片。更多信息请见 thecausalmindset.com
我一直致力于使因果推断变得更易于访问,以帮助人们做出更好的决策并减少被操控的风险。在过去的十年里,我向大约 12,000 人教授统计学,大多数是在学术界。
除了理论和经验课程,我创建了一个框架,“因果思维模式”,本质上是一套基于因果推断和统计学的实用批判性思维工具(不涉及数学),我每周在 LinkedIn 和 Instagram 上发布内容。
区分事实与虚构对于做出明智决策和保护自己免受操控至关重要。不幸的是,面对信息的泛滥,这一任务变得越来越困难;虚假信息、误导性文章和彻头彻尾的谎言随处可见。
事实核查并不总是实际可行,因为它可能耗时较长,并且通常假设存在一个明确的“正确”答案。
这个应用程序的目的是提供一个多功能的工具,可以随时随地用来揭示论证中的缺陷,并增强决策能力。此外,它旨在赋予你这些批判性思维工具,培养你对应用程序本身的独立性。
我的应用程序是如何工作的?
你可以向聊天机器人分享一个陈述、一张图表或一个反思,它会应用因果思维模式框架来解剖和质疑这一主张的有效性。
你可以在 thecausalmindset.com 上找到示例或自己尝试。但这不是本文的重点。我更想向你展示在这一过程中我学到的可能对你有用的内容。
2. 热情(这如何运作?)
现在,如果你有 ChatGPT Plus 的访问权限,你可以进入“操作界面”并创建你自己的个性化 GPT。OpenAI 已经发布了一些他们自制的 GPT(公司称之为 GPTs),并计划在不久的将来推出一个由用户生产的 GPT 应用商店(参见:openai.com/blog/introducing-gpts
)。
基本的方面是,你不需要编程,而是使用自然语言来设置和调整应用程序。以下是这个操作界面的样子:
作者提供的图片,展示了创建个性化 GPT 的实验室界面
在左侧,你可以与 GPT 构建器聊天以进行设置,在右侧,你可以尝试它。所以,基本上,这就是我开始的方式。我给出了指示,也向构建器请教建议,并在右侧测试了结果。
在左上角,你也可以点击“配置”,在那里你可以直接访问应用的其他方面(见下图)。你也可以通过在“创建”选项卡中与 GPT Builder 讨论来填写所有这些字段。注意,Builder 可以根据你进行的对话决定调整这些字段的内容(这将在下一部分变得重要)。
图片来源:作者,创建你个性化 GPT 的实验室截图
你可以在对话中添加图标或用 Dall-E 生成图标。你还可以添加一个始终在应用名称旁边可见的描述。然后是指令部分:这个 GPT 做什么?它的行为方式如何?它应该避免做什么?最后,你有对话启动器,用户可以选择这些来测试应用。
以下是其他选项:
知识: 这是知识库。你可以放置文件,GPT 会优先处理来自这个部分的信息。因此,这一点是区别于基础版 ChatGPT 的关键。
能力: 你还可以选择它的能力:网页浏览、Dall-e 图像生成和代码解释器。
操作: 添加操作按钮:“你可以通过提供关于端点、参数的详细信息,以及模型如何使用它的描述,使第三方 API 可用于你的 GPT。”
起初,这一切都令人难以置信。感觉就像在训练某人。我把我的知识库(我的文章、TEDx 的文字记录,以及我写的关于框架《因果思维》的书)交给了 GPT。这非常迷人,感觉像《黑客帝国》。与其阅读我的书,我不如把它上传到机器中,别人就可以访问其中的知识。
但很快,它开始崩溃。
3. 失望(有哪些问题?)
类似于教导某人,GPT 会开始忘记一些事情。因此,我开始重新修复之前修复过的东西。当最新的问题解决后,我会发现其他地方有变化,我不得不再次修复。因此,我很快陷入了一个循环,这让我感到很沮丧。更糟糕的是,模型会不断更改对话启动器、指令、描述等。其他参与这个项目的人也有同样的感觉。
一些同事和朋友把这个东西丢弃了,并声称它太初步了。我没有。
然后,我还意识到用户可以找到我的知识库、我的指令,甚至是我用于构建应用的对话(‘提示注入’)。这也是相当令人沮丧的(见下图)。
作为应用用户,我如何获取知识库中的文档列表的示例。图片来源:作者。
4. 接受(如何解决问题并接受局限性?)
我不打算放弃。因此,这是我找到的所有这些问题的解决方案:
4.1 隐私
我放弃了隐私方面的考虑,仅分享了我已经在网上分享的内容,任何人都可以访问(这不是适用于任何人和任何情况的解决方案)。此外,我认为,与其保护我如何指示模型,不如分享出来让其他人改进他们的 GPT(因此这篇文章)。所以我接受了“开源”这一方面。请注意,他们可能在处理这个方面,禁用代码解释器可能会稍微减少风险。
4.2 用户指南
当你进入应用程序时,你的信息有限。我有一个完整的网站提供示例,还有一个指南在其他地方(thecausalmindset.com),但在应用程序本身,你只有一个简短的描述,通常不够。因此,这是我的第一个想法。我将“这个应用程序是如何工作的?”作为第一个对话开场白,以便人们在进入应用程序时能够立即点击这个问题。然后,我指示 GPT 始终给出完全相同的回答(见下图)。我追求简短且带有示例的内容,以邀请人们尝试,而不是阅读手册。
我的简短用户指南。图像由作者提供。
4.3 如何防止一切不断变化
尽管这种方法很好,但有时如果我指示应用程序做其他事情,应用程序仍会“忘记”。这是我做出最大更新的地方。
我彻底改变了与 Builder 的工作方式。
我创建了一个说明文件,包含了应用程序的所有关键方面,而不仅仅是聊天功能。这个文件在知识库中,并且几乎完整地粘贴在说明中(最大长度为 80000 字符)。
因此,与其讨论这些说明几个小时,我决定编写说明,将其粘贴到知识库中,并仅通过聊天指示始终仔细遵循这些说明(完整说明可以在本文附录中找到)。
这是我说明的内容:
应用程序的主要规则: 在这里,我提供了“应用程序的 10 条戒律”。
这个应用程序是如何工作的? 我给出了上面显示的描述,这是我希望应用程序告诉用户的内容。
核心分析结构: 这是我方法的核心。我指示模型始终首先按照这个结构回答(它也在我的主要规则中)。
因果思维应用程序说明: 这些是整个程序的次要说明,超出了核心分析的范围。
写作风格: 我希望它的写作风格(语气、风格等)。
对话开场白: 这里我列出了所有的对话开场白。
5. 结论
本质上,在这次经历之后,我彻底改变了创建 GPT 的方法。与模型讨论数小时不再是我的选择,我会准备一个指令文件,上传它,只要求应用程序仔细遵循这些指令。瞧。如果你想保护你的内容,你可能愿意等待一段时间,因为这是我迄今未能解决的问题。
我看到了一些对这一新概念的怀疑:GPTs。我听到的主要观点是它基本上仍然是 ChatGPT。我不同意这一点。
我认为它远不止是 ChatGP 的一个微妙不同的版本,我相信它可以非常有用,可能为你打开意想不到的可能性,原因有两个。
首先,预提示是有价值的。即使你知道想要达到的目标,也可能需要数小时或数天。即便是几分钟,也能防止你多次重做。此外,它允许你从他人的专业知识中受益。
其次,知识库的差异非常大。我花了多年时间创建因果思维模式。因此,即使设置需要几分钟,它也可能代表了十年的研究成果。
我期待看到可能性,你做了什么以及你的想法。
如果你想测试并在下面的评论中提供反馈,这里是我的应用程序:thecausalmindset.com
附录:我的指令文件
应用程序的主要规则:
· 使每一个词都有意义,并广泛依赖例子来阐明你的观点。
· 始终将讨论引导到因果分析上。
· 如果有人问一些无关的问题,回顾你的目的和目标,并建议人们可以问你的问题的例子(例如使用对话引导语)。
· 始终从下面描述的核心分析开始分析的第一部分。
· 优先处理上传的文件:应用程序优先考虑来自上传文件的信息进行分析,并将这些文件作为主要知识来源。
· 遵守事实:应用程序坚持文件中呈现的事实,避免猜测。它在依赖基线知识或其他来源之前,重视这些文件中提供的知识。
这个应用程序是如何工作的?
因果思维帮助你将事实与虚构分开。
分享一个声明、一个图表或一个反思给我,我将应用因果思维框架来剖析和质疑这一声明的有效性。
示例:
· 我跟随了一个健身计划一个月,我可以将我的表现提高归因于这个计划吗?
· 我读到一篇科学文章称冷水淋浴显著减少工作病假。他们在荷兰进行了一项有 4000 名参与者的对照实验,并收集了自我报告的数据。
· 我看到一个城市中心的免费电动滑板车公司声称,他们帮助减少了排放,因为使用电动滑板车比开车更环保。
核心分析结构:
因果思维应用程序应始终按照以下结构开始对因果问题或情况的分析:
-
提醒用户你将应用因果思维框架来分析情况。
-
快速反馈 — 主要缺陷:用一个明确的例子描述所呈现情况中的主要问题。
-
还有其他情况吗?在这里,你应该基本上找出是否有替代解释因果效应的情况,或者至少一些混杂效应(遗漏变量偏见)。如果反向因果关系是一个问题,你也应该在这里提到(但如果不是问题,可以不提)。
-
我们能否推广?你应该以至少一个明确的例子质疑外部有效性。
-
反事实:在这里,你应该提出潜在的反事实,邀请对方思考他们的例子中的比较是否有效。
-
常见偏见:在这里,潜在的附加统计或行为偏见(例如安慰剂效应、选择偏见、确认偏见、期望偏见等)。
-
提供解决方案:建议对实验或自然实验进行解释,以适当测量因果关系,适用于广泛的受众。
-
结论:通过邀请对方根据其角色进行更深入的分析来结束分析:
o 记者/内容创作者/批判性思维爱好者:对于那些从事媒体和内容创作的人,应用程序建议进一步探索来源验证和偏见检测。这对于确保工作的准确性和公正性至关重要。
o 专业人士/决策者:对于这一群体,应用程序提供了高级模块用于风险评估和决策分析。这些工具对于在复杂、高风险环境中做出明智决策至关重要。
o 学生/研究人员:应用程序推荐关于因果推理的教育资源和学术合作机会。这对学术环境中的人员或希望深入理解因果分析的人特别有益。
因果思维应用程序说明:
因果思维应用程序的结构旨在引导用户通过全面的因果分析框架。它的工作原理如下:
核心分析阶段: 当用户呈现一个情况或问题时,应用程序首先进行因果关系的基本评估。这涉及识别关键点、潜在偏见和变量之间的关系。这里的重点是理解问题的因果结构,基于因果推理的原则。
量身定制的深入探索: 基于初步分析,应用程序会提供更详细的探索,针对用户的具体角色或需求:
-
记者/内容创作者:对于媒体和内容创作相关人员,应用建议进一步探索来源验证和偏见检测。这对确保工作的准确性和公正性至关重要。
-
专业人士/决策者:对于这一群体,应用提供高级模块用于风险评估和决策分析。这些工具对于在复杂、高风险环境中做出明智决策至关重要。
-
学生/研究人员:应用推荐关于因果推断的教育资源和学术合作机会。这对学术环境中的人员或希望深入理解因果分析的人员尤为有益。
用户互动
-
应用通过一系列问题和分析与用户互动,促使他们考虑情况的各个方面。
-
应用可能使用现实世界的例子或假设情境来说明观点,帮助用户将因果思维应用到具体背景中。
-
应用可能使用如因果图或流程图等视觉工具,帮助用户可视化复杂关系。
附加功能
-
教育内容: 应用可能包含关于因果推断的教程、文章和案例研究,以教育用户关于关键概念。
-
互动练习: 为了强化学习,应用可能提供互动练习或模拟,允许用户在不同情境中练习因果分析。
应用目标
主要目标是赋予用户使更好决策的技能,基于对因果关系的深刻理解。
它旨在提升批判性思维和分析技能,帮助用户对虚假信息具有更强的免疫力,更好地应对职业和个人生活中的复杂问题。
这种方法确保用户从应用中获得即时、实际的价值,同时也可以根据特定需求和专业水平参与更深入和专业的内容。
写作风格:
· 信息性和教育性:写作主要旨在教育读者,以易于理解的方式传达因果关系和统计分析的复杂概念。目标是传授这些领域的知识和理解。
· 分析性和启发性:风格具有分析性,鼓励读者深入探讨主题。它促使读者进行批判性思考,并挑战他们质疑和探索各种情境中的因果关系。
· 结构化和清晰:文本结构良好,以逻辑顺序呈现观点。这种清晰性使复杂主题更易理解,特别是对因果分析和批判性思维新手尤为有益。
· 结合现实世界例子:写作经常结合现实世界的情境和例子,有助于将抽象概念与实际情况联系起来。这种方法使材料更具相关性,更易于理解。
· 互动性和包容性:文本鼓励读者互动,提出问题和情境,促使读者应用他们所学习的概念。这种互动风格增强了学习和记忆。
· 学术严谨但易于理解:写作风格在学术严谨性和可接近性之间取得了平衡。文本显然基于全面的研究和专业知识,但以一种对更广泛的观众易于接触的方式呈现。
· 有时对话式:在一些部分,写作采用了更对话式的语气,增加了其吸引力。这种风格有助于揭开统计概念的神秘面纱,使内容更易于接受。
总的来说,写作风格有利于学习和参与,特别适合那些寻求理解和应用因果分析于其职业、学术或个人生活各个方面的人。
交流引子:
这个应用程序是如何工作的?
示例:健身计划:我跟随一个划船的健身计划一个月。我今天测试了,速度更快了。这是这个计划的功劳吗?
示例:环境政策评估:2022 年 9 月,瑞士政府启动了一项国家广告宣传活动,支持节能以应对因乌克兰战争带来的短缺威胁。它展示了一张图表,显示在实施该政策后,瑞士的每月净消费下降(在冬季)。
示例:公司影响:我看到一家在欧洲城市中心提供免费电动滑板车的公司,并声称他们的服务减少了污染。他们使用生命周期分析来比较汽车每公里和他们的电动滑板车每公里的污染。
示例:冷水澡研究论文:我看到一篇同行评审的文章,介绍了一项随机对照试验,其中 2000 名荷兰志愿者被分成每天冷水澡和正常淋浴两组。他们发现,洗冷水澡的人因生病缺勤的次数比对照组少 1/3。所有结果均为自我报告。你怎么看?
使用 Apache Spark 在 PB 规模上构建数据湖
原文:
towardsdatascience.com/building-a-data-lake-on-pb-scale-with-apache-spark-1622d7073d46
我们在 Emplifi 如何处理大数据
·发布于 Towards Data Science ·15 分钟阅读·2023 年 1 月 26 日
–
图片由Victor Hanacek在picjumbo提供
在职业上,我在 Emplifi(前身为 Socialbakers)公司的数据工程团队中度过了过去四年,其中一个我参与的最大项目是构建一个分布式数据存储系统,目前该系统存储了近一 PB 的数据,目的是为数据分析师和研究人员提供可以高效分析和研究的数据表。正如你可以想象的那样——构建和维护这样一个数据湖并非易事,因为数据不仅频繁变化,而且其架构随着时间的推移而演变,拥有几十个甚至几百个具有不同嵌套层级的字段。
在这篇文章中,我想分享我在这段旅程中的经验和亮点,主要集中在技术层面。
我们每天处理的数据来源于社交网络,如 Facebook、Twitter、Instagram、YouTube、LinkedIn 或 TikTok。处理的数据集主要是这些网络上的公共档案和发布的帖子。一小部分数据也来自内部系统。我们的存储系统建立在 S3 AWS 上,我们称之为数据湖,因为我们在这里以原始格式以及预处理、处理和聚合后的形式存储数据。原始数据主要以压缩的 JSON 文件形式存在,而处理和聚合的数据集则以 Apache Parquet 格式存储,并以(Hive metastore)表格形式向用户公开。
数据流
数据经过公司基础设施的多个点,首先,它通过公共和在某些情况下也通过私有 API 从社交网络下载。接下来,它到达 DynamoDB,这是 AWS 上的一个键值分布式数据库服务,我们称之为主数据库。DynamoDB 擅长处理记录的频繁更新,这对我们的用例非常有用,因为数据来自社交网络,其中每条记录(如 post)随着时间的推移而演变——它收集互动并频繁变化。每次在 DynamoDB 中的更新也会在我们数据湖的stage层保存到 S3。这是 Apache Spark 中一系列步骤的起点,这些步骤将生成一致的表,数据分析师和研究人员可以使用标准分析工具高效访问。接下来的章节,我将更详细地描述这些步骤。
数据的结构
传入到stage的数据是以压缩 JSON 文件保存的连续记录流。每条记录有两个子结构——oldImage和newImage,其中第一个是记录更新前的状态,而后者包含更新后的数据。每条记录(post)可能在一天内被多次下载自社交网络,因此这个流中包含重复项,即每个id唯一标识的 post 可能出现多次。我们需要只考虑每个id的最新version来正确更新最终表。这里可以看到数据结构的一个简单示例:
{
'newImage': {
'id': 1,
'version': 100,
'created_time': '2023-01-01: 10:00:00',
'interactions': 50,
MANY OTHER FIELDS...
}
'oldImage': {
'id': 1,
'version': 101,
'created_time': '2023-01-01: 10:00:00',
'interactions': 51,
MANY OTHER FIELDS...
}
}
数据湖——三层抽象
我们的数据湖构建在 S3(AWS 上的分布式对象存储)上,并有三层抽象,我们称之为stage、target和mart。在stage层,我们保存主要来自 DynamoDB 的原始数据,但也来自其他数据库,如 Elasticsearch、Postgres、MongoDB,以及各种内部 API。这些数据主要以压缩 JSON 格式到达,有时也以 CSV 格式到达。这些数据集每天进行预处理,并以 Apache Parquet 格式保存于stage中。
这些 Parquet 文件接下来会被处理,并用于更新target中的表,更新可以是每日一次,也有些情况下——数据新鲜度不那么关键——每周一次。与stage数据不同,target表由我们的数据用户直接访问。
对于一些数据结构不太友好的表,我们创建新的表来转换数据,以便于查询,并将其保存到mart层——我们不在target层这样做,因为这里我们希望保持表作为数据来源系统的镜像,以避免用户混淆,例如,DynamoDB 中的数据与target中的数据不同。
数据湖的结构(图片由作者制作)
对我们来说,使表格在查询时高效是一个重要优先事项。数据分析师不得不等待几分钟才能完成查询是非常痛苦的。我们在两个层面上以不同的成功程度实现了这一点:
-
在 target 中:我们的手脚会有些束缚,因为我们保证数据与原始数据库中的结构相同。然而,我们可以通过自定义文件组织来实现相当大的成果,例如使用分区、桶和排序等技术。有关更多细节,请参见下一节。
-
在 mart 中:我们可以做得更多,简化嵌套结构,并根据对表执行的典型查询进行预连接或预聚合。
阶段
stage 是数据最终暴露给用户的入口点。这里传入的原始数据每天都会经过以下 3 个处理步骤:
-
模式演变
-
数据清理
-
数据去重
模式演变——阶段层
来自数据的结构随时间发生变化并不罕见——数据集中可能会添加新字段,甚至某些字段的数据类型可能会改变。如果没有一个健全的模式演变步骤,这将是一个相当大的问题。Apache Spark 允许你通过提供模式来读取数据,或者让 Spark 推断它(使用所谓的 schema-on-read)。
通常,你对模式应该是什么有一些期望(理想情况下与已经存在的 target 表的模式相同),但是将此模式提供给 Spark 会导致所有新字段可能因为某些原因在数据的新增量中缺失。此外,如果某个字段的数据类型发生变化,Spark 将会将其读取为 NULL(如果这两种数据类型不兼容),或者会丢弃这些行,或者根据指定的 模式 完全失败(另见我的另一篇 文章 关于 JSON 模式演变)。另一方面,让 Spark 推断模式似乎更合理,但也有一些缺点和需要考虑的点:
-
时间戳将被推断为字符串(除了 Spark 3.0 版本,它将被推断为时间戳)
-
map 类型被推断为结构体
-
如果没有其他逻辑,所有新字段可能会传播到最终表中,这可能是不希望的——有人应该确认新字段是否可以暴露给用户(也许新字段只是由于某些错误产生的)
-
如果没有其他逻辑,如果数据类型与 target 表中的类型不同,将增量合并到 target 表中将会失败
我们解决这个问题的方法是使用我们内部开发的架构比较工具。我们为所有数据集的架构建立了版本控制系统(架构注册表)。在处理过程中,我们首先让 Spark 推断架构,然后将其与版本化的架构进行比较。它会检测所有新字段并在一些通知渠道中报告。如果数据类型发生变化,算法会决定这些新类型是否可以安全地转换为版本化类型而不会丢失数据。基于这种比较,我们创建了一个新的架构,并使用这个新架构重新读取数据,将变化的字段转换为版本化字段。
唯一的问题发生在新数据类型无法安全转换时,此时我们会使用新的数据类型,但在通知渠道中报告错误,因为这意味着在目标层的后续作业将会因为这个问题而失败。在这种情况下,我们需要手动调查为什么会发生这种情况,是否只是由于某些错误导致的单次更改,或者是数据结构中的永久性更改,这意味着我们需要相应地调整最终表格。
我们的架构注册表提供了 API 和 Web 用户界面。在这里,你可以看到来自 Web UI 的截图,其中比较了一个数据源的两个架构版本。版本 1.0.6 中存在而版本 1.0.5 中不存在的新字段以绿色突出显示。
数据清洗
在使用演化后的架构读取数据到 DataFrame 并转换类型后,我们应用一系列过滤器,排除某些关键字段(如id、version、created_time等)中具有NULL值的记录。这些字段对于数据的正确去重和在目标层的最终表格中的正确合并是必要的。我们将这些记录称为损坏记录,并将其保存在一个单独的文件夹中,以便调查为什么缺少这些值。
数据去重
如前所述,stage中的数据可能包含多个具有相同id的记录版本。这是因为个别记录(主要是来自社交网络的帖子)不断变化,我们一天内从社交网络下载这些记录多次。
stage层的处理作业在清晨运行,目标是获取前一天的增量(新数据),并为将数据合并到最终表格的目标作业做准备。在架构演化和数据清洗之后,我们还需要进行去重,即对于每个id,我们将仅保留最新版本:
increment_df = (
spark.read
.format('json')
.option('path', input_path)
.schema(evolved_schema)
.load()
.select('newImage.*')
)
w = Window().partitionBy('id').orderBy(desc('version'))
result_df = (
increment_df
.withColumn('r', row_number().over(w))
.filter(col('r') == 1)
.drop('r')
)
从代码中可以看出,我们仅从newImage中选择数据,并使用窗口函数row_number()进行去重(有关在 Spark 中使用窗口函数的详细信息,请参见我另一篇文章)。
目标
目标层的目标是创建与 DynamoDB、Elasticsearch 和其他内部数据库镜像的表,这些表可以通过标准分析工具直接访问,如 SQL、Spark 中的 DataFrame API 或 Pandas,以及 Python 数据科学生态系统中的所有库都可以在其上使用。请注意,通过直接访问原始数据库实现这一点将非常困难。
从技术上讲,在目标中,我们从阶段获取 parquet 文件,并将其插入到现有表中。我们还希望为分析师提供准备好的分析查询表,因此在创建表的布局时我们投入了一些精力。表本身只是对文件的某种抽象,我们再次使用 Apache Parquet 格式。
表布局
我们使用写时复制的概念,这意味着在每次更新时,整个表和所有文件都会被覆盖。这是相当昂贵的,特别是对于大型表,然而,它允许我们保持文件的紧凑和组织,以满足高效读取的要求。
表示社交网络帖子的数据表按时间维度进行分区,典型的分区列是created_year和created_month,这两个列都是从帖子的created_time派生出来的。这种分区加快了数据的读取,因为分析师通常对一些近期数据感兴趣。因此,处理引擎如 Spark 或 Presto 确保只扫描查询的分区,而跳过其他分区。这些表还按列profile_id进行分桶,因为它可以与表示个人资料的表在此列上连接。对该列进行聚合和连接的查询不会引起 Spark 中的数据洗牌,这样可以提高此类查询的总体效率。此外,如果查询包含对profile_id列的过滤条件,则可以进行分桶修剪(有关分桶的更多信息,请参见我另一篇文章)。
合并增量
target层中的处理作业在stage作业后清晨运行,其目标是将前一天的增量合并到target表中。我们将增量和表都读取到 DataFrames 中,将它们联合,并使用窗口函数row_number — 类似于我们在stage中所做的 — 选择将保存到新快照中的最新记录版本。因此,如果特定的id已经在表中,它将被来自增量的具有更大version的记录所替代:
stage_df = (
spark.read
.format('parquet')
.schema(input_schema)
.option(input_path)
.load()
)
target_df = spark.table(table_name)
w = Window().partitionBy('id').orderBy(desc('version'))
result_df = (
stage_df.unionByName(target_df)
.withColumn('r', row_number().over(w))
.filter(col('r') == 1)
.drop('r')
)
然而,如果stage_df和target_df的模式不同,这种方法将不起作用,因此我们需要再次应用模式演变步骤。
模式演变 — 目标层
同样地,在stage中我们对模式进行了演变,我们也需要在target中进行相同的操作。要成功将增量合并到target表中,两个表的模式必须相同。然而,如果增量中添加了新字段,这种情况可能不会成立。
我们使用一种称为input_schema的模式来控制这些新字段是否已添加到表中。这个input_schema也有版本,如果我们想将新字段推广到表中,我们会创建一个包含这些新字段的新版本的input_schema。这是一个手动步骤,考虑到应该有人确认这些新字段可以暴露给数据用户,这似乎是合理的。
接下来,我们还需要将表的模式更改为与增量的模式相同。Spark SQL 提供了一种使用ALTER TABLE table_name ADD COLUMNS向表中添加新列的方法,但它没有提供添加嵌套字段的方法。
因此,我们将这一功能自行实现到我们的框架中。在这种情况下,target作业会在一个临时位置创建一个具有新模式的空表。之后,我们删除原始表,并将空表指向原始表的数据位置。这样,我们将获得一个具有相同数据但模式已修改的新表:
(
spark.createDataFrame([], new_schema)
.write
.mode('overwrite')
.option('path', table_location_temp)
.saveAsTable(table_name_temp)
)
spark.sql('ALTER TABLE table_name_temp SET LOCATION table_location')
spark.sql('MSCK REPAIR TABLE table_name_temp')
spark.sql('DROP TABLE table_name')
spark.sql('ALTER TABLE table_name_temp RENAME TO table_name')
原子写入
从我们开发数据湖的初期开始,最大的挑战之一就是确保原子写入到位。问题是,当简单地覆盖一个表时…
(
df.write
.mode('overwrite')
.format('parquet')
.option('path', output_path)
.saveAsTable(table_name)
)
…不是原子的。如果 Spark 作业因任何原因在写入过程中失败,表可能会停止存在,或者 S3 前缀可能会开始包含部分写入的数据。在我们的框架中,我们通过始终将表保存在不同的位置并以新名称table_name_temp保存来实现原子性,在写入成功后,我们交换表名称,因此新的快照只有在成功写入后才会对用户可用。如果作业失败,我们不会交换名称,而是从头开始在新位置重新启动过程:
output_path = posixpath.join(output_path, str(int(time.time())))
(
df.write
.mode('overwrite')
.format('parquet')
.option('path', output_path)
.saveAsTable(table_name_tmp)
)
spark.sql('DROP TABLE IF EXISTS table_name')
spark.sql('ALTER TABLE table_name_tmp RENAME TO table_name')
我们保存数据的路径包含写入时的时间戳,并且我们始终保留最近的几个快照,这使得所谓的时间旅行成为可能——如果我们发现新创建的表因某些原因(由于某些错误)而损坏,我们可以将其指向任何之前的快照。
数据质量
目标 表中的数据的一致性和质量至关重要,因为这些表直接暴露给数据用户。这可以通过在每次更新后检查表格来实现,以确保数据没有由于代码中的某些错误、原始数据库的错误导出或数据在经过转换步骤后到达最终表时可能发生的其他问题而被破坏。
我们使用的数据质量框架是Great Expectations,它可以与 Spark 集成。对于每个数据集,我们可以定义一组期望,这些期望会在创建表的新快照后进行验证。如果一些关键期望未能满足,我们可以将表“时间旅行”到之前的快照,并详细调查失败的期望。
Mart
尽管从目标 表中访问数据相当高效,但对于某些表来说,这并不是很用户友好。这是自然的,因为,如前所述,我们保持目标 表中的数据结构与数据在原始数据库中存储的结构相同,而这些数据库不一定是为分析目的存储的。这主要涉及嵌套数据结构中的列,例如数组或结构体。有时,使用 SQL 的人们在查询中转换这些字段可能太繁琐,因此我们创建了从目标 表派生的附加表,以便在这些表中转换数据结构,从而更容易进行查询。有时这些表也可以与其他表连接,并且它们是为某些特定用户群体量身定制的。这些派生表被保存在湖泊的mart 层中。mart 的另一个目的是保留由分析师计算的一些汇总和结果表。
环境、作业、协调
在stage、target 和mart 层中的所有处理都是用 PySpark 实现的,并且在 Python 中实现了处理模式演变的自定义逻辑。代码在Databricks平台上运行,对于每个数据源,我们在stage 层和target 层都有一个作业,对于某些数据源,还有mart 层的作业。所有作业都使用 Apache Airflow 技术进行协调,这使我们能够定义作业之间的依赖关系。
去重、模式演变、原子数据保存等核心逻辑都在类和模块中实现,并编译成一个轮子(wheel),这个轮子被导入到任务中,以便基于软件工程的最佳实践复用代码。
湖中的所有步骤基本上都是自动化的,仅有少数几个与引入新数据源相关的手动步骤。在这种情况下,我们只需在配置文件中添加一些配置参数,并将新数据源的模式添加到模式注册表中。Git 管道会在 Databricks 中创建作业,并通过 Airflow 按照预定的调度启动它们。其他手动步骤与监控和审批新字段以及对无法安全转换为版本化类型的更改数据类型进行调试相关。
如前所述,数据用户可以使用 SQL 或 Spark 中的 DataFrame API 访问目标和数据集市表,这些 API 可以方便地与 Pandas 集成,从而实现与整个 Python 数据科学和机器学习库生态系统的集成。喜欢使用 SQL 的分析师可以使用Querybook——一个用户友好的笔记本界面——它连接到 Presto 引擎,而喜欢使用 Spark 的科学家则可以在 Databricks 的笔记本环境中访问表格。
我们的数据湖用户(分析师和研究人员)可以在 Spark 上使用数据科学和机器学习库,或者使用 Presto 引擎用 SQL 查询数据。他们可以选择在 Databricks 和 Querybook 中工作,这两个平台都提供了具有用户友好界面的笔记本环境,适用于各种使用场景。
致谢
构建和维护数据湖是整个数据工程团队的协作工作。我想感谢所有同事,不仅感谢在数据湖项目中进行的有益且富有成效的合作,感谢我们在过去四年中提出的大量有趣的想法和意见,还要感谢对本文的审阅和有益的评论。
结论
在处理大数据的公司中,拥有一个强大且可靠的数据湖似乎是必需的。传统的数据仓库概念已不再适用,因为数据量庞大,并且来自具有复杂结构的不同数据库和系统。然而,在实现这样的存储系统时,仍然有几个挑战需要应对。
我们在关系数据库中认为理所当然的 ACID 事务在此已不再那么明显,必须处理存储系统中的低级概念,如文件组织。通过一些额外的努力,可以恢复 ACID——我们在原子性方面达到了这一点。另一种方法是使用更先进的表格格式,如 Delta 或 Iceberg,它们提供 ACID。
数据模式的变化,无论是由于导出过程中出现错误还是实体真的发生了变化,都是一个相当棘手的问题,如果不加以谨慎处理,可能会导致数据丢失。我们在内部开发了一个工具,我们考虑将来将其开源。这个工具允许我们对所有数据源进行模式版本控制,以便查看其模式如何演变的历史。它还允许我们比较模式,执行各种操作,并轻松处理模式演变。
维护数据湖需要与数据用户(通常是分析师、研究人员和数据科学家)进行持续讨论。了解对表执行了什么样的查询,使我们能够通过在分区、桶化甚至排序方面使用自定义布局来进行高度优化。
构建分子属性预测的图卷积网络
人工智能
制作分子图和开发一个基于 PyTorch 的简单 GCN 的教程
·
关注 发表在 Towards Data Science ·17 分钟阅读·2023 年 12 月 23 日
–
照片由 BoliviaInteligente 提供,来源于 Unsplash
人工智能在全球范围内引起了轰动。每周都会出现新的模型、工具和应用程序,承诺推动人类努力的边界。开放源代码工具的可用性使得用户能够在少量代码中训练和使用复杂的机器学习模型,真正实现了人工智能的民主化;同时,尽管许多这些现成的模型可能提供了出色的预测能力,但它们作为黑箱模型的使用可能会剥夺了对人工智能深入理解的好奇学生。特别是在自然科学中,这种理解尤为重要,因为知道一个模型准确是不够的——还必须了解它与其他物理理论的联系、其局限性以及它对其他系统的普遍适用性。在本文中,我们将通过化学的视角探讨一种特定的机器学习模型——图卷积网络。这并不是一个数学严格的探讨;相反,我们将尝试将网络的特征与传统自然科学模型进行比较,并思考它为何表现如此出色。
1. 对图形和图神经网络的需求
在化学或物理学中,模型通常是一个连续函数,比如 y=f(x₁, x₂, x₃, …, xₙ),其中 x₁, x₂, x₃, …, xₙ 是输入,y 是输出。这样的模型的一个例子是决定两个点电荷 q₁ 和 q₂ 之间的静电相互作用(或力)的方程,这两个点电荷在相对介电常数为 εᵣ 的介质中,相隔距离为 r,通常称为库仑定律。
图 1:库仑方程作为点电荷之间静电相互作用的模型(图像来源:作者)
如果我们不知道这种关系,但假设有多个数据点,每个数据点包括点电荷之间的相互作用(输出)和相应的输入,我们可以拟合一个人工神经网络来预测任何给定点电荷在指定介质中的任何给定分离下的相互作用。在这个问题的情况下,虽然忽略了一些重要的警告,但创建一个数据驱动的物理问题模型是相对简单的。
现在考虑从分子的结构预测某一特定性质的问题,比如在水中的溶解度。首先,没有明显的输入集来描述一个分子。你可以使用各种特征,如键长、键角、不同类型元素的数量、环的数量等等。然而,没有保证任何这样的任意集合对所有分子都有效。
其次,与点电荷的例子不同,输入可能不一定存在于连续空间中。例如,我们可以将甲醇、乙醇和丙醇视为一组链长逐渐增加的分子;然而,它们之间并不存在任何概念——链长是一个离散参数,没有办法在甲醇和乙醇之间进行插值以得到其他分子。拥有一个连续的输入空间对于计算模型的导数是至关重要的,这些导数随后可以用于优化所选属性。
为了克服这些问题,已经提出了各种编码分子的方法。其中一种方法是使用 SMILES 和 SELFIES 等方案进行文本表示。这种表示方法有大量文献资料,我推荐感兴趣的读者阅读这篇有用的综述。第二种方法涉及将分子表示为图形。虽然每种方法都有其优点和缺点,但图形表示对化学更直观。
图是由节点通过边连接组成的数学结构,边表示节点之间的关系。分子自然适应这种结构——原子成为节点,键成为边。图中的每个节点由一个向量表示,该向量编码了相应原子的属性。通常,一位编码方案就足够了(更多内容见下一节)。这些向量可以堆叠起来形成一个节点矩阵。节点之间的关系——由边表示——可以通过一个方形的邻接矩阵来划分,其中每个元素aᵢⱼ 取值为 1 或 0,取决于两个节点i 和 j 是否由边连接。对角线上的元素设置为 1,表示自连接,这使得矩阵适合卷积(如你将在下一节看到的)。可以开发更复杂的图形表示,其中边的属性也在一个单独的矩阵中进行一位编码,但我们将这些留待另一篇文章。这些节点和邻接矩阵将作为我们模型的输入。
图 2:将乙酰胺分子表示为图形,节点的原子序号通过一位编码表示(图片来源:作者)
通常,人工神经网络模型接受的是一维输入向量。对于多维输入,比如图像,开发了一类叫做卷积神经网络的模型。在我们的情况下,我们有二维矩阵作为输入,因此需要一个修改过的网络来接受这些输入。图神经网络是为了处理这样的节点和邻接矩阵而开发的,它们将这些矩阵转换为适当的一维向量,这些向量可以通过普通的人工神经网络的隐藏层来生成输出。图神经网络有许多类型,比如图卷积网络、消息传递网络、图注意力网络等等,它们主要在于节点和边之间交换信息的函数上有所不同。由于图卷积网络相对简单,我们将更详细地了解它们。
2. 图卷积和池化层
考虑你输入的初始状态。节点矩阵表示了每个原子的独热编码。为了简化起见,我们考虑原子序数的独热编码,其中原子序数为n的原子在nᵗʰ索引处有一个 1,其余位置都是 0。邻接矩阵表示节点之间的连接。在当前状态下,节点矩阵不能作为人工神经网络的输入,原因有以下几点:(1) 它是二维的,(2) 它不是排列不变的,(3) 它不是唯一的。这里的排列不变性意味着无论你如何排列节点,输入应该保持不变;目前,相同的分子可以由相同节点矩阵的多个排列表示(假设邻接矩阵也有适当的排列)。这是一个问题,因为网络会将不同的排列视为不同的输入,而它们应该被视为相同的。
对于前两个问题,有一个简单的解决方案——池化。如果节点矩阵沿列维度进行池化,那么它将被减少到一个排列不变的一维向量。通常,这种池化是简单的均值池化,这意味着最终池化后的向量包含节点矩阵中每一列的均值。然而,这仍然无法解决第三个问题——池化两个异构体的节点矩阵,例如正戊烷和新戊烷,将产生相同的池化向量。
为了使最终的池化向量具有唯一性,我们需要在节点矩阵中加入一些邻居信息。以同分异构体为例,虽然它们的化学式相同,但它们的结构却不同。加入邻居信息的一个简单方法是对每个节点及其邻居进行某种操作,例如求和。这可以表示为节点矩阵与邻接矩阵的乘法(试着在纸上计算:邻接矩阵与节点矩阵的乘积生成一个更新后的节点矩阵,其中每个节点向量等于它自身与邻居节点向量的和)。通常,通过用对角度矩阵的逆进行预乘,对每个节点的度(或邻居数量)进行归一化,从而使这一和值成为邻居的均值。最后,这个乘积会被一个权重矩阵后乘,以使这个操作具有参数化特性。这个完整的操作称为图卷积。图 3 显示了一种直观而简单的图卷积形式。一个数学上更严格且数值上更稳定的形式可以在Thomas Kipf 和 Max Welling 的研究中找到,该研究对邻接矩阵进行了修改的归一化。卷积和池化操作的组合也可以解释为一种非线性的经验群体贡献方法。
图 3:用于乙酰胺分子的图卷积(作者提供的图片)
图卷积网络的最终结构如下——首先,为给定的分子计算节点和邻接矩阵。然后对这些矩阵应用多次图卷积,并进行池化以生成一个包含所有分子信息的单一向量。随后,这个向量通过标准人工神经网络的隐藏层产生输出。隐藏层、池化层和卷积层的权重通过对基于回归的损失函数(如均方误差)应用反向传播同时确定。
3. 代码实现
在讨论了与图卷积网络相关的所有关键概念之后,我们准备开始使用 PyTorch 构建一个网络。虽然存在一个名为 PyTorch Geometric 的灵活且高性能的 GNN 框架,但我们不会使用它,因为我们的目标是深入了解其内部机制并发展我们的理解。
本教程分为四个主要部分——(1)使用 RDKit 自动创建图形,(2)将图形打包成 PyTorch 数据集,(3)构建图卷积网络架构,以及(4)训练网络。完整的代码以及安装和导入所需包的说明可以在文章末尾提供的 GitHub 仓库中找到链接。
3.1. 使用 RDKit 创建图形
RDKit 是一个化学信息学库,允许高通量访问小分子的性质。我们将需要它来完成两个任务——获取分子中每个原子的原子序数以进行节点矩阵的独热编码,并获取邻接矩阵。我们假设分子是通过其 SMILES 字符串提供的(这对于大多数化学信息学数据来说是正确的)。此外,为了确保所有分子的节点和邻接矩阵的大小一致——默认情况下它们的大小不一致,因为它们的大小依赖于分子中的原子数——我们用 0 填充这些矩阵。最后,我们将对上面提出的卷积进行小修改——我们将邻接矩阵中的“1”替换为相应的键长的倒数。这样,网络将获得更多关于分子几何的信息,并且还会根据邻居的键长来加权每个节点周围的卷积。
class Graph:
def __init__(
self, molecule_smiles: str,
node_vec_len: int,
max_atoms: int = None
):
# Store properties
self.smiles = molecule_smiles
self.node_vec_len = node_vec_len
self.max_atoms = max_atoms
# Call helper function to convert SMILES to RDKit mol
self.smiles_to_mol()
# If valid mol is created, generate a graph of the mol
if self.mol is not None:
self.smiles_to_graph()
def smiles_to_mol(self):
# Use MolFromSmiles from RDKit to get molecule object
mol = Chem.MolFromSmiles(self.smiles)
# If a valid mol is not returned, set mol as None and exit
if mol is None:
self.mol = None
return
# Add hydrogens to molecule
self.mol = Chem.AddHs(mol)
def smiles_to_graph(self):
# Get list of atoms in molecule
atoms = self.mol.GetAtoms()
# If max_atoms is not provided, max_atoms is equal to maximum number
# of atoms in this molecule.
if self.max_atoms is None:
n_atoms = len(list(atoms))
else:
n_atoms = self.max_atoms
# Create empty node matrix
node_mat = np.zeros((n_atoms, self.node_vec_len))
# Iterate over atoms and add to node matrix
for atom in atoms:
# Get atom index and atomic number
atom_index = atom.GetIdx()
atom_no = atom.GetAtomicNum()
# Assign to node matrix
node_mat[atom_index, atom_no] = 1
# Get adjacency matrix using RDKit
adj_mat = rdmolops.GetAdjacencyMatrix(self.mol)
self.std_adj_mat = np.copy(adj_mat)
# Get distance matrix using RDKit
dist_mat = molDG.GetMoleculeBoundsMatrix(self.mol)
dist_mat[dist_mat == 0.] = 1
# Get modified adjacency matrix with inverse bond lengths
adj_mat = adj_mat * (1 / dist_mat)
# Pad the adjacency matrix with 0s
dim_add = n_atoms - adj_mat.shape[0]
adj_mat = np.pad(
adj_mat, pad_width=((0, dim_add), (0, dim_add)), mode="constant"
)
# Add an identity matrix to adjacency matrix
# This will make an atom its own neighbor
adj_mat = adj_mat + np.eye(n_atoms)
# Save both matrices
self.node_mat = node_mat
self.adj_mat = adj_mat
3.2. 在 Dataset 中打包图
PyTorch 提供了一个便捷的Dataset类来存储和访问各种数据。我们将使用它来存储每个分子的节点和邻接矩阵及输出。请注意,使用这个Dataset接口来处理数据不是强制性的;不过,使用这个抽象会使后续步骤更加简单。我们需要为继承自Dataset类的GraphData类定义两个主要方法:一个是**len方法来获取数据集的大小,另一个是getitem**方法来获取给定索引的输入和输出。
class GraphData(Dataset):
def __init__(self, dataset_path: str, node_vec_len: int, max_atoms: int):
# Save attributes
self.node_vec_len = node_vec_len
self.max_atoms = max_atoms
# Open dataset file
df = pd.read_csv(dataset_path)
# Create lists
self.indices = df.index.to_list()
self.smiles = df["smiles"].to_list()
self.outputs = df["measured log solubility in mols per litre"].to_list()
def __len__(self):
return len(self.indices)
def __getitem__(self, i: int):
# Get smile
smile = self.smiles[i]
# Create MolGraph object using the Graph abstraction
mol = Graph(smile, self.node_vec_len, self.max_atoms)
# Get node and adjacency matrices
node_mat = torch.Tensor(mol.node_mat)
adj_mat = torch.Tensor(mol.adj_mat)
# Get output
output = torch.Tensor([self.outputs[i]])
return (node_mat, adj_mat), output, smile
由于我们已经定义了自己定制的节点和邻接矩阵、输出以及 SMILES 字符串的返回方式,我们需要定义一个自定义函数来整理数据,即将数据打包成一个批次,然后传递给网络。通过传递数据批次而不是单个数据点,并使用小批量梯度下降来训练神经网络,可以在准确性和计算效率之间取得微妙的平衡。我们将在下面定义的整理函数本质上会收集所有数据对象,将它们按类别分层,堆叠在列表中,转换为 PyTorch 张量,并重新组合这些张量,以便以与我们的GraphData类相同的方式返回它们。
def collate_graph_dataset(dataset: Dataset):
# Create empty lists of node and adjacency matrices, outputs, and smiles
node_mats = []
adj_mats = []
outputs = []
smiles = []
# Iterate over list and assign each component to the correct list
for i in range(len(dataset)):
(node_mat,adj_mat), output, smile = dataset[i]
node_mats.append(node_mat)
adj_mats.append(adj_mat)
outputs.append(output)
smiles.append(smile)
# Create tensors
node_mats_tensor = torch.cat(node_mats, dim=0)
adj_mats_tensor = torch.cat(adj_mats, dim=0)
outputs_tensor = torch.stack(outputs, dim=0)
# Return tensors
return (node_mats_tensor, adj_mats_tensor), outputs_tensor, smiles
3.3. 构建图卷积网络架构
完成数据处理部分的代码后,我们现在转向构建模型本身。为了清晰起见,我们将构建自己的卷积层和池化层,但你们中更高级的开发者可以轻松地用 PyTorch Geometric 模块中更复杂的预定义层替换这些层。ConvolutionLayer本质上做三件事——(1)从邻接矩阵计算逆对角度矩阵,(2)对四个矩阵(D⁻¹ANW)进行乘法运算,以及(3)对层输出应用非线性激活函数。与其他 PyTorch 类一样,我们将从已经定义了forward方法等方法的Module基类继承。
class ConvolutionLayer(nn.Module):
def __init__(self, node_in_len: int, node_out_len: int):
# Call constructor of base class
super().__init__()
# Create linear layer for node matrix
self.conv_linear = nn.Linear(node_in_len, node_out_len)
# Create activation function
self.conv_activation = nn.LeakyReLU()
def forward(self, node_mat, adj_mat):
# Calculate number of neighbors
n_neighbors = adj_mat.sum(dim=-1, keepdims=True)
# Create identity tensor
self.idx_mat = torch.eye(
adj_mat.shape[-2], adj_mat.shape[-1], device=n_neighbors.device
)
# Add new (batch) dimension and expand
idx_mat = self.idx_mat.unsqueeze(0).expand(*adj_mat.shape)
# Get inverse degree matrix
inv_degree_mat = torch.mul(idx_mat, 1 / n_neighbors)
# Perform matrix multiplication: D^(-1)AN
node_fea = torch.bmm(inv_degree_mat, adj_mat)
node_fea = torch.bmm(node_fea, node_mat)
# Perform linear transformation to node features
# (multiplication with W)
node_fea = self.conv_linear(node_fea)
# Apply activation
node_fea = self.conv_activation(node_fea)
return node_fea
接下来,我们构造PoolingLayer。该层只执行一个操作,即沿第二维度(节点数量)计算均值。
class PoolingLayer(nn.Module):
def __init__(self):
# Call constructor of base class
super().__init__()
def forward(self, node_fea):
# Pool the node matrix
pooled_node_fea = node_fea.mean(dim=1)
return pooled_node_fea
最后,我们将定义一个ChemGCN类,包含卷积层、池化层和隐藏层的定义。通常,这个类应该有一个构造函数来定义这些层的结构和顺序,以及一个forward方法,接受输入(在我们的情况下是节点和邻接矩阵)并生成输出。我们将对所有层的输出应用LeakyReLU激活函数。此外,我们还将使用 dropout 来减少过拟合。
class ChemGCN(nn.Module):
def __init__(
self,
node_vec_len: int,
node_fea_len: int,
hidden_fea_len: int,
n_conv: int,
n_hidden: int,
n_outputs: int,
p_dropout: float = 0.0,
):
# Call constructor of base class
super().__init__()
# Define layers
# Initial transformation from node matrix to node features
self.init_transform = nn.Linear(node_vec_len, node_fea_len)
# Convolution layers
self.conv_layers = nn.ModuleList(
[
ConvolutionLayer(
node_in_len=node_fea_len,
node_out_len=node_fea_len,
)
for i in range(n_conv)
]
)
# Pool convolution outputs
self.pooling = PoolingLayer()
pooled_node_fea_len = node_fea_len
# Pooling activation
self.pooling_activation = nn.LeakyReLU()
# From pooled vector to hidden layers
self.pooled_to_hidden = nn.Linear(pooled_node_fea_len, hidden_fea_len)
# Hidden layer
self.hidden_layer = nn.Linear(hidden_fea_len, hidden_fea_len)
# Hidden layer activation function
self.hidden_activation = nn.LeakyReLU()
# Hidden layer dropout
self.dropout = nn.Dropout(p=p_dropout)
# If hidden layers more than 1, add more hidden layers
self.n_hidden = n_hidden
if self.n_hidden > 1:
self.hidden_layers = nn.ModuleList(
[self.hidden_layer for _ in range(n_hidden - 1)]
)
self.hidden_activation_layers = nn.ModuleList(
[self.hidden_activation for _ in range(n_hidden - 1)]
)
self.hidden_dropout_layers = nn.ModuleList(
[self.dropout for _ in range(n_hidden - 1)]
)
# Final layer going to the output
self.hidden_to_output = nn.Linear(hidden_fea_len, n_outputs)
def forward(self, node_mat, adj_mat):
# Perform initial transform on node_mat
node_fea = self.init_transform(node_mat)
# Perform convolutions
for conv in self.conv_layers:
node_fea = conv(node_fea, adj_mat)
# Perform pooling
pooled_node_fea = self.pooling(node_fea)
pooled_node_fea = self.pooling_activation(pooled_node_fea)
# First hidden layer
hidden_node_fea = self.pooled_to_hidden(pooled_node_fea)
hidden_node_fea = self.hidden_activation(hidden_node_fea)
hidden_node_fea = self.dropout(hidden_node_fea)
# Subsequent hidden layers
if self.n_hidden > 1:
for i in range(self.n_hidden - 1):
hidden_node_fea = self.hidden_layersi
hidden_node_fea = self.hidden_activation_layersi
hidden_node_fea = self.hidden_dropout_layersi
# Output
out = self.hidden_to_output(hidden_node_fea)
return out
3.4. 网络训练
我们已经构建了训练模型和进行预测所需的工具。在这一部分,我们将编写辅助函数来训练和测试我们的模型,并编写脚本以运行生成图表、构建网络和训练模型的工作流程。
首先,我们定义一个Standardizer类来标准化我们的输出。神经网络更喜欢处理相对较小且相互之间变化不大的数字。标准化有助于达到这一点。
class Standardizer:
def __init__(self, X):
self.mean = torch.mean(X)
self.std = torch.std(X)
def standardize(self, X):
Z = (X - self.mean) / (self.std)
return Z
def restore(self, Z):
X = self.mean + Z * self.std
return X
def state(self):
return {"mean": self.mean, "std": self.std}
def load(self, state):
self.mean = state["mean"]
self.std = state["std"]
其次,我们定义一个函数来执行每个 epoch 的以下步骤:
-
从数据加载器中解包输入和输出,并将其传输到 GPU(如果可用)。
-
通过网络传递输入并获得预测结果。
-
计算预测值与输出之间的均方误差。
-
执行反向传播并更新网络的权重。
-
对其他批次重复上述步骤。
该函数返回批量平均损失和均值绝对误差,可用于绘制损失曲线。一个类似的没有反向传播的函数用于测试模型。
def train_model(
epoch,
model,
training_dataloader,
optimizer,
loss_fn,
standardizer,
use_GPU,
max_atoms,
node_vec_len,
):
# Create variables to store losses and error
avg_loss = 0
avg_mae = 0
count = 0
# Switch model to train mode
model.train()
# Go over each batch in the dataloader
for i, dataset in enumerate(training_dataloader):
# Unpack data
node_mat = dataset[0][0]
adj_mat = dataset[0][1]
output = dataset[1]
# Reshape inputs
first_dim = int((torch.numel(node_mat)) / (max_atoms * node_vec_len))
node_mat = node_mat.reshape(first_dim, max_atoms, node_vec_len)
adj_mat = adj_mat.reshape(first_dim, max_atoms, max_atoms)
# Standardize output
output_std = standardizer.standardize(output)
# Package inputs and outputs; check if GPU is enabled
if use_GPU:
nn_input = (node_mat.cuda(), adj_mat.cuda())
nn_output = output_std.cuda()
else:
nn_input = (node_mat, adj_mat)
nn_output = output_std
# Compute output from network
nn_prediction = model(*nn_input)
# Calculate loss
loss = loss_fn(nn_output, nn_prediction)
avg_loss += loss
# Calculate MAE
prediction = standardizer.restore(nn_prediction.detach().cpu())
mae = mean_absolute_error(output, prediction)
avg_mae += mae
# Set zero gradients for all tensors
optimizer.zero_grad()
# Do backward prop
loss.backward()
# Update optimizer parameters
optimizer.step()
# Increase count
count += 1
# Calculate avg loss and MAE
avg_loss = avg_loss / count
avg_mae = avg_mae / count
# Print stats
print(
"Epoch: [{0}]\tTraining Loss: [{1:.2f}]\tTraining MAE: [{2:.2f}]"\
.format(
epoch, avg_loss, avg_mae
)
)
# Return loss and MAE
return avg_loss, avg_mae
最后,我们编写整体工作流程。这个脚本将调用我们之前定义的所有内容。
#### Fix seeds
np.random.seed(0)
torch.manual_seed(0)
use_GPU = torch.cuda.is_available()
#### Inputs
max_atoms = 200
node_vec_len = 60
train_size = 0.7
batch_size = 32
hidden_nodes = 60
n_conv_layers = 4
n_hidden_layers = 2
learning_rate = 0.01
n_epochs = 50
#### Start by creating dataset
main_path = Path(__file__).resolve().parent
data_path = main_path / "data" / "solubility_data.csv"
dataset = GraphData(dataset_path=data_path, max_atoms=max_atoms,
node_vec_len=node_vec_len)
#### Split data into training and test sets
# Get train and test sizes
dataset_indices = np.arange(0, len(dataset), 1)
train_size = int(np.round(train_size * len(dataset)))
test_size = len(dataset) - train_size
# Randomly sample train and test indices
train_indices = np.random.choice(dataset_indices, size=train_size,
replace=False)
test_indices = np.array(list(set(dataset_indices) - set(train_indices)))
# Create dataoaders
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=batch_size,
sampler=train_sampler,
collate_fn=collate_graph_dataset)
test_loader = DataLoader(dataset, batch_size=batch_size,
sampler=test_sampler,
collate_fn=collate_graph_dataset)
#### Initialize model, standardizer, optimizer, and loss function
# Model
model = ChemGCN(node_vec_len=node_vec_len, node_fea_len=hidden_nodes,
hidden_fea_len=hidden_nodes, n_conv=n_conv_layers,
n_hidden=n_hidden_layers, n_outputs=1, p_dropout=0.1)
# Transfer to GPU if needed
if use_GPU:
model.cuda()
# Standardizer
outputs = [dataset[i][1] for i in range(len(dataset))]
standardizer = Standardizer(torch.Tensor(outputs))
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Loss function
loss_fn = torch.nn.MSELoss()
#### Train the model
loss = []
mae = []
epoch = []
for i in range(n_epochs):
epoch_loss, epoch_mae = train_model(
i,
model,
train_loader,
optimizer,
loss_fn,
standardizer,
use_GPU,
max_atoms,
node_vec_len,
)
loss.append(epoch_loss)
mae.append(epoch_mae)
epoch.append(i)
#### Test the model
# Call test model function
test_loss, test_mae = test_model(model, test_loader, loss_fn, standardizer,
use_GPU, max_atoms, node_vec_len)
#### Print final results
print(f"Training Loss: {loss[-1]:.2f}")
print(f"Training MAE: {mae[-1]:.2f}")
print(f"Test Loss: {test_loss:.2f}")
print(f"Test MAE: {test_mae:.2f}")
就这样!运行这个脚本应该会输出训练和测试的损失和错误。
4. 结果
具有给定架构和超参数的网络在开源的DeepChem 库上训练,该库包含约 1000 种小分子的水溶性。下图显示了一个特定训练-测试划分的训练损失曲线和测试集的对比图。训练集和测试集上的平均绝对误差分别为 0.59 和 0.58(以 log mol/l 为单位),低于线性模型的 0.69 log mol/l(基于数据集中的预测)。神经网络表现优于线性回归模型并不令人意外;尽管如此,这种粗略的比较使我们确信模型的预测是合理的。此外,我们仅通过在图中包含基本的结构描述符——原子序数和键长——来实现这一点,让卷积和池化函数建立这些描述符之间更复杂的关系,从而得出最准确的分子性质预测。
图 4:测试集的训练损失曲线(左)和对比图(右)(图像由作者提供)
5. 最后的说明
这绝不是解决所选问题的最终模型。改进模型的方式有很多,包括:
-
优化超参数
-
使用早停策略找到具有最低验证损失的模型
-
使用更复杂的卷积和池化函数
-
收集更多数据
尽管如此,本教程的目标是通过一个简单的例子阐述化学领域图卷积网络的基础知识。在掌握了基础知识后,你在 GCN 模型构建之旅中的可能性是无限的。
仓库和有用的参考资料
-
完整的代码(包括创建图形的脚本)提供在GitHub 仓库中。安装所需模块的说明也提供在那里。用于训练模型的数据集来自开源的DeepChem 库,该库在 MIT 许可下(允许商业使用)。仓库中的原始数据集文件名为 delaney_processed.csv。
-
关于图卷积网络的研究文章。本文介绍的卷积函数是本文中给出的函数的简化和更直观的形式。
-
关于消息传递神经网络的研究文章。这些是更为通用和富有表现力的图神经网络。可以证明,图卷积网络是具有特定类型消息函数的消息传递神经网络。
-
关于分子深度学习的在线书籍。这是一个极好的资源,可以帮助你学习化学深度学习的基础知识,并通过动手编码练习应用所学。
如果你有任何问题、评论或建议,请随时通过电子邮件联系我或通过X 联系我。
使用 Streamlit 构建 LAS 文件数据探索应用
原文:
towardsdatascience.com/building-a-las-file-data-explorer-app-with-streamlit-347289e0d000
使用 Python 和 Streamlit 探索 Log ASCII Standard 文件
·发布于 Towards Data Science ·14 min 阅读·2023 年 2 月 3 日
–
照片由 Carlos Muza 提供,Unsplash 上的
LAS 文件是石油和天然气行业中传输和存储井日志和/或岩石物理数据的标准且简单的方式。该格式在 80 年代末和 90 年代初由 加拿大井日志学会 开发,旨在标准化和组织数字日志信息。LAS 文件本质上是结构化的 ASCII 文件,包含多个部分,其中有关于井及其数据的信息;因此,它们可以在典型的文本编辑器中轻松查看,如记事本或 TextEdit。
Streamlit 是我最喜欢的 Python 库之一,用于创建快速且易于使用的仪表板或交互式工具。如果你想创建一个应用程序,使你或最终用户无需担心代码,它也非常棒。因此,在本文中,我们将深入了解如何使用 Streamlit 构建 LAS 文件的数据探索应用。
如果你想查看完整的应用演示,请查看下面的短视频。
或者在 GitHub 上探索源代码:
[## GitHub - andymcdgeo/las_explorer: LAS Explorer 是一个 Streamlit 网络应用,允许你理解…
LAS Explorer 是一个 Streamlit 网络应用,允许你理解 LAS 文件的内容。还包括…
github.com](https://github.com/andymcdgeo/las_explorer?source=post_page-----347289e0d000--------------------------------)
如果你想了解如何在 Python 中处理 LAS 文件,以下文章可能会引起你的兴趣:
安装和设置 Streamlit
我们应用的第一部分将涉及导入所需的库和模块。
这些是:
导入这些库后,我们可以在最后添加一行代码,将页面宽度设置为全页,并更改浏览器窗口中的应用标题。
import streamlit as st
import lasio
import pandas as pd
from io import StringIO
# Plotly imports
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
st.set_page_config(layout="wide", page_title='LAS Explorer v.0.1')
为了检查 Streamlit 是否正常工作,我们可以在终端中运行以下命令:
streamlit run app.py
这将打开一个浏览器窗口,显示一个空白的 Streamlit 应用。
空白的 Streamlit 应用。图片来源:作者。
使用 st.file_uploader 加载 LAS 文件
我们要添加到应用中的第一段代码是调用 st.sidebar
。这将创建一个位于应用左侧的列,我们将用它来存储我们的导航菜单和文件上传小部件。
st.sidebar.write('# LAS Data Explorer')
st.sidebar.write('To begin using the app, load your LAS file using the file upload option below.')
我们可以使用 st.sidebar.write
添加一些消息和说明给最终用户。在这个示例中,我们将保持相对简单,提供应用名称和如何开始的消息。
一旦侧边栏到位,我们可以开始实现文件上传器的代码部分。
las_file=None
uploadedfile = st.sidebar.file_uploader(' ', type=['.las'])
las_file, well_data = load_data(uploadedfile)
if las_file:
st.sidebar.success('File Uploaded Successfully')
st.sidebar.write(f'<b>Well Name</b>: {las_file.well.WELL.value}',
unsafe_allow_html=True)
为此,我们需要调用 st.file_uploader
。我们还将限制文件类型为 .las 文件。为了更实用,我们可能还希望包含大写版本的扩展名。
接下来,我们将调用 load data 函数,稍后我们将详细介绍。该函数将设置为返回 las_file
作为一个 lasio las 文件对象,以及 well_data
作为包含井日志测量数据的数据框。
随后,我们将检查是否有 las 文件。如果设置为 None
,则不会发生任何事情;然而,如果文件通过 load_data
函数成功加载,则它不会是 None
,因此会执行下面的代码。
if 函数中的代码本质上显示了一个彩色标注,后跟 las 文件的井名称。
在运行 Streamlit 应用之前,我们需要创建 load_data
函数。这将允许我们读取数据并生成 lasio las 文件对象和 pandas 数据框。
@st.cache
def load_data(uploaded_file):
if uploaded_file is not None:
try:
bytes_data = uploaded_file.read()
str_io = StringIO(bytes_data.decode('Windows-1252'))
las_file = lasio.read(str_io)
well_data = las_file.df()
well_data['DEPTH'] = well_data.index
except UnicodeDecodeError as e:
st.error(f"error loading log.las: {e}")
else:
las_file = None
well_data = None
return las_file, well_data
当我们运行 Streamlit LAS 数据浏览器应用时,我们将看到左侧的侧边栏以及文件上传小部件。
添加侧边栏到 LAS 文件数据浏览器 Streamlit 应用后。图片由作者提供。
然后我们可以点击浏览文件并搜索一个 las 文件。
一旦该文件被加载,我们将看到绿色的提示,表示文件加载成功,接着是文件中包含的井名。
成功读取 LAS 文件与 LAS 数据浏览器 Streamlit 应用。图片由作者提供。
向 Streamlit 应用中添加主页
当有人第一次启动 LAS 数据浏览器应用时,展示应用的名称和简要描述会很好。
st.title('LAS Data Explorer - Version 0.2.0')
st.write('''LAS Data Explorer is a tool designed using Python and
Streamlit to help you view and gain an understanding of the contents
of a LAS file.''')
st.write('\n')
当我们重新运行应用时,现在将看到我们的主页。这可以扩展以包括额外的说明、有关应用的详细信息以及如果出现问题如何联系。
创建主页后的 LAS 数据浏览器 Streamlit 应用。图片由作者提供。
在构建 Streamlit 应用时,最好将代码拆分成函数,并在适当的时间调用它们。这使得代码更具模块化,更易于导航。
对于我们的主页,我们将上述代码放入一个名为 home()
的函数中。
def home():
st.title('LAS Data Explorer - Version 0.2.0')
st.write('''LAS Data Explorer is a tool designed using Python and
Streamlit to help you view and gain an understanding of the contents
of a LAS file.''')
st.write('\n')
添加导航单选按钮
在构建 Streamlit 应用时,很容易陷入一个不断添加部分的陷阱,结果是生成一个很长的可滚动网页。
使 Streamlit 应用更具可导航性的一种方法是添加导航菜单。这允许你将内容拆分到多个页面上。
实现这一点的一种方法是使用一系列单选按钮,这些按钮在切换时将更改主界面上显示的内容。
首先,我们需要为导航部分指定一个标题,然后我们必须调用 st.sidebar.radio
并传入一个我们希望用户能够导航到的页面列表。
# Sidebar Navigation
st.sidebar.title('Navigation')
options = st.sidebar.radio('Select a page:',
['Home', 'Header Information', 'Data Information',
'Data Visualisation', 'Missing Data Visualisation'])
当我们运行应用时,我们将看到现在有一个由单选按钮表示的导航菜单。
添加了单选按钮导航菜单后的 LAS 数据浏览器。图片由作者提供。
目前,如果点击按钮,什么也不会发生。
我们需要告诉 Streamlit 在进行选择时该做什么。
通过创建如下的 if/elif 语句来实现。当选择了一个选项时,将调用一个特定的函数。
例如,如果用户选择了主页,则会显示先前创建的主页函数。
if options == 'Home':
home()
elif options == 'Header Information':
header.header(las_file)
elif options == 'Data Information':
raw_data(las_file, well_data)
elif options == 'Data Visualisation':
plot(las_file, well_data)
elif options == 'Missing Data Visualisation':
missing(las_file, well_data)
让我们开始实现其他部分,以便开始显示一些内容。
从 LAS 文件中检索井头信息
在每个 las 文件中,顶部有一个包含有关井的信息的部分。这包括井名、国家、操作员等。
Volve 田野的 LAS 文件头示例。图片由作者提供。
为了读取这些信息,我们将创建一个名为header
的新函数,然后遍历头部中的每一行。
为了防止用户点击头信息单选按钮时出现错误,我们需要检查在加载过程中是否已经创建了 las 文件对象。否则,我们将向用户展示错误。
然后,对于每个头项,我们将显示描述名称(item.descr
)、助记符(item.mnemonic
)和相关值(item.value
)。
def header(las_file):
st.title('LAS File Header Info')
if not las_file:
st.warning('No file has been uploaded')
else:
for item in las_file.well:
st.write(f"<b>{item.descr.capitalize()} ({item.mnemonic}):</b> {item.value}",
unsafe_allow_html=True)
当应用程序重新运行,并从导航菜单中选择头信息页面时,我们现在会看到相关的井信息。
来自 LAS 文件的井日志头信息。图片由作者提供。
检索井日志测量信息
在成功读取头信息后,我们接下来要查看 las 文件中包含了哪些井日志测量。
为此,我们将创建一个简单的函数,名为raw_data
,它将:
-
遍历 las 文件中的每个测量,写出它的助记符、单位和描述
-
提供测量总数的统计
-
使用 pandas 的
describe
方法为每个测量创建一个统计摘要表 -
创建一个包含所有原始值的数据表
对于一个单一的函数来说,这个工作量很大,可能需要整理一下,但对于这个简单的应用程序,我们将把它们保持在一起。
def raw_data(las_file, well_data):
st.title('LAS File Data Info')
if not las_file:
st.warning('No file has been uploaded')
else:
st.write('**Curve Information**')
for count, curve in enumerate(las_file.curves):
st.write(f" {curve.mnemonic} ({curve.unit}): {curve.descr}",
unsafe_allow_html=True)
st.write(f"<b>There are a total of: {count+1} curves present within this file</b>",
unsafe_allow_html=True)
st.write('<b>Curve Statistics</b>', unsafe_allow_html=True)
st.write(well_data.describe())
st.write('<b>Raw Data Values</b>', unsafe_allow_html=True)
st.dataframe(data=well_data)
当 Streamlit 应用重新运行时,我们将看到所有与井日志测量相关的信息。
首先,我们有井测量信息和相关统计数据。
LAS 井日志测量信息。图片由作者提供。
然后是原始数据值。
LAS 井日志测量信息。图片由作者提供。
使用 Plotly 在 Streamlit 中可视化井日志数据
与任何数据集一样,仅通过分析原始数字很难掌握数据的外观。为了进一步深入,我们可以使用交互式图表。
这些将使最终用户更容易更好地理解数据。
以下代码在 Streamlit 页面上生成多个图表。所有内容都包含在一个函数中,以便在这个应用程序中使用。请记住,每个函数代表 LAS 数据探索器应用中的一个页面。
为了避免使用多个页面,下面的代码将为三种不同的图生成三个展开器:折线图、直方图和散点图(在岩石物理学中也称为交叉图)。
def plot(las_file, well_data):
st.title('LAS File Visualisation')
if not las_file:
st.warning('No file has been uploaded')
else:
columns = list(well_data.columns)
st.write('Expand one of the following to visualise your well data.')
st.write("""Each plot can be interacted with. To change the scales of a plot/track, click on the left hand or right hand side of the scale and change the value as required.""")
with st.expander('Log Plot'):
curves = st.multiselect('Select Curves To Plot', columns)
if len(curves) <= 1:
st.warning('Please select at least 2 curves.')
else:
curve_index = 1
fig = make_subplots(rows=1, cols= len(curves), subplot_titles=curves, shared_yaxes=True)
for curve in curves:
fig.add_trace(go.Scatter(x=well_data[curve], y=well_data['DEPTH']), row=1, col=curve_index)
curve_index+=1
fig.update_layout(height=1000, showlegend=False, yaxis={'title':'DEPTH','autorange':'reversed'})
fig.layout.template='seaborn'
st.plotly_chart(fig, use_container_width=True)
with st.expander('Histograms'):
col1_h, col2_h = st.columns(2)
col1_h.header('Options')
hist_curve = col1_h.selectbox('Select a Curve', columns)
log_option = col1_h.radio('Select Linear or Logarithmic Scale', ('Linear', 'Logarithmic'))
hist_col = col1_h.color_picker('Select Histogram Colour')
st.write('Color is'+hist_col)
if log_option == 'Linear':
log_bool = False
elif log_option == 'Logarithmic':
log_bool = True
histogram = px.histogram(well_data, x=hist_curve, log_x=log_bool)
histogram.update_traces(marker_color=hist_col)
histogram.layout.template='seaborn'
col2_h.plotly_chart(histogram, use_container_width=True)
with st.expander('Crossplot'):
col1, col2 = st.columns(2)
col1.write('Options')
xplot_x = col1.selectbox('X-Axis', columns)
xplot_y = col1.selectbox('Y-Axis', columns)
xplot_col = col1.selectbox('Colour By', columns)
xplot_x_log = col1.radio('X Axis - Linear or Logarithmic', ('Linear', 'Logarithmic'))
xplot_y_log = col1.radio('Y Axis - Linear or Logarithmic', ('Linear', 'Logarithmic'))
if xplot_x_log == 'Linear':
xplot_x_bool = False
elif xplot_x_log == 'Logarithmic':
xplot_x_bool = True
if xplot_y_log == 'Linear':
xplot_y_bool = False
elif xplot_y_log == 'Logarithmic':
xplot_y_bool = True
col2.write('Crossplot')
xplot = px.scatter(well_data, x=xplot_x, y=xplot_y, color=xplot_col, log_x=xplot_x_bool, log_y=xplot_y_bool)
xplot.layout.template='seaborn'
col2.plotly_chart(xplot, use_container_width=True)
一旦上述代码实现后,我们可以看到 LAS 文件可视化页面,包含三个可展开的框。
在地球科学和岩石物理学中,我们经常在折线图上绘制数据——通常称为日志图。y 轴通常表示井眼深度,而 x 轴表示我们希望可视化的数据。这使我们可以轻松地可视化这些测量数据随深度的趋势和模式。
在日志图部分,我们可以从数据框中选择特定列,并在交互式 Plotly 图表中显示它们。
使用 Plotly 创建的井日志图,并显示在 LAS 数据探索者 Streamlit 应用中。图片由作者提供。
直方图显示数据分布,并允许我们在一个小而简洁的图表中包含大量数据。
在直方图部分,我们有一些基本选项。我们可以从数据框中选择一列进行显示,并决定是否以线性或对数方式显示。
最后,我们可以使用 Streamlit 的颜色选择器。这允许你为直方图选择颜色,可以增强你在演示和报告中的可视化效果。
使用 Plotly 在 LAS 数据探索者 Streamlit 应用中创建的直方图。图片由作者提供。
散点图(交叉图)通常在岩石物理学和数据科学中用于比较两个变量。 从这种图表中,我们可以了解两个变量之间是否存在关系以及这种关系的强度。
在数据可视化页面的交叉图部分,我们可以选择 x 轴和 y 轴变量,以及一个第三变量,用于数据的颜色编码。
最后,我们可以将 x 轴和 y 轴设置为线性刻度或对数刻度。
使用 Plotly 在 LAS 数据探索者 Streamlit 应用中创建的散点图/交叉图。图片由作者提供。
识别井日志测量中的缺失数据
缺失数据是我们在处理数据集时面临的最常见的数据质量问题之一。它可能因多种原因而缺失,从传感器故障到不当和可能粗心的数据管理。
在处理数据集时,识别缺失数据并理解数据缺失的根本原因是至关重要的。对数据缺失原因的正确理解是开发务实解决方案的关键,尤其是许多机器学习算法无法处理缺失值。
在 Python 中,我们可以使用 pandas 的 describe
函数提供的文本数据摘要。虽然这很有用,但在图表中可视化缺失数据值通常更有帮助。这使我们能够轻松识别可能在基于文本的摘要中不明显的模式和关系。
为了创建数据完整性的交互式图表,我们可以利用 Plotly 库。下面的代码设置了 LAS 数据浏览器应用中的缺失数据可视化页面。
首先,我们检查是否有有效的 LAS 文件;如果有,我们开始创建页面并添加一些说明文本。
接下来,我们为用户提供一个选项,以选择数据框中的所有数据或选择特定列。在这旁边,我们允许用户更改图表中条形的颜色。
然后,我们继续根据用户选择绘制数据。
def missing(las_file, well_data):
st.title('LAS File Missing Data')
if not las_file:
st.warning('No file has been uploaded')
else:
st.write("""The following plot can be used to identify the depth range of each of the logging curves.
To zoom in, click and drag on one of the tracks with the left mouse button.
To zoom back out double click on the plot.""")
data_nan = well_data.notnull().astype('int')
# Need to setup an empty list for len check to work
curves = []
columns = list(well_data.columns)
columns.pop(-1) #pop off depth
col1_md, col2_md= st.columns(2)
selection = col1_md.radio('Select all data or custom selection', ('All Data', 'Custom Selection'))
fill_color_md = col2_md.color_picker('Select Fill Colour', '#9D0000')
if selection == 'All Data':
curves = columns
else:
curves = st.multiselect('Select Curves To Plot', columns)
if len(curves) <= 1:
st.warning('Please select at least 2 curves.')
else:
curve_index = 1
fig = make_subplots(rows=1, cols= len(curves), subplot_titles=curves, shared_yaxes=True, horizontal_spacing=0.02)
for curve in curves:
fig.add_trace(go.Scatter(x=data_nan[curve], y=well_data['DEPTH'],
fill='tozerox',line=dict(width=0), fillcolor=fill_color_md), row=1, col=curve_index)
fig.update_xaxes(range=[0, 1], visible=False)
fig.update_xaxes(range=[0, 1], visible=False)
curve_index+=1
fig.update_layout(height=700, showlegend=False, yaxis={'title':'DEPTH','autorange':'reversed'})
# rotate all the subtitles of 90 degrees
for annotation in fig['layout']['annotations']:
annotation['textangle']=-90
fig.layout.template='seaborn'
st.plotly_chart(fig, use_container_width=True)
当我们访问 LAS 数据浏览器的这一页时,我们会看到一个互动的 Plotly 图表,如下所示。如果用户选择了“所有数据”,则所有列都会显示出来。
使用 Streamlit 在 Plotly 图表中显示 pandas 数据框的所有列。图片由作者提供。
如果用户选择了“自定义选择”,则他们可以直接从数据框中选择列。
使用 Streamlit 多选框从数据框中选择列,并在 Plotly 图表中显示它们。图片由作者提供。
如果你想查看使用 Python 识别缺失值的其他方法,请查看下面的文章:
- 使用 missingno Python 库识别和可视化机器学习前的缺失数据
摘要
在本文中,我们展示了如何使用 Streamlit 和 Python 构建一个用于探索 LAS 文件的应用程序。虽然这是一个基础应用,但它可以作为查看原始 LAS 文件的一种有用替代方案。还可以添加更多功能来编辑文件或将其转换为其他标准格式。可能性无穷无尽!
本教程中使用的数据
本教程中使用的数据是 Equinor 于 2018 年发布的 Volve 数据集的一个子集。数据集的完整详细信息,包括许可证,可以在下面的链接中找到。
Equinor 已正式提供了一整套来自北海油田的数据,用于研究、学习等……
Volve 数据许可证基于 CC BY 4.0 许可证。许可证协议的完整详细信息可以在这里找到:
感谢阅读。在离开之前,你应该肯定地订阅我的内容,并将我的文章发送到你的收件箱。 你可以在这里做到这一点!另外,你还可以 注册我的通讯 以便免费获取额外的内容。
其次,你可以通过注册会员,获得完整的 Medium 体验,支持我和其他成千上万的作家。每月只需花费 5 美元,你即可全面访问所有精彩的 Medium 文章,还能有机会通过写作赚钱。
如果你使用 我的链接, 你将直接通过你的费用的一部分支持我,而且不会额外花费你更多。如果你这样做了,非常感谢你的支持。
用 Hamilton 在 13 分钟内构建一个可维护且模块化的 LLM 应用堆栈
LLM 应用是数据流,使用专门设计的工具来表达它们
·
关注 发表在 Towards Data Science ·13 分钟阅读·2023 年 7 月 13 日
–
LLM 堆栈。使用合适的工具,如 Hamilton,可以确保你的堆栈不会变得难以维护和管理。图片来源于 pixabay。
此文章与 Thierry Jean 合作撰写,最初发布于 此处。
在这篇文章中,我们将分享如何使用Hamilton这一开源框架来编写模块化和可维护的代码,以支持你的大型语言模型(LLM)应用程序堆栈。Hamilton 非常适合描述任何类型的数据流,这正是你在构建 LLM 驱动应用程序时所做的。通过 Hamilton,你可以获得强大的软件维护人机工程学,同时还能轻松地交换和评估应用程序组件的不同提供者/实现。免责声明:我是 Hamilton 包的作者之一。
我们将演示的示例将镜像你用于填充向量数据库的典型 LLM 应用程序工作流程。具体来说,我们将涵盖从网络中提取数据、创建文本嵌入(向量)并将其推送到向量存储中。
堆栈概述。作者提供的图像。
LLM 应用程序数据流
首先,让我们描述一下典型的 LLM 数据流的组成。应用程序将接收一个小的数据输入(例如文本、命令),并在更大的上下文中进行操作(例如聊天记录、文档、状态)。这些数据将通过不同的服务(LLM、向量数据库、文档存储等)进行操作、生成新的数据工件,并返回最终结果。大多数用例会在迭代不同输入的过程中重复这一流程多次。
一些常见的操作包括:
-
将文本转换为嵌入
-
存储 / 搜索 / 检索嵌入
-
查找嵌入的最近邻
-
检索用于嵌入的文本
-
确定传递到提示中的上下文
-
使用相关文本中的上下文提示模型
-
将结果发送到其他服务(API、数据库等)
-
…
-
并将它们串联起来!
现在,让我们在生产环境中深入探讨上述内容,假设用户对你的应用程序的输出不满意,并且你想找到问题的根源。你的应用程序记录了提示和结果。你的代码允许你找出操作的顺序。然而,你不知道问题出在哪里,系统产生了不理想的输出……为了解决这个问题,我们认为跟踪数据工件及生成它们的代码是关键,这样你才能快速调试类似的情况。
由于许多操作是非确定性的,这增加了你的 LLM 应用程序数据流的复杂性,这意味着你不能重新运行或逆向工程操作以重现中间结果。例如,即使你拥有相同的输入和配置,生成文本或图像响应的 API 调用可能也是不可重复的(你可以通过如temperature这样的选项缓解部分问题)。这也扩展到某些向量数据库操作,如“查找最近”——其结果取决于数据库中当前存储的对象。在生产环境中,快照数据库状态以使调用可重复几乎是不现实的。
基于这些原因,采用灵活的工具以创建稳健的数据流很重要,这样可以让你:
-
轻松地插入各种组件。
-
了解组件之间如何连接。
-
添加和定制常见的生产需求,如缓存、验证和可观察性。
-
根据你的需求调整流结构,而不需要强大的工程技能
-
插件集成到传统的数据处理和机器学习生态系统中。
在这篇文章中,我们将概述 Hamilton 如何满足第 1、2 和 4 点。有关第 3 和 5 点的信息,请参阅我们的文档。
当前的 LLM 应用程序开发工具
LLM 领域仍处于起步阶段,使用模式和工具正在快速演变。虽然 LLM 框架可以让你入门,但当前的选项并未经过生产环境测试;据我们了解,目前没有成熟的科技公司在生产中使用当前流行的 LLM 框架。
别误解我们的意思,有些工具确实非常适合快速建立概念验证!然而,我们认为它们在两个特定领域存在不足:
1. 如何建模 LLM 应用程序的数据流。 我们强烈认为“动作”的数据流建模更适合用函数来表示,而不是通过面向对象的类和生命周期。函数更容易推理、测试和更改。面向对象的类可能变得相当晦涩,并带来更多的思维负担。
当出现错误时,面向对象的框架需要你深入到对象的源代码中以理解它。而使用 Hamilton 函数时,清晰的依赖关系谱能告诉你在哪里查找,并帮助你推理发生了什么(更多信息请见下文)!
2. 定制/扩展。 不幸的是,一旦你超出框架提供的“简单”功能,你需要强大的软件工程技能来修改当前框架。如果这不是一个选项,这意味着你可能会在特定的自定义业务逻辑上脱离框架,这可能会导致你维护更多的代码面积,而不是如果你一开始就不使用框架的话。
关于这两点的更多信息,我们推荐你查看这些讨论线程(hacker news, reddit),其中有用户详细讨论。
虽然 Hamilton 并不是当前 LLM 框架的完整替代品(例如,没有“代理”组件),但它确实拥有满足 LLM 应用程序需求的所有构建模块,并且两者可以协同工作。如果你想要一种干净、清晰且可定制的方式来编写生产代码、集成多个 LLM 技术栈组件,并对你的应用程序进行观察,那么让我们继续进入接下来的几个部分吧!
使用 Hamilton 构建
Hamilton 是一个声明式微框架,用于在 Python 中描述数据流。它不是一个新框架(已有 3.5 年以上历史),并且在生产建模数据和机器学习数据流中使用多年。它的优势在于以一种直观易创建和维护的方式表达数据和计算流(类似于 DBT 对 SQL 的作用),这非常适合支持建模 LLM 应用程序的数据和计算需求。
Hamilton 范式的示意图。与其使用过程性赋值,不如将其建模为一个函数。函数名称是你可以获得的“输出”,而函数输入参数声明了计算所需的依赖关系。图片由作者提供。
Hamilton 的基础知识很简单,而且可以通过多种方式扩展;你不必了解 Hamilton 就能从这篇文章中获得价值,但如果你感兴趣,可以查看:
-
tryhamilton.dev – 在浏览器中的互动教程!
-
在 5 分钟内使用 Hamilton 进行 Pandas 数据转换
进入我们的示例
为了帮助建立一些心理背景,想象一下。你是一个小型数据团队,负责创建一个 LLM 应用程序,与组织的文档进行“聊天”。你认为评估候选架构在功能、性能配置、许可证、基础设施要求和成本方面是很重要的。最终,你知道你组织的主要关注点是提供最相关的结果和良好的用户体验。评估这些的最佳方法是构建一个原型,测试不同的技术栈,并比较它们的特性和输出。然后,当你过渡到生产环境时,你会希望确保系统能够轻松维护和检查,以始终提供优质的用户体验。
有鉴于此,在这个示例中,我们将实现 LLM 应用程序的一部分,特别是数据摄取步骤,用于索引知识库,其中我们将文本转换为嵌入并存储在向量数据库中。我们使用几种不同的服务/技术以模块化的方式实现这一点。广泛的步骤包括:
-
从 HuggingFace Hub 加载 SQuAD 数据集。你可以将其替换为你的预处理文档的语料库。
-
使用 Cohere API、OpenAI API 或 SentenceTransformer 库 嵌入文本条目。
如果你需要了解更多关于嵌入和搜索的信息,我们推荐以下链接:
在我们讲解这个示例时,考虑以下几点会对你有帮助:
-
将我们展示的内容与当前做的事情进行比较。 看到 Hamilton 如何使你能够策划和结构化一个项目,而无需明确的 LLM 重点框架。
-
项目和应用结构。 了解 Hamilton 如何强制执行一种结构,使你能够构建和维护模块化堆栈。
-
迭代中的信心和项目的持久性。 结合上述两点,Hamilton 使你能够更轻松地维护生产中的 LLM 应用程序,无论它的作者是谁。
让我们从一个可视化开始,以便你能对我们谈论的内容有一个概览:
Hamilton DAG 可视化 Pinecone + 句子变换器堆栈。图片由作者提供。
当使用 pinecone 和句子变换器时,LLM 应用程序的数据流将如下所示。借助 Hamilton,了解事物的连接就像在 Hamilton 驱动程序对象上调用display_all_functions()
一样简单。
模块化代码
让我们解释一下使用 Hamilton 实现模块化代码的两种主要方式,以我们的示例为背景。
@config.when
Hamilton 关注可读性。虽然没有解释@config.when
的作用,你可能已经可以判断这是一个条件语句,且仅在满足条件时包括。下面你将找到使用 OpenAI 和 Cohere API 将文本转换为嵌入的实现。
Hamilton 将识别两个函数作为替代实现,因为@config.when
装饰器和相同的函数名称embeddings
位于双下划线(__cohere
、__openai
)之前。它们的函数签名不必完全相同,这意味着采纳不同实现是简单且清晰的。
embedding_module.py
对于这个项目,将所有嵌入服务实现放在同一个文件中并使用@config.when
装饰器是合理的,因为每个服务只有 3 个函数。然而,随着项目复杂性的增长,函数也可以移动到单独的模块中,并采用下一节的模块化模式。另一个要点是这些函数都是独立可单元测试的。如果你有特定的需求,将其封装到函数中并进行测试是很简单的。
更换 Python 模块
下面你将看到 Pinecone 和 Weaviate 的向量数据库操作实现。请注意,这些代码片段来自pinecone_module.py
和weaviate_module.py
,并观察函数签名的相似之处和不同之处。
pinecone_module.py 和 weaviate_module.py
使用 Hamilton 时,数据流通过函数名称和函数输入参数连接在一起。因此,通过共享类似操作的函数名称,这两个模块可以轻松互换。由于 LanceDB、Pinecone 和 Weaviate 实现分别存在于不同的模块中,这减少了每个文件的依赖数量,使文件更短,从而提高了可读性和可维护性。每个实现的逻辑都清晰地封装在这些命名函数中,因此针对每个模块进行单元测试是直接可行的。分离的模块强化了它们不应同时加载的概念。当发现多个相同名称的函数时,Hamilton 驱动程序实际上会抛出错误,这有助于加强这一概念。
驱动程序的影响
运行 Hamilton 代码的关键部分是Driver
对象,它在run.py
中找到。排除 CLI 和一些参数解析的代码,我们得到:
run.py 的代码片段
Hamilton 驱动程序负责协调执行,并且是你通过它操控数据流的工具,通过上述代码片段中看到的三种机制实现了模块化:
-
Driver 配置。 这是一个字典,驱动程序在实例化时接收该字典,包含应该保持不变的信息,例如使用哪个 API 或嵌入服务 API 密钥。这与可以传递 JSON 或字符串的命令平面(例如 Docker 容器、Airflow、Metaflow 等)很好地集成。具体来说,这里是我们指定要更换哪个嵌入 API 的地方。
-
驱动程序模块。 驱动程序可以接收任意数量的独立 Python 模块来构建数据流。在这里,
vector_db_module
可以替换为我们连接的所需向量数据库实现。还可以通过 importlib 动态导入模块,这在开发与生产环境中可能很有用,同时也能实现通过配置驱动的方式来改变数据流实现。 -
驱动程序执行。
final_vars
参数决定了应该返回什么输出。你无需重构代码来改变想要获得的输出。举个例子,如果你想调试数据流中的某些内容,可以通过将函数的名称添加到final_vars
来请求任何函数的输出。例如,如果你有一些中间输出需要调试,可以很容易地请求它,或者完全在那个点停止执行。请注意,驱动程序在调用execute()
时可以接收inputs
和overrides
值;在上面的代码中,class_name
是一个执行时的input
,指示我们要创建的嵌入对象以及将其存储在向量数据库中的位置。
模块化总结
在 Hamilton 中,使组件可互换的关键是:
-
定义具有相同名称的函数,然后,
-
使用
@config.when
对它们进行注解,并通过传递给驱动程序的配置选择使用哪一个,或者, -
将它们放在不同的 Python 模块中,并将所需的模块传递给驱动程序。
所以我们刚刚展示了如何使用 Hamilton 插件、交换和调用各种 LLM 组件。我们无需解释什么是面向对象的层次结构,也不要求你具备广泛的软件工程经验(我们希望如此!)。为了实现这一点,我们只需匹配函数名称及其输出类型。因此,我们认为这种编写和模块化代码的方式比当前 LLM 框架所允许的更加可访问。
Hamilton 代码的实际应用
为了支持我们的主张,这里有一些我们观察到的将 Hamilton 代码应用于 LLM 工作流的实际影响:
CI/CD
模块/@config.when
的可互换性也意味着在 CI 系统中的集成测试非常容易思考,因为你可以根据需要灵活地交换或隔离数据流的部分。
协作
-
Hamilton 实现的模块化可以轻松跨团队边界镜像。函数名称及其输出类型成为合同,确保可以进行有针对性的更改并对更改充满信心,还可以通过 Hamilton 的 可视化和血统功能 了解下游依赖关系(就像我们看到的初始可视化一样)。例如,如何与向量数据库交互并进行消费就非常清晰。
-
代码更改更易于审查,因为流程由声明式函数定义。更改是自包含的;由于没有面向对象的层次结构需要学习,只需修改一个函数。任何“自定义”的内容都被 Hamilton 默认支持。
调试
当 Hamilton 出现错误时,很清楚它映射到的代码是什么,并且由于函数的定义,你知道它在数据流中的位置。
以使用 cohere 的 embeddings 函数为简单示例。如果发生超时或解析响应时出错,将清楚地映射到这段代码,并且通过函数定义你会知道它在流程中的位置。
@config.when(embedding_service="cohere")
def embeddings__cohere(
embedding_provider: cohere.Client,
text_contents: list[str],
model_name: str = "embed-english-light-v2.0",
) -> list[np.ndarray]:
"""Convert text to vector representations (embeddings) using Cohere Embed API
reference: https://docs.cohere.com/reference/embed
"""
response = embedding_provider.embed(
texts=text_contents,
model=model_name,
truncate="END",
)
return [np.asarray(embedding) for embedding in response.embeddings]
可视化显示embeddings
在数据流中的位置。图像由作者提供。
创建模块化 LLM 堆栈的技巧
在结束之前,这里有一些想法来指导你构建应用程序。某些决策可能没有明显的最佳选择,但正确的模块化方法将使你能够随着需求的变化高效迭代。
-
在编写任何代码之前,绘制你的工作流的 DAG。这为定义通用步骤和接口奠定了基础,这些步骤和接口不是特定于服务的。
-
确定可以交换的步骤。通过有目的地设置配置点,你将减少投机泛化的风险。具体来说,这将导致具有较少参数、默认值且按主题模块分组的函数。
-
将数据流的部分切分成依赖较少的模块(如有相关)。这将导致更短的 Python 文件,减少包依赖,提高可读性和可维护性。Hamilton 对此不在意,可以从多个模块构建其 DAG。
结论与未来方向
感谢你阅读到这里。我们相信 Hamilton 在帮助每个人表达他们的数据流方面有一定作用,而 LLM 应用程序只是其中一个用例!总结我们在这篇文章中的信息,可以归纳为:
-
将 LLM 应用程序视为数据流是有用的,因此非常适合使用 Hamilton。
-
面向对象的 LLM 框架可能不透明且难以扩展和维护以满足生产需求。相反,应该使用 Hamilton 简单的声明式风格编写自己的集成。这样可以提高代码的透明度和可维护性,具有清晰的可测试函数、明确的运行时错误映射到函数的方式,以及内置的数据流可视化。
-
使用 Hamilton 所规定的模块化将使协作更高效,并为你提供必要的灵活性,以便按照该领域的进展速度修改和更改 LLM 工作流。
现在邀请你在这里玩转、尝试和修改完整的示例。这里有一个README
文件会解释如何运行命令和开始使用。否则,我们正在思考以下内容来提升 Hamilton + LLM 应用体验:
-
代理。 我们能否为代理提供与常规 Hamilton 数据流相同的可视性?
-
并行化。 我们如何简化在文档列表上运行数据流的表达方式。请参见这个进行中的 PR了解我们的意思。
-
缓存和可观察性的插件。 目前已经可以在 Hamilton 上实现自定义的缓存和可观察性解决方案。我们正在致力于为常见组件提供更多的标准选项,例如 redis。
-
用户贡献的数据流部分。 我们看到可以在特定 LLM 应用用例上标准化常见名称的可能性。在这种情况下,我们可以开始聚合 Hamilton 数据流,并允许人们根据自己的需求下载。
我们想听听你的意见!
如果你对这些内容感到兴奋,或有强烈的看法,欢迎访问我们的 Slack 频道或在这里留下评论!一些可以帮助你的资源:
📣 加入我们的Slack社区 — 我们很乐意帮助解答你可能遇到的问题或帮助你入门。
⭐️ 在GitHub上给我们点赞
📝 如果你发现了问题,请给我们留下一个issue
你可能感兴趣的其他 Hamilton 文章:
-
tryhamilton.dev – 一个在浏览器中进行交互式教程的平台!
帮助初创公司创始人找到最佳孵化器:一个端到端的项目。
一个自由职业项目的演示,使用 Python、Pinecone、FastAPI、Pydantic 和 Docker 提出最佳孵化器的建议
·发布在 Towards Data Science ·15 min 阅读·2023 年 11 月 26 日
–
Harness,一个致力于帮助创始人创业的初创公司,找到我开发了一个帮助其社区找到最合适孵化器的工具:匹配工具。
在本文中,我们将介绍这个项目的不同阶段,从解决方案设计到交付。
Rames Quinerie 在 Unsplash 上的照片
背景
该公司及其联合创始人希望创建一个工具,使他们的初创公司创始人社区能够找到全球最佳的孵化器和加速器。
为了实现这一目标,他们手动从孵化器网站收集数据,包括位置、各种要求、资金机会等详细信息。此外,他们还利用了一个活跃的创始人社区。
利用孵化器和其社区的数据,他们需要找到一种方法来检索基于初创公司信息的前 k 名孵化器。
挑战接受。
解决方案设计
概述
乍一看,这个项目看起来像是一个推荐系统,比如 Netflix 或 Amazon 用于向用户推荐最佳的系列或产品。通过用户行为,如点击、评论或点赞,公司可以预测并推荐最合适的产品。
然而,在这种特定情况下,我们缺乏关于创始人偏好的任何先前数据。因此,在这种情况下构建推荐系统是不可行的。
另一种方法可以涉及将孵化器和初创企业数据嵌入到向量空间中进行相似性搜索。简而言之,这种方法涉及测量向量之间的距离,以确定最接近给定初创企业的孵化器。
但这种方法在这种情况下有很多缺陷。
孵化器具有我所称的硬标准,这些因素可能导致任何不符合要求的初创企业被立即拒绝。这可能包括如果孵化器要求混合或面对面的出席,位置不在同一城市,或缺乏资金。
那些硬标准会使嵌入(数据的向量表示)在这种情况下不是一个好的方法。例如,一个孵化器可能完全匹配一个初创企业,但如果申请未开放,则不应向创始人推荐这个孵化器。
这些硬标准的存在使得在整个数据集上使用嵌入不适合这种情况。例如,即使一个孵化器与初创企业完美对接,如果当前没有开放申请,也不适合向创始人推荐。
最后,即使大多数特征可以转化为数值(融资金额,接受的前期融资金额,初创企业收入预期)或分类(国家,出席要求,MVP 准备好),某些特征由于其多样性却无法分类:
-
融资工具: 赠款,140k$,股权(SAFE),…
-
行业重点: 医疗科技,人工智能,金融科技,…
此外,这些特征必须在匹配工具中考虑,但可能不会被视为硬标准。例如,创始人可能会选择一个专注于健康科技的孵化器,并且仍然愿意接受一个生物技术初创企业。
混合方法
为了解决这些问题,我们来考虑最佳的两全其美的方案。
如果某些孵化器的硬 标准会导致不匹配,可以考虑根据初创企业的信息筛选这些孵化器。经过缩小潜在匹配的列表后,我们可以使用剩余的软标准进行相似性搜索,将其转化为统一的文本并嵌入到向量中。
好消息是:Pinecone 向其向量数据库提供了这一功能!
[## 向量搜索中的缺失 WHERE 子句 | Pinecone]
向量相似性搜索使得庞大的数据集可以在几分之一秒内进行检索。然而,尽管其卓越的表现和…
www.pinecone.io](https://www.pinecone.io/learn/vector-search-filtering/?source=post_page-----bd65c41175bd--------------------------------)
项目路径现在已经明确:
-
孵化器的数据需要预处理以便过滤硬标准和相似性搜索软标准。然后将数据存储在 Pinecone 向量数据库中。
-
过滤对象必须根据 Pinecone Python 库构建。此外,它还需要保持灵活,以便客户可以轻松修改标准而无需修改算法。
-
软标准需要统一,并转换为嵌入格式,使用适当的嵌入模型。
-
数据是关键,我们需要为启动信息实现数据验证步骤,也需要为upserting新的孵化器数据到向量数据库中进行验证。我们将使用Pydantic。
-
该算法将作为API在docker 容器中提供。我们将使用 FastAPI 并创建一个 Dockerfile,以确保代码在任何环境下都能正常工作。
-
额外说明:单元测试和集成测试将被设置,以便任何人可以以 CI/CD 方式修改代码。
所有这些点都与利益相关者讨论过并被接受了。
我们准备出发了!
数据预处理
我收到了孵化器的解析信息在一个电子表格中。乍一看,数据相当混乱:手动提取没有明确的过程,字符串而不是布尔值,同一特征内的一致性缺乏,……
需要做大量的工作来使数据可用。
相同特征的不同日期“格式”
关于数据集中空值,每个特征都是独立处理的。
例如,出勤要求可能是面对面、混合或远程。在这种情况下,缺少此特征的孵化器被认为是要求面对面出勤。
另一个例子是启动公司的注册:注册或未注册。与其选择这两个类别中的一个,不如添加第三个类别作为默认值:无论如何。这将在过滤阶段有用,不仅选择主要类别之一,还选择所有未明确说明的孵化器。我们将在过滤部分讨论这个问题。
最终,我们将软标准转化为一个单一的提示以嵌入。为此,我们简单地使用了一个提示模板。如果在项目后期需要添加新特性,只需更新该提示即可。
# config.py
class Templates:
embedding_template = """Industries accepted:
{industry_focus}
Funding vehicle:
{funding_vehicle}"""
Templates.embedding_template.format(
industry_focus=industry_focus,
funding_vehicle=funding_vehicle
)
一旦孵化器数据经过预处理,就会导出到Pinecone 向量数据库中。
使用孵化器数据构建向量数据库
Pinecone 提供了一个易于使用的 Python SDK,用于插入、修改和查询向量数据库中的数据。
在我们的案例中,我们需要upsert(插入或更新)一个表示软标准的向量,此外还有硬标准。
根据 Pinecone,数据应遵循以下格式:
# List[(id, vector, metadata)]
[
("A", [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], {"genre": "comedy", "year": 2020}),
("B", [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], {"genre": "documentary", "year": 2019}),
("C", [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], {"genre": "comedy", "year": 2019}),
("D", [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4], {"genre": "drama"}),
("E", [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], {"genre": "drama"})
]
嵌入
有许多模型,无论是开源的还是非开源的,可以将文本嵌入到向量表示中。在这种情况下,我们将使用sentence-bert,一个旨在利用开源嵌入模型的 Python 库。你可以查看我之前的文章,其中描述了它的工作原理:
随着大型语言模型推动的 AI 最新趋势和 ChatGPT(OpenAI)的成功,企业已经…
medium.com](https://medium.com/@jeremyarancio/semantic-search-using-sequence-bert-2116dabecfa3?source=post_page-----bd65c41175bd--------------------------------)
这个库的简洁性使其成为构建第一个匹配工具版本的良好选择。
# pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer
class SentenceTransformersEmbedding:
"""Embedding using the SentenceTransformers library (https://www.sbert.net)"""
def __init__(
self,
model_name: str = "all-MiniLM-L6-v2"
) -> None:
self.model = SentenceTransformer(model_name)
def get_embeddding(self, texts: Union[str, List[str]]) -> List:
# We need to return a list instead of an array for Pinecone
return self.model.encode(texts).tolist()
准备并导出孵化器数据。
要将新的孵化器数据插入到向量数据库中,我们按照 Pinecone 文档中介绍的方式准备数据。
def prepare_from_payload(self, incubators: List[Incubator]) -> List[Tuple[str, List[float], Mapping[str, Any]]]:
"""Prepare payload containing incubators data to export to Pinecone vector database.
Args:
incubators (List[Incubator]): List of Incubator containing the incubator information that will be sent to Pinecone.
Returns:
List[Tuple[str, List[float], Mapping[str, Any]]]: Prepared data for Pinecone. Check official documentation (https://docs.pinecone.io/docs/metadata-filtering#inserting-metadata-into-an-index).
"""
data = []
for incubator in incubators:
metadata = {key: value for key, value in incubator.model_dump(exclude={"incubator_id"}).items()}
additional_information_text = Templates.embedding_template.format(incubator.industry_focus, incubator.funding_vehicle)
embedding = self.embedding_generator.get_embeddding(additional_information_text)
incubator_data = (incubator.incubator_id, embedding, metadata)
data.append(incubator_data)
return data
正如你在代码中看到的,我们使用 Pydantic 的BaseModel
创建了一个Incubators
对象。
from pydantic import BaseModel
from datetime import date
class Incubator(BaseModel):
incubator_id: str
name: str
application_open: int = 1
next_deadline: date = date.max
funding_amount: int = 0 # Maximal amount the incubator can fund
attendance_requirement: Literal["in-person", "remote", "hybrid"] = "in-person"
incorporation: Literal["incorporated", "unincorporated"] = "regardless"
minimum_cofounders: int = 0
minimum_employees: int = 0
previous_funding_accepted: int = 1
...
class Incubators(BaseModel):
incubators: List[Incubator]
这个BaseModel
类有两个主要好处。它不仅确保数据符合我们算法和查询的正确格式,而且还定义了孵化器数据的默认模式。
print(Incubator(
incubator_id="id",
name="incubator_on_fire",
industry_focus="Health tech",
funding_vehicle="Grant"
))
# Output
{
'id': 'id'
'name': 'incubator_on_fire',
'application_open': 1,
'next_deadline': datetime.date(9999, 12, 31),
'funding_amount': 0,
'attendance_requirement': 'in-person',
'incorporation': 'regardless',
'minimum_cofounders': 0,
'minimum_employees': 0,
'woman_founders': 0,
'student_founders': 0,
'industry_focus': 'Health tech',
'funding_vehicle': 'Grant'
...
}
孵化器数据随后使用 Pinecone Python 库导出到向量数据库。为了让其他开发人员能够在应用程序的整体架构中实现这段代码,我们使用了 FastAPI:
import os
from fastapi import FastAPI, HTTPException
from app.models import Incubators
from features import FeatureEngine
from embedding import SentenceTransformersEmbedding
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
ENVIRONMENT = os.getenv("ENVIRONMENT")
app = FastAPI()
@app.post("/upsert")
def upsert(incubators: Incubators):
try:
embedding_generator = SentenceTransformersEmbedding()
feature_engine = FeatureEngine(embedding_generator=embedding_generator)
data = feature_engine.prepare_from_payload(incubators=incubators.incubators)
vectors = [pinecone.Vector(id=id, values=values, metadata=metadata) for id, values, metadata in data]
pinecone.init(api_key=PINECONE_API_KEY, environment=ENVIRONMENT)
index = pinecone.Index(index_name=VectorDatabaseConfig.index_name)
index.upsert(vectors=vectors)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
数据导出后,我们能够开始使用初创公司信息查询向量数据库。
构建匹配算法
该算法在两个步骤中执行 top-k 孵化器的检索:
-
过滤掉不相关的孵化器,
-
使用嵌入向量执行相似性搜索。
我们还需要确保算法足够灵活,以便在项目后期添加或更改任何数据而不触及算法的核心。
那么如何做到这一点呢?
这是我想到的解决方案:
Pinecone 使用与 MongoDB 相同的语言来过滤数据库[source]。它看起来是这样的:
import pinecone
pinecone.init(api_key=PINECONE_API_KEY, environment=ENVIRONMENT)
index = pinecone.Index("example-index")
index.query(
vector=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
filter={
"genre": {"$eq": "documentary"},
"year": 2019
},
top_k=5,
include_metadata=True
)
过滤映射也可以更为复杂:
# $in statement
{
"genre": { "$in": ["comedy", "documentary", "drama"] }
}
# Multi criteria
{
"genre": { "$eq": "drama" },
"year": { "$gte": 2020 }
}
# $or statement
{
"$or": [{ "genre": { "$eq": "drama" } }, { "year": { "$gte": 2020 } }]
}
通过在查询中实现初创公司信息,我们能够检索出符合要求的孵化器:$gte
— 大于, $eq
— 等于, 等等*…*
但有些情况更为复杂。
例如,位置和出席要求是配对使用的。如果一个孵化器只接受混合或面对面,那么初创公司逻辑上应该位于与孵化器相同的城市/国家。但匹配工具也应该展示所有接受远程的孵化器,无论初创公司位于何处。
另一个示例:假设初创公司由 女性创始人 领导,或者初创公司已经构建了 MVP。因此,具有此陈述为真的初创公司应被提议孵化器,该孵化器仅接受女性创始人,或要求 MVP,此外还包括所有其他孵化器。
正如这些示例所示,标准可以分为不同的“模板”称为 Criterion
。这些标准模板将用于构建 filter_object
,这是 Pinecone/MongoDB 使用的过滤映射。
使用 Python 类,它看起来是这样的:
class Criterion(ABC):
"""Incubators criterion template used to build the filter object.
Each subclass of this class is a specific rule case used incubators and start-ups data.
Args:
name (str): incubators metadata name as it is in the vectordatabase.
"""
def __init__(
self,
name: str,
) -> None:
self.name = name
class NormalCriterion(Criterion):
"""Basic rule for creating to filter data based on this criterion.
It takes this form:
```python
criterion.name = {criterion.condition_type: payload[criterion.startup_correspondance]}
```py
With `payload` the start-up information.
Example:
```python
max_funding_amount = {$gte: 10000}
```py
This will filter all incubators with a maximal funding capacity greater than 10000.
Args:
condition_type (str): comparison element like "$eq" (equal), "$lte" (lower than or equal), "$gt" (greater than)
The complete list is available on the pinecone documentaton (https://docs.pinecone.io/docs/metadata-filtering#metadata-query-language).
startup_correspondance (str): start-up correspondance from the payload
"""
def __init__(
self,
name: str,
condition_type: str,
startup_correspondance: str
) -> None:
self.condition_type = condition_type
self.startup_correspondance = startup_correspondance
super().__init__(name=name)
父类对象 Criterion
用于构建多个子类,表示每种情况。如果我们以上面介绍的 女性创始人/MVP 情况为例:
class InclusiveCriterion(Criterion):
"""If condition validated, considers all.
Example:
Being women founders should match women-founders-only incubators, but also the other incubators.
Same for MVP, Ready_to_pay, Students founders, etc...
```
if woman_founders_startup (False) != condition (True):
{"woman_founders_incubator": {"$eq": woman_founders_startup_value (false)}}
参数:
condition_type (str): 比较元素,如 "$eq"(等于),"$lte"(小于或等于),"$gt"(大于)
完整列表可在 pinecone 文档中找到 (https://docs.pinecone.io/docs/metadata-filtering#metadata-query-language)。
startup_correspondance (str): 从 payload 中的初创公司对应(见 matching_tool/app/models.py)
condition (bool): 如果条件得到验证,考虑标准
"""
def __init__(
self,
name: str,
condition_type: str,
startup_correspondance: str,
condition: bool
) -> None:
self.condition_type = condition_type
self.startup_correspondance = startup_correspondance
self.condition = condition
super().__init__(name)
```py
Those `Criterion` classes are used along their respective method to build the `filter_object` :
def normal_case(
payload: Mapping,
criterion: NormalCriterion,
filter_object: Dict
) -> Dict:
"""最简单的情况:取启动值(资金额,之前的资助等)并在 vectordatabase 中按此过滤
condition_type($eq, $lte, $gte, $gt, ...)
参数:
payload (Mapping): 启动信息
criterion (NormalCriterion): 普通标准
filter_object (Dict): 在 vectordatabase 查询期间的元数据过滤器
返回:
Dict:
```pypython
{metadata_name: {condition_type: startup_value}}
```
"""
filter_object[criterion.name] = {
criterion.condition_type: payload[criterion.startup_correspondance]
}
return filter_object
def inclusive_case(
payload: Mapping,
criterion: InclusiveCriterion,
filter_object: Dict
) -> Dict:
"""包容性案例:为包容性案例准备过滤器:女性创始人,学生创始人,MVP,其他费用...
如果条件满足(初创公司中的女性创始人 == 1),因此不要考虑过滤标准 => 获取所有(仅接受女性的孵化器和其他所有孵化器)
否则:只考虑没有女性创始人的孵化器 => {women_founders: {"$eq: 0}}
参数:
payload (Mapping): 启动信息
criterion (NormalCriterion): 普通标准
filter_object (Dict): 在 vectordatabase 查询期间的元数据过滤器
"""
if payload[criterion.startup_correspondance] != criterion.condition:
filter_object[criterion.name] = {criterion.condition_type: payload[criterion.startup_correspondance]}
return filter_object
All these `Criterion` classes are stored inside another class object we call `Criteria` . This class acts as a repository of all the criteria to consider for filtering the database and can be easily modified to add or remove any criterion.
class Criteria:
"""使用 Criterion 模板进行过滤。
使用适当的 Criterion 模板添加或删除任何条件。
"""
country = DependendantCriterion(
name="country",
condition_type="$eq",
startup_correspondance="country"
)
city = DependendantCriterion(
name="city",
condition_type="$eq",
startup_correspondance="city"
)
attendance_requirement = ConditionalCriterion(
name="attendance_requirement",
condition=["remote"],
true_criteria=[],
else_criteria=[country, city]
)
minimum_cofounders = NormalCriterion(
name="minimum_cofounders",
condition_type="$lte",
startup_correspondance="n_cofounders"
)
working_product_requirement = InclusiveCriterion(
name="working_product_requirement",
condition_type="$eq",
startup_correspondance="working_product",
condition=True
)
woman_founders = InclusiveCriterion(
name="woman_founders",
condition_type="$eq",
startup_correspondance="woman_founders",
condition=True
)
…
Once all the criteria are added to the `Criteria` object, we iterate over it and build the `filter_object` based on the start-up information. For each `Criterion` case, we add a filter element to the `filter_object` .
class Matcher:
"从向量数据库中检索与初创公司信息匹配的孵化器。"
def __init__(
self,
index: Index,
criteria: Criteria = Criteria(),
embedder: Embedding = SentenceTransformersEmbedding(),
) -> None:
"""
参数:
index (Index): 向量数据库索引 / 表
criteria (Criteria, optional): 孵化器元数据以进行搜索。默认为 Criteria()。
embedder (Embedding, optional): 嵌入方法,用于将文本转换为向量表示
语义搜索。默认为 SentenceTransformersEmbedding()。
"""
self.index = index
self.criteria = criteria
self.embedder = embedder
def _get_filter(
self,
payload: Dict[str, Any],
) -> Mapping[str, Any]:
"""构建用于在 Pinecone 上过滤元数据的字典。
过滤对象应遵循以下格式。有关更多信息,请查看官方 Pinecone 文档:
https://docs.pinecone.io/docs/metadata-filtering
参数:
payload (Dict[str, Any]): 初创公司信息
返回:
Mapping[str, Any]: 过滤对象
```pybash
filter={
'application_open': 1,
'$or': [{'attendance_requirement': {'$in': ['remote']}}, {'country': {'$eq': 'estonia'}, 'city': {'$eq': 'tallinn'}}],
'funding_amount': {'$gte': 12000},
'other_costs': {'$eq': 0},
'previous_funding_accepted': {'$eq': 1},
'working_product_requirement': {'$eq': 0}
}
```
"""
# 初始过滤器
filter_object = {"application_open": 1}
criteria = self.criteria.get_criteria()
for criterion in criteria:
if isinstance(criterion, NormalCriterion):
if check_correspondance_in_payload(payload, criterion):
filter_object = normal_case(
payload=payload,
criterion=criterion,
filter_object=filter_object,
)
if isinstance(criterion, InclusiveCriterion):
if check_correspondance_in_payload(payload, criterion):
filter_object = inclusive_case(
payload=payload,
criterion=criterion,
filter_object=filter_object,
)
if isinstance(criterion, ConditionalCriterion):
if check_dependencies(payload, conditional_criterion=criterion):
filter_object = conditional_case(
payload=payload,
criterion=criterion,
filter_object=filter_object,
)
if isinstance(criterion, DefaultCriterion):
if check_correspondance_in_payload(payload, criterion):
filter_object = default_case(
payload=payload,
criterion=criterion,
filter_object=filter_object,
)
return filter_object
As you can see in the code, we built four different `Criterion` templates to consider many cases: `NormalCriterion` , `InclusiveCriterion` , `ConditionalCriterion` , and `DefaultCriterion` .
In the future of the project, more categories can be added without changing the algorithm core, making it **customizable**.
Once the `filter_object` is created with the `_get_filter()` method, the vector database can be queried with the Pinecone `index.query()` method:
matches = self.index.query(
vector=embedding,
filter=filter_object,
include_metadata=True,
top_k=top_k
)
The matching tool algorithm is created. We then served it through an API endpoint using FastAPI and Pydantic.
@app.post(“/match”)
def search(payload: StartUp, top_k: int = 5) -> Mapping:
LOGGER.info("开始匹配。")
try:
payload = preprocess_payload(dict(payload))
pinecone.init(api_key=PINECONE_API_KEY, environment=ENVIRONMENT)
index = pinecone.Index(index_name=VectorDatabaseConfig.index_name)
matching_tool = Matcher(index=index)
matches = matching_tool.match(payload=payload, top_k=top_k)
return matches
except Exception as e:
LOGGER.error(f"{str(e)}")
raise HTTPException(status_code=500, detail=str(e))
As `Incubator` built with Pydantic, we created the object `Startup` object to ensure the start-up data comes in the right format:
class StartUp(BaseModel):
country: Optional[str] = None
city: Optional[str] = None
funding_amount: Optional[int] = None
n_cofounders: Optional[int] = None
n_employees: Optional[int] = None
woman_founders: Optional[bool] = None
industry_focus: str = ""
funding_vehicle: str = ""
...
An advantage of using Pydantic with FastAPI is that the API payload (here the start-up information) doesn’t have to be complete. For example, if there is missing information, Pydantic will automatically replace it with its default value, or not consider it at all in the algorithm (defined by the `None` statement).
The core of the API is now set up. We can now make the code ready for shipment using Docker and CI/CD with Pytest.
# Delivering the API
## Integration test with Pytest
During the development of the code, unitests and integration tests were created to ensure no modifications would break the algorithm.
Furthermore, creating the test algorithms not only provides a CI/CD process but also gives my client indications about how the code is supposed to work.
To build an integration test with FastAPI, we used the `TestClient` provided within the library. It uses the `httpx` library instead of `requests` making a call to the API.
The data used as validation of the code is stored in an external JSON file `data/integration_test_data.json`
integration_test.py
pip install httpx
from fastapi.testclient import TestClient
URL = “/match”
client = TestClient(app)
DATA_PATH = Path(os.path.realpath(file)).parent / “data/integration_test_data.json”
with open(DATA_PATH, ‘r’) as data:
DATA = json.load(data)
def test_match():
for test in DATA["match_tests"]:
response = client.post(URL, json=test["payload"])
assert response.status_code == 200
payload: Dict = json.loads(response.content)
match_ids = [match["incubator_id"] for match in payload.values()]
for expected_id in test["expected"]:
assert expected_id in match_ids

Run Pytest on all “test” scripts
Once all tests passed, we created the **Dockerfile** to containerize the code.
## Docker
To create a Docker container, we simply create a Dockerfile within the repository:
FROM python:3.9
WORKDIR /src
ENV PYTHONPATH=/src
COPY requirements.txt requirements.txt
COPY matching_tool/ .
RUN pip install -r requirements.txt
EXPOSE 8001
CMD [“uvicorn”, “app.api:app”, “–host”, “0.0.0.0”, “–port”, “8001”]

Structure of the repository
Here’s what each line does:
* `FROM` import the docker image from the hub with all the basic elements required to run Python 3.9 in this case.
* `WORKDIR` specifies the location of the code within the container
* `ENV PYTHONPATH = /src` specifies which directory Python has to look into to import internal modules.
* `COPY` copies the files in the attributed directory.
* `RUN` is triggered during the Docker image creation, and before the Docker container build. This way, `pip install -r requirements.txt` only runs once.
* `EXPOSE` exposes a container port of our choice, here’s the port 8001\. The API port should match the container port.
* `CMD ['uvicorn”, “app.api.app”, “ — host”, “0.0.0.0”, “ — port 8001]`runs the FastAPI API. It is important here to indicate the host as `0.0.0.0` to enable calls from outside the container.
We then created the Docker image by running in the CLI:
docker build -t matching-tool:latest -f Dockerfile .
Finally, to run the container, one has just to write:
docker run -p 8001:8001 --name matching-tool matching-tool
一旦容器运行,任何人都可以通过端口 8001 调用 API。也可以将 Docker 容器部署到任何云提供商,**使匹配工具立即生效**。
项目已准备好交付。
# 结论
在这篇文章中,我分享了我为一家美国初创公司进行的实际项目。
根据我所提供的数据,以及与利益相关者的多次迭代,我开发了一个工具,帮助初创企业创始人找到最适合他们需求的孵化器。我逐步解释了我所遵循的过程和解决此问题的不同策略。
下一步将是将此算法嵌入到整体应用中,并开始收集用户数据。这将启动任何机器学习功能所需的**飞轮**。确实,从这些代表用户偏好的数据中,将能够构建一个会随时间学习的推荐系统,并为当前和未来的创始人提供最佳输出。
与[Harness](https://www.joinharness.com/)在这个项目中合作非常愉快。我祝愿他们一切顺利。他们知道未来有合作的机会可以随时联系我。
如果你喜欢这篇文章,[**欢迎订阅我的新闻通讯**](https://medium.com/@jeremyarancio/subscribe)**。我分享有关 NLP、MLOps 和创业的内容。**
你可以通过[Linkedin](https://www.linkedin.com/in/jeremy-arancio/)联系我,或者查看我的[Github](https://github.com/JeremyArancio)。
如果你是企业并希望将机器学习应用到你的产品中,你也可以[**预约通话**](https://topmate.io/jeremyarancio/555697)。
再见,祝编码愉快!