Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

吴金笛、林亦霖校对王菁编辑张若楠翻译Matthew Smith作者

不同机器学习模型的决策边界(附代码)

作者前言

我使用Iris数据集训练了一系列机器学习模型,从数据中的极端值合成了新数据点,并测试了许多机器学习模型来绘制出决策边界,这些模型可根据这些边界在2D空间中进行预测,这对于阐明目的和了解不同机器学习模型如何进行预测会很有帮助。

前沿的机器学习

机器学习模型可以胜过传统的计量经济学模型,这并没有什么新奇的,但是作为研究的一部分,我想说明某些模型为什么以及如何进行分类预测。我想展示我的二分类模型所依据的决策边界,也就是展示数据进行分类预测的分区空间。该问题以及代码经过一些调整也能够适用于多分类问题

初始化

首先加载一系列程序包,然后新建一个logistic函数,以便稍后将log-odds转换为logistic概率函数。

library(dplyr)
library(patchwork)
library(ggplot2)
library(knitr)
library(kableExtra)
library(purrr)
library(stringr)
library(tidyr)
library(xgboost)
library(lightgbm)
library(keras)
library(tidyquant)
##################### Pre-define some functions


logit2prob <- function(logit){
  odds <- exp(logit)
  prob <- odds / (1 + odds)
  return(prob)
}

数据

我使用的iris数据集包含有关英国统计员Ronald Fisher在1936年收集的3种不同植物变量的信息。该数据集包含4种植物物种的不同特征,这些特征可区分33种不同物种(Setosa,Virginica和Versicolor)。但是,我的问题需要一个二元分类问题,而不是一个多分类问题。在下面的代码中,我导入了iris数据并删除了一种植物物种virginica,以将其从多重分类转变为二元分类问题

data(iris)
df <- iris %>%
  filter(Species != "virginica") %>%
  mutate(Species = +(Species == "versicolor"))
str(df)


## 'data.frame':    100 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : int  0 0 0 0 0 0 0 0 0 0 ...

我首先采用ggplot来绘制数据,以下储存的ggplot对象中,每个图仅更改x和y变量选择。

plt1 <- df %>%  ggplot(aes(x = Sepal.Width, y = Sepal.Length, color = factor(Species))) +  geom_point(size = 4) +  theme_bw(base_size = 15) +  theme(legend.position = "none")
plt2 <- df %>%  ggplot(aes(x = Petal.Length, y = Sepal.Length, color = factor(Species))) +  geom_point(size = 4) +  theme_bw(base_size = 15) +  theme(legend.position = "none")
plt3 <- df %>%  ggplot(aes(x = Petal.Width, y = Sepal.Length, color = factor(Species))) +  geom_point(size = 4) +  theme_bw(base_size = 15) +  theme(legend.position = "none")
plt3 <- df %>%  ggplot(aes(x = Sepal.Length, y = Sepal.Width, color = factor(Species))) +  geom_point(size = 4) +  theme_bw(base_size = 15) +  theme(legend.position = "none")
plt4 <- df %>%  ggplot(aes(x = Petal.Length, y = Sepal.Width, color = factor(Species))) +  geom_point(size = 4) +  theme_bw(base_size = 15) +  theme(legend.position = "none")
plt5 <- df %>%  ggplot(aes(x = Petal.Width, y = Sepal.Width, color = factor(Species))) +  geom_point(size = 4) +  theme_bw(base_size = 15) +  theme(legend.position = "none")
plt6 <- df %>%  ggplot(aes(x = Petal.Width, y = Sepal.Length, color = factor(Species))) +  geom_point(size = 4) +  theme_bw(base_size = 15) +  theme(legend.position = "none")

我还使用了新的patchwork 包,使展示ggplot结果变得很容易。下面的代码很直白的绘制了我们的图形(1个顶部图占满了网格空间的长度,2个中等大小的图,另一个单个图以及底部另外2个图)

    (plt1)    /
  (plt2 + plt3)

或者,我们可以将绘图重新布置为所需的任何方式,并通过以下方式进行绘图:

(plt1 + plt2) /
 (plt5 + plt6)

我觉得这看起来不错。

专业用户独享

本文为机器之心深度精选内容,专业认证后即可阅读全文
开启专业认证
工程机器学习决策边界
21
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

逻辑回归技术

逻辑回归(英语:Logistic regression 或logit regression),即逻辑模型(英语:Logit model,也译作“评定模型”、“分类评定模型”)是离散选择法模型之一,属于多重变量分析范畴,是社会学、生物统计学、临床、数量心理学、计量经济学、市场营销等统计实证分析的常用方法。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

随机森林技术

在机器学习中,随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。 Leo Breiman和Adele Cutler发展出推论出随机森林的算法。而"Random Forests"是他们的商标。这个术语是1995年由贝尔实验室的Tin Kam Ho所提出的随机决策森林(random decision forests)而来的。这个方法则是结合Breimans的"Bootstrap aggregating"想法和Ho的"random subspace method" 以建造决策树的集合。

决策边界技术

在具有两类的统计分类问题中,决策边界或决策曲面是一个超曲面,它将底层的向量空间分成两组,每组一个。分类器会将决策边界一侧的所有点分为属于一个类,而另一侧属于另一个类。也即二元分类或多类别分类问题中,模型学到的类别之间的分界线。

逻辑技术

人工智能领域用逻辑来理解智能推理问题;它可以提供用于分析编程语言的技术,也可用作分析、表征知识或编程的工具。目前人们常用的逻辑分支有命题逻辑(Propositional Logic )以及一阶逻辑(FOL)等谓词逻辑。

支持向量机技术

在机器学习中,支持向量机是在分类与回归分析中分析数据的监督式学习模型与相关的学习算法。给定一组训练实例,每个训练实例被标记为属于两个类别中的一个或另一个,SVM训练算法创建一个将新的实例分配给两个类别之一的模型,使其成为非概率二元线性分类器。SVM模型是将实例表示为空间中的点,这样映射就使得单独类别的实例被尽可能宽的明显的间隔分开。然后,将新的实例映射到同一空间,并基于它们落在间隔的哪一侧来预测所属类别。

目标函数技术

目标函数f(x)就是用设计变量来表示的所追求的目标形式,所以目标函数就是设计变量的函数,是一个标量。从工程意义讲,目标函数是系统的性能标准,比如,一个结构的最轻重量、最低造价、最合理形式;一件产品的最短生产时间、最小能量消耗;一个实验的最佳配方等等,建立目标函数的过程就是寻找设计变量与目标的关系的过程,目标函数和设计变量的关系可用曲线、曲面或超曲面表示。

分类问题技术

分类问题是数据挖掘处理的一个重要组成部分,在机器学习领域,分类问题通常被认为属于监督式学习(supervised learning),也就是说,分类问题的目标是根据已知样本的某些特征,判断一个新的样本属于哪种已知的样本类。根据类别的数量还可以进一步将分类问题划分为二元分类(binary classification)和多元分类(multiclass classification)。

过拟合技术

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价。

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

XGBoost技术

XGBoost是一个开源软件库,为C ++,Java,Python,R,和Julia提供了渐变增强框架。 它适用于Linux,Windows,MacOS。从项目描述来看,它旨在提供一个“可扩展,便携式和分布式的梯度提升(GBM,GBRT,GBDT)库”。 除了在一台机器上运行,它还支持分布式处理框架Apache Hadoop,Apache Spark和Apache Flink。 由于它是许多机器学习大赛中获胜团队的首选算法,因此它已经赢得了很多人的关注。

京东・算法工程师
虽然很简单,但是很实用。