go.Image在绘制单通道图像(比如灰度图)的时候需要将图像变成三通道的才可以.
这里可以使用np.stack函数进行堆叠.
import plotly.graph_objects as go
# data.shape==(28, 28)
data = (batch_data[0][0] * 255).astype(np.uint8)
# img.shape==(28, 28, 3)
img = np.stack((data, data, data), axis=-1)
fig = go.Figure(go.Image(z=img))
fig.show()
当然也可以使用go.Heatmap或者ff.create_annotated_heatmap来创建热力图,并把colorscale指定为Greys即可.
使用热力图的另外一个好处是对于值在[0,1]的数据无需转换到[0, 255],而go.Image需要转换.
import plotly.graph_objects as go
data = batch_data[0][0] # data.shape == (28, 28)
fig = go.Figure(go.Heatmap(z=data))
fig.update_traces(colorscale='Greys')
fig.show()