gnn_captcha (using DeepMind's graph_nets)

(1) curtain.py: load the captcha images and turn each one into a grid graph. Every image is cut into stride x stride pixel blocks; each block becomes a node whose features are its flattened pixels, and 4-neighbouring blocks are linked by directed edges in both directions.



from graph_nets import utils_np
from pathlib import Path

import networkx as nx
import numpy as np
import random
import cv2

class Code:
	def __init__(self, path_kind, batch_size, stride, img_width, img_height):
		# Label alphabet: digits 0-9 plus lowercase a-z, 36 classes in total.
		self.alpha = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b',
					  'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
					  'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

		self.data_root = Path(path_kind)
		self.batch_size = batch_size
		self.img_height = img_height
		self.img_width = img_width
		self.stride = stride

		# Number of stride x stride blocks per axis; image dimensions that are
		# not divisible by stride lose their border pixels.
		self.block_height = img_height // stride
		self.block_width = img_width // stride

		self.load()
		
	def load(self):
		# Collect all image files under the data directory.
		self.second_image_paths = list(self.data_root.glob('*'))
		self.second_image_paths = [str(path) for path in self.second_image_paths]
		self.total_number = len(self.second_image_paths) // self.batch_size


	def mess_up_order(self):
		# Shuffle the image order in place; call once per epoch.
		random.shuffle(self.second_image_paths)
	
	
	def deal_image(self, image):
		graph_nx = nx.OrderedMultiDiGraph()

		# Globals: a 36-dim placeholder; the network replaces the global state
		# with its own 36-way class prediction.
		graph_nx.graph["features"] = np.random.randn(36)

		# Nodes: one per stride x stride block, features = the block's flattened pixels.
		for i in range(self.block_height):
			for j in range(self.block_width):
				graph_nx.add_node(i*self.block_width+j, features=image[i*self.stride:(i+1)*self.stride, j*self.stride:(j+1)*self.stride].flatten())

		# Edges: link each block with its upper and left neighbours, one edge in
		# each direction, with random placeholder features. Handling only these
		# two neighbours covers every adjacency exactly once, so no edge is
		# inserted twice into the multigraph.
		for i in range(self.block_height):
			for j in range(self.block_width):
				if i-1 >= 0:
					graph_nx.add_edge(i*self.block_width+j, (i-1)*self.block_width+j, features=np.random.randn(10))
					graph_nx.add_edge((i-1)*self.block_width+j, i*self.block_width+j, features=np.random.randn(10))

				if j-1 >= 0:
					graph_nx.add_edge(i*self.block_width+j, i*self.block_width+j-1, features=np.random.randn(10))
					graph_nx.add_edge(i*self.block_width+j-1, i*self.block_width+j, features=np.random.randn(10))

		return graph_nx
		
	def next_batch(self, index):
		graph_dicts = []
		labels = []
		for path in self.second_image_paths[self.batch_size*index:self.batch_size*(index+1)]:
			# Parse the label character from the file name (the class character
			# is assumed to sit two positions after the first underscore).
			temp_str = Path(path).name
			begin = temp_str.find('_')
			end = temp_str.find('.')
			label = self.alpha.index(temp_str[begin+2:end])

			img = cv2.imread(path)
			img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
			img = img / 255.0

			graph_dicts.append(self.deal_image(img))
			labels.append(label)

		# A batch of graphs plus the matching one-hot labels.
		return utils_np.networkxs_to_graphs_tuple(graph_dicts), np.eye(len(self.alpha))[labels]
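
A quick way to sanity-check the loader; the directory path is only an example. With stride 8, a 56x100 image yields 7 x 12 = 84 nodes of 8*8*3 = 192 features each:

if __name__ == '__main__':
	code = Code("../tf2_modeldict/image", 4, 8, 56, 100)
	code.mess_up_order()
	graphs, labels = code.next_batch(0)
	print(graphs.nodes.shape)  # (4 * 84, 192)
	print(labels.shape)        # (4, 36)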

(2) Training, TensorFlow 1 version: the update step is compiled with tf.function using an explicit input signature, so repeated calls with fresh GraphsTuple batches do not retrace.


import tensorflow as tf  
import os

from graph_nets import utils_np
from graph_nets import utils_tf

import sonnet as snt
import graph_nets as gn

from curtain import Code

batch_size = 16
img_height = 100
img_width = 56
learning_rate = 1e-4
max_iteration = 1000000

stride = 8

checkpoint_root = "./checkpoints"
checkpoint_name = "model"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)

start_step = 0
code = Code("../tf2_modeldict/image",batch_size,stride,img_width,img_height)


OUTPUT_EDGE_SIZE = 256
OUTPUT_NODE_SIZE = 256
OUTPUT_GLOBAL_SIZE = 36

node = snt.Sequential([
	snt.Linear(1024),
	tf.nn.relu,
	snt.Linear(OUTPUT_NODE_SIZE)
])

edge = snt.Sequential([
	snt.Linear(1024),
	tf.nn.relu,
	snt.Linear(OUTPUT_EDGE_SIZE)
])

global_s = snt.Sequential([
	snt.Linear(256),
	tf.nn.relu,
	snt.Linear(512),
	tf.nn.relu,
	snt.Linear(OUTPUT_GLOBAL_SIZE)
])

# Full GN block: the edge model runs first, then the node model, then the
# global model, whose 36-way output is the class prediction.
graph_network = gn.modules.GraphNetwork(
	edge_model_fn=lambda: edge,
	node_model_fn=lambda: node,
	global_model_fn=lambda: global_s)
  


checkpoint = tf.train.Checkpoint(module=graph_network)

latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
	checkpoint.restore(latest)

loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

############# tf_function ###################
def update_step(inputs_tr, targets_tr):
	with tf.GradientTape() as tape:
		# The 36-way prediction lives in the output GraphsTuple's globals.
		outputs_tr = graph_network(inputs_tr).globals
		# Loss: softmax cross-entropy against the one-hot labels.
		#loss_tr = loss_object(targets_tr, outputs_tr)
		loss_tr = tf.nn.softmax_cross_entropy_with_logits(logits=outputs_tr, labels=targets_tr)
	gradients = tape.gradient(loss_tr, graph_network.trainable_variables)
	optimizer.apply_gradients(zip(gradients, graph_network.trainable_variables))
	return outputs_tr, loss_tr

def specs_from_tensor(tensor_sample, description_fn=tf.TensorSpec):
	# Build a TensorSpec matching an example tensor's shape and dtype.
	shape = list(tensor_sample.shape)
	dtype = tensor_sample.dtype
	return description_fn(shape=shape, dtype=dtype)

# Get some example data that resembles the tensors that will be fed
# into update_step():

Input_data, example_target_data = code.next_batch(0)
graph_dicts = utils_np.graphs_tuple_to_data_dicts(Input_data)
example_input_data = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)  # numpy GraphsTuple -> tensor GraphsTuple

# Get the input signature for that function by obtaining the specs
input_signature = [
  utils_tf.specs_from_graphs_tuple(example_input_data),  # spec of the input GraphsTuple (a named tuple of tensors)
  specs_from_tensor(example_target_data)  # spec of the targets, a plain matrix
]

# Compile the update function using the input signature for speedy code.
compiled_update_step = tf.function(update_step, input_signature=input_signature)
############# tf_function ###################
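
A matching evaluation step could be compiled with the same signature. This is a minimal sketch (eval_step is not part of the original script): update_step's forward pass without the tape and optimizer.

def eval_step(inputs_tr, targets_tr):
	outputs = graph_network(inputs_tr).globals
	loss = tf.nn.softmax_cross_entropy_with_logits(logits=outputs, labels=targets_tr)
	return outputs, loss

compiled_eval_step = tf.function(eval_step, input_signature=input_signature)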

for epoch in range(max_iteration):
	code.mess_up_order()

	for i in range(code.total_number):
		Input_data, Output_data = code.next_batch(i)
		graph_dicts = utils_np.graphs_tuple_to_data_dicts(Input_data)
		graphs_tuple_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)

		outputs_tr, loss = compiled_update_step(graphs_tuple_tf, Output_data)

		print('Epoch %d, Iter %d: train_loss is: %.5f train_accuracy is: %.5f' % (epoch+1, i+1, tf.reduce_mean(loss), train_accuracy(Output_data, outputs_tr)))

		# Save a checkpoint every 10 iterations.
		if i and i % 10 == 0:
			checkpoint.save(save_prefix)
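
After training, prediction is the same forward pass plus an argmax over the 36 global logits. A minimal sketch (predict_chars is a hypothetical helper, not part of the original script):

def predict_chars(graphs_tuple_tf):
	# Map each graph's global logits back to a character from code.alpha.
	logits = graph_network(graphs_tuple_tf).globals
	return [code.alpha[i] for i in tf.argmax(logits, axis=-1).numpy()]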

	




(3) Training, TensorFlow 2 version: the same model and loss, but each update runs eagerly inside a GradientTape instead of a compiled tf.function.


import tensorflow as tf  
import os

from graph_nets import utils_np
from graph_nets import utils_tf

import sonnet as snt
import graph_nets as gn

from curtain import Code

batch_size = 16
img_height = 100
img_width = 56
learning_rate = 1e-4
max_iteration = 1000000

stride = 8

checkpoint_root = "./checkpoints"
checkpoint_name = "model"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)

start_step = 0
code = Code("../tf2_modeldict/image",batch_size,stride,img_width,img_height)


OUTPUT_EDGE_SIZE = 256
OUTPUT_NODE_SIZE = 256
OUTPUT_GLOBAL_SIZE = 36

node = snt.Sequential([
    snt.Linear(1024),
    tf.nn.relu,
    snt.Linear(OUTPUT_NODE_SIZE)
])

edge = snt.Sequential([
    snt.Linear(1024),
    tf.nn.relu,
    snt.Linear(OUTPUT_EDGE_SIZE)
])

global_s = snt.Sequential([
    snt.Linear(256),
    tf.nn.relu,
    snt.Linear(512),
    tf.nn.relu,
    snt.Linear(OUTPUT_GLOBAL_SIZE)
])

# Full GN block, identical to the tf.function version above.
graph_network = gn.modules.GraphNetwork(
    edge_model_fn=lambda: edge,
    node_model_fn=lambda: node,
    global_model_fn=lambda: global_s)
  


checkpoint = tf.train.Checkpoint(module=graph_network)

latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
	checkpoint.restore(latest)

loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

for epoch in range(max_iteration):
	code.mess_up_order()

	for i in range(code.total_number):
		# Build the batch outside the tape; only the forward pass needs taping.
		Input_data, Output_data = code.next_batch(i)
		graph_dicts = utils_np.graphs_tuple_to_data_dicts(Input_data)
		graphs_tuple_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)

		with tf.GradientTape() as tape:
			output_data = graph_network(graphs_tuple_tf).globals
			loss = tf.nn.softmax_cross_entropy_with_logits(logits=output_data, labels=Output_data)
			#loss = loss_object(Output_data, output_data)

		gradients = tape.gradient(loss, graph_network.trainable_variables)
		optimizer.apply_gradients(zip(gradients, graph_network.trainable_variables))

		print('Epoch %d, Iter %d: train_loss is: %.5f train_accuracy is: %.5f' % (epoch+1, i+1, tf.reduce_mean(loss), train_accuracy(Output_data, output_data)))

		# Save a checkpoint every 10 iterations.
		if i and i % 10 == 0:
			checkpoint.save(save_prefix)
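
Note that tf.keras.metrics.CategoricalAccuracy accumulates over every batch it is fed, so both scripts print a running accuracy since start-up. For a per-epoch figure, the metric can be reset at the top of each epoch, e.g.:

	# at the start of each epoch, before the inner batch loop:
	train_accuracy.reset_states()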

    




 
