以下代码实现了 VGG-16 网络结构（基于 TensorFlow 1.x API）。
# -*- coding:utf-8 -*-
#
# VGG-16 Net model
import numpy as np
import tensorflow as tf
class Vgg16:
def __init__(self, images, name):
    """Build the VGG-16 graph on ``images`` and record trainable variables.

    Args:
        images: Input tensor passed to ``self.vgg16``. Its expected
            shape/dtype is defined by ``vgg16`` (presumably a batch of
            NHWC images -- TODO confirm against ``vgg16``).
        name: Name for this model instance; stored as ``self.name``.

    Attributes set:
        name, input: the constructor arguments.
        output: whatever ``self.vgg16(images)`` returns.
        vars: list of trainable variables collected after the graph is built.
    """
    self.name = name
    self.input = images
    self.output = self.vgg16(self.input)
    # The original called ``tf.trainable_variable()`` (singular), which does
    # not exist and raised AttributeError; the TF1 API is
    # ``tf.trainable_variables()``, returning a list of Variables.
    # NOTE(review): this collects *every* trainable variable in the default
    # graph, not only those created by this model. If several models share a
    # graph, consider filtering, e.g. by ``self.name`` in ``var.name`` (only
    # valid if vgg16 builds its variables under such a scope -- confirm).
    self.vars = list(tf.trainable_variables())
def get_conv_weight(self, shape, name, stddev=0.1):
    """Create a convolution-kernel variable drawn from a truncated normal.

    Args:
        shape: Shape of the kernel, passed straight to
            ``tf.truncated_normal`` (presumably
            [height, width, in_channels, out_channels] -- TODO confirm
            against ``conv_layer``).
        name: Name given to the created ``tf.Variable``.
        stddev: Standard deviation of the truncated normal. Defaults to
            0.1, the value that was previously hard-coded.
            NOTE(review): 0.1 may be large for a 16-layer network; exposing
            it lets callers tune the initialisation.

    Returns:
        A ``tf.Variable`` initialised with the sampled kernel.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=stddev), name=name)
def get_bias(self, shape, name):
    """Return a bias ``tf.Variable`` of the given ``shape``, filled with 0.0.

    Args:
        shape: Shape of the bias tensor.
        name: Name given to the created variable.
    """
    zero_init = tf.constant(0.0, shape=shape)
    bias_var = tf.Variable(zero_init, name=name)
    return bias_var
def get_fc_weight(self, shape, name, stddev=0.1):
    """Create a fully-connected weight variable drawn from a truncated normal.

    Args:
        shape: Shape of the weight matrix, passed straight to
            ``tf.truncated_normal`` (presumably [in_units, out_units] --
            TODO confirm against the fully-connected layer code).
        name: Name given to the created ``tf.Variable``.
        stddev: Standard deviation of the truncated normal. Defaults to
            0.1, the value that was previously hard-coded.

    Returns:
        A ``tf.Variable`` initialised with the sampled weights.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=stddev), name=name)
def conv_layer(self, x, ks, out_units, name):
with t