1.由分步执行改成一个文件
训练文件
# Training script: builds the PSENet model, compiles it with the project's
# loss and IoU metrics, and fits it on the train/test label generators.
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand, and register the session with
# Keras before any model is built.
# The object is named `tf_config` so that it does not collide with the
# project's `config` module imported below.
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
session = tf.Session(config=tf_config)
KTF.set_session(session)

import keras
from keras.callbacks import ModelCheckpoint
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
from keras.utils import multi_gpu_model

import config
from models.loss import build_loss
from models.metrics import build_iou, mean_iou
from models.psenet import psenet
from tool.generator import Generator

# ---- model -----------------------------------------------------------------
# Height and width are left unspecified; the last axis has 3 channels.
input_shape = (None, None, 3)
inputs = keras.layers.Input(shape=input_shape)
output = psenet(inputs)
model = keras.models.Model(inputs, output)
model.summary()

# Single-device training. To train on several GPUs, wrap the model instead:
# parallel_model = multi_gpu_model(model, gpus=1)
parallel_model = model
adam = Adam(1e-3)
ious = build_iou([0, 1], ['bk', 'txt'])
parallel_model.compile(loss=build_loss, optimizer=adam, metrics=ious)

# ---- data ------------------------------------------------------------------
train_dir = config.MIWI_2018_TRAIN_LABEL_DIR
test_dir = config.MIWI_2018_TEST_LABEL_DIR
batch_size = 1
num_class = 2
# Size handed to the generators through `reshape`.
reshape_size = (640, 640)

gen_train = Generator(
    train_dir,
    batch_size=batch_size,
    istraining=True,
    num_classes=num_class,
    mirror=False,
    reshape=reshape_size,
)
# The validation generator is created with every optional transform switched off.
gen_test = Generator(
    test_dir,
    batch_size=batch_size,
    istraining=False,
    num_classes=num_class,
    reshape=reshape_size,
    mirror=False,
    scale=False,
    clip=False,
    trans_color=False,
)

# ---- callbacks -------------------------------------------------------------
# One weights-only checkpoint file per epoch, plus TensorBoard logs.
checkpoint = ModelCheckpoint(
    r'resent50-190422_BLINEAR-{epoch:02d}.hdf5',
    save_weights_only=True,
)
tb = TensorBoard(log_dir='./logs')

# ---- training --------------------------------------------------------------
print(gen_test.num_samples(), gen_train.num_samples())
res = parallel_model.fit_generator(
    gen_train,
    steps_per_epoch=gen_train.num_samples() // batch_size,
    epochs=40,
    validation_data=gen_test,
    validation_steps=gen_test.num_samples() // batch_size,
    verbose=1,
    initial_epoch=0,
    workers=4,
    max_queue_size=16,
    callbacks=[checkpoint, tb],
)
测试文件:
# Inference script: loads trained PSENet weights, runs the model on every
# *.jpg under `image_dir`, and hands the fitted rectangles to
# save_MTWI_2108_resault together with a per-image .txt path.
import glob
import os

# Choose the visible GPU before TensorFlow is imported, so the setting is in
# place before any CUDA initialisation can happen.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import keras
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import numpy as np
import tqdm

# NOTE(review): TensorFlow documents device_count keys as device type names
# such as 'CPU'/'GPU'; confirm whether the lowercase 'gpu' key has any effect,
# and whether limiting GPUs to 0 is intended alongside the memory fraction.
config = tf.ConfigProto(device_count={'gpu': 0})
# config.gpu_options.allow_growth=True
config.gpu_options.per_process_gpu_memory_fraction = 0.85
session = tf.Session(config=config)
KTF.set_session(session)

from models.psenet import psenet
from tool.utils import ufunc_4, scale_expand_kernels, fit_minarearectange, fit_boundingRect, save_MTWI_2108_resault

