实战:预测房价
此数据集由Bart de Cock于2011年收集 , 涵盖了2006-2010年期间亚利桑那州埃姆斯市的房价。 这个数据集是相当通用的,不会需要使用复杂模型架构。
下载和缓存数据集
这里实现几个函数来方便下载数据。 首先,我们建立字典DATA_HUB
, 它可以将数据集名称的字符串映射到数据集相关的二元组上, 这个二元组包含数据集的url和验证文件完整性的sha-1密钥。 所有类似的数据集都托管在地址为DATA_URL
的站点上。
import hashlib
import os
import tarfile
import zipfile
import requests
#@save
DATA_HUB = dict()
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'
下面的download
函数用来下载数据集, 将数据集缓存在本地目录(默认情况下为../data
)中, 并返回下载文件的名称。
如果缓存目录中已经存在此数据集文件,并且其sha-1与存储在DATA_HUB
中的相匹配, 我们将使用缓存的文件,以避免重复的下载。
def download(name, cache_dir=os.path.join('..', 'data')): #@save
"""下载一个DATA_HUB中的文件,返回本地文件名"""
assert name in DATA_HUB, f"{
name} 不存在于 {
DATA_HUB}"
url, sha1_hash = DATA_HUB[name]
os.makedirs(cache_dir, exist_ok=True)
fname = os.path.join(cache_dir, url.split('/')[-1])
if os.path.exists(fname):
sha1 = hashlib.sha1()
with open(fname, 'rb') as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
if sha1.hexdigest() == sha1_hash:
return fname # 命中缓存
print(f'正在从{
url}下载{
fname}...')
r = requests.get(url, stream=True, verify=True)
with open(fname, 'wb') as f:
f.write(r.content)
return fname
我们还需实现两个实用函数: 一个将下载并解压缩一个zip或tar文件, 另一个是将使用的所有数据集从DATA_HUB
下载到缓存目录中。
def download_extract(name, folder=None): #@save
"""下载并解压zip/tar文件"""
fname = download(name)
base_dir = os.path.dirname(fname)
data_dir, ext = os.path.splitext(fname)
if ext == '.zip':
fp = zipfile.ZipFile(fname, 'r')
elif ext in ('.tar', '.gz'):
fp = tarfile.open(fname, 'r')
else:
assert False, '只有zip/tar文件可以被解压缩'
fp.extractall(base_dir)
return os.path.join(base_dir, folder) if folder else data_dir
def download_all(): #@save