Download and extract the data

1.传参

import argparse
# Build a command-line parser for this script.
parser = argparse.ArgumentParser()
# Optional string flag; defaults to the empty string when not supplied.
parser.add_argument('--kitti_url', default='', type=str)
args = parser.parse_args()
# Personalized KITTI download link, passed as: --kitti_url <url>
kitti_data_url = args.kitti_url

这几行的作用是外部传入参数--kitti_url http://mynameisjupy.com  详细参考Python命令行解析argparse常用语法使用简介

parser.add_argument里面参数讲解参考python的argparse模块add_argument详解

2.文件操作

    # Paths of the train/val/test split files shipped with the project.
    train_txt = "data/train3.txt"
    val_txt = "data/val3.txt"
    testing_txt = "data/testing.txt"
    # Copy the split files into the extracted KITTI road directory.
    # copy2 also preserves file metadata (timestamps, permissions).
    # NOTE(review): `copy2` and `kitti_road_dir` are defined elsewhere in
    # the full script this fragment was taken from.
    copy2(train_txt, kitti_road_dir)
    copy2(val_txt, kitti_road_dir)
    copy2(testing_txt, kitti_road_dir)

详细参照python- shutil 高级文件操作

3.下载数据集或者预训练模型

使用

import os
import sys  
import tarfile # 解压tar文件
from six.moves import urllib  # 下载工具
def maybe_download_and_extract():
  """Download and extract the model tar file.

  If the pretrained model we're using doesn't already exist, this function
  downloads it from the TensorFlow.org website and unpacks it into a
  directory.

  NOTE(review): `FLAGS` and `DATA_URL` are module-level names defined
  elsewhere in the original file.
  """
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  # The archive keeps its remote filename, e.g. '.../model.tgz' -> 'model.tgz'.
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      # total_size is -1 (or 0) when the server sends no Content-Length;
      # guard to avoid a ZeroDivisionError / nonsense percentage.
      if total_size > 0:
        percent = min(
            float(count * block_size) / float(total_size) * 100.0, 100.0)
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL,
                                             filename=filepath,
                                             reporthook=_progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  # Use a context manager so the archive file handle is always closed
  # (the original leaked the open tar file).
  with tarfile.open(filepath, 'r:gz') as tar:
    tar.extractall(dest_directory)

 

