numpy.argmax, used when computing a confusion matrix

numpy.argmax

numpy.argmax(a, axis=None, out=None)

Returns the indices of the maximum values along an axis.

Parameters:

a : array_like

Input array.

axis : int, optional

By default, the index is into the flattened array, otherwise along the specified axis.

out : array, optional

If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.

Returns:

index_array : ndarray of ints

Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
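
As a quick check of the shape rule above, here is a minimal sketch. The 3-D array `x` and its shape are made up purely for illustration and do not come from the original post.

```python
import numpy as np

# Hypothetical 3-D array, chosen only to illustrate the shape rule.
x = np.arange(24).reshape(2, 3, 4)

flat_idx = np.argmax(x)            # axis=None: index into the flattened array
idx_axis0 = np.argmax(x, axis=0)   # dimension 0 removed -> shape (3, 4)
idx_axis2 = np.argmax(x, axis=2)   # dimension 2 removed -> shape (2, 3)

print(flat_idx)          # 23
print(idx_axis0.shape)   # (3, 4)
print(idx_axis2.shape)   # (2, 3)
```

With `axis=None` the result is a single flat index. With an explicit axis, the result has one entry for every position along the remaining dimensions.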

See also

ndarray.argmax, argmin

amax
The maximum value along a given axis.
unravel_index
Convert a flat index into an index tuple.
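
Since the default call returns an index into the flattened array, `unravel_index` is the usual way to turn it back into a row/column position. A small sketch, using a hypothetical array `m` that is not part of the original post:

```python
import numpy as np

# Hypothetical 2-D array; the maximum (9) sits at row 1, column 2.
m = np.array([[1, 7, 3],
              [4, 2, 9]])

flat_idx = np.argmax(m)                         # index into the flattened array
row, col = np.unravel_index(flat_idx, m.shape)  # convert to a (row, col) pair

print(flat_idx)                    # 5
print(row, col)                    # 1 2
print(m[row, col] == np.amax(m))   # True
```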

Notes

In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

Examples

>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])

>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b)  # Only the first occurrence is returned.
1

Here is how I use it when training a multi-class model. org_labels holds the class labels, which are numbered from 0: [0, 1, 2, ..., max_label].
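import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# load_data, get_model and EarlyStoppingCallback are the author's own helpers (not shown here).

if __name__ == "__main__":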
    width, height = 32, 32
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=100)

    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    #try:
    #    model.load(filename)
    #    print("Model loaded OK. Resume training!")
    #except:
    #    pass

    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.6)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration:
        print("OK, stopped iterating. Good!")

    model.save(filename)

    # predict all data and calculate confusion_matrix
    model.load(filename)

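    # pro_arr has one row per sample and one column per class (class probabilities).
    pro_arr = model.predict(X)
    # The column index of the largest probability in each row is the predicted label.
    predict_labels = np.argmax(pro_arr, axis=1)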
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))
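
To make the last step concrete without the trained model, here is a minimal NumPy-only sketch of the same idea. The probability array and true labels are made-up stand-ins for `model.predict(X)` and `org_labels`. The confusion matrix is built by hand with `np.add.at` instead of scikit-learn, so this is an illustration of the technique, not the author's code.

```python
import numpy as np

# Hypothetical predicted probabilities for 5 samples and 3 classes
# (a stand-in for the output of model.predict).
pro_arr = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
])
# Hypothetical true labels, numbered from 0.
true_labels = np.array([0, 1, 2, 1, 1])

# Row-wise argmax: the column with the largest probability is the predicted class.
predict_labels = np.argmax(pro_arr, axis=1)

# Confusion matrix: rows are true classes, columns are predicted classes.
n_classes = pro_arr.shape[1]
cm = np.zeros((n_classes, n_classes), dtype=int)
np.add.at(cm, (true_labels, predict_labels), 1)

print(predict_labels)   # [0 1 2 0 1]
print(cm)
```

Here the fourth sample (true class 1) is predicted as class 0, so it lands off the diagonal at `cm[1, 0]`. If the true labels are stored one-hot encoded (an assumption; the post does not show how `Y` is built), `np.argmax(Y, axis=1)` should recover integer labels in the same way.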

 

Reposted from: https://www.cnblogs.com/bonelee/p/8976380.html
