pytorch矩阵计算

class Vector(object):
    CANNOT_NORMALIZE_ZERO_VECTOR_MSG = 'Cannot normalize the zero vector'
    def __init__(self, coordinates):
        self.coordinates = coordinates
        self.num_samples = len(coordinates)
        self.dimension = len(coordinates[0])

    def __str__(self):
        # return 'Vector: {}'.format(self.coordinates)
        return self.coordinates

    #两个矩阵是否相等
    def __eq__(self, v):
        return self.coordinates == v.coordinates

    # 矩阵的大小
    def magnitude(self):
        return torch.sqrt(torch.sum(torch.square(self.coordinates), dim = -1)).view(self.num_samples, 1)

    # 矩阵归一化
    def normalized(self):
        try:
            magnitude = self.magnitude()
            weight = (1.0 / magnitude).view(self.num_samples, 1)
            return self.coordinates * weight
        except ZeroDivisionError:
            raise Exception(self.CANNOT_NORMALIZE_ZERO_VECTOR_MSG)

    # 两个矩阵的点积
    def dot(self, v):
        # return sum([x * y for x, y in zip(self.coordinates, v.coordinates)])
        return torch.sum(self.coordinates * v, dim = -1)


    # 判断两个矩阵是否正交
    def is_orthogonal_to(self, v, tolerance = 1e-10):
        return abs(self.dot(v)) < tolerance

    # 是否全为0的矩阵
    def is_zero(self, tolerance = 1e-10):
        return self.magnitude() < tolerance

    # 两个矩阵是否平行
    def is_parallel_to(self, v):
        return (self.is_zero() or
                v.is_zero() or
                self.angle_with(v) == 0 or
                self.angle_with(v) == pi)

    # 矩阵在另一个矩阵上的投影
    def component_parallel_to(self, basis):
        try:
            u = basis.normalized()
            weight = torch.sum(self.coordinates * u, dim = -1).view(self.num_samples, 1)
            return u*weight
        except Exception as e:
            if str(e) == self.CANNOT_NORMALIZE_ZERO_VECTOR_MSG:
                raise Exception('Cannot compute an angle with the zero vector')
            else:
                raise e

    # 矩阵相对于投影矩阵的垂直矩阵
    def component_orthogonal_to(self, basis):
        try:
            projection = self.component_parallel_to(basis)
            return self.minus(projection)
        except Exception as e:
            if str(e) == self.CANNOT_NORMALIZE_ZERO_VECTOR_MSG:
                raise Exception('Cannot compute an angle with the zero vector')
            else:
                raise e

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值