TowardsDataScience-2023-博客中文翻译-四十一-
TowardsDataScience 2023 博客中文翻译(四十一)
原文:TowardsDataScience
协议:CC BY-NC-SA 4.0
停止在数据科学项目中硬编码——改用配置文件
原文:
towardsdatascience.com/stop-hard-coding-in-a-data-science-project-use-config-files-instead-479ac8ffc76f
如何在 Python 中高效地与配置文件交互
Khuyen Tran
·发表于 Towards Data Science ·6 分钟阅读·2023 年 5 月 26 日
--
作者提供的图片
最初发布于 https://mathdatasimplified.com 2023 年 5 月 26 日。
问题
在你的数据科学项目中,某些值经常会变化,例如文件名、选定的特征、训练-测试分割比例和模型的超参数。
作者提供的图片
在编写临时代码用于假设测试或演示目的时,硬编码这些值是可以接受的。然而,随着代码库和团队的扩展,避免硬编码变得至关重要,因为它可能导致各种问题:
- 可维护性:如果值在代码库中分散,保持一致地更新它们会变得更加困难。这可能导致在值需要更新时出现错误或不一致。
作者提供的图片
- 可重用性:硬编码值限制了代码在不同场景下的重用性。
作者提供的图片
- 安全问题:将敏感信息如密码或 API 密钥直接硬编码到代码中可能会带来安全风险。如果代码被共享或暴露,可能会导致未经授权的访问或数据泄露。
def connect_to_database():
# Hardcoded sensitive information
host = 'localhost'
username = 'admin'
password = 'secretpassword'
database = 'mydatabase'
# Code to connect to the database
- 测试和调试:硬编码的值可能使测试和调试变得更加困难。如果值被硬编码到代码中,就很难有效地模拟不同的场景或测试边界情况。
def divide(num1: int):
num2 = 3
return num1 / num2
def test_divide():
assert divide(3) == 1.0 # passed
# We can't test num2 = 0 because it is hard-coded
解决方案——配置文件
配置文件通过提供以下好处来解决这些问题:
- 配置与代码分离:配置文件允许你将参数与代码分开存储,这提高了代码的可维护性和可读性。
作者提供的图片
- 灵活性和可修改性:通过配置文件,你可以轻松修改项目配置,而不必修改代码本身。这种灵活性允许快速实验、参数调整,并将项目适应不同的场景或环境。
作者提供的图片
- 版本控制:将配置文件存储在版本控制中可以跟踪配置的更改。这有助于维护项目配置的历史记录,并促进团队成员之间的协作。
作者提供的图片
- 部署:当将数据科学项目部署到生产环境时,配置文件可以轻松自定义生产环境的特定设置,而无需修改代码。这种配置与代码的分离简化了部署过程。
Hydra 简介
有许多 Python 库可以用来创建配置文件,如 pyyaml、configparser、ConfigObj。然而,Hydra这款开源 Python 库脱颖而出,成为我首选的配置管理工具,因为它拥有一套令人印象深刻的功能,包括:
-
便捷的参数访问
-
命令行配置覆盖
-
从多个来源组合配置
-
执行具有不同配置的多个作业
让我们深入探讨这些功能。
随意玩耍并在这里分叉本文的源代码:
[## GitHub - khuyentran1401/hydra-demo
你现在不能执行该操作。你在另一个标签页或窗口中登录了。你在另一个标签页或…
github.com](https://github.com/khuyentran1401/hydra-demo?source=post_page-----479ac8ffc76f--------------------------------)
便捷的参数访问
假设所有的配置文件都存储在conf
文件夹下,所有的 Python 脚本都存储在src
文件夹下。
.
├── conf/
│ └── main.yaml
└── src/
├── __init__.py
├── process.py
└── train_model.py
main.yaml
文件如下所示:
# main.yaml
data:
raw: data/raw/winequality-red.csv
intermediate: data/intermediate
model: models
processs:
cols_to_drop:
- free sulfur dioxide
feature: quality
test_size: 0.2
在 Python 脚本中访问配置文件就像在你的 Python 函数上应用一个装饰器一样简单。
import hydra
from omegaconf import DictConfig
@hydra.main(config_path="../conf", config_name="main", version_base=None)
def process_data(config: DictConfig):
...
要从配置文件中访问特定参数,我们可以使用点符号(例如,config.process.cols_to_drop
),这比使用括号(例如,config['process']['cols_to_drop']
)更简洁直观。
作者提供的图片
这种直接的方法使你可以轻松地检索所需的参数。
命令行配置覆盖
假设你正在尝试不同的test_size
。一遍遍打开配置文件并修改test_size
值是很耗时的。
作者图片
幸运的是,Hydra 使得直接从命令行覆盖配置变得简单。这种灵活性允许快速调整和微调,而无需修改底层配置文件。
python src/process_data.py processs.test_size=0.3
从多个来源组合配置
想象一下你希望尝试各种数据处理方法和模型超参数的组合。虽然你可以每次运行新实验时手动编辑配置文件,但这种方法可能非常耗时。
作者图片
Hydra 通过配置组支持从多个来源组合配置。要创建一个数据处理的配置组,创建一个名为process
的目录来保存每种处理方法的文件:
.
└── conf/
├── process/
│ ├── process1.yaml
│ └── process2.yaml
└── main.yaml
作者图片
如果你想默认使用process1.yaml
文件,请将其添加到 Hydra 的默认列表中。
作者图片
按照相同的程序创建训练超参数的配置组:
.
└── conf/
├── process/
│ ├── process1.yaml
│ └── process2.yaml
├── train/
│ ├── train1.yaml
│ └── train2.yaml
└── main.yaml
作者图片
将train1
设置为默认配置文件:
作者图片
现在运行应用程序将默认使用process1.yaml
文件和model1.yaml
文件中的参数:
$ python src/process.py --help
process:
cols_to_drop:
- free sulfur dioxide
feature: quality
test_size: 0.2
train:
hyperparameters:
svm__kernel:
- rbf
svm__C:
- 0.1
- 1
这种功能特别有用,当需要无缝组合不同的配置文件时。
多次运行
假设你想对多种处理方法进行实验,一个一个地应用每个配置可能是一个耗时的任务。
$ python src/process.py process=process1 # wait for this to finish
$ python src/process.py process=process2 # then run the application with another config
幸运的是,Hydra 允许你同时使用不同的配置运行相同的应用程序。
$ python src/process.py --multirun process=process1,process2
这种方法简化了使用各种参数运行应用程序的过程,最终节省了宝贵的时间和精力。
结论
恭喜!你刚刚了解了使用配置文件的重要性以及如何使用 Hydra 创建配置文件。我希望这篇文章能为你提供创建自己配置文件所需的知识。
我喜欢撰写关于数据科学概念的文章,并玩弄各种数据科学工具。你可以通过以下方式获取我最新的帖子:
-
订阅我在数据科学简化上的新闻通讯。
-
在LinkedIn和Twitter上与我联系。
停止使用 PowerPoint 来做你的机器学习演示,试试这个替代工具
原文:
towardsdatascience.com/stop-using-powerpoint-for-your-ml-presentations-and-try-this-instead-f943c2e9e284
Gradio 是打动技术和非技术利益相关者的可靠方式——为什么更多的数据科学家和机器学习工程师不使用它呢?
Matt Chapman
·发表于 Towards Data Science ·6 分钟阅读·2023 年 7 月 3 日
--
图片来自 Will Porada 在 Unsplash
PowerPoint 演示文稿很糟糕。
至少,差的确实如此。
不好的 PowerPoint 会让观众分心(他们会关闭摄像头并进行多任务处理),而且容易让演讲者养成使用过多技术术语和长时间冗长讲解等不良习惯。
那么为什么数据科学家会如此频繁地使用 PowerPoint 呢?
在最近一个关于这个话题的 Reddit 讨论串 中,从事数据科学的受访者报告称,他们花费了 10% 到 60% 的时间制作幻灯片或进行演讲。我意识到这不是一个非常可靠的统计数据,但不论真实分布如何,这种情绪对于许多从事数据科学工作的人来说是准确的:我们使用 PowerPoint——非常多——来展示从模型卡片到 ROC 曲线和 Shapley 值的截图。
无论你喜欢与否,PowerPoint 是现代机器学习技术栈的重要组成部分,它不会消失。
或者它并非如此?
在这篇文章中,我将向你介绍 Gradio,一个免费的工具,允许你:
-
通过你的浏览器或 Jupyter Notebook 可视化机器学习模型
-
通过互动且易于理解的可视化给你的非技术利益相关者留下深刻印象
-
测试你的模型并识别弱点和特征重要性
我与 Gradio 没有任何关联,也不试图向你推销任何东西——我只是想展示一个在我作为数据科学家的工作中效果良好的工具,特别是对于使用表格数据的模型,如 XGBoost。
介绍 Gradio:一个免费的、互动的方式来展示和测试您的机器学习模型
在开发者自己的话语中,
Gradio 是演示机器学习模型的最快方法,它提供了一个友好的网络界面,使任何人都可以在任何地方使用它!
这怎么运作的?其实非常简单。
你好,世界!
首先,通过 pip 安装 Gradio。
pip install gradio
接下来,导入 Gradio 并定义一个可以接受输入的函数。然后,将您的模型包装在‘gradio.Interface()’类中,——瞧——您的模型就拥有了一个友好的互动界面,可以嵌入到笔记本或网页中。这里是一个使用非常简单的“Hello {user}!”函数的示例:
import gradio as gr
def greet(name):
return "Hello " + name + "!"
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
图片由作者提供
如果您在 Jupyter Notebook 中运行它,上面的演示将自动出现在新单元格中。如果您在脚本中运行,它将出现在您的浏览器中,网址为 http://localhost:7860。如果您愿意,您还可以“自动生成一个公共链接,您可以与同事分享,让他们通过自己的设备远程与您计算机上的模型进行互动”(docs)。
如何将 Gradio 与您的机器学习模型一起使用
举一个稍微复杂一点的例子,假设我们有一个可以识别手绘图像的机器学习模型。使用 Gradio,我们可以创建一个接受用户输入的草图板……
import gradio as gr
def sketch_recognition(img):
pass# Implement your sketch recognition model here...
gr.Interface(fn=sketch_recognition, inputs="sketchpad", outputs="label").launch()
…给我们提供了一种巧妙的方法来绘制草图,将其传递给模型,并实时演示模型:
图片由作者提供
注意:为了保持本文的长度在可管理范围内,我没有包括有关模型本身的详细信息;如果您正在寻找模型,您可能想查看 HuggingFace,这是一个很棒的预训练模型库,可以很容易地加载到 Jupyter 笔记本或 Python 脚本中,并与 Gradio 一起使用。
只需几行代码,Gradio 就能轻松地以任何人都能理解的互动方式展示您的模型。想想您团队中的可能性——这将使您向团队或利益相关者展示模型变得多么容易?
如果您喜欢这个故事,点击我的‘关注’按钮对我意义重大——只有 1%的读者这样做!感谢阅读。
那 XGBoost 呢?
是的,我知道——用一些炫酷的图像识别模型做演示很好,但实际上,XGBoost 通常在实际的数据科学团队中占据主导地位。
好吧,如果这是您在想的,我有一些好消息要告诉您:Gradio 可以与各种模型一起使用,包括像 XGBoost 这样的表格数据模型。
这是一个示例,展示了一个用 Gradio 展示的 XGBoost 模型,该模型用于预测人们的收入:
图片由作者提供。代码可在 此处 获取。
如你所见,Gradio 使得通过调整开关和滑块与模型互动成为可能,并观察这些调整如何影响预测。你甚至可以使用 Shapley 值来查看不同特征在确定模型预测中的重要性!
我如何使用 Gradio
Gradio 永远不会取代传统的模型测试和评估方法(例如分类报告和使用适当的样本外测试数据集),但我发现它在我日常工作中有两个方面很有用:
-
向非技术利益相关者解释模型——Gradio 帮助我回答利益相关者的临时问题,比如“如果我们改变这个变量,模型的预测会发生什么?”而且,你不需要拥有 AI 博士学位来调整开关,因此每个人都可以尝试操作模型,建立自己对模型工作的理解,无论他们是否有技术背景或懂得编码。
-
测试我的模型——Gradio 让你实时查询模型,测试你的假设,并迅速识别模型中的隐藏弱点(例如发现意外行为)。这是比批量测试更快速的替代方案,而且——至关重要的是——它使测试过程民主化。使用 Gradio,你不需要依赖一个孤立的数据科学家来运行笔记本中的批量测试——你可以公开托管模型(例如在 HuggingFace Spaces 上),分享模型的链接,每个团队成员都可以参与并尝试探测模型。
Gradio 是预言中的 PowerPoint 杀手吗?
不——但这并不是重点。
PowerPoint 的多功能性意味着像 Gradio 这样的狭窄工具不能完全替代它。此外,如果 PowerPoint 使用得好,它实际上可以是展示 ML 生命周期部分(如模型卡和业务案例)的极其有效的方法。
但是对于 ML 生命周期的其他部分(如展示特征重要性和模型结果),PowerPoint 并不总是效果很好,你可能会发现 Gradio 有助于解决它的一些弱点,防止观众在演示过程中打瞌睡。
所以,如果你是一个数据科学家或机器学习工程师,觉得 PowerPoint 无聊或效果不好,为什么不试试 Gradio 呢?对我来说,它确实帮助很大。
哦,还有一件事——
我开设了一个名为 AI in Five 的免费通讯,每周分享 5 个要点,涵盖最新的 AI 新闻、编码技巧和数据科学家/分析师的职业故事。没有炒作,没有“数据是新石油”的废话,也没有来自埃隆的推文——只有实用的技巧和见解,帮助你在职业发展中进步。如果这对你有吸引力,点击这里订阅!
[## AI in Five | Matt Chapman | Substack
最新的数据科学和人工智能领域的新闻、职业故事和编码技巧,总结为 5 个要点…
aiinfive.substack.com](https://aiinfive.substack.com/?source=post_page-----f943c2e9e284--------------------------------)
在 TensorFlow 记录文件中存储图像
原文:
towardsdatascience.com/storing-images-in-tensorflow-record-files-166d030269fb
如何使用 TFRecord 文件,这是一种针对 TensorFlow 的高效数据存储和读取的数据格式,在处理图像时
Pascal Janetzky
·发表于 Towards Data Science ·6 分钟阅读·2023 年 3 月 2 日
--
你知道 TensorFlow 有一种自定义格式来存储数据吗?它叫做 TensorFlowRecords——简称 TFRecords——并建立在一个简单的原则上:
将数据按顺序存储(在一个文件中),以便快速访问连续的数据块。
这种方法基于协议缓冲区,这是一种跨平台的结构化数据存储方法。我们不需要深入探讨背景;我们需要知道的是数据以类似字典的映射形式存储:
{"string": value}
一个文件可以包含许多这样的“字典”,在 TensorFlow 中称为Examples,如下图所示:
TensorFlow 记录文件背后的概念概述。图片由作者提供。
在每个Example——或字典——内,单独的数据条目被存储。这种格式非常灵活:你可以存储图像、文本、音频以及任何可以转换为字节表示的数据。此外,数据类型可以混合,这使我们可以保留,例如,图像和边界框以及文本描述。然而,在过早深入之前,我们将专注于一种模态:图像。其余的模态,音频和文本数据,将在未来的帖子中涵盖。
根据我的经验,最好用简单的示例来讲解这种高级主题,以最佳展示底层的工作流程。在这种情况下,我们使用随机的(图像形状)矩阵。
存储图片
创建随机数据
考虑一个包含 1000 张图片的数据集,每张图片的尺寸为 224 x 224,包含三个颜色通道。这个虚拟数据集的每个样本都标记为 0 到 9 中的一个类别。仅使用 numpy 库,我们可以轻松创建这样的数据集:
这段代码的结果是一个充满图像数据的数据集(这里是 numpy 数组)。
辅助函数
在我们拥有一个可用的数据集后,我们必须将其转换为字节数据。
为此,我们创建了四个辅助函数(也见这里)。前三个辅助函数将某些数据类型(如浮点数)转换为 TFRecord 兼容的表示。最后一个辅助函数将数组转换为二进制数据字符串:
创建 TFRecord 数据集
这些函数在我们开始创建 TFRecords 文件时发挥作用。在这里,我们需要一个函数来创建单个Example的布局,即我们要存储的图像的内部表示布局。使用之前的简化视觉表示,这样的Example具有多个包含数据的槽,称为Features:
关于数据如何在Example中存储的概念性概述。图片由作者提供。
对于第一次使用者来说,创建这样的浓缩表示可能会感到不知所措,所以让我们逐一介绍。首先,我们需要存储信息以恢复输入的数据维度。对于我们的图像用例,这些是高度、宽度(224)和通道数(3)。每个数字都是整数,这意味着我们可以将它们存储为整数数据。
其次,我们需要存储图像的字节表示。
第三,我们需要存储标签,标签像数据维度一样以整数数据形式存储。在代码中,这三个要求建模如下:
接下来,我们需要一个函数,它处理包含随机图像及其同样随机标签的数据集,并为存储做好准备。首先,我们打开一个处理将数据写入磁盘的 writer 对象。之后,我们使用一个遍历 numpy 数组的 for 循环,创建图像-标签对,并使用前面描述的方法将它们存储在 TFRecord 文件中。最后,在我们完成遍历数据集后,我们关闭 writer:
就这样!调用这个函数后,我们将拥有一个存储整个数据集的文件!
检索图像
提取字节数据
当我们在之后的时间点想要处理 TFRecords 时,我们需要检索存储的数据。从概念上讲,我们现在是反向存储过程。在这里,我们准备结构,但尚未填充数据。要小心:占位符必须具有相同的名称和适当的数据类型,否则提取将失败。然后,对于 TFRecord 文件中的每个Example,我们提取内容并重新塑造图像:
创建数据集
在编写提取数据的例程后,我们需要一种方法来将其应用于 TFRecord 文件中的每个样本。这一过程,即将数据解析为正确格式,是通过将提取函数映射到每个Example来完成的。在这里,我们依赖于 TensorFlow 的tf.data API,它具有这样的功能:
之后,我们将此函数指向之前创建的 TFRecord 文件(这里是“random_images.tfrecords”),并检索数据。然后,作为一个 sanity check,我们可以比较图像的形状,看看它是否被正确恢复:
注意事项
我们在这篇文章中涵盖的是如何将图像数据放入 TFRecord 文件。这里有两个注意事项,分别是前提假设:
首先,我们从已经加载到内存中的图像(我们的 numpy 数组)开始。第二,在我们的设置中,所有示例的形状都是相同的——这在实际应用中不太可能。
第一点很容易解决:使用许多优秀的库之一来完成。这里的例子包括imageio库或Pillow。对于这些库,存在大量教程,展示了加载数据所需的步骤。
第二点稍微复杂一些。挑战在于 TFRecord 文件的创建,而不是数据加载与批处理的结合。记得我们通过之前的函数存储了原始图像数据及其形状吗?在解析 TFRecord 文件时,这些信息使我们能够恢复图像的适当形状。然而,现在,当将多个示例组合成一个批次时,我们面临着数据维度各异的可能性:图像 1 可能是 224x224 像素,但下一个可能是 124x356 像素。
对于这种情况,我们有一个解决方案:TensorFlow 的padded_batch()方法。为了帮助你入门,这里是之前的数据集创建代码(最初没有使用任何批处理;样本是一个一个返回的),但这次使用了填充批处理:
有趣的部分从第 10 行开始,这一行将数据集中的每个批次填充到由padded_shapes
参数指定的固定形状。元组的第一个元素填充为[256, None, 3],这意味着张量的第一个维度固定为 256,第二个维度填充为适合该批次所有示例的最小支持长度,第三个维度固定为 3。批次元组的第二个元素,即标签,不需要填充,这就是我们写[]
的原因,表示不应应用任何填充。
总结
在这篇文章中,我们介绍了将一种数据模态——图像——存储到 TFRecord 文件中,这是一个用于高效数据存储和读取的 TensorFlow 专用数据格式。在介绍相关工作流程时,我们生成了一组随机的“图像”和同样随机的标签。然后,我们使用这些数据集展示了如何使用三个辅助函数准备数据以供存储。最后,在使用 TensorFlow 原生方法将数据写入磁盘后,我们还编写了相反的过程:从文件中提取数据。从概念上讲,这涉及通过填充占位符字典来逆转存储过程。最后,我们还简要讨论了两个注意事项及其解决方法。
可视化故事讲述——哪个区域的社会经济评分最高,为什么
原文:
towardsdatascience.com/story-telling-with-visualization-which-area-has-the-highest-socio-economic-score-and-why-c1205b2450c7
使用实际地理数据演示
Jin Cui
·发表于Towards Data Science ·阅读时间 8 分钟·2023 年 12 月 26 日
--
图片由Joash Viriah拍摄,来源于Unsplash
背景
有时,为了更有效地分配资源,政府可能会收集个人或家庭的关于人口统计特征的数据,例如年龄、性别和出生国,以及他们的社会经济特征,例如收入、职业和支出。这些数据的一部分会被按地理区域汇总,并向公众提供。
在我居住的澳大利亚,政府通过澳大利亚统计局(ABS)校准一个称为经济资源指数(IER)的指标,该指标利用来自五年一次的人口普查数据的多种变量评分地理区域的相对社会经济状态。
IER 可以通过将澳大利亚划分为不同大小的地理区域的各种数字边界进行汇总。例如,州界(图 1 中的虚线)将澳大利亚划分为 8 个州和领地,而统计区域 1(SA1)边界(图 2 中)则将澳大利亚划分为更细的区域,有时是几条街道的集群。
在检查 ABS 提供的互动地图中的 IER 时,如下图所示,我发现 IER 在不同地区甚至街道级别上差异很大,我思考这可能是由什么因素驱动的。
图像 1 — 按 SA1 划分的 IER,全国视图。澳大利亚统计局,人口与住房普查:地区社会经济指数(SEIFA),澳大利亚,2021 年,ABS 网站,访问日期 2023 年 12 月 24 日。
图像 2 — 按 SA1 划分的 IER,街道视图。澳大利亚统计局,人口与住房普查:地区社会经济指数(SEIFA),澳大利亚,2021 年,ABS 网站,访问日期 2023 年 12 月 24 日
我们能否通过可视化揭示 ABS 如何根据地理区域区分 IER 得分?继续阅读!
数据
感谢 ABS,我们可以很方便地在一个地方获取 SA1 的 IER 得分及其支持变量,这在此网页的数据下载部分,“标准化变量比例数据立方体”1。
为了本文的目的,我将提供创建一系列可视化图表的 Python 代码,这可能有助于读者理解每个支持变量对 SA1 的 IER 得分的贡献。
我将从加载所需的包、读取和检查数据开始。
## Load the required packages
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import polars as pl
import hiplot as hip
## Read and inspect data
data_proportion = pd.read_excel('/content/gdrive/MyDrive/SEIFA/Statistical Area Level 1 (SA1), Standardised Variable Proportions, SEIFA 2021.xlsx',
sheet_name = 'Table 4', header = 5, usecols = 'A:Q')
data_proportion.head()
初看下面的图像,支持变量的列命名规则似乎不够直观。
图像 3 — 示例数据。作者提供的图像。
为了解决这个问题,ABS 提供了用于校准 IER 得分的 14 个变量的数据字典,点击这里查看。总结如下:
-
INC_LOW: 年收入在$1 和$25,999 AUD 之间的家庭中居住的人的比例
-
INC_HIGH: 年收入大于$91,000 AUD 的人的比例
-
UNEMPLOYED_IER: 15 岁及以上失业者的比例
-
HIGHBED: 占用私人物业中有四间或更多卧室的比例
-
HIGHMORTGAGE: 占用私人物业中每月按揭金额大于$2,800 AUD 的比例
-
LOWRENT: 占用私人物业中每周租金低于$250 AUD 的比例
-
OWNING: 没有抵押贷款的私有住房的百分比
-
MORTGAGE: 有抵押贷款的私有住房的百分比
-
GROUP: 被群体占用的私有住房的百分比(例如公寓或单元房)
-
LONE: 孤身占用的私有住房的百分比
-
OVERCROWD: 根据加拿大国家居住标准,要求额外卧室的私有住房占比
-
NOCAR: 没有汽车的私有住房占比
-
ONEPARENT: 单亲家庭的百分比
-
UNINCORP: 至少有一个是企业主的物业的百分比
另外,ABS 出版物中‘Score’列的 IER 评分经过标准化,均值为 1,000,标准差为 100,以显示特定区域的相对社会经济状况。
现在让我们来看看一些视觉效果!
可视化
下面的 Python 代码对数据进行了一些小的转换,以便将数据转换为适合可视化的格式。
## Rename the first column (optional)
data_proportion.rename(columns = {'2021 Statistical Area Level 1 (SA1) 11-Digit Code': 'SA1'}, inplace = True)
## Create a SCORE_HIGH class column with 1 indicating a high score based on standardised mean
data_proportion['SCORE_HIGH'] = np.where(data_proportion['Score'] > 1000, 1, 0)
## Select only the columns needed for visualization
column_select = [
'SCORE_HIGH', 'INC_LOW', 'INC_HIGH',
'UNEMPLOYED_IER', 'HIGHBED', 'HIGHMORTGAGE', 'LOWRENT', 'OWNING',
'MORTGAGE', 'GROUP', 'LONE', 'OVERCROWD', 'NOCAR', 'ONEPARENT',
'UNINCORP'
]
data = data_proportion[column_select]
## Remove rows with missing value (113 out of 60k rows)
data_dropna = data.dropna().reset_index(drop = True)
## Create a Polars dataframe
df = pl.from_pandas(data_dropna)
现在我们有效地拥有一个包含 14 个特征和一个类别变量的数据框,我们正在尝试对其进行可视化。与其在数据科学工作流程中通常对单一特征和目标变量进行低维度单向分析,不如使用下面的 Python 代码,可以交互式地可视化所有 14 个特征与目标变量之间的关系。
df_ml = pl.concat([
df.filter(pl.col("SCORE_HIGH") == 0).sample(5_000),
df.filter(pl.col("SCORE_HIGH") == 1).sample(5_000),
]
) # Note that I'm selecting a sample of 5,000 data points for each class
# as I don't want the visual to look too busy for the purpose of this demonstration
hip.Experiment.from_iterable(df_ml.to_dicts()).display()
如下图所示,通过选择‘SCORE_HIGH’列进行着色:
-
可视化中的每一条线代表一行数据。特别是,橙色线代表 IER 评分相对较高的 SA1 区域,蓝色线代表 IER 评分相对较低的 SA1 区域。
-
通过检查每个单独特征的颜色分布(回顾一下,这些特征是根据数据字典的某些特征的百分比),我们可以轻松判断特定特征与 IER 评分之间的关系。
-
例如,有证据表明,IER 评分与 SA1 区域内高收入人群的百分比(由 INC_HIGH 变量表示)呈正相关,而 IER 评分与 SA1 区域内单亲家庭的百分比(由 ONEPARENT 变量表示)呈负相关。这些观察结果都是直观上合理的。
-
从高层次来看,IER 评分可以通过与收入、物业价值和家庭构成相关的变量进行校准。
图像 4 — 可视化输出。图片由作者提供。
令人惊讶的是,除了少数几个 SA1 区域,大多数变量与 IER 评分的关系相对明确。
增强
由于上图所示每个变量的数值(即百分比)被标准化为均值 0 和标准差 1,如 ABS 方法论所述,因此可能很难确切解释某一特定变量的百分比对高或低 IER 分数的贡献。例如,高收入人群或具有 4 间或更多卧室的房产,在特定地区需要多少百分比才能获得高于平均水平的 IER 分数?
幸运的是,应要求,ABS 能够提供这些变量的非标准化(即原始)值,这提供了更全面的视角,如下所示。
图 5 — 可视化输出,原始百分比。图像由作者提供。
有用的可视化示例
你知道仅仅通过可视化也可以进行预测建模,这可能会超越某些更先进的模型(如神经网络)下的预测吗?
使用与上述类似的高维度可视化技术,目的是“记录”可能对特定类别有预测性的值范围。
例如,在上面的可视化中,为了预测高 IER 分数,我们可能需要的值范围是 UNINCORP >15%,HIGHBED > 50%,INC_LOW < 10% 等,这些可以用来制定一个简单的嵌套 IF ELSE 语句(在任何编程语言中)。
以下视频演示了如何使用上图中的可视化进行交互式操作。
视频 6 — 如何识别高 IER 分数的预测变量
为了进一步实验,我将其应用于 Google 开发的欺诈检测分类问题,该问题用于推广Keras 包(用于构建神经网络模型),并产生了以下比较模型评估结果。
表 7 — 评估指标比较。表格由作者提供。
不仅可视化超越了神经网络,而且通过记录影响“IF-ELSE”规则的特征值,我们能够比一些其他不易解释的模型(如神经网络)更清晰地向预测者提供建议。
结论思考
在本演示中,除了创建有助于揭示各种特征与目标变量之间关系的可视化外,我还希望强调能够解释模型和解释输出的重要性,这可能与仅基于评估指标拟合模型一样重要。
我是可视化的忠实粉丝,我坚信每一个统计数据都有一个故事。如在 这篇文章 中提到的另一个使用案例,我的经验是良好的可视化大大有助于讲故事,从而最终提升了我个人品牌的可信度并与利益相关者建立了信任。因此,不论单独还是作为数据科学管道中的重要环节,可视化都是至关重要的。
参考资料
1 澳大利亚统计局 (2021),地区社会经济指数 (SEIFA),ABS 网站,访问日期 2023 年 12 月 24 日(创意共享许可)
我之前在以下文章中博文过其他可视化技术。如果你喜欢这些,确保关注 Medium 上的作者 Medium!
在 PowerBI 中使用形状图可视化进行互动地理空间可视化
用 Python 创建互动地理空间可视化
当我顺应 AI/ML 浪潮时,我喜欢用全面的语言编写和分享逐步指南和操作教程,并附有现成的代码。如果你想访问我所有的文章(以及其他 Medium 上的从业者/作者的文章),可以使用 这个链接 注册!
使用图表讲故事
原文:
towardsdatascience.com/storytelling-with-charts-23dd41096721
第一部分:显示单一定量变量
Darío Weitz
· 发表在数据科学前沿 ·8 分钟阅读·2023 年 2 月 10 日
--
图片由Derek Story拍摄,Unsplash
每个数据集都包含大量细节。此外,许多数据集只是充满了没有任何分类的数字列表。
例如:
· 阿根廷中部地区过去 20 年 1 月份的平均降雨量。
· 信息系统工程学生的智商测试结果。
· 根据 2022 年普查数据,阿根廷 24 个省份的人口。
· 根据一天中的星期几和小时计算的阿根廷车祸致死人数。
以上是典型的包含相对较少单一定量变量的数据集示例。请记住,定量数据表示数量。同时,记住根据值是测量还是计数,定量变量可以是连续的或离散的。
为什么数据分析师应该对这样一组数字感兴趣?首先,尽管许多科学、商业或管理问题涉及数字变量之间的比较、关系、组成或趋势,但对数据集中每个变量的可视化作为基本的探索性数据分析,对于理解这些变量的变化模式非常重要。其次,如上例所示,探索性图表可以帮助理解生成数据集中存储数字的过程。
如我之前所述:“在探索性图表中可以发现数据集的三个重要特征:1)异常值,数据集中与其他数据非常不同且不符合相同模式的数据。这些异常值可能代表了有价值的信息用于分析。首先,必须验证这些异常值的存在是否由于数据测量错误;2)间隙,一个不包含数据的区间。数据间隙的可视化合理化了对其存在原因的深入分析;3)簇,孤立的数据点组,也可能需要对其在图表中存在的原因进行特定分析。当然,间隙和簇可能代表数据收集方法中的错误。”
本文将展示三种简单的(基于 Python)探索性图表,以便可视化单一定量变量的分布。
点图
点图,也称为点 chart,是最简单的可视化图之一,由在二维 x 轴和 y 轴方案中绘制的数据值点(小圆圈)组成。一个轴显示数据值分组的类别或范围,而另一个轴显示每个不同组的数据点数量。每个小圆圈代表一个值。根据分析师的喜好,点可以垂直或水平堆叠。
图 1:作者使用 Matplotlib 制作的点图。
它们适用于小到中等规模的数据集(10–45 个值),并且非常有助于突出异常值、间隙、簇和偏斜。
点图有两种类型:1)威尔金森点图;2)克利夫兰点图。前者表示连续数据值的分布,而后者是条形图的替代方案。本文专门讨论威尔金森点图。
Matplotlib 没有专门绘制点图的方法。Plotly 使用px.scatter来绘制克利夫兰点图。我发现了一段非常有趣的 Python 代码,来源于帕特里克·菲茨杰拉德(1)在stackoverflow(2)上:
# Create random data
rng = np.random.default_rng(1) # random number generator
data = rng.integers(10, 30, size=45)
values, counts = np.unique(data, return_counts=True)
# Set formatting parameters based on data
data_range = max(values)-min(values)
width = data_range/2 if data_range<30 else 15
height = max(counts)/3 if data_range<50 else max(counts)/4
marker_size = 10 if data_range<50 else np.ceil(30/(data_range//10))
# Create dot plot with appropriate format
fig, ax = plt.subplots(figsize=(width, height))
for value, count in zip(values, counts):
ax.plot([value]*count, list(range(count)), marker='o', color='green',
ms=marker_size, linestyle='')
for spine in ['top', 'right', 'left']:
ax.spines[spine].set_visible(False)
ax.yaxis.set_visible(False)
ax.set_ylim(-1, max(counts))
ax.set_xticks(range(min(values), max(values)+1))
ax.tick_params(axis='x', length=0, pad=10)
ax.set_title('Dot Plot of Random Integer Values')
plt.show()
图 2 是通过代码创建的。很容易注意到在 21 和 24 之间存在一个间隙,该间隙被样本中最高浓度的值所包围。聪明的数据分析师应该探究这种特殊值分布的原因。
图 2:作者使用 Matplotlib 制作的点图。
茎叶图
这种表格图在 1900 年左右非常流行,并在约翰·W·图基在普林斯顿大学的讲座中被重新发现(3)。这种图表之所以得名,是因为每个值被分为叶子和茎。
我们如何手动绘制茎叶图?我们首先确定数据的范围。然后,根据数据集的大小,我们将范围划分为固定长度的区间。接下来,我们绘制一条垂直线,将数字的前几位(千位、百位或十位)除了最后一位 按升序排列在垂直线的左侧。这就是茎。我们再次遍历数据集,将下一个显著的(最后)数字写在垂直线的右侧。这就是叶子。在最左侧,从底部到顶部,我们累计我们要绘制的值的数量。
例如,给定以下列表:[16, 25, 47, 56, 23, 45, 19, 55, 44, 27],对应的茎叶图是(区间长度等于 10):
图 3:作者用 Stemgraphic 制作的茎叶图。
在图的底部,我们标出了值 16 和 19,然后是 23,再上一行放置 25 和 27。接着是 44,再上一行是 45 和 47,最后在两个单独的叶子中是 55 和 56。
Matplotlib 和 Plotly 都没有茎叶图。专门为此目的设计了一个名为Stemgraphic的 Python 模块。它支持任何大小的数据,并可以生成准备打印的图表(4)。
首先,你必须安装它:
pip3 install -U stemgraphic
Stemgraphic 需要 docopt、Matplotlib 和 pandas。可选地,安装 Scipy 会提供次级图表(5)。
假设我们有来自两个信息科学工程课程的 IQ 数据。让我们使用stemgraphic创建一个茎叶图来分析这些数据的分布。
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import Stemgraphic
data_IQ1 = [111, 85, 83, 98, 107, 101, 100, 94, 101, 86,105, 122, 104, 106, 90, 123, 102, 107, 93, 109]
data_IQ2 = [146, 81, 91, 88, 98, 121, 83, 105, 97, 116, 93, 105, 108, 87, 104, 120, 109, 108, 107, 75, 93, 108, 104]
data_IQ = data_IQ1 + data_IQ2
df_IQ = pd.DataFrame(data_IQ, columns = ["IQ"])
stemgraphic.stem_graphic(df_IQ['IQ'])
print('')
print(df_IQ.describe())
print("")
print(df_IQ.agg({
"IQ": ["count", "min", "max", "mean", "median", "skew"],
})
)
图 4:作者用 Stemgraphic 制作的茎叶图。
图 5:作者制作。
我们可以清晰地观察到两个离群值,一个在下端(75),另一个在排名顶部(146)。图形还表明中位数在 104 和 105 之间。数据呈适度偏斜。最左边的列包含累计计数。茎叶图的主要优点是我们可以通过观察图表轻松重建源数据。
直方图
也叫柱状图,它们是数据集分布的图形表示。它们是二维图,具有两个坐标轴:水平轴被划分为箱子(数值范围的区间);垂直轴是频率轴,其值来源于每个箱子的计数。每个箱子的频率通过垂直矩形条的面积显示。
直方图提供了连续定量变量分布的视觉总结。你可以推断数据集的位置、分布、对称性和偏斜程度。你还可以注意到是否存在集群、间隙和离群值。
你可以在我之前的 Medium 文章中找到更多理论概念和几个示例,关于直方图的内容:1) 直方图,为什么以及如何。讲故事、技巧和扩展;2) 使用 Plotly Express 的直方图,主题和模板。
我在与蒙特卡洛模拟相关的几篇文章中使用了 Matplotlib 来绘制直方图。
例如,在 蒙特卡洛模拟。第二部分 中,我使用了以下代码来显示图 6:
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(list_of_costs, histtype ='bar', bins=20, color = 'c',
edgecolor='k', alpha=0.65, density = True) #density=True show probability
ax.axvline(media, color='g', linestyle='dashed', linewidth=3)
ax.axvline(Sum_most_like, color='r', linestyle='dashed', linewidth=3)
ax.text(48,0.185, 'Mean - - -', color = 'green' )
ax.text(48,0.155, 'Most Like - - -', color = 'red' )
ax.set_title('Frequency Chart')
ax.set_ylabel('Probability')
ax.set_xlabel('Total Cost (MM*1000 U$S)')
ax.grid(axis = 'y')
图 6:作者用 Matplotlib 制作的直方图。
在 蒙特卡洛模拟。第三部分 中,我使用了以下代码来显示图 7 中所示的直方图:
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(list_cost_TTR, histtype ='bar', bins=20, color = 'c',
edgecolor='k', alpha=0.65, density = False) # density=False show counts
ax.axvline(media, color='g', linestyle='dashed', linewidth=3)
ax.axvline(median, color='r', linestyle='dashed', linewidth=3)
ax.set_title('Frequency Chart')
ax.set_ylabel('Counts')
ax.set_xlabel('U$S')
ax.grid(axis = 'y')
图 7:作者用 Matplotlib 制作的直方图。
结论
约翰·图基曾表示,探索性数据分析就像是数字侦探工作。他试图解释说,数据中的统计结构是建立模型或理论之前的基础,这些模型或理论可以解释生成存储在数据集中的数字的过程。
每个数据分析师都必须具备某些可视化工具,用于分析和展示数据。特别是,对于单一的定量变量。
在本文中,我描述了三种简单的基于 Python 的探索性图表,用于可视化单一定量变量的分布:点图;茎叶图;直方图。
在三个图形之间选择的主要标准是数据集的大小。建议点图不应显示超过 50 个值。还建议茎叶图不应显示超过 300 个数值。相反,直方图在所表示的数值增加时会得到改善。在前面提到的文章中(与蒙特卡洛模拟相关),这些直方图总结了来自 5000 次复制(5000 个数值)的模拟运行获得的信息。
参考文献
(1) stackoverflow.com/users/14148248/patrick-fitzgerald
(2) stackoverflow.com/questions/49703938/how-to-create-a-dot-plot-in-matplotlib-not-a-scatter-plot
(3) Tukey, John. “Exploratory Data Analysis”, Addison-Wesley, 1977
(4) stemgraphic.org/doc/modules.html
(5) pypi.org/project/stemgraphic/
用图表讲故事
原文:
towardsdatascience.com/storytelling-with-charts-29e233182be6
第四部分(II):你想展示组成吗?
Darío Weitz
·发布于Towards Data Science ·阅读时间 7 分钟·2023 年 7 月 1 日
--
照片由Jonatan Pie在Unsplash上提供
这是第二部分(第四篇文章中的一部分),其目的是指明当传达给受众的消息目的是展示数据的组成时,最佳的数据可视化技术。
为了更好地理解本文内容,强烈建议阅读(或重温)上一篇文章,其中描述了组成的概念及其一些分析元素。
在上一篇文章中,我们指出,以下是最常用于展示组成的六种图表:饼图;堆叠条形图;树图;堆叠面积图;瀑布图;Marimekko 图表。
列表中的前三种已在那篇文章中详细描述。现在,我们将集中讨论列表中的后三种(堆叠面积图;Marimekko 图表;瀑布图)。
堆叠面积图
首先,让我们定义一下什么是面积图:它是一种线图,其中连接数据点的线与水平轴之间的区域填充了特定颜色。
有四种不同类型的面积图:1) 标准面积图;2) 堆叠面积图;3) 百分比堆叠面积图;4) 重叠面积图。只有堆叠面积图(StACs)和百分比堆叠面积图(%StACs)用于显示组成。
在两个堆叠区域图中,多个区域堆叠在一起。 它们展示了随时间变化的数值变量(动态组成),并使用通常是分类的第三个变量来显示组成。
与 StAC 相关,这是一种部分与整体图表,其中每个区域表示相对于类别总量的每个部分的绝对值。与 %StAC 相关,这也是一种部分与整体图表,其中每个区域表示相对于类别总量的每个部分的百分比。不同区域之间没有重叠。 在 StAC 中,纵轴的最终高度与所表示的所有数值之和相关。在 %StAC 中,纵轴的最终高度始终为 100%。
图 1 显示了一个 StAC,表示 2013 年至 2018 年间四个不同地区的 PS4 销售情况。图表右上角显示的图例指示不同颜色区域属于哪个地区。可以看到每个地区(每个区域,每个部分)如何对总销售(整体,总销售额)做出贡献。每个区域的高度代表了每个特定地区的销售绝对值,而最终高度是这些值的总和,表示每年的总销售额。可以看出,StAC 应主要用于传达整体趋势和每个部分对整体的相对贡献,而不必关注显示每个部分的精确数值。
图 1:堆叠区域图。图表由作者使用 Plotly Express 制作。
图 2 是一个 %StAC,表示相同的 PS4 销售数据。每个区域表示每个地区相对于全球 PS4 销售总量的百分比。如上所述,最终高度为 100%。毫无疑问,这种图表比图 1 中的图表更好地分析了全球销售的组成。
图 2:百分比堆叠区域图。图表由作者使用 Plotly Express 制作。
最终警告:StAC 和 %StAC 相对难以读取和理解,因为它们依赖于观众通过比较堆叠区域解码数据信息的能力。我们建议仅使用它们来传达整体趋势和每个部分对整体的相对贡献。
Marimekko 图表
它们是一种特殊的可变宽度条形图。 Marimekko 图表(MCs)类似于 100% 堆叠条形图,但不同之处在于它们的矩形条可以有不同的宽度。
MC 用于显示数据集中每个类别的两个数值变量。它们有两个轴:垂直轴有 100% 数值刻度,而水平轴可以是分类的或数值的。矩形条形图以垂直方向排列,中间没有空隙。水平轴的整个宽度都被占据。
图 3 显示了一个 Marimekko 图表。该图表显示了每个品牌和地区的年度收入。百分比垂直轴表示每个地区的百分比,而水平轴表示每个品牌的年度收入。我们在一个图表中指示了每个类别和子类别的两个数值。
图 3:一个 Marimekko 图表。由 Vizzlo 制作并获得许可 (#1)。
如我之前所述的我之前提到的: “Marimekko 图表的特点包括:一个矩形区域被划分为不同宽度的小矩形;垂直堆叠的矩形;占据整个图表宽度的水平轴;带有百分比刻度的垂直轴;顶部基线上的每个品牌的总收入;不同的条形宽度可以计算每个品牌对总收入的相对贡献”。
Marimekko 图表可以用作 100% 堆叠条形图的替代品,但仅用于静态分析(展示某一时刻的组成)。它们绝不应用于展示随时间变化的组成。
与堆叠面积图相同的警告:MCs 难以解释,因为人类不擅长计算面积,特别是当矩形数量增加时。
瀑布图
瀑布图(WCs)是一种特殊类型的条形图,表示数据在添加和减去之间的累积效应。其信息是讲述两个数据点之间的组成变化。
WC 由一个初始垂直条形图、一组中间垂直条形图和一个最终垂直条形图组成。通常(且建议的)布局是初始和最终的垂直条形图(列)具有相同的颜色,而中间条形图(浮动条形图)显示添加的绿色值和减少的红色值。第一列和最后一列通常从零基线开始。
图 4 显示了一个基于类别的瀑布图,具有上述特征。这种类型的 WC 通常用于人力资源(显示某个部门的招聘和离职情况)、特定业务(显示收入和支出)、仓库(库存增加、库存减少)以及许多其他数据在正负值之间波动的情况。基于时间的 WC 用于金融行业(显示在单个时间段内的收益和损失)。
图 4:作者使用 Plotly 制作的基于类别的瀑布图。
瀑布图提供比标准条形图更多的上下文信息。后者仅显示初始值和最终值,而前者则指示加法和减法元素对总量的贡献,以及这些初始值和最终值之间变化的组成。
讲述初始值与最终值之间变化的这一显著能力,其复杂性在于正确解读这些变化的幅度。这是因为浮动列中缺乏共同的基准,使得比较连续的加法和减法的具体大小变得困难。最佳做法是在列中添加数字注释,并通过连接的水平线将它们连接起来(图 4 和图 5)。
图 5 显示了一个基于时间的瀑布图,展示了一个虚拟网页每月访客数量变化的故事。任何其他视觉表现形式对普通观众理解这种特定情况会更复杂。
图 5:作者使用 Plotly 制作的基于时间的瀑布图。
结论
在任何数据可视化项目中,一个关键问题是:“我选择了正确的图表来讲述我的故事吗?”
选择最合适的图表取决于要传达给观众的消息的性质。
当传达的消息是组成时,使用六种不同类型的图表:饼图;堆积条形图;树图;堆积面积图;瀑布图;马里梅科图。
我们的建议是,对于静态组成使用饼图,对于动态组成使用堆积条形图。当整体由十个或数千个部分组成时,树图是一个有效的替代方案。马里梅科图适用于表示两个数值变量,包括一个主要类别及其子类别。最后,瀑布图仅显示初始值和最终值之间变化的组成。
如果您觉得这篇文章有趣,请阅读我之前的 56 篇文章中的任何一篇:medium.com/@dar.wtz
。关于数据可视化、模拟、蒙特卡罗技术、仪表盘等的浏览量超过 30 万次。
1: vizzlo.com/
用图表讲故事
原文:
towardsdatascience.com/storytelling-with-charts-c59c52c49871
第三部分:你想要比较项目吗?
Darío Weitz
·发布于 Towards Data Science ·阅读时间 7 分钟·2023 年 5 月 8 日
--
图片由 charlesdeluvio 提供,来源于 Unsplash
这是系列中的第三篇文章,旨在帮助从事数据可视化活动的人选择最合适的图表,以便向他们的观众展示他们试图传达的信息类型。
在系列的第一篇文章中,指出了三种基于 Python 的图表,这些图表允许展示单个定量变量的分布。
在信息包括显示一组特定数字的大小时,最合适的图表在系列的第二篇文章中已有指示。
现在信息是比较项目。以下是常用的图表,用于呈现这种图形表示:
-
· 标准条形图
-
· 聚类条形图
-
· 重叠条形图
-
· 棒棒糖图
-
· 哑铃图
-
· 分歧条形图
标准条形图
标准条形图(SBCs)每个项目或类别只比较一个数字变量。它们试图回答这样一个问题:“每个类别有多少?” 请记住,类别或项目指的是诸如国家、城市、姓氏、公司、品牌、年份、月份、日期、数据科学方法等定性元素。
图 1:垂直和水平 SBC 的示意图。由作者使用 Matplotlib 制作。
有垂直条形图(柱状图)和水平条形图。柱状图通过垂直矩形条形的高度比较项目。水平条形图通过水平矩形条形的长度比较数量。矩形的最终值(每个条形的长度或高度)与打算比较的数值成正比。每个条形代表一个单一项目,条形之间通常留有一定的空间。
观众对这些类型的图表很熟悉,因此可以专注于信息而无需浪费时间理解图示。更多细节可以在我的上一篇文章中找到。
聚类条形图
聚类条形图(CBCs)展示了主要项目或类别与其子组之间存在的相对比例的数据信息,这些子组属于第二个分类变量。一个变量被称为分类变量,如果其观察值可以分配到不重叠的类别中。通常,它们只能取有限(通常是固定)的值。
与 SBCs 类似,它们可以水平或垂直排列。每个主要项目被分为一个条形图簇,表示第二个分类变量的子组。每个子类别的数量由一些相邻的矩形条形的高度或长度表示,这些条形排列在一起形成簇,簇之间的间隙略宽于单个标准条形。
CBCs 用于讲述比较和比例,但重点在于组成(整体分析的一部分)。因此,当整体被划分为多个部分时,CBCs 特别有效。它们可以进行跨子组的比较(堆积条形图进行子组内的比较)。
图 2:垂直和水平 CBCs 的示意布局。由作者使用 Matplotlib 制作。
下图显示了一个虚拟公司在 2016–2019 年期间的销售、支出和利润信息。这是一个垂直排列的 CBC,以年份作为主要类别。销售、支出和利润每年以条形簇的形式表示。图表显示,尽管 2018 年支出显著增加,但利润仅减少了少量。
图 3:虚拟公司在 2016–2019 年期间的经济表现。该图由作者使用 Matplotlib 开发。
观众也对这些类型的图表很熟悉,因此可以专注于信息而无需浪费时间理解图示。更多细节可以在我的上一篇文章中找到。
重叠条形图
当我们希望在单一图表中比较每个项目或类别的两个数值变量时,可以使用重叠条形图(OVCs)。当然,这两个数值变量必须具有足够的相关性以证明比较的合理性。
概念上,通过重叠的方式对比两个变量的数值。这种重叠允许我们以更强的表达力讲述故事。
与 SBCs 和 CBCs 类似,经典布局包括两个轴和可以水平或垂直定向的矩形条。一个轴显示类别,另一个轴显示与待比较变量相关的数值。当然,这两个数值变量必须共享相同的数值尺度。每个数值变量的条宽不同,较小的条在前以便于阅读,尽管这种情况并不总是适用于所有图表。
图 4:作者使用 Matplotlib 开发的重叠条形图。
可以用部分重叠的条形图来比较两个以上的数值变量。在这种类型的图表中,表示不同数值变量的条形图会被前面位置的其他矩形部分遮挡。从概念上讲,部分重叠的条形图类似于 CBC,只不过表示不同子组的矩形开始重叠而不是并排放置。在这方面,必须非常小心,以免混淆观众。因此,建议仅用于比较最多三个不同的数值变量。
更多细节可以在我的上一篇文章中找到。
棒棒糖图
从概念上讲,与 SBCs 类似,棒棒糖图(LCs)用于对不同项目或类别之间进行比较。它们只比较每个项目的一个数值变量。不同之处在于 LCs 用一条末端带点的线代替了矩形条。该点的位置表示相应的数值。因此,这些线末端的点的位置等同于垂直或水平标准条的高度或长度。
图 5:标准条形图与棒棒糖图的比较。由作者使用 Matplotlib 制作。
经典布局包括两个轴和可以垂直或水平定向的非常细的线条。一个轴表示类别,另一个轴具有与待比较项目相关的数值尺度(最好带单位)。
当需要显示大量相似值时,建议使用 LCs 作为 SBCs 的替代方案。这样,我们可以避免显示杂乱的图表,也可以防止观众经历一种令人烦恼的光学效应,称为摩尔纹。
图 6:作者使用 Plotly 制作的棒棒糖图。
哑铃图
虽然哑铃图与棒棒糖图相似,但它们的主要目标是表示两个数据点之间的变化。在这方面,它们通常用于对比两个类别。
图 7:作者使用 Plotly 制作的哑铃图。
我们建议将它们用于比较范围、分布、变化和两个数值变量之间的差异或点之间的距离。
发散条形图
发散条形图(DBC)基本上由两个水平矩形(条形)组成,它们对齐,使得一个矩形从右到左延伸,另一个从左到右延伸,并且两个矩形都从一个共同的垂直基线开始,通常位于图表的中心。如前所述,每个矩形(条形)的长度与其显示的数值成比例。每个条形表示一个项目或分类变量,并且它们之间必须留有一些空间。对于 DBC 的最佳编码是当需要比较两个数值选项时。
蝴蝶图是一种特殊的 DBC。通常,它们在条形之间留有一些空间,以放置被比较的变量名称。
图 8:蝴蝶图。来源:#1。
另一组图表也可以用来比较项目:雷达图、Mekko 图和 Marimekko 图。问题是它们很难解释,因为它们依赖于受众通过比较角度或面积解码数字信息的能力。因此,总是更倾向于选择之前提到的图表。更多细节请参见我的上一篇文章。
数据可视化是讲述数据背后故事的最强有力工具。但如果我们没有正确选择最适合我们想要展示的信息的可视化技术,我们的受众可能会感到困惑。
这系列文章的目的就是指出哪些图表和图示最适合特定类型的信息。
在本文中,我们指出了六种不同的图表[标准条形图、簇状条形图、重叠条形图、棒棒糖图、哑铃图、发散条形图],旨在对比项目。我们指出了它们之间的相似性和差异。
数据科学家和数据分析师接受与受众所需信息类型相关的培训是至关重要的。与上述内容相关的是选择最适合讲述我们数字背后故事的图表。
敬请关注即将发布的文章。
#1:https://www.slideteam.net/butterfly-chart-tornado-chart-for-price-comparison-powerpoint-slide.html
图表讲故事
原文:
towardsdatascience.com/storytelling-with-charts-dae59034f60
第二部分:你想展示数量吗?
达里奥·维茨
·发表于 Towards Data Science ·8 分钟阅读·2023 年 4 月 3 日
--
图片由 Claudio Schwarz 提供,来源于 Unsplash
数据可视化的第一个规则是:“数据可视化是一种沟通工具”。它是讲述商业、商业、科学、学术和创业环境中的故事的最佳工具。
但请记住,你总是有一个观众。也许这个观众包括你的老板、同事、客户、员工、公务员,或者你自己。
在任何情况下,关键问题是:我是否清晰地传达了信息?
下一个问题是:我是否选择了正确的图表来讲述我的故事?
在 本系列的第一篇文章 中,我们介绍了三种基于 Python 的探索性图表,允许可视化单一定量变量的分布。在本文(以及接下来的几篇文章中),我们将建议根据要传达给观众的信息的性质选择最合适的图表。
信息 1:展示数量
这个信息包括展示某一组数字的大小:销售额、销售的项目数量、缴纳的税款、制造的物品、生产的项目、完成的程度、人口数据、性能数据、商业智能仪表板中的关键绩效指标(KPIs)等。
当要展示的数据只有一项(一个数值度量)时,通常会附带其他指标来添加额外的信息并增强信息的传达效果。
常用的数字图形表示图表如下:
-
数值指标
-
角度指标
-
子弹图
-
图示符号
-
柱状图
-
横向条形图
数值指标
它们是定制的小部件,用于显示一个或两个数值,通常伴有标题和副标题、一些注释以及趋势指示器。它们经常用于在商业智能仪表盘中显示关键绩效指标(KPI)。
图 1:作者使用 Plotly 制作。
Plotly包括一个名为Indicator的图形,具有几种模式。对于上述图中的数字指示器,Plotly trace是go.Indicator(),对应的参数有:mode = “number+delta” 用于可视化数值量以及该数值与参考值之间的差异;value 用于指示要显示的数值;domain 用于指示图形的位置;title 用于数值上方的文本。delta 有几个属性:reference 是要比较的值;relative = False 表示数量与参考值之间的差异必须以绝对方式计算;position 确定相对于数值的位置;valueformat 编码差异的数值格式。mode = “delta” 仅显示数量与参考值之间的差异。
import plotly.graph_objects as go
# defining the path where the charts will be stored.
path = your_path
fig = go.Figure()
fig.add_trace(go.Indicator(
mode = "number+delta",
value = 300,
delta = {'reference': 380, 'relative': False,
"valueformat": "0.0f"},
title = {'text': "Total Revenue (MU$S)", 'font': {'size': 32}},
domain = {'x': [0.25, 0.75], 'y': [0.5, 1.0]}
))
fig.add_trace(go.Indicator(
mode = "delta",
value = 450,
delta = {'reference': 350, 'relative': False,
'position' : "bottom", "valueformat": "0.0f"},
title = {'text': "Total Quantities (units)", 'font': {'size': 32}},
domain = {'x': [0.25, 0.75], 'y': [0.0, 0.4]}
))
fig.update_layout(paper_bgcolor = "lightblue")
fig.write_image(path + 'figIndic1.png')
fig.show()
角度指示器
最著名的角度指示器是标准仪表图。这是一种非常简单的图表,它提供了有关单一数值测量的准确信息,允许将显示的值与目标值进行比较,并与由不同颜色带表示的数值范围进行对比。目标值可以是基准、需要超越的前一个值,或是需要达到的目标。
标准仪表图类似于汽车的速度计。它有一个径向数值刻度,分为几个部分,每个部分由特定的颜色标识。它们还显示了下限和上限值,特别是目标值。当前值也用指针或指示针标示,或者用弯曲条在径向数值刻度上移动。最常见的配置有三个部分及其对应的颜色 [红色、黄色、绿色]。通常红色表示差、警告或低性能;黄色表示一般、警示或正常性能;绿色表示好、满意或高性能。红色还意味着数值测量低于下限,黄色表示数值测量在阈值之间,绿色则表示高于上限。
以下图示表明当前销售值(45)处于平均范围内,但比预定目标低 25%。
图 2:作者使用 Plotly 制作的标准仪表图。目标为 60。阈值为 25 和 50。
图 2 是用以下 Python 代码制作的。更多细节可以在我的上一篇文章中找到。
import plotly.graph_objects as go
path = your_path
val = 45
title = 'Sales (MU$S)'
fig1 = go.Figure(go.Indicator(
domain = {'x': [0, 1], 'y': [0, 1]},
title = {'text': title, 'font_size' : 40},
value = val,
number= {'font_size' : 50},
mode = "gauge + number",
gauge = { 'shape' : 'angular',
'steps' : [
{'range': [0, 25], 'color': "red"},
{'range': [25, 50], 'color': "yellow"},
{'range': [50, 100],'color': "lightgreen"}
],
'bar' : {'color' : "black", 'thickness': 0.5},
'threshold' : {'line': {'color': "orange", 'width':8},
'thickness': 0.8, 'value': 60},
'axis': {'range': [None, 100], 'tickformat' : '$'},
}
))
fig1.add_annotation(x=0.5, y=0.4, text="-25%",
font=dict(size = 30, color="darkred"),
showarrow=False)
# Add a circle
fig1.add_shape(type="circle", x0=0.4, y0=0.3, x1=0.6, y1=0.5,
line_color="purple")
fig1.update_layout(paper_bgcolor = "lightblue")
fig1.write_image(path + 'figIndic1.png')
fig1.show()
子弹图
从概念上讲,它们类似于标准仪表图。它们显示关于单一定量测量的信息,将其当前值与目标值进行比较,并与由不同颜色带表示的一组数值范围进行对比。在视觉上,它们类似于标准条形图(水平或垂直),因为它们也通过矩形条的长度或高度来编码信息。主要的区别在于,子弹图包括一个中央狭窄的条,表示报告的数值测量的当前值。
图 3:作者使用 Plotly 制作的子弹图。目标是 450。阈值为 250 和 350。
目标值由 450 处的红色垂直线表示。类似于仪表图,有多个扇区(从三个到五个),用颜色或单一色调的不同强度来表示。扇区指示定性值,如[差,满意,良好],[警报,警告,满意],[差绩效,平均绩效,高绩效]。扇区也可以表示与显示的数值相关的不同数值范围。
子弹图相较于仪表图的主要优点是它们占用的屏幕空间较少,没有分散注意力的装饰物,并且最重要的是,其较小的尺寸允许你在单一图表中比较多个不同的类别。
更多细节和 Python 代码可以在我的上一篇文章中找到。
象形图
又称:象形图,图示单位图,图片图表。
象形图使用图标显示离散的数据集。通常,图标代表与所示数据相关的主题类别;例如,人口数据使用人物图标(见下图)。每个图标可以代表一个单位或其他数量(如 10,100,1000)。数据集通过图标的列或行的接近程度进行比较。
由Edward Howell拍摄,来源于Unsplash
使用图标有时有助于克服语言、文化和教育水平的差异。观众通常很容易理解它们。 图标还可以提供数据的更具代表性和全面的视图。
在象形图中,始终避免显示大量图标,这可能会使显示的值难以计数。也要避免部分图标,这可能会使观众感到困惑并影响信息的传达。
柱状图
又称:柱形图,垂直条形图
这是特别类型的垂直标准条形图。它们有两个坐标轴:水平轴显示类别,垂直轴显示数量。每个类别的数量通过垂直矩形条的高度显示。每条的高度与要显示的数值成比例。每条代表一个类别,并且它们之间留有一些空隙。
图 5:由作者使用 Matplotlib 制作的柱形图。
虽然柱形图的经典用途通常是进行比较、显示排名或指示时间趋势,但它们对于简单展示数值也是非常有用的。这些数值通过每个矩形的最终高度进行编码。柱形图也适用于显示负数值。
通过在条形图内或条形图上方使用注释,可以大大提高信息的清晰度。显示网格线也很方便,但仅限于水平网格线。
图 6:由作者使用 Plotly 制作的柱形图。
好的实践建议将垂直轴从 0 开始,因为如果基线被修改,我们不可避免地会扭曲视觉效果。还必须避免所有 3D 效果,因为这些效果违反了适当讲述的所有规则。最后,必须避免使用圆角而不是锐角矩形,因为这会使最终数值的读取变得困难。
始终记住,最多 10% 的男性观众可能有色觉障碍。在这方面,尽量不要在同一图表中使用绿色和红色条形。 色盲症是一种缺乏绿色感知锥体的情况,而红盲症则是缺乏红色感知锥体的情况。
图 7:由作者使用 Plotly 制作的柱形图。
水平条形图
这是特别类型的水平标准条形图。它们有两个坐标轴:水平轴显示数量,垂直轴显示类别。每个类别的数量通过垂直矩形条的长度显示。每条的长度与要显示的数值成比例。每条代表一个类别,并且它们之间留有一些空隙。
图 8:由作者使用 Matplotlib 制作的水平条形图。
通过在条形图内或条形图右侧使用注释,可以大大提高信息的清晰度。显示网格线也很方便,但仅限于垂直网格线。
在绘制多个类别时,尤其是带有非常长标签的类别时,水平条形图更为合适。
图 9:由作者使用 Plotly 制作的水平条形图。
正如之前在柱状图描述中提到的,你必须将数值轴从 0 开始,避免所有的 3D 效果,并且避免使用圆角边缘。同时,还要考虑红绿色盲的问题。
数据可视化是将某种信息转化为视觉上下文的过程。它是向多样化的受众讲述故事的强大工具。为此,提供了大量的图表、图形和示意图。完成任务的关键步骤是选择合适的图表来讲述你的故事,这取决于你想传达的信息。
在这篇文章中,我们介绍了六种不同的图表 [数值指示器、角度指示器、子弹图、图标图、柱状图和水平条形图],这些图表旨在展示一个或几个数值,通常还附带注释和趋势指示器。我们介绍了优点,一些良好的实践,警告以及几个 Python 代码。
明智地选择,但始终问自己:我是否在清晰地沟通?
图表讲故事
原文:
towardsdatascience.com/storytelling-with-charts-fbd23ebb70ee
第四部分(I):你想展示组成部分吗?
达里奥·维茨
· 发表在 Towards Data Science · 阅读时间 7 分钟 · 2023 年 6 月 1 日
--
图片由 Hiral Parikh 提供,来源于 Unsplash
这是系列中的第四篇文章,旨在帮助人们根据他们试图向特定观众展示的信息来决定使用哪种类型的图表。
前三篇文章聚焦于以下内容:文章 1,展示了单一数值变量的分布;文章 2,显示了一系列数字的大小;文章 3,比较了各项。
本文的目的是指示在展示组成部分时最常用的图表。请记住,组成部分涉及一个可以分为各个部分的整体,以及每个部分与该整体的关系(绝对或相对)。分析可以是静态(显示某一时刻的组成部分)或动态(显示组成部分随时间的变化)。
常用于展示组成部分的图表如下:
· 饼图
· 堆叠条形图
· 堆叠面积图
· 瀑布图
· Mekko 图
· 树状图
本文将集中描述以下图表类型:饼图;堆叠条形图;以及树状图。在接下来的文章中,我们将描述剩余的三种图表。
饼图
饼图(PCs)(图 1)是圆形图表,分成楔形扇区,用于显示整体的部分,这些部分是互斥的且不重叠的类别。完整的圆表示整体,而楔形(切片、扇区、段)表示部分。因此,完整的圆必须表示所有数据的总和,并且必须始终加总到 100%。包含在一个切片中的数值不能包含在另一个切片中,因为如前所述,扇区必须是互斥的,重叠是禁止的。从概念上讲,它们表示整体的简单份额。
图 1:作者使用 Plotly Express 制作的饼图。
饼图通过两个视觉标记来编码数值:1) 每个扇区的面积;2) 每个扇区沿圆周的长度。与大多数其他图表不同,饼图的轴和刻度不是线性的。
人类在视觉上计算曲线周围的面积或距离并不容易。这是对这类图表的主要反对意见,也是无休止争议的源头:它们非常简单易制,观众也习惯了它们的使用,但如果没有注释和百分比来澄清上下文,它们的解释则非常困难。
有时,可以通过以下替代方案增强饼图传达的信息:A1)甜甜圈图;A2)扇区分离。
A1: 甜甜圈图(图 2),概念上等同于饼图,但与饼图不同的是,它们在图表的中心有一个空白区域(类似于一个洞),其中显示某种附加信息,以增强叙事效果。
图 2:作者使用 plotly.graph_objects 制作的带注释的甜甜圈图。
中心的空白区域不允许进行面积比较,因此甜甜圈图只有一个视觉标记:每个扇区的数值仅通过圆周上的弧长进行编码。
A2: 扇区分离,通过从标准饼图或甜甜圈图中拉出或分离一个(或几个)扇区,可以增强信息传达效果。
图 3:作者使用 plotly.graph_objects 制作的带有拉出扇区的甜甜圈图。
当然,必须有充分的理由来证明这种分离的合理性,因为观众的注意力不可避免地会集中在该扇区。此外,还有一种视觉失真,使得与其他扇区进行直接比较变得困难。
最后,饼图只显示某一时刻的组成(静态组成)。有关饼图的更多细节可以在我的上一篇文章中找到。
堆叠条形图
堆叠条形图(SBCs)(图 4)是可以垂直(水平)排列的矩形条形图。它们有两个轴:一个轴显示类别,另一个轴显示数值及其对应的刻度。每个条形代表一个主要类别,并被分割成代表第二个分类变量子类别的矩形扇区。这些矩形段的高度(长度)显示了每个子类别的数值,这些矩形段垂直(水平)堆叠在一起。每个主要条形的最终高度(长度)表示每个类别的总量(百分比堆叠条形图除外)。
图 4:作者使用 Matplotlib 制作的简单堆叠条形图。
有两种特定类型的 SBCs:1)简单堆叠条形图(图 4);2)百分比堆叠条形图(图 5)。
简单 SBs 将每个子类别的绝对值堆叠在前一个子类别上,而百分比 SBs 将每个子类别的百分比堆叠在前一个子类别上。简单 SBs 中的主要条形通常具有不同的高度(长度),而在百分比 SBs 中,所有主要条形具有相同的高度。当仅相对差异重要时,必须使用百分比 SBs;当相对和绝对差异都重要时,使用简单 SBs。
图 5:作者使用 Matplotlib 制作的百分比堆叠条形图。
SBCs 在 展示随时间变化的组成(动态组成)方面表现出色。** 对于这种类型的动态分析,必须使用垂直方向堆叠条形图,并且与时间(天、月、年、时间范围)相关的变量始终 放在横轴上(图 6)。
图 6:作者使用 Matplotlib 制作的堆叠条形图。
在堆叠扇区数量或长时间绘制图表时应当谨慎。建议每个主要条形上堆叠的扇区不超过四个或五个。当主要条形过多或很长时间内扇区超过三个时,观众可能会感到困惑。在这种情况下,我们建议使用堆叠面积图,当你需要展示大量的时间数据和/或每个主要条形上有四个或更多的扇区时。
更多细节可以在我的上一篇文章中找到。
树形图
这种特定类型的图表由马里兰大学计算机科学教授 Ben Shneiderman 发明,他在寻找“目录树结构的紧凑可视化”时发明了它(#2)。
用我自己的话说:“树形图是一种基于矩形的可视化工具,允许你表示一个层次结构(树状结构)数据集。概念是比较数量并在物理限制的空间中展示某些层次结构的模式。为此,使用不同大小和颜色的矩形从不同角度展示数据集。目标不是指示确切的数值,而是将数据集‘拆解’成其组成部分,并快速识别出其较大和较小的组件” (#3)。
图 7:作者使用 Plotly Express 制作的树形图。
后来发现它们可以作为饼图的替代方案,显示部分与整体的关系。由于每个矩形的面积与其所代表的数值成正比,它们开始用于指示部分之间的相对比例和差异。整个矩形的面积必须表示所有数据的总和。树形图仅显示某一时刻的组成情况(静态组成)。
相比饼图,树形图有两个主要优点:1)它们可以在相对较小的空间内包含十个或上千个部分的嵌套矩形;2)它们用面积编码数值,这是比圆周上的弧长更好的视觉属性。
必须始终用适当的注释标明数值,因为缺乏共同的基线严重影响了部分矩形之间的比较。
图 8:作者使用 Plotly Express 制作的带注释的树形图。
更多详细信息请参见我的上一篇文章。
待续
我们很多时候需要向观众展示组成情况。这种部分与整体的分析对我们的特定观众并不总是容易解读的。因此,事先我们必须分析我们有哪些方法及其优缺点,如何与我们的数据和信息相关联。
如前所述,展示组成情况可以使用六种不同类型的图表:饼图;堆叠条形图;树形图;堆叠面积图;Mekko 图;瀑布图。这里,我们描述了其中的三种,特别是它们的特性、优点以及需要注意的一些事项。
请关注接下来的文章,描述其余图表。
参考文献
#1: https://serialmentor.com/dataviz/visualizing-proportions.html
2 Ben Shneiderman (1992). “使用树形图的树形可视化:2D 空间填充方法”。ACM Transactions on Graphics. 11: 92–99. doi:10.1145/102377.115768。
3 medium.com/towards-data-science/treemaps-why-and-how-cfb1e1c863e8
如果你觉得这篇文章有趣,请阅读我之前的 55 篇文章中的任何一篇:medium.com/@dar.wtz
用表格讲故事
原文:
towardsdatascience.com/storytelling-with-tables-514412adc4b7
第二部分:良好表格的建议指南
Darío Weitz
·发表于Towards Data Science ·8 分钟阅读·2023 年 1 月 10 日
--
图片由Joao Viegas提供,来源于Unsplash
我们已经努力工作了几个月,我们相信有一些非常重要的报告。
我们如何让观众理解我们贡献的深度?
简单:通过恰当地讲述我们的故事。
为了更好的讲述故事,我们通常使用视觉元素,如图表、图解、表格、插图和图片。
你可以查看我在 Medium 上的为什么与如何列表,了解使用经典图表(条形图、散点图、饼图、直方图)以及不常见图表(平行坐标图、Mekko 图、面积图、仪表盘图表)时的基本概念、技巧和需要避免的陷阱。
你还可以查看我在 Medium 上使用 Plotly 进行数据可视化的列表,其中展示了如何使用Plotly Express和Plotly.graph_objects来实现类似的目的。
但如果你需要使用表格,强烈建议你首先阅读以下指南。
表格
记住,表格是一种由行和列组成的结构,其主要目的是在标记的列中显示数字和/或文本的列表。
在这一系列的第一篇文章中,我指出了一个设计良好的表格应该具备的内容,并建议了使用表格的合适时机。
现在,我会提供一些建议和需要避免的陷阱。
表格的 10 个技巧
每个表格必须能够让观众理解,而无需参考周围的文本;
所有表格必须按照它们在文本中引用的准确顺序使用阿拉伯数字进行顺序编号;
将数值右对齐,以便比较它们的大小;
将包含文本的列的标题和内容左对齐;
使用适当且有意义的小数位数。统一小数位数并右对齐;
根据列内容对齐列名称。每列标题的首字母应大写;
设计表格时,确保要比较的数据是连续的或接近彼此的;
每个表格必须有一个标题或自解释的标题;
始终包括脚注以提供额外的说明、解释缩写或不常见的定义;
始终包括一个来源行,以指示表格中数据的来源;
避免的 10 个陷阱
首先问自己 你的观众是否需要表格来帮助他们理解你的故事;
如果你的表格只有一到两列且行数很少,也许可以在文本中展示这些数据;
如果数据已经在文本中,不要在表格中重复它;
表格不应占据整个屏幕(整页)。另一方面,表格也不应小到难以阅读;
“Table” 这个词不应该像“Fig.”用于图表那样被缩写;
始终在每一列中使用相同的小数位数。同时,不要在每列中更改计量单位;
使用不同背景颜色交替的行时要谨慎(斑马条纹)。它们通常在大型数据集中的使用是合适的,以帮助数据的可读性。尽量使用柔和的色彩调色板;
尽量避免使用垂直线来分隔列;
列标题不应明显宽于列中最宽的数据;
始终记住,表格 比图表需要更多的处理时间。给观众足够的时间来理解。
使用 plotly.figure.factory 的表格
如果你确实急需用最少的代码绘制表格,我建议你使用 plotly.figure_factory 模块。该模块包含一些包装函数,用于扩展 Plotly 的绘图功能。让我们来看看它是如何工作的:
# Tables with Plotly Figure Factory
import plotly.figure_factory as ff
data_to_table1 = [['Year', 'Battery Electric', 'Plug-in Hybrid', 'Full Hybrid', 'Petrol', 'Diesel'],
[2013, 5.79, 0.00, 6.68, 34.99, 52.54],
[2014, 12.56, 1.15, 6.95, 30.67,48.67],
[2015, 17.11, 5.30, 7.13, 29.64, 40.82],
[2016, 15.23, 13.82,11.31,28.84, 30.80],
[2017, 18.93, 19.70, 12.69, 25.54, 23.14],
[2018, 31.19, 17.81, 10.66, 22.64, 17.69],
[2019, 42.38, 13.55, 12.35, 15.68, 16.03]]
fig_ff1 = ff.create_table(data_to_table1)
fig_ff1.write_image(your_path + 'FF_Table1.png', scale = 2)
fig_ff1.show()
表格 1:由作者使用 plotly.figure.factory 制作
表格中显示的数据对应于 2013 年至 2019 年挪威按类型分类的新乘用车注册 1。
这是一个漂亮的表格,具有清晰的标题、易读的字体类型和大小、充足的空白空间,以及适当的斑马条纹颜色调色板。此外,你可以改变字体颜色和大小、背景颜色以及行高。
但该模块有一些限制,如果你需要更详细的表格,你必须求助于 Plotly Express。
使用 Plotly 图形对象的表格
参考 文章 1: “你可以使用 Table 方法在 Plotly 中创建表格。在其最基本的变体中,程序是使用 fig.add_trace()、go.Table 和两个参数:header 和 cells。前者如其名所示,代表表格的标题(第一行),而 cells 代表我们想展示给观众的数值或非数值” 2。
首先,我们需要转置在 data_to_table1 列表中的数字数据。
# Plotly Table 1
import numpy as np
import plotly.graph_objects as go
headers_tb1 = ['Year', 'Battery Electric', 'Plug-in Hybrid', 'Full Hybrid', 'Petrol', 'Diesel']
values_tb1 = [[2013, 5.79, 0.00, 6.68, 34.99, 52.54],
[2014, 12.56, 1.15, 6.95, 30.67, 48.67],
[2015, 17.11, 5.30, 7.13, 29.64, 40.82],
[2016, 15.23, 13.82, 11.31, 28.84, 30.80],
[2017, 18.93, 19.77, 12.69, 25.40, 23.14],
[2018, 31.19, 17.81, 10.66, 22.64, 17.69],
[2019, 42.38, 13.55, 12.35, 15.68, 16.03]]
transposed_tb1 = np.array(values_tb1).T.tolist()
fig_tb1 = go.Figure()
fig_tb1.add_trace(
go.Table(
header = dict(values = headers_tb1 ),
cells = dict(values = transposed_tb1)
))
fig_tb1.write_image(your_path + 'Table_tb1.png', scale = 2)
fig_tb1.show()
表 2:由作者使用 plotly.graph.objects 制作
我们还有很多工作要做,直到我们的表格准备好展示给观众。
然后,我们必须解决对齐问题:年份列将居中对齐,因为它是有序的,而其余的数字值将右对齐。
# Plotly Table 2
aligns = ['center','right', 'right','right','right', 'right']
fig_tb2 = go.Figure()
fig_tb2.add_trace(
go.Table(
header = dict(values = headers_tb1, align = aligns ),
cells = dict(values = transposed_tb1, align = aligns )
))
fig_tb2.write_image(your_path + 'Table_tb2.png', scale = 2)
fig_tb2.show()
表 3:由作者使用 plotly.graph.objects 制作
虽然有所改善,但我们仍然面临某些单元格小数位数不同的问题。
# Plotly Table 3
fig_tb3 = go.Figure()
fig_tb3.add_trace(
go.Table(
header = dict(values = headers_tb1, align = aligns ),
cells = dict(values = transposed_tb1, align = aligns,
format = [None, ",.2f"])
))
fig_tb3.write_image(your_path + 'Table_tb3.png', scale = 2)
fig_tb3.show()
表 4:由作者使用 plotly.graph.objects 制作
现在是时候进行装饰和最终细节的调整了。
首先,我们使用了 Plotly Express 模块 px.colors.qualitative 更改了颜色调色板,该模块包含内置的颜色序列,如 Pastel2\。我们使用了‘darkslategray’作为边框颜色,背景颜色则使用了三种不同的柔和色调:
# Plotly Table 4
fill_color_h = px.colors.qualitative.Pastel2[1]
line_color_h = 'darkslategray'
fill_color_c = [px.colors.qualitative.Pastel2[0], px.colors.qualitative.Pastel2[2],
px.colors.qualitative.Pastel2[2], px.colors.qualitative.Pastel2[2],
px.colors.qualitative.Pastel2[2], px.colors.qualitative.Pastel2[2]]
line_color_c = 'darkslategray'
fig_tb4 = go.Figure()
fig_tb4.add_trace(
go.Table(
header = dict(values = headers_tb1, align = aligns,
fill_color = fill_color_h,
line_color = line_color_h),
cells = dict(values = transposed_tb1, align = aligns,
fill_color = fill_color_c,
line_color = line_color_c,
format = [None, ",.2f"])
))
fig_tb4.write_image(your_path + 'Table_tb4.png', scale = 2)
fig_tb4.show()
表 5:由作者使用 plotly.graph.objects 制作
接下来,我们按照上面的说明(技巧 8 和 9)添加了标题和来源行。我们还通过增加单元格高度来增加了空白空间。
# Plotly Table 5
fig_tb5 = go.Figure()
fig_tb5.add_trace(
go.Table(
header = dict(values = headers_tb1, align = aligns,
fill_color = fill_color_h,
line_color = line_color_h),
cells = dict(values = transposed_tb1, align = aligns,
fill_color = fill_color_c,
line_color = line_color_c,
height = 35, format = [None, ",.2f"])
))
fig_tb5.update_layout(title = "New Passenger Vehicle Registrations in Norway (%)",
title_font_size = 20, title_x = 0.5)
fig_tb5.add_annotation(x=1, yref= 'paper', y = 0.01,
text="Source: Our World in Data (2021)",
showarrow=False,
font_size = 15, font_color = 'blue')
fig_tb5.update_layout(autosize = False,
margin=dict(l=20, r=20, t=70, b=100))
fig_tb5.write_image(your_path + 'Table_tb5.png', scale = 2)
fig_tb5.show()
表 6:由作者使用 plotly.graph.objects 制作
第 2 和第 3 列(电池电动和插电式混合动力)的标题字符数多于其他列。增加这两列的宽度(columnwidth=)会使视觉效果更加美观。
# Plotly Table 6
fig_tb6 = go.Figure()
fig_tb6.add_trace(
go.Table(
columnwidth = [80,90,90,80, 80,80],
header = dict(values = headers_tb1, align = aligns,
fill_color = fill_color_h,
line_color = line_color_h),
cells = dict(values = transposed_tb1, align = aligns,
fill_color = fill_color_c,
line_color = line_color_c,
height = 35, format = [None, ",.2f"])
))
fig_tb6.update_layout(title = "New Passenger Vehicle Registrations in Norway (%)",
title_font_size = 20, title_x = 0.5)
fig_tb6.add_annotation(x=1, yref= 'paper', y = 0.01,
text="Source: Our World in Data (2021)",
showarrow=False,
font_size = 15, font_color = 'blue')
fig_tb6.update_layout(autosize = False,
margin=dict(l=20, r=20, t=70, b=100))
fig_tb6.write_image(your_path + 'Table_tb6.png', scale = 2)
fig_tb6.show()
表 7:由作者使用 plotly.graph.objects 制作
结论
表格是数据分析师向技术、商业或管理观众展示其发现时的基本工具。因此,其准备工作必须非常细致。
在这篇文章中,我们提供了 10 个改进表格设计的提示和 10 个准备过程中应避免的陷阱。我们还介绍了如何使用 plotly.figure_factory 准备一个基本表格。最后,我们描述了一系列在设计表格时应遵循的步骤,使用 plotly.graph_objects。
不要让设计糟糕的表格毁掉你的叙述。
参考文献
战略数据分析(第二部分):描述性问题
原文:
towardsdatascience.com/strategic-data-analysis-for-descriptive-questions-b6c9e469b32f?source=collection_archive---------3-----------------------#2023-10-14
深入探讨回答描述性问题的方法
Viyaleta Apgar
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 10 月 14 日
--
这是关于战略数据分析的系列文章的一部分。
战略数据分析(第一部分) → 战略数据分析(第二部分):描述性问题战略数据分析(第三部分):诊断性问题(第三部分) 战略数据分析(第四部分):预测性问题 ← 即将推出!
战略数据分析(第五部分):处方性问题 ← 即将推出!*
在第一部分中,我讨论了数据分析师尝试回答的四种问题类型以及识别每种问题类型的方法。如果你记得,当我们问描述性问题时,我们试图获取对某事的理解。这些问题通常以“what/is/does”开头,并涉及现在时或过去时。现在,让我们深入探讨如何回答这些问题的策略。
描述性问题的回答策略
描述性问题通常是数据分析师遇到的最多的问题,而这些问题的答案往往为后续问题提供基础。通常,经验丰富的分析师已经有一个策略(或至少一些指南)来回答描述性问题。更具体的策略根据问题、行业、个人偏好和知识等不同而有所不同。然而,任何策略的骨架应包括以下内容:
-
评估问题的意图
-
确定问题中的变量
-
定义问题的分析目标
这些步骤应指导你选择最佳方法论并提供最合适的答案。让我们深入了解一下。
作者制作的图表
步骤 1:评估问题的意图
在应用任何技术来回答决策者提出的问题之前,我们必须首先理解为什么会提出这个问题。这可以显著影响我们的策略和最终选择的方法。意图中的一些考虑因素包括:
-
答案将如何被解释,
-
我们的答案将通知哪些决策,以及
-
我们受众的技术或统计素养
我最喜欢的意图意识的例子(也是我最常与同行分享的)是Tyler Buffington, PhD写的文章:均值还是中位数?根据决策选择,而不是分布。在这篇出色的关于选择正确方法论的评审中,Tyler 认为分布的偏斜度不应构成均值或中位数作为“平均值”指标的选择。相反,分析师应关注这个指标将如何被决策者用于推断。
问题的意图也可以帮助我们选择正确的数据点。让我们看一个例子:“我们今年第二季度的销售额是多少?”我们的回答可以是总销售额(销售单位数量乘以每单位价格)或净销售额(总销售额减去折扣和促销)。在某些情况下,我们的决策者可能不知道这个区别,因此教育他们或澄清这个值将如何被使用应该能告诉我们使用哪个值。
另一个考虑因素是受众,这是意图的一部分。如果我们试图回答一个要求我们比较不同组之间分布的问题,向不懂如何阅读箱型图的决策者展示复杂的可视化图表可能并不明智。简单的统计数据可能是更好的选择,特别是对那些每天做出数百个决策且没有时间查看复杂图表的商业伙伴(例如,C 级高管)。另一方面,如果我们想向具有统计知识的数据科学家展示信息,那么箱型图可能是更好的选择。
第二步:识别问题中的变量
下一步是识别和澄清问题中的变量,这些变量我们希望以某种方式进行描述,并确保这些变量具有代表性数据。
例如,在“今年第二季度我们的销售额是多少?”这个问题中,单一变量显而易见——即今年第二季度的销售额,我们可以从销售账本中轻松获取数据。
然而,如果问题缺乏明显的变量,那么问题应该被重新表述,以便涉及明确且可以用数据表示的变量。
例如,在“我们的临床患者护理中是否存在性别偏见?”这个问题中,变量是“性别偏见”,但“性别偏见”不一定是数据点本身。然而,“性别之间的结果差异”或“性别之间的患者满意度”是“性别偏见”的潜在衡量标准。因此,我们可以将问题重新表述为“在我们的临床患者护理中,不同性别之间的患者结果是否存在差异?”
透彻审视问题的复杂性也是很重要的。有些问题可能涉及多个名词,但要求我们找出特定的变量,我们应该从问题中孤立出这个变量。
例如,“来自哪个城市的游客倾向于在我们酒店逗留更长时间?”这个问题涉及游客、城市和酒店,但我们寻找的变量是游客的来源城市。对于问题:“在我们增加了更多的呼叫中心代表后,等待时间是否有变化?”两个变量是:1. 时间序列(帮助我们推断变化前后的信息)和 2. 客户等待的时间。
第三步:定义问题的分析目标
在识别出问题中的变量后,我们现在可以对问题的目标进行分类。这可以通过将其重新表述为一个指令,并对该指令进行分类来实现。识别目标可以帮助我们缩小合适的定量技术范围,从而回答原始问题。
记住:分析目标和问题的意图是不同的。问题的意图确定了决策者计划如何利用答案或他们计划如何解读分析结果。问题的分析目标决定了我们在识别变量后希望如何使用这些变量。
描述性问题可能有三种类型的目标,这些目标取决于我们之前识别的变量:
-
描述变量 如果问题的目标是描述单个变量,那么答案将要求我们找到描述主题的一些参数或一组参数。如果我们可以使用“找到”这个关键词,并跟上问题的主题来重新表述问题,那么问题的目标就是描述变量。
例如:“今年第二季度我们的销售额是多少?”其目标是得到一个代表所有销售额的值;因此,它要求我们找到销售总额。作为指令,我们可以将问题重述为“找到今年第二季度的销售总额”。
大多数可以用来回答这些问题的技术包括计算描述性统计(如总和、均值、众数、范围等)或可视化工具(如直方图或核密度估计图)。不过,根据问题的性质,也存在更高级的技术。
-
比较组或变量 如果问题的目标是比较变量中的组或比较不同的变量,那么我们可以使用“比较”关键词重新表述问题。这些问题也可以包括时间上的比较,这可能要求我们从时间序列中创建一个变量来作为时间类别(如“之前/之后”、小时、月份等)。
例如,“我们的临床患者护理是否存在性别偏见?”这个问题旨在比较性别组之间的患者护理,也可以重新表述为指令:“比较所有性别的临床患者护理。”
有许多技术可以帮助比较组或变量。可视化工具如条形图或饼图可以帮助比较组,直方图和密度图可以帮助比较两个变量之间的值分布,折线图可以帮助比较随时间变化的值,而散点图可以帮助比较个体点。描述性统计和统计比较测试(如 t 检验或 ANOVA)可以用来比较两个或多个分布[1]。
-
识别趋势或关系 如果问题的目标是识别一个系列中的模式(如时间)或两个或多个变量之间的模式,那么我们可以使用关键词“识别连接/相关性”将描述性问题重新表述为指令。需要注意的是,关系不意味着因果关系,而只是试图建立变量之间的联系;因果关系在诊断性问题中解决。
例如:“我们今年的收入变化如何?”旨在识别收入随时间的变化趋势。我们可以将其重新表述为一个指令:“识别收入与时间之间的关系。”
问题“空气温度和海水温度是否相关”旨在找出两种温度之间的关系。我们可以将其重新表述为“识别空气温度和海水温度之间的相关性”。
为了识别变量之间的关系,散点图、气泡图和热力图可以在视觉上提供帮助,而像 Pearson 或 Spearman 相关性这样的统计方法可以帮助识别变量是否存在联系。识别时间/序列中的趋势最好通过使用折线图和像 ARIMA 这样的统计方法来实现。
一个案例研究
让我们来看一个来自第一部分的问题:“火车是否晚点?”为了找到正确有效的技术来回答这个问题,让我们遵循上述策略步骤。
评估意图: 假设这个问题来自于火车运营公司的副总裁。通过与她的对话,我们了解到副总裁想知道如果火车确实晚点,是否需要采取措施来调整当前的火车时间表。如果火车实际上并没有晚点,她还希望将晚点情况建立为一个 KPI 指标并继续监测。此外,副总裁告诉我们,她认为“火车晚点”是指大多数火车晚点超过一分钟。
识别变量: 在问题“火车是否晚点”中,感兴趣的身份是火车的晚点情况,但哪个变量或变量组合可以正确地代表这一身份?通过对问题及其意图的分析,我们可以确定几个变量选择的选项:
-
两个变量:火车预期到达时间和火车实际到达时间
-
一个变量:火车实际到达时间与预期到达时间的差异
-
一个变量:如果火车实际到达时间与预期到达时间的差异大于 1 分钟,则设置为 1 的二进制标志
我们的变量选择应根据问题的意图,并且一定会影响我们如何确定问题的目标。从意图中,我们知道副总裁认为如果大多数列车迟到,那么列车就算是迟到。因此,我们实际上只需要一个二元标志来识别每列车是否真的迟到。这是我们可以提供的最简单的信息,帮助我们理解整体列车迟到情况,并帮助决策者确定她的下一步。
定义分析目标: 既然我们已经确定了意图和相关变量,我们现在可以定义分析目标并选择技术。由于我们正在处理一个单一变量,即二元“迟到列车”标志,我们知道问题的目标是描述该变量。问题的意图是确定大多数列车是否迟到。因此,我们可以选择的技术之一是计算所有迟到列车的百分比,以确定是否有超过 50%的列车迟到。我们可以将最终信息传达给我们的副总裁,以便她决定下一步做什么。
如果问题的意图或受众不同,这一策略将有显著差异。如果我们的决策者想要了解列车迟到的分布,我们应该选择实际到达时间与预期到达时间的差异,并选择如直方图这样的可视化技术来传达列车迟到的分布。
图表由作者制作
一些最终说明
你可以根据自己的需要使用上述策略,但这里有一些建议来帮助你:
-
保持简单,根据需要逐步增加复杂性。
-
战略过程应该是直观的,但写下意图、变量和目标总是个好主意,这样你对任务有清晰的认识,或者在你的方法中培养纪律性。
-
保持灵活——你的策略可能会随着时间的推移而改变或演变。这份文件是一个良好的开端,但不要让它限制你的创造力和思考。
-
不要忘记进行分析!有些问题不像其他问题那样直观,需要我们思考和分析以理解并找到最佳答案。
感谢阅读!在我的下一篇文章中,我将深入探讨诊断问题,请继续关注,并在评论中告诉我你的想法!
来源:
1 www.scribbr.com/statistics/statistical-tests/
照片由 Scott Graham 在 Unsplash 上拍摄
战略数据分析(第一部分)
原文:
towardsdatascience.com/strategic-data-analysis-part-1-fb2df3a43831?source=collection_archive---------3-----------------------#2023-10-07
数据分析及其试图回答的四种问题
Viyaleta Apgar
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 10 月 7 日
--
这是《战略数据分析》系列的一部分。
→ 战略数据分析(第一部分)战略数据分析(第二部分):描述性问题战略数据分析(第三部分):诊断性问题(第三部分)
战略数据分析(第四部分):预测性问题 ← 敬请期待!
战略数据分析(第五部分):预测性问题 ← 敬请期待!
在我十年的数据工作经验中,我注意到学习定量技术以进行数据分析的重点非常突出。我花费了数千小时来完善我对从统计学到机器学习再到经济学等各方面的知识。然而,我发现关于通过数据分析回答商业问题的战略方法指导非常有限。我还遇到过许多初级分析师,他们常常将数据分析误认为是定量技术,而忽视了分析是一种强大的思维方式和极好的问题解决工具——即数据分析不仅仅是其方法的产物。
在这个多部分系列中,我希望提供一个数据分析入门指南,这将为使用分析回答商业问题提供结构化的方法。在第一部分中,我将介绍数据分析及其能够帮助回答的四种问题类型。这可以用作正确识别分析问题的指导。在随后的帖子中,我将提出回答每种问题类型的策略和选择正确技术的方法。我希望你觉得这个指南有用——在评论中告诉我吧!
什么是数据分析?
那么数据分析是什么,它试图实现什么?一般来说,分析是通过将复杂的信息拆解成更小、更简单的部分并首先理解这些部分的过程。这个过程用于帮助解决问题或回答问题。与一般情况一样,数据分析是通过尝试理解关于复杂数据的更可管理的信息来理解一些东西。
分析师可以使用各种技术来进行数据分析。例如,如果我们与医疗机构的管理者合作,他们要求我们描述其典型患者,我们可以使用统计方法如取均值或计算范围来描述患者群体。因此,我们可以通过一些描述患者总体的简单统计数据来了解诊所的所有患者。这个问题要求我们理解复杂规模的数据,而我们可以通过理解其更简单的部分来实现。
数据分析是“分析数据以回答问题、提取见解和识别趋势的过程和实践”1。然而,尽管数据分析需要借鉴统计学、机器学习、数学和其他学科的技术,数据分析师并不是统计学家、数据科学家或数学家。虽然数据科学家应该对他们所研究的主题有深入的了解,但他们不一定是该主题的专业专家。数据分析师的目标是对各种技术足够熟悉,并在应用这些技术时成为专家,以生成见解和建议,并帮助业务伙伴做出更好的数据驱动决策。但你不必成为数据分析师才能进行数据分析,任何熟悉定量技术和数据分析策略的人都可以利用这些技术帮助做出数据驱动的决策。
几乎所有需要数据分析的问题都属于四大类:描述性、诊断性、预测性和规范性。有些问题涉及已知的值和变量(如描述性和诊断性问题);有些问题则更具假设性而非具体性(如诊断性和规范性问题)。回答这些问题需要批判性思维、创造性解决问题和逻辑推理。然而,如果我们能够将需要数据分析的问题进行分类,我们可以根据其类别制定回答策略。因此,有必要熟悉问题类型及其应对策略。
本文的其余部分介绍了四种问题类型,描述了它们,并提供了示例,以帮助我们识别每种类型。
描述性问题
描述性问题旨在获得对某些具体事物的理解。这可以包括对一个群体的描述、不同变量之间的关系或各种趋势。这些问题通常最容易识别——它们通常指的是当前状态或过去的情况,并且通常以“what”或“is/does/did”关键词开头。由于并非所有描述性问题都以这些关键词开头,另一种识别描述性问题的方法是检查问题的关键词是否可以重述为以“what”开头。这些问题的一些示例如下:
-
我们在今年第二季度的销售额是多少?2
-
我们的收入自上季度以来是否有所增加?
-
我们今年的收入变化如何?
-
客户取消订阅的频率有多高?
-
火车是否经常晚点?
-
我们的临床患者护理中是否存在性别偏见?
-
来自哪个城市的游客倾向于在我们酒店逗留更长时间?
-
上个月的温度变化如何?
-
空气温度和海水温度之间是否有关联?
-
我们在雇佣更多呼叫中心代表后,等待时间有变化吗?
上述问题都涉及一些已知的变量,这些变量可以用于分析——诊所中的性别记录、温度记录或年度收入。如前所述,所有这些问题都可以重新表述为以“什么”或“是否”开头:“空气温度和海水温度有关吗?”与“空气和海水温度之间是否存在关系?”是相同的问题,“客户取消订阅的频率是多少?”与“客户订阅取消的频率是多少?”也是相同的问题。
诊断性问题
诊断性问题旨在理解某件事情发生的原因或如何发生,并试图评估变量之间的依赖关系。这些问题以“为什么”及其同义词(如“怎么会”、“是什么导致的”等)开头,涉及已经发生或正在发生的事件。
诊断性问题的关键在于要求分析师提出潜在原因并验证这些原因是否正确。这是非常直观的,也是大多数人尝试诊断根本原因的方式。通常,相关的因变量发生了变化,我们想知道原因。我们还可以将诊断性问题视为“因果关系”问题,其中“因”是未知的。一些诊断性问题的例子包括:
-
为什么某一客户群体比其他客户群体更愿意与我们互动?
-
为什么我们的销售额在这个季度下降了?
-
是什么导致了热浪?
-
为什么我们的客户取消了他们的订阅?
-
为什么火车晚点了?
-
怎么会有些病人最终进入重症监护室?
在诊断性问题中,未知的是效果的原因。如果我们能够识别已知效果和未知原因,我们可能正在处理一个诊断性问题。
预测性问题
预测性问题旨在识别已知或未知变量中的未知值。我们想预测的值可能涉及部分已知和完全未知的变量。例如,在预测未来销售时,“销售”变量是部分已知的(我们有当前或过去销售的数据);在客户细分中,“客户群体”是完全未知的变量,我们必须依赖其他特征或信息来推测新变量的值。
决策者通常会提出预测性问题,以便进行战略性决策或评估他们对未来状态的准备。预测性问题通常用于寻找未知信息,但与描述性问题不同,答案总是不确定的。一些预测性问题的例子包括:
-
下个季度我们的销售额会是多少?
-
我们的酒店在接下来的 90 天里预计会有多少名客人?
-
我们的 Instagram 帖子会得到多少个赞?
-
我们的客户给我们在 Yelp 上打五星的可能性有多大?
-
今年冬天会有很多雪吗?
-
我们如何根据植物的物理特征分组家居植物?
-
驼背鲸的人口将来会如何变化?
-
火车会继续晚点吗?
如前所述,预测性问题不仅仅是尝试预测未来。它们处理的是部分或完全未知的事物。问题“我们如何根据植物的物理特征分组家居植物?”与未来时态无关,而是希望解决家居植物的未知参数。问题“我们的 Instagram 帖子会获得多少个赞?”很可能涉及一个部分未知的变量:我们可能知道其他 Instagram 帖子获得的赞的数量,但这个特定帖子会收到多少赞是未知的。
规定性问题
规定性问题旨在预测在特定决策下会发生什么 [3]。从这个意义上说,提问的决策者希望根据一组预测结果获得推荐意见。通常,这些问题以两种方式中的一种来表达:“如果…会发生什么?”或“应该怎么做,以便…。”
规定性问题通过评估当前情况的变化如何导致特定结果或识别能带来最佳结果的最优变化,进一步推动了预测性问题的进展。像回答预测性问题一样,我们的结果永远不会是确定的,存在一定的不确定性。然而,答案可以帮助数据驱动的决策,或引导研究来验证预测结果。
一些规定性问题的例子包括:
-
如果我们降低价格,销售会增加吗? 2
-
我怎样才能最大化员工生产力? 2
-
我们如何减少碳排放?
-
我们的商店每天应该开多长时间?
-
如果我们强制实施高等教育入学考试,毕业率会增加吗?
-
我们如何减少急诊部门的患者等待时间?
-
我们的产品价格应该是多少?
规定性问题可能会或可能不会建议决策者计划采取的潜在行动。例如,“如果我们降低价格,销售会增加吗?”包含了一个我们将分析的潜在行动:价格的降低。但不同的问题,如“我们如何减少碳排放?”不包含任何行动,而是要求列出可能减少碳排放的候选行动。这意味着我们需要在策略中采取额外的步骤,以制定候选行动列表。
希望你喜欢这篇数据分析入门文章。敬请期待下一部分,我将分享选择正确技术来回答描述性问题的策略。
来源:
1 online.hbs.edu/blog/post/diagnostic-analytics
2 www.pragmaticinstitute.com/resources/articles/data/32-business-questions-for-data-analysis/
[3] www.pragmaticinstitute.com/resources/articles/data/32-business-questions-for-data-analysis/
照片来自Gia Oris于Unsplash
《战略数据分析(第三部分):诊断性问题》
原文:
towardsdatascience.com/strategic-data-analysis-part-3-diagnostic-questions-c0fcb840294b?source=collection_archive---------6-----------------------#2023-10-26
深入探讨回答“为什么”问题的方法
Viyaleta Apgar
·
关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 10 月 26 日
--
这是《战略数据分析》系列的一部分。
战略数据分析(第一部分)战略数据分析(第二部分):描述性问题→ 《战略数据分析(第三部分):诊断性问题》 战略数据分析(第四部分):预测性问题 ← 敬请期待!
《战略数据分析(第五部分):处方性问题》← 敬请期待!*
回答“为什么”问题对于任何数据分析师来说都可能很困难。缺乏主题知识、技术储备不足和缺乏战略性方法都可能对帮助决策者找到正确答案产生不利影响。然而,凭借扎实的基础和方向,这些诊断问题可以被任何人轻松解决。
诊断问题通常紧随描述性问题的答案。在提出诊断问题时,决策者的目的是了解某些信息是如何产生的或是什么导致了某件事的发生。因此,当我们考虑诊断问题时,我们通常会想到因果推断。因此,熟悉因果推断的一般原则是很好的。
在本文中:
-
因果推断简介
-
诊断问题回答策略
-
案例研究
-
一些最后的备注
因果推断简介
因果推断旨在揭示干预(或对现状的变化)如何影响结果。在因果推断中,我们假设因果关系发生在某个干预,即“处理”,施加到某个单位上,并导致该单位结果的变化。如果我们比较有无处理的单位的结果,我们将能够观察到处理的效果(即因果关系)。
例如,如果我们想知道在房屋上市出售之前是否粉刷外墙会使其更快出售,最理想的情况是我们需要同时比较粉刷和未粉刷房屋的销售时间。在这里,房屋是我们的单位,粉刷外墙是我们的处理,销售时间是我们的结果。然而,不可能同时对同一房屋进行粉刷和不粉刷。因此,“我们无法同时观察到同一单位的处理与未处理”1。
这就是因果推断发挥作用的地方。我们不是直接测量特定单位上的处理效果,而是测量关联性和偏差。关联性是接受处理的所有单位与未接受处理的所有单位之间结果的平均差异。偏差通过捕捉所有结果不同的因素,将关联性与因果关系区分开来。
在我们的房屋出售示例中,我们可以比较所有粉刷过的房屋和所有未粉刷的房屋,并记录它们的销售时间。两组房屋之间的销售时间差异称为“关联”。如果没有偏差,我们可以确定在出售前粉刷房屋会使其更快出售。
然而,大多数决定在出售前粉刷房屋的原房主也有能力这么做,因为他们住在更好的社区;而更好的社区中的房屋通常销售得更快。因此,偏差可能是房屋销售更快的原因不仅仅是因为新刷的油漆,还因为它们位于更好的社区。如果我们能消除这种偏差(以及其他偏差),我们可以确定在出售前粉刷房屋是否会导致它销售更快。
这是因果推断的核心。要深入了解,我强烈推荐 Matheus Facure Alves 的书籍:Causal Inference for the Brave and True,该书详细讲解了这一主题。因果推断的基础构建了回答诊断问题的策略,因此让我们更详细地探讨一下。
解决诊断问题的策略
诊断问题难以回答的原因是,它们可能需要对主题有深刻的了解。揭示某事发生或正在发生原因的一般策略需要理解所有可能的原因和偏差,然后通过严格的技术方法评估它们的效果。理解所有可能的原因需要付出努力和时间。因此,回答诊断问题的大部分时间都花在研究上。不幸的是,研究有时会将分析师引向各种死胡同。采用战略方法和严谨性有助于使过程更顺利。
一般来说,回答诊断问题的方法包括:
-
识别结果
-
识别可能的原因和潜在的偏差
-
评估因果关系
图片由 Mediamodifier 提供,来源于 Unsplash
在开始之前,需要注意的是,在几乎所有情况下,我们可能无法确定某件事的确切根本原因。相反,我们可以识别出最可能的影响因素,并评估它们的影响概率。
不仅要理解这一点,还要制定沟通策略,以便决策者在我们甚至还未开始回答他们的诊断问题之前,就能了解这一警告。在寻找诊断问题的答案时,决策者承受着风险。答案越不确定——风险越大。因此,决策者必须知道在基于所提供的答案做出决策时,必须权衡这种风险。
说明完毕,让我们详细看看这个策略。
步骤 1:识别结果
问题中的结果是经历了某种潜在原因影响的因变量。通常,诊断问题应该只有一个因变量。识别结果非常重要,以便明确它并验证它是否可以被测量。如果问题有多个因变量,则应该将问题拆分成不同的问题。
例如,在第一部分的“是什么导致了热浪”这个问题中,结果是热浪,这可以定义为温度的突然和剧烈升高。在“为什么我们的客户取消了他们的订阅”这个问题中,我们想要调查的结果是订阅取消。如果我们遇到一个类似于“为什么房价在上涨而租金在下降”的问题,我们应该回答两个独立的诊断问题:“为什么房价在上涨”和“为什么租金在下降”。
步骤 2:识别可能的原因和潜在的偏见
一旦我们确定了问题中的结果,我们必须列出所有可能解释它的因素,并帮助我们回答“为什么”。通常,这个过程可以分解为识别三件事:原因、偏见和因果机制。应构建图形因果模型来辅助识别过程。
图片由Mark Rabe提供,来源于Unsplash
潜在原因可以通过研究、专业知识、访谈和关联来确定。如果没有适当的主题知识或专家的帮助,这很难实现。因此,有必要尽可能多地收集有关该主题的知识(有关为什么建立知识很重要的更多细节,请查看我的文章首先我们必须发现,然后我们可以探索)。
在列出潜在原因的清单时,一个很好的工具是头脑风暴。头脑风暴的一种新颖方法是一个重复的过程,首先:列出尽可能多的原因,而不对其有效性进行判断,其次:遍历清单,确保列出的原因是合理和逻辑的。
例如,为了回答第一部分中的一个问题:“为什么我们的客户取消了他们的订阅”,我们可以首先进行调研,了解我们的流失客户是否报告了他们取消订阅的原因。我们可以采访客户成功团队,以了解他们经常收到的客户投诉。然后,我们可以通过与决策者的头脑风暴会议,提出任何额外的原因。
潜在偏差可能比潜在原因更难发现,但对答案有重大影响。就像原因一样,偏差可以通过建立主题专长来确定。然而,与主要依赖知识的潜在原因不同,偏差识别通常需要创造性和建设性的思维。
一个好的起点是熟悉数据分析中常见的偏差类型,并推测这些偏差是否出现在你的使用案例中。一些常见的偏差类型包括确认偏差、选择偏差、历史偏差、生存者偏差、可得性偏差和异常值偏差2(更多信息请查看这篇文章)。
一个非常突出的生存者偏差例子涉及了亚伯拉罕·瓦尔德在二战期间的工作。作为哥伦比亚大学统计研究小组的一部分,瓦尔德及其团队被委托优化战机应携带的装甲量:如果战机装甲过多——由于重量过大,它们将无法飞行;如果装甲过少——它们将没有保护。在分析了安全返回但有弹孔的战机后,亚伯拉罕·瓦尔德建议将装甲添加到战机上没有弹孔的地方(而不是弹孔所在的位置)。为什么?因为分析只包括了幸存的战机,因此那些没有幸存的战机可能在一些关键区域有弹孔。如果这些关键区域被击中,它们未能返回,所以在关键区域增加装甲是有意义的[3]。了解整个故事请查看这篇文章作者亚历山德罗·巴齐。
来源: 维基共享资源
因果机制构成了潜在原因如何影响结果。没有因果机制,很难区分因果关系与巧合。这在选择推断因果关系的模型时起着重要作用。
一个很好的巧合例子是缅因州离婚率与人造黄油消费之间的相关性(见原始文章)。这两个趋势可能是平行的,但没有合理的机制来解释为什么一个会导致另一个。因此,我们不能认为缅因州离婚率的增加会导致人造黄油消费的增加,反之亦然。
图形因果模型应该被开发出来,以帮助识别原因和偏差,以及构成因果关系的机制。实质上,这些模型是包括所有原因和结果的有向图。开发一个图形模型以理解因果关系也可以帮助我们加深对主题的理解,并可以用来帮助我们与决策者沟通。
例如,图形因果模型可以帮助我们揭示混杂偏差。我们来自原因和偏差的变量不一定只是影响结果——它们实际上可以互相影响。如果某些变量影响我们的潜在原因和结果,那么我们就涉及到混杂偏差。为了消除这种偏差,我们应该控制所有共同的潜在原因。
假设我们正在研究在房屋上市出售之前进行粉刷是否会影响销售时间。我们可以假设收入较高可能会影响房主是否决定在出售之前粉刷房屋。然而,我们也可以认识到更高的收入意味着房主还可以利用减少销售时间的资源。这是混杂偏差的一个例子,我们应该在最终模型中控制收入因素。
由我在 draw.io 制作的图示
第 3 步:评估因果关系
现在我们有了一个结果、原因和偏差,以及构成我们依赖关系的机制,我们可以评估因果关系。这一步要求我们验证我们假设的观点是否合理。根据情况和可用资源,我们可以通过两种方式实现这一目标:1. 进行随机实验并比较结果,或 2. 利用历史数据开发统计模型来测量因果关系。
照片来自 Bradyn Trollip 在 Unsplash
进行随机实验,包括处理组和对照组,可以帮助我们通过确保实验中的两个(或更多)组具有相似的人群代表性来减少偏差。如果这些组在组成上相似且我们的样本量足够,我们应该能够比较组间的结果,并识别出结果差异是否显著。
在我们的房屋销售示例中,我们可以抽取两组房屋卖家(确保两组在房主人群中具有相同的代表性)。我们可以要求其中一组在上市前粉刷他们的房子,而要求另一组保持外墙涂料不变。然后,我们将比较两组之间的销售时间分布。通过统计测试,我们可以查看销售时间指标是否存在显著差异。
在实践中,由于许多原因,这可能很难实现,其中一些原因包括让志愿的房主参与我们的实验、确保实验有足够的资金以及确保我们的样本是随机的并且代表了房屋销售人群。然而,即使我们无法进行这样的实验,我们仍然有选择。
建立统计模型使用历史数据可以帮助我们控制混杂原因和偏差,并估计直接原因对结果的影响。使用回归等技术,我们可以为每个原因和通用偏差度量分配权重。我们可以通过使用历史数据训练模型来估计模型的参数(模型中的权重)。最终结果应该帮助我们理解变量对最终结果的因果效应。“即使我们不能使用随机对照试验来保持处理组和对照组之间的其他因素相等,回归也可以通过将这些相同的因素包括在模型中来做到这一点,即使数据不是随机的!” 1
然而,无论我们选择什么技术来测量因果关系,都需要注意我们的模型不能确定因果关系。我们可以在回归模型中包括数百个特征,但仅仅因为它们被包括在内以及因为它们有一定的权重并不能保证它们是结果的原因。因此,捕捉图形因果模型中的可能因果机制是很重要的,这样我们可以避免包括不相关的特征,并确保我们获得充分的结果。
案例研究
让我们继续讨论第二部分的案例研究,在那里我提出了一个关于火车迟到的描述性问题的回答策略。假设我们的决策者现在想知道“为什么火车会迟到?”按照本文中概述的步骤,我们可以制定以下策略来回答这个问题:
识别结果。 问题“火车为什么迟到”的结果是火车迟到(我们定义为“如果火车实际到达时间与预期到达时间之间的差异大于 1 分钟,则设置为 1 的二进制标志”)
识别潜在的原因和偏差。
-
为了识别潜在的原因,我们可以与决策者进行一些访谈和头脑风暴会议,我们可以在平台上观察火车并进行火车旅行,还可以与列车员和乘客交谈。潜在原因的示例包括平台装卸时间延迟、轨道施工、缺乏专用轨道导致的火车交会和超车延迟、危险(如树叶、冰雪)、火车年龄和技术问题。对于每个原因,我们还应该识别一个机制,通过该机制原因对结果产生影响。
-
为了识别潜在的偏差,我们可以熟悉各种偏差类型,并评估这些偏差是否适用于我们的用例。例如,选择偏差可能不会对我们造成问题,因为我们可以将所有火车纳入研究,而不是仅选择一部分火车。另一方面,我们可能存在幸存者偏差的情况,因为某些火车的机械问题可能导致火车从未到达,因此被排除在晚点火车数据集之外。
-
为了识别潜在的因果机制,我们应该识别每个潜在原因如何影响或冲击结果。例如,某种危险(如落叶或雪)可能导致火车晚点,因为它会让火车因危险而减速。我们可以假设火车的年龄影响火车的晚点情况,因为老旧的火车较慢。但这是真的吗?收集相关数据并进行探索性数据分析可以帮助我们验证这一因果机制是否合理。
我们可以整理一个图形因果模型,以评估我们提出的原因和偏差与结果的关系,并为每个原因概述一个潜在机制。此时,我们还可以进行更多的探索性数据分析,以发现我们原因之间的隐性关联,并选择最终的潜在原因纳入模型。例如,如果我们发现有技术问题的火车大多数是老旧火车,我们就不需要将火车年龄作为模型参数,因为它已经通过技术问题参数得到了暗示。
来源:由我在 draw.io 中制作
评估因果关系。 最后,我们准备评估因果关系。对我们而言,进行一系列实验以测试每个潜在原因将会是困难且成本高昂的。然而,由于我们有详细的火车时刻表、火车问题以及天气和轨道条件记录,我们应该努力构建一个回归模型,以验证可能的原因。在我们的案例中,我们可以使用可能的原因构建一个逻辑回归模型,以预测火车是否确实晚点。模型训练后,与模型参数相关的权重应指示每个原因对结果的影响。
在选择了具有非零权重的原因后,我们可以向决策者展示我们的发现,并回答他们最初的问题:“为什么火车会晚点?”
几个最终说明
这篇文章较长,但我希望它能够阐明一个复杂的主题,并使其更容易处理。几点说明:
-
我们可能无法确定过去或当前事件的确切原因。在大多数情况下,我们可以识别出最可能或最有可能的原因。因此,决策者需要承担一定的风险,并应对此有所了解。
-
图形因果模型可以成为与决策者沟通的一个极好工具,帮助揭示潜在情况之间的关联,并有助于识别偏差。
-
如果没有因果机制,一个与结果有非零关联的潜在原因可能只是巧合。
-
作为分析师,运用批判性思维技能非常重要,特别是在处理诊断问题时。这些问题可能有很多曲折,可能会将你引向错误的路径。
感谢阅读!在我的下一篇文章中,我将深入探讨预测问题,请保持关注,并在评论中告诉我你的想法!
来源
1 matheusfacure.github.io/python-causality-handbook/01-Introduction-To-Causality.html
2 www.metabase.com/blog/6-most-common-type-of-data-bias-in-data-analysis
[3] www.cantorsparadise.com/survivorship-bias-and-the-mathematician-who-helped-win-wwii-356b174defa6
从云存储中流式传输大数据文件
原文:
towardsdatascience.com/streaming-big-data-files-from-cloud-storage-634e54818e75?source=collection_archive---------9-----------------------#2023-01-25
高效处理大型文件的方法
Chaim Rand
·
关注 发表于 Towards Data Science ·13 分钟阅读·2023 年 1 月 25 日
--
图片由 Aron Visuals 提供,来源于 Unsplash
处理非常大的文件可能会给应用程序开发人员带来与高效资源管理和运行时性能相关的挑战。例如,文本文件编辑器可以分为能够处理大文件的编辑器和那些让你的 CPU 卡顿、让你的 PC 冻结、让你想尖叫的编辑器。当大型文件存储在远程存储位置时,这些挑战会更加严重。在这种情况下,必须考虑文件如何被拉取到应用程序中,同时考虑:带宽容量、网络延迟和应用程序的文件访问模式。在这篇文章中,我们考虑了我们的数据应用程序需要访问存储在云对象存储中的一个或多个大文件的情况。这是关于高效从云中获取数据的一系列文章中的一篇(例如,这里,这里 和 这里)。
在开始之前,让我们明确一下……在使用云存储时,通常不建议处理特别大的文件。如果你正在处理多个数据文件,也建议不要选择特别小的文件大小(因为对存储服务发出的多个请求会产生额外开销)。最佳文件大小因平台而异,但通常在几 MB 到几百 MB 之间。(如果你严重依赖云存储,可能需要设计一个简单的实验来测试这一点。)不幸的是,我们并不总是能控制数据设计过程,有时只能接受现有的条件。
另一个好的实践,特别是在处理大文件时,是选择支持部分文件读取的格式——也就是说,选择一种不需要加载整个文件即可处理其任何部分的格式。这类文件的几个例子包括:
-
一个简单的文本文件,
-
一个顺序数据集——一个包含按顺序分组到单个文件中的单独记录的数据集,
-
以列式格式存储的数据集,例如 Apache Parquet,这种格式专门设计用于仅加载选定的列,
-
允许从任意时间偏移播放的视频文件(大多数格式都支持这种功能)。
在这篇文章中,我们假设我们的大文件允许部分读取。我们将考虑几种在 Python 中读取文件内容的选项,并测量它们在不同应用场景下的表现。尽管我们的演示将基于 AWS 的对象存储服务 Amazon S3,但我们所写的内容同样适用于任何其他对象存储服务。请注意,我们选择的具体服务、API、库等,不应被视为对这些选项的偏好。云端数据流的最佳解决方案很大程度上依赖于项目和环境的细节,强烈建议在得出任何结论之前进行深入分析。
从 Amazon S3 直接下载
在这篇文章中,我们将假设我们直接从 Amazon S3 下载文件。然而,需要注意的是,有许多服务和/或解决方案在对象存储和应用程序之间包括一个中间步骤。例如,AWS 提供了 Amazon FSx 和 Amazon EFS 等服务,将数据镜像到云中的高性能文件系统中。AI Store 提供了一种基于 kubernetes 的解决方案,用于与数据消费应用程序相邻的轻量级存储堆栈。这些解决方案可能会缓解使用大文件时的一些挑战,例如,它们可能会减少延迟并支持更高的带宽。另一方面,它们通常会引入一系列与部署、维护、扩展性等相关的新挑战。此外,它们还会增加额外的费用。
比较性能测量指标
在接下来的几节中,我们将描述在 Python 中从 Amazon S3 拉取大文件的不同方法。为了比较这些方法的行为,我们将使用以下指标:
-
首次采样时间 — 读取文件中第一个样本需要多长时间。
-
平均顺序读取时间 — 在顺序遍历所有样本时,每个样本的平均读取时间是多少。
-
总处理时间 — 整个数据文件的总处理时间是多少。
-
平均随机读取时间 — 在读取任意偏移量的样本时的平均读取时间是多少。
不同的应用程序将对这些指标的优先级有不同的偏好。例如,视频流应用可能会优先考虑较低的首次样本时间,以提高观众体验。它还需要在任意偏移量处进行高效读取,以支持快进等功能。另一方面,只要平均每样本时间超过某个阈值(例如,每秒 30 帧),优化此指标就不那么重要。相比之下,深度学习训练应用可能会优先考虑减少平均顺序读取时间和总处理时间,以最小化训练流程中的潜在性能瓶颈。
玩具示例
为了方便讨论,我们创建了一个 2 GB 的二进制文件,并假设该文件包含 2,048 个数据样本,每个样本大小为 1 MB。下面的代码块包括以下片段:创建一个包含随机数据的文件并将其上传到 Amazon S3(使用 boto3),按顺序遍历所有样本,以及在非顺序文件偏移量处采样数据。对于此及所有后续代码片段,我们假设您的 AWS 账户和本地环境已被适当地 配置 以访问 Amazon S3。
import os, boto3
KB = 1024
MB = KB * KB
def write_and_upload():
# write 2 GB file
with open('2GB.bin', 'wb') as f:
for i in range(2*KB):
f.write(os.urandom(MB))
# upload to S3
bucket = '<s3 bucket name>'
key = '<s3 key>'
s3 = boto3.client('s3')
s3.upload_file('2GB.bin', bucket, key)
def read_sequential(f, t0):
t1 = time.time()
x = f.read(MB)
print(f'time of first sample: {time.time() - t1}')
print(f'total to first sample: {time.time() - t0}')
t1 = time.time()
count = 0
while True:
x = f.read(MB)
if len(x) == 0:
break
count += 1
print(f'time of avg read: {(time.time() - t1)/count}')
def fast_forward(f):
t1 = time.time()
total = 10
for i in range(total):
f.seek(i * 100 * MB)
t1 = time.time()
x = f.read(MB)
print(f'time of avg random read: {(time.time() - t1)/total}')
从 S3 下载到本地磁盘
我们考虑的第一个选项是将大型文件下载到本地磁盘,然后以读取任何其他本地文件的方式从那里读取它。有多种方法可以将文件下载到本地磁盘。我们在这里评估的三种方法是:Python boto3 API、AWS CLI 和 S5cmd。
Boto3 文件下载
在 Python 中从 Amazon S3 拉取文件的最直接方法是使用专用的 Boto3 Python 库。在下面的代码块中,我们展示了如何定义一个 S3 客户端 并使用 下载文件 API 将文件从 S3 拉取到本地路径。该 API 接受一个 TransferConfig 对象,其中包含调节下载行为的控制项。在我们的示例中,我们将设置保留为默认值。
import boto3, time
bucket = '<s3 bucket name>'
key = '<s3 key>'
local_path = '<local path>'
s3 = boto3.client('s3')
config = boto3.s3.transfer.TransferConfig(
multipart_threshold=8 * MB,
max_concurrency=10,
multipart_chunksize=8 * MB,
num_download_attempts=5,
max_io_queue=100,
io_chunksize=256 * KB,
use_threads=True,
max_bandwidth=None)
t0 = time.time()
s3.download_file(bucket, key, local_path, Config=config)
with open(local_path, 'rb') as f:
read_sequential(f,t0)
print(f'total time: {time.time()-t0}')
我们在本地环境中运行了这个脚本(以及所有后续脚本),并将结果在 10 次试验中取了平均值。不出所料,平均首次样本时间相对较高,约为 21.3 秒。这是因为我们需要等待整个文件下载完成后才能打开。一旦下载完成,顺序和任意样本的平均读取时间都非常微小,就像我们从其他本地文件中预期的一样。
Boto3 包含一个类似的 API,download_fileobj,用于将文件直接下载到内存中(例如,使用io.BytesIO对象)。然而,这在处理大文件时通常不推荐。
AWS CLI
AWS CLI工具提供了类似的命令行功能。AWS CLI 是用 Python 编写的,使用与 Boto3 相同的底层 API。一些开发者对这种使用方式感到更加舒适。下载配置设置通过AWS 配置文件进行控制。
import shlex, time
from subprocess import Popen
bucket = '<s3 bucket name>'
key = '<s3 key>'
local_path = '<local path>'
cmd = f'aws s3 cp s3://{bucket}/{key} {local_path}'
p = Popen(shlex.split(cmd)).wait()
with open(local_path, 'rb') as f:
read_sequential(f,t0)
print(f'total time: {time.time()-t0}')
不出所料,运行此脚本(使用默认配置设置)的结果与之前的 Boto3 结果几乎完全相同。
S5cmd
我们在之前的文章中详细介绍了S5cmd命令行工具,展示了它在并行从云存储中下载数百个小文件的价值。与之前的方法不同,S5cmd 是用Go 编程语言编写的,因此能够更好地利用底层资源(例如,CPU 核心和 TCP 连接)。有关 S5cmd 如何工作及其显著性能优势的更多细节,请查看这篇信息丰富的博客。S5cmd 的concurrency标志允许控制下载速度。下面的代码块演示了将 S5cmd 的concurrency设置为 10 的使用方法。
import shlex, time
from subprocess import Popen
bucket = '<s3 bucket name>'
key = '<s3 key>'
local_path = '<local path>'
s5cmd = f's5cmd cp --concurrency 10 s3://{bucket}/{key} {local_path}'
p = Popen(shlex.split(cmd)).wait()
with open(local_path, 'rb') as f:
read_sequential(f,t0)
print(f'total time: {time.time()-t0}')
很遗憾,我们未能再现 S5cmd 在50 GB 文件上的卓越性能提升。平均首次样本时间约为 23.1 秒,稍高于我们之前的结果。
多线程与单线程下载
之前的每种方法都在后台使用了多线程下载。在多线程下载中,多个线程并行运行,每个线程负责下载文件的一个不重叠的块。多线程下载对于及时从云端拉取大文件至关重要。为了演示其重要性,我们重新进行了 Boto3 实验,将use_threads标志设置为False,实际上禁用了多线程下载。这导致结果的平均首次样本时间飙升至 156 秒。
数据文件预取的艺术
预取是一种常用技术,用于遍历多个大型文件时。当进行预取时,应用程序会开始并行下载一个或多个后续文件,同时处理当前文件。通过这种方式,应用程序可以避免除第一个文件之外的所有文件的下载延迟。有效的预取需要适当的调整以达到最优结果。许多框架通过从云中预取数据来加速数据摄取速度。例如,PyTorch和TensorFlow都支持预取训练数据文件以优化深度学习训练。
从 S3 流式传输数据
一些应用程序可能愿意在平均每个样本的读取时间上做出妥协,以换取较低的首样本时间。在这一节中,我们演示了一种 Boto3 选项,用于从 S3 流式传输文件,以便在完成文件下载之前就开始处理它。我们描述的方法涉及创建一个Linux FIFO 管道,并将其传递给 Boto3 的download_fileobj API:
import os, boto3, time, multiprocessing as mp
bucket = '<s3 bucket name>'
key = '<s3 key>'
t0 = time.time()
os.mkfifo(local_path)
def stream_file():
s3 = boto3.client('s3')
with open(local_path, 'wb') as f:
s3.download_fileobj(bucket, key, f)
proc = mp.Process(target=stream_file)
proc.start()
with open(local_path, 'rb') as f:
read_sequential(f, t0)
print(f'total time: {time.time()-t0}')
确实,平均首样本时间降至约 2.31 秒(从超过 20 秒降下)。另一方面,平均每样本时间和总文件处理时间分别增加到约 0.01 秒和 24.7 秒。
从任意偏移量读取数据
一些应用程序要求能够在任意偏移量处仅读取文件的特定部分。对于这些用例,下载整个文件可能是极其浪费的。在这里,我们展示了如何使用 Boto3 的get_object数据流 API 下载文件的特定字节范围。下面的代码块演示了 API 在流式传输整个文件和读取任意数据块时的使用。
import boto3, time
bucket = '<s3 bucket name>'
key = '<s3 key>'
s3 = boto3.client('s3')
# stream entire file
t0 = time.time()
response = s3.get_object(
Bucket=bucket,
Key=key
)
f = response['Body']
read_sequential(f,t0)
print(f'total time: {time.time()-t0}')
# fast forward
total = 10
t0 = time.time()
for i in range(total):
response = s3.get_object(
Bucket=bucket,
Key=key,
Range=f'bytes={i*100*MB}-{i*100*MB+MB-1}'
)
f = response['Body']
x = f.read()
print(f'time of avg random read: {(time.time() - t0)/total}')
尽管此解决方案的首样本时间结果约为 1.37 秒,但总文件处理时间(约 119 秒)使其不适合顺序读取。其价值体现在读取任意样本的平均时间上——约 0.191 秒。
请注意,我们的示例没有利用任意偏移量是预先确定的这一事实。现实世界中的应用程序将使用这些信息来预取文件段并提升性能。
如上所述,从云中高效拉取大型文件依赖于并行多部分下载。实现这一点的一种方法是使用 Boto3 的get_object API 以刚刚展示的方式读取不连续的文件块。
使用 Amazon S3 Select 过滤数据
有时,我们所寻求的部分数据是从存储在 CSV、JSON 或 Apache Parquet 文件格式的大文件中提取的少量行和/或列。在这种情况下,我们可以简单地使用专用服务,如Amazon S3 Select来应用 SQL 过滤器。要对多个文件运行 SQL 过滤器,你可以考虑使用Amazon Athena。这两种服务都允许你将数据检索限制为所需的特定信息,避免了拉取一个或多个大型文件的高开销。请务必查看文档以了解更多信息。
使用 Goofys 挂载 S3 数据
到目前为止,我们讨论的所有解决方案都涉及直接从云存储中提取数据。其他解决方案则将云存储访问暴露给应用程序,作为(类似 POSIX 的)文件系统。Goofys是一个流行的基于FUSE的库,用于从 Amazon S3 中读取数据。下面的命令演示了如何将 S3 桶挂载到本地文件路径。
goofys -o ro -f <s3 bucket name> <local path>
一旦配置了 goofys 挂载,应用程序可以通过指向本地路径的方式访问大型文件,如下面的代码块所示:
bucket = '<s3 bucket name>'
key = '<s3 key>'
mount_dir = '<local goofys mount>'
sequential = True # toggle flag to run fast_forward
t0 = time.time()
with open(f'{mount_dir}/{key}', 'rb') as f:
if sequential:
read_sequential(f, t0)
print(f'total time: {time.time()-t0}')
else:
fast_forward(f)
基于 goofys 的解决方案导致了约 1.36 秒的首次样本时间、约 27.6 秒的总文件处理时间和约 0.149 秒的读取任意样本的平均时间。
应该注意的是,在底层,goofys 尝试优化响应时间,即使是以增加对 Amazon S3 的额外调用为代价(例如,预取数据块,即使在它们被请求之前)。根据你的设置和应用程序的数据访问模式的细节,这可能导致相对于其他解决方案,Amazon S3 成本略有增加。Goofys 包括多个设置来控制其行为,例如,“ — cheap”标志用于在性能和潜在较低成本之间进行权衡。请注意,这些控制是在定义挂载时应用的。与基于 Boto3 的解决方案相反,goofys 不允许你在运行时调整控制(例如,块大小、预取行为等),以支持不同的数据摄取模式。
另一个需要注意的点是,goofys 读取的数据部分会被缓存。因此,如果在缓存中仍然存在相同的数据部分时再次读取,该响应将是即时的。在运行测试时请记住这一点,并确保在每次实验之前重置 goofys 挂载。
确保查看文档以获取更多关于 goofys 如何工作、如何控制其行为、其限制等方面的详细信息。
比较结果
下表总结了我们进行的实验结果:
从 S3 拉取 2GB 文件的比较结果(作者提供)
结果表明,处理大数据文件的最佳方法可能取决于应用程序的具体需求。根据这些结果,旨在最小化首次样本时间的应用程序将从基于 goofys 的解决方案中获益最大,而旨在最小化总处理时间的应用程序则会选择使用 Boto3 将文件下载到本地磁盘。
我们强烈建议不要依赖此表格来为你的应用程序做出任何决策。测试是在一个非常特定的环境下进行的,使用的是本文撰写时可用的工具版本。比较结果可能会因以下因素而有很大差异:设置、网络负载、与云存储设施的距离、所用工具的版本、应用程序的资源需求等。
我们的结果没有考虑不同解决方案之间 CPU、内存和其他系统资源的使用差异。在实际应用中,系统资源的负载应予以考虑,因为它们可能会影响应用程序的行为。
在做出任何设计决策之前,请务必进行自己的深入、基于用例的实验。
总结
在这篇文章中,我们讨论了从云存储中提取大型数据文件的话题。我们看到,最佳的方法可能会根据数据应用的具体需求而有所不同。我们的讨论突显了一个在所有关于云服务的文章中都存在的主题——尽管云为技术发展和进步开辟了广泛的新机会,但它也带来了许多独特而令人兴奋的挑战,并促使我们重新思考常见的应用设计原则。
一如既往,欢迎提出评论、问题和更正。
数据工程中的流数据
原文:
towardsdatascience.com/streaming-in-data-engineering-2bb2b9b3b603
流数据管道和实时分析
💡Mike Shakhomirov
·发表于Towards Data Science ·阅读时间 9 分钟·2023 年 12 月 12 日
--
图片由DESIGNECOLOGIST提供,在Unsplash上
流数据是最受欢迎的数据管道设计模式之一。将事件作为单个数据点创建了从一个点到另一个点的持续数据流,从而为实时数据摄取和分析提供了机会。如果你想了解数据流并学习如何构建实时数据管道,这篇文章适合你。了解如何测试解决方案,并模拟事件流的测试数据。这篇文章是一个绝佳的机会,让你掌握一些受欢迎的数据工程技能,使用流行的流处理工具和框架,即 Kinesis、Kafka 和 Spark。我想谈谈数据流的好处、示例和用例。
数据流究竟是什么?
流数据,也称为事件流处理,是一种数据管道设计模式,当数据点不断从源头流向目的地时使用。这可以实时处理,使实时分析功能能够快速对数据流和分析事件作出反应。由于流处理,应用程序可以对新数据事件触发即时响应,通常这将是处理企业级数据最受欢迎的解决方案之一。
只要在点 A 和点 B 之间进行数据处理,就会有一个数据管道 1。
流数据管道示例。图片由作者提供
在这个例子中,我们可以创建一个ELT 流处理数据管道到AWS Redshift。AWS Firehose delivery stream可以提供这种无缝集成,将数据直接创建到数据仓库表中。然后,数据将被转化,以使用AWS Quicksight作为 BI 工具生成报告。
假设我们需要创建一个报告仪表盘来展示公司中的收入来源。在许多场景中,业务需求是实时生成洞察。这正是我们需要使用流处理的情况。
数据流可以由各种数据源生成,例如物联网、服务器数据流、营销应用内事件、用户活动、支付交易等。这些数据可以以不同格式流动,并且经常变化。流处理模式的理念是实时应用 ETL 并无缝处理事件流。
每当我们需要处理毫秒级的数据延迟时,流处理是正确的选择。
考虑下面的例子以更好地理解它。所有应用程序都使用 OLTP 数据库[4],例如 MySQL。你的应用程序也是其中之一,但你需要将这些数据存储到数据仓库解决方案(DWH),即 Snowflake 或 BigQuery。
## 数据建模为数据工程师
初学者的终极指南
towardsdatascience.com
使用批量数据管道解决方案,我们可能希望从 MySQL 加载到 DWH 一次/每天/每小时/每五分钟等。流处理则相反,它将使用专用系统,例如 Kafka Connect。它会在数据进入数据库时立即处理数据。
流行的数据流工具
让我们深入了解过去几年中被证明最有用的流数据平台和框架。
-
Apache Spark — 用于大规模分析和复杂数据转换的分布式数据计算框架
-
Apache Kafka — 一个实时数据管道工具,具有分布式消息系统用于应用程序
-
AWS Kinesis — 一个用于分析和应用程序的实时流平台
-
Google Cloud Dataflow — 谷歌的实时事件处理和分析管道的流平台
-
Apache Flink — 一个分布式流数据平台,旨在进行低延迟数据处理。
几乎所有这些平台都有它们的托管云服务(如 AWS Kinesis、Google Cloud Dataflow),并且可以与其他服务(如存储(S3)、队列(SQS、pub/sub)、数据仓库(Redshift)或 AI 平台)无缝集成。
所有这些工具都可以部署在 Kubernetes、Docker 或 Hadoop 上,旨在解决一个问题——处理高容量和高速度的数据流。
数据流的好处
流数据管道设计模式帮助组织主动减轻与数据处理延迟相关的不利业务事件的影响,例如各种损失和中断、客户流失和财务衰退。由于今天业务需求的复杂性,传统的批处理数据处理是‘不可行’的解决方案,因为它只能处理累积时间的事务数据组。
所以使用数据流的业务优势如下:
-
提高客户满意度,进而提高客户保留率
-
减少操作损失,因为它可以提供关于系统中断和漏洞的实时洞察。
-
投资回报率提高,因为公司现在可以更快地对业务数据做出反应,对客户需求和市场趋势的响应能力提高。
主要的技术优势在于数据处理,因为它以严格的逐个处理方式运行事件处理。与批处理处理相比,它具有更好的故障容忍性,如果管道中的一个事件因某些原因无法处理,那么只有这个事件会受到影响。在批处理管道中,由于单个数据点可能具有错误的模式或数据格式,整个数据处理块会因此失败。
流数据管道的主要缺点是成本
每次我们的流处理器达到端点时,它都需要计算能力。通常,流数据处理会导致与特定数据管道相关的更高成本。
构建流数据管道的挑战
- 故障容忍性 — 我们能否设计和构建一个能够处理单一数据事件处理失败的数据平台?数据通常来自不同的数据源,甚至可能以不同的格式出现。在设计带有流组件的数据平台时,数据的可用性和持久性成为重要的考虑因素[3]。
数据平台架构类型
它能多好地满足你的业务需求?选择的困境。
towardsdatascience.com
-
排队和排序 — 数据流中的事件必须正确排序。否则,数据处理可能会失败。例如,如果排序不正确,应用内消息将没有意义。
-
可扩展性 — 应用程序需要扩展。就这么简单。设计一个能够很好应对来自源的事件数量增加的数据管道并非易事。能够为数据管道添加更多资源和数据处理能力是一个强大数据平台的重要组成部分。
-
数据一致性 — 在分布式数据平台中,数据经常是并行处理的。这可能会成为挑战,因为在一个数据处理器中数据可能已经被修改,而在另一个处理器中变得陈旧。
现实世界中的一个例子
让我们看一下这个使用 AWS Kinesis 和 Redshift 构建的流数据管道示例。
示例管道。图片由作者提供
Amazon Kinesis Data Firehose 是一个 ETL 服务,可以高可靠性地收集、转换并分发流数据到数据湖、数据存储和分析服务。
我们可以用它将数据流传输到 Amazon S3,并将数据转换为分析所需的格式,无需开发处理管道。它对于机器学习(ML)管道也非常适合,其中模型用于检查数据并预测推断端点,因为数据流向其目标。
Kinesis Data Streams 与 Kinesis Data Firehose
Kinesis Data Streams 主要关注于消费和存储数据流。Kinesis Data Firehose 旨在将数据流传递到特定的目标。两者都可以消费数据流,但使用哪个取决于我们希望数据流去往何处。
AWS Kinesis Data Firehose 允许我们将数据流重定向到 AWS 数据存储。Kinesis Data Firehose 是收集、处理和加载数据流到 AWS 数据存储的最直接方法。
Amazon Kinesis Data Firehose 支持批处理操作、加密和流数据压缩,以及自动化的每秒 TB 级别的可扩展性。Firehose 可以无缝集成 S3 数据湖、RedShift 数据仓库解决方案或 ElasticSearch 服务。
AWS Kinesis Data Streams 是一个 Amazon Kinesis 实时数据流解决方案,具有卓越的可扩展性和耐用性,数据流全天候 24/7 可用给任何消费者。这使得它比 Kinesis Data Firehose 更昂贵。
如何使用 AWS CloudFormation 创建 Firehose 资源
请查看下面的 CloudFormation 模板。它部署了包括我们需要的 Firehose 在内的所有必要资源。
AWSTemplateFormatVersion: 2010-09-09
Description: >
Firehose resources relating to data generation.
Parameters:
Environment:
AllowedValues:
- staging
- production
Description: Target environment
Type: String
Default: 'staging'
DataLocation:
Description: S3 data lake bucket name.
Type: String
Default: data.staging.aws
Resources:
MyDataStream:
Type: AWS::KinesisFirehose::DeliveryStream
Properties:
DeliveryStreamName: !Sub 'my-event-${Environment}'
DeliveryStreamType: DirectPut
ExtendedS3DestinationConfiguration:
BucketARN:
- !Sub 'arn:aws:s3:::${DataLocation}' # For example: 'arn:aws:s3:::data.staging.aws'
BufferingHints:
IntervalInSeconds: 300
SizeInMBs: 30
CompressionFormat: UNCOMPRESSED
Prefix: !Sub 'events/my-event-${Environment}/'
RoleARN: !GetAtt AccessRole.Arn
AccessRole:
Type: AWS::IAM::Role
Properties:
AssumeRolePolicyDocument:
Statement:
- Effect: Allow
Principal:
Service:
- firehose.amazonaws.com
Action:
- sts:AssumeRole
Path: /
Policies:
- PolicyName: !Sub '${AWS::StackName}-AccessPolicy'
PolicyDocument:
Statement:
- Effect: Allow
Action:
- s3:AbortMultipartUpload
- s3:GetBucketLocation
- s3:GetObject
- s3:ListBucket
- s3:ListBucketMultipartUploads
- s3:PutObject
Resource:
- !Sub 'arn:aws:s3:::${DataLocation}'
- !Sub 'arn:aws:s3:::${DataLocation}/*'
# - 'arn:aws:s3:::data.staging.aws' # replace with your S3 datalake bucket
# - 'arn:aws:s3:::data.staging.aws/*'
- Effect: Allow
Action:
- kinesis:DescribeStream
- kinesis:GetShardIterator
- kinesis:GetRecords
Resource:
- !Sub 'arn:aws:kinesis:${AWS::Region}:${AWS::AccountId}:stream/my-event-${Environment}'
可以使用 AWS CLI 工具在 AWS 中部署它。我们需要在命令行中运行这个(在你的账户中替换为唯一的存储桶名称):
./deploy-firehose-staging.sh s3-lambda-bucket s3-data-lake-bucket
我们的 shell 脚本如下所示:
#!/usr/bin/env bash
# chmod +x ./deploy-firehose-staging.sh
# Run ./deploy-firehose-staging.sh s3-lambda-bucket s3-data-lake-bucket
STACK_NAME=FirehoseStackStaging
LAMBDA_BUCKET=$1 #datalake-lambdas.aws # Replace with unique bucket name in your account
S3_DATA_LOCATION=$2 #data.staging.aws # S3 bucket to save data, i.e. datalake
# Deploy using AWS CLI:
aws \
cloudformation deploy \
--template-file firehose_stack.yaml \
--stack-name $STACK_NAME \
--capabilities CAPABILITY_IAM \
--parameter-overrides \
"Environment"="staging" \
"DataLocation"=$S3_DATA_LOCATION #"data.staging.aws"
已创建 Firehose 资源。图片由作者提供
现在我们需要创建一个事件生产者。我们可以使用 Python 完成这个操作,app.py
的代码如下:
import boto3
kinesis_client = boto3.client('firehose', region_name='eu-west-1')
...
response = client.put_record_batch(
DeliveryStreamName='string',
Records=[
{
'Data': b'bytes'
},
]
)
put_record_batch
方法可以在一次调用中将多个数据记录写入交付流,这比单条记录写入方式能提供更好的每生产者吞吐量。PutRecord
用于将单条数据记录写入交付流。在本教程中选择哪个方法由你决定。
我们可以在 app.py
中使用下面的辅助函数生成一些合成数据。
def get_data():
'''This function will generate random data for Firehose stream.'''
return {
'event_time': datetime.now().isoformat(),
'event_name': random.choice(['JOIN', 'LEAVE', 'OPEN_CHAT', 'SUBSCRIBE', 'SEND_MESSAGE']),
'user': round(random.random() * 100)}
现在这些数据可以通过以下方式发送到我们的事件生产者:
try:
print('Sending events to Firehose...')
for i in range(0, 5):
data = get_data()
print(i, " : ", data)
kinesis_client.put_record(
DeliveryStreamName=STREAM_NAME,
Record={
"Data":json.dumps(data)
}
)
processed += 1
print('Wait for 5 minutes and Run to download: aws s3 cp s3://{}/events/ ./ --recursive'.format(S3_DATA))
# For example, print('Wait for 5 minutes and Run to download: aws s3 cp s3://data.staging.aws/events/ ./ --recursive')
except Exception as e:
print(e)
完成!我们已经创建了一个简单的流数据管道,将汇总结果输出到云存储(AWS S3)。
在命令行中运行 python app.py
:
事件连接器示例。作者提供的图片
查看下面的教程,了解更高级的数据管道示例 2
## 使用 Redshift Serverless 和 Kinesis 构建流数据管道
面向初学者的端到端教程
towardsdatascience.com
结论
项目理想的流数据平台并不存在。流设计有其好处,但在使用时也会遇到一些明显的挑战。选择哪个流工具不是一个容易的决定。这取决于你的业务目标和功能数据需求。你可能需要尝试并比较多个流平台,考虑功能、性能、成本、易用性和兼容性等特征。它会是一个机器学习管道吗?我们需要处理分区、窗口和连接吗?我们需要高吞吐量、容错性还是低延迟?
不同的流框架具有不同的能力,例如,Kafka 有一个方便的会话 库,可以很容易地集成到你的分析管道中。
我们的管道需要什么频率的数据传输和消费?它将交付到数据仓库解决方案还是数据湖中?一些平台比其他平台提供更好的集成功能。
另一个重要的考虑因素是必须对流数据进行的数据处理和分析的类型和复杂性。
我建议根据你自己数据管道场景和公司主要利益相关者收集的需求来创建一个原型。最终的流数据管道应该是能够为业务增值并满足你的数据工程目标的。
推荐阅读:
1 towardsdatascience.com/data-pipeline-design-patterns-100afa4b93e3
2 towardsdatascience.com/building-a-streaming-data-pipeline-with-redshift-serverless-and-kinesis-04e09d7e85b2
[3] medium.com/towards-data-science/data-platform-architecture-types-f255ac6e0b7
[4] medium.com/towards-data-science/data-modelling-for-data-engineers-93d058efa302
使用笔记本风格工作区简化 dbt 模型开发
原文:
towardsdatascience.com/streamline-dbt-model-development-with-notebook-style-workspace-eb156fe6e81
互动式构建和编排数据模型
Khuyen Tran
·发布于 Towards Data Science ·7 分钟阅读·2023 年 6 月 5 日
--
作者提供的图片
最初发布于 https://mathdatasimplified.com 2023 年 6 月 5 日。
动机
dbt(数据构建工具)是一个强大的数据仓库数据转换工具。
然而,它确实存在一些限制,包括以下几点:
-
缺乏输出预览: 使用 dbt core 时,无法在构建模型之前预览模型的输出,这可能会阻碍对数据转换过程的验证和迭代。
-
特征工程的局限性: 由于 SQL 是 dbt 的主要语言,因此在执行复杂的特征工程任务时存在一定的局限性。为了执行超出 SQL 能力的复杂特征工程,可能需要额外的工具或语言,如 Python。
-
部分 ETL 解决方案: 虽然 dbt 在数据转换方面表现出色,但它并未提供全面的端到端解决方案来处理数据加载、数据提取和编排等任务。
为了缓解这些挑战,dbt Cloud 提供了诸如开发数据模型、预览输出和通过用户友好的网页界面调度 dbt 作业等功能。然而,随着项目数量的增加,使用 dbt Cloud 的成本可能会变得相当高。
作者提供的图片
Mage + dbt 集成
dbt cloud 的一个免费替代品是 Mage,这是一个开源的数据管道工具,用于数据转换和集成任务。
Mage 无缝地与 dbt 互补,带来了诸多好处,包括:
-
集成的基于网页的 IDE: Mage 提供了一个方便的基于网页的 IDE,你可以在一个界面内轻松开发和探索数据模型。
-
语言灵活性: 使用 Mage,你可以将不同工具和语言的优势与 dbt 结合,以增强数据处理能力。
-
可视化 dbt 模型输出: Mage 提供了内置的可视化功能,允许用户通过几次点击轻松地可视化 dbt 模型生成的输出。
-
数据提取和加载: 除了数据转换,Mage 还提供了数据提取和加载的功能,实现更全面的端到端数据管道解决方案。
-
管道调度和重试机制: Mage 允许你调度数据管道并自动重试失败的组件,确保数据集成过程的顺利和可靠执行。
让我们深入探讨这些功能。
随意通过克隆这个 GitHub 仓库来探索和实验源代码:
[## GitHub - khuyentran1401/dbt-mage
当前无法执行该操作。你在另一个标签页或窗口中登录了。你在另一个标签页或窗口中登出了…
github.com
设置
安装 Mage
你可以通过 Docker、pip 或 conda 安装 Mage。本文将使用 Docker 安装 Mage 并初始化项目。
docker run -it -p 6789:6789 -v $(pwd):/home/src mageai/mageai /app/run_app.sh mage start [project_name]
例如,我们将项目命名为“dbt_mage”,因此命令变为:
docker run -it -p 6789:6789 -v $(pwd):/home/src mageai/mageai /app/run_app.sh mage start dbt_mage
其他安装 Mage 的方式请参考 这里。
创建管道
在浏览器中打开 localhost:6789/
查看 Mage UI。
点击“新建”,选择“标准(批量)”来创建一个新的批量管道。将其重命名为“dbt_pipeline”。
作者提供的图片
安装依赖
由于我们将使用 BigQuery 作为 dbt 的数据仓库,我们需要通过将 dbt-bigquery
添加到“requirements.txt”文件中并点击“安装包”来安装它。
作者提供的图片
创建 dbt 项目
要创建 dbt 项目,请导航到右侧面板并点击终端按钮。
作者提供的图片
移动到项目下的“dbt”文件夹并执行命令 dbt init
:
cd dbt_mage/dbt
dbt init demo -s
该命令将“demo”文件夹添加到 dbt 目录。
作者提供的图片
右键点击“demo”文件夹,创建一个名为“profiles.yml”的新文件。在此文件中指定你的 BigQuery 凭证。
作者提供的图片
请参考 此文档 以获取指定其他数据平台提供商凭证的说明。
现在设置完成,我们准备好探索 Mage 的令人兴奋的功能了!
集成的基于 Web 的 IDE
Mage 提供了一个用户友好的基于 Web 的 UI,简化了 dbt 模型的创建。它的独特功能允许进行交互式代码开发,类似于在 Jupyter Notebook 中工作。每个块都作为一个独立的可执行代码文件。
图片来源:作者
要创建一个新的 dbt 模型,点击“DBT model”并提供名称和位置。
图片来源:作者
编写查询并点击“Compile & preview”以预览查询结果。
预览结果后,通过点击三点图标并选择“Run model”来执行模型。
图片来源:作者
你可以在同一个编辑器中创建额外的模型,只需再次点击“DBT model”。
图片来源:作者
使用{{ ref('model_name') }}
来引用另一个模型,就像在 dbt 中一样。
图片来源:作者
语言灵活性和 ETL 功能
Mage 使你能够将 dbt 模型与其他语言(包括 Python、R 或 SQL)结合,用于数据加载、转换和导出目的。
图片来源:作者
例如,我们创建两个 Python 数据加载器——一个用于美国数据,一个用于印度数据。
图片来源:作者
然后,我们将结合一个 dbt 块以连接这些数据加载器的输出。通过点击“Edit parent block”并选择数据加载器块来设置join_tables
块的父块。
图片来源:作者
当一个 dbt 模型依赖于非 dbt 上游块时,Mage 会自动将该块的源添加到“dbt/demo/models/mage_sources.yml”文件中。
图片来源:作者
现在你可以在 dbt 块中利用 Python 块的输出。
图片来源:作者
然后,我们可以设置一个名为convert_object_to_int
的变换器块,将其作为join_tables
的下游块以处理其输出。
图片来源:作者
可视化 dbt 模型输出
虽然传统工具如 Tableau 可以用于可视化,但 Mage 提供了一个集中化的解决方案,可以在一个地方处理和分析 dbt 块的输出。
为了演示这一点,我们创建另一个块,名为convert_week_to_datetime
,它将Week
列转换为 datetime 类型。
图片来源:作者
点击“Add chart”图标,选择“Time series line chart”,并创建一个时间序列可视化。
图片来源:作者
你将看到如下的折线图:
作者提供的图片
管道调度和重试机制
Mage 使你能够调度管道运行,并包含一个重试机制,以处理失败的块而无需重新运行整个管道。
要立即执行管道,请点击左侧面板上的“触发器”图标,然后选择“立即运行管道”。
你可以通过选择新创建的管道触发器并点击“查看块运行”来监控所有块运行的实时状态。
作者提供的图片
要调度管道,请点击“创建新触发器”,选择“调度”,并定义所需的频率。
作者提供的图片
点击“重试未完成的块”以重试失败的块而无需重新启动整个管道。
作者提供的图片
缺点
虽然 Mage 是 dbt 的绝佳补充,但在使用 Mage 时需要考虑一些缺点:
-
项目复杂性增加: 将 Mage 和 dbt 集成可能会增加项目结构的复杂性。
-
更长的错误信息: 由于 Mage 在块代码周围添加了额外的代码,错误信息通常比标准错误信息要长。
-
学习曲线: 虽然 Mage 提供了直观的用户体验,但熟悉这个新工具需要一些时间和努力。
结论
如果你寻求提高效率并且愿意接受项目复杂性的轻微增加,Mage 是补充你的 dbt 项目的理想工具。
我喜欢写关于数据科学概念的文章,并玩弄各种数据科学工具。你可以通过以下方式保持最新:
-
订阅我的新闻通讯,内容在数据科学简化。
-
在LinkedIn和Twitter上与我联系。
使用 GPT-3 精简你的文档
原文:
towardsdatascience.com/streamline-your-documentation-with-gpt-3-5d9f2bbf217c
使用人工智能生成 Python 文档字符串
Mikhail Klassen
·发表于 Towards Data Science ·6 分钟阅读·2023 年 1 月 16 日
--
图像由 DALL-E 生成,提示语为:“一个机器人在计算机终端检查代码。赛博朋克风格,逼真。”
GPT-3,OpenAI 开发的最新语言模型,具有生成类人文本的能力,使其成为各种自然语言处理任务的强大工具。
该模型还进行了编程语言的训练。一个应用场景是生成 Python 函数的文档字符串(docstrings)。
什么是文档字符串?
文档字符串是出现在 Python 函数中的第一个语句。它提供了函数及其输入和输出的简要描述,使其他开发人员更容易理解和使用代码。编写清晰且信息丰富的文档字符串可能耗时且乏味,尤其是在拥有许多函数的大型项目中。
# A docstring example
def square(n):
"""Takes in a number n and returns the square of n"""
return n**2
通过在一个有用的提示中提供 Python 函数的源代码,可以使用 GPT-3 来生成文档字符串。
示例:计算一组数字的平均值
假设你编写了一个计算数字列表平均值的函数。你希望 GPT-3 创建文档字符串。
以下是提示的示例:
# Python 3.7
def mean_of_arr(arr):
return sum(arr)/len(arr)
# An elaborate, high quality docstring for the above function:
"""
制定正确的提示非常重要。
注意我们如何通过以下方式帮助 GPT-3:
-
告诉它编程语言(Python 3.7),
-
提供关于如何显式要求生成详细且高质量的函数文档字符串的评论,
-
然后添加一行带有三个双引号的内容,这也是文档字符串的标准起始模式。
GPT-3 生成了以下响应
This function takes an array of numbers and returns the mean of the array.
The mean is the sum of the numbers divided by the length of the array.
GPT-3 甚至知道如何缩进文本块。
需要注意的是,生成的文档字符串可能不完美,可能需要一些编辑。然而,GPT-3 可以在编写文档字符串的过程中节省大量时间和精力。
使用文档字符串生成器自动创建文档字符串
将你的代码复制并粘贴到 OpenAI 的 playground 中,仅仅让它为你生成文档字符串有点乏味。
我们已经看到 AI 工具直接嵌入我们的开发环境中,如 GitHub Copilot 或 Amazon CodeWhisperer.
未来的程序员将是人类和 AI 的混合体。让我们自己开始玩转这一能力。
假设你在使用 Jupyter notebook,如在 Colab,并且你想快速为一个新函数生成文档字符串。
假设你已经有一个 OpenAI 账户,生成你的 API 密钥 并将其粘贴到下面的示例中,复制到你的 notebook。
import openai # Install via "pip install openai"
import inspect
openai.api_key = "YOUR_OPENAI_API_KEY_GOES_HERE"
def generate_docstring(my_func, python_version=3.7):
# Grab the source code from the function as a string
source_code = inspect.getsource(my_func)
# Remove any existing docstring
if my_func.__doc__ is not None:
source_code = source_code.replace(my_func.__doc__,"")\
.replace("\"\"\"","")
# Wrap the source code with hints for GPT-3
prompt = f"# Python {python_version}\n\n" \
+ source_code \
+ "\n# An elaborate, high quality docstring for the above function:\n\"\"\""
# Send prompt to GPT-3
response = openai.Completion.create(
model="code-davinci-002",
prompt=prompt,
temperature=0,
max_tokens=150,
top_p=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
stop=["#", "\"\"\""]
)
# Return the generated docstring
docstring = response["choices"][0]["text"]
return docstring
inspect
库的 getsource
函数从你传递给它的任何函数中获取源代码,并将其作为字符串返回。
如果函数已经有了文档字符串,我不想让 GPT-3 产生偏见,因此我包括了一些代码来剥离任何现有的文档字符串。
我假设代码是 Python,但我将 Python 版本留作可选变量。
实验
现在进入有趣的部分。让我们看看它在几个示例中的表现。
选择排序
这是一种时间复杂度为 O(n²) 的排序算法,意味着它在大列表上效率低下,但它的优点是简单。
def sort(arr):
for i in range(len(arr)):
min_idx = i
for j in range(i+1, len(arr)):
if arr[min_idx] > arr[j]:
min_idx = j
arr[i], arr[min_idx] = arr[min_idx], arr[i]
return arr
我明确选择不将函数命名为 selection_sort
,因为我不希望 GPT-3 仅仅从它在训练语料库中见过的选择排序示例中获取文档字符串。
有趣的是,GPT-3 推断出这里使用的是选择排序算法。很可能在它的语料库中,选择排序算法就是以这种方式编写的。以下是 GPT-3 返回的结果:
This function sorts an array of integers in ascending order.
It uses the selection sort algorithm.
Parameters
----------
arr : list
The array to be sorted.
Returns
-------
list
The sorted array.
那是一个相当适当的文档字符串。
将数据拆分为训练集和测试集
在对数据进行机器学习模型训练之前,常见的一步是将数据拆分为训练集和测试集。
假设你有一个函数来完成这个任务,并将结果写入两个文件:
import pandas as pd
from sklearn.model_selection import train_test_split
def randomly_split_dataset(folder, filename, split_ratio=[0.8, 0.2]):
df = pd.read_json(folder + filename, lines=True)
train_name, test_name = "train.jsonl", "test.jsonl"
df_train, df_test = train_test_split(df, test_size=split_ratio[1],
random_state=42)
df_train.to_json(folder + train_name, orient='records', lines=True)
df_test.to_json(folder + test_name, orient='records', lines=True)
return
仅将 randomly_split_dataset
函数传递给我们的文档字符串生成器:
my_docstring = generate_docstring(randomly_split_dataset)
print(my_docstring)
GPT-3 返回:
This function takes in a folder and filename, and splits the dataset into
two files, train.jsonl and test.jsonl, with the split ratio specified by
the user.
Parameters
----------
folder : str
The folder where the dataset is located.
filename : str
The name of the dataset file.
split_ratio : list
The ratio of the split, with the first element being the training set
and the second element being the test set.
Returns
-------
None
Examples
--------
>>> randomly_split_dataset("data/", "train
不错。它在示例中突然停止,但我发现可以通过增加 max_tokens
的数量来改进这一点。
现在进行最后的测试:记录它自己。
为文档字符串生成器编写文档字符串
也许最令人满意的结果是让文档字符串生成器为自己编写文档字符串。
my_docstring = generate_docstring(generate_docstring)
print(my_docstring)
结果是一个完全可以接受的文档字符串,我们现在可以将其复制并粘贴到函数中。
This function takes a function as input and returns a docstring for that function.
It uses the OpenAI API to generate the docstring.
Parameters
----------
my_func : function
The function for which you want a docstring.
python_version : float
The version of Python you are using.
Returns
-------
docstring : str
The generated docstring.
结论
你的体验可能会有所不同。在我自己的实验中,结果有时会显得怪异或与我测试的功能无关。其他时候,我被明显的洞察力震惊。调整 temperature
变量可能值得尝试,以探索更多“创造性”的输出。
需要注意的是,GPT-3 实际上并不理解代码。它不能对你的代码进行推理。GPT-3 是在数十亿行代码上进行训练的,其中包括许多高质量的 docstrings。这个拥有 1750 亿参数的神经网络非常擅长预测在给定提示之后最有可能出现的 token(单词等)序列。
除了生成 docstrings,GPT-3 还可以用于许多其他自然语言处理任务,如文本总结、问答和语言翻译。GPT-3 的潜在应用广泛,未来它的使用情况将非常有趣。
总结来说,GPT-3 可以成为生成 Python 函数 docstrings 的有用工具,帮助开发者以最小的努力编写清晰且有信息量的文档。由于该模型仍在开发中,我们可以期待它在软件开发的其他方面如何改进。
如果你喜欢阅读这些故事并希望支持我作为作者,请考虑注册成为 Medium 会员。每月 $5,可访问我所有的写作以及成千上万其他作家的作品。如果你通过 我的链接 注册,我将获得一小部分佣金,对你没有额外费用。
## 使用我的推荐链接加入 Medium — Mikhail Klassen
阅读 Mikhail Klassen(以及 Medium 上成千上万其他作者)的每个故事。你的会员费用直接支持…
mikhailklassen.medium.com
优化 Azure 虚拟机性能并降低成本:提升效率的可靠策略
原文:
towardsdatascience.com/streamlining-azure-vm-performance-while-slashing-costs-proven-strategies-for-optimal-efficiency-23a9bfc7fe62
在不妨碍效率的前提下降低成本的技术
Subha Ganapathi
·发表于 Towards Data Science ·8 分钟阅读·2023 年 7 月 13 日
--
照片来源 Growtika 于 Unsplash
概述
在设置 Azure 虚拟机时,重要的是要在环境配置之前了解定价模型和服务内容。如果不这样做,我们可能会面临高额账单,而这些费用本可以通过遵循成本优化策略来避免。本文将讨论实用的策略和见解,帮助你避免这种情况并更好地控制成本。我们还将讨论多少成本过高以及 Azure 定价中高级功能的作用。
请注意,本文中的图像包含 Azure 的尺寸和配置。这些图像来源于写作时某个地区的 Azure 门户。它们仅用于演示目的,不应视为你所在地区当前可用或配置的指示。建议参考官方 Azure 文档(或你的 Azure 门户)以获取最新和准确的成本及定价信息。
让我们开始吧。
实施调度机制来启动和停止你的 Azure 虚拟机
Azure 虚拟机的计费基于资源使用情况,包括 CPU、内存和存储消耗。需要注意的是,即使虚拟机处于空闲状态或未被主动使用,你仍然会为这些资源付费。换句话说,即使虚拟机内部没有正在运行的进程/作业,你也会因为虚拟机‘开着’而被收费。因此,需要优化工作负载以防止不必要的费用。
为了有效处理这一点,请执行以下操作:
-
确定你打算在虚拟机中设置的进程和作业。这可以是与你的应用服务器相关的进程,甚至是虚拟机上托管的数据库中运行的作业。
-
确定作业之间的依赖关系,即,识别需要并行运行的作业和需要同时运行的作业。还要记录那些与其他作业无关的作业。
-
尝试在一致的时间范围内安排作业运行。例如,如果你有多个并发作业从数据库中处理数据,将作业运行对齐在相同时间范围内是有意义的。
-
一旦你缩小了时间范围,你可以安排你的虚拟机在所需的时间框架内启动和停止。例如,如果作业需要在某个时区的每晚 9 PM 到 12 AM 之间运行,你可以安排你的虚拟机在 8:45 PM 启动。
使用 Powershell、CLI 或控制台可以安排虚拟机的启动和停止。Azure 控制台为用户提供了使用 Azure 门户管理虚拟机的功能和便利。
你可以通过访问 Azure 门户中的虚拟机来找到启动、停止和释放虚拟机的选项。以下是 Azure 提供的自动化模板。
Azure 控制台中的任务(图片来源:作者)
以下是访问自动化模板的步骤 -
选择你的虚拟机并访问任务(预览)。
Azure 虚拟机 — 自动化 - 任务(图片来源:作者)
点击‘添加任务’
Azure 虚拟机 — 任务 — 添加任务(图片来源:作者)
选择模板‘启动虚拟机’
Azure 虚拟机 - 自动化模板(图片来源:作者)
在进行身份验证后,你可以配置启动虚拟机的时间。
Azure 虚拟机 — 配置启动虚拟机自动化任务(图片来源:作者)
相同的步骤也可以用于关闭和释放虚拟机。
减少成本的最重要一步是解除分配虚拟机。 停止虚拟机和解除分配虚拟机之间存在很大差异。当你停止虚拟机时,Azure 资源管理会暂时暂停虚拟机,但保留分配的资源并保存虚拟机的状态。然而,当你解除分配虚拟机时,Azure 资源管理会将分配的资源释放回 Azure 资源池。这些释放的资源可以供其他资源或服务使用,也可以根据需求分配给其他应用程序或虚拟机。这样,你可以节省计算成本。请记住,你仍然需要支付存储费用。下面是一个显示 Azure 环境及其与虚拟机交互的组件的示意图。
Azure 资源管理(作者提供的图片)
一个值得考虑的好策略是在非活动期间将虚拟机解除分配。解除分配非生产环境也是一个不错的选择,可以帮助实现显著的成本节省。确保解除分配与你的操作需求一致,并且不会影响关键流程或服务。
将磁盘从高级改为标准
希望拥有高级磁盘存储、管理磁盘和自动扩展等高级功能是很自然的。但是,通常需要问自己这些功能是否真的适合你的用例。
Azure 支持多种高级和标准磁盘。正如你可能猜到的,高级磁盘比标准磁盘更贵。如果你有像数据库和大数据处理这样的 I/O 密集型操作,你将需要继续使用高级磁盘。但是,如果你的工作负载没有时间限制且不太密集,那么标准磁盘是一个不错的选择。我们用一个例子来看看这个问题。
假设你有一个脚本每天(每天一次)从外部文件系统中提取数据,并填充到基于云的报告工具中。这是一项不那么密集的操作,可以轻松地使用标准磁盘处理。
假设你在虚拟机上托管了一个 SQL 数据库,该数据库直接连接到一个用户可以进行临时查询的自助报告。如果是这种情况,最好使用高级磁盘。不过,你可以考虑将高级磁盘降级到较低层级,以获得一些成本节省。
以下是通过控制台更改磁盘的过程。点击 Azure 虚拟机并访问侧边栏中的‘磁盘’选项。
Azure 虚拟机 — 磁盘(作者提供的图片)
点击磁盘名称中的超链接。
Azure 虚拟机- 当前操作系统(作者提供的图片)
点击侧边栏中的‘大小 + 性能’。
Azure 虚拟机- 大小 + 性能(作者提供的图片)
当前使用的配置会以灰色高亮显示。这是用户可以将磁盘从高级磁盘更改为标准磁盘类型的地方。您可以通过点击选择不同的配置。请注意,任何不适用于您的环境的配置将无法点击。
更改虚拟机配置
Azure 提供不同尺寸、内存、存储和计算能力的虚拟机。每一系列和代的 Azure 虚拟机都旨在提供特定的性能特征,让用户根据其工作负载选择最合适的类型。下图展示了 Azure 虚拟机提供的不同代数。
Azure 虚拟机世代(作者提供的图片)
确定您的工作负载,并降级到最符合工作负载要求且能够节省成本的虚拟机。以下是进行相应操作的步骤 -
-
点击虚拟机
-
点击侧边栏中的“尺寸”
Azure 虚拟机 - 尺寸(作者提供的图片)
结果面板显示了不同的虚拟机世代及其所有相关详细信息,如每小时费用、内存和支持的工作负载。
请注意,这只是一个示例图片,并不旨在展示所有可用的配置选项。
Azure 虚拟机 - 磁盘世代与尺寸(作者提供的图片)
向右滚动以访问每小时费用。
请注意,这些成本是撰写文章时的情况。当前正在使用的虚拟机的尺寸和代数会在顶部显示。要更改虚拟机配置,请点击“调整大小”。
Azure 虚拟机 - 调整虚拟机配置(作者提供的图片)
考虑更改为定价较低的地区
Azure 的定价在不同地区可能有所不同。需求较低的地区可能会有稍低的定价,而需求较高和资源有限的地区可能价格稍高。规划部署时,可以使用 Azure 定价计算器比较不同地区的定价。不过,尽量选择离您较近的地区,以避免不必要的网络延迟。
例如,你可能会发现,位于东欧的某些区域比美国东海岸和西海岸的区域便宜。假设你有一些位于欧洲的用户需要访问你设置的虚拟机。为了提高性能,设置在离他们位置较近的区域的虚拟机可能更有意义,而不是在美国区域设置虚拟机。然而,需要注意的是,在确定虚拟机设置的最佳区域时,可能需要考虑数据传输规定和限制。此外,还需要考虑的是,随着区域距离当前位置的远离,网络延迟可能增加,从而导致响应时间变慢。
总结
在这篇文章中,我们讨论了 4 种策略,以减少 Azure 账单上的成本,同时不影响现有/新工作负载的效率。所有这些策略都可以通过 Azure 门户无缝实现,无需使用编程方法。理解你的工作负载需求和处理需求是管理成本效益环境的关键。同时,平衡成本节省与性能和数据传输规定等因素同样至关重要。通过应用这些技巧,你可以优化 Azure VM 性能,同时有效管理成本,确保资源的高效利用,实现你期望的结果。
在探索性数据分析中简化重复任务
原文:
towardsdatascience.com/streamlining-repetitive-tasks-during-exploratory-data-analysis-46a40fe1d588
数据科学中的自动化
邀请你识别你的重复 EDA 任务,并创建一个自动化工作流,通过一个示例工具进行说明。
Christabelle Pabalan
·发布于 Towards Data Science ·7 min read·Oct 24, 2023
--
图片来源:作者 (DALL-E 生成)
编程原则:自动化单调的任务
人们常说懒惰的程序员是最好的程序员。然而,更准确的说法是,那些没有耐心处理重复工作流的程序员会投入前期时间来自动化他们能自动化的所有任务,以避免这些任务。简而言之,最好的程序员不会耐心重复单调的任务——他们会自动化它们。熟练的程序员之所以“懒惰”,是因为他们在前期投入时间创建工具,以便在未来节省精力。这可能意味着学习键盘快捷键、创建自定义模块或寻找聪明的软件来自动化工作流。
在一篇标题为“为什么优秀的程序员懒惰和愚蠢”的文章中,Philipp Lenssen 说道:
“只有懒惰的程序员才会避免编写单调、重复的代码——从而避免冗余,这对软件维护和灵活重构来说是敌人 […] 为了让一个懒惰的程序员成为一个优秀的程序员,他(或她)也必须在学习如何保持懒惰时非常不懒惰——也就是说,哪些软件工具使他的工作更轻松,哪些方法可以避免冗余,以及他如何使自己的工作能够被轻松维护和重构。”
没有人喜欢枯燥单调的任务,如果有人发现自己在项目中重复相同的功能,这种总体的沮丧感就会开始缠绕他们,并低声耳语,“将它们打包成模块。”
图片来源:作者
EDA 的重复性质
这些低语在我探索性数据分析阶段确实出现了。
探索性数据分析(EDA) 涉及使用统计技术和可视化来研究数据、理解其结构、识别模式,并检测任何不规则性或异常值。通常,对于新的数据集需要相同的分析和可视化,因此 EDA 可以大大受益于自动化。
完全自动化的限制
然而,在之前的尝试中,我每次都被阻碍,因为完全自动化受到了每个数据集独特挑战的限制,例如确定编码策略和确保数据类型正确。数据清洗过程与数据分析之间的相互作用是重复的,因此,很难完全标准化。
模块化方法
为了解决这个限制,我创建了一个工具,假设数据已经经过了最小处理并具有正确的数据类型。它还需要定义数值列、分类列和目标列(假设我们在处理分类任务)。
它包含了什么?
-
数值和分类数据的高级统计
-
统计显著性测试
-
相关性热图
-
类别平均值
-
数据分布可视化
该函数还提供了可选参数的灵活性,以启用或禁用上述任何功能。
本文旨在展示创建定制化 EDA 工具的价值。虽然示例侧重于自动化摘要和可视化,但关键是识别您在重复 EDA 工作中的痛点,并将您自己的重复工作流程编码化。我将重点展示工具的关键功能和示例输出,而不是包含完整代码。
数据集
数据集已上传到 Kaggle,目的是研究哪些因素可能预测患者是否会被诊断为中风。
数据集的五个随机样本观察。作者提供的图片。
轻度预处理和特征工程
我开始了这个过程:
-
从“胆固醇水平”中提取 HDL 和 LDL 胆固醇值
-
为每个症状生成二进制指示符列
-
通过标签编码将分类列和目标列转换为数值代码
# Define a function to extract values from a column and convert to integer
def extract_and_convert(column, prefix):
return column.str.extract(f'{prefix}(\d+)')[0].astype(int)
# Extract HDL and LDL values and add them as new columns
df['HDL'] = extract_and_convert(df['Cholesterol Levels'], 'HDL:')
df['LDL'] = extract_and_convert(df['Cholesterol Levels'], 'LDL:')
# List of unique symptoms
unique_symptoms = ['Difficulty Speaking', 'Headache', 'Loss of Balance', 'Dizziness',
'Confusion', 'Seizures', 'Blurred Vision', 'Severe Fatigue',
'Numbness', 'Weakness']
# Create binary columns for each unique symptom indicating its presence in 'Symptoms'
df[unique_symptoms] = df['Symptoms'].str.contains('|'.join(unique_symptoms))
# Convert categorical columns to numerical codes using label encoding
df[categorical_columns] = df[categorical_columns].apply(lambda x: pd.factorize(x)[0])
# Convert the target variable to numerical codes using label encoding
df[target] = pd.factorize(df[target])[0]
轻度预处理数据的示例
特征工程后数据集的 5 个随机样本观察。作者提供的图片。
从这里,我需要采取两个步骤:
-
定义数值列、分类列和目标列
-
运行
summary()
并输入我希望看到的函数
Summary()
定义数值列、分类列和目标列
# Define numerical columns
numerical_columns = ['age', 'bmi', 'glucose', 'stress', 'bp', 'hdl', 'ldl', ]
# Define categorical columns
categorical_columns = ['gender', 'hypertension', 'heart_dis', 'married', 'work', 'residence',
'smoker', 'alcohol', 'fitness', 'stroke_history', 'family_stroke_history',
'diet', 'speech', 'headache', 'balance', 'dizziness', 'confusion',
'seizures', 'vision', 'fatigue', 'numbness', 'weakness']
# Define target column
target = 'diagnosis'
在这篇文章中,我包含了更大的工具summary()
,并排除了辅助函数:calculate_entropy()
、statistical_tests()
、plot_distribution_plots()
、plot_correlation_heatmap()
、calculate_categorical_summary
、calculate_numerical_summary()
。
Summary() 实现
def summary(df: pd.DataFrame,
numerical_columns: list,
categorical_columns: list,
target: str,
categorical_summary: Optional[bool] = True,
numerical_summary: Optional[bool] = True,
perform_tests: Optional[bool] = True,
plot_corr_heatmap: Optional[bool] = True,
calculate_cat_averages: Optional[bool] = True,
plot_distribution: Optional[bool] = True) -> None:
"""
Generate a summary of data exploration tasks.
"""
df_numerical = df[numerical_columns]
df_categorical = df[categorical_columns]
# Join numerical and categorical columns together
df_joined = df_numerical.join(df_categorical)
df_joined[target] = df[target]
if categorical_summary:
print('\nCATEGORICAL SUMMARY')
categorical_summary = calculate_categorical_summary(df_categorical)
display(categorical_summary.round(2))
if numerical_summary:
print('\nNUMERICAL SUMMARY')
numerical_summary = calculate_numerical_summary(df_numerical)
display(numerical_summary.round(2))
if perform_tests:
print('\nSTATISTICAL TESTS')
df_summary = statistical_tests(df, categorical_columns, numerical_columns, target)
display(df_summary.round(2))
if plot_corr_heatmap:
plot_correlation_heatmap(df_joined)
if calculate_cat_averages:
for col in categorical_columns:
display(df_joined.groupby(col).mean())
if plot_distribution:
plot_distribution_plots(df, categorical_columns + [target], numerical_columns)
类别和数值汇总
该工具生成两个统计汇总——一个用于类别变量,一个用于数值变量。
类别汇总提供了对每个类别的高层次洞察,包括:
-
唯一值的数量
-
最频繁的值及其频率
-
缺失值的百分比
-
熵——分布的随机性测量
数值汇总计算常见的描述性统计数据,如:
-
唯一值的数量
-
缺失值的百分比
-
异常值的数量
-
集中趋势测量(均值,中位数)
-
离散度测量(标准差,最小/最大)
这种分解作为对类别和数值数据分布及完整性的快速评估。这些汇总有效地指出了需要更深入探索的领域,如显著的缺失数据或重要的异常值。总体而言,它们提供了数据集基本特征的全面快照。
例如,下文中可以明显看出血压数据中有四个异常值,其中一半的人群有中风病史,75%的患者表现出高血压。
图片由作者提供。
统计测试
统计测试汇总包括统计测试结果,用于评估每个特征与目标变量之间的关系。该工具对类别变量运行卡方检验,对数值变量运行双尾 t 检验,以评估每个特征与目标之间的关系。
然而,这些测试有其局限性。它们检测线性相关性,但可能忽略非线性关联或变量之间的复杂交互。结果提供了识别潜在预测特征的起点,但需要进一步分析来揭示细致的关系。因此,自动化测试加速了初步特征筛选,但应结合更深入的技术,如多变量建模和集成方法,以获得进一步的见解。
图片由作者提供。
相关热图
这种可视化突出了数值变量、序数变量和目标变量之间的斯皮尔曼相关性。选择斯皮尔曼相关性是因为它在捕捉各种类型关系方面更具鲁棒性。与皮尔逊相关性不同,斯皮尔曼的相关性是非参数的,适用于序数、类别或非线性关系。
斯皮尔曼相关热图。图片由作者提供。
图示
对于分布可视化,summary()
将返回类别变量的条形图以及数值变量的直方图和箱型图。分布可视化可以揭示数据应如何分离和处理不同,并可能突出质量保证(QA)问题或异常。
类别数据的条形图。图片作者提供。
直方图、箱型图以及不包含异常值的箱型图用于数值数据。图片作者提供。
总结
本文展示了一个以自动生成统计摘要、可视化和基本特征分析为重点的示例 EDA 工具。虽然不全面,但它允许快速探索新数据集,并揭示洞察以指导更有针对性的分析。通过一些定制,这些工具可以适应不同领域或业务背景的典型探索性工作流程。
关键在于识别流程中的冗余,并提前花时间将工作流程编码化。这会随着时间的推移积累,使你能够将认知资源集中在更高价值的领域,如领域知识、特征工程和建模。简而言之——创建你的工具,自动化重复工作,让自动化处理繁重的工作,以便你可以专注于艺术。
精简无服务器 ML 推理:释放 Candle 框架在 Rust 中的力量
原文:
towardsdatascience.com/streamlining-serverless-ml-inference-unleashing-candle-frameworks-power-in-rust-c6775d558545?source=collection_archive---------4-----------------------#2023-12-21
使用 Hugging Face 的新 Candle 框架构建一个精简且强大的模型服务层,用于向量嵌入和搜索
Alon Agmon
·
关注 发表在 Towards Data Science · 14 分钟阅读 · 2023 年 12 月 21 日
--
图片由 Clay Banks 提供,发布于 Unsplash
1. 介绍
在过去十年中,AI 研究和工具的惊人进展催生了更准确、更可靠的机器学习模型,以及使将 AI 功能集成到现有应用程序中变得越来越简单的库和框架。
然而,在要求高的生产环境中,推理规模仍然是一个相当大的挑战。例如,假设我们有一个简单的搜索服务,它接收几个关键词,然后使用语言模型将其嵌入到向量中,并在某个向量数据库中搜索类似的文本或文档。这是一个相当流行的用例,也是 RAG 架构的核心部分,RAG 架构通常用于将生成型 AI 应用于特定领域的知识和数据。
从本身来看,这似乎是一个相对简单的实现用例。我们可以使用许多开源语言模型和模型中心,几行代码就可以用作嵌入模型。如果我们进一步假设需要存储和查询的向量数量相对适中(例如少于 1M),那么有很多简单的向量存储和搜索选项:从纯内存存储到数据库如 Postgres、Redis 或 Elastic。
但,如果我们的服务需要每秒处理数千或数十万的请求怎么办?如果我们需要在每个请求上保持相对较低的 — 毫秒级 — 延迟呢?如果我们需要快速扩展以应对请求的高峰怎么办?
尽管我们的使用案例确实相当简单,但规模和负载要求无疑使其成为一项挑战。可扩展的高吞吐量系统通常基于多个小而高效的二进制文件实例,这些实例可以快速启动、扩展并处理请求。在机器学习系统,尤其是深度学习的背景下,这带来了挑战,因为常见的库通常比较笨重,部分原因是大多数库是用 Python 实现的,而 Python 在要求高的环境中扩展性较差。
因此,面对这些挑战时,我们要么选择使用一些付费服务平台来处理规模问题,要么不得不使用多种技术创建自己的专用服务层。
针对这些挑战,Hugging Face 引入了Candle框架,它被描述为“一个注重性能和易用性的 Rust 轻量级 ML 框架”。Candle 使我们能够在 Rust 中构建稳健且轻量的模型推理服务,使用类似 torch 的 API。基于 Candle 的推理服务将能够轻松扩展、快速启动,并以极快的速度处理请求,使其更适合于旨在应对规模和弹性挑战的云原生无服务器环境。
本帖的目的是展示如何使用 Candle 框架实现之前描述的流行用例,从头到尾。我们将深入探讨基于 Candle 和 Axum(作为我们的 Web 框架)的相对简单但强大的向量嵌入和搜索 REST 服务实现。我们将使用一个特定的新闻头条数据集,但代码可以很容易扩展到任何文本数据集。
这将是一个非常实践和实用的帖子。第二部分展示了我们服务的主要设计或流程,以及我们将开发和使用的相关组件。第三部分聚焦于 Candle 框架,并展示如何使用 Bert 模型实现向量嵌入和搜索功能。第四部分展示了如何使用 Axum 将模型推断功能封装在 REST Web 服务中。第五部分解释了如何创建我们服务所需的实际嵌入和工件。第六部分总结。
2. 高级服务设计
我将从我们要实现的推断服务的主要构建块开始。主要要求是创建一个 HTTP REST 端点,该端点将接收一个由几个关键字组成的文本查询,并响应与搜索查询最相似的前 5 条新闻头条。
对于这个任务,我们将使用Bert作为语言模型,因为它通常在文档嵌入任务中表现良好。在一个离线批处理过程中(将在第五部分中解释),我们将使用 Bert 来嵌入大约 20K 条新闻头条或文档,并为每个头条创建一个向量嵌入,服务将使用这些嵌入来搜索匹配项。嵌入和文本头条将被序列化为二进制格式,每个服务实例将加载这些格式以服务查询请求。
作者提供的图片
如上所示,在接收到带有搜索查询的请求后,服务将首先使用语言模型将搜索查询嵌入到一个向量中。接下来,它将搜索预加载的embedding以找到 N 个最相似的向量(每个向量代表一个新闻头条)。最后,它将使用最相似向量的索引通过映射文件提取它所代表的实际文本头条。
我们将通过创建一个名为BertInferenceModel的模块或 rust 库来实现这一点,该模块将提供主要功能:模型加载、句子嵌入和向量搜索。该模块将被一个使用 Axum Web 框架实现的 REST 服务使用。下一节将专注于实现该模块,而后续的部分将专注于 Web 服务本身。
请注意,接下来的部分包括许多代码示例,但为了清晰起见,它们仅展示了实现的主要功能。有关解决方案的完整实现,请参阅下面链接的配套 git 仓库。
3. 使用 Candle 进行模型服务和嵌入
本节将重点介绍将作为 Candle 库 API 上的抽象层的模块的实现。我们将实现一个名为BertInferenceModel的结构体,该结构体包含三个主要功能:模型加载、推理(或句子嵌入)以及使用余弦相似度进行简单的向量搜索。
pub struct BertInferenceModel {
model: BertModel,
tokenizer: Tokenizer,
device: Device,
embeddings: Tensor,
}
BertInferenceModel 将封装我们从 Hugging Face 仓库下载的 Bert 模型和分词器,并基本上包装它们的功能。
从 Hugging Face Hub 加载模型
BertInferenceModel 将通过 load() 函数实例化,该函数将返回一个新实例的结构体,加载了相关模型和分词器,并准备进行推理任务。加载函数的参数包括我们希望加载的模型的名称和修订版(我们将使用 Bert 句子转换器)以及嵌入文件路径。
pub fn load(
model_name: &str,
revision: &str,
embeddings_filename: &str,
embeddings_key: &str,
) -> anyhow::Result<Self> {}
如下加载函数代码所示,加载模型涉及创建一个包含 Hugging Face 仓库相关属性的Repo结构体(例如名称和修订版),然后创建一个API结构体以实际连接到仓库并下载相关文件(用于创建模型的模型权重使用 HuggingFace 的 safetensors 格式表示)。
api.get函数返回相关文件的本地名称(无论是下载的还是仅从缓存中读取的)。文件将只下载一次,而随后的api.get调用将仅使用缓存版本。我们使用分词器配置文件实例化一个分词器结构体,并使用权重文件(以 safetensors 格式)和配置文件来构建我们的模型。
在加载了模型和分词器之后,我们最终可以加载实际的嵌入文件,用于搜索匹配项。我将稍后展示如何使用相同的模型生成嵌入文件,然后将其序列化为文件。使用 HuggingFace 的 safetensors 模块将嵌入加载为 Tensor 相对简单,我们只需要文件名和保存 Tensor 时给予的密钥。
现在我们已经加载了模型和分词器,并且内存中有了嵌入向量,我们完成了对返回给调用函数的 BertInferenceModel 的初始化,可以继续实现推理方法。
句子推理和嵌入
推理函数也相当简单。我们首先使用加载的分词器对句子进行编码(第 5 行)。encode()函数返回一个 Encoding 结构体,该结构体具有一个 get_ids() 函数,返回一个数组或句子中单词的数值表示。接下来,我们将令牌 ID 数组包装在一个 Tensor 中,以便将其输入到我们的嵌入模型中,并使用模型的前向函数获取表示句子的嵌入向量(第 10 行)。
我们从嵌入模型中在第 12 行得到的向量维度是[128, 384]。这是因为 Bert 用大小为 384 的向量表示每个标记或词,且句子向量的最大输入长度为 128(因为我们的输入只有几个单词,所以大部分是填充)。换句话说,除了填充和其他指令标记外,我们基本上获得了每个标记或词的大小为 384 的向量。
接下来,我们需要将句子向量从大小为[128, 384]的张量压缩成一个大小为[1, 384]的单一向量,该向量将代表或捕捉句子的“精髓”,以便我们可以将其与嵌入中的其他句子进行匹配,并找到与之相似的句子。为此,并且部分因为我们处理的输入是短关键字而不是长句子,我们将使用最大池化,它本质上通过取给定张量每个维度的最大值来创建一个新向量,以捕捉每个维度中最显著的特征。正如您在下方看到的,这使用 Candle 的 API 实现起来相当简单。最后,我们使用 L2 归一化以避免偏差,并通过确保所有向量具有相同的幅度来改善余弦相似性度量。您可以在下方看到池化和归一化函数的实际实现。
测量向量相似性
尽管这与 Candle 库没有直接关系,但我们的模块也将提供一个向量搜索实用方法,该方法将接收一个向量并利用其内部嵌入以返回最相似向量的索引。
这实现得相当简单:我们首先创建一个元组集合(第 7 行),其中元组的第一个成员将表示相关文本的索引,第二个成员将表示余弦相似性评分。然后,我们遍历所有索引,测量每个与我们需要匹配的给定向量之间的余弦相似性。最后,我们将(索引,相似性评分)的元组添加到集合中,对其进行排序,并返回前 N 个请求的结果。
4. 嵌入和搜索 Web 服务
现在我们有了一个封装主要模型功能的结构体,我们需要将其封装在一个 REST 服务中。我们将创建一个 REST 端点,包含一个 POST 路由,该路由将接收 JSON 有效负载中的几个关键字,并返回在预加载嵌入中的最相似向量的索引。根据请求,服务将把关键字嵌入到一个向量中,搜索其内存中嵌入的相似性,并返回最相似向量的索引。该服务将使用索引在文本映射文件中找到相应的标题。
我们将使用优秀的 Axum web 框架来实现这个服务。相关代码大部分是典型的 Axum 模板代码,所以我不会详细讲解如何使用 Axum 创建 REST 端点。与许多 web 框架一样,构建 REST 端点通常涉及创建一个 Router 并在某个路由上注册一个处理函数来处理请求。然而,ML 模型服务层具有额外的复杂性,即管理模型本身的状态和持久性。模型加载可能在性能上很昂贵,因为它涉及加载模型文件的 IO 操作(无论是从 Hugging Face 的仓库还是本地)。同样,我们需要找到一种方法来缓存和重用模型以应对多个请求。
为了满足这些要求,Axum 提供了应用状态功能,我们可以用来初始化和持久化任何我们想要注入到每个请求上下文中的资产。首先让我们逐行查看服务的整个初始化代码,看看它是如何工作的。
每个服务实例都会创建并加载一个模型包装器,然后缓存它以供每个接收到的请求重用。在第 3 行,我们通过调用load()函数来创建模型包装器,以引导并加载模型。除了从 HF 加载的 Bert 模型的名称和版本,我们还需要指定嵌入文件的位置,该文件被加载到内存中以便搜索相似的向量,以及在创建嵌入时使用的密钥。
除了实际的模型,我们还需要缓存映射文件以供每个请求重用。服务在使用模型嵌入关键词后,会在其嵌入文件中搜索最相似的向量,然后返回它们的索引。服务接着使用映射文件提取与最相似向量的索引对应的实际文本。在更稳健的生产系统中,服务会从某个快速访问的数据库中获取实际文本,但在我们的案例中,从存储在文件中的预加载字符串列表中读取就足够了。在第 10 行,我们加载了之前保存为二进制文件的列表。
现在我们有两个需要缓存和重用的资产——模型(包装器)和映射文件。Axum 使我们能够使用Arc,即线程安全的引用计数指针,每个请求都会共享。正如第 15 行所示,我们在包含模型包装器和映射文件的元组周围创建了一个新的 Arc。在第 17–19 行,我们创建了一个新的 HTTP 路由来处理每个请求的函数。
let shared_state =
Arc::new((bert_model, text_map));
let app = Router::new()
.route("/similar", post(find_similar))
.with_state(shared_state);
为了缓存元组并使其对每个请求可用,我们使用with_state(state)函数将其注入到相关的请求上下文中。我们来看看具体是如何操作的。
处理请求
我们的服务将处理 HTTP POST 请求,这些请求包含以下有效负载,这些有效负载包括关键词和我们想要接收的相似向量或标题的数量。
{
"text": "europe climate change storm",
"num_results":5
}
我们将实现处理函数的相应请求和响应结构体,Axum 将在需要时处理序列化和反序列化。
#[derive(Deserialize)]
struct ReqPayload {
keywords: String,
num_results: u32,
}
#[derive(Serialize)]
struct ResPayload {
text: Vec<String>,
}
接下来,我们可以进入处理函数本身。处理函数将接受 2 个参数:我们之前初始化的应用程序状态(Axum 将负责将其注入到每个函数调用中),以及我们之前定义的请求结构体。
处理每个请求将包括 4 个主要阶段,这些阶段现在应该已经很清楚了。在第 5 行,我们首先提取一个指向状态元组的引用,该元组持有对模型和映射文件的引用。在第 6 行,我们使用模型将关键词嵌入到一个向量中。接下来,在第 9 行,我们搜索 N 个最相似的向量。score_vector_similarity()函数返回一个由元组组成的向量,每个元组包含一个索引和余弦相似度分数。最后,我们遍历这些索引元组,从映射文件中提取对应索引的字符串,并将其封装到响应有效负载结构中。
然后……我们就可以开始了!虽然这可能并没有具体说明太多,但我在我的 Mac 上进行了测试,使用了大约 20K 向量的嵌入,并获得了 100ms 的良好平均响应时间。对于基于 Bert 的向量嵌入和向量搜索来说,这算是不坏。
curl -s -w "\\nTotal time: %{time_total}s\\n" \
-X POST http://localhost:3000/similar \
-H "Content-Type: application/json" \
-d '{"text": "self driving cars navigation", "num_results": 3}' | jq
{
"text": [
"Item:Stereo Acoustic Perception ... (index: 8441 score:0.8516491)",
"Item:Vision-based Navigation of ... (index: 7253 score:0.85097575)",
"Item:Learning On-Road Visual ..... (index: 30670 score:0.8500275)"
]
}
Total time: 0.091665s
(这个示例是使用在 Arxiv 论文摘要数据集上生成的嵌入创建的。实际数据集可以在这里以公共领域许可证获取。)
5. 生成嵌入
在我们结束之前,还有一个最后的组件需要覆盖。到目前为止,我们假设了一个嵌入文件的存在,在其中我们搜索相似的向量。然而,我还没有解释如何创建嵌入文件本身。
请记住,在上一节中创建的结构体— BertInferenceModel,已经包含了一个将一组关键词嵌入到向量中的函数。当我们创建一个需要嵌入多个关键词集的函数时,我们只需将它们作为批量处理即可。
我们使用BertInferenceModel的主要区别在于使用 tokenizer 的encode_batch函数而不是encode,前者接受一个字符串向量而不是单个字符串。然后,我们将所有向量堆叠成一个单一的张量,并将其输入到模型的forward()函数中,就像我们处理单个向量嵌入时一样(你可以在下面链接的辅助仓库中查看函数的完整源代码)。
一旦我们拥有能够嵌入多个字符串的函数,嵌入生成器本身就相当简单。它使用 rayon crate 来并行处理文本文件的嵌入,然后将结果堆叠在一起,创建一个单一的张量。最后,它使用 safetensors 格式将嵌入写入磁盘。嵌入是这个管道中的重要资产,因为它需要被复制到每个服务实例。
现在我们可以得出结论 😃
6. 结论
在机器学习工程中,最大的挑战之一是大规模推断。人工智能绝非轻量级,因此,扩展推断工作负载往往是一项非常昂贵或过度设计的痛苦挑战。这正是 Hugging Face 的 Candle 库试图解决的难题。它使用类似 Torch 的 Rust API,使我们能够创建一个精简且快速的模型服务层,能够轻松扩展并在无服务器环境中运行。
在这篇文章中,我展示了如何使用 Candle 创建一个端到端的模型推断层,能够处理向量嵌入和搜索请求。我解释了如何将 Bert / sentence transformers 模型包装成一个内存占用小的库,并在基于 Axum 的 REST 服务中使用它。
Hugging Face 的 Candle 库的真正价值在于其能够弥合强大机器学习能力与高效资源利用之间的差距。通过利用 Rust 的性能和安全特性,Candle 为更可持续和成本效益高的机器学习解决方案铺平了道路。这对那些希望在不增加开销的情况下大规模部署 AI 的组织特别有利。我希望借助 Candle,我们将看到一波不仅性能高效,而且更轻量且适应各种环境的机器学习应用。
一些关于 Candle 的资源
-
github.com/huggingface/candle
-
medium.com/@Aaron0928/hugging-face-has-written-a-new-ml-framework-in-rust-now-open-sourced-1afea2113410
-
pub.towardsai.net/candle-and-falcon-a-guide-to-large-language-models-in-rust-3f0a4369df03
本文的所有源代码可以在我的 GitHub 仓库中找到 这里
Streamlit 和 MongoDB:在云端存储你的数据
原文:
towardsdatascience.com/streamlit-and-mongodb-storing-your-data-in-the-cloud-c28663313ade
Streamlit Cloud 没有本地存储,因此在应用程序终止时创建的数据会丢失——除非你使用类似 MongoDB 的第三方存储
Alan Jones
·发表于 Towards Data Science ·12 分钟阅读·2023 年 8 月 25 日
--
经典 NoSQL 数据库——照片由 Jan Antonin Kolar 提供,Unsplash
Streamlit 允许你将公共应用程序部署到他们的云端,免费提供,但你在本地创建的任何文件或数据库将在应用结束时消失。这可能不是你想要的行为,因此我们将探讨使用 MongoDB 的解决方案。
对于许多应用程序来说,丢失这些数据并不成问题。例如,如果你设计了一个从外部源读取数据的仪表板,那么你生成的任何数据都可能是临时的,仅在应用运行期间需要。
但正如我在为文章开发应用程序时所提到的,Simple Surveys with Streamlit,如果应用程序本身生成需要存储的数据,这就不那么简单了。在那个应用程序中,我将数据存储在本地文件中,但在基于云的部署中,这些数据将在应用停止运行时消失——正确的解决方案是使用外部数据存储。
我们将看看如何通过 MongoDB 实现这一点,但也有其他选择。
有哪些选择?
在 Streamlit 文档中,有关于连接各种数据库和云存储供应商的指南。它们基本上分为三个领域:数据桶,例如 AWS S3 和 Google Cloud Storage,你可以在其中存储任何东西;SQL 数据库,如微软的 SQL Server、MySQL、PostgreSQL;以及 NoSQL 数据库,Firestore 和 MongoDB 就是其中的例子。对于每种类型,你显然需要访问托管该特定数据库的服务器。
坦白说,我不是 SQL 的最大粉丝。对我来说,SQL 代码和 Python 之间的脱节让人感到不适。(话虽如此,我确实欣赏 SQL 的强大功能和便利,并且在这里、这里和这里写过相关内容。)
但 NoSQL 数据库如 MongoDB 感觉更符合 Python 的工作方式。
我相信关于速度、效率、易用性、安全性等各种争论都有。但我不打算讨论这些。我将使用 MongoDB,因为这是我的个人喜好和偏见,你可以自己决定它是否是一个好的选择。
调查应用程序
我将使用我开发的一个应用程序版本来演示在 Streamlit 中构建简单调查。我有意编写它以便于移植到不同的数据存储,因此所有的数据存储代码都在一个库中。其目的是,为了从使用本地数据存储转到基于云的数据库,你只需重新编写库即可。
但首先,我们需要访问 MongoDB 数据库。
MongoDB Atlas
你可以下载 MongoDB 并在自己的服务器上运行一个实例。或者,你可以使用像 MongoDB Atlas 这样的托管服务。我们将使用后者。
MongoDB Atlas 是搜索中出现的第一个托管服务——虽然还有其他服务,但我不知道有哪个提供免费层,因此我们将使用这个。作为一种免费增值服务,你可以从零开始,这几乎满足了你对简单(或者实际上不那么简单)应用程序的所有需求,但限制你存储数据为 0.5 GB。
要使用 MongoDB,你需要注册一个帐户(使用上面的链接),并且还应阅读他们网页上资源标签中的全面入门指南。
一旦设置好,你将拥有创建连接字符串的信息,如下所示:
"mongodb+srv://<user>:<password>@<cluster-url>?retryWrites=true&writeConcern=majority"
替换 <user>
、<password>
和 <cluster-url>
为你自己的详细信息,并将其保留为 Streamlit 密钥。
一个 MongoDB 数据库
MongoDB 由 集群 和 数据库 组成。数据库内有 集合(有点像表),集合内有 文档(以类似 JSON 的格式存储的实际数据)。
为了本教程的方便,我创建了一个数据库和两个集合。这些集合将存储构成调查的问题以及用户生成的结果。我将它们命名为 survey1
和 results1
——我知道这并不特别有创意。
在结果收集部分,每个文档将代表一个问题。在结果表中,每个文档将代表对调查中所有问题的单独回答。
调查 数据库 截图
点击其中一个集合名称将显示文档(如果有的话)。
调查 1 集合中的问题
如你所见,这些条目非常类似于 Python 字典。注意 _id
字段:这是一个唯一标识符,对于 MongoDB 文档来说是必需的;你可以在创建文档时指定它,否则 MongoDB 会自动为你创建一个。
应用程序
你可以在原始文章中阅读实现的详细信息。这里,我们将专注于数据库代码。(在本文发布后不久,将会有一个指向我网页上所有代码的链接。)但让我们快速回顾一下应用程序的功能。
原始应用程序由三个页面组成:
-
第一个页面允许用户通过指定单独的问题并存储它们来创建一个简单的问卷。
-
第二个页面向用户展示问卷,收集响应并将其添加到之前的响应中。
-
第三个页面展示结果;它读取响应并对其进行一些简单的分析 —— 绘制概览柱状图,并让用户选择一个问题以更详细地查看该问题的响应。数据也可以作为 CSV 文件下载。
每个页面使用一个库,DButils,来访问数据。在原始应用程序中,数据存储为 JSON 文件。
为了保持简单,我将专注于问卷的展示和收集的数据的展示。这意味着我们需要对数据库进行读写操作。(我的想法是,创建问卷可能最好在本地完成,并将结果存储在文件中,之后再将其完整上传到 MongoDB —— 因此,本文将不涉及应用程序的这一方面。)
数据库库
原始库使用了一些辅助函数,但该版本应用程序使用的函数有:
-
get_survey()
— 这个函数用于检索整个问卷 -
append_results(value)
— 这个函数将新的响应添加到现有响应中 -
get_results()
— 这个函数用于检索所有的响应
对于 MongoDB 版本,我们需要一点前置准备。
import streamlit as st
from pymongo.mongo_client import MongoClient
DB = "survey"
SURVEY_KEY = "survey1"
RESULTS_KEY = "results1"
uri = st.secrets['mongoURI']
client = MongoClient(uri)
我们显然需要导入 Streamlit;我们还需要从 pymongo 库中导入 MongoClient。然后,我们设置一些便利常量,将数据库和集合名称硬编码到代码中。这当然使得库特定于这个应用程序,如果我们希望在不同的应用程序中再次使用该库,可能会避免这种情况。我们更倾向于在应用程序中设置这些值,并将其传递给库。
变量 uri
是我们之前创建的连接字符串,通过它我们创建一个 MongoDB 客户端,从中可以访问我们的数据。
读取和写入数据非常简单。首先,你需要指定数据库以及该数据库中的集合。然后,你可以使用 find()
函数来读取数据,使用 insert_one()
或 insert_many()
函数将数据添加到集合中。
我们将在这里说明读取和写入函数的简单用法,但要想充分利用 MongoDB,你需要查阅他们网站上的文档,例如 如何在 MongoDB 中使用 Python。
读取数据
要从 MongoDB 集合中读取所有数据,我们使用 find()
函数。下面的代码将从我们其中一个集合中读取所有数据——key
参数应为 SURVEY_KEY
或 RESULTS_KEY
。
def get(key):
coll = db[key]
item_details = coll.find()
return list(data)
从 find()
返回的值是一个游标,我们可以遍历它以检索数据。然而,由于我们在这里处理的数据量很小,将其转换为 Python 列表以便所有数据都保存在内存中是完全可以接受的。
但这并不是我们想要的。正如你会记得的,每个集合中的文档都有一个由 MongoDB 自动提供的唯一标识符。对于这个应用程序,我们不需要,也确实不想要那个属性。因此,让我们更仔细地查看这个函数。
MongoDB 查询和过滤器
正如你所料,我们可以对 MongoDB 数据库进行查询。为了方便这一点,find()
函数可以接受修改其行为的参数。每个参数都是可选的,正如我们看到的,通过不指定这些参数,我们检索到所有的数据。
下面我们可以看到一个小调查的片段。这是集合中所有文档的数组,每个文档由 id 字段加上三个调查问题和每个问题的回答组成。
[
{
"_id": "ObjectId('64de4b2425de44afbff94ba5')",
"How many years of programming experience do you have?": "less than 1",
"What programming language do you use most?": "Python",
"Towards Data Science is one of the most useful publications on Medium": "Strongly agree"
},
{
"_id": "ObjectId('64de4b3125de44afbff94ba6')",
"How many years of programming experience do you have?": "1 to 5",
"What programming language do you use most?": "Python",
"Towards Data Science is one of the most useful publications on Medium": "Agree"
},
...
]
如果我们只需要对某个问题的回答,可以添加一个查询参数。在下面的代码中,我们指定我们想要包含对问题“你使用最多的编程语言是什么?”的回答为 ‘R’ 的文档。该参数以 Python 字典的形式提供,你可以在代码下方立即看到响应。
def get(key):
coll = db[key]
item_details = coll.find({'What programming language do you use most?':'R'})
return list(item_details)
[
{
"_id": "ObjectId('64df43224c943462625ec464')",
"How many years of programming experience do you have?": "5 to 10",
"Towards Data Science is one of the most useful publications on Medium": "Disagree",
"What programming language do you use most?": "R"
},
{
"_id": "ObjectId('64df43324c943462625ec465')",
"How many years of programming experience do you have?": "1 to 5",
"Towards Data Science is one of the most useful publications on Medium": "Neither agree not disagree",
"What programming language do you use most?": "R"
}
]
正如你所预期的,返回值是与查询匹配的文档集合——它相当于 SQL 中的 SELECT * WHERE …
。
这为我们提供了 MongoDB 查询如何工作的有用见解,但并未解决省略 id 字段的问题。为此,我们需要一个第二个参数。
第二个参数也是 Python 字典的形式,并充当过滤器,例如 {'_id':False}
。这告诉 MongoDB 我们不需要 id 字段。我们可以在此参数中指定字段列表,每个字段标记为 True
或 False
,取决于我们是否希望包含它们。
如果我们将这个第二个参数添加到之前的查询中,结果将是这样:
[
{
"How many years of programming experience do you have?": "5 to 10",
"Towards Data Science is one of the most useful publications on Medium": "Disagree",
"What programming language do you use most?": "R"
},
{
"How many years of programming experience do you have?": "1 to 5",
"Towards Data Science is one of the most useful publications on Medium": "Neither agree not disagree",
"What programming language do you use most?": "R"
}
]
与之前相同的条目,但省略了 id 字段。
好的,让我们回到我们真正想要的,即所有的文档,但省略 id 字段。我们仍然需要指定两个参数,但第一个参数需要告诉 MongoDB 我们想要所有文档,我们通过传递一个空字典来实现。
def get(key):
coll = db[key]
item_details = coll.find({},{'_id':False})
return list(item_details)
这给了我们想要的结果。
[
{
"How many years of programming experience do you have?": "less than 1",
"What programming language do you use most?": "Python",
"Towards Data Science is one of the most useful publications on Medium": "Strongly agree"
},
{
"How many years of programming experience do you have?": "1 to 5",
"What programming language do you use most?": "Python",
"Towards Data Science is one of the most useful publications on Medium": "Agree"
},
{
"How many years of programming experience do you have?": "more than 10",
"What programming language do you use most?": "Python",
"Towards Data Science is one of the most useful publications on Medium": "Strongly agree"
},
{
"How many years of programming experience do you have?": "5 to 10",
"What programming language do you use most?": "Julia",
"Towards Data Science is one of the most useful publications on Medium": "Neither agree not disagree"
},
{
"How many years of programming experience do you have?": "more than 10",
"What programming language do you use most?": "other",
"Towards Data Science is one of the most useful publications on Medium": "Agree"
},
{
"How many years of programming experience do you have?": "5 to 10",
"Towards Data Science is one of the most useful publications on Medium": "Disagree",
"What programming language do you use most?": "R"
},
{
"How many years of programming experience do you have?": "1 to 5",
"Towards Data Science is one of the most useful publications on Medium": "Neither agree not disagree",
"What programming language do you use most?": "R"
},
{
"How many years of programming experience do you have?": "more than 10",
"Towards Data Science is one of the most useful publications on Medium": "Strongly agree",
"What programming language do you use most?": "Python"
},
{
"How many years of programming experience do you have?": "1 to 5",
"Towards Data Science is one of the most useful publications on Medium": "Agree",
"What programming language do you use most?": "Julia"
}
]
当我们读取调查问题时,我们想做的正是同样的事情——返回所有文档,但不包含 id 字段——因此这个函数用于读取两个数据库。我们只需要提供正确的键,无论是 SURVEY_KEY
还是 RESULTS_KEY
。
追加数据
在这个版本的应用中,我们仅向结果集合中添加数据。具体方法如下,使用 insert_one()
函数完成。
def append_results(value):
coll = db[RESULTS_KEY]
coll.insert_one(value)
再次,我们选择具有适当键的集合,并添加以 Python 字典列表形式表示的调查问题及其每个回答的值。正如我们之前提到的,MongoDB 会自动添加唯一的 id。我们最终得到的集合与我们上面看到的类似。
如果你需要向集合中追加一组文档,insert_many()
函数可以为你完成这项工作。
代码
库的完整代码如下——实际上非常简洁!
import streamlit as st
from pymongo.mongo_client import MongoClient
DB = "survey"
SURVEY_KEY = "survey1"
RESULTS_KEY = "results1"
uri = st.secrets['mongoURI']
client = MongoClient(uri)
db = client[DB]
# Get functions
def get(key):
coll = db[key]
item_details = coll.find({},{'_id':False})
return list(item_details)
def get_survey(key=SURVEY_KEY):
return get(key)
def get_results(key=RESULTS_KEY):
return get(key)
# Append results function
def append_results(value):
coll = db[RESULTS_KEY]
coll.insert_one(value)
结论
我希望你能看到,使用 MongoDB 是相当简单的,并且为 Streamlit 应用中的持久存储问题提供了一个简洁的解决方案。我们本可以使用托管的 SQL 解决方案,但正如我之前提到的,NoSQL 解决方案在我看来更符合 Python 的处理方式——将 Python 列表映射到 MongoDB 集合是完全自然的。
不过,我应该补充一点:本文中的代码是在 Streamlit 引入新的 st.experimental_connection()
特性之前编写的,这种方法或其他基于类的解决方案可能会为生产代码提供更好的前景,特别是当 MongoDB 将被用于多个不同的项目时。然而,对于像这样的临时项目,我在这里使用的方法似乎简单且足够。
感谢阅读,希望这对你有帮助。如果你对代码或我的方法有任何疑问,或有任何评论,或者只是想说“你好”,请在下面添加评论——我总是会回复。
当你阅读这篇文章时,我的网站上应该有指向完整应用代码的链接 website。在那里你可以找到其他关于 Streamlit、数据可视化、数据科学和编程的一些文章和书籍——请查看一下。
Streamlit 教程:为数据科学项目创建 Word 报告
原文:
towardsdatascience.com/streamlit-tutorial-creating-word-reports-for-data-science-projects-96a749483cb3
将 python-docx 和 Streamlit 结合用于数据科学报告自动化
Andy McDonald
·发表于Towards Data Science ·12 分钟阅读·2023 年 4 月 17 日
--
报告图像由作者使用 Midjourney 基础计划生成。
在数据相关项目的结束阶段,无论是石油物理学还是数据科学,创建报告是非常常见的。生成的报告为客户和最终用户提供了在研究过程中获得的关键结果和结论的信息,并详细说明了使用的方法。
然而,创建结构化报告可能是一个繁琐且耗时的过程,特别是在确保报告格式正确和数据以最佳方式呈现时。
本文将展示我们如何使用流行的Streamlit库,结合python-docx,来创建自动化报告过程的第一步。
python-docx库将允许我们创建一个 Microsoft Word 报告。将报告以这种格式呈现将使我们能够进行编辑,并在转换为 PDF 之前完成最后的润色。
尽管本文中的示例需要主要的手动输入,但它可以适应大语言模型的强大功能,以总结数据并创建所需的文本。
让我们开始构建一个 Streamlit Word 文档报告生成器吧。
导入库和数据
首先,我们将导入我们要使用的主要库。这些库包括Streamlit、pandas、matplotlib和python-docx。
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import docx
接下来,我们将设置 Streamlit 页面布局为宽幅。这使得应用程序能够占据浏览器窗口的整个宽度,而不是位于中间的窄列中。
st.set_page_config(layout='wide')
设置 Streamlit 用户界面(UI)
现在我们可以开始构建用户界面(UI)。
我们将从为我们的应用程序添加标题开始。
st.title('Streamlit Data Report Generator')
报告生成器的起始点。图片来源:作者。
为了简化操作,我们将使用 pd.read_csv()
预加载我们的数据并传入一个文件名。
df = pd.read_csv('Xeek_Well_15-9-15.csv')
本教程使用的数据集是 Xeek 和 FORCE 2020 举办的机器学习竞赛(Bormann 等,2020)中的一个子集。该数据集在 Creative Commons Attribution 4.0 International 许可证下授权。
为了使应用程序更具灵活性,我们可以添加一个文件上传器,允许用户加载他们自己的数据。
你可以在我的文章 上传和读取 Streamlit 文件 中了解更多关于如何做到这一点的信息。
使用 st.form 创建报告表单
当小部件包含在 Streamlit 应用程序中时,每次编辑或选择它们时,Streamlit 应用程序都会重新运行。为了防止这种情况,我们可以创建一个表单。
这将允许我们输入值,应用程序只会在按下按钮时运行。
我们可以使用 with st.form('report')
创建一个表单,然后添加我们想要的输入。
带有报告详细部分的 Streamlit 报告生成器。图片来源:作者。
我们将使用应用程序的上部来创建报告元数据。这包括报告的标题、作者、客户以及报告的日期。
这些元素中的每一个都与来自 Streamlit 的用户输入小部件相关联。
report_title = col1.text_input("Enter report title")
report_author = col1.text_input("Enter the report author's name")
report_date = col2.date_input("Select a date for the report")
report_client = col2.text_input("Enter the client's name")
为了显示表单,我们需要添加一个提交按钮。使用 st.form_submit_button()
可以完成这一点。在这里,我们可以传递一个将出现在按钮上的标签。
if st.form_submit_button('Generate'):
generate_report(report_title)
在此之下,我放置了一个 generate_report
函数调用,我们将很快创建这个函数。目前,这将作为一个占位符。
这是目前表单的代码。
with st.form('report'):
st.write("### Report Details")
col1, col2 = st.columns(2, gap='large')
report_title = col1.text_input("Enter report title")
report_author = col1.text_input("Enter the report author's name")
report_date = col2.date_input("Select a date for the report")
report_client = col2.text_input("Enter the client's name")
if st.form_submit_button('Generate'):
generate_report(report_title)
在 Streamlit 表单中创建报告部分
报告通常由多个部分或章节组成。
为了说明如何在我们的应用程序中创建一个非常简单的部分,我们将添加一些输入,让用户输入部分标题和摘要。
带有部分输入框的 Streamlit 报告生成器。图片来源:作者。
在上图中,我添加了两个新的输入小部件。
一个部分标题,这是一个简单的文本输入(st.text_input
),和该部分的摘要,这是一个文本区域(st.text_area
)。
此外,我创建了两个新的列,以将它们与上面的列分开。如果我们想在这些表单部分之间添加任何全宽的文本/信息,这一点非常重要。
这是我们目前的表单代码:
with st.form('report'):
st.write("### Report Details")
col1, col2 = st.columns(2, gap='large')
report_title = col1.text_input("Enter report title")
report_author = col1.text_input("Enter the report author's name")
report_date = col2.date_input("Select a date for the report")
report_client = col2.text_input("Enter the client's name")
sect_col1, sect_col2 = st.columns(2, gap='large')
sect_col1.write("### Section Details")
section_title = sect_col1.text_input("Enter section title")
section_text_summary = sect_col1.text_area("Section Summary")
我们可以扩展此功能,以便用户可以添加多个部分。每个部分都可以被编码为在新页面上开始,使用分页符。
此外,为了使其更全面,我们可以在应用程序中生成报告的预览。
可能性非常多!
在 Word 文档中包含数据框使用 docx
表格在报告中至关重要,因为它们有助于以清晰、简单和有组织的方式展示信息。这使读者能够快速理解数据,并将其与同一或不同表中的其他数据值/类别进行比较。
为了说明在报告中包含一个表格,我们可以使用 pandas describe()
函数生成的统计摘要作为示例。
在 UI 中,我们可以添加一个多选选项,允许用户从数据框中选择列。如果我们有许多列且只对其中一些感兴趣,这将特别方便。
Streamlit 报告生成器允许用户从数据框中选择列。图片由作者提供。
在创建多选条目框之前,我们首先需要从数据框中获取列名,这可以通过创建一个新变量并将其分配给 df.columns
来完成。
然后我们使用 st.multiselect()
创建多选框。由于我们在处理列,因此需要调用所需的列。在这种情况下,是 sect_col2
。
with st.form('report'):
st.write("### Report Details")
col1, col2 = st.columns(2, gap='large')
report_title = col1.text_input("Enter report title")
report_author = col1.text_input("Enter the report author's name")
report_date = col2.date_input("Select a date for the report")
report_client = col2.text_input("Enter the client's name")
sect_col1, sect_col2 = st.columns(2, gap='large')
sect_col1.write("### Section Details")
section_title = sect_col1.text_input("Enter section title")
section_text_summary = sect_col1.text_area("Section Summary")
data_features = df.columns
sect_col2.write("### Data Summary")
data_to_summarise = sect_col2.multiselect("Select features to include in statistical summary",
options=data_features)
if st.form_submit_button('Generate'):
generate_report(report_title)
接下来,我们需要创建两个函数。
第一个函数将获取我们感兴趣的特征和数据框,并生成数据的统计摘要。
def create_df_stats_summary(dataframe, features_to_include):
sub_df = dataframe[features_to_include].copy()
return sub_df.describe()
第二个函数要复杂一些。
由于 python-docx 不原生支持数据框,我们需要使用 docx 创建一个表格,如下所示:
def add_df_to_docx(doc, dataframe):
# Reset the index and get the new shape
dataframe = dataframe.reset_index()
num_rows, num_cols = dataframe.shape
# Add a table to the document with the necessary number
# of rows and columns
table = doc.add_table(rows=num_rows + 1, cols=num_cols)
# Add the header row
for i, col in enumerate(dataframe.columns):
table.cell(0, i).text = str(col)
# Add the data rows
for i, row in dataframe.iterrows():
for j, value in enumerate(row):
table.cell(i + 1, j).text = str(value)
return table
当按钮被按下时,我们将调用这些函数。
将图表添加到 Word 文档
图表是报告的另一个重要部分。它们使我们能够简明扼要地传达大量数据。
为了说明在最终 Word 文档中创建和包含图表,我们将允许用户从数据集中选择三列。然后,这些将用于创建一个散点图,并添加到报告中。
在包括散点图选项后,Streamlit 报告生成器。图片由作者提供。
如上图所示,我们将在前面两个部分下方添加三个选择框。这些将添加到三个新列中,并使用 Streamlit 的 selectbox()
创建。
with st.form('report'):
st.write("### Report Details")
col1, col2 = st.columns(2, gap='large')
report_title = col1.text_input("Enter report title")
report_author = col1.text_input("Enter the report author's name")
report_date = col2.date_input("Select a date for the report")
report_client = col2.text_input("Enter the client's name")
sect_col1, sect_col2 = st.columns(2, gap='large')
sect_col1.write("### Section Details")
section_title = sect_col1.text_input("Enter section title")
section_text_summary = sect_col1.text_area("Section Summary")
data_features = df.columns
sect_col2.write("### Data Summary")
data_to_summarise = sect_col2.multiselect("Select features to include in statistical summary",
options=data_features)
st.write("### Scatterplot Setup")
sub_col1, sub_col2, sub_col3 = st.columns(3)
chart_x = sub_col1.selectbox('X axis', options=data_features)
chart_y = sub_col2.selectbox('Y axis', options=data_features)
chart_z = sub_col3.selectbox('Z axis', options=data_features)
if st.form_submit_button('Generate'):
generate_report(report_title)
然后我们将创建一个新的函数,称为 create_scatterplot
,用于生成我们的图形。
我们将设置我们的函数以接受多个参数:
-
dataframe
:包含数据的数据框对象 -
xaxis
:要在 x 轴上绘制的特征 -
yaxis
:要在 y 轴上绘制的特征 -
colour
:用于为数据点上色的特征 -
plot_name
:我们图表的名称。这将用作文件名 -
xaxis_scale
:一个包含两个元素的列表,用于定义 x 轴的最小值和最大值范围 -
yaxis_scale
:一个包含两个元素的列表,用于定义 y 轴的最小值和最大值范围
默认情况下,xaxis_scale
和 yaxis_scale
都会设置为 None
。如果用户没有提供这些,matplotlib 将使用数据绘制的最小值和最大值作为轴的范围。
Python-docx 本身不支持 matplotlib 图形。作为一种解决方法,我们需要将我们的图保存为文件,然后在开始写入 Word 文档时使用。
def create_scatterplot(dataframe, xaxis, yaxis, colour, plot_name,
xaxis_scale= None, yaxis_scale=None):
fig, ax = plt.subplots()
ax.scatter(dataframe[xaxis], dataframe[yaxis],
c=dataframe[colour], cmap='viridis')
ax.set_xlabel(xaxis)
ax.set_ylabel(yaxis)
if xaxis_scale is not None:
ax.set_xlim(xmin=xaxis_scale[0], xmax=xaxis_scale[1])
if yaxis_scale is not None:
ax.set_ylim(ymin=yaxis_scale[0], ymax=yaxis_scale[1])
filename = f'{plot_name}.png'
plt.savefig(filename)
向 Streamlit UI 添加分隔水平线
为了帮助分隔 UI 并使每个部分突出显示,我们可以使用 st.write('---')
添加水平线。
这将从 Markdown 语言转换为实际的行。
如果你想了解更多关于 st.write
函数的信息,可以查看:如何使用 Streamlit 的 st.write 函数来改善你的 Streamlit 仪表板。
我们的最终代码如下:
with st.form('report'):
st.write("### Report Details")
col1, col2 = st.columns(2, gap='large')
report_title = col1.text_input("Enter report title")
report_author = col1.text_input("Enter the report author's name")
report_date = col2.date_input("Select a date for the report")
report_client = col2.text_input("Enter the client's name")
st.write("---")
sect_col1, sect_col2 = st.columns(2, gap='large')
sect_col1.write("### Section Details")
section_title = sect_col1.text_input("Enter section title")
section_text_summary = sect_col1.text_area("Section Summary")
data_features = df.columns
sect_col2.write("### Data Summary")
data_to_summarise = sect_col2.multiselect("Select features to include in statistical summary",
options=data_features)
st.write("---")
st.write("### Scatterplot Setup")
sub_col1, sub_col2, sub_col3 = st.columns(3)
chart_x = sub_col1.selectbox('X axis', options=data_features)
chart_y = sub_col2.selectbox('Y axis', options=data_features)
chart_z = sub_col3.selectbox('Z axis', options=data_features)
if st.form_submit_button('Generate'):
generate_report(report_title)
创建 Word 报告生成函数
我们的最后一步是创建 generate_report
函数。
这个函数将接收我们从用户那里收集的所有内容,然后将其写入我们的 Word 文档中。
如下代码所示,我们首先需要创建我们的 docx 对象,通过调用 docx.Document()
来完成。
然后,我们开始使用标题和段落的组合来创建报告的每个部分。其中一些利用 f-strings,以便我们可以将文本与输入变量结合起来。
接着,我们将添加之前创建的散点图,这可以通过 doc.add_picture()
完成。
最后一部分包含我们的 dataframe 统计摘要,它调用 add_df_to_docx
函数。
最后,我们将报告保存到 docx
文件中。
def generate_report(report_title, report_author, report_date, report_client,
section_title=None,
section_text_summary=None,
data_stats_summary=None,
graph_figure=None):
doc = docx.Document()
# Add Title Page followed by section summary
doc.add_heading(report_title, 0)
doc.add_paragraph(f'Authored By: {report_author}')
doc.add_paragraph(f'Created On: {str(report_date)}')
doc.add_paragraph(f'Created For: {report_client}')
doc.add_heading(section_title, 1)
doc.add_paragraph(section_text_summary)
# Add Scatter plot
doc.add_heading('Data Visualisation', 2)
doc.add_picture(graph_figure)
# Add dataframe summary
doc.add_heading('Data Summary', 2)
summary_table = add_df_to_docx(doc, data_stats_summary)
summary_table.style = 'LightShading-Accent1'
doc.save('report.docx')
return st.info('Report Generated')
一旦写入函数创建完毕,我们就可以填充用户点击生成按钮时的操作。
首先,我们需要调用 summary_stats
和 scatter_plot_file
函数。这些函数的结果将被传递到 generate_report
函数中。
if st.form_submit_button('Generate'):
summary_stats = create_df_stats_summary(df, data_to_summarise)
scatter_plot_file = create_scatterplot(df, chart_x, chart_y, chart_z,
plot_name='scatter', yaxis_scale=[3,1], )
generate_report(report_title, report_author, report_date, report_client,
section_title, section_text_summary, summary_stats,
graph_figure='scatter.png')
当我们查看我们的应用时,可以填写输入框中的所需信息并点击生成。
Streamlit Word 报告生成器的最终视图。图像由作者提供。
这将创建我们下面看到的 Word 文档。
从 Streamlit 应用生成的报告的第一页。图像由作者提供。
从 Streamlit 应用生成的报告的第二页。图像由作者提供。
摘要
创建报告是任何数据科学或岩石物理工作流程中的关键部分。然而,创建这些报告往往是耗时且繁琐的。
结合使用 Streamlit 创建用户界面和 docx 创建 Word 文档,我们可以帮助减少报告生成的负担,并开始自动化这一过程。
随着大型语言模型(LLMs)的到来,我们可能将这些模型集成到此应用中,以进一步提升其功能,并将自动化提升到一个新的水平。
参考文献
Bormann, Peter, Aursand, Peder, Dilib, Fahad, Manral, Surrender, & Dischington, Peter. (2020). FORCE 2020 Well well log and lithofacies dataset for machine learning competition [数据集]. Zenodo. doi.org/10.5281/zenodo.4351156
Streamlit Word 文档报告生成器的完整代码
以下是生成 Word 报告的 Streamlit 应用的完整代码:
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import docx
st.set_page_config(layout='wide')
def create_df_stats_summary(dataframe, features_to_include):
sub_df = dataframe[features_to_include].copy()
return sub_df.describe()
def create_scatterplot(dataframe, xaxis, yaxis, colour, plot_name,
xaxis_scale= None, yaxis_scale=None):
fig, ax = plt.subplots()
ax.scatter(dataframe[xaxis], dataframe[yaxis],
c=dataframe[colour], cmap='viridis')
ax.set_xlabel(xaxis)
ax.set_ylabel(yaxis)
if xaxis_scale is not None:
ax.set_xlim(xmin=xaxis_scale[0], xmax=xaxis_scale[1])
if yaxis_scale is not None:
ax.set_ylim(ymin=yaxis_scale[0], ymax=yaxis_scale[1])
filename = f'{plot_name}.png'
plt.savefig(filename)
def add_df_to_docx(doc, dataframe):
# Reset the index and get the new shape
dataframe = dataframe.reset_index()
num_rows, num_cols = dataframe.shape
# Add a table to the document with the necessary number
# of rows and columns
table = doc.add_table(rows=num_rows + 1, cols=num_cols)
# Add the header row
for i, col in enumerate(dataframe.columns):
table.cell(0, i).text = str(col)
# Add the data rows
for i, row in dataframe.iterrows():
for j, value in enumerate(row):
table.cell(i + 1, j).text = str(value)
return table
def generate_report(report_title, report_author, report_date, report_client,
section_title=None,
section_text_summary=None,
data_stats_summary=None,
graph_figure=None):
doc = docx.Document()
# Add Title Page followed by section summary
doc.add_heading(report_title, 0)
doc.add_paragraph(f'Authored By: {report_author}')
doc.add_paragraph(f'Created On: {str(report_date)}')
doc.add_paragraph(f'Created For: {report_client}')
doc.add_heading(section_title, 1)
doc.add_paragraph(section_text_summary)
# Add Scatter plot
doc.add_heading('Data Visualisation', 2)
doc.add_picture(graph_figure)
# Add dataframe summary
doc.add_heading('Data Summary', 2)
summary_table = add_df_to_docx(doc, data_stats_summary)
summary_table.style = 'LightShading-Accent1'
doc.save('report.docx')
return st.info('Report Generated')
st.title('Streamlit Data Report Generator')
df = pd.read_csv('Xeek_Well_15-9-15.csv')
with st.form('report'):
st.write("### Report Details")
col1, col2 = st.columns(2, gap='large')
# Setup the title and associated data
report_title = col1.text_input("Enter report title")
report_author = col1.text_input("Enter the report author's name")
report_date = col2.date_input("Select a date for the report")
report_client = col2.text_input("Enter the client's name")
st.write("---")
sect_col1, sect_col2 = st.columns(2, gap='large')
# Setup the first report section and associated data
sect_col1.write("### Section Details")
section_title = sect_col1.text_input("Enter section title")
section_text_summary = sect_col1.text_area("Section Summary")
data_features = df.columns
sect_col2.write("### Data Summary")
data_to_summarise = sect_col2.multiselect("Select features to include in statistical summary",
options=data_features)
st.write("---")
st.write("### Scatterplot Setup")
sub_col1, sub_col2, sub_col3 = st.columns(3)
chart_x = sub_col1.selectbox('X axis', options=data_features)
chart_y = sub_col2.selectbox('Y axis', options=data_features)
chart_z = sub_col3.selectbox('Z axis', options=data_features)
if st.form_submit_button('Generate'):
summary_stats = create_df_stats_summary(df, data_to_summarise)
scatter_plot_file = create_scatterplot(df, chart_x, chart_y, chart_z,
plot_name='scatter', yaxis_scale=[3,1], )
generate_report(report_title, report_author, report_date, report_client,
section_title, section_text_summary, summary_stats,
graph_figure='scatter.png')
感谢阅读。在你离开之前,你应该订阅我的内容,将我的文章直接送到你的邮箱。 你可以在这里操作!或者,你也可以 注册我的通讯 以获取额外内容,直接免费送到你的邮箱。
其次,通过注册会员,你可以获得完整的 Medium 体验,并支持我和其他成千上万的作家。每月仅需 $5,你可以完全访问所有精彩的 Medium 文章,还可以通过写作赚钱。
如果你通过 我的链接注册, 你将直接用你的一部分费用支持我,并且不会额外增加你的费用。如果你这么做了,非常感谢你的支持。
压力测试你的 NLP 模型
原文:
towardsdatascience.com/stress-test-for-your-nlp-models-94dba45b6d83
指标无法显示 NLP 模型的实际性能。学习如何彻底测试你的模型并修复注释伪影。
Alexander Biryukov
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 1 月 30 日
--
数据集伪影是自然语言处理(NLP)中的一个问题,会影响模型在实际环境中的表现。尽管预训练模型在基准数据集上表现良好,但在其他环境中表现却很差。这些失败是由于数据集伪影或注释伪影 —— 是语言模型在训练过程中学到的 1 的虚假相关性。
由 Nathan Dumlao 摄影,图片来源于 Unsplash
事实证明,模型在训练过程中会吸收大量伪影,这些伪影源于数据集的特殊性和注释者带来的偏见。伪影为何有害?基本上,它们为你的模型提供了记忆一些实际是虚假的因果关系的捷径,而不是学习正确的“推理”。例如,如果一个数据集中有很多男性角色是医生的例子,那么模型在推断时会更倾向于男性成为医生,而不是女性。也可能构造一些对抗性示例,使得模型的表现出乎意料地低。
作为数据科学家的你,工作是尽可能消除伪影,同时提高模型的整体性能。
找出所有这些伪影!
首先,你需要了解你的敌人。因此,你需要找到这些伪影。为了找到这些伪影,我们需要定义我们要寻找的东西。总而言之,我们可以列出以下伪影,尽管这份列表并不详尽:
-
模型无法理解比较。
-
模型无法理解强化词。
-
一个模型无法理解分类(同义词、反义词等)
-
一个模型对拼写错误、无关的变化等缺乏鲁棒性
-
一个模型无法适当地处理命名实体
-
一个模型对某些少数群体或性别表现出不公平
-
一个模型无法理解事件的顺序
-
一个模型无法适当地处理否定
-
一个模型无法理解共指
-
一个模型无法理解角色如代理、对象等
-
一个模型对某些触发词不够鲁棒(某些词组“破解”你的模型,使其显示一些不希望看到的结果)
-
一个模型无法处理对抗样本(对抗样本是通过在输入段落中添加干扰句子创建的,但它们既不与正确答案矛盾,也不会混淆人类)
-
以及其他……
现在既然我们知道了它们的面貌,我们想找出伪影。在查看数据集或标注它们时,你可能会发现一些伪影,但没有比训练一个基准模型并进行一些测试更好的方法来找到它们。幸运的是,有一个完美的工具——CheckList 2。它不是灵丹妙药,但它非常有用,因为它帮助分析了上述大多数伪影。
伪影检查
工具和仪器
多亏了作者,有一个很棒的开源工具可以即刻用于一些数据集(如 SQuAD、QQP 等)。让我们仔细看看。
CheckList 是一个测试套件,灵感来源于软件开发中的单元测试。作者创建了一些脚本,这些脚本生成了许多带有特殊“模板”的测试,例如:
ret = editor.template({'question': 'Is this a {adj} movie?',
'context': 'This is a {adj} movie.' },
labels='Yes, this is {adj}.',
adj=['good', 'great', 'awesome', 'excellent'])
print(ret.data[0])
print(ret.labels[0])
print()
print(ret.data[1])
print(ret.labels[1])
print()
这个模板将返回一堆上下文、问题和答案,这些将用于测试你的模型:
{'question': 'Is this a good movie?', 'context': 'This is a good movie.'}
Yes, this is good.
{'question': 'Is this a great movie?', 'context': 'This is a great movie.'}
Yes, this is great.
使用这个工具,你可以生成任何数量的此类示例。它的优点在于你可以自定义提供的测试套件,以添加或编辑任何特定的模板。
此外,一旦你设置好测试模板并准备好测试套件,你可以运行测试并获得一个相当整洁的总结,你甚至可以用小部件可视化这个总结(在 Colab 中不起作用)。这个工具使用起来相当简单,而且文档也很完善。所以花些时间去了解一下吧。
测试结果
那么模型呈现了什么结果呢?作者们在 2019 年对大多数流行的最先进模型进行了大量测试,发现它们都有偏见并存在许多伪影。
作者总结了在以下模型上的 CheckList 测试结果(按图中从左到右的顺序排列):
-
Microsoft 的文本分析
-
Google Cloud 的自然语言
-
Amazon 的 Comprehend
-
BERT-base
-
RoBERTa-base
根据论文,测试模型的失败率在几个测试中都相当高。它们在大多数与否定处理相关的测试中表现糟糕,在情感随时间变化和两个陈述的比较中表现不佳。有趣的是,特别是 BERT 基础的模型也未能正确分类中性情感句子。因此,事实证明,即使在大量数据上训练,大多数 NLP 模型也容易受到伪影的影响。
以示例为例,我和我的同事 Derrick Xu 对 SQuAD 训练的 ELECTRA-small 模型进行了类似的测试。正如预期的那样,我们得到了更差的结果(图 1)。尽管模型本身获得了 86.3 的 F1 分数和 78.5 的准确匹配分数,但它存在许多偏差,这可以在图 1 中看到。
图 1. 针对 SQuAD 数据集训练的 ELECTRA-small 模型的测试结果。
我们还通过使用对抗性 AddOneSent 数据集(见图 2)评估了基线 QA 模型的对抗性表现。
与在 SQuAD 开发集上的性能相比,F1 分数从 86%降至 49.6%,准确匹配分数从 78%降至 42.1%。
图 2:模型在 SQuAD 数据集示例上的预测失败示例。对抗性句子和错误预测的答案以红色标记。原始短语和正确答案以蓝色标记。
因此,除了修复通用语言能力(使用 CheckList 发现),我们还寻求提高模型在对抗性示例上的表现。
现在修复它们!
有几种方法可以对抗标注伪影:
-
对难度较大的数据子集或黄金标签分布模糊的数据进行重新训练。建议使用数据集制图[3]或其他任何方法来寻找这些示例。
-
基于集成的去偏差:使用一个弱模型学习相关性,然后训练你的模型学习该模型的残差[4],或者将其从输出分布中移除。这将使你的主要模型在困难示例上进行额外训练。
-
还有其他方法……
基本上,所有的方法都归结为对模型难以推断的数据进行重新训练。所以最简单的方法就是生成这些数据并重新训练你的模型。这实际上效果很好。
幸运的是,你不必手动编写所有额外的数据。你已经设置了 CheckList 生成工具。因此,你只需要设置模板即可开始使用。
为了修复我们的 ELECTRA-small 模型,我们也使用了 CheckList 工具。以下是生成额外数据以改进比较能力的示例代码。
import checklist
import spacy
import itertools
import json
import checklist.editor
from checklist.test_types import MFT, INV, DIR
from checklist.expect import Expect
from checklist.test_suite import TestSuite
from checklist.perturb import Perturb
import checklist.text_generation
# Template to generate comparison examples
adj = ['large', 'fat', 'fresh', 'kind', 'deep', 'wierd', 'poor', 'clear', 'bold', 'calm', 'clever', 'firm', 'mean', 'quick', 'quiet', 'strong', 'bright', 'light']
adj = [(x.rstrip('e'), x) for x in adj]
temp1 = editor.template(
[(
'{first_name} is {adj[0]}er than {first_name1}.',
'Who is less {adj[1]}?'
),(
'{first_name} is {adj[0]}er than {first_name1}.',
'Who is {adj[0]}er?'
)
],
labels = ['{first_name1}','{first_name}'],
adj=adj,
remove_duplicates=True,
nsamples=1000,
save=True
)
# Generating train extension from comparisons
train_extension_comparison = []
id_n = 0
for string in range(len(temp1['data'])):
for i in range(len(temp1['data'][string])):
index_of_answer = temp1['data'][string][i][0].find(temp1['labels'][string][i])
train_extension_comparison.append({
'id':f'aug{id_n}',
'title':'aug_comparison',
'context':temp1['data'][string][i][0],
'question':temp1['data'][string][i][1],
'answers':{"text": [temp1['labels'][string][i], temp1['labels'][string][i], temp1['labels'][string][i]], "answer_start": [index_of_answer, index_of_answer, index_of_answer]}
})
id_n += 1
# will generate 1996 examples with different compbinations
# of adjectives and first names according to the template.
然后只需将获得的数据倒入并与原始训练数据连接。在我们的案例中,这些数据是 SQuAD 训练数据。不要忘记将新数据与原始数据混合,以避免训练数据的偏斜。
我们为该模型重复生成了几个能力的数据。最终,我们实现了对原始 SQuAD 训练数据的大约 30% 扩展。我们还包含了一些对抗数据,这也是对抗伪影的方法之一。刘等人(2019)[5] 发现,当使用来自挑战数据集的 500 个或更多对抗示例重新训练模型(即 BiDAF,QANet)时,模型性能会提高。因此,我们也将大约 750 个对抗示例包含到扩展数据集中,对整个数据集进行混洗,并重新训练了模型。
我们进行了与之前相同的 CheckList 测试,以下是我们得到的结果。结果如图 3 所示。
图 3. CheckList 测试结果在重新训练模型与基线模型(在 SQuAD 数据集上训练的 ELECTRA-small)之间的比较。
如你所见,结果并不完美。我们显然降低了模型在某些能力上的表现,但主要是那些数据未生成用于重新训练的能力。对于那些我们已增强数据的能力,性能显著提升(失败率降低)。
然而,有些结果并不那么稳定。可能没有足够的数据来提高对否定能力(尽管我们确实针对了这一能力)、核心指代和时间能力的表现。
与此同时,我们成功地提高了对抗数据集上的性能(图 4),并弥补了对抗重新训练在原始开发数据集上带来的性能下降。
就最终整体指标而言,原始开发数据集上的测试结果显示:精确匹配度从 78.5 提升至 78.7,F1 分数从 86.3 下降至 86.2。这实际上是一个好结果,因为根据刘等人(2019)的研究,使用对抗数据进行重新训练会导致原始开发数据集上重新训练的 QA 模型性能显著下降(在我们的案例中,精确匹配度下降至 74.2,F1 分数为 81.9),因此单独进行对抗训练看起来并不太有吸引力。然而,与其他伪影处理技术(如在我们案例中的 CheckList 生成数据增强)结合使用,可以获得更好的性能,并显著提高对抗数据上的表现。
图 4. 重新训练模型在 SQuAD 对抗数据集上的表现。与基线模型相比的改进用绿色突出显示。
总结来说,我们在提高模型能力和解决伪影问题的同时,整体指标大致保持不变。
总结
CheckList + 对抗训练只是分析和修复标注伪影的方法之一。为了显著提升模型性能并消除伪影,你必须同时使用几种方法。
然而,我不能过分强调,你必须考虑这些文档并与之抗争。这是朝着更强健、公平且减少偏见的自然语言处理模型迈出的重要一步!
参考文献
本文由我和Derrick Xu共同提供。
除非另有说明,否则所有图片均由作者提供。
1 Gururangan, S., Swayamdipta, S., Levy, O., Schwartz, R., Bowman, S. R., & Smith, N. A. (2018). 自然语言推理数据中的注释伪影。arXiv 预印本 arXiv:1803.02324。
2 Ribeiro, M. T., Wu, T., Guestrin, C., & Singh, S. (2020). 超越准确性:使用 CheckList 对 NLP 模型进行行为测试。arXiv 预印本 arXiv:2005.04118
[3] Swayamdipta, S., Schwartz, R., Lourie, N., Wang, Y., Hajishirzi, H., Smith, N. A., & Choi, Y. (2020). 数据集制图:通过训练动态映射和诊断数据集。arXiv 预印本 arXiv:2009.10795。
[4] He, H., Zha, S., & Wang, H. (2019). 通过拟合残差来消除自然语言推理中的数据集偏差。arXiv 预印本 arXiv:1908.10763。
[5] Liu, N. F., Schwartz, R., & Smith, N. A. (2019). 通过微调进行免疫:分析挑战数据集的方法。arXiv 预印本 arXiv:1904.02668。
结构化您的云实例启动脚本
原文:
towardsdatascience.com/structuring-your-cloud-instances-startup-scripts-2ce981825b8d
区分首次启动与重启
Jake Teo
·发布于 Towards Data Science ·阅读时间 7 分钟·2023 年 11 月 9 日
--
大多数机器学习任务通常在初步探索阶段后,会被打包成镜像并部署到本地或云服务器上。这将促进快速迭代,以建立支持 MLOps 管道运行的基础设施,涉及整个开发团队,包括数据科学家、数据、软件和云工程师等。
示例图展示了机器学习任务典型的服务器部署(VM = 虚拟机)。图片由作者提供
启动脚本用于在云服务器实例启动时执行自动化配置或其他任务。这在 AWS EC2 中被称为 用户数据,在 Google Cloud Engine 中称为 启动脚本,在 Azure Virtual Machine 中称为 自定义脚本扩展。启动脚本中的内容可以包括安装、元数据设置、环境变量等。其主要目的是确保每个实例在启动时始终配置为准备好服务于内部或相邻服务的应用程序。
就像我们编写的所有脚本一样,我们应始终使其保持整洁、结构化和集中,以便可以将其作为模板重复使用。这将使您在管理项目中不同实例的多个应用程序时更加轻松。在接下来的部分中,我将展示如何做到这一点。
虽然后面的部分专门针对 AWS EC2 的用户数据,但可以很容易地将其适应于其他提供商,只需使用相同的概念。
1) 首次启动与重启启动脚本
在实例首次启动时使用启动脚本是相当直观的,但重启呢?如果我们使用的是按需实例,并且它们不用于生产环境(例如开发、测试、系统集成测试、用户验收测试),那么在开发人员不在工作时(例如周末或下班后)让它们运行是没有经济意义的。因此,它们会被安排在需要时关闭和重启。在打补丁时也有需要重启的情况。
在这些关闭期间,可能会有应用程序需要的元数据更新。因此,在重启后,这些数据应该被刷新,以反映最新的信息。
实例首次启动和重启所需的示例。图片来源:作者。
从现在开始,用户数据可以在实例首次启动时进行配置,也可以在重启时进行配置。通常,这两种启动方式所需的配置并不相同,但问题在于我们只能将一个用户数据文件附加到每个实例上。那么,我们如何在同一个用户数据文件中区分它们呢?
多部分格式
如果我们只要求在实例首次启动时执行用户数据,脚本只需包含 shell 命令。然而,要使其在每次实例重启时也能执行,则需要一个 cloud-config 命令。这是一个单独的格式,因此 AWS 使用 MIME(多用途互联网邮件扩展)多部分格式来包含这两种信息。
Content-Type: multipart/mixed; boundary="//"
MIME-Version: 1.0
--//
Content-Type: text/cloud-config; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="cloud-config.txt"
#cloud-config
cloud_final_modules:
- [scripts-user, always]
--//
Content-Type: text/x-shellscript; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="userdata.txt"
# your script here
--//--
从上面可以看到 MIME 的定义,随后是云配置,其中 [scripts-user, always]
表示用户数据将在实例首次启动和随后的重启时执行。下一个格式是为 shell 命令量身定制的。
区分首次启动和重启
从技术上讲,AWS 没有用户数据配置来根据首次启动和重启分隔你的脚本。幸运的是,我们可以使用一些简单的脚本优雅地完成这一任务,如下面的伪代码所示。
Content-Type: multipart/mixed; boundary="//"
MIME-Version: 1.0
--//
Content-Type: text/cloud-config; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="cloud-config.txt"
#cloud-config
cloud_final_modules:
- [scripts-user, always]
--//
Content-Type: text/x-shellscript; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="userdata.txt"
#!/bin/bash
# --------------- define functions --------------- #
function install_docker() {
# some installations
}
function create_dotenv() {
# create .env file
}
function setup_docker_compose() {
# setup docker-compose.yml
}
function launch_docker_compose() {
# launch your container
}
# --------------- execute script --------------- #
if [ ! -e "STARTED" ]; then
# on first launch
install_docker
create_dotenv
setup_docker_compose
launch_docker_compose
touch "STARTED";
else
# on restart
create_dotenv
setup_docker_compose
fi
--//--
首先,我们需要将脚本结构化为函数,以便它们可以在首次启动或重启时调用。你可以看到我定义了 install_docker
、create_dotenv
、setup_docker_compose
和 launch_docker_compose
四个函数。应设置适当的参数,使其尽可能可重用。
其次,我们有一个简单的 if-else 语句,当 STARTED
文件不存在时,它将执行所有四个函数,并在末尾创建 STARTED
文件。在该实例重启时,由于 STARTED
文件存在,它将仅运行两个配置函数,而不是其他函数。
这很简单,对吧?下面是一个使用 Ubuntu 虚拟机进一步说明的实际示例。
Content-Type: multipart/mixed; boundary="//"
MIME-Version: 1.0
--//
Content-Type: text/cloud-config; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="cloud-config.txt"
#cloud-config
cloud_final_modules:
- [scripts-user, always]
--//
Content-Type: text/x-shellscript; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="userdata.txt"
#!/bin/bash
# --------------- define functions --------------- #
function install_docker() {
# https://docs.docker.com/engine/install/ubuntu/
sudo apt-get update;
sudo apt-get install -y ca-certificates gnupg lsb-release;
sudo mkdir -p /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --yes --dearmor -o /etc/apt/keyrings/docker.gpg
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
$(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt update
sudo apt-get -y install docker-ce docker-ce-cli containerd.io docker-compose-plugin
}
function create_dotenv() {
ENV=$(curl -s http://169.254.169.254/latest/meta-data/tags/instance/Env)
cd $1
rm -f .env
# Calculate memory limit & reservation for docker container
# 90% limit, 70% reserved
total_memory=$(free -m | awk '/^Mem:/{print $2}')
MEM_LIMIT=$(echo "$total_memory * 0.9" | bc)
MEM_LIMIT=$(printf "%.0f" "$MEM_LIMIT")
MEM_RES=$(echo "$total_memory * 0.7" | bc)
MEM_RES=$(printf "%.0f" "$MEM_RES")
echo "Memory limit: $MEM_LIMIT $MEM_RES MB"
echo MEM_LIMIT=${MEM_LIMIT}M >> .env
echo MEM_RES=${MEM_RES}M >> .env
echo ENV=$ENV >> .env
echo -e "[INFO] dotenv created ==========\n"
}
function setup_docker_compose() {
# pull docker-compose file
CI_TOKEN="get from secrets-manager"
curl --header "PRIVATE-TOKEN: $CI_TOKEN" "https://gitlab.com/api/v4/projects/${1}/repository/files/docker-compose.yml/raw?ref=main" -o ${2}docker-compose.yml
# pull image
AWS_ACCOUNT=$(curl -s http://169.254.169.254/latest/dynamic/instance-identity/document | jq -r .accountId)
AWS_REGION=$(curl -s http://169.254.169.254/latest/meta-data/placement/region)
aws ecr get-login-password --region $AWS_REGION | docker login --username AWS --password-stdin
echo -e "[INFO] docker-compose downloaded & docker logged in ==========\n"
}
function launch_docker_compose() {
docker compose pull
docker compose up -d
echo -e "[INFO] docker image pulled and up ==========\n"
}
# --------------- execute script --------------- #
PROJECTID=12345678
HOMEDIR=/home/ubuntu/
if [ ! -e "STARTED" ]; then
# on first launch
install_docker
create_dotenv $HOMEDIR
setup_docker_compose $PROJECTID $HOMEDIR
launch_docker_compose
touch "STARTED";
else
# on restart
create_dotenv $HOMEDIR
setup_docker_compose $PROJECTID $HOMEDIR
fi
--//--
每个函数的简短描述已经提供。请注意参数的使用,使每个函数都可以重用。
-
install_docker():更新包管理器,安装基础库以及 docker 和 docker-compose。
-
create_dotenv():从实例元数据标签中获取环境的元数据,例如开发、暂存、生产,并将其放入
.env
文件中。 -
set_docker_compose():从源代码库中获取最新的
docker-compose.yml
文件,使用文件中的环境设置镜像标签,然后登录到容器注册表。 -
launch_docker_compose():将镜像部署为容器
2) 集中启动脚本与克服字符限制
用户数据的字符或大小限制分别为 16K 和 16KB。这在大多数使用案例中是一个合理的长度。然而,如果超过此限制,你可以轻松地将脚本存储在像 S3 桶这样的 Blob 存储中,并在用户数据中拉取并执行这些脚本。这也是通过中央存储更新所有用户数据脚本的首选方法。
Content-Type: multipart/mixed; boundary="//"
MIME-Version: 1.0
--//
Content-Type: text/cloud-config; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="cloud-config.txt"
#cloud-config
cloud_final_modules:
- [scripts-user, always]
--//
Content-Type: text/x-shellscript; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="userdata.txt"
#!/bin/bash
# --------------- define functions --------------- #
function install_aws_cli() {
sudo apt-get update;
sudo apt-get install -y curl unzip;
sudo curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip";
sudo unzip awscliv2.zip;
sudo ./aws/install;
rm -f awscliv2.zip; rm -rf aws;
}
function download_scripts() {
# download template functions from S3
aws s3 cp s3://<s3.bucket.name>/userdata_template.sh userdata_template.sh
source userdata_template.sh
}
# --------------- execute script --------------- #
PROJECTID=12345678
HOMEDIR=/home/ubuntu/
cd $HOMEDIR
if [ ! -e "STARTED" ]; then
# on first launch
install_aws_cli
download_scripts
install_docker
create_dotenv $HOMEDIR
setup_docker_compose $PROJECTID $HOMEDIR
launch_docker_compose
touch "STARTED";
else
# on restart
download_scripts
create_dotenv $HOMEDIR
setup_docker_compose $PROJECTID $HOMEDIR
fi
我们可以将上述四个函数存储在名为userdata_template.sh
的文件中,并将其放置在你选择的 S3 桶中。
要访问 S3 桶,我们需要确保 1) 实例在实例配置文件中具有读取该桶的相关权限,以及 2) 实例已安装aws-cli
以便使用适当的命令从 S3 中拉取启动脚本。
有了这些,我们可以轻松下载脚本,source
它以访问之前的函数,并根据需要执行它们。
3) 调试
如果用户数据没有按预期执行任务,你可以查看实例中的日志文件,以查看是否捕获到任何错误消息。日志文件位于/var/log/cloud-init-output.log
。
# print the last 100 lines of log file
tail -n 100 /var/log/cloud-init-output.log
如果你需要检查用户数据脚本本身,可以使用以下两种方法。
# print the user data script
curl -s http://169.254.169.254/latest/user-data
# the script itself is stored in this directory
cd /var/lib/cloud/instance/scripts
总结
就这样!希望你从虚拟机的首次启动和重启中学到一些优雅的结构化用户数据的技巧。希望你觉得这些技巧有用且直观。
参考资料
-
docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html
-
docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-add-user-data.html#
将你的机器学习项目与 MLOps 思维相结合进行结构化
原文:
towardsdatascience.com/structuring-your-machine-learning-project-with-mlops-in-mind-41a8d65987c9
MLOps 实践:项目结构化
Chayma Zatout
· 发布于 Towards Data Science ·14 分钟阅读·2023 年 3 月 16 日
--
图片由 Priscilla Du Preez 提供,来源于 Unsplash
如果你希望将机器学习项目提升到一个新的水平,MLOps 是过程中的重要部分。本文将为你提供一个实用的教程,教你如何为 MLOps 结构化你的项目,以经典的手写数字分类问题为例。我们将逐步带你完成创建一个基本项目模板的过程,你可以用来组织自己的项目。通过这个教程,你将对 MLOps 原则有一个扎实的理解,并学会如何将它们应用到自己的项目中。然而,如果你对 MLOps 不熟悉,我们建议你先从我的 适合初学者的教程 开始,以便快速上手。让我们开始吧,将你的机器学习项目提升到一个新水平!
目录:
· 1. 介绍
· 2. MLOps
∘ 2.1. 业务问题
∘ 2.2. 数据工程
∘ 2.3. 机器学习模型工程
∘ 2.4. 代码工程
· 3. 项目结构
∘ 3.1. Cookiecutter 数据科学
· 4. MLOps 项目结构
∘ 4.1. 开始一个新的 MLOps 项目
∘ 4.2. 使用 MLOps 项目模板进行手写数字分类
∘ 4.3. 如何运行你的项目?
· 5. 结论
我的 MLOps 教程:
-
教程 1:MLOps 的关键入门:探索其核心组件
-
教程 2:初学者友好的 MLOps 工作流程介绍
-
教程 3:MLOps 原则介绍
-
教程 4:以 MLOps 为核心构建机器学习项目结构
-
教程 5:实践中的版本控制:数据、ML 模型和代码
-
教程 6:实践中的测试:代码、数据和 ML 模型
-
教程 7:实践中的跟踪:代码、数据和 ML 模型
[我会在发布相关文章时更新此列表]
1. 介绍
在之前的教程中,我们将 MLOps 定义为一种以高效、优化和有序的方式设计、构建和部署机器学习模型的方法。这是通过结合一组技术、实践和工具来实现的,这些技术、实践和工具通常在 MLOps 生命周期的背景下进行讨论。
在 MLOps 生命周期中,了解问题后的第一步是构建您的项目。这通常通过使用模板完成,无论是公司模板、公共模板还是您自己的模板,如我们将在本教程中看到的那样。
在本教程中,我们将以手写数字分类为例。 在之前的教程中,我为 MNIST 分类创建了一个 Github 仓库,项目结构如下:
MNIST_classification
├── dataset_scripts
│ ├── construct_dataset_csv.py
│ ├── construct_dataset_folders.py
│ ├── describe_dataset_csv.py
│ ├── explore_dataset_idx.py
│ └── README.md
├── main_classification_convnet.py
├── main_classification_onehot.py
├── main_classification_single_output.py
├── .gitignore
└── README.md
项目文件夹包括“dataset_scripts”文件夹,该文件夹包含用于操作原始 IDX 格式数据集的脚本(有关更多信息,您可以查看我之前的教程“如何轻松探索您的 IDX 数据集”),用于训练三种不同类型模型的 Python 脚本,一个.gitignore 文件,以及一个 README 文件。由于该项目结构是为教程目的设计的,因此它非常简单。在本教程中,我将介绍我在 MLOps 项目中的项目结构。请注意,如果您想了解有关模型和训练的编程细节,您可以随时参考我的教程“神经网络简要介绍:分类问题”。
2. MLOps
机器学习过程中的不同步骤在 MLOps 工作流程中进行了概述,该工作流程包括业务问题、数据工程、机器学习模型工程和代码工程。在本节中,我们将探讨如何实现每一步。然而,由于我们正在解决的问题(手写数字分类)不需要某些步骤,我们不会深入讨论这些步骤。我们将重点关注用绿色突出显示的步骤(见下图)。剩余步骤将在未来的教程中覆盖。如果你想了解更多关于 MLOps 工作流程的内容,你可以查看我的入门友好教程。
MLOps 工作流程。
2.1. 业务问题
本教程中解决的问题是手写数字分类,这是一个多类分类任务。具体而言,给定一个手写数字的输入图像(范围从 0 到 9),模型需要识别数字并输出其对应的标签。
AI 画布包含以下组件:任务描述、预测(模型输出)、判断、行动、结果、训练、输入、反馈和模型对问题的影响。对于当前的手写数字分类问题,我们的 AI 画布将按如下方式构建和填写:
用于手写数字分类的 AI 画布。
2.2. 数据工程
数据工程
数据工程涵盖了各种任务,如数据摄取、探索和验证、清理、标注和拆分。在这个项目中,我们执行了以下数据工程任务:
-
数据摄取: 我们从其官方网站下载了 MNIST 数据集的原始格式,并将其转换为 CSV 文件。
-
数据探索和验证: 我们可视化了一些数据集中的图像,并展示了一些洞见。
-
数据清理: 数据集已经很干净,无需进一步清理。
-
数据标注: 数据集已经标注完毕,因此不需要额外的标注。
-
数据拆分: 数据集已经被拆分为训练集和测试集。我们将从训练集中提取验证集。
值得注意的是,这个项目涉及的相对简单的数据工程过程,因为数据集已经准备和处理好了。然而,我们将在未来的文章中探讨更复杂的例子。
2.3. 机器学习模型工程
机器学习模型工程
机器学习模型工程是 MLOps 工作流程中的第三步。它包括各种任务,如模型训练、评估、测试和打包。在这个项目中,我们执行了以下机器学习模型工程任务:
-
模型训练:在特征工程中,我们使用了数据缩放(将像素缩放到[0,1]范围内)、数据重塑(将图像表示为 1D 向量或 2D 矩阵)和数据编码(独热编码)。在模型工程中,我们实现了两种不同类型的模型并应用了超参数调优。
-
模型评估:除了准确率,我们还使用了召回率、精确度和 F1 分数等其他评估指标,以确保模型符合 AI 画布中描述的业务目标(结果)。
-
模型测试:在评估模型之后,我们在两种不同类型的数据上进行了测试:第一种是 MNIST 数据集的测试集,第二种是从应用程序生成的一些手写数字图像。
-
模型打包和版本控制将在下一教程中讨论,我们将更详细地介绍机器学习管道。
如果你想了解更多编程细节,可以随时查看我之前的教程。
2.4. 代码工程
代码工程
在代码工程步骤中,选择的模型会被部署到应用中,其性能需要被监控和记录。在部署模型之前,需要仔细选择服务模式和部署策略。部署后,需要管理和维护其行为,以确保其正常运行。尽管这一部分在本教程中没有详细说明,但我计划在不久的将来专门写一篇文章。
3. 项目结构
现在我们已经突出展示了在手写数字分类中应用的不同 MLOps 步骤,接下来让我们着手结构化项目,以满足项目需求,同时考虑这些步骤。为此,我将首先介绍一个著名的项目结构,然后展示我的 MLOps 项目结构模板。该模板将在我们添加更多组件时进行更新。
但为什么正确结构化你的机器学习项目很重要呢?好吧,有几个好处:
-
良好的透明度: 组织有序的项目不仅对你自己,而且对他人也更易于理解。
-
简单的维护: 结构良好的项目更易于维护和更新,从而节省时间和精力。
-
提高效率: 清晰的计划减少了浪费的时间,最小化了偏离方向或丢失重要信息的风险。
-
良好的可复现性和可重用性: 一个良好的项目结构确保项目结果可以轻松复现,并且其组件可以重用。
-
便捷的协作: 当项目组织得清晰且逻辑性强时,其他人更容易理解和参与。
总之,正确构建机器学习项目的结构可以带来更大的透明度、效率、可维护性和协作。
3.1. Cookiecutter 数据科学
正如本文之前提到的,在编写任何代码之前,我们首先需要定义项目结构。这可以通过使用项目结构模板来实现。模板可以是公司为响应公司/项目需求而制定的公司模板,也可以是一个团体或个人创建并发布的公共模板,或者是您自己感觉舒适的自定义模板。
在该领域最著名的项目结构之一是Cookiecutter 数据科学,其结构如下:
一种逻辑合理、标准化但灵活的项目结构,用于进行和共享数据科学工作。
您可以在下面找到此模板的项目结构,以及每个文件的描述:
4. MLOps 项目结构
现在我们已经解释了 MLOps 工作流程的不同步骤是如何执行的,让我们定义一个与 MLOps 工作流程对齐的项目结构模板。Cookiecutter MLOps 模板基于我们之前介绍的 Cookiecutter 数据科学模板。与 Cookiecutter 数据科学类似,我的 Cookiecutter MLOps 模板包括 LICENSE、README、Makefile 和 requirements 文件;以及 docs、models、notebooks、references、reports、visualization 和 source 文件夹。然而,新增了一个文件夹 (configs),并且 source 和 visualization 文件夹得到了增强。
MLOps 项目结构模板具有以下结构:
{{ cookiecutter.repo_name }}/
├── LICENSE
├── README.md
├── Makefile # Makefile with commands like `make data` or `make train`
├── configs # Config files (models and training hyperparameters)
│ └── model1.yaml
│
├── data
│ ├── external # Data from third party sources.
│ ├── interim # Intermediate data that has been transformed.
│ ├── processed # The final, canonical data sets for modeling.
│ └── raw # The original, immutable data dump.
│
├── docs # Project documentation.
│
├── models # Trained and serialized models.
│
├── notebooks # Jupyter notebooks.
│
├── references # Data dictionaries, manuals, and all other explanatory
│ # materials.
│
├── reports # Generated analysis as HTML, PDF, LaTeX, etc.
│ └── figures # Generated graphics and figures to be used in reporting.
│
├── requirements.txt # The requirements file for reproducing the environment.
└── src # Source code for use in this project.
├── __init__.py # Makes src a Python module.
│
├── data # Data engineering scripts.
│ ├── build_features.py
│ ├── cleaning.py
│ ├── ingestion.py
│ ├── labeling.py
│ ├── splitting.py
│ └── validation.py
│
├── models # ML model engineering (a folder for each model).
│ └── model1
│ ├── dataloader.py
│ ├── hyperparameters_tuning.py
│ ├── model.py
│ ├── predict.py
│ ├── preprocessing.py
│ └── train.py
│
└── visualization # Scripts to create exploratory and results
│ # oriented visualizations.
├── evaluation.py
└── exploration.py
configs 文件夹包含所有配置文件,例如模型超参数。
data 文件夹(src 的子文件夹)包括以下文件:
-
ingestion.py: 用于收集数据。如果需要创建备份、保护私人信息或创建元数据目录,最好在这里完成。
-
cleaning.py: 用于通过减少离群值/噪声、处理缺失值等来清理数据。
-
labeling.py: 如果需要,使用该文件对数据进行标注。
-
splitting.py: 用于将数据分为测试集和训练集。
-
validation.py: 用于验证数据(以确保其准备好进行训练)。
-
build_features.py: 该文件已移动到此文件夹,因为构建特征意味着将数据集组织成特定结构。
在 models 文件夹(src 的子文件夹)中,每个模型的脚本都在模型的文件夹中组织,包括:
-
model.py: 用于定义模型架构。
-
dataloader.py: 用于加载数据,以供模型使用。
-
preprocessing.py: 用于在将数据输入模型之前进行预处理。
-
train.py: 用于训练模型。
-
hyperparameters_tuning.py: 用于调整模型和/或训练超参数。
-
predict.py:用于对随机图像进行预测(不是来自数据集)。
可视化 文件夹包括以下内容:
-
exploration.py:此文件包括在数据工程过程中用于可视化数据的函数。
-
evaluation.py:此文件包括用于可视化训练结果的函数。
这是 MLOps 模板,有一些重要的注意事项需要考虑:
-
这是一个基本模板,因此根据你的项目需求,可以删除或添加一些文件和文件夹。
-
一些预处理函数可以在所有模型中使用,因此可以创建一个单独的预处理文件,并将其移动到数据文件夹中,以避免函数重复。然而,建议将预处理文件分开,以提高模型的可重用性,并防止未来潜在的问题。
-
在预测脚本中,假设数据来自应用程序而不是数据集,因此可能需要额外的预处理步骤。
4.1. 启动一个新的 MLOps 项目
如果你想使用此模板启动你的机器学习项目,你可以使用GitHub 模板或使用Cookiecutter 模板,如下所示:
- 要使用GitHub 模板,首先,你需要访问这里的模板页面。然后,点击绿色的按钮‘使用此模板’,你将需要选择是‘创建一个新仓库’还是‘在代码空间中打开’:
GitHub 模板
- 要使用Cookiecutter 模板,你首先需要安装 Cookiecutter,使用:
pip install cookiecutter
或者:
conda config --add channels conda-forge
conda install cookiecutter
然后在命令行中运行此命令:
cookiecutter https://github.com/Chim-SO/cookiecutter-mlops
这里是手写数字分类的示例配置,你可以通过填写所需的参数来进行自定义。按下 Enter 键将保留你不想更改的任何参数的默认值:
project_name [project_name]: MLOps_MLflow_mnist_classification
repo_name [mlops_mlflow_mnist_classification]:
author_name [Your name (or your organization/company/team)]: Chim SO
description [A short description of the project.]: MNIST classification
Select open_source_license:
1 - MIT
2 - BSD-3-Clause
3 - No license file
Choose from 1, 2, 3 [1]: 1
s3_bucket [[OPTIONAL] your-bucket-for-syncing-data (do not include 's3://')]:
aws_profile [default]:
Select python_interpreter:
1 - python3
2 - python
Choose from 1, 2 [1]:
4.2. 使用 MLOps 项目模板进行手写数字分类
在第二部分,我们讨论了手写数字分类任务中 MLOps 工作流的不同步骤。使用 MLOps 模板实现该管道将导致以下项目结构:
MLOps_MLflow_mnist_classification
├── configs
│ ├── cnnbased.yaml
│ └── singleoutput.yaml
├── data
│ ├── external
│ │ └── test
│ │ ├── 0_0.png
│ │ ├── 1_0.png
│ │ ├── 1_1.png
│ │ ├── 3_1.png
│ │ ├── 5_1.png
│ │ ├── 7_0.png
│ │ └── 8_0.png
│ ├── interim
│ ├── processed
│ │ ├── test.csv
│ │ └── train.csv
│ └── raw
│ ├── test_images.gz
│ ├── test_labels.gz
│ ├── train_images.gz
│ └── train_labels.gz
├── LICENSE
├── Makefile
├── MLproject
├── mlruns
├── models
├── README.md
├── requirements.txt
└── src
├── data
│ ├── build_features.py
│ ├── dataloader.py
│ └── ingestion.py
├── models
│ ├── cnnbased
│ │ ├── hyperparameters_tuning.py
│ │ ├── model.py
│ │ ├── predict.py
│ │ ├── preprocessing.py
│ │ └── train.py
│ └── singleoutput
│ ├── hyperparameters_tuning.py
│ ├── model.py
│ ├── predict.py
│ ├── preprocessing.py
│ └── train.py
└── visualization
├── evaluation.py
└── exploration.py
由于我们已经描述了每个文件和文件夹的内容,现在我将重点介绍一些可能有点模糊的最重要步骤。
configs
文件夹包含两个配置文件,每个模型一个。例如,singleoutput.yaml
文件包括模型配置、训练参数、日志参数(将在下一个教程中讨论)以及模型调优参数。
# Data parameters
data:
dataset_path : 'data/processed/'
# Model parameters
model:
name: 'singleoutput'
num_units: 224
num_layers: 5
activation_function : 'sigmoid'
# Training parameters
training:
batch_size: 128
num_epochs: 200
loss_function: 'mae'
metric: 'mse'
# Logging and output parameters
mlflow:
mlruns_path: 'file:models/mlruns'
experiment_name: 'singleOutput'
# Tuning
hyperparameter_tuning:
num_layers: [3, 5]
num_units: [16, 64, 224]
activation_function: ['relu', 'sigmoid']
batch_size: [128, 256]
loss_function: ['mae']
metric: ['mse']
num_epochs: [200]
-
使用
src/data/ingestion.py
,数据首先被下载并存储在data/raw/
中。然后,使用src/data/build_features.py
将其转换为记录结构,并直接存储到data/processed
中。 -
在
data/external
文件夹中,我添加了一个test
子文件夹,其中包括一些随机的手写数字图像。这些图像将由predict.py
脚本用于测试训练模型对新、未见数据的预测。
-
对于这个示例,data/interim 文件夹是空的,因为数据处理管道中没有中间步骤。
-
由于数据集是经典数据集,数据加载器被移动到了
src/data/
,而不是为每个模型重复使用。 -
src/models/<model>/predict.py
脚本概述了预测随机图像类别的管道。与用于训练模型的预处理管道(包括调整大小和缩放)不同,预测管道首先对图像进行裁剪,反转像素,然后调整大小和缩放。
随机图像的数据预处理管道。
MLproject
文件和mlruns
文件夹由 MLflow 库使用,MLflow 是一个用于管理机器学习管道的平台。下一篇文章将详细介绍这个主题,所以如果你不熟悉它,也不用担心。
4.3. 如何运行你的项目?
执行 Python 项目有几种方法:交互式运行(逐行执行),批处理运行(安排定时任务或使用作业调度器),容器化运行(使用 Docker 或 Kubernetes),自动化运行(例如使用 MLflow),或分布式运行(使用像 Apache Spark 这样的分布式计算框架)。由于这不是本文的主要内容,我们将使用最简单的方法:从项目目录执行这些命令。
python src/data/ingestion.py -r data/raw/ # Download data
python src/data/build_features.py -r data/raw/ -p data/processed/ # Create csv files
python -m src.models.cnnbased.train -c configs/cnnbased.yaml # Train CNN model
5. 结论
在这篇文章中,我们提供了一个 MLOps 项目结构模板,并应用于手写数字分类问题。我们展示了如何将 MLOps 工作流应用于解决这个问题,并制定了一个你可以作为Cookiecutter 项目或GitHub 模板使用的项目结构模板。如果你觉得这个模板有帮助,请在 GitHub 上给它一个星标,以便其他人也能发现。如果你是 MLOps 的新手,可以阅读我的初学者友好教程。
在接下来的文章中,我们将继续使用这个示例来覆盖所有 MLOps 工作流和原则。我会写更多关于 MLOps 及其各种技术的教程,并提供示例,请继续关注。
感谢阅读这篇文章。你可以在我的GitHub 个人资料中找到示例项目。如果你有任何问题或建议,请随时留言。
图像来源
文章中所有未在标题中提到来源的图像和图表均为作者提供。
使用分布式随机森林研究美国性别工资差距
原文:
towardsdatascience.com/studying-the-gender-wage-gap-in-the-us-using-distributional-random-forests-ec4c2a69abf0?source=collection_archive---------6-----------------------#2023-02-18
分布式随机森林(DRF)的真实数据分析示例
Jeffrey Näf
·
关注 发表在 Towards Data Science · 13 分钟阅读 · 2023 年 2 月 18 日
--
图片由 Ehimetalor Akhere Unuabona 拍摄,发布于 Unsplash
在之前的两篇文章中,我解释了分布式随机森林(DRFs),这是一种能够估计条件分布的随机森林,以及一种方法的扩展,它允许进行不确定性量化,如置信区间等。这里我展示了一个实际应用的例子,数据来自 2018 年美国社区调查,由美国人口普查局提供。 在第一篇DRF 论文中,我们获得了来自 2018 年美国社区调查的大约 100 万名全职员工的数据,从中提取了薪资信息和所有可能与薪资相关的协变量。这些数据非常适合用来实验 DRF 这种方法(实际上我们将在本分析中只使用一个微小的子集)。
当研究原始时薪数据时,两个性别之间存在一致的差距,即男性往往赚得更多。一个有趣的问题是,男性(G=1)和女性(G=0)之间观察到的时薪差距(W)是否仅仅由于性别,还是可以通过一些其他混杂变量X来解释,这些变量受性别影响并反过来影响工资。也就是说,我们想研究与以下因果图中的粗体箭头对应的效应大小:
假设因果图,G=性别,W=工资,X是混杂变量
例如,假设X仅包括职业,并且女性倾向于选择不涉及高 monetary 奖励的职业,如医生、护士或教师,而男性则倾向于从事专业赌博工作,时薪极高。如果仅凭这一点来解释性别之间的时薪差异,我们仍然会看到直接观察到的时薪差距。然而,如果我们将职业固定为X的医生,并比较这两种工资分布,那么任何统计上显著的差异只能来自性别本身。
我们关注于两阶段分析:
-
我们将X固定为一个特定值,并比较在X=x固定的协变量下两个组的工资分布。这从两个方面来看都很有趣:首先,如果X确实包含所有影响工资且与性别相关的其他因素,那么固定X=x并查看两性工资,就意味着我们真正观察到了性别对工资的影响。其次,它允许对具有给定特征x的个体进行整个工资分布的预测。
-
我们使用上面假设的因果图和因果规则,通过 DRF 估计一个反事实分布:女性工资的分布,假如她们被当作男性来设定工资。如果X包含所有相关协变量,并且不存在性别工资差距,这个分布应该与男性的工资分布相同(忽略统计随机性)。
这篇文章是几个人工作的最终成果:代码和数据集来自原始的DRF 仓库,然后与我们新论文中在arXiv上开发的方法结合,这篇论文由Corinne Emenegger共同撰写。
在继续之前,我想指出这仅仅是一个用来说明 DRF 使用的例子。我并不打算在这里做出任何严肃的(因果)声明,因为分析肯定在某些方面存在缺陷,我们下面假设的因果图肯定是错误的。此外,我们只使用了可用数据中的一小部分。
此外,请注意代码运行速度较慢。这是因为,虽然 DRF 本身是用 C 语言编写的,但用于置信区间的重复拟合目前是用 R 实现的。
话虽如此,让我们深入了解。接下来,除非另有说明,所有图片均由作者提供。
数据
来自 2018 年 1 年期美国社区调查的 PUMS(公共使用微数据区域)数据来自美国人口普查局 API。该调查每年发送给约 350 万人,旨在提供比每十年进行一次的官方普查更为最新的数据。2018 年的数据集包含大约 300 万条匿名数据点,涵盖了 51 个州和哥伦比亚特区。对于上面链接的 DRF 论文,我们仅提取了可能与工资相关的变量子集,如个人的性别、年龄、种族、婚姻状况、教育水平和英语水平。
预处理的数据可以在这里找到。我们首先进行一些进一步的清理:
##Further data cleaning ##
which = rep(TRUE, nrow(wage))
which = which & (wage$age >= 17)
which = which & (wage$weeks_worked > 48)
which = which & (wage$hours_worked > 16)
which = which & (wage$employment_status == 'employed')
which = which & (wage$employer != 'self-employed')
which[is.na(which)] = FALSE
data = wage[which, ]
sum(is.na(data))
colSums(is.na(data))
rownames(data) = 1:nrow(data)
#data = na.omit(data)
data$log_wage = log(data$salary / (data$weeks_worked * data$hours_worked))
## Prepare data and fit drf
## Define X and Y
X = data[,c(
'age',
'race',
'hispanic_origin',
'citizenship',
'nativity',
'marital',
'family_size',
'children',
'education_level',
'english_level',
'economic_region'
)]
X$occupation = unlist(lapply(as.character(data$occupation), function(s){return(substr(s, 1, 2))}))
X$occupation = as.factor(X$occupation)
X$industry = unlist(lapply(as.character(data$industry), function(s){return(substr(s, 1, 2))}))
X$industry[X$industry %in% c('32', '33', '3M')] = '31'
X$industry[X$industry %in% c('42')] = '41'
X$industry[X$industry %in% c('45', '4M')] = '44'
X$industry[X$industry %in% c('49')] = '48'
X$industry[X$industry %in% c('92')] = '91'
X$industry = as.factor(X$industry)
X=dummy_cols(X, remove_selected_columns = TRUE)
X = as.matrix(X)
Y = data[,c('sex', 'log_wage')]
Y$sex = (Y$sex == 'male')
Y = as.matrix(Y)
实际上,这些观察值远远超过我们需要的,我们在此分析中随机抽样了 4'000 个训练数据点。
train_idx = sample(1:nrow(data), 4000, replace = FALSE)
## Focus on training data
Ytrain=Y[train_idx,]
Xtrain=X[train_idx,]
再次说明,这是因为它只是一个示例——实际上,你会希望获取尽可能多的数据点。这 4'000 个数据点的两性工资估计密度绘制在图 1 中,使用了以下代码:
## Plot the test data without adjustment
plotdfunadj = data[train_idx, ]
plotdfunadj$weight=1
plotdfunadj$plotweight[plotdfunadj$sex=='female'] = plotdfunadj$weight[plotdfunadj$sex=='female']/sum(plotdfunadj$weight[plotdfunadj$sex=='female'])
plotdfunadj$plotweight[plotdfunadj$sex=='male'] = plotdfunadj$weight[plotdfunadj$sex=='male']/sum(plotdfunadj$weight[plotdfunadj$sex=='male'])
#pooled data
ggplot(plotdfunadj, aes(log_wage)) +
geom_density(adjust=2.5, alpha = 0.3, show.legend = TRUE, aes(fill=sex, weight=plotweight)) +
theme_light()+
scale_fill_discrete(name = "gender", labels = c('female', "male"))+
theme(legend.position = c(0.83, 0.66),
legend.text=element_text(size=18),
legend.title=element_text(size=20),
legend.background = element_rect(fill=alpha('white', 0.5)),
axis.text.x = element_text(size=14),
axis.text.y = element_text(size=14),
axis.title.x = element_text(size=19),
axis.title.y = element_text(size=19))+
labs(x='log(hourly_wage)')
每小时工资的(无条件)原始对数的估计密度
计算两者工资的百分比中位差,即
(男性中位工资 - 女性中位工资)/(女性中位工资)*100,
我们获得了约 18%的结果。也就是说,在未经调整的数据中,男性的中位薪资比女性高 18%(!)
## Median Difference before adjustment!
quantile_maleunadj = wtd.quantile(x=plotdfunadj$log_wage, weights=plotdfunadj$plotweight*(plotdfunadj$sex=='male'), normwt=TRUE, probs=0.5)
quantile_femaleunadj = wtd.quantile(x=plotdfunadj$log_wage, weights=plotdfunadj$plotweight*(plotdfunadj$sex=='female'), normwt=TRUE, probs=0.5)
(1-exp(quantile_femaleunadj)/exp(quantile_maleunadj))
分析
问题现在变成了这是否真的“不公平”。也就是说,我们假设上述因果图,其中性别(G)影响工资(W),以及协变量X,这些协变量反过来影响W。我们想知道的是性别是否直接影响工资(粗体箭头)。也就是说,如果一个女性和一个具有完全相同特征的男性X=x获得相同的工资,还是因为她的性别她获得了更少的工资。
我们将在两种情况下进行研究。第一种情况是将X=x保持不变,并使用在早期文章中解释的机制。直观地说,如果我们固定性别之外可能影响工资的所有其他协变量,然后比较这两种工资分布,那么任何观察到的差异必须仅由工资造成。
第二种方法尝试对所有可能的X值量化这种差异。通过计算反事实分布来实现这一点。
W(男性,X(女性))。
这个量是一个男性如果具有女性的特征时得到的反事实工资。也就是说,我们询问一个女性在像男性一样对待时的工资。
请注意,这假设了上述因果图是正确的。特别是,它假设X捕捉到除了性别之外的所有相关因素,这些因素会决定工资。可能情况并非如此,因此在本文开头的免责声明。
研究条件分布差异
接下来,我们将x固定到一个任意点:
i<-47
# Important: Test point needs to be a matrix
test_point<-X[i,, drop=F]
以下图片展示了一些包含在该测试点x中的值——我们正在查看具有高中学历、已婚且有 1 个孩子的保育员。使用 DRF,我们可以估计并绘制条件于X=x的两个组的密度:
# Load all relevant functions (the CIdrf.R file can be found at the end of this
# article
source('CIdrf.R')
# predict with the new framework
DRF = predictdrf(drf_fit, x=x)
weights <- DRF$weights
## Conditional Density Plotting
plotdfx = data[train_idx, ]
propensity = sum(weights[plotdfx$sex=='female'])
plotdfx$plotweight = 0
plotdfx$plotweight[plotdfx$sex=='female'] = weights[plotdfx$sex=='female']/propensity
plotdfx$plotweight[plotdfx$sex=='male'] = weights[plotdfx$sex=='male']/(1-propensity)
gg = ggplot(plotdfx, aes(log_wage)) +
geom_density(adjust=5, alpha = 0.3, show.legend=TRUE, aes(fill=sex, weight=plotweight)) +
labs(x='log(hourly wage)')+
theme_light()+
scale_fill_discrete(name = "gender", labels = c(sprintf("F: %g%%", round(100*propensity, 1)), sprintf("M: %g%%", round(100*(1-propensity), 1))))+
theme(legend.position = c(0.9, 0.65),
legend.text=element_text(size=18),
legend.title=element_text(size=20),
legend.background = element_rect(fill=alpha('white', 0)),
axis.text.x = element_text(size=14),
axis.text.y = element_text(size=14),
axis.title.x = element_text(size=19),
axis.title.y = element_text(size=19))+
annotate("text", x=-1, y=Inf, hjust=0, vjust=1, size=5, label = point_description(data[i,]))
plot(gg)
给定X=x的两个性别的对数(每小时工资)密度的估计。此图的代码可以在文章末尾找到。
在这个图中,即使在固定的x情况下,也明显存在工资差异(记住,在这种情况下所有假定的混杂因素都被固定,因此我们实际上只是直接比较工资)。使用 DRF,我们现在估计并测试中位差异。
## Getting the respective weights
weightsmale<-weights*(Ytrain[, "sex"]==1)/sum(weights*(Ytrain[, "sex"]==1))
weightsfemale<-weights*(Ytrain[, "sex"]==0)/sum(weights*(Ytrain[, "sex"]==0))
## Choosing alpha:
alpha<-0.05
# Step 1: Doing Median comparison for fixed x
quantile_male = wtd.quantile(x=data$log_wage[train_idx], weights=matrix(weightsmale), normwt=TRUE, probs=0.5)
quantile_female = wtd.quantile(x=data$log_wage[train_idx], weights=matrix(weightsfemale), normwt=TRUE, probs=0.5)
(medianx<-unname(1-exp(quantile_female)/exp(quantile_male)))
mediandist <- sapply(DRF$weightsb, function(wb) {
wbmale<-wb*(Ytrain[, "sex"]==1)/sum(wb*(Ytrain[, "sex"]==1))
wbfemale<-wb*(Ytrain[, "sex"]==0)/sum(wb*(Ytrain[, "sex"]==0))
quantile_maleb = wtd.quantile(x=data$log_wage[train_idx], weights=matrix(wbmale), normwt=TRUE, probs=0.5)
quantile_femaleb = wtd.quantile(x=data$log_wage[train_idx], weights=matrix(wbfemale), normwt=TRUE, probs=0.5)
return( unname(1-exp(quantile_femaleb)/exp(quantile_maleb)) )
})
varx<-var(mediandist)
## Use Gaussian CI:
(upper<-medianx + qnorm(1-alpha/2)*sqrt(varx))
(lower<-medianx - qnorm(1-alpha/2)*sqrt(varx))
这给出了中位差异的置信区间。
(0.06, 0.40) 或 (6%, 40%)
这个区间非常明显地不包含零,因此中位差异确实是显著的。
使用 Witobj 函数,我们可以更清楚地显示这种差异。
Witobj<-Witdrf(drf_fit, x=test_point, groupingvar="sex", alpha=0.05)
hatmun<-function(y,Witobj){
c<-Witobj$c
k_Y<-Witobj$k_Y
Y<-Witobj$Y
weightsall1<-Witobj$weightsall1
weightsall0<-Witobj$weightsall0
Ky=t(kernelMatrix(k_Y, Y , y = y))
out<-list()
out$val <- tcrossprod(Ky, weightsall1 ) - tcrossprod(Ky, weightsall0 )
out$upper<- out$val+sqrt(c)
out$lower<- out$val-sqrt(c)
return( out )
}
all<-hatmun(sort(Witobj$Y),Witobj)
plot(sort(Witobj$Y),all$val , type="l", col="darkblue", lwd=2, ylim=c(min(all$lower), max(all$upper)),
xlab="log(wage)", ylab="witness function")
lines(sort(Witobj$Y),all$upper , type="l", col="darkgreen", lwd=2 )
lines(sort(Witobj$Y),all$lower , type="l", col="darkgreen", lwd=2 )
abline(h=0)
这导致了图示:
工资的条件见证函数的估计,男性减去女性的工资。
我们参考相关文章以获得对该概念的更详细解释。本质上,它可以被视为
在给定 x 的情况下,男性的工资对数的条件密度 — 在给定 x 的情况下,女性的工资对数的条件密度
即,条件见证函数显示了一个组的密度高于另一个组的位置,而无需实际估计密度。在这个例子中,负值表示女性工资的密度在给定x的情况下高于男性工资的密度,正值表示女性工资的密度较低。由于我们已经估计了上述的条件密度,条件见证函数本身并不会增加太多信息。但它对于说明情况很有用。确实,我们看到它在开始时是负的,对于条件密度的女性工资高于条件密度的男性工资的值。相反,它在更大的值下变为正值,对于这些值,男性工资的条件密度高于女性工资。因此,关于两个密度的相关信息在见证函数图中总结:我们看到女性工资的密度在较低工资值时较高,而在较高工资值时较低,表明密度向左偏移,女性赚得更少!此外,我们还可以提供包含真实函数 95%的 95%置信区间(绿色),在所有 y 值上均匀分布。 (尽管实际上需要大量的数据才能使其有效)由于这个均匀置信区间在 2 到 2.5 左右的零线之间不包含,我们再次看到这两个分布之间的差异在统计上是显著的。
对特定的x进行条件化,使我们能够详细研究个体效应,并具有不确定性的概念。然而,研究整体效应也很有趣。我们将在下一节中通过估计反事实分布来实现这一点。
估计反事实分布
使用我们假设的因果图的因果性计算法则,可以推导出:
即,我们寻找的反事实分布是通过对性别为女性的x,平均条件分布W | G=male,X=x获得的。
由于分布被给定为简单的权重,这可以通过以下方式轻松完成:DRF
## Add code
## Male is 1, Female is 0
# obtain all X from the female test population
Xtestf<-Xtest[Ytest[,"sex"]==0,]
# Obtain the conditional distribution of W | G=male, X=x, for x in the female
# population.
# These weights correspond to P(W, G=male | X=x )
weightsf<-predictdrf(drf_fit, x=Xtestf)$weights*(Ytrain[, "sex"]==1)
weightsf<-weightsf/rowSums(weightsf)
# The counterfactual distribution is the average over those weights/distributions
counterfactualw<-colMeans(weightsf)
这导致了以下的反事实密度估计:
plotdfc<-rbind(plotdfc, plotdfunadj[plotdfunadj$sex=='female',])
plotdfc$sex2<-c(rep(1, length(train_idx)), rep(0,nrow(plotdfunadj[plotdfunadj$sex=='female',])))
plotdfc$sex2<-factor(plotdfc$sex2)
#interventional distribution
ggplot(plotdfc, aes(log_wage)) +
geom_density(adjust=2.5, alpha = 0.3, show.legend=TRUE, aes(fill=sex2, weight=plotweight)) +
theme_light()+
scale_fill_discrete(name = "", labels = c("observed women's wages", "wages if treated as men"))+
theme(legend.position = c(0.2, 0.98),
legend.text=element_text(size=16),
legend.title=element_text(size=20),
legend.background = element_rect(fill=alpha('white', 0)),
axis.text.x = element_text(size=14),
axis.text.y = element_text(size=14),
axis.title.x = element_text(size=19),
axis.title.y = element_text(size=19))+
labs(x='log(hourly wage)')
这两个密度现在分别是红色的女性工资密度,以及如果女性被当作男性来设置工资的话的绿色-蓝绿色密度。显然,现在这些密度比之前更接近——调整了混杂因素使得性别薪酬差异变小。然而,中位数差异仍然
quantile_male = wtd.quantile(x=plotdfc$log_wage[plotdfc$sex2==1], weights=counterfactualw, normwt=TRUE, probs=0.5)
quantile_female = wtd.quantile(x=plotdfunadj$log_wage, weights=plotdfunadj$plotweight*(plotdfunadj$sex=='female'), normwt=TRUE, probs=0.5)
(1-exp(quantile_female)/exp(quantile_male))
0.11 或 11 百分比!
因此,如果我们的分析是正确的,那么 11%的薪资差异仍然可以归因于性别。换句话说,虽然我们将未调整数据中 18%的中位收入差异减少到 11%,但仍然存在实质性的差异,表明性别之间存在“非公平”的工资差距(至少如果X确实捕捉到了相关的混杂因素)。
结论
在这篇文章中,我们研究了如何将 DRF 应用于实际数据分析的一个例子。我们探讨了固定的x的情况,对于这种情况,本文讨论的方法允许构建不确定性度量,以及反事实量的分布。在这两种情况下,我们都看到在调整可用的混杂变量时,仍然存在实质性差异,在固定的x情况下尤其显著。
虽然我没有检查,但看到这个小实验的结果与更严肃的分析相比可能会很有趣。无论如何,我希望这篇文章展示了 DRF 如何在实际数据分析中使用。
额外代码
## Functions in CIdrf.R that is loaded above ##
drfCI <- function(X, Y, B, sampling = "binomial",...) {
### Function that uses DRF with subsampling to obtain confidence regions as
### as described in https://arxiv.org/pdf/2302.05761.pdf
### X: Matrix of predictors
### Y: Matrix of variables of interest
### B: Number of half-samples/mini-forests
n <- dim(X)[1]
# compute point estimator and DRF per halfsample S
# weightsb: B times n matrix of weights
DRFlist <- lapply(seq_len(B), function(b) {
# half-sample index
indexb <- if (sampling == "binomial") {
seq_len(n)[as.logical(rbinom(n, size = 1, prob = 0.5))]
} else {
sample(seq_len(n), floor(n / 2), replace = FALSE)
}
## Using refitting DRF on S
DRFb <-
drf(X = X[indexb, , drop = F], Y = Y[indexb, , drop = F],
ci.group.size = 1, ...)
return(list(DRF = DRFb, indices = indexb))
})
return(list(DRFlist = DRFlist, X = X, Y = Y) )
}
predictdrf<- function(DRF, x, ...) {
### Function to predict from DRF with Confidence Bands
### DRF: DRF object
### x: Testpoint
ntest <- nrow(x)
n <- nrow(DRF$Y)
## extract the weights w^S(x)
weightsb <- lapply(DRF$DRFlist, function(l) {
weightsbfinal <- Matrix(0, nrow = ntest, ncol = n , sparse = TRUE)
weightsbfinal[, l$indices] <- predict(l$DRF, x)$weights
return(weightsbfinal)
})
## obtain the overall weights w
weights<- Reduce("+", weightsb) / length(weightsb)
return(list(weights = weights, weightsb = weightsb ))
}
Witdrf<- function(DRF, x, groupingvar, alpha=0.05, ...){
### Function to calculate the conditional witness function with
### confidence bands from DRF
### DRF: DRF object
### x: Testpoint
if (is.null(dim(x)) ){
stop("x needs to have dim(x) > 0")
}
ntest <- nrow(x)
n <- nrow(DRF$Y)
coln<-colnames(DRF$Y)
## Collect w^S
weightsb <- lapply(DRF$DRFlist, function(l) {
weightsbfinal <- Matrix(0, nrow = ntest, ncol = n , sparse = TRUE)
weightsbfinal[, l$indices] <- predict(l$DRF, x)$weights
return(weightsbfinal)
})
## Obtain w
weightsall <- Reduce("+", weightsb) / length(weightsb)
#weightsall0<-weightsall[, DRF$Y[, groupingvar]==0, drop=F]
#weightsall1<-weightsall[,DRF$Y[, groupingvar]==1, drop=F]
# Get the weights of the respective classes (need to standardize by propensity!)
weightsall0<-weightsall*(DRF$Y[, groupingvar]==0)/sum(weightsall*(DRF$Y[, groupingvar]==0))
weightsall1<-weightsall*(DRF$Y[, groupingvar]==1)/sum(weightsall*(DRF$Y[, groupingvar]==1))
bandwidth_Y <- drf:::medianHeuristic(DRF$Y)
k_Y <- rbfdot(sigma = bandwidth_Y)
K<-kernelMatrix(k_Y, DRF$Y[,coln[coln!=groupingvar]], y = DRF$Y[,coln[coln!=groupingvar]])
nulldist <- sapply(weightsb, function(wb){
# iterate over class 1
wb0<-wb*(DRF$Y[, groupingvar]==0)/sum(wb*(DRF$Y[, groupingvar]==0))
wb1<-wb*(DRF$Y[, groupingvar]==1)/sum(wb*(DRF$Y[, groupingvar]==1))
diag( ( wb0-weightsall0 - (wb1-weightsall1) )%*%K%*%t( wb0-weightsall0 - (wb1-weightsall1) ) )
})
# Choose the right quantile
c<-quantile(nulldist, 1-alpha)
return(list(c=c, k_Y=k_Y, Y=DRF$Y[,coln[coln!=groupingvar]], nulldist=nulldist, weightsall0=weightsall0, weightsall1=weightsall1))
}
### Code to generate plots
## Step 0: Choosing x
point_description = function(test_point){
out = ''
out = paste(out, 'job: ', test_point$occupation_description[1], sep='')
out = paste(out, '\nindustry: ', test_point$industry_description[1], sep='')
out = paste(out, '\neducation: ', test_point$education[1], sep='')
out = paste(out, '\nemployer: ', test_point$employer[1], sep='')
out = paste(out, '\nregion: ', test_point$economic_region[1], sep='')
out = paste(out, '\nmarital: ', test_point$marital[1], sep='')
out = paste(out, '\nfamily_size: ', test_point$family_size[1], sep='')
out = paste(out, '\nchildren: ', test_point$children[1], sep='')
out = paste(out, '\nnativity: ', test_point$nativity[1], sep='')
out = paste(out, '\nhispanic: ', test_point$hispanic_origin[1], sep='')
out = paste(out, '\nrace: ', test_point$race[1], sep='')
out = paste(out, '\nage: ', test_point$age[1], sep='')
return(out)
}
数据科学成功秘诀:你在大学里没有学到的 4 项关键技能
原文:
towardsdatascience.com/succeeding-in-data-science-4-essential-skills-you-didnt-learn-in-university-1920815acef3
如何弥合学术界与数据科学领域就业之间的差距
Tomer Gabay
·发表于 Towards Data Science ·6 分钟阅读·2023 年 5 月 11 日
--
由 krakenimages 提供的照片,来源于 Unsplash
自从从荷兰的一所大学毕业后,我在不同组织和公司担任了多个数据科学家的职位。让我和一些其他毕业生感到惊讶的是,大学所教的技能与成为一名有价值的数据科学员工所需的技能之间的差距。在这篇文章中,我想强调大学所教的技能和我发现对数据科学员工表现良好所需的四大技能差距。
我在每一节中都包含了有用的资源,以帮助你提升在该领域的技能,若你希望提高你的熟练程度。
构建扎实的数据科学项目
在大学里,通常会使用 Jupyter Notebook 进行“笔记本式”的工作。我在大学时设置的最稳固项目结构是从main.py文件调用不同的模块,并使用requirements.txt。但是使用setup.py?或者pyproject.toml?没有。
公司对数据科学项目的要求往往大相径庭。你的数据科学项目通常应该被构建成一个可安装的包,例如通过pip install <package_name>
。有时需要一个requirements.txt,但至少应该在pyproject.toml中指定包的依赖(现在比setup.py更受欢迎,参见PEP-518)。
除了可以安装外,你的数据科学项目通常还应能够在云中作为基于 Docker 的 API 运行。通过使用 REST 请求调用你的 API,可以通过管道处理数据,甚至用机器学习模型进行预测。
如果你发现自己也主要是在以 'Notebook 风格' 进行数据科学项目,或者你不熟悉如何构建一个可安装的 Python 项目,我建议阅读例如:
## 如何将你的 Python 项目转换为通过 pip 安装的包
一个带有现成模板的教程,描述了如何将 Python 项目转换为可以在…中使用的包。
towardsdatascience.com [## 如何启动任何专业的 Python 包项目。
包括测试自动化、在 ReadTheDocs 上创建文档以及发布到 PyPi。
medium.com](https://medium.com/mlearning-ai/how-to-start-any-professional-python-package-project-9f66538ebc2?source=post_page-----1920815acef3--------------------------------)
如果你能够构建一个可安装的包,但不确定如何将数据科学项目构建并运行为基于 Docker 的 API,请阅读:
## 在 AWS 上将机器学习模型部署为 API
逐步指南,构建模型 API、使用 Docker 容器化并通过 AWS Elastic 部署到网络。
towardsdatascience.com
在基于云的环境中工作
现在,几乎每家公司和组织都使用至少一些基于云的软件。目前,最受欢迎的有 Microsoft Azure、GCP 和 AWS。在我的大学里,我几乎没有使用这些云平台上提供的资源。然而,在工作中,数据科学家对云平台的知识几乎与他们的编程技能一样重要。以下是我作为数据科学家迄今为止需要完成的(云)工程任务的一个例子:
-
在 Docker 容器中部署 Python 应用程序
-
构建 CI/CD 管道
-
在云中托管 Python 包作为工件
-
构建数据管道
如果你对上述某个平台有一定的云知识,这在申请数据科学职位时已经是一个巨大的优势。每个平台都有自己的学习路线和证书来证明你在云计算方面的知识。每个平台都有不同角色的不同路线。以下是每个平台针对数据科学家的相关路线:
-
Azure DP-100
-
GCP 机器学习工程师
-
AWS 数据分析
获得这样的证书无疑会让你作为(潜在的)数据科学员工更具价值。
与其他数据科学家合作
大学里的数据科学和现实生活中的数据科学之间最重要的区别之一是,在现实生活中,你必须持续不断地协作,而在大学里,通常只有少数几个团队项目贯穿整个本科或硕士阶段。
到目前为止,我工作过的每家公司和组织都使用了 scrum 或 DevOps。我没有遇到过一次面试中对这两种方法的知识和/或经验没有加分的情况。除了使用 scrum 或 DevOps,代码审查也是数据科学家工作中的关键任务。这意味着你既需要能够编写清晰易读的代码以供他人审查,也需要能够快速且批判性地评估他人的代码。你可以在下面的链接中找到一篇非常有趣的关于如何为代码审查提供建设性反馈的文章:
## 如何像人类一样进行代码审查(第一部分)
最近,我在阅读关于代码审查最佳实践的文章。我注意到这些文章专注于发现...
mtlynch.io
此外,你不能再用不明确的函数或变量名称,或不遵守 Python 的 PEP-8 命名规范 的对象来应付了。如果你想了解更多关于如何编写高质量代码的内容,可以参考下面的文章:
## 区分高级开发者和初级开发者的 6 个 Python 最佳实践
如何编写被视为来自经验丰富开发者的 Python 代码。
towardsdatascience.com
处理现实生活中的数据
在大学里,学生遇到的大多数数据集已经经过大量预处理。这些数据集的列通常具有有意义的名称,错误的条目已被删除,数据类型也已经正确配置。你在公司或组织中遇到的数据通常远不如大学里遇到的数据标准,除非你在一家真正的数据驱动(技术)公司工作。
我在工作中遇到的一些杂乱的真实数据示例:
-
一些犯罪的数据,其中某些犯罪的实施日期在未来。
-
没有任何命名规范的列名 [
name
,Address
,JobDescription
,place_of_birth
] -
重复的列具有不同的值 [
job_title
JobTitle
] -
不同时区的 DateTime 值而不涉及时区问题。
由于真实数据比为学生准备的数据要复杂得多,作为数据科学家,你要准备好花费大部分时间与业务方沟通,询问列和数值的解释、数据清理以及合并来自不同来源的数据,而不是进行数据分析或构建机器学习模型。
欲了解更多数据清理的信息,例如:
## 数据清理的终极指南
当数据产生垃圾时
[towardsdatascience.com
如果你想在一些真实(杂乱)数据集上练习,这里有十个数据集供你练习数据清理技能!
总结
大学教授的数据科学技能可能不足以在数据科学员工职位上出色表现。作为在多个组织工作过的数据科学家,我注意到四个主要技能缺口需要解决:
-
建立稳固的数据科学项目。
-
在基于云的环境中工作。
-
与其他数据科学家合作。
-
处理真实数据。
利用本文中提到的资源,你可以帮助缓解在这些高度重视的领域中可能存在的知识或技能缺乏!
当然,每所大学和每个专业都有自己的数据科学课程,一些大学可能在为你准备成为数据科学员工方面做得更好或更差。此外,我毕业已经有几年了,因此也许大学的数据科学课程在云领域等方面已经有所改进,例如增加了(更多)与云相关的课程。
资源
稳固的 Python 项目 packaging.python.org/en/latest/tutorials/packaging-projects/
medium.com/r/url=https%3A%2F%2Ftowardsdatascience.com%2Fhow-to-convert-your-python-project-into-a-package-installable-through-pip-a2b36e8ace10
medium.com/mlearning-ai/how-to-start-any-professional-python-package-project-9f66538ebc2
towardsdatascience.com/deploy-a-machine-learning-model-as-an-api-on-aws-43e92d08d05b
基于云的环境 learn.microsoft.com/en-us/certifications/azure-data-scientist/
cloud.google.com/learn/certification/machine-learning-engineer
aws.amazon.com/certification/certified-data-analytics-specialty/
与其他数据科学家合作 www.scrum.org/resources/what-scrum-module
en.wikipedia.org/wiki/DevOp
mtlynch.io/human-code-reviews-1/
peps.python.org/pep-0008/
medium.com/towards-data-science/6-python-best-practices-that-distinguish-seniors-from-juniors-84199d4cac3c
与现实数据打交道 towardsdatascience.com/the-ultimate-guide-to-data-cleaning-3969843991d4
analyticsindiamag.com/10-datasets-for-data-cleaning-practice-for-beginners/
通过技术图示实现 ML 项目的成功
原文:
towardsdatascience.com/success-in-ml-projects-through-technical-drawings-69dd8d2744a4
通过技术图示改进机器学习(ML)项目的工作流程和期望管理
本杰明·图勒
·发表于数据科学前沿 ·阅读时间 8 分钟·2023 年 5 月 30 日
--
机器学习(ML)项目在商业中变得越来越受欢迎,因为组织力图通过利用人工智能获得竞争优势或增加市场价值。然而,与传统的软件开发或分析项目不同,机器学习项目的调度可能更加具有挑战性,因为项目的成功往往在完成之前很难预知,且工作流程较为松散。也可以这样说:
“你不知道 ML 项目是否成功,直到模型开发完成并投入生产”
换句话说,你希望在开始一个大型 ML 项目之前,通过有结构的工作流程和合理的期望来确保你的团队能够成功(如果你想了解更多关于 ML 的一般知识,你可以找到最好的教程来自Cassie Kozyrkov)。成功的一个关键因素是有效的沟通,这能够促成合适的工作流程和项目管理。
为什么选择技术图示?
用语言进行沟通是困难的,尤其是当人们讲的语言有所不同(业务方面和技术方面)。此外,描述复杂关系的语言需要大量的文本或长时间的会议。然而,图示的好处在于容易理解,并且可以非常直观(当然前提是制作得当)。有时候,一幅图示(或图片)可以替代 1000 个词。
我和我的团队一开始制作了一个技术图,主要是为了帮助我们自己保持对机器学习项目的总体概览,并确保我们有一个标准化的工作流程。随着时间的推移,我意识到使用这样的图不仅对未来项目有帮助,还可以帮助公司内部的其他人和领导了解一个机器学习项目的内容。到目前为止的整体反馈非常好,所以我想展示我们所做的工作,以及我们如何利用这些图进行沟通、设定截止日期和期望。
技术图:从简单到复杂
复杂的技术图可能令人不知所措,因此最好从简单开始,之后再添加更多层。这是一个机器学习项目的潜在初始图:
“内循环”的简化图,展示了如何使用特征和目标来迭代训练模型。
如图所示,这个技术图非常高层次,展示了如何通过输入数据生成训练模型所需的两个必要成分:特征和目标。用红色突出显示的是所谓的“内循环”,它描述了使用数据来训练和改进模型的简化迭代过程。
在我们添加更多细节之前,先完成这个过程,并添加模型训练后的情况:
“外循环”部分的简化图,展示了如何将训练好的模型投入生产并产生输出。
这个图现在添加了生产管道和部分所谓的“外循环”,其中特征按照计划输入模型,最终预测被输出到数据集或直接给用户。
这是我们拥有的高层次但完整的机器学习项目周期技术图。现在是时候添加一些缺失且非常重要的细节了:
监控对于每一个投入生产的模型都是至关重要的,必须在图中体现。
监控对于每个产品都是必不可少的,尤其是对于机器学习。特征的轻微变化可能会对预测产生巨大影响,反之亦然。如果你想防止客户替你做监控(即等到他们抱怨时才采取行动),建议在模型训练前后密切监控数据。如果你的模型会实时再训练和更新,则需要额外的监控(在图中未显示)。
额外的细节现在非常依赖于图的受众。这个例子的受众一方面是经验丰富的数据科学家组成的机器学习团队,另一方面是需要了解过程的部分技术领导团队。鉴于主要的技术背景,我会首先添加有关特征存储的更多细节:
良好的特征提取是模型性能的关键,而一个可扩展且可维护的特征提取过程将确保更快的周转。
由于我们处理的是表格数据(随时间变化的地理空间数据),我们不会“仅仅”让数据科学家和数据工程师独自面对抽象的“特征存储”概念。我们将详细定义前期认为对改进模型最有价值的内容,并建立一个能够轻松处理所需数据类型的数据管理系统。例如,我们将不具备时间线的静态上下文与具有时间线但不定期更新的历史动态上下文分开。此外,我们将具有实时数据流的持续动态上下文分开。最终,所有这些不同的数据集汇聚成一个统一的数据集,我们称之为特征存储,以便所有人可以在一个地方拉取和测试新特征,并要求添加更多特征。
对于数据科学家来说,这是一种很好的情况。特征已经可用并以标准方式格式化,因此可以立即开始有趣的工作(训练模型)。但在数据科学家可以开始之前,我们需要调整图中的内部循环:
内部循环的详细完善,包括添加预处理、将数据集分为 4 个用于建模的部分,以及评估层级。
这张图展示了对内部循环的详细完善,添加了 3 个重要方面到训练周期中:
-
预处理: 为了使特征存储具有一定的可扩展性,任何对现有特征的处理(例如,为自回归创建滞后版本)将在预处理时通过视图/指针函数即时完成,而不是作为另一种冗余特征添加到存储中。
-
分成 4 个数据集: 模型的训练是一个迭代过程。然而,为了改进这些迭代,我们需要获取关于如何改进模型的见解,这样会自动引入可能导致过拟合的偏差。因此,我们创建一个用于训练的数据集,一个用于训练验证的数据集,一个用于调试和获得迭代改进见解的数据集,以及一个作为项目最终测试的数据集,以判断模型是否符合成功标准。
-
评估层级: 考虑到这 4 个数据集,我们还引入了基于验证、调试或测试集的 3 种不同评估。在开始处理内部循环之前,重要的是对这些评估中的“成功”进行对齐并记录下来。只有当测试评估按计划进行时,模型才算成功并准备投入生产。
这样,我们几乎完成了一个中等详细度的机器学习项目技术图纸,只需添加一个小的部分:
最终的政策层确保过滤掉或替换意外的模型行为。
使用机器学习的一个缺点是模型总会给出结果,无论结果是否合理。因此,无论你的训练数据有多好,由于你无法确保新进入的数据遵循训练数据的规范,你应该预期模型行为有时会出现不愉快的情况。政策层的理念是过滤或替换这些行为。例如,一个预测某地访问人数的模型可能会在某些情况下预测负数。政策层检测到这种情况并将其替换为硬编码的数字或另一个位于其上的模型的输出。
使用颜色或形状进行规划并设置时间表
我们现在已经制定了涵盖整个机器学习项目的技术图纸。然而,即使我们一步一步地进行并逐层建立技术图纸,查看图纸仍然可能令人不知所措。想象一下你需要更多细节。增加的复杂性会使图纸越来越不直观,因此变得无用。
这时直观的颜色编码或形状可以派上用场。例如,我们在图纸中使用颜色编码来跟踪进度。例如,当新的功能被添加到功能存储并且代码在生产中运行时,功能存储任务将标记为绿色。这会通知整个团队(和公司)当前项目的进展情况。
绿色形状突出显示“完成”的任务。橙色标记“进行中”的任务。所有其他任务则处于“即将开始”状态。
此外,与团队讨论图纸将使设置截止日期变得更容易。可以将图纸拆分成子部分,对于每个子部分,工作所需时间相对可预测。例如,设置功能存储或将现有模型投入生产是可预测的工作任务。然而,内部循环则更难预测,但这里也可以设定估算。例如,通过将内部循环的时间限制为 X 周。到时,测试评估标准必须得到满足,否则,项目将被降级优先级(但可能会在后续被重新考虑)。
总结
从利用机器学习的产品创意到生产就绪状态的过程可能面临独特的挑战。由于其复杂性,机器学习项目难以估算,而且在项目管理薄弱的情况下,只有有限数量的项目能够达到生产状态。技术图纸可以通过提供视觉清晰度和有效沟通来克服这些限制。以直观的方式使项目的整体复杂性可见将有助于团队和整个组织了解项目内容并设定合理的估算。通过这种方式,技术图纸可以促进整个组织更有效的沟通和工作流程。
除非另有说明,所有图片均为作者提供。
总结最佳实践以进行提示工程
原文:
towardsdatascience.com/summarising-best-practices-for-prompt-engineering-c5e86c483af4?source=collection_archive---------0-----------------------#2023-05-29
如何使用 OpenAI API 构建自己的基于 LLM 的应用
Dmytro Nikolaiev (Dimid)
·
关注 发表在Towards Data Science · 13 分钟阅读 · 2023 年 5 月 29 日
--
图片由Glenn Carstens-Peters拍摄,来源于Unsplash。
提示工程指的是为大型语言模型(LLMs),如 OpenAI 的 ChatGPT,创建称为提示的指令的过程。利用 LLMs 解决各种任务的巨大潜力,借助提示工程可以节省大量时间,并促进令人印象深刻的应用程序的开发。它是释放这些大型模型全部能力的关键,改变我们与这些模型的互动方式及其带来的好处。
在这篇文章中,我尝试总结提示工程的最佳实践,以帮助你更快地构建基于 LLM 的应用程序。虽然这个领域发展迅速,但以下这些“经过时间考验”的:) 技术往往效果很好,并能让你取得出色的成果。特别是,我们将涵盖:
-
迭代提示开发的概念,使用分隔符和结构化输出;
-
思维链推理;
-
少样本学习。
结合直观的解释,我将分享实际的示例和未来调查的资源。
然后我们将探索如何使用OpenAI API免费构建一个简单的基于 LLM 的本地应用程序。我们将使用 Python 来描述逻辑,使用Streamlit 库来构建网页界面。
让我们开始吧!
提示工程的最佳实践
在这篇文章中,我将通过网页界面和 API 与 ChatGPT 互动。我将使用的gpt-3.5-turbo
模型是 ChatGPT 背后的模型,因此你可以直接在浏览器中实验你的提示。
这里一个重要的点是,ChatGPT 不仅仅是大型语言模型(LLM);如你所知,它还是一个 SFT(监督微调)模型,经过了来自人类反馈的强化学习(RLHF)的进一步微调。虽然许多开发者目前利用 OpenAI 的模型进行实验项目和个人探索,但由于隐私和其他原因,其他模型可能更适合在大型企业的生产环境中部署。
如果你想知道为什么基础模型(如GPT-3,Chinchilla,LLaMA)的功能与微调和 RLHF 训练的助手(如ChatGPT,Koala,Alpaca)不同,可以参考Andrej Karpathy 关于训练和使用类似 GPT 模型的讲座。我强烈推荐查看这个讲座以获得更深入的理解,时间只有 40 分钟。总结可以参考这个 Twitter 线程。
现在让我们深入探讨针对指令调优 LLMs 的最佳实践!
迭代提示开发
就像任何机器学习模型都是通过迭代过程构建的,有效的提示也是通过类似的迭代方法构建的。即使是最有才华的开发者也可能在第一次尝试时没有创建完美的提示,因此要准备好接受现实,可能需要多次尝试才能实现预期目标。
基于数据的应用程序构建始终是一个迭代过程。公共领域
通过示例理解事物总是更好。让我们开始构建一个从职位描述中提取信息的系统。我们将在示例中使用的职位描述是来自 LinkedIn 的机器学习工程师招聘广告。
示例职位描述。来自LinkedIn 职位页面的截图
初始提示可以简单地请求模型提取特定信息。此外,我会使用分隔符(你可以在稍后提到的ChatGPT 提示工程师课程中了解更多)。虽然本地应用程序不太可能受到提示注入攻击的影响,但这仍然是一个好习惯。
提示 v1 的输出。由作者使用ChatGPT创建的图像
嗯,这样并没有太大帮助。我们需要更具体地说明我们希望从模型中获得什么。让我们要求它提取职位名称、公司名称、所需关键技能以及总结的职位描述。
记住,这只是一个示例,你可以设计你的提示来提取你想要的任何信息:学位、所需经验年限、地点等等。
提示 v2 的输出。由作者使用ChatGPT创建的图像
看起来更好!为了使输出更简洁明了,让我们要求模型将技能输出为列表,并提供更简短的职位描述总结。
提示 v3 的输出。由作者使用ChatGPT创建的图像
这是我们初次尝试的显著改进,不是吗?我们仅仅用两次迭代就达到了这一点!所以,当事情最初进展不顺利时,不要失去希望。继续指导模型,进行实验,你一定会取得成功。
请求结构化输出
我想讨论的第二点是要求模型以某种预期的结构化格式输出结果。虽然这对通过网页界面与 LLM 交互(例如我们与 ChatGPT 的方式)可能不是那么关键,但对于基于 LLM 的应用程序来说是极其有用的,因为解析结果的过程要容易得多。
一种常见的做法是使用 JSON 或 XML 等格式,并定义特定的键来组织输出数据。让我们修改提示以展示模型预期的 JSON 结构。
对于提示 v4 的输出,要求 JSON 输出。由 ChatGPT 创作的作者提供的图像
这样的输出更容易在应用程序的后续逻辑中解析和处理。
值得提及的是这个方向的发展。一些工具旨在严格将模型的输出适配到给定的格式,这对某些任务非常有用。只需看看下面的示例!
其中一个可能的应用是生成特定格式的大量内容(例如,使用 guidance 的游戏角色信息)。
生成 JSON 格式的游戏角色信息。来自 guidance GitHub 仓库 的 gif
像 LMQL 这样的语言为提示语言模型引入了类似编程的方法。随着这些工具的不断发展和改进,它们有可能彻底改变我们与 LLM 的互动方式,从而提供更准确和结构化的响应。
LMQL 查询示例。请参阅 LMQL 网页,更多示例请见此处 的截图
Chain-of-Thought 推理
Chain-of-Thought (CoT) 推理被发现对于需要……好吧,推理的任务非常有帮助。因此,如果你有机会通过将任务分解成多个更简单的步骤来解决问题,这对于 LLM 来说可能是一个很好的方法。
看看原始论文中的示例。通过将问题拆分成更小的步骤并提供明确的指示,我们可以帮助模型生成正确的输出。
介绍 CoT 提示。来自 Chain-of-Thought Prompting Elicits Reasoning in LLMs 论文 的图 1
有趣的是,后来发现,在提示的末尾附加一个简单的且神奇的 “让我们一步一步思考” 可以改善结果——这种技巧被称为 零-shot CoT。因此,构建允许模型“思考过程”的提示是很有帮助的,因为模型没有其他表达思想的能力,除了生成标记。
目前最佳的零-shot CoT 提示是“让我们一步一步地解决这个问题,以确保我们得到正确的答案”。
最佳零样本提示。来自LLMs 是人类水平的提示工程师论文的表 7
更复杂的解决方案正在积极开发中。虽然它们在某些场景中显著超越其他方法,但其实际应用仍然有限。我将提到两种这样的技术:自一致性和思维树。
自一致性论文的作者提出了以下方法。他们建议不要仅仅依赖于初始模型输出,而是通过多次采样并通过多数投票来聚合结果。通过依赖直觉和集成方法在经典机器学习中的成功,这种技术增强了模型的鲁棒性。
自一致性。来自自一致性改进语言模型中的链式推理论文的图 1
你也可以在不实施聚合步骤的情况下应用自一致性。对于短输出的任务,要求模型建议几个选项并选择最佳选项。
思维树(ToT)将这一概念进一步扩展。它提出了为模型的“推理思维”应用树搜索算法的想法,当模型遇到不良假设时,实际上是回溯。
思维树。来自思维树:使用 LLMs 进行深思熟虑问题解决论文的图 1
如果你感兴趣,可以查看Yannic Kilcher 的 ToT 论文评论视频。
对于我们的具体场景,利用链式思维推理并非必要,但我们可以将模型的总结任务分为两个阶段。最初,它可以概括整个职位描述,然后再对得出的总结进行集中于职位职责的总结。
带有逐步指令的提示 v5 的输出。由ChatGPT创建的作者图像
在这个特定的例子中,结果没有显示出显著的变化,但这种方法对大多数任务非常有效。
少样本学习
我们将介绍的最后一种技术叫做少样本学习,也称为上下文学习。它的简单之处在于将几个示例纳入提示中,以便为模型提供更清晰的任务描述。
这些示例不仅要与任务相关,还要多样化,以涵盖数据的多样性。对于少量样本学习,使用 CoT 可能会有些挑战,特别是当你的流程有很多步骤或输入较长时。然而,通常来说,结果是值得这些努力的。另外,请记住,标记少量示例的成本远低于在传统机器学习模型开发中标记整个训练/测试集的成本。
如果我们在提示中添加一个示例,它将更好地理解要求。例如,如果我们表明我们希望最终总结以要点形式呈现,模型将会按照我们的模板进行回应。
这个提示可能会让人感到有些压倒,但不要害怕:它只是一个以前的提示(v5)和一个标记的示例,采用 For example: 'input description' -> 'output JSON'
格式。
包含一个示例的提示 v6 输出。图像由作者使用ChatGPT创建
总结最佳实践
总结提示工程的最佳实践,请考虑以下几点:
-
不要害怕实验。尝试不同的方法,逐步迭代,纠正模型并一次进行小的步骤;
-
在输入中使用分隔符(例如 <>),并要求结构化输出(例如 JSON);
-
提供完成任务的行动列表。尽可能地,向模型提供一组行动,并让它输出“内部想法”;
-
如果输出较短,请求多个建议;
-
提供示例。如果可能,向模型展示几个代表数据的多样化示例,并显示期望的输出。
我认为这个框架为自动化各种日常任务提供了足够的基础,例如信息提取、总结、文本生成(如电子邮件)等。然而,在生产环境中,仍然可以通过在特定数据集上微调模型来进一步优化性能。此外,插件和代理的快速发展也值得关注,但那是完全不同的话题。
DeepLearning.AI 和 OpenAI 的提示工程课程
除了之前提到的Andrej Karpathy 的讲座,这篇博客文章还从DeepLearning.AI 和 OpenAI 的 ChatGPT 提示工程课程中汲取了灵感。该课程完全免费,完成仅需几小时,并且我个人非常喜欢的一点是,它允许你在无需注册的情况下实验 OpenAI API!
这是一个很好的实验场地,所以一定要去看看。
使用 OpenAI API 和 Streamlit 构建基于 LLM 的应用程序
哇,我们涵盖了很多信息!现在,让我们继续前进,开始使用我们获得的知识构建应用程序吧。
生成 OpenAI 密钥
要开始使用,你需要注册一个 OpenAI 账户并创建你的 API 密钥。OpenAI 目前为每个人提供3 个月的 $5 免费信用额。请参考OpenAI API 入门介绍页面来注册你的账户和生成你的 API 密钥。
一旦你有了密钥,创建一个OPENAI_API_KEY
环境变量以便通过os.getenv('OPENAI_API_KEY')
在代码中访问它。
使用标记器实验室估算成本
在这个阶段,你可能会好奇仅凭免费试用可以做多少事情,以及初始三个月后有哪些选项。这是个很好的问题,特别是当你考虑到LLMs 需要花费数百万美元!
当然,这些数百万是关于训练的。事实证明,推理请求非常实惠。虽然 GPT-4 可能被认为比较贵(尽管价格可能会下降),但gpt-3.5-turbo
(默认 ChatGPT 背后的模型)仍然足以应对大多数任务。实际上,考虑到这些模型的原始参数量以亿计,OpenAI 做了一个令人难以置信的工程工作,因为这些模型现在既便宜又快速。
gpt-3.5-turbo
模型的费用是每 1,000 个标记 $0.002。
那么,具体多少钱呢?让我们看看。首先,我们需要了解什么是标记。简单来说,标记指的是单词的一部分。在英语中,你可以预期每 10 个单词大约有 14 个标记。
为了更准确地估算你特定任务和提示的标记数量,最好的方法是亲自尝试!幸运的是,OpenAI 提供了一个标记器实验室来帮助你。
附注:不同语言的标记化
由于英语在互联网的广泛使用,这种语言的标记化效果最好。正如在“所有语言的标记化并不相等”博客文章中强调的那样,标记化在不同语言中并不统一,某些语言可能需要更多的标记来表示。如果你想构建一个涉及多语言提示的应用程序,比如翻译,请记住这一点。
为了说明这一点,我们来看看不同语言中 全句 的分词情况。在这个玩具示例中,英语需要 9 个标记,法语 — 12,保加利亚语 — 59,日语 — 72,俄语 — 73。
不同语言的分词。截图来自 OpenAI 分词器游乐场
成本与性能
正如你可能注意到的那样,提示可能会变得相当长,尤其是在包含示例时。通过增加提示的长度,我们可能会提高质量,但同时随着使用更多标记,成本也会增加。
我们最新的提示(v6)大约由 1.5k 个标记组成。
提示的分词 v6。截图来自 OpenAI 分词器游乐场
考虑到输出长度通常与输入长度相同,我们可以估计每次请求的平均标记数约为 3k 个(输入标记 + 输出标记)。通过将这个数字乘以初始费用,我们发现每次请求约为 0.006 美元或 0.6 美分,这相当实惠。
即使我们考虑到每次请求稍高的费用为 1 美分(相当于大约 5k 个标记),你仍然可以以仅需 1 美元进行 100 次请求。此外,OpenAI 提供了设置软限制和硬限制的灵活性。软限制会在你接近定义的限制时发出通知,而硬限制则会限制你超出指定阈值。
对于本地使用的 LLM 应用程序,你可以舒适地配置每月 1 美元的硬限制,确保在预算范围内享受模型的好处。
Streamlit 应用模板
现在,让我们构建一个网页界面,以编程方式与模型互动,消除每次手动复制提示的需要。我们将使用 Streamlit 来完成这一任务。
Streamlit 是一个 Python 库,它允许你创建简单的网页界面,无需使用 HTML、CSS 和 JavaScript。它对初学者友好,并允许使用最少的 Python 知识创建基于浏览器的应用程序。现在我们来创建一个简单的模板,用于基于 LLM 的应用程序。
首先,我们需要处理与 OpenAI API 通信的逻辑。在下面的示例中,我假设 generate_prompt()
函数已定义并返回给定输入文本的提示(例如,类似于你之前看到的)。
就这样!了解更多关于不同参数的信息,请参考 OpenAI 的文档,但默认设置已经很好。
有了这些代码,我们可以设计一个简单的网页应用程序。我们需要一个输入文本的字段,一个处理文本的按钮,以及几个输出小部件。我更倾向于访问完整的模型提示和输出,以便进行调试和探索原因。
整个应用程序的代码大致如下,可以在这个 GitHub 仓库中找到。由于共享 OpenAI 密钥不是一个好主意,我添加了一个名为toy_ask_chatgpt()
的占位符函数。目前,这个应用程序只是将提示复制到输出中。
如果不定义函数和占位符,这只有大约 50 行代码!
幸好Streamlit 最近的更新现在允许嵌入它到这篇文章中!所以你应该能够在下方看到它。
现在你可以看到这有多么简单。如果你愿意,你可以使用 Streamlit Cloud 部署你的应用程序。但要小心,因为如果你在其中放置你的 API 密钥,每个请求都会花费你金钱!
结论
在这篇博客文章中,我列出了几种提示工程的最佳实践。我们讨论了迭代提示开发、使用分隔符、请求结构化输出、思维链推理以及少量示例学习。我还提供了一个模板,用于在不到 100 行代码的情况下使用 Streamlit 构建一个简单的网页应用程序。现在,轮到你来提出一个令人兴奋的项目创意并将其变为现实了!
现代工具允许我们在仅仅几小时内创建复杂的应用程序,这真是令人惊叹。即使没有丰富的编程知识、Python 熟练度或对机器学习的深刻理解,你也可以快速构建一些有用的东西并自动化一些任务。
如果你是初学者并且想创建类似的项目,请随时向我提问。我将非常乐意协助你,并尽快回复。祝你的项目好运!
资源
这里是我关于 LLM 的其他文章,希望对你有所帮助。我已经涵盖了:
-
估算大语言模型的规模:LLM 是什么,如何训练,它们需要多少数据和计算资源;
-
使用 ChatGPT 进行调试:如何使用 LLM 进行调试和代码生成。
你可能还会对以下内容感兴趣:
-
免费的Learn Prompting 课程,以深入了解提示和相关技术;
-
最近发布的DeepLearning.AI 短期课程,用于构建 OpenAI API 应用程序
感谢阅读!
-
希望这些材料对你有所帮助。在 Medium 上关注我,以获取更多类似的文章。
-
如果你有任何问题或评论,我会很高兴收到任何反馈。可以在评论中问我,或通过LinkedIn或Twitter联系我。
-
支持我作为作者并访问其他成千上万篇 Medium 文章,请通过我的推荐链接获取 Medium 会员(对你没有额外费用)。
使用 NLP 和 AI 更好地总结播客文字记录和长文本
原文:
towardsdatascience.com/summarize-podcast-transcripts-and-long-texts-better-with-nlp-and-ai-e04c89d3b2cb?source=collection_archive---------0-----------------------#2023-05-03
为什么现有的总结方法存在缺陷,以及如何改进的详细步骤
Isaac Tham
·
关注 发表在 Towards Data Science · 11 分钟阅读 · 2023 年 5 月 3 日
--
图片来源于 Unsplash。
像 GPT-4 这样的 LLM 已经席卷了世界,而生成文本模型特别擅长于长文本的摘要,如书籍或播客转录。然而,使用 LLM 来总结长文本的传统方法实际上存在根本性的缺陷。在这篇文章中,我将告诉你现有摘要方法的问题,并提出一种更好的摘要方法,实际上考虑了文本的结构!更棒的是,这种方法还会给我们文本的主要话题——一举两得!
我将引导你如何通过对现有方法进行几个小调整,在 Python 中轻松实现这一点。这是我们在Podsmart中使用的方法,我们新推出的 AI 驱动的播客摘要应用,帮助忙碌的知识分子节省听取的时间。
现有解决方案的问题
总结长文本的经典方法是递归摘要,其中长文本被平分成较短的块,这些块可以适应 LLM 的上下文窗口。每个块都被总结,摘要被连接在一起,然后传递给 GPT-3 进行进一步总结。这个过程会重复进行,直到获得所需长度的最终摘要。
然而,主要的缺点是现有实现,例如 LangChain 的summarize chain using map_reduce,将文本分割成块,而不考虑文本的逻辑和结构流。
例如,如果文章长 1000 字,200 字的块大小意味着我们将获得 5 个块。如果作者有几个主要点,其中第一个占据了前 250 字?最后 50 字会被放入第二个块中,与作者的下一个观点的文本一起处理,传递这个块给 GPT-3 的摘要器可能会导致遗漏第一个观点的潜在重要信息。此外,一些关键点可能比其他点更长,并且事先无法知道这一点。
另一种方法是“精炼”方法,它将每一段文本以及来自前几段的摘要传递给 LLM,这样随着看到更多的文本,摘要会逐步被精炼(请参见这里的提示)。然而,该过程的顺序特性意味着它无法并行处理,且所需时间是线性的,远长于递归方法的对数时间。此外,直觉上初始部分的意义在最终摘要中会被过度代表。对于播客转录,其中前几分钟是与播客其余部分完全无关的广告,这成为一个绊脚石。因此,这种方法并不被广泛使用。
即使出现了更先进的语言模型,具有更长的上下文窗口,它仍然会在许多总结用例(整个书籍)中显得极其不足,因此不可避免地需要一些分块和递归总结。
本质上,如果总结过程未能识别文本的意义层次结构,并且与之不兼容,那么生成的总结很可能不足以准确传达作者的意图。
更好的前进方式
更好的解决方案是将总结和主题建模过程在同一个算法中一起处理。在这里,我们将递归总结的一个步骤的总结输出拆分为块,然后将这些块输入到下一步。我们可以通过将块在语义上进行主题聚类,并将主题传递到下一次总结迭代中来实现这一点。让我们来看看如何在 Python 中实现这一点吧!
要求
Python 包:
-
scipy — 用于余弦距离度量
-
networkx — 用于 Louvain 社区检测算法
-
langchain — 一个实用功能包,允许你调用像 OpenAI 的 GPT-3 这样的 LLMs。
数据和预处理
包含 Jupyter notebook 和数据的 GitHub 仓库可以在这里找到:github.com/thamsuppp/llm_summary_medium
我们今天总结的文本是 2023 年美国总统乔·拜登的国情咨文演讲。文本文件在 GitHub 仓库中,这里 是原始来源。演讲,像所有美国政府出版物一样,属于公有领域。请注意,确保你被允许使用源文本很重要——《Towards Data Science》发布了一些有用的提示,关于如何检查数据集的版权和许可证。
我们将原始文本拆分为句子,限制句子的最小长度为 20 个单词,最大长度为 80 个单词。
创建块
与其创建足够大以适应上下文窗口的块,我建议块的大小应为表达一个离散思想通常需要的句子数量。这是因为我们随后将嵌入这个文本块,本质上将其语义意义提炼成一个向量。我目前使用 5 个句子(但你可以尝试其他数量)。我倾向于在块之间有 1 句重叠,以确保连续性,使每个块都能包含一些关于前一个块的上下文信息。对于给定的文本文件,共有 65 个块,平均块长 148 个单词,范围从 46 到 197 个单词。
获取每个块的标题和摘要
现在,这是我开始偏离 LangChain 的总结链的地方。
用 1 次 LLM 调用获得 2 个:标题 + 摘要
我想要既获得一个信息丰富的标题,又要对每个块进行总结(标题的重要性会在后面变得更加明确)。因此,我创建了一个自定义提示,改编了 Langchain 的 默认总结链提示。正如你在 map_prompt_template
中看到的 - text
是一个将被插入到提示中的参数 - 这将是每个块的原始文本。我创建了一个 LLM,目前是 GPT-3,并创建了一个 LLMChain,将 LLM 与提示模板结合起来。然后,map_llm_chain.apply()
调用 GPT-3,并将插入文本的提示模板传入,返回每个块的标题和摘要,我将这些解析为字典输出。注意,所有块可以并行处理,因为它们彼此独立,从而带来了巨大的速度优势。
你可以使用 ChatGPT 以更便宜的价格和类似的性能,但我尝试过的时候,只有 GPT-3 LLM 能并行运行查询,而使用 ChatGPT 则是逐个运行,这非常缓慢,因为我通常会同时传入 ~100 个块。并行运行 ChatGPT 需要异步实现。
def summarize_stage_1(chunks_text):
# Prompt to get title and summary for each chunk
map_prompt_template = """Firstly, give the following text an informative title. Then, on a new line, write a 75-100 word summary of the following text:
{text}
Return your answer in the following format:
Title | Summary...
e.g.
Why Artificial Intelligence is Good | AI can make humans more productive by automating many repetitive processes.
TITLE AND CONCISE SUMMARY:"""
map_prompt = PromptTemplate(template=map_prompt_template, input_variables=["text"])
# Define the LLMs
map_llm = OpenAI(temperature=0, model_name = 'text-davinci-003')
map_llm_chain = LLMChain(llm = map_llm, prompt = map_prompt)
map_llm_chain_input = [{'text': t} for t in chunks_text]
# Run the input through the LLM chain (works in parallel)
map_llm_chain_results = map_llm_chain.apply(map_llm_chain_input)
stage_1_outputs = parse_title_summary_results([e['text'] for e in map_llm_chain_results])
return {
'stage_1_outputs': stage_1_outputs
}
嵌入块并按主题进行聚类
在获得每个块的摘要后,我将使用 OpenAI 的嵌入将它们嵌入到 1536 维向量中。传统的递归总结方法不需要嵌入,因为它们按均匀长度任意拆分文本。对我们来说,我们的目标是通过将语义上相似的块分组到一起,来改进这一方法。
将文本按主题分组是 NLP 中一个研究较多的问题,许多传统方法如 Latent Dirichlet Allocation 早于深度学习时代。我记得在 2017 年使用 LDA 来为我大学报纸的文章进行聚类——它估计的速度非常慢,并且只使用词频,这不能捕捉语义含义。
现在,我们可以利用 OpenAI 的嵌入即服务 API 来获取在一秒钟内捕捉句子语义含义的嵌入。这里还有许多其他可能的嵌入模型,例如 HuggingFace 的 sentence-transformers
,据报道它比 OpenAI 的嵌入表现更好,但这涉及到下载模型并在你自己的服务器上运行。
在获得块的嵌入向量后,我们将相似的向量聚集在一起。
我创建了一个块相似性矩阵,其中 (i,j)th
条目表示第 i 个和第 j 个块的嵌入向量之间的余弦相似性,即块之间的语义相似性。
国情咨文演讲的片段相似度矩阵。你可以看到某些片段组在语义上彼此相似——这正是话题检测算法随后会揭示的内容。图片由作者提供。
我们可以将其视为节点之间的相似度图,其中节点是片段,边的权重是两个片段之间的相似度。我们使用 Louvain 社区检测算法 从片段中检测话题。这是因为在图分析中,社区被定义为具有密集的内部社区连接和稀疏的社区间连接,这正是我们想要的:话题内的片段彼此非常语义相似,而每个片段与其他检测到的话题中的片段的语义相似度较低。
Louvain 社区检测算法有一个超参数叫做分辨率——较小的分辨率会导致较小的簇。此外,我增加了一个超参数 proximity_bonus
—— 如果原始文本中片段的位置彼此接近,它会提高片段的相似度分数。你可以将此解释为将文本的时间结构视为先验(即,彼此接近的片段更可能在语义上相似)。我加入这个参数是为了避免检测到的话题中包含来自文本各处的片段,这种情况不太可能。该函数还试图最小化簇大小的方差,防止出现一个簇有 1 个片段而另一个簇有 13 个片段的情况。
对于国情咨文演讲,输出是 10 个簇,这些簇之间的连贯性很好。
检测到的国情咨文演讲的话题簇。图片由作者提供。
检测到的 Bloomberg Surveillance 播客转录的话题簇。紫色和橙色的话题识别了广告。图片由作者提供。
第二张图片是另一个播客节目中的话题聚类。正如你所见,开始和结束部分被检测为相同的话题,这在开头和结尾有广告的播客中很常见。一些话题,比如紫色的,也是不连续的——我们的算法允许这种情况很不错,因为文本可能回到之前提到的话题,而这也是传统文本拆分未考虑的另一种可能性。
话题标题和总结
现在,我们得到的是语义一致的话题,可以进入递归总结的下一步。对于这个例子,这将是最后一步,但对于更长的文本如书籍,你可以想象重复这个过程几次,直到剩下大约 10 个话题,其话题总结可以适应上下文窗口。
下一步涉及三个不同的部分。
主题标题:对于每个主题,我们生成了该主题中块的标题列表。我们将所有主题的标题列表传递给 GPT-3,并要求其聚合这些标题以得出每个主题的一个标题。我们对所有主题同时进行此操作,以防主题标题之间过于相似。以前,当我单独生成主题标题时,GPT-3 没有其他主题标题的上下文,因此出现了 7 个标题中有 4 个是‘联邦储备的货币政策’的情况。这就是我们希望生成块标题的原因——试图将所有块摘要放入上下文窗口在非常长的文本中可能不可行。
如下所示,标题看起来很好!描述性强,但彼此独特。
1\. Celebrating American Progress and Resilience
2\. US Economy Strengthening and Inflation Reduction
3\. Inflation Reduction Act: Lowering Health Care Costs
4\. Confronting an Existential Threat: Making Big Corporations Pay
5\. Junk Fee Prevention Act: Stopping Unfair Charges
6\. COVID-19 Resilience and Vigilance
7\. Fighting Fraud and Public Safety
8\. United States' Support for Ukraine and Global Peace
9\. Progress Made in Healthcare and Gun Safety
10\. United States of America: A Bright Future
主题摘要: 没有新意,这涉及将每个主题的块摘要结合在一起,并要求 GPT-3 将其总结为主题摘要。
最终摘要: 为了得到文本的总体摘要,我们再次将主题摘要连接在一起,并提示 GPT-3 进行总结。
国情咨文的最终摘要。图像由作者提供。
总结
我们的方法有哪些好处?
文本被分层拆分为主题、块和句子。随着层级的深入,我们得到越来越详细和具体的摘要,从最终摘要,到每个主题的摘要,再到每个块的摘要。
正如我上面提到的,摘要准确地捕捉了文本的语义结构——其中有一个总体主题,分成几个主要主题,每个主题包含若干关键思想(块),确保在各层摘要中保留了关键信息。
这比仅仅提供总体摘要更具灵活性。不同的人对文本的不同部分更感兴趣,因此他们会选择适合自己需求的详细程度。
当然,这需要将生成的摘要与直观且连贯的界面配对,该界面可视化文本的层级特性。此类可视化的一个示例见 Podsmart—— 点击这里 查看演讲的互动摘要。
提取的主题及其时间线的可视化。图像由作者提供。
请注意,这并不会显著增加 LLM 的费用——我们仍然将与传统方法相同的输入传递给 LLM,但我们获得了更丰富的摘要。
TLDR —— 这里是生成优质文本摘要的秘密法宝
-
语义一致的主题——通过对文本的小块进行语义嵌入并按语义相似性拆分文本
-
从数据块中获取标题和摘要——这需要自定义提示,而不是使用默认的 LangChain 摘要链
-
校准 Louvain 社区检测算法——如分辨率和接近度奖励等超参数确保生成的话题簇是合理的
-
不同的话题标题——同时生成所有话题标题,这需要块标题
再次提醒,你可以在 GitHub 仓库 查看整个源代码
如果你觉得这篇文章对你有帮助:
-
请查看我在 Medium 上的其他文章:作为数据科学家构建 AI 应用的技术提示、使用深度学习生成音乐
-
试试我的 app——Podsmart 可以转录和总结播客和 YouTube 视频,为忙碌的知识分子节省听音时间
-
关注我在 LinkedIn 或 Twitter/X,通过消息或评论与我联系!我很乐意讨论关于数据科学和 AI 的各种想法
感谢阅读!
用 ChatGPT 总结最新的 Spotify 发布内容
原文:
towardsdatascience.com/summarizing-the-latest-spotify-releases-with-chatgpt-553245a6df88
探索音乐发现的力量:使用 ChatGPT 或 GPT-4 和 Spotify API 总结新发布的音乐
Luís Roque
·发表于Towards Data Science ·10 分钟阅读·2023 年 3 月 16 日
--
在当今快节奏的世界中,自然语言处理(NLP)已成为各种应用中至关重要的组成部分。像 OpenAI 的 ChatGPT 和 GPT-4 这样的巨大模型,解锁了在摘要生成、语音转文字、语音识别、语义搜索、问答系统、聊天机器人等任务中令人难以置信的潜力。
我很高兴地宣布“大型语言模型编年史:驾驭 NLP 前沿”这一新的每周文章系列,将探讨如何利用大型模型的力量进行各种 NLP 任务。通过深入研究这些前沿技术,我们旨在赋能开发者、研究人员和爱好者,充分利用 NLP 的潜力,解锁新的可能性。
在本系列的第一篇文章中,我们将重点介绍如何使用 OpenAI 的 ChatGPT 和 Spotify API 创建一个智能摘要系统,以获取最新的音乐发布。随着系列的发展,我们将深入探讨多种 NLP 应用,提供洞见、技术和实际示例,展示大型模型在改变我们与语言互动和理解方式方面的能力。
敬请关注更多文章,随着我们踏上这段激动人心的 NLP 之旅,指导你掌握各种语言任务的最新大型模型。
图 1:LLM 是否开始了人与机器之间的新合作? (source)
代码可在我的Github上找到。
介绍
ChatGPT 和 GPT-4,由 OpenAI 开发,是先进的语言模型,在各种自然语言处理任务中表现出色。它们能够理解上下文,生成类似人类的响应,甚至有效总结大段文本。这使它们成为总结 Spotify 上最新音乐发布的理想工具。
作为领先的音乐流媒体平台,Spotify 提供了一个广泛的 API,使开发者能够访问大量音乐数据,包括新发布、播放列表等。通过将 ChatGPT 强大的语言理解能力与 Spotify API 提供的丰富音乐数据结合起来,我们可以构建一个系统,让您及时了解 Spotify 目录中的最新内容。
我们将引导您完成构建这个智能音乐总结系统的过程。我们的方法将包括以下步骤:
-
访问 Spotify API:我们将通过 Spotify API 获取最新音乐发布的数据。
-
使用 ChatGPT 总结:然后,我们将使用 OpenAI 的 API 生成最新发布内容的简明总结。
-
结果:最后,我们将以一种易于阅读和引人入胜的格式展示总结。
敬请关注,我们将深入探讨每一步的细节,助您创建自己的音乐总结工具!
访问 Spotify API
在本节中,我们将探讨如何从 Spotify API 获取最新的音乐发布及其相关曲目数据。然后我们将把这些数据保存到 JSON 文件中以供进一步处理。将使用以下 Python 函数实现此目标:
-
get_new_releases
:从 Spotify 获取新专辑发布。 -
get_album_tracks
:检索特定专辑的曲目信息。 -
save_data_to_file
:将获取的数据保存到 JSON 文件中。 -
load_data_from_file
:从 JSON 文件中加载保存的数据。 -
download_latest_albums_data
:从 Spotify 下载最新的专辑和曲目数据,并保存到 JSON 文件中。
让我们解析这些功能的关键组件,理解它们如何协同工作以访问 Spotify API。
获取新发布
get_new_releases
函数接受两个可选参数,limit
和 offset
。limit
确定要返回的专辑结果的最大数量,而 offset
指定第一个结果的索引。默认情况下,limit
设置为 50,offset
设置为 0。然后,函数调用 Spotify API 的 sp.new_releases
,它返回一个包含专辑信息的字典。相关的专辑项被提取并返回为字典列表。
def get_new_releases(limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
"""
Fetch new releases from Spotify.
Args:
limit (int, optional): Maximum number of album results to return. Defaults to 50.
offset (int, optional): The index of the first result to return. Defaults to 0.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing album information.
"""
new_releases = sp.new_releases(limit=limit, offset=offset)
albums = new_releases["albums"]["items"]
return albums
检索专辑曲目
get_album_tracks
函数接受一个参数 album_id
,这是我们想要获取曲目信息的专辑的 Spotify ID。该函数调用 Spotify API 的 sp.album_tracks
,返回一个包含曲目数据的字典。然后,曲目项被提取并作为字典列表返回。
def get_album_tracks(album_id: str) -> List[Dict[str, Any]]:
"""
Fetch tracks from a specific album.
Args:
album_id (str): The Spotify ID of the album.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing track information.
"""
tracks = sp.album_tracks(album_id)["items"]
return tracks
保存和加载数据
save_data_to_file
函数接受两个参数:data
,这是一个包含专辑和曲目信息的字典列表;file_path
,这是保存数据的 JSON 文件路径。该函数使用json.dump
方法将数据写入指定的文件。
相反,load_data_from_file
函数从指定的 JSON 文件中读取数据,并使用json.load
方法将其作为字典列表返回。
def save_data_to_file(data: List[Dict[str, Any]], file_path: str) -> None:
"""
Save data to a JSON file.
Args:
data (List[Dict[str, Any]]): List of dictionaries containing album and track information.
file_path (str): Path to the JSON file where the data will be saved.
"""
with open(file_path, "w", encoding="utf-8") as file:
json.dump(data, file, ensure_ascii=False, indent=4)
def load_data_from_file(file_path: str) -> List[Dict[str, Any]]:
"""
Load data from a JSON file.
Args:
file_path (str): Path to the JSON file where the data is stored.
Returns:
List[Dict[str, Any]]: List of dictionaries containing album and track information.
"""
with open(file_path, "r", encoding="utf-8") as file:
return json.load(file)
下载最新专辑数据
download_latest_albums_data
函数作为从 Spotify 下载最新专辑和曲目数据的主要驱动程序。它初始化了诸如limit
、offset
、total_albums
、album_count
等变量,并创建了一个空列表all_albums
以存储获取的数据。
函数进入一个循环,直到获取到指定数量的专辑(total_albums
)为止。在每次迭代中,函数调用get_new_releases
和get_album_tracks
以检索专辑和曲目信息。这些数据随后存储在all_albums
列表中。
在获取数据后,函数将offset
按limit
值递增,以便在随后的迭代中获取下一组专辑。为了避免触及 Spotify API 的速率限制,添加了 1 秒的延迟。函数最后调用save_data_to_file
将获取的数据存储到 JSON 文件中。
def download_latest_albums_data() -> None:
"""
Download the latest albums and tracks data from Spotify and save it to a JSON file.
"""
limit = 50
offset = 0
total_albums = 30
album_count = 0
all_albums = []
while total_albums is None or album_count < total_albums:
new_releases = get_new_releases(limit, offset)
if total_albums is None:
total_albums = sp.new_releases()["albums"]["total"]
for album in new_releases:
album_info = {
"album_name": album["name"],
"artist_name": album["artists"][0]["name"],
"album_type": album["album_type"],
"release_date": album["release_date"],
"available_markets": album["available_markets"],
"tracks": [],
}
tracks = get_album_tracks(album["id"])
for track in tracks:
track_info = {
"track_name": track["name"],
"duration_ms": track["duration_ms"],
"preview_url": track["preview_url"],
}
album_info["tracks"].append(track_info)
all_albums.append(album_info)
album_count += 1
offset += limit
time.sleep(1) # Add a delay to avoid hitting the rate limit
print(f"Downloaded {album_count}/{total_albums}")
save_data_to_file(all_albums, "albums_and_tracks.json")
通过使用这些函数,我们可以有效地访问 Spotify API,以收集最新音乐发行的数据。在下一节中,我们将探讨如何预处理这些数据并使用 ChatGPT 生成这些新发行的摘要。
使用 LangChain 通过 ChatGPT 生成摘要
在本节中,我们将讨论如何预处理从 Spotify API 获得的专辑和曲目数据,并利用 ChatGPT 通过 LangChain 库生成最新音乐发行的摘要。LangChain 是一个强大的工具,使开发者能够构建将 LLMs 与其他计算或知识源结合的应用程序。
我们将使用以下 Python 函数来实现这一目标:
-
preprocess_docs
: 将 JSON 数据转换为 Document 对象列表。 -
get_summary
: 使用提供的 JSON 数据生成摘要,该数据在 Document 对象的列表中。
数据预处理
preprocess_docs
函数接受一个包含专辑和曲目信息的字典列表,这是我们从 Spotify API 检索到的数据。该函数将这些数据转换为 JSON 字符串,然后将其拆分为 3500 字符的段落。这些段落用于创建 Document 对象列表,并将传递给 ChatGPT 以生成摘要。
将数据拆分为较小的段落是为了处理 ChatGPT API 施加的文本长度限制。通过将文本拆分成较小的部分,我们可以更有效地处理数据,而不会超出模型的最大令牌限制。
def preprocess_docs(data: List[Dict[str, Any]]) -> List[Document]:
"""
Convert the JSON data to a list of Document objects.
Args:
data (List[Dict[str, Any]]): List of dictionaries containing album and track information.
Returns:
List[Document]: A list of Document objects containing the JSON data as strings, split into 3000-character segments.
"""
json_string = json.dumps(data, ensure_ascii=False, indent=4)
doc_splits = [json_string[i : i + 3500] for i in range(0, len(json_string), 3500)]
docs = [Document(page_content=split_text) for split_text in doc_splits]
return docs
使用 LangChain 通过 ChatGPT 生成摘要
LangChain 的 CombineDocuments 链旨在处理和组合多个文档的信息,使其非常适合诸如摘要和问答等任务。在我们的案例中,我们将专注于使用 MapReduce 方法生成最新 Spotify 发布内容的摘要。如果你已经可以访问 API,你可以轻松地使用 GPT-4。为此,你只需更新传递给ChatOpenAI
类的model_name
参数。
MapReduce 方法通过对每个数据块运行初始提示来工作,为每个数据块生成一个输出。例如,在摘要任务中,这涉及为每个单独的数据块创建一个摘要。在下一步中,会运行不同的提示,将所有这些初始输出合并成一个连贯的输出。
使用 MapReduce 方法的主要优势在于它可以扩展到更大的文档并处理比 Stuffing 方法更多的文档。此外,对每个文档的 LLM 调用是独立的,允许并行处理和更快的处理速度。
在我们的项目背景下,我们将应用 MapReduce 方法使用 ChatGPT 总结最新的 Spotify 发布内容。我们使用 MapReduce 方法为每个文档生成摘要,然后将这些摘要合并成一个简洁的总结。
def get_summary(docs: List[Document]) -> str:
"""
Generate a summary using the JSON data provided in the list of Document objects.
Args:
docs (List[Document]): A list of Document objects containing the JSON data as strings.
Returns:
str: The generated summary.
"""
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
prompt_template = """Write a short summary about the latest songs in Spotify based on the JSON data below: \n\n{text}."""
prompt_template2 = """Write an article about the latest music released in Spotify (below) and adress the change in music trends using the style of Rick Beato. : \n\n{text}"""
PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
PROMPT2 = PromptTemplate(template=prompt_template2, input_variables=["text"])
chain = load_summarize_chain(
llm,
chain_type="map_reduce",
return_intermediate_steps=True,
map_prompt=PROMPT,
combine_prompt=PROMPT2,
verbose=True,
)
res = chain({"input_documents": docs}, return_only_outputs=True)
return res
结果
为了更好地理解我们使用 Spotify API 和 OpenAI API 的 ChatGPT 实现的摘要能力,我们将展示一个示例,演示系统如何处理数据并生成简洁的摘要。让我们检查输入数据、中间步骤和最终输出。
输入数据
输入数据包括几个专辑及其相应的曲目,如 Dillaz 的“Oitavo Céu”和 T-Rex 的“CASTANHO”。每张专辑包括专辑名称、艺术家名称、专辑类型、发行日期以及曲目名称和持续时间(以毫秒为单位)的列表。
中间步骤
中间步骤包括使用 MapReduce 方法处理输入数据。例如,以下是为部分输入数据生成的摘要:
Spotify 上的最新歌曲包括三张新专辑中的曲目:Dillaz 的“Oitavo Céu”、T-Rex 的“CASTANHO”和 Branko 的“OBG”。“Oitavo Céu”包含 12 首曲目,其中包括标题曲目和持续时间最长的“Maçã”,其持续时间为 219130 ms。“CASTANHO”有 11 首曲目,其中“LADO NENHUM”的持续时间最长,为 278190 ms。“OBG”有 10 首曲目,其中“ETA”的持续时间最长,为 226058 ms。所有三张专辑均于 2022 年 4 月发布。
最终输出
最终输出结合了中间步骤的摘要,提供了最新 Spotify 发布内容的连贯、简洁概述:
近年来,音乐行业的趋势发生了显著变化,流媒体平台如 Spotify 的崛起以及嘻哈和电子舞曲(EDM)等流派的日益流行导致了这些变化。因此,Spotify 上发布的最新音乐反映了这些变化,呈现了多样化的艺术家和风格。
一个显著的趋势是葡萄牙语音乐的日益突出,比如 Dillaz 的“八重天”和 T-Rex 的“CASTANHO”在平台上占据了重要位置。这些专辑展示了葡萄牙音乐的独特声音和节奏,将传统风格与现代影响相结合。
另一个趋势是不同风格和背景的艺术家之间合作的日益受欢迎。例如,Don Toliver 的“Life of a DON”和 Pop Smoke 的“Faith”专辑中包含了与 Travis Scott、Kanye West、Rick Ross 和 Lil Tjay 等多位艺术家的合作。这些合作让艺术家们探索新的声音和风格,并创造出更广泛受众喜爱的音乐。
此外,Spotify 上的最新音乐反映了音乐风格和影响的日益多样化。像 Olivia Rodrigo 的“SOUR”和 Billie Eilish 的“Happier Than Ever”这样的专辑展示了年轻女性艺术家的独特声音和视角,为当代音乐提供了新鲜的视角。
如所示,我们基于 ChatGPT 的系统有效地总结了最新的 Spotify 发布内容,为音乐爱好者提供了一个易于访问和参与的概述,以便随时了解和发现新内容。
结论
在这篇文章中,我们展示了将 Spotify API 和 OpenAI 的 ChatGPT 结合起来创建总结系统的强大功能,该系统使您能够及时了解最新的音乐发行。我们讨论了文档链技术,选择了因其可扩展性而被广泛使用的 MapReduce 方法,并展示了我们系统在生成连贯且信息丰富的总结方面的有效性。
AI 驱动的语言模型与像 Spotify 这样的热门平台 API 之间的协同作用为创新和个性化开辟了新的机会。随着 AI 技术的不断发展,它们在各种 NLP 任务中的应用将不断扩展,为我们日常生活的提升提供令人兴奋的方式。
总之,我们的探索作为尖端 AI 技术在解决现实世界挑战和创造有价值用户体验方面的潜力的激励性示例。我们希望这篇文章能鼓励您在自己的项目中进一步探索 AI 的应用,并激励您创造出能够带来改变的创新解决方案。
保持联系:LinkedIn
4 个简单步骤让你的机器学习系统超充电
原文:
towardsdatascience.com/super-charge-your-ml-systems-in-4-simple-steps-4485f0208440?source=collection_archive---------6-----------------------#2023-10-27
使用 DALL.E-3 生成的图像
Donal Byrne
·
关注 发表在 Towards Data Science · 8 分钟阅读 · 2023 年 10 月 27 日
--
欢迎来到机器学习优化的过山车之旅!这篇文章将带你了解我的优化任何机器学习系统以实现闪电般快速训练和推理的 4 个简单步骤。
想象一下:你终于被分配到一个酷炫的新机器学习项目中,你正在训练你的智能体来统计照片中的热狗数量,其成功可能为你的公司带来数十美元的收入!
你在你最喜欢的框架中实现了最新的炙手可热的物体检测模型,该模型有很多 GitHub 星标,运行一些玩具示例,经过一个小时左右,它就像一个在大学第 3 年重修的穷学生一样准确地识别热狗,生活美好。
下一步显而易见,我们想将其扩展到更困难的问题,这意味着更多的数据、更大的模型,当然,还有更长的训练时间。现在你需要面对几天的训练时间,而不是几个小时。不过没关系,你已经忽视你的团队 3 周了,可能应该花一天时间处理积压的代码审查和被动攻击的电子邮件。
你在为你在同事的 MR 上留下的有见地且绝对必要的细节而感到满意的一天后回来,结果发现你的性能崩溃了,在经历了 15 小时的训练后(因果报应来得很快)。
接下来的几天变成了试验、测试和实验的旋风,每个潜在的想法都需要超过一天的运行时间。这些迅速开始积累数百美元的计算成本,所有这些都导致了一个大问题:我们如何才能让这一切变得更快、更便宜?
欢迎来到机器学习优化的情感过山车!这里有一个简单的 4 步流程,可以使局势对你有利:
-
基准测试
-
简化
-
优化
-
重复
这是一个迭代过程,很多时候你会在进行下一步之前重复某些步骤,所以这不只是一个 4 步系统,更像是一个工具箱,但 4 步听起来更好。
1 — 基准测试
“测量两次,切割一次”—某位智者。
你应该始终做的第一件(可能也是第二件)事是对系统进行性能分析。这可以是简单地计时特定代码块运行所需的时间,也可以是复杂的全程性能跟踪。重要的是你有足够的信息来识别系统中的瓶颈。我根据我们在过程中所处的阶段进行多次基准测试,并通常将其分为两种类型:高层次和低层次基准测试。
高层次
这就是你会在每周“我们到底有多糟糕?”会议上向老板展示的内容,并希望这些指标成为每次运行的一部分。这些将给你一个关于系统性能的高层次感受。
Batches Per Second——我们每秒处理多少批次?这应该尽可能高。
Steps Per Second——(特指强化学习)我们在环境中生成数据的速度是多少,应该尽可能高。这里有一些复杂的步伐时间与训练批次之间的相互作用,我在这里不详细讨论。
GPU Util——在训练过程中你的 GPU 使用了多少?这应该始终接近 100%,如果不是,那么你有可以优化的空闲时间。
CPU Util——在训练过程中你的 CPU 使用了多少?同样,这应该尽可能接近 100%。
FLOPS——每秒浮点运算次数,这能让你了解你是如何有效利用总硬件的。
低层次
使用上述指标后,你可以进一步深入查看瓶颈可能出现在何处。一旦有了这些信息,你需要开始查看更细粒度的指标和分析。
时间分析 — 这是最简单且通常最有用的实验。像cprofiler这样的分析工具可以帮助你从整体上了解每个组件的时间消耗,或者查看特定组件的时间。
内存分析 — 另一个优化工具箱中的常见工具。大型系统需要大量内存,所以我们必须确保没有浪费内存!像memory-profiler这样的工具将帮助你缩小系统消耗 RAM 的范围。
模型分析 — 像Tensorboard这样的工具提供了优秀的分析工具,用于查看你的模型中哪些部分正在消耗性能。
网络分析 — 网络负载是导致系统瓶颈的常见原因。像wireshark这样的工具可以帮助你进行网络分析,但说实话,我从未使用过。相反,我更倾向于对我的组件进行时间分析,测量组件内部所需的总时间,然后隔离网络 I/O 本身所花费的时间。
确保查看这篇关于 Python 性能分析的优秀文章,RealPython,以获取更多信息!
2 — 简化
一旦你在性能分析中确定了需要优化的区域,就要进行简化。去除除该部分之外的所有内容。继续将系统简化为更小的部分,直到找到瓶颈。不要害怕在简化过程中进行性能分析,这将确保你在迭代过程中走在正确的方向上。继续重复这个过程,直到找到你的瓶颈。
提示
-
用存根和模拟函数替换其他组件,这些存根和模拟函数仅提供预期的数据。
-
使用
sleep
函数或虚拟计算来模拟重负载函数。 -
使用虚拟数据以去除数据生成和处理的开销。
-
从本地、单进程版本的系统开始,然后再转到分布式系统。
-
在单台机器上模拟多个节点和演员,以去除网络开销。
-
找出系统每个部分的理论最大性能。如果系统中所有其他瓶颈都消除了,除了这个组件,我们的预期性能是什么?
-
再次进行性能分析!每次简化系统时,重新运行你的性能分析。
问题
一旦我们锁定了瓶颈,就有一些关键问题需要回答。
这个组件的理论最大性能是多少?
如果我们已经充分隔离了瓶颈组件,那么应该能够回答这些问题。
我们距离最大性能还有多远?
这个最优性差距将告诉我们系统的优化程度。现在,可能会出现其他硬性约束,一旦我们将组件重新引入系统中,这也是可以接受的,但至少要意识到这个差距。
是否存在更深层的瓶颈?
总是问自己这个问题,也许问题比你最初想到的更深层次,在这种情况下,我们需要重复基准测试和简化的过程。
3 — 优化
好的,我们已经识别出了最大的瓶颈,现在进入有趣的部分,我们怎么改进?通常我们应该关注 3 个可能的改进领域。
-
计算
-
通信
-
内存
计算
为了减少计算瓶颈,我们需要尽可能高效地使用数据和算法。这显然是项目特定的,有很多可以做的事情,但让我们来看一些好的经验法则。
并行化 — 确保尽可能多地进行并行工作。这是设计系统时第一个显著的胜利,可以大幅度提升性能。考虑使用向量化、批处理、多线程和多进程等方法。
缓存 — 尽可能地预计算和重用计算结果。许多算法可以利用预计算的值,从而节省每一步训练中的关键计算。
卸载 — 我们都知道 Python 速度不快。幸运的是,我们可以将关键计算卸载到低级语言如 C/C++。
硬件扩展 — 这有点偷懒,但当一切都失败时,我们总可以增加更多的计算机来解决问题!
通信
任何经验丰富的工程师都会告诉你,沟通是成功交付项目的关键,我们当然是指系统内部的沟通(天哪,我们希望不要跟同事交流)。一些好的经验法则包括:
无闲置时间 — 你所有可用的硬件必须始终被利用,否则你将错失性能提升。这通常是由于系统间通信的复杂性和开销所致。
保持本地化 — 在迁移到分布式系统之前,尽可能长时间地将所有内容保留在单台机器上。这使你的系统保持简单,同时避免了分布式系统的通信开销。
异步 > 同步 — 识别任何可以异步完成的任务,这将有助于通过在数据移动的同时保持工作进行,从而减轻通信的成本。
避免数据移动 — 将数据从 CPU 移动到 GPU 或从一个进程移动到另一个进程是昂贵的!尽量减少这种操作,或者通过异步方式减少其影响。
内存
最后但同样重要的是内存。上述许多领域可以帮助缓解瓶颈,但如果没有足够的内存,这可能是不可能的!让我们来看一些需要考虑的事项。
数据类型 — 保持这些尽可能小,有助于减少通信和内存成本,并且与现代加速器一起,它还会减少计算。
缓存 — 类似于减少计算,聪明的缓存可以帮助节省内存。然而,确保你的缓存数据使用频率足够高,以证明缓存的合理性。
预分配 — 在 Python 中我们不太习惯这样做,但严格进行内存预分配可以让你准确知道所需内存量,减少碎片化的风险,并且如果你能够写入共享内存,你将减少进程之间的通信!
垃圾回收 — 幸运的是,Python 处理了大部分这方面的工作,但重要的是确保你没有在作用域中保留不必要的大值,或者更糟的是,存在可能导致内存泄漏的循环依赖。
懒惰 — 仅在必要时评估表达式。在 Python 中,你可以使用生成器表达式代替列表推导式,以便进行惰性计算的操作。
4 — 重复
那么,我们什么时候才算完成呢?这真的取决于你的项目、需求是什么,以及在你渐渐崩溃之前需要多久!
随着你消除瓶颈,你在优化系统时投入的时间和精力将会得到递减的回报。在这个过程中,你需要决定何时“足够好”。记住,速度是实现目标的一种手段,不要陷入为优化而优化的陷阱。如果对用户没有影响,那么可能是时候继续前进了。
结论
构建大规模 ML 系统是困难的。这就像玩一个扭曲的“沃尔多在哪里”游戏,混合了《黑暗之魂》的元素。如果你真的找到问题,你必须进行多次尝试才能解决,而且你会花费大部分时间被虐待,问自己“我为什么要在周五晚上做这些?”。有一个简单且有原则的方法可以帮助你通过最终 boss 战,并品尝到那些甜美的理论最大 FLOPs。
## ML in Action | Donal Byrne | Substack
提供未经请求的建议、实用见解和在快速发展的领域中学到的经验的机器学习通讯…
donalbyrne.substack.com
使用超级收敛加速你的深度学习模型训练
原文:
towardsdatascience.com/supercharge-training-of-your-deep-learning-models-7168ff81a042?source=collection_archive---------7-----------------------#2023-11-22
使用单周期学习率实现超级收敛
Raghav Bali
·
关注 发表在《走向数据科学》 ·7 分钟阅读·2023 年 11 月 22 日
--
照片由Philip Swinburn拍摄,来源于Unsplash
你是否遇到这样的情况:一开始提高准确率很容易,但一旦达到 90%,就必须非常努力才能进一步提高性能?你的模型训练时间太长吗?
在本文中,我们将探讨一种有趣的技术,以加速你的训练设置,并获得你一直寻求的额外性能,提高训练速度。本质上,我们将致力于通过一种称为一周期学习率的策略,在训练过程中动态调整学习率。
原本在 Leslie Smith 的论文中提到的一周期学习率计划[1], [2],专注于一种独特的策略,在训练过程中动态更新学习率。听起来术语很多,别担心,让我们先从一个典型的训练设置开始,然后逐渐理解如何通过一周期学习率来改进结果。
训练图像分类器
当我们致力于学习一种提高模型性能的巧妙技巧(周期率)时,何不在享受经典的石头剪子布游戏时进行呢。
由Markus Spiske拍摄于Unsplash
问题陈述
石头剪子布是一个经典的儿童游戏,涉及两个玩家使用手势(石头、纸或剪刀)进行竞争,以压倒对手。例如,石头手势战胜剪刀,但纸手势战胜石头。有趣吧?
我们的目标是训练一个图像分类模型,能够检测三种手势之一。然后,我们可以利用这样的训练模型开发一个端到端的游戏。为了本文的目的,我们将把范围限制在训练分类器本身,端到端的游戏和可部署的模型是另一个可能的文章主题。
数据集
我们很幸运已经有一个带标签的数据集,我们可以利用这个数据集有效地训练分类模型。数据集托管在TensorFlow 数据集目录中,由劳伦斯·莫罗尼(CC BY 2.0)提供。数据集具有以下属性:
-
数据点数量:2800
-
类别数量:3
-
可用的训练-测试拆分:是
-
数据集大小:220 MiB
TensorFlow 提供了一个干净的 API 来访问这些数据集,以下代码片段允许我们下载训练和验证拆分
import tensorflow_datasets as tfds
DATASET_NAME = 'rock_paper_scissors'
(dataset_train_raw, dataset_test_raw), dataset_info = tfds.load(
name=DATASET_NAME,
data_dir='tmp',
with_info=True,
as_supervised=True,
split=[tfds.Split.TRAIN, tfds.Split.TEST],
)
# plot samples from the dataset
fig = tfds.show_examples(dataset_train_raw, dataset_info)
以下是该数据集的一些样本图像:
图:石头剪子布数据集中的样本数据点
学习率
学习率是一个关键的超参数,它可能决定设置的成败,但通常被忽视。忽视的原因是,大多数库/包提供了足够好的默认值,但这些默认值只能带你走到一定程度。
对于像我们这样的定制使用案例,获得正确的学习率非常重要。找到最佳值是一项棘手的权衡。学习率设置得过慢(或过小),你的模型几乎无法学到任何东西。设置得过快(或过大),它将超越所有神经网络旨在找到的神秘最小值。下图展示了这一点,以便更好地理解。
图示:学习率对模型学习目标(最小值)能力的影响。来源:作者
梯度下降与优化器
梯度下降是训练/优化神经网络的标准方法。它通过更新网络参数,使其朝着梯度的反方向前进,从而最小化目标函数。不深入细节,它有助于沿着目标函数的坡度向下行进。有关梯度下降的详细介绍,请参考这里。
深度学习社区自最初模型使用基础梯度下降进行训练以来已经取得了长足的进步。多年来,许多改进帮助更快地训练并避免明显的陷阱。简要地说,一些显著和最受欢迎的方法有:
AdaGrad 自适应梯度算法 是一种优化算法,根据各个参数的历史梯度调整学习率,从而对不频繁的参数进行较大更新,对频繁的参数进行较小更新。它旨在有效处理稀疏数据,特别适合处理稀疏数据。
RMSProp 均方根传播 通过对每个参数单独调整学习率来优化学习。它通过使用平方梯度的移动平均值来解决 AdaGrad 中学习率递减的问题。这有助于根据最近的梯度大小自适应地缩放学习率。
ADAM 自适应矩估计 是一种优化算法,结合了 RMSProp 和动量方法的思想。它保持过去梯度和平方梯度的指数衰减平均值,利用这些值自适应地更新参数。ADAM 以其高效性和有效性在训练深度神经网络中而闻名。
One-Cycle Learning Rate 和超收敛
One-Cycle Learning Rate 是一个简单的两步过程,用于在训练过程中改进学习率和动量。其工作原理如下:
-
第 1 步:我们从较低的学习率开始,在几个时期内以线性递增的方式逐步提高到较高的值
-
第 2 步:我们在几个时期内维持学习率的最高值
-
第 3 步:然后我们回到一个较低的学习率,并随着时间的推移逐渐衰减
在这三个步骤中,动量在完全相反的方向上进行更新,即当学习率上升时,动量下降,反之亦然。
一周期学习率的实际应用
首先,我们将通过一个简单的 1 周期学习率实现进行操作,然后用它来训练我们的模型。我们将利用Martin Gorner 2019 年在TensorFlow World的演讲中现成的 1 周期 LR 计划实现,如清单 2 所示。
def lr_function(epoch):
# set start, min and max value for learning rate
start_lr = 1e-3; min_lr = 1e-3; max_lr = 2e-3
# define the number of epochs to increase
# LR lineary and then the decay factor
rampup_epochs = 6; sustain_epochs = 0; exp_decay = .5
# method to update the LR value based on the current epoch
def lr(epoch, start_lr, min_lr, max_lr, rampup_epochs,
sustain_epochs, exp_decay):
if epoch < rampup_epochs:
lr = ((max_lr - start_lr) / rampup_epochs
* epoch + start_lr)
elif epoch < rampup_epochs + sustain_epochs:
lr = max_lr
else:
lr = ((max_lr - min_lr) *
exp_decay**(epoch - rampup_epochs -
sustain_epochs) + min_lr)
return lr
return lr(epoch, start_lr, min_lr, max_lr,
rampup_epochs, sustain_epochs, exp_decay)
我们执行这个函数(见 清单 2)来展示学习率如何根据我们之前讨论的两个步骤进行变化。这里我们从1e-3的初始学习率开始,并在前几个 epoch 中将其提升至2e-3。然后在剩余的 epoch 中将其再次降低至1e-3。这种动态学习率曲线在以下 24 个 epoch 的样本运行中得以展示。
24 个 epoch 的 1 周期学习率策略。学习率线性上升,然后在剩余的 epoch 中缓慢衰减。图像来源:作者
我们将通过在使用 MobileNetV2 模型作为特征提取器,并为当前的石头剪子布分类任务训练一个分类头时,测试我们的 1 周期学习率调度器。然后我们将其与简单的 CNN 以及使用标准 Adam 优化器的 MobileNetV2+分类头进行比较。完整的笔记本可以在github上找到参考。以下片段快速概述了我们如何使用 TensorFlow 回调来插入我们的 1 周期学习率工具。
# Set Image Shape
INPUT_IMG_SHAPE= (128, 128, 3)
# Get Pretrained MobileNetV2
base_model = tf.keras.applications.MobileNetV2(
input_shape=INPUT_IMG_SHAPE,
include_top=False,
weights='imagenet',
pooling='avg'
)
# Attach a classification head
model_lr = tf.keras.models.Sequential()
model_lr.add(base_model)
model_lr.add(tf.keras.layers.Dropout(0.5))
model_lr.add(tf.keras.layers.Dense(
units=NUM_CLASSES,
activation=tf.keras.activations.softmax,
kernel_regularizer=tf.keras.regularizers.l2(l=0.01)
))
# compile the model
model_lr.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']
)
# set number of epochs
initial_epochs = 24
# Set the model for training
# The LearningRateScheduler callback is where we
# plug our custom 1-cycle rate function
training_history_lr = model_lr.fit(
x=dataset_train_augmented_shuffled.repeat(),
validation_data=dataset_test_shuffled.repeat(),
epochs=initial_epochs,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
callbacks=[
tf.keras.callbacks.LearningRateScheduler(lambda epoch: \
lr_function(epoch),
verbose=True)
],
verbose=1
)
我们用批量大小为 64 的情况下训练了所有 3 个模型 24 个 epoch。下图展示了 1 周期学习率的影响。与其他两个模型相比,它能帮助我们的模型在仅 5 个 epoch 内实现收敛。超收敛现象在验证数据集上也可见。
使用 1 周期学习率(mobileNetV2_lr)的 MobileNetV2 在 5 个 epoch 内即可收敛,表现优于 MobileNetV2 和简单 CNN 架构。
我们在 10 个 epoch 内达到了 90–92%的验证准确率,这在所有模型中是迄今为止表现最好的。模型在测试数据集上的表现也显示了同样的情况,即 MobileNetV2_lr 轻松超越了其他两个模型。
# Simple CNN
Test loss: 0.7511898279190063
Test accuracy: 0.7768816947937012
# MobileNetV2
Test loss: 0.24527719616889954
Test accuracy: 0.9220430254936218
# MobileNetV2_LR
Test loss: 0.27864792943000793
Test accuracy: 0.9166666865348816
结论
克服模型性能超过 90%准确率的瓶颈并优化训练时间,可以通过实施One-Cycle Learning Rate来实现。这一技术由 Leslie Smith 及其团队提出,在训练过程中动态调整学习率,提供了一种战略性方法以增强模型性能。通过采用这一方法,你可以有效地应对训练设置的复杂性,并发掘更快、更有效的深度学习模型的潜力。拥抱One-Cycle Learning Rate的力量,提高你的训练体验,实现卓越的结果!
使用营销组合建模来超级提升你的跨渠道客户获取
原文:
towardsdatascience.com/supercharge-your-cross-channel-customer-acquisition-with-marketing-mix-modeling-21ba1bc5f2e8?source=collection_archive---------8-----------------------#2023-02-22
离线营销和销售渠道正在回归,你需要适应
艾薇·刘
·
关注 发布于 Towards Data Science ·5 min 阅读·2023 年 2 月 22 日
--
随着数字媒体上的客户获取成本持续上涨,越来越多的消费品牌开始寻求多样化的营销支出。品牌正在重新评估传统渠道,如电视、广播、店内营销等。然而,这些渠道的数据追踪有限。现在,品牌面临着在所有渠道中可靠衡量营销绩效的新挑战。
此外,消费者在经历了三年的 COVID 后正回归到线下购物。对于同时在线上和线下销售的品牌来说,全面理解营销和销售表现是困难的。
此外,通过视频或网红进行促销的品牌难以揭示这些营销努力的回报(或没有回报)。随着 IDFA 和 Cookies 的消退,使用归因方法测量视图广告表现成为不可能的任务。
图片来源于作者
这些问题每年让品牌损失数百万美元。在当前经济不确定的情况下,防止这种浪费比以往任何时候都更为关键。因此,品牌高管们正在要求他们的营销和数据团队探索潜在的解决方案,并确定最有效的方案。在本文中,我将讨论如何通过营销组合建模(MMM)来解决这些难题,并在当今充满挑战的经济环境中获得优势。
使用 MMM 测量营销信号
经典的衡量营销活动是否带来回报的方法是通过回归分析。我们可以通过评估每个自变量与回报之间的相关性来发现变量是否带来了回报。
图片来源于作者
MMM 基于复杂的回归分析,处理包括广告支出、宏观经济和其他外部因素以及收入在内的许多输入,然后将转化信用分配给每个输入。典型的 MMM 输入包括以下内容。
图片来源于作者
MMM 展示了因素贡献如下。
图片来源于作者
MMM 实践中的应用
让我们来看一个现实例子。一家美妆品牌在其店面、Amazon、Sephora 和便利店销售产品,并在 Google、Meta、Amazon、电视、播客和店内花费了大量的营销预算。每个月,它都想了解其广告支出如何影响各渠道的销售。美妆品牌还希望了解下个月最佳的营销预算分配。
根据经验,许多美妆品牌的客户在线查看广告后会线下购买,反之亦然。因此,将在线和离线渠道的表现分开是不切实际的。
图片来源于作者
MMM 可以提供帮助。通过将每日广告支出、美妆市场趋势和每日销售数据输入模型,品牌可以看到每个变量对销售的贡献随时间的变化。
图片来源于作者
更好的是,回归模型易于可视化和解释。美妆品牌的业务利益相关者可以评估模型对其业务的拟合程度,并决定是否接受模型结果。
另一方面,通过将不同的广告支出情景输入到训练模型中,美妆品牌将获得这些情景下的相关收入预测。
图片由作者提供。
连接营销和销售渠道之间的点。
MMM 处理如每日广告支出和销售数据等汇总数据,而不是像点击流这样的详细用户级数据。广告支出和销售数据通常可以从主要的营销和销售平台获取,并且每个数据属性通常遵循类似的测量方式。因此,品牌可以期望在广告平台之间进行可比的营销测量。如果没有这种可比性,营销测量将是不可靠的。
同样,MMM 为视图广告提供了可靠的测量。如果品牌尝试用归因模型测量视图广告的表现,测量结果可能会被低估。由于归因模型依赖于点击流数据,如果用户看到广告但没有点击,模型将无法知道用户是否查看了广告。在这种情况下,MMM 成为归因模型的绝佳补充。
MMM 有其局限性。
然而,MMM 并不完美,有许多局限性。
由于 MMM 需要汇总数据,因此需要动态且长期的营销数据以检测足够的市场信号。因此,只有在营销上大量投资的品牌才能利用 MMM。此外,如果品牌希望测量特定的营销渠道,该品牌必须在该渠道上进行长期的积极营销活动。否则,MMM 因数据不足而无法生成有意义的结果。
此外,MMM 通常只能在广告平台层面进行测量,无法在活动层面进行测量。这是因为大多数品牌在特定活动上不会生成足够的数据点。
最后,品牌需要定期对活动进行实验,以创建动态的营销运动,使 MMM 能够发挥作用。并非所有团队都有足够的带宽来有效地操作活动,从而充分利用 MMM。
营销测量应服务于你的收入目标。
总结来说,以下是 MMM 可以独特地为品牌增值的使用场景:
-
在多个营销渠道之间进行多样化投资。
-
同时进行在线和线下销售。
-
在线营销旨在提升线下销售的认知度,反之亦然。
-
对于视频和某些付费社交广告等视图广告进行大量投入。
然而,在某些情况下,MMM 可能对品牌的帮助有限:
-
如果特定渠道的营销或销售历史较短,MMM 将无法对这些渠道进行准确测量。
-
测量目标设定在活动层面。
-
品牌在进行营销实验的带宽有限,因此没有足够的营销信号供 MMM 使用。
重要的是要记住,没有任何一种营销测量方法是完美的。营销人员和数据团队应该寻找一种最佳契合其使用案例的技术组合,以帮助他们通过营销获得更多利润。在当前的经济环境中,运营时减少开支至关重要,而营销测量方法可以帮助我们实现这一目标。
我在我的文章中讨论了如何利用数据科学提升您的业务和优化您的营销。如果您想讨论营销测量或其他数据科学话题,请在LinkedIn上关注我,或通过 newsletter@ivyliu.io 与我联系。下次见。
用这个新工具提升你的数据清洗技能
原文:
towardsdatascience.com/supercharge-your-data-cleaning-game-with-this-new-tool-d43e99cdc6a5
PYTHON | 数据 | 分析
一份利用 pandas_dq 轻松进行数据清洗的指南
David Farrugia
·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 4 月 19 日
--
照片由 JESHOOTS.COM 提供,来源于 Unsplash
如果这篇文章的标题引起了你的兴趣,那么你肯定意识到数据清洗和预处理步骤对整体分析项目的重要性。
如果你准备训练一个机器学习模型,或者你只是想进行一些探索性数据分析,脏数据无疑是你面前的障碍。你可能听过这样一句话:准备数据是 80%的工作。
数据清洗过程可能是数据分析过程中最耗时、最繁琐、最令人沮丧的部分之一。查找重复项、多重共线性、缺失值或无限值等,都是从理解数据和提取可操作洞察中浪费的宝贵时间。
在这篇文章中,我们将探讨令人惊叹的 Python 包pandas_dq
,以及它如何提升你下一个数据清洗任务的速度和质量。
[## GitHub - AutoViML/pandas_dq: 用一行代码发现数据质量问题并清理数据…
用兼容 Scikit-Learn 的转换器用一行代码发现数据质量问题并清理数据。…
github.com
首先…
我们需要安装这个包。pandas_dq
可以通过以下方式获得:
pip install pandas_dq
或者,你可以从源代码安装:
# download from https://github.com/AutoViML/pandas_dq/archive/master.zip
cd <pandas_dq_Destination>
git clone git@github.com:AutoViML/pandas_dq.git
它是如何工作的?
给定一个数据集,开始使用这个工具非常简单。它目前有 3 个主要功能:
-
dq_report
-
Fix_DQ
-
DataSchemaChecker
dq_report
这个函数的目的是生成包含数据集所有数据质量问题的报告。
假设我们有 iris 数据集
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')
我们可以如下运行 dq_report
:
from pandas_dq import dq_report
dq_report(df, target=None, verbose=1)
这将给我们带来以下结果:
作者提供的图像
该包迅速告诉我们特征sepal_width
有 4 个异常值,并建议我们要么对其进行修剪(将任何异常值设置为最大值),要么将其移除。它还可以识别是否存在多重共线性特征(高度相关的特征),例如petal_length
和petal_width
。
我们还有选项运行针对目标的数据质量检查。在有监督学习任务中(每当我们有目标/标签需要预测时),我们还需要检查特征与目标之间的关系。dq_report
通过允许我们指定目标列,使这一过程变得非常简单。
注意:目标列不能是非数值型值
# map string target to numeric
df['species'] = pd.factorize(df['species'])[0]
dq_report(df, target='species', verbose=1)
作者提供的图像
这个函数执行所有这些检查:
-
它检测 ID 列
-
它检测零方差列
-
它识别稀有类别(少于列中 5%的类别)
-
它查找列中的无限值
-
它检测混合数据类型(即具有多于一种数据类型的列)
-
它检测异常值(即超出四分位范围的浮点列)
-
它检测高基数特征(即具有超过 100 个类别的特征)
-
它检测高度相关的特征(即两个特征的绝对相关性高于 0.8)
-
它检测重复行(即数据集中同一行出现多次)
-
它检测重复列(即数据集中同一列出现两次或更多次)
-
它检测偏斜分布(即偏斜度超过 1.0 的特征)
-
它检测不平衡的类别(即目标变量中的某一类别显著多于其他类别)
-
它检测特征泄漏(即与目标高度相关的特征,相关性 > 0.8)
Fix_DQ
这个函数执行dq_report
所做的所有检查,但也在一行代码中进行处理。这通常在准备建模时对一组特征(排除目标列)进行操作。
from pandas_dq import Fix_DQ
fdq = Fix_DQ()
result = fdq.fit_transform(df.drop('species', axis=1))
我们得到以下输出:
Alert: Detecting 1 duplicate rows...
Dropping petal_length which has a high correlation with ['sepal_length']
Dropping petal_width which has a high correlation with ['sepal_length', 'petal_length']
Alert: Dropping 1 duplicate rows can sometimes cause column data types to change to object. Double-check!
以及结果数据框:
作者提供的图像
DataSchemaChecker
这个函数接受一个数据模式,并确保我们的数据框遵循该模式。这在我们想确保列的数据类型时很有用,无论是因为我们想在使用预训练模型时确保一致的数据类型,执行某种类型验证、序列化,甚至将其导入到数据库中。
这个函数的主要特点(可能是一个 bug)是当没有数据类型问题时,它会输出 AttributeError。
from pandas_dq import DataSchemaChecker
wrong_schema = dict(zip(df.columns, ['float64', 'float64', 'float64', 'float64', 'float64']))
ds = DataSchemaChecker(schema=wrong_schema)
ds.fit_transform(df)
作者提供的图片
总结
pandas_dq
仍然是一个相对较新的工具,它极大地改善了我的数据清洗过程。除了自动化整个流程,它还以惊人的速度执行这些步骤。你绝对需要检查它的结果并深入挖掘——但它在缩小焦点方面非常有帮助——在这个过程中节省了宝贵的时间。
你喜欢这篇文章吗?只需每月 $5,你就可以成为会员,解锁对 Medium 的无限访问权限。你将直接支持我以及你在 Medium 上的其他喜爱作家。非常感谢!
[## 使用我的推荐链接加入 Medium - David Farrugia
获取对我所有⚡高级⚡内容的独家访问权限,并在 Medium 上畅享无限制的阅读。通过购买我提供的服务来支持我的工作…
david-farrugia.medium.com
想要联系我吗?
我很想听听你对这个话题的看法,或者对 AI 和数据的任何想法。
如果你希望联系我,请发邮件到 davidfarrugia53@gmail.com。
用aiomultiprocess
超级增强你的 Python Asyncio:一份全面指南
原文:
towardsdatascience.com/supercharge-your-python-asyncio-with-aiomultiprocess-a-comprehensive-guide-571ee0e2f416
PYTHON TOOLBOX
利用asyncio
和multiprocessing
的力量来加速你的应用程序
Peng Qian
·发表于Towards Data Science ·阅读时间 9 分钟·2023 年 7 月 5 日
--
图片来源:作者创作,Canva
在这篇文章中,我将带你深入了解aiomultiprocess
,一个结合了 Python asyncio
和multiprocessing
强大功能的库。
这篇文章将通过丰富的代码示例和最佳实践进行解释。
在文章结束时,你将理解如何利用 aiomultiprocess 的强大功能来增强你的 Python 应用程序,就像主厨带领一组厨师来准备一顿丰盛的宴席。
引言
想象一下你想在周末邀请你的同事们来一顿大餐。你会怎么做?
作为一名经验丰富的厨师,你肯定不会一次只做一道菜;那样太慢了。你会高效地利用时间,让多个任务同时进行。
例如,当你等待水开时,你可以离开去洗菜。这样,当水开时,你可以把菜放进锅里。这就是并发的魅力。
然而,食谱常常很残酷:你需要在做汤时不停搅拌;蔬菜需要洗净并切碎;你还需要烤面包、煎牛排等等。
当有很多菜需要准备时,你会感到不知所措。
幸运的是,你的同事们不会只是坐在那等着吃。他们会进厨房来帮助你,每多一个人就像增加一个工作进程。这就是多进程和并发的强大结合。
代码也是如此。即使使用 asyncio,你的 Python 应用程序是否仍然遇到瓶颈?你是否在寻找进一步提高并发代码性能的方法?如果是的话,aiomultiprocess
是你一直在寻找的答案。
如何安装和基本用法
安装
如果你使用 pip,按照如下方式安装:
python -m pip install aiomultiprocess
如果你使用 Anaconda,从 conda-forge 安装:
conda install -c conda-forge aiomultiprocess
基本用法
aiomultiprocess
由三个主要类组成:
Process
是其他两个类的基类,用于启动一个进程并执行协程函数。通常你不需要使用这个类。
Worker
用于启动一个进程,执行一个协程函数,并返回结果。我们也不会使用这个类。
Pool
是我们将重点关注的核心类。与 multiprocessing.Pool
类似,它启动一个进程池,但其上下文需要使用 async with
来管理。我们将使用 Pool 的两个方法:map
和 apply
。
map
方法接受一个协程函数和一个可迭代对象。Pool
将遍历可迭代对象,并将协程函数分配给不同的进程运行。map
方法的结果可以使用 async for:
进行异步迭代
import asyncio
import random
import aiomultiprocess
async def coro_func(value: int) -> int:
await asyncio.sleep(random.randint(1, 3))
return value * 2
async def main():
results = []
async with aiomultiprocess.Pool() as pool:
async for result in pool.map(coro_func, [1, 2, 3]):
results.append(result)
print(results)
if __name__ == "__main__":
asyncio.run(main())
apply
方法接受一个协程函数和该函数所需的参数元组。根据调度器的规则,Pool
将协程函数分配给适当的进程进行执行。
import asyncio
import random
import aiomultiprocess
async def coro_func(value: int) -> int:
await asyncio.sleep(random.randint(1, 3))
return value * 2
async def main():
tasks = []
async with aiomultiprocess.Pool() as pool:
tasks.append(pool.apply(coro_func, (1,)))
tasks.append(pool.apply(coro_func, (2,)))
tasks.append(pool.apply(coro_func, (3,)))
results = await asyncio.gather(*tasks)
print(results) # Output: [2, 4, 6]
if __name__ == "__main__":
asyncio.run(main())
实现原理和实际示例
aiomultiprocess.Pool 的实现原理
在上一篇文章中,我解释了如何将 asyncio 任务分布到多个 CPU 核心上。
通用方法是在主进程中使用 [loop.run_in_executor](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor)
启动一个进程池。然后,在进程池中的每个进程中创建一个 asyncio 事件循环,并在各自的循环中执行协程函数。示意图如下:
此图展示了 asyncio 和 multiprocessing 的集成方式。图片由作者提供
aiomultiprocess.Pool
的实现类似。它包括 scheduler
、queue
和 process
作为其三个组件。
-
scheduler
可以理解为大厨,负责以合适的方式将任务分配给每个厨师。当然,你可以雇佣(实现)一个适合你需求的大厨。 -
queue
就像厨房的流水线。严格来说,它包括一个订单线和一个交付线。大厨通过订单线将菜单传递给厨师,厨师通过交付线返回完成的菜肴。 -
process
就像餐厅里的厨师。他们根据分配同时处理几个菜肴。每当一道菜准备好时,它将按照分配的顺序交付。
整个示意图如下所示:
Aiomultiprocess 由三个组件组成:调度器、队列和进程。图片由作者提供
现实世界的示例
基于之前提供的介绍,您现在应该理解如何使用aiomultiprocess
。让我们深入一个现实世界的示例,以体验它的强大功能。
首先,我们将使用远程调用和循环计算来模拟实际数据检索和处理过程。这个方法展示了 IO 绑定和 CPU 绑定任务通常混合在一起,它们之间的界限并不那么明确。
import asyncio
import random
import time
from aiohttp import ClientSession
from aiomultiprocess import Pool
def cpu_bound(n: int) -> int:
result = 0
for i in range(n*100_000):
result += 1
return result
async def invoke_remote(url: str) -> int:
await asyncio.sleep(random.uniform(0.2, 0.7))
async with ClientSession() as session:
async with session.get(url) as response:
status = response.status
result = cpu_bound(status)
return result
接下来,我们将使用传统的 asyncio 方法调用此任务 30 次作为基线:
async def main():
start = time.monotonic()
tasks = [asyncio.create_task(invoke_remote("https://www.example.com"))
for _ in range(30)]
await asyncio.gather(*tasks)
print(f"All jobs done in {time.monotonic() - start} seconds")
if __name__ == "__main__":
asyncio.run(main())
代码使用传统的 asyncio 方法运行。截图由作者提供
代码执行结果如图所示,耗时约 21 秒。现在让我们看看 aiomultiprocess 能带来多大的改进。
使用 aiomultiprocess 很简单。原始的并发代码无需修改。您只需调整主方法中的代码,使其在 Pool 内部运行:
async def main():
start = time.monotonic()
async with Pool() as pool:
tasks = [pool.apply(invoke_remote, ("https://www.example.com",))
for _ in range(30)]
await asyncio.gather(*tasks)
print(f"All jobs done in {time.monotonic() - start} seconds")
if __name__ == "__main__":
asyncio.run(main())
仅需使用修改后的 aiomultiprocess 版本。截图由作者提供
如您所见,使用 aiomultiprocess 的代码在我的笔记本电脑上只需 14 秒即可完成。性能提升在更强大的计算机上会更大。
详细的最佳实践
最后,基于我的经验,我将分享一些更实用的最佳实践。
仅使用池
尽管aiomultiprocess
也提供了Process
和Worker
类供我们选择,但由于创建进程会消耗大量资源,我们应始终使用Pool
类以确保最大效率。
如何使用队列
在上一篇文章中,我解释了如何使用asyncio.Queue
实现生产者-消费者模式来平衡资源和性能。
在aiomultiprocess
中,我们也可以使用队列。然而,由于我们处于进程池中,不能使用asyncio.Queue
。同时,我们也不能在进程池中直接使用multiprocessing.Queue
。
在这种情况下,您应该使用multiprocessing.Manager().Queue()
来创建队列,代码如下:
import random
import asyncio
from multiprocessing import Manager
from multiprocessing.queues import Queue
from aiomultiprocess import Pool
async def worker(name: str, queue: Queue):
while True:
item = queue.get()
if not item:
print(f"worker: {name} got the end signal, and will stop running.")
queue.put(item)
break
await asyncio.sleep(random.uniform(0.2, 0.7))
print(f"worker: {name} begin to process value {item}", flush=True)
async def producer(queue: Queue):
for i in range(20):
await asyncio.sleep(random.uniform(0.2, 0.7))
queue.put(random.randint(1, 3))
queue.put(None)
async def main():
queue: Queue = Manager().Queue()
producer_task = asyncio.create_task(producer(queue))
async with Pool() as pool:
c_tasks = [pool.apply(worker, args=(f"worker-{i}", queue))
for i in range(5)]
await asyncio.gather(*c_tasks)
await producer_task
if __name__ == "__main__":
asyncio.run(main())
使用initializer
初始化资源
假设您需要在协程方法中使用aiohttp
会话或数据库连接池,但由于这些对象无法被序列化,因此我们不能在主进程中创建任务时传递参数。
另一种选择是定义一个全局对象和一个初始化方法。在这个初始化方法中,访问全局对象并进行初始化。
就像 [multiprocessing.Pool](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool)
一样,aiomultiprocess.Pool
在初始化时可以接受一个初始化方法和相应的初始化参数。每个进程启动时都会调用这个方法来完成初始化:
import asyncio
from aiomultiprocess import Pool
import aiohttp
from aiohttp import ClientSession, ClientTimeout
session: ClientSession | None = None
def init_session(timeout: ClientTimeout = None):
global session
session = aiohttp.ClientSession(timeout=timeout)
async def get_status(url: str) -> int:
global session
async with session.get(url) as response:
status_code = response.status
return status_code
async def main():
url = "https://httpbin.org/get"
timeout = ClientTimeout(2)
async with Pool(initializer=init_session, initargs=(timeout,)) as pool:
tasks = [asyncio.create_task(pool.apply(get_status, (url,)))
for i in range(3)]
status = await asyncio.gather(*tasks)
print(status)
if __name__ == "__main__":
asyncio.run(main())
异常处理和重试
尽管 aiomultiprocess.Pool
提供了 exception_handler
参数来帮助处理异常,但如果你需要更多的灵活性,你需要将它与 asyncio.wait
结合使用。关于 asyncio.wait
的使用,你可以参考 我之前的文章。
使用 asyncio.wait
,你可以获取遇到异常的任务。提取任务后,你可以进行一些调整,然后重新执行任务,如下代码所示:
import asyncio
import random
from aiomultiprocess import Pool
async def worker():
await asyncio.sleep(0.2)
result = random.random()
if result > 0.5:
print("will raise an exception")
raise Exception("something error")
return result
async def main():
pending, results = set(), []
async with Pool() as pool:
for i in range(7):
pending.add(asyncio.create_task(pool.apply(worker)))
while len(pending) > 0:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_EXCEPTION)
print(f"now the count of done, pending is {len(done)}, {len(pending)}")
for result in done:
if result.exception():
pending.add(asyncio.create_task(pool.apply(worker)))
else:
results.append(await result)
print(results)
if __name__ == "__main__":
asyncio.run(main())
使用 Tenacity 进行重试
当然,我们还有更灵活和强大的异常处理和重试选项,比如使用 Tenacity
库,我在 这篇文章 中解释了它。
使用 Tenacity
,上述代码可以大大简化。你只需要在协程方法上添加一个装饰器,该方法在抛出异常时会自动重试。
import asyncio
from random import random
from aiomultiprocess import Pool
from tenacity import *
@retry()
async def worker(name: str):
await asyncio.sleep(0.3)
result = random()
if result > 0.6:
print(f"{name} will raise an exception")
raise Exception("something wrong")
return result
async def main():
async with Pool() as pool:
tasks = pool.map(worker, [f"worker-{i}" for i in range(5)])
results = await tasks
print(results)
if __name__ == "__main__":
asyncio.run(main())
使用 tqdm 指示进度
我喜欢 tqdm
,因为它总是能告诉我在屏幕前等待时代码执行的进度。这篇文章 也解释了如何使用它。
由于 aiomultiprocess
使用 asyncio 的 API 等待任务完成,它也与 tqdm
兼容:
import asyncio
from random import uniform
from aiomultiprocess import Pool
from tqdm.asyncio import tqdm_asyncio
async def worker():
delay = uniform(0.5, 5)
await asyncio.sleep(delay)
return delay * 10
async def main():
async with Pool() as pool:
tasks = [asyncio.create_task(pool.apply(worker)) for _ in range(1000)]
results = await tqdm_asyncio.gather(*tasks)
print(results[:10])
if __name__ == "__main__":
asyncio.run(main())
结论
运行 asyncio 代码就像厨师做饭一样。即使你可以通过并发执行不同的任务来提高效率,但最终你还是会遇到瓶颈。
在这一点上,最简单的解决方案是增加更多的厨师,以提高烹饪过程的并行性。
Aiomultiprocess
是一个强大的 Python 库。它通过允许并发任务在多个进程中运行,完美地突破了 asyncio 单线程特性造成的性能瓶颈。
本文中 aiomultiprocess
的使用和最佳实践基于我的工作经验。如果你对任何方面感兴趣,欢迎评论和参与讨论。
除了提高代码执行速度和性能,使用各种工具来提高工作效率也是一种性能提升:
彭茜
Python 工具箱
查看列表6 个故事!Seaborn 0.12:对象接口和声明式图形的深度指南 [## 通过我的推荐链接加入 Medium - 彭乾
作为 Medium 的会员,您的部分会员费用会分配给您阅读的作者,并且您可以完全访问每一个故事……
qtalen.medium.com
本文最初发表于:www.dataleadsfuture.com/supercharge-your-python-asyncio-with-aiomultiprocess-a-comprehensive-guide/
用 ChatGPT 超级提升你的电子表格
原文:
towardsdatascience.com/supercharge-your-spreadsheets-with-chatgpt-4f49f7faf676
重新定义你使用电子表格的方式。
Soner Yıldırım
· 发表在 Towards Data Science · 阅读时间 4 分钟 · 2023 年 5 月 29 日
--
图片由 Rubaitul Azad 提供,来源于 Unsplash
插件数量的增加展示了 ChatGPT 或其他大型语言模型(LLMs)的强大功能。
这些插件使 OpenAI API 能够与各种工具集成,从而方便地使用 OpenAI 提供的模型,包括 GPT-3.5 和 GPT-4。
可以通过大型语言模型(LLMs)增强的工具之一是 Google 电子表格。电子表格的实用性和易于访问使其成为处理表格数据的顶级工具之一。当它们配备 LLMs 时,最终结果会令人震惊。
让我们设置“GPT for Sheets and Docs”插件,探索它如何在只需一次点击的情况下处理各种任务。
设置 GPT for Sheets and Docs
在浏览器中打开一个新的 Google 电子表格,在“扩展”菜单下找到“附加组件”,点击“获取附加组件”:
(图片来源:作者)
在市场中找到“GPT for Sheets and Docs”并安装它。安装前务必阅读条款和条件,因为你需要授予它访问你 Google 账户中某些内容的权限:
(图片来源:作者)
安装后,你需要提供一个 API 密钥,可以从 OpenAI 网站上的 API Keys 菜单中获取。创建一个新密钥并复制它。然后,只需在插件的弹出菜单中输入该密钥即可。这里有一个逐步的 指南 介绍如何为“GPT for Sheets and Docs”插件整合你的 API 密钥。这个操作只需要做一次,不是每次使用时都需要。
总结文本
假设你是产品经理,需要总结来自电子商务网站的大量产品评价。通常,阅读评价并提取关键点需要很长时间。
我们知道 LLM 在总结文本方面很出色。借助我们刚刚安装的插件,我们可以一键总结产品评价。
我们需要一个主提示语。我正在使用以下提示,但可以根据需要进行自定义。
你的任务是为来自电子商务网站的给定产品评价生成一个简短的总结。最多使用 20 个词。
然后,我们可以在GPT
函数中使用以下提示语:
(作者提供的图片)
按下回车键,你就完成了。你可以像应用其他函数一样,将相同的提示语应用到所有单元格中。
这是我获得的 3 条产品评价的结果:
(作者提供的图片)
如果你不想编写提示语,可以使用GPT_SUMMARIZE
函数,它会总结给定单元格中的文本:
(作者提供的图片)
我总是喜欢编写提示语,因为它可以让我根据需要自定义输出。
评价的情感
假设你不需要总结评价,而是想找出它们的情感。你可以自定义提示语,并以你喜欢的格式获得情感,例如 1 表示积极,0 表示消极。
我将使用以下提示语来提取评价的情感:
你的任务是找出给定产品评价的情感。用以下一个词回答:positive,negative,neutral。
这是两个产品评价的结果:
(作者提供的图片)
天高地迥
我们仅仅触及了我们讨论的两个用例的表面。将 LLM 集成到电子表格中有潜力加速各种任务。一旦你习惯了使用它,你会发现你的效率和效果达到新的高度。
你可以考虑在电子表格中使用 ChatGPT 插件的几种方式:
-
生成自动报告
-
数据验证
-
推测
-
扩展
真的,天高地迥。
它有可能重新定义我们与电子表格的工作方式,不仅提高我们的生产力,还提升我们的创造力。
你可以成为 Medium 会员 以解锁我所有的写作内容,以及 Medium 的其余部分。如果你已经是会员,不要忘记 订阅 以便在我发布新文章时收到电子邮件。
感谢阅读。如果你有任何反馈,请告诉我。
超强 pandas:加密从 DataFrames 写入的 Excel 文件
原文:
towardsdatascience.com/supercharged-pandas-encrypting-excel-files-written-from-dataframes-1251b585145b?source=collection_archive---------7-----------------------#2023-06-12
介绍一个 ExcelHelper 类,它允许你使用强密码或自定义密码加密 Excel 文件
Ji Wei Liew
·
关注 发表在 Towards Data Science ·6 min read·2023 年 6 月 12 日
--
将数据写入 Excel 并加密(图片由作者提供)
介绍
在这篇文章中,我将分享如何结合使用 ExcelHelper
类来在将数据框写入 Excel 后打开并加密 Excel 文件。我在之前的 文章 中已将这一加密功能包含在 to_excelp
函数中。
## 超强大的 pandas:从 Excel 中读取和写入
增强.read_excel 和.to_excel 方法,以便你可以专注于数据探索
towardsdatascience.com
对于数据科学家和机器学习爱好者来说,你会发现这非常有用,因为它可以加快将数据框导出到 Excel 时的工作速度。
动机
在最近的一个项目中,我需要分析数据并为几个人准备统计数据。由于数据包含敏感信息,需要对文件进行密码保护。这与我的技能非常匹配,如果你读过我之前的文章,你会发现,在pypo.py
中,我已经有一个to_excelp
函数,该函数在 Excel 文件通过df.to_excel()
方法创建后打开它。虽然这对我有效,但现在似乎是重新审视实现方式并增加密码保护功能的好时机。
内容
-
打开和加密 Excel 文件
-
生成强密码
-
整合所有内容
完整代码 在这里。
第一部分 — 打开和加密 Excel 文件
在使用 python 和 pandas 时,打开 Excel 文件有很多理由,例如,在测试期间进行视觉检查,或者在发送给利益相关者之前进行格式化。如果有额外的需求来加密 Excel 文件,则需要处理 3 种情况:仅打开,仅加密,打开并加密。如果我们既不打开也不加密 Excel 文件,那么不需要做任何操作,因为df.to_excel()
就足够了。
ExcelHelper
是一个用于启动 Excel(应用程序)并根据提供的path
打开工作簿的类。从程序角度看,这是一个两步过程。大多数人从未意识到这一点,因为当你双击一个 Excel 文件时,Excel 应用程序和工作簿会一起启动。
初始化 ExcelHelper 类
-
__init__(self, launch=True, encrypt=False, password=None, length=20, quit_=True)
这是ExcelHelper
的初始化调用。 -
如果
launch
等于True
,在加密完成后工作簿会被显示出来。 -
如果
encrypt
等于True
,将调用self._encrypt()
方法,稍后会进行解释。 -
password
允许用户输入首选密码,否则将自动建议一个长度为length
的强密码,最大长度为 255。
打开工作簿
-
_open_wb(self, path, visible=False)
将给定路径转换为绝对路径,然后打开它。将路径转换为绝对路径是必要的,否则由win32com.client
调度的应用程序无法找到文件。(之前,我使用了try-except
块来在路径前添加当前工作目录,但这显得过于冗长,并且需要花费一些时间才能真正理解我们要做的事情。) -
visible
控制应用程序是否对用户可见。通常,只有在加密完成后才显示应用程序才有意义。因此,如果我们正在启动并加密,应该在self._encrypt()
完成后才将visible=True
设置为真。
加密 Excel
-
_encrypt(self, visible=False)
加密 Excel 工作簿,然后在加密完成后通过设置self.xl.Visible
属性来显示应用程序。 -
将
self.xl.DisplayAlerts
设置为True
是很重要的,否则启动的 Excel 文件不会发出任何警报(例如,如果你按 Ctrl+F 尝试查找一些无意义的字符而没有任何提示 😱;这曾经发生在我身上,我真的很困惑!)。
执行方法
-
execute(self, path, launch, encrypt, quit_)
处理上述 3 种情况。 -
quit_
参数关闭 Excel 应用程序(尾部下划线是一种约定,表示quit
在 Python 中是一个保留关键字)。当ExcelHelper
被初始化时,如果launch=False
,Excel 应用程序将在后台运行,Excel 文件会被打开
。如果用户现在双击 Excel 文件,他会被提示只能以只读模式打开。对于非技术用户来说,关闭文件相当困难;解决方法是打开任务管理器,选择 Excel 程序,然后结束任务。因此,需要调用.Quit()
来终止 Excel 应用程序。我们本可以直接关闭工作簿,但也许现在不需要处理得如此细致。
第二部分 — 生成强密码
起初,我使用from cryptography.fernet import Fernet; Fernet.generate_key()
来生成随机密码。虽然几个用户对密码的长度和随机性感到惊讶,但我不是很喜欢,因为它有点长,并且不包含多样的标点符号。我在 Google 上查找了一个更好的方法在 StackOverflow 上。(我总是对 StackOverflow 上如何轻松获得高质量答案感到非常惊讶。所有困难的工作已经由所有前辈完成,我们只需搜索、复制、粘贴并做些小调整(例如更改变量名)。)这个函数相当简单直接,自解释性也很强。
import secrets
import string
def gen_password(self, length):
char = string.ascii_letters + string.digits + string.punctuation
return ''.join(secrets.choice(char) for _ in range(length))
正当一切进展得过于顺利时,在测试我的代码时,我注意到有时密码无法用来打开文件!我真的很困惑。经过一番试错后,我开始怀疑可能有些字符是不适合用作密码的,因为这种现象只在密码包含两个反斜杠 \\
时发生。
这里有一些背景信息可以让你更好地理解情况:我使用 Powershell 和 Notepad++,我的代码将密码打印到 stdout
。接下来,我在 Powershell 上突出显示打印出的密码,然后在 Excel 提示我输入密码时粘贴它。所以问题是 \
是一个转义字符,因此在我输入密码时,第一个 \
应该被忽略。处理起来很麻烦,对于密码的目的,我可以少用一个字符。因此,我所做的就是在 string.punctuation
中切除反斜杠。
def _get_password(self, length):
string_punc = string.punctuation[:23] + string.punctuation[24:]
char = string.ascii_letters + string.digits + string_punc
return ''.join(secrets.choice(char) for _ in range(length))
第三部分 — 将一切整合起来
由于如果你不启动或加密 Excel 文件,实例化 ExcelHelper
对象几乎没有附加值,因此应该以 if launch or encrypt:
开始。接下来,仅需将关键字参数从 to_excelp
传递到 ExcelHelper
并返回对象和 password
。
def to_excelp(df, *arg, launch=True, encrypt=False, password=None, **kw):
''' Writes to Excel and opens it'''
filename, *arg = arg
if not filename.endswith(('.xlsx','.xls','.xlsm')):
filename += '.xlsx'
if os.path.isfile(filename):
name, ext = filename.rsplit('.')
filename = f'{name}_{timestr()}.{ext}'
# Default index=False
index = kw.get('index', False)
if not index:
kw['index']=False
df.to_excel(filename, *arg, **kw)
if launch or encrypt:
xl = ExcelHelper(filename, launch=launch, encrypt=encrypt, password=password)
return xl, xl.password
else:
return filename
如果你通过调用这个函数将数据框写入多个不同的 Excel 文件,我建议将结果存储在一个元组列表中。你可以随后遍历这个元组列表以获取 Excel 文件的路径及其密码。存储对象可能在未来很有用,特别是如果你打算为 ExcelHelper
添加更多功能的话。
l_xl_pw = []
for df in (df1, df2, df3, df4):
xl, pw = df.to_excelp(launch=False, encrypt=True, password=None)
l_xl_pw.append((xl, pw))
l_path_pass = [[xl.path, pw] for (xl, pw) in l_xl_pw]
df_path_pass = pd.DataFrame(l_path_pass, columns=['Path', 'Pw'])
# df_path_pass can also be written to Excel using .to_excelp(), how elegant! :D
ExcelHelper
也可以添加到你现有的其他脚本中。
def some_func():
df = pd.read_excel('some_file.xlsx')
# some data manipulation...
df.to_excel('some_file_modified.xlsx')
def some_func(launch=False, encrypt=True, password='5tr0ngP@ssw0rd'):
df = pd.read_excel('some_file.xlsx')
# some data manipulation...
df.to_excel('some_file_modified.xlsx')
if launch or encrypt:
xl = ExcelHelper('some_file_modified.xlsx', launch=launch, encrypt=encrypt, password=password)
return xl, xl.password
结论
重新审视自己写的旧代码就像是走在记忆的长廊上,揭示了当时我知道的少。虽然对此感到非常尴尬,但我很高兴知道自己已经有所进步。
“如果你对自己旧的代码不感到尴尬,那么你作为程序员就没有进步。” [匿名]
编写这些小类和函数需要时间,但它们的好处巨大,因为它们自动化了机械化且不太有趣的工作部分,并使人能够专注于重要任务。(想象一下,总是要考虑包含大写字母、小写字母、数字和标点符号的密码,并将它们存储在文件中。)
Python 中的监督与非监督主题建模方法
原文:
towardsdatascience.com/supervised-unsupervised-approach-to-topic-modelling-in-python-d03e0b9da1dc
从头开始在 Python 中构建主题建模管道
Vatsal
·发表在 Towards Data Science ·11 分钟阅读·2023 年 1 月 31 日
--
图片来源于 Unsplash 由 v2osk
本文将提供关于主题建模及其相关应用的高层次直观理解。将深入探讨解决需要主题建模的问题的各种方法,以及如何在监督和非监督方式下解决这些问题。我强调了数据和初始问题的重构,以便可以通过多种方法执行解决方案。下表详细说明了本文的内容。
目录
-
什么是主题建模?
-
主题建模的应用
-
监督学习与非监督学习
-
问题拆解
-
需求
-
数据
-
加载数据
-
清洗与预处理
-
数据统计
-
-
非监督学习
-
训练模型
-
可视化
-
主题分析
-
-
监督学习
-
关键词统计
-
生成标签
-
训练模型
-
评估
-
-
结论
-
资源
什么是主题建模?
主题建模是自然语言处理(NLP)或文本挖掘的一个子领域,旨在建立模型以解析各种文本,以识别与文本相关的主题。这些模型有助于识别与文档相关的大致主题,适用于大规模的文档处理。它是理解和组织大量文本数据的有用工具,并能帮助组织理解大量非结构化数据。
主题建模的应用
-
文档分类 — 将文档归类为各种主题
-
社交媒体分析 — 识别用户在社交媒体上发帖的主要话题
-
推荐系统 — 根据用户感兴趣的主题推荐产品。一个常见的应用是根据用户感兴趣的主题推荐定制广告。例如,如果用户对汽车感兴趣,他们可能会喜欢来自像本田/丰田这样有前景的汽车品牌的广告。
监督学习与非监督学习
监督学习与非监督学习之间有明确的区别。监督学习涉及在给定标签的情况下训练模型以映射到初始数据集。相反,非监督学习涉及在没有标签信息的情况下训练模型。主题建模通常是非监督学习方法,但本文将涵盖监督学习和非监督学习方法的主题建模。
监督学习方法将包括二分类。二分类是将输入数据映射到恰好 2 个目标,而多类分类是将输入数据映射到超过 2 个目标。二分类主题模型将指示输入文章是否映射到我们已标记的主题中。多类分类主题模型将识别该文章最有可能归属的主题。本文将展示二分类方法的实现。
问题分析
本文旨在解决的问题是,给定论文的摘要,识别与之相关的主要主题。根据识别出的主题,用户可以推断这篇论文是否对他们感兴趣。我们将使用 arXiv 数据库查询并获取多个领域的研究论文。
需求
以下是跟随本教程所需的模块及版本。我的环境中的 Python 版本是 3.10.0.
如果执行过程中发生错误,请注意您引用的模块的版本,因为这是跨平台协作中的常见问题。
pandas>=1.3.5
numpy>=1.22.4
arxiv>=1.4.2
Unidecode>=1.3.6
nltk>=3.7
gensim>=4.2.0
wordcloud>=1.8.2.2
pyLDAvis==2.1.2
如果您尚未安装 gensim 包,这里 是通过命令行安装它的库文档。类似地,您可以按照以下说明在 Python 中安装 arXiv 包,这里 是相关文档。
数据
根据 arXiv 的 API 使用条款,它是完全免费的,并且鼓励使用。有关使用条款的更多信息,请参考其文档,您可以在这里找到。
在本文中,我将展示如何通过 Python 访问 API 来收集构建今天模型所需的以下信息。如果你想通过其他编程语言访问这个 API,或者想了解如何使用 API 的更多信息,我强烈建议你参考他们的文档,你可以在这里找到。
加载数据
下列代码包括导入所需模块、设置在整个项目中使用的常量以及定义一个函数,通过一组提示从 arXiv 查询和加载数据。
该脚本应生成一个类似于下方截图的 pandas DataFrame:
从 arXiv 查询的数据。图像由作者提供。
清理与预处理
现在我们可以清理和预处理与每篇文章相关的摘要信息。处理文本数据的清理和预处理阶段对于优化底层模型的性能至关重要。你喂给模型的数据质量越低,模型在生产环境中的性能也会越低。此外,你清理、预处理和减少的数据量将影响模型的训练和推理时间。这将整体提高你运行的实验和生产中的性能。话题建模算法依赖于文档中词语的频率来识别模式和主题,因此任何传入的无关信息都可能扭曲结果。
我们将进行的文本预处理包括以下内容:
-
对输入数据进行 Unicode 编码。这在处理不同语言的数据时至关重要。它会将
à
转换为á
,这在清理阶段非常关键。 -
将文本转换为小写,使所有大写字符都变为小写。
我们将进行的文本清理包括以下内容:
-
移除标点符号
-
移除停用词
在删除停用词时,请注意你正在处理的数据。你想删除停用词的原因是因为它们不会提供任何新信息,并且有助于优化模型的性能。不想删除停用词的情况是当句子周围的上下文很重要时。不删除停用词对于情感分析和摘要等任务很有用。然而,对于我们的话题建模用例,我们可以继续删除停用词。
清理后的数据输出应生成一个名为cleaned_summary
的新列。结果数据集应类似于下图所示的样子。
通过清理和预处理摘要列转换初始数据集。图像由作者提供。
数据统计
现在让我们调查与清理后的数据集相关的词频分布,并识别与词频相关的潜在分布。
基于从 arXiv 中的研究论文样本的清理总结的字数分布。图片由作者提供。
基于此,我们从 arXiv 查询的 1548 篇文章中,大约有 700 篇文章少于 100 个字。这对应于 45.1% 的数据少于 100 个字。
无监督学习
我们将使用 LDA 作为 Python 中无监督学习方法的主题建模算法,用于识别研究论文的主题。LDA 是一种常见的主题建模方法,大型组织如 AWS 提供的 Comprehend
工具也使用这种方法。这种方法将基本展示 AWS 用于处理文档并以无监督方式生成每个主题的后台代码。至少这样,你就不必为此付费(除了计算成本——这取决于你处理的数据量)。
训练模型
可视化
现在我们已经拥有与我们准备训练的数据相关的模型对象。我们可以创建几个独特的可视化,这将有助于提供与模型为每篇文章识别的主题和关键词相关的见解。
我们将使用 pyLDAvis
库进行以下可视化。请注意,这个库的较新版本不支持在 JupyterNotebooks 中进行可视化。我强烈建议安装文章要求部分中提到的特定版本 2.1.2
。这个 Stack Overflow 线程突出了在不同版本中生成此可视化的难度。
LDA 主题可视化,显示每个主题的前 30 个最频繁的术语。图片由作者提供。
LDA 模型识别的主题词云。图片由作者提供。
通过上述脚本创建了 10 个词云图像,但本文仅展示了 2 个。如你所见,主题之间存在相当大的重叠(例如模型和模型等术语)。基于此,可以看出需要进一步预处理和清理,例如提取词干,去除像 use, show, first, also, may, one, number, etc...
这样的停用词。模型开发过程是一个迭代过程,但这高度突显了将高质量数据输入模型的重要性。
尽管如此,我们也可以看到,采用这种方法识别的两个主题是相当独特的。第一个主题似乎深入探讨了围绕量子计算和深度学习的内容,而第二个主题则集中于机器学习、自动化机器学习和数据。
主题分析
学习到的主题频率,阈值大于 0.3。图像由作者提供。
看起来模型训练所用的大多数文章属于第一个和第五个主题。
前 30 个词的词频。图像由作者提供。
显然,根据这些结果,似乎可以进行进一步的数据清理和预处理。由于数据、模型和模型
是最常见的术语,我们不希望模型受到这些词的影响,因为它们不够具有区分性。
与每个主题相关的顶级关键词及其文档计数。图像由作者提供。
如前两张图所证实,最常预测的主题是 5 和 0。这两个主题使用了诸如模型
和数据
之类的词汇,这些词在进一步迭代中应该被移除。这是该方法的模型开发过程的第一次迭代。投入生产的模型永远不会是你训练的第一个模型,必须利用前几次迭代模型的结果来影响未来迭代中模型所需的变化。
监督学习
主题建模的监督学习方法将包括生成主题标签来训练一个二分类模型。这可以通过识别我们感兴趣的标记和预测的主题的相关关键词来完成。我将主要关注机器学习、自然语言处理(NLP)和数学
这三个主题。
这是我为每个主题识别出的关键词集。这个列表绝不是详尽无遗,但足以作为起点。
topics_dct = {
'machinelearning': [
'machinelearning', 'clustering', 'classification', 'regression',
'supervised machine learning', 'unsupervised machine learning'
],
'mathematics': [
'mathematics', 'graph theory', 'combinatorics', 'calculus',
'linear algebra', 'probability', 'statistics', 'trigonometry',
'topology', 'differential equations', 'differentiate', 'algebra'
],
'nlp': [
'natural language', 'topic modelling', 'sentiment analysis',
'translation', 'chat bot', 'text analysis', 'text mining',
'semantic analysis', 'summarization', 'linguistic processing',
'language recognition', 'text processing', 'language models',
'linguistic', 'sequencetosequence', 'neural machine translation',
'word embeddings', 'word2vec'
]
}
我们可以解析与每篇文章相关的清理后的摘要,识别出包含我们感兴趣的关键词的摘要,并将其链接回这些关键词所映射的原始主题。这将为上述每个主题提供标签。我们可以使用 TF-IDF 将输入摘要转换为与输入模型的文章相对应的向量。
关键词统计
具有相应关键词计数的文章计数。图像由作者提供。
文章中关键词出现的频率。图像由作者提供。
生成标签
从我们处理的 1548 篇文档中,根据上述定义的主题相关关键词,这是拥有正标签的文档计数。图像由作者提供。
标签生成后的数据框。图像由作者提供。
训练模型
上述脚本将生成以下 sklearn 管道,对应于我们上面生成的清理后的摘要和标签。图片由作者提供。
评价
由于我们在训练阶段生成了一个保留集,我们现在可以将训练好的模型应用于保留集,以识别模型的性能。请注意,由于我们处理的是一个小样本数据且存在类别不平衡,训练模型很可能会过拟合。这可以通过增加我们标记的文章数量并用来训练模型来比较容易解决。这意味着我们应该查询 arXiv 以获取更大的数据集,并生成更好的关键词和标记文章。如果你处理的是不同的数据集,这可能不是一个容易解决的问题。
我也强烈建议你尝试多个分类模型,而不仅仅是梯度提升分类器。如前所述,迭代是机器学习开发周期的一个重要部分!
结论
这篇文章旨在为读者提供一个教程,旨在提供监督学习和无监督学习两种主题建模方法。我希望我能够概述在查看底层数据时思维方式的变化,以及这种变化如何影响和拓宽解决特定问题的方法。
我还希望这篇文章概述了机器学习中迭代的重要性。投入生产的模型永远不会是你训练的第一个模型,重要的是利用之前迭代模型的结果来影响未来迭代中模型所需的变更。
我希望也能清楚的是,无监督学习方法的结果可以影响监督学习方法。这也可能引发一种半监督学习方法来进行主题建模,你可以在 LDA 模型的结果上训练一个二分类模型。
如果你想下载与本教程相关的 jupyter notebook,我已经在这里提供了它。
资源
-
en.wikipedia.org/wiki/Topic_model
-
docs.aws.amazon.com/comprehend/latest/dg/topic-modeling.html
如果你喜欢今天我写的文章,这里还有一些我写的关于自然语言处理的其他文章,你可能也会喜欢!
## Python 中的文本相似度与 Levenshtein 距离
使用 Python 构建剽窃检测管道
towardsdatascience.com ## 解释 Word2Vec
解释 Word2Vec 的直觉及其在 Python 中的实现
towardsdatascience.com ## 使用 Jaro-Winkler 和 PageRank 进行 Python 文本摘要
使用 Jaro-Winkler 和 PageRank 构建文本摘要器
towardsdatascience.com ## 用 Python 识别推文情感
如何使用 Tweepy 和 Textblob 识别推文情感
towardsdatascience.com
使用 Scikit-Learn 的支持向量机:友好的介绍
原文:
towardsdatascience.com/support-vector-machine-with-scikit-learn-a-friendly-introduction-a2969f2ff00d
每个数据科学家都应该在他们的工具箱中拥有 SVM。通过实践介绍学习如何掌握这一多功能模型。
Riccardo Andreoni
·发布于 Towards Data Science ·阅读时间 9 分钟·2023 年 10 月 11 日
--
图片来源:unsplash.com。
在现有的机器学习模型中,有一种模型的多功能性使它成为每个数据科学家工具箱中的必备工具:支持向量机 (SVM)。
支持向量机(SVM)是一种强大而多功能的算法,其核心能够在高维空间中划分最优超平面,有效地隔离数据集中的不同类别。但它不仅仅止于此!它的有效性不限于分类任务:SVM 也非常适合回归和异常值检测任务。
有一个特点使得 SVM 方法特别有效。与 KNN 方法 处理整个数据集不同,SVM 战略性地只关注位于决策边界附近的数据点。这些点被称为支持向量,这种独特想法背后的数学将在接下来的部分中简单解释。
这样,SVM 算法在计算上是保守的,非常适合处理中等或甚至中大型数据集的任务。
正如我在所有文章中所做的,我不仅会解释理论概念,还会提供编码示例,以帮助你熟悉Scikit-Learn(sklearn)Python 库。
让我们分析支持向量机(SVM)算法,并探索机器学习技术、Python 编程和数据科学应用。
线性 SVM 分类
从本质上讲,SVM 分类类似于线性代数的优雅简洁。想象一个二维空间中的数据集,其中有两个需要分开的不同类别。线性 SVM 尝试用最佳的直线分隔这两个类别。
图片由作者提供。
在这个背景下,“最佳”是什么意思?SVM 寻找的是最优的分隔线:不仅能分隔类别,而且与每个类别的最近训练实例之间的距离尽可能远。这个距离称为边际。位于边际边缘的数据点是线性 SVM 分类器的关键元素,被称为支持向量。
需要优化的决策边界的方程。
需要注意的是,分隔线仅由支持向量定义。因此,添加更多“随机”训练实例对决策边界没有影响。当更多不是支持向量的训练实例被添加到训练集中时,决策边界不会移动,这些实例会被“忽略”。这一特性是 SVM 的一个巨大优势,因为它不需要记住整个训练集。
对于一个 m 维数据集,分隔线变成了分隔超平面,但激发的思想依然有效。
由于 SVM 模型基于距离,因此对特征缩放极为敏感。因此,通常建议对特征值进行标准化。
在 Scikit-Learn 中,我们可以从 .svm
模块中实例化一个 LinearSVC()
对象:
import pandas as pd
import sklearn
from sklearn import datasets
# Import the dataset
df_ = datasets.load_iris()
df = pd.DataFrame()
df['petal_length'] = df_['data'][:,2]
df['petal_width'] = df_['data'][:,3]
df['target'] = df_['target'] == 1
# Define train and test sets
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(df[['petal_length', 'petal_width']],
df['target'],
test_size=0.2,
random_state=42)
# Normalize data
scaler = sklearn.preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
# Instantiate the classifier model
svm_classifier = sklearn.svm.LinearSVC()
# Fit the model with the training data
svm_classifier.fit(X_train, y_train)
# Predict new intances classes
y_predicted = svm_classifier.predict(scaler.transform(Xtest))
# Evaluate model's accuracy
accuracy = sklearn.metrics.accuracy_score(y_test, y_predicted)
print(accuracy)
得到的准确率仅为 66%:这是一个不令人满意的结果。在下一节中,我们将了解为什么线性支持向量机分类器表现不佳以及如何应对这种情况。
软边际分类
正如你所想,拥有一个类别可以通过线性面完美分隔的数据集是一种奢侈。在现实世界中,这种情况并不会发生,数据只需要一个离群点就会使线性 SVM 无法找到可行的决策边界。
图片由作者提供。
如上图所示,在这种情况下,没有任何直线能够分开这两个类别。
为了应对实际场景,我们需要引入软边际支持向量机分类。
与传统的线性 SVM(也称为硬间隔 SVM)不同,软间隔 SVM 不要求类别之间有严格的分隔,允许一些灵活性元素。通过使用软间隔 SVM,我们承认存在噪声数据点和离群点。
更实际地说,我们可以指定一个超参数 C
,它作为正则化参数:
-
将
C
设置为较大值,我们强制执行更严格的间隔,从而减少误分类。 -
将
C
设置为较小的值,我们鼓励更宽的间隔,从而允许更多的误分类。
增加 C 值减少间隔违规。
有两种方法可以在 Scikit-Learn 中实现软间隔线性 SVM 分类器:
# Method 1
svm_classifier_soft = sklearn.svm.LinearSVC(C=10)
svm_classifier_soft.fit(X_train, y_train)
# Method 2
svm_classifier_soft = sklearn.svm.SVC(kernel='linear', C=10)
svm_classifier_soft.fit(X_train, y_train)
对于相同的数据集运行这些代码行,我们获得的准确度略高于我们为硬间隔线性 SVM 获得的准确度。尽管线性 SVM 分类器在许多场景中效率高且表现良好,但大多数数据集的类别并不是线性可分的。在这些情况下,线性决策边界会产生较差的结果。
幸运的是,SVM 分类器不仅限于线性决策边界。多亏了 核技巧,它们可以学习甚至最复杂的分隔形状。下一节将集中于此。
非线性 SVM 分类
如我所预见的,许多数据集不是线性可分的。即使我们考虑一些灵活性,通过线性分隔线获得的结果也不是最佳的。为了解决这个问题,我们可以添加更多特征,例如多项式特征。添加新特征会将原始数据集转换为更高维度的数据集,在那里它可能会被一条直线或超平面分开。
考虑这个简单的例子:
非线性可分的数据。图像由作者提供。
数据只有一个特征,且无法通过任何直线分开。如果我们添加一个人工特征,计算方式为
我们得到以下数据集,它可以被线性边界分开:
线性可分的数据。图像由作者提供。
然而,添加多项式特征,如我们上面所做的,对大型和复杂数据集来说不可行。结果特征数量会太高。
幸运的是,存在一种称为核技巧的技术,即使在高次多项式下也能实现。核技巧背后的数学并不复杂,但由于我想专注于实际实施,所以我将留给 这份指南 来解释。
要在 Python 中实现带有多项式核的 SVM 分类器,我们只需使用 SVC()
类,并指定我们打算使用的核类型及其度数:
# Instantiate the classifier object
SVM_classifier = sklearn.svm.SVC(kernel='poly', degree=5, C=10, coef0=1)
# Instantiate the scaler object
scaler = sklearn.preprocessing.StandardScaler()
# Normalize data
X_train = scaler.fit_transform(X_train)
# Train the classifier
SVM_classifier.fit(X_train, Y_train)
coef0
超参数作为正则化参数,允许控制模型如何受到高次多项式的影响。
找到 degree
、C
和 coef0
超参数之间的正确平衡并不是一项简单的任务。通常建议使用网格搜索方法来找到一些可行的值,然后进行更精细的手动搜索以精确调整它们。
相似性特征:高斯径向基函数
多项式核函数通常适用于各种机器学习问题,然而,还有一种显著的技术常常效果更佳:相似性特征。
与其基于原始特征值添加人工多项式特征,不如在我们高维特征空间中放置多个 标志 并测量它们到每个数据点的距离。这些距离度量成为 新模型的特征。
考虑一个具有两个独立特征的训练集,如下所示:
图片由作者提供。
通过在特定位置添加一定数量的标志,我们可以创建额外的特征,计算为数据点到每个标志的距离。
在这个例子中,我们可以在这些精确的位置放置两个标志。SVM 分类器可以学习将接近标志的实例预测为 A 类,将远离标志的实例预测为 B 类。
可以使用任何距离度量,然而,已证明实现 高斯核函数 是很方便的:
其中:
-
x 是包含数据点坐标的向量
-
l 是包含标志坐标的向量
-
gamma 作为正则化超参数
现在的问题是:“我应该将标志放在哪里?”。最常见的方法是在每个训练样本的位置放置一个标志。拥有 m 个训练样本意味着创建 m 个标志,从而产生 m 个新特征。
这种方法的缺点是对于大规模训练集,我们最终会得到同样大量的距离特征。
在 Scikit-Learn 中,带有高斯核的 SVM 分类器实现如下:
import pandas as pd
import sklearn
from sklearn import datasets
# Import the dataset
df_ = datasets.load_iris()
df = pd.DataFrame()
df['petal_length'] = df_['data'][:,2]
df['petal_width'] = df_['data'][:,3]
df['target'] = df_['target'] == 1
# Define train and test sets
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(df[['petal_length', 'petal_width']],
df['target'],
test_size=0.2,
random_state=42)
# Normalize data
scaler = sklearn.preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
# Instantiate the classifier model
svm_gaussian_classifier = sklearn.svm.SVC(kernel='rbf', gamma=6, C=0.001)
# Fit the model with the training data
svm_gaussian_classifier.fit(X_train, y_train)
# Predict new intances classes
y_predicted = svm_gaussian_classifier.predict(scaler.transform(X_test))
# Evaluate model's accuracy
accuracy = sklearn.metrics.accuracy_score(y_test, y_predicted)
结论
在这篇文章中,我们详细介绍了这一多功能且强大的模型的理论,并了解了通过 Scikit-Learn 库在 Python 中实现它是多么简单。
我通过阐述支持向量机模型的优缺点来总结这篇入门指南。
毫无疑问,SVM 的优势在于其高维空间中的准确性,使其非常适合特征众多的数据,如图像或基因数据。此外,我必须强调 SVM 的多样性:除了分类,它还可以无缝地适应回归和异常检测任务。最后,我必须指出其灵活性。SVM 通过不同的核函数处理非线性关系的能力,赋予了它广泛的应用范围。
众所周知,“机器学习没有免费的午餐”这一箴言,我们必须承认 SVM 的缺点。首先,SVM 容易受到噪声数据的影响,正如我们在线性 SVM部分所见。可扩展性也是 SVM 的一个致命弱点:用 SVM 训练大规模数据集可能会计算密集且要求高。
在完成本指南时,我需要指出,尽管已经涉及了 SVM 的所有最相关的方面,但其领域远超本指南所提供的见解。因此,我建议深入研究附在本文后的资源和参考文献。
如果你喜欢这个故事,考虑关注我,以便及时了解我即将发布的项目和文章!
这是我过去的一些项目:
使用 NetworkX 进行社交网络分析:友好入门
了解像 Facebook 和 LinkedIn 这样的公司如何从网络中提取洞察
集成学习与 Scikit-Learn:友好入门
像 XGBoost 或随机森林这样的集成学习算法是 Kaggle 竞赛中的顶尖模型之一……
深度学习生成幻想名字:从零构建语言模型
一个语言模型能否创造独特的幻想角色名字?让我们从零开始构建它
深度学习生成幻想名字:从零构建语言模型
参考文献
-
使用 Python 进行机器学习
-
《动手学机器学习:基于 Scikit-Learn、Keras 和 TensorFlow 的实践(第 2 版)》 — 奥雷利安·热龙
-
Scikit-Learn SVM 文档
-
机器学习专项课程
数据讲故事的辅助材料
原文:
towardsdatascience.com/supporting-pieces-for-data-storytelling-a0b5c69a476c
从原始数据到引人入胜的叙事
掌握数据讲故事艺术和通过引人注目的可视化和设计吸引受众的指南
Richmond Alake
·发表于Towards Data Science ·7 分钟阅读·2023 年 3 月 27 日
--
向观众展示——作者通过 Midjourney 制作的图像
在之前的文章中,我探讨了数据讲故事及相关主题,获得了读者的积极关注。像数据领域中的大多数事物一样,还有更多内容需要覆盖。
讲故事的艺术不仅仅是提供背景、识别争议和呈现解决方案;这只是基础——一个非常坚实的基础。
本文进一步深入探讨了如何利用创意和表现力材料制作完整而引人入胜的数据驱动叙事。
我写这篇文章的目的是为你,即 AI/数据/机器学习从业者,提供提升数据沟通、展示和讲故事能力的见解。
我还将向你介绍一些工具和库,这些工具和库可以在各种格式中促进数据可视化,例如动态和静态图表。总的来说,本文提到的技巧、工具和材料对于有效提升你通过视觉沟通传达数据和信息的能力至关重要。
数据可视化
人脑在从图像中快速提取信息方面优于从文本表示的相同数据。
坦率地说,我对阅读感到厌倦。逐行或逐句处理信息可能会很枯燥,特别是在时间紧迫的情况下。当然,信息的文字密集型展示还是有其必要的。但当涉及将数据中的模式和见解传达给关键利益相关者时,使用这种方法会更好。
使用图像。
观察图像时,你会同时处理大量信息,而不必从左上角开始并向右移动。数据的视觉表示利用了人类最佳的感官输入和最高的消耗带宽:视觉感知。
人脑在从图像中快速提取信息方面优于从相同数据的文本表示中提取信息。这就是为什么数据可视化在用数据讲故事时发挥着至关重要的作用。
如果你仅依靠原始数值数据而没有可视化元素来创建演示文稿,你会迅速失去观众的注意力。将可视化元素纳入演示文稿对于整体参与度至关重要 —— 也是为了保持理智。
在检查一些示例之前,让我们首先定义数据可视化。
数据可视化是将数据呈现为替代表示形式,以便通过使用文本、图像、图表、图形等视觉组件来识别和理解模式、趋势和洞察。
数据从业者职位的表格数据 — 图片由作者提供
参考上面的图像,它展示了 Google 趋势中 Data Analyst、Data Scientist、Data Engineer、ML Engineer 和 MLOps 等职位的表格表示。假设提取这些数据的目的是为了向年轻学习者介绍最受欢迎的 AI/数据相关职位。你可能会认为,纯数字/文本数据的表示无法有效传达预期的信息。简单来说,通过观察数据的表格表示而不进行一些统计分析,很难确定哪些职位是需求量大的。
通过使用 Flourish 的数据可视化和讲故事工具,我们可以将这些数据转化为更具吸引力和信息量的格式。
数据的视觉版本呈现了过去几十年中哪些 AI/数据相关职业的需求情况。尽管最初的表格数据是结构化的,但其复杂性可能会影响对关键模式、趋势和信息的理解。
以下是要点:数据可视化在适当应用时,可以使理解变得更加容易,帮助利益相关者在用数据讲故事时把握洞察和智能。作为额外的好处,Flourish 使数据展示中动态动画的集成成为可能。请参考这个 链接 体验上述数字数据背后的整个故事。
AI/Data 从业者职位趋势 — 图片由作者提供
让我们重点关注本节的关键要点:
-
使用图表、图像和动画可视化数据,将复杂数据转化为易于理解的格式,使观察者能够迅速提取关键模式和信息。
-
为此,开发了多种工具和库,包括 Matplotlib、Seaborn、Tableau、D3 和 Plotly。
-
对于 AI/数据从业者而言,以多种格式展示和沟通数据发现至关重要,因为这提升了他们作为团队、公司或组织中的跨职能资产的价值。
-
这一技能在技术能力不是每个人所需的组织中特别重要。
-
从业者可以通过有效展示数据洞察来弥合技术与非技术利益相关者之间的差距。这意味着每个人都可以理解和利用数据分析得出的信息。
支持材料和文档
小册子、材料和文档 — 作者图像来自 Midjourney
虽然每次可能无法做到,但有效实施时,可以显著提高信息的吸收。
在数据讲述中,超越常规是什么样的?给人们一部分可以触摸、感受并带走的数据。是的,这可能不是每次都能做到,但当有效执行时,它显著提高信息的吸收。
所以,让我们动手试试。
将触觉融入数据讲述并不常见,因为实际操作、成本和后勤挑战。比如,为超过 100 位利益相关者或与会者提供宣传册可能耗时且昂贵。
然而,在适当的时候,利用“硬”材料如纸质文档、宣传册或上述任何示例,可以作为秘密武器,为您的数据讲述能力增加“wow”因素。 通过调动多种感官,您可以为观众创造更身临其境和难忘的体验,进一步增强数据驱动叙事的影响力。
以下是一些将“硬”材料纳入演示的示例:
-
附带用户手册为利益相关者提供了一个信息包,他们可以在演示过程中及其后进行阅读。用户手册的目的是为演示内容提供额外的理解,包括定义、可视化、组件、网络链接等。
-
有效分发打印版宣传册提供关于外部组织或机构的信息,而不会偏离演示或数据的主要观点。
-
物理设备或装备,甚至原型,可以在向利益相关者展示专有硬件解决方案时,弥合想象与现实之间的差距。
-
宣传册和小册子可以强调和具体化关键指标、数据点以及供利益相关者记住的关键要点。
设计考虑
Canva 上的色彩斑斓 — 作者通过 Midjourney 提供的图片
不要害怕像 UI 设计师一样思考。
将设计考虑融入根植于严格逻辑和系统思维的技术学科可能具有挑战性,但却是可能的。作为 AI/数据从业者,我们的技术责任优先,但在数据展示和讲故事时,设计考虑同样不可忽视。
黑白展示已不再足够。
到目前为止,我们已了解数据从业者创建的材料必须吸引人的视觉感官。这种方法的一个重要好处是利益相关者能够快速消化信息,特别是当选择合适的颜色来突出特定数据点和引导关注时。简单来说,如果看起来好,内容可能也会好。
通过选择适当的色彩方案和创建视觉上令人愉悦的演示文稿,你正在微妙地邀请利益相关者接受你的信息,这提升了你有效沟通数据洞察力的能力。
你可以通过考虑颜色和空间来做到这一点。不要害怕像 UI 设计师一样思考。
颜色
这里没有什么太复杂的,关键在于人们对特定颜色的反应不同,对颜色的情感反应的研究属于颜色心理学的范畴。
例如,红色用于激发强烈的情感,橙色有时代表兴奋和能量,这也是为什么麦当劳、OpenTable 等食品品牌使用红色的原因。在数据可视化中,红色用于传达负面的模式、情感和动态。
绿色代表自然和动力,因此在生产力应用中是一个受欢迎的选择。另一个有趣的例子是生动的蓝色,它引发一种柔和、平静的情感。在数据可视化中,绿色用于传达积极的模式。
许多社交网络品牌,如 LinkedIn、Twitter 和 Facebook,利用不同的蓝色调,正念和冥想应用如 Calm 也融入了这种颜色的平静本质。
空间
考虑传达内容的结构至关重要。融入空间在信息结构化以突出关键点方面发挥着关键作用。例如,通过有意间隔演示组件,可以控制观众对演示和数据特定区域的关注。
另一种看法是,杂乱无章可能会阻碍数据讲故事的影响力。
故事中的重要方面包括关键的信息、议程或要点——你希望观众记住并带走的内容。利用空间结构可以使关键信息成为关注的焦点。通过去除支撑数据讲故事的演示和可视化中的冗余内容,你可以实现简洁性和全面性。去掉任何没有增加价值的多余信息。
摘要
讲故事是一种具有表现力的沟通艺术。当与数据的揭示性质结合时,它创造了一种媒介,使你能够改变人们的观点,推动议程,有效沟通,赢得项目等。这一切都是说,精通数据讲故事这一创造性技能的技术从业者可以获得只有少数人才享有的好处。
未来,我打算继续写技术文章,但加入轶事、个人经历和更具表现力的图像将增强我试图传达给读者和你的教训。
感谢阅读。
希望你觉得这篇文章对你有帮助。
若要与我联系或找到更多类似的内容,请执行以下操作:
-
支持我的写作成为推荐的 Medium 会员
-
订阅我的 YouTube 频道
-
订阅我的播客 Apple Podcasts | Spotify | Audible
-
订阅我的 邮件列表 获取我的新闻简报
生存分析:利用深度学习进行事件时间预测(第二部分)
原文:
towardsdatascience.com/survival-analysis-leveraging-deep-learning-for-time-to-event-forecasting-5c55bd4bb066
作者插图
实际应用于再住院
Lina Faik
·发布于 Towards Data Science ·10 分钟阅读·2023 年 4 月 21 日
--
生存模型非常适合预测事件发生的时间。这些模型可以应用于各种用例,包括预测维护(预测机器可能故障的时间)、营销分析(预测客户流失)、患者监测(预测患者可能会再度住院)等。
通过将机器学习与生存模型结合,结果模型可以利用前者的高预测能力,同时保留后者的框架和典型输出(例如生存概率或随时间变化的风险曲线)。欲了解更多信息,请查看本系列的第一篇文章 这里。
然而在实践中,基于机器学习的生存模型仍然需要广泛的特征工程,因此需要先验的业务知识和直觉才能取得令人满意的结果。那么,为什么不使用深度学习模型来弥补这一差距呢?
目标
本文重点介绍了如何将深度学习与生存分析框架结合,以解决例如预测患者(再)住院可能性的用例。
阅读本文后,你将了解:
-
深度学习如何应用于生存分析?
-
在生存分析中,常见的深度学习模型有哪些,它们是如何工作的?
-
这些模型如何具体应用于住院预测?
本文是关于生存分析系列的第二部分。如果你对生存分析不熟悉,最好先阅读第一部分 这里**。文章中描述的实验使用了库 scikit-survival、 pycox* 和* plotly。你可以在 GitHub上找到代码。
[## GitHub - linafaik08/survival_analysis
作者:莉娜·法伊克 创建日期:2023 年 2 月 最后更新:2023 年 4 月 本仓库包含代码和笔记本……
github.com
1. 生存分析与深度学习:它们如何结合?
1.1. 问题陈述
让我们首先描述一下手头的问题。
我们感兴趣的是根据可用的健康状态信息预测某一患者重新入院的可能性。更具体地说,我们希望在最后一次就诊后的不同时间点估计这种概率。这种估计对监测患者健康和降低复发风险至关重要。
这是一个典型的生存分析问题。数据包括三个元素:
患者的基线数据包括:
-
人口统计学:年龄、性别、所在地(乡村或城市)
-
患者历史:吸烟、饮酒、糖尿病、高血压等。
-
实验室结果:血红蛋白、总淋巴细胞计数、血小板、葡萄糖、尿素、肌酐等。
-
关于源数据集的更多信息请见 这里。
一个时间 t 和一个事件指示器 δ∈{0;1}:
-
如果事件在观察期间发生,t 等于从数据收集时刻到观察到事件(即重新入院)时刻之间的时间,这种情况下,δ = 1。
-
如果没有发生事件,t 等于从数据收集时刻到最后一次接触患者(例如,研究结束)之间的时间,这种情况下,δ = 0。
图 1 — 生存分析数据,作者插图。注意:患者 A 和 C 被删失。
⚠️ 使用生存分析方法的原因是什么,尤其是当问题如此类似于回归任务时?初始论文很好地解释了主要原因:
“如果选择使用标准回归方法,右删失数据会变成一种缺失数据。通常会被删除或填补,这可能会引入模型偏差。因此,建模右删失数据需要特别注意,因此需要使用生存模型。” 来源 [2]
1.2. DeepSurv
方法
让我们进入理论部分,稍微复习一下风险函数。
“风险函数是指一个个体在已经存活到时间 t 的条件下,额外微小时间δ内未能存活的概率。因此,更大的风险意味着更大的死亡风险。”
来源 [2]
与 Cox 比例风险(CPH)模型类似,DeepSurv 基于以下假设:风险函数是两个函数的乘积:
-
基线风险函数:λ_0(t)
-
风险评分,r(x)=exp(h(x))。它建模了给定个体在观察到的协变量下,风险函数如何相对于基线变化。
更多关于 Cox 比例风险(CPH)模型的信息,请参见本系列的第一篇文章。
函数 h(x)通常被称为对数风险函数。这正是 DeepSurv 模型旨在建模的函数。
实际上,CPH 模型假设h(x)是一个线性函数:h(x) = β . x。拟合模型的过程就是计算权重β以优化目标函数。然而,线性比例风险假设在许多应用中并不成立。这就证明了需要一个更复杂的非线性模型,理想情况下能够处理大量数据。
架构
在这种情况下,DeepSurv 模型如何提供更好的替代方案?让我们先来描述一下它。根据原始论文,它是一个“深度前馈神经网络,通过网络权重θ预测患者协变量对其风险率的影响。” [2]
它是如何工作的?
‣ 网络的输入是基线数据 x。
‣ 网络通过多个隐藏层进行输入传播,隐藏层具有权重θ。这些隐藏层包含全连接的非线性激活函数,随后是 dropout。
‣ 最后一层是一个节点,对隐藏特征进行线性组合。网络的输出被视为预测的对数风险函数。
来源 [2]
图 2 — DeepSurv 架构,作者插图,灵感来源于[2]
由于这种架构,模型非常灵活。通常使用超参数搜索技术来确定隐藏层的数量、每层中的节点数、丢弃概率以及其他设置。
那么,优化的目标函数是什么?
-
CPH 模型被训练以优化 Cox 部分似然。它包括计算每个患者i在时间Ti的事件发生概率,考虑到在时间Ti时仍处于风险中的所有个体,然后将所有这些概率相乘。您可以在这里找到确切的数学公式[2]。
-
类似地,DeepSurv 的目标函数是相同部分似然的对数负均值,加上一个用于正则化网络权重的附加部分。[2]
代码示例
下面是一个小的代码片段,以了解如何使用pycox库实现这种类型的模型。完整代码可以在该库的笔记本示例中找到这里 [6]。
# Step 1: Neural net
# simple MLP with two hidden layers, ReLU activations, batch norm and dropout
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = 1
batch_norm = True
dropout = 0.1
output_bias = False
net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
dropout, output_bias=output_bias)
model = CoxPH(net, tt.optim.Adam)
# Step 2: Model training
batch_size = 256
epochs = 512
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True
model.optimizer.set_lr(0.01)
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
val_data=val, val_batch_size=batch_size)
# Step 3: Prediction
_ = model.compute_baseline_hazards()
surv = model.predict_surv_df(x_test)
# Step 4: Evaluation
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
ev.concordance_td()
1.3. DeepHit
方法
如果我们可以训练一个深度神经网络直接学习生存时间的分布,而不是对其做出强假设呢?
这就是 DeepHit 模型的情况。特别是,它相对于以前的方法带来了两个显著改进:
-
它不依赖于任何关于基础随机过程的假设。因此,网络学会了建模协变量与风险之间随时间演变的关系。
-
它可以通过多任务学习架构处理竞争风险(例如,同时建模再住院和死亡的风险)。
架构
如此处所述[3],DeepHits 遵循多任务学习模型的常见架构。它由两个主要部分组成:
-
共享子网络,其中模型从数据中学习到对所有任务有用的通用表示。
-
任务特定的子网络,其中模型学习到更具任务特定性的表示。
然而,DeepHit 模型的架构与典型的多任务学习模型在两个方面有所不同:
-
它包括一个残差连接,将初始协变量与任务特定子网络的输入连接起来。
-
它只使用一个 softmax 输出层。得益于此,模型学习的是竞争事件的联合分布,而不是边际分布。
下图展示了模型在两个任务上同时训练的情况。
DeepHit 模型的输出是每个受试者的向量y。它给出受试者在观察时间内每个时间戳t发生事件 k ∈ [1, 2]的概率。
图 3 — DeepHit 架构,作者插图,灵感来源于源[4]
2. 用例应用:这些模型在实际中的表现如何?
2.1. 方法论
数据
数据集被分成三部分:训练集(占数据的 60%),验证集(20%),和测试集(20%)。训练集和验证集用于在训练期间优化神经网络,测试集用于最终评估。
基准
深度学习模型的表现与包括 CoxPH 和基于 ML 的生存模型(梯度提升和 SVM)在内的基准模型进行比较。有关这些模型的更多信息,请参见系列的第一篇文章。
指标
两种指标用于评估模型:
-
Concordance index (C-index):它衡量模型根据个体风险评分提供可靠生存时间排名的能力。它的计算方法是数据集中一致对的比例。
-
Brier 分数:这是均方误差对右删失数据的时间依赖扩展。换句话说,它表示观察到的生存状态与预测的生存概率之间的平均平方距离。
2.2. 结果
从 C-index 的角度来看,深度学习模型的表现明显优于基于 ML 的生存分析模型。此外,Deep Surval 和 Deep Hit 模型的表现几乎没有差异。
图 4 — 模型在训练集和测试集上的 C-Index
从 Brier 分数的角度来看,Deep Surv 模型脱颖而出。
- 在考察 Brier 分数随时间变化的曲线时,Deep Surv 模型的曲线低于其他模型,这反映了更好的准确性。
图 5 — 测试集上的 Brier 分数
- 考虑到在相同时间区间内分数的积分,这一观察结果得到了确认。
图 6 — 测试集上的集成 Brier 分数
请注意,Brier 分数未计算 SVM,因为该分数仅适用于能够估计生存函数的模型。
图 7 — 使用 DeepSurv 模型的随机选取患者的生存曲线
最终,深度学习模型不仅可以用于生存分析,也可以用于统计模型。例如,在这里,我们可以看到随机选择的患者的生存曲线。这些输出可以带来许多好处,特别是能够更有效地跟踪最有风险的患者。
关键要点
✔️ 生存模型对于预测事件发生所需的时间非常有用。
✔️ 它们可以通过提供学习框架和技术以及有用的输出(如生存概率或随时间变化的危险曲线)来帮助解决许多使用案例。
✔️ 在这种使用案例中,它们甚至是不可或缺的,以利用所有数据,包括被删失的观察(例如,当事件在观察期内未发生时)。
✔️ 基于机器学习的生存模型往往比统计模型表现更好(更多信息见这里)。然而,它们需要基于坚实的业务直觉进行高质量的特征工程,以实现令人满意的结果。
✔️ 这就是深度学习可以弥补差距的地方。基于深度学习的生存模型如 DeepSurv 或 DeepHit 有可能以更少的努力取得更好的表现!
✔️ 然而,这些模型也不是没有缺点。它们需要一个大型数据库进行训练,并且需要对多个超参数进行微调。
参考文献
1 Bollepalli, S.C.; Sahani, A.K.; Aslam, N.; Mohan, B.; Kulkarni, K.; Goyal, A.; Singh, B.; Singh, G.; Mittal, A.; Tandon, R.; Chhabra, S.T.; Wander, G.S.; Armoundas, A.A. 一种优化的机器学习模型准确预测入院心脏病科的住院结果。Diagnostics 2022, 12, 241。
2 Katzman, J., Shaham, U., Bates, J., Cloninger, A., Jiang, T., & Kluger, Y. (2016). DeepSurv: 个性化治疗推荐系统使用 Cox 比例风险深度神经网络,ArXiv
[3] Laura Löschmann, Daria Smorodina, 用于生存分析的深度学习,信息系统研讨会(WS19/20),2020 年 2 月 6 日
[4] Lee, Changhee 等人 DeepHit:一种深度学习方法用于处理竞争风险的生存分析。 AAAI 人工智能大会(2018)。
[5] 维基百科,比例风险模型
[6] Pycox 库
生存分析:用机器学习预测事件时间(第一部分)
原文:
towardsdatascience.com/survival-analysis-predict-time-to-event-with-machine-learning-part-i-ba52f9ab9a46
作者插图
客户流失预测的实际应用
Lina Faik
·发布于 Towards Data Science ·11 分钟阅读·2023 年 2 月 9 日
--
预测事件发生的概率很好,预测事件发生前的剩余时间更好!
以客户流失为例。如果不是预测客户在接下来的几个月内离开公司的概率,而是能够在接下来的几个月中多个时间点预测这一概率,这种方法的好处是显而易见的。它将使你能够更有效地预测和优先考虑营销行动,最终减少流失率。
这正好属于生存分析的领域,也称为事件时间分析。它指的是一种学习框架和一套技术,用于根据观察估计某一感兴趣事件发生的时间。
生存分析的名称来源于其首次应用的典型用例:预测临床研究中的死亡时间。然而,不应被其名称误导:它并不限于医学领域,还可以应用于多个行业的用例。随着数据科学的最新进展,生存分析已从经典统计学领域重新出现,纳入了更先进的机器学习方法。
目标
这篇文章集中讲述了如何将机器学习与生存分析框架结合,解决诸如预测流失等用例。
阅读完这篇文章后,你将了解:
-
生存分析到底是什么?
-
主要的生存模型是什么,它们是如何工作的?
-
这些模型如何具体应用于流失预测?
本文是围绕生存分析系列的第一部分。理解本文不需要任何先前知识。文章中描述的实验使用了 scikit-survival 和 plotly库进行。你可以在 GitHub 上找到代码 这里 。
1. 生存分析是关于什么的?
1.1. 问题陈述
乍一看,生存分析可能只是另一个回归问题,因为目标是预测事件发生的时间(一个连续变量)。然而,问题有一个转折:部分训练数据可能是部分观察到的——它是被审查的。
为了说明这一点,我们以一个按订阅方式提供服务的公司为例。该公司希望预测每个向支持部门求助的客户在一段时间内取消订阅的概率。在数据收集期间:
-
客户 A 在研究结束前没有取消订阅。
-
客户 B 和 D 在几个月后取消了他们的订阅。
-
客户 C 决定限制平台对其数据的访问。
在这种情况下,客户 A 和 C 的记录是审查的。
图 1 — 生存分析数据,由作者插图
更正式地说,每个观测包含一组协变量 X = (x_1, …, x_n),事件发生的时间 t,或审查时间 c>0。我们引入一个事件指示符 δ∈{0;1}。右审查样本的可观察时间 y 定义为:
在我们预测客户流失的案例中,数据包括客户的支持联系历史。每个观测包含以下信息:
-
互动:日期、呼叫原因(注册/支持)和渠道(电子邮件/电话)。
-
客户:年龄和性别。
-
订阅:产品、价格、计费周期(月度/年度)、注册。
数据通过额外特征进行了丰富,包括客户过去联系公司的次数、自客户订阅以来的时长以及周期性日期相关特征。
1.2. 用例
生存分析可以用于广泛的应用场景,其中目标是预测两个事件之间的时间。以下是一些其他示例:
-
预测性维护: 预测机器在开启后何时可能会发生故障。如果机器因外部因素(如火警警报等)不得不停止,数据可能会被审查。
-
患者监测: 预测患者在首次诊断或住院后何时可能再次住院。如果患者离开了研究的地理范围,数据可能会被审查。
-
市场营销分析: 预测潜在客户从第一次通话开始的转化时间。如果在观察期间个人去世,数据可能会被删失。
-
经济学: 预测被裁员的人找到工作的时间。如果一个人退出了研究,数据可能会被删失。
2. 应对生存案例的常见方法是什么?
2.1. Kaplan-Meier 估计器
解决这个问题最简单的方法之一是使用 Kaplan-Meier 估计器。这是一种非参数方法,专注于逼近生存函数。在检查应用案例的结果之前,我们先讨论一下基本的理论。
生存函数
生存函数 S(t) 表示一个被试在时间 t 之后生存的概率,或者类似地,持续时间至少等于 t 的概率。
其中 T 是从研究人群中取出的随机寿命。
S 在时间 t=0 时从 1 开始,因为在开始时没有被试经历事件。它减少并趋近于 0,因为每个人都可能在某个时刻经历感兴趣的事件。
Kaplan-Meier 估计器
为了逼近生存函数,Kaplan-Meier 模型将估计过程分解为小步骤。对于每个区间,概率计算如下:
其中 n_i 是在时间点 t_i 处面临风险的个体数量,d_i 是在时间点 t_i 处经历事件的个体数量。
这是一种非常简单的方法,不考虑协变量。
-
它可以作为一个简单的基线模型使用。
-
它也可以作为数据探索方法。在这种情况下,它提供了整个群体生存函数的概述,或帮助比较某些群体之间的差异。
例如,在我们的案例研究中,我们可能会根据订阅计费周期比较估计值。下面的图表确认了这样的直觉:每月订阅的客户更具波动性。他们在订阅后的头几年内更频繁、更快地流失。
图 2 — Kaplan-Meier 模型估计的生存函数
2.2. Cox 比例风险模型
最广泛使用的估计器无疑是 Cox 比例风险(Cox PH)模型。它易于实现,考虑了协变量,并提供了可解释的结果。这是一种半参数方法,旨在建模风险函数。
风险函数
风险函数 h(t) 表示在时间 t 发生死亡事件的概率,前提是被试在时间 t 之前没有经历死亡事件。
因此,风险函数在找到最安全或最风险的时间段方面非常有用。
Cox 比例风险
CoxPH 模型的风险函数如下:
作者插图
模型由两部分组成:
-
基线风险:它描述了风险随时间的演变。
-
风险比:它建模了解释变量对风险的影响。
使用这个参数函数,模型依赖于一个强大的比例假设:在某一时间点,主体的风险函数与基线或其他主体保持相同比例。
- 例如,如果客户在初始观察时的流失风险是另一位客户的两倍,那么在所有后续时间观察中,流失风险仍然是两倍。
应用
模型的输出具有很高的解释性。
在实例级别,模型为每个观察提供:
-
风险评分:风险越高,客户取消订阅的可能性就越大。
-
生存函数:它使分析师能够评估在时间点 t 之前生存的概率。例如,下面的图表显示,客户 2 最有可能在前几天内流失,而客户 1、3 和 4 则没有风险。
图 3 — 5 位随机选择的客户的生存函数
- 风险函数:它具有相同的目的。下面的图表确认了之前的结论。
图 4 — 5 位随机选择的客户的累积风险函数
在全球级别,模型可以通过其系数进行解释(见上方公式)。对于正系数,系数越高,对流失风险的影响越强。
例如,下面的图表显示模型通常将寻求支持的联系视为流失的风险因素。
图 5 — 由 Cox PH 模型获得的系数
3. 机器学习如何用于生存分析?
3.1. 基于机器学习的生存模型
在比较我们关于流失预测的案例研究中的模型性能之前,我们先来了解生存分析中的机器学习模型的基本原理。
随机生存森林
就像标准的随机森林一样,随机生存森林的核心在于对数据集的多个子样本(通常带有替换地抽取)训练若干生存树,并使用平均化来提高预测准确性和限制过拟合。
主要区别在于用于评估分裂质量的指标:log-rank,它通常用于比较两个或多个组之间的生存曲线。
有关模型的更多信息可以在这里找到。
梯度提升生存分析
应用于生存分析的梯度提升也非常相似:它通过将多个基本学习器的预测以加法方式结合起来,从而获得一个强大的整体模型。基本学习器,也称为弱学习器,通常是非常简单的模型。与随机森林不同的是,生存树不是独立训练的,而是以贪婪的阶段性方式顺序训练的。
该模型是一个非常通用的框架:它可以优化许多损失函数,包括:
-
纽曼的部分似然损失
-
平方回归损失
-
删失加权最小二乘误差。
这种损失允许模型通过一个常数因子来加速或减缓事件的发生时间。这被称为加速失效时间(AFT)。与仅特征影响风险函数的 Cox 比例风险模型不同。
关于该模型的更多信息可以在这里找到。
3.3. 生存支持向量机
生存支持向量机(SVM)也可以扩展到生存分析。它也是一个非常通用的模型,因为它可以通过所谓的核技巧来考虑特征与生存之间的复杂非线性关系。
然而,其预测不能很容易地与生存分析的标准量相关联,即生存函数和累积风险函数。
关于该模型的更多信息可以在这里找到。
3.2. 比较
方法论
为了比较模型的性能,初始数据集包含大约 320,000 个观测值,被划分为两个集合:一个训练集(70%)和一个验证集(30%),它们具有相同的删失分布。模型通过 5 折交叉验证进行训练和微调,然后在验证集上进行评估。
一致性指数
最常用的评估指标是一致性指数,也称为 c 指数。它衡量模型根据个体风险评分提供可靠的生存时间排名的能力。它的计算方法是数据集中一致对的比例。
更具体地说,我们来考虑两个观测值(i,j):
-
首先,为了进行比较,较低时间的观测值需要经历事件。
-
其次,如果可比的话,当生存模型估计的风险对生存时间较短的个体更高时,它就是一致的。
下图显示了模型在 5 折交叉测试和验证集上的结果。梯度提升是表现最好的模型,在 5 折交叉测试和验证集上的一致性指数均约为 0.70。
图 6— 模型的一致性指数
一致性指数计算和解释都很简单。然而,它有两个主要缺点:
-
当审查量增加时,它往往过于有利。另一种方法是使用逆审查概率权重。审查分布是通过对训练数据应用 Kaplan-Meier 估计器获得的。
-
当主要目标是衡量特定时间段内的表现时(例如预测在订阅的第一年内的流失),这并不是很有用。
第二个缺点可以通过使用其他指标如累计/动态 AUC 来克服。
累计/动态 AUC
众所周知的接收器操作特征曲线(ROC 曲线)可以扩展到审查生存时间。其思路是考虑几个时间点。在每个时间点,我们分别考虑:
-
累计病例:在时间 t 之前或时间 t 发生事件的所有个体。
-
动态控制:那些将在某时间点后经历事件的个体。
我们可以评估模型在区分会经历事件的个体(灵敏度)和不会经历事件的个体(特异性)方面的能力。
使用这种方法,可以仅在上下文中最重要的时间点上评估估计器(例如前两年)。
下图展示了模型在验证集上的结果。梯度提升依然是表现最好的模型,在 2 年期间的平均 AUC 约为 0.80。虽然随机森林和 Cox PH 的平均表现相似,但在订阅的前几个月,它们远远落后于梯度提升。
图 7— 模型在验证集上随时间变化的 AUC
关键要点
✔️ 生存分析指的是一种学习框架和一系列技术,用于根据观察数据估计事件发生所需的时间。
✔️ 这不仅仅是一个简单的回归预测问题,因为部分训练数据可能是部分观察到的——它是被审查的。
✔️ 常见的机器学习模型,如随机森林、梯度提升或 SVM,可以扩展到生存分析,从而得到更好的且仍可解释的模型。
✔️ 将生存分析框架与机器学习的预测能力相结合,可以为包括预测性维护、患者监测、营销分析、经济学等在内的广泛应用带来显著的商业价值。
参考文献
1 使用 scikit-survival 的生存分析介绍
2 Scikit survival 文档
[3] 维基百科,比例风险模型
[4] 维基百科,肯德尔秩相关系数
[5] 劳拉·勒施曼,达里亚·斯莫罗迪纳,生存分析中的深度学习,洪堡大学,2020 年 2 月
适者生存:紧凑型生成式 AI 模型是规模化成本效益 AI 的未来
原文:
towardsdatascience.com/survival-of-the-fittest-compact-generative-ai-models-are-the-future-for-cost-effective-ai-at-scale-6bbdc138f618?source=collection_archive---------6-----------------------#2023-07-25
图片来源:Adobe Stock。
支持灵活、针对性强的检索式模型作为规模化部署生成式 AI 应用的最佳解决方案。
Gadi Singer
·
关注 发表在 Towards Data Science ·18 分钟阅读·2023 年 7 月 25 日
--
在人工智能 (AI) 模型复杂性和计算能力快速增长了十年之后,2023 年标志着对效率和生成 AI (GenAI) 广泛应用的关注转变。因此,一批参数少于 150 亿的新型模型被称为灵活 AI,可以在特定领域中接近 ChatGPT 风格的大型模型(参数超过 1000 亿)的能力。尽管 GenAI 已经在各行业广泛应用于各种商业用途,但紧凑且智能的模型使用率正在上升。在不久的将来,我预计将会有少量的大型模型和大量的小型、更灵活的 AI 模型嵌入到无数应用中。
尽管较大的模型取得了巨大进展,但在训练和环境成本方面,大型模型未必更好。TrendForce 估计,仅 GPT-4 的 ChatGPT 训练费用就超过 1 亿美元,而灵活模型的预训练成本则低得多(例如,MosaicML 的 MPT-7B 的预训练费用约为 20 万美元)。大部分计算成本发生在持续的推断执行过程中,但这与大型模型面临的类似挑战有关,包括高昂的计算费用。此外,托管在第三方环境中的大型模型会带来安全性和隐私问题。灵活模型的运行成本大大降低,并提供了额外的好处,如适应性、硬件灵活性、在更大应用中的集成性、安全性和隐私、可解释性等(见图 1)。对较小模型表现不如大型模型的看法也在改变。较小的、针对性的模型并不缺乏智能——它们可以在商业、消费和科学领域提供等效或更优的性能,增加了其价值,同时减少了时间和成本投资。
越来越多的这些灵活模型大致匹配了 ChatGPT-3.5 级别的大型模型的性能,并且在性能和范围上持续快速提升。而且,当灵活模型配备了即时检索策划的特定领域私人数据和基于查询的网络内容有针对性的检索时,它们比记忆广泛数据集的大型模型更加准确且更具成本效益。
图 1. 灵活 GenAI 模型的好处。图片来源:Intel Labs。
随着灵活的开源 GenAI 模型不断推进领域的快速发展,这一“iPhone 时刻”——当一种革命性技术变得主流——正受到“Android 革命”的挑战,因为一个强大的研究和开发社区在彼此的开源努力基础上进行构建,创造出越来越强大的灵活模型。
思考、执行、了解:目标领域的灵活模型可以像巨型模型一样表现
图 2. 生成型人工智能能力分类。图像来源:Intel Labs。
要更深入地了解何时以及如何让较小的模型在生成型人工智能中提供高度竞争的结果,重要的是观察到,无论是灵活的还是巨型的生成型人工智能模型,都需要三类能力才能表现出色:
-
认知能力以思考: 包括语言理解、总结、推理、规划、从经验中学习、长篇阐述和互动对话。
-
功能技能以执行: 例如——在自然环境中阅读文本、阅读图表/图形、视觉识别、编程(编码和调试)、图像生成和语音。
-
信息(记忆或检索)以了解: 网页内容,包括社交媒体、新闻、研究和其他一般内容,和/或策划的领域特定内容,如医学、金融和企业数据。
思考的认知能力。 基于其认知能力,模型可以“思考”并理解、总结、综合、推理和构建语言及其他符号表示。无论是灵活模型还是巨型模型,在这些认知任务中表现良好,并且这些核心能力是否需要庞大的模型规模尚不清楚。例如,像微软研究的 Orca这样的灵活模型已经在多个基准测试中展示了与 ChatGPT 相匹配或超越的理解、逻辑和推理技能。此外,Orca 还表明,推理技能可以从作为教师的大型模型中提炼出来。然而,目前用于评估模型认知技能的基准仍然很初级。需要进一步的研究和基准测试来验证灵活模型是否可以通过预训练或微调来完全匹配巨型模型的“思考”能力。
执行功能技能。 较大的模型由于其作为全能模型的一般关注,可能具有更多的功能技能和信息。然而,对于大多数业务用途来说,每个应用程序所需的功能技能有特定的范围。用于业务应用的模型应具备灵活性和扩展性,但通常不需要无限的功能技能。GPT-4 可以生成多种语言的文本、代码和图像,但说几百种语言并不一定意味着这些巨型模型本质上具有更多的认知能力——它主要是给模型增加了更多的“执行”功能技能。此外,功能专用引擎将与 GenAI 模型关联,并在需要该功能时使用——例如,将数学“Wolfram 超能力”添加到 ChatGPT 模块化可以提供最佳的功能,而不会给模型带来不必要的规模负担。例如,GPT-4 正在部署插件,这些插件实质上利用了较小的模型来提供附加功能。此外,据传 GPT-4 模型本身是由多个巨型(少于 100B 参数)“专家混合”模型组成,这些模型在不同的数据和任务分布上进行训练,而不是像 GPT-3.5 那样的单一密集模型。为了获得最佳的能力和模型效率组合,未来的多功能模型可能会使用每个小于 15B 参数的更小、更专注的专家混合模型。
图 3. 基于检索的功能扩展模型可以提供广泛的功能和相关信息,与模型大小基本无关。图片来源:Intel Labs。
需要了解的信息(记忆的或检索的)。 巨型模型通过在参数记忆中记忆大量数据来“知道”更多内容,但这不一定使它们更聪明。它们只是比小模型更具一般知识。在零-shot 环境下,巨型模型具有较高的价值,对于新用例提供了通用消费者基础,当不需要进行目标定位时,以及在提炼和微调灵活模型(如 Orca)时作为教师模型。然而,目标明确的灵活模型可以为特定领域进行训练和/或微调,从而提供所需能力的更锐利技能。
图 4. 检索在允许小模型匹配更大模型的价值(使用 Contriever 检索方法)。图片来源:Intel Labs,基于 Mallen et al的工作。
例如,一个针对编程的模型可以专注于与医疗 AI 系统不同的能力集。此外,通过使用针对内部和外部数据的检索,模型的准确性和时效性可以大大提高。最近的一项研究 显示,在PopQA 基准上,参数只有 1.3B 的模型通过检索可以与参数高达 175B 的模型表现相当(见图 4)。在这种意义上,具有高质量索引的可访问数据的针对性系统的相关知识可能远比全能的通用系统更广泛。这对于需要用例或应用程序特定数据的多数企业应用程序来说可能更为重要——在许多情况下,还需要本地知识而不是广泛的通用知识。这就是灵活模型未来将显现其价值的地方。
三大因素推动灵活模型的爆炸性增长
评估灵活模型的好处和价值时需要考虑三个方面:
-
中等模型尺寸的高效率。
-
开源或专有的许可。
-
模型的专门化包括通用或针对性检索。
在尺寸方面,灵活的通用模型,如Meta 的 LLaMA-7B 和 -13B 或 技术创新研究所的 Falcon 7B 开源模型,以及专有模型如MosaicML 的 MPT-7B、微软研究院的 Orca-13B 和 Salesforce AI Research 的 XGen-7B 正在迅速改进(见图 6)。选择高性能的小型模型对操作成本以及计算环境的选择有重大影响。ChatGPT 的 175B 参数模型和 GPT-4 的估计 1.8 万亿参数 需要大量的加速器安装,如具有足够计算能力的 GPU 以处理训练和微调工作负载。相比之下,灵活的模型通常可以在任何选择的硬件上运行推理,从单插槽 CPU,到入门级 GPU,再到最大加速机架。灵活 AI 的定义目前已基于 13B 参数或更小模型的出色结果经验性地设定为 15B 参数。总体而言,灵活模型提供了一种更具成本效益和可扩展性的方法来处理新的用例(见灵活模型的优缺点部分)。
开源许可的第二个方面允许大学和公司互相迭代模型,从而推动了创造性创新的蓬勃发展。开源模型允许小型模型能力的惊人进步,如图 5 所示。
图 5. 灵活的开源非商业和商业 GenAI 模型在 2023 年上半年迅速崛起。图片来源:英特尔实验室。
从 2023 年初开始,有多个例子显示了通用灵活生成性 AI 模型的出现,比如Meta 的 LLaMA,该模型包括 7B、13B、33B 和 65B 参数。以下在 7B 和 13B 参数范围内的模型是通过微调 LLaMA 创建的:斯坦福大学的Alpaca,伯克利人工智能研究所的Koala,以及由加州大学伯克利分校、卡内基梅隆大学、斯坦福大学、加州大学圣地亚哥分校和 MBZUAI 的研究人员创建的Vicuna。最近,微软研究院发布了一篇关于尚未发布的 Orca 的论文,这是一个基于 LLaMA 的 13B 参数模型,模拟了大型模型的推理过程,并在针对特定领域进行微调之前取得了令人印象深刻的结果。
图 6. 使用 Vicuna 评估集,通过 GPT-4 评估的开源聊天机器人的相对响应质量比较。图片来源: 微软研究院。
Vicuna 可以作为近期从 LLaMA 衍生出的开源灵活模型的一个良好代表。Vicuna-13B 是一个由大学合作创建的聊天机器人,旨在“解决现有模型如 ChatGPT 中训练和架构细节的缺乏。”在 ShareGPT 上进行用户共享对话微调后,Vicuna 的响应质量相比于 ChatGPT 和 Google Bard 的 GPT-4 判定结果超过 90%。然而,这些早期的开源模型尚不可用于商业用途。MosaicML 的 MPT-7B和技术创新研究所的 Falcon 7B为商业可用的开源模型,其质量 reportedly 与 LLaMA-7B 相当。
图 7. Orca-13B 在 BIG-bench Hard 的复杂零样本推理任务上表现与 ChatGPT 相当。图片来源: 微软研究院。
Orca 在复杂的零-shot 推理基准测试中,如Big-Bench Hard(BBH),超越了传统的指令调优模型,如 Vicuna-13B,超过了 100%。在 BBH 基准测试中,它与 ChatGPT-3.5 达到了相同的水平,”根据研究人员的说法。Orca-13B 在其他通用模型中的顶级表现强化了这样的观点,即巨型模型的巨大规模可能来源于早期模型的暴力破解。巨型基础模型的规模对一些较小的模型,如 Orca-13B,用于提炼知识和方法可能很重要,但大小不一定是推理所必需的——即使对于一般情况也是如此。需要注意的是——对模型的认知能力、功能技能和知识记忆的全面评估,只有在广泛部署和应用时才有可能。
截至撰写本博客时,Meta 发布了他们的 Llama 2 模型,包含 7B、13B 和 70B 参数。该模型在首代发布四个月后问世,提供了显著的改进。在对比图表中,灵活的 Llama 2 13B 达到了与前一代 LLaMA 更大模型以及 MPT-30B 和 Falcon 40B 相似的结果。Llama 2 是开源的,可用于研究和商业用途。它在与微软及包括英特尔在内的多个合作伙伴的紧密合作下推出。Meta 对开源模型的承诺及其广泛的合作将无疑为我们看到的这种模型的跨行业/学术快速改进周期提供额外的推动力。
灵活模型的第三个方面涉及专业化。许多新推出的灵活模型都是通用的——例如 LLaMA、Vicuna 和 Orca。通用灵活模型可能完全依赖于它们的参数记忆,通过细调方法进行低成本更新,包括LoRA: 大型语言模型的低秩适应以及检索增强生成,在推理时实时从策划的语料库中提取相关知识。检索增强的解决方案正在建立并不断通过 GenAI 框架如LangChain和Haystack进行增强。这些框架允许轻松灵活的集成索引和有效访问大型语料库以进行基于语义的检索。
大多数商业用户更倾向于针对其特定领域的定制模型。这些针对性模型通常也是基于检索的,以利用所有关键的信息资产。例如,医疗保健用户可能希望自动化患者沟通。
针对性模型使用两种方法:
-
对模型本身进行专业化以满足目标使用案例所需的任务和数据类型。这可以通过多种方式实现,包括在特定领域知识上对模型进行预训练(如phi-1在来自网络的教科书质量数据上进行预训练)、对相同规模的通用基础模型进行微调(如Clinical Camel微调了 LLaMA-13B),或将巨型模型的知识提炼并学习到学生型灵活模型中(如Orca学习模仿 GPT-4 的推理过程,包括解释痕迹、逐步思维过程和其他复杂指令)。
-
为即时检索策划和索引相关数据,这可能是大量的,但仍在目标使用案例的范围/空间内。模型可以检索持续更新的公共网络和私有消费者或企业内容。用户确定索引哪些来源,从而选择来自网络的高质量资源以及更完整的资源,如个人的私人数据或公司的企业数据。虽然检索现在已集成到巨型和灵活系统中,但在小型模型中发挥着关键作用,因为它提供了模型性能所需的所有必要信息。它还允许企业将其所有私有和本地信息提供给在其计算环境中运行的灵活模型。
灵活生成性 AI 模型的优点和缺点
在未来,紧凑模型的规模可能会增加到 20B 或 25B 参数,但仍远低于 100B 参数范围。也有许多中等规模的模型,如 MPT-30B、Falcon 40B 和 Llama 2 70B。虽然这些模型在零样本任务上预计表现会比小模型更好,但我不认为它们在任何特定功能集上的表现会显著优于灵活的、针对性的、基于检索的模型。
与巨型模型相比,灵活模型有许多优点,尤其是当模型是针对性的和基于检索的时,这些优点会得到进一步增强。这些好处包括:
-
可持续且成本较低的模型: 模型在训练和推理计算上具有显著较低的成本。推理运行时计算成本可能是 24x7 使用的商业导向模型可行性的决定因素,并且在广泛部署中整体环境影响的大幅减少也很显著。最后,凭借其可持续、特定和功能导向的系统,灵活模型并不试图解决人工通用智能(AGI)的雄心勃勃的目标,因此在与后者相关的公共和监管辩论中参与较少。
-
更快的微调迭代: 较小的模型可以在几小时(或更短时间)内完成微调,通过类似 LoRA 的适应方法向模型添加新信息或功能,这在灵活模型中非常有效。这使得改进周期更频繁,使模型始终与其使用需求保持同步。
-
基于检索的模型优点: 检索系统重新组织知识,从直接来源引用大部分信息,而不是模型的参数记忆。这改善了以下方面:
– 可解释性: 检索模型使用源属性,提供来源或能够追溯到信息来源的能力,以提供可信度。
– 时效性: 一旦索引了最新的来源,模型可以立即使用,无需任何训练或微调。这允许在接近实时的情况下不断添加或更新相关信息。
– 数据范围: 为按需检索编制的信息可以非常广泛和详细。当集中在目标领域时,模型可以覆盖大量的私有和公共数据的范围和深度。它可能在其目标空间中包含比巨型基础模型训练数据集更多的量和细节。
– 准确性: 直接访问数据的原始形式、细节和上下文可以减少幻觉和数据近似。只要在检索范围内,就可以提供可靠和完整的答案。使用较小的模型时,按需检索的可追溯策划信息和(如巨型模型中)可能过时、部分且未标注来源的记忆信息之间的冲突也更少。
-
硬件选择: 灵活模型的推理可以在任何硬件上实际完成,包括可能已经是计算设置一部分的普遍解决方案。例如,Meta 的 Llama 2 灵活模型(7B 和 13B 参数)在英特尔的数据中心产品上运行良好,包括 Xeon、Gaudi2 和 Intel 数据中心 GPU Max 系列。
-
集成、安全和隐私: 今天的 ChatGPT 和其他巨型 GenAI 模型是独立模型,通常在第三方平台上的大型加速器设施上运行,并通过接口访问。灵活的 AI 模型可以作为嵌入到更大业务应用程序中的引擎运行,并且可以完全集成到本地计算环境中。这对于安全和隐私有重大影响,因为不需要与第三方模型和计算环境交换/暴露信息,而且更广泛应用的所有安全机制都可以应用于 GenAI 引擎。
-
优化和模型缩减: 优化和模型缩减技术,如量化,通过将输入值转换为较小的输出值来减少计算需求,在增加功率效率的灵活模型上已经显示出了强大的初步结果。
检索模型需要对所有源数据进行索引: 模型在推断期间通过索引映射获取所需信息,但存在错过信息源的风险,使其对模型不可用。为了确保来源可追溯性、可解释性和其他属性,针对检索型模型不应依赖于存储在参数记忆中的详细信息,而应主要依赖于在需要时可用于提取的索引信息。
-
任务范围减少: 通用巨型模型具有出色的多功能性,特别擅长于未曾考虑的零-shot 新用途。灵活系统能够实现的广度和范围目前仍在评估中,但似乎随着最近的模型而有所改善。目标模型假定任务范围在预训练和/或微调期间已知和定义,因此范围的缩减不应影响任何相关能力。目标模型不是单一任务,而是一组相关能力。这可能会导致由于任务或业务特定的灵活模型而产生的碎片化。
-
可能通过少量样本的精细调整来改进: 为了有效地解决目标空间的模型,不总是需要进行精细调整,但可以通过调整模型以适应应用程序所需的任务和信息来提高 AI 的效果。现代技术使得这一过程可以用少量示例完成,而无需深入的数据科学专业知识。
-
在 Intel CPU 上的生成 AI 模型已经显示了强大的初步结果。
摘要
生成 AI 的重大飞跃使得新的能力成为可能,例如 AI 代理以自然语言交谈、引人入胜的文本总结和生成、图像创作、利用先前迭代的上下文等。本文介绍了“灵活 AI”这一术语,并阐述了为什么它将成为大规模部署生成 AI 的主要方法。简单来说,灵活 AI 模型运行更快,通过持续微调更新更迅速,并且通过开源社区的集体创新,更容易进行快速技术改进。
通过多个示例展示,随着最大模型的进化,表现出的卓越性能表明,灵活模型不需要像巨型模型那样的庞大体积。一旦掌握了基本认知能力,调整了所需功能,并根据需要提供数据,灵活模型为商业世界提供了最高的价值。
尽管如此,灵活模型不会使巨型模型灭绝。巨型模型在零-shot、开箱即用的设置中仍然有更好的表现。这些大型模型也可能被用作蒸馏成较小灵活模型的来源(教师模型)。虽然巨型模型拥有大量额外的记忆信息来处理任何潜在的用途,并且配备了多种技能,但这种通用性不被期望在大多数生成 AI 应用中需要。相反,将模型微调到与领域相关的信息和技能,并能够从策划的本地和全球来源中检索最新信息,将为许多应用提供更好的价值主张。
将灵活、针对性的 AI 模型视为可以集成到任何现有应用中的模块,提供了非常有吸引力的价值主张,包括:
-
部署和操作成本仅为其一小部分。
-
适应任务和私人/企业数据。
-
每晚更新,并且可以在从 CPU 到 GPU 或加速器的任何硬件上运行。
-
集成到当前计算环境和应用中。
-
运行在本地或私有云中。
-
受益于所有的安全和隐私设置。
-
更高的准确性和可解释性。
-
在提供类似生成 AI 能力的同时,更具环保责任感。
在少量巨型模型上取得的令人印象深刻的进展将继续。然而,行业最可能需要的只是几十个通用的灵活基础模型,这些模型可以用来构建无数的目标版本。我预见到不久的将来,广泛扩展的高级生成 AI 将渗透到所有行业,主要通过将灵活、针对性的安全智能模块作为增长引擎。
参考资料
-
Tseng, P. K. (2023 年 3 月 1 日). TrendForce 表示,由于云公司发起 AI 军备竞赛,ChatGPT 对 GPU 的需求可能达到 30,000 芯片,为商业化做好准备。TrendForce。
www.trendforce.com/presscenter/news/20230301-11584.html
-
介绍 MPT-7B:开源、商业可用 LLM 的新标准。 (2023 年 5 月 5 日)。
www.mosaicml.com/blog/mpt-7b
-
Mukherjee, S., Mitra, A., Jawahar, G., Agarwal, S., Palangi, H., & Awadallah, A. (2023). Orca:从 GPT-4 的复杂解释轨迹中逐步学习。arXiv (康奈尔大学)。
doi.org/10.48550/arxiv.2306.02707
-
Wolfram, S. (2023 年 3 月 23 日). ChatGPT 获得了“Wolfram 超能力”!Stephen Wolfram Writings。
writings.stephenwolfram.com/2023/03/chatgpt-gets-its-wolfram-superpowers/
-
Schreiner, M. (2023 年 7 月 11 日). GPT-4 架构、数据集、成本等泄露。THE DECODER。
the-decoder.com/gpt-4-architecture-datasets-costs-and-more-leaked/
-
ChatGPT 插件。 (无日期)。
openai.com/blog/chatgpt-plugins
-
Izacard, G., Caron, M., Hosseini, L., Riedel, S., Bojanowski, P., Joulin, A., & Grave, E. (2021). 使用对比学习进行无监督密集信息检索。arXiv (康奈尔大学)。
doi.org/10.48550/arxiv.2112.09118
-
Mallen, A., Asai, A., Zhong, V., Das, R., Hajishirzi, H., 和 Khashabi, D. (2022). 何时不信任语言模型:调查参数化和非参数化记忆的有效性。arXiv (康奈尔大学)。
doi.org/10.48550/arxiv.2212.10511
-
Papers with Code — PopQA 数据集。 (无日期)。
paperswithcode.com/dataset/popqa
-
介绍 LLaMA:一个基础的 65 亿参数的大型语言模型。 (2023 年 2 月 24 日)。
ai.facebook.com/blog/large-language-model-llama-meta-ai/
-
介绍 Falcon LLM。 (无日期)。
falconllm.tii.ae/
-
Nijkamp, E., Hayashi, H., Xie, T., Xia, C., Pang, B., Meng, R., Kryscinski, W., Tu, L., Bhat, M., Yavuz, S., Xing, C., Vig, J., Murakhovs’ka, L., Wu, C. S., Zhou, Y., Joty, S. R., Xiong, C., 和 Savarese, S. (2023). 使用 XGen 进行长序列建模:一个在 8K 输入序列长度上训练的 7B LLM。Salesforce AI Research。
blog.salesforceairesearch.com/xgen/
-
Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., 和 Chen, W. (2021 年 6 月 17 日). LoRA:大型语言模型的低秩适应。arXiv (康奈尔大学).
doi.org/10.48550/arXiv.2106.09685
-
Lewis, P., Perez, E., Piktus, A., Petroni, F., Karpukhin, V., Goyal, N., Küttler, H., Lewis, M., Yih, W., Rocktäschel, T., Riedel, S., 和 Kiela, D. (2020). 针对知识密集型 NLP 任务的检索增强生成。NeurIPS 2020.
proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html
-
Introduction LangChain. (无日期).
python.langchain.com/docs/get_started/introduction.html
-
Haystack. (无日期).
www.haystackteam.com/core/knowledge
-
Mantium. (2023). Haystack 和 LangChain 如何赋能大型语言模型。Mantium.
mantiumai.com/blog/how-haystack-and-langchain-are-empowering-large-language-models/
-
Taori, R., Gulrajani, I., Zhang, T., Dubois, Y., Li, X., Guestrin, C., Liang, P., 和 Hashimoto, T. B. (2023 年 3 月 13 日). Alpaca:一个强大且可复制的指令跟随模型。斯坦福大学 CRFM.
crfm.stanford.edu/2023/03/13/alpaca.html
-
Geng, X., Gudibande, A., Liu, H., Wallace, E., Abbeel, P., Levine, S. 和 Song, D. (2023 年 4 月 3 日). Koala:一个用于学术研究的对话模型。伯克利人工智能研究博客.
bair.berkeley.edu/blog/2023/04/03/koala/
-
Chiang, W. L., Li, Z., Lin, Z., Sheng, Y., Wu, Z., Zhang, H., Zheng, L., Zhuang, S., Zhuang, Y., Gonzalez, J. E., Stoica, I., 和 Xing, E. P. (2023 年 3 月 30 日). Vicuna:一个令人印象深刻的开源聊天机器人,媲美 GPT-4,质量达到 90%的 ChatGPT。LMSYS Org.
lmsys.org/blog/2023-03-30-vicuna/
-
Rodriguez, J. (2023 年 4 月 5 日). 认识 Vicuna:最新的 Meta 的 Llama 模型,与 ChatGPT 性能相当。Medium.
pub.towardsai.net/meet-vicuna-the-latest-metas-llama-model-that-matches-chatgpt-performance-e23b2fc67e6b
-
Papers with Code — BIG-bench 数据集. (无日期).
paperswithcode.com/dataset/big-bench
-
Meta. (2023 年 7 月 18 日). Meta 和微软推出下一代 Llama。Meta.
about.fb.com/news/2023/07/llama-2/
-
Meta AI. (无日期). 介绍 Llama 2。
ai.meta.com/llama/
-
Gunasekar, S., Zhang, Y., Aneja, J., Mendes, C. C. T., Allie, D. G., Gopi, S., Javaheripi, M., Kauffmann, P., Gustavo, D. R., Saarikivi, O., Salim, A., Shah, S., Behl, H. S., Wang, X., Bubeck, S., Eldan, R., Kalai, A. T., Lee, Y. T., 和 Li, Y. (2023). 教科书就是你所需的一切。arXiv (康奈尔大学)。
doi.org/10.48550/arxiv.2306.11644
-
Toma, A., Lawler, P. R., Ba, J., Krishnan, R. G., Rubin, B. B., 和 Wang, B. (2023). Clinical Camel: 一个开源的专家级医疗语言模型,具有基于对话的知识编码。arXiv (康奈尔大学)。
doi.org/10.48550/arxiv.2305.12031
-
Patel, D., 和 Ahmad, A. (2023 年 5 月 4 日). Google “我们没有护城河,OpenAI 也没有。” SemiAnalysis。
www.semianalysis.com/p/google-we-have-no-moat-and-neither
-
加速 Llama 2 通过英特尔 AI 硬件和软件优化。(无日期)。英特尔。
www.intel.com/content/www/us/en/developer/articles/news/llama2.html
-
更小更好:Q8-Chat,在 Xeon 上的高效生成 AI 体验。(无日期)。Hugging Face。
huggingface.co/blog/generative-ai-models-on-intel-cpu
分析车辆尺寸与行人安全
原文:
towardsdatascience.com/suvs-are-killing-people-de6ce08bac3d
公开的交通事故数据表明,SUV 导致行人死亡和受伤的比例高于小型汽车
Danny Cunningham
·发布于Towards Data Science ·8 分钟阅读·2023 年 1 月 10 日
--
图片由Alexandru Acea提供,来源于Unsplash
《纽约时报》最近强调了“极具美国特色”的道路死亡上升问题。除了美国之外,几乎所有发达国家的道路变得越来越安全。即使在 COVID-19 疫情高峰期,当时道路上的汽车大幅减少,交通死亡人数仍在增加。
美国的道路对行人特别危险。保险公路安全研究所(IIHS)2022 年 3 月的一项研究发现,自 2009 年以来,行人死亡人数增加了 59%,而 2020 年所有机动车死亡中有 20%是行人。导致这些严峻统计数据的因素有很多,但一个主要因素(言外之意)是道路上车辆的尺寸。
大型 SUV 和皮卡车由于其较大的重量和更高的前端,明显比小型汽车更容易伤害或杀死行人。而且,如果你生活在美国,你肯定会注意到,美国人喜欢大型车辆。一些报告显示,现在超过 80%的新车销售在美国是 SUV 或皮卡。这对行人来说是个坏消息。并且我们还将面临荒谬沉重的电动汽车。
但我们真的可以将这些行人安全统计数据视为绝对真实吗?我决定自己分析数据,以查看大型车辆是否确实导致了行人伤害和死亡的显著增加(剧透:确实如此)。
以下分析的所有代码可以在 GitHub 上的 Jupyter 笔记本中找到。该笔记本包含了本文中未包含的额外细节。
数据
我分析了我所在城市芝加哥的数据。芝加哥发布了交通碰撞数据(以及大量其他数据集),并通过 Socrata API 公开提供。开始使用 Socrata 数据集很简单:只需按照文档获取应用程序令牌,找到你感兴趣的 API 端点,然后开始查询。我使用了sodapy Python 包来与 API 交互。
具体来说,我使用了与芝加哥交通碰撞相关的三个数据集:
-
碰撞记录:有关碰撞的基本细节,例如发生的时间和地点。每次碰撞一个记录。
-
人员记录:有关涉及碰撞的人员的详细信息。识别一个人是否受伤,以及他们是司机还是行人。每次碰撞至少一个记录。
-
车辆:涉及碰撞的车辆的详细信息。包括品牌、型号和车辆类型。每次碰撞至少包含一条记录。
这些数据集中的数据是由每次碰撞事件中的报告警官收集的。
照片由 Sawyer Bengtson 提供,来源于 Unsplash
方法论
分析的目标是确定大型车辆(SUV 和皮卡)是否比小型汽车更可能致使行人遇难或严重受伤。我们将通过拟合逻辑回归模型来实现这一点,该模型将显示哪些因素在决定伤害是否发生中是相关的。
定义目标变量和特征
在拟合模型之前,我们需要将原始数据转换为模型可以使用的格式。这意味着我们需要定义我们的目标变量(我们试图预测的内容)和一些特征(我们将用来进行预测的内容)。
在我们的案例中,目标变量是一个二进制值,指示在碰撞中是否有行人遇难或严重受伤(1 = 受伤,0 = 无受伤)。这可以从 People 数据集中轻松计算得出。首先,我们通过“person_type”字段确定碰撞中的人员是否为行人。然后,我们使用“injury_classification”字段确定是否有任何行人受到了致残性或致命伤害。我们实际上会在这里创建两个目标变量——一个用于致残伤害,一个用于致命伤害——并在后续单独建模。(注:数据集将致残性伤害定义为“阻止受伤者行走、驾驶或正常继续其能够进行的活动”的任何伤害。)
我们需要包括的主要特征是车辆类型。这通常比较简单,因为车辆数据集中包含了区分汽车、SUV、皮卡等的分类。然而,数据集中对车辆的标注并不一致。例如,一位警官可能将一辆丰田 RAV4 标记为 SUV,而另一位警官可能将其标记为汽车。为了考虑这些差异,我们将检查每个品牌/型号最常被标记的类别,并将其用于该品牌/型号的所有实例。在完成这小部分特征工程后,我们可以很容易地为每次碰撞计算两个二进制特征:一个表示是否涉及 SUV,另一个表示是否涉及皮卡。
我们还应该控制其他可能影响碰撞结果的因素,如天气条件和限速。我们将通过在回归模型中包括额外的特征来做到这一点,这将使我们能够估计每个单独特征的独立效应。
最后,我们将从数据集中移除所有不涉及行人的事故,因为这些事故与当前的问题无关。
逻辑回归模型
我们试图回答的问题可以很容易地框定为一个二元分类问题。根据关于事故的信息,预测行人是否会被杀害(或受伤)。在我们的用例中,理解为什么做出这种预测也是至关重要的。
逻辑回归是一个不错的选择,因为它产生了一个简单且易于解释的模型。我们将能够分析模型以了解哪些特征对目标变量的概率有正面或负面影响。具体来说,我们将能够看到车辆类型是否对行人死亡和受伤有显著影响。
逻辑回归的完整数学推导超出了本文的范围,但本质上模型会产生一个线性预测器:
其中β系数表示每个特征X对目标变量的影响。我们可以将系数转换为易于解释的赔率比(例如,当车辆是 SUV 时,行人受伤的可能性增加Z%)。
我们将拟合两个独立的逻辑回归模型:一个用于预测行人死亡,另一个用于预测行人致残伤害(包括死亡)。除了目标变量外,模型结构是相同的。
假设
我们在模型中做了几个假设,无论是明确还是隐含的:
-
SUV 和较小的汽车在涉及行人的事故中发生率相同。我们只是在预测事故的结果,而不是预测事故是否会因不同的汽车而发生。(注意: 有一些证据 表明 SUV 撞击行人的频率高于小型汽车,这可能会质疑这一假设。如果这确实是事实,那么 SUV 对行人死亡的影响可能比下面的结果所显示的更糟。)
-
所有 SUV 都是相同的(所有皮卡车也是如此)。显然,在现实世界中情况并非如此,因为 SUV 的形状和尺寸差异很大。
显然,这两辆 SUV 并不相同,但在模型中被视为相同。图像已获得求助于 carsized.com的许可。
结果
那么 SUV 和皮卡车对行人更危险吗?是的。
逻辑回归模型显示,与较小的汽车相比,SUV 更可能导致 16%的致残伤害和 36%的行人死亡。皮卡车导致致残伤害的可能性增加 33%,导致行人死亡的可能性增加 108%(是两倍多!)。
模型预测,在过去五年中,由于 SUV 和皮卡车,芝加哥大约发生了 20 起额外死亡(即,如果这些车辆被更小的汽车替代,这些死亡将不会发生)。
2017 到 2022 年间,205 名行人被个人车辆撞击致死。如果 SUV 和皮卡车被更小的汽车替代,大约可以拯救 20 条生命。(图像由作者提供)
模型还识别了一些其他有趣(但可能并不完全令人惊讶)的因素,这些因素导致了行人死亡。下表显示了模型发现的所有统计显著特征的影响。以下是关键发现的总结:
-
大型车辆对行人更危险。 SUV 和皮卡车在统计上更可能造成致残伤害。这是一个相当明显的结果,也是本文的主要话题。
-
高速行驶对行人危险。 更重要的是,允许高速的条件对行人危险。这最明显地表现为与限速的正相关关系,但也通过其他特征显示出来(有更多车道或由中央隔离带分隔的道路通常允许更高速度;停车场则不允许)。令人惊讶的是,雪天条件降低了行人受伤的可能性。这可能是因为雪天道路迫使汽车减速行驶。
-
低能见度条件对行人危险。 晚上的伤害概率增加。巷道中的碰撞也可能导致伤害,这可能是因为司机在狭窄的巷道中能见度降低(且巷道中可能有行人与车辆共同通过)。
-
在 COVID 期间,严重的行人伤害变得更加常见。 即使在控制了数据集中的其他特征时,COVID-19 大流行期间行为也发生了统计学上显著的变化。模型中添加了两个二元日期特征(“高峰 COVID”和“高峰后 COVID”)以解释这种差异。这可能归因于疫情高峰期街道更空旷导致速度更快。 (注:致残伤害的总数不一定增加;模型检测到的是 COVID 期间致残伤害的发生率增加。)
下表显示了支持上述总结信息的数据点,更多细节(如置信区间)可以在GitHub上找到。
一张表格总结了各种碰撞特征对行人致残伤害可能性的影响。特征只有在模型识别为统计显著(p < 0.05)时才被包含。“相对几率”+X% 表示如果该特征为真(或对于数值特征增加一个单位),而其他碰撞特征保持不变,那么致残伤害的可能性提高了 X%。
结论…我们可以对此做些什么?
我们得出的结论是——请鼓掌!——比起小型、轻型汽车,大型、重型汽车更容易致使行人死亡。好吧…这并不算突破性发现。然而,许多人没有意识到这是我们城市中的一个问题——直到被指出来我才想到这个问题。我希望一些有力的证据能提高人们的意识。
那我们能做些什么呢?
首先,如果你开的是 SUV 或皮卡车,这并不意味着你是一个坏人。我甚至不会试图劝你不要买这种车。人们购买这些车辆有很多原因(其中一些理由可能还是真正有效的)。
值得注意的是,SUV 确实在碰撞中能使驾驶员更安全。这就形成了一个每个驾驶员都被激励购买大型车辆的情况,但如果每个人都做出这个决定,整个系统对所有人(特别是车外的人)来说会变得更危险。在这种“安全军备竞赛”场景中,唯一无可争议受益的群体是那些宁愿向你出售昂贵 SUV 而不是便宜轿车的汽车公司。
冒着有点政治风险的说,解决这个问题的一种方法是通过法规。例如,欧洲车辆安全法规包括了行人安全的条款;而美国的安全法规则没有。下次去投票的时候,尤其是在地方选举中,值得记住这一点,因为当选官员实际上能影响你所在城市的街道。
成为 Medium 会员以获取成千上万作家的故事!
Svelte & 数据可视化
原文:
towardsdatascience.com/svelte-data-visualisation-6210e8164e74
Svelte 条形图(来源:作者,2023 年)
使用 Svelte 创建交互式条形图
Sutan Mufti
·发表于 Towards Data Science ·阅读时间 6 分钟·2023 年 3 月 20 日
--
介绍
数据可视化揭示了我们数据中的信息。通常的方法是制作图表或地图。我们可以在诸如 Microsoft Excel 或 Google Sheets 的电子表格软件中完成这项工作。通常这已经足够了,但我们可以通过添加交互性来使其更具吸引力。一个简单的交互性添加例子是通过添加切片器来根据类别筛选数据。在 Microsoft Excel 中,我们可以启用“开发者”选项卡并添加表单元素,如按钮、复选框等。
使用现代软件技术,创建这种交互式数据体验的经典工具是使用 d3.js 以及原生 JavaScript。它的工作原理如此简单,却能产生美丽的结果,真是令人惊叹。然而,与 React、Angular、Vue 等现代框架相比,原生 JavaScript 感觉有些过时。虽然大家都知道这些框架,但我觉得有一个用于数据可视化的 UI 框架被低估了;你从标题中猜到它就是 Svelte。考虑到它的年轻生态系统,我认为这个出色的框架需要被推广,以便让更多人使用或开发它;尤其是在数据科学领域。
本文展示了 Svelte 在交互式数据可视化中的应用(代码链接见下文的 GitHub 链接)。首先,我将简要说明其区别于其他框架的特性。其次,我将概述我们将要可视化的数据和我们的特性。最后,我将解释一些代码片段的工作原理;我更侧重于理念而不是语法。最后,我们将制作一个如文章主要 gif 图片所示的交互式数据 UI。
代码
代码可在以下链接获取。
[## GitHub - sutanmufti/svelte-data-visualisation: 使用 svelte 可视化数据
由 Sutan Mufti(2023)创建的这段代码是 RekayasaData.co.uk 项目的一部分。这个代码库托管了代码……
github.com
演示
你可以在这里找到演示:
[## 伦敦人口数据 2023
演示
ceritapeta.co.uk
什么是 Svelte?
Svelte 是一个通过将其文件编译成纯 HTML + JavaScript 和 CSS 来创建用户界面的框架。Svelte-kit 是开发 svelte 应用程序的新元框架。它刚刚在去年 12 月(2022)达到 v1.0。这是一个很好的工具来开始和管理 svelte 应用程序。
在使用 Svelte 几个月后,我爱上了这个框架,因为它的惊人功能。个人认为,这些功能让它脱颖而出,并提供了出色的开发体验。对于数据科学和可视化,我认为这些功能最为有用:
-
HMR(热模块替换):在开发过程中,应用程序在我们更改代码时保持其状态。这意味着实时更改会直接反映在浏览器中,同时保持变量。
-
原生响应性:这将作为本文中我们应用的主要功能进行演示!我们不需要在 svelte 中管理状态。如果你熟悉 React,这相当于
useState
钩子。虽然,我们仍然需要管理多个组件之间的状态连接。一个非常简短的代码来演示这个功能;点击一个按钮,生成一个随机数,并在 HTML 元素中显示它:
<!-- ./randomiser.svelte -->
<script lang='ts'>
let theValue: number = 40;
function randomiser(){theValue = Math.floor(Math.random() * 100);}
</script>
<button on:click={randomiser}>random!</button><br>
<p>the value: {theValue}</p>
随机器应用(来源:作者,2023)
你可以按原样将上述代码粘贴到 svelte 文件中!
随机器应用代码(来源:作者,2023)
- 数据绑定:类似于
Document.querySelector
,bind
语法将一个元素绑定到脚本的变量上。我认为这种方式管理 HTML 元素更简单。我在 github 代码库中提供了一个循环和绑定的示例。
绑定(来源:作者,2023)
构建界面
构建界面需要 HTML、JavaScript 和 CSS 的知识。我更喜欢 typescript 来减少 bug 开发,因为它使我们编写的代码比纯 JavaScript 更严格。现在,让我们使用这些功能(绑定、响应性和 HMR)。我将讨论的步骤是:
-
使用
+layout.svelte
设置主题 -
使用 ES6 语法设置要导入的数据
-
整合各个部分
设置主题
首先,我喜欢从根布局的样式开始。+layout.svelte
文件定义了所有 Svelte 文件的行为。在这个文件中,我只设置了字体、颜色、边距和背景颜色。这基本上是应用的“主题”。为了这个演示,我们保持简单,至少它不是默认的。
<!-- +layout.svelte -->
<svelte:head>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
color: #303030;
margin: 0;
background-color: #ebebeb;
}
</style>
</svelte:head>
<slot></slot>
设置数据
在这个演示中,我们将使用来自英国大伦敦地区的人口数据。你可以在这里找到这些数据。数据格式为 csv,建议将其转换为 JSON 格式。使用pandas
非常简单,因为有to_json("data.json", orient="record")
。请注意orient
参数,它表示一种无模式的风格数据,类似于 MongoDB。这样,我们可以在 Javascript 中使用forEach
或map
方法来处理一个数组。
在 ES6 语法中,我们可以使用import
从其他脚本文件中导入变量或对象。由于我们的数据基本上是 JSON,我们可以像这样将其定义为一个变量:
// data.ts
export const data = [
{name: "sutan", article: "Visualising Data with Svelte"},
{name: "sutan", article: "Spatial Data Science"},
{name: "someone", article: "somedata"}
//... and many more rows
]
然后我们可以将数据导入并解构到主脚本文件中。
// main.ts
import {data} from './data.ts'
interface Record {
name: string;
article: string;
}
function main(data: Record[]){
data.forEach(record=>console.log(`${record.name} wrote ${record.article}`))
}
main(data)
在我的示例中,我们将伦敦人口数据存储在./src/routes/population_london.ts
文件中,变量名为LondonData
。
构建应用
让我们将这一部分分成两个小节:处理用户事件和实际可视化数据。
绑定与用户事件
首先,我们创建一个选项表单字段,用户可以用来选择一个区。请查看以下代码,我们使用bind
将表单值绑定到selectionIndex
。在未来的组件中,我们可以使用selectionIndex
,它会根据用户的事件动态变化。
<div>
density of
<select bind:value={selectionIndex}>
{#each data as d,i}
<option value={i}>{d.Name}</option>
{/each}
</select>
is {selectedDensity} per square km
</div>
数据可视化
借用 d3.js 开发范式的理念,我使用 SVG 来可视化图形。Svelte 提供了原生的each
语法来循环变量。对于我们的人口数据中的每条记录,我们想要创建一个矩形,并将高度设置为密度数据。查看rect
的height
属性。
然后,on:click
属性处理用户点击条形图时的点击事件。它的回调执行一个函数,该函数接受矩形索引作为参数。activateSelection(index: number)
基本上设置了前一节中绑定的selectionIndex
。当selectionIndex
改变时,选择选项表单的值也会改变。
<svg width={data.length * barwidth / 100} height={maxHeight} >
<g >
{#each data as d,i}
<rect height={d.Population_per_square_kilometre / scaling} y={maxHeight - d.Population_per_square_kilometre / scaling} width={barwidth/100*0.8} x={(i * barwidth / 100 * 1)} fill={(i == selectionIndex)? "black": "grey"} on:click={()=>{return activateSelection(i)}} on:keydown={()=>{}}></rect>
{/each}
</g>
</svg>
d3.js 中的等效代码如下:
// usingD3.js
const graph = d3
.select("svg")
.selectAll("rect")
.data(LondonData)
.enter()
.append("rect")
.attr("width", (d)=>d.Population_per_square_kilometre)
.on("click",handleClick)
演示
使用静态适配器,我们可以将其构建为普通的index.html
文件,以便在如 apache httpd 这样的 Web 服务器中静态服务。以下链接是已发布的演示,也是在文章顶部的主要图片。
[## 伦敦人口数据 2023
演示使用 Svelte! ceritapeta.co.uk
结论
我认为 Svelte 是一个被低估的工具,特别是在交互式数据可视化方面。在这个演示中,我们甚至不需要额外的库来进行数据可视化,只需导入数据并使用 Svelte 进行显示。目前一切都是静态的,因为我认为这样足以展示可视化部分。更有趣的是,我们在这篇文章中还没有探讨 Svelte 的全部潜力。例如,可以利用 +server.js
构建 API 并将其集成到我们的应用中;或者使用过渡效果;或者构建组件以增强图表的外观。可以确定的是:
Svelte 在数据可视化方面的未来令人兴奋。
我希望你觉得这篇文章有用。谢谢阅读。
大型语言模型(LLMs)的软件/硬件协同优化策略
原文:
towardsdatascience.com/sw-hw-co-optimization-strategy-for-large-language-models-llms-855f20a14629
如何最大限度地发挥系统性能以加速运行 LLMs?— 最佳实践
Liz Li
·发表在Towards Data Science ·5 分钟阅读·2023 年 12 月 16 日
--
领先的大型语言模型(LLMs)如 ChatGPT、Llama 等正在革新科技行业并影响每个人的生活。然而,它们的成本构成了一个重大障碍。使用 OpenAI API 的应用会产生持续的高费用(每 1,000 个提示令牌$0.03,每 1,000 个采样令牌$0.06)。
为了降低成本,公司倾向于托管自己的 LLMs,费用因模型大小而异(100–200B 参数的大型 LLMs 成本约为 7–15B 参数的小型模型的 10 倍)。这一趋势催生了 AI 芯片竞赛,各大科技公司旨在开发自己的 AI 芯片,减少对昂贵硬件的依赖。
模型大小的趋势。来源:AWS reInvent
如何挤压每一丝计算能力以运行 LLMs?在这篇文章中,我将对 LLM 的优化策略进行深入分析,涵盖模型、软件和硬件。这将遵循我在之前文章中编写的 AI SW/HW 协同设计方法论,并对 LLM 特有的成本和性能优化进行更深入的探讨。
## 如何在新时代共同设计 AI/ML 的软件/硬件架构?
对 AI/ML 高效架构设计的全面视角
towardsdatascience.com
来源:作者及其他同事制作
运行 LLM 模型的计算和内存需求正在呈指数级增长,而计算/内存能力却在较慢的轨迹上滞后,如上图所示。为了弥补这种性能差距,探索三个关键领域的改进是至关重要的:
-
算法改进与模型压缩: 我们如何通过增强模型特性来减少计算和内存需求而不影响质量?最新的 LLM 量化技术有哪些进展,能在保持质量的同时减少模型大小?
-
高效的软件栈与加速库: 在构建能够无缝连接 AI 模型和硬件的软件栈时,哪些考虑因素至关重要?我们如何利用硬件特性来优化 LLM 加速?现有的软件挑战和潜在的改进有哪些?
-
强大的 AI 硬件加速与先进的内存层次结构: 当前有哪些针对 LLM 的硬件加速器?我们如何通过内存层次结构的潜在改进来缓解高内存需求?
我将为上述每个主题撰写一篇文章。在这篇文章中,我们将深入探讨第一个主题(算法改进与模型压缩)!
LLM 基于变换器架构(编码器-解码器),其中包括仅解码器模型架构如 Llama、ChatGPT 等,以及编码器-解码器模型架构如 Whisper、T5 等。新兴模型每天都在出现。在这篇文章中,我们将重点关注以下四个新特性来加速变换器的性能。
1. 量化
将 FP32 模型转换为 INT8 模型可以将内存大小缩小约 4 倍,而 INT4 量化则能实现大约 8 倍的模型大小减少。此外,整数矩阵乘法的计算成本显著降低,因为其速度超过了浮点计算。量化分为两类——后训练量化(PTQ)和量化感知训练(QAT)。对于推理,推荐使用 PTQ。Hugging Face 托管了大量利用各种量化方法如 GPTQ、GGUF、AWQ 等的量化 LLM 模型。
通过量化实现模型大小的减少。来源:huggingface.co/TheBloke
2. 注意力机制
缩放点积注意力计算密集,涉及多次矩阵乘法键、查询和数值。在多头注意力中,存在多个注意力层(称为头),每个头生成的输出都会被拼接在一起。
一个缩放点积注意力(左)和多头注意力(右)的示意图,多头注意力实际上是多个 SDPA 头并行。来源:注意力机制就是一切 [参考 1]
为了优化注意力推理,引入了多查询注意力的概念(参考 2 快速变换器解码)。在这种方法中,键和值在不同的注意力头之间共享,减少了为每个注意力头获取新键值对的需求,最小化了数据传输。
此外,在多头注意力和多查询注意力之间存在一种中间机制,称为分组查询注意力。它涉及将键和值投影到不同的组中,这与多查询注意力中的单一投影不同。这种方法在保持模型质量的同时有效减少了内存需求。
不同注意力机制的比较。来源: GQA: 从多头检查点训练通用多查询变换器模型 [Ref 3]
Flash Attention(参考 [4])。与传统的逐层计算方法不同,Flash Attention 使用平铺技术将多个层融合在一起,并在单次操作中计算出最终结果。平铺大小考虑了系统内存层次结构,优化了 IO 操作。下图演示了 Flash Attention 与 PyTorch 原生实现相比的概念和延迟改进。
在 40 GB GPU 上使用的平铺 Flash 注意力计算模式和内存层次结构。来源:Flash Attention: 快速且内存高效的精确注意力与 IO 感知
3. 分页 KV 缓存
随着输入和输出标记数量的增加,键值缓存可能变得非常庞大,具有动态长度,这导致由于碎片化和冗余复制而造成的内存访问效率低下。受到操作系统中虚拟内存机制的启发,Paged Attention 旨在最小化 KV 缓存内存中的冗余,并促进 KV 缓存在请求内和跨请求的灵活共享。
左侧:参数(灰色)在每次服务请求中保持在内存和 KV 缓存(红色)中。右侧:vLLM 有助于减缓内存需求以提升系统吞吐量。来源:大语言模型服务的高效内存管理与 PagedAttention [Ref 5]
4. 推测采样 [参考 6]
在自回归生成模型中,生成单个标记需要完整的模型推理,这会导致重复的权重加载,耗时较长。推测采样旨在缩小小型模型和大型模型之间的差距,通过提供类似于大型模型的高质量结果,同时具有类似于小型模型的更快速度。
AWQ 引擎显著加快了猜测性解码的速度。来源:在快速车道上!猜测性解码 — 10 倍更大的模型,无额外成本
除了从算法和模型的角度提到的四大推理加速技术,还有许多其他特性可以加速 LLM 模型的推理。这些包括模型/张量并行、模型稀疏性、知识蒸馏等,新的研究也在不断涌现。利用这些技术对加速 LLM 解决方案至关重要。
需要注意的是,优化 AI 工作负载总是涉及模型、软件和硬件方面的协同。在即将发布的文章中,我们将深入探讨 LLM 加速的软件栈/库和硬件架构方面,敬请关注!
参考文献
1 Ashish Vaswani 等人,注意力即你所需,NIPS 2017,加州长滩
2 Noam Shazeer,快速 Transformer 解码:一个写头就足够,2019,arxiv
[3] Joshua Ainslie 等人,GQA:从多头检查点训练通用多查询 Transformer 模型,2023
[4] Tri Dao 等人,Flash Attention:具有 IO 感知的快速且内存高效的精确注意力,2022,arxiv
[5] Woosuk Kwon 等人,大语言模型服务的高效内存管理与 PagedAttention,2023,arxiv
[6] Charlie Chen 等人,使用猜测性采样加速大语言模型解码,2023,arxiv
LLM 和 GUI 的协同作用,超越聊天机器人
原文:
towardsdatascience.com/synergy-of-llm-and-gui-beyond-the-chatbot-c8b0e08c6801?source=collection_archive---------3-----------------------#2023-10-20
使用 OpenAI GPT 功能调用来驱动你的移动应用
汉斯·范·达姆
·
关注 发布于 Towards Data Science ·10 分钟阅读·2023 年 10 月 20 日
--
图片由 Midjourney 创建
介绍
我们引入了一种激进的用户体验(UX)方法,以最佳方式将对话式人工智能(Conversational AI)与图形用户界面(GUI)交互融合,形式为自然语言条。它位于每个屏幕的底部,允许用户通过一个入口点与整个应用程序进行交互。用户始终可以选择语言或直接操作。他们无需搜索如何完成任务,可以用自己的语言表达意图,同时保持 GUI 的速度、紧凑性和可操作性。GUI 的屏幕定义与用户的请求一起发送到大型语言模型(LLM),让 LLM 引导 GUI 朝向用户的意图。我们在上一篇文章中介绍了这一概念,并在此基础上进行了优化,实施了一个 Flutter 示例应用程序,点击这里试用。完整的 Flutter 代码可以在 GitHub 上找到,因此你可以在自己的上下文中探索这一概念。一个简短的视频解释了该功能,点击这里观看。本文面向产品负责人、UX 设计师和移动开发人员。
背景
自然语言接口和图形用户界面(GUIs)将人类用户与计算机系统的能力连接起来。自然语言允许人们在即时性之外交流,而指点允许对世界上具体事物进行沟通。指点相对于产生和处理自然语言来说,要求对方的认知努力更少,也减少了混淆的可能。然而,自然语言可以传达关于整个世界的信息:具体的、抽象的、过去的、现在的、未来的,以及元世界,提供对一切的随机访问。
随着 ChatGPT 的兴起,自然语言处理(NLP)的解读质量已达到了高水平,利用‘功能调用’,现在可以创建完整的自然语言接口,减少误解的发生。当前 LLM 社区的趋势集中在聊天界面作为主要的对话用户界面。这种方法源于聊天是书面的人际交互的主要形式,并在滚动窗口中保留对话历史。许多信息适合图形表示。一种常见的方法是将 GUI 元素融入聊天对话中。然而,这样的成本是聊天历史变得庞大,并且在聊天历史中管理 GUI 元素的状态是复杂的。此外,通过完全采用聊天范式,我们失去了向用户提供菜单驱动交互路径的选项,使他们在应用程序的功能方面更加模糊。
这里采用的方法可以应用于各种应用程序,例如银行、购物和旅行应用。移动应用的最重要功能位于主屏幕上,但其他选项卡或菜单中的功能可能让用户很难找到。当用户可以用自己的语言表达请求时,他们自然会被引导到最有可能满足需求的屏幕上。当最重要的功能在主屏幕上时,针对这一核心功能的可选项数量可能会使人不知所措。自然语言从另一端接近这一点:用户主动表达他们想要的内容。将这两者结合起来,可以实现最佳状态,即两者互补,用户可以选择最适合其任务或子任务的选项。
自然语言条
自然语言条(NLB)允许用户输入或说出他们想要从应用程序中得到什么。与他们的请求一起,所有屏幕的定义都通过 OpenAI 创造的“函数调用”技术发送到 LLM。在我们的概念中,我们将 GUI 屏幕视为应用程序中的一个函数,其中屏幕上的用户输入小部件被视为该函数的参数。
我们将以银行应用程序为例来说明这一概念。当用户用自然语言发出请求时,LLM 会告诉我们应用中的导航组件打开哪个屏幕以及设置哪些值。这在以下图中进行了说明:
以下图像提供了一些交互示例:
以下图像展示了 LLM 得出的结论。它得出的最佳方式是展示您附近的银行网点:
以下示例展示了即使显著缩短的表达也可能达到用户期望的结果:
因此,自由输入也可以是一种非常快速的交互模式。这种缩略语的正确解释取决于其背后意图的明确性。在这种情况下,应用程序没有其他屏幕可以转移,因此 LLM 可以做出明确的决定。
另一个额外功能是交互有历史记录,因此用户可以继续输入以纠正之前的意图:
因此,LLM 可以结合几条消息,其中一条纠正或增强另一条,以产生所需的函数调用。这对于旅行规划应用程序非常方便,用户最初只提到出发地和目的地,然后在后续消息中添加额外的要求,如日期、时间、仅直达连接、仅头等舱等。
点击这里亲自尝试示例应用程序。语音输入在 Chrome 浏览器以及 Android 和 iOS 本地环境中有效。使用的是平台提供的语音识别,因此如果质量不足以满足您的目的,还有改进的空间。
工作原理
当用户在自然语言栏中提出问题时,会向 LLM 的提示中添加一个JSON 模式。JSON 模式定义了所有屏幕及其输入元素的结构和目的。LLM 尝试将用户的自然语言表达映射到这些屏幕定义中的一个。它返回一个 JSON 对象,以便您的代码可以进行‘函数调用’以激活相应的屏幕。
功能和屏幕之间的对应关系在下图中进行了说明:
完整的功能规范可以在这里查看。
自然语言栏的 Flutter 实现基于LangChain Dart,这是 LangChain 生态系统的 Dart 版本。所有提示工程都发生在客户端。将屏幕、导航逻辑和功能模板保存在一起更有意义。由于一对一关系,功能模板被整合到导航结构中。以下展示了激活并导航到信用卡屏幕的代码:
DocumentedGoRoute(
name: 'creditcard',
description: 'Show your credit card and maybe perform an action on it',
parameters: [
UIParameter(
name: 'limit',
description: 'New limit for the card',
type: 'integer',
),
UIParameter(
name: 'action',
description: 'Action to perform on the card',
enumeration: ['replace', 'cancel'],
),
],
pageBuilder: (context, state) {
return MaterialPage(
fullscreenDialog: true,
child: LangBarWrapper(
body: CreditCardScreen(
label: 'Credit Card',
action: ActionOnCard.fromString(
state.uri.queryParameters['action']),
limit:
int.tryParse(state.uri.queryParameters['limit'] ?? ''))));
}),
在顶部,我们看到这是一个路由:应用程序路由系统中的一个目标,可以通过超链接激活。描述部分是 LLM 用来将屏幕与用户意图匹配的部分。下面的参数(信用卡额度和要采取的操作)定义了自然语言中的屏幕字段,以便 LLM 可以从用户的问题中提取这些字段。然后,pageBuilder-item 决定如何使用深层链接的查询参数来激活屏幕。可以在langbar-1d3b9.web.app/home
中识别这些参数:在 NLB 中输入‘信用卡额度为 10000’,浏览器的地址栏将显示:langbar-1d3b9.web.app/creditcard?limit=10000
。
使用了 LangChain 代理,这使得这种方法独立于 GPT,因此也可以使用其他 LLM 如 Llama、Gemini、Falcon 等。此方法还便于添加基于 LLM 的辅助功能。
历史面板
自然语言栏提供了一个可折叠的互动历史面板,用户可以轻松重复以前的陈述。这样,互动历史会被保存,类似于聊天界面,但以紧凑、可折叠的形式出现,节省屏幕空间并防止混乱。用户之前的语言陈述会以用户使用的语言显示。系统响应会作为超链接包含在用户陈述中,可以点击以重新激活对应的屏幕:
当 LLM 无法完全确定要激活的屏幕时,系统响应会被明确显示,此时历史面板会自动展开。这种情况可能发生在用户提供的信息过少、用户的请求超出应用程序的范围,或发生错误时:
未来
历史面板是提供客户支持和上下文敏感帮助的绝佳场所,采用聊天机器人形式。写作时,有关 RAG(检索增强生成)技术的讨论和演变非常活跃,这些技术使聊天机器人能够根据贵组织提供的大量文本内容回答用户的问题。此外,自然语言栏是想象如何利用自然语言赋予应用程序更多力量和便捷的良好起点。
客户支持
互动的历史面板是嵌入客户支持对话的好地方。这些对话占用的垂直空间比文本中大多数示例更多。在客户支持对话中,贵组织的回答是语言表达,无论是由聊天机器人还是人工服务人员生成。它们需要完全显示,而不是嵌入超链接中。但这没关系,因为否则这些空间会被其他地方占用。贵组织可能已经在网站或应用程序上拥有一个聊天机器人。将其与自然语言栏的历史面板统一是合乎逻辑的。
上下文敏感的帮助
在上述描述的背景下,我们保持与应用程序的语言互动历史。未来,我们可能(隐形地)将直接用户交互的轨迹添加到该历史序列中。通过将用户交互的历史轨迹与 RAG 结合在应用程序的帮助文档中,可以提供上下文敏感的帮助。用户的问题将更能根据应用程序当前状态得到回答。
超越移动应用程序的静态辅助
当前提议是一个 MVP(最简可行产品)。它提供了一个静态模板,用于在应用程序的上下文中解释用户的语言请求。这种技术为未来的广泛改进打开了广阔的前景:
-
当用户在特定屏幕上提问时,我们可能能够动态地将更多具体的解释模板(功能)添加到提示中,这些模板依赖于该屏幕的状态,例如‘为什么提交按钮变灰/禁用?’。
-
使用自然语言栏进行功能调用可以作为创意应用的助手,例如执行像‘调整为相同大小’或‘转变为可重用组件’这样的操作。微软 Copilot 365 已经在使用类似的功能。本文中采用的方法也可以帮助您的组织利用这些功能。
与系统每个方面的自然语言交互将迅速成为每个 UI 的主要组成部分。在使用‘功能调用’时,您必须在提示中包含系统能力,但很快会有更经济和更强大的方法上市。例如,OpenAI 最近开放了模型微调与功能调用,允许您创建一个具有系统能力的 LLM 版本。即使这些能力非常广泛,提示的负担仍然有限。
结论
LLMs 可以通过‘功能调用’成为与基于 GUI 的应用程序进行自然语言交互的绝佳桥梁。引入了一个自然语言栏,允许用户输入或说出他们的意图。系统将通过导航到正确的屏幕并预填正确的值来回应。示例应用程序使您实际体验到这一点,提供的源代码使得如果您使用 Flutter,可以快速将其应用到自己的应用程序中。自然语言栏不仅适用于 Flutter 或移动应用程序,还可以应用于任何具有 GUI 的应用程序。它的最大优势是可以从一个单一的访问点打开应用程序的全部功能,而无需用户知道如何操作、在哪里找到功能,甚至不需要了解应用程序的术语。从应用开发的角度来看,您只需简单地记录屏幕的目的和屏幕上的输入小部件,就可以向用户提供这一切。
请在评论中分享您的想法。我非常好奇。
请在LinkedIn或UXX.AI上关注我。
特别感谢David Miguel Lozano帮助我完成LangChain Dart。
一些有趣的文章:多模态对话、谷歌关于 GUI 和 LLM 的博客、将 GUI 交互解释为语言、LLM 驱动的助手、语言与 GUI、聊天机器人与 GUI
除非另有说明,本文中的所有图片均由作者提供
SynthDiD 101:Synthetic Difference-in-Differences 初学者指南
原文:
towardsdatascience.com/synthdid-101-a-beginners-guide-to-synthetic-difference-in-differences-84fed9b730ae
关于该方法的优缺点,使用 R 中的 synthdid 包进行演示
Nazlı Alagöz
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 4 月 26 日
--
由作者使用Nightcafe生成的标题图像
在这篇博客文章中,我简要介绍了 Synthetic Difference-in-Differences(SynthDiD)方法,并讨论了它与传统的 Difference-in-Differences(DiD)和 Synthetic Control Method(SCM)的关系。SynthDiD 是 SCM 和 DiD 的一个概括版本,结合了两种方法的优势。它可以在大面板数据下进行因果推断,即使在短期的前处理期间也能有效。我讨论了这种方法的优缺点,并使用 R 中的synthdi包演示了这种方法。我提供了简要的要点介绍。
Synthetic Control Method 与 Synthetic Difference-in-Differences
Synthetic Control Method 和 Synthetic Difference-in-Differences 方法紧密相关,但在估计因果效应的方式上有所不同。Synthetic Control Method 是一种统计技术,通过结合多个与处理单元在所有相关特征上相似的对照单元,创建一个“合成”的对照组。合成对照组的构建旨在尽可能接近处理单元的前处理结果。然后,通过将处理单元的后处理结果与合成对照组的结果进行比较,来估计处理效应。
另一方面,合成 DiD 结合了合成控制方法和差分中的差分方法1。在这种方法中,使用与合成控制方法相同的方法构建一个合成控制组。然而,通过比较处理单位和合成控制组在处理前后结果的变化来估计处理效果。这种方法通过考虑处理组和控制组之间的先前差异来允许对处理效果进行更稳健的估计。
总结来说,尽管两种方法都使用合成控制组,但合成控制方法通过比较处理单位和合成控制组的后处理结果来估计处理效果,而合成 DiD 通过比较处理单位和合成控制组在引入处理前后的结果变化来估计处理效果。
合成 DiD 的要点:
-
SynthDiD 是 SCM 和 DiD 的广义版本。
-
它借鉴了 DiD 方法和合成控制方法的优点[2][3]。
-
它通过最优加权控制组单位来构建处理组的反事实,以最小化处理组和控制组在处理前阶段的差异,就像在 SCM 中一样。
-
然后,通过比较处理单位和合成控制组在干预前后结果的变化来估计处理效果,就像在 DiD 方法中一样。
-
SynthDiD 考虑了单位级别的结果变化,就像 DiD 方法一样[4]。
-
即使在处理阶段很短的情况下,它也能在广泛的面板数据中进行推断,这使其与合成控制方法有所区别(SCM 需要较长的处理阶段)。
-
与 SCM 相同,单位变成了“变量”,我们将结果表示为单位的加权平均值(即,合成控制)。
示例
假设我们是一家销售植物性食品产品的公司,例如豆奶或豆酸奶,我们在多个国家运营。一些国家实施了新法规,禁止我们将植物性产品标记为“牛奶”或“酸奶”,因为声明只有动物产品可以被标记为“牛奶”或“酸奶”(感谢我的一位前学生提供这个示例的灵感 😃。因此,由于这些国家的新规定,我们必须将豆奶标记为豆饮料,而不是豆奶等。我们想知道这项立法对我们收入的影响,因为这可能有助于指导我们在不同国家的游说和营销活动。
我模拟了一个平衡的面板数据集,显示了我们公司在 30 个不同国家 30 个时期的收入。三个国家在 20 期实施了这项立法。下图中可以看到数据的快照。treat
是一个虚拟变量,指示一个国家是否在给定时期实施了立法。revenue
是以百万欧元计的收入。你可以在这个 Gist中找到模拟和估计代码。
# Install and load the required packages
# devtools::install_github("synth-inference/synthdid")
library(synthdid)
library(ggplot2)
library(data.table)
# Set seed for reproducibility
set.seed(12345)
source('sim_data.R') # Import simulation function and some utilities
dt <- sim_data()
head(dt)
数据快照,图片由作者提供。
接下来,我们将面板数据转换为synthdid
包所需的矩阵。给定结果、处理和对照单元以及处理前期,创建一个合成对照,并使用synthdid_estimate
函数估计治疗效果。为了进行推断,我们还需要计算标准误差。如果有多个处理单元,我使用jacknife
方法。如果只有一个处理单元,placebo
方法是唯一的选择。根据标准误差,我还计算了治疗效果的 95%置信区间。我将在下图中报告这些结果。
# Convert the data into a matrix
setup = panel.matrices(dt, unit = 'country', time = 'period',
outcome = 'revenue', treatment = 'treat')
# Estimate treatment effect using SynthDiD
tau.hat = synthdid_estimate(setup$Y, setup$N0, setup$T0)
# Calculate standard errors
se = sqrt(vcov(tau.hat, method='jackknife'))
te_est <- sprintf('Point estimate for the treatment effect: %1.2f', tau.hat)
CI <- sprintf('95%% CI (%1.2f, %1.2f)', tau.hat - 1.96 * se, tau.hat + 1.96 * se)\
我们还将结果绘制一些附加数据。
# Check the number of treatment and control countries to report
num_treated <- length(unique(dt[treat==1]$country))
num_control <- length(unique(dt$country))-num_treated
# Create spaghetti plot with top 10 control units
top.controls = synthdid_controls(tau.hat)[1:10, , drop=FALSE]
plot(tau.hat, spaghetti.units=rownames(top.controls),
trajectory.linetype = 1, line.width=.75,
trajectory.alpha=.9, effect.alpha=.9,
diagram.alpha=1, onset.alpha=.9, ci.alpha = .3, spaghetti.line.alpha =.2,
spaghetti.label.alpha = .1, overlay = 1) +
labs(x = 'Period', y = 'Revenue', title = 'Estimation Results',
subtitle = paste0(te_est, ', ', CI, '.'),
caption = paste0('The number of treatment and control units: ', num_treated, ' and ', num_control, '.'))
在下图中,展示了估计结果。观察治疗国家和合成对照的平均趋势如何相对平行(它可能看起来不完全平行,但对于本示例而言并不必要)。治疗国家的平均值变动较大,主要由于只有三个这样的国家,导致趋势不够平滑。透明的灰色线条代表不同的对照国家。自 20 期的处理开始后,治疗国家的收入下降,估计为 0.51 百万欧元,如图所示。这意味着新法规对我们公司的收入有负面影响,应采取必要措施以防止进一步下降。
结果,图片由作者提供。
让我们绘制用于估计合成对照的权重。
# Plot control unit contributions
synthdid_units_plot(tau.hat, se.method='jackknife') +
labs(x = 'Country', y = 'Treatment effect',
caption = 'The black horizontal line shows the actual effect;
the gray ones show the endpoints of a 95% confidence interval.')
在下图中,你可以观察到每个国家在构建合成对照时的加权情况。治疗效果根据选定的未处理国家作为对照单元有所不同。
国家权重,图片由作者提供。
现在我们对 SynthDiD 有了更多的了解,接下来讨论一下这种方法的优缺点。每种方法都有其优缺点,SynthDiD 也不例外。以下是开始使用这种方法时需要注意的一些优缺点。
SynthDiD 方法的优点:
-
合成控制方法通常用于少量处理和控制单位,并且需要处理前的长期平衡数据。而 SynthDiD,即使在处理前的数据周期较短的情况下也能很好地工作,这与合成控制方法不同[4]。
-
该方法之所以被优先考虑,特别是因为它不像 DiD 那样有严格的平行趋势假设(PTA)要求。
-
SynthDiD 保证了控制单位的适当数量,考虑了可能的干预前模式,并可能容纳一定程度的内生处理时机[4]。
SynthDiD 方法的缺点:
-
计算可能比较昂贵(即使只有一个处理组/区块)。
-
需要一个平衡面板(即,你只能使用在所有时间段都被观察到的单位),并且处理时机对所有处理单位是相同的。
-
需要足够的处理前时间段以获得良好的估计,因此,如果没有足够的处理前时间段,可能更适合使用普通的 DiD。
-
计算和比较子组的平均处理效应很棘手。一个选择是将样本拆分为子组,并计算每个子组的平均处理效应。
-
实施 SynthDiD 时,如果处理时机各不相同,可能会很棘手。对于错峰处理时机,可以考虑为每个处理队列估计平均处理效应,然后将队列特定的平均处理效应汇总为总体平均处理效应。
这里还有一些其他你可能想知道的起步时要点。
需要注意的事项:
-
SynthDiD 采用正则化岭回归(L2),同时确保结果权重的总和为 1。
-
在处理前匹配的过程中,SynthDiD 尝试确定整个样本的平均处理效应。这种方法可能会导致个别时间段的估计不够精确。然而,总体平均值能提供一个无偏的评估。
-
处理效应的标准误差是通过抛弃法(jacknife)估计的,或者如果一个队列只有一个处理单位,则使用安慰剂方法。
-
估计量被认为是一致且渐近正态的,前提是控制单位数量和处理前时间段的组合相对于处理单位数量和处理后时间段的组合足够大。
-
在实践中,处理前变量在合成 DiD 中作用较小,因为滞后结果具有更强的预测能力,使得这些变量的处理不那么重要。
结论
在这篇博客文章中,我介绍了 SynthDiD 方法,并讨论了它与传统 DiD 和 SCM 的关系。SynthDiD 结合了 SCM 和 DiD 的优势,即使在处理期较短的情况下也能进行因果推断。我使用 R 中的 synthdid 包来演示该方法。尽管它有几个优点,比如不需要严格的平行趋势假设,但也有缺点,如计算开销大且需要平衡面板。总体而言,SynthDiD 是一个对有兴趣使用观察数据估计因果效应的研究人员非常有价值的工具,为传统的 DiD 和 SCM 方法提供了替代方案。
参考文献
1 D. Arkhangelsky, S. Athey, D.A. Hirshberg, G.W. Imbens, 和 S. Wager, 合成差分法 (2021),美国经济评论。
2 A. Abadie, A. Diamond, J. Hainmueller, 比较案例研究的合成控制方法:估计加州烟草控制计划的效果 (2010),美国统计协会杂志。
[3] A. Abadie, 使用合成控制:可行性、数据要求和方法论方面 (2021),经济学视角杂志。
[4] Berman, R., & Israeli, A., 描述性分析的价值:来自在线零售商的证据 (2022),营销科学。
有用的链接
勇敢和真实的因果推断,合成差分法。
Matteo Courthhoud, 了解合成控制方法。
感谢阅读!
如果你喜欢这篇文章并希望看到更多我的文章,可以考虑 关注我。
免责声明:我写作是为了学习,因此你可能会在文章或代码中发现错误。如果发现,请告知我。
合成控制:如果我们可以模拟替代现实呢?
原文:
towardsdatascience.com/synthetic-control-what-if-we-could-simulate-alternate-realities-4e88eb69d7b9
一种更好的政策评估方法
布鲁诺·波内
·发表于 Towards Data Science ·6 分钟阅读·2023 年 6 月 10 日
--
图片由 Hubert Buratynski 提供,来源于 Unsplash
看哪,看那边
太阳已经消失的地方
《莫比 — 最后一日》
当我在高中时,我清楚地记得问我的历史老师:“如果罗马帝国没有崩溃,我们今天的技术会有多先进?”她并不特别欣赏我的问题。事实上,历史学家往往对“如果”问题持保留态度,有时称之为反事实历史。他们更喜欢解释和说明事件的发生,而不是可能发生的情况。他们的工作基于事实、来源和证据,而“如果”场景可能会导致猜测或推测,影响对历史现实的严谨分析。
作为一个青少年时期的内省梦想者,我不断想象如果我们没有经历中世纪,会发生什么。实证科学是否会更早发展?几个世纪以来,战争是否会更频繁发生?我们是否会更好地照顾我们的星球?
这样的问题依然开放,因为一旦发展发生,就不可能经历一个没有发生该发展的替代现实。这本质上是因果推断的基本问题,即研究因果关系的科学。例如,如果政府决定实施禁止饮酒的政策,这是否会导致车祸死亡率的下降?理想情况下,这个因果关系问题应通过比较在没有禁令的实际世界与只有政策实施的平行世界中的车祸死亡率来回答。在这个理想的场景中,政策的影响将是观察到的无政策死亡率与有政策死亡率之间的差异。显然,这并不可行,因为我们只能访问我们自己的现实。
激励市长是否能改善教育?
我一直对教育和公共政策的动态感兴趣,特别是它们如何相互作用来塑造社会的未来。当选择硕士论文的主题时,我希望探讨一个相关、具有影响力并且有现实世界影响的课题。我想深入探讨一个可能为改善巴西教育系统提供见解的主题,不仅在理论上,而且在实践中也是如此。正是在这个过程中,我发现了巴西塞阿拉州实施的两个有趣的教育政策。
第一个政策是对市长的税收激励(TI),以改善市政教育。这是一种创新的方法,将市政税收转移与教育成就挂钩,鼓励地方政府更多地投资于教育系统。第二个政策是向市政府提供教育技术援助(TA)的项目,为他们提供改善教育实践所需的资源。
一些描述性图表显示,尽管塞阿拉州的资源投入较少,但相较于其他州其表现有所提升,如下图所示。纵轴展示了学生在数学和葡萄牙语测试中的正面分数变化,而横轴则显示了教育平均支出。
来源:Ponne, B. G. (2023)¹
为确保政策确实导致了这些改善,我需要更深入地分析这些政策,再次遇到了因果推断的基本问题:如果塞阿拉州没有采取这些政策,会怎样?他们的教育指标会更差吗?换句话说,这些政策是否对教育成就产生了积极影响?我没有一个完美的反事实,一个政策未被采纳的替代塞阿拉州。幸运的是,因果推断提供了一些近似反事实的方法。其中之一是合成控制法。
合成控制法
合成控制方法是一种统计技术,主要用于评估政策变化或其他干预措施的效果,当没有对照组时使用。其原理是通过结合未经历政策变化的多个州,创建一个感兴趣单位(在本例中为塞阿拉州)的合成版本。这个“合成控制”作为对照——它是我们期望在没有实施政策的情况下,感兴趣单位会发生的情况。
为了构建这个合成控制,我们必须选择一组未受到政策影响的州——这些州通常称为捐赠单位。然后,合成控制作为这些捐赠单位的加权组合创建,选择的方式是使合成控制与处理单位(塞阿拉州)干预前的特征紧密匹配。实际上,合成控制代表了一个没有采纳教育政策的假设性塞阿拉州。这种解释只是概述了该方法的基本概念。要更全面地理解,请参阅 Alberto Abadie(2021)²的《使用合成控制:可行性、数据要求和方法论方面》。
一旦建立了合成控制,我们就比较处理单位(塞阿拉州)及其合成对照组的干预后结果。这两者之间的差异可以解释为干预或政策的效果。
在下面的图表中,我描绘了塞阿拉州及其人工构建的未受政策影响的塞阿拉州的数学和葡萄牙语分数趋势。请注意,在政策实施之前,合成趋势与实际趋势非常接近,但之后明显分歧。根据这种方法,在没有政策的情况下,塞阿拉州的分数将遵循黄色线所代表的轨迹。在政策影响下,塞阿拉州的实际分数由绿色线表示。这两条线之间的差异表明这些政策的积极效果。
两项政策的结合导致了葡萄牙语测试分数在小学教育中稳定提高了 12%,在初中教育中提高了 6.5%。这些结果表明,精心设计的政策可能对教育成果产生重大影响。数学领域的结果并没有统计学意义。在我发表的论文¹中,我提供了一些解释为什么会发生这种情况。
然而,我的分析也揭示了一个关注点。尽管在初级和中级教育方面取得了这些进展,但上级中学虽然没有直接受到新政策的影响,但却接收到了来自低年级的更好准备的学生,却没有显著改善。这一发现突显了政策实施中的关键缺口,并引发了对将教育政策的好处扩展到上级中学以及其他巴西州的进一步讨论的需求。
R 中的合成控制方法
我使用了R synth library来实现合成控制。这个库是 R 中估计合成控制的一个非常强大的工具。它提供了两个主要功能:
-
dataprep()
: 准备捐赠池和处理单元特征的矩阵以及它们的感兴趣结果。这些矩阵随后可以传递给synth()
。 -
synth()
: 优化权重集以形成合成单元。
这个包还提供了在基础 R 中绘制结果的函数,但你也可以像我上面做的那样,准备由synth()
交付的数据以便在 ggplot2 中绘制。查看代码:github.com/bruno-ponne/Better-Incentives-Better-Marks
最后的思考
合成控制方法为我提供了一个独特的机会来研究这些政策对塞阿拉州教育成就的因果影响,为“如果”问题提供了定量维度。通过这种方法,我的研究超越了理论推测的领域,使基于数据和统计方法的严格分析成为可能。
我一直相信教育是促进发展中国家宽容、机会和民主的关键因素。使用合成控制的方法揭示了精心设计的政策在显著改善教育成果方面的潜力。我希望这些发现能为政策制定者提供有价值的见解,以便做出明智的教育决策。
引用的文章:
¹Ponne, B. G. (2023). 更好的激励,更好的成绩:对巴西塞阿拉州教育政策的合成控制评估。巴西政治科学评论,17(1),e0005。 doi.org/10.1590/1981-3821202300010005
²Abadie, Alberto (2021), 使用合成控制:可行性、数据要求和方法学方面。经济文献杂志。59(2),第 391–425 页。
系统设计备忘单:ElasticSearch
原文:
towardsdatascience.com/system-design-cheatsheets-elasticsearch-673b98eebfff?source=collection_archive---------0-----------------------#2023-11-28
了解如何以及何时在系统中使用 ElasticSearch,并通过三个实际的系统设计示例进行说明
Sanil Khurana
·
关注 发表在 Towards Data Science ·13 分钟阅读·2023 年 11 月 28 日
--
引言
什么是搜索?它为何重要?
如果你阅读过我之前关于搜索的文章,你会知道搜索对于应用程序的重要性。想一想:在你每天使用的各种网络应用和移动应用中,无论是 Netflix、Amazon 还是 Swiggy,搜索框可能是唯一一个在所有这些应用中都存在的通用 UI 元素,而且它通常都位于主页的顶部。如果你在设计一个系统,十有九九,你会考虑如何支持搜索功能。
建立一个搜索系统不是小事,但一个很好的起点是 ElasticSearch。如果你对搜索或推荐系统的工作原理一无所知,这篇博客文章是一个很好的起点。我们将讨论 ElasticSearch 是什么,它的适用场景和不适用场景,以及 ElasticSearch 常见的三种设计。搜索系统还有很多其他属性,但这些内容将在文章的后面详细讨论。
什么是 ElasticSearch?
ElasticSearch 是一个流行的数据库,它处理的是大多数数据库难以应对的任务:搜索。搜索对于 ElasticSearch 来说至关重要,它甚至体现在它的名字里!
但是如果你还没有听说过 ElasticSearch,你可能会想:为什么搜索这么困难?为什么关系型数据库无法进行搜索?大多数关系型数据库支持多种搜索和过滤数据的方法,比如 WHERE
查询、LIKE
关键字或索引。或者为什么像 MongoDB 这样的文档数据库无法工作?你也可以在 MongoDB 中编写 find
查询。
要理解这个答案,想象你正在建立一个新闻网站。当用户使用你的搜索框搜索新闻时,例如“新德里 COVID19 感染”,用户对所有 讨论 新德里 COVID 感染的文章感兴趣。在一个简单的搜索系统中,这意味着扫描数据库中的所有文章,并返回那些包含“COVID19”、“感染”或“新德里”这些词的文章。你不能用关系型数据库做到这一点。关系型数据库允许你基于特定属性搜索文章,比如特定作者撰写的文章或今天发布的文章等,但它不能(至少不能高效地)执行一个扫描 每一篇 新闻文章(通常是数千万篇)并返回那些包含特定单词的搜索。
此外,还有许多复杂的细节需要考虑。你如何对这些文章进行评分?也许有一篇文章讨论了 COVID19 的传播,也许有一篇讨论了新感染,你如何知道哪一篇与用户查询更相关,换句话说,你如何根据相关性对这些文章进行排序?
答案是:ElasticSearch!ElasticSearch 可以立即提供所有这些功能和更多。
但是,和世界上其他一切事物一样,它也有其自身的缺点。让我们来探讨一下 ElasticSearch 是什么,什么时候使用它,最重要的是,什么时候不适合使用它。
ElasticSearch
搜索能力
ElasticSearch 提供了一种执行“全文搜索”的方法。全文搜索指的是在大量文档中搜索一个短语或单词。继续我们之前的例子,假设你正在构建一个包含数百万篇新闻文章的新闻网站。每篇文章包含一些数据,比如标题、副标题、文章内容、发布时间等。在 ElasticSearch 的上下文中,每篇文章被存储为一个 JSON 文档。
你可以将所有这些文档加载到 ElasticSearch 中,然后在几毫秒内搜索每个文档中的特定单词或短语。因此,如果你加载所有新闻文章,然后执行一个搜索,比如“COVID19 感染在德里”,ElasticSearch 会返回所有包含“COVID19”、“感染”或“德里”这些词的文章。
为了演示在 ElasticSearch 中的搜索,我们来设置 Elasticsearch 并加载一些数据。对于本文,我将使用我在 Kaggle 上找到的这个新闻数据集(Misra, Rishabh. “News Category Dataset.” arXiv 预印本 arXiv:2209.11429 (2022)) (来源) (许可证)。该数据集非常简单,包含约 210,000 篇新闻文章,涵盖标题、简短描述、作者以及一些我们不太关注的其他字段。我们并不需要所有 210,000 篇文档,因此我会加载大约 10,000 篇文档到 ES 中并开始搜索。
这些是数据集中一些文档的示例——
[
{
"link": "https://www.huffpost.com/entry/new-york-city-board-of-elections-mess_n_60de223ee4b094dd26898361",
"headline": "Why New York City’s Board Of Elections Is A Mess",
"short_description": "“There’s a fundamental problem having partisan boards of elections,” said a New York elections attorney.",
"category": "POLITICS",
"authors": "Daniel Marans",
"country": "IN",
"timestamp": 1689878099
},
....
]
每个文档代表一篇新闻文章。每篇文章包含一个 link
、headline
、一个 short_description
、一个 category
、authors
、country
(随机值,由我添加)和 timestamp
(同样是随机值,由我添加)。
Elasticsearch 查询是用 JSON 编写的。在深入探讨所有不同的语法之前,我们先从简单的开始,逐步构建。
最简单的全文查询之一是 multi_match
查询(不用太担心在 ElasticSearch 中查询数据,它非常简单,我们将在文章末尾讨论)。其思想很简单,你编写一个查询,Elasticsearch 执行全文搜索,实质上扫描你数据库中的所有文档,找到包含查询中单词的文档,给它们分配一个评分,并返回这些文档。例如,
GET news/_search
{
"query": {
"multi_match": {
"query": "COVID19 infections"
}
}
}
上述查询找到了与“COVID19 感染”相关的文章。这些是我得到的结果 -
[
{
"_index" : "news",
"_id" : "czrouIsBC1dvdsZHkGkd",
"_score" : 8.842152,
"_source" : {
"link" : "https://www.huffpost.com/entry/china-shanghai-lockdown-coronavirus_n_62599aa1e4b0723f8018b9c2",
"headline" : "Strict Coronavirus Shutdowns In China Continue As Infections Rise",
"short_description" : "Access to Guangzhou, an industrial center of 19 million people near Hong Kong, was suspended this week.",
"category" : "WORLD NEWS",
"authors" : "Joe McDonald, AP",
"country" : "IN",
"timestamp" : 1695106458
}
},
{
"_index" : "news",
"_id" : "ODrouIsBC1dvdsZHlmoc",
"_score" : 8.064016,
"_source" : {
"link" : "https://www.huffpost.com/entry/who-covid-19-pandemic-report_n_6228912fe4b07e948aed68f9",
"headline" : "COVID-19 Cases, Deaths Continue To Drop Globally, WHO Says",
"short_description" : "The World Health Organization said new infections declined by 5 percent in the last week, continuing the downward trend in COVID-19 infections globally.",
"category" : "WORLD NEWS",
"authors" : "",
"country" : "US",
"timestamp" : 1695263499
}
},
....
]
正如你所见,它返回了讨论 COVID19 感染的文档。它还按相关性顺序对这些文档进行排序(_score
字段表示特定文档的相关性)。
ElasticSearch 具有丰富的查询语言和大量功能,但目前只需知道构建一个简单的搜索系统非常容易,只需将所有数据加载到 ElasticSearch 中,并使用我们讨论过的简单查询即可。我们有许多选项可以改进、配置和调整搜索性能和相关性(关于搜索查询的更多内容将在本文末尾讨论)。
分布式架构
ElasticSearch 作为分布式数据库工作。这意味着在一个 ElasticSearch 集群中有多个节点。如果一个节点变得不可用或失败,这通常不会导致系统停机,其他节点通常会接管额外的工作并继续服务用户请求。因此,多个节点有助于提高可用性。
多个节点还帮助我们扩展系统,数据和用户请求可以在这些节点之间划分,从而减少每个节点的负载。例如,如果你想在 ElasticSearch 中存储 1 亿篇新闻文章,你可以将这些数据分割到多个节点上,每个节点存储一部分文章。实际上,这非常简单,ElasticSearch 提供了内置功能来使这一过程尽可能简单和无缝。
扩展性
ElasticSearch 横向扩展,能够将数据分区到多个节点。这意味着你可以通过增加更多节点来始终提高查询性能。
关于架构你的 ElasticSearch 集群,思考的过程远不止于增加更多服务器。不同类型的节点运行着称为“shards”的进程,每个 shard、节点,可以有多种类型和配置选项。关于 ElasticSearch 集群的架构及其工作原理有很多内容可以讨论,如果你想更深入了解,可以查看我写的完整文章 这里。
总结:你可以添加更多机器来扩展你的集群并提高性能。数据和查询会被分配到多个机器上,这有助于提高性能和扩展性。
基于文档的数据建模
ElasticSearch 是一个文档数据库,它以 JSON 文档格式存储数据,类似于 MongoDB。因此,在我们的例子中,每篇新闻文章都作为 JSON 文档存储在集群中。
实时数据分析
实时数据分析是实时查看用户行为并了解用户模式和行为。我们可以绘制用户行为图表,更好地理解我们的用户,从而改进我们的产品。例如,假设我们测量每个用户在新闻网站上的每一次点击、滚动事件和阅读时间。我们将这些指标绘制在仪表板上并观察几天。通过这些数据,我们可以收集大量可操作的见解来改进我们的新闻应用。我们发现用户通常在早上 9-10 点使用网站,并且发现用户通常点击与他们国家相关的文章。利用这些信息,我们可以在高峰期(早上 9-10 点)超配资源,并可能在用户的首页上显示来自他们国家的文章。
Elasticsearch 由于其分布式架构和强大的搜索能力,非常适合实时数据分析。当处理实时数据时,如日志、指标或社交媒体更新,Elasticsearch 能高效地索引和存储这些信息。它的近实时索引使得数据在摄取后几乎可以立即被搜索。Elasticsearch 还可以很好地与其他工具配合使用,如用于可视化的 Kibana 或用于收集指标的 Logstash 和 Beats。
在文章的末尾,我们将探讨一种有助于实现这一点的架构。
成本
ElasticSearch 的运行和维护成本很高。正如世上所有美好事物都需付出代价一样,为了执行全文搜索,ElasticSearch 会在 RAM 中保持大量数据并构建复杂的索引。这意味着它需要大量的 RAM 来运行,这也是一笔不小的开支。
简而言之,ElasticSearch 在执行全文搜索时提供了惊人的性能,但它并不便宜。
什么时候不该使用 ElasticSearch
ACID 合规性
ElasticSearch 像大多数 NoSQL 数据库一样,对 ACID 的支持非常有限,因此如果你需要强一致性或事务支持,ElasticSearch 可能不是适合你的数据库选择。其后果是,如果你在 ElasticSearch 中插入一个文档(称为“索引”一个文档),该文档可能不会立即对其他节点可见,并且可能需要几毫秒才能被其他节点看到。
比如说,你正在构建一个银行系统;如果用户向其账户中存款,你希望这些数据能立即对用户执行的其他交易可见。另一方面,如果你使用 ElasticSearch 为你的新闻网站提供搜索服务,当一篇新文章发布时,文章在前几毫秒内对所有用户不可见可能是可以接受的。
当你需要复杂的联接时
ElasticSearch 不支持 JOIN 操作或不同表之间的关系。如果你习惯使用关系型数据库,这可能会让你感到有些震惊,但大多数 NoSQL 数据库对这些类型的操作支持有限。
如果你需要执行 JOIN 操作或使用外键来处理高度相关的结构化数据,ElasticSearch 可能不是你用例的最佳选择。
小型数据集或简单查询需求
ElasticSearch 复杂且成本高昂。运行和管理大型 ElasticSearch 集群不仅需要软件工程师和 DevOps 工程师的知识和技能,还可能需要擅长管理和架构 ElasticSearch 集群的专家,称为“ElasticSearch 架构师”。有大量的配置选项和架构选择可以尝试,每一个都对你的查询和摄取产生重要影响,从而间接影响系统核心流程中的用户体验。
如果你只需要执行简单的查询或数据量相对较少,那么简单的数据库可能更适合你的应用程序。
如何在系统设计中使用 ElasticSearch
一个单一的软件系统通常需要多个数据库,每个数据库支持不同的功能。让我们通过一个例子来更好地理解使用 ElasticSearch 的设计选择。
假设你想构建一个视频流媒体服务,比如 Netflix。让我们看看 ElasticSearch 在这个例子中可以适应的地方。
作为搜索系统
ElasticSearch 的一个非常常见的用例是作为支持全文搜索查询的辅助数据库。这对我们的在线视频应用非常有用。我们不能将视频存储在 ElasticSearch 中,并且我们可能也不想将与计费或用户相关的数据存储在 ElasticSearch 中。
为此,我们可以使用其他数据库,但我们可以将电影标题、描述、类型、评分等信息存储在 ElasticSearch 中。
我们可以有一个类似这样的架构:
作者提供的图片
我们可以将我们希望支持全文搜索的数据摄取到 ElasticSearch 中。当用户执行搜索操作时,我们可以查询 ElasticSearch 集群。这样,我们就可以利用 ElasticSearch 的全文搜索功能,当我们需要更新用户信息时,可以在我们的主要存储中执行这些更新。
作为实时数据分析管道
正如我们讨论的,了解用户行为和模式是决定如何发展产品的关键步骤。我们可以发布事件,例如点击流事件和滚动事件,以更好地理解用户如何使用我们的产品。
例如,在我们的在线视频应用中,我们可以在用户点击电影或节目时发布包含用户和电影数据的事件。然后我们可以分析和绘制汇总图表,以更好地理解用户如何使用我们的产品。例如,我们可能会注意到用户在晚上使用我们的产品的频率比在下午高,或者用户可能更喜欢用本国语言而非其他语言的节目或电影。利用这些信息,我们可以改进我们的产品,提升用户体验。
这就是使用 ElasticSearch 和 Kibana(一个与 ElasticSearch 配合良好的仪表板工具)的实时数据分析基本系统的样子:
作者提供的图片
作为推荐系统
我们可以在 ElasticSearch 中构建查询,以对某些属性给予更多优先级(称为提升)。例如,与简单查询相比
我们可以使用 ElasticSearch 构建基本的推荐系统。我们可以存储有关用户的信息,例如用户的国家、年龄、偏好等,并生成查询,以获取该用户的热门电影节目或系列。
理解查询语言和如何提升某些字段以及执行汇总是一个较大的主题,但我在这里写了一篇涵盖基础知识的博客文章:
## 掌握 Elasticsearch: 初学者强大的搜索与精准指南 — 第一部分
在第一部分解锁 Elasticsearch 的力量:深入了解 Elasticsearch,掌握基本的搜索查询,并探索词汇…
towardsdatascience.com
结论
如何构建 ElasticSearch 集群?
构建 ElasticSearch 集群绝非易事,它需要了解节点、分片、索引以及如何协调它们。选择几乎是无限的,且领域不断发展(尤其是随着 AI 和 AI 驱动搜索的流行)。为了深入探讨,我写了一篇完整的博客文章,从基础知识到构建搜索集群所需了解的一切:
[## 系统设计系列: ElasticSearch, 搜索架构设计
理解 Elasticsearch 架构和全文搜索
betterprogramming.pub](https://betterprogramming.pub/system-design-series-elasticsearch-architecting-for-search-5d5e61360463?source=post_page-----673b98eebfff--------------------------------)
理解搜索查询并改进搜索系统
搜索是复杂的,非常复杂。有很多方法可以改进搜索系统,使其更强大并更理解用户需求。你已经了解了 ElasticSearch 及其功能。从这里开始,构建一个基本的搜索查询,理解查询和系统中的问题,并通过示例一步一步地演变和改进系统。
## 掌握 Elasticsearch: 初学者强大的搜索与精准指南 — 第一部分
在第一部分解锁 Elasticsearch 的力量:深入了解 Elasticsearch,掌握基本的搜索查询,并探索词汇…
towardsdatascience.com
上下文感知搜索
我最近读到一个很好的搜索系统类比。你可以把我们讨论的搜索系统看作是一个机械而僵化的搜索。当用户输入一个词时,我们找到所有包含该词的文档并返回它们。
或者你可以把搜索系统想象成一个图书管理员。当用户问一个问题,比如“温斯顿·丘吉尔在第二次世界大战中的角色是什么?”,图书管理员不会仅仅告诉他包含“温斯顿”、“丘吉尔”或“第二次世界大战”这些词的书籍。相反,图书管理员会评估和理解客户及其背景。也许是一个小学生,所以她不会推荐一本大教科书,而是找到一本更适合年轻孩子的书。或者她可能没有任何关于温斯顿·丘吉尔的书籍,于是她会找到一本讲述第二次世界大战或英国首相的书籍,并推荐这本书。图书管理员甚至可能会为考试和暑假作业推荐不同的书籍(你们中的一些人可能不知道,但在一些国家,暑假作业量非常大)。
对你我来说这很容易理解,但我们的系统如何知道温斯顿·丘吉尔是英国首相并推荐关于第二次世界大战期间英国的书籍,或者我们的系统如何理解讨论的背景、理解用户并推荐合适的书籍呢?
尽管看起来很困难,但实际上并没有那么难。这叫做语义搜索,它是大多数大型科技公司构建搜索系统的方式。
语义搜索是一组搜索技术,旨在理解用户查询背后的含义和内容的上下文,通过考虑单词之间的关系和搜索意图,从而提供更准确、更相关的搜索结果。
这是一个广泛的话题,我仍在阅读和理解更多内容,但即将发布一篇从基础开始的博客文章,如果你想了解更多这个话题,可以在 Medium 上关注我。
其他数据库
我写关于系统设计概念的文章,例如数据库、队列和发布-订阅系统,所以可以在 Medium 上关注我,获取类似的文章。我还在 LinkedIn 上写了很多简短的内容(例如,这篇文章讲述了 RabbitMQ 和 Kafka 的区别),所以可以在 LinkedIn 上关注我,获取更短的内容形式在这里。
同时,你可以查看我关于其他数据库和系统设计概念的博客文章-
[## Sanil Khurana 在 Medium 上策划了一些列表
开始探索 Linux、Cassandra、面试问题等
medium.com](https://medium.com/@sanilkhurana7/lists?source=post_page-----673b98eebfff--------------------------------)
系统设计系列:从零开始构建高性能数据流系统的终极指南!
原文:
towardsdatascience.com/system-design-series-0-to-100-guide-to-data-streaming-systems-3dd584bd28fa?source=collection_archive---------0-----------------------#2023-12-17
Sanil Khurana
·
关注 发表在 Towards Data Science ·12 分钟阅读·2023 年 12 月 17 日
--
来源:Unsplash
设定一个示例问题:推荐系统
“数据流”听起来非常复杂,而“数据流管道”则更为复杂。在我们讨论这意味着什么并陷入术语之前,让我们从任何软件系统存在的原因——一个问题——开始。
我们的问题非常简单,我们需要为一个电子商务网站(类似于亚马逊)建立一个推荐系统,即一个根据用户偏好返回一组产品的服务。我们暂时不需要为它如何工作而感到疲惫(稍后会详细讨论),现在我们将专注于数据如何发送到这个服务中,以及它如何返回数据。
数据以“事件”的形式发送到服务中。这些事件都是用户执行的特定操作。例如,点击特定产品或搜索查询。简单来说,我们网站上所有的用户互动,从简单的滚动到昂贵的购买,都被视为一个“事件”。
图片来源:作者
这些事件本质上告诉我们有关用户的信息。例如,一个有意购买游戏 PC 的用户可能也会对游戏键盘或鼠标感兴趣。
我们的服务不时会收到获取用户推荐的请求,它的工作很简单,响应用户感兴趣的产品列表。
图片来源:作者
目前,我们不关心这些推荐列表是如何生成的,假设这个“推荐服务”执行了一些神奇的步骤(关于这些魔法的更多内容将在文章末尾讨论,现在我们不太关心这些步骤的逻辑),并找出用户的偏好。
推荐在许多系统中通常是事后的考虑,但它比你想象的要重要得多。几乎你使用的每个应用程序都依赖于像这样的推荐服务来驱动用户行为。例如,根据这篇论文,35%的亚马逊网络销售是通过推荐商品生成的。
然而,问题在于数据的巨大规模。即使我们运行的是一个中等流行的网站,在高峰时段,我们仍可能每秒接收到数十万(甚至可能是百万)个事件!如果有新产品或大型促销活动,那么这个数量可能会更高。
但我们面临的问题不仅仅如此。我们必须实时处理这些数据(执行之前讨论的魔法),并实时向用户提供推荐!如果有促销活动,即使是几分钟的推荐更新延迟也可能对业务造成重大财务损失。
什么是数据流处理管道?
数据流处理管道就是我上面描述的那样。它是一个接收连续数据(如事件)、执行多个处理步骤并存储结果以备未来使用的系统。
在我们的案例中,事件将来自多个服务,我们的处理步骤将涉及几个“神奇”的步骤来计算用户的推荐,然后我们将在数据存储中更新每个用户的推荐。当我们收到对特定用户推荐的查询时,我们只需获取之前存储的推荐并返回它们。
这篇文章的目的是了解如何处理这种规模的数据,如何摄取、处理和输出这些数据以供以后使用,而不是了解处理步骤的实际逻辑(但我们仍会稍微探讨一下,增加一些趣味)。
创建数据流管道:逐步指南
这涉及到许多方面,如数据摄取、处理、输出和查询,因此我们一步一步来。把每一步看作是一个较小的、孤立的问题。在每一步,我们将从最直观的解决方案开始,看看它为什么不起作用,然后构建一个有效的解决方案。
数据摄取
让我们从管道的起点开始,数据摄取。数据摄取问题相对容易理解,目标只是从多个来源摄取事件。
作者提供的图片
但虽然问题看起来简单,它也有其复杂的方面,
-
数据的规模非常大,轻松达到每秒数十万事件。
-
所有这些事件必须实时摄取,我们不能有哪怕几秒钟的延迟。
我们从简单的开始,实现这一目标的最直观方法是将每个事件作为请求发送到推荐系统,但这个解决方案存在许多问题。
-
发送事件的服务不应等待我们的推荐服务的响应。这会增加服务的延迟,并在推荐服务发送 200 状态码之前阻塞它们。它们应该改为发送“火并忘”的请求。
-
事件的数量会高度波动,一整天都会上下波动(例如,晚上或促销期间会增加),我们需要根据事件的规模来扩展我们的推荐服务。这是我们需要管理和计算的内容。
-
如果我们的推荐服务崩溃,我们将会在其停机期间丢失事件。在这种架构中,我们的推荐服务是一个单点故障。
我们可以通过使用消息代理或像 Apache Kafka 这样的“事件流平台”来解决这个问题。如果你不知道这是什么,它简单来说就是一个工具,可以从“发布者”那里接收消息并发布到特定主题上。“订阅者”监听或订阅这些主题,每当在主题上发布消息时,订阅者就会收到消息。我们将在下一节中进一步讨论 Kafka 主题。
你需要了解关于 Kafka 的一点是,它促进了生产者和消费者之间的解耦架构。生产者可以在 Kafka 主题上发布消息,而不需要关心消费者何时、如何或是否消费消息。消费者可以在自己的时间消费消息并处理它。Kafka 也能够实现非常高的扩展,因为它可以水平扩展,并且线性扩展,提供几乎无限的扩展能力(只要我们继续增加更多机器)。
作者提供的图片
所以每个服务将事件发送到 Apache Kafka。推荐服务从 Kafka 中获取这些事件。让我们看看这对我们有什么帮助 —
-
事件是异步处理的,服务不再需要等待推荐服务的响应。
-
扩展 Kafka 更加容易,如果事件的规模增加,Kafka 将简单地存储更多事件,同时我们也扩展我们的推荐服务。
-
即使推荐服务崩溃,我们也不会丢失任何事件。事件被持久化在 Kafka 中,因此我们不会丢失任何数据。
现在我们知道如何将事件引入我们的服务,让我们转到架构的下一部分——处理事件。
数据处理
数据处理是我们数据管道的一个组成部分。一旦我们接收到事件,我们需要为用户生成新的推荐。例如,如果用户搜索“显示器”,我们需要基于这个搜索更新该用户的推荐,也许会添加用户对显示器感兴趣的信息。
在我们进一步讨论架构之前,让我们忘记这些,稍微谈谈如何生成推荐。这也是机器学习发挥作用的地方,虽然理解这些内容对继续阅读文章并不是非常重要,但它非常有趣,所以我会尝试给出一个非常基础的简要描述。
让我们更好地理解用户互动及其含义。当用户通过搜索、点击或滚动事件与我们的网站互动时,用户是在告诉我们他/她的兴趣。我们的目标是理解这些互动,并利用它们来了解用户。
当你想到一个用户时,你可能会想到一个有姓名、年龄等个人信息的人。然而,为了我们的目的,将每个用户视为一个向量,或者简单地说是一组数字会更容易理解。听起来可能有些令人困惑(毕竟用户怎么能用一组数字来表示呢),但请耐心看下去,我们来看看这是怎么工作的。
假设我们可以将每个用户(或他的/她的兴趣)表示为二维空间中的一个点。每个轴表示用户的一个特征。假设 X 轴表示他/她喜欢旅行的程度,Y 轴表示他/她喜欢摄影的程度。用户的每个行为都会影响这个用户在二维空间中的位置。
假设一个用户在我们的二维空间中从以下点开始 —
作者提供的图片
当用户搜索“旅行包”时,我们将点向右移动,因为这表明用户喜欢旅行。
作者提供的图片
如果用户搜索了相机,我们将把用户向 Y 轴方向上移。
我们还将每个产品表示为相同二维空间中的一个点。
作者提供的图片
上图中用户的位置表明用户喜欢旅行,并且也有一点喜欢摄影。每个产品的位置也是根据它们与摄影和旅行的相关性来放置的。
由于用户和产品只是二维空间中的点,我们可以对它们进行比较和数学运算。例如,从上面的图表中,我们可以找到离用户最近的产品,在这种情况下是行李箱,并自信地说它是用户的一个好推荐。
上面是推荐系统的一个非常基础的介绍(更多内容请参见文章末尾)。这些向量(通常比 2 维大得多)被称为嵌入(用户嵌入表示我们的用户,产品嵌入表示我们网站上的产品)。我们可以使用不同类型的机器学习模型生成它们,虽然它们比我描述的要复杂得多,但基本原理是一样的。
回到我们的问题。对于每个事件,我们需要更新用户的嵌入(在我们的 n 维图表上移动用户),并返回相关的产品作为推荐。
让我们考虑一下生成这些嵌入所需的几个基本步骤,
-
update-embeddings
:更新用户的嵌入 -
gen-recommendations
:获取与用户嵌入相关(或接近)的产品 -
save
:保存生成的推荐和事件
我们可以为每种类型的事件构建一个 Python 服务。
作者提供的图片
每个微服务会监听一个 Kafka 主题,处理事件,然后将其发送到下一个主题,在那里一个不同的服务会监听。
作者提供的图片
由于我们再次使用 Kafka 而不是发送请求,这种架构也给我们带来了之前讨论的所有优势。没有单一的 Python 微服务是单点故障,处理规模也更容易。最后一个服务save-worker
需要保存推荐以供将来使用。让我们看看它是如何工作的。
数据接收端
一旦我们处理了一个事件,并生成了推荐,我们需要存储事件和推荐数据。在决定将事件和推荐数据存储在哪里之前,让我们考虑数据存储的要求。
-
可扩展性和高写入吞吐量 — 记住我们有大量的事件到达,每个事件还会更新用户推荐。这意味着我们的数据存储应该能够处理非常高数量的写入。我们的数据库应该具有高度的可扩展性,并且能够线性扩展。
-
简单查询 — 我们不会执行复杂的 JOIN 操作,也不会进行各种类型的查询。我们的查询需求相对简单,对于给定的用户,返回预计算的推荐列表。
-
无 ACID 要求 — 我们的数据库不需要强 ACID 合规性。它不需要对一致性、原子性、隔离性和持久性提供任何保证。
简单来说,我们关注的是一个可以处理极大规模的数据库,没有额外的花里胡哨。
Cassandra 是满足这些要求的完美选择。由于其去中心化的架构,它可以线性扩展,并且能够处理非常高的写入吞吐量,这正是我们需要的。
我们可以使用两个表,一个用于存储每个用户的推荐,另一个用于存储事件。最后的 Python 微服务 save
工作人员将把事件和推荐数据保存在 Cassandra 中。
图片来源于作者
查询
查询非常简单。我们已经为每个用户计算并持久化了推荐。要查询这些推荐,我们只需查询我们的数据库并获取特定用户的推荐。
图片来源于作者
完整架构
就这样!我们完成了整个架构,让我们画出完整的架构,看看它是什么样子的。
图片来源于作者
进一步学习
Kafka
Kafka 是 LinkedIn 开发的一个惊人的工具,用于处理极大的规模(这篇 LinkedIn 在 2015 年的博客文章讨论了每秒约 1300 万条消息!)。
Kafka 在线性扩展和处理极高规模方面非常出色,但要构建这样的系统,工程师需要了解和理解 Kafka,它是什么,它如何工作,以及与其他工具的对比。
我写了一篇博客文章,解释了 Kafka 是什么,它与消息代理的不同之处,以及 LinkedIn 工程师撰写的原始 Kafka 论文的摘录。如果你喜欢这篇文章,看看我关于 Kafka 的文章 —
[## 系统设计系列:从 10,000 英尺看 Apache Kafka
让我们来看看 Kafka 是什么,它是如何工作的以及我们什么时候应该使用它!
betterprogramming.pub](https://betterprogramming.pub/system-design-series-apache-kafka-from-10-000-feet-9c95af56f18d?source=post_page-----3dd584bd28fa--------------------------------)
Cassandra
Cassandra 是一种独特的数据库,旨在处理非常高的写入吞吐量。它能够处理如此高吞吐量的原因在于其高度可扩展的去中心化架构。我最近写了一篇博客文章讨论 Cassandra、它是如何工作的,以及最重要的何时使用它和何时不使用它 —
## 系统设计解决方案:何时使用 Cassandra 以及何时不使用
关于何时使用 Cassandra 以及何时不使用 Cassandra 的所有信息
medium.com
推荐系统
推荐系统是非常出色的技术,它们几乎在你我今天使用的所有应用程序中都得到了应用。在任何系统中,个性化和推荐系统形成了用户搜索和发现流程的核心。
我一直在写关于搜索系统的内容,并稍微涉及了一下如何在搜索系统中构建基础的个性化功能,但我的下一个话题将深入探讨推荐引擎的细节,它们是如何工作的,以及如何设计它们。如果这对你感兴趣,请在 Medium 上关注我以获取更多内容!我也在 LinkedIn 上发布了很多简短的内容,例如,这篇关于 Kafka Connect 的帖子,描述了它的工作原理以及为什么它如此受欢迎,仅用一个简单的图表。
结论
我喜欢讨论有趣且复杂的话题,并将其分解成 10 分钟的阅读内容。如果你喜欢这篇文章,请在 Medium 上关注我以获取更多类似的内容! 在 LinkedIn 上关注我 ,每天获取更小、更常规的指南,逐步提升你的技术和设计知识。
希望你喜欢这篇文章,如果你对这篇文章有任何反馈或对我接下来应该讨论的内容有任何想法,可以在评论中发表!
从头实现 t-SNE(配合 NumPy)
原文:
towardsdatascience.com/t-sne-from-scratch-ft-numpy-172ee2a61df7?source=collection_archive---------2-----------------------#2023-04-14
封面图片由作者提供
通过从头实现 t-SNE 并使用 Python,深入理解其内部工作原理
Jacob Pieniazek
·
关注 发表在 Towards Data Science ·17 分钟阅读·2023 年 4 月 14 日
--
我发现,真正理解任何统计算法或方法的最佳方式之一就是亲自实现它。另一方面,编写这些算法有时会很耗时且非常麻烦,如果别人已经完成了,为什么我还要花时间去做呢——这似乎不太高效,不是吗?这两个观点都很公平,我并不是要为其中一个观点辩护。
本文旨在通过将 原始论文 — 由 Laurens van der Maaten 和 Geoffrey Hinton 合作编写 — 中的数学翻译成 Python 代码实现来帮助读者理解 t-SNE。1 我发现这类练习对于揭示统计算法/模型的内部工作机制非常有启发性,并真正测试你对这些算法/模型的理解和假设。你几乎可以肯定地带着比以前更好的理解离开。至少,成功的实现总是非常令人满意的!
本文对任何程度接触 t-SNE 的读者都是可访问的。然而,请注意这篇文章绝对不是:
-
对 t-SNE 的严格概念性介绍和探索,因为有很多其他很棒的资源已经做到了这一点;尽管如此,我将尽力将数学方程式与其直观/概念性对应物在每个实现阶段连接起来。
-
对 t-SNE 的应用及优缺点的全面讨论,以及 t-SNE 与其他降维技术的直接比较。我将会在本文中简要提及这些话题,但绝不会深入探讨。
言归正传,让我们开始对 t-SNE 的简要介绍。
t-SNE 简要介绍
t-分布随机邻居嵌入(*t-SNE)是一个降维工具,主要用于具有大维度特征空间的数据集,能够将数据可视化到更低维度的空间(通常是 2-D)。它特别适用于可视化非线性可分数据,其中线性方法如主成分分析(PCA)会失败。将线性降维框架(如 PCA)推广到非线性方法(如 t-SNE)也称为流形学习。这些方法对于可视化和理解高维非线性数据集的基础结构非常有用,并且对解开和分组在高维空间中相似的观察值很有帮助。有关 t-SNE 和其他流形学习技术的更多信息,scikit-learn 文档 是一个很好的资源。此外,要了解 t-SNE 的一些有趣应用领域,维基百科页面 列出了这些领域及其相关工作的参考资料。
让我们先把名字 t-distributed stochastic neighbor embedding 拆解成它的组成部分。t-SNE 是对 6 年前 Geoffrey Hinton 和 Sam Roweis 在这篇论文中提出的随机邻域嵌入 (SNE) 的扩展。让我们从那里开始。名字中的 stochastic 部分源于目标函数不是凸的,因此不同的初始化可能会产生不同的结果。neighbor embedding 突出了算法的特性——在尽可能保留点的“邻域”结构的同时,将原始高维空间中的点最佳映射到相应的低维空间。SNE 包含以下(简化的)步骤:
-
获得原始空间中点之间的相似性矩阵: 计算每个数据点 j 相对于每个数据点 i 的条件概率。这些条件概率是在原始高维空间中使用以 i 为中心的高斯分布计算的,并具有以下解释:i 选择 j 作为其在原始空间中邻居的概率。这会创建一个表示点之间相似性的矩阵。
-
初始化: 在低维空间(例如 2-D)中为每个原始空间中的数据点选择随机起点,并在这个新空间中类似地计算新的条件概率。
-
映射: 迭代改进低维空间中的点,直到所有条件概率之间的Kullback-Leibler 发散度最小化。本质上,我们正在最小化两个空间的相似性矩阵之间的概率差异,以确保在将原始高维数据集映射到低维数据集时,尽可能保留相似性。
t-SNE 主要通过两种方式改进 SNE:
-
它最小化了Kullback-Leibler 发散度,而不是条件概率之间的发散度。作者称之为“对称 SNE”,因为他们的方法确保了联合概率 p_ij = p_ji。这使得成本函数的表现大大改善,更易于优化。
-
它使用具有一个自由度的Student-t 分布(也就是柯西分布)来计算点之间的相似性,而不是在低维空间中使用高斯分布(上面的第 2 步)。在这里我们可以看到 t-SNE 中的“t”来自哪里。这一改进有助于缓解作者所强调的“拥挤问题”,并进一步改善优化问题。 “拥挤问题”可以这样理解:假设我们有一个 10 维空间,那么在 2 维空间中可用的空间将不足以准确捕捉那些适度不相似的点,而与 10 维空间中相邻点所占用的空间相比,2 维空间的空间远远不够。更简单地说,只需设想将 3 维空间投影到 2 维空间,3 维空间将有更多的整体空间来建模相似性,相对于投影到 2 维的空间。Student-t 分布通过具有比正态分布更重的尾部来帮助缓解这个问题。有关这个问题的更深入的讨论,请参见原始论文。
如果这些内容没有立即跟上,那也没关系!我希望当我们在代码中实现这些时,所有部分都会迎刃而解。主要的要点是:t-SNE 在高维空间中通过“数据点选择其他点作为邻居”的联合概率来建模数据点之间的相似性,然后尝试找到这些点映射到低维空间的最佳方式,同时尽可能保留原始高维相似性。
从头开始的实现
现在让我们继续了解 t-SNE,方法是实现 Laurens van der Maaten 和 Geoffrey Hinton 在论文中提出的算法原版。我们将首先逐步实现下面的算法 1,这将涵盖主算法的 95%。作者还提到了两个额外的改进:1) 早期夸张和 2) 自适应学习率。我们将仅讨论添加早期夸张,因为这有助于解释实际算法的内部工作原理,而自适应学习率则侧重于提高收敛速度。
算法 1(见论文)
1. 输入和输出
根据原始论文,我们将使用来自 OpenML 的公开 MNIST 数据集,该数据集包含从 0 到 9 的手写数字图像。2 我们还将从数据集中随机抽取 1000 张图像,并使用主成分分析 (PCA) 降维,将组件数保留为 30。这两者都是为了提高算法的计算时间,因为这里的代码没有针对速度进行优化,而是为了可解释性和学习。
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
import pandas as pd
# Fetch MNIST data
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.target = mnist.target.astype(np.uint8)
X_total = pd.DataFrame(mnist["data"])
y_total = pd.DataFrame(mnist["target"])
X_reduced = X_total.sample(n=1000)
y_reduced = y_total.loc[X_total.index]
# PCA to keep 30 components
X = PCA(n_components=30).fit_transform(X_reduced)
这将是我们的 X 数据集,每一行是一个图像,每一列是一个特征,或者在这种情况下是主成分(即原始像素的线性组合):
从 MNIST 数据集中抽取的 1000 个样本及前 30 个主成分
我们还需要指定代价函数参数——困惑度——以及优化参数——迭代次数、学习率和动量。我们现在暂时不讨论这些参数,而是在每个阶段出现时再进行处理。
就输出而言,请记住,我们寻求的是原始数据集 X 的低维映射。在整个示例中,我们将把原始空间映射到二维空间。因此,我们的新输出将是现在以二维空间表示的 1000 张图像,而不是原始的 30 维空间:
所需的二维空间输出
2. 计算原始空间中 X 的亲和力/相似度
现在我们有了输入,第一步是计算原始高维空间中的成对相似度。即,对于每个图像 i,我们计算 i 在原始空间中选择图像 j 作为其邻居的概率。这些概率是通过围绕每个点的正态分布计算的,然后归一化为总和为 1。数学上,我们有:
Eq. (1) — 高维亲和力
请注意,在我们的例子中,n = 1000,这些计算将产生一个 1000 x 1000 的相似度评分矩阵。注意,我们在 i = j 时将 p = 0,因为我们正在建模成对相似度。然而,你可能会注意到,我们没有提到如何确定 σ。这个值是通过基于用户指定的期望困惑度的网格搜索为每个观察值 i 确定的。我们将立即讨论这个问题,但首先让我们看看如何编写上述公式(1)的代码:
def get_original_pairwise_affinities(X: np.ndarray, perplexity: int = 10) -> np.ndarray:
"""
Function to obtain affinities matrix.
Parameters:
X (np.ndarray): The input data array.
perplexity (int): The perplexity value for the grid search.
Returns:
np.ndarray: The pairwise affinities matrix.
"""
n = len(X)
print("Computing Pairwise Affinities....")
p_ij = np.zeros(shape=(n, n))
for i in range(0, n):
# Equation 1 numerator
diff = X[i] - X
σ_i = grid_search(diff, i, perplexity) # Grid Search for σ_i
norm = np.linalg.norm(diff, axis=1)
p_ij[i, :] = np.exp(-(norm**2) / (2 * σ_i**2))
# Set p = 0 when j = i
np.fill_diagonal(p_ij, 0)
# Equation 1
p_ij[i, :] = p_ij[i, :] / np.sum(p_ij[i, :])
# Set 0 values to minimum numpy value (ε approx. = 0)
ε = np.nextafter(0, 1)
p_ij = np.maximum(p_ij, ε)
print("Completed Pairwise Affinities Matrix. \n")
return p_ij
在我们查看这段代码的结果之前,让我们讨论一下如何通过 grid_search() 函数确定 σ 的值。给定一个指定的 perplexity 值(在这种情况下可以大致理解为每个点的最近邻数量),我们对一系列 σ 值进行网格搜索,以便使以下方程对于每个 i 尽可能接近等式:
Perplexity
其中 H(P) 是 P 的香农 熵。
P 的香农熵
在我们的案例中,我们将 perplexity 设置为 10,并将搜索空间定义为 [0.01 * 图像 i 和 j 之间差异的范数的标准差,5 * 图像 i 和 j 之间差异的范数的标准差],分成 200 个相等的步长。知道这一点后,我们可以按如下方式定义我们的 grid_search() 函数:
def grid_search(diff_i: np.ndarray, i: int, perplexity: int) -> float:
"""
Helper function to obtain σ's based on user-specified perplexity.
Parameters:
diff_i (np.ndarray): Array containing the pairwise differences between data points.
i (int): Index of the current data point.
perplexity (int): User-specified perplexity value.
Returns:
float: The value of σ that satisfies the perplexity condition.
"""
result = np.inf # Set first result to be infinity
norm = np.linalg.norm(diff_i, axis=1)
std_norm = np.std(norm) # Use standard deviation of norms to define search space
for σ_search in np.linspace(0.01 * std_norm, 5 * std_norm, 200):
# Equation 1 Numerator
p = np.exp(-(norm**2) / (2 * σ_search**2))
# Set p = 0 when i = j
p[i] = 0
# Equation 1 (ε -> 0)
ε = np.nextafter(0, 1)
p_new = np.maximum(p / np.sum(p), ε)
# Shannon Entropy
H = -np.sum(p_new * np.log2(p_new))
# Get log(perplexity equation) as close to equality
if np.abs(np.log(perplexity) - H * np.log(2)) < np.abs(result):
result = np.log(perplexity) - H * np.log(2)
σ = σ_search
return σ
有了这些函数,我们可以通过p_ij = get_original_pairwise_affinities(X)
计算亲和矩阵,从而得到以下矩阵:
原始高维空间中条件概率的亲和矩阵
请注意,主对角线元素按构造设置为 ε ≈ 0(每当 i = j 时)。请记住,t-SNE 算法的一个关键扩展是计算联合概率而不是条件概率。这可以简单地按如下方式计算:
将条件概率转换为联合概率
因此,我们可以定义一个新函数:
def get_symmetric_p_ij(p_ij: np.ndarray) -> np.ndarray:
"""
Function to obtain symmetric affinities matrix utilized in t-SNE.
Parameters:
p_ij (np.ndarray): The input affinity matrix.
Returns:
np.ndarray: The symmetric affinities matrix.
"""
print("Computing Symmetric p_ij matrix....")
n = len(p_ij)
p_ij_symmetric = np.zeros(shape=(n, n))
for i in range(0, n):
for j in range(0, n):
p_ij_symmetric[i, j] = (p_ij[i, j] + p_ij[j, i]) / (2 * n)
# Set 0 values to minimum numpy value (ε approx. = 0)
ε = np.nextafter(0, 1)
p_ij_symmetric = np.maximum(p_ij_symmetric, ε)
print("Completed Symmetric p_ij Matrix. \n")
return p_ij_symmetric
将上面的p_ij
代入,我们得到p_ij_symmetric = get_symmetric_p_ij(p_ij)
,从而获得以下symmetric亲和矩阵:
原始高维空间中联合概率的对称亲和矩阵
现在我们已经完成了 t-SNE 中的第一个主要步骤!我们计算了原始高维空间中的对称亲和矩阵。在我们深入优化阶段之前,我们将在接下来的两个步骤中讨论优化问题的主要组件,然后将它们结合到我们的 for 循环中。
3. 样本初始解决方案及计算低维亲和矩阵
现在我们想在低维空间中随机抽样一个初始解决方案,如下所示:
def initialization(
X: np.ndarray, n_dimensions: int = 2, initialization: str = "random"
) -> np.ndarray:
"""
Obtain initial solution for t-SNE either randomly or using PCA.
Parameters:
X (np.ndarray): The input data array.
n_dimensions (int): The number of dimensions for the output solution. Default is 2.
initialization (str): The initialization method. Can be 'random' or 'PCA'. Default is 'random'.
Returns:
np.ndarray: The initial solution for t-SNE.
Raises:
ValueError: If the initialization method is neither 'random' nor 'PCA'.
"""
# Sample Initial Solution
if initialization == "random" or initialization != "PCA":
y0 = np.random.normal(loc=0, scale=1e-4, size=(len(X), n_dimensions))
elif initialization == "PCA":
X_centered = X - X.mean(axis=0)
_, _, Vt = np.linalg.svd(X_centered)
y0 = X_centered @ Vt.T[:, :n_dimensions]
else:
raise ValueError("Initialization must be 'random' or 'PCA'")
return y0
其中调用 y0 = initialization(X)
我们得到一个随机的起始解决方案:
2-D 初始随机解决方案
现在,我们想在这个低维空间中计算亲和矩阵。然而,请记住,我们是利用具有 1 个自由度的学生-t 分布来完成的:
方程 (4) — 低维亲和
同样地,我们设置 q = 0 当 i = j。注意这个方程与公式 (1) 的不同之处在于分母涉及所有 i,因此按照构造是对称的。将其转化为代码,我们得到:
def get_low_dimensional_affinities(Y: np.ndarray) -> np.ndarray:
"""
Obtain low-dimensional affinities.
Parameters:
Y (np.ndarray): The low-dimensional representation of the data points.
Returns:
np.ndarray: The low-dimensional affinities matrix.
"""
n = len(Y)
q_ij = np.zeros(shape=(n, n))
for i in range(0, n):
# Equation 4 Numerator
diff = Y[i] - Y
norm = np.linalg.norm(diff, axis=1)
q_ij[i, :] = (1 + norm**2) ** (-1)
# Set p = 0 when j = i
np.fill_diagonal(q_ij, 0)
# Equation 4
q_ij = q_ij / q_ij.sum()
# Set 0 values to minimum numpy value (ε approx. = 0)
ε = np.nextafter(0, 1)
q_ij = np.maximum(q_ij, ε)
return q_ij
这里我们正在寻找一个 1000 x 1000 的亲和矩阵,但现在是在低维空间中。调用 q_ij = get_low_dimensional_affinities(y0)
我们得到:
新低维空间中联合概率的对称亲和矩阵
4. 计算成本函数的梯度
回顾一下,我们的成本函数是高维空间和低维空间中联合概率分布的 Kullback-Leibler 散度:
联合概率分布的 Kullback-Leibler 散度
直观上,我们希望最小化亲和矩阵 p_ij
和 q_ij
之间的差异,从而最好地保留原始空间的“邻域”结构。使用梯度下降法来解决优化问题,但首先让我们看看如何计算上面成本函数的梯度。作者推导了成本函数的梯度(见 论文 的附录 A)如下:
成本函数的梯度(公式 5,但来自附录)
在 Python 中,我们有:
def get_gradient(p_ij: np.ndarray, q_ij: np.ndarray, Y: np.ndarray) -> np.ndarray:
"""
Obtain gradient of cost function at current point Y.
Parameters:
p_ij (np.ndarray): The joint probability distribution matrix.
q_ij (np.ndarray): The Student's t-distribution matrix.
Y (np.ndarray): The current point in the low-dimensional space.
Returns:
np.ndarray: The gradient of the cost function at the current point Y.
"""
n = len(p_ij)
# Compute gradient
gradient = np.zeros(shape=(n, Y.shape[1]))
for i in range(0, n):
# Equation 5
diff = Y[i] - Y
A = np.array([(p_ij[i, :] - q_ij[i, :])])
B = np.array([(1 + np.linalg.norm(diff, axis=1)) ** (-1)])
C = diff
gradient[i] = 4 * np.sum((A * B).T * C, axis=0)
return gradient
输入相关参数,我们通过 gradient = get_gradient(p_ij_symmetric,q_ij,y0)
得到在 y0
处的梯度及相应输出:
初始解(y0)下成本函数的梯度
现在,我们已经准备好解决优化问题的所有部分!
5. 迭代与优化低维映射
为了更新我们的低维映射,我们使用 带动量的梯度下降,正如作者所述:
更新规则(带动量的梯度下降)
其中 η 是我们的 学习率,α(t) 是我们随时间变化的动量项。学习率控制每次迭代的步长,而动量项使优化算法在搜索空间的平滑方向上获得惯性,同时不被梯度的噪声部分所困扰。我们将例子中的 η=200,并且如果 t < 250 时将 α(t)=0.5,否则将 α(t)=0.8。以上是计算更新规则所需的所有组件,因此我们可以在设定的迭代次数 T 上进行优化(我们将 T=1000)。
在设置迭代方案之前,首先介绍作者所称的“早期夸张”增强。这一术语是一个常数,用于缩放原始的亲和度矩阵 p_ij
。这将更多地强调在新空间中早期建模非常相似的点(原始空间中 p_ij
的高值),从而形成高度相似点的“簇”。早期夸张在迭代方案的开始阶段 (T<250) 中开启,然后关闭。在我们的情况下,早期夸张将设置为 4。我们将在下面的可视化中看到它的实际效果。
现在,将所有算法部分结合起来,我们得到了如下内容:
def tsne(
X: np.ndarray,
perplexity: int = 10,
T: int = 1000,
η: int = 200,
early_exaggeration: int = 4,
n_dimensions: int = 2,
) -> list[np.ndarray, np.ndarray]:
"""
t-SNE (t-Distributed Stochastic Neighbor Embedding) algorithm implementation.
Args:
X (np.ndarray): The input data matrix of shape (n_samples, n_features).
perplexity (int, optional): The perplexity parameter. Default is 10.
T (int, optional): The number of iterations for optimization. Default is 1000.
η (int, optional): The learning rate for updating the low-dimensional embeddings. Default is 200.
early_exaggeration (int, optional): The factor by which the pairwise affinities are exaggerated
during the early iterations of optimization. Default is 4.
n_dimensions (int, optional): The number of dimensions of the low-dimensional embeddings. Default is 2.
Returns:
list[np.ndarray, np.ndarray]: A list containing the final low-dimensional embeddings and the history
of embeddings at each iteration.
"""
n = len(X)
# Get original affinities matrix
p_ij = get_original_pairwise_affinities(X, perplexity)
p_ij_symmetric = get_symmetric_p_ij(p_ij)
# Initialization
Y = np.zeros(shape=(T, n, n_dimensions))
Y_minus1 = np.zeros(shape=(n, n_dimensions))
Y[0] = Y_minus1
Y1 = initialization(X, n_dimensions)
Y[1] = np.array(Y1)
print("Optimizing Low Dimensional Embedding....")
# Optimization
for t in range(1, T - 1):
# Momentum & Early Exaggeration
if t < 250:
α = 0.5
early_exaggeration = early_exaggeration
else:
α = 0.8
early_exaggeration = 1
# Get Low Dimensional Affinities
q_ij = get_low_dimensional_affinities(Y[t])
# Get Gradient of Cost Function
gradient = get_gradient(early_exaggeration * p_ij_symmetric, q_ij, Y[t])
# Update Rule
Y[t + 1] = Y[t] - η * gradient + α * (Y[t] - Y[t - 1]) # Use negative gradient
# Compute current value of cost function
if t % 50 == 0 or t == 1:
cost = np.sum(p_ij_symmetric * np.log(p_ij_symmetric / q_ij))
print(f"Iteration {t}: Value of Cost Function is {cost}")
print(
f"Completed Low Dimensional Embedding: Final Value of Cost Function is {np.sum(p_ij_symmetric * np.log(p_ij_symmetric / q_ij))}"
)
solution = Y[-1]
return solution, Y
调用 solution, Y = tSNE(X)
我们得到以下输出:
其中 solution
是最终的 2-D 映射,Y
是我们在每次迭代步骤中的 2-D 映射值。绘制 Y
的演变,其中 Y[-1]
是我们的最终 2-D 映射,我们得到(注意算法在早期夸张开启和关闭时的表现):
t-SNE 算法中的 2-D 映射演变
我建议尝试不同的参数值(如困惑度、学习率、早期夸张等),看看解决方案如何变化(参见原始论文和scikit-learn 文档获取使用这些参数的指南)。
结论
就这样,我们从零开始实现了 t-SNE!我希望你发现这个练习对 t-SNE 的内部工作有启发,至少是令人满意的。请注意,这个实现并不旨在优化速度,而是为了理解。t-SNE 算法的改进包括提高计算速度和性能,例如Barnes-Hut 算法的变体(基于树的方法)、使用 PCA 作为嵌入的初始化,或使用如自适应学习率等额外的梯度下降扩展。scikit-learn中的实现采用了许多这些增强功能。
一如既往,我希望你阅读这篇文章的乐趣与我写作时的乐趣一样。
资源
1 van der Maaten, L.J.P.; Hinton, G.E. 使用 t-SNE 可视化高维数据。《机器学习研究期刊》9:2579–2605, 2008。
2 LeCun et al. (1999):手写数字 (图像) 的 MNIST 数据集 许可证:CC BY-SA 3.0
通过这个 GitHub 仓库访问所有代码: github.com/jakepenzak/Blog-Posts
感谢你阅读我的帖子!我在 Medium 上的帖子旨在探讨利用 计量经济学 和 统计学/机器学习 技术的实际和理论应用。此外,我还希望通过理论和模拟提供有关各种方法论的理论基础的帖子。最重要的是,我写作是为了学习并帮助他人学习!我希望使复杂的主题对大家稍微更加易懂。如果你喜欢这篇帖子,请考虑 在 Medium 上关注我!
T5:文本到文本的变换器(第一部分)
原文:
towardsdatascience.com/t5-text-to-text-transformers-part-one-6b655f27c79a
创建一个统一的语言建模框架
Cameron R. Wolfe, Ph.D.
·发表于Towards Data Science ·14 分钟阅读·2023 年 6 月 27 日
--
(照片由Patrick Tomasso拍摄,来源于Unsplash)
迁移学习范式包括两个主要阶段。首先,我们在一大堆数据上对深度神经网络进行预训练。然后,我们在一个更具体的下游数据集上对这个模型进行微调(即,再训练它一段时间)。这些阶段的具体实现可能有多种形式。例如,在计算机视觉中,我们通常使用有监督学习目标在 ImageNet 数据集上对模型进行预训练。然后,这些模型在下游数据集上进行有监督微调(即,我们实际尝试解决的任务)。而在自然语言处理(NLP)中,我们通常在未标记的文本语料库上进行自监督预训练。
结合大型深度神经网络和庞大的(预)训练数据集通常会产生令人印象深刻的结果。这一发现对自然语言处理(NLP)尤其适用。由于原始文本数据在互联网上可以自由获得,我们可以简单地下载大量文本语料库,先在这些数据上预训练一个大型神经网络,然后在各种下游任务上微调模型(或仅使用零/少样本学习技术)。这一大规模迁移学习方法最初由 BERT 探索,该方法在未标记的数据上使用掩蔽目标预训练了一个变换器编码器,然后在下游语言任务上进行微调。
BERT 2 的成功不可小觑(即,在几乎所有语言基准上都取得了新的最先进性能)。因此,自然语言处理(NLP)社区开始深入研究迁移学习这一主题,提出了许多新的扩展和改进。由于这一领域的发展迅速,各种方法的比较变得困难。文本到文本转换器(T5)模型 1 提出了一个统一的框架,用于研究 NLP 中的迁移学习方法,使我们能够分析不同的设置并得出一套最佳实践。这套最佳实践包括 T5,这是一种用于语言理解任务的最先进模型和训练框架。
(来自 1)
相关历史和背景
T5 将现有的迁移学习技术重新定义为统一的格式,进行比较,并确定最佳实践以获得高性能结果。但这意味着什么?迁移学习是什么,为什么我们应该关注它? 为了回答这些问题,我们将首先概述一些重要的概念,包括迁移学习和不同的 Transformer 架构变体,这些对于理解 1 中的分析至关重要。从这里开始,我们将通过解释 BERT 2 架构来提供一些历史背景,这一架构使得迁移学习在自然语言处理(NLP)任务中变得流行。
什么是迁移学习?
训练神经网络的不同选项(由作者创建)
如果我们想训练一个神经网络来解决某个任务,我们有两个基本的选择。
-
从头开始训练:随机初始化你的神经网络,并在你的目标任务上进行训练(以监督方式)。
-
迁移学习:在一个独立的数据集上进行预训练,然后在目标任务上进行微调(即,进一步训练)。
通常,预训练是在比下游目标数据集大得多的数据集上进行的。一般而言,预训练会大幅提高数据效率。模型在微调期间学习得更快,甚至可能表现更好。迁移学习过程可以有许多不同的形式。例如,在计算机视觉中,我们可能会在 ImageNet 上进行模型预训练(使用监督学习),然后在像 CIFAR-10/100 这样的较小数据集上进行微调。对于自然语言处理(NLP)任务,情况稍有不同。通常,我们使用自监督预训练目标(例如,掩蔽语言建模或因果语言建模)与未标记的文本。
不同的 Transformer 架构
(来自 [6])
Transformer,如在 1 中最初提出的,使用编码器-解码器架构,如上所示。有关此架构的更深入概述,请查看链接 here。然而,编码器-解码器 Transformer 架构并不是我们唯一的选择!BERT 使用 仅编码器架构,而大多数 现代大型语言模型(LLMs)基于 仅解码器的 Transformers。让我们花一点时间了解这些架构变体之间的区别。
Transformer 编码器中的双向自注意力(由作者创建)
自注意力的简介。 自注意力操作将一个令牌向量序列作为输入,并生成一个长度相同的新的变换后令牌向量序列作为输出;如上所示。这个新序列的每个条目都是输入序列中向量的加权平均值。具体而言,我们计算输出序列中每个令牌向量的方法如下,其中 y_i
和 x_j
分别是输出和输入序列的元素。
(由作者创建)
上述权重 w_{i, j}
是一个注意力分数,它是 x_i
和 x_j
的函数。简单来说,这个分数捕捉了当前令牌在计算其新表示时应该“关注”序列中的其他令牌的程度。
(来自 [6])
单堆栈还是双堆栈? 原始 Transformer 架构使用两个“堆栈” 的 Transformer 层;见上文。第一个堆栈(编码器模块)由几个包含双向自注意力和 前馈神经网络 的块组成。第二个堆栈(解码器模块)非常相似,但它使用 掩码自注意力,并增加了一个“交叉注意力”机制,该机制在执行自注意力时考虑对应编码器层中的激活。Transformer 最初用于 序列到序列 任务(例如语言翻译)。对于其他任务,单堆栈 Transformer 模型变得非常流行:
-
语言模型使用仅解码器架构
-
BERT 风格的模型使用仅编码器架构
(来自 1)
注意力掩码。 变压器架构的变体有一个主要区别:在其注意力层中使用的掩码类型。在这里,当我们说“掩码”时,我们指的是在自注意力计算过程中某些标记被掩盖(或忽略)。简单来说,某些标记可能仅查看完整输入序列中的一部分其他标记。上图描绘了自注意力的不同掩码选项。
仅编码器模型利用双向(或完全可见)自注意力,这在自注意力过程中考虑了整个序列中的所有标记。自注意力中的每个标记表示是通过序列中所有其他标记的加权平均来计算的。相比之下,仅解码器模型使用因果自注意力,其中每个标记仅考虑序列中其之前的标记。
(来自 1)
我们还可以通过定义“前缀”来采用混合方法。更具体地说,我们可以对序列开头的一组标记(即前缀)执行双向自注意力,然后对序列中其余的标记执行因果自注意力;见上文。完全可见(或双向)自注意力对于处理前缀或执行分类任务非常有用。然而,某些应用(例如,语言建模)在训练过程中需要因果自注意力,以防止变压器“看到未来”(即在生成输出时仅复制正确的标记)。
T5 使用什么? 尽管1中的分析考虑了许多变压器架构,但 T5 主要使用的是标准的编码器-解码器架构。除了少数小修改外,该模型与最初提出的变压器[6]非常相似。由于编码器仅架构设计用于标记或序列级分类,而不是像翻译或总结这样的生成任务,1中没有探索它们。T5 旨在找到一种统一的方法(基于迁移学习)来解决许多语言理解任务。
BERT:NLP 的迁移学习
在早期,NLP 中的迁移学习通常使用经过 因果语言建模目标 预训练的递归神经网络。然而,随着 BERT 2 的提出,一切发生了变化。BERT 是一种基于变换器的模型 [6],其使用 自监督目标 进行预训练。BERT 可以在大量未标记的文本上进行预训练,然后微调以对句子(甚至句子中的单个标记)进行高精度分类。在提出时,BERT 在几乎所有被考虑的 NLP 任务上都设立了新的最先进水平,巩固了迁移学习在 NLP 中的主导地位。
使用 BERT 执行自监督 MLM 预训练(由作者创建)
为了使这一点更加具体,BERT 在预训练过程中依赖于一种“去噪”目标,称为 掩码语言建模 (MLM);见上文。虽然这听起来可能有些复杂,但其核心思想很简单,我们只需:
-
将输入序列中的一些标记掩盖,用特殊的
[MASK]
标记替代 -
使用 BERT 处理这些被破坏/修改过的序列
-
训练 BERT 以准确预测掩码标记
精确的实现要复杂一些。我们随机选择 15% 的标记,然后将它们替换为 [MASK]
标记(90% 的概率)或随机标记(10% 的概率)。通过在足够大的预训练语料库上使用这一目标,BERT 可以学习大量的一般语言学知识,使其成为一个高效的迁移学习模型。
T5 与 BERT 有什么关系? BERT 的提出展示了迁移学习是一种解决 NLP 问题的有效方法。许多人很快开始使用 BERT,尝试新技术并提出改进建议。因此,该领域充斥着各种使用类似 BERT 模型进行迁移学习的选项。T5 1 在这一研究方向上继续前进,但试图使用统一的框架来分析所有这些不同的提案,从而为我们提供了关于 NLP 中迁移学习最佳实践的更清晰的视角。最终的 T5 模型利用这些最佳实践进行训练,以达到最先进的性能。
T5 与 LLMs 的关系是什么? 目前,我们正在看到生成 AI 领域的重大革命,其中 LLMs(基于仅解码器的变压器架构)被用于通过语言模型预训练解决语言任务,然后进行 零/少样本学习。LLMs 很出色,但 T5 存在于一个相对独特的工具和研究领域。即,T5 主要关注那些明确通过编码器处理输入,然后通过单独的解码器生成输出的模型。此外,T5 采用迁移学习方法(即,预训练后在每个目标任务上进行微调),而不是零/少样本学习。
其他有用的链接
-
变压器架构 [link]
-
自注意力 [link]
-
BERT 模型 [link]
-
语言模型的基础 [link]
T5:统一的文本到文本变压器
T5 的贡献并不是一种新颖的架构或训练方法。相反,1中进行的研究完全基于现有技术。T5 考虑了 NLP 领域中迁移学习管道的各个方面,例如不同的(未标记的)数据集、预训练目标、基准测试和微调方法。然而,所有这些方面都是通过统一的文本到文本格式进行研究的。T5 的目标是 i) 分析迁移学习设置和 ii) 确定最有效的方法。
文本到文本框架
T5 将所有文本处理问题转换为“文本到文本”格式(即,将文本作为输入并生成文本作为输出)。这种通用结构也被采用于具有零/少样本学习的 LLMs,使我们能够用统一的方法建模和解决各种不同的任务。我们可以将相同的模型、目标、训练程序和解码过程应用于我们考虑的每个任务!我们只需采用 提示 方法,并要求我们的语言模型以文本格式生成答案。
(来自 1)
为了让这一点更加具体,T5 解决的所有任务都可以转换为文本到文本格式,如下所示:
-
为原始输入序列添加任务特定的前缀
-
将这个序列输入到变压器中
-
将模型的目标形式化为文本序列
使用这种格式,我们可以轻松执行诸如总结或翻译(即目标自然是一个序列)之类的任务。此外,我们可以通过仅训练模型生成与正确类别相关的文本来进行分类。这一过程在回归问题(即我们必须将实值输出舍入到最近的小数并将其视为分类问题)时会变得有些复杂,但对于大多数语言任务,它往往效果很好。示例见上图。
“如果我们的模型在文本分类任务中输出的文本与任何可能的标签都不对应,那么会出现问题……在这种情况下,我们总是将模型的输出视为错误,尽管我们在任何训练模型中都没有观察到这种行为。” — 来自1
T5 会针对它解决的每个任务进行微调。这与使用少样本学习的 LLM 和使用多任务学习一次性解决多个任务的 NLP 十项全能[3]形成对比。
T5 是如何研究的?
所有在1中进行的分析都使用了上述统一的文本到文本框架,因为它允许将各种不同的语言理解任务转换为共享格式。此外,T5 的分析使用了相同的基础变换器架构和预训练数据集。
(来自[6])
模型。 如前所述,变换器架构,正如[6]中最初提出的那样,包含编码器和解码器模块。最近对语言建模的研究探讨了只使用编码器或解码器的架构变体;例如,BERT只使用编码器2,而大多数(大型)语言模型只使用解码器。T5 使用了一个与原始变换器非常相似的编码器-解码器架构。不同之处在于:
-
LayerNorm在每次注意力和前馈变换之前立即应用(即,位于残差路径之外)
-
对于 LayerNorm 未使用加性偏置(即,见这里;我们只使用缩放并消除加性偏置)
-
使用了一个简单的位置嵌入方案,将一个标量添加到计算注意力权重时使用的相应logit中。
-
在整个网络中应用了 Dropout(例如,注意力权重、前馈网络、跳跃连接等)
这些修改在上图中有所说明。使用该模型(以及其他一些模型),T5 可以测试许多不同的迁移学习设置,以得出一套最佳实践。
预训练数据集。 T5 在 Colossal Clean Crawled Corpus(C4)上进行预训练,这是一个 750GB 的“相对干净”的英文文本数据集,详见1。尽管先前的工作中提出了各种预训练数据集,1中的作者选择构建自己的数据集,因为先前的数据集不可公开获取,使用的过滤规则有限,范围有限(例如,仅来自Creative Commons),或仅专注于机器翻译的平行数据(即多个不同语言中的相同句子版本)。
(来自[4])
值得注意的是,C4 后来被用作 MassiveText 数据集的一个子集,该数据集用于预训练Gopher和Chinchilla[4, 5]。请参见上表,以了解该数据集的规模指标,这有助于更好地理解 C4 与用于训练现代 LLM 的预训练数据集的相对大小。对于 LLM,我们已经看到,预训练仅解码器模型在足够大的数据集上是其成功的关键。不同架构的变换器,如 T5,亦是如此。在大型未标记数据集上的广泛预训练有利于更好的下游表现。
实验设置。 T5 在 C4 上进行预训练,然后微调以解决各种下游任务。然而,在这个框架中使用的确切设置是可变的。即,我们可以更改:
-
变换器架构
-
预训练设置(即任务或数据量)
-
微调设置
-
模型的规模/大小
通过逐一更改这些设置并评估结果,我们可以为 NLP 中的迁移学习开发一套最佳实践,从而将 BERT 之后的众多提议提炼成一个有效的管道,用于创建有效的语言理解模型。
要点
本文涵盖了与 T5 模型相关的所有初步信息,包括重要的背景信息和使用的基本实验框架。在下一篇文章中,我们将详细介绍1中进行的广泛分析,揭示 NLP 中迁移学习的最佳实践。目前,T5 的主要要点概述如下。
迁移学习是强大的。 迁移学习是指在某些独立数据集上预训练深度学习模型,然后在下游目标数据集(即我们实际要解决的任务)上微调(或进一步训练)该模型。如果在足够大且对齐(即,与下游任务类似)的数据集上进行,预训练是非常有效的。模型在微调期间可以学习得更快,甚至达到更高的准确率。这种技术在不同领域(例如计算机视觉和自然语言处理)中都有效,但用于预训练或微调的确切方法可能会有所不同。
“虽然我们在本文中没有明确测量数据效率的提升,但我们强调这是迁移学习范式的主要好处之一。” — 来源于 1
BERT 之后是什么? BERT 2 的提出是一个巨大的突破,普及了迁移学习在自然语言处理任务中的应用。事实上,BERT 在几乎所有涉及的任务上都设置了新的最先进性能。由于其成功,研究社区采纳并迭代了 BERT 的方法。T5 尝试统一 BERT 提出后的所有后续工作和分析,提供了对最有效迁移学习方法的更清晰视角。
通用任务制定。 为了创建一个统一的框架,以便研究多种不同的迁移学习方法,T5 提出了一个通用的文本到文本框架。类似于用于大型语言模型(LLM)的提示和少样本学习技术,这个文本到文本框架可以将任何语言任务重组为文本输入和输出。具体来说,这通过在文本输入中附加特定任务的前缀(即,让 T5 知道它正在解决什么任务)来完成,并使用 T5 的解码器模块生成与期望目标(例如标签、回归值或文本序列)对应的文本。
结束语
感谢阅读本文。我是 Cameron R. Wolfe,Rebuy 的 AI 总监。我研究深度学习的实证和理论基础。你也可以查看我在 Medium 上的 其他文章!如果你喜欢,请在 twitter 上关注我,或订阅我的 Deep (Learning) Focus 电子通讯,我在其中通过易于理解的热门论文概述帮助读者深入了解 AI 研究中的主题。
参考文献
1 Raffel, Colin, et al. “Exploring the limits of transfer learning with a unified text-to-text transformer.” The Journal of Machine Learning Research 21.1 (2020): 5485–5551.
2 Devlin, Jacob, et al. “Bert: Pre-training of deep bidirectional transformers for language understanding.” arXiv preprint arXiv:1810.04805 (2018).
[3] McCann, Bryan, 等人。“自然语言十项全能:将多任务学习应用于问答。” arXiv 预印本 arXiv:1806.08730 (2018)。
[4] Rae, Jack W., 等人。“语言模型的扩展:训练 Gopher 的方法、分析与见解。” arXiv 预印本 arXiv:2112.11446 (2021)。
[5] Hoffmann, Jordan, 等人。“训练计算最优的大型语言模型。” arXiv 预印本 arXiv:2203.15556 (2022)。
[6] Vaswani, Ashish, 等人。“注意力机制才是你所需要的。” 神经信息处理系统进展 30 (2017)。
T5: 文本到文本的变换器(第二部分)
原文:
towardsdatascience.com/t5-text-to-text-transformers-part-two-837ba23a9eb4
大型语言模型的最佳迁移学习
Cameron R. Wolfe, Ph.D.
·发表于Towards Data Science ·阅读时间 14 分钟·2023 年 7 月 5 日
--
(照片来源于Patrick Tomasso于Unsplash)
BERT [5] 的提出促使了自然语言处理(NLP)领域中迁移学习方法的普及。由于互联网上大量未标记文本的广泛存在,我们可以轻松地(i) 在大量原始文本上预训练大型变换器模型,以及(ii) 微调这些模型以准确解决下游任务。这种方法非常有效,但其新兴的受欢迎程度导致许多替代方法和改进被提出。随着这些新方法的出现,人们不禁开始想:NLP 中迁移学习的最佳实践是什么?
这个问题通过对统一的文本到文本变换器(T5)模型的分析得到了回答。T5 在预训练和微调期间都将所有任务重新格式化为文本到文本的格式,这意味着模型接收文本输入并生成文本输出。利用这种统一格式,T5 可以分析各种不同的迁移学习设置,从而允许比较多种方法。在之前的通讯中,我们了解了 T5 模型的格式、架构和整体方法。
在本期通讯中,我们将概述 T5 进行的分析,包括不同预训练目标、架构、模型/数据规模和 NLP 领域迁移学习的训练方法的实证比较。在1中,这些选项被逐一研究,以确定它们对 T5 性能的影响。通过研究这些分析,我们得出了一套最佳实践,这些实践(当结合在一起时)产生了最先进的 T5 框架,能够以极高的准确性解决语言理解任务。
(来自 1)
基础知识
我们已经覆盖了 T5 架构的动机和基础知识。请查看链接中的帖子 这里。我们也可以快速回顾这些思想。 BERT [5] 的提案普及了 迁移学习 范式(即,在某些独立数据集上预训练模型,然后在目标数据集上进行微调)用于 NLP。然而,BERT 的有效性使许多研究人员将重点放在这个话题上,并提出各种修改和改进。T5 的想法是 (i) 将所有语言任务转换为统一的文本到文本格式(见下图),并 (ii) 研究各种不同的 NLP 迁移学习设置,以推断出最有效的技术。
(来自 1)
语言建模 vs. 去噪
初期的 NLP 迁移学习方法利用了 因果语言建模 目标 [6] 进行预训练。然而,随后显示去噪(也称为 掩码语言建模,或 MLM)目标表现更好 [5]。给定一组要作为输入传递给某些模型的文本标记,MLM 通过以下方式操作:
-
随机(均匀)选择 15% 的标记
-
用一个
[MASK]
标记替换 90% 的选择标记 -
用一个随机标记替换 10% 的选择标记
-
训练模型预测/分类每个
[MASK]
标记
被均匀选择的标记的百分比称为“损坏率”。在 T5 中,我们将看到这种去噪目标的几个不同变体,但基本想法保持不变。
“我们所有的目标都接收来自我们未标记文本数据集的一个标记化文本片段的标记 ID 序列。标记序列被处理以生成一个(损坏的)输入序列和相应的目标。然后,模型像往常一样通过最大似然法来预测目标序列。” — 来自 1
基准测试和评估
T5 尝试推导出 NLP 中迁移学习的最佳实践集。然而,为了确定哪些技术效果最佳,T5 在各种任务和自然语言基准上进行了评估。所有这些任务都是使用 T5 的 文本到文本格式解决的。有关这些任务的完整描述,请参见 1 的第 2.3 节。下面提供了简要总结。
-
GLUE 和 SuperGLUE [7, 8]:这两个基准测试包含许多不同的任务,如 句子接受度判断、情感分析、释义、句子相似度、自然语言推断(NLI)、共指消解、句子完成、词义消歧和 问答。SuperGLUE 是一个改进且更具挑战性的基准测试,其结构类似于 GLUE。
-
CNN + Daily Mail 抽象总结 [9]:将新闻文章与一段简短的总结文本配对,捕捉文章的主要亮点。
-
SQuAD [10]:一个关于维基百科文章的问答数据集,其中每个问题的答案是相关文章中的一段文本。
-
几个翻译数据集(例如,英语到德语、法语和罗马尼亚语)。
值得注意的是,GLUE 和 SuperGLUE 基准测试中的所有任务都由 T5 连接在一起,并且在所有任务上同时进行微调。
其他重要观点
-
不同类型的 Transformer 架构 [link]
-
语言建模基础 [link]
-
自注意力 [link]
我们从 T5 中学到了什么?
如前所述,T5 的实验尝试发现 NLP 中迁移学习的最佳实践。为此,首先提出一种基线方法,然后逐一更改该基线的几个方面(例如,模型架构/大小、数据集和预训练目标),以查看哪些效果最佳。这种方法类似于 坐标下降策略。我们将首先描述基线技术,然后解释 T5 在测试各种迁移学习设置后的发现。
T5 基线模型
(来自 [11])
模型。T5 基础架构使用标准的编码器-解码器变换器架构;见上文。编码器和解码器的结构与BERTBase类似。尽管许多现代 NLP 方法使用“单堆栈”变换器架构(例如,仅编码器架构用于 BERT 或仅解码器架构用于大多数语言模型),T5 选择避免这些架构。有趣的是,1中的作者发现编码器-解码器架构在生成和分类任务中都取得了令人印象深刻的结果。1中没有考虑仅编码器模型,因为它们专门用于标记/跨度预测,不能很好地解决生成任务。
(来自 1)
与编码器-解码器架构相比,仅解码器模型受到限制,因为它们仅使用因果(或掩码)自注意力;见上文。掩码自注意力在计算序列中任何给定标记的表示时只考虑前面的标记。然而,有些情况下我们希望对初始范围或文本前缀进行完全可见的注意力,然后基于这个前缀生成输出(例如,翻译任务)。仅解码器模型无法处理这些情况,因为它们对整个输入进行因果自注意力。
训练 T5。T5 模型在C4 语料库上进行了 34B 标记的预训练。作为对比,BERT 在 137B 标记上进行训练,而 RoBERTa 在 2.2T 标记上进行训练[5, 12]。受 BERT 中的 MLM 目标启发,T5 使用略微修改的去噪目标进行预训练:
-
随机选择输入序列中的 15%标记
-
用单一标记替换所有连续选择的标记范围
“哨兵”标记
-
给每个哨兵标记一个在当前输入序列中唯一的 ID
-
使用所有选择的标记构造目标,用哨兵标记分隔
虽然这个任务看起来有点复杂,但我们可以在下面看到它在短输入序列上的工作示例。
(来自 1)
通过用单一的哨兵标记替换整个掩码标记的范围,我们降低了预训练的计算成本,因为我们通常在较短的输入和目标序列上进行操作。
微调。 在完成预训练后,T5 会在每个下游任务上单独进行微调,然后再进行评估。由于 T5 使用文本到文本的格式,预训练和微调都使用相同的 最大似然目标!换句话说,我们只需将正确答案表述为文本序列(在预训练和微调过程中),然后训练模型输出正确的文本序列。
基线表现如何? 如下表所示,基线 T5 模型的表现与之前的模型(如 BERT)相似,尽管这些模型并不可直接比较(即,基线 T5 模型使用的计算量是 BERTBase 的 25%)。此外,我们看到预训练在大多数任务上提供了巨大的好处。这个规则的例外是翻译任务,在这些任务中,预训练与未预训练的表现相似。
(来自 1)
寻找更好的方法…
在测试基线架构和训练方法后,1 的作者逐次修改这种方法的一个方面,例如底层架构、预训练目标或微调策略。通过测试这些不同的迁移学习变体,我们可以找到在不同语言理解任务中 consistently 表现最佳的方法。
(来自 1)
架构。 为了研究架构选择对迁移学习结果的影响,我们可以测试不同的 变体变换器架构。在 1 中测试的架构包括普通的编码器-解码器架构、仅解码器架构以及一个前缀语言模型,它在序列中对固定前缀执行完全可见的注意力,然后使用因果自注意力生成输出;见上文。这些架构之间的主要区别在于它们的自注意力机制中使用的掩蔽类型。
(来自 1)
当测试几种不同的架构(使用因果语言建模和去噪目标进行预训练)时,我们发现编码器-解码器变换器架构(具有去噪目标)表现最佳,因此在剩下的实验中采用了这种架构。相对于其他模型,这种编码器-解码器变体总共有 2P 个参数,但计算成本与具有 P 个参数的仅解码器模型相同。为了将总参数数量减少到 P,我们可以在编码器和解码器之间共享参数,发现这种方法表现非常好。
预训练目标。 最初,T5 使用三种不同类型的预训练目标进行训练。第一种是 BERT 风格的 MLM 目标。其他目标是一个去洗牌策略(即模型尝试将打乱的句子恢复到正确的顺序)和一个基于前缀的语言建模目标。在后者中,文本被分成两个跨度,第一个跨度作为输入传递给编码器,第二个跨度由解码器预测(即回忆一下我们使用的是编码器-解码器转换器)。下文比较了这些目标训练的模型的表现,我们可以看到去噪目标明显优于其他策略。
(来自 1)
从这里开始,1中的作者测试了对 BERT 风格 MLM 目标的几种修改,如下表所示。
(来自 1)
这些变体的表现趋于相似;见下文。然而,通过选择替换整个损坏令牌跨度为单个哨兵令牌的预训练目标,并且仅尝试预测目标中的损坏令牌,我们可以最小化预训练的计算成本。因此,遮蔽整个连续令牌跨度的基线策略是高效的,因为它产生了更短的目标序列。
(来自 1)
1 的作者测试了不同的损坏率,发现损坏率对结果没有显著影响,并且 15%的设置效果良好。还发现一种替代的预训练目标,即明确选择令牌跨度进行损坏(即基线方法选择令牌时采用均匀分布而不是跨度,然后将连续的令牌组合在一起),其表现与基线方法相似。下图展示了1中测试的不同预训练目标的示意图。
(来自 1)
研究了许多不同的策略,但这里的主要结论是 (i) 去噪目标效果最好,(ii) 去噪目标的变体表现相似,以及 (iii) 最小化目标长度的策略在计算上最为高效。
数据和模型规模。 最后,研究了规模对 T5 质量的影响。首先,T5 使用几个不同的数据集进行预训练,包括一个未过滤的数据集,一个新闻特定数据集,一个模仿 GPT-2 的 WebText 语料库的数据集,以及几个维基百科语料库的变体。T5 在每个数据集上进行预训练后的表现如下面所示。
(来自 1)
我们在这里看到 (i) 不过滤预训练语料库是极其有害的,(ii) 在特定领域语料库上进行预训练在某些情况下是有帮助的。例如,在基于新闻的语料库上进行预训练在 ReCoRD 这个基于新闻文章的阅读理解数据集上表现最佳。
“这些发现背后的主要教训是,在领域内未标记的数据上进行预训练可以提高下游任务的性能。这并不令人意外,但如果我们的目标是预训练一个能够快速适应来自任意领域的语言任务的模型,这种情况就显得不那么令人满意。” — 来源 1
进一步说,T5 使用不同大小的 C4 语料库的截断版本进行预训练。从这些实验中,我们了解到更多的数据(并不令人意外)更好。在预训练期间多次循环通过较小版本的数据集会导致过拟合,并损害下游性能;见下文。
(来源 1)
为了扩展 T5 模型,作者测试了以下修改:
-
4X
更多训练迭代(或 4X
更大的批次大小) -
2X
更多训练迭代和 2X
更大的模型 -
4X
更大的模型 -
训练一个由 4 个编码器-解码器变换器组成的集合
这里,为了简化,预训练和微调步骤都增加了。这些实验的结果如下所示。
(来源 1)
这些结果大致符合我们的预期。增加训练时间(或批次大小)可以提高性能。将此与更大的模型结合起来,相比单独增加训练迭代或批次大小,能带来进一步的好处。换句话说,增加预训练数据量和模型大小在提升性能方面是互补的。
“机器学习研究的痛苦教训认为,可以利用额外计算的通用方法最终会胜过依赖人类专业知识的方法” — 来源 1
其他内容。 T5 也使用不同的多任务训练策略进行了微调。总体而言,这些模型的表现略逊于那些针对每个任务单独微调的模型。然而,确实存在策略来最小化任务特定微调和多任务学习之间的性能差距。有关更多信息,请查看 此处 的概述。
许多深度神经网络的微调方法仅训练模型参数的子集(例如,“冻结”早期层并仅微调模型中的最后几层)。1 的作者尝试了几种这种微调 T5 的技术(例如,通过 适配器层 或逐步解冻 [6]),但这些方法被端到端微调完整模型所超越;见下文。
(来自1)
T5: 把一切整合在一起!
现在我们已经回顾了1中的整个实验分析,我们对 NLP 中的迁移学习不同选项以及什么效果最好有了更清晰的认识!下面,我们将讨论这一分析的主要要点,这些要点构成了 T5 所使用的官方迁移学习框架。与各种替代方案相比,这种方法被发现表现相当不错。
基线设置。 首先,让我们回顾一下 T5 的基线架构。它是一个编码器-解码器变换器,通过统一的文本到文本格式进行训练。在进行去噪预训练后,模型会在每个下游任务上单独进行微调,然后再进行评估。值得注意的是,最终的 T5 模型在 GLUE 和 SuperGLUE 基准测试中对每个任务进行了单独的微调,因为对所有任务同时训练的效果稍微差一点(假设我们采取必要的步骤以避免过拟合)。
预训练。 最终的 T5 方法并不是均匀选择标记,而是进行跨度破坏(即一次选择整个标记跨度进行破坏),平均跨度长度为三。然而,15%的标记仍然会被选择进行破坏。这个目标略好于基线,并产生更短的目标序列长度。此外,T5 将无监督预训练更新与多任务监督更新混合使用。无监督更新与监督更新的比例取决于所使用的模型大小(即,大型模型需要更多无监督更新以避免过拟合)。
训练量。 额外的预训练对 T5 的性能有帮助。具体来说,增加批量大小和训练迭代次数都能提升 T5 的性能。因此,最终的 T5 模型在总共预训练了 1T 标记。这比基线的 34B 标记要大得多,但仍远远低于在 2.2T 标记上进行预训练的 RoBERTa[12]。预训练是在通用的过滤 C4 数据集上进行的,因为特定任务的预训练在不同任务中没有一致的好处。
模型规模。 使用更大的模型是有帮助的,但有时较小的模型可能更合适(例如,当你在推理时计算资源有限)。因此,T5 发布了五种不同规模的模型,参数从 220M 到 11B 不等。因此,T5 实际上是一套不同的模型!我们可以通过链接这里访问这些模型。
结论
非常感谢阅读这篇文章。我是 Cameron R. Wolfe, Rebuy 的 AI 总监。我研究深度学习的实证和理论基础。你还可以查看我在 medium 上的 其他文章!如果你喜欢这篇文章,请在 twitter 上关注我,或订阅我的 Deep (Learning) Focus 新闻通讯,在这里我通过对流行论文的易懂概述帮助读者建立对 AI 研究主题的深入理解。
参考文献
1 Raffel, Colin, 等. “探索统一文本到文本转换器的迁移学习极限。” 机器学习研究杂志 21.1 (2020): 5485–5551。
2 Liu, Peter J., 等. “通过总结长序列生成维基百科。” arXiv 预印本 arXiv:1801.10198 (2018)。
[3] Liu, Peter J., Yu-An Chung 和 Jie Ren. “Summae: 使用长度无关的自编码器进行零样本抽象文本总结。” arXiv 预印本 arXiv:1910.00998 (2019)。
[4] Song, Kaitao, 等. “Mass: 用于语言生成的掩码序列到序列预训练。” arXiv 预印本 arXiv:1905.02450 (2019)。
[5] Devlin, Jacob, 等. “Bert: 用于语言理解的深度双向转换器预训练。” arXiv 预印本 arXiv:1810.04805 (2018)。
[6] Howard, Jeremy 和 Sebastian Ruder. “通用语言模型微调用于文本分类。” arXiv 预印本 arXiv:1801.06146 (2018)。
[7] Wang, Alex, 等. “GLUE: 一个多任务基准和自然语言理解分析平台。” arXiv 预印本 arXiv:1804.07461 (2018)。
[8] Wang, Alex, 等. “Superglue: 一个更具粘性的通用语言理解系统基准。” 神经信息处理系统进展 32 (2019)。
[9] Hermann, Karl Moritz, 等. “教机器阅读和理解。” 神经信息处理系统进展 28 (2015)。
[10] Rajpurkar, Pranav, 等. “Squad: 100,000+ 个用于机器文本理解的问题。” arXiv 预印本 arXiv:1606.05250 (2016)。
[11] Vaswani, Ashish, 等. “注意力机制就是你所需要的一切。” 神经信息处理系统进展 30 (2017)。
[12] Liu, Yinhan, 等. “Roberta: 一种稳健优化的 BERT 预训练方法。” arXiv 预印本 arXiv:1907.11692 (2019)。
TaatikNet: 序列到序列学习用于希伯来文音译
原文:
towardsdatascience.com/taatiknet-sequence-to-sequence-learning-for-hebrew-transliteration-4c9175a90c23?source=collection_archive---------5-----------------------#2023-06-28
一个简单的示例,展示了字符级别的 seq2seq 学习应用于复杂任务:在希伯来文文本和拉丁文音译之间的转换
莫里斯·阿尔珀
·
关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 6 月 28 日
--
我们如何利用深度学习在不被“困惑”的情况下进行字符串转换?(图片 by Andrew Malone, CC BY 2.0)
本文描述了 TaatikNet 以及如何轻松实现 seq2seq 模型。有关代码和文档,请参见 TaatikNet GitHub 仓库。要查看交互式演示,请参见 TaatikNet 在 HF Spaces 上。
介绍
许多自然语言处理中的有趣任务涉及在不同风格、语言或格式的文本之间转换:
-
机器翻译(例如,从英语到德语)
-
文本摘要 和释义(例如,将长文本缩短为短文本)
-
拼写纠正
-
抽象问题回答(输入:上下文和问题,输出:答案文本)
这些任务统称为序列到序列(Seq2seq)学习。在所有这些任务中,输入和期望的输出都是字符串,这些字符串可能具有不同的长度,并且通常彼此之间没有一一对应关系。
假设你有一个配对示例的数据集(例如,句子及其翻译的列表,拼写错误和修正文本的多个示例等)。如今,只要数据足够多,使得模型能够学习到对新输入的泛化,训练神经网络变得相当容易。让我们看看如何使用 PyTorch 和 Hugging Face transformers 库以最小的努力训练 seq2seq 模型。
我们将重点关注一个特别有趣的用例:学习在希伯来文本和拉丁音译之间转换。我们将在下文中概述这一任务,但这里提出的想法和代码对超出这一特定案例的应用也是有用的——本教程对任何希望从示例数据集中执行 seq2seq 学习的人都应该有帮助。
我们的任务:希伯来音译
为了展示具有趣味性和相当新颖的应用案例,我们将其应用于音译。一般来说,音译是指在不同书写系统之间转换。虽然英语使用拉丁字母书写(“ABC…”),但世界上的语言使用许多不同的书写系统,如下所示:
世界上的一些书写系统。(图片 by Nickshanks, CC-BY-SA-3)
如果我们想使用拉丁字母来书写一个原本用不同书写系统书写的语言中的单词怎么办?这一挑战通过许多书写犹太节日汉 ukkah 名称的方式得到体现。目前的介绍在维基百科文章中读取如下:
Hanukkah (/ˈhɑːnəkə/; 希伯来语: חֲנֻכָּה, 现代希伯来语: Ḥanukka, 提比留语: Ḥănukkā) 是一个犹太节日,纪念耶路撒冷的恢复以及第二圣殿的重新奉献,发生在公元前 2 世纪的马加比起义之初,对抗塞琉古帝国。
希伯来词汇חֲנֻכָּה的拉丁字母音译可以是Hanukkah、Chanukah、Chanukkah、Ḥanukka或许多其他变体。在希伯来语以及许多其他书写系统中,存在各种约定和模糊性,使得音译复杂,而不是简单的一对一字符映射。
在希伯来语的情况下,可以使用复杂的规则将带有 nikkud(元音符号)的文本音译成拉丁字符,尽管也存在各种边缘情况,使得这看似复杂。此外,尝试将没有元音符号的文本音译或执行反向映射(例如Chanukah → חֲנֻכָּה)要困难得多,因为可能的有效输出非常多。
幸运的是,借助于对现有数据应用深度学习,我们可以用极少的代码在解决这个问题上取得很大进展。让我们看看如何训练一个 seq2seq 模型——TaatikNet——使其能够自主学习如何在希伯来文文本和拉丁音译之间转换。我们注意到这是一个字符级任务,因为它涉及到对希伯来文本和音译中不同字符之间关系的推理。我们将进一步讨论这一点的重要性。
顺便提一下,你可能听说过我们的 UNIKUD 模型,它用于给无标记的希伯来文文本添加元音符号。这些任务之间有一些相似之处,但主要的区别在于 UNIKUD 执行的是字符级分类,即对每个字符,我们学习是否在其旁边插入一个或多个元音符号。相比之下,在我们的情况下,由于音译的复杂性,输入和输出文本的长度或顺序可能不完全对应,因此我们在这里使用 seq2seq 学习(而不仅仅是按字符分类)。
数据收集
与大多数机器学习任务一样,我们很幸运能够收集到大量模型输入和期望输出的例子,以便我们可以使用监督学习对其进行训练。
对于许多与单词和短语相关的任务,一个很好的资源是Wiktionary及其多语言对照——可以将其想象为维基百科与词典的结合。特别是,希伯来语 Wiktionary (ויקימילון)包含了结构化的语法信息条目,如下所示:
来自希伯来语 Wiktionary 文章עגבניה(西红柿)的语法信息。
特别是,这包括拉丁音译(agvaniya,其中粗体表示重音)。连同包含 nikkud(元音字符)的章节标题,这为我们提供了训练模型所需的(自由许可)数据。
为了创建数据集,我们使用 Wikimedia REST API 抓取这些条目(示例见此)。请注意,Wiktionary 条目的原始文本具有宽松的衍生作品许可(CC 和 GNU 许可,详情见此),并要求共享相同许可(TaatikNet 许可见此);通常情况下,如果你执行数据抓取,请确保使用具有宽松许可的数据,适当抓取,并使用正确的衍生作品许可。
我们对这些数据执行各种预处理步骤,包括:
-
去除 Wiki 标记和元数据
-
用尖音符号代替粗体表示重音(例如 agvaniya → agvaniyá)。
-
Unicode NFC 规范化 用于统一相同出现的字形,例如 בּ(U+05D1 希伯来字母 Bet + U+05BC 希伯来点 Dagesh 或 Mapiq)和 בּ(U+FB31 希伯来字母 Bet 带 Dagesh)。你可以通过将它们粘贴到Show Unicode Character 工具中自行比较。我们还统一了类似的标点符号,如希伯来语 geresh(׳)和撇号(‘)。
-
将多词表达拆分为单个词。
数据抓取和预处理后,我们得到近 15k 对单词-音译对(csv 文件见此)。以下是几个示例:
我们数据集中的几个项目示例。请注意,带有 nikkud(元音点)的希伯来语在第二列,但由于从右到左的文本渲染问题,它首先出现。
音译绝非一致或无误;例如,重音标记不一致且经常错误标记,且使用了各种拼写规则(例如 ח 可能对应于 h, kh, 或 ch)。我们不会尝试清理这些,而是将其直接输入模型,让模型自行理解。
训练
现在我们有了数据集,让我们进入项目的“核心”——在我们的数据上训练 seq2seq 模型。我们将最终模型命名为 TaatikNet,取自希伯来语单词 תעתיק taatik,意为“音译”。我们将在这里高层次地描述 TaatikNet 的训练,但强烈建议你阅读注释过的训练笔记本。训练代码本身非常简短且具有指导性。
要在自然语言处理(NLP)任务上实现最先进的结果,一种常见的方法是使用预训练的变换器神经网络,并通过继续在任务特定数据集上进行微调来应用迁移学习。对于 seq2seq 任务,最自然的基模型选择是一个编码器-解码器(enc-dec)模型。像 T5 和 BART 这样的常见 enc-dec 模型非常适合常见的 seq2seq 任务,如文本摘要,但由于它们对文本进行标记化(将其拆分为子词标记,大致是词或词块),因此不太适合我们的任务,因为我们需要在单个字符级别上进行推理。为此,我们使用无标记化的 ByT5 enc-dec 模型(论文,HF 模型页面),该模型在单个字节级别上进行计算(大致为字符,但请参阅Joel Spolsky 对 Unicode 和字符集的优秀文章以更好地理解 Unicode 字形如何映射到字节)。
我们首先创建一个 PyTorch Dataset 对象来封装我们的训练数据。我们可以简单地将数据集 CSV 文件中的数据包装起来而不做任何修改,但我们添加了一些随机增强,使模型的训练过程更加有趣:
def __getitem__(self, idx):
row = self.df.iloc[idx]
out = {}
if np.random.random() < 0.5:
out['input'] = row.word if np.random.random() < 0.2 else row.nikkud
out['target'] = row.transliteration
else:
out['input'] = randomly_remove_accent(row.transliteration, 0.5)
out['target'] = row.nikkud
return out
这种增强方法教会 TaatikNet 接受希伯来文字或拉丁文字作为输入,并计算相应的匹配输出。我们还会随机丢弃元音符号或重音,以训练模型对其缺失具有鲁棒性。一般来说,随机增强是一种很好的技巧,当你希望网络学会处理各种类型的输入,而不必事先计算数据集中所有可能的输入和输出时。
我们使用一行代码通过 Hugging Face pipeline API 加载基础模型:
pipe = pipeline("text2text-generation", model='google/byt5-small', device_map='auto')
在处理数据整合和设置超参数(如训练轮次、批量大小、学习率)后,我们在数据集上训练模型,并在每轮训练后打印出选定的结果。训练循环是标准的 PyTorch,除了 evaluate(…)
函数外,该函数在其他地方定义,并打印出模型在各种输入上的当前预测:
for i in trange(epochs):
pipe.model.train()
for B in tqdm(dl):
optimizer.zero_grad()
loss = pipe.model(**B).loss
losses.append(loss.item())
loss.backward()
optimizer.step()
evaluate(i + 1)
比较早期轮次和训练结束时的一些结果:
Epoch 0 before training: kokoro => okoroo-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-oroa-o
Epoch 0 before training: יִשְׂרָאֵל => אלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלאלא
Epoch 0 before training: ajiliti => ajabiliti siti siti siti siti siti siti siti siti siti siti siti siti siti siti siti siti siti sit
Epoch 1: kokoro => מְשִׁית
Epoch 1: יִשְׂרָאֵל => mará
Epoch 1: ajiliti => מְשִׁית
Epoch 2: kokoro => כּוֹקוֹרְבּוֹרוֹר
Epoch 2: יִשְׂרָאֵל => yishishál
Epoch 2: ajiliti => אַדִּיטִי
Epoch 5: kokoro => קוֹקוֹרוֹ
Epoch 5: יִשְׂרָאֵל => yisraél
Epoch 5: ajiliti => אֲגִילִיטִי
Epoch 10 after training: kokoro => קוֹקוֹרוֹ
Epoch 10 after training: יִשְׂרָאֵל => yisraél
Epoch 10 after training: ajiliti => אָגִ'ילִיטִי
在训练之前,模型输出的是无意义的字符,这在预期之中。在训练过程中,我们看到模型首先学会如何构造有效的希伯来语和音译,但花费更长时间来学习它们之间的联系。它也需要更长时间来学习诸如ג׳(gimel + geresh)对应于 j 的稀有项目。
一个警告:我们没有尝试优化训练过程;超参数的选择相当随意,也没有为严格评估预留验证集或测试集。此目的仅为提供一个 seq2seq 训练的简单示例和音译学习的概念验证;然而,超参数调优和严格评估将是未来工作的一个有前途的方向,结合下文限制部分提到的要点。
结果
下面显示了一些示例,演示了希伯来文本(有或没有元音)与拉丁音译之间的双向转换。你可以在HF Spaces 上的互动演示中尝试自己使用 TaatikNet。注意,它使用束搜索(5 束)进行解码,推理是对每个单词单独进行的。
TaatikNet 的输入和输出示例见互动演示。使用束搜索解码(5 束)生成多个输出。
更长文本的示例输出。推理是对每个单词单独进行的。注意挑战性案例如שבעיניו(最后的 yud 不发音)、חוכמה(kamatz gadol)、כאלה(倒数第二个音节重音)的成功音译。
限制与进一步方向
为了简化,我们将 TaatikNet 实现为一个最小的 seq2seq 模型,没有进行广泛的调优。然而,如果你对改进希伯来文文本与音译之间的转换结果感兴趣,有许多有前景的未来工作方向:
-
TaatikNet 仅尝试根据字母或声音对应关系猜测适当的拼写(无论是希伯来文还是拉丁音译)。然而,根据上下文,你可能希望将音译转换为有效的希伯来文本(例如 zot dugma → זאת דוגמא,而不是拼写错误的 זות דוגמע)。实现这一点的可能方法包括检索增强生成(访问词典)或在希伯来语句子及其拉丁音译的配对上进行训练,以学习上下文提示。
-
形式不寻常的输入可能会导致 TaatikNet 的解码陷入循环,例如 drapapap → דְּרַפָּפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּאפָּ. 这可能通过训练中的数据增强、更丰富的训练数据或在训练或解码中使用循环一致性来解决。
-
TaatikNet 可能无法处理其训练数据中一些相当稀有的惯例。例如,它常常无法正确处理ז׳(zayin+geresh),这表示稀有的外来音zh。这可能表明过拟合,或者在训练时使用样本权重以强调困难示例可能会有所帮助。
-
seq2seq 训练的便利性是以解释性和鲁棒性为代价的——我们可能希望确切了解 TaatikNet 如何做出决策,并确保这些决策的一致性。一个有趣的可能扩展是将其知识提炼成一组基于规则的条件(例如,如果在上下文 Y 中看到字符 X,则写 Z)。最近的代码预训练 LLM 可能对这个有所帮助。
-
我们没有处理“完整拼写”和“缺陷拼写”(כתיב מלא / חסר),即希伯来词在有无元音符号的情况下拼写略有不同。理想情况下,模型应在无元音的“完整”拼写和有元音的“缺陷”拼写上进行训练。请参阅 UNIKUD 了解处理这些拼写的一种方法。
如果你尝试这些或其他想法并发现它们带来了改进,我非常乐意听到你的反馈,并在此处给你致谢——可以通过本文下方的联系信息与我联系。
结论
我们已经看到,用监督学习训练一个 seq2seq 模型是相当简单的——教它从大量配对示例中进行归纳。在我们的案例中,我们使用了一个字符级模型(TaatikNet,从基础 ByT5 模型微调而来),但几乎可以使用相同的过程和代码来处理更标准的 seq2seq 任务,如机器翻译。
我希望你从本教程中学到的东西与我整理它时学到的一样多!如有任何问题、意见或建议,请随时与我联系;我的联系信息可以在下面链接的我的网站上找到。
莫里斯·阿尔珀,硕士 是特拉维夫大学的博士生,研究多模态学习(NLP、计算机视觉及其他模态)。有关更多信息和联系信息,请访问他的网页:morrisalp.github.io/
使用 Tableau 仪表盘处理大数据:挑战与经验
原文:
towardsdatascience.com/tableau-dashboards-and-big-data-learnings-e0a29cb7377c
像专业人士一样使用 Tableau 分析大数据
Alle Sravani
·发表于 Towards Data Science ·阅读时长 7 分钟·2023 年 1 月 5 日
--
图片由 Myriam Jessier 在 Unsplash 提供
我的谦逊仪表盘创建旅程始于 Excel。此后,我使用了多个工具,如 Qlik、Sisense、PowerBI 和 Tableau。尽管如此,Tableau 仍然是我最喜欢的工具,因为它从未变得枯燥。它易于使用和学习,但也可以迅速变得复杂。完成任务后获得的满足感是无价的。我有机会与许多 Tableau 专家合作,逐渐掌握了制作强大视觉效果的许多技巧和窍门。尽管我已经熟练使用这个工具,我认为我仅仅是触及了表面。
最近我面临的最大挑战之一是使用大数据构建仪表盘。我需要设计一个仪表盘来跟踪推送通知对移动应用用户的效果。推送通知发送的原因有很多,包括状态更新、新优惠、提醒、活动等。根据用户数量和发送的通知,这些数据可能会变得极其庞大。当我开始处理这个项目时,源表中有三个月的数据,总共有超过 5000 万条记录!
仪表盘创建过程的最佳实践。图像由作者使用 Powerpoint 创建。
在创建新仪表板之前,应从利益相关者/最终用户那里收集所有需求。理解仪表板的目的——它是为了什么,需要跟踪哪些指标,使用什么粒度级别,需要哪些过滤器等,这一点至关重要。如果可能的话,所有这些内容都应进行文档记录。这是为了避免收到多次变更请求和浪费时间开发最终不会被使用的内容。在项目进行期间,我与利益相关者进行了几次会议,以确保满足所有需求。我还涉及了工程团队,以确保所需数据在我们的数据仓库中随时可用。
下一步是创建工程框架——在 Tableau 的情况下,这将是一个适当的数据提取,仪表板可以连接到其中。什么是提取?它们只是原始数据的保存子集,可以定期刷新。提取的最大优势是性能提升。你减少了对主要数据源的总查询次数(仅在刷新时需要)。此外,你可以汇总或仅保留所需字段,并应用过滤器以减少提取的总大小。如果提取设计得当,可以作为多个工作簿的数据源,从而节省每次创建新数据集的时间。你可以在这里阅读 Tableau 提取的完整文档。
创建提取的方法有多种:
-
使用 Tableau Desktop 创建数据模型,然后将其发布到服务器作为提取——这种方法简单且适用于小型数据集。
-
使用 Tableau Desktop 创建数据模型,将其发布到服务器作为“实时”连接,然后在服务器上将其更改为提取——这种方法最适合中等大小的数据集,有助于节省在本地创建提取的时间。
这是我遇到的第一个障碍。我无法创建提取,因为数据量太大。我建立了与数据集的实时连接,看看是否可以在不创建提取的情况下继续工作。这一过程简直是一场噩梦。仪表板非常缓慢,每当我尝试更改视图中的内容时,加载时间要么需要 20 秒或更长时间,要么应用程序完全停止工作。我开始寻找其他选项,直到遇到了一篇有用的文章,展示了如何创建空提取并欺骗服务器刷新数据这里。
发布到 Tableau 后的成功消息。截图由作者拍摄。
我对成功创建提取感到非常兴奋。我根据规格构建了仪表板,并在经过彻底的 QA(质量保证)检查后,仪表板向所有人开放。然而,我的成功并未持续太久。几个月后,提取在服务器上未能刷新,连接自动从提取切换到实时。数据量已超过 2 亿条记录,任何尝试提取的方法,即便使用空提取技巧,也失败了。仪表板再次变得缓慢,我开始收到来自各个用户的提高性能的请求。
课程 1 — 临时修复不会持久
我知道唯一提高性能的方法是减少数据的大小并将其制作成提取。虽然在我们的数据仓库中创建数据集时,已经遵循了一些最佳实践。为了避免昂贵的联接,所有需要分析的属性(无论是用户还是推送通知)都被整合到一个大型表中(OBT)。普遍认为星型模式总是优于 OBT,因为“规范化”表需要更少的存储空间且更易于理解。然而,对于大数据分析报告,事实表和多个维度表之间的联接在检索数据时会对性能产生很大影响。有关星型模式和 OBT 之间权衡的详细解释,请参见这个博客 这里。
我能进行的任何操作仅限于我为仪表板创建的数据集。现在仪表板已经完成,我注意到有些指标仅针对用户,而其他指标仅针对推送通知。这开启了新的可能性:
-
我可以创建两个不同的数据集 — 一个用于用户,另一个用于推送通知
-
我还可以利用 CTE(公用表表达式)和窗口查询来添加额外的指标,以避免在 Tableau 中创建一些计算字段。
通过上述更改,新聚合数据集的合并大小仅为原始数据的 3%!剩下的就是用新的数据源替换仪表板中的旧数据源。
课程 2 — 在 Tableau 上始终使用提取
我遇到的第二个障碍是,虽然 Tableau 允许你替换数据源,但不能用多个数据源替换它。没有办法选择性地将数据源分配给特定的工作表。这意味着我必须用新的数据集之一从头开始重建至少一半的整个仪表板。我对重新开始感到犹豫,因为我已经实现了很多自定义计算、参数和格式化技巧。这是我最后的选择。我在网上寻找解决方案但没有找到。我认识的一位 Tableau 专家 — Kasia Gąsiewska-Holc— 建议了一个聪明的解决方法,即从另一个工作簿复制工作表。
以下是其工作原理:
-
复制你的原始工作簿(称之为 workbook1),并删除任何你不想替换数据源的工作表。
-
将 workbook1 中的数据源替换为新的数据源。
-
将 workbook1 的工作表复制并粘贴到原始工作簿中。(要复制工作表,请在底部功能区选择工作表名称,然后选择‘复制’。然后,通过右键单击功能区并选择粘贴,返回到原始工作簿。)
-
现在就是魔法时刻了:你将有两套工作表,一套来自 workbook1 具有正确的数据源,一套来自原始数据源但具有不正确的数据源。
-
剩下的就是交换工作表了。要交换工作表,进入仪表板并点击‘交换工作表’以确保你选择了正确的工作表。
交换后,删除原始工作簿中不再需要的工作表。
-
将原始工作簿的旧数据源替换为新的数据源。这将改变剩余工作表的数据源。
-
恭喜!你已成功将单一数据源工作簿替换为两个数据源工作簿。
整个过程花费了一个小时,而手动构建一半的仪表板可能要花费我整整一天的时间。
第三课 — 与 Tableau 专家交朋友
新版仪表板上线已有一段时间,目前没有报告任何错误或性能问题。所以我猜这里的工作已经完成了。有时候,你会期待一个非常复杂的解决方案,但实际上效果最好的解决方案往往是最简单的。希望你在这里找到了有用的东西。如果你也是 Tableau 用户,我很想知道你最近使用过的最佳技巧!
继续学习,不要犹豫寻求帮助。图片由 John Schnobrich 提供,来源于 Unsplash。
在你离开之前…
请在 Medium 上关注我,以免错过我未来撰写的任何新文章;你可以在我的 个人主页上找到更多我的文章。你也可以通过 LinkedIn 或 Twitter与我联系!
Tableau 数据融合教程——初学者的逐步指南
原文:
towardsdatascience.com/tableau-data-blending-tutorial-a-step-by-step-guide-for-beginners-5fd80fa001db
我们探索使用 Tableau 对数据融合进行全面概述,适用于数据科学家和数据分析师
Zoumana Keita
·发布于 数据科学前沿 ·7 分钟阅读·2023 年 1 月 30 日
--
图片由 Lukas Blazek 提供,来源于 Unsplash
数据融合——动机
现在,公司使用来自不同来源的数据来解决业务问题。能够有效地收集和结合这些数据已成为所有数据科学家和数据工程师的基本技能,以帮助组织做出明智的决策。
在本教程中,我们将首先建立对一种强大的数据组合方法——数据融合的理解,然后探索其好处。接着,我们将探讨其一些缺点,最后解答一些关于使用 Tableau 进行数据融合的常见问题。
Tableau 中的数据融合与连接。
在现实生活中,你将处理来自多个来源的信息,例如 Excel 表格、SQL、CSV 等。作为数据科学家,你需要将这些信息互联,以生成全球业务洞察。
Tableau 可以通过两种不同的方法来解决这个问题:连接、融合和关系。
在本节中,我们将重点理解前两种方法之间的区别。
Tableau 中的连接是什么?
如果你之前使用过 SQL,可能对主要的连接概念有所了解:左连接、右连接、内连接、交叉连接和全外连接。使用 Tableau 时也是如此。
这些连接旨在基于这些表之间的一些逻辑列关系,结合来自相同来源的不同表。例如,尝试将 Excel 文件和 SQL 表结合起来会失败,因为它们不来自相同的来源。
此外,在连接表时,两个表中使用的列必须相同,改变这些数据类型会导致连接中断。
最后,无法从 Tableau 的连接中删除重复列。
Tableau 中的数据混合是什么?
与连接和关系不同,数据混合将来自多个来源(如数据库、商业智能系统、云系统、平面文件、网络服务等)的数据汇总为单一数据,以便更好地进行可视化。
为什么使用 Tableau 进行数据混合?
首先,什么是 Tableau?
Tableau 是一个无代码的商业智能工具,提供了一个直观的拖放界面用于分析和可视化。这一非技术性特点使其在行业中脱颖而出。
此外,它快速且提供了将来自多个来源的数据互连的能力,如电子表格、SQL 数据库、网络服务等,无论是来自云还是本地。
使用 Tableau 进行数据混合提供了相同的灵活性,并使得从多个来源合并数据变得更容易,无需编程技能。
数据混合在 Tableau 中的逐步指南。
为了更好地理解数据混合的概念,本节将通过逐步过程指导你如何在 Tableau 中进行数据混合。
- 了解用例
本案例分析了来自 Zoom.AI 的两个部门在 2020 年和 2021 年的收入,Zoom.AI 是一家在不同非洲首都(大城市)运营的 AI 公司,如下图所示。
部门 1 和部门 2 的收入数据(作者图片)
2. 了解数据列
我们注意到这两个数据的格式相同,并传达相同的信息。然而,前两列的名称不同(部门 1 为 Period、Capitals,部门 2 为 Year、Cities)。
尝试将连接应用于当前的数据格式将会失败,因为参与连接的列需要相同。我们将看到数据混合如何巧妙地结合这些数据,而无需事先进行列的标准化。
3. 创建数据连接
了解我们的数据存储在不同的文件中,我们需要在 Tableau 中创建一个空间,将它们重新分组,这可以通过在 Tableau 中创建数据连接来完成,如下所示:
将数据导入 Tableau(作者动画)
我们部门的数据出现在左上角的数据选项卡中,意味着它们已成功上传。
4. 开始数据混合
首先,我们从可视化第一个部门的 Capitals 收入的变化开始。这可以通过以下两个步骤完成:
-
将 Revenue 列拖放到 Tableau 的列部分。
-
将 Capitals 列拖放到 Tableau 的行部分。
2020 年收入的可视化(作者动画)
目标是使用数据融合来比较部门的收入。为此,我们需要对第二部门的数据执行相同的拖放操作,而该数据使用的是城市列,如下所示。
部门 2 的结果已用橙色标记以便于说明。
为部门 2 添加收入(作者动画)
可视化破裂了💥💥💥,为什么?!
原因在于我们没有向 Tableau 指明首都列的信息水平与城市列相同,这导致在可视化中出现了城市列的星号标记。
这时数据融合就派上用场了,通过指定城市和首都列表示相同的内容来解决这个问题。下面是示例图。
数据的正确融合(作者动画)
在解决问题后,我们可以看到左侧城市列上有一个橙色链条⛓ 标记。这意味着城市列被用作连接字段。
更多的数据融合
现在你理解了如何正确地融合数据,让我们在年份/期间级别上进行相同的分析。为此,我们只需在期间和年份列之间创建一个链接,否则,我们将面临与之前相同的问题。
数据的正确融合(作者动画)
恭喜你达到了这个层级!现在你知道如何进行数据融合,你可以通过这个课程学习如何使用 Tableau 创建高级可视化。
数据融合在 Tableau 中有哪些好处?
从之前的教程中可以观察到,与 Joins 相比,使用数据融合提供了许多好处。下面列出了一些:
- 数据列无需预处理
我们不需要对列进行任何昂贵的预处理就能将不同的数据结合起来,因为只需简单指定感兴趣的列,数据融合算法就能理解如何高效地结合数据。
2. 解决粒度问题的更好方法
数据融合在结合不同粒度级别的数据时提供了更好的灵活性。假设我们的部门收入数据是按月组织的,而不是按年组织的。
使用简单的联接会导致不准确的结果,因为每年的行将被复制到每个月的行中,如下所示。数据融合可以通过将月份信息汇总到年度水平来解决这个问题。
不同粒度级别的数据融合(作者图片)
团队生产力
让我们考虑之前的粒度情况。如果没有数据混合,我们将不得不执行一些数据预处理,如汇总和重复数据删除。处理庞大的数据集时,这种预处理可能会迅速变得耗时。
快速分析以更好地决策
精确的决策制定依赖于对数据的更好概述。只需几次点击和拖放,就能捕捉到两个部门的本质。使用联接可能会导致不准确的结果,从而导致不准确的决策制定。
Tableau 数据混合的限制。
尽管数据混合有其优点,但也有一些缺点,如下所示:
-
数据混合在后台使用左连接,它不执行其他类型的连接。
-
在尝试混合不同粒度的数据时,顺序很重要。次要数据始终必须具有最小的粒度,因为它是需要聚合的。
-
非加性汇总,如 COUNT、MEAN、MEDIAN 和 SUM,会受到数据混合的影响。
结论
本教程涵盖了使用数据混合的动机、与连接相比对数据科学家的某些好处,以及实践操作,以更好地理解其工作原理。
希望你觉得这次基准分析对做出明智选择有帮助!
如果你喜欢阅读我的故事并希望支持我的写作,可以考虑 成为 Medium 会员。每月$5 的承诺,您将解锁对 Medium 上故事的无限访问权限。
随时欢迎在 Medium、Twitter 和 YouTube 上关注我,或者在 LinkedIn 上打招呼。讨论 AI、ML、数据科学、NLP 和 MLOps 相关话题总是一件愉快的事!
什么是禁忌搜索?
原文:
towardsdatascience.com/tabu-search-simply-explained-ee2852339d78
禁忌搜索优化算法的直观解释及其在旅行商问题中的应用
Egor Howell
·发表于Towards Data Science ·5 分钟阅读·2023 年 3 月 13 日
--
图片由Clint Adair提供,来源于Unsplash
背景
在我的上一篇文章中,我们讨论了元启发式优化算法模拟退火。 这是一种随机搜索算法,用于尝试在组合优化问题中找到全局最优解,如著名的旅行商问题 (TSP)和背包问题。
另一种类似的算法称为禁忌搜索, 它可以被视为模拟退火算法的推广。在这篇文章中,我想讨论和解释禁忌搜索,回顾旅行商问题(TSP),然后在 Python 中实现禁忌搜索来解决 TSP。
禁忌搜索
概述
Tabu Search 是一种元启发式优化算法,由 Fred Glover 在 1980 年代末期构思。类似于模拟退火,Tabu Search 使用 局部搜索,但可以接受更差的解以避免陷入 局部最小值。它的另一个主要关键成分是它使用 记忆结构 防止算法访问之前观察过的解,从而更广泛地探索搜索空间。换句话说,它有一个‘TABU’列表!
Tabu Search 算法可以用来解决各种各样的问题:
-
资源分配
-
调度
-
供应链优化
因此,学习和理解它是值得的,因为它可以应用于多种情况。
记忆结构、任期和 Tabu List
如上所述,Tabu Search 记录之前访问过的解,这被称为 Tabu List,并在特定时间内将它们保存在记忆中,任期,以防止解的回收并更好地 探索 搜索空间。
一般来说,Tabu Search 使用两种类型的 记忆结构:
-
短期: 这通常是一定数量的之前访问过的解,我们不应回到这些解上。
-
长期: 这是为了在搜索遇到困境时提供帮助,并帮助扩展搜索范围。
实际上,没有要求一定要有一个或两个,甚至两个都不需要。主要的思想是我们跟踪算法正在做什么,并帮助其探索更广泛的可能解。
算法概述与愿望准则
Tabu Search 的一般流程如下:
-
生成一个初始有效解。
-
使用从当前解出发的局部搜索获取可能的邻域解集。
-
从这些邻域解中,获取一个不在 Tabu List 上的最佳候选解。
-
将这个最佳候选解与迄今为止找到的最佳解进行比较,并根据需要进行分配。
-
使用最佳候选解更新 tabu 列表。
-
使用最佳候选解重复步骤 2-5 以生成新的邻域,直到满足某些停止条件。
另一个规则是,如果我们发现一个在 Tabu List 上的解但其目标函数比当前最佳解更好,我们仍然接受这个解。这被称为 愿望准则。
就这样!如果你对这个过程还有疑问,请继续阅读,我们将用 Python 实现这个理论以使其更为具体。
旅行推销员问题
在使用 Tabu Search 解决旅行推销员问题(TSP)之前,值得快速讨论一下什么是 TSP。
TSP 可能是最著名且最易于理解的组合优化问题。问题很简单:“给定一组城市,什么是访问每个城市一次并返回原始城市的最短路线?”
这个问题难以解决的原因是它是NP 难的,随着我们需要访问的城市数量增加,可能路线的数量会组合爆炸。例如,对 20 个城市进行穷举所有解需要约 2000 年!
可能路线的数量按(n-1)!/2 计算
由于 TSP 在特定数量城市下的难解性,我们需要 resort to heuristics such as Tabu Search and Simulated Annealing to provide sufficient solutions in a reasonable amount of time.
用 Python 实现 TSP 的禁忌搜索
算法设计
首先列出一些伪代码,说明我们如何实现 TSP 的禁忌搜索:
-
生成一个初始路线并用这个初始路线更新禁忌列表
-
从这个初始路线中,通过交换当前路线中的相邻城市来生成邻域
-
从这个邻域中获取最短的最佳邻域路线,该路线不在禁忌列表上
-
将最佳邻域路线与找到的最佳总体路线进行比较,并根据需要更新
-
使用当前最佳邻域路线重复步骤 1–3,以产生新的邻域
这是一个相当基础的禁忌搜索算法,因为它仅包含短期记忆结构。
Python 代码
以下是实现上述算法的通用类。该类只需要一个initial_solution
,即城市的某个顺序列表,以及一个将城市映射到其坐标的字典cities
。
作者的 Github Gist。
现在让我们对一些合成生成的数据集运行这个类:
由作者在 Python 中生成的图。
由作者在 Python 中生成的图。
最佳解决方案看起来是一个相当合理的结果,而且没有花费我们数千年的时间来计算!
总结与进一步思考
在这篇文章中,我们解释了元启发式禁忌搜索算法。这个优化算法使用局部搜索技术,但仍可以通过接受更差的解决方案来逃避局部最小值。它还利用了禁忌列表,阻止它过渡到之前访问过的解决方案,并更大程度地探索搜索空间。应用于旅行推销员问题时,这个算法取得了很好的结果。
完整代码可以在我的 GitHub 上找到:
## Medium-Articles/Optimisation/tabu-search at main · egorhowell/Medium-Articles
目前无法执行该操作。您已在另一个标签页或窗口中登录。您在另一个标签页或…
github.com
另一个事项!
我有一个免费的新闻通讯,Dishing the Data,每周分享成为更好的数据科学家的技巧。没有“废话”或“点击诱饵”,只有来自实践数据科学家的纯粹可操作见解。
## Dishing The Data | Egor Howell | Substack
如何成为更好的数据科学家。点击阅读《Dishing The Data》,由 Egor Howell 编写,是 Substack 出版物,包含…
newsletter.egorhowell.com
连接与我!
-
YouTube
-
LinkedIn
-
Twitter
-
GitHub
参考文献与进一步阅读
- 优化算法。 Mykel J. Kochenderfer 和 Tim A. Wheeler。2019。
Tabyl:现代 R 用户的频率表格
原文:
towardsdatascience.com/tabyl-a-frequency-table-for-the-modern-r-user-e061cd48baef?source=collection_archive---------4-----------------------#2023-05-20
旧的已过时,新的正在兴起!
兹沃尼米尔·博班
·
关注 发表在 向数据科学迈进 ·6 分钟阅读·2023 年 5 月 20 日
--
使用 Canva 图像生成器创建的图像
任何处理分类数据的人最终都会遇到计算某个类别的绝对数和比例的需要。本文介绍了通过一系列实际示例使用 tabyl
函数创建频率表格。
tabyl
为表格带来了什么(没有刻意捣乱的意思 😄)?
tabyl
函数是 R 语言中 janitor
包的一个特性。它是创建列联表的非常方便的工具,也被称为频率表或交叉制表表格。以下是使用 tabyl
的一些好处:
1. 简单的语法:tabyl
具有易于使用的语法。它可以接受一个、两个或三个变量,并自动返回一个包含计数和比例的数据框。
2. 灵活性:tabyl
可以生成单向(单变量)、双向(双变量)和三向(三变量)列联表。这种灵活性使其适用于各种应用场景。
3. 自动计算比例:tabyl
自动计算单向列联表的比例(百分比)。对于双向和三向表,可以结合使用同一包中的 adorn_percentages
函数来实现相同的结果。
4. 与 dplyr
的兼容性:tabyl
的输出是数据框(或 tibble),这使其完全兼容 dplyr
函数和 tidyverse 生态系统。这意味着你可以轻松地将 %>%
管道操作符应用于进一步的数据处理或可视化功能。
5. 整洁且信息丰富的输出:tabyl
提供整洁且信息丰富的输出,包括将变量名作为行名和列名,这使得结果更易于解释。
基于以上所有原因,当你想在 R 中创建频率表时,tabyl
是一个很好的选择。它简化了许多步骤,并且与 tidyverse 数据分析方法很好地集成。
数据集
由 Hans Veth 在 Unsplash 拍摄的照片
本文将使用关于不同类型蘑菇的气味的可食性数据来演示 tabyl
函数的优点。在这里,我将使用一个名为 mushrooms 的整理数据集,但你可以访问 Kaggle 上的原始数据。以下是用于清理数据的代码。
library(tidyverse)
library(janitor)
mushrooms <- read_csv("mushrooms.csv") %>%
select(class, odor) %>%
mutate(
class = case_when(
class == "p" ~ "poisonous",
class == "e" ~ "edible"
),
odor = case_when(
odor == "a" ~ "almond",
odor == "l" ~ "anise",
odor == "c" ~ "creosote",
odor == "y" ~ "fishy",
odor == "f" ~ "foul",
odor == "m" ~ "musty",
odor == "n" ~ "none",
odor == "p" ~ "pungent",
odor == "s" ~ "spicy"
)
)
如果你对上述语法不熟悉,请查看我早期文章中的 tidyverse 实用指南。
## 使用泰坦尼克号数据深入了解 tidyverse
[towardsdatascience.com
旧的
为了更好地理解 tabyl
提供了哪些优势,让我们首先使用基础 R 中的 table
函数创建一个频率表。
table(mushrooms$class)
edible poisonous
4208 3916
table(mushrooms$odor, mushrooms$class)
edible poisonous
almond 400 0
anise 400 0
creosote 0 192
fishy 0 576
foul 0 2160
musty 0 36
none 3408 120
pungent 0 256
spicy 0 576
毫不奇怪,气味竟然是预测蘑菇可食性的一个重要指标,任何“有怪味”的蘑菇可能都是有毒的。谢谢进化!此外,似乎还有很多更多的有毒蘑菇,因此在自己采摘蘑菇时始终要小心谨慎。
如果我们希望直接使用变量名而不指定$
运算符,我们需要使用with
命令使数据集对table
函数可用。
mush_table <- with(mushrooms, table(odor, class))
不幸的是,如果我们想升级到比例而不是绝对数,我们不能使用相同的函数,而是要使用另一个函数——prop.table
。
prop.table(mush_table)
class
odor edible poisonous
almond 0.049236829 0.000000000
anise 0.049236829 0.000000000
creosote 0.000000000 0.023633678
fishy 0.000000000 0.070901034
foul 0.000000000 0.265878877
musty 0.000000000 0.004431315
none 0.419497784 0.014771049
pungent 0.000000000 0.031511571
spicy 0.000000000 0.070901034
默认情况下,这会给我们一个按列的比例表。如果我们想要按行的比例,我们可以指定margin
参数(1 表示按行,2 表示按列)。
prop.table(mush_table, margin = 1)
class
odor edible poisonous
almond 1.00000000 0.00000000
anise 1.00000000 0.00000000
creosote 0.00000000 1.00000000
fishy 0.00000000 1.00000000
foul 0.00000000 1.00000000
musty 0.00000000 1.00000000
none 0.96598639 0.03401361
pungent 0.00000000 1.00000000
spicy 0.00000000 1.00000000
所有这些特殊函数可能会感觉繁琐且难以记住,因此拥有一个包含所有上述功能的单一函数将是很好的。
此外,如果我们使用class(mush_table)
命令检查创建的对象类型,我们会发现它属于table
类。
这创建了一个兼容性问题,因为如今 R 用户大多使用 tidyverse 生态系统,该系统以将函数应用于data.frame
类型对象并通过管道(%>%
)运算符串联结果为中心。
新的
让我们用tabyl
函数做同样的事情。
tabyl(mushrooms, class)
class n percent
edible 4208 0.5179714
poisonous 3916 0.4820286
mush_tabyl <- tabyl(mushrooms, odor, class)
mush_tabyl
odor edible poisonous
almond 400 0
anise 400 0
creosote 0 192
fishy 0 576
foul 0 2160
musty 0 36
none 3408 120
pungent 0 256
spicy 0 576
与相应的table
输出相比,使用tabyl
函数生成的表格更整洁,变量名(类别)被明确指出。此外,对于一维表格,除了数字,百分比也会自动生成。
我们还可以注意到,我们不需要使用 which 函数就可以直接指定变量名。此外,运行class(mush_tabyl)
告诉我们生成的对象是data.frame
类,这确保了与 tidyverse 的兼容性!
装饰过的 janitor
使用 Canva 图像生成器创建的图像
为了额外的tabyl
功能,janitor
包还包含了一系列adorn
函数。要获取百分比,我们只需将生成的频率表传递给adorn_percentages
函数。
mush_tabyl %>% adorn_percentages()
odor edible poisonous
almond 1.0000000 0.00000000
anise 1.0000000 0.00000000
creosote 0.0000000 1.00000000
fishy 0.0000000 1.00000000
foul 0.0000000 1.00000000
musty 0.0000000 1.00000000
none 0.9659864 0.03401361
pungent 0.0000000 1.00000000
spicy 0.0000000 1.00000000
如果我们想要按列的百分比,我们可以将denominator
参数指定为“col”。
mush_tabyl %>% adorn_percentages(denominator = "col")
odor edible poisonous
almond 0.09505703 0.000000000
anise 0.09505703 0.000000000
creosote 0.00000000 0.049029622
fishy 0.00000000 0.147088866
foul 0.00000000 0.551583248
musty 0.00000000 0.009193054
none 0.80988593 0.030643514
pungent 0.00000000 0.065372829
spicy 0.00000000 0.147088866
tabyl
— adorn
组合甚至使我们能够轻松地将数量和百分比结合在同一表格单元格中…
mush_tabyl %>% adorn_percentages %>% adorn_ns
odor edible poisonous
almond 1.0000000 (400) 0.00000000 (0)
anise 1.0000000 (400) 0.00000000 (0)
creosote 0.0000000 (0) 1.00000000 (192)
fishy 0.0000000 (0) 1.00000000 (576)
foul 0.0000000 (0) 1.00000000 (2160)
musty 0.0000000 (0) 1.00000000 (36)
none 0.9659864 (3408) 0.03401361 (120)
pungent 0.0000000 (0) 1.00000000 (256)
spicy 0.0000000 (0) 1.00000000 (576)
… 或将总计添加到行和列。
mush_tabyl %>% adorn_totals(c("row", "col"))
odor edible poisonous Total
almond 400 0 400
anise 400 0 400
creosote 0 192 192
fishy 0 576 576
foul 0 2160 2160
musty 0 36 36
none 3408 120 3528
pungent 0 256 256
spicy 0 576 576
Total 4208 3916 8124
结论
R 中janitor
包的tabyl()
函数提供了一个用户友好且灵活的解决方案,用于创建一维、二维或三维列联表。它在自动计算比例和生成整洁的数据框方面表现出色,能够与 tidyverse 生态系统,特别是dplyr
,无缝集成。它的输出结构良好且易于解释,并且可以通过 adorn 函数进一步增强,从而简化生成信息频率表的整体过程。这使得tabyl()
在 R 的数据分析中成为一个非常有用的工具。
处理集中数据管理中的敏感性问题
原文:
towardsdatascience.com/tackling-sensitivities-in-centralized-data-management-c1050a4310b7?source=collection_archive---------12-----------------------#2023-07-24
战斗中的经验教训
Willem Koenders
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 7 月 24 日
--
图片由 Jehyun Sung 提供,通过 Unsplash 获取。
大多数规模和年龄相当的组织如今已经启动了提升数据处理和管理方式的举措。任何结构性提升数据管理能力的尝试都将需要一定程度的集中化,即便只是为了发现组织内的现状实践。一个普遍的趋势是任命首席数据官,一项调查发现 超过 80%的财富 1000 强组织现在都报告说已经设立了首席数据官。
任何变革都可能引发反应,尤其是那些定义新责任、改变权威并需要资金的组织变革可能会敏感,因此处理起来可能棘手。我花了十多年时间陪伴新兴的首席数据官走过他们的旅程。我有许多经验教训可以分享。接下来的文章中,我们将回顾启动中央数据团队时的典型挑战以及应对这些风险的实际措施。
启动中央数据团队及相关敏感问题
创建一个初步的中央团队并集中责任可以通过多种方式完成。这可能包括启动首席数据办公室、成立数据(治理)委员会、改进中央驱动的政策和标准,以及提供中央资源以执行特定的数据治理流程。
随着中央数据团队获得权威,其他团队必须让出这部分权力——这可能带来敏感问题。以下是一些最常见的挑战:
-
资源争夺战。中央团队需要人员和预算来运作。有时,这可能意味着从其他团队的预算或人员中挪用资源,这可能引发这些团队领导的反应。
-
控制权丧失。曾经对某些任务有自由支配权的团队可能不喜欢中央团队接管这些任务,或对如何执行这些任务提供标准。
-
缺乏地方相关性。如果中央团队与地方(或业务特定)的需求、操作和细节过于脱节,他们可能做出不能产生预期业务影响的决策,并且地方专长可能会丧失。
-
成为瓶颈。如果中央团队的人员配置和管理不当,可能会成为组织中的瓶颈或摩擦点,从而成为挫败感和怨恨的源头。
-
成为倾卸点。中央团队可能成为一个倾卸点,接手其他人不愿意承担的任务和责任,而这些任务和责任只产生有限的价值。
-
对变革的普遍抵触。团队和个人可能会对变革产生抵触,尤其是当影响未知或理解不充分时,且如果他们觉得自己无法达到新的期望。
让我们回顾一些与这些挑战相关的真实案例,我曾亲身目睹。在第一个案例中,该组织是一家全球前 100 的保险公司,拥有超过 5 万名员工,策略非常激进。新任首席数据官有“完成任务”的历史,这也是他被任命的部分原因。他的计划是定义一系列被归类为“数据治理活动”的活动。然后,手握描述的他,识别出整个组织中已经执行这些活动的人,最初是以评估当前状态的名义进行的。在最终被识别出的人中,有时数据治理占据了他们整个工作的内容,有时只是其中的一小部分。例如,有人可能在市场营销流程中从事数据质量工作。
一旦完成了对这些人员的分析,提出了将这些人员的大部分转移到中央团队的建议。当然,这导致了业务和职能团队的即时反抗,这些人员原本就位于这些团队中。他们曾通过参与成熟度评估展现了善意,却看到他们的团队成员,往往是各自操作中至关重要的部分,被重新分配。新的首席数据官很快发现自己陷入了敌对的局面,并且在第一年剩下的时间里几乎没有取得任何成果。
另一个案例来自一家全球银行,其中全球首席数据与分析官(“CDAO”)启动了一个转型项目,以创建一个经过整理的数据湖。CDAO 担心创建一个数据沼泽,因此她坚持严格的认证过程。该项目的承诺很强 — 团队可以提交请求,将数据上传到数据湖中,由 CDAO 领导的中央团队负责数据的摄取、标记和质量控制,以及访问权限的提供,使得业务和职能团队可以专注于数据分析和建立分析模型。然而,需求很快超过了容量,数据进入数据湖的平均处理时间超过了 1 个月,业务团队也因此变得沮丧,因为中央数据团队成了瓶颈。尽管有良好的意图,但需要进行一次重大重置(在情况如上所述恶化时,我的团队帮助分析了可以自动化数据治理过程的工具)。
上述例子来源于我的个人经验,但对于那些感兴趣的人,《数据驱动企业与 DataOps》由 Ashish Thusoo 和 Joydeep Sen Sarma 编写,提供了一些出色的阅读材料和附加案例研究。
减轻策略
图片由作者提供。
组织可以采取各种措施来管理与集中数据管理相关的敏感问题。以下是我个人的一些偏爱方法:
-
关注自动化。通过采用自动化优先的思维方式,可以减少手动工作和成本,从而缩减对大团队的需求。设计中的数据治理,将治理原则嵌入数据系统设计中,可以促进一致性和效率。
-
沟通和透明度。关于集中化的原因和预期结果的清晰及时沟通可以培养信任,并鼓励利益相关者的支持。
-
指导和培训。任何新的或更新的流程或政策都应配备适用的教育材料和社交化流程,以明确其实施方式。
-
提供反馈和影响的途径。数据委员会或治理论坛可以确保利益相关者不仅仅是数据治理指南的接受者,而是旅程中积极且受重视的合作伙伴。
-
主动应对棘手的利益相关者。识别可能有特定关注点或异议的利益相关者,并主动应对,尽可能包括一对一的会议,并明确管理各自的需求和要求。
-
从小处开始,取得初步胜利。从一个理解明确、涉及相对“友好”人员且有明确业务案例的范围开始。这可以产生初步的动力,为以后处理更复杂的问题奠定基础。
-
赋权于本地团队和管理者。智能数据治理不是将所有工作集中化,而是使现有角色和人员能够以一致的方式更好地履行职责。跨功能和跨区域的团队可能是一个选择。
-
仅在需要时集中化。确保标准化或集中化数据治理的努力基于积极的业务案例。在不确定时避免集中化。
让我们回顾一下我个人观察的一些额外案例。首先,为了避免资源上的激烈争夺,一家大型区域零售商采用了有针对性的、基于用例的方法,分析了基于约 25 个数据角色的资源需求和缺口。这些角色包括数据所有者、数据管理员、数据建模师、数据科学家、数据工程师、系统所有者、流程所有者、领域数据管理员、数据架构师等。下一步是与业务和职能团队讨论这些角色,识别他们已经拥有的角色和面临的挑战,例如在拥有正确的专业知识和技术方面。当识别出共同的痛点和机会领域,中央团队能够有效地提升业务和职能团队时,这会得到立即欢迎。完成业务案例以获得初步的、尽管规模较小的团队的资金是快速而简单的。
与我曾合作过的另一位首席数据官一起,我们在她推出新的数据治理操作模型的同时,开展了数据素养和文化推广活动。新的操作模型引入了许多员工不熟悉的新术语和责任,我们认为这可能会引起混淆和焦虑。除了其他举措外,我们还组织了一场“问我任何问题”的会议,组织中的任何人——总计超过 25,000 名员工——都可以匿名提交问题,问题会被实时回答。除了许多实际问题外,还提出了很多敏感问题,比如“这些新责任对我的薪酬有何影响?”,“这是三年来的第 3 个新操作模型——我们真的需要另一个吗?”,以及“我的团队领导认为这是胡说八道——我该怎么办?”随后匿名收集的反馈和轶事评论表明,这种透明的方法缓解了个人的担忧,并鼓励了那些受到重新定义的责任影响的员工的认可。
最后的一个例子,也许是我最喜欢的案例研究,是经过改造的首席数据办公室采纳了自动化优先的思维模式。为新实施的云原生数据平台创建了一个参考架构,其中包括几个基础设施:数据产品方法、严格的互操作标准和数据管理中心。这意味着任何转移到平台上的内容从定义上都是自动管理的。即,只有符合数据产品标准的内容才被允许存储在平台上,从而确保了(产品和数据)所有权的定义。互操作标准、平台内部的共同存储模式以及数据管理中心(包括数据目录)的结合实现了几乎完全自动化的元数据管理。尽管这需要大量的前期投资来打好基础,但避免了中央数据平台团队需要大量分析师和工程师来操作平台及其数据产品的情况。
展望未来
集中管理数据并非一项简单的任务;它充满了复杂性和敏感性。然而,通过有效的沟通、利益相关者参与、重点培训和经过深思熟虑的方法,仅在有明确业务案例的情况下进行集中,这些挑战可以得到有效管理。如果你有任何见解或经验要分享,请在评论中告诉我们。
参考文献
-
首席数据官的三个被低估的成功因素,Medium。
-
首席数据官创造并展示价值的 8 种策略,HBR。
-
三次首席数据官的五个教训,Data World。
-
创建以数据驱动的企业与 DataOps,Ashish Thusoo 和 Joydeep Sen Sarma。
应对变化世界中的问题
原文:
towardsdatascience.com/tackling-the-problems-of-a-changing-world-with-data-13781b3a6088?source=collection_archive---------8-----------------------#2023-03-30
TDS Editors
·
关注 发表在 Towards Data Science · 以 新闻通讯 发送 · 3 分钟阅读 · 2023 年 3 月 30 日
--
如果你觉得过去几年特别忙碌和紧张,你并不孤单。我们怀疑人类历史上任何时期都可以被描述为平静,但全球危机、技术进步和媒体传播的交汇使得我们当前的时刻具有非常独特的—有时甚至是压倒性的—特质。一切!无处不在!一齐发生!
我们希望为许多数据从业者提供一个建设性的出口,让他们思考解决方案和务实的方法来应对我们全球社区面临的(众多)问题,因此 我们最近汇编了一系列 30 篇优秀文章,这些文章涉及人口分析、可持续农业和城市规划等主题。我们还包括了一些公共数据集,以便您可以直接参与并开始探索。
本周,我们重点介绍了您可以在 “我们的全球村庄变迁” 中找到的一些贡献,这是一项我们本月早些时候分享的特别特辑。我们希望您能探索我们其他的一些推荐,并希望您阅读后能受到启发,进一步了解更多内容,并可能参与有意义的项目。祝您阅读愉快!
-
在人的尺度上讲故事。当我们谈论全球性问题如人口增长或气候变化时,很容易忽视那些将要参与(并受这些讨论影响)的实际人群。Emily A. Halford 对艺术家 Norwood Viviano 作品的概述很好地将数据、可视化叙事与复杂信息的传达连接起来。
-
计算机视觉与野生动物的结合。设计得当,新兴技术具有巨大的潜力,可以帮助人类在其环境中更可持续地生活。Abhay Kashyap 提供了一个引人注目的案例研究:一个利用 AI 系统支持致力于保护加州野猫的非营利组织的项目。
Kyle Winkle 拍摄的照片,来源于 Unsplash。
-
评估关键行业对环境的影响。由于研究人员可以获取大量公开数据,挑战通常在于找到合适的数据集并为项目定义合理的范围。Aine Fairbrother-Browne 提供了一个很好的例子,展示了如何同时完成这两项工作,通过对英国航空公司数据的研究,突显了航空行业对环境的影响以及改进的最大潜力领域。
-
如何推动交通规划向前发展。如果你对地理空间数据、城市规划或图论感兴趣,不要错过苏坦·穆夫提的实践入门,其中向我们介绍了运输规划者网络分析这一迷人的主题。(苏坦使用的 Python 工具你可能已经熟悉,所以学习曲线可能比你想象的要平缓!)
在我们让你去探索这些重要主题之前,我们想为希望本周探索其他主题的你提供一些额外的阅读推荐:
-
针对在生产中监控 NLP 模型的实践方法,不要错过埃琳娜·萨穆伊洛娃的最新教程。
-
什么是模拟退火,为什么你需要关心它?跟随亨妮·德·哈德的介绍,了解这一强大的优化技术。
-
管理数据团队面临独特的挑战;瑞贝卡·维克里提出了在你的组织中实施最佳实践的六种方法。
-
在他的首篇 TDS 文章中,迪奥戈·莱塔奥探讨了一个对机器学习从业者至关重要的问题:在使用梯度提升树时,何时使用提前停止。
-
来自图形机器学习前沿的迈克尔·加尔金及其合著者任洪宇、迈克尔·科赫兹和朱兆诚展示他们最新的神经图数据库研究。
感谢你本周的时间和支持!如果你喜欢我们发布的内容(并想访问所有内容),考虑成为 Medium 会员。
直到下次可变性,
TDS 编辑们
Taipy:构建用户友好的生产就绪数据科学应用程序的工具
原文:
towardsdatascience.com/taipy-a-tool-for-building-user-friendly-production-ready-data-scientists-applications-80de97aaf7dd
一种简单、快速且高效的方式来构建全栈数据应用程序
Zoumana Keita
·发布在 Towards Data Science ·14 分钟阅读·2023 年 7 月 6 日
--
图片由 Campaign Creators 提供,来源于 Unsplash
介绍
作为数据科学家,你可能希望创建数据可视化的仪表板,展示数据,甚至实现商业应用来协助利益相关者做出可操作的决策。
多种工具和技术可以用于执行这些任务,无论是开源还是专有软件。然而,这些可能由于以下原因而不理想:
-
一些开源技术需要陡峭的学习曲线和聘请具备相应专长的人员。因此,组织可能面临新员工的入职时间增加、培训成本更高以及寻找合格候选人的潜在挑战。
-
其他开源解决方案非常适合原型设计,但无法扩展到生产就绪的应用程序。
-
同样,专有工具也带来了自己的挑战,包括更高的许可费用、有限的自定义和企业难以切换到其他解决方案。
如果有一个不仅是开源的,而且易于学习并能够扩展为完整应用程序的工具,那该多好啊?
这就是 Taipy 发挥作用的地方🎉
本文将解释 Taipy 是什么,并展示它可以解决的一些商业案例,然后再深入探讨其关键特性。此外,它还将说明创建完整 Web 应用程序的所有步骤。
Taipy 是什么,为什么你应该关心它?
这是一个开源的、100% Python 库,只需要基本的 Python 编程知识即可使用。它允许数据科学家、机器学习工程师以及任何其他 Python 程序员迅速将他们的数据和机器学习模型转化为功能齐全的 Web 应用程序。
在当今迅速变化的环境中,对强大、灵活且高效工具的需求变得至关重要,以下是一些使 Taipy 成为独特平台的特性:
-
它不仅仅为试点项目设计,还可以扩展到工业化项目。
-
Taipy 的简单性与强大的功能相结合,使得具有最低编程背景的 Python 开发者可以在短时间内构建强大的解决方案。
-
高度的可定制性使用户能够快速修改和调整 Taipy 的功能以满足他们的需求,这提供了许多开源工具无法提供的个性化体验。
-
Taipy 提供的同步和异步调用允许同时执行多个任务,从而提高了整体性能。
-
Taipy 应用程序可以通过 Python 脚本或 Jupyter Notebooks 开发。
-
借助 Taipy 的管道版本控制功能,用户可以有效地管理不同的项目版本。
Taipy studio 扩展可以安装到 Visual Studio Code 中,从而显著加快 Taipy 应用程序的开发速度。
Taipy 的关键特性
尽管 Taipy 对于前端或后端开发非常出色,但当涉及到开发具有前端和后端组件的完整 Web 应用程序时,其真正的潜力才会显现。
让我们仔细看看它们的主要功能:
Taipy 前端功能
-
创建用户界面只需具备基本的 Python 编程知识。
-
Taipy 旨在用户友好,使用户界面的创建简单直观。
-
不需要网页设计知识,消除了所有 CSS 和 HTML 的先决条件。
-
它利用增强的 Markdown 语法来帮助用户创建他们想要的网页。
Taipy 后端功能
-
Taipy 支持创建强大的管道以处理不同的场景。
-
它使得有向无环图(DAGs)的建模变得简单明了。
-
数据缓存功能提升了 Taipy 应用程序的整体性能。
-
管道执行的注册。
-
管道版本控制。
-
用户可以通过 Taipy 的 KPI 追踪工具跟踪和评估他们应用程序的性能。
-
内置的管道和相关数据的可视化。
开始使用 Taipy
现在你对 Taipy 有了更好的了解,让我们深入探讨一个端到端的实现。
核心Taipy 文档和社区贡献包含相关信息,本文绝不会取代它们,但可以作为了解 Taipy 在实际场景中的一种替代起点。
为了更好地说明我们的案例,我们将使用健康相关数据泄露,这些数据由美国卫生与公众服务部民权办公室维护。它提供了关于 500 多名个人的未加密受保护健康信息泄露的报告信息。
本节将分为两个部分:
-
使用 Taipy 构建一个图形界面,以帮助最终用户对不同类型的漏洞有一个整体概述,从而做出可操作的决策。
-
开发一个 Taipy 后端框架,以与分类机器学习模型互动,以预测给定信息的漏洞类型。
快速安装
使用 Taipy 需要 Python 3.8 或更高版本。使用 Anaconda Python 发行版(conda)和 Visual Studio Code IDE 安装 Taipy,如下所示:
创建名为taipy-env的虚拟环境并安装 Python 3.8
conda create –-name taipy-env python=3.8
激活之前创建的环境
conda activate taipy-env
以下命令将在虚拟环境中安装 taipy 库
pip install taipy
运行 Taipy 应用程序
-
创建一个 Python 脚本文件<taipy_app.py>
-
输入以下代码,然后保存文件:
from taipy import Gui
analytics_choice = ["Breach types distribution",
"Breach by State",
"Top 10 Risky States",
"Covered Entity Type",
""]
choice = ""
my_app_page = """
# Security Breach Analytics Dashboard
## Breach Analysis
Please choose from the list below to start your analysis
<|{choice}|selector|lov={analytics_choice}|dropdown|>
Your choice: <|{choice}|text|>
"""
if __name__ == '__main__':
Gui(page=my_app_page).run(host="0.0.0.0", port=9696)
在 conda 控制台中,从 taipy_app.py 中输入以下命令:
python taipy_app.py
上述代码的成功执行会生成此 URL,并自动打开一个浏览器窗口:
访问应用程序的 URL
作者提供的图片
太棒了!
现在,让我们深入了解之前的代码。
-
导入用于创建仪表盘的 Gui 模块。
-
analytics_choice
是可能选择的列表。 -
然后,变量
choice
将保存来自analytics_choice
的值,这些变量的插值使用<|…|>语法完成。 -
my_page 包含以下 markdown 格式的信息:
-
安全漏洞分析仪表盘的 H1 级别用单个“#”符号表示。
-
漏洞分析的 H2 级别用双“#”符号表示,后跟简单的文本“请选择…分析”
-
我们使用原始的
analytics_choice
和 choice 变量创建一个下拉列表。 -
显示用户做出的选择。
最后,通过指定 my_app_page 以及端口和主机来运行应用程序。不指定服务器端口将使用默认端口(5000)。对于这个特定的示例,应用程序在9696端口打开,网址为http://localhost:9696
从头开始创建一个 Taipy 仪表盘
通过实现一个完整的仪表盘,将我们的 Taipy 知识提升到一个新的水平。仪表盘的主要部分将利用 Taipy 的以下视觉元素:
-
从选项列表中进行选择,使用 选择器。
-
使用 按钮 通过点击按钮触发操作。
-
在表格中显示原始数据。
使用 图表 显示图形结果。
所有这些可视化元素都是通过引入以下 Markdown 语法创建的:
<|{variable}|visual_element_name|param1=param1|param2=param2|…|>
最终仪表板将如下所示,最终的源代码将在文章末尾提供。
使用 Taipy GUI 创建的最终仪表板(作者提供的图像)
为了进行逐步演示,将在单独的文件中提供每个组件的示例,并使用以下命令运行每个文件:
python file_name.py
选择器
这些选项允许用户从下拉列表中选择,这与我们在“运行 Taipy 应用程序”部分中实现的功能相对应。
按钮和表格
用户界面中的按钮在被点击或按下时会启动特定的功能。on_action 函数会在按钮被按下时触发。
表格用于组织数据,提供三种显示模式:分页、allow_all_rows、unpaginated 和 auto_loading。 官方文档 提供了关于这些模式的更多信息。
创建一个新文件 button.py
,并包含以下代码:
from taipy import Gui
import pandas as pd
breach_data = pd.read_csv("data/breach_report_data.csv")
def toggle_table_dialog(state):
state.show_table_dialog = not state.show_table_dialog
show_table_dialog = False
my_app_page = """
<center> Security Breach Analytics Dashboard</center>
------------------------------
<br/>
<center> Click the Button below to display data </center>
<br/>
<center><|Display Raw Data|button|on_action=toggle_table_dialog|></center>
<|{show_table_dialog}|dialog|on_action=toggle_table_dialog|width=90vw|labels=Cancel|
<center><|{breach_data}|table|width=fit-content|height=65vh|></center>
|>
"""
我们首先将违约数据加载到 Pandas 数据框中。然后,选择“显示原始数据”将所有数据以表格格式展示,如下所示:
使用 Taipy 创建的按钮结果(作者提供的图像)
图表
通过更好地理解上述组件,我们可以将它们结合起来创建图表,基于全面的 plotly.js 图形库。否则,Taipy 的文档 提供了很好的示例作为起点。与前一部分类似,创建一个 charts.py
文件并包含以下代码:
创建一个条形图,其中 State 位于 x 轴
,Proportion 位于 y 轴
。
# import libraries here
my_app_page = """
<center> Security Breach Analytics Dashboard</center>
------------------------------
<center> Graph 3: Top 10 Most Affected States</center>
<br/>
<|{breach_df}|chart|type=bar|x=State|y=Individuals_Affected|>
"""
# Put the '__main__' section here
最终结果是这个动态图表,显示了按 State 受影响的个人数量,似乎加利福尼亚州受影响最严重。
使用 Taipy 的图表(作者提供的图像)
显示图像
在 Taipy GUI 中显示图像也很简单,可以使用 image
属性实现。以下代码展示了由 generate_word_cloud
生成的词云。图像的宽度为 2400 像素,高度为 1000 像素。当用户的鼠标悬停在图像上时,将显示 hover_text
属性的值:在这种特定情况下为 “违约地点的词云”。
<|{breach_location_image}|image|width="2400px"|height="1000px"|hover_text="Word cloud of Breach Location"|>
违约信息的位置词云(作者提供的图像)
此外,辅助函数generate_word_cloud
的定义如下:
from wordcloud import WordCloud
from PIL import Image
from io import BytesIO
def generate_word_cloud(data, column_name):
# Join all the location information into one long string
text = ' '.join(data[str(column_name)])
wordcloud = WordCloud(
background_color="#1E3043"
)
# Generate the word cloud
my_wordcloud = wordcloud.generate(text)
image = my_wordcloud.to_image()
my_buffer = BytesIO()
image.save(my_buffer, format = 'PNG')
return my_buffer.getvalue()
回调函数
目标是拥有一个基于用户选择动态更新的 GUI。通过使用 Taipy 的回调函数实现,这些函数会自动触发局部命名空间中的任何on_change
函数作为全局回调函数。实现如下:
def update_Type_of_Breach(state, var_name, var_value):
if var_name == "Type_of_Breach":
state.df = breach_df[breach_df.Type_of_Breach == var_value]
布局
多个图表可以提供有价值的商业洞察,但将它们垂直展示一个接一个可能不是最有效的方法。
相反,我们可以创建一个布局,将组件组织成一个规则网格,放置在layout.start
和layout.end
块之间。每个组件都在part.start
和part.end
块内创建。
以下基本语法创建了一个 2 列网格,根元素的字体大小为 1.8:
<|layout.start|columns= 1 2|gap=1.8rem|
<optional_id|part|>
<|{first content}|>
|optional_id>
…
<
<|{second content}|>
>
>
理解布局后,我们可以创建最终的仪表板,其中包含五个主要图表:
-
图表 1 展示了与漏洞信息位置相关的词云。
-
图表 2 显示了按州受影响的人员数量。
-
图表 3 确定了按漏洞类型受影响的总人数。
-
图表 4 展示了每年受影响的总人数。
-
图表 5 显示了每个覆盖实体的受影响人数。
# Preprocessing of the DateTime column
breach_df['Breach_Submission_Date'] = pd.to_datetime(breach_df['Breach_Submission_Date'])
breach_df["Year"] = breach_df["Breach_Submission_Date"].dt.year
markdown = """
<|toggle|theme|>
# <center>Security Breach Analytics Dashboard 🚨</center>
<center>**Chart 1:**General Trend Location of Breached Information </center>
<center><|{breach_location_image}|image|width=2400px|height=1000px|hover_text=Word cloud of Breach Location|></center>
------------------------------
<|layout|columns=2 5 5|gap=1.5rem|
<column_1|
### Type of Breach:
<|{breach_type}|selector|lov={breach_types}|dropdown|width=100%|>
------------------------------
<|Display Raw Data|button|on_action=toggle_table_dialog|>
<|{show_table_dialog}|dialog|on_action=toggle_table_dialog|width=90vw|labels=Cancel|
<center><|{breach_df}|table|width=fit-content|height=65vh|></center>
|>
|column_1>
<column_2|
**Chart 2:** Individuals Affected by State
<|{df}|chart|type=bar|x=State|y=Individuals_Affected|>
**Chart 4:** Individuals Affected by Year
<|{df}|chart|type=bar|x=Year|y=Individuals_Affected|>
|column_2>
<column_3|
**Chart 3:** Individuals Affected by Type of Breach
<|{df}|chart|type=bar|x=Type_of_Breach|y=Individuals_Affected|>
**Chart 5:** Individuals Affected per Covered Entity Type
<|{df}|chart|type=bar|x=Covered_Entity_Type|y=Individuals_Affected|>
|column_3>
|>
"""
if __name__ == "__main__":
gui = Gui(page=markdown)
gui.run(dark_mode=False, host="0.0.0.0", port=9696)
在配置仪表板之前,从Breach_Submission
列创建一个新的Year
列,然后将其用作图表 4 中的 x 轴。
运行所有代码应生成上面展示的第一个仪表板。
Taipy 后端运作情况
在下一节中,你将使用 Taipy 的后端功能轻松高效地创建、管理和执行数据管道,以训练一个随机森林分类器,从而确定给定数据的漏洞类型。
本节分为两个主要部分。首先,你将使用 Taipy Studio 构建完整的工作流图形表示。然后,编写相应的 Python 代码。
Taipy Studio
Taipy Studio 是 Visual Studio Code 的一个扩展,安装方法如下:
Taipy 安装过程(图像由作者提供)
安装完成后重启 VSCode,然后点击左下角的 Taipy 图标将显示 Taipy Studio 界面。这将显示四个主要标签,如配置文件、数据笔记、任务、管道和场景。
Taipy Studio 界面(图像由作者提供)
所有这些标签都可以用来实现我们的端到端管道目标,第一步是创建一个配置文件(taipy_config.toml),该文件将包含所有这些标签,这些标签由选择“Taipy: Show View”图标后右上角的 4 个图标表示。
Taipy Studio 组件(图像由作者提供)
Taipy 标签说明
以下是将要实现的主要函数,并附有对每个先前选项卡的简要说明。
-
filter_columns
函数负责从数据中选择相关列并生成 Pandas 数据框。 -
preprocess_columns
用于执行特征工程。 -
encode_features
负责以正确的格式对相关特征进行编码。 -
split_data
是将数据拆分为训练集和测试集的函数。 -
train_model
用于训练模型。 -
show_performance
是展示模型性能的最后阶段。
场景和管道
这是设置管道时要做的第一件事。一个场景由一个或多个管道组成。它作为执行的注册表。让我们创建一个名为 DATA_BREACH_SCENARIO 的场景,然后创建一个名为 DATA_BREACH_PIPELINE 的管道,如下所示:
从场景到管道(作者提供的图像)
任务
一个任务指的是一个可以执行的 Python 函数,总共会实现六个任务,从 filter_columns
到 show_performance
。
管道的输出连接到每个任务的输入,如下所示:
从管道到任务
下一步是在 Taipy Studio 中配置这些任务,通过将每个 Python 函数连接到相应的任务。但是在此之前,我们需要在 data_breach_tasks.py
文件中创建这些函数的签名,如下所示:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
confusion_matrix,
accuracy_score,
precision_score,
recall_score,
f1_score
)
def filter_columns(df, list_columns_to_skip):
filtered_df = df.drop(list_columns_to_skip, axis=1)
return filtered_df
def preprocess_columns(df):
df['Breach_Submission_Date'] = pd.to_datetime(data['Breach_Submission_Date'])
df['Breach_Submission_Month'] = df['Breach_Submission_Date'].dt.month
df['Breach_Submission_Year'] = df['Breach_Submission_Date'].dt.year
df.drop("Breach_Submission_Date", axis=1, inplace=True)
return df
def encode_features(df):
list_columns_to_encode = ['State','Location_of_Breached_Information',
'Business_Associate_Present',
'Covered_Entity_Type']
le = LabelEncoder()
for col in list_columns_to_encode:
df[col] = le.fit_transform(df[col])
X = df.drop('Type_of_Breach', axis=1)
y = le.fit_transform(df['Type_of_Breach'])
return {"X": X, "y": y}
def split_data(features_target_dict):
X_train, X_test, y_train, y_test =
train_test_split(features_target_dict["X"],
features_target_dict["y"],
test_size=0.3,
random_state=42)
return {
"X_train": X_train, "X_test": X_test,
"y_train": y_train, "y_test": y_test
}
def train_model(train_test_dictionary):
classifier = RandomForestClassifier()
classifier.fit(train_test_dictionary["X_train"],
train_test_dictionary["y_train"])
predictions = classifier.predict(train_test_dictionary["X_test"],
train_test_dictionary["y_test"])
return predictions
def show_performance(train_test_dictionary, predictions):
y_test = train_test_dictionary["y_test"]
accuracy = accuracy_score(y_test, predictions)
precision = precision_score(y_test, predictions)
recall = recall_score(y_test, predictions)
f1score = f1_score(y_test, predictions)
return pd.DataFrame({
"Metrics": ['accuracy', 'precision', 'recall', 'f1_score'],
"Values": [accuracy, precision, recall, f1score]
})
接下来,我们按照以下 3 个步骤将每个任务链接到相应的 Python。下面的插图是针对 filter_columns
任务的,但必须对每个任务执行。
将任务链接到脚本的 3 个主要步骤(作者提供的图像)
数据节点
数据节点不包含实际数据,而是包含读取和写入这些数据所需的所有信息。它们可以是对任何数据类型的引用,例如文本、CSV、JSON 等。
例如,filter_columns
函数包含:
-
一个输入节点 (filtering_node),其类型为 .CSV 文件,以及
-
一个输出节点 (filtered_df):也以 .CSV 文件的形式存储。这然后作为 preprocess_columns 函数的输入。
交互节点定义如下,显示了存储类型从 pickle 修改为 .csv:
过滤节点输入类型的定义(作者提供的图像)
更新后的过滤节点的输入类型(作者提供的图像)
下一步是定义原始输入数据集的路径。这是通过数据节点中的“新属性”属性完成的。然后,按 Enter 并提供 .CSV 文件的路径。
过滤节点属性的定义(作者提供的图像)
过滤节点路径的定义(图片来源于作者)
对于所有需要.csv 文件的输入,重复相同的过程,最终图示在指定所有数据节点及其关系后将如下所示。
指定所有数据节点及其关系后的工作流程状态(图片来源于作者)
在配置管道后,整个图示的.toml 脚本格式会生成在taipy_config.toml文件中,其样式如下面的动画所示。
taipy_config.toml 文件的内容
然后,可以在任何 Python 脚本中加载这个.toml 文件来执行管道。我们来创建一个名为run_pipeline.py
的文件。
from taipy import Core, create_scenario
from taipy.core.config import Config
config_file_name = "./taipy_config.toml"
scenario_name = "DATA_BREACH_SCENARIO"
Config.load(config_file_name)
scenario_config = Config.scenarios[scenario_name]
if __name__ == "__main__":
Core().run()
pipeline_scenario = create_scenario(scenario_config)
pipeline_scenario.submit() # This executes the scenario
model_metrics = pipeline_scenario.performance_data.read()
print(model_metrics)
我们首先导入相关模块,然后定义配置文件及触发场景的名称。
然后,使用 submit()函数执行管道。
最后,我们检索模型的性能并打印结果,如下所示:
run_pipeline.py 的结果(图片来源于作者)
这个数据框可以进一步整合到初始仪表盘中,以图形化的方式展示数值。
结论
本文提供了对 Taipy 的全面概述,并展示了如何将前端和后端与任何数据和机器学习模型结合起来,创建完全功能的 Web 应用程序。
此外,随着新版本的发布,Taipy 提供了核心可视化元素,允许前端和后端之间的无缝集成,使用户能够轻松创建强大的业务对象,这些集成功能可以从官方网站获取。
如果你还在犹豫是否使用 Taipy,是时候尝试一下,以节省时间、精力,最重要的是金钱。最后,这些绝妙的教程可以帮助你进一步学习并提升技能。
迈出下一步,扩展你的数据科学技能
原文:
towardsdatascience.com/take-the-next-step-to-expand-your-data-science-skill-set-f9d3beb0652e?source=collection_archive---------6-----------------------#2023-11-09
TDS 编辑
·
关注 发布于 Towards Data Science · 发送至 通讯 · 阅读时间 3 分钟 · 2023 年 11 月 9 日
--
从有效的讲故事技巧到战略性的职业规划,你在数据科学职业生涯中所需的技能各异且日益跨学科。与统计或编程不同,这些领域的结构化教育环境,如学位课程或训练营,只能提供有限的帮助。
幸运的是,TDS 作者来自各种专业和个人背景,并在实际组织中学到了许多关于什么有效、什么无效的经验。本周的重点,我们汇聚了一系列强有力的文章,关注这些不那么技术性的但同样重要的数据科学工作方面:它们基于这些专业人士的实际经验提供了宝贵的见解。
-
哦,你是说“管理变革”? 在数据科学迅速变化的领域中航行总是一个挑战,尤其是在应对业务优先级变化的组织中。Marc Delbaere 揭示了在数据团队背景下实施变革管理的困难,并探讨了领导者和个体贡献者如何平衡他们有时冲突的目标。
-
数据讲故事中的 4D:将科学变为艺术 “ 数据科学家的工作如果没有讲故事,纯粹是数字化的算命,” Zijing Zhu, PhD 如是说,他接着分享了一个详细的框架,该框架超越了可视化基础,帮助从业者更高效、更有影响力地传达数据见解。
-
了解你的受众:技术演示准备指南 从不同的角度探讨数据讲故事,John Lenehan 深入分析了数据驱动的演示文稿准备工作,并提供了具体建议,帮助你以引人入胜的方式结构化见解,吸引同事和利益相关者并解决他们的关注点。(完成后,你还应该探讨 John 的后续文章,关于 将数据转化为连贯叙事。)
图片由 Terri Bleeker 提供,来源于 Unsplash
-
在掌握这 6 项必备数据科学技能之前不要申请科技职位在她最新的职业指导中,Khouloud El Alami 聚焦于数据科学家需要在一些关键领域扎实掌握的技能,以便成为主要科技公司职位的竞争者。根据她在 Spotify 的经验,Khouloud 涵盖了技术和非技术技能,并解释了它们如何最终必须与可衡量的影响连接起来。
-
初级数据科学家的 3 个关键职业决策尽管这个领域相对较新,但数据专业人员现在已经有了相当成熟的职业路径——但是如果你不想让你的职业跟随这些路径怎么办?Matt Chapman的新帖子邀请早期职业者审视自己的优先事项,反思对他们真正重要的东西,并根据他们所追求的角色类型来塑造他们的选择。
如果你也希望拓展你的技能和知识,我们为你准备了一些优秀的阅读材料,你不容错过。
-
想要了解推荐系统的新视角,请访问 Irene Chang的 Thompson 采样简介,这是关于这些无处不在的算法工具系列的第一部分。
-
通过关注 Caroline Arnold 的易懂且图文并茂的入门介绍,了解 伪随机数及其在机器学习中的作用。
-
Michael Galkin及其合著者介绍了他们最新的研究,ULTRA,一个 知识图谱推理的基础模型。
-
在最近的一次深入探讨中,Ms Aerin 探索了 RLHF(来自人类反馈的强化学习)及其作用 在训练数据和学习范式中,推动了最近的大型语言模型的进展。
-
我们非常高兴欢迎新的Carolina Bento解释员:一个隐藏马尔可夫模型的实践指南,配有直观的实现。
-
在他的首篇 TDS 文章中,Jon Flynn提供了关于当前 AI 持续学习方法的清晰全面的介绍,以及它们如何旨在应对保持模型最新这一(主要)挑战。
感谢支持我们作者的工作!如果你喜欢在 TDS 上阅读的文章,可以考虑成为 Medium 会员——这将解锁我们的整个档案(以及 Medium 上的其他每一篇文章)。
直到下一期 Variable,
TDS 编辑部
什么是时间序列预测中的谐波回归?
原文:
towardsdatascience.com/take-your-forecasting-to-the-next-level-with-harmonic-regression-5a8515f63295
揭示傅里叶级数与时间序列之间的迷人关系
Egor Howell
·发表于Towards Data Science ·阅读时间 6 分钟·2023 年 3 月 29 日
--
图片来自于Pawel Czerwinski在Unsplash
背景与问题
当我们想要对时间序列中的季节性进行建模时,我们通常会使用SARIMA模型。这个模型通过在特定滞后索引处找到自回归量和移动平均来添加季节性组件到ARIMA模型中。例如,具有年度季节性的月度数据将适配自回归量和移动平均于12的倍数。你可以在我之前的文章中阅读更多关于这个过程的信息:
## 如何使用 SARIMA 进行预测
深入探讨 SARIMA 模型及其应用
pub.towardsai.net
然而,如果我们有一个年季节性为365.25 天的日数据集怎么办?或者有一个季节性为52.14的周数据集呢?
不幸的是,SARIMA 无法处理这种 非整数 的数据,并且由于需要在 365 个数据点中寻找模式,计算上也 有困难。
那么,我们该怎么办?
使用 傅里叶级数 来拯救局面!
补充视频。
什么是傅里叶级数?
直觉
傅里叶级数是数学中最有趣的发现之一,它指出 这个:
任何周期函数都可以分解为正弦和余弦波的总和
这是一个非常简单的陈述,但其含义非常重要。
例如,下图展示了 sin(2x) 和 cos(3x) 及其对应的总和:
作者在 Python 中生成的图。
注意到 sin(2x) 和 cos(3x) 的函数非常均匀和简单,但它们的总和(红线)却产生了更复杂的模式。这就是傅里叶级数的核心思想。
我们甚至可以使用傅里叶级数通过将不同 奇数 频率和幅度的正弦波 (谐波) 相加来构造一个 方波:
作者在 LaTeX 中的方程。
作者在 Python 中生成的图。
令人惊讶的是,我们从光滑的正弦函数中生成了一条尖锐而直的图。这展示了傅里叶级数构造任何周期函数的真正力量。
用于生成这些图的代码可以在我的 GitHub 上找到:
[## Medium-Articles/fourier_series.py 在 egorhowell/Medium-Articles 主分支]
我在我的中等博客/文章中使用的代码。通过创建账户来贡献于 egorhowell/Medium-Articles 的开发...
github.com](https://github.com/egorhowell/Medium-Articles/blob/main/Time Series/Time Series Tools/fourier_series.py?source=post_page-----5a8515f63295--------------------------------)
理论
如上所述,傅里叶级数指出任何周期函数都可以分解为正弦和余弦波的总和。从数学角度来看,这是写作的:
作者在 LaTeX 中的方程。
其中:
-
A_0: 给定周期函数的平均值
-
A_n: 余弦分量的系数
-
B_n: 正弦分量的系数
-
n: 阶数,即正弦或余弦波的频率,这被称为‘谐波’
-
P: 函数的周期
周期P和阶数n是事先已知的。然而,需要计算系数(A_0, A_n, B_n)来确定哪些正弦和余弦分量的组合产生给定的周期函数。这些通常通过积分推导得出(有关示例请参见这里),但幸运的是,大多数 Python 数据科学包为我们完成了这个过程!
预测链接
你是否在想傅里叶级数如何适用于时间序列预测?好吧,请记住傅里叶级数处理周期函数,而我们经常发现时间序列中包含某些周期性结构(通常是季节性)。因此,我们可以使用傅里叶级数来建模我们时间序列数据中的任何复杂季节性模式!
使用傅里叶级数建模季节性的优点是:
-
任何季节长度
-
建模多个季节模式
-
傅里叶季节性的敏感性可以通过正弦和余弦分量的阶数和振幅进行调节
-
当季节性周期 大于~200 时计算效率高
许多这些优势无法通过 SARIMA 模型实现,因为它仅接受整数季节性、单一季节,并且当季节周期超过~200时,往往会耗尽内存。
使用傅里叶级数建模季节性的缺点是:
- 假设季节性模式和周期保持不变
现在的问题是,我们如何将其添加到我们的模型中?
ARIMAX 与外生特征
直觉
对于 ARIMA 模型,我们可以添加额外的外部特征来帮助预测。这些特征被称为 外生特征,使得 ARIMA 模型变成一个 ARIMAX 模型。例如,我们可以在预测房屋价值时使用当前的利率作为外生特征。
你可以将 ARIMAX 模型看作是常规的线性回归加上自回归量和移动平均组件(内生变量)。关键是让傅里叶级数成为这些外生特征之一,或者像线性回归中常描述的那样成为解释变量。
由于我们处理的是时间序列,外生特征需要像自回归量和移动平均一样具有时间索引。它们也需要在预测时已知。例如,如果我们想预测 5 月的房价,我们需要知道 5 月的利率,如果我们希望将其作为外生特征。
理论
数学上,外生特征以以下方式添加到经典 ARIMA 模型中:
作者用 LaTeX 给出的方程。
-
y: 不同时间步长的时间序列/滞后
-
x: 外生特征
-
β: 外生特征的系数
-
ϕ: 自回归组件(滞后)的系数 自回归组件(滞后)
-
p: 自回归组件的数量
-
ε: 预测 误差项, 移动平均组件
-
θ: 滞后预测误差的系数
-
q: 滞后误差组件的数量
傅里叶级数特征
将傅里叶级数作为外生量添加到 ARIMA 模型中相对简单,因为系数/幅度β是为我们推导的,我们只需要提供相应的正弦和余弦项。在伪代码中,这等同于:
# Sine component
sin(2*pi*frequency*time_index/period)
# Cosine component
cos(2*pi*frequency*time_index/period)
例如,假设我们有每月数据并且有年度季节性,我们想要 5 月的傅里叶分量。在伪代码中,这将是:
# Sine component
sin(2*pi*frequency*5/12)
# Cosine component
cos(2*pi*frequency*5/12)
五月是第 5 个月,一年有 12 个月。
然而,我们仍然需要推导频率(阶数)值。这通常通过传递大量正弦和余弦组件的阶数来找到,并让模型找出最有用的阶数。在下面的 Python 示例中,我们将演示这个过程。
Python 实现
我们将使用谐波回归和 ARIMAX 进行一些实际的预测!我们将使用来自 Kaggle 的美国航空公司乘客数据集。
数据 来自 Kaggle 采用 CC0 许可证。
作者总结。
由作者在 Python 中生成的图。
正如我们所看到的,傅里叶级数已经很好地捕捉到了季节性!
注意:在上面的代码中,我们使用了 Box-Cox 变换以使方差平稳。你可以在这里了解更多关于这个过程的内容。
总结与思考
当你的时间序列的季节性是非整数的、有众多模式或非常长(>50 点)时,最好使用傅里叶级数来建模这一季节性组件。这可以通过将傅里叶级数作为外生特征添加到常规 ARIMA 模型中,将其转化为 ARIMAX 来实现。这些外生特征是辅助预测时间序列的外部协变量。
完整代码可以在我的 GitHub 上找到:
## GitHub - egorhowell/Medium-Articles: 代码我在我的 Medium 博客/文章中使用。
你目前无法执行该操作。你在另一个标签或窗口中登录了。你在另一个标签或窗口中注销了…
github.com
还有别的事情!
我有一个免费的时事通讯,Dishing the Data,在其中我每周分享成为更好数据科学家的技巧。没有“无关紧要的内容”或“点击诱饵”,只有来自实践数据科学家的纯粹可操作的见解。
## Dishing The Data | Egor Howell | Substack
如何成为更好的数据科学家。点击阅读《Dishing The Data》,由 Egor Howell 撰写的 Substack 出版物,包含…
newsletter.egorhowell.com
连接我!
-
YouTube
-
LinkedIn
-
Twitter
-
GitHub
参考文献和进一步阅读
- 预测:原则与实践:
otexts.com/fpp2/
使用 LangChain 和 Azure OpenAI 与您的 SQL 数据库“对话”
原文:
towardsdatascience.com/talk-to-your-sql-database-using-langchain-and-azure-openai-bb79ad22c5e2?source=collection_archive---------0-----------------------#2023-09-28
探索使用 LLM 处理数据库查询的自然语言处理的强大功能
Satwiki De
·
关注 发表在 Towards Data Science · 10 分钟阅读 · 2023 年 9 月 28 日
--
LangChain 是一个开源框架,用于开发可以使用 LLM(大型语言模型)处理自然语言的应用程序。
LangChain 的 Agent 组件是一个 LLM 的封装器,它决定解决问题的最佳步骤或行动。Agent 通常可以访问一组称为 Tools(或工具包)的功能,并且可以根据用户输入决定使用哪个工具。每个代理都可以执行各种 NLP 任务,例如解析、计算、翻译等。
Agent Executor 是 Agent 及其工具集的可运行接口。Agent 执行器负责调用 Agent、获取操作和操作输入、调用操作引用的工具及其对应输入、获取工具的输出,然后将所有这些信息传回 Agent 以获取下一步操作。通常这是一个迭代过程,直到 Agent 达到最终答案或输出。
在这篇文章中,我将展示如何使用 LangChain Agent 和 Azure OpenAI gpt-35-turbo 模型 通过自然语言查询你的 SQL 数据库(无需编写任何 SQL!)并获取有用的数据洞察。我们将使用 SQL Database Toolkit and Agent,它可以将用户输入转换为适当的 SQL 查询,并在数据库中执行以获取答案。
这是一篇探索性文章。它旨在提供对当前可用工具的概述,并在过程中识别任何挑战。
作者提供的图片(使用 Bing 图像创作者创建)
需求范围
在这次探索中,我们只读取数据库中的数据,避免任何插入、更新或删除操作。这是为了保持数据库中的数据完整性。我们将重点关注如何利用数据库中可用的数据来回答问题。
不过,SQL Agent 不保证它不会根据特定的问题对你的数据库执行任何 DML 操作。确保不发生任何意外 DML 操作的一种方法是创建一个仅具有 *read*
权限的数据库用户,并在以下代码中使用它。
让我们以一个电子零售公司的订单和库存系统数据库为例。库存跟踪多个类别的产品,例如厨房、园艺、文具、浴室等。订单系统记录每个产品的购买历史,包括订单状态、交货日期等。
以下可能是该应用程序的最终用户提出的一些问题:
-
上个月销售的厨房产品数量。
-
有多少订单尚未发货?
-
上个月有多少订单延迟交付?
-
上个月销售的前三大产品是什么?
设置
-
Python>=3.8 和一个 IDE 进行我们的探索。我使用 VS Code。
-
我在这里使用 Azure OpenAI gpt-35-turbo 作为 LLM。这个模型是 GPT-3.5 系列的一部分,可以理解和生成自然语言及代码。要跟随本指南,你需要一个启用了 OpenAI 服务的 Azure 订阅。了解更多 here。
-
我在这里使用一个 Azure SQL 数据库。然而,你也可以使用本地 SQL 数据库。
数据库
我创建了一个名为retailshopdb
的数据库,包含以下表格和关系:
-
类别
-
产品
-
订单
作者提供的图片
除了‘Id’列作为每个表的主键外,这些表还有相互之间的外键关系,例如 CategoryId 是 Product 表中的外键,ProductId 是 Orders 表中的外键。这些关系对于 LangChain 代理根据最终用户的问题构造 SQL 查询至关重要。
Azure OpenAI
如果你在订阅中创建了 Azure OpenAI 资源,请导航到 Azure OpenAI Studio。为 gpt-35-turbo
模型创建一个 deployment
。
作者提供的图片
代码与输出分析
让我们从一些基础代码开始,以便在 VS Code Python 笔记本中访问 LLM。
- 通过安装所需的库并设置所需的环境变量来初始化笔记本。
%pip install langchain openai sqlalchemy
import os
from dotenv import load_dotenv
os.environ["OPENAI_API_TYPE"]="azure"
os.environ["OPENAI_API_VERSION"]="2023-07-01-preview"
os.environ["OPENAI_API_BASE"]="" # Your Azure OpenAI resource endpoint
os.environ["OPENAI_API_KEY"]="" # Your Azure OpenAI resource key
os.environ["OPENAI_CHAT_MODEL"]="gpt-35-turbo-16k" # Use name of deployment
os.environ["SQL_SERVER"]="" # Your az SQL server name
os.environ["SQL_DB"]="retailshop"
os.environ["SQL_USERNAME"]="" # SQL server username
os.environ["SQL_PWD"]="{<password>}" # SQL server password
2. 连接到数据库。
from sqlalchemy import create_engine
driver = '{ODBC Driver 17 for SQL Server}'
odbc_str = 'mssql+pyodbc:///?odbc_connect=' \
'Driver='+driver+ \
';Server=tcp:' + os.getenv("SQL_SERVER")+'.database.windows.net;PORT=1433' + \
';DATABASE=' + os.getenv("SQL_DB") + \
';Uid=' + os.getenv("SQL_USERNAME")+ \
';Pwd=' + os.getenv("SQL_PWD") + \
';Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;'
db_engine = create_engine(odbc_str)
3. 初始化 LangChain chat_model
实例,它提供了使用聊天 API 调用 LLM 提供者的接口。选择聊天模型的原因是 gpt-35-turbo
模型已针对聊天进行了优化,因此我们在这里使用 AzureChatOpenAI
类来初始化实例。
from langchain.chat_models import AzureChatOpenAI
llm = AzureChatOpenAI(model=os.getenv("OPENAI_CHAT_MODEL"),
deployment_name=os.getenv("OPENAI_CHAT_MODEL"),
temperature=0)
请注意,
temperature
设置为 0。温度是一个控制生成文本的“创造力”或随机性的参数。较低的温度(0 为最低)使输出更“专注”或确定。由于我们处理的是数据库,LLM 生成的响应必须是事实性的。
4. 创建提示模板。
提示 是我们发送给 LLM 以生成输出的输入。提示也可以设计为包含指令、上下文、示例(单次或少次),这些都对生成准确的输出至关重要,同时也可以设置语气和格式化输出数据。
使用提示模板是结构化这些属性的好方法,包括提供给 LLM 的最终用户输入。我们在这里使用 LangChain 的 ChatPromptTemplate
模块,它基于 ChatML(聊天标记语言)。
这是一个基础提示模板,供我们开始使用。随着时间的推移,我们会根据需要更新此模板 —
from langchain.prompts.chat import ChatPromptTemplate
final_prompt = ChatPromptTemplate.from_messages(
[
("system",
"""
You are a helpful AI assistant expert in querying SQL Database to find answers to user's question about Categories, Products and Orders.
"""
),
("user", "{question}\n ai: "),
]
)
现在初始化 create_sql_agent
,它设计用于与 SQL 数据库交互,如下所示。该代理配备了与 SQL 数据库连接并读取表的元数据和内容的工具包。
from langchain.agents import AgentType, create_sql_agent
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
db = SQLDatabase(db_engine)
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)
sql_toolkit.get_tools()
sqldb_agent = create_sql_agent(
llm=llm,
toolkit=sql_toolkit,
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True
)
请注意,我们在这里使用 ZERO_SHOT_REACT_DESCRIPTION
作为 agent_type
参数的值,这指示代理 不使用记忆。
一切准备就绪进行测试 —
sqldb_agent.run(final_prompt.format(
question="Quantity of kitchen products sold last month?"
))
注意在以下单元格输出中,LangChain Agent Executor 如何以迭代的方式使用
Action
、Observation
和Thought
流,直到达到Final Answer
。
作者提供的图片
输出: 10
这是一个正确的答案。
让我们来点乐趣吧?将问题中的‘Quantity’替换为‘how many’,这应该会产生相同的答案。
sqldb_agent.run(final_prompt.format(
question="How many kitchen products were sold last month?"
))
但这就是我们得到的结果 —
作者提供的图片
输出: ‘本月销售了 2 件厨房产品。’
这个输出不正确! 代理在创建 SQL 查询时犯了错误。它没有执行 SUM(ProductOrderedQuantity)
来获取输出,而是对 JOIN 结果执行了 COUNT(*)
,这导致了错误的输出。
为什么稍微改变提示输入会产生不同的输出?
OpenAI 模型是非确定性的,这意味着相同的输入可能会产生不同的输出。将温度设置为 0 会使输出大多是确定性的,但由于 GPU 浮点运算,可能仍会存在少量的变异。
使用不同的输入运行另一个测试 —
sqldb_agent.run(final_prompt.format(
question="How many orders have not been shipped yet?"
))
图片由作者提供
输出:‘还有 15 个订单尚未发货。’
这又是错误的结果。 代理也考虑了‘已完成’的订单,而我们的提问仅涉及尚未发货的订单。
让我们看看可以做哪些修改以生成准确的输出。
玩转 Prompt Engineering
LangChain 代理可以使用其工具包从 SQL 数据库中读取表格元数据,并且在一定程度上可以解释列名。但是仍然存在一些推理上的差距,我们可以尝试使用 Prompt Engineering 技术来弥补这些差距。
我们从一个基础提示模板开始,该模板只有一行指令。让我们添加一些额外的信息,以便向 LLM 提供更多关于我们用例的上下文,以形成更好的 SQL 查询。以下是我在系统消息中添加的高层次信息:
-
表格列的信息
-
订单状态值的‘含义’
-
最具体的信息在最后
from langchain.prompts.chat import ChatPromptTemplate
final_prompt = ChatPromptTemplate.from_messages(
[
("system",
"""
You are a helpful AI assistant expert in identifying the relevant topic from user's question about Categories, Products and Orders and then querying SQL Database to find answer.
Use following context to create the SQL query. Context:
Product table contains information about products including product name, description, product price and product category.
Category table contains information about categories including category name and description. Each Product is mapped to a Category.
Orders table contains information about orders placed by customers including
quantity or number of products ordered,
expected delivery date and actual delivery date of the Order in the location and the status of the order.
Order status = 'Processing' means the order is being processed by seller and not yet shipped,
Order status = 'Shipped' means the order is shipped by the seller and is on the way to the customer,
Order status = 'Completed' means the order is delivered to the customer, and
Order status = 'Cancelled' means the order is cancelled by the customer.
If the question is about number of products in an order, then look for column names with 'quantity' in the tables and use 'sum' function to find the total number of products.
"""
),
("user", "{question}\n ai: "),
]
)
现在重新运行第一个输入 —
sqldb_agent.run(final_prompt.format(
question ="How many Kitchen products were sold in current month?"
))
图片由作者提供
输出:10
推理得到了改善!通过向 LLM 提供额外的上下文,我们能够获得准确的输出。
现在测试所有用户输入 —
思考之后 . . .
这一切在测试时看起来都很不错,但如果我们想实际构建一个解决方案并将其发布供最终用户使用呢?这是一个很好的想法,适用于类似于在自己数据库上运行聊天机器人的用例,但像任何典型的软件开发一样,在构建基于 LLM 的系统之前,我们还需要考虑和决定一些关键的设计方面。
可扩展性
在这个例子中,我使用了 3 个表,总共约 30 行。涉及连接所有 3 个表的输出的平均延迟约为 5 秒。截至目前,我在官方文档中没有找到关于我们可以使用的数据库最大大小的信息。然而,我们可以考虑一些参数来确定我们的需求:
-
您的应用程序需要的延迟是多少?如果您正在构建一个聊天机器人,那么您的预期延迟可能不会超过某个数值。
-
您的数据库大小是多少?还要考虑您希望用于查询的各个表的大小。
请注意,你不需要将整个数据库传递给 Agent Toolkit。可以选择特定的表与 Toolkit 一起使用。一个好的选择是为不同的用例识别表的子集,并创建多个指向不同表子集的 Agent。
3. Azure OpenAI 资源的速率和配额限制。如果你使用的是其他 LLM 提供商,也请查看那里可能存在的限制/约束。
可靠性
我们如何确保始终获得准确的响应?如何确保系统不会出现幻觉或产生完全意外的内容?
当前正在研究如何提高 LLM 的可靠性和鲁棒性。通过使用特定用例的提示,我们可以帮助改善我们的用例的推理,这有时也被称为“临时学习”或“上下文学习”。
请记住,我们在这里不是在训练 LLM。从使用预训练 LLM 构建产品的角度,我们只能有针对性地调整我们的代码、模型参数和构建在这些 LLM 上的提示。
以迭代方式进行开发,并在过程中进行评估,可以将我们引向正确的方向,开发一个整体有效的系统。
日志记录和监控
像其他典型的软件开发一样,启用日志记录和持续监控 LLM 基础的应用程序是一种好习惯。监控不仅可以捕获系统相关的指标,如性能、延迟、请求-响应率,还可以捕获我们系统的输入和输出,这可以帮助我们确定系统的一致性。从监控中收集的一些有用信息可以用来改进我们的系统:
-
终端用户频繁提出类似问题,以及这些问题生成的输出
-
LLM 的幻觉率
结论
软件工程领域正在迅速变化,LLM 的巨大生成能力带来了许多解决方案。我们有机会采用这项技术,利用其力量创建产品,同时保持对 LLM 支持系统的可靠性的检查。通常,从小处开始,建立一个概念验证应用程序,看看它是否符合你的需求,总是一个好主意。
参考文献:
community.openai.com/t/run-same-query-many-times-different-results/140588
help.openai.com/en/articles/6654000-best-practices-for-prompt-engineering-with-openai-api
mlops.community/concepts-for-reliability-of-llms-in-production/
如果你想阅读更多关于新兴技术的内容,请关注我。请在评论区留下你的反馈。
2023 年最佳:关于 ChatGPT 和 LLMs
原文:
towardsdatascience.com/tds-best-of-2023-on-chatgpt-and-llms-83bdfbb2136d?source=collection_archive---------3-----------------------#2023-12-14
TDS Editors
·
关注 发表在 Towards Data Science · 发送为 通讯 · 阅读时间 5 分钟·2023 年 12 月 14 日
--
你可能会说 2023 年对数据科学家和机器学习专业人士来说是多事之年,但这并不能完全捕捉到过去 12 个月我们在这个领域所经历的繁忙活动量。
尽管我们总是力图抵制炒作和夸张,但我们不得不承认,是的,我们确实看到了一些戏剧性的变化,既包括从业者的观点,也包括社会整体对人工智能及其对我们日常生活影响的看法。ChatGPT 在 2022 年最后几周的发布远不是这种过渡的唯一因素,但很难否认它既是催化剂又是象征性的焦点。
当我们考虑如何盘点 2023 年我们作者的最佳和最受欢迎作品时,回顾关于大语言模型的文章——尤其是那个无处不在的聊天机器人——成为了一个非常自然的选择。我们在这里呈现的文章并不全面,但确实提供了一个具有代表性的样本,展示了你们这些读者对哪些文章反响最强烈——无论是你们无法停止阅读和分享的文章,还是那些在 TDS 及其他地方引发了最深刻讨论的文章。
在我们深入探讨过去一年中最引起关注的文章之前,我们想花一点时间感谢我们的整个社区对我们的支持。我们特别感谢我们了不起的作者们、Medium的合作伙伴、慷慨提供专业知识的志愿编辑组,以及我们的两位前同事和杰出编辑,凯特琳·金迪格和凯瑟琳·普雷里。
-
ChatGPT 如何运作:模型背后的机器人在最不令人惊讶的发展中,莫莉·鲁比的易于理解且信息丰富的解释成为了我们 2023 年最受欢迎的文章。如果你还没读过,现在也不算太晚!
-
封闭的 AI 模型造成糟糕的基准在一个后 ChatGPT 世界中,自然语言处理研究将会采取什么方向?安娜·罗杰斯考察了这个迅速变化领域的现状。
-
ChatGPT 能写出比数据分析师更好的 SQL 吗?虽然大语言模型是否对整个职业构成威胁仍有待观察,玛丽·陈在 ChatGPT 发布后不久便花时间调查了其编程技能。
-
GPT 是一个不可靠的信息库在对人工智能幻觉的前瞻性观察中,诺布尔·阿克森深入探讨了将大语言模型当作可靠搜索引擎使用的潜在风险。
图片由 米奇·豪普特 拍摄,来源于 Unsplash
-
如何将任何文本转换成概念图 由于 LLM,NLP 领域的可能性得到了探索,Rahul Nayak 提供了一种将文本语料库转换为知识图谱的实用方法。
-
并非所有美好:ChatGPT 的阴暗面 从内建偏见到隐私和剽窃问题,Mary Reagan PhD 揭示了 ChatGPT 崛起后出现的一些主要风险。
-
零 ETL、ChatGPT 与数据工程的未来 ChatGPT 和类似工具将如何影响日常的数据工程工作流?Barr Moses 分享了对“后现代数据栈”未来的见解。
-
构建你的第一个 LLM 应用所需了解的一切 2023 年是 LLM 驱动的应用程序构建过程变得实际民主化的一年,这在很大程度上得益于像 Dominik Polzer 的广泛分享的教程。
-
GPT-4 与 ChatGPT:训练、性能、能力与局限性的探讨 在发布 ChatGPT 几个月后,OpenAI 通过最新的 GPT-4 提升了标准,Mary Newhauser 迅速提供了这两款产品的详细对比。
-
TimeGPT:首个时间序列预测的基础模型 随着年份的推进,我们遇到了越来越多针对特定用例的 LLM 解决方案。Marco Peixeiro 对 TimeGPT 进行了解释,它是一个定制化基础模型的示例。
-
掌握客户细分与 LLM LLM 的实际应用案例及其支持的产品每天都在不断增长;Damian Gil 为营销人员和商业战略家概述了一个有前途的方向。
-
开始使用 LangChain:构建 LLM 驱动应用程序的初学者指南 与 ChatGPT 一起,LangChain 成为了构建基于 LLM 的产品的热门工具;Leonie Monigatti 撰写了这本资源指南,适合任何对其进行探索的人。
-
新的 ChatGPT 提示工程技术:程序模拟 将我们的需求和目标转化为 LLM 可以正确解读的语言仍然是一个挑战。Giuseppe Scalamogna 揭示了一种更有效的提示设计创新框架。
-
GPT 模型如何工作 作为对 GPT 模型背后的数学和理论的全面且易于理解的入门介绍,Beatriz Stollnitz的深入分析仍然是初学者和经验丰富的从业者的顶级选择。
-
如何从零开始构建 LLM 如果你偏好更具实践性的学习方法,Shawhin Talebi 的教程将带你从数据整理到模型评估——即使你不打算在家里创建下一个 Llama 或 Falcon 模型,它也值得一探!
-
RAG 与微调——哪个是提升 LLM 应用的最佳工具? 随着我们了解了预训练模型的局限性,新的方法出现以提升其性能。Heiko Hotz 提供了对两种主要选项的有用比较:微调和检索增强生成(RAG)。
-
在 CPU 上本地运行 Llama 2 进行文档问答 能够与我们自己的文本文件、PDF 和音频记录“对话”已经成为 LLM 的一个流行日常应用场景。Kenneth Leung 的逐步指南展示了我们如何在本地机器上创建这样的工作流程。
-
LangChain 中链式 LLMs、代理和工具的温和介绍 对于任何刚刚开始使用 LLMs 的人,Dr. Varshita Sher关于 LangChain 构建模块的有用且全面的教程是必读之作。
-
分子生物学中的大型语言模型 探索 LLMs 在科学研究中的潜力,Serafim Batzoglou的深度挖掘关注了其在分子生物学中的影响,应用范围从基因结构预测到药物发现。
敬请关注! 在 2023 年,我们发布了大量优秀文章,涵盖了远超 LLMs 和 ChatGPT 的广泛话题。下周,我们将把今年的最后一期 Variable 专注于数据科学和编程技能、职业道路以及特别项目的精彩文章。
再次感谢您在 2023 年支持我们作者的工作!如果您喜欢 TDS 上的文章,可以考虑成为 Medium 的朋友会员:这是一个新的会员等级,能为您喜爱的作者提供更大的奖励。
直到下一个 Variable,
TDS 编辑部
TDSP:当敏捷遇上数据科学
原文:
towardsdatascience.com/tdsp-when-agile-meets-data-science-15ccb5bf8f87
实用指南:将敏捷原则应用于数据科学项目
Amol Mavuduru
·发表于 Towards Data Science ·7 分钟阅读·2023 年 1 月 23 日
--
由 Daria Nepriakhina 🇺🇦 拍摄于 Unsplash
如果你参加过软件开发/项目管理课程或培训,你可能听说过敏捷。敏捷是一套关注适应性规划、早期交付、持续改进和灵活响应需求变化的软件开发实践。
尽管敏捷在软件开发中非常流行,但这些实践带来的灵活性在数据科学项目中同样,甚至更为有效。事实上,微软在 2016 年发布了一个迭代的数据科学框架,专门将敏捷原则应用于数据科学项目。
在这篇文章中,我将解释微软的团队数据科学流程(TDSP)如何用于将敏捷原则应用于数据科学项目。
引言——敏捷原则
在介绍 TDSP 之前,我们应当讨论敏捷的指导原则,这些原则最初由敏捷宣言的作者提出。
-
个人与互动胜于过程和工具。
-
工作软件胜于全面文档。
-
客户协作胜于合同谈判。
-
响应变化胜于遵循计划。
整个敏捷框架中的实践,如每日立会和迭代增量开发,都是基于这些价值观。这些实践旨在适应变化,并迅速交付可工作的解决方案。
什么是 TDSP?
TDSP 是一种将敏捷原则应用于高效交付数据科学解决方案的方法论。TDSP 的核心概念是数据科学生命周期,类似于敏捷中的软件开发生命周期。数据科学生命周期包括在数据科学项目中反复进行的五个生命周期步骤:
-
业务理解
-
数据获取与理解
-
建模
-
部署
-
利益相关者/客户接受
这些步骤可以通过下面的工作流图进行可视化。
数据科学工作流。图片由作者提供,灵感来源于Microsoft Azure。
在这种方法论中,每个数据科学项目都从定义业务问题和理解业务需求开始。这导致了数据获取和理解步骤,这是进行任何模型开发工作的先决条件。一旦我们有了一个表现良好的模型,就可以将其部署到生产环境中,或以仪表板或报告的形式展示结果。
如果在任何时候我们对结果不满意或面临变化的需求,我们可以返回到之前的步骤,因为这种方法论专注于迭代开发。每个步骤将在后续部分中详细解释。
业务理解
照片由Medienstürmer提供,来源于Unsplash
这一阶段的核心是确定项目的业务需求和识别解决相关机器学习问题所需的数据。此阶段有两个主要任务:
-
定义目标: 我们需要与客户/利益相关者合作,以确定我们试图解决的业务问题。
-
识别数据来源:一旦我们知道了要解决的问题,我们需要识别为解决该问题所需的数据来源。
这些步骤是每个实际数据科学项目的基础。明确的目标和数据来源在数据分析和模型开发之前就应确定。
数据获取与理解
照片由Myriam Jessier提供,来源于Unsplash
在制定业务问题并确定数据来源之后,下一步最自然的步骤是获取和探索数据。TDSP 的这一阶段包括三个主要任务:
-
获取数据: 我们需要将数据引入用于分析的环境中。如果我们在本地工作,这个任务可能只是将文件上传到 Jupyter 的工作目录。
-
探索数据: 这一步通常被称为探索性数据分析(EDA),包括预处理和理解数据中的属性和模式,并确定数据是否适合模型开发。
-
建立数据管道: 这一步涉及构建一个处理新数据的过程。数据管道可以是基于批处理的、实时的或这两者的混合。
尽管所有这些步骤可能由数据科学家执行,但第三步甚至可能是第一步可能需要由数据工程师执行。这个阶段突显了数据科学团队中除数据科学外的其他技能的重要性。数据科学家在进行 EDA 时,数据工程师可能在设置数据管道,这样我们可以更快地完成这个阶段。
建模
照片由 Kevin Ku 提供,来自 Unsplash
建模阶段,可能是大多数数据科学家最兴奋的阶段,依赖于前两个阶段的成功完成。我们的模型质量最终受到数据质量以及对数据理解的限制。一旦我们有了数据管道和经过探索的高质量数据,我们就准备开始建模阶段,该阶段有三个关键步骤:
-
特征工程: 这一步是从原始数据中创建数据特征以用于模型训练的过程。特征工程通过对数据的良好理解来增强,这使我们能够提取甚至创建新的特征以供模型使用。
-
模型训练: 一旦我们有了一组特征和目标变量,我们可以训练模型来预测目标变量。在这一阶段,我们将数据分为训练集、验证集和测试集。
-
模型评估: 在训练每个模型后,我们需要通过回答以下问题来评估这些模型:1)模型在验证/测试集上的指标是什么?2)模型是否解决了业务问题并符合问题的限制?3)这个模型是否适合生产环境?
这一阶段的步骤是迭代的。我们可以训练一个模型,发现结果不令人满意,然后返回到特征工程和模型训练阶段,打造更好的特征并尝试不同的建模方法。在这个阶段训练和评估多个模型并不是不寻常的,事实上,这是预期中的事情。
部署
照片由 Kevin Ku 提供,来源于 Unsplash
这个阶段将我们的模型转变为实际可用并产生业务结果的东西。这个阶段的唯一步骤是通过将数据管道和模型集成到生产或类似生产的环境中来使我们的模型具备操作性。有许多方法可以部署模型,但我们选择的路线将取决于业务用例。考虑以下部署选项:
-
通过 API 公开模型,以便其他应用程序可以使用。
-
创建一个微服务或容器化应用程序来运行模型。
-
将模型集成到一个具有仪表板的 Web 应用程序中,仪表板显示预测结果。
-
创建一个批处理过程来运行模型并将预测写入一个可以使用的数据源。
一旦这个阶段完成,相关利益方和/或客户应该能够观察到模型产生的结果。例如,一个投入生产的产品推荐系统应该能够为在线购物的客户生成推荐。
客户接受
照片由 Rock Staar 提供,来源于 Unsplash
在这个最终阶段,我们的目标是确认数据管道、模型和生产部署满足客户的需求,并解决第一阶段中提到的业务问题。这个阶段有两个步骤:
-
系统验证: 确认数据管道、模型和部署满足业务用例,并满足客户需求。
-
项目移交: 将项目转交给负责生产管理的团队。
请注意,这个过程是迭代的,因此如果数据管道、模型或系统验证步骤中的部署有问题,我们可能需要返回到前一个步骤来修复问题。
为什么 TDSP 适用于数据科学
以下是 TDSP 在数据科学项目中效果良好的原因:
-
TDSP 捕捉了大多数数据科学项目中的基本步骤和依赖关系。
-
数据科学是一个迭代的过程,我们经常在项目的不同阶段发现新信息。
-
TDSP 设计用于包含多个角色的数据科学团队,这些角色包括纯机器学习之外的技能,如数据工程和软件工程。
-
TDSP 从业务需求和数据开始,然后进入模型开发。
-
TDSP 允许数据科学团队适应不断变化的需求以及数据分析、建模和部署步骤中的意外结果。
总结
尽管敏捷原则通常应用于软件开发,但我们也可以认为这些原则在数据科学项目中更为重要。微软的团队数据科学过程(TDSP)是一个将敏捷原则应用于数据科学的框架,围绕五步数据科学生命周期进行设计。TDSP 设计良好,能够反映大多数数据科学项目的阶段,并允许数据科学团队使用迭代方法适应变化。
加入我的邮件列表
加入我的 邮件列表 以获取有关我数据科学内容的更新。你还可以在 注册 时免费获得我的 解决机器学习问题的逐步指南!你也可以通过 Twitter 关注我以获取内容更新。
同时,你也可以考虑 加入 Medium 社区,阅读来自其他成千上万位作者的文章。
来源
- 微软 Azure 团队,什么是团队数据科学过程?,(2020),团队数据科学过程文档。
教人工智能玩棋盘游戏
原文:
towardsdatascience.com/teaching-ai-to-play-board-games-77e5d1749dd9
使用从零开始的强化学习教计算机玩井字棋
Heiko Hotz
·发表于 Towards Data Science ·18 分钟阅读·2023 年 12 月 12 日
--
图片由作者提供(由 ChatGPT 创建)
这是什么内容?
目前,人工智能领域似乎每个人都在提升他们的强化学习(RL)技能,特别是在 Q-learning 方面,跟随关于 OpenAI 新 AI 模型 *Q** 的最新传闻,我也参与其中。然而,我决定用我对棋盘游戏的热情来介绍 Q-learning 🤓,而不是对 Q* 进行猜测或重温 Q-learning 的旧论文和示例。
在这篇博客文章中,我将从头开始创建一个简单的程序,教一个模型如何玩井字棋(TTT)。我将避免使用任何强化学习库,比如 Gym 或 Stable Baselines;所有内容都是用原生 Python 手动编写的,脚本只有 100 行。如果你对如何指导人工智能玩游戏感到好奇,请继续阅读。
你可以在 GitHub 上找到所有代码,链接为 github.com/marshmellow77/tictactoe-q
。
为什么这很重要?
教人工智能玩井字棋(TTT)可能看起来并不那么重要。然而,它确实提供了一个(希望)清晰且易于理解的 Q-learning 和 RL 的介绍,这在生成式人工智能(GenAI)领域可能是重要的,因为有人猜测像 GPT-4 这样的独立 GenAI 模型对于显著的进步是不够的。它们的局限性在于只能预测下一个标记,而无法进行任何推理。RL 被认为能够解决这个问题,并可能增强 GenAI 模型的响应能力。
无论你是为了迎接这些进展而提升你的 RL 技能,还是仅仅寻求一个有趣的 Q 学习入门教程,这个教程都适合这两种情况🤗
理解 Q 学习
从本质上讲,Q 学习是一种算法,它学习特定状态下一个动作的价值,然后利用这些信息找到最佳动作。让我们考虑Frozen Lake游戏的例子,这是一款用于演示 Q 学习的流行单人游戏。
在 Frozen Lake 中,玩家(从单元格 0 开始)在冰和水的网格上移动,目标是到达目标(单元格 15)而不掉入水中。每个单元格代表一个状态,玩家可以向四个方向移动:上、下、左或右。
作者提供的图片(使用 Stable Diffusion 创建)
在游戏开始时,代理(这就是 AI 玩家通常的称呼)没有任何信息,只会随机尝试一些动作。在 Q 学习的背景下,这个探索阶段至关重要。代理通过根据其动作获得奖励或惩罚来学习。在 Frozen Lake 中,达到目标会获得高奖励,而掉入水中则会受到惩罚。这种奖励和惩罚的系统引导代理学习最有效的到达目标的路径。
Q 学习使用一个表格,称为 Q-表,用于记录每个状态下每个动作的价值。随着代理探索环境,这个表格会不断更新。Q-表条目,称为 Q 值,表示在给定状态下采取某个动作的预期效用,它们通过贝尔曼方程进行更新。这个方程考虑了动作的即时奖励和可能的最高未来奖励(稍后会详细讲解)。
基本上,Q-表是代理的备忘单或查找表:根据游戏的状态,代理会查找该状态,确定哪个动作具有最高效用(即哪个是最佳动作),然后执行该动作。以下是 Q-表可能的示例:
作者提供的图片
在这个例子中,如果玩家处于状态 1(即在单元格 1),他会选择动作右,因为这是具有最高价值的动作。
随着时间的推移,代理探索环境并更新 Q-表,它在导航 Frozen Lake 时变得更加熟练,最终学会了一种最佳或接近最佳的策略,以可靠地到达目标。Q 学习在这种情况下的美妙之处在于它的无模型性质,这意味着它不需要环境模型,可以仅通过交互学习,使其广泛适用于各种 RL 问题。
存在许多教程演示了如何利用和实现 Q-learning 来解决 Frozen Lake 游戏,例如 towardsdatascience.com/q-learning-for-beginners-2837b777741
。然而,正如前面提到的,作为一个棋盘游戏爱好者,我对将这种方法适用于双人游戏,甚至更多玩家的游戏更感兴趣。
双人游戏中的挑战
将 Q-learning 应用于双人游戏,如井字棋,需要进行一些小的修改。在 Frozen Lake 游戏中,下一状态仅由代理的行动决定。然而,在井字棋中,尽管玩家可能采取一个回合,但随后的状态还依赖于对手的行动。例如,如果我在左上角放置一个‘X’,那么下一状态是不确定的,因为我的对手有几个潜在的移动:
作者提供的图片
可以采用几种方法来解决这个问题。一种方法是模拟对手所有可能的行动及其相应结果。这需要生成所有潜在后续状态的概率分布,并根据这些状态的预期结果更新 Q 值。然而,这种方法可能计算量较大。在本教程中,我们将采用一种更简单的方法,为对手随机采取一个动作,并根据这个动作的实际结果更新 Q 表。这很好地反映了对手的不可预测性,正如我们后面将看到的那样。通过这种方法,Q-learning 可以有效地适应双人游戏,使 AI 不仅能够学习最佳移动,还能(最终)适应人类对手的策略。
这种方法原则上与训练 AlphaGo Zero 的方法类似。该 AI 程序在快速连续的对弈中自我对弈了 490 万局围棋。在这个过程中,它不断提高自己的技能,自主学习和调整策略。这种自学习方法,绕过了模拟对手每一个可能的移动的需求,为 AI 系统提供了一种高效且有效的学习和适应复杂任务的方法。
李世石与 AlphaGo 的第 2 局比赛,AlphaGo 的著名第 37 步。图片来源:commons.wikimedia.org/wiki/File:Lee_Sedol_(W)_vs_AlphaGo_(B)_-_Game_2.svg
(许可 CC BY-SA 4.0)
在接下来的部分中,我们将深入探讨这些原则如何在井字棋的具体情况下应用,展示在双人环境中 Q-learning 的实现。
井字棋的 Q-Learning
当我们开始将 Q-learning 应用于井字棋时,了解我们程序的设置以及 AI 代理将要操作的环境非常重要。
概述
这段代码旨在训练一个 AI(我们称之为玩家 1或智能体),通过 Q 学习(一种强化学习形式)来玩类似井字棋的游戏。它首先设置学习参数并初始化一个 Q 表来存储不同状态下不同动作的值。脚本定义了几个函数来管理游戏机制,如确定可能的移动、检查胜者、更新游戏状态,以及在移动后计算下一个状态和奖励。
在脚本的主要部分中,实现了 Q 学习算法。它运行多个回合,模拟智能体与其对手(我们称之为玩家 2)之间的游戏。在每一回合中,AI 要么探索一个随机动作,要么利用 Q 表中的知识来进行决策,从结果中学习以更新 Q 表的值。这个过程涉及随着时间的推移调整探索率,从随机探索转向更具策略性的动作。
我们设置的一个关键方面是 AI 的对手。与对手可能拥有复杂策略的更复杂场景不同,我们的 AI 将与一个随机移动的对手进行游戏。这一选择简化了学习环境,使我们可以专注于 AI 的学习过程,而不是对手策略的复杂性。
Q 学习设置
我们的 Q 学习设置涉及一些关键参数,这些参数将影响 AI 的学习方式:
learning_rate = 0.2
discount_factor = 0.9
num_episodes = int(1e7)
epsilon = 1.0 # Exploration rate
epsilon_min = 0.01
epsilon_decay = 0.999
-
学习率 (
**learning_rate**
): 这决定了新信息对现有知识的影响程度。较高的学习率加速了学习过程,但可能导致不稳定。学习率为 0.2 在学习新策略和保留之前的学习之间取得了平衡。 -
折扣因子 (
**discount_factor**
): 这反映了未来奖励的重要性,影响 AI 策略的远见程度。折扣因子为 0.9 时,AI 会特别重视未来奖励,鼓励 AI 前瞻性思考,而不仅仅关注即时收益。 -
回合数 (
**num_episodes**
): 这是 AI 学习的游戏数量,为 AI 提供了充足的机会来体验各种游戏场景。将其设置为 1000 万 (1e7
) 允许广泛的训练,为 AI 提供了从各种游戏场景中学习的充足机会。 -
探索率 (
**epsilon**
): 探索率(epsilon)最初设置较高,以允许 AI 探索各种动作,而不是仅仅利用已知策略。最初,AI 会更多地进行探索(由于epsilon
为 1.0)。随着时间的推移,epsilon
逐渐减小到epsilon_min
,AI 将开始更多地利用其学习到的策略。
关于探索率的附注
在 Q 学习中,探索率通常用符号 ε(epsilon)表示,这是一个关键参数,决定了探索(尝试新动作)和利用(使用已知最佳动作)之间的平衡。最初,智能体对环境了解不多,因此它需要广泛探索,通过尝试不同的动作。探索率通常在开始时设定为较高的值(例如 1 或接近 1),决定了智能体选择随机动作而不是根据 Q 表选择最佳已知动作的概率。
然而,随着智能体对环境的了解越来越多,Q 表变得更加可靠,探索的必要性减少,利用已获得的知识变得更加有益。这时,探索率衰减就发挥作用了。探索率衰减是一个随着时间推移而减少探索率的因素。它确保智能体在学习和收集更多信息的过程中,逐渐从探索环境转向利用 Q 表中学到的值。
这种平衡在 Q 学习中很重要,因为它可以避免两个主要问题:
陷入局部最优: 如果智能体只利用已知信息(低探索),可能会陷入局部最优。这意味着它会根据有限的信息反复选择看似最佳的动作,但可能错过发现能带来更好长期奖励的动作。
低效学习: 另一方面,如果智能体过度探索(高探索)且时间过长,可能导致低效学习。智能体可能会不断尝试次优动作而没有充分利用已经获得的知识,从而导致收敛到最优策略的速度变慢。
通过适当设置探索率及其衰减,Q-learning 算法可以有效地平衡这两个方面,使智能体能够最初探索环境,然后逐渐更多地专注于利用它所学到的最佳策略。这种平衡对于在复杂环境中学习的效率和有效性至关重要。
在接下来的章节中,我们将深入代码,看看 AI 如何使用 Q-learning 来做决策、更新策略,并最终掌握 Tic-Tac-Toe。
代码深度解析
训练脚本
这是 train.py 文件的详细解读。
训练从 for 循环开始(大致在脚本的中间),我们将在其中进行一定数量的回合:
for episode in range(num_episodes):
state = [0] * 9 # Starting state - empty board
接着,我们随机确定起始玩家。一个更简单的方法是让我们的智能体总是作为起始玩家。然而,实现一个随机起始玩家并不比直接总是让智能体作为起始玩家多花费多少精力,并且这种方法使 Q 表模式更加通用,即我们的智能体将学习如何作为起始玩家以及非起始玩家进行游戏。
如果玩家 2 开始游戏,那么我们将为玩家 2 进行随机移动:
# If Player 2 starts, make a random move
if current_player == 2:
actions = get_possible_actions(state)
random_action = random.choice(actions)
state = update_state(state, random_action, 2)
current_player = 1 # Switch to Player 1
现在我们进入实际的 TTT 游戏训练循环,只有在游戏结束时才会停止。一个关键机制是之前讨论的开发 vs 探索机制。它的实现如下:
if random.uniform(0, 1) < epsilon:
# Explore: choose a random action
action = random.choice(actions)
else:
# Exploit: choose the best action based on Q-table
action = max(actions, key=lambda x: Q_table[state_str][x])
epsilon 值越低,智能体通过随机移动进行的探索越少,它将更多地利用 Q 表。
一旦选择了智能体的动作,我们将执行它并确定下一状态(以及适用的奖励):
# Take action and observe new state and reward
new_state, reward = get_next_state_and_reward(state, action)
处理所有这些操作的函数值得更仔细地查看:
def get_next_state_and_reward(state, action):
new_state = update_state(state, action, 1) # Player 1's move
if is_winner(new_state, 1):
return (new_state, 1) # Reward for winning
elif 0 not in new_state:
return (new_state, 0.1) # Draw
else:
# Player 2 (random) makes a move
actions = get_possible_actions(new_state)
random_action = random.choice(actions)
new_state = update_state(new_state, random_action, 2)
if is_winner(new_state, 2):
return (new_state, -1) # Penalty for losing
else:
return (new_state, 0) # No immediate reward or penalty
在这个函数中,我们首先更新棋盘的状态并检查我们的智能体是否赢得了游戏。如果没有,我们为对手进行随机移动,并再次检查对手是否赢得了游戏。根据结果,我们返回 0(游戏仍在进行中)、0.1(平局)、+1(智能体获胜)或 -1(对手获胜)。我们选择 0.1 作为平局的奖励是为了激励智能体尽快结束游戏。
现在我们已经确定了奖励,接下来是整个程序中最关键的部分:通过 Bellman 方程更新 Q 表:
Q_table[state_str][action] += learning_rate * (
reward + discount_factor * max(Q_table[new_state_str]) - Q_table[state_str][action])
这个 Bellman 方程在其他博客文章中解释得更好(再次参考 towardsdatascience.com/q-learning-for-beginners-2837b777741
)。但简要解释如下:
如前所述,Q 表本质上是一个大的备忘单:它跟踪游戏中的所有可能状态以及从该状态开始的每个可能移动的价值。它告诉智能体在给定情况下每个移动的好坏,基于它迄今为止学到的知识。
Bellman 方程更新这个 Q 表。它通过查看智能体收到的即时奖励(赢、输、平局)和它可以移动到的未来状态(即未来奖励)的质量来实现。因此,在每局游戏后,智能体使用 Bellman 方程来修订其 Q 表,学习哪些移动可能导致胜利、失败或平局。
最后,我们调整探索率,以便在未来的游戏中,智能体更多地使用 Q 表而较少进行探索。
epsilon = max(epsilon_min, epsilon_decay * epsilon)
运行训练
一旦训练脚本准备好,我们就可以执行它。幸运的是,这个过程计算需求不高,完成得非常快,不需要特别的计算能力。例如,我在 MacBook M1 Air 上执行了这个过程,它在 1000 万局游戏中不到 5 分钟就完成了。训练完成后,我们将保存 Q 表(它不是特别大),以便我们可以用它来测试智能体,与 AI 对战,并可能在稍后的阶段继续训练,以进一步增强表格。我们来看看吧 🧐
Q 表的人工检查
这个表格相对容易理解:每一行代表了棋盘状态、可采取的行动及其质量。让我们来看看一些有趣的状态。请注意,你的表格可能会有不同(但希望是相似)的值:
图片由作者提供
棋盘状态显示了每个玩家已经放置的位置(前 3 个数字代表第一行,接下来的 3 个代表第二行,最后 3 个代表最后一行。动作对应棋盘上的位置,每个动作的数字表示该动作的质量。在这个例子中,我们看到一个状态,似乎只有一个动作(动作 7)被认为是好的,其他所有动作都显得较差。
注意:棋盘位置的索引如下:
图片由作者提供
所以,让我们来可视化 Q 表中的这个特定条目的棋盘状态:
图片由作者提供
的确,在这个位置,代理(玩家 1)唯一的好选择是选择位置 7。所有其他移动可能会导致输掉比赛(请记住,玩家 2 将在下一轮随机移动,因此输掉比赛并非必然)。
让我们再看一个例子:
图片由作者提供
图片由作者提供
在这个例子中,显然最佳移动是选择位置 8(右下角)并赢得比赛。如果代理选择其他位置,它很可能会输掉比赛。因此,Q 表将指示我们的代理采取动作 8。
测试新代理
现在我们已经训练了模型,我们可以用 GH 仓库中的脚本test.py来测试它。在脚本中,我们将让代理与一个随机移动的对手进行若干局比赛,看看它的表现如何。我们首先初始化我们的代理并加载 Q 表以便在游戏环境中用于决策。play_game
函数模拟了一场比赛,使用加载的 Q 表来指导代理的决策。这里的游戏环境是一个简单的 3x3 棋盘,每个状态代表棋盘的不同配置。
代理以玩家 1 的身份,根据 Q 表做出决策——选择当前状态下值最高的行动。如果在 Q 表中找不到状态,代理将做出随机移动。这种学习行为和随机性的结合有助于评估训练的鲁棒性。玩家 2 的移动完全随机,为代理提供了多样化的场景。
这些游戏的结果会被跟踪,量化胜利、失败和平局的数量。这有助于评估训练模型的效果。如果设置了log_lost_games
标志,将保存详细的失败游戏日志,这对于进一步分析和改进模型是非常宝贵的。这一测试过程,通过进行大量游戏,提供了对训练后代理能力的全面了解。
作者提供的图片
与 AI 对战
看起来对随机机器人进行的测试很成功。我们的 AI 赢得了超过 95%的比赛。现在,我们想亲自与 AI 对战。我们可以使用play.py来实现这一点。
在这个程序中,我们通过一个简单的控制台界面与 AI 互动。游戏板表示为一个 3x3 的网格,每个位置从 0 到 8 编号。当轮到我们时,我们会被提示输入一个数字,以选择我们想要移动的位置。
AI 使用从 CSV 文件加载的 Q 表来做出决策。这个 Q 表来源于之前的训练过程,引导 AI 根据当前的游戏板状态选择最佳可能的移动。如果 AI 遇到 Q 表中没有的状态,它将默认进行随机移动。
游戏在我们的回合和 AI 的回合之间交替进行。每次移动后,更新后的棋盘会被显示,程序会检查是否有赢家。如果玩家获胜或游戏结果为平局,游戏结束,结果将被宣布——无论是我们获胜、AI 获胜还是平局。
这个互动游戏提供了一个很好的机会来实时测试 AI 的能力。让我们开始吧:
作者提供的图片
在这个游戏中,如果我们不选择动作 0(左上角),AI 将有机会赢得比赛。它会意识到这一点吗?
作者提供的图片
确实做到了!很好😊
结论
在这篇文章中,我们训练了我们的 AI 代理对抗一个进行随机移动的玩家。这已经足够好,能够在对抗进行随机移动的对手时达到超过 95%的胜率。但是,有方法可以改进训练过程,希望也能提高 AI 的表现。
参数调整的影响
将 Q 学习应用于井字游戏揭示了强化学习的一个关键方面:调整参数的艺术。正确设置这些参数,如开发与探索之间的平衡、学习率和折扣因子,是 RL 代理成功的关键。
-
探索与利用: 由
epsilon
值控制,这一平衡决定了智能体尝试新策略的频率与依赖已知策略的比例。高探索率鼓励智能体尝试新事物,可能导致创新策略,而高利用率使智能体依赖现有知识,虽然可能更高效,但可能会错过更好的策略。 -
学习率: 高学习率意味着智能体迅速采纳新信息,这在动态环境中可能有利,但如果智能体过快地覆盖有用的学习,可能导致不稳定。相反,低学习率意味着智能体更多依赖过去的知识,导致稳定但可能较慢的学习。
-
折扣因子: 这个参数影响智能体对未来奖励的重视程度。高折扣因子使智能体更具前瞻性,考虑其行动的长期后果。相反,低折扣因子则使智能体目光短浅,专注于即时奖励。
这些参数的变化可以显著改变 RL 智能体的行为。例如,折扣因子低的智能体可能会以攻击性方式玩井字棋,专注于即时胜利,而不是制定未来的策略。相反,折扣因子高的智能体可能会更具策略性,考虑每一步对游戏未来状态的影响。
同样,高学习率的智能体可能迅速适应新策略,不断发展其游戏玩法,而低学习率的智能体可能坚持经过验证的策略,游戏中的变化较小。
轮到你来实验了
这就是强化学习真正的激动所在。每个参数都可以进行微调,以观察它如何影响 AI 智能体的学习和表现。我邀请你深入这个实验的世界。调整学习率、探索率和折扣因子,观察这些变化如何影响 AI 在井字棋游戏中的策略。
更高级的技术
为了进一步提高模型的表现,实施自我对弈机制,即 AI 与来自不同训练阶段的自身版本对弈(而不是与进行随机移动的对手对弈),可能是一种有效的策略。这种方法在 AlphaGo 等系统中成功应用过,并可能导致更强大和适应性更强的 AI 玩家。
对于更复杂的游戏,如国际象棋和围棋,维持一个 Q 表将不再可行,因为它变得过于庞大。在这些游戏中,采用像深度 Q 学习这样的技术可以显著增强 AI 的学习能力。通过使用神经网络来逼近 Q 表,AI 可以处理超出简单 3x3 井字棋网格的更复杂状态,使其在更复杂的游戏中具备可扩展性。
总之,目前的设置已经展示了有希望的结果。然而,这些建议的改进可能会进一步提升 AI 的表现,将其从一个合格的井字棋玩家转变为一个能够应对更复杂战略游戏的高级 AI。
进一步的相关资料
如果你对学习更多关于强化学习如何应用于棋盘游戏感兴趣,可以查看下面的两个视频。第一个视频非常简短,深入探讨了现代象棋 AI 机器人如何进行游戏:
第二个视频是电影AlphaGo(在 YouTube 上免费观看),讲述了 AlphaGo 模型的开发过程以及它如何击败当时的世界冠军:
Heiko Hotz
👋 在Medium和LinkedIn关注我,阅读更多关于生成 AI、机器学习和自然语言处理的内容。
👥 如果你在伦敦,可以参加我们的NLP London Meetups。
教学 CLIP 时尚
原文:
towardsdatascience.com/teaching-clip-some-fashion-3005ac3fdcc3?source=collection_archive---------3-----------------------#2023-03-07
训练 FashionCLIP,一个专门用于时尚的 CLIP 模型
Federico Bianchi
·
关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 3 月 7 日
--
图片由 Domenico Loia 提供,发布在 Unsplash 上。
这是一篇简短的博客文章,描述了 FashionCLIP。如果你是数据科学家,你可能需要处理图像和文本。然而,你的数据将非常特定于你的领域,标准模型可能效果不佳。本文解释了如何在领域特定的环境中使用领域特定的视觉和语言模型,以及为何使用这些模型可能是创建搜索引擎或(零样本)分类器的一个有前景的方式。
FashionCLIP 是一种用于时尚行业的新视觉和语言模型,支持从业者解决两个任务:
-
分类:产品图像的零样本分类;
-
搜索:根据查询高效检索产品。
尽管 FashionCLIP 是许多人努力工作的结果,这篇博客文章主要是我在构建过程中获得的惊人经验的总结和个人观点,并不一定代表所有其他作者及其组织的观点。
模型
我们目前以两种不同的格式发布模型:
-
我们的内部封装
-
HuggingFace 权重
我们还有一个 colab 教程 介绍了使用 FashionCLIP 可以做的大部分事情。
介绍
时尚是可以从 AI 产品中受益最多的行业之一。实际上,由于领域的性质、不同的目录和客户特定的数据集,通常很难构建可以无缝应用于不同问题的解决方案。
想象一下在一家大型时尚公司工作的两位数据科学家:Mary 和 Luis。他们必须应对不断变化的系统,其操作需要持续的关注:
-
Mary 正在构建一个 产品分类器 以帮助大规模分类:她的模型接收一个产品并从一系列类别中选择一个(鞋子、连衣裙等);
-
Luis 正在研究 产品匹配 以改善搜索体验:他的模型接受一种支持的语言中的查询(例如,“一件红色连衣裙”),并返回匹配该查询的产品列表。
正如每个从业者所知道的,任何新的生产模型都会带来复杂的生命周期和某种程度的脆弱依赖:
-
随着库存的增长和类别的变化,Mary 的模型需要不断重新训练;
-
Luis 的模型依赖于产品元数据的质量。
同一公司,不同用例,不同模型。
如果有另一种方法呢?
今天我们尝试向前迈出一步,展示如何构建一个用于时尚数据的通用模型。我们描述了 FashionCLIP,它是著名的 CLIP 模型的微调版本,专门处理时尚数据。我们最近的关于 FashionCLIP 的论文已在《自然科学报告》中发布。
Chia, P.J., Attanasio, G., Bianchi, F. 等 一般时尚概念的对比语言与视觉学习。Sci Rep 12, 18958 (2022)。
doi.org/10.1038/s41598-022-23052-9
FashionCLIP 的诞生源于与Farfetch的合作,这是一家在纽约证券交易所上市的巨大(且真实的)奢侈品电商。FashionCLIP 是与来自业界(Coveo、Farfetch)和学术界(斯坦福、博科尼、比科卡)的人们共同完成的工作。模型权重可以在线获得,格式为HuggingFace。使用示例可以在Patrick 的 repo中找到。
我们将首先介绍用例,并解释一些模型的更深入细节。最后,我们将分享我们用来训练模型的代码以及如何获取权重。
FashionCLIP: 故事
FashionCLIP 是一个通用模型,用于将时尚产品的图像及其描述嵌入到同一个向量空间中:每个图像和每个产品将由一个单独的稠密向量表示。
为什么我们要把它们放在同一个向量空间中? 这样它们才能进行比较。 这个原则是像 CLIP 这样的模型成功的关键。
FashionCLIP 源自原始的 CLIP。这个想法非常简单。如果你:
-
大量带有标题的图像;
-
一个图像编码器(这可以是 CNN 或 ViT);
-
一个文本编码器(这可以是基于 transformers 的语言模型)。
你可以训练一个模型(使用对比损失)来使图像的嵌入接近其标题嵌入,并远离不相关的标题。在 GIF 中,你展示了一个二维的例子。这个概念可以推广到 N 维。
FashionCLIP 将描述和图像嵌入到同一个向量空间中。这对于零-shot 分类和图像检索非常有用。图片由作者使用 Farfetch 目录提供。
最终结果是一个多模态空间,允许你在视觉和文本交互之间移动,使用新的图像和新的文本描述:如果你有一些文本,你可以检索到对应的图像(如产品搜索);如果你有一些图像,你可以排序标题基于语义相似性(如分类)。
要微调 CLIP,你需要一个好的数据集。我们与 Farfetch 合作,使用高质量的图像和标题来训练 CLIP。这个数据集(即将公开发布)包含了超过 80 万的样本。
我们训练模型几个周期,并检查在多个基准上的表现,包括零-shot 分类、探测和检索。在查看结果之前,让我们深入了解一下现在有了训练好的 FashionCLIP 后我们可以做什么。
我们不会深入探讨 CLIP 本身。如果你想了解更多关于 CLIP 的内容,我这里有一篇专门的博客文章:
[## 如何训练你的 CLIP]
介绍 CLIP 以及我们如何在 HuggingFace 社区周期间为意大利语言微调它。
towardsdatascience.com
FashionCLIP 可以处理的两个关键任务是:
-
图像检索
-
零-shot 分类
检索:从文本到图像
我们首先从文本到图像:我们使用 FashionCLIP 文本编码器对搜索查询(“一件红色连衣裙”)进行编码,并通过简单的点积检索最接近的图像向量。点积的值越大,文本和图像之间的相似度越高。在下面的 GIF 中,搜索以 4 个产品图像为例进行。
对于检索,我们可以在目标目录上预先计算图像嵌入。在运行时,我们编码查询并通过简单的点积对图像进行排名。图片由作者使用 Farfetch 目录提供。
虽然“红色连衣裙”是一个简单的查询,搜索引擎可能不需要额外的输入,但稍微模糊一些的查询,如“浅红色连衣裙”与“深红色连衣裙”则变得有趣,其中“浅”和“深”是同一颜色的修饰词:
FashionCLIP 有助于消歧义几何特征。图片由作者使用 Farfetch 目录提供。
更有趣的是 FashionCLIP 捕捉到衣物中代表的物品的能力。产品描述通常未能明确提及具象图案,FashionCLIP 能够识别印刷的物品,即使是类似卡通的形状,如下面 T 恤上挂着的猫:
FashionCLIP 识别印刷在 T 恤上的具象物品。图片由作者使用 Farfetch 目录提供。
虽然我们尚未详细评估这一能力,但我们相信这可能来自原始 CLIP 所具备的“知识”,在微调过程中部分保留。
当然,信息在描述中(例如,品牌通常在描述中提及)比 FashionCLIP 可能捕获的任何语义细微差别编码得更好。然而,它在增强标准学习排名信号而没有行为数据方面的能力可能大大改善搜索体验,特别是在冷启动场景下。
分类:从图像到文本
我们现在从图像到文本进行分类:我们使用 FashionCLIP 的图像编码器对要分类的时尚物品图像进行编码,并通过点积检索最接近的标签向量:
对于零-shot 分类,我们计算查询项的图像嵌入和目标标签的文本嵌入。图片由作者使用 Farfetch 目录提供。
CLIP-like 模型的技巧在于将标签视为语义上有意义的标签,而不是类别变量。
换句话说,当我们“分类”时,我们在问“这些文本中哪个是这个图像的最佳标题?”。
得益于 CLIP 的预训练和自然语言的无限可能性,我们现在拥有一个不局限于任何特定标签、类别或属性的分类器:当然,首要应用可能是在 Farfetch 目录中的新产品上使用该分类器,我们还可以在具有不同标签或用途的其他数据集上重复使用相同的模型,例如:
-
如果供应商没有将鞋子分类为“高跟鞋”与“平底鞋”,我们可以添加该属性;
-
如果商品管理员在目录中创建新的视图——例如,将项目匹配到风格——我们可以根据新的维度(“优雅”、“街头风”等)对现有产品进行分类。
CLIP 的泛化能力当然是以某些精度为代价的:也就是说,如果我们以监督方式训练一个新的分类器来解决上述用例,它们都会比 FashionCLIP 更好。像往常一样,真实世界的机器学习没有一刀切的方案,模型之间的权衡可以根据用例的重要性、训练时间、标注成本等不同方式进行评估。
性能
我们在两个不同任务和多个数据集上将 FashionCLIP 与 CLIP 进行比较。有关设置的更多细节请参阅论文,本节的范围只是为了展示在时尚相关任务中使用 FashionCLIP 替代 CLIP 时性能的提升。
对于零样本分类,我们使用了三个不同的数据集(KAGL、DEEP 和 FMNIST),这些数据集应作为分布外数据集(我们知道从其他实验中我们在领域内数据上表现比 CLIP 好得多,但这是预期中的)。
不同数据集上的加权宏 F1 分数(领域外数据)。FashionCLIP 在这些数据集上显示出相对于 CLIP 的显著提升。
Zero-shot 结果确认我们的模型表现如预期!
对于图像检索,我们使用了在训练时遗漏的原始数据集的一部分。需要注意的是,这显然使我们相对于 CLIP 有优势,因为这个数据集对于我们来说是领域内的。然而,这仍然是一个有趣的实验。以下结果确认我们的模型表现最佳:
在我们内部测试集上的前 5 和前 10 精度(领域内数据)。FashionCLIP 的检索性能明显更好。
Torch 实现和 HuggingFace 权重
由于 Patrick 的工作,FashionCLIP 使用起来非常简单。你只需加载模型并使用简单的方法进行零样本分类,所有这些都可以用 Python 完成!
fclip = [...load FCLIP ...]
test_captions = [
"nike sneakers", "adidas sneakers", "nike blue sneakers",
"converse", "nike", "library", "the flag of italy",
"pizza", "a gucci dress"
]
test_img_path = 'images/16790484.jpg'
fclip.zero_shot_classification([test_img_path], test_captions)
你还可以进行图像检索!
candidates = fclip.retrieval(['shoes'])
print(candidates)
告别
漫长旅程的总结
构建 FashionCLIP 是一段长时间且有趣的冒险,涉及到来自地球上一些最酷地方的老朋友和新朋友。结果总是更美好,当你和朋友一起获得它们时。此外,我们中的一些人已经合作多年,实际上从未在现实生活中见过面!
从更务实的角度来看,我们希望 FashionCLIP 能为快速迭代内部和外部时尚用例的公司开辟前所未有的机会:例如,虽然你可能会最终构建一个专注的风格分类器,但使用 FashionCLIP 进行概念验证将大大证明该功能的价值,而无需在新的模型生命周期支持上进行前期投资。
当我们考虑零售领域日益增长的智能 API SaaS 服务提供商——如 Coveo、Algolia、Bloomreach——时,垂直模型的重要性不可低估:由于 B2B 公司以账户为基础增长,稳健性和可重用性比纯粹的精准度更为重要。我们展望不久的将来,FashionCLIP —— 以及 DIYCLIP、ElectronicsCLIP 等 —— 将成为 B2B 机器学习参与者的标准组件,使得迭代迅速、数据标准化,并在完全不同于目前的水平上实现规模经济。
我去年也在 Pinecone 上做了一个关于 FashionCLIP 的演讲:
我在 Pinecone 上关于如何构建像 FashionCLIP 这样的模型的演讲。
另一个演示
开源的力量是什么?Pablo 看到这个模型并联系了我们,提供了一个用户界面来帮助我们测试标准的 HuggingFace CLIP 与我们刚刚发布的 FashionCLIP 之间的差异——然后我使用了 Objective Search 来测试使用 FashionCLIP 的几个查询(您可以在这里亲自查看):
使用 FashionCLIP 进行搜索。GIF 由作者提供,图片来自 H&M 数据集。
很酷,不是吗?
局限性、偏见与公平性
我们承认 FashionCLIP 存在某些限制,并预计它继承了原始 CLIP 模型中的一些局限性和偏见。我们不期望我们的微调会显著增加这些限制:我们承认,我们使用的时尚数据对性别的概念做出了明确假设,例如“女性的蓝色鞋子”,这不可避免地将服装的某些方面与特定的人联系在一起。
我们的调查还表明,所使用的数据在 FashionCLIP 中引入了某些限制。从文本模态来看,鉴于大多数来自 Farfetch 数据集的标题较长,我们观察到 FashionCLIP 在处理较长查询时可能比短查询表现更好。
从图像模态来看,FashionCLIP 对标准产品图像(居中、白色背景)也存在偏见。这意味着模型可能在不具备相同结构的图像上表现不佳。
我们做的更多事情
FashionCLIP 的发展经历了漫长的过程,但在等待正式发布期间我们做了一些事情。
GradedRecs
我们在 FashionCLIP 的基础上进行了探索,通过遍历潜在空间来研究推荐。如果你感兴趣,请查看我们的 论文!
GradedRec。图片由作者提供。
推荐系统评估中的公平性
如果你对相关行业任务感兴趣,例如推荐系统,我们去年进行了一项关于推荐系统全面评估的挑战。
这个挑战旨在理解如何构建不仅仅关注点对点度量(例如准确率)的评估。你可以在这里找到一些细节和介绍性的博客文章
[## 关于推荐系统的全面评估
EvalRS:在多个测试中评估推荐系统
fede-bianchi.medium.com](https://fede-bianchi.medium.com/a-rounded-evaluation-of-recommender-systems-b9fa101ef79a?source=post_page-----3005ac3fdcc3--------------------------------)
教学很难:如何训练小模型并超越大型对手
原文:
towardsdatascience.com/teaching-is-hard-how-to-train-small-models-and-outperforming-large-counterparts-f131f9d463e1
|模型蒸馏|人工智能|大型语言模型|
蒸馏大型模型的知识是复杂的,但一种新方法显示出惊人的性能
Salvatore Raieli
·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 11 月 11 日
--
图片由 JESHOOTS.COM 提供,来源于 Unsplash
大型语言模型(LLMs)和少样本学习已经证明我们可以将这些模型用于未见过的任务。然而,这些技能是有代价的:大量的参数。这意味着你还需要一个专业化的基础设施,并且将最先进的 LLMs 限制在只有少数几家公司和研究团队中。
-
我们真的需要为每个任务设计一个独特的模型吗?
-
是否有可能创建专门的模型来替代它们用于特定的应用?
-
我们如何才能拥有一个在特定应用中与大型 LLMs 竞争的小模型?我们是否确实需要大量的数据?
在这篇文章中,我对这些问题给出了答案。
“教育是人生成功的关键,教师在学生的生活中留下了深远的影响。” ——所罗门·奥尔蒂斯
匹配冠军!
图片由 Fauzan Saari 提供,来源于 Unsplash
教学的艺术是协助发现的艺术。——马克·范·多伦
大型语言模型(LLMs)展现了革命性的能力。例如,研究人员对像上下文学习这样的难以捉摸的行为感到惊讶。这导致模型规模的增加,越来越大的模型寻找新能力,这些能力超出了参数的数量。
## 关于上下文学习的一切
什么是大型语言模型,它是如何工作的,以及是什么使大型语言模型如此强大
[towardsdatascience.com ## 人工智能中的涌现能力:我们是否在追逐一个神话?
对大型语言模型出现特性的观点变化
[towardsdatascience.com
但这会有代价;例如,GPT-3(超过 175 万亿个参数)至少需要 350 GB 的 GPU 来运行。这意味着你需要专门的基础设施来训练和使用它进行推理。将这样的模型部署以使其公开访问需要克服重大挑战和成本(尤其是如果你想减少延迟)。因此,只有少数公司能够负担得起为实际应用部署一定规模的模型。
拥有超过 100 B 参数的模型具有大型建模能力,但这些能力分散在许多技能上。相比之下,少于 10 B 的模型建模能力较弱,但可以将这种能力集中于单一任务。例如,推理是超过 100 B 参数模型展示的一种能力,但在小型模型中缺失。这项研究的作者表明,推理只是大型 LLM 中的众多能力之一。因此,将小型模型的训练重点放在推理上,即使模型小于 100 B,也可以获得显著的结果。
当然,专注于小型模型会有代价:对其他任务的表现。但通常你只对一个任务感兴趣,因此可以使用小型模型。
以前的研究表明,推理能力随着规模的增加而突然出现(左侧图)。这项研究的作者表明,通过专注于推理任务(专业化),你可以在推理方面取得良好的结果。图片来源:这里
因此,几家公司专注于仅对特定任务表现良好的小模型。此外,微调的使用使得为特定应用创建小型专业模型成为可能。对于一些任务,如分类,微调需要一个带注释的数据集。收集这些带注释的数据集是昂贵的,因此使用的另一种技术是蒸馏。
蒸馏是一种技术,通过它你可以利用从更大模型生成的标签来训练一个小模型。收集这些未标记的数据集可能同样昂贵(例如,在医疗领域)。性能要求越高,成本也就越高。因此,使用微调或蒸馏来实现与大型语言模型(LLM)相同的性能可能在计算上是昂贵的。
因此,我们如何才能使小模型以数据和时间高效的方式从 LLM 中学习呢?
如何让 LLM 成为高效的教师
照片由 ThisisEngineering RAEng 提供,来源于 Unsplash
我不能教任何人任何东西;我只能让他们思考。——苏格拉底
当我们想训练一个小模型时,LLM 要么用来为未标记的文本生成标签,要么用于数据增强(从 LLM 生成的示例数据集中提取)。直观上,这可能不足以使模型学习高效。
例如,如果我想让我的小模型学习如何对推文进行排序(积极、消极或中立),我可以下载大量推文,通过 LLM 生成标签,然后用这些标签训练小模型。
蒸馏的示意图。图片由作者提供
然而,虽然这对于像推文分类这样的简单任务有效,但对于更复杂的任务来说是不够的。 我们可能会从互联网上下载谜题并让 LLM 解决它们,但解决方案本身并未提供关于解决过程的任何信息。一个通过解决方案训练的小模型不会学会如何解谜。
确实,要学会解决困难的任务(例如解谜),你需要比仅仅解决方案更多的信息。
实际上,这对于 LLM 也是如此,对于推理任务(算术、常识和符号推理),提供链式思维的上下文有助于模型得出解决方案而不会产生幻觉。
图片来源:这里
基于这一意图,一些谷歌研究人员甚至训练了在特定任务上超越 LLM 的小模型(770M 参数与 540BPaLM)。他们随后在最近发表的一篇论文中描述了这种方法。
## 逐步提炼!用更少的训练数据和更小的模型超越大型语言模型…
部署大型语言模型(LLMs)具有挑战性,因为它们在内存使用上效率低下,并且计算密集型……
arxiv.org
简而言之,作者利用了 LLM 进行推理的能力(超越单纯生成标签)。通过使用一个未标记的数据集,他们要求 LLM 生成正确的标签和推理(为什么这是最合适的标签的自然语言解释)。之后,他们使用标签和推理来训练小模型。
该方法的示意图。图像由作者提供
通过这种方式,他们不仅向小模型提供了问题的解决方案,还提供了老师(LLMs)如何得出该解决方案的过程。 此外,推理不仅包含解释,还包含理解任务的有用元素(这些元素从简单的输入中不易推断出,特别是对于参数有限的模型)。
图片来源:这里
逐步提炼
更详细地说,作者使用了与链式思维(CoT)相同的提示。一个提示包括一个问题、背景或推理,以及问题的答案。之后,将推理附加到问题上,模型必须给出答案。
图片来源:这里
小模型通过简单的多任务方法进行训练:它不仅需要预测正确的标签,还需要生成相应的推理。损失函数也会考虑生成推理时是否出现错误。
通过这种方式,作者迫使模型生成中间推理步骤,从而引导模型找到正确的答案。
从比喻的角度来看,这就像是一个老师强迫学生写下所有的推理步骤,而不是直接给出答案。这种方法的优点是,在测试时,模型将不再需要老师模型(LLM),而应该学会进行推理。
我们能否教会学生推理?
图片由Element5 Digital提供,来源于Unsplash
你告诉我,我会忘记。你教我,我会记住。你让我参与,我会学习。 – 本杰明·富兰克林
作者使用 PaLM (540 B 参数)作为 LLM 生成理由。他们选择使用T5作为小模型,使用现有的预训练权重检查点。有趣的是,作者使用了一个已经训练过的非常小的模型。通过这种方式,他们使用一个已经具备一般语言知识的模型,但可以适应特定任务。
模型比较,以更好地理解大小差异(圆圈按比例)。图片由作者提供,生成图片的脚本可以在这里找到
他们选择了三个特定的自然语言处理任务:
-
自然语言推理。 他们使用了两个不同的数据集:e-SNLI 和 ANLI。
-
常识问答 (CQA)。
-
算术数学应用题 (SVAMP)。
如所示,这些任务和数据集要求模型展示推理能力。
在文章中,该方法与两种经典方法进行了比较:
-
微调。 其中预训练模型在带有正确标签的注释示例上进行训练。
-
蒸馏。 在该方法中,LLM 用于生成真实标签。
结果显示,新的方法(逐步提炼)在所有基准数据集和任务中都优于标准微调,同时所需示例也远少于达到更好表现的标准。因此,这种方法性能更佳,同时成本更低(仅有 12.5%的示例表现超过传统微调)。
图片来源:这里
对于标准蒸馏而言,同样的新方法在性能上更优,并且所需的示例数量也少得多。
图片来源:这里
作者们随后使用不同版本的模型(220M、770M、11B)采用相同的方法,并与 LLM 基线(PaLM)进行比较。结果表明,新方法根据规模提高了性能(更大的模型表现更好)。此外,逐步蒸馏在某些任务上似乎甚至超越了 LLM 基线。换句话说,770M 模型在 ANLI 中超越了一个大 700 倍的模型。更令人印象深刻的是,在 e-SNLI 中,一个 220M 的模型超越了一个大 2000 倍的模型。
图片来源: 这里
在标准微调中,我们使用人工标注的数据,而在蒸馏中,我们使用未标注的数据。结果类似,显示模型即使从 LLM 标注的数据中也能学习。
图片来源: 这里
这些结果本身已经很令人印象深刻,但令人难以置信的是你不需要整个数据集。 即使仅用 0.1% 的数据集,该方法仍然有效。对于标准的微调和任务蒸馏,您需要更多的示例才能看到显著的性能。在 ANLI 中,对于 T5-770M,80% 的示例足以超越 PaLM 540B。即使使用完整的数据集,标准微调也无法达到 LLM 基线。
图片来源: 这里
图片来源: 这里
正如作者所提到的,尽管这种方法也适用于其他模型(如 20B GPT-NeoX 模型),但结果不如预期。这是因为 PaLM 提供了更高质量和更详细的推理。
图片来源: 这里
在一个消融研究中,他们注意到多任务训练效果更好。换句话说,让模型生成推理有助于它的学习。
图片来源: 这里
作者们也发布了供社区测试的代码:
## GitHub - google-research/distilling-step-by-step
通过在 GitHub 上创建帐户,您可以为 google-research/distilling-step-by-step 的开发做出贡献。
GitHub - google-research/distilling-step-by-step
结束语
照片由 Saif71.com 提供,来自 Unsplash
教育是创造所有其他职业的唯一职业。 – 无名氏
本文展示了如何利用 LLM 教导较小的模型解决特定任务。 超越结果,本文还展示了即使是较小的模型,通过提供上下文也能得出解决方案。因此,这种方法使用户能够用更少的数据提炼出一个小模型,并超越大型 LLM:
文章的示意图。图片来源:这里
作者在本文中展示了比 LLM 小 2000 倍的模型能够学习并在复杂任务(如推理任务)上超越教师模型。此外,与经典的逐步提炼方法相比,它需要的数据要少得多。
一般来说,近年来模型学习研究发生了范式转变,试图将记忆与实际学习分开。
[## 理解:学习是泛化而非记忆
理解神经网络如何学习可以帮助我们避免模型忘记所学内容。
levelup.gitconnected.com
确实,本文表明,要执行特定任务,你并不需要大容量(记忆)。你可以教导一个小模型通过提供解决问题的信息来学习任务(泛化)。
这项工作很重要,因为用少量数据,可以训练出一个更小的模型在任务上表现出色。这些模型可以以更低的成本更容易地部署。此外,这种方法适用于任何模型,因此用户可以使用开源模型(如 LLaMA)或专有模型(GPT-4 或 PaLM)的 API 进行逐步提炼,创建自己的专业模型。
这项工作开辟了许多令人兴奋的可能性,如以低成本开发适用于多个应用的专业模型,并且其性能优于巨型模型。这些模型不仅可以在线部署,还可以在桌面计算机或手机应用中使用。因此,拥有一个小而专有的数据集,你可以用有限的资源开发和部署专家模型。
例如,你可以设想一个用户开发一个专门解决谜题的小模型。你只需与 LLM 创建推理,使用逐步提炼来训练你的专家模型,然后甚至可以将其部署到手机应用上。
TL;DR
-
Google 公布了一种新的简单方法来从大型模型中提取知识。通过使用推理和答案,你可以教导一个小模型(甚至小 2000 倍)在推理任务中超越 LLM。
-
这种方法超越了之前的最新技术。
-
这种方法只需要一个小的训练集和较小的模型尺寸
-
这种方法使得可以为专业任务部署独立的语言模型。现在模型尺寸与网页应用兼容,并且可以在设备上进行推理,无需复杂的基础设施。
你怎么看?在评论中告诉我
如果你觉得这很有趣:
你可以查看我的其他文章,你也可以 订阅 以便在我发布文章时收到通知,也可以在LinkedIn上联系或找到我。
这是我 GitHub 仓库的链接,我计划在这里收集与机器学习、人工智能等相关的代码和资源。
[## GitHub - SalvatoreRa/tutorial: 机器学习、人工智能、数据科学的教程…
关于机器学习、人工智能、数据科学的教程,包括数学解释和可重复使用的代码(用 Python 编写…
github.com](https://github.com/SalvatoreRa/tutorial?source=post_page-----f131f9d463e1--------------------------------)
或者你可能对我的一篇近期文章感兴趣:
## 无需重新训练即可重塑模型的记忆
擦除大型语言模型所学到的有问题内容的任何回响
towardsdatascience.com [## 超越语言:用 AI 解码脑波中的言语
AI 能够从非侵入性脑记录中解码语言
levelup.gitconnected.com](https://levelup.gitconnected.com/beyond-words-unraveling-speech-from-brain-waves-with-ai-7ff81862dfff?source=post_page-----f131f9d463e1--------------------------------)
参考文献
这是我在撰写本文时参考的主要文献列表(只引用了每篇文章的第一作者姓名)。
-
傅, 2023, 《将较小的语言模型专门化为多步骤推理》, 链接
-
辛顿, 2015, 《提炼神经网络中的知识》, 链接
-
霍华德, 2018, 《通用语言模型微调用于文本分类》, 链接
-
卡普兰, 2020, 《神经语言模型的规模定律》, 链接
-
韦, 2022, 《链式思维提示在大型语言模型中引发推理》, 链接
-
Hsieh, 2023, 逐步提炼!以更少的训练数据和更小的模型尺寸超越更大的语言模型,链接
-
Chowdhery, 2022, PaLM: 通过路径扩展语言建模,链接
-
Raffel, 2019, 使用统一的文本到文本转换器探索迁移学习的极限,链接
教授语言模型使用工具
原文:
towardsdatascience.com/teaching-language-models-to-use-tools-7fd58916c66b
使用工具让我们作为人类更具能力。LLMs 是否也是如此?
Cameron R. Wolfe, Ph.D.
·发表于 Towards Data Science ·阅读时间 17 分钟·2023 年 8 月 27 日
--
(照片由 Barn Images 提供,来自 Unsplash)
随着我们对大语言模型(LLMs)了解的深入,这些模型变得越来越有趣。这些模型能够准确解决各种复杂任务。然而,与此同时,它们在某些我们人类认为基本的功能上却存在困难!例如,LLMs 常常犯算术错误,缺乏获取当前信息的能力,甚至难以理解时间的进程。鉴于这些局限性,我们可能会想,如何才能使 LLMs 更具能力?LLMs 注定要永远受到这些局限的困扰吗?
人类历史上的许多进步都由获得新的创新工具(例如 印刷机 或 计算机)所推动。相同的发现是否适用于 LLMs? 在这篇概述中,我们将研究一个最新的研究方向,旨在教会 LLMs 如何使用外部工具,这些工具通过简单的文本到文本的 API 提供。通过使用这些工具,LLMs 可以将执行算术或查找当前信息等任务委派给专门的工具。然后,这些工具返回的信息可以被 LLM 在生成输出时用作上下文,从而产生更准确和有依据的响应。
(来自 1 和 ChatGPT Plus)
使 LLMs 更具能力
为 LLM 提供外部工具是一种可靠的方法,可以解决这些模型面临的一些限制。然而,LLM 不会自然地知道如何使用工具,这就提出了一个问题:我们如何教会模型利用外部工具? 在本节中,我们将探讨我们拥有的一些选项,并列举对构建 LLM 应用程序有用的各种工具。
不同类型的学习
LLM 的不同学习形式(作者创建)
教会 LLM 利用工具与学习如何解决其他任务没有什么不同。由于这些模型以几种不同的方式学习,我们将在这里讨论 LLM 的主要学习形式。本文之外,网上还有详细解释。
预训练。 LLM 的第一个和最基本的学习形式是预训练。在预训练过程中,模型在大量未标记的文本数据上进行训练,使用语言建模目标。预训练过程从随机初始化开始,计算成本相当高。通常,由于计算成本,预训练只执行一次——我们不希望频繁重复预训练过程!值得注意的是,预训练的计算成本解释了像 ChatGPT 这样的 LLM 中存在知识截止点的原因。这些模型在预训练期间学习所有知识,因此知识截止点仅与最近预训练期间存在的数据相关。
LLM 的微调方法(来自[11])
微调。 在预训练之后,LLM 可以准确地执行下一个标记预测,但这并不总是意味着它们实际上有用。例如,如果我们玩一下GPT-2的演示,只需 2 分钟,我们立刻会发现准确预测下一个标记可能会导致一些相当无聊和无用的输出!因此,我们通常在预训练之后对 LLM 进行微调,通常通过监督微调(SFT)或从人类反馈中进行强化学习(RLHF);详情见上面的图片和这里。虽然这些技术的细节超出了本文的范围,但基本的思路是:
-
筛选更多的训练数据(例如,针对我们要解决的任务的领域数据、正确对话的示例、人类对 LLM 输出的反馈等)。
-
使用强化学习或带有(自我)监督目标的梯度下降对模型参数进行训练。
通过这样做,我们可以完成很多事情!例如,使用 RLHF 对 LLM 进行微调 [11] 已被证明可以使 LLM 更有趣、更准确、更有帮助。更进一步,Meta 最近的 LIMA 出版物显示,通过仅 1,000 个高质量对话示例进行 SFT,可以生成一个与 GPT-4 质量相媲美的模型 [12]。简单来说,微调将我们从一个普通的 LLM 提升到真正特别且有用的水平。
(来自 [7])
上下文学习。 我们应当了解的最终学习形式是上下文学习;见上文。上下文学习不同于预训练和微调,它并不会实际修改底层模型的参数。相反,我们通过修改提示来教 LLM 更有效地解决问题!特别是,我们可以通过使用特定的提示技术重新表述提示,甚至将数据插入提示中以进行少样本学习。微调和上下文学习之间的区别如下所示。
(来自 [7])
上下文学习是极其强大的,因为它允许我们使用单一模型解决各种不同的任务。我们可以将有用的数据插入到 LLM 的提示中,而不是微调模型或修改其底层参数。LLM 可以从这些数据中学习,更准确地解决任务,而无需修改模型本身!此外,我们可以使用预训练模型和微调模型进行上下文学习。要了解可以与 LLM 配合使用的提示技术,请查看下面的概述:
-
实用提示 [link]
-
高级提示 [link]
-
思维链提示 [link]
-
提示集合 [link]
对 LLM 有用的工具有哪些?
尽管将 LLM(大语言模型)与外部工具连接的想法很诱人,但我们可能会想:哪些工具最有用? 为了回答这个问题,我们应当关注 LLM 的常见局限性,例如:
-
缺乏访问最新信息 2
-
有产生幻觉的倾向(即,输出不正确的信息)
-
处理数学表达式的困难
-
对低资源语言的理解不完全
-
无法理解时间的推移[8]
如果我们想解决这些问题,我们有几个选项。我们可以专注于通过SFT 或 RLHF对模型进行微调和完善——彻底微调模型以避免上述行为。实际上,大量资源已经投入到通过目标人类反馈来完善像GPT-4这样的模型,这也取得了相当令人印象深刻的结果。然而,我们也可以选择将重点放在让模型采取间接但通常更可靠的方法,而不是在模型内部解决这些问题。特别是,我们可以教会模型如何使用外部工具来帮助回答问题!
工具如何提供帮助? 在解决问题时,LLM 通常会通过查询一个可以提供更多上下文的外部工具来获得帮助。值得注意的有用工具包括(但不限于):
-
能够返回当前日期的日历应用
-
能够评估数学表达式的计算器
-
向量数据库用于存储(可能)相关但无法直接存储在提示中的大量信息。
-
将数据转换为不同语言的翻译模块
总的来说,工具在提供额外信息或上下文来帮助 LLM 解决问题时极为有用。超越这些简单的工具,我们甚至可以将 LLM 连接到外部代码解释器,使其能够编写和执行任意程序。结合支持代码的 LLM(例如,Codex [10]),这种方法实际上可以非常强大!更多信息请见这里。
工具非常受欢迎!
尽管本概述将主要关注最近研究的工具与 LLM 集成,但通过外部工具增强模型(如 GPT-4)已成为近期关注的主题。例如,OpenAI 最近发布了一个模型插件扩展,使这些强大的 LLM 能够利用大量外部工具;见下文。
ChatGPT Plus 插件商店中的热门应用(来自 ChatGPT Plus)
截至撰写时,GPT-4 有近 130 种不同的插件可用,这展示了将各种工具与强大的 LLM 集成的巨大兴趣。超越第三方插件,OpenAI 最近为 GPT-4 发布了代码解释器和互联网搜索工具。互联网搜索工具对于减轻 LLM 中的幻觉非常有用,因为模型提供的答案可以通过从互联网获取的相关、最新信息进行情境化。除了使 LLM 更具事实性和基础性外,代码解释器工具能够处理大量代码和数据文件并对这些数据进行准确分析,以提供有价值的见解。
TL;DR: 主要结论是,工具正在成为 LLM 的一个常见特性。除了 OpenAI 的产品外,我们甚至看到像 Bard 这样的模型正在增强类似功能,而像 LangChain 这样的开源库可以用来轻松构建多种工具类功能供现有 LLM 使用。
教授 LLM 使用工具
(来自1)
在1中,作者探讨了一种名为 Toolformer 的方法,它i) 教授 LLM 如何利用外部工具,并且ii) 在过程中保持 LLM 的通用性质。这些工具通过一组简单的文本到文本的 API 提供给 LLM(即模型提供文本作为输入,API 返回文本输出)。有趣的是,我们在1中看到 LLM 可以完全端到端地学习如何利用这些工具。模型决定调用哪些 API,向这些 API 传递哪些参数,并且如何最佳地利用返回的信息,而无需任何硬编码的控制流。
“语言模型可以学习控制各种工具,并自行选择何时、如何使用哪个工具。” — 来源于1
为了做到这一点,我们策划了一个训练数据集,展示了这些工具的正确使用。在1中,这个数据集是使用自监督启发式方法自动创建的——意味着不需要人工干预——只需为每个工具提供几个使用示例。然后,我们在这些数据上微调 LLM,使其学习每个工具的正确使用方法。结果是一个高性能的 LLM,它可以将简单但困难的子任务(如语言翻译、算术运算、访问当前信息等)委托给专门的外部工具,这些工具返回相关且准确的数据供 LLM 生成输出。
(来自1)
使用了哪些工具? 在1中,Toolformer 使用了以下固定的一组工具:
-
问答工具: 基于 Atlas [13],一种针对回答简单、基于事实的问题进行微调的 LLM。
-
计算器: 用于数值运算的基本计算器。
-
维基百科搜索工具: 一个搜索引擎,给定搜索词返回来自维基百科的简短文本片段。
-
翻译器: 一个可以将任何语言的文本翻译成英文的语言翻译系统(但不能反向翻译!)。
-
日历: 一个在查询时只返回当前日期的工具。
这些工具都通过一个简单的文本到文本结构的 API 提供;见上文。要使用这些工具,LLM 必须学习* i)* 识别需要工具的场景,ii) 指定使用哪个工具,iii) 向工具的 API 提供相关的文本输入,以及iv) 使用从 API 返回的文本来制作响应。值得注意的是,这些 API 简单的文本到文本结构允许我们轻松地将工具使用示例直接插入到文本序列中;见下文。
对外部 API 的调用以文本格式呈现,并与现有文本序列内嵌在一起(来自1)
相较于以前工作的改进。 让 LLM 使用外部工具并不是一个新想法。例如,许多研究者尝试通过让 LLM 访问外部计算器来提高其在算术——特别是大数计算——方面的能力(见[4]的附录 B)。然而,主要问题是:我们应该如何教 LLM 使用这样的工具? 以前的方法严重依赖于人类标注的数据集。例如,LaMDA[3]使用外部搜索工具来减少幻觉;见下文。
(来自[3])
然而,我们在[3]中看到,教会 LaMDA 利用外部工具——在这个例子中是外部的信息检索系统——需要大量的人类标注数据。更具体地说,[3]中的作者让大量的众包工人手动编写对话,利用与 LLM 相同的搜索工具,从而提供了 LLM 应如何行为和回应的示例。相关出版物往往依赖于类似的人类中心方法2。创建这样的数据集困难、昂贵且耗时,这促使1中的作者开发了更高效的解决方案。
自动学习使用工具。 在1中,我们看到一个用于教 LLM 如何利用外部工具的数据集——为了简单起见,我们称之为“工具跟随数据集”——可以通过利用现有的、预训练的 LLM 的提示方法自动创建。我们从一个初始(正常)数据集开始,例如用于预训练的文本语料库。然后,我们提示一个预训练的 LLM 用外部 API 调用来增强这些数据。在这里,我们依赖于通用预训练 LLM 的上下文学习能力,来策划一组 API 调用,展示如何正确使用可用工具。下面展示了一个生成请求到问答工具 API 的示例提示。
(来自 1)
在我们用每个工具的示例用法扩充了数据集之后,我们需要执行过滤步骤。这一步骤是必要的,因为我们只希望在工具实际上对 LLM 有帮助时才使用外部工具!我们不应该在不需要时总是依赖外部工具——使用工具通常会有延迟(甚至是经济)成本。为了捕捉这个想法,我们可以这样做:
-
使用工具测量 LLM 的性能(即,交叉熵损失在 API 调用之后的标记上)。
-
测量 LLM 在没有工具情况下的性能。
-
丢弃那些使用工具未能使 LLM 的性能超越某个阈值的示例。
在这里,我们假设可以访问一个演示 LLM 应产生正确输出的数据集。通过这种方法,我们可以自动构建一个包含示例的数据集,说明何时以及如何利用工具来实际改善 LLM 的输出。在实践中,实际过程要复杂一些。具体来说,为了在没有工具的情况下测量 LLM 的性能,我们观察两个独立的情况——一个是完全不使用工具的情况,另一个是执行 API 调用但不提供响应的情况。这种方法确保了工具及其数据对 LLM 的有用性。
“如果提供此调用的输入和输出使得预测未来的标记更容易,则 API 调用对[语言模型]是有帮助的”— 来自 1
此外,我们没有将 API 调用插入到文本序列中,而是将其作为前缀附加,这样可以避免 LLM 损失的波动。记住,这样的 API 调用在 LLM 的原始预训练语料库中不存在,这意味着直接将 API 调用插入文本序列可能会扭曲用于过滤的结果。模型并不期望在数据中看到这样的 API 调用! 此外,在测量性能时,我们为 API 调用空间上接近的标记分配更高的权重,确保 API 调用发生在所需的地方,而不是在生成输出时的随机时刻。
(来自 1)
1中使用的工具跟随数据集的完整构建过程如上所示。与以前的工作不同,这个过程不需要人工劳动。相反,我们利用 LLM 的上下文学习能力和一些巧妙的启发式方法来自动构建数据集。尽管这个过程并不完美(即,某些无用的 API 调用可能会避免过滤),但在实践中效果相当好!
学习使用工具。 一旦我们构建了数据集,教会 LLM 如何利用外部工具是很容易的——我们只需使用标准语言建模目标对模型进行微调。在1中,工具跟随数据集来源于预训练语料库。因此,尽管微调后的 LLM 能够利用外部工具,但它仍然是一个通用模型。此外,由于1中的筛选过程会去除那些不利于性能的 API 调用,LLM 会在隐含中学习何时以及如何使用每个工具以提升其输出。这种简单的方法取得了相当酷的结果!
工具是否有影响?
在1中分析的模型基于GPT-J [5],这是一个拥有 60 亿参数的语言模型,并且采用了CCNet作为训练数据集。Toolformer 与多个基准模型进行了比较,包括禁用 API 调用的 Toolformer 模型、原始的 GPT-J 模型、在 CCNet 上微调的 GPT-J 版本,以及其他一些 LLM,如OPT [6]和GPT-3 [7]。与研究少样本学习的先前工作不同,这些模型使用零样本方法进行评估,这种方法只是简单地向模型描述任务而不提供任何示例,并且使用了贪婪解码策略。在 Toolformer 中,只要<API>
(即 API 调用的起始标记)出现在模型的k
个最可能标记之一中,就会利用工具。
Toolformer 在多个不同领域中进行了评估。在基于事实的数据集上,我们发现问答工具被大量利用,相比基准模型的准确率显著提高。同样,在数学推理数据集上,计算器工具也被发现非常有用;见下文。
(来自1)
在(多语言)问答基准上,模型的表现并不像预期那样令人印象深刻(即,Toolformer 在某些情况下不及 GPT-3 或 GPT-J 的表现)。然而,某些工具,如日历工具,被发现对提升 LLM 在时间推理等任务上的表现非常有用。有趣的是,作者还进行了一些分析,修改了 LLM 解码策略中 API 调用的概率。通过这项分析,我们了解到更频繁地利用外部工具并不总是好事——如果工具使用过于频繁,性能会下降;见下文。
(来自1)
这样的发现突显了1中使用的过滤策略的重要性。工具使用不仅有成本,而且可能会降低性能。LLM 必须学习在何种场景下调用工具最为重要。1中采取的方法明确地使 LLM 在仅在显著提升模型性能时才利用外部工具。
(摘自1)
保持通用。 除了上述下游评估,1中的作者在工具跟随数据集微调后,在预训练数据集的留出部分上评估了 Toolformer,发现模型在微调前后都达到可比的困惑度;如上所述。换句话说,Toolformer 在学习如何利用外部工具时不会丧失作为通用语言模型的任何能力,这意味着—与先前以任务特定方式接近工具跟随的工作不同[8]—该模型仍然是一个基础模型,能够解决各种不同的任务。
使用工具变得越来越简单
尽管1中提出的方法具有突破性并且信息量巨大,但它仍然需要一个广泛的微调过程。与大多数最近应用的 LLM 相比,这确实是一个麻烦!我们是否可以利用仅通过提示的方法来教会 LLM 使用外部工具? 最近围绕 GPT-4 的进展表明,这个问题可能通过提高 LLM 的指令跟随能力来解决。
(作者创建)
GPT-4 插件工作流程。 例如,GPT-4 可以通过插件商店访问各种工具。然而,模型并没有明确地针对商店中的每个插件进行微调。相反,它只是使用上下文学习。特别是,OpenAI 在提升 GPT-4 的可控性方面投入了大量资金,这使得模型能够非常详细地跟随指令和提示。因此,教会 GPT-4 如何使用插件只需要:
-
描述插件目的的文本描述
-
描述插件 API 的输入/输出格式的架构
使用这些信息,模型可以自行决定何时使用插件,进行格式正确的 API 调用,并将结果信息整合到对话中。这一切都是通过文本描述完成的,没有任何明确的微调,这表明教会 LLM 利用外部工具可能会随着时间的推移变得更加容易。要更详细地了解这一过程,我们可以查看 开源插件实现 或 OpenAI 插件开发文档。
结语
类似于人类在使用工具(例如,锤子、计算机、飞机等)后变得更好,LLMs 在获得一组可以提供有用信息或执行简单任务的简单 API 时也变得更有能力。为什么我们要完全依赖 LLM 解决一切问题,而不是将困难的任务委派给更准确、更专业的工具? 我们可以使用这种方法来缓解这些模型常常遇到的问题,例如输出中的不正确信息或缺乏时间推理能力。通过 Toolformer 1,我们看到 LLM 可以通过对工具跟随示例的数据集进行微调来学习利用外部工具。但是,最近的趋势表明,仅通过上下文学习可能就能教会 LLM 使用外部工具。这个领域还有很多未被揭示的内容,观察这些主题和相关应用随时间的发展将会很有趣!
与我联系!
非常感谢你阅读这篇文章。我是 Cameron R. Wolfe,Rebuy 的人工智能总监。我研究深度学习的实证和理论基础。如果你喜欢这个概述,请订阅我的 Deep (Learning) Focus 新闻通讯,在这里我通过从基础开始概述相关主题,帮助读者理解 AI 研究。你还可以在 X 和 LinkedIn 上关注我,或者查看我在 Medium 上的 其他文章!
参考文献
1 Schick, Timo 等人. “Toolformer: 语言模型可以自我学习使用工具。” arXiv 预印本 arXiv:2302.04761 (2023)。
2 Komeili, Mojtaba, Kurt Shuster 和 Jason Weston. “互联网增强的对话生成。” arXiv 预印本 arXiv:2107.07566 (2021)。
[3] Thoppilan, Romal 等人. “Lamda: 对话应用的语言模型。” arXiv 预印本 arXiv:2201.08239 (2022)。
[4] Wei, Jason 等人. “思维链提示引发大型语言模型的推理。” arXiv 预印本 arXiv:2201.11903 (2022)。
[5] Wang, Ben 和 Aran Komatsuzaki. “GPT-J-6B: 一种 60 亿参数的自回归语言模型。” (2021)。
[6] 张苏珊等,“Opt: 开放预训练变换器语言模型。” arXiv 预印本 arXiv:2205.01068 (2022)。
[7] 布朗·汤姆等,“语言模型是少样本学习者。” 神经信息处理系统进展 33 (2020): 1877–1901。
[8] 帕里西·亚伦、姚赵和诺亚·费德尔,“Talm: 工具增强语言模型。” arXiv 预印本 arXiv:2205.12255 (2022)。
[9] 丁格拉·布万等,“时间感知语言模型作为时间知识库。” 计算语言学学会会刊 10 (2022): 257–273。
[10] 陈马克等,“评估基于代码训练的大型语言模型。” arXiv 预印本 arXiv:2107.03374 (2021)。
[11] 欧阳龙等,“训练语言模型以遵循人类反馈的指令。” 神经信息处理系统进展 35 (2022): 27730–27744。
[12] 周春婷等,“Lima: 对齐的少即是多。” arXiv 预印本 arXiv:2305.11206 (2023)。
[13] 伊扎卡德·戈蒂埃等,“Atlas: 带检索增强的语言模型的少样本学习。” arXiv 预印本 arXiv 2208 (2022)。
时间差学习及探索的重要性:图解指南
原文:
towardsdatascience.com/temporal-difference-learning-and-the-importance-of-exploration-an-illustrated-guide-5f9c3371413a?source=collection_archive---------2-----------------------#2023-09-23
在动态网格世界中比较无模型和有模型的强化学习方法
Ryan Pégoud
·
关注 发表在 Towards Data Science · 15 分钟阅读 · 2023 年 9 月 23 日
--
图片来源:Saffu 供图于 Unsplash
最近,强化学习(RL)算法因解决诸如蛋白质折叠、在无人机竞速中达到超人类水平,甚至在你喜欢的聊天机器人中整合人类反馈等研究问题而受到广泛关注。
的确,RL 为各种顺序决策问题提供了有用的解决方案。时间差分学习(TD 学习)方法是 RL 算法中的一个流行子集。TD 学习方法结合了蒙特卡洛和动态规划方法的关键方面,以加速学习而不需要完美的环境动态模型。
在这篇文章中,我们将比较不同类型的TD 算法在自定义网格世界中的表现。实验设计将展示持续探索的重要性以及被测试算法的个体特征:Q-learning、Dyna-Q 和 Dyna-Q+。
本文的概要包括:
-
环境描述
-
时间差分(TD)学习
-
无模型 TD 方法(Q-learning)和基于模型的 TD 方法(Dyna-Q 和 Dyna-Q+)
-
参数
-
性能比较
-
结论
允许重现结果和图表的完整代码可以在这里找到: github.com/RPegoud/Temporal-Difference-learning
环境
我们将在此实验中使用的环境是一个具有以下特征的网格世界:
-
网格是 12 x 8 单元格。
-
代理从网格的左下角开始,目标是到达位于右上角的宝藏(一个终端状态,奖励为 1)。
-
蓝色传送门是相连的,通过位于单元格(10, 6)的传送门到达单元格(11, 0)。代理在第一次过渡后不能再次使用该传送门。
-
紫色传送门仅在100 个剧集后出现,但能使代理更快到达宝藏。这鼓励持续探索环境。
-
红色传送门是陷阱(终端状态,奖励为 0),并结束剧集。
-
碰到墙壁会导致代理保持在同一状态。
网格世界不同组件的描述(由作者制作)
本实验旨在比较 Q-learning、Dyna-Q 和 Dyna-Q+ 代理在变化环境中的行为。的确,在100 个剧集之后,最优策略必定会发生变化,成功剧集中的最优步骤数将从17减少到12。
网格世界的表示,最优路径依赖于当前的剧集(由作者制作)
时间差分学习介绍:
时间差分学习是蒙特卡洛(MC)和动态规划(DP)方法的组合:
-
与 MC 方法类似,TD 方法可以从经验中学习而不需要环境动态模型。
-
与 DP 方法类似,TD 方法在每一步后更新估计,基于其他学习到的估计,而不是等待结果(这称为 自举)。
TD 方法的一个特点是,它们在每个时间步都更新其价值估计,而 MC 方法则等到回合结束。
确实,这两种方法有不同的更新目标。MC 方法旨在更新回报Gt,它仅在一个回合结束时可用。而 TD 方法则针对:
TD 方法的更新目标
其中V是真实价值函数 Vπ的估计。
因此,TD 方法结合了MC的采样(通过使用真实价值的估计)和DP的自举(通过基于进一步估计的估计更新 V)。
时间差分学习的最简单版本称为TD(0)或一步 TD,实际实现 TD(0)看起来像这样:
TD(0)算法的伪代码,摘自《强化学习导论》[4]
当从状态S转移到新状态S’时,TD(0)算法将计算备份值并相应地更新V(S)。这个备份值称为 TD 误差,即观察到的奖励R加上新状态γV(St+1)的折扣值与当前价值估计V(S)之间的差异:
TD 误差
总之,TD 方法具有若干优点:
-
它们不需要环境动态的完美模型p
-
它们以在线方式实现,在每个时间步后更新目标
-
如果α(学习率或步长)遵循随机逼近条件,TD(0)保证会在任何固定策略π下收敛(更多细节请参见[4]第 55 页“追踪非平稳问题”)
实现细节:
以下各节探讨了多个 TD 算法在网格世界中的主要特性和性能。
为了简化起见,所有模型使用了相同的参数:
-
Epsilon (ε) = 0.1:在ε-贪心策略中选择随机动作的概率
-
Gamma (γ) = 0.9:应用于未来奖励或价值估计的折扣因子
-
Aplha (α) = 0.25:限制 Q 值更新的学习率
-
Planning steps = 100:对于 Dyna-Q 和 Dyna-Q+,每次直接交互执行的规划步骤数量
-
Kappa (κ**) = 0.001:对于 Dyna-Q+,在规划步骤中应用的奖励加权
每个算法的性能首先在单次运行 400 个回合的基础上进行展示(部分:Q 学习、Dyna-Q和Dyna-Q+),然后在“总结与算法比较”部分对 100 次运行 250 回合的数据进行平均。
Q 学习
我们在这里实现的第一个算法是著名的 Q 学习(Watkins, 1989):
Q 学习被称为离策略算法,因为其目标是直接逼近最优值函数,而不是代理遵循的策略π的值函数。
实际上,Q 学习仍然依赖于一个策略,通常称为‘行为策略’,以选择哪些状态-动作对被访问和更新。然而,Q 学习是离策略的,因为它基于未来奖励的最佳估计来更新其 Q 值,无论所选动作是否遵循当前策略π。
与之前的 TD 学习伪代码相比,有三个主要区别:
-
我们需要初始化所有状态和动作的 Q 函数,并且 Q(terminal)应为 0
-
动作是从基于 Q 值的策略中选择的(例如相对于 Q 值的ϵ-贪心策略)
-
更新的目标是动作值函数 Q 而非状态值函数 V
Q 学习算法的伪代码,摘自《强化学习导论》[4]
现在我们有了第一个算法读取用于测试,我们可以开始训练阶段。我们的代理将使用其ε-贪心策略在网格世界中导航,相对于 Q 值。该策略以(1 - ε)的概率选择最高 Q 值的动作,并以ε的概率选择随机动作。每次行动后,代理将更新其 Q 值估计。
我们可以使用热图可视化每个网格世界单元的估计最大动作值 Q(S, a)的演变。这里代理器进行 400 个回合。由于每个回合只有一次更新,Q 值的演变较慢,大部分状态仍未映射:
训练过程中学习到的每个状态的 Q 值的热图表示(作者提供)
完成 400 个回合后,对每个单元总访问次数的分析为我们提供了代理平均路径的合理估计。如下面右侧图所示,代理似乎已收敛到一个次优路径,避免了单元(4,4),并且始终沿着下墙行进。
(左)每个状态的最大动作值估计,(右)每个状态的访问次数(作者提供)
由于这种次优策略,代理在每回合达到最少21 步,遵循“总访问次数”图中勾画的路径。步骤数量的变化可归因于ε-贪心策略,该策略引入了 10%的随机动作概率。鉴于这一策略,沿下墙行进是一种限制随机动作带来的潜在干扰的不错策略。
训练最后 100 回合的步数(300–400)(作者提供)
总结来说,Q 学习代理如前所述收敛于次优策略。此外,Q 函数仍有一部分环境是未被探索的,这阻止了代理在第 100 集后出现紫色传送门时找到新的最佳路径。
这些性能限制可以归因于相对较少的训练步骤(400),这限制了与环境互动的可能性以及 ε-贪婪策略引发的探索。
规划,作为基于模型的强化学习方法的一个基本组成部分,特别有助于提高样本效率和动作价值的估计。Dyna-Q 和 Dyna-Q+ 是结合了规划步骤的 TD 算法的良好示例。
Dyna-Q
Dyna-Q 算法(动态 Q 学习)是基于模型的强化学习和TD 学习的结合体。
基于模型的强化学习算法依赖于环境模型,将规划作为其更新价值估计的主要方式。相比之下,无模型算法依赖于直接学习。
“环境模型是代理可以用来预测环境如何对其动作做出响应的任何东西” — 强化学习:导论。
在本文的范围内,模型可以被视为对转移动态 p(s', r|s, a) 的近似。这里,p 返回一个单一的下一个状态和奖励对,给定当前状态-动作对。
在随机的环境中,我们区分分布模型和样本模型,前者返回下一个状态和动作的分布,而后者返回从估计分布中抽样得到的单一对。
模型特别有助于模拟情节,因此通过用规划步骤替代现实世界的互动来训练代理,即与模拟环境的互动。
实施 Dyna-Q 算法的代理是规划代理的一部分,这些代理结合了直接强化学习和模型学习。它们使用与环境的直接互动来更新它们的价值函数(如 Q 学习所示),同时也学习环境的模型。在每次直接互动之后,它们还可以执行规划步骤,通过模拟互动来更新它们的价值函数。
一个快速的国际象棋示例
想象一下玩一局好的国际象棋。每次走一步棋后,你对手的反应让你评估你的走棋质量。这类似于收到正面或负面的奖励,这让你可以“更新”你的策略。如果你的走棋导致了失误,你可能不会再这样做,前提是棋盘的配置相同。到目前为止,这与直接强化学习是类似的。
现在,让我们加入规划。假设在你每次移动后,当对手思考时,你在脑海中回顾你的每一次移动以重新评估它们的质量。你可能会发现最初忽视的弱点,或发现某些移动比你想象的更好。这些思考还可能让你更新策略。这正是规划的意义,在不与真实环境交互的情况下更新值函数,而是对环境的模型。
计划、行动、模型学习和直接强化学习:一个规划代理的时间表(由作者制定)
因此,Dyna-Q 相比 Q 学习包含了一些额外的步骤:
在每次直接更新 Q 值后,模型会存储观察到的状态-动作对、奖励和下一个状态。这个步骤称为模型训练。
-
在模型训练后,Dyna-Q 执行n规划步骤:
-
从模型缓冲区中选择一个随机的状态-动作对(即这个状态-动作对是在直接交互中观察到的)
-
模型生成模拟的奖励和下一个状态
-
值函数通过模拟观察进行更新(s, a, r, s’)
Dyna-Q 算法的伪代码,摘自《强化学习简介》[4]
我们现在使用n=100来复制 Dyna-Q 算法的学习过程。这意味着在每次与环境的直接交互后,我们使用模型执行 100 次规划步骤(即更新)。
下图热力图展示了 Dyna-Q 模型的快速收敛。事实上,算法只需约10 个回合即可找到最优路径。这是因为每一步会导致 Q 值的 101 次更新(而 Q 学习只更新 1 次)。
训练期间每个状态的学习 Q 值的热力图表示(由作者制作)
规划步骤的另一个好处是更好地估计网格中的动作值。由于间接更新针对的是存储在模型中的随机过渡,距离目标较远的状态也会被更新。
相比之下,Q 学习中的动作值会从目标点缓慢传播,导致网格的映射不完整。
(左)每个状态的最大动作值估计,(右)每个状态的访问次数(由作者制作)
使用 Dyna-Q,我们找到一个最优路径,允许在17 步内解决网格世界,如下图红条所示。尽管为了探索偶尔会有ε-贪婪行为的干扰,最佳表现仍然会定期达到。
最终,虽然 Dyna-Q 由于引入了规划,可能看起来比 Q-learning 更具说服力,但需要记住的是,规划带来了权衡,在计算成本和现实世界探索之间。
训练的最后 100 集(300–400)的步骤数(作者制作)
Dyna-Q+
到目前为止,测试的算法没有一个能找到第 100 步之后出现的最优路径(紫色传送门)。实际上,这两个算法都迅速收敛到一个在训练阶段结束前保持固定的最优解决方案。这突显了持续探索在训练过程中的必要性。
Dyna-Q+与 Dyna-Q 大致相似,但在算法上增加了一个小变化。实际上,Dyna-Q+不断跟踪自每个状态-动作对在与环境的真实交互中尝试以来所经过的时间步数。
特别地,考虑一个奖励r的转移,该转移在τ时间步中没有被尝试。Dyna-Q+会进行规划,假设该转移的奖励为r + κ √τ,其中κ足够小(实验中为 0.001)。
这种奖励设计的变化鼓励智能体持续探索环境。它假设状态-动作对未被尝试的时间越长,这对的动态发生变化或模型不正确的可能性就越大。
Dyna-Q+算法的伪代码,摘自《强化学习导论》[4]
如下热图所示,与之前的算法相比,Dyna-Q+在更新方面更加活跃。在第 100 集之前,智能体探索了整个网格,找到了蓝色传送门和第一个最优路线。
网格其余部分的动作值在减少后再缓慢增加,因为左上角的状态-动作对在一段时间内没有被探索。
当紫色传送门在第 100 集出现时,智能体找到新的捷径,整个区域的值上升。在完成 400 集之前,智能体将不断更新每个状态-动作对的动作值,同时保持对网格的偶尔探索。
训练过程中每个状态的学习 Q 值的热图表示(作者制作)
多亏了对模型奖励的额外奖金,我们最终得到了Q 函数的完整映射(每个状态或单元都有一个动作值)。
结合持续探索,智能体能够找到出现的新最佳路线(即最优策略),同时保留以前的解决方案。
(左)每个状态的最大动作值估计,(右)每个状态的访问次数(作者制作)
然而,Dyna-Q+ 中的探索与利用权衡确实带来了成本。当状态-动作对在足够长时间内未被访问时,探索奖励会鼓励代理重新访问这些状态,这可能会暂时降低其即时性能。这种探索行为优先更新模型以改善长期决策。
这解释了为什么 Dyna-Q+ 有些回合可以长达 70 步,而 Q 学习和 Dyna-Q 最多为 35 步和 25 步。Dyna-Q+ 中较长的回合反映了代理愿意投入额外的步数进行探索,以获取更多关于环境的信息并完善其模型,即使这会导致短期性能下降。
相比之下,Dyna-Q+ 经常实现最佳性能(如下图中的绿色条形图所示),这是以前的算法未能达到的。
训练最后 100 回合的步数(300–400)(作者提供)
总结与算法比较
为了比较算法之间的关键差异,我们使用了两个指标(请注意,结果依赖于输入参数,为简化起见,所有模型的输入参数均相同):
-
每回合步数:该指标描述了算法向最优解收敛的速度。它还描述了算法在收敛后的行为,特别是在探索方面。
-
平均累计奖励:指导致正奖励的回合百分比。
分析每回合的步数(见下图)揭示了基于模型和非基于模型的方法的几个方面:
-
基于模型的效率:在这个特定的网格世界中,基于模型的算法(Dyna-Q 和 Dyna-Q+)往往更具样本效率(这一特性在 RL 中也较为普遍)。这是因为它们可以利用环境的学习模型进行前瞻性规划,从而更快地收敛到接近最优或最优的解决方案。
-
Q 学习收敛:Q 学习虽然最终会收敛到接近最优解,但需要更多的回合(125)。需要强调的是,Q 学习每步仅执行 1 次更新,这与 Dyna-Q 和 Dyna-Q+ 执行的多次更新形成对比。
-
多次更新:Dyna-Q 和 Dyna-Q+ 每步执行 101 次更新,这有助于它们更快地收敛。然而,这种样本效率的权衡是计算成本(见下表的运行时间部分)。
-
复杂环境:在更复杂或随机的环境中,基于模型的方法的优势可能会减弱。模型可能引入错误或不准确,从而导致次优策略。因此,这种比较应被视为不同方法的优缺点概述,而不是直接的性能比较。
平均每集步骤数的比较(由作者制作)
现在我们引入平均累计奖励(ACR),它表示代理达到目标的集数百分比(因为达到目标的奖励为 1,而触发陷阱的奖励为 0),因此 ACR 计算方式为:
其中 N 是集数(250),K 是独立运行次数(100),Rn,k 是运行 k 中第 n 集的累计奖励。
以下是所有算法性能的详细分析:
-
Dyna-Q 收敛迅速,达到最高的总体回报,ACR 为 87%。这意味着它在很大一部分集数中能够高效地学习并达到目标。
-
Q-learning 也达到了类似的性能水平,但需要更多的集数才能收敛,这解释了其稍低的 ACR,为 70%。
-
Dyna-Q+ 能够迅速找到一个良好的策略,在仅经过 15 集后达到累计奖励 0.8。然而,奖励的变异性和探索性降低了其性能,直到第 100 步之后才开始改善,因为它发现了新的最优路径。然而,短期的探索会妨碍其性能,导致其 ACR 为 79%,低于 Dyna-Q,但高于 Q-learning。
平均每集累计奖励的比较(由作者制作)
结论
在本文中,我们介绍了时序差分学习的基本原理,并将 Q-learning、Dyna-Q 和 Dyna-Q+ 应用于自定义网格世界。这个网格世界的设计有助于强调持续探索的重要性,以发现和利用在变化环境中新的最优策略。通过每集步骤数和累计奖励的表现差异,展示了这些算法的优缺点。
总结来说,基于模型的方法(Dyna-Q、Dyna-Q+)相较于基于模型的方法(Q-learning)在样本效率上有优势,但计算效率较低。然而,在随机或更复杂的环境中,模型的不准确性可能会阻碍性能并导致次优策略。
参考文献:
1 Demis Hassabis, AlphaFold 揭示了蛋白质宇宙的结构 (2022), DeepMind
2 Elia Kaufmann, Leonard Bauersfeld, Antonio Loquercio, Matthias Müller, Vladlen Koltun & Davide Scaramuzza, 冠军级无人机竞速使用深度强化学习 (2023), Nature
[3] Nathan Lambert, Louis Castricato, Leandro von Werra, Alex Havrilla, 从人类反馈中阐述强化学习(RLHF), HuggingFace
[4] Sutton, R. S. 和 Barto, A. G. . 强化学习:导论 (2018), 剑桥(马萨诸塞州):麻省理工学院出版社。
[5] Christopher J. C. H. Watkins 和 Peter Dayan, Q-learning (1992), 《机器学习》,Springer Link
Python 中的时序差分:第一个基于样本的强化学习算法
原文:
towardsdatascience.com/temporal-differences-with-python-first-sample-based-reinforcement-learning-algorithm-54c11745a0ee
使用 Python 编写并理解 TD(0)算法
Eligijus Bujokas
·发表于 Towards Data Science ·13 分钟阅读·2023 年 1 月 27 日
--
Kurt Cotoaga在Unsplash上的照片
这是我之前文章的续集:
## Python 强化学习中的第一步
Python 的原始实现,展示了如何在强化学习的基本世界之一中找到最佳位置……
towardsdatascience.com
在这篇文章中,我想让读者熟悉强化学习中的基于样本的算法逻辑(RL)。为此,我们将创建一个带有洞的网格世界(类似于缩略图中的那个),并让我们的代理在创建的世界中自由遍历。
希望在代理的旅程结束时,他能学会在世界上哪个地方是好的地方,哪些位置应该避免。为了帮助我们的代理学习,我们将使用著名的TD(0)算法。
在深入算法之前,让我们定义一下我们想要解决的目标。
在这篇文章中,我们将创建一个 5 行 7 列的网格世界,这意味着我们的代理将能够处于 35 个状态中的一个。移动规则如下:
-
代理不能离开网格世界的边界。
-
在每个时间步,代理只能向上、向下、向左或向右移动。
-
代理从我们网格世界的左上角开始。
-
如果代理达到目标或掉入洞里,游戏结束,代理会被返回到起始状态。
-
每次移动都会获得-1 的奖励。
-
掉入洞里会获得-10 的奖励。
-
达到目标会获得 10 的奖励。
我们代理的终极目标是尽可能准确地评估它可能处于的每一个状态。换句话说,我们代理希望根据给定的移动策略评估每个状态的价值。**
以下代码片段初始化了前一节中描述的环境:
import numpy as np
def init_policy(S: np.array, weight_dict: dict = {'right': 1}) -> dict:
# Saving all the unique states to a vector
states = np.unique(S)
# Getting the number of rows and columns of the S matrix
n_row = S.shape[0]
n_col = S.shape[1]
# Dictionary to hold each action for a given state
P = {}
for s in states:
s_dict = {}
# Checking which index is the current state in the S matrix
s_index = np.where(S == s)
# If the state is in the top left corner, we can only move right and down
if s_index == (0, 0):
s_dict['right'] = 0.5 * weight_dict['right']
s_dict['down'] = 1 - s_dict['right']
# If the state is in the top right corner, we can only move left and down
elif s_index == (0, n_col - 1):
s_dict['left'] = 0.5
s_dict['down'] = 0.5
# If the state is in the bottom left corner, we can only move right and up
elif s_index == (n_row - 1, 0):
s_dict['right'] = 0.5 * weight_dict['right']
s_dict['up'] = 1 - s_dict['right']
# If the state is in the bottom right corner, we can only move left and up
elif s_index == (n_row - 1, n_col - 1):
s_dict['left'] = 0.5
s_dict['up'] = 0.5
# If the state is in the first row, we can only move left, right, and down
elif s_index[0] == 0:
s_dict['right'] = 0.333 * weight_dict['right']
s_dict['left'] = (1 - s_dict['right']) / 2
s_dict['down'] = (1 - s_dict['right']) / 2
# If the state is in the last row, we can only move left, right, and up
elif s_index[0] == n_row - 1:
s_dict['right'] = 0.333 * weight_dict['right']
s_dict['left'] = (1 - s_dict['right']) / 2
s_dict['up'] = (1 - s_dict['right']) / 2
# If the state is in the first column, we can only move up, down, and right
elif s_index[1] == 0:
s_dict['right'] = 0.333 * weight_dict['right']
s_dict['up'] = (1 - s_dict['right']) / 2
s_dict['down'] = (1 - s_dict['right']) / 2
# If the state is in the last column, we can only move up, down, and left
elif s_index[1] == n_col - 1:
s_dict['up'] = 0.333
s_dict['down'] = 0.333
s_dict['left'] = 1 - s_dict['up'] - s_dict['down']
# If the state is in the middle, we can move in all directions
else:
s_dict['right'] = 0.25 * weight_dict['right']
s_dict['up'] = (1 - s_dict['right']) / 3
s_dict['down'] = (1 - s_dict['right']) / 3
s_dict['left'] = (1 - s_dict['right']) / 3
# Saving the current states trasition probabilities
P[s] = s_dict
return P
def generate_holes(nrow: int, ncol: int, start_coords: list, hole_coords: list, nholes: int = 1) -> list:
"""
Function that generates nholes in a gridworld
The holes cannot be:
- in the start state
- in the goal state
"""
# Generating the hole coordinates
# The hole cannot be in the start or goal state
hole_coords = []
for _ in range(nholes):
hole_row = np.random.randint(0, nrow - 1)
hole_col = np.random.randint(0, ncol - 1)
while (hole_row, hole_col) in start_coords or (hole_row, hole_col) in hole_coords:
hole_row = np.random.randint(0, nrow - 1)
hole_col = np.random.randint(0, ncol - 1)
# Appending to the hole coordinates list
hole_coords.append((hole_row, hole_col))
return hole_coords
def init_env(
n_rows: int,
n_cols: int,
step_reward: float = -1,
goal_reward: float = 10,
hole_reward: float = -10,
n_holes: int = 1,
random_seed: int = 42,
policy_weights: dict = {'right': 1}
) -> np.array:
"""
Functionat that returns the initial environment:
S - the state matrix indexed by [row, col]
V - the initial value matrix indexed by [row, col]
R - the reward matrix indexed by [row, col]
A - the action matrix indexed by [row, col]
P - the probability dictionary where for each state, the keys are the actions and the values are the probabilities of the next state
"""
# Setting the random seed
np.random.seed(random_seed)
# Initiating the S matrix
S = np.arange(0, n_rows * n_cols).reshape(n_rows, n_cols)
# Creating the initial V matrix
V = np.zeros((n_rows, n_cols))
# The start state will be always the top left corner
# The goal state will be always the bottom right corner
# We will generate a random holes that our agent can fall in
# Any other state that is not the hole or the goal state will receive a step reward
goal_coord = (n_rows - 1, n_cols - 1)
R = np.zeros((n_rows, n_cols))
R.fill(step_reward)
R[0, 0] = step_reward
R[goal_coord] = goal_reward
# Generating the hole coordinates
# The hole cannot be in the start or goal state
hole_coords = generate_holes(n_rows, n_cols, [(0, 0)], [goal_coord], n_holes)
# Setting the hole reward
for hole_coord in hole_coords:
R[hole_coord] = hole_reward
# Initiating the policy
P = init_policy(S, weight_dict=policy_weights)
return S, V, R, P, hole_coords, [goal_coord]
我们需要开始学习的对象是:
-
状态矩阵 S
-
值矩阵 V
-
奖励矩阵 R
-
策略字典 P
默认情况下,上述代码片段初始化了一个随机策略的世界。
随机策略意味着我们的代理通过均匀概率分布选择从一个状态转移到另一个状态。
让我们创建我们的世界,更详细地探索这些矩阵:
S, V, R, P, hole_coords, goal_coard = init_env(5, 7, n_holes=4, random_seed=3)
以下代码片段用于绘制矩阵:
def array_index_to_matplot_coords(i: int, j: int, n_cols: int) -> Tuple[int, int]:
"""
Converts an array index to a matplot coordinate
"""
x = j
y = n_cols - i - 1
return x, y
def plot_matrix(
M: np.array,
goal_coords: list = [],
hole_coords: list = [],
img_width: int = 5,
img_height: int = 5,
title: str = None,
) -> None:
"""
Plots a matrix as an image.
"""
height, width = M.shape
fig = plt.figure(figsize=(img_width, img_width))
ax = fig.add_subplot(111, aspect='equal')
for x in range(height):
for y in range(width):
# By default, the (0, 0) coordinate in matplotlib is the bottom left corner,
# so we need to invert the y coordinate to plot the matrix correctly
matplot_x, matplot_y = array_index_to_matplot_coords(x, y, height)
# If there is a tuple of (x, y) in the goal_coords list, we color the cell gray
if (x, y) in goal_coords:
ax.add_patch(matplotlib.patches.Rectangle((matplot_x - 0.5, matplot_y - 0.5), 1, 1, facecolor='gray'))
# If there is a tuple of (x, y) in the hole_coords list, we color the cell salmon
elif (x, y) in hole_coords:
ax.add_patch(matplotlib.patches.Rectangle((matplot_x - 0.5, matplot_y - 0.5), 1, 1, facecolor='salmon'))
ax.annotate(str(M[x][y]), xy=(matplot_x, matplot_y), ha='center', va='center')
offset = .5
ax.set_xlim(-offset, width - offset)
ax.set_ylim(-offset, height - offset)
ax.hlines(y=np.arange(height+1)- offset, xmin=-offset, xmax=width-offset)
ax.vlines(x=np.arange(width+1) - offset, ymin=-offset, ymax=height-offset)
plt.title(title)
plt.show()
def plot_policy_matrix(P: dict, S:np.array, terminal_coords: list = [], img_width: int = 5, img_height: int = 5, title: str = None) -> None:
"""
Plots the policy matrix out of the dictionary provided; The dictionary values are used to draw the arrows
"""
height, width = S.shape
fig = plt.figure(figsize=(img_width, img_width))
ax = fig.add_subplot(111, aspect='equal')
for x in range(height):
for y in range(width):
matplot_x, matplot_y = array_index_to_matplot_coords(x, y, height)
# If there is a tuple of (x, y) in the goal_coords list, we color the cell gray
if (x, y) in terminal_coords:
ax.add_patch(matplotlib.patches.Rectangle((matplot_x - 0.5, matplot_y - 0.5), 1, 1, facecolor='gray'))
else:
try:
# Adding the arrows to the plot
if 'up' in P[S[x, y]]:
plt.arrow(matplot_x, matplot_y, 0, 0.3, head_width = 0.05, head_length = 0.05)
if 'down' in P[S[x, y]]:
plt.arrow(matplot_x, matplot_y, 0, -0.3, head_width = 0.05, head_length = 0.05)
if 'left' in P[S[x, y]]:
plt.arrow(matplot_x, matplot_y, -0.3, 0, head_width = 0.05, head_length = 0.05)
if 'right' in P[S[x, y]]:
plt.arrow(matplot_x, matplot_y, 0.3, 0, head_width = 0.05, head_length = 0.05)
except Exception as e:
print(f"Error: {e}")
print(f"Current x and y: {x}, {y}")
offset = .5
ax.set_xlim(-offset, width - offset)
ax.set_ylim(-offset, height - offset)
ax.hlines(y=np.arange(height+1)- offset, xmin=-offset, xmax=width-offset)
ax.vlines(x=np.arange(width+1) - offset, ymin=-offset, ymax=height-offset)
plt.title(title)
首先让我们可视化状态矩阵:
plot_matrix(S, goal_coords=goal_coard, hole_coords=hole_coords, title='State Matrix')
状态矩阵;作者拍摄的照片
红色状态表示洞的坐标——这些是我们的代理想要避免的状态。
灰色状态表示目标——这是我们的代理想要到达的地方。
我们的代理总是从状态 0 开始它的旅程。
奖励矩阵如下:
plot_matrix(R, goal_coords=goal_coard, hole_coords=hole_coords, title='Reward Matrix')
奖励矩阵;作者拍摄的照片
转移到某个状态的奖励矩阵在上面可视化。
例如:
-
从状态 1 到 8 会获得-1 的奖励
-
从状态 9 到 10 会获得-10 的奖励
-
从状态 33 到 34 会获得 10 的奖励
依此类推。
我们的代理将遵循的策略是随机策略——进入每个状态的概率均等:
plot_policy_matrix(P, S, terminal_coords=hole_coords + goal_coard, title='Policy Matrix')
策略矩阵;作者拍摄的照片
策略矩阵中的灰色状态表示终端状态:如果代理选择进入该状态,剧集将结束,代理将被重置到状态 0。
TD(0)算法的目标是评估给定策略下每个状态的价值。
换句话说,我们想要填充值矩阵的值:
plot_matrix(V, goal_coords=goal_coard, hole_coords=hole_coords, title='Value Matrix')
初始值矩阵;作者拍摄的照片
TD(0)算法是单步时序差分算法的简称。为了开始建立直觉并广泛地说,在此算法中,我们的代理按照给定的策略执行一步,观察奖励,并在这种步骤后更新状态价值的估计。
从数学上讲,更新步骤如下:
TD(0) 更新方程
这里:
-
s prime — 我们的代理从当前状态 s 转移到的状态。
-
奖励 r 等于转移到 s prime 的奖励。
-
Gamma 是折扣率(大于 0,小于或等于 1)。
-
Alpha 是大小(大于 0,小于或等于 1)。
完整算法¹如下:
完整 TD(0);作者照片
TD(0)算法是一种预测算法。在强化学习中,预测算法指的是一种尝试估计状态值的算法,同时不改变给定的策略(转移概率)。
这也是一种自助算法,因为我们使用当前的价值函数估计来估计下一个状态的价值函数。
因此,我们只关心状态值——智能体从当前状态移动的总期望累计奖励:
状态价值
现在让我们开始实现算法。
我们的智能体首先需要根据我们创建的策略进行移动:
def select_move(s, S, P) -> int:
"""
Given the current state, returns the coordinates of the next state based on the current policy
"""
# Getting the current state index
s_index = np.where(S == s)
# Getting the current state policy
s_policy = P[s]
# Selecting the next action based on the current policy
next_action = np.random.choice(list(s_policy.keys()), p=list(s_policy.values()))
# Getting the next state coordinates based on the next action
try:
if next_action == 'up':
next_state = S[s_index[0] - 1, s_index[1]][0]
elif next_action == 'down':
next_state = S[s_index[0] + 1, s_index[1]][0]
elif next_action == 'left':
next_state = S[s_index[0], s_index[1] - 1][0]
elif next_action == 'right':
next_state = S[s_index[0], s_index[1] + 1][0]
except Exception as e:
print(f"Current state: {s}")
print(f'Next action: {next_action}')
print(f'Error: {e}')
return next_state
当智能体处于状态 s 时,它只能前往策略矩阵字典中存在的状态。例如,状态 1 中的所有动作是:
状态 1 的所有可能动作
所有概率的总和等于 1,我们的智能体随机选择右、左或下(请参阅状态矩阵图以查看状态位置)。
上述动作是开始更新价值函数所需的全部。当智能体进行移动时,它转移到另一个状态并收集该状态的奖励。然后我们应用方程:
TD(0)更新方程
def get_state_coords(s, S) -> tuple:
"""
Returns the state coordinates given the state index
"""
s_index = np.where(S == s)
return s_index[0][0], s_index[1][0]
def update_value(s, s_prime, S, P, V, R, alpha: float = 0.1, gamma: float = 0.9) -> float:
"""
Updates the current value function based on the current policy
"""
# Getting the CURRENT state's nrow and ncol index
s_index_now = get_state_coords(s, S)
# Getting the SELECTED state's nrow and ncol index
s_index_prime = get_state_coords(s_prime, S)
# Getting the reward by moving to the selected state
move_reward = R[s_index_prime[0], s_index_prime[1]]
# Getting the current estimated value of the selected state
current_value = V[s_index_now[0], s_index_now[1]]
# The next value
prime_value = V[s_index_prime[0], s_index_prime[1]]
# Returning the TD(0) current state value
return current_value + alpha * (move_reward + gamma * prime_value - current_value)
最后一步是将所有内容封装到一个while 循环中,只有当我们的智能体转移到终止状态时才停止探索:
def episode_exploration(S, P, V, R, terminal_state_coords: list, alpha: float = 0.1, gamma: float = 0.9) -> None:
"""
Agent exploration and value updating using TD(0) equation until a terminal state is reached
"""
# The starting state is 0
s = 0
# Keeping track of the number of moves
n_moves = 0
# Getting the coordinates of the s
s_coords = get_state_coords(s, S)
while s_coords not in terminal_state_coords:
# Selecting the next state based on the current policy
s_prime = select_move(s, S, P)
# Updating the current state value
V[s_coords] = update_value(s, s_prime, S, P, V, R, alpha, gamma)
# Updating the current state
s = s_prime
# Incrementing the number of moves
n_moves += 1
# Getting teh new s coords
s_coords = get_state_coords(s, S)
return n_moves
我们现在拥有了实施完整 TD(0)算法所需的一切。
让我们定义 10000 次实验,让我们的智能体进行学习吧!
# Defining the number of episodes to explore
n_episodes = 10000
# We will plot the V matrix after each episode filling the same device plot to make an animation
number_of_walks = []
for _ in tqdm(range(n_episodes)):
n = episode_exploration(S, P, V, R, terminal_state_coords=hole_coords + goal_coard, alpha=0.1, gamma=0.9)
number_of_walks.append(n)
我们的智能体在终止之前所采取的动作数量:
# Ploting the distribution of the number of moves
plt.figure(figsize=(10, 5))
sns.kdeplot(number_of_walks, fill=True)
plt.title(f'Number of moves distribution | Mean: {np.mean(number_of_walks):.2f} | Std: {np.std(number_of_walks):.2f}')
plt.xlabel('Number of moves')
plt.ylabel('Frequency')
plt.show()
移动次数;作者绘图
平均而言,我们的智能体在碰到终止状态之前进行了 10 次移动。
最终评估的状态价值矩阵:
使用 TD(0)和随机策略评估的 V;作者绘图
正如我们所见,按照给定的策略,智能体开始旅程的状态非常糟糕。平均而言,从该状态开始,智能体仅获得-9.96 的奖励。然而,随着我们接近目标状态,价值会增加。
注意,目标状态和洞穴状态的值为 0,因为这些状态没有探索——每次智能体转移到这些状态,游戏就结束了。
如果我们选择了另一种策略会发生什么?例如,更频繁地选择“向右”方向:
# Assiging a different policy
S, V, R, P, hole_coords, goal_coard = init_env(5, 7, n_holes=4, random_seed=3, policy_weights={'right': 1.5})
# Defining the number of episodes to explore
n_episodes = 10000
# We will plot the V matrix after each episode filling the same device plot to make an animation
number_of_walks = []
for _ in tqdm(range(n_episodes)):
n = episode_exploration(S, P, V, R, terminal_state_coords=hole_coords + goal_coard, alpha=0.1, gamma=0.9)
number_of_walks.append(n)
不同策略下的移动次数
不同策略的价值矩阵
随机策略的状态价值矩阵总和为-249.29,而更高概率向右的策略总和为-213.51。
从这个意义上说,我们可以说更频繁地向右移动是一种更好的策略!
在这篇文章中,我介绍了 RL 中的第一个基于样本的算法——一步时序差分算法或 TD(0)。
这是一种预测算法,即仅用于评估给定策略的状态。改变策略会得到不同的状态价值结果。
祝大家学习愉快,编程快乐!
1
-
作者: 理查德·S·萨顿,安德鲁·G·巴托
-
年份: 2018
-
页码: 120
-
书名: 强化学习:一种介绍
-
URL:http://archive.ics.uci.edu/ml
时间图基准
原文:
towardsdatascience.com/temporal-graph-benchmark-bb5cc26fcf11?source=collection_archive---------2-----------------------#2023-12-09
挑战性和现实的时间图学习数据集
Shenyang(Andy) Huang
·
关注 发表在 Towards Data Science · 10 分钟阅读 · 2023 年 12 月 9 日
--
近年来,静态图上的机器学习取得了显著进展,这得益于公共数据集和标准化评估协议的普及,例如广泛采用的开放图基准 (OGB)。然而,许多现实世界系统,如社交网络、交通网络和金融交易网络,随着时间的推移不断演变,节点和边不断添加或删除。这些系统通常被建模为时间图。到目前为止,时间图学习的进展受到缺乏大型高质量数据集以及缺乏适当评估的制约,导致了过于乐观的性能表现。
现实世界的网络随着时间的推移而演变。图片来源:Armand Khoury on Unsplash。
为了解决这一问题,我们推出了时序图基准测试(TGB),这是一个针对时序图的挑战性和多样化基准数据集的集合,用于现实的、可重复的、稳健的机器学习评估。受到 OGB 成功的启发,TGB 自动化了数据集下载和处理以及评估协议,并允许用户通过排行榜比较模型性能。我们希望 TGB 能够成为时序图社区的标准化基准,促进新方法的发展,并提高对大型时序网络的理解。
针对时序图学习的挑战性和现实性基准
-
网站:
tgb.complexdatalab.com/
-
论文:
arxiv.org/abs/2307.01026
-
Github:
github.com/shenyangHuang/TGB
这篇文章基于我们的论文 时序图基准测试在时序图上的机器学习 (NeurIPS 2023 数据集和基准测试专题),由 Emanuele Rossi共同撰写。请在 我的网站查找更多时序图相关工作。想了解更多关于时序图的内容?加入 时序图阅读小组 和 NeurIPS 2023 时序图学习研讨会 ,了解最前沿的 TG 研究。
目录:
-
动机
-
问题设定
-
数据集详情
-
动态链接属性预测
-
动态节点属性预测
-
开始使用 TGB
-
结论与未来工作
动机
近年来,静态图的机器学习领域得到了显著提升,这主要归功于公开数据集的出现和已建立的基准测试,例如开放图基准(OGB)、长程图基准和TDC 基准。然而,许多现实世界的系统,如社交网络、交通网络和金融交易网络,都是时间性的:它们随着时间的发展而演变。直到现在,时间图的发展由于缺乏大型、高质量的数据集和全面的评估框架而受到显著阻碍。这种稀缺性,加上评估限制,导致了在流行数据集(如 Wikipedia 和 Reddit)上的几乎完美的 AP 或 AUROC 分数,导致了对模型性能的过于乐观的评估,并且在区分竞争模型方面面临挑战。
数据集的缺乏。 常见的 TG 数据集仅包含几百万条边,远远小于实际时间网络中的规模。此外,这些数据集大多限制在社交和互动网络领域。由于网络属性在不同领域间通常变化显著,因此在多个领域上进行基准测试也很重要。最后,缺乏节点级任务的数据集,导致大多数方法仅关注链接预测。为了解决这个挑战,TGB 包含了来自五个不同领域的nine个数据集,这些数据集在节点、边和时间戳的数量上都是数量级更大的。此外,TGB 还提出了四个数据集用于新的节点亲和预测任务。
TGB 数据集显著大于常见的 TG 数据集
简化的评估。 动态链接预测通常被框架化为二分类任务:正(真实)边标记为 1,负(不存在)边标记为 0。在评估时,通过保持源节点固定并随机选择目标节点来采样每个正边的一个负边。这种评估仅考虑少量容易预测的负边,导致模型性能被夸大,许多模型在 Wikipedia 和 Reddit 上获得了>95%的 AP(Poursafaei et al. 2022,Rossi et al. 2020,Wang et al. 2021,Souza et al. 2022)。在 TGB 中,我们将链接预测任务视为排序问题,并使评估更加稳健。我们展示了改进的评估结果能提供更现实的性能表现,并突出了不同模型之间的明显差距。
问题设定
在 TGB 中,我们专注于连续时间的时间图,如 Kazemi et al. 2020 定义的那样。在这种设置中,我们将时间图表示为带时间戳的边流,由(源节点, 目标节点, 时间戳)三元组组成。请注意,时间边可以是加权的、有向的,同时节点和边可以选择性地具有特征。
此外,我们还考虑了流式设置,在这种设置中,模型可以在推理时纳入新信息。特别地,在时间t预测测试边时,模型可以访问1所有发生在t之前的边,包括测试边。然而,不允许使用测试信息进行反向传播和权重更新。
数据集详情
TGB 包含 九 个数据集,其中七个是为此工作专门整理的,两个来自以前的文献。这些数据集在时间上分为训练集、验证集和测试集,比例为 70/15/15。数据集根据边的数量分类:小型(<5 百万)、中型(5–25 百万)和大型(> 25 百万)。
TGB 数据集的统计信息
TGB 数据集还具有不同的领域和时间粒度(从 UNIX 时间戳到年度)。最后,数据集的统计信息也非常多样化。例如,惊讶指数,由训练集中从未观察到的测试边的比例定义,在不同的数据集中差异显著。许多 TGB 数据集中还包含许多测试集中出现的新节点,这需要归纳推理。
TGB 数据集也与现实世界任务相关。例如,tgbl-flight
数据集是一个从 2019 年到 2022 年的众包国际航班网络,其中机场建模为节点,而边则是给定日期的机场之间的航班。任务是预测未来某个日期两特定机场之间是否会发生航班。这对于预测潜在的航班中断(如取消和延误)非常有用。例如,在 COVID-19 大流行期间,为了遏制 COVID-19 的传播,许多航班路线被取消。预测全球航班网络对研究和预测疾病(如 COVID-19)向新地区传播也很重要,如 Ding et al. 2021 中所见。详细的数据集和任务描述在论文第四部分中提供。
动态链接属性预测
动态链接属性预测的目标是预测在未来时间戳下,节点对之间链接的属性(通常是存在性)。
负边采样。 在实际应用中,真实的边在事先并不为人知。因此,查询大量节点对,仅将得分最高的节点对视为边。受到这一点的启发,我们将链接预测任务框架化为排名问题,并对每个正边采样多个负边。具体而言,对于给定的正边(s,d,t),我们固定源节点s和时间戳t,并采样q个不同的目标节点d。对于每个数据集,q的选择基于评估完整性和测试集推断时间之间的权衡。在q个负样本中,一半是均匀随机采样的,另一半是历史负边(在训练集中观察到但在时间t时不存在的边)。
性能指标。 我们使用过滤后的平均倒数排名(MRR)作为本任务的指标,因为它专为排名问题设计。MRR 计算真实目标节点在负样本或伪目标中的倒数排名,通常用于推荐系统和知识图谱文献中。
tgbl-wiki 和 tgbl-review 数据集上的 MRR 表现
小数据集的结果。 在小数据集tgbl-wiki
和tgbl-review
上,我们观察到最佳表现的模型有很大差异。此外,在tgbl-wiki
上的顶级模型,如 CAWN 和 NAT,在tgbl-review
上的性能显著下降。一个可能的解释是,与tgbl-wiki
数据集相比,tgbl-review
数据集具有更高的惊讶指数。高惊讶指数表明,测试集边的高比例从未在训练集中观察到,因此tgbl-review
需要更多的归纳推理。在tgbl-review
中,GraphMixer 和 TGAT 是表现最佳的模型。由于其较小的规模,我们能够为tgbl-wiki
采样所有可能的负样本,为tgbl-review
每个正边采样一百个负样本。
tgbl-coin、tgbl-comment 和 tgbl-flight 数据集上的 MRR 表现。
大多数方法在这些数据集上运行时耗尽了 GPU 内存,因此我们对 TGN、DyRep 和 Edgebank 进行了比较,因为它们的 GPU 内存需求较低。注意,某些数据集如tgbl-comment
或tgbl-flight
跨越多年,因此可能导致其长期跨度上的分布变化。
负样本数量对 tgbl-wiki 的影响
洞察。 如tgbl-wiki
中所示,用于评估的负样本数量可以显著影响模型性能:我们看到,当负样本数量从 20 增加到所有可能的目标时,大多数方法的性能显著下降。这验证了确实需要更多的负样本来进行稳健的评估。有趣的是,像 CAWN 和 Edgebank 这样的算法性能下降相对较小,我们将其作为未来的工作来调查为何某些方法受影响较小。
TG 模型的总训练和验证时间
接下来,我们观察到 TG 方法的训练和验证时间差异高达两个数量级,其中启发式基线 Edgebank 始终是最快的(因为它简单地实现为哈希表)。这表明,提高模型效率和可扩展性是未来的重要方向,以便可以在 TGB 中提供的大型数据集上测试新的和现有的模型。
动态节点属性预测
动态节点属性预测的目标是在任何给定的时间戳t预测节点的属性。由于缺乏具有动态节点标签的大型公共 TG 数据集,我们引入了节点亲和性预测任务来研究时间图上的节点级任务。如果您希望贡献具有节点标签的新数据集,请与我们联系。
节点亲和性预测。 该任务考虑节点子集(例如用户)对其他节点(例如项目)的亲和性及其随时间自然变化的方式。这个任务在推荐系统中很相关,在那里,通过建模用户对不同项目的偏好随时间的变化来为用户提供个性化推荐非常重要。在这里,我们使用前 10 项的归一化折扣累积增益(NDCG@10)来比较预测项目与真实值的相对顺序。标签是通过统计用户在未来一段时间内与不同项目的互动频率生成的。
节点亲和性预测任务的实证结果。
结果。 对于这个任务,我们将 TG 模型与两种简单的启发式方法进行比较:持久性预测,即预测当前时间点的最近观察到的节点标签,以及移动平均,即过去几步中的节点标签的平均值。这里的关键观察是,在这个任务中,像持久性预测和移动平均这样的简单启发式方法是 TG 方法的有力竞争者,并且在大多数情况下,它们的表现超过了 TG 方法。这突显了未来需要开发更多针对节点级任务的 TG 方法。
开始使用 TGB
TGB 的机器学习管道
如何使用 TGB?上面展示了 TGB 的 ML 流程。你可以自动下载数据集,并将其处理为 numpy
、PyTorch
和 PyG
兼容的数据格式。用户只需设计自己的 TG 模型,这些模型可以通过 TGB 评估器 进行 标准化评估。 最后,公开的 TGB 排行榜帮助研究人员跟踪时间图领域的最新进展。你可以轻松安装该软件包:
pip install py-tgb
最后,你可以将你的模型性能提交到 TGB 排行榜。我们要求你提供代码链接和描述你方法的论文以确保可重复性。要提交,请填写 google 表单。
结论与未来工作
为了实现对时间图进行现实、可重复和鲁棒的评估,我们推出了时间图基准(Temporal Graph Benchmark),这是一个包含挑战性和多样化数据集的集合。通过 TGB 数据集和评估,我们发现模型性能在不同数据集上差异显著,这显示了在多样的时间图领域进行评估的必要性。此外,在节点亲和度预测任务中,简单的启发式方法优于 TG 方法,从而激发了未来开发更多节点级 TG 模型的动机。
集成到 PyG 中。 Matthias Fey(Kumo.AI),PyG 的核心负责人,在 斯坦福图学习研讨会 2023 上宣布,TGB 将集成到 PyG 的未来版本中。敬请关注!
TGX 库。 我们目前正在开发一个用于时间图的实用工具和可视化 Python 库,名为 TGX。TGX 支持来自 TGB 的 20 个内置时间图数据集以及 Poursafaei et al. 2022。
社区反馈与数据集贡献。 TGB 是一个社区驱动的项目,我们感谢所有通过电子邮件或 Github 问题提供建议的社区成员。如果你有任何建议或希望向 TGB 贡献新的数据集,请通过 email 或 在 Github 上创建问题 与我们联系。我们正在寻找大规模数据集,特别是用于动态节点或图分类任务的数据集。
2023 年的时间图学习
原文:
towardsdatascience.com/temporal-graph-learning-in-2023-d28d1640dbf2?source=collection_archive---------1-----------------------#2023-01-16
目前为止的故事
Shenyang(Andy) Huang
·
关注 发表在 Towards Data Science ·15 分钟阅读·2023 年 1 月 16 日
--
现实世界的网络,如社交网络、交通网络和引用网络,往往会随着时间演变,而时间图学习(TGL)领域旨在从这些不断演变的网络中提取、学习和预测。最近,TGL 在机器学习社区中受到越来越多的关注,相关论文数量激增,去年在 NeurIPS 2022 上举办了该领域的首个研讨会!
时间图中的演变。图片由作者提供。
这篇文章由 Emanuele Rossi, Michael Galkin 和 Kellin Pelrine 共同撰写。 感谢 Farimah Poursafaei 提供的有益反馈。
在这篇博客文章中,我们展示了 TGL 在 2022 年之前的主要进展,并讨论了有前景的未来方向。请注意,我们将“动态图”和“时序图”交替使用。如果你想学习或开始一个 TGL 项目,这篇文章将是一个很好的参考和起点。
请在评论区与我们分享您感兴趣的其他进展。
目录:
-
时序图学习简介
-
时序图网络的表达能力
-
重新思考时序图中的评估
-
时序知识图谱
-
库和数据集
-
利用时序图进行疾病建模
-
时序图中的异常检测
-
检测时序图中的虚假信息
-
加入时序图学习社区
时序图学习简介
在本节中,我们简要介绍了文献中一些著名的 TGL 方法。学习连续时间动态图(CTDGs)的方法主要分为两类:时序图网络和游走聚合方法。有关 CTDGs 的详细信息,请参阅 Kazemi 等 的这篇综述。
时序图网络 (TGNs) 将信息传递神经网络 (MPNNs) 推广到时序图。它们通过引入一个节点记忆来实现,该记忆表示节点在给定时间的状态,作为节点过去交互的压缩表示。每当两个节点参与交互时,它们会相互发送消息,这些消息然后用于更新它们的记忆。在计算节点嵌入时,会对节点的时序邻居进行额外的图聚合,使用该时刻的原始节点特征和记忆。以下是 TGN 计算的示意图。
对一批训练边缘的 TGN 计算。
图片来源:Rossi 等
TGN 是一个通用框架,它将以前的模型,如联合动态用户-项目嵌入 (JODIE) 和时序图注意力 (TGAT),作为特例进行推广。有关 TGN 的更全面介绍,请参阅下面其中一位作者的博客文章。
## 时序图网络
一种用于动态图的新型神经网络架构。
[towardsdatascience.com
诸如 Causal Anonymous Walks (CAW) 这样的 Walk 聚合方法则依赖于(时间)随机游走。特别是,为了预测时间 t 上的一个链接 (u, v) 的存在,CAW 首先提取多个从 u 和 v 开始的随机游走,使得游走中的边的时间戳只能单调递减。这些游走首先通过用节点在游走中每个可能位置出现的次数向量替换每个节点标识符来进行匿名化。然后,使用 RNN 对每个游走进行编码,并通过自注意力或简单平均来聚合编码。
时间图网络的表达力。
关于在静态图上运行的图神经网络(GNNs)表达能力的研究已有大量工作。Xu et al. 2019 首次通过将图神经网络(GNNs)与 Weisfeiler-Lehman (WL) 图同构测试关联起来,并展示了许多 GNNs 的能力不超过 1-WL 测试,从而描述了其区分能力。随后,出现了更具表达能力的模型,如 子图 GNNs,图变换器 和 高阶 GNNs,这些模型被设计得比 1-WL 测试更具表达力(下面是 Michael Bronstein 关于如何超越 WL 测试的精彩博客文章的链接)。
Graph Neural Networks beyond Weisfeiler-Lehman and vanilla Message Passing
受物理启发的图上连续学习模型可以克服传统 GNNs 的局限性。
towardsdatascience.com
直到今年,关于 TGL 方法的表达力的研究仍然很少。第一个弥合这一差距的努力是由 Ribeiro et al. 提出的,其关键思想是将现有的 TGL 方法分为 时间-和-图 和 时间-然后-图 框架。
将 TG 转换为时间-然后-图表示。
图片来源: Ribeiro et al.
1️)。在 时间-和-图 中,GNNs 用于在每个时间快照图上生成节点嵌入,从而形成节点嵌入的序列。
2️)。在 时间-然后-图 中,TG 中的每条边被转换为一个时间序列,该序列指示边存在的时间,从而将时间边折叠为静态图中的边特征。
已证明 时间-然后-图 表示可以从任何给定的 时间-和-图 表示中构建,从而证明 时间-然后-图 至少与 时间-和-图 一样具备表达力。通过在 时间-然后-图 中的静态表示,我们可以直接将静态图的 WL 测试表达框架应用于 TGL 方法。这样,只要使用 1-WL GNN 作为主干模型,时间-然后-图 就比 时间-和-图 更具表达力。
Souza et al. 也旨在为 TGL 方法建立 1-WL 表达框架。值得注意的是,他们将 CTDG 视为一系列时间戳多图,其中在给定时间 t 的多图 G(t) 是通过顺序应用所有早于 t 的事件来获得的。这里的多图意味着两个节点之间可以有多条边,而边的属性是时间戳信息。
现在,时间 WL 测试可以通过对从 CTDG 构建的多图应用 WL 测试来定义。因此,更具表达力的 TGN 方法必须在其时间邻域上是单射的(即将两个不同的多集节点哈希为不同的颜色),称为单射 MP-TGNs。Souza et al. 还分析了基于游走的 TGNs,如 CAW,并显示 MP-TGNs 和 CAW 之间并没有比彼此更具表达力(如上所示)。他们提出的 PINT 方法结合了这两类方法的优点,因此是最具表达力的。下面的示例显示了 MP-TGNs 无法区分的两个时间图。颜色表示节点标签,边的时间戳从 t₁ 开始。
MP-TGNs 无法区分的时间图示例,例如直径、环长和循环数量。
图片来源:Souza et al.
重新思考时间图中的评估
在很大程度上,TGL 中的评估程序相对未被充分探索,并且受到静态图学习的重大影响。例如,对动态图上的链路预测任务(或动态链路预测)的评估通常涉及:1)。固定的训练、测试拆分,2)。随机负边采样 和 3)。来自类似领域的小数据集。这样的评估协议往往导致结果表中报告的指标已经达到 95% 以上,很难区分新模型是否带来了实际的好处,还是只是重新使用现有方法。
典型的时间链路预测结果表,报告了平均精度(AP)。即使基线模型也能达到 98%,我们真的在取得进展吗?
图片来源: Souza 等
You 等 讨论了当前 TGL 方法在离散时间动态图(DTDGs)中的模型设计、评估设置和训练策略的局限性。他们认为数据和模型的演变特性没有被考虑。在标准评估中,所有时间点按时间顺序划分为训练集、评估集和测试集。对于给定的数据集,这种划分是固定的。
他们指出,这种固定的划分意味着只有来自所选测试期的边会被评估,因此可能跨越训练、验证和测试期的长期行为将无法正确评估。此外,许多 TGL 方法在测试时是过时的,意味着模型表示在评估过程中没有得到更新。考虑一个示例交易图,如果前一天的信息可用,用户很可能希望利用这些信息来更新模型,以实现最佳性能。因此,提出了一种实时更新评估的方法,其中模型根据新观察到的数据进行微调,利用历史信息并预测未来的连接。
灰色/红色条分别表示 Wikipedia / MOOC 数据集中的重复/新颖边。时间图中的许多边随时间重复出现。
图片来源: Poursafaei 等
近期工作由两位作者研究了如何选择负边进行 CTDG 方法的评估,并引入了来自不同领域的更多数据集。在动态链接预测中,负边通常是从任意节点对中随机抽取的。然而,时间图中的许多边会随着时间的推移而重复(如上图所示)。考虑到现实世界图的稀疏性,大多数节点对不太可能形成边。因此,随机 负边可以被视为容易的负边。
TGL 方法的平均性能。使用更困难的负边显著影响模型性能。简单的基线 EdgeBank 的表现也出奇地好。
图片来源: Poursafaei 等
现在,什么可以被视为困难的负边?首先,我们介绍历史负边,即在训练集中出现但在当前测试步骤中缺失的边。我们还将归纳负边定义为在测试集中之前出现但在当前步骤中不存在的测试边。最后,我们提出了一个基线 EdgeBank,仅依靠记住过去的边(本质上是已见边的哈希表)。在上面的图中,我们看到,通过改变负边进行评估时,现有 TGL 方法在历史和归纳设置下的平均性能显著降低,与标准设置相比。EdgeBank 在标准设置下也是一个出乎意料的强大基线。有关详细信息,请参见下方作者之一的博客。
[## 迈向更好的动态图链接预测
伴随博客文章,介绍了《迈向更好的动态链接预测》的评估,将在 NeurIPS 2022 数据集和…
medium.com](https://medium.com/@shenyanghuang1996/towards-better-link-prediction-in-dynamic-graphs-cdb8bb1e24e9?source=post_page-----d28d1640dbf2--------------------------------)
时间知识图谱
在知识图谱(KG)的领域中,时间设置与同质世界略有不同,即时间戳图快照并不常见。相反,一些(或所有)三元组具有一个(开始时间,结束时间)对属性,表示某个事实为真的时间范围。因此,三元组变成了五元组,或者在 Wikidata 中,时间属性成为 限定词 的一部分,更一般的 声明(主三元组 + 多个键值限定词),声明形成所谓的 超关系 KGs.
例如,(法国总统,职务持有者,尼古拉·萨科齐,2007,2012)
是一个五元组,描述了尼古拉·萨科齐担任法国总统的时间段。或者,每个三元组也可以只有一个时间戳(形成四元组)。最常见的预测任务是给定时间属性评分头/尾预测,例如,(法国总统,职务持有者,**???**,2007,2012)
—— 这可以被视为超关系链接预测的特例,其中限定词仅为日期时间文字。一个经典的时间 KG 补全模型是 TNTComplex(ICLR 2020)。
Krause et al. 已经迈出了弥合时间知识图谱与同质图之间差距的第一步。在这项工作中,作者提出了一个框架,以形式化知识图谱中的各种时间方面。即,他们将 时间 知识图谱定义为局部扩展,即边上具有时间戳的图,而 动态 知识图谱定义为全局扩展,即随着时间的推移通过添加或删除节点和边而改变拓扑的图。更进一步,这些基本类型的组合是存在的,例如,时间和动态知识图谱的组合被称为 增量。我们希望这项工作能为时间知识图谱的繁杂文献带来更多秩序和清晰度,社区也能遵循这个良好的分类法。下一步:为这些图类型最终确定一个适当的评估协议。
时间和动态知识图谱(及其组合)。
图片来源:Krause et al.
Wang et al. 解决了在时间 + 动态图上进行少样本链接预测的任务,其中边具有时间戳并且新节点可能在后续时间步出现(增量 图,如 Krause et al. 上述分类)。少样本场景使得任务更加具有挑战性——我们只能访问有限数量的训练和推理点(通常小于 5)来推理查询链接。在这里,作者提出了 MetaTKGR,这是一种基于元学习的方法,通过聚合一定 delta t 时间邻域内现有节点的特征来构建新节点的表示。时间戳之间的标量差异通过傅里叶变换进行向量化。
MetaTKGR 的组件。
图片来源:Wang et al.
库和数据集
过去几年中,缺乏大规模数据集和具有挑战性的任务一直阻碍着时间图学习领域的研究。幸运的是,来自不同领域的新数据集正在涌现。例如,Poursafaei et al. 引入了六个新的公开可用的 TG 数据集,涵盖了交通、政治、经济和接近领域。然而,该领域仍然缺乏一致的努力来将基准和评估标准化到高质量,就像 OGB 对静态图所做的那样。我们希望在 2023 年,我们能看到更多关注实际应用的标准化 TG 基准。
关于库,著名的一个是 Pytorch Geometric Temporal,这是 Pytorch Geometric 的时间图扩展。然而,Pytorch-Geometric Temporal 似乎只包含离散时间方法和数据集。一个包含连续时间方法的库将为社区带来很大价值。最近,Zhou et al. 提出了 TGL,这是一个用于大规模离线时间图神经网络训练的统一框架。特别是在一台 4-GPU 机器上,TGL 可以在 1–10 小时内训练超过十亿条时间边的一轮。
我们列出了各种 TGL 库和数据集的链接如下。
-
通过 “pip install dgb” 访问的 13 个处理过的 TG 数据集
-
Pytorch Geometric Temporal
-
TGL library
-
Chartalist Dynamic Blockchain Transaction Network
-
Temporal knowledge graph forecasting benchmark
使用时间图进行疾病建模
在近期的 COVID-19 大流行中,流行病建模对理解疾病传播以及设计相应的干预策略至关重要。人际接触网络实际上是时间图。通过将接触图与经典的基于隔离的模型如 SEIR 和 SIR 相结合,我们可以更准确地预测 COVID-19 感染曲线,并超越同质混合假设(所有个体之间的接触概率相等)。
Chang et al. 从手机数据中推导出了时间移动网络,并将 9800 万人的小时移动从人口普查区块组(CBGs)映射到美国的特定兴趣点(POIs)。通过将小时接触网络与 CBG 层面的 SEIR 模型结合,他们能够准确拟合实际感染轨迹。特别是,模型显示一些‘超级传播者’POIs 如餐馆和健身中心占据了大多数感染。此外,不同种族和社会经济群体之间的流动差异导致这些群体之间的感染率不同。这项工作展示了利用大规模时间图进行疾病预测和制定干预策略的现实潜力。
除了人际接触网络,动态交通网络在 COVID-19 的传播中也扮演着重要角色。在一项研究中,我们将每日航班网络纳入 SEIR 模型,以估计输入的 COVID-19 病例。通过纳入航班网络,可以实现对疫情爆发的早期检测并预测旅行限制的影响。更多细节请见作者的博客文章。
尽管基于时间图的疾病模型在实践中取得了成功,但回答诸如“接触网络结构如何影响疾病传播?”和“如何修改接触模式以减缓或阻止 COVID-19 的传播?”等问题也很重要。Holme 等比较了在八个网络数据集中使用时间、静态和完全连接网络的爆发特征差异,并研究了不同网络结构对疾病传播的影响。他们展示了将时间网络转换为静态网络可能导致对疾病爆发规模和消失时间的严重低估或高估。
TGL 在流行病建模方面的下一步是什么?
首先,预测整个接触或流动网络的快照以应对短期挑战是一个关键问题。通过预测的结构,我们可以应用基于网络的 SEIR 模型来估计感染曲线。
其次,定义和理解互动模式对接触网络的影响对于政策制定和可解释性至关重要。分析图结构与感染曲线之间的相互作用可以帮助我们确定最有效的干预策略。
时间图中的异常检测
异常检测是分析时间图中的一个基本任务,它识别出与其他实体显著偏离的实体。例如,欺诈检测可以被建模为在交易网络中检测异常边缘,而交通事故识别可以被视为在交通网络中检测异常事件。
对于利用时间图网络的表示能力进行异常检测的兴趣日益增长。蔡等人设计了一个端到端结构化时间图神经网络模型,用于检测异常边,称为StrGNN。首先基于感兴趣的边提取一个包围子图,一个以该边为中心的 k-hop 子图,以减少计算复杂性。然后使用图卷积神经网络(GCN)从子图中生成结构嵌入。接着使用门控递归单元(GRUs)来捕捉时间信息。异常检测的挑战之一是缺乏标记样本。因此,蔡等人提出通过替换正常边中的一个节点来生成“上下文相关”的负边,并用这些负边来训练模型。
与无监督的、非 GNN 基础的异常检测方法,如SEDANSPOT和AnomRank相比,GNN 基础的方法可以轻松地结合任何给定的属性,并具有实现更强性能的潜力。然而,GNN 基础的方法面临两个重大挑战。
1). 首先,如何扩展到具有数百万条边和节点的动态图?这是一个开放性问题,既涉及到 GNN 模块在提取图特征时的挑战,也涉及到处理长期信息的时间模块,如 GRUs 和 transformers。
2️). 其次,如何为检测到的异常提供准确的解释?在实际应用中,检测到的异常通常会被验证,然后可能对这些检测到的实体采取惩罚措施。GNN 在动态图上的可解释性仍然是一个未解决的挑战。
LAD 检测到 2013 年是加拿大 MP 投票网络中的一个变化点,原因是政治党派之间的边的数量异常。
图片来源:黄等人
变化点检测任务旨在检测动态图中时间点的变化,其中图结构或分布显著偏离之前观察到的状态。这种变化可能归因于外部事件(如交通中断和 COVID-19 相关的航班限制),或仅仅是动态图的自然演变。作者之一的近期工作利用了每个图快照的拉普拉斯矩阵的特征值来嵌入图结构,同时应用滑动窗口来比较图结构在长短期内的变化。在上述内容中,提出的拉普拉斯异常检测 (LAD) 方法检测到了由于政治党派之间边缘增加而导致的加拿大国会议员(MP)投票网络中的变化。这与贾斯廷·特鲁多在 2013 年被选为自由党领导人的事件相吻合。
在时间图上检测虚假信息
虚假信息的传播模式和速度与真实信息不同 (Vosoughi 等)。已有大量研究在静态图中研究这些网络模式,而动态图方法尚未得到充分探索 (Song 等)。然而,在过去的一年里,使用 TGL 方法进行虚假信息检测和理解的数量有所增加。例如,Zhang 等 开发了一种基于时间点过程的方法,而动态 GCN (DynGCN) 和 DGNF 是基于动态 GNN 的方法。
下图展示了 DynGCN 的架构。他们以均匀时间间隔构建图快照,通过 GCN 层处理每个快照,然后结合这些表示并使用注意力机制学习快照的演变模式。这是一种相对简单的方法,比起上述一些方法如 TGN 或 CAW,它利用时间信息的方式更为简单,但在作者检查的数据集上,比之前的最先进技术在虚假信息检测方面表现更好。
DynGCN 使用具有共享权重的 GCN 层处理单个图快照,然后通过注意力机制结合这些表示以获取时间上的演变。
图片来源: Choi 等
动态交互模式在虚假信息检测中被证明非常有用(Plepi 等)。随着 TGL 方法的显著进展,我们可以期待结合动态图的新型最先进的虚假信息检测方法。
加入时间图学习社区
2022 年,机器学习社区对时间图学习(TGL)的关注有所增加。首届TGL 研讨会于 NeurIPS 2022 上举办。会议的演讲和讨论会录像将很快在NeurIPS 虚拟网站上提供。接受的论文可以在研讨会网站上找到。请关注 TGL 研讨会的新版本公告,并加入研讨会 Slack(网站上有最新链接)以便与社区互动。今年,我们还计划组织一个 TGL 阅读小组,如果你希望分享你的工作或参与组织阅读小组,请发送邮件至 shenyang.huang@mail.mcgill.ca。
图片来源:NeurIPS 2022 时间图学习研讨会的 logo。图片由作者提供。
Python 中的临时变量:可读性与性能
原文:
towardsdatascience.com/temporary-variables-in-python-readability-versus-performance-f6708b5f293c
PYTHON 编程
临时变量可以使代码更清晰。那么这样的代码性能如何呢?
Marcin Kozak
·发布于Towards Data Science ·8 分钟阅读·2023 年 6 月 1 日
--
Python 的快捷方式快吗?图片由Stefan Steinbauer提供,Unsplash
临时变量是生命周期很短的变量:
[## 临时变量 - 维基百科
来自维基百科,免费的百科全书。在计算机编程中,临时变量是生命周期很短的变量……
en.wikipedia.org
临时变量在编程中非常常用,你不需要知道这个术语也可以使用临时变量。一个最常见的用例是使代码更清晰,例如,在管道中:
input → tempvar_1 := func_1(input) →
tempvar_2 := func_2(tempvar_1) →
func_3(tempvar_2) → output
在这里,我使用了 Python 的海象运算符来直观地表示赋值,就像在 Python 代码中使用的一样。在这个管道中,我们有两个临时变量:tempvar_1
和 tempvar_2
。它们在数据流过代码的时间上生命周期很短,尽管在实际时间上可能很长。tempvar_1
仅用于一个目的:将管道第一步的结果传递到下一步。但请注意,从技术上讲,它是没有必要的:
input → func_3(func_2(func_1(input))) → output
虽然这两个版本的功能相同,但后者的可读性可能大大降低。因此,前者在编程中被广泛使用,唯一的原因是使代码更清晰。
注意,如果tempvar_1
或tempvar_2
在代码后续中使用,它们就不再是临时变量,因为它们的生命周期不会很短。为了简单起见,我们可以假设临时变量是你只使用一次的变量,用于将一个可调用的输出作为输入传递给另一个。
你是否曾经思考过在管道中使用临时变量是否比直接——和最短——的计算方式更好?比如,以下两个代码片段哪个更好?
# snippet 1
third_function(second_function(first_function(x)))
# snippet 2
x1 = first_function(x)
x2 = second_function(x1)
x3 = third_function
或者,这次使用简单的算术运算:
# snippet 1
x1 = 2.056 * x
x2 = x1 / (1 + x1)
y = 2.3 / (- x2 - 7.33)
# snippet 2
y = 2.3 / (- 2.056 * x / (1 + 2.056 * x) - 7.33)
你会选择哪种方式?这重要吗?
Python 因各种原因而非常受欢迎,其中之一是其代码的可读性。同时,Python 也因其性能差而闻名——尽管它并不像许多人声称的那样糟糕,正如我在下面的文章中所写的:
[## Python 的速度:并没有那么糟糕!
我一直听到 Python 太慢了。这是真的吗?
medium.com](https://medium.com/pythoniq/the-speed-of-python-it-aint-that-bad-9f703dd2924e?source=post_page-----f6708b5f293c--------------------------------)
很多时候,你可以——也需要——在可读性和性能之间进行选择。有时你可能需要哪怕是最微小的性能提升,即使这意味着可读性降低。其他时候,小幅度的性能提升意味着没有副作用,并且代码既可读又易懂;为何不选择它呢?
当性能的提升带来一些成本时,你应该小心。你应该问自己——或你所在的开发团队——以下问题:这种微小的性能提升是否值得降低代码的可读性?
在这篇文章中,我想向你展示一个通过避免临时变量实现的改进的例子。去除它们可以稍微提高性能,但通常会以降低可读性为代价。是的,通常是这样,所以不总是:如果你运气好,去除临时变量可以帮助你同时提高性能和可读性。这是完美的情况,不是吗?
Python 代码中的临时变量
想象一下你想实现一个计算一系列事物的函数。为了简单起见,我们将进行一些基本的算术计算,使例子变得简单。然而,在现实生活中,这样的管道可能包含多个函数执行各种操作,甚至相当复杂。
def calc_with_tempvar(x):
y = x**2
z = y/2
f = z + 78
g = f/333.333
return g
因此,我们从 x
开始,然后计算 y
、z
、f
,最终得到 g
,g
是最终输出,因此返回。这类似于 函数组合,不同之处在于这里我们不是组合函数,而是组合计算。然而,在许多场景中,你将会有实际的函数;例如,代替 y = x**2
,你可以有 y = some_function(x)
。在 Python 中类似的一个完美示例就是生成器管道:
## 在 Python 中构建生成器管道
本文提出了一种优雅的构建生成器管道的方法。
towardsdatascience.com
及其一般版本,理解管道:
## 在 Python 中构建理解管道
理解管道是 Python 特有的构建管道的概念。
towardsdatascience.com
在简单的情况下,例如我们的 calc_with_tempvar()
函数中,这种方法似乎有些多余。相反,我们可以简单地做如下操作:
def calc_without_tempvar(x):
return ((x**2)/2 + 78)/333.333
这些测试表明,两者产生了完全相同的结果:
>>> for x in (1, 2.3, 0.0000465, 100_000_000.004):
... assert calc_with_tempvar(x) == calc_without_tempvar(x)
没有输出意味着这确实是正确的。
首先,让我们拆解这两个函数,看看它们如何转换为 Python 字节码:
使用 dis.dis() 函数对两个函数进行拆解。图片由作者提供
即使不分析这两个函数的字节码,我们也可以看到,使用临时变量的函数比不使用临时变量的函数在 Python 中需要做更复杂的工作。这不奇怪,对吧?函数的定义方式本身表明 calc_with_tempvar()
要达到结果需要做更多工作,而不是 calc_without_tempvar()
。
临时变量:性能
然而,这如何转化为性能呢?为了了解这一点,让我们使用 [perftester](https://github.com/nyggus/perftester)
Python 包,它专门用于基准测试和测试 Python 函数的时间和内存性能:
## 轻松基准测试 Python 函数:perftester
你可以使用 perftester 轻松地基准测试 Python 函数。
towardsdatascience.com
对于基准测试,我在 Windows 10 机器上的 WSL 1 中使用了 Python 3.11,配备 32GB 的 RAM 和四个物理(八个逻辑)核心。然而,在我们的案例中,原始时间并不那么重要;我们将重点关注相对比较。
首先,让我更改基准测试的默认设置。我将使用 2000 万次函数调用重复 7 次;从中选择最快的一次作为基准结果。
>>> import perftester
>>> perftester.config.set_defaults(
... "time",
... Number=20_000_000,
... Repeat=7,
... )
现在实际的基准测试,对于一个float
数值:
>>> x = 1.67
>>> t1 = perftester.time_benchmark(calc_with_tempvar, x)
>>> t2 = perftester.time_benchmark(calc_without_tempvar, x)
然后让我们看看结果¹:
>>> perftester.pp({
... "1\. composition": t1["min"],
... "2\. one-shot": t2["min"],
... "3\. composition--to--one-shot ratio": t1["min"] / t2["min"]
... })
{'1\. composition': 2.063e-07,
'2\. one-shot': 1.954e-07,
'3\. composition--to--one-shot ratio': 1.056}
如预期的那样,一次性版本(不使用临时变量)更快——大约快 5%。一方面,这并不多。另一方面,这仅仅是通过如此微小的改变——如此小的变化就达到了 5%!
上述计算是快速的。然而,对于较长的计算,差异可能接近于不可见。
你注意到我们可以稍微改进一下calc_with_tempvar()
函数吗?我们需要最后一个对象g
吗?有时候,像这样的对象通过一个好的名字可以提高函数的可读性,但在这种情况下并不需要——所以我们不需要g
。让我们看看去掉它是否会提高性能:
def calc_with_tempvar_shorter(x):
y = x**2
z = y/2
f = z + 78
return f/333.333
>>> t3 = perftester.time_benchmark(calc_without_tempvar_shorter, x)
>>> t3["min"]
1.998e-07
一个微小的改进,因为组合版本比这个版本慢1.032
倍,而这个版本比一次性版本慢1.023
倍。但再次强调,这种改进是通过如此微小的变化实现的!既然如此,这个微小的改变值得使用吗?
结论
对我来说——绝对值得,但并非总是。
关键是,当性能不重要时,优先考虑可读性。如果程序运行多一分钟、10 秒钟或甚至半秒钟都不会改变任何事情——干脆不要考虑通过这些技巧来提高性能。为什么要这样做?为什么要为了微小的改进而降低可读性呢?只需关注可读性。
当然,有时去掉临时变量会提高函数的可读性。在这种情况下,为什么我们还要讨论这个?再次强调,优先考虑可读性,当这也意味着提高性能时——那就完美了。
有时性能确实很重要。即使是秒的分割也可能产生差异。如果是这种情况,你应该分析你的代码并找出瓶颈。其他时候,你可能希望优化代码的每一个部分。一个例子是为他人使用的框架进行工作,对于其中一些人,性能将很重要。在这种情况下,作为框架作者的你有责任提供尽可能快的工具。否则,你有可能会冒着一些用户不会使用你的框架的风险。
总结:
-
如果性能很重要,避免使用像
calc_with_tempvar()
中的临时变量。如果性能重要性较低(如果有的话),则优先考虑可读性——这意味着是否使用临时变量的决定应该完全基于代码的可读性。 -
临时变量并不总是增加可读性。例如,假设你有一个数学函数
y(x) = ((x**2)/2 + 78)/333.333
。你认为calc_with_tempvar()
,那个包含所有临时变量的函数,会提高可读性吗?我不这么认为。
因此,有时临时变量会提高代码的可读性,有时则不会。如果性能至关重要,请记住,临时变量可能会增加一些轻微的开销。更多时候,这些开销是微不足道的——但在一些项目中,即使是那些秒数的分割也可能很重要。
总之,始终双重检查是否值得在你的代码中去除临时变量——或者是否值得使用它们。
注释
¹ 代码使用了 perftester.pp()
函数,该函数以标准库函数 pprint.pprint()
的方式美观地打印 Python 对象,并将其中的所有数字四舍五入到四位有效数字。它使用了rounder包:
## GitHub - nyggus/rounder: 用于在复杂 Python 对象中对浮点数和复杂数字进行四舍五入的 Python 包
rounder
是一个轻量级的包,用于在复杂的 Python 对象(如字典、列表、元组等)中对数字进行四舍五入。
github.com
感谢阅读。如果你喜欢这篇文章,你可能也会喜欢我写的其他文章;你可以在这里查看。如果你想加入 Medium,请使用下面的推荐链接:
## 使用我的推荐链接加入 Medium - Marcin Kozak
阅读 Marcin Kozak 的每一个故事(以及 Medium 上成千上万其他作家的故事)。你的会员费直接支持…
medium.com
AI 十年回顾
原文:
towardsdatascience.com/ten-years-of-ai-in-review-85decdb2a540
从图像分类到聊天机器人治疗
Thomas A Dorfer
·发表于 Towards Data Science ·阅读时间 15 分钟·2023 年 5 月 23 日
--
图片由作者提供。
过去十年对于人工智能(AI)领域来说是一段激动人心且充满变故的旅程。对深度学习潜力的初步探索转变为一个爆炸性的增长领域,现在涵盖了从电子商务中的推荐系统到自动驾驶汽车的目标检测,以及能够生成从逼真图像到连贯文本的一切生成模型。
在这篇文章中,我们将回顾一下让我们走到今天的一些关键突破。不论你是经验丰富的 AI 从业者还是对该领域最新进展感兴趣的读者,这篇文章都将为你提供一个全面的概述,展示 AI 成为家喻户晓名字的非凡进展。
2013 年:AlexNet 和变分自编码器
2013 年被广泛认为是深度学习的“成熟期”,这主要得益于计算机视觉领域的重大进展。根据最近对 Geoffrey Hinton 的采访,到 2013 年“几乎所有的计算机视觉研究都转向了神经网络”。这股热潮主要由一年前图像识别领域的一个相当意外的突破所推动。
在 2012 年 9 月,AlexNet,一个深度卷积神经网络(CNN),在 ImageNet 大规模视觉识别挑战赛(ILSVRC)中表现出色,展示了深度学习在图像识别任务中的潜力。它达到了 15.3%的前五名错误率,比其最接近的竞争对手低了 10.9%。
图片由作者提供。
这些成功背后的技术进步对人工智能未来的发展轨迹至关重要,并且极大地改变了人们对深度学习的认知。
首先,作者应用了一个由五个卷积层和三个全连接线性层组成的深度卷积神经网络(CNN)——这一架构设计当时被许多人认为是不切实际的。此外,由于网络的深度产生了大量的参数,训练是在两个图形处理单元(GPU)上并行进行的,展示了在大规模数据集上显著加速训练的能力。通过将传统的激活函数,如 sigmoid 和 tanh,替换为更高效的修正线性单元(ReLU),训练时间得到了进一步缩短。
图片由作者提供。
这些共同促成 AlexNet 成功的进展标志着人工智能历史上的一个转折点,并激发了学术界和技术社区对深度学习的广泛兴趣。因此,2013 年被许多人认为是深度学习真正开始蓬勃发展的拐点。
虽然在 2013 年也发生了相关进展,但被 AlexNet 的声势所掩盖,即变分自编码器(VAEs)的发展——这种生成模型能够学习表示和生成图像、声音等数据。它们通过学习输入数据在一个低维空间中的压缩表示,称为潜在空间,从而生成新的数据。这使得它们能够通过从学习到的潜在空间中采样来生成新数据。变分自编码器后来被证明为生成建模和数据生成开辟了新途径,并在艺术、设计和游戏等领域中找到了应用。
2014 年:生成对抗网络
在接下来的一年,即 2014 年 6 月,深度学习领域见证了另一项重大进展,那就是 Ian Goodfellow 和同事们引入了生成对抗网络(GANs)。
生成对抗网络是一种神经网络,能够生成与训练集相似的新数据样本。基本上,两个网络被同时训练:(1)生成器网络生成虚假的或合成的样本,(2)鉴别器网络评估这些样本的真实性。这种训练在一个类似游戏的设置中进行,生成器试图创建能够欺骗鉴别器的样本,而鉴别器则试图正确识别出虚假的样本。
当时,生成对抗网络(GANs)代表了一种强大而新颖的数据生成工具,不仅用于生成图像和视频,还用于音乐和艺术创作。它们还推动了无监督学习的发展,这一领域在很大程度上被认为是欠发展的且具有挑战性的,通过展示在没有明确标签的情况下生成高质量数据样本的可能性。
2015 年:ResNets 和 NLP 突破
在 2015 年,人工智能领域在计算机视觉和自然语言处理(NLP)方面取得了显著进展。
Kaiming He 及其同事发表了一篇名为“图像识别的深度残差学习”的论文,在论文中他们介绍了残差神经网络(ResNets)的概念——这种架构通过添加快捷连接使信息在网络中更容易流动。与普通神经网络不同,在普通神经网络中,每一层都将前一层的输出作为输入,而在 ResNet 中,增加了额外的残差连接,这些连接跳过一层或多层,并直接连接到网络中的更深层。
因此,ResNets 能够解决梯度消失问题,这使得训练比当时认为可能的更深层次的神经网络成为可能。这反过来又在图像分类和物体识别任务中带来了显著的改进。
与此同时,研究人员在递归神经网络(RNNs)和长短期记忆(LSTM)模型的开发方面也取得了显著进展。尽管这些模型自 1990 年代以来就存在,但它们直到 2015 年左右才开始引起关注,这主要是由于以下几个因素:(1)训练用的数据集变得更大、更具多样性,(2)计算能力和硬件的提升,使得训练更深层次、更复杂的模型成为可能,(3)在此过程中进行的修改,如更复杂的门控机制。
因此,这些架构使语言模型能够更好地理解文本的上下文和含义,从而在语言翻译、文本生成和情感分析等任务中取得了巨大的进步。RNNs 和 LSTMs 在那个时期的成功为我们今天看到的大型语言模型(LLMs)的发展铺平了道路。
2016 年:AlphaGo
在 1997 年加里·卡斯帕罗夫被 IBM 的 Deep Blue 击败之后,另一场人类对机器的对决在 2016 年引起了游戏界的震动:谷歌的 AlphaGo 战胜了围棋世界冠军李世石。
图片由Elena Popova提供,来源于Unsplash。
Sedol 的失败标志着人工智能进步轨迹中的另一个重要里程碑:它证明了机器能够超越即使是最熟练的人类玩家,在一个曾经被认为过于复杂的游戏中。利用深度强化学习和蒙特卡罗树搜索的结合,AlphaGo 分析了来自先前游戏的数百万种局面,并评估了最佳可能的走法——这种策略在此背景下远超人类决策。
2017: Transformer 架构与语言模型
可以说,2017 年是奠定我们今天所见生成式 AI 突破基础的关键一年。
2017 年 12 月,Vaswani 及其同事发布了基础性的论文“Attention is all you need”,介绍了利用自注意力概念处理序列输入数据的 transformer 架构。这使得对长距离依赖的处理变得更加高效,而这以前一直是传统 RNN 架构的挑战。
照片由Jeffery Ho拍摄,发布在Unsplash。
Transformer 由两个基本组件组成:编码器和解码器。编码器负责对输入数据进行编码,例如,这可以是一系列单词。它然后对输入序列应用多层自注意力和前馈神经网络,以捕捉句子中的关系和特征,并学习有意义的表示。
本质上,自注意力使模型能够理解句子中不同单词之间的关系。与传统模型按固定顺序处理单词不同,transformers 实际上一次性检查所有单词。它们根据单词在句子中的相关性为每个单词分配一种叫做注意力的分数。
解码器则从编码器处获取编码后的表示,并生成一个输出序列。在诸如机器翻译或文本生成等任务中,解码器根据从编码器接收到的输入生成翻译序列。与编码器类似,解码器也由多层自注意力和前馈神经网络组成。然而,它包含一个额外的注意力机制,使其能够关注编码器的输出。这使得解码器在生成输出时能够考虑输入序列中的相关信息。
自那时以来,transformer 架构已成为 LLM 发展的关键组成部分,并在 NLP 领域(如机器翻译、语言建模和问答)带来了显著的改进。
2018 年:GPT-1、BERT 和图神经网络
在 Vaswani 等人发表了他们的基础论文几个月后,Generative Pretrained Transformer,或 GPT-1 于 2018 年 6 月由 OpenAI 推出,利用 transformer 架构有效捕捉文本中的长距离依赖关系。GPT-1 是最早展示无监督预训练效果并在特定 NLP 任务上进行微调的模型之一。
谷歌也利用了当时还相对新颖的 transformer 架构,2018 年末发布并开源了他们自己的预训练方法,称为Bidirectional Encoder Representations from Transformers,或 BERT。与以单向方式处理文本的之前模型(包括 GPT-1)不同,BERT 同时考虑每个单词的前后上下文。为了说明这一点,作者提供了一个非常直观的例子:
在句子 “I accessed the bank account” 中,一个单向上下文模型会基于 “I accessed the” 来表示 “bank” ,而不是 “account” 。然而,BERT 通过使用前后上下文来表示 “bank” ——即 “I accessed the … account” ——从深度神经网络的最底层开始,使其具有深度的双向特性。
双向性的概念如此强大,以至于 BERT 在各种基准任务上超越了最先进的 NLP 系统。
除了 GPT-1 和 BERT,图神经网络或 GNNs 在那一年也引起了一些关注。它们属于专门设计用于处理图数据的神经网络类别。GNNs 利用消息传递算法在图的节点和边之间传播信息。这使得网络能够以更加直观的方式学习数据的结构和关系。
这一工作使得从数据中提取更深层次的见解成为可能,从而拓宽了深度学习应用的问题范围。借助 GNNs,社交网络分析、推荐系统和药物发现等领域取得了重大进展。
2019 年:GPT-2 和改进的生成模型
2019 年标志着生成模型的几项显著进展,特别是 GPT-2 的推出。这个模型通过在许多自然语言处理任务中实现最先进的性能而让同行相形见绌,而且能够生成高度逼真的文本,事后来看,这为我们展示了该领域即将到来的新进展。
该领域的其他改进包括 DeepMind 的 BigGAN,它生成了几乎无法与真实图像区分的高质量图像,以及 NVIDIA 的 StyleGAN,它允许对生成的图像外观进行更好的控制。
总体而言,这些现在被称为生成式 AI 的进展将这一领域的边界推向了更远的地方,…
2020 年:GPT-3 和自监督学习
… 不久之后,另一个模型诞生了,它甚至在技术圈外也成为了家喻户晓的名字:GPT-3。这个模型在大型语言模型的规模和能力上迈出了重大的一步。为了让事情有个背景,GPT-1 只有区区 1.17 亿个参数。这个数字在 GPT-2 中增加到了 15 亿,而 GPT-3 则有 1750 亿个参数。
这个巨大的参数空间使 GPT-3 能够在各种提示和任务中生成非常连贯的文本。它在文本完成、问题回答甚至创意写作等各种自然语言处理任务中也表现出了令人印象深刻的性能。
此外,GPT-3 再次突出了使用自监督学习的潜力,这使得模型可以在大量未标记的数据上进行训练。这种方法的优点在于,这些模型可以在不需要大量任务特定训练的情况下获得广泛的语言理解,这使得它更加经济。
Yann LeCun 在推特上提到了一篇关于自监督学习的《纽约时报》文章。
2021 年:AlphaFold 2、DALL·E 和 GitHub Copilot
从蛋白质折叠到图像生成以及自动化编程辅助,2021 年因 AlphaFold 2、DALL·E 和 GitHub Copilot 的发布而变得格外引人注目。
AlphaFold 2 被誉为对长期存在的蛋白质折叠问题的期待已久的解决方案。DeepMind 的研究人员扩展了 Transformer 架构,创建了 evoformer blocks——一种利用进化策略进行模型优化的架构——来构建一个能够基于蛋白质的 1D 氨基酸序列预测其 3D 结构的模型。这一突破在药物发现、生物工程以及我们对生物系统的理解等领域具有巨大的革命性潜力。
OpenAI 今年再次成为新闻焦点,发布了DALL·E。本质上,这个模型结合了 GPT 风格的语言模型和图像生成的概念,使得从文本描述中生成高质量图像成为可能。
为了说明这个模型的强大程度,请参见下面的图像,它是通过提示“未来世界的油画,飞行汽车”生成的。
由 DALL·E 制作的图像。
最后,GitHub 发布了后来成为每位开发者最佳伙伴的Copilot。这是与 OpenAI 合作实现的,OpenAI 提供了基础语言模型 Codex,该模型在大量公开可用的代码库上进行训练,从而学习理解和生成各种编程语言的代码。开发者只需提供一个描述他们尝试解决问题的代码注释,Copilot 就会建议实现解决方案的代码。其他功能包括能够用自然语言描述输入代码以及在编程语言之间翻译代码。
2022 年:ChatGPT 和 Stable Diffusion
在过去十年中,人工智能的快速发展 culminated in a groundbreaking advancement: OpenAI 的ChatGPT,这是一个于 2022 年 11 月发布的聊天机器人。这个工具代表了自然语言处理领域的尖端成就,能够对各种查询和提示生成连贯且具有上下文相关的回应。此外,它还可以进行对话、提供解释、提出创意建议、协助解决问题、编写和解释代码,甚至模拟不同的人物性格或写作风格。
作者提供的图像。
与聊天机器人交互的简单直观界面也刺激了其使用率的急剧上升。之前,主要是科技圈会玩弄最新的基于 AI 的发明。然而,如今,AI 工具几乎渗透到了每个专业领域,从软件工程师到作家、音乐家和广告商。许多公司也在利用这一模型自动化服务,如客户支持、语言翻译或回答常见问题。事实上,我们所见的自动化浪潮重新激发了一些担忧,并引发了有关哪些工作可能面临自动化风险的讨论。
尽管 ChatGPT 在 2022 年占据了大部分的焦点,但在图像生成方面也取得了显著的进展。Stable Diffusion,一个能够从文本描述生成照片级真实图像的潜在文本到图像扩散模型,由 Stability AI 发布。
稳定扩散是传统扩散模型的扩展,其工作原理是通过迭代地向图像中添加噪声,然后逆转这一过程以恢复数据。它的设计目的是通过不直接处理输入图像,而是处理其较低维度的表示或潜在空间,从而加快这一过程。此外,扩散过程通过将用户的变换器嵌入文本提示添加到网络中来进行修改,使其能够在每次迭代中引导图像生成过程。
总的来说,2022 年 ChatGPT 和 Stable Diffusion 的发布突显了多模态生成 AI 的潜力,并引发了该领域进一步发展的巨大推动和投资。
2023:LLMs 和 Bots
今年无疑成为了 LLMs 和聊天机器人的一年。越来越多的模型以快速增长的速度被开发和发布。
图片由作者提供。
例如,2 月 24 日,Meta AI 发布了 LLaMA —— 一个在大多数基准测试中超过 GPT-3 的 LLM,尽管其参数数量要少得多。不到一个月后,3 月 14 日,OpenAI 发布了 GPT-4 —— 一个比 GPT-3 更大、更强大且多模态的版本。尽管 GPT-4 的确切参数数量未知,但据推测已达到万亿级别。
3 月 15 日,斯坦福大学的研究人员发布了 Alpaca,这是一个从 LLaMA 上经过指令跟随演示微调的轻量级语言模型。几天后,3 月 21 日,谷歌推出了其 ChatGPT 竞争对手:Bard。谷歌还于本月 5 月 10 日发布了其最新的 LLM,PaLM-2。鉴于这一领域的发展步伐如此迅猛,极有可能在你阅读此文时,又有一个新模型出现了。
我们还看到越来越多的公司将这些模型融入到他们的产品中。例如,Duolingo 宣布了其 GPT-4 驱动的 Duolingo Max,这是一个新的订阅层级,旨在为每个人提供量身定制的语言课程。Slack 也推出了一个名为 Slack GPT 的 AI 助手,可以进行诸如草拟回复或总结对话线程等任务。此外,Shopify 将 ChatGPT 驱动的助手引入了公司的 Shop 应用,该助手可以帮助客户使用各种提示识别所需的产品。
Shopify 在 Twitter 上宣布了其 ChatGPT 驱动的助手。
有趣的是,如今 AI 聊天机器人甚至被视为人类治疗师的替代品。例如,Replika,一个美国聊天机器人应用程序,为用户提供了一个“关心的 AI 伴侣,总是倾听和交谈,总是在你身边”。其创始人尤金尼亚·库伊达表示,该应用拥有广泛的用户群体,从自闭症儿童,他们使用它作为“在人际互动前热身”的方法,到孤独的成年人,他们仅仅需要一个朋友。
在我们总结之前,我想强调一下可能是过去十年人工智能发展高潮的事件:人们实际上在使用 Bing!今年早些时候,微软推出了其基于 GPT-4 的“网络助手”,该助手为搜索进行了定制,并且在……永远以来(?)首次成为谷歌在搜索领域长期霸主的真正竞争者。
回顾过去,展望未来
当我们回顾过去十年的人工智能发展时,很明显我们见证了一场对我们工作、商业和人际互动方式产生深远影响的变革。最近在生成模型,尤其是大型语言模型(LLMs)方面取得的显著进展似乎都遵循了“越大越好”的普遍信念,这指的是模型的参数空间。这在 GPT 系列中尤为明显,GPT-1 最初有 1.17 亿个参数,而每个后续模型的参数量大约增加一个数量级,最终发展到 GPT-4,可能拥有数万亿个参数。
然而,根据最近的采访,OpenAI 首席执行官山姆·奥特曼认为我们已经迎来了“越大越好”时代的终结。他认为,未来参数数量仍会增长,但未来模型改进的主要关注点将是提高模型的能力、实用性和安全性。
后者尤为重要。考虑到这些强大的人工智能工具现在已经掌握在公众手中,而不再局限于受控的研究实验室环境,现在比以往任何时候都更加关键,我们要谨慎行事,确保这些工具的安全,并符合人类的最佳利益。希望我们在人工智能安全领域能看到与其他领域同样的开发和投资。
附言: 如果我遗漏了你认为应该包含在这篇文章中的核心 AI 概念或突破,请在下方评论告诉我!
喜欢这篇文章吗?
让我们保持联系吧!你可以在Twitter、LinkedIn和Substack找到我。
如果你想支持我的写作,可以通过Medium 会员来实现,这样你可以访问我所有的故事以及 Medium 上成千上万其他作家的故事。
[## 通过我的推荐链接加入 Medium - Thomas A Dorfer
阅读 Thomas A Dorfer 的每一个故事(以及 Medium 上成千上万其他作家的故事)。你的会员费直接支持…
medium.com
张量量化:未被讲述的故事
原文:
towardsdatascience.com/tensor-quantization-the-untold-story-d798c30e7646?source=collection_archive---------2-----------------------#2023-09-08
详细了解机器学习框架中量化的实现细节
Dhruv Matani
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 9 月 8 日
--
与 Naresh Singh 合著。
目录
-
介绍
-
量化中的尺度和零点这两个术语是什么意思?
-
量化方案的类型
-
量化尺度和零点示例
-
量化与激活归一化
-
结论
-
参考文献
介绍
在本文的剩余部分,我们将尝试通过具体示例回答以下问题。
-
量化中的尺度和零点这两个术语是什么意思?
-
不同的量化方案有哪些类型?
-
如何计算不同量化方案的尺度和零点
-
为什么零点对于量化很重要?
-
归一化技术如何有利于量化
量化中的尺度和零点是什么意思?
尺度: 在量化浮点范围时,通常会将浮点范围[Fmin..Fmax]表示为量化范围[Qmin..Qmax]。在这种情况下,尺度是浮点范围和量化范围的比率。
我们稍后将看到如何计算它的示例。
零点: 量化的零点是量化范围中浮点 0.0 的表示。具体来说,零点是一个量化值,它代表浮点值 0.0,满足所有实际用途。稍后我们将通过示例看到它是如何计算的,以及这种表示对我们为何具有实际意义。
接下来,让我们查看实际中使用的主要量化方案,并熟悉它们的相似之处和不同之处。
量化方案类型
在考虑用于模型压缩的量化类型时,有两种主要类型可以选择。
-
对称量化: 在这种情况下,零点是零 —— 即浮点范围的 0.0 与量化范围中的 0 相同。通常,这种方法在运行时计算更高效,但如果浮点范围在浮点 0.0 周围不均匀分布,可能会导致较低的准确度。
-
仿射(或非对称)量化: 这是具有非零值零点的量化方法。
但在深入细节之前,让我们尝试定义一下零点是什么意思。
量化尺度和零点示例
让我们从一个非常简单的例子开始,并逐步构建起来。
示例-1:对称 uint8 量化
假设我们希望将浮点范围[0.0 .. 1000.0]映射到量化范围[0 .. 255]。范围[0 .. 255]是一组可以适合无符号 8 位整数的值。
为了进行这种转换,我们希望重新调整浮点范围,使得以下条件成立:
浮点 0.0 = 量化 0
浮点 1000.0 = 量化 255
这被称为对称量化,因为浮点 0.0 被量化为 0。
因此,我们定义一个尺度,其等于
其中,
在这种情况下,尺度= 3.9215
要将浮点值转换为量化值,我们可以简单地将浮点值除以尺度。例如,浮点值 500.0 对应于量化值
在这个简单的例子中,浮点范围的 0.0 恰好映射到量化范围中的 0。这被称为对称量化。让我们看看当情况不是这样时会发生什么。
示例-2:仿射 uint8 量化
假设我们希望将浮点范围[-20.0 .. 1000.0]映射到量化范围[0 .. 255]。
在这种情况下,由于我们的xmin不同,所以我们有一个不同的缩放因子。
如果我们将缩放因子应用于 0.0,看看浮点数 0.0 在量化范围中是如何表示的
好吧,这似乎不太对,因为根据上面的图示,我们本来预期浮点值-20.0 映射到量化值 0。
这就是零点概念的所在。零点作为偏差,用于移动缩放后的浮点值,并对应于表示浮点值 0.0 的量化范围中的值。 在我们的案例中,零点是-20.0 的缩放浮点表示的负值,即-(-5) = 5。零点总是最小浮点值表示的负值,因为最小值总是负数或零。我们将在解释示例 4 的部分进一步了解为什么会这样。
每当我们量化一个值时,我们总是会将零点加到这个缩放值上,以获得有效量化范围中的实际量化值。如果我们希望量化值-20.0,我们将其计算为-20.0 的缩放值加上零点,即-5 + 5 = 0。因此,量化(-20.0, scale=4, zp=5) = 0。
示例-3:仿射 int8 量化
如果我们的量化范围是带符号的 8 位整数而不是无符号的 8 位整数会发生什么?好吧,范围现在是[-128 .. 127]。
在这种情况下,浮点范围中的-20.0 映射到量化范围中的-128,而浮点范围中的 1000.0 映射到量化范围中的 127。
我们计算零点的方法是将量化范围视为[0 .. 255],然后偏移-128,因此新范围中的零点是
因此,新范围的零点是-123。
到目前为止,我们查看了浮点范围包括值 0.0 的示例。在下一组示例中,我们将查看当浮点范围不包括值 0.0 时会发生什么。
0.0 的重要性
为什么在浮点范围中表示浮点值 0.0 很重要?
当使用填充卷积时,我们期望边缘像素在最常见的情况下使用值 0.0 进行填充。因此,0.0 在浮点范围内的表示是很重要的。同样,如果值 X 将在网络中用于填充,你需要确保值 X 在浮点范围内被表示,并且量化是意识到这一点的。
示例-4:未讲述的故事 — 歪斜的浮点范围
现在,让我们来看一下如果 0.0 不在浮点范围内会发生什么。
在这个示例中,我们试图将浮点范围[40.0 .. 1000.0]量化到量化范围[0 .. 255]。
由于我们无法在浮点范围内表示值 0.0,我们需要将范围的下限扩展到 0.0。
我们可以看到量化范围的某些部分被浪费了。为了确定浪费了多少,让我们计算浮点值 40.0 映射到的量化值。
因此,我们在量化范围[0 .. 9]中浪费了约 3.92%的范围。这可能会显著影响量化后模型的准确性。
如果我们希望确保浮点范围内的值 0.0 可以在量化范围内表示,那么这种歪斜是必要的。
将值 0.0 包含在浮点范围中的另一个原因是有效地比较量化值以检查其是否在浮点范围内的 0.0 非常有价值。想想像 ReLU 这样的运算符,它将浮点范围内所有小于 0.0 的值剪裁为 0.0。
对我们来说,使用与量化值相同的数据类型(有符号或无符号的 int8) 表示零点是很重要的。这使得我们可以快速而有效地进行这些比较。
接下来,让我们看看激活归一化如何帮助模型量化。我们将特别关注激活值的标准化如何使我们有效地使用整个量化范围。
量化与激活归一化
批量/层归一化将激活张量的均值调整为零,方差调整为单位,无论是按通道还是按层。
假设我们有一个浮点范围为[2000.0 .. 4000.0]的输入张量。这就是量化范围的样子。
我们观察到量化范围[-127 .. -1]的一半没有使用。这是一个问题,因为我们只使用了可用的 8 位中的 7 位来量化整个浮点范围。这无疑会导致更高的量化误差和降低模型的准确性。为了解决这个问题,让我们对激活张量应用层归一化。
在对激活张量应用层归一化后,激活张量将具有[-2.0 .. 2.0]的浮点范围。这可以表示为带符号的 int8 范围[-128 .. 127]。为了确保分布的对称性,我们将量化范围限制为[-127 .. 127]。
因此,归一化可以避免量化范围中的空洞或未使用的部分。
结论
我们了解了什么是仿射(非对称)量化和对称量化,它们的区别是什么。我们还学习了什么是尺度和零点,以及如何为这两种量化方案计算它们。
接下来,我们看到需要在浮点范围中包括浮点 0.0,并了解为什么以及如何在实践中做到这一点。这导致了一个缺点,即在量化范围内浪费了空间。
最后,我们看到了归一化如何通过将激活值带入固定范围并避免在量化范围内浪费空间来帮助量化。事实上,基于 0 均值的归一化可以帮助将仿射量化转换为对称量化,这可以在推理过程中加快速度。
本文中的所有图片均由作者创建。
参考文献
-
高效深度学习书籍,第二章:压缩技术简介
-
Hugging Face:量化
-
TensorRT:量化
-
神经网络蒸馏:量化
-
雷毛:量化
-
量化浮点数
TensorFlow Decision Forests:全面介绍
原文:
towardsdatascience.com/tensorflow-decision-forests-a-comprehensive-introduction-3b6056a6d6b0
使用 TensorFlow 训练、调整、评估、解释和服务基于树的模型
Antons Tocilins-Ruberts
·发表于Towards Data Science ·阅读时长 11 分钟·2023 年 4 月 14 日
--
图片由Javier Allegue Barros提供,来源于Unsplash
介绍
两年前,TensorFlow (TF) 团队开源了一个用于训练基于树的模型的库,TensorFlow Decision Forests (TFDF)。就在上个月,他们终于宣布该软件包已准备好投入生产,因此我决定更深入地了解一下。本文的目的是让你对这个软件包有更好的了解,并展示如何(有效地)使用它。下面你可以看到本文的结构,随意跳过你最感兴趣的部分。
-
什么是 TFDF 以及为什么使用它?
-
使用 TFDF 训练随机森林(RF)和梯度提升树(GBT)模型
-
使用 TFDF 和 Optuna 进行超参数调整
-
模型检查
-
使用 TF Serving 服务 GBT 模型
设置
你可以在我的仓库中找到所有代码,如果还没有,请务必给它加个星。在这篇文章中,我们将使用美国小企业管理局数据集(CC BY-SA 4.0 许可)来训练几个贷款违约预测模型。模型将使用已经预处理的数据进行训练,但你可以在仓库中找到一个笔记本,描述了处理和特征工程步骤。如果你想直接复制我的代码,请确保遵循这些步骤。或者,使用这些代码作为起点,并根据你的数据集进行调整(这是我推荐的方法)。
安装 TensorFlow Decision Forests 非常简单,只需运行pip install tensorflow_decision_forests
,通常这就可以了。虽然有一些在 M1 和 M2 Mac 上报告的问题,但我个人在最新版本的 TFDF 上没有遇到问题。
TensorFlow Decision Forest
什么是 TFDF?
TensorFlow Decision Forest 实际上是基于 Google 开发的 C++库 Yggdrasil Decision Forests。原始 C++算法旨在构建可扩展的决策树模型,以处理大数据集和高维特征空间。通过将这个库整合到更广泛的 TF 生态系统中,用户现在可以轻松地构建可扩展的 RF 和 GBT 模型,而无需学习另一种语言。
为什么使用它?
这个库相对于例如 XGBoost 或 LightGBM 的主要优势在于它与其他 TF 生态系统组件的紧密集成。对于已经在管道中使用其他 TensorFlow 模型或使用 TFX 的团队,它可能特别有趣。TFDF 可以很容易地与例如 NLP 模型集成,使多模态管道更加简单。此外,如果你使用 TF Serving 来服务模型,你也可能会考虑这个库,因为它原生支持(无需 ONNX 或其他跨包序列化方法)。最后,这个库提供了大量参数,你可以调整以接近 XGBoost、LightGBM 和许多其他梯度提升机(GBM)方法的模型。这意味着你在训练过程中无需在不同的 GBM 库之间切换,这对于代码维护来说是非常有利的。
模型训练
确保拉取这个笔记本,并按照下面的步骤操作,因为你在这里只能看到部分代码。
数据
正如设置部分所述,我将使用这个数据集的预处理版本。为了准备 TFDF,我们首先需要像往常一样用 pandas 读取数据,并决定哪些列将作为分类变量,哪些作为数值变量。
特征使用情况
为了确保项目结构良好并避免意外行为,通常建议为每个特征指定一个FeatureUsage
,虽然这不是强制性的。幸运的是,这是一项简单的任务:你只需决定将每个特征分配给六种支持类型中的一种——BOOLEAN
、CATEGORICAL
、CATEGORICAL_SET
、DISCRETIZED_NUMERICAL
、HASH
和NUMERICAL
。其中一些类型还有额外的参数,因此确保了解更多信息 这里。
在这个示例中,我们将保持简单,只使用数值型和类别型数据类型,但不要犹豫尝试其他选项,特别是DISCRETIZED_NUMERICAL
,因为它们可以显著加快训练过程(类似于 LightGBM)。正如下方所示,你需要将所选的数据类型提供给semantic
参数,对于类别特征,我们还需要指定min_vocab_frequency
参数以去除稀有值。
使用 TF 数据集读取数据
读取数据集的最简单方法是使用 TF 数据集。TFDF 提供了一个非常好的实用函数pd_dataframe_to_tf_dataset
,使这一过程变得非常简单。
在上面的代码中,我们将 DataFrame 对象传递给函数,并提供以下参数:
-
标签列的名称
-
权重列的名称(在此情况下为 None)
-
批量大小(有助于加快数据读取速度)
生成的数据集已按照 TF 数据集的正确格式(批处理和预取)进行准备,可以用于训练/评估。当然,你也可以创建自己的读取数据集的方法,但必须特别注意输出格式。
TFDF 默认参数
如果你按照所有先前的数据准备说明进行操作,训练模型非常简单。
从上面的代码可以看出,只需几行代码即可使用默认参数构建和训练 GBT 和 RF 模型。你只需指定所使用的特征、训练和验证数据集,然后就可以开始了。在使用 ROC 和 PR AUC 评估这两个模型时,我们可以看到其性能已经相当不错。
# GBT with Default Parameters
PR AUC: 0.8367
ROC AUC: 0.9583
# RF with Default Parameters
PR AUC: 0.8102
ROC AUC: 0.9453
让我们看看这些结果是否可以通过超参数调整进一步改善。为了简化,我将专注于优化 GBT 模型,但这些方法也可以很容易地应用于 RF 模型。
超参数调整
有许多参数需要调整,每一个的详细解释可以在官方 Yggdrasil 文档中找到。TFDF 为你提供了一些内置的参数调整选项,但你也可以使用更标准的库,如Optuna或Hyperpot。以下是按从最少参与到最多参与的方式排列的方法列表。
-
超参数模板
-
使用预定义空间的超参数搜索
-
使用自定义空间的超参数搜索
超参数模板
TFDF 提供了一个非常酷的功能,就是超参数模板的可用性。这些是论文中显示在各种数据集上表现最好的参数。两个可用的模板是——better_default
和benchmark_rank1
。如果你时间紧迫或对机器学习不太熟悉,这可能是一个不错的选择。指定这些参数只需一行代码。
从结果来看,我们可以看到使用better_default
参数在 ROC 和 PR AUC 上都有了轻微的提升。而benchmark_rank1
参数则表现较差。这就是为什么在部署模型之前正确评估结果模型很重要的原因。
GBT with 'Better Default' Parameters
PR AUC: 0.8483
ROC AUC: 0.9593
GBT with 'Benchmark Rank 1' Parameters
PR AUC: 0.7869
ROC AUC: 0.9442
预定义搜索空间
TFDF 附带了一个名为RandomSearch
的实用工具,它在许多可用参数之间执行随机网格搜索(类似于sklearn
)。有一个选项可以手动指定这些参数(参见示例),但也可以使用预定义的搜索空间。再次说明,如果你对机器学习不太熟悉,这可能是一个不错的选择,因为它不需要你手动设置这些参数。
警告:这个搜索花了我很长时间,因此我不得不在 12 次迭代后停止。一些被测试的参数(例如斜切分裂)需要较长时间来拟合。
你可以使用以下命令访问所有尝试过的组合。
tuning_logs = tuned_model.make_inspector().tuning_logs()
超参数表。截图由作者提供。
在 12 次迭代后,最佳模型的表现稍逊于基线,因此使用这种调优方法时要谨慎。你可以尝试更改搜索空间,但此时你不妨使用另一个库。
PR AUC: 0.8216
ROC AUC: 0.9418
自定义搜索空间(带自定义损失)
使用RandomSearch
方法有一些显著的缺点:
-
仅提供随机网格搜索算法
-
没有选项可以定义你自己的损失函数进行优化
-
如果不使用
use_predefined_hps
标志,则需要提供完整的参数网格
由于这些原因,我强烈建议如果你有足够的知识来自己设置合理的搜索空间,最好使用外部优化库。下面你可以看到如何使用optuna
进行调优。
这些参数中的大多数对于 GBT 来说都是相当标准的,但也有一些值得注意的参数。首先,我们可以将 growing_strategy
更改为 BEST_FIRST_GLOBAL
(即叶子优先生长),这是 LightGBM 使用的策略。其次,我们可以使用 BINARY_FOCAL_LOSS
,这种方法对于不平衡数据集表现更佳 (source)。第三,可以选择将 split_axis
参数更改为使用稀疏斜切分,这在 这篇论文 中显示效果相当好。最后,还可以使用 honest
参数构建 “诚实树”。
下面是使用最佳参数取得的结果。正如你所见,自定义搜索空间的调优迄今为止取得了最佳结果。
GBT with Custom Tuned Parameters
PR AUC: 0.8666
ROC AUC: 0.9631
现在我们已经确定了超参数,让我们重新训练模型并继续进行检查。
模型检查
TFDF 提供了一个名为 Inspector
的实用工具来检查训练后的模型。这个对象有 3 个主要的用途,下面我将进行详细探讨:
-
检查模型的属性,例如类型、树的数量或使用的特征
-
获取特征重要性
-
提取树结构
检查模型属性
inspector 类存储了各种属性,如果例如你加载了其他人的模型或有一段时间没有使用它,你可能想要探索这些属性。你可以打印出模型类型(GBT 或 RF)、模型的树的数量、训练目标以及用于训练模型的特征。检查树的数量特别有用,因为如果早期停止已触发,这个参数会比你设置的值小。
另一种选择是简单地运行 manual_tuned.summary()
以更详细地检查模型。
特征重要性
就像所有其他库一样,TFDF 提供了内置的特征重要性评分。对于 GBT,你可以访问 NUM_NODES
、SUM_SCORE
、INV_MEAN_MIN_DEPTH
、NUM_AS_ROOT
方法进行解释。请注意,你还可以在训练过程中将 compute_permutation_variable_importance
参数设置为 True
,这将添加一些额外的方法。缺点是模型训练时间会显著增加,因此请谨慎使用(最好是在数据样本上)。
重要性条形图。截图由作者提供。
对于我构建的模型,Term 变量一直被认为是最重要的特征,类别变量如 Bank、State 和 Bank State 紧随其后。我认为 TFDF 库最大的缺点之一是无法与 SHAP 配合使用。希望未来的版本能提供支持。
检查单独的树
有时我们需要查看单独的树,以便于解释性或模型验证。 TFDF 在 inspector 对象中提供了对训练过程中构建的所有树的简单访问。目前,让我们检查一下我们的 GBT 模型的第一棵树,因为它通常是最具信息性的。
树结构。作者截图。
正如你所见,当我们处理大型树时,使用打印语句检查它们可能不是很方便。这就是为什么 TFDF 还有一个树绘图工具——tfdf.model_plotter.plot_model
。
第一棵 GBT 树(深度=4)。作者截图。
另外请注意,对于随机森林模型,你可以使用dtreeviz
包,它会给你更具视觉吸引力的结果(这里是如何使用它)。目前,这个包对 GBT 模型尚不支持。
TF Serving
到目前为止,我们已经训练、调整和评估了模型。还有什么其他的呢?当然是服务模型!幸运的是,TFDF 得到了 TF Serving 的原生支持(从最新版本开始),所以这一部分也很简单。如果你已经有了最新的 TF Serving 实例,你只需在model_base_path
参数中指向你保存的模型即可。你可以使用save
方法保存 TFDF 模型。请注意,你应该将其保存到文件夹1
,因为这是你模型的第一个版本。
manual_tuned.save("../models/loan_default_model/1/")
对于那些没有使用 TF Serving 模型的人,你可以在 TF 团队的这里找到一个很好的教程,我也写了这个Colab 笔记本,以防你在 M1 或 M2 Mac 上工作(目前不支持 TF Serving)。
本质上,你需要做的就是在本地安装 TF Serving 并使用正确的参数启动它。一旦下载了二进制文件,这里是启动服务器的命令:
./tensorflow_model_server \
--rest_api_port=8501 \
--model_name=loan_default_model \
--model_base_path=/path/models/loan_default_model/1
请注意,model_base_path
应该是绝对路径,而不是相对路径。TF Serving 服务器启动后,你可以开始向其发送请求。有两种预期的格式——instances
和inputs
。下面你可以看到后者格式的示例,但你可以在这个教程中查看到两者的示例。
# Input data formatted correctly
data = {
"Bank": ["Other"],
"BankState": ["TN"],
"City": ["Other"],
"CreateJob": [12.0],
"FranchiseCode": ["0"],
"GrAppv": [14900000.0],
"NoEmp": [28.0],
"RetainedJob": [16.0],
"RevLineCr": ["N"],
"SBA_Appv": [14900000.0],
"State": ["TN"],
"Term": [240.0],
"UrbanRural": ["0"],
"is_new": [0.0],
"latitude": [35.3468],
"longitude": [-86.22],
"naics_first_two": ["44"],
"same_state": [1.0],
"ApprovalFY": [1]
}
payload = {"inputs": data}
# Send the request
url = 'http://localhost:8501/v1/models/default_model:predict'
response = requests.post(url, json=payload)
# Print out the response
print(json.loads(response.text)['outputs'])
# Expected output: [[0.0138759678]]
如果你成功获得了响应(可能与我的不同)——恭喜你!你已经完成了本文的最后一步。现在,让我们回顾一下如果你遵循了所有这些章节,你完成了什么。
结论
总结来说,TFDF 是一个强大且可扩展的库,用于在 TensorFlow 中训练基于树的模型。TFDF 模型与 TensorFlow 生态系统的其他部分集成良好,因此如果你在使用 TFX,有其他 TF 模型在生产中,或在使用 TF Serving,你会发现这个库非常有用。
如果你已经完成了这些笔记本,你现在应该知道如何训练、调优、检查和服务 TFDF 模型。正如你所看到的,TFDF 模型高度可定制,因此如果你需要一个高性能的树模型库,可以试试看,并告诉我效果如何!
还不是 Medium 会员?
[## 使用我的推荐链接加入 Medium - Antons Tocilins-Ruberts
阅读 Antons Tocilins-Ruberts(以及 Medium 上成千上万的其他作者)撰写的每一个故事。您的会员费直接……
medium.com](https://medium.com/@antonsruberts/membership?source=post_page-----3b6056a6d6b0--------------------------------)
TensorFlow-GNN:图神经网络的端到端指南
原文:
towardsdatascience.com/tensorflow-gnn-an-end-to-end-guide-for-graph-neural-networks-a66bfd237c8c
“Mapsterpiece”由 Heidi Malin 创作,已获许可使用
教程
如何使用自己的 Pandas/NetworkX 数据集进行图、节点和边预测
Michael Malin
·发表于Towards Data Science ·20 分钟阅读·2023 年 1 月 16 日
--
特别感谢 DeepMind 的 Alvaro Sanchez Gonzalez 和 Google 的 Bryan Perozzi 及 Sami Abu-el-haija,他们在本教程中给予了帮助
更新于 2023 年 04 月 22 日,修复了小问题并添加了 Graph Nets 方法
图数据无处不在。图研究仍处于起步阶段,图数据建模工具刚刚开始出现。这使得如果你是希望脱颖而出的数据科学家,现在是最佳时机。不幸的是,由于缺乏教程和支持,处于前沿可能很困难。本指南希望显著减轻这个痛点。
为什么选择 TensorFlow-GNN?
TF-GNN 是 Google 最近发布的用于图神经网络的 TensorFlow 库。虽然市场上还有其他 GNN 库,但由于 TF-GNN 在大规模图上的建模灵活性、分布式学习带来的性能优势和 Google 的支持,它很可能会成为行业标准。本指南假设你已经了解了这个库的优点,但请参阅这篇论文以获取更多信息和性能比较。此外,查看 TF-GNN 的文档。如果你对 GNN 完全陌生,请查看这本指南以获得概念理解。
缺点是什么?
由于该库当前处于 alpha 阶段,代码对建模所需的结构、输入形状和格式非常严格。这使得没有指南的情况下很难进行导航。不幸的是,目前没有关于使用 TF-GNN 的大量信息。我能找到的指南都集中在使用预构建 TensorFlow 数据集的相同上下文级预测用例上。截至写作时,没有一个完整的操作示例:
-
进行边或节点预测
-
从您自己的 Pandas 或 NetworkX 数据集开始
-
创建保留数据集
-
模型调优
-
解决您可能遇到的故障
在经过一个月的文档重读、反复试错编程和来自 Google/DeepMind 的 TensorFlow 开发人员的直接帮助后,我决定编写这个指南。
“许多[小时]为我们带来了这些信息。”
本指南将涵盖:
首先,我们将非常简单地开始,以掌握构建模块。然后我们将转向一个更高级的示例——大学橄榄球会议预测。以下是将要涵盖的内容概要:
-
TF-GNN 元素
-
构建模块
-
从 Pandas 生成的图形张量
-
-
数据设置
-
从 NetworkX 生成的图形张量
-
特性工程
-
创建测试拆分
-
创建图形 TensorFlow 数据集
-
-
构建模型
-
节点模型
-
边模型
-
上下文模型
-
-
故障排除
-
参数调优
TF-GNN 元素
一个图形由节点和边组成。以下是一个简单图形的示例,显示了最近互相接触的人(节点):
作者提供的示例图形
相同的图形也可以表示为节点和边表。我们还可以为这些节点和边添加特性。例如,我们可以添加“年龄”作为节点特性,并将“是否为朋友”作为边特性。
作者提供的示例节点和边数据
当我们向 TF-GNN 添加边时,我们需要按数字而非名称进行索引。我们可以这样做:
node_df = node_df.reset_index()
merge_df = node_df.reset_index().set_index('Name').rename(
columns={'index':'Name1_idx'})
edge_df = pd.merge(edge_df,merge_df['Name1_idx'],
how='left',left_on='Name1',right_index=True)
merge_df = merge_df.rename(columns={'Name1_idx':'Name2_idx'})
edge_df = pd.merge(edge_df,merge_df['Name2_idx'],
how='left',left_on='Name2',right_index=True)
作者提供的带有数字索引的节点和边数据
最后,我们可能会得到图形的上下文值。例如,也许这个朋友小组在某次测试中的平均分为 84%。这对于这个单一图形示例意义不大。如果我们有其他朋友图形,我们或许可以基于学到的群体动态预测新朋友小组的分数。
从 pandas 生成的图形张量
通过这些元素,我们现在可以为我们的 GNN 构建基础:一个图形张量。
import tensorflow_gnn as tfgnn
graph_tensor = tfgnn.GraphTensor.from_pieces(
node_sets = {
"People": tfgnn.NodeSet.from_fields(
sizes = [len(node_df)],
features ={
'Age': np.array(node_df['Age'],
dtype='int32').reshape(len(node_df),1)})},
edge_sets ={
"Contact": tfgnn.EdgeSet.from_fields(
sizes = [len(edge_df)],
features = {
'Is-friend': np.array(edge_df['Is-friend'],
dtype='int32').reshape(len(edge_df),1)},
adjacency = tfgnn.Adjacency.from_indices(
source = ("People", np.array(edge_df['Name1_idx'], dtype='int32')),
target = ("People", np.array(edge_df['Name2_idx'], dtype='int32'))))
})
注意我们创建的特性如何适配到节点和边中。缩进结构使得添加额外的节点、边和特性变得简单。例如,我们可以轻松地为每个朋友观看的电影添加节点和边,并这次包含一个图形上下文值。
graph_tensor = tfgnn.GraphTensor.from_pieces(
context_spec = tfgnn.ContextSpec.from_field_specs(
features_spec ={
"score": [[0.84]]
}),
node_sets = {
"People": tfgnn.NodeSet.from_fields(
sizes = [len(node_df)],
features ={
'Age': np.array(node_df['Age'],
dtype='int32').reshape(len(node_df),1)}),
"Movies": tfgnn.NodeSet.from_fields(
sizes = [len(movie_df)],
features ={
'Name': np.array(movie_df['Name'],
dtype='string').reshape(len(movie_df),1),
'Length': np.array(movie_df['Length'],
dtype='float32').reshape(len(movie_df),1)})},
edge_sets ={
"Contact": tfgnn.EdgeSet.from_fields(
sizes = [len(edge_df)],
features = {
'Is-friend': np.array(edge_df['Is-friend'],
dtype='int32').reshape(len(edge_df),1)},
adjacency = tfgnn.Adjacency.from_indices(
source = ("People", np.array(edge_df['Name1_idx'], dtype='int32')),
target = ("People", np.array(edge_df['Name2_idx'], dtype='int32')))),
'Watched': tfgnn.EdgeSet.from_fields(
sizes = [len(watched_df)],
features = {},
adjacency = tfgnn.Adjacency.from_indices(
source = ("People", np.array(watched_df['Name_idx'], dtype='int32')),
target = ("Movies", np.array(watched_df['Movie_idx'], dtype='int32'))))
})
注意:请非常小心你的数据类型和形状。任何偏差都会导致错误或训练问题。唯一支持的数据类型是‘int32’,‘float32’,和‘string’。如果遇到问题,请参阅本文末尾的故障排除部分。
你可能已经注意到,图张量是有方向的,具有源节点和目标节点。这对于萨姆看电影可能没问题,但通信是双向的。当萨姆与艾米交谈时,艾米也在与萨姆交谈。对于双向数据,你需要复制那些边(将源和目标反转),以指示数据流的两个方向。
作者提供的示例双向数据
有了这个基础,我们现在准备过渡到在真实数据集上进行预测。
数据设置
训练数据是 2000 年秋季期间 IA 分区大学之间的美式足球比赛网络,如下所示
作者:M. Girvan 和 M. Newman。节点数据包括大学名称和他们所属的会议索引(例如,会议 8 = Pac 10)。边数据包括两个大学名称,表示它们之间进行了一场比赛。数据可以如下提取(参见 Google Colab 以便跟进):
import urllib.request
import io
import zipfile
import networkx as nx
url = "http://www-personal.umich.edu/~mejn/netdata/football.zip"
sock = urllib.request.urlopen(url) # open URL
s = io.BytesIO(sock.read()) # read into BytesIO "file"
sock.close()
zf = zipfile.ZipFile(s) # zipfile object
txt = zf.read("football.txt").decode() # read info file
gml = zf.read("football.gml").decode() # read gml data
# throw away bogus first line with # from mejn files
gml = gml.split("\n")[1:]
G = nx.parse_gml(gml) # parse gml data
print(txt)
从 NetworkX 导入图张量
我们的数据现在在 NetworkX 图中。让我们看看用节点按其所属会议着色的效果如何。
cmap = {0:'#bd2309', 1:'#bbb12d',2:'#1480fa',3:'#14fa2f',4:'#faf214',
5:'#2edfea',6:'#ea2ec4',7:'#ea2e40',8:'#577a4d',9:'#2e46c0',
10:'#f59422',11:'#8086d9'}
colors = [cmap[G.nodes[n]['value']] for n in G.nodes()]
pos = nx.spring_layout(G, seed=1987)
nx.draw_networkx_edges(G, pos, alpha=0.2)
nx.draw_networkx_nodes(G, pos, nodelist=G.nodes(),
node_color=colors, node_size=100)
作者提供的大学美式足球网络
对于我们的节点模型,我们将尝试预测一个学校所属的会议。对于我们的边模型,我们将尝试预测一场比赛是否是会议内的比赛。 这两个预测将基于持出数据集进行评估。我们如何从 NetworkX 做到这一点?可以直接从图中构建图张量,使用这些函数来提取数据:
node_data = G.nodes(data=True)
edge_data = G.edges(data=True)
问题是,我们仍然想做一些特征工程,但我们还没有持出数据集。基于这些原因,我强烈建议将你的图数据转换为 Pandas。之后,我们可以使用在第一个示例中展示的方法将数据插入图张量。
node_df = pd.DataFrame.from_dict(dict(G.nodes(data=True)), orient='index')
node_df.index.name = 'school'
node_df.columns = ['conference']
edge_df = nx.to_pandas_edgelist(G)
作者提供的大学美式足球节点和边数据
特征工程
使用基础图,一个模型可能能够基于网络确定两所大学是否在同一个会议中。但它如何知道具体是哪个会议呢?在没有任何节点或边数据的情况下,它如何学习会议之间的差异?为此任务,我们需要添加更多特征。
我们应该收集什么样的特征?我不是大学橄榄球方面的专家,但我想会议的组成是基于邻近性和排名的。本指南侧重于 TF-GNN,因此我将使用魔法添加这些新特征,但你可以在链接的 Google Colab 中找到具体的代码。
对于节点,我们将添加纬度/经度以及前一年的(1999 年)排名、胜场和会议胜场。我们还将把会议列转换为 12 个虚拟变量列,以进行 softmax 预测。
作者提供的最终节点数据集
对于边,我们将计算学校之间的距离,添加名称相似度评分(也许名称中包含相同州的学校更有可能在同一会议中),以及比赛是否为会议内比赛的目标值。
作者提供的最终边数据集
让我们用新的信息可视化我们的数据(橙色边表示会议比赛)。地理位置显然在会议选择中至少发挥了作用。
作者提供的美国地图上的大学数据
创建测试拆分
创建训练集是直接的;排除保留的节点和边的方式与你通常的做法相同。然而,保留数据与典型的机器学习应用有所不同。由于整体连接对于准确预测很重要,最终的预测需要基于整个图。一旦做出预测,结果可以过滤到保留数据中进行最终评估。我将在预测阶段更详细地展示这个过程;目前我创建拆分的方式如下:
from sklearn.model_selection import train_test_split
node_train, node_test = train_test_split(node_df,test_size=0.15,random_state=42)
edge_train = edge_df.loc[~((edge_df['source'].isin(node_test.index)) | (edge_df['target'].isin(node_test.index)))]
edge_test = edge_df.loc[(edge_df['source'].isin(node_test.index)) | (edge_df['target'].isin(node_test.index))]
使用我们新的拆分,现在我们可以进行双向调整并添加边索引列。
def bidirectional(edge_df):
reverse_df = edge_df.rename(columns={'source':'target','target':'source'})
reverse_df = reverse_df[edge_df.columns]
reverse_df = pd.concat([edge_df, reverse_df], ignore_index=True, axis=0)
return reverse_df
def create_adj_id(node_df,edge_df):
node_df = node_df.reset_index().reset_index()
edge_df = pd.merge(edge_df,node_df[['school','index']].rename(columns={"index":"source_id"}),
how='left',left_on='source',right_on='school').drop(columns=['school'])
edge_df = pd.merge(edge_df,node_df[['school','index']].rename(columns={"index":"target_id"}),
how='left',left_on='target',right_on='school').drop(columns=['school'])
edge_df.dropna(inplace=True)
return node_df, edge_df
edge_full_adj = bidirectional(edge_df)
edge_train_adj = bidirectional(edge_train)
node_full_adj,edge_full_adj = create_adj_id(node_df,edge_full_adj)
node_train_adj,edge_train_adj = create_adj_id(node_train,edge_train_adj)
创建 TensorFlow 数据集
现在我们准备创建我们的图张量,我们将把它们转换成 TensorFlow 数据集。
def create_graph_tensor(node_df,edge_df):
graph_tensor = tfgnn.GraphTensor.from_pieces(
node_sets = {
"schools": tfgnn.NodeSet.from_fields(
sizes = [len(node_df)],
features ={
'Latitude': np.array(node_df['Latitude'], dtype='float32').reshape(len(node_df),1),
'Longitude': np.array(node_df['Longitude'], dtype='float32').reshape(len(node_df),1),
'Rank': np.array(node_df['Rank'], dtype='int32').reshape(len(node_df),1),
'Wins': np.array(node_df['Wins'], dtype='int32').reshape(len(node_df),1),
'Conf_wins': np.array(node_df['Conf_wins'], dtype='int32').reshape(len(node_df),1),
'conference': np.array(node_df.iloc[:,-12:], dtype='int32'),
}),
},
edge_sets ={
"games": tfgnn.EdgeSet.from_fields(
sizes = [len(edge_df)],
features = {
'name_sim_score': np.array(edge_df['name_sim_score'], dtype='float32').reshape(len(edge_df),1),
'euclidean_dist': np.array(edge_df['euclidean_dist'], dtype='float32').reshape(len(edge_df),1),
'conference_game': np.array(edge_df['conference_game'], dtype='int32').reshape(len(edge_df),1)
},
adjacency = tfgnn.Adjacency.from_indices(
source = ("schools", np.array(edge_df['source_id'], dtype='int32')),
target = ("schools", np.array(edge_df['target_id'], dtype='int32')),
)),
})
return graph_tensor
full_tensor = create_graph_tensor(node_full_adj,edge_full_adj)
train_tensor = create_graph_tensor(node_train_adj,edge_train_adj)
在创建数据集之前,我们需要一个函数将图拆分为训练数据和我们将要预测的目标(如下所示的标签)。对于我们的节点预测问题,我们将“conference”作为我们的标签。我们还需要从数据集中删除“conference_game”特征,因为它会造成数据泄露问题(即作弊)。
def node_batch_merge(graph):
graph = graph.merge_batch_to_components()
node_features = graph.node_sets['schools'].get_features_dict()
edge_features = graph.edge_sets['games'].get_features_dict()
label = node_features.pop('conference')
_ = edge_features.pop('conference_game')
new_graph = graph.replace_features(
node_sets={'schools':node_features},
edge_sets={'games':edge_features})
return new_graph, label
我们将对边模型进行反向操作:删除“conference”特征并将“conference_game”拆分为目标(标签)。
def edge_batch_merge(graph):
graph = graph.merge_batch_to_components()
node_features = graph.node_sets['schools'].get_features_dict()
edge_features = graph.edge_sets['games'].get_features_dict()
_ = node_features.pop('conference')
label = edge_features.pop('conference_game')
new_graph = graph.replace_features(
node_sets={'schools':node_features},
edge_sets={'games':edge_features})
return new_graph, label
我们现在可以创建数据集,并通过上述函数进行映射。
def create_dataset(graph,function):
dataset = tf.data.Dataset.from_tensors(graph)
dataset = dataset.batch(32)
return dataset.map(function)
#Node Datasets
full_node_dataset = create_dataset(full_tensor,node_batch_merge)
train_node_dataset = create_dataset(train_tensor,node_batch_merge)
#Edge Datasets
full_edge_dataset = create_dataset(full_tensor,edge_batch_merge)
train_edge_dataset = create_dataset(train_tensor,edge_batch_merge)
这些程序的顺序非常重要:
1. 我们从图张量创建数据集。
2. 我们将数据集按批次拆分(了解一下批次大小)。
3. 在映射函数中,我们将这些批次合并回一个图中。
4. 根据需要拆分/删除特征。
如果你不严格按照此顺序操作,模型将无法训练(或无法正确训练)。
构建模型
我们已经有了数据集,现在是有趣的部分!首先,我们使用数据集规格定义输入。
graph_spec = train_node_dataset.element_spec[0]
input_graph = tf.keras.layers.Input(type_spec=graph_spec)
现在我们需要初始化特征。我们将创建初始化节点和边的函数。然后,我们通过这些函数映射我们的特征。为了简化,我将为每个特征创建一个密集层。
def set_initial_node_state(node_set, node_set_name):
features = [
tf.keras.layers.Dense(32,activation="relu")(node_set['Latitude']),
tf.keras.layers.Dense(32,activation="relu")(node_set['Longitude']),
tf.keras.layers.Dense(32,activation="relu")(node_set['Rank']),
tf.keras.layers.Dense(32,activation="relu")(node_set['Wins']),
tf.keras.layers.Dense(32,activation="relu")(node_set['Conf_wins'])
]
return tf.keras.layers.Concatenate()(features)
def set_initial_edge_state(edge_set, edge_set_name):
features = [
tf.keras.layers.Dense(32,activation="relu")(edge_set['name_sim_score']),
tf.keras.layers.Dense(32,activation="relu")(edge_set['euclidean_dist'])
]
return tf.keras.layers.Concatenate()(features)
graph = tfgnn.keras.layers.MapFeatures(
node_sets_fn=set_initial_node_state,
edge_sets_fn=set_initial_edge_state
)(input_graph)
在之前的步骤中可以进行很多自定义。例如,我们可以为字符串特征创建词嵌入。我们可以通过对纬度/经度网格进行哈希处理而不是仅仅使用密集层来获得一些准确性。TensorFlow 为我们提供了许多选项。
几点说明:
-
如果你有多个节点或边,你需要添加 ‘if 语句’ 以将特征应用到正确的节点/边。
-
没有特征的节点或边也可以使用 ‘MakeEmptyFeature’ 函数进行初始化。
-
对于以节点为中心的问题,初始化边是可选的(阅读更多关于节点与边中心的内容)。
-
第一个节点必须至少有一个特征。如果没有特征,你可能需要在一个索引上创建一个嵌入(结果可能不会很好)。
# Examples, do not use for this problem
def set_initial_node_state(node_set, node_set_name):
if node_set_name == "node_1":
return tf.keras.layers.Embedding(115,3)(node_set['id'])
elif node_set_name == "node_2":
return tfgnn.keras.layers.MakeEmptyFeature()(node_set)
graph = tfgnn.keras.layers.MapFeatures(
node_sets_fn=set_initial_node_state)(input_graph)
在我们开发更新循环之前,我们需要一个额外的辅助函数。随着我们添加密集层,我们希望确保我们在使用 L2 正则化和/或 dropout(L1 也可以使用)。
def dense_layer(self,units=64,l2_reg=0.1,dropout=0.25,activation='relu'):
regularizer = tf.keras.regularizers.l2(l2_reg)
return tf.keras.Sequential([
tf.keras.layers.Dense(units,
kernel_regularizer=regularizer,
bias_regularizer=regularizer),
tf.keras.layers.Dropout(dropout)])
节点模型
有许多模型架构,但图卷积网络迄今为止是最常见的(见其他方法 这里 )。图卷积类似于计算机视觉问题中常用的卷积。主要区别在于图卷积处理的是你在图结构中找到的不规则数据。让我们跳入实际代码中。
graph_updates = 3 # tunable parameter
for i in range(graph_updates):
graph = tfgnn.keras.layers.GraphUpdate(
node_sets = {
'schools': tfgnn.keras.layers.NodeSetUpdate({
'games': tfgnn.keras.layers.SimpleConv(
message_fn = dense_layer(32),
reduce_type="sum",
sender_edge_feature = tfgnn.HIDDEN_STATE,
receiver_tag=tfgnn.TARGET)},
tfgnn.keras.layers.NextStateFromConcat(
dense_layer(64)))})(graph) #start here
logits = tf.keras.layers.Dense(12,activation='softmax')(graph.node_sets["schools"][tfgnn.HIDDEN_STATE])
node_model = tf.keras.Model(input_graph, logits)
上面的代码可能有些令人困惑,因为 TensorFlow 堆叠的工作原理。请记住,‘#start here’ 标记的(图)实际上是前面代码的输入。在开始时,这个(图)等于我们之前映射的初始化特征。输入被送入 ‘GraphUpdate’ 函数,成为新的(图)。每次‘graph_updates’循环中,之前的 ‘GraphUpdate’ 成为新的 ‘GraphUpdate’ 的输入,同时还指定了一个通过 ‘NextStateFromConcat’ 函数的密集层。这个图示应该有助于解释:
图卷积网络图,显示了作者进行的两次图更新
‘GraphUpdate’函数简单地更新指定的状态(节点、边或上下文),并添加一个下一个状态层。在这种情况下,我们只用‘NodeSetUpdate’更新节点状态,但当我们处理边模型时,我们将探索以边为中心的方法。通过这个节点更新,我们在边缘上应用了一个卷积层,使信息能够从邻近的节点和边缘传递到节点。图形更新的数量是一个可调节的参数,每次更新允许信息从更远的节点传播。例如,我们在案例中指定的三次更新允许信息从最多三节点远的地方传播。在图形更新后,最终的节点状态成为我们标记为‘logits’的预测头的输入。由于我们预测 12 个不同的会议,我们有一个包含 12 个单元的密集层,并使用 softmax 激活函数。现在我们可以编译模型。
node_model.compile(
tf.keras.optimizers.Adam(learning_rate=0.01),
loss = 'categorical_crossentropy',
metrics = ['categorical_accuracy']
)
node_model.summary()
最后,我们训练模型。我使用了一个回调函数来在验证数据集停止提高准确率时停止训练。这并不完美,因为我们必须使用完整数据集(如上所述)。这将导致我们的准确率包括数据泄漏。一个完美的解决方案是编写一个自定义评估函数,仅返回验证数据上的验证节点的准确率,以及训练数据上的训练节点的准确率。这需要大量工作(要讲解的话会成为一个教程),以便在最准确的停止点上更近一步。我选择保持简单,并接受一个略微不那么准确的模型。
es = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',mode='min',verbose=1,
patience=10,restore_best_weights=True)
node_model.fit(train_node_dataset.repeat(),
validation_data=full_node_dataset,
steps_per_epoch=10,
epochs=1000,
callbacks=[es])
现在来看一下我们使用 node_model.predict(full_node_dataset)
的效果,并通过魔法将结果打印在地图上(见Google Colab)。
按作者的节点预测准确性比较
总体而言,我们的准确率达到了令人尊敬的 88%(见Google Colab获取模型参数)。模型在山区州的表现似乎较差。深入分析后,我们发现了一些有趣的见解。例如,模型错误地预测犹他州属于 Pac 10 会议。然而,事实上,犹他州在第二年确实加入了 Pac 10。完全有可能模型准确地识别了应该如何进行,而约 12%的误差实际上是衡量在创建会议时的人为不一致性。另一种考虑方式是用社交网络中的朋友。如果网络预测两个人是朋友,而他们从未见过面,这个模型是错的还是他们实际上是好朋友?对于许多(或大多数)图形问题,这些“错误”实际上是你真正想要找到的。这些“错误”可以用于推荐购买的产品、观看的电影、应该联系的人等。
在这种情况下,假设数据是完美的,我们关注的是分类准确率。为了真正了解我们的表现如何,我们需要在持出数据上测试准确率。为此,我们将在完整数据集上进行预测,然后筛选到持出节点。
def evaluate_node():
### Add raw prediction ####
yhat = node_model.predict(full_node_dataset)
yhat_df = node_full_adj.set_index('school').iloc[:,-12:].copy()
yhat_df.iloc[:,:] = yhat
### Classify max of softmax output ###
yhat_df = yhat_df.apply(lambda x: x == x.max(), axis=1).astype(int)
### Merge output back to single column ###
yhat_df = yhat_df.dot(yhat_df.columns).to_frame().rename(columns={0:'conf_yhat'})
yhat_df = yhat_df['conf_yhat'].str.replace('conf_', '').astype(int).to_frame()
yhat_df['conf_actual'] = node_full_adj['conference']
### Filter down to test nodes ###
yhat_df = yhat_df.loc[yhat_df.index.isin(params['testset'].index)]
### Calculate accuracy ###
yhat_df['Accuracy'] = yhat_df['conf_yhat']==yhat_df['conf_actual']
return yhat_df['Accuracy'].mean()
对于这个模型,准确率下降到约 72%(别惊慌,持出数据集上下降是预期的)。鉴于特征工程有限、仅有一年的数据和 12 个输出预测——这些结果是合理的。通过对下面地图的视觉检查(并与上面的完整地图进行比较),大多数错误看起来像是合理的猜测。
按作者比较节点持出预测准确率
边模型
现在我们将尝试预测某场比赛是否为会议内比赛。我们已经在上面定义了我们的边数据集,大多数步骤可以重复使用,仅有一个更改:
### Change to train_edge_dataset ###
graph_spec = train_edge_dataset.element_spec[0]
input_graph = tf.keras.layers.Input(type_spec=graph_spec)
graph = tfgnn.keras.layers.MapFeatures(
node_sets_fn=set_initial_node_state,
edge_sets_fn=set_initial_edge_state
)(input_graph)
我们确实需要对图更新进行一些更改。首先,我们需要在‘GraphUpdate’函数中添加一个‘edge_sets’更新。保留‘node_sets’更新是可选的,但模型似乎在保留它时表现更好。接下来,我们将从 GCN 切换到 Graph Nets 方法。这种方法将边视为一等公民(即一种时髦的说法,表示它们将学习自己的权重,这正是我们所追求的)。最后,我们需要将‘logits’更新为一个单单位的 sigmoid 激活密集层,因为我们正在预测一个虚拟变量。
graph_updates = 3
for i in range(graph_updates):
graph = tfgnn.keras.layers.GraphUpdate(
edge_sets = {'games': tfgnn.keras.layers.EdgeSetUpdate(
next_state = tfgnn.keras.layers.NextStateFromConcat(
dense_layer(64,activation='relu')))},
node_sets = {
'schools': tfgnn.keras.layers.NodeSetUpdate({
'games': tfgnn.keras.layers.Pool(
tag=tfgnn.TARGET,
reduce_type="sum",
feature_name = tfgnn.HIDDEN_STATE)},
tfgnn.keras.layers.NextStateFromConcat(
dense_layer(64)))})(graph)
logits = tf.keras.layers.Dense(1,activation='sigmoid')(graph.edge_sets['games'][tfgnn.HIDDEN_STATE])
edge_model = tf.keras.Model(input_graph, logits)
这次我们使用‘binary_crossentropy’来编译模型。
edge_model.compile(
tf.keras.optimizers.Adam(learning_rate=0.01),
loss = 'binary_crossentropy',
metrics = ['Accuracy']
)
edge_model.summary()
我们使用与节点问题中定义的相同回调来拟合模型。
edge_model.fit(train_edge_dataset.repeat(),
validation_data=full_edge_dataset,
steps_per_epoch=10,
epochs=1000,
callbacks=[es])
yhat = edge_model.predict(full_edge_dataset)
yhat_df = edge_full_adj.copy().set_index(['source','target'])
yhat_df['conf_game_yhat'] = yhat.round(0)
yhat_df = yhat_df.loc[yhat_df.index.isin(
edge_test.set_index(['source','target']).index)]
yhat_df['loss'] = abs(yhat_df['conference_game'] - yhat_df['conf_game_yhat'])
loss = yhat_df['loss'].mean()
print("edge accuracy:",1 - loss)
在持出数据集上进行评估时,我们获得了 85%的准确率,而均值为 56%。模型完成了它的工作,我对这些结果感到满意。
上下文模型
这个特定的问题没有上下文值。我们可以假设我们将上面的图划分为每个会议的单独图。这些新图将显示会议中每个球队进行的每场比赛,并忽略所有其他比赛。然后我们可以为每个图提供会议排名的值。现在我们可以训练一个模型来进行上下文级别的预测。
首先,我们需要将上下文值添加到图中。
graph_tensor = tfgnn.GraphTensor.from_pieces(
context = tfgnn.Context.from_fields(
features ={
<context_feature>
}),
node_sets = {
...
接下来,我们需要创建一个新的数据集,将上下文映射到标签。
def node_batch_merge(graph):
graph = graph.merge_batch_to_components()
context_features = graph.context.get_features_dict()
label = context_features.pop('<context_feature>')
new_graph = graph.replace_features(
context=context_features)
return new_graph, label
我们可以设置初始上下文状态。在这种情况下,我们正在预测这个特征,因此它将从训练数据中缺失。对于其他模型,上下文可能是一个可训练的特征,可以这样设置:
def set_initial_context_state(context):
return tf.keras.layers.Dense(32,activation="relu")(context['<context_feature>'])
graph = tfgnn.keras.layers.MapFeatures(
context_fn=set_initial_context_state,
node_sets_fn=set_initial_node_state,
edge_sets_fn=set_initial_edge_state
)(input_graph)
再次,我们可以选择性地将上下文更新添加到‘GraphUpdate’(见下文)。我没有测试这种方法,所以可以随意尝试。
graph = tfgnn.keras.layers.GraphUpdate(
node_sets ={...},
context = tfgnn.keras.layers.ContextUpdate({
'schools': tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "mean")},
tfgnn.keras.layers.NextStateFromConcat(tf.keras.layers.Dense(128))))
最后,我们更新‘logits’以进行上下文预测。
logits = tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "mean",
node_set_name="schools")(graph)
故障排除错误
在尝试理解上述代码时,我遇到了许多错误和训练效果不佳的模型。虽然我尽量保持内容的通用性以适用于多种不同的问题,但你在为数据进行调整时无疑会遇到错误。诀窍是识别错误的来源。我发现诊断错误的最佳方法是创建一个图形模式。
在我们的代码中,我们从数据集中提取了图形模式。然而,你也可以直接构建一个图形模式。对于我们的足球示例,图形模式如下所示:
graph_spec = tfgnn.GraphTensorSpec.from_piece_specs(
context_spec = tfgnn.ContextSpec.from_field_specs(
features_spec ={
#Added as an example for context problems
#"conf_rank": tf.TensorSpec(shape=(None,1), dtype=tf.float32),
}),
node_sets_spec={
'schools':
tfgnn.NodeSetSpec.from_field_specs(
features_spec={
'Latitude': tf.TensorSpec((None, 1), tf.float32),
'Longitude': tf.TensorSpec((None, 1), tf.float32),
'Rank': tf.TensorSpec((None, 1), tf.int32),
'Wins': tf.TensorSpec((None, 1), tf.int32),
'Conf_wins': tf.TensorSpec((None, 1), tf.int32),
'conference': tf.TensorSpec((None, 12), tf.int32)
},
sizes_spec=tf.TensorSpec((1,), tf.int32))
},
edge_sets_spec={
'games':
tfgnn.EdgeSetSpec.from_field_specs(
features_spec={
'name_sim_score': tf.TensorSpec((None, 1), tf.float32),
'euclidean_dist': tf.TensorSpec((None, 1), tf.float32),
'conference_game': tf.TensorSpec((None, 1), tf.int32)
},
sizes_spec=tf.TensorSpec((1,), tf.int32),
adjacency_spec=tfgnn.AdjacencySpec.from_incident_node_sets(
'schools', 'schools'))
})
我们可以通过尝试构建和编译模型来测试‘graph_spec’是否至少有效。如果出现错误,可能是你的特征形状或‘set_initial_…’函数存在问题。如果成功,你可以验证你创建的模式是否与‘graph_tensor’兼容。
graph_spec.is_compatible_with(full_tensor)
如果为假,你可以打印‘full_tensor.spec’和‘graph_spec’来比较每一部分,以确保形状和数据类型完全相同。你还可以直接从‘graph_spec’创建一个随机生成的图张量。
random_graph = tfgnn.random_graph_tensor(graph_spec)
使用这个‘random_graph’你可以尝试训练一个模型。这应该有助于你确定错误是出在规范还是模型代码上。如果没有错误,你可以打印‘random_graph’的值,看看输出与‘graph_tensor’的比较情况。
print("Nodes:",random_graph.node_sets['schools'].features)
print("Edges:",random_graph.edge_sets['games'].features)
print("Context:",random_graph.context.features)
这些步骤应该可以帮助你追踪大部分遇到的问题。
参数调优
我们已经成功修复了所有错误并训练了一个模型。现在我们希望调整超参数以获得一个准确的模型。我选择的调优工具是 Hyperopt 库,因为它易于使用且集成了贝叶斯优化。但首先,我们需要将上述建模代码转换为具有变量的类。
class GCNN:
def __init__(self,params):
self.params = params
def set_initial_node_state(self, node_set, node_set_name):
features = [
tf.keras.layers.Dense(self.params['feature_dim'],activation="relu")(node_set['Latitude']),
tf.keras.layers.Dense(self.params['feature_dim'],activation="relu")(node_set['Longitude']),
tf.keras.layers.Dense(self.params['feature_dim'],activation="relu")(node_set['Rank']),
tf.keras.layers.Dense(self.params['feature_dim'],activation="relu")(node_set['Wins']),
tf.keras.layers.Dense(self.params['feature_dim'],activation="relu")(node_set['Conf_wins'])
]
return tf.keras.layers.Concatenate()(features)
def set_initial_edge_state(self, edge_set, edge_set_name):
features = [
tf.keras.layers.Dense(self.params['feature_dim'],activation="relu")(edge_set['name_sim_score']),
tf.keras.layers.Dense(self.params['feature_dim'],activation="relu")(edge_set['euclidean_dist'])
]
return tf.keras.layers.Concatenate()(features)
def dense_layer(self,units=64):
regularizer = tf.keras.regularizers.l2(self.params['l2_reg'])
return tf.keras.Sequential([
tf.keras.layers.Dense(units,
kernel_regularizer=regularizer,
bias_regularizer=regularizer,
activation='relu'),
tf.keras.layers.Dropout(self.params['dropout'])])
def build_model(self):
input_graph = tf.keras.layers.Input(type_spec=self.params['trainset'].element_spec[0])
graph = tfgnn.keras.layers.MapFeatures(
node_sets_fn=self.set_initial_node_state,
edge_sets_fn=self.set_initial_edge_state
)(input_graph)
if self.params['loss']=='categorical_crossentropy':
for i in range(self.params['graph_updates']):
graph = tfgnn.keras.layers.GraphUpdate(
node_sets = {
'schools': tfgnn.keras.layers.NodeSetUpdate({
'games': tfgnn.keras.layers.SimpleConv(
message_fn = self.dense_layer(self.params['message_dim']),
reduce_type="sum",
receiver_tag=tfgnn.TARGET)},
tfgnn.keras.layers.NextStateFromConcat(
self.dense_layer(self.params['next_state_dim'])))})(graph)
logits = tf.keras.layers.Dense(12,activation='softmax')(graph.node_sets['schools'][tfgnn.HIDDEN_STATE])
else:
for i in range(self.params['graph_updates']):
graph = tfgnn.keras.layers.GraphUpdate(
edge_sets = {'games': tfgnn.keras.layers.EdgeSetUpdate(
next_state = tfgnn.keras.layers.NextStateFromConcat(
self.dense_layer(self.params['next_state_dim'])))},
node_sets = {
'schools': tfgnn.keras.layers.NodeSetUpdate({
'games': tfgnn.keras.layers.SimpleConv(
message_fn = self.dense_layer(self.params['message_dim']),
reduce_type="sum",
receiver_tag=tfgnn.TARGET)},
tfgnn.keras.layers.NextStateFromConcat(
self.dense_layer(self.params['next_state_dim'])))})(graph)
logits = tf.keras.layers.Dense(1,activation='sigmoid')(graph.edge_sets['games'][tfgnn.HIDDEN_STATE])
return tf.keras.Model(input_graph, logits)
def train_model(self,trial=True):
model = self.build_model()
model.compile(tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate']),
loss=self.params['loss'],
metrics=['Accuracy'])
callbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_loss',
mode='min',
verbose=1,
patience=self.params['patience'],
restore_best_weights=True)]
model.fit(self.params['trainset'].repeat(),
validation_data=self.params['full_dataset'],
steps_per_epoch=self.params['steps_per_epoch'],
epochs=self.params['epochs'],
verbose=0,
callbacks = callbacks)
loss = self.evaluate_model(model,trial=trial)
if trial == True:
sys.stdout.flush()
hypt_params = {
'graph_updates':self.params['graph_updates'],
'feature_dim':self.params['feature_dim'],
'next_state_dim':self.params['next_state_dim'],
'message_dim':self.params['message_dim'],
'l2_reg':self.params['l2_reg'],
'dropout':self.params['dropout'],
'learning_rate':self.params['learning_rate']}
print(hypt_params,'loss:',loss)
return {'loss': loss, 'status': STATUS_OK}
else:
print('loss:',loss)
return model
def evaluate_model(self,model,trial=True):
if self.params['loss'] == 'categorical_crossentropy':
yhat = model.predict(full_node_dataset)
yhat_df = node_full_adj.set_index('school').iloc[:,-12:].copy()
yhat_df.iloc[:,:] = yhat
yhat_df = yhat_df.apply(lambda x: x == x.max(), axis=1).astype(int)
yhat_df = yhat_df.dot(yhat_df.columns).to_frame().rename(columns={0:'conf_yhat'})
yhat_df = yhat_df['conf_yhat'].str.replace('conf_', '').astype(int).to_frame()
yhat_df['conf_actual'] = node_full_adj.set_index('school')['conference']
yhat_df = yhat_df.loc[yhat_df.index.isin(node_test.index)]
yhat_df['Accuracy'] = yhat_df['conf_yhat']==yhat_df['conf_actual']
loss = 1 - yhat_df['Accuracy'].mean()
else:
yhat = model.predict(full_edge_dataset)
yhat_df = edge_full_adj.copy().set_index(['source','target'])
yhat_df['conf_game_yhat'] = yhat.round(0)
yhat_df = yhat_df.loc[yhat_df.index.isin(
edge_test.set_index(['source','target']).index)]
yhat_df['loss'] = abs(yhat_df['conference_game'] - yhat_df['conf_game_yhat'])
loss = yhat_df['loss'].mean()
return loss
现在我们定义我们的参数。对于调优参数,我们可以明确地定义值(例如,‘dropout’: 0.1),或像下面我所做的那样定义 Hyperopt 可以实验的空间。‘hp.choice’会在你指定的选项之间进行选择,而‘hp.uniform’会在两个值之间选择。Hyperopt 文档中还有许多其他可用的选项。
params = {
### Tuning parameters ###
'graph_updates': hp.choice('graph_updates',[2,3,4]),
'feature_dim': hp.choice('feature_dim',[16,32,64,128]),
'message_dim': hp.choice('message_dim',[16,32,64,128]),
'next_state_dim': hp.choice('next_state_dim',[16,32,64,128]),
'l2_reg': hp.uniform('l2_reg',0.0,0.3),
'dropout': hp.choice('dropout',[0,0.125,0.25,0.375,0.5]),
'learning_rate': hp.uniform('learning_rate',0.0,0.1),
### Static parameters ###
'loss': 'categorical_crossentropy',
'epochs': 1000,
'steps_per_epoch':10, ### This could also be a tuned parameter
'patience':10,
'trainset':train_node_dataset,
'full_dataset':full_node_dataset
}
接下来,我们定义一个辅助函数,并将其与我们的参数一起插入到‘fmin’中。每次评估都是一个训练好的模型,因此根据你的硬件,这可能需要一段时间。如果速度对你来说太慢,可以考虑减少‘max_evals’的次数。我个人的经验法则是每个调优参数大约进行 15 次评估,因此我会明确地定义一些参数,以便根据评估次数的减少进行调整。
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
def tune_model(params):
return GCNN(params).train_model()
best = fmin(tune_model, params, algo=tpe.suggest,
max_evals=100, trials=Trials())
现在我们有了最佳的超参数,可以训练最终模型(注意:由于 TensorFlow 随机初始化其权重,你的准确度可能会略有不同)。
### Perameters from my hyperopt run ###
best = {'graph_updates': 4,
'feature_dim': 64,
'next_state_dim': 32,
'message_dim': 128,
'l2_reg': 0.095,
'dropout': 0,
'learning_rate': 0.0025
}
node_params = params
for param, value in best.items():
node_params[param] = value
node_model = GCNN(node_params).train_model(trial=False)
我们可以通过一些微小的调整来调优和训练我们的边缘模型:
params['loss'] = 'binary_crossentropy'
params['trainset'] = train_edge_dataset
params['full_dataset'] = full_edge_dataset
best = fmin(tune_model, params, algo=tpe.suggest,
max_evals=100, trials=Trials())
### Perameters from my hyperopt run ###
best = {'graph_updates': 4,
'feature_dim': 64,
'next_state_dim': 32,
'message_dim': 128,
'l2_reg': 0.095,
'dropout': 0,
'learning_rate': 0.0025
}
edge_params = params
for param, value in best.items():
edge_params[param] = value
edge_model = GCNN(edge_params).train_model(trial=False)
最后的思考
GNN 研究仍处于初期阶段。新的建模方法很可能会被发现。由于 TF-GNN 仍处于 alpha 状态,未来可能会有一些代码变更。如果您发现我尚未修复的更改或错误,请在下方留言,我会尽力更新此指南。如果您不喜欢这篇文章,可以在评论中将我与您最喜欢的历史独裁者作类比。否则,鼓掌或友好的评论将不胜感激。
我希望这份指南能成为更多人进入这一领域并进行实验的起点。请考虑这是您参与下一波 AI 浪潮的机会!
关于我
我是一名资深数据科学家和兼职自由职业者,拥有超过 12 年的经验。我始终希望能与他人建立联系,请随时:
-
在 LinkedIn 上与我连接
-
在 Twitter 上关注我
-
访问我的网站:www.modelforge.ai
-
查看我的其他文章
如果您有任何问题,请随时在下面留言。