使用热力图表示联邦学习场景中的客户端数据分布

用于生成热力图,记录过程,方便之后直接使用。
使用场景:联邦学习中显示客户端数据分布,或者显示数据分布的各类其他场景


一、代码

写这段代码时主要考虑联邦学习中显示客户端数据分布这一场景

hot.py

import numpy as np
import matplotlib.pyplot as plt
def hot_map(y_train, dataidx_map):
    # CIFAR-10 数据集共有 10 个类别
    num_classes = 10
    # 有 10 个客户端
    num_clients = 10
    #图片中字体大小
    font_size = 32

    # 初始化一个矩阵来存储每个客户端的数据分布
    client_data_distribution = np.zeros((num_clients, num_classes), dtype=int)
    # 统计每个客户端中每个类别的样本数量
    for client_id in range(num_clients):
        indices = dataidx_map[client_id]
        client_labels = y_train[indices]
        unique_labels, label_counts = np.unique(client_labels, return_counts=True)
        for label, count in zip(unique_labels, label_counts):
            client_data_distribution[client_id, label] = count
    # 转置矩阵,这里的转置主要是为了让横坐标是客户端,纵坐标是类标签。如果不转置,横纵坐标会交换
    client_data_distribution = client_data_distribution.T
    # 设置全局字体为新罗马字体
    plt.rcParams["font.family"] = "Times New Roman"
    # 绘制热力图
    plt.figure(figsize=(10, 6))
    plt.imshow(client_data_distribution, cmap='Reds', interpolation='nearest')
    #设置图片标题(上方)
    # plt.title('Clients Data Distribution in CIFAR-10 Dataset')
    # 隐藏坐标轴的边框,更美观
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    plt.xlabel('Client', fontsize=font_size)
    plt.ylabel('Label', fontsize=font_size)
    cbar = plt.colorbar()
        
    # 隐藏颜色条的边框
    cbar.outline.set_visible(False)
    cbar.ax.tick_params(labelsize=font_size)  # 设置颜色条刻度标签的字体大小 

    plt.xticks(np.arange(num_classes), np.arange(num_classes), fontsize =font_size)
    plt.yticks(np.arange(num_clients), np.arange(num_clients), fontsize=font_size)
    
    # 设置坐标(i, j)显示的数值,可直接注释去除
    for i in range(num_clients):
        for j in range(num_classes):
            # text((x, y)=坐标, s=数值, ha=水平对齐, va=垂直对齐, color=颜色)
            plt.text(x=i, y=j, s=client_data_distribution[j][i], ha='center', va='center', color='white')

    plt.tight_layout()
    plt.savefig('Fig.jpg',dpi = 400, bbox_inches='tight')# bbox_inches用于在保存时将图片位于画布中间,保持紧凑;dpi是一个关于图片清晰度的参数,数值越大,图片越高清
    plt.show()

使用方法

首先在需要调用热力图的地方引入文件

from hot import hot_map

接着在需要画图的地方调用,通常是刚对客户端分配好数据或者对数据分布进行处理后的位置


hot_map(y_train,net_dataidx_map)

二、参数解释

y_train:[6 9 9 … 9 1 1],就是训练数据的标签,用列表表示。

net_dataidx_map:{0:[39982, 40086, 49891, 13047, 8170, 94, 4697,],1:[…], …},这是各客户端的数据分配情况,使用字典显示,字典的键表示客户端标记,表示几号客户端;值用列表显示,列表中的各数值表示y_train的下标,举例来说,以0的39982为例,表示0号客户端包含了y_train中第39982个标签,是客户端与数据标签的映射。

三、样图

在这里插入图片描述


关键词

热力图; 联邦学习; 数据分布;python

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

虫本初阳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值