1 基本使用
三个维度以下的看看这个就行了
python 多维切片之冒号和三个点
https://blog.csdn.net/z13653662052/article/details/78010654
2 高维张量切片
五个维度的如何理解?
先看最后两个维度,就是我们常规理解的矩阵(几行几列)。
a = np.arange(720).reshape((2,3,4,5,6))
a 就是由很多个(2 × 3 × 4 = 24 个)五行六列的矩阵组成的。
注意观察中括号
打印结果开头有五层中括号,总共 720 个数(0–719):
最外层包括 2 个“四层中括号”的张量;
一个四层中括号的张量由 3 个“三层中括号”的张量构成;
一个三层中括号的张量由 4 个“两层中括号”的张量构成;
一个两层中括号的张量就是 5 × 6 的矩阵了。
具体的切片操作可以复制代码运行看一下,下面是 print(a) 的输出:
[[[[[ 0 1 2 3 4 5]
[ 6 7 8 9 10 11]
[ 12 13 14 15 16 17]
[ 18 19 20 21 22 23]
[ 24 25 26 27 28 29]]
[[ 30 31 32 33 34 35]
[ 36 37 38 39 40 41]
[ 42 43 44 45 46 47]
[ 48 49 50 51 52 53]
[ 54 55 56 57 58 59]]
[[ 60 61 62 63 64 65]
[ 66 67 68 69 70 71]
[ 72 73 74 75 76 77]
[ 78 79 80 81 82 83]
[ 84 85 86 87 88 89]]
[[ 90 91 92 93 94 95]
[ 96 97 98 99 100 101]
[102 103 104 105 106 107]
[108 109 110 111 112 113]
[114 115 116 117 118 119]]]
[[[120 121 122 123 124 125]
[126 127 128 129 130 131]
[132 133 134 135 136 137]
[138 139 140 141 142 143]
[144 145 146 147 148 149]]
[[150 151 152 153 154 155]
[156 157 158 159 160 161]
[162 163 164 165 166 167]
[168 169 170 171 172 173]
[174 175 176 177 178 179]]
[[180 181 182 183 184 185]
[186 187 188 189 190 191]
[192 193 194 195 196 197]
[198 199 200 201 202 203]
[204 205 206 207 208 209]]
[[210 211 212 213 214 215]
[216 217 218 219 220 221]
[222 223 224 225 226 227]
[228 229 230 231 232 233]
[234 235 236 237 238 239]]]
[[[240 241 242 243 244 245]
[246 247 248 249 250 251]
[252 253 254 255 256 257]
[258 259 260 261 262 263]
[264 265 266 267 268 269]]
[[270 271 272 273 274 275]
[276 277 278 279 280 281]
[282 283 284 285 286 287]
[288 289 290 291 292 293]
[294 295 296 297 298 299]]
[[300 301 302 303 304 305]
[306 307 308 309 310 311]
[312 313 314 315 316 317]
[318 319 320 321 322 323]
[324 325 326 327 328 329]]
[[330 331 332 333 334 335]
[336 337 338 339 340 341]
[342 343 344 345 346 347]
[348 349 350 351 352 353]
[354 355 356 357 358 359]]]]
[[[[360 361 362 363 364 365]
[366 367 368 369 370 371]
[372 373 374 375 376 377]
[378 379 380 381 382 383]
[384 385 386 387 388 389]]
[[390 391 392 393 394 395]
[396 397 398 399 400 401]
[402 403 404 405 406 407]
[408 409 410 411 412 413]
[414 415 416 417 418 419]]
[[420 421 422 423 424 425]
[426 427 428 429 430 431]
[432 433 434 435 436 437]
[438 439 440 441 442 443]
[444 445 446 447 448 449]]
[[450 451 452 453 454 455]
[456 457 458 459 460 461]
[462 463 464 465 466 467]
[468 469 470 471 472 473]
[474 475 476 477 478 479]]]
[[[480 481 482 483 484 485]
[486 487 488 489 490 491]
[492 493 494 495 496 497]
[498 499 500 501 502 503]
[504 505 506 507 508 509]]
[[510 511 512 513 514 515]
[516 517 518 519 520 521]
[522 523 524 525 526 527]
[528 529 530 531 532 533]
[534 535 536 537 538 539]]
[[540 541 542 543 544 545]
[546 547 548 549 550 551]
[552 553 554 555 556 557]
[558 559 560 561 562 563]
[564 565 566 567 568 569]]
[[570 571 572 573 574 575]
[576 577 578 579 580 581]
[582 583 584 585 586 587]
[588 589 590 591 592 593]
[594 595 596 597 598 599]]]
[[[600 601 602 603 604 605]
[606 607 608 609 610 611]
[612 613 614 615 616 617]
[618 619 620 621 622 623]
[624 625 626 627 628 629]]
[[630 631 632 633 634 635]
[636 637 638 639 640 641]
[642 643 644 645 646 647]
[648 649 650 651 652 653]
[654 655 656 657 658 659]]
[[660 661 662 663 664 665]
[666 667 668 669 670 671]
[672 673 674 675 676 677]
[678 679 680 681 682 683]
[684 685 686 687 688 689]]
[[690 691 692 693 694 695]
[696 697 698 699 700 701]
[702 703 704 705 706 707]
[708 709 710 711 712 713]
[714 715 716 717 718 719]]]]]
a[...,0] 注意看和 a[...,0:1] 的 shape 区别:整数索引 0 会去掉最后一个维度,而切片 0:1 会保留这个维度(长度为 1)。
print(a[...,0].shape)
(2, 3, 4, 5)
print(a[...,0:1].shape)
(2, 3, 4, 5, 1)
可以参考这篇:python中的a[...,0],或者a[...,0:1]
https://blog.csdn.net/qq_44487483/article/details/116085502
print(a[...,0])
[[[[ 0 6 12 18 24]
[ 30 36 42 48 54]
[ 60 66 72 78 84]
[ 90 96 102 108 114]]
[[120 126 132 138 144]
[150 156 162 168 174]
[180 186 192 198 204]
[210 216 222 228 234]]
[[240 246 252 258 264]
[270 276 282 288 294]
[300 306 312 318 324]
[330 336 342 348 354]]]
[[[360 366 372 378 384]
[390 396 402 408 414]
[420 426 432 438 444]
[450 456 462 468 474]]
[[480 486 492 498 504]
[510 516 522 528 534]
[540 546 552 558 564]
[570 576 582 588 594]]
[[600 606 612 618 624]
[630 636 642 648 654]
[660 666 672 678 684]
[690 696 702 708 714]]]]
print(a[...,1])
[[[[ 1 7 13 19 25]
[ 31 37 43 49 55]
[ 61 67 73 79 85]
[ 91 97 103 109 115]]
[[121 127 133 139 145]
[151 157 163 169 175]
[181 187 193 199 205]
[211 217 223 229 235]]
[[241 247 253 259 265]
[271 277 283 289 295]
[301 307 313 319 325]
[331 337 343 349 355]]]
[[[361 367 373 379 385]
[391 397 403 409 415]
[421 427 433 439 445]
[451 457 463 469 475]]
[[481 487 493 499 505]
[511 517 523 529 535]
[541 547 553 559 565]
[571 577 583 589 595]]
[[601 607 613 619 625]
[631 637 643 649 655]
[661 667 673 679 685]
[691 697 703 709 715]]]]
print(a[...,0:1]) 和 print(a[...,:1]) 是一样的
[[[[[ 0]
[ 6]
[ 12]
[ 18]
[ 24]]
[[ 30]
[ 36]
[ 42]
[ 48]
[ 54]]
[[ 60]
[ 66]
[ 72]
[ 78]
[ 84]]
[[ 90]
[ 96]
[102]
[108]
[114]]]
[[[120]
[126]
[132]
[138]
[144]]
[[150]
[156]
[162]
[168]
[174]]
[[180]
[186]
[192]
[198]
[204]]
[[210]
[216]
[222]
[228]
[234]]]
[[[240]
[246]
[252]
[258]
[264]]
[[270]
[276]
[282]
[288]
[294]]
[[300]
[306]
[312]
[318]
[324]]
[[330]
[336]
[342]
[348]
[354]]]]
[[[[360]
[366]
[372]
[378]
[384]]
[[390]
[396]
[402]
[408]
[414]]
[[420]
[426]
[432]
[438]
[444]]
[[450]
[456]
[462]
[468]
[474]]]
[[[480]
[486]
[492]
[498]
[504]]
[[510]
[516]
[522]
[528]
[534]]
[[540]
[546]
[552]
[558]
[564]]
[[570]
[576]
[582]
[588]
[594]]]
[[[600]
[606]
[612]
[618]
[624]]
[[630]
[636]
[642]
[648]
[654]]
[[660]
[666]
[672]
[678]
[684]]
[[690]
[696]
[702]
[708]
[714]]]]]
print(a[...,:2])
[[[[[ 0 1]
[ 6 7]
[ 12 13]
[ 18 19]
[ 24 25]]
[[ 30 31]
[ 36 37]
[ 42 43]
[ 48 49]
[ 54 55]]
[[ 60 61]
[ 66 67]
[ 72 73]
[ 78 79]
[ 84 85]]
[[ 90 91]
[ 96 97]
[102 103]
[108 109]
[114 115]]]
[[[120 121]
[126 127]
[132 133]
[138 139]
[144 145]]
[[150 151]
[156 157]
[162 163]
[168 169]
[174 175]]
[[180 181]
[186 187]
[192 193]
[198 199]
[204 205]]
[[210 211]
[216 217]
[222 223]
[228 229]
[234 235]]]
[[[240 241]
[246 247]
[252 253]
[258 259]
[264 265]]
[[270 271]
[276 277]
[282 283]
[288 289]
[294 295]]
[[300 301]
[306 307]
[312 313]
[318 319]
[324 325]]
[[330 331]
[336 337]
[342 343]
[348 349]
[354 355]]]]
[[[[360 361]
[366 367]
[372 373]
[378 379]
[384 385]]
[[390 391]
[396 397]
[402 403]
[408 409]
[414 415]]
[[420 421]
[426 427]
[432 433]
[438 439]
[444 445]]
[[450 451]
[456 457]
[462 463]
[468 469]
[474 475]]]
[[[480 481]
[486 487]
[492 493]
[498 499]
[504 505]]
[[510 511]
[516 517]
[522 523]
[528 529]
[534 535]]
[[540 541]
[546 547]
[552 553]
[558 559]
[564 565]]
[[570 571]
[576 577]
[582 583]
[588 589]
[594 595]]]
[[[600 601]
[606 607]
[612 613]
[618 619]
[624 625]]
[[630 631]
[636 637]
[642 643]
[648 649]
[654 655]]
[[660 661]
[666 667]
[672 673]
[678 679]
[684 685]]
[[690 691]
[696 697]
[702 703]
[708 709]
[714 715]]]]]
输出看一下就理解了。
print(a[:2, :3]) 输出所有的 720 个数(前两个维度的长度正好是 2 和 3,切片取满了):
print(a[:2, :3].shape)
(2, 3, 4, 5, 6)
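print(a[1:2,:3]) 和 print(a[1,:3]) 输出的数值是一样的(都是 360–719),但 shape 不同:切片 1:2 保留第一个维度,shape 为 (1, 3, 4, 5, 6);整数索引 1 会去掉这个维度,shape 为 (3, 4, 5, 6)。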
print(a[1:2,:3].shape)
(1, 3, 4, 5, 6)
print(a[1:2,:3])
[[[[[360 361 362 363 364 365]
[366 367 368 369 370 371]
[372 373 374 375 376 377]
[378 379 380 381 382 383]
[384 385 386 387 388 389]]
[[390 391 392 393 394 395]
[396 397 398 399 400 401]
[402 403 404 405 406 407]
[408 409 410 411 412 413]
[414 415 416 417 418 419]]
[[420 421 422 423 424 425]
[426 427 428 429 430 431]
[432 433 434 435 436 437]
[438 439 440 441 442 443]
[444 445 446 447 448 449]]
[[450 451 452 453 454 455]
[456 457 458 459 460 461]
[462 463 464 465 466 467]
[468 469 470 471 472 473]
[474 475 476 477 478 479]]]
[[[480 481 482 483 484 485]
[486 487 488 489 490 491]
[492 493 494 495 496 497]
[498 499 500 501 502 503]
[504 505 506 507 508 509]]
[[510 511 512 513 514 515]
[516 517 518 519 520 521]
[522 523 524 525 526 527]
[528 529 530 531 532 533]
[534 535 536 537 538 539]]
[[540 541 542 543 544 545]
[546 547 548 549 550 551]
[552 553 554 555 556 557]
[558 559 560 561 562 563]
[564 565 566 567 568 569]]
[[570 571 572 573 574 575]
[576 577 578 579 580 581]
[582 583 584 585 586 587]
[588 589 590 591 592 593]
[594 595 596 597 598 599]]]
[[[600 601 602 603 604 605]
[606 607 608 609 610 611]
[612 613 614 615 616 617]
[618 619 620 621 622 623]
[624 625 626 627 628 629]]
[[630 631 632 633 634 635]
[636 637 638 639 640 641]
[642 643 644 645 646 647]
[648 649 650 651 652 653]
[654 655 656 657 658 659]]
[[660 661 662 663 664 665]
[666 667 668 669 670 671]
[672 673 674 675 676 677]
[678 679 680 681 682 683]
[684 685 686 687 688 689]]
[[690 691 692 693 694 695]
[696 697 698 699 700 701]
[702 703 704 705 706 707]
[708 709 710 711 712 713]
[714 715 716 717 718 719]]]]]
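print(a[1:2,2:3]) 和 print(a[1:2,2]) 输出的数值是一样的(都是 600–719),区别在于 shape:前者是 (1, 1, 4, 5, 6),后者的整数索引 2 去掉了第二个维度,是 (1, 4, 5, 6)。可以这么理解:a 的前两个维度是 2×3,把前两个维度都固定后剩下的一块就是 720/6 = 120 个数。下面是 print(a[1:2,2:3]) 的输出: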
[[[[[600 601 602 603 604 605]
[606 607 608 609 610 611]
[612 613 614 615 616 617]
[618 619 620 621 622 623]
[624 625 626 627 628 629]]
[[630 631 632 633 634 635]
[636 637 638 639 640 641]
[642 643 644 645 646 647]
[648 649 650 651 652 653]
[654 655 656 657 658 659]]
[[660 661 662 663 664 665]
[666 667 668 669 670 671]
[672 673 674 675 676 677]
[678 679 680 681 682 683]
[684 685 686 687 688 689]]
[[690 691 692 693 694 695]
[696 697 698 699 700 701]
[702 703 704 705 706 707]
[708 709 710 711 712 713]
[714 715 716 717 718 719]]]]]
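注意以下几个切片与 print(a[1:2,:3]) 的结果完全一样:shape 都是 (1, 3, 4, 5, 6),内容都是 360–719,也就是把 720 个数沿第一个维度切出了后一半。原因是切片的上界超出该维度长度时会被自动截断,不会报错。而 print(a[1,:3]) 内容虽然相同,shape 却是 (3, 4, 5, 6),少了第一个维度。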
print(a[1:2,:4].shape)
(1, 3, 4, 5, 6)
print(a[1:3, :5].shape)
(1, 3, 4, 5, 6)
print(a[1:2,:6].shape)
(1, 3, 4, 5, 6)
import numpy as np

# 720 = 2 * 3 * 4 * 5 * 6: read `a` as 2 x 3 groups of 4 matrices,
# each matrix having 5 rows and 6 columns.
a = np.arange(720).reshape((2, 3, 4, 5, 6))
print(a)

# Experiments -- uncomment to try. Slice bounds past an axis' length are
# clipped instead of raising. An integer index drops its axis, while a slice
# keeps it (possibly with length 1). So "same values" does not imply
# "same shape".
# print(a[..., :])            # the whole array, same as print(a)
# print(a[2:3, 2:5, None])    # [] -- axis 0 has length 2, so 2:3 is empty
# print(a[2:3, :5])           # [] -- same reason
# print(a[1:3, :5])           # same values (360-719) as a[1, :3], but
#                             #   shape (1, 3, 4, 5, 6) vs (3, 4, 5, 6)
# print(a[1:3, 2:5])          # same values (600-719) as a[1:3, 2:3]
#                             #   (both (1, 1, 4, 5, 6)) and a[1, 2:3] (1, 4, 5, 6)
# print(a[1:3, 2:5, None])    # same values as a[1, 2:3, None]; None inserts
#                             #   an axis: (1, 1, 1, 4, 5, 6) vs (1, 1, 4, 5, 6)
# print(a[..., 2])            # column 2 of every matrix, shape (2, 3, 4, 5)
# print(a[..., 2:4])          # columns 2-3, shape (2, 3, 4, 5, 2)
# b = a[..., 2:4]
# print(b / 2)                # elementwise true division -> float array
3 实战代码演练
YOLOv4 中 CIoU 的计算(代码用到了 K(Keras backend)、tf(TensorFlow)和 math,需要在文件开头导入)
def _box_corners(b):
    """Split an (x, y, w, h) box tensor into centre, size and corner points.

    Only the last axis is indexed (via ``...``), so any leading dimensions
    are accepted.

    Returns:
        (xy, wh, mins, maxes): box centre, box size, the min corner
        (top-left in image coordinates) and the max corner (bottom-right).
    """
    xy = b[..., :2]
    wh = b[..., 2:4]
    wh_half = wh / 2.
    return xy, wh, xy - wh_half, xy + wh_half


def box_ciou(b1, b2):
    """Compute the Complete IoU (CIoU) between predicted and ground-truth boxes.

    CIoU = IoU - rho^2 / c^2 - alpha * v, where
      * rho is the distance between the two box centres,
      * c is the diagonal of the smallest box enclosing both boxes,
      * v = 4 / pi^2 * (atan2(w1, h1) - atan2(w2, h2))^2 measures the
        aspect-ratio mismatch,
      * alpha = v / ((1 - IoU) + v) weights that mismatch.

    Args:
        b1: predicted boxes, last axis laid out as (x, y, w, h), e.g. shape
            (batch, feat_w, feat_h, anchor_num, 4).
        b2: ground-truth boxes with the same layout and shape as ``b1``.

    Returns:
        ciou: tensor with the same leading dimensions and a trailing axis of
        size 1, e.g. (batch, feat_w, feat_h, anchor_num, 1).
    """
    b1_xy, b1_wh, b1_mins, b1_maxes = _box_corners(b1)
    b2_xy, b2_wh, b2_mins, b2_maxes = _box_corners(b2)

    # Plain IoU of predicted vs. ground-truth boxes.
    intersect_mins = K.maximum(b1_mins, b2_mins)
    intersect_maxes = K.minimum(b1_maxes, b2_maxes)
    intersect_wh = K.maximum(intersect_maxes - intersect_mins, 0.)
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = b1_area + b2_area - intersect_area
    iou = intersect_area / (union_area + K.epsilon())

    # Squared distance between the two box centres (rho^2).
    center_distance = K.sum(K.square(b1_xy - b2_xy), axis=-1)

    # Squared diagonal of the smallest box enclosing both boxes (c^2).
    enclose_mins = K.minimum(b1_mins, b2_mins)
    enclose_maxes = K.maximum(b1_maxes, b2_maxes)
    enclose_wh = K.maximum(enclose_maxes - enclose_mins, 0.0)
    enclose_diagonal = K.sum(K.square(enclose_wh), axis=-1)

    # DIoU part; epsilon keeps the division finite for degenerate boxes.
    ciou = iou - center_distance / (enclose_diagonal + K.epsilon())

    # Aspect-ratio consistency term v and its trade-off weight alpha.
    v = 4 * K.square(
        tf.math.atan2(b1_wh[..., 0], b1_wh[..., 1])
        - tf.math.atan2(b2_wh[..., 0], b2_wh[..., 1])
    ) / (math.pi * math.pi)
    # For identical boxes v == 0. iou can also round to exactly 1.0: with
    # float32, the epsilon added to union_area is absorbed once the area is
    # large enough. The unguarded denominator 1 - iou + v then becomes 0, and
    # alpha is 0/0 = NaN. Clamping leaves every denominator >= epsilon
    # unchanged.
    alpha = v / K.maximum(1.0 - iou + v, K.epsilon())
    ciou = ciou - alpha * v

    return K.expand_dims(ciou, -1)
引用大佬博客:https://blog.csdn.net/weixin_44791964/article/details/106014717
继续加油奥里给