# Required module: import scipy [as alias]
# Or: from scipy import vstack [as alias]
def train(self, datadir, pickle_model=""):
    """Train the model on review JSON files found in *datadir*.

    Each file is expected to hold one JSON object per line with a
    "Reviews" list; every 8th review that has an "Overall" rating is
    kept. Text batches of 100k are transformed on a background thread
    while the main thread keeps reading; all transformed batches are
    then stacked into one sparse matrix and fed to the Keras model
    (last 1000 samples held out for validation).

    Parameters
    ----------
    datadir : str
        Directory containing the line-delimited JSON review files.
    pickle_model : str, optional
        If non-empty, path to save the Keras model; the wordbatch
        extractor is gzip-pickled next to it as ``pickle_model + ".wb"``.
    """
    texts = []
    labels = []
    training_data = os.listdir(datadir)
    rcount = 0
    texts2 = []
    batchsize = 100000
    batch_data = BatchData()
    p_input = None
    for jsonfile in training_data:
        with open(os.path.join(datadir, jsonfile), 'r') as inputfile:
            for line in inputfile:
                try:
                    line = json.loads(line.strip())
                except ValueError:
                    # Skip malformed JSON lines only — a bare except
                    # here used to swallow every exception.
                    continue
                for review in line["Reviews"]:
                    rcount += 1
                    if rcount % 100000 == 0:
                        print(rcount)
                    if rcount % 8 != 0:
                        continue  # subsample: keep every 8th review
                    if "Overall" not in review["Ratings"]:
                        continue
                    texts.append(review["Content"])
                    # Map the 1..5 star rating onto [-1, 1].
                    labels.append((float(review["Ratings"]["Overall"]) - 3) * 0.5)
                    if len(texts) % batchsize == 0:
                        # Wait for the previous background transform to
                        # finish and collect its result before launching
                        # the next batch (batch_data.texts only exists
                        # after transform_batch has run at least once).
                        if p_input is not None:
                            p_input.join()
                            texts2.append(batch_data.texts)
                        p_input = threading.Thread(
                            target=self.transform_batch,
                            args=(texts, batch_data))
                        p_input.start()
                        texts = []
    # Collect the last in-flight background batch, if any.
    if p_input is not None:
        p_input.join()
        texts2.append(batch_data.texts)
    # Transform the remaining (partial) batch synchronously.
    texts2.append(self.wb.partial_fit_transform(texts))
    del texts
    texts = sp.vstack(texts2)
    self.wb.dictionary_freeze = True
    # Hold out the final 1000 samples for validation.
    test = (np.array(texts[-1000:]), np.array(labels[-1000:]))
    train = (np.array(texts[:-1000]), np.array(labels[:-1000]))
    self.model.fit(train[0], train[1], batch_size=2048, epochs=2,
                   validation_data=(test[0], test[1]))
    if pickle_model != "":
        self.model.save(pickle_model)
        # The batcher backend (e.g. a multiprocessing handle) is not
        # picklable; swap in "serial" while dumping, then restore.
        backend = self.wb.batcher.backend
        backend_handle = self.wb.batcher.backend_handle
        self.wb.batcher.backend = "serial"
        self.wb.batcher.backend_handle = None
        with gzip.open(pickle_model + ".wb", 'wb') as model_file:
            pkl.dump(self.wb, model_file, protocol=2)
        self.wb.batcher.backend = backend
        self.wb.batcher.backend_handle = backend_handle