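When I worked on ALS-related stuff before, I mostly used my company's internal tools. Today I tried Spark's MLlib for the first time and tested the waters with a small dataset of just a few MB.
I ran it and... wow. It actually threw a StackOverflowError...
Source code: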
from pyspark.mllib.recommendation import ALS
from pyspark import SparkContext

if __name__ == '__main__':
    # sc = SparkSession\
    #     .builder\
    #     .appName("PythonWordCount")\
    #     .getOrCreate()
    sc = SparkContext(appName="PythonWordCount")
    # Each line is "user product rating", separated by single spaces
    data = sc.textFile("CollaborativeFiltering.txt", 20)
    ratings = data.map(lambda line: [float(x) for x in line.split(' ')]).persist()
    rank = 10  # number of latent factors
    n = 30     # number of ALS iterations
    model = ALS.train(ratings, rank, n)
    # Predict every (user, product) pair that appears in the training data
    testdata = ratings.map(lambda r: (int(r[0]), int(r[1])))
    predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
    # Join actual and predicted ratings on the (user, product) key
    ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions).persist()
    MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count()
    print("Mean Squared Error = " + str(MSE))
    ratesAndPreds.unpersist()
    sc.stop()
Error message:
2017-11-24 17:15:23 [INFO] ShuffleMapStage 66 (flatMap at ALS.scala:1272) failed in Unknown s due to Job aborted due to stage failure: Task serialization failed: java.lang.StackOverflowError
java.lang.StackOverflowError
at java.io.ObjectOutputStream$BlockDataOutputStream.write(ObjectOutputStream.java:1841)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1534)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
at scala.collection.immutable.$colon$colon.writeObject(List.scala:379)
at sun.reflect.GeneratedMethodAccessor15.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeWriteObject(ObjectStreamClass.java:1028)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1496)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
Tears + tears + tears
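Later I began to suspect the lineage was too long, so I googled it and asked some experts.
Sure enough, that was it. In an iterative computation like ALS, every iteration stacks new RDDs on top of the previous ones, so the lineage (the chain of RDD dependencies) grows rapidly. When Spark serializes a task, it walks that dependency chain recursively (the nested ObjectOutputStream.writeObject calls in the trace above), so the stack space needed grows with the lineage until the stack finally overflows.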
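To make the mechanism concrete, here is a toy analogy in plain Python, not Spark code. Each "iteration" derives a new node from the previous one, the way each ALS iteration builds new RDDs on top of the old ones, and serializing the final node recurses through the whole chain, one stack frame per level. Python raises RecursionError where the JVM raised StackOverflowError. The Node class and the helpers below are hypothetical, and "checkpointing" is modeled simply as keeping a node's data while dropping its link to the parent.

```python
class Node:
    """Toy stand-in for an RDD: some data plus a link to the parent it was derived from."""
    def __init__(self, data, parent=None):
        self.data = data
        self.parent = parent


def build_lineage(iterations, checkpoint_interval=0):
    """Derive one new node per iteration from the previous node.

    If checkpoint_interval > 0, every checkpoint_interval-th node is
    'checkpointed': it keeps its data but drops the link to its parent."""
    node = Node(0)
    for i in range(1, iterations + 1):
        node = Node(i, parent=node)
        if checkpoint_interval > 0 and i % checkpoint_interval == 0:
            node.parent = None  # lineage truncated here
    return node


def lineage_depth(node):
    """Count how many ancestors are still reachable from this node."""
    depth = 0
    while node.parent is not None:
        node = node.parent
        depth += 1
    return depth


def serialize(node):
    """Recursively serialize a node and all of its ancestors (one stack frame per level)."""
    if node is None:
        return None
    return {"data": node.data, "parent": serialize(node.parent)}


def serializes_ok(node):
    """True if the whole chain can be serialized without exhausting the stack."""
    try:
        serialize(node)
        return True
    except RecursionError:
        return False


deep = build_lineage(100_005)
print(lineage_depth(deep), serializes_ok(deep))  # 100005 False
cut = build_lineage(100_005, checkpoint_interval=10)
print(lineage_depth(cut), serializes_ok(cut))    # 5 True
```

The data per node is the same in both runs; only the length of the chain that has to be walked during serialization changes.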
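The fix for this kind of problem:
Add sc.setCheckpointDir(path) to the code to explicitly set a checkpoint directory, and the problem goes away. The reason is that Spark's ALS only checkpoints its intermediate RDDs (which cuts the lineage) when a checkpoint directory has been set; by default it does so every 10 iterations. Of course, this brings a problem of its own: checkpointing writes the intermediate data to disk, so as the data grows, disk I/O becomes the bottleneck. I haven't solved that part yet. If any of you experts have a better solution, feel free to contact me~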
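The I/O trade-off comes down to how often you checkpoint: fewer checkpoints mean fewer disk writes but a longer lineage between them. (The DataFrame-based pyspark.ml ALS exposes this as its checkpointInterval parameter.) Below is a minimal sketch of that arithmetic under a simplified, hypothetical rule: iterations are numbered from 1 and a checkpoint is written whenever the iteration number is a multiple of the interval. Spark's actual loop may differ in detail, and checkpoint_plan is an illustrative helper, not a Spark API.

```python
def checkpoint_plan(max_iter, interval):
    """Hypothetical model of periodic checkpointing.

    Iterations run 1..max_iter; a checkpoint is written whenever
    iteration % interval == 0 (interval <= 0 disables checkpointing).
    Returns (number of checkpoint writes, longest lineage in iterations)."""
    writes = 0
    longest = current = 0
    for i in range(1, max_iter + 1):
        current += 1                 # one more iteration stacked on the lineage
        longest = max(longest, current)
        if interval > 0 and i % interval == 0:
            writes += 1              # intermediate data written to the checkpoint dir
            current = 0              # lineage truncated
    return writes, longest


# n = 30 iterations, as in the script above
for interval in (0, 2, 10):
    print(interval, checkpoint_plan(30, interval))
# 0 (0, 30)
# 2 (15, 2)
# 10 (3, 10)
```

Picking the interval is a matter of finding the largest value whose lineage still serializes safely, since every step smaller than that only adds disk writes.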
The modified code:
from pyspark.mllib.recommendation import ALS
from pyspark import SparkContext

if __name__ == '__main__':
    # sc = SparkSession\
    #     .builder\
    #     .appName("PythonWordCount")\
    #     .getOrCreate()
    sc = SparkContext(appName="PythonWordCount")
    # The fix: give Spark a checkpoint directory so ALS can periodically
    # checkpoint its intermediate RDDs and cut the lineage.
    # On a cluster this should be a shared path (e.g. HDFS), not a local directory.
    sc.setCheckpointDir('checkpoint')
    data = sc.textFile("CollaborativeFiltering.txt", 20)
    ratings = data.map(lambda line: [float(x) for x in line.split(' ')]).persist()
    rank = 10  # number of latent factors
    n = 30     # number of ALS iterations
    # DataFrame-based alternative (pyspark.ml.recommendation.ALS, not the mllib class imported above):
    # ALS().setCheckpointInterval(2).setMaxIter(100).setRank(10).setAlpha(0.1)
    model = ALS.train(ratings, rank, n)
    testdata = ratings.map(lambda r: (int(r[0]), int(r[1])))
    predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
    ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions).persist()
    MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count()
    print("Mean Squared Error = " + str(MSE))
    ratesAndPreds.unpersist()
    sc.stop()
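For reference, the evaluation step at the end of the script (join actual and predicted ratings on the (user, product) key, then average the squared differences) is the same computation as this plain-Python sketch on small made-up data. mean_squared_error is a hypothetical helper, and it assumes each (user, product) pair appears at most once among the predictions.

```python
def mean_squared_error(ratings, predictions):
    """ratings and predictions are iterables of (user, product, value) triples."""
    # Index predictions by (user, product), the same key used in the RDD join
    predicted = {(int(u), int(p)): v for u, p, v in predictions}
    # Inner join: keep only pairs that have both an actual and a predicted rating
    squared_errors = [
        (actual - predicted[(int(u), int(p))]) ** 2
        for u, p, actual in ratings
        if (int(u), int(p)) in predicted
    ]
    # Sum of squared errors divided by the number of joined pairs
    return sum(squared_errors) / len(squared_errors)


ratings = [(1.0, 1.0, 5.0), (1.0, 2.0, 3.0), (2.0, 1.0, 4.0)]
predictions = [(1, 1, 4.5), (1, 2, 3.0), (2, 1, 5.0)]
print(mean_squared_error(ratings, predictions))  # about 0.4167, i.e. (0.25 + 0 + 1) / 3
```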