前言:各种踩过的坑
plt.subplot(121) 画图一行两列 中的第一个图
raise() 自动显示异常,一旦raise()执行后面的语句不在执行
yield python的一个生成器
简单理解:到yield就返回,下一次执行从yield下面的一条语句执行
def funny():
for i in range(1,10):
yield i
f = funny()
while next(f):
print(next(f)) # next方法
def fun():
n = 0
while True:
n = yield n
a = fun()
next(a) # 在一个生成器未启动前不能传值
print(a.send(1)) # send传值
应用
需要一个无限循环序列,通常解决方法是生成一个非常大的列表,但很明显内存限制了这种办法
使用yield就可以解决问题,每次只返回一个数据
fluid.dygraph.guard()通过with语句创建一个dygraph运行的context,执行context代码。
优化器:多种优化器连接
注:有一个数据格式转换,excel转csv要把excel另存为csv
csv最后一行后面有空格,运行时就会发现,出现错误删除即可
CSVFILE = 'H:/eyeWork/valid_gt/PALM-Validation-GT/Labels.csv'
filelists = open(CSVFILE).readlines()
for line in filelists[1:]:
line = line.strip().split(',') # 这里根据数据格式而定
print(line)
整体网络
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
# %matplotlib inline
from PIL import Image
DATADIR = 'H:/eyeWork/PALM-Training400/PALM-Training400' # 这里有改动
file1 = 'N0012.jpg'
file2 = 'P0095.jpg'
# 读取图片
img1 = Image.open(os.path.join(DATADIR,file1))
img1 = np.array(img1)
img2 = Image.open(os.path.join(DATADIR,file2))
img2 = np.array(img2)
# 画出读取的数据
plt.figure(figsize=(16,8))
f = plt.subplot(121)
f.set_title('Normal',fontsize=20)
plt.imshow(img1)
f = plt.subplot(122)
f.set_title('PM',fontsize=20)
plt.imshow(img2)
plt.show()
# 定义数据读取器
# 使用opencv从磁盘读取图片,将每张图片放缩到224*224大小,并且将像素值调整到[-1,1] 之间
import cv2
import random
import numpy as np
# 对读入的图像数据进行处理
def transform_img(img):
# 图片尺寸放缩到 224*224
img = cv2.resize(img,(224,224))
# 读入图片的格式是[H,W,c]
# 使用转置操作将其变成[C,H,W]
img = np.transpose(img,(2,