直接见代码:
import paddle.fluid as fluid
import numpy as np
#data = fluid.layers.data(name='data', shape=[1], dtype='int64')
data = fluid.data(name='data', shape=[-1, 1], dtype='int64')
np.random.seed(28)
weight_data = np.random.random(size=(20, 8))
print(weight_data)
#加载用户自定义或预训练的词向量
w_param_attrs = fluid.ParamAttr(
name="w_param_attrs",
initializer=fluid.initializer.NumpyArrayInitializer(weight_data),
trainable=False)
emb_2 = fluid.embedding(input=data,
size=(20, 8),
param_attr=w_param_attrs,
dtype='float32')
cpu = fluid.CPUPlace() # 定义运算场所
exe = fluid.Executor(cpu) # 创建执行器
exe.run(fluid.default_startup_program()) # 网络参数初始化
x = np.array([[1], [2]])
outs = exe.run(feed={'data': x}, fetch_list=[emb_2.name])
print(outs)
官方文档 https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn/data_cn.html#data 中不推荐使用 paddle.fluid.layers.data
,因其在之后的版本中会被删除。请使用 paddle.fluid.data
。
结果如下:
[[0.72901374 0.5612396 0.12496709 0.39759237 0.78130821 0.51099298
0.18269336 0.85351288]
[0.95537189 0.98421347 0.19270097 0.9707951 0.23480835 0.02635385
0.94606034 0.92172485]
[0.29397577 0.1662737 0.39542284 0.51066973 0.30803723 0.42956883
0.83006941 0.56239357]
[0.83088831 0.99692929 0.33257881 0.09100813 0.77383156 0.14938373
0.72535506 0.95514643]
[0.07309577 0.44716275 0.84111807 0.14553967 0.76527154 0.78178492
0.67507855 0.13170219]
[0.03930318 0.65602308 0.25118261 0.98841838 0.53338304 0.05917524
0.69875531 0.62717477]
[0.89577854 0.16192467 0.61038158 0.3169851 0.76326567 0.15628208
0.92988758 0.49781052]
[0.83323397 0.22996943 0.10681001 0.67370038 0.57898325 0.87584937
0.99712764 0.27530634]
[0.74263626 0.28473195 0.72624867 0.49107034 0.86801609 0.1622617
0.9713251 0.04888569]
[0.70054591 0.65194491 0.04645909 0.19730088 0.33060701 0.75264495
0.36501458 0.53077101]
[0.35418132 0.51467406 0.26169937 0.85173949 0.62324126 0.30446975
0.77547856 0.89555198]
[0.7374077 0.85555241 0.82012533 0.86522095 0.38212962 0.61140706
0.41550595 0.2421348 ]
[0.06125105 0.81751611 0.38363211 0.97884048 0.38187252 0.63014968
0.44335181 0.02552223]
[0.23321525 0.77924846 0.16996923 0.41457111 0.59480006 0.91087008
0.50639157 0.4386332 ]
[0.03229215 0.22840922 0.18160441 0.24255622 0.8094556 0.51928847
0.36861752 0.46235367]
[0.60488351 0.55737864 0.03305479 0.39902018 0.08332113 0.48316635
0.85653765 0.84775654]
[0.37035053 0.71812028 0.00461064 0.76418841 0.74670009 0.85891882
0.45676896 0.94777212]
[0.63737347 0.49762039 0.18912248 0.75981605 0.37119162 0.20927375
0.32256109 0.20617277]
[0.40986867 0.13548799 0.81640462 0.63828349 0.67581164 0.00853934
0.73750379 0.76717025]
[0.16223589 0.9606869 0.79786617 0.58411784 0.04252264 0.34268869
0.36767624 0.88560098]]
[array([[[0.9553719 , 0.9842135 , 0.19270097, 0.9707951 , 0.23480836,
0.02635385, 0.94606036, 0.92172486]],
[[0.29397577, 0.1662737 , 0.39542285, 0.5106697 , 0.30803722,
0.42956883, 0.8300694 , 0.56239355]]], dtype=float32)]