学会了 Dataset / DataLoader 的实现：
import numpy as np


class Model:
    """Toy linear model: projects padded index vectors onto one random weight column."""

    def __init__(self, max_len=10):
        """Create random weights of shape (max_len, 1).

        Args:
            max_len: Length of the padded index vectors passed to ``forward``.
                This used to be read from a module-level global that only
                exists when the file is run as a script; it is now an explicit
                parameter (the default matches the script's value).
        """
        self.model = np.random.normal(size=(max_len, 1))

    def forward(self, text_index):
        """Return ``text_index @ weights``.

        Args:
            text_index: A batch of padded index vectors -- an array of shape
                (batch, max_len), or a list of length-max_len arrays as
                produced by ``Dhl_DataLoader``.

        Returns:
            An array of shape (batch, 1) for 2-D input.
        """
        return np.asarray(text_index) @ self.model


class Dhl_Dataset:
    """Texts, labels and batching settings; iterating it yields mini-batches.

    Every ``iter()`` call creates a fresh ``Dhl_DataLoader``, so the same
    dataset can be looped over once per epoch (reshuffled each time when
    ``shuffle`` is true).
    """

    def __init__(self, all_text, all_label, max_len, batch_size, shuffle, word_2_index):
        """
        Args:
            all_text: Sequence of texts (strings are indexed per character).
            all_label: Sequence of labels, same length as ``all_text``.
            max_len: Texts are truncated / right-padded to this length.
            batch_size: Number of samples per batch; must be >= 1.
            shuffle: Whether each epoch visits samples in random order.
            word_2_index: Mapping from character to integer index.

        Raises:
            ValueError: If text and label counts differ or ``batch_size`` < 1.
        """
        self.all_text = np.array(all_text)
        self.word_2_index = word_2_index
        self.all_label = np.array(all_label)
        self.max_len = max_len
        # A real check instead of `assert ..., print(...)`: that form printed the
        # message but raised an empty AssertionError, and vanishes under `python -O`.
        if len(self.all_text) != len(self.all_label):
            raise ValueError("数据标签长度不等!")
        # A non-positive batch size would never advance the loader's cursor,
        # turning every epoch into an infinite loop.
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        return Dhl_DataLoader(self)

    def __len__(self):
        return len(self.all_text)


class Dhl_DataLoader:
    """One-epoch iterator over a ``Dhl_Dataset`` yielding ``(texts, labels)`` batches."""

    def __init__(self, dataset):
        self.dataset = dataset
        # Visiting order for this epoch; shuffled in place when requested.
        self.shuffle_index = np.arange(len(self.dataset))
        self.cursor = 0
        if self.dataset.shuffle:
            np.random.shuffle(self.shuffle_index)

    def __iter__(self):
        # Iterator protocol: an iterator returns itself, so the loader can be
        # used directly with `for`, `list()`, `zip()`, etc.
        return self

    def __getitem__(self, index):
        """Return ``(index_vector, label)`` for sample ``index`` of the dataset.

        The text is truncated to ``max_len`` and right-padded with 0 (<PAD>).
        Raises KeyError if a character is missing from ``word_2_index``.
        """
        text = self.dataset.all_text[index][:self.dataset.max_len]
        label = self.dataset.all_label[index]
        text_index = [self.dataset.word_2_index[word] for word in text]
        text_index += [0] * (self.dataset.max_len - len(text_index))
        return np.array(text_index), np.array(label)

    def __next__(self):
        """Return the next batch as two lists: index vectors and labels."""
        if self.cursor >= len(self.dataset):
            raise StopIteration
        index = self.shuffle_index[self.cursor:self.cursor + self.dataset.batch_size]
        batch_text_index = []
        batch_label = []
        for i in index:
            text_index, label = self[i]
            batch_text_index.append(text_index)
            batch_label.append(label)
        self.cursor += self.dataset.batch_size
        return batch_text_index, batch_label


def read_data(file_path):
    """Read a UTF-8 file of ``"<text> <label>"`` lines into parallel lists.

    Each non-blank line is split on its LAST space, so texts may themselves
    contain spaces; the part after it must be an integer label. Blank lines
    (e.g. a trailing newline at end of file) are skipped.

    Returns:
        ``(all_text, all_label)`` -- a list of str and a list of int.
    """
    all_text = []
    all_label = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\r\n")
            if not line.strip():
                continue
            text, label = line.rsplit(" ", 1)
            all_text.append(text)
            all_label.append(int(label))
    return all_text, all_label


def build_word_2_index(all_text):
    """Map every element of every text (a character, for str texts) to an index.

    Index 0 is reserved for "<PAD>"; other entries are numbered in first-seen order.
    """
    word_2_index = {"<PAD>": 0}
    for text in all_text:
        for word in text:
            # len() is evaluated before insertion, so an unseen character gets
            # the next free index; already-seen characters are left unchanged.
            word_2_index.setdefault(word, len(word_2_index))
    return word_2_index


if __name__ == "__main__":
    train_text, train_label = read_data("..//not_that_easy//data//train.txt")
    if len(train_text) != len(train_label):
        raise ValueError("文本数量和标签数量不等!")
    print(f"加载数据成功,长度为:{len(train_text)}")
    word_2_index = build_word_2_index(train_text)

    epoch = 10
    batch_size = 3
    max_len = 10
    shuffle = True

    train_dataset = Dhl_Dataset(train_text, train_label, max_len, batch_size, shuffle, word_2_index)
    model = Model(max_len)

    for e in range(epoch):
        print(f"this is the {e}th epoch")
        for batch_text_index, batch_label in train_dataset:
            pre = model.forward(batch_text_index)
            print(pre)
学习了 yield（生成器）的使用。
pandas 简单用法：
碰到了一个很蠢（但很好笑）的 bug：
AttributeError: module 'pandas' has no attribute 'read_csv'
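原因：我把自己的一个 py 文件命名成了 pandas.py。以脚本方式运行时，脚本所在目录位于 sys.path 的最前面，所以 `import pandas` 导入的是这个本地文件，而不是安装的 pandas 库，自然找不到 read_csv。给文件改个名字就好了（笑死）。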