1.简单的张量加法案例
"""
@Time : 2023/10/16 0016 14:27
@Auth : yeqc
"""
'''简单的张量加法案例'''
import torch
import syft as sy
hook = sy.TorchHook(torch)
# 创建一个虚拟联邦学习节点
bob = sy.VirtualWorker(hook, id="Bob")
print('bob:', bob)
# 将x,y发送给Bob
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 1, 1])
x_ptr = x.send(bob)
y_ptr = y.send(bob)
print('bob._objects = ', bob._objects)
print('x_ptr = ', x_ptr)
print('y_ptr = ', y_ptr)
print('x_ptr.location = ', x_ptr.location) # 张量指针指向的位置
print('x_ptr.owner = ', x_ptr.owner) # 持有指针张量的虚拟机器工作节点
z = x_ptr + y_ptr # x_ptr 与 y_ptr 加法
print('z = ', z)
print('bob._objects = ', bob._objects)
2.基于指针的远程操作案例
"""
@Time : 2023/10/16 0016 14:36
@Auth : yeqc
"""
'''基于指针的远程操作案例'''
import torch
import syft as sy
hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="Bob")
alice = sy.VirtualWorker(hook, id="Alice")
x = torch.tensor([1, 2, 3])
# x发送给bob
x_ptr = x.send(bob)
print("bob._objects = ", bob._objects)
print("alice._objects = ", alice._objects)
'''数据仍然在bob,alice具有了从alice->bob的指针,x_ptr是me->bob的指针,pointer_to_x_ptr是me->alice的指针。x_ptr发送给alice后,
实际数据x还是原始Torch张量数据,还在bob,仍然可以进行Torch张量操作,但是只能在指针张量上执行get()方法,不能在Torch张量上执行get()方法'''
# 创建新指针pointer_to_x_ptr
pointer_to_x_ptr = x_ptr.send(alice)
print(pointer_to_x_ptr)
print("bob._objects = ", bob._objects)
print("alice._objects = ", alice._objects) # alice指向bob
'''下面从alice处取回指针x_ptr,即alice为空,从bob处取回指针x_ptr'''
x_ptr = pointer_to_x_ptr.get()
print(x_ptr)
print("alice._objects = ", alice._objects)
x = x_ptr.get()
print(x)
print("bob._objects = ", bob._objects)
3.基于指针的链式操作案例
"""
@Time : 2023/10/16 0016 14:48
@Auth : yeqc
"""
'''基于指针的链式操作案例'''
import torch
import syft as sy
hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="Bob")
alice = sy.VirtualWorker(hook, id="Alice")
x = torch.tensor([1, 2, 3])
x_ptr = x.send(bob)
print(x_ptr)
print("bob:", bob._objects)
print("alice:", alice._objects)
# 张量x执行move()操作
x_ptr.move(alice)
print(x_ptr) # 此时x_ptr指针为me->alice,bob张量数据为空,alice有张量数据
print("bob:", bob._objects)
print("alice:", alice._objects)
# 下面进行三个工作节点的指针操作测试
dave = sy.VirtualWorker(hook, id="Dave")
alice.clear_objects() # 由于上述alice有张量数据,因此清空alice
print("bob:", bob._objects)
print("alice:", alice._objects)
print("dave:", dave._objects)
# 将张量数据x按指针模式发送给bob、alice和dave
x = torch.tensor([3, 2, 1]).send(bob).send(alice).send(dave)
print("bob:", bob._objects)
print("alice:", alice._objects)
print("dave:", dave._objects)
# 引入新节点,并执行move()操作
fiona = sy.VirtualWorker(hook, id="Fiona")
x.move(fiona) # 可以看到move()操作只改变最后部分的指针指向
print("bob:", bob._objects)
print("alice:", alice._objects)
print("dave:", dave._objects)
print("fiona:", fiona._objects)
# 连续调用三次get(),可以将数据全部收回
x = x.get().get().get()
print("bob:", bob._objects)
print("alice:", alice._objects)
print("dave:", dave._objects)
print("fiona:", fiona._objects)
print(x)