# ---- model -----------------------------------------------------------------
shape = (None, None, 3)
inputs = keras.layers.Input(shape=shape)
output = psenet(inputs)
model = keras.models.Model(inputs, output)
model.summary()
model.load_weights('resent50-190219_BLINEAR-iou8604.hdf5')

# ---- paths and size limits -------------------------------------------------
# image_dir = '/home/yang/Documents/data/ali/mtwi_2018_train/image_test'
image_dir = '/home/yang/Documents/data/ali/icpr_mtwi_task2/image_test'
saveimgdir = "/home/yang/Documents/model/detect/PSENET-keras/imgs/result/img"
savetxtdir = "/home/yang/Documents/model/detect/PSENET-keras/imgs/result/txt"
imagesfile = glob.glob(os.path.join(image_dir, '*.jpg'))
MIN_LEN = 640
MAX_LEN = 1024


def _target_size(height, width, min_len=MIN_LEN, max_len=MAX_LEN):
    """Return the ``(height, width)`` an image is resized to before prediction.

    If the shorter side is below ``min_len`` it is set to ``min_len`` and the
    other side is scaled by the same factor. Each side is then capped at
    ``max_len`` independently and rounded down to a multiple of 32.
    """
    if width < height and width < min_len:
        height = min_len / width * height
        width = min_len
    elif height <= width and height < min_len:
        width = min_len / height * width
        height = min_len
    width = min(width, max_len)
    height = min(height, max_len)
    return int(height // 32 * 32), int(width // 32 * 32)


# ---- inference loop --------------------------------------------------------
with tqdm.tqdm(total=len(imagesfile)) as bar:
    for image_path in imagesfile:
        bar.update()
        try:
            # Decode from the raw file bytes; flag -1 keeps the file's own
            # channel layout, so images that are not 3-channel fail at the
            # reshape below and are reported by the except clause.
            image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), -1)
            if image is None:
                raise ValueError('cv2.imdecode could not decode the file')

            h, w = _target_size(*image.shape[0:2])
            # Ratios between the original and the resized image, passed on to
            # the result writer.
            scalex = image.shape[1] / w
            scaley = image.shape[0] / h

            # The third positional parameter of cv2.resize is `dst`, so the
            # interpolation flag has to be passed by keyword to take effect.
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
            batch = np.reshape(image, (1, h, w, 3))

            res = model.predict(batch)
            res1 = res[0]
            # Binarise the prediction at 0.9.
            res1[res1 > 0.9] = 1
            res1[res1 <= 0.9] = 0

            # Channels 0-4 are each masked by channel 5; channel 5 itself is
            # appended last. All maps are scaled to 0/255.
            kernels = [
                np.logical_and(res1[:, :, 5], res1[:, :, k]) * 255
                for k in range(5)
            ]
            kernels.append(res1[:, :, 5] * 255)

            num_label, labelimage = scale_expand_kernels(kernels)
            rects = fit_minarearectange(num_label, labelimage)
            # NOTE(review): the factor 2 presumably maps label-map coordinates
            # back to the resized input - confirm against psenet's output size.
            boxes = np.array(rects) * 2

            # Only visible if the imwrite below is re-enabled.
            cv2.drawContours(batch[0], boxes, -1, (0, 0, 255), 2)
            base_name = os.path.splitext(os.path.basename(image_path))[0]
            # cv2.imwrite(os.path.join(saveimgdir, base_name + '.jpg'), batch[0])
            save_MTWI_2108_resault(os.path.join(savetxtdir, base_name + '.txt'), boxes, scalex, scaley)
        except Exception as exc:
            # Best effort: report the failing file and why, then move on to
            # the next image instead of aborting the whole run.
            print(image_path)
            print('  skipped: {!r}'.format(exc))
2.要先生成标签文件
执行tool文件夹下的gen_dataset.py,并将npy文件的保存类型修改为uint8,否则文件会很大。
# Snippet for tool/gen_dataset.py: allocate the label array as uint8 so the saved .npy file stays small.
npys = np.zeros((img.shape[0],img.shape[1],config.n),dtype='uint8')