ArcGIS Pro 2.8.3:在 RTX 3090 上训练 UNet 模型

我的 segmentation.py(对应 torchvision.models.segmentation.segmentation 模块):

```python

import warnings

# Import all methods/classes for BC:

from . import *  # noqa: F401, F403


 

# Emit the upstream deprecation notice when this module is imported. The legacy
# builder functions below are presumably kept so that existing imports of this
# module path (e.g. from ArcGIS Pro's deep-learning tools) keep working — confirm.
warnings.warn(
    "The 'torchvision.models.segmentation.segmentation' module is deprecated since 0.12 and will be removed in "
    "0.14. Please use the 'torchvision.models.segmentation' directly instead.",
    UserWarning,
)

from .._utils import IntermediateLayerGetter

from ..utils import load_state_dict_from_url

from .. import resnet

from .deeplabv3 import DeepLabHead, DeepLabV3

from .fcn import FCN, FCNHead


 

# Public model builders exported by a star-import of this module; the private
# helpers (_segm_resnet, _load_model) and model_urls are intentionally not listed.
__all__ = [
    'fcn_resnet50',
    'fcn_resnet101',
    'deeplabv3_resnet50',
    'deeplabv3_resnet101',
]


 

# Download locations of the COCO-trained checkpoints. Keys follow the
# '<arch_type>_<backbone>_coco' pattern that _load_model builds for its lookup.
model_urls = {
    'fcn_resnet50_coco':
        'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth',
    'fcn_resnet101_coco':
        'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
    'deeplabv3_resnet50_coco':
        'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
    'deeplabv3_resnet101_coco':
        'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
}


 

def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Assemble a ResNet-backed segmentation model (DeepLabV3 or FCN).

    Args:
        name (str): Model family, ``'deeplabv3'`` or ``'fcn'``. Any other value
            raises ``KeyError`` (after the backbone has been built).
        backbone_name (str): Name of a constructor in the ``resnet`` module,
            e.g. ``'resnet50'``.
        num_classes (int): Number of output channels of the classifier head(s).
        aux (bool): If truthy, also expose ``layer3`` and attach an auxiliary
            ``FCNHead`` to it.
        pretrained_backbone (bool): Forwarded as ``pretrained`` to the ResNet
            constructor.

    Returns:
        An instance of ``DeepLabV3`` or ``FCN`` wrapping the backbone, the main
        head and the (possibly ``None``) auxiliary head.
    """
    # layer3 and layer4 use dilation instead of stride (torchvision's
    # replace_stride_with_dilation takes one flag per layer2..layer4).
    make_resnet = resnet.__dict__[backbone_name]
    trunk = make_resnet(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])

    # The final stage always feeds the main head; layer3 is tapped only for the aux head.
    taps = {'layer4': 'out'}
    if aux:
        taps['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=taps)

    # NOTE(review): the hard-coded input widths (1024 for layer3, 2048 for
    # layer4) presumably match bottleneck ResNets (resnet50/101); basic-block
    # variants would need different values — confirm before using them.
    # The aux head is built before the main head, keeping parameter-init order.
    aux_head = FCNHead(1024, num_classes) if aux else None

    head_cls, model_cls = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]
    main_head = head_cls(2048, num_classes)

    return model_cls(trunk, main_head, aux_head)


 

def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load its COCO-trained weights.

    Args:
        arch_type (str): Model family understood by ``_segm_resnet``
            (``'fcn'`` or ``'deeplabv3'``).
        backbone (str): ResNet constructor name, e.g. ``'resnet50'``.
        pretrained (bool): If True, load the checkpoint registered in
            ``model_urls`` under ``'<arch_type>_<backbone>_coco'``.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        num_classes (int): Number of output classes of the heads.
        aux_loss (bool or None): Whether to build the auxiliary head; forced to
            True when ``pretrained`` is set.
        **kwargs: Forwarded to ``_segm_resnet`` (e.g. ``pretrained_backbone``).
            When ``pretrained`` is set, ``pretrained_backbone`` is overridden to
            False because the full checkpoint replaces the backbone weights.

    Returns:
        The model built by ``_segm_resnet``, with weights loaded if requested.

    Raises:
        NotImplementedError: If ``pretrained`` is True and no checkpoint is
            registered for the arch/backbone combination.
    """
    model_url = None
    if pretrained:
        arch = arch_type + '_' + backbone + '_coco'
        # .get() so an unregistered combination reaches the NotImplementedError
        # below; indexing raised a bare KeyError and made that branch dead code.
        # Checked before building anything, so nothing is downloaded in vain.
        model_url = model_urls.get(arch)
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # load_state_dict below is strict, so the model must contain the aux
        # head (presumably present in the COCO checkpoints) and every backbone
        # tensor gets overwritten -- downloading backbone weights first is waste.
        aux_loss = True
        kwargs['pretrained_backbone'] = False

    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)

    if model_url is not None:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model


 

def fcn_resnet50(pretrained=False, progress=True,
                 num_classes=21, aux_loss=None, **kwargs):
    """Construct a Fully-Convolutional Network (FCN) with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, return a model pre-trained on COCO train2017,
            which contains the same classes as Pascal VOC.
        progress (bool): If True, display a progress bar of the download to stderr.
        num_classes (int): Number of output classes.
        aux_loss (bool or None): Whether to add the auxiliary FCN head on
            ``layer3``; treated as True when ``pretrained`` is set.
        **kwargs: Forwarded through ``_load_model`` to ``_segm_resnet``
            (e.g. ``pretrained_backbone``).
    """
    return _load_model(
        arch_type='fcn',
        backbone='resnet50',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )



 

def fcn_resnet101(pretrained=False, progress=True,
                  num_classes=21, aux_loss=None, **kwargs):
    """Construct a Fully-Convolutional Network (FCN) with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, return a model pre-trained on COCO train2017,
            which contains the same classes as Pascal VOC.
        progress (bool): If True, display a progress bar of the download to stderr.
        num_classes (int): Number of output classes.
        aux_loss (bool or None): Whether to add the auxiliary FCN head on
            ``layer3``; treated as True when ``pretrained`` is set.
        **kwargs: Forwarded through ``_load_model`` to ``_segm_resnet``
            (e.g. ``pretrained_backbone``).
    """
    return _load_model(
        arch_type='fcn',
        backbone='resnet101',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )



 

def deeplabv3_resnet50(pretrained=False, progress=True,
                       num_classes=21, aux_loss=None, **kwargs):
    """Construct a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, return a model pre-trained on COCO train2017,
            which contains the same classes as Pascal VOC.
        progress (bool): If True, display a progress bar of the download to stderr.
        num_classes (int): Number of output classes.
        aux_loss (bool or None): Whether to add the auxiliary FCN head on
            ``layer3``; treated as True when ``pretrained`` is set.
        **kwargs: Forwarded through ``_load_model`` to ``_segm_resnet``
            (e.g. ``pretrained_backbone``).
    """
    return _load_model(
        arch_type='deeplabv3',
        backbone='resnet50',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )



 

def deeplabv3_resnet101(pretrained=False, progress=True,
                        num_classes=21, aux_loss=None, **kwargs):
    """Construct a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, return a model pre-trained on COCO train2017,
            which contains the same classes as Pascal VOC.
        progress (bool): If True, display a progress bar of the download to stderr.
        num_classes (int): Number of output classes.
        aux_loss (bool or None): Whether to add the auxiliary FCN head on
            ``layer3``; treated as True when ``pretrained`` is set.
        **kwargs: Forwarded through ``_load_model`` to ``_segm_resnet``
            (e.g. ``pretrained_backbone``).
    """
    return _load_model(
        arch_type='deeplabv3',
        backbone='resnet101',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )

```

1. 先克隆 ArcGIS Pro 原有的 Python 环境。

2. 用 pip uninstall 卸载与 torch 相关的库(共 7 个)。

3. 安装下面分享的 Python 库。

链接:https://pan.baidu.com/s/1AklKmI3UJj3SeAncuO4Vhw?pwd=u92i 
提取码:u92i

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

GIS 数据栈

谢谢打赏!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值