Analysis of Spark's built-in logistic_regression example

import sys

import numpy as np
from pyspark.sql import SparkSession


D = 10  # Number of dimensions


# Read a batch of points from the input file into a NumPy matrix object. We operate on batches to
# make further computations faster.
# The data file contains lines of the form <label> <x1> <x2> ... <xD>. We load each block of these
# into a NumPy array of size numLines * (D + 1) and pull out column 0 vs the others in gradient().
def readPointBatch(iterator): 
    strs = list(iterator)
    matrix = np.zeros((len(strs), D + 1))
    for i, s in enumerate(strs):
        matrix[i] = np.fromstring(s.replace(',', ' '), dtype=np.float32, sep=' ')
    return [matrix]

if __name__ == "__main__":

    if len(sys.argv) != 3:
        print("Usage: logistic_regression <file> <iterations>", file=sys.stderr)
        exit(-1)

    print("""WARN: This is a naive implementation of Logistic Regression and is
      given as an example!
      Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py
      to see how ML's implementation is used.""", file=sys.stderr)

    spark = SparkSession\
        .builder\
        .appName("PythonLR")\
        .getOrCreate()

    points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\
        .mapPartitions(readPointBatch).cache() 
    iterations = int(sys.argv[2])

    # Initialize w to a random value
    w = 2 * np.random.ranf(size=D) - 1
    print("Initial w: " + str(w))

    # Compute logistic regression gradient for a matrix of data points
    def gradient(matrix, w):
        Y = matrix[:, 0]    # point labels (first column of input file)
        X = matrix[:, 1:]   # point coordinates
        # For each point (x, y), compute gradient function, then sum these up
        return ((1.0 / (1.0 + np.exp(-Y * X.dot(w))) - 1.0) * Y * X.T).sum(1)

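    # Sum two gradient matrices in place; mutating x is safe here because
    # every value reduce() passes in is a fresh array returned by gradient()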
    def add(x, y):
        x += y
        return x

    for i in range(iterations):
        print("On iteration %i" % (i + 1))
        w -= points.map(lambda m: gradient(m, w)).reduce(add)

    print("Final w: " + str(w))

    spark.stop()
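The update w -= points.map(lambda m: gradient(m, w)).reduce(add) is plain batch gradient descent with a fixed step size of 1. Assuming the labels are in {-1, +1} (which the formula in gradient() implies), the per-point logistic loss and its gradient are:

$$\ell(w; x, y) = \log\left(1 + e^{-y\,w^\top x}\right), \qquad \nabla_w \ell = \left(\frac{1}{1 + e^{-y\,w^\top x}} - 1\right) y\,x$$

gradient() evaluates the right-hand side for a whole partition at once: X.dot(w) computes all the inner products in one shot, and .sum(1) adds up the per-point gradient columns.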

1. The difference between map() and mapPartitions():

map() is applied to each individual element of the dataset, so its input is a single element.

mapPartitions() is applied to each partition of the dataset, so its input is an iterator (containing multiple elements); see the sketch below.
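A minimal sketch of the contrast (hypothetical toy data, assuming the same SparkSession as above):

rdd = spark.sparkContext.parallelize([1, 2, 3, 4], 2)

# map(): the function receives one element at a time
rdd.map(lambda x: x * x).collect()                 # [1, 4, 9, 16]

# mapPartitions(): the function receives an iterator over a whole
# partition and must return an iterable (here, a one-element list)
rdd.mapPartitions(lambda it: [sum(it)]).collect()  # [3, 7]

readPointBatch() follows the same pattern: it consumes a partition's iterator of lines and returns a one-element list holding the whole matrix, so points becomes an RDD of per-partition matrices rather than an RDD of rows.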

2. The enumerate() function

It returns an enumerate object that yields (index, value) pairs.

The following demonstrates enumerate():

>>> seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1))  # index starts at 1
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]

P.S. readPointBatch() is written quite elegantly.

3. NumPy's fromstring() converts a string into a one-dimensional array

Examples

>>> np.fromstring('1 2', dtype=int, sep=' ')
array([1, 2])
>>> np.fromstring('1, 2', dtype=int, sep=',')
array([1, 2])
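This is exactly what readPointBatch() relies on: the replace(',', ' ') call normalizes comma-separated input so a single sep=' ' parse handles both delimiters. A small sketch with a hypothetical input line:

>>> line = '1,2.5 3.5'  # hypothetical line mixing both delimiters
>>> np.fromstring(line.replace(',', ' '), dtype=np.float32, sep=' ')
array([1. , 2.5, 3.5], dtype=float32)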



