import breeze.linalg.{DenseMatrix, DenseVector}
import org.apache.spark.rdd.RDD
/**
* Created by legotime
*/
/**
 * A single stored cell of a coordinate-format matrix.
 *
 * @param i     row index
 * @param j     column index
 * @param value value stored at (i, j)
 */
case class MatrixEntry(
    i: Long,
    j: Long,
    value: Double)

/**
 * An indexed array of doubles.
 *
 * Note: case-class equality compares `value` by reference, because it is an `Array`.
 *
 * @param i     index of this entry
 * @param value payload values
 */
case class VectorEntry(
    i: Long,
    value: Array[Double])
/**
 * A distributed matrix stored as an RDD of (row, column, value) coordinate entries.
 *
 * Cells with no stored entry read as 0.0 through the accessors below. When `nRows` /
 * `nCols` are non-positive, they are derived on first use from the largest row/column
 * index present in `entries` (see `computeSize`).
 *
 * NOTE(review): `getColumn(x)` returns ROW x and `getRow(y)` returns COLUMN y. The names
 * are swapped relative to what they return. They are kept as-is for API compatibility.
 *
 * @param entries coordinate entries; duplicates for the same (i, j) are not merged
 * @param nRows   number of rows, or a non-positive value to compute it on demand
 * @param nCols   number of columns, or a non-positive value to compute it on demand
 */
class bigMatrix(val entries: RDD[MatrixEntry],
                var nRows: Long,
                var nCols: Long) extends Serializable {

  /** Builds a matrix whose dimensions are inferred from `entries` on first use. */
  def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L)

  /** Number of columns; computed from `entries` if not yet known (non-positive). */
  def numCols(): Long = {
    if (nCols <= 0L) {
      computeSize()
    }
    nCols
  }

  /** Number of rows; computed from `entries` if not yet known (non-positive). */
  def numRows(): Long = {
    if (nRows <= 0L) {
      computeSize()
    }
    nRows
  }

  /**
   * Returns the value stored at (row, col), or 0.0 when no entry exists there.
   *
   * Negative indices count from the end (-1 is the last row/column), matching the
   * `[-n, n)` range reported in the error message. If several entries share the same
   * coordinates, the last one returned by `collect()` wins.
   *
   * @throws IndexOutOfBoundsException if row is outside [-numRows, numRows) or col is
   *                                   outside [-numCols, numCols)
   */
  def apply(row: Long, col: Long): Double = {
    val nr = numRows()
    val nc = numCols()
    if (row < -nr || row >= nr || col < -nc || col >= nc) {
      throw new IndexOutOfBoundsException(s"${(row, col)} not in [-$nr,$nr) x [-$nc,$nc)")
    }
    // Resolve negative indices first; comparing them raw would never match an entry.
    val r = if (row < 0L) row + nr else row
    val c = if (col < 0L) col + nc else col
    // Filter on the executors so only the matching entries reach the driver.
    entries
      .filter(e => e.i == r && e.j == c)
      .map(_.value)
      .collect()
      .lastOption
      .getOrElse(0.0)
  }

  /**
   * Returns `getRow(row)` after validating and normalising the index.
   *
   * NOTE(review): `getRow` selects entries whose COLUMN index equals its argument. This
   * method keeps delegating to it, but checks bounds against `numRows` as the original
   * error message did. Confirm which orientation callers expect.
   *
   * @throws IndexOutOfBoundsException if row is outside [-numRows, numRows)
   */
  def apply(row: Long): DenseVector[Double] = {
    val nr = numRows()
    if (row < -nr || row >= nr) {
      val nc = numCols()
      throw new IndexOutOfBoundsException(s"$row not in [-$nr,$nr) x [-$nc,$nc)")
    }
    getRow(if (row < 0L) row + nr else row)
  }

  /** Alias for `apply(i)`. */
  def valueAt(i: Int): DenseVector[Double] = apply(i)

  /** Alias for `apply(row, col)`. */
  def valueAt(row: Long, col: Long): Double = apply(row, col)

  /**
   * Counts how many stored entries hold exactly `target`.
   *
   * @param target the value to look for
   * @return number of matching entries
   */
  def valueCount(target: Double): Long = entries.filter(_.value == target).count()

  /**
   * Finds the coordinates of every stored entry equal to `target`.
   *
   * @param target the value to look for
   * @return an n x 2 matrix; row k holds (i, j) of the k-th match, n = number of matches
   */
  def valueIndex(target: Double): DenseMatrix[Int] = {
    // Collect only the matches. This is one job, instead of a count job plus a full collect.
    val hits = entries.filter(_.value == target).collect()
    val mat = DenseMatrix.zeros[Int](hits.length, 2)
    hits.zipWithIndex.foreach { case (MatrixEntry(i, j, _), k) =>
      mat(k, 0) = i.toInt
      mat(k, 1) = j.toInt
    }
    mat
  }

  /**
   * Returns the `n` largest values of row `i`, in descending order.
   *
   * @param i index of the row
   * @param n how many top values to return; must be in [0, numCols]
   * @throws IllegalArgumentException if n is out of range
   */
  def getRowTopN(i: Long, n: Int): DenseVector[Double] = {
    require(n >= 0 && n <= numCols(),
      s"getRowTopN function input arg must be in [0, matrix length], got $n")
    // getColumn(i) extracts row i (see the class-level note).
    topN(getColumn(i), n)
  }

  /**
   * Returns the `n` largest values of column `j`, in descending order.
   *
   * @param j index of the column
   * @param n how many top values to return; must be in [0, numRows]
   * @throws IllegalArgumentException if n is out of range
   */
  def getColTopN(j: Long, n: Int): DenseVector[Double] = {
    require(n >= 0 && n <= numRows(),
      s"getColTopN function input arg must be in [0, matrix length], got $n")
    // getRow(j) extracts column j (see the class-level note).
    topN(getRow(j), n)
  }

  /**
   * Sums every stored value in the matrix.
   *
   * @return the total; 0.0 for a matrix with no entries
   */
  def sum(): Double = entries.map(_.value).fold(0.0)(_ + _)

  /**
   * Sums the stored values of row `x`.
   *
   * @param x row index
   * @return the row total; 0.0 if the row has no entries
   */
  def sumRow(x: Long): Double =
    entries.filter(_.i == x).map(_.value).fold(0.0)(_ + _)

  /**
   * Sums the stored values of column `y`.
   *
   * @param y column index
   * @return the column total; 0.0 if the column has no entries
   */
  def sumCol(y: Long): Double =
    entries.filter(_.j == y).map(_.value).fold(0.0)(_ + _)

  /**
   * Extracts ROW `x` as a dense vector of length `numCols` (despite the method name).
   *
   * @param x row index
   * @return dense vector indexed by column; missing cells are 0.0
   */
  def getColumn(x: Long): DenseVector[Double] = {
    val vec = DenseVector.zeros[Double](numCols().toInt)
    // Only this row's entries are shipped to the driver.
    entries.filter(_.i == x).collect().foreach { e =>
      vec(e.j.toInt) = e.value
    }
    vec
  }

  /**
   * Extracts COLUMN `y` as a dense vector of length `numRows` (despite the method name).
   *
   * @param y column index
   * @return dense vector indexed by row; missing cells are 0.0
   */
  def getRow(y: Long): DenseVector[Double] = {
    // The vector is indexed by row, so its length must be numRows. The original used
    // numCols, which overflowed whenever the matrix had more rows than columns.
    val vec = DenseVector.zeros[Double](numRows().toInt)
    entries.filter(_.j == y).collect().foreach { e =>
      vec(e.i.toInt) = e.value
    }
    vec
  }

  /** Returns the transpose: every (i, j, v) becomes (j, i, v), and the dimensions swap. */
  def transpose(): bigMatrix = {
    new bigMatrix(entries.map(e => MatrixEntry(e.j, e.i, e.value)), numCols(), numRows())
  }

  override def toString: String = s"${numCols()} colmun ${numRows()} row big matrix"

  // Delegates to the inherited reference equality, so it stays consistent with the
  // inherited hashCode.
  override def equals(obj: scala.Any): Boolean = super.equals(obj)

  /**
   * Materialises the whole matrix as a Breeze dense matrix on the driver.
   * Memory use is numRows * numCols doubles; only call this for small matrices.
   *
   * @return dense copy of the matrix; missing cells are 0.0
   */
  def toDenseMatrix: DenseMatrix[Double] = {
    val mat = DenseMatrix.zeros[Double](numRows().toInt, numCols().toInt)
    entries.collect().foreach { case MatrixEntry(i, j, value) =>
      mat(i.toInt, j.toInt) = value
    }
    mat
  }

  /** Returns the `n` largest elements of `values`, in descending order (0 <= n <= length). */
  private def topN(values: DenseVector[Double], n: Int): DenseVector[Double] = {
    val ascending = values.toArray
    // Primitive, non-recursive sort. It handles empty input and cannot overflow the stack
    // the way the former recursive merge sort could on long vectors.
    java.util.Arrays.sort(ascending)
    new DenseVector(Array.tabulate(n)(k => ascending(ascending.length - 1 - k)))
  }

  /**
   * Grows `nRows` / `nCols` to cover the largest row/column index present in `entries`.
   *
   * @return the updated (nRows, nCols)
   */
  private def computeSize(): (Long, Long) = {
    // fold with (-1, -1) instead of reduce, so an empty RDD yields size 0 instead of throwing.
    val (maxI, maxJ) = entries
      .map(e => (e.i, e.j))
      .fold((-1L, -1L)) { case ((i1, j1), (i2, j2)) =>
        (math.max(i1, i2), math.max(j1, j2))
      }
    nRows = math.max(nRows, maxI + 1L)
    nCols = math.max(nCols, maxJ + 1L)
    (nRows, nCols)
  }
}
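// 基于RDD解决大矩阵问题 (Solving large-matrix problems on top of Spark RDDs)
// 最新推荐文章于 2020-04-30 00:04:44 发布 (page footer from the original article source)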