"""Download data relevant to train the KittiSeg model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import sys
import os
import subprocess

import zipfile


from six.moves import urllib
from shutil import copy2

import argparse

# Log INFO and above to stdout, prefixed with a timestamp and level.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)

# Make the bundled 'incl' directory importable before site-packages.
sys.path.insert(1, 'incl')

# Please set kitti_data_url to the download link for the Kitti DATA.
#
# You can obtain by going to this website:
# http://www.cvlibs.net/download.php?file=data_road.zip
#
# Replace 'http://kitti.is.tue.mpg.de/kitti/?????????.???' by the
# correct URL.


# Download location of the pretrained VGG16 weights used by KittiSeg.
vgg_url = 'ftp://mi.eng.cam.ac.uk/pub/mttt2/models/vgg16.npy'


def get_pathes():
    """
    Get location of `data_dir` and `run_dir`.

    Default is ./DATA and ./RUNS.
    Alternatively they can be set by the environment variables
    'TV_DIR_DATA' and 'TV_DIR_RUNS'.

    Returns:
        (data_dir, run_dir) tuple of directory path strings.
    """
    # Bug fixes vs. original:
    #  * os.path.join was called with a *list* (['hypes']) as its first
    #    argument, which raises TypeError at runtime.
    #  * the run_dir branch read TV_DIR_DATA instead of TV_DIR_RUNS
    #    (copy-paste error).
    data_dir = os.environ.get('TV_DIR_DATA', "DATA")
    run_dir = os.environ.get('TV_DIR_RUNS', "RUNS")

    return data_dir, run_dir


def download(url, dest_directory):
    """Download `url` into `dest_directory`, printing progress to stdout.

    Args:
        url: Source URL; the last path component is used as the filename.
        dest_directory: Existing directory to place the file in.

    Returns:
        Path of the downloaded file.
    """
    filename = url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)

    logging.info("Download URL: {}".format(url))
    logging.info("Download DIR: {}".format(dest_directory))

    def _progress(count, block_size, total_size):
        # total_size is -1 when the server sends no Content-Length; guard
        # to avoid ZeroDivisionError and clamp to 100% on the final chunk.
        if total_size > 0:
            prog = min(float(count * block_size) / float(total_size) * 100.0,
                       100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, prog))
            sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(url, filepath,
                                             reporthook=_progress)
    print()
    return filepath


def main():
    """Download VGG weights and the KITTI road dataset, then prepare it.

    The KITTI download link must be requested from cvlibs.net and passed
    via --kitti_url; the VGG weights are fetched from a fixed URL.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--kitti_url', default='', type=str)
    args = parser.parse_args()
    kitti_data_url = args.kitti_url

    data_dir, run_dir = get_pathes()

    vgg_weights = os.path.join(data_dir, 'weights', 'vgg16.npy')

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    # Download VGG weights unless a copy already exists.
    if not os.path.exists(vgg_weights):
        download_command = "wget {} -P {}".format(vgg_url, data_dir)
        logging.info("Downloading VGG weights: {}".format(download_command))
        download(vgg_url, data_dir)
    else:
        logging.warning("File: {} exists.".format(vgg_weights))
        logging.warning("Please delete to redownload VGG weights.")

    data_road_zip = os.path.join(data_dir, 'data_road.zip')

    # Download KITTI DATA. A personalized link is required; validate that
    # the supplied URL ends in the expected 'kitti/data_road.zip' suffix.
    if not os.path.exists(data_road_zip):
        if kitti_data_url == '':
            logging.error("Data URL for Kitti Data not provided.")
            url = "http://www.cvlibs.net/download.php?file=data_road.zip"
            logging.error("Please visit: {}".format(url))
            logging.error("and request Kitti Download link.")
            # Fixed user-facing message: 'scipt' -> 'script', added the
            # missing space between the joined pieces, quoted the whole
            # example command.
            logging.error("Rerun script using "
                          "'python download_data.py --kitti_url [url]'")
            sys.exit(1)
        if not kitti_data_url[-19:] == 'kitti/data_road.zip':
            logging.error("Wrong url.")
            url = "http://www.cvlibs.net/download.php?file=data_road.zip"
            logging.error("Please visit: {}".format(url))
            logging.error("and request Kitti Download link.")
            logging.error("Rerun script using "
                          "'python download_data.py --kitti_url [url]'")
            sys.exit(1)
        else:
            logging.info("Downloading Kitti Road Data.")
            download(kitti_data_url, data_dir)

    # Extract and prepare KITTI DATA. Use a context manager so the zip
    # file handle is closed even if extraction fails.
    logging.info("Extracting kitti_road data.")
    with zipfile.ZipFile(data_road_zip, 'r') as road_zip:
        road_zip.extractall(data_dir)
    kitti_road_dir = os.path.join(data_dir, 'data_road/')

    logging.info("Preparing kitti_road data.")

    # Copy the train/val/test split files into the extracted dataset.
    train_txt = "data/train3.txt"
    val_txt = "data/val3.txt"
    testing_txt = "data/testing.txt"
    copy2(train_txt, kitti_road_dir)
    copy2(val_txt, kitti_road_dir)
    copy2(testing_txt, kitti_road_dir)

    logging.info("All data have been downloaded successful.")


if __name__ == '__main__':
    main()

 

MNIST数据集是一个常用的手写数字识别数据集,可以用于初步学习神经网络。可以通过以下Python代码使用urllib.request库下载MNIST数据集: ```python import urllib.request import os # 下载MNIST数据集 def download_mnist(): base_url = 'http://yann.lecun.com/exdb/mnist/' file_names = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'] save_path = './mnist_data' if not os.path.exists(save_path): os.makedirs(save_path) for file_name in file_names: url = (base_url + file_name).format(**locals()) print("Downloading " + url) urllib.request.urlretrieve(url, os.path.join(save_path, file_name)) print("Download finished.") ``` 这段代码会将MNIST数据集下载到当前目录下的mnist_data文件夹中。下载完成后,可以使用gzip库解压缩数据集文件,例如: ```python import gzip import numpy as np # 解压缩MNIST数据集文件 def extract_data(file_name, num_data, head_size, data_size): with gzip.open(file_name) as f: f.read(head_size) buf = f.read(data_size * num_data) data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32) return data # 加载MNIST数据集 def load_mnist(): save_path = './mnist_data' train_images = extract_data(os.path.join(save_path, 'train-images-idx3-ubyte.gz'), 60000, 16, 28*28) train_labels = extract_data(os.path.join(save_path, 'train-labels-idx1-ubyte.gz'), 60000, 8, 1) test_images = extract_data(os.path.join(save_path, 't10k-images-idx3-ubyte.gz'), 10000, 16, 28*28) test_labels = extract_data(os.path.join(save_path, 't10k-labels-idx1-ubyte.gz'), 10000, 8, 1) return train_images, train_labels, test_images, test_labels ``` 这段代码会将MNIST数据集解压缩并加载到内存中,返回四个NumPy数组,分别是训练图像、训练标签、测试图像和测试标签。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值