Reposted

Predicting Loan Risk with Random Forests on Apache Spark

Original article: Predicting Loan Credit Risk using Apache Spark Machine Learning Random Forests

Author: Carol McDonald, Solutions Architect at MapR

Translator: KK4SBB

Editor: Zhou Jianding (周建丁, zhoujd@csdn.net)

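In this article, I show how to use the random forest algorithm from Apache Spark's spark.ml library to classify the credit risk of bank loans. The spark.ml library is built on DataFrames and provides a rich set of APIs for building and tuning machine learning workflows (pipelines). Using spark.ml with DataFrames makes it straightforward to tune models automatically, which can improve model quality.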

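Classification

Classification is a family of supervised machine learning algorithms. It learns from examples whose labels are already known, such as transactions already confirmed as fraudulent or legitimate. It then predicts the class of new items, for example whether a new transaction is fraudulent. Classification needs a labeled dataset and a set of predefined features, from which it learns how to label new examples. You can think of the features as "yes or no" questions, and the label as the answer to those questions. For example, if an animal walks, swims and quacks like a duck, it gets the label "duck".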
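To make "features are questions, the label is the answer" concrete, here is a deliberately naive plain-Scala sketch; it is not Spark code. It assumes three hypothetical yes/no features. It "learns" by remembering the majority label seen for each combination of answers in the labeled training data. Real classifiers such as decision trees generalize to answer combinations they have never seen, which this lookup table cannot do.

```scala
// Illustrative sketch only: each feature is a yes/no question, the label is the answer.
// The model memorizes the majority label for every combination of answers it was trained on.
object DuckClassifier {
  // Hypothetical features: walks like a duck, swims like a duck, quacks like a duck
  final case class Animal(walksLikeDuck: Boolean, swimsLikeDuck: Boolean, quacksLikeDuck: Boolean)

  // Train: group labeled samples by their answers and keep the most common label per group
  def train(samples: Seq[(Animal, String)]): Map[Animal, String] =
    samples.groupBy(_._1).map { case (features, group) =>
      val majority = group.groupBy(_._2).maxBy(_._2.size)._1
      features -> majority
    }

  // Predict: look up the learned label; unseen answer combinations get a default
  def predict(model: Map[Animal, String], a: Animal, default: String = "unknown"): String =
    model.getOrElse(a, default)
}
```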

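Let's look at a credit-risk example for bank loans:

  • What are we trying to predict?
    • Whether a person will repay a loan on time
    • This is the label: the person's creditworthiness
  • What are the "yes or no" questions, or properties, that you can use to make the prediction?
    • The applicant's basic and social profile: job, age, savings, marital status, and so on
    • These are the features. You use them to build a classification model, extracting the ones that help the classification.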

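Decision Trees

A decision tree is a classification model that predicts a class or label from input features. At each node, it evaluates an expression on one feature and uses the result to choose the branch that leads to the next node. The following shows a decision tree for predicting credit risk. Each decision question is a node of the model, and the "yes" or "no" answers are the branches leading to its child nodes.

  • Question 1: Is the account balance greater than 200 yuan?
    • Question 2: Has the applicant been in the current job for more than 1 year?
      • Not creditable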
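The traversal just described can be sketched in plain Scala; this is an illustration, not how spark.ml represents its trees. Several details are assumptions filled in around the two questions listed above:

- the node types;
- the feature names `balance` and `employmentYears`;
- which branch of each question leads to which leaf.

```scala
// Sketch of decision-tree prediction: each internal node asks one question about the
// applicant and follows the "yes" or "no" branch until a leaf with a label is reached.
object CreditTree {
  sealed trait Node
  final case class Leaf(label: String) extends Node
  final case class Split(question: Map[String, Double] => Boolean, yes: Node, no: Node) extends Node

  // Walk from the root down to a leaf, answering one question per node
  @annotation.tailrec
  def predict(node: Node, applicant: Map[String, Double]): String = node match {
    case Leaf(label)       => label
    case Split(q, yes, no) => predict(if (q(applicant)) yes else no, applicant)
  }

  // Hypothetical tree built from the two example questions; leaf placement is assumed
  val tree: Node = Split(
    a => a("balance") > 200,
    Leaf("creditable"),
    Split(a => a("employmentYears") > 1, Leaf("creditable"), Leaf("not creditable"))
  )
}
```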


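Random Forests

Ensemble learning algorithms combine several machine learning models to obtain a better one. Random forest is a popular ensemble method for both classification and regression. It builds many decision trees, each trained on a different subset of the training data, and combines them into a single model. The final prediction combines the outputs of all the trees, which reduces variance and improves accuracy. In random forest classification, each tree's prediction counts as one vote, and the class with the most votes becomes the predicted class.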
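Here is a minimal plain-Scala sketch of the two mechanisms just described. The first is bootstrap sampling: drawing rows with replacement, so each tree trains on a different subset. The second is majority voting over the trees' predictions.

This is not Spark's implementation; Spark's random forest also samples a random subset of features at each split. The object and method names are hypothetical. Ties in the vote are broken toward the smaller class index, which is an arbitrary choice.

```scala
import scala.util.Random

// Sketch of the two ideas behind a random forest (not Spark's code):
// each tree gets its own bootstrap sample, and the forest predicts by majority vote.
object ForestVoting {
  // Draw data.size rows with replacement; different draws give each tree a different subset
  def bootstrapSample[A](data: IndexedSeq[A], rng: Random): IndexedSeq[A] =
    IndexedSeq.fill(data.size)(data(rng.nextInt(data.size)))

  // Each tree casts one vote (a class index); the most frequent class wins,
  // and ties go to the smaller class index
  def majorityVote(votes: Seq[Int]): Int = {
    require(votes.nonEmpty, "need at least one vote")
    votes.groupBy(identity).maxBy { case (label, vs) => (vs.size, -label) }._1
  }
}
```

For example, five trees voting `Seq(1, 0, 1, 1, 0)` produce class 1.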


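Analyzing Credit Risk with Spark Machine Learning

We use the German Credit dataset, which classifies people as good or bad credit risks based on a set of attributes. For each bank loan applicant, we have the following information: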


The German credit data is stored in a CSV file in the following format:

    1,1,18,4,2,1049,1,2,4,2,1,4,2,21,3,1,1,3,1,1,1
    1,1,9,4,0,2799,1,3,2,3,1,2,1,36,3,1,2,3,2,1,1
    1,2,12,2,9,841,2,4,2,2,1,4,1,23,3,1,1,2,1,1,1
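Each row holds 21 comma-separated values. The first is creditability (the label) and the remaining 20 are the features. Here is a plain-Scala sketch of splitting one row that way. The `CreditCsv` object and its `Row` type are hypothetical; the tutorial itself does this with Spark RDDs, `parseRDD`, `parseCredit` and a `Credit` case class.

```scala
// Sketch: split one CSV row into the label (first column) and the 20 feature values
object CreditCsv {
  final case class Row(label: Double, features: Vector[Double])

  def parseLine(line: String): Row = {
    val values = line.split(",").map(_.trim.toDouble).toVector
    Row(values.head, values.tail) // column 0 is creditability, the rest are features
  }
}
```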

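In this scenario, we build a random forest of decision trees to predict the creditability label/class from the following:

  • Label -> creditable or not creditable (1 or 0)
  • Features -> {account balance, credit history, loan purpose, and so on}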

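Software

This tutorial uses Spark 1.6.1.

  • You can download the code and data to run the examples from:
    https://github.com/caroljmcdonald/spark-ml-randomforest-creditrisk
  • The examples in this post can be run interactively in the Spark shell, started with the spark-shell command.
  • You can also run the code as a standalone application, as described in Getting Started with Spark on MapR Sandbox.

Following that tutorial, log in to the MapR Sandbox with username user01 and password mapr. Copy the sample data file to your sandbox home directory /user/user01 using scp. (Note: you may need to update Spark to a newer version first.) Then start the Spark shell:

    $ spark-shell --master local[1]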

Loading and Parsing the CSV Data File

First, we import the machine learning packages we need.

    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
    import org.apache.spark.ml.{ Pipeline, PipelineStage }
    // needed for the RDD type used in parseRDD below
    import org.apache.spark.rdd.RDD
    import sqlContext.implicits._
    import sqlContext._

We use a Scala case class to define the Credit schema; each instance corresponds to one line of the CSV file.

    // define the Credit schema
    case class Credit(
      creditability: Double,
      balance: Double, duration: Double, history: Double, purpose: Double, amount: Double,
      savings: Double, employment: Double, instPercent: Double, sexMarried: Double, guarantors: Double,
      residenceDuration: Double, assets: Double, age: Double, concCredit: Double, apartment: Double,
      credits: Double, occupation: Double, dependents: Double, hasPhone: Double, foreign: Double
    )

The function below parses one line of the data file into a Credit object. For the categorical values that the file codes starting from 1, it subtracts 1 so that their indices start at 0.

    // function to create a Credit object from an Array of Double
    def parseCredit(line: Array[Double]): Credit = {
      Credit(
        line(0),
        line(1) - 1, line(2), line(3), line(4), line(5),
        line(6) - 1, line(7) - 1, line(8), line(9) - 1, line(10) - 1,
        line(11) - 1, line(12) - 1, line(13), line(14) - 1, line(15) - 1,
        line(16) - 1, line(17) - 1, line(18) - 1, line(19) - 1, line(20) - 1
      )
    }

    // function to transform an RDD of Strings into an RDD of Array[Double]
    def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = {
      rdd.map(_.split(",")).map(_.map(_.toDouble))
    }
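To see what parseCredit does to a raw row without starting Spark, the same column adjustments can be written as a plain-Scala function over a vector. The set of shifted column positions is read off parseCredit above; the `ZeroBased` object name is hypothetical. Applied to the first sample CSV row, it should yield the values shown in the first row of `creditDF.show`.

```scala
// Sketch of parseCredit's adjustment: columns coded from 1 in the CSV are shifted down
// by one so their category codes start at 0; all other columns pass through unchanged.
object ZeroBased {
  // Column positions that parseCredit decrements (taken from the code above)
  val oneBasedColumns: Set[Int] = Set(1, 6, 7, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20)

  def shift(values: Vector[Double]): Vector[Double] =
    values.zipWithIndex.map { case (v, i) => if (oneBasedColumns(i)) v - 1 else v }
}
```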

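Next, we load the data from germancredit.csv into an RDD of Strings. The parseRDD function maps over that RDD, splitting each string and converting it into an array of Doubles. A second map applies parseCredit, turning each array of Doubles into a Credit object. Finally, toDF() converts the resulting RDD[Credit] into a DataFrame of Credit rows. The DataFrame is cached and registered as the temporary table credit so it can be queried with SQL later.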

    // load the data into an RDD, parse it, and convert it to a DataFrame
    val creditDF = parseRDD(sc.textFile("germancredit.csv")).map(parseCredit).toDF().cache()
    creditDF.registerTempTable("credit")

The DataFrame printSchema() method prints the schema to the console in tree form:

    // return the schema of this DataFrame
    creditDF.printSchema

    root
     |-- creditability: double (nullable = false)
     |-- balance: double (nullable = false)
     |-- duration: double (nullable = false)
     |-- history: double (nullable = false)
     |-- purpose: double (nullable = false)
     |-- amount: double (nullable = false)
     |-- savings: double (nullable = false)
     |-- employment: double (nullable = false)
     |-- instPercent: double (nullable = false)
     |-- sexMarried: double (nullable = false)
     |-- guarantors: double (nullable = false)
     |-- residenceDuration: double (nullable = false)
     |-- assets: double (nullable = false)
     |-- age: double (nullable = false)
     |-- concCredit: double (nullable = false)
     |-- apartment: double (nullable = false)
     |-- credits: double (nullable = false)
     |-- occupation: double (nullable = false)
     |-- dependents: double (nullable = false)
     |-- hasPhone: double (nullable = false)
     |-- foreign: double (nullable = false)

    // display the top 20 rows of the DataFrame
    creditDF.show

    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
    |          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|     9.0|    4.0|    0.0|2799.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
    |          1.0|    1.0|    12.0|    2.0|    9.0| 841.0|    1.0|       3.0|        2.0|       1.0|       0.0|              3.0|   0.0|23.0|       2.0|      0.0|    0.0|       1.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|    12.0|    4.0|    0.0|2122.0|    0.0|       2.0|        3.0|       2.0|       0.0|              1.0|   0.0|39.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
    |          1.0|    0.0|    12.0|    4.0|    0.0|2171.0|    0.0|       2.0|        4.0|       2.0|       0.0|              3.0|   1.0|38.0|       0.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
    |          1.0|    0.0|    10.0|    4.0|    0.0|2241.0|    0.0|       1.0|        1.0|       2.0|       0.0|              2.0|   0.0|48.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
    |          1.0|    0.0|     8.0|    4.0|    0.0|3398.0|    0.0|       3.0|        1.0|       2.0|       0.0|              3.0|   0.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
    |          1.0|    0.0|     6.0|    4.0|    0.0|1361.0|    0.0|       1.0|        2.0|       2.0|       0.0|              3.0|   0.0|40.0|       2.0|      1.0|    0.0|       1.0|       1.0|     0.0|    1.0|
    |          1.0|    3.0|    18.0|    4.0|    3.0|1098.0|    0.0|       0.0|        4.0|       1.0|       0.0|              3.0|   2.0|65.0|       2.0|      1.0|    1.0|       0.0|       0.0|     0.0|    0.0|
    |          1.0|    1.0|    24.0|    2.0|    3.0|3758.0|    2.0|       0.0|        1.0|       1.0|       0.0|              3.0|   3.0|23.0|       2.0|      0.0|    0.0|       0.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|    11.0|    4.0|    0.0|3905.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
    |          1.0|    0.0|    30.0|    4.0|    1.0|6187.0|    1.0|       3.0|        1.0|       3.0|       0.0|              3.0|   2.0|24.0|       2.0|      0.0|    1.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|     6.0|    4.0|    3.0|1957.0|    0.0|       3.0|        1.0|       1.0|       0.0|              3.0|   2.0|31.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    1.0|    48.0|    3.0|   10.0|7582.0|    1.0|       0.0|        2.0|       2.0|       0.0|              3.0|   3.0|31.0|       2.0|      1.0|    0.0|       3.0|       0.0|     1.0|    0.0|
    |          1.0|    0.0|    18.0|    2.0|    3.0|1936.0|    4.0|       3.0|        2.0|       3.0|       0.0|              3.0|   2.0|23.0|       2.0|      0.0|    1.0|       1.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|     6.0|    2.0|    3.0|2647.0|    2.0|       2.0|        2.0|       2.0|       0.0|              2.0|   0.0|44.0|       2.0|      0.0|    0.0|       2.0|       1.0|     0.0|    0.0|
    |          1.0|    0.0|    11.0|    4.0|    0.0|3939.0|    0.0|       2.0|        1.0|       2.0|       0.0|              1.0|   0.0|40.0|       2.0|      1.0|    1.0|       1.0|       1.0|     0.0|    0.0|
    |          1.0|    1.0|    18.0|    2.0|    3.0|3213.0|    2.0|       1.0|        1.0|       3.0|       0.0|              2.0|   0.0|25.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    1.0|    36.0|    4.0|    3.0|2337.0|    0.0|       4.0|        4.0|       2.0|       0.0|              3.0|   0.0|36.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    3.0|    11.0|    4.0|    0.0|7228.0|    0.0|       2.0|        1.0|       2.0|       0.0|              3.0|   1.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    0.0|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+

Once the DataFrame is created, you can explore the data with the DataFrame API or with SQL. Here are some examples of queries using the Scala DataFrame API.

describe computes statistics for numeric columns: count, mean, standard deviation, minimum and maximum.

    // compute statistics for balance
    creditDF.describe("balance").show

    +-------+-----------------+
    |summary|          balance|
    +-------+-----------------+
    |  count|             1000|
    |   mean|            1.577|
    | stddev|1.257637727110893|
    |    min|              0.0|
    |    max|              3.0|
    +-------+-----------------+

    // compute the average balance by creditability (the label)
    creditDF.groupBy("creditability").avg("balance").show

    +-------------+------------------+
    |creditability|      avg(balance)|
    +-------------+------------------+
    |          1.0|1.8657142857142857|
    |          0.0|0.9033333333333333|
    +-------------+------------------+
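For intuition about what `describe` reports, here is a plain-Scala sketch that computes the same five statistics for a sequence of values. It assumes the standard deviation is the sample standard deviation (dividing by n - 1), which we believe matches Spark's `stddev`. The `Describe` object and its `Summary` type are hypothetical names.

```scala
// Sketch of the five summary statistics that describe() reports for a numeric column
object Describe {
  final case class Summary(count: Long, mean: Double, stddev: Double, min: Double, max: Double)

  def describe(xs: Seq[Double]): Summary = {
    require(xs.size > 1, "need at least two values for a sample standard deviation")
    val n = xs.size
    val mean = xs.sum / n
    // sample variance: squared deviations from the mean divided by n - 1
    val variance = xs.map(x => (x - mean) * (x - mean)).sum / (n - 1)
    Summary(n.toLong, mean, math.sqrt(variance), xs.reduce(_ min _), xs.reduce(_ max _))
  }
}
```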

You can register a DataFrame as a temporary table under a given name and then run SQL statements through the sql method of SQLContext. Here are some example queries using sqlContext:

    sqlContext.sql("SELECT creditability, avg(balance) as avgbalance, avg(amount) as avgamt, avg(duration) as avgdur FROM credit GROUP BY creditability").show

    +-------------+------------------+------------------+------------------+
    |creditability|        avgbalance|            avgamt|            avgdur|
    +-------------+------------------+------------------+------------------+
    |          1.0|1.8657142857142857| 2985.442857142857|19.207142857142856|
    |          0.0|0.9033333333333333|3938.1266666666666|             24.86|
    +-------------+------------------+------------------+------------------+
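The GROUP BY query above averages three columns for each value of the label. Here is a plain-Scala sketch of the same computation over an in-memory collection. The `GroupAvg` object and its `Loan` type, a subset of the Credit columns, are hypothetical.

```scala
// Sketch: group rows by creditability and average balance, amount and duration per group,
// mirroring SELECT creditability, avg(...) ... GROUP BY creditability
object GroupAvg {
  final case class Loan(creditability: Double, balance: Double, amount: Double, duration: Double)

  def averagesByLabel(rows: Seq[Loan]): Map[Double, (Double, Double, Double)] =
    rows.groupBy(_.creditability).map { case (label, group) =>
      val n = group.size.toDouble
      val avgBalance = group.map(_.balance).sum / n
      val avgAmount = group.map(_.amount).sum / n
      val avgDuration = group.map(_.duration).sum / n
      (label, (avgBalance, avgAmount, avgDuration))
    }
}
```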

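Extracting Features

To build a classification model, you first extract the features that contribute most to the classification. In the German credit dataset, each example is labeled with one of two classes: 1 (creditable) or 0 (not creditable).

Each example consists of the following fields:

  • Label -> creditability: 0 or 1
  • Features -> {"balance", "duration", "history", "purpose", "amount", "savings", "employment", "instPercent", "sexMarried", "guarantors", "residenceDuration", "assets", "age", "concCredit", "apartment", "credits", "occupation", "dependents", "hasPhone", "foreign"}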

Defining the Feature Array


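(Image source: Learning Spark)
To be used by a machine learning algorithm, the features must be transformed and put into feature vectors: vectors of numbers, one per feature dimension.

Below, a VectorAssembler combines the feature columns into a single vector column and returns a new DataFrame.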

    // define the feature columns to put in the feature vector
    val featureCols = Array("balance", "duration", "history", "purpose", "amount",
      "savings", "employment", "instPercent", "sexMarried", "guarantors",
      "residenceDuration", "assets", "age", "concCredit", "apartment",
      "credits", "occupation", "dependents", "hasPhone", "foreign")
    // set the input and output column names
    val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
    // return a dataframe with all of the feature columns in a vector column
    val df2 = assembler.transform(creditDF)
    // the transform method produced a new column: features.
    df2.show
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
    |          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|
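The features column above is printed in sparse-vector notation, (size,[indices],[values]): 20 features in total, listing only the positions that hold non-zero values (Spark may store a row this way). The plain-Scala sketch below is not Spark's Vectors API; SparseSketch and Sparse are hypothetical names used only to show how a dense row maps to that notation. Applied to the first row above, it gives the indices 1,2,3,4,6,7,8,10,11,12,13,16, which matches the printed prefix.

```scala
// Plain-Scala sketch (not Spark code): how one row of numeric columns maps
// to the (size,[indices],[values]) form shown in the "features" column.
object SparseSketch {
  final case class Sparse(size: Int, indices: Vector[Int], values: Vector[Double]) {
    override def toString: String =
      s"($size,[${indices.mkString(",")}],[${values.mkString(",")}])"
  }

  // Keep only the non-zero entries, remembering their positions.
  def toSparse(dense: Vector[Double]): Sparse = {
    val nonZero = dense.zipWithIndex.filter { case (v, _) => v != 0.0 }
    Sparse(dense.size, nonZero.map(_._2), nonZero.map(_._1))
  }
}
```

Feature position i in the vector corresponds to position i in featureCols, so index 1 is duration and index 4 is amount.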

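Next, we use a StringIndexer to return a DataFrame with a label column added, created by indexing the creditability column. StringIndexer assigns indices by frequency, with the most frequent value getting index 0.0, which is why the row shown below has creditability 1.0 but label 0.0.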

    // Create a label column with the StringIndexer
    val labelIndexer = new StringIndexer().setInputCol("creditability").setOutputCol("label")
    val df3 = labelIndexer.fit(df2).transform(df2)
    // the transform method produced a new column: label.
    df3.show
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
    |          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|  0.0|
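The frequency-based indexing that StringIndexer performs can be sketched in plain Scala as follows. This is an illustration of the idea, not Spark's implementation: IndexerSketch is a hypothetical name, values are handled as strings, and breaking ties by the value itself is an assumption made only so the sketch is deterministic.

```scala
// Plain-Scala sketch of the StringIndexer idea (not Spark's implementation):
// distinct values are ranked by descending frequency, so the most frequent
// value receives index 0.0.
object IndexerSketch {
  // Returns value -> index, most frequent value first.
  def fit(values: Seq[String]): Map[String, Double] = {
    val counts = values.groupBy(identity).map { case (v, vs) => (v, vs.size) }
    counts.toSeq
      .sortBy { case (v, n) => (-n, v) } // ties broken by value: an assumption
      .zipWithIndex
      .map { case ((v, _), i) => (v, i.toDouble) }
      .toMap
  }

  // Applies a fitted mapping to a column of values.
  def transform(mapping: Map[String, Double], values: Seq[String]): Seq[Double] =
    values.map(mapping)
}
```

With three "1.0" values and one "0.0", fit maps "1.0" to 0.0 and "0.0" to 1.0, the same flip seen between creditability and label in the table above.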

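Below, the data is split into a training set and a test set: about 70% of the data is used to train the model and about 30% to test it (randomSplit samples rows randomly, so the proportions are approximate).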

    // split the dataframe into training and test data
    val splitSeed = 5043
    val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)
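The idea behind a seeded, weighted split can be sketched in plain Scala. This is not Spark's randomSplit code; SplitSketch is a hypothetical name. Each row draws one uniform number and lands in the split whose cumulative-weight range contains it, which is why the split sizes are only approximately 70/30 and why a fixed seed reproduces the same split.

```scala
import scala.util.Random

// Plain-Scala sketch of a seeded, weighted random split (the idea behind
// DataFrame.randomSplit, not Spark's code). Each row draws one uniform
// number in [0, 1) and goes to the split whose cumulative-weight range
// contains it, so split sizes are only approximately proportional.
object SplitSketch {
  def randomSplit[A](rows: Seq[A], weights: Seq[Double], seed: Long): Seq[Seq[A]] = {
    val total = weights.sum
    // Cumulative upper bounds, e.g. Seq(0.7, 0.3) -> Seq(0.7, 1.0)
    val bounds = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    val rng = new Random(seed) // same seed, same split
    val assigned = rows.map { r =>
      val u = rng.nextDouble()
      val i = bounds.indexWhere(u < _)
      // Guard against floating-point rounding at the last bound.
      (if (i < 0) weights.size - 1 else i, r)
    }
    weights.indices.map(k => assigned.filter(_._1 == k).map(_._2))
  }
}
```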

Train the Model


Next, we train a random forest classifier with the following parameters:

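  • maxDepth: maximum depth of each tree. Deeper trees can make the model more expressive, but they take longer to train.
  • maxBins: maximum number of bins used to discretize continuous features and to decide how to split on them at each node.
  • impurity: criterion used to compute the information gain ("gini" or "entropy").
  • featureSubsetStrategy: number of features considered for splits at each node; "auto" lets the algorithm choose (for a forest of classification trees it uses the square root of the number of features).
  • numTrees: number of trees in the forest.
  • seed: random number generator seed.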
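The two impurity criteria follow textbook formulas. With p_k the fraction of samples at a node carrying label k, Gini impurity is 1 - Σ p_k², entropy is -Σ p_k log2(p_k), and the information gain of a split is the parent's impurity minus the size-weighted impurity of its two children. The plain-Scala sketch below (ImpuritySketch is a hypothetical name) implements these formulas for illustration; it is not Spark's internal code.

```scala
// Plain-Scala sketch of the impurity measures accepted by setImpurity
// ("gini" and "entropy") and of the information gain of a split.
// Textbook formulas, not Spark internals.
object ImpuritySketch {
  // Fraction of samples carrying each distinct label.
  private def proportions(labels: Seq[Double]): Iterable[Double] =
    labels.groupBy(identity).values.map(_.size.toDouble / labels.size)

  // Gini impurity: 1 - sum(p_k^2)
  def gini(labels: Seq[Double]): Double =
    1.0 - proportions(labels).map(p => p * p).sum

  // Entropy in bits: -sum(p_k * log2(p_k))
  def entropy(labels: Seq[Double]): Double =
    -proportions(labels).map(p => p * math.log(p) / math.log(2)).sum

  // Parent impurity minus the size-weighted impurity of the two children.
  def infoGain(parent: Seq[Double], left: Seq[Double], right: Seq[Double],
               impurity: Seq[Double] => Double): Double = {
    val n = parent.size.toDouble
    impurity(parent) - (left.size / n) * impurity(left) - (right.size / n) * impurity(right)
  }
}
```

A split that separates a 50/50 node into two pure children has Gini gain 0.5 and entropy gain 1 bit, the best possible for two classes.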

Training the model means associating the input features with the labels of the samples that have those features.

    // create the classifier, set parameters for training
    val classifier = new RandomForestClassifier()
      .setImpurity("gini")
      .setMaxDepth(3)
      .setNumTrees(20)
      .setFeatureSubsetStrategy("auto")
      .setSeed(5043)
    // use the random forest classifier to train (fit) the model
    val model = classifier.fit(trainingData)
    // print out the random forest trees
    model.toDebugString
    res5: String =
    "RandomForestClassificationModel (uid=rfc_6c4ceb92ba78) with 20 trees
      Tree 0 (weight 1.0):
        If (feature 0 <= 1.0)
         If (feature 10 <= 0.0)
          If (feature 3 <= 6.0)
           Predict: 0.0
          Else (feature 3 > 6.0)
           Predict: 0.0
         Else (feature 10 > 0.0)
          If (feature 12 <= 63.0)
           Predict: 0.0
          Else (feature 12 > 63.0)
           Predict: 0.0
        Else (feature 0 > 1.0)
         If (feature 13 <= 1.0)
          If (feature 3 <= 3.0)
           Predict: 0.0
          Else (feature 3 > 3.0)
           Predict: 1.0
         Else (feature 13 > 1.0)
          If (feature 7 <= 1.0)
           Predict: 0.0
          Else (feature 7 > 1.0)
           Predict: 0.0
      Tree 1 (weight 1.0):
        If (feature 2 <= 1.0)
         If (feature 15 <= 0.0)
          If (feature 11 <= 0.0)
           Predict: 0.0
          Else (feature 11 > 0.0)
           Predict: 1.0
         Else (feature 15 > 0.0)
          If (feature 11 <= 0.0)
           Predict: 0.0
          Else (feature 11 > 0.0)
           Predict: 1.0
        Else (feature 2 > 1.0)
         If (feature 12 <= 31.0)
          If (feature 5 <= 0.0)
           Predict: 0.0
          Else (feature 5 > 0.0)
           Predict: 0.0
         Else (feature 12 > 31.0)
          If (feature 3 <= 4.0)
           Predict: 0.0
          Else (feature 3 > 4.0)
           Predict: 0.0
      Tree 2 (weight 1.0):
        If (feature 8 <= 1.0)
         If (feature 6 <= 2.0)
          If (feature 4 <= 10875.0)
           Predict: 0.0
          Else (feature 4 > 10875.0)
           Predict: 1.0
         Else (feature 6 > 2.0)
          If (feature 1 <= 36.0)
           Predict: 0.0
          Else (feature 1 > 36.0)
           Predict: 1.0
        Else (feature 8 > 1.0)
         If (feature 5 <= 0.0)
          If (feature 4 <= 4113.0)
           Predict: 0.0
          Else (feature 4 > 4113.0)
           Predict: 1.0
         Else (feature 5 > 0.0)
          If (feature 11 <= 2.0)
           Predict: 0.0
          Else (feature 11 > 2.0)
           Predict: 0.0
      Tree 3 ...
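Each tree in the debug output is read top-down: at a node, go to the "If" branch when feature(i) <= threshold, otherwise to the "Else" branch, until a Predict leaf is reached. Feature i is position i in featureCols, so for example feature 12 is age and feature 4 is amount. The plain-Scala sketch below encodes Tree 0 from the output above by hand; TreeSketch, Leaf and Node are hypothetical names, not Spark classes.

```scala
// Plain-Scala sketch of how a tree printed by toDebugString is evaluated:
// at each node go left if feature(i) <= threshold, otherwise right,
// until a leaf's prediction is reached.
object TreeSketch {
  sealed trait Tree
  final case class Leaf(prediction: Double) extends Tree
  final case class Node(feature: Int, threshold: Double, left: Tree, right: Tree) extends Tree

  @annotation.tailrec
  def predict(tree: Tree, x: IndexedSeq[Double]): Double = tree match {
    case Leaf(p)          => p
    case Node(f, t, l, r) => predict(if (x(f) <= t) l else r, x)
  }

  // Tree 0 from the model.toDebugString output, transcribed by hand.
  val tree0: Tree =
    Node(0, 1.0,
      Node(10, 0.0,
        Node(3, 6.0, Leaf(0.0), Leaf(0.0)),
        Node(12, 63.0, Leaf(0.0), Leaf(0.0))),
      Node(13, 1.0,
        Node(3, 3.0, Leaf(0.0), Leaf(1.0)),
        Node(7, 1.0, Leaf(0.0), Leaf(0.0))))
}
```

For a row with balance 0.0, residenceDuration 2.0 and age 28.0, Tree 0 follows feature 0 <= 1.0, then feature 10 > 0.0, then feature 12 <= 63.0, and predicts 0.0.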

Test the Model

Next, we use the model to make predictions on the test data.

    // run the model on test features to get predictions
    val predictions = model.transform(testData)
    // As you can see, the model transform produced new columns: rawPrediction, probability and prediction.
    predictions.show
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|       rawPrediction|         probability|prediction|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |          0.0|    0.0|    12.0|    0.0|    5.0|1108.0|    0.0|       3.0|        4.0|       2.0|       0.0|              2.0|   0.0|28.0|       2.0|      1.0|    1.0|       2.0|       0.0|     0.0|    0.0|(20,[1,3,4,6,7,8,...|  1.0|[14.1964586927573...|[0.70982293463786...|       0.0|
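In the row above, rawPrediction starts with 14.196... and probability with 0.7098..., and 14.196 / 20 = 0.7098: the probability vector appears to be the rawPrediction normalized over the 20 trees. The non-integer raw value suggests that each tree contributes the class-probability vector of the leaf the sample reaches, rather than a single hard vote; that reading is an inference from these numbers. The plain-Scala sketch below (ForestVoteSketch is a hypothetical name, not Spark's code) shows that combination step under this assumption.

```scala
// Plain-Scala sketch (not Spark's code) of combining per-tree outputs:
// each tree contributes the class-probability vector of the leaf the
// sample reaches; the forest sums them (raw), normalizes the sum
// (probability) and predicts the index of the largest entry.
object ForestVoteSketch {
  final case class Output(raw: Vector[Double], probability: Vector[Double], prediction: Double)

  def combine(perTree: Seq[Vector[Double]]): Output = {
    // Element-wise sum of the per-tree vectors.
    val raw = perTree.reduce((a, b) => a.zip(b).map { case (x, y) => x + y })
    val total = raw.sum
    val probability = raw.map(_ / total)
    // Arg-max: the first class with the largest probability wins.
    val prediction = probability.zipWithIndex
      .reduce((best, cur) => if (cur._1 > best._1) cur else best)._2.toDouble
    Output(raw, probability, prediction)
  }
}
```

The predicted value is a label index, so in the row above prediction 0.0 means the model predicts the class indexed 0.0 (creditworthy), while the true label is 1.0.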

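Then we evaluate the predictions with a BinaryClassificationEvaluator, which compares the predictions with the actual labels and returns a metric, by default the area under the ROC curve (AUC). In this example the AUC is about 0.78. Although the variable below is named accuracy, the value is the AUC, not the fraction of correctly classified samples.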

    // create an Evaluator for binary classification, which expects two input columns: rawPrediction and label.
    val evaluator = new BinaryClassificationEvaluator().setLabelCol("label")
    // Evaluates predictions and returns a scalar metric areaUnderROC (larger is better).
    val accuracy = evaluator.evaluate(predictions)
    accuracy: Double = 0.7824906081835722
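The area under the ROC curve equals the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one, with ties counted as half. The plain-Scala sketch below computes AUC that way from (score, label) pairs. It only illustrates what the metric measures: AucSketch is a hypothetical name, and Spark computes the value differently, by integrating the ROC curve over score thresholds.

```scala
// Plain-Scala sketch of what areaUnderROC measures (not Spark's algorithm):
// the chance that a random positive (label 1.0) gets a higher score than a
// random negative (label 0.0), counting ties as half.
object AucSketch {
  def auc(scored: Seq[(Double, Double)]): Double = { // (score, label)
    val pos = scored.collect { case (s, 1.0) => s }
    val neg = scored.collect { case (s, 0.0) => s }
    require(pos.nonEmpty && neg.nonEmpty, "need both classes")
    val wins = for (p <- pos; n <- neg) yield
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    wins.sum / (pos.size.toDouble * neg.size)
  }
}
```

An AUC of 0.5 means the scores rank samples no better than chance. An AUC of 1.0 means every positive is ranked above every negative.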

Using an ML Pipeline

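Next, we train the model using a pipeline, which may give better results. Combined with a parameter grid and cross-validation, a pipeline provides a simple way to compare different combinations of parameters, a process called grid search: you set up the parameter values to test, and MLlib tests all of their combinations. Pipelines make it easy to tune an entire model-building workflow at once, rather than tuning each element separately.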

Below, we use a ParamGridBuilder to construct the parameter grid.

    // We use a ParamGridBuilder to construct a grid of parameters to search over
    val paramGrid = new ParamGridBuilder()
      .addGrid(classifier.maxBins, Array(25, 28, 31))
      .addGrid(classifier.maxDepth, Array(4, 6, 8))
      .addGrid(classifier.impurity, Array("entropy", "gini"))
      .build()
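The grid built above is the Cartesian product of the three value lists: 3 × 3 × 2 = 18 parameter combinations. The plain-Scala sketch below shows that expansion. GridSketch is a hypothetical name, and values are typed as Any only for brevity; Spark's ParamGridBuilder returns typed ParamMaps instead.

```scala
// Plain-Scala sketch of what ParamGridBuilder.build() produces: the
// Cartesian product of every value list, one map per combination.
// Values are kept as Any for brevity; Spark uses typed ParamMaps.
object GridSketch {
  def build(grid: Seq[(String, Seq[Any])]): Seq[Map[String, Any]] =
    grid.foldLeft(Seq(Map.empty[String, Any])) { case (acc, (name, values)) =>
      // Extend every partial combination with each value of this parameter.
      for (m <- acc; v <- values) yield m + (name -> v)
    }
}
```

Calling GridSketch.build with maxBins 25/28/31, maxDepth 4/6/8 and impurity entropy/gini returns 18 distinct maps, each setting all three parameters.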

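Next, create and set up a Pipeline. A Pipeline consists of a sequence of stages, each of which is either an Estimator or a Transformer. Here the pipeline has a single stage, the random forest classifier.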

    val steps: Array[PipelineStage] = Array(classifier)
    val pipeline = new Pipeline().setStages(steps)
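How a pipeline runs its stages can be sketched in plain Scala. The types below (Rows, Stage, Transformer, Estimator, PipelineModel in PipelineSketch) are hypothetical simplifications, not Spark's classes. Stages run in order. A Transformer is applied as-is, while an Estimator is first fitted on the data produced by the earlier stages and its fitted Transformer is then kept in the resulting model. For simplicity this sketch transforms the data after every stage.

```scala
// Plain-Scala sketch of the Pipeline idea (hypothetical types, not Spark's
// classes): stages run in order; an Estimator is fitted on the output of the
// earlier stages and becomes a Transformer in the fitted PipelineModel.
object PipelineSketch {
  type Rows = Seq[Double]
  sealed trait Stage
  trait Transformer extends Stage { def transform(rows: Rows): Rows }
  trait Estimator extends Stage { def fit(rows: Rows): Transformer }

  final case class PipelineModel(stages: Seq[Transformer]) extends Transformer {
    def transform(rows: Rows): Rows = stages.foldLeft(rows)((r, t) => t.transform(r))
  }

  def fit(stages: Seq[Stage], rows: Rows): PipelineModel = {
    val (fitted, _) = stages.foldLeft((Vector.empty[Transformer], rows)) {
      case ((done, data), t: Transformer) => (done :+ t, t.transform(data))
      case ((done, data), e: Estimator) =>
        val t = e.fit(data) // fit on data already transformed by earlier stages
        (done :+ t, t.transform(data))
    }
    PipelineModel(fitted)
  }
}
```

In the tutorial's pipeline the only stage is the RandomForestClassifier (an Estimator), so fitting the pipeline amounts to fitting the classifier.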

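We use a CrossValidator for model selection. A CrossValidator takes an Estimator, a set of ParamMaps, and an Evaluator. Note that cross-validation is expensive: it fits one model per ParamMap per fold, so the 18 parameter combinations and 10 folds used here mean 180 fits, plus a final fit with the best parameters.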

    // Evaluator used by cross-validation to score each ParamMap (areaUnderROC by default)
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(10)

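The CrossValidator optimizes the model automatically by working through the parameter grid. For each ParamMap, it trains the Estimator on k-1 folds, scores the result on the held-out fold with the Evaluator, and averages the scores across the folds. It then fits the final Estimator on all of the training data passed to fit, using the best ParamMap.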


    // When fit is called, the pipeline stages are executed in order.
    // cv.fit runs cross-validation and chooses the best set of parameters.
    // It returns a CrossValidatorModel whose bestModel is the PipelineModel
    // (fitted models and transformers) trained with the best ParamMap.
    val pipelineFittedModel = cv.fit(trainingData)
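The selection procedure described above can be sketched in plain Scala. This is an illustration of k-fold model selection, not Spark's CrossValidator. CvSketch is a hypothetical name, and fit and score are hypothetical function parameters standing in for Estimator.fit and Evaluator.evaluate; larger scores are better, as with areaUnderROC. Folds are assigned round-robin here for determinism, whereas Spark samples its folds randomly.

```scala
// Plain-Scala sketch of k-fold model selection, the idea behind
// CrossValidator (not Spark's code). `fit` and `score` are hypothetical
// stand-ins for Estimator.fit and Evaluator.evaluate; larger scores are
// better. Folds are assigned round-robin here (Spark samples them randomly).
object CvSketch {
  def select[P, R, M](rows: Seq[R], params: Seq[P], k: Int)(fit: (Seq[R], P) => M)(
      score: (M, Seq[R]) => Double): (P, M) = {
    val indexed = rows.zipWithIndex

    // Average validation score of one parameter setting over the k folds.
    def meanScore(p: P): Double = {
      val scores = (0 until k).map { fold =>
        val train = indexed.collect { case (r, i) if i % k != fold => r }
        val valid = indexed.collect { case (r, i) if i % k == fold => r }
        score(fit(train, p), valid)
      }
      scores.sum / k
    }

    // Best setting by mean score, then one final fit on all rows.
    val best = params.map(p => (p, meanScore(p)))
      .reduce((a, b) => if (b._2 > a._2) b else a)._1
    (best, fit(rows, best))
  }
}
```

The number of fit calls is params.size × k + 1, which gives the 18 × 10 + 1 = 181 fits for the grid and fold count used in this tutorial.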

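Now we can use the best model found by cross-validation to make predictions on the test data and compare them with the labels. The AUC rises to about 0.82, up from about 0.78 for the first model.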

    // call transform to make predictions on test data. The fitted model will use the best model found
    val predictions = pipelineFittedModel.transform(testData)
    val accuracy = evaluator.evaluate(predictions)
    accuracy: Double = 0.8204386232104784
    // RegressionMetrics on (prediction, label) pairs
    val rm2 = new RegressionMetrics(
      predictions.select("prediction", "label").rdd.map(x =>
        (x(0).asInstanceOf[Double], x(1).asInstanceOf[Double])))
    println("MSE: " + rm2.meanSquaredError)
    println("MAE: " + rm2.meanAbsoluteError)
    println("RMSE: " + rm2.rootMeanSquaredError)
    println("R Squared: " + rm2.r2)
    println("Explained Variance: " + rm2.explainedVariance + "\n")
    MSE: 0.2575250836120402
    MAE: 0.25752508361204013
    RMSE: 0.5074692932700856
    R Squared: -0.1687988628287138
    Explained Variance: 0.15466269952237702
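Because both prediction and label are 0.0 or 1.0, every error is 0 or 1. Squared and absolute errors therefore coincide, and MSE and MAE both equal the misclassification rate. The 0.2575 above corresponds to a classification accuracy of about 0.74 (1 - 0.2575) on the test set. The plain-Scala sketch below (BinaryErrorSketch is a hypothetical name) demonstrates this identity on (prediction, label) pairs.

```scala
// Plain-Scala sketch: for 0/1 predictions and labels every error is 0 or 1,
// so squared and absolute errors coincide and both averages equal the
// misclassification rate; accuracy is one minus that rate.
object BinaryErrorSketch {
  def mse(pairs: Seq[(Double, Double)]): Double =
    pairs.map { case (p, l) => (p - l) * (p - l) }.sum / pairs.size

  def mae(pairs: Seq[(Double, Double)]): Double =
    pairs.map { case (p, l) => math.abs(p - l) }.sum / pairs.size

  def accuracy(pairs: Seq[(Double, Double)]): Double =
    pairs.count { case (p, l) => p == l }.toDouble / pairs.size
}
```

In the spark-shell session itself, the same accuracy can be computed directly from the predictions DataFrame, for example with predictions.filter("prediction = label").count.toDouble / predictions.count.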
Source: http://dataunion.org/25243.html