前言
任务是这样的,我们需要设计一个模型,针对不同的名字,判断他到底是哪个国家。
比如:
bob 英国的
bingbing Wang 中国的
一、先上代码
#根据名字分类国家
import numpy as np
import torch
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import gzip
import csv
import time
class NameDataset(Dataset):
def __init__(self,is_train_set=True):
filename='names_train.csv.gz'if is_train_set else'names_test.csv.gz'
#多种不同的方式读取数据,gzip格式如下
with gzip.open(filename,'rt') as f:
reader=csv.reader(f)
rows=list(reader)#name,lauguage对
self.names=[row[0] for row in rows]
self.len=len(self.names)
self.countries=[row[1] for row in rows]
#set 集合,去除重复元素 sort排序 list变成列表
self.country_list=list(sorted(set(self.countries)))
#country变成一个字典
self.country_dict=self.getCountryDict()
self.country_num=len(self.country_list)#索引总数
#save countries and its index in list and dictionary
def __getitem__(self, index):
#名字string country索引
return self.names[index],self.country_dict[self.countries[index]]
def __len__(self):
return self.len
#convert list to dictionary
def getCountryDict(self):
country_dict=dict()#空字典