代码库:dgl_1_gcn.py

  • 在文章 代码库:DGL_基本框架 中出现的代码文件 dgl_1_gcn.py
import dgl
import torch
import torch.nn as nn # torch 模型层
import torch.nn.functional as F # torch 函数 包括激活函数等
import torch.optim as optim
from dgl.nn import GraphConv # 图模型层
import dgl.data # 图数据集

import time

# 构造一个两层GCN模型
# batch_size = 1
class GCN(nn.Module):
	def __init__(self, in_feats, h_feats, num_classes):	# 模型构造时的输入参数
		
		super(GCN, self).__init__()
		self.conv1 = GraphConv(in_feats, h_feats)
		# exp self.conv1 = GraphConv(in_feats, h_feats, weight=False)
		self.conv2 = GraphConv(h_feats, num_classes)

		'''
		GraphConv

		参数:

		in_feats: int, 输入层的特征维度
		out_feats: int, 输出层的特征维度
		norm: str, 是否进行normalization
			
			right: 右乘normalization矩阵,即只对输入特征进行normalization
			none: no normalization
			both: default, 同时考虑对输入特征和输出特征进行normalization

		weight: bool, 是否对输入特征进行线性加和,default True

			若False 则不考虑weight matrix
			PS: in_feats and out_feats 共同构成了weight matrix的维度

		bias: bool, default True
		activation: default None, or callable function
		'''

	'''
	Tradition: 将图和自身的特征独立开来 调用图的API实现相关功能
	'''
	def forward(self, graph, in_feature): # forward(使用对象)时的输入参数
		
		h = self.conv1(graph, in_feature)
		h = F.relu(h)
		h = self.conv2(graph, h)

		return h
		# conv 参数中graph如何构造?

# 训练pipeline
def train(graph, model, epochs = 10):

	train_acc = 0
	valid_acc = 0

	# optimizer
	optimizer = optim.Adam(model.parameters(), lr=0.01)

	# split dataset
	features = graph.ndata['feat']
	labels = graph.ndata['label']
	train_mask = graph.ndata['train_mask']
	valid_mask = graph.ndata['val_mask']

	start = time.time()
	# train
	for epoch in range(epochs):
		
		# forward propagation
		prob = model(graph, features)
		pred = prob.argmax(dim=1)

		# loss
		# Note that you should only compute the losses of the nodes in the training set.
		loss = F.cross_entropy(prob[train_mask], labels[train_mask])

		# accuracy
		train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
		valid_acc = (pred[valid_mask] == labels[valid_mask]).float().mean()

		# backward
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		if epoch % 5 == 0:
			print('In epoch {}, loss: {:.3f}, train acc: {:.3f}, val acc: {:.3f}, time: {:.3f}'
				.format(epoch, loss, train_acc, valid_acc, time.time() - start))

			start = time.time()

if __name__ == '__main__':
	
	# Pipeline

	# 读入数据集
	dataSet = dgl.data.CoraGraphDataset()
	graph = dataSet[0]

	# 构建模型
	model = GCN(graph.ndata['feat'].shape[1], 16, dataSet.num_classes)

	# 训练模型
	train(graph, model, epochs=500)
			


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
使用QTimer对象代替QBasicTimer对象,修改程序class MyWindow(QWidget): def init(self): super().init() self.thread_list = [] self.color_photo_dir = os.path.join(os.getcwd(), "color_photos") self.depth_photo_dir = os.path.join(os.getcwd(), "depth_photos") self.image_thread = None self.saved_color_photos = 0 # 定义 saved_color_photos 属性 self.saved_depth_photos = 0 # 定义 saved_depth_photos 属性 self.init_ui() def init_ui(self): self.ui = uic.loadUi("C:/Users/wyt/Desktop/D405界面/intelrealsense1.ui") self.open_btn = self.ui.pushButton self.color_image_chose_btn = self.ui.pushButton_3 self.depth_image_chose_btn = self.ui.pushButton_4 self.open_btn.clicked.connect(self.open) self.color_image_chose_btn.clicked.connect(lambda: self.chose_dir(self.ui.lineEdit, "color")) self.depth_image_chose_btn.clicked.connect(lambda: self.chose_dir(self.ui.lineEdit_2, "depth")) def open(self): self.profile = self.pipeline.start(self.config) self.is_camera_opened = True self.label.setText('相机已打开') self.label.setStyleSheet('color:green') self.open_btn.setEnabled(False) self.close_btn.setEnabled(True) self.image_thread = ImageThread(self.pipeline, self.color_label, self.depth_label, self.interval, self.color_photo_dir, self.depth_photo_dir, self._dgl) self.image_thread.saved_color_photos_signal.connect(self.update_saved_color_photos_label) self.image_thread.saved_depth_photos_signal.connect(self.update_saved_depth_photos_label) self.image_thread.start() def chose_dir(self, line_edit, button_type): my_thread = MyThread(line_edit, button_type) my_thread.finished_signal.connect(self.update_line_edit) self.thread_list.append(my_thread) my_thread.start()
05-26
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值