【C#】使用 Winform 联合 TibcoRV中间件通信项目

  • 📢摘要: 本文将用C# Winform实现TIBCO RV的通信功能,实现连接、侦听、发送与接收全流程。

  • 🌟 前言

    • TIBCO Rendezvous(TIBCO RV)是一款高性能的中间件,专为实时消息传递设计,适用于对延迟极度敏感的场景。
    • 本文将用C# Winform实现TIBCO RV的通信功能,实现连接、侦听、发送与接收全流程。
    • 文末附项目源码及安装测试步骤。

  • 🔍核心概念速览

    • 发布/订阅模型:基于主题的消息路由机制。
    • 对等网络架构:无中心节点,分布式通信
    • 实时性能:低延迟、高吞吐量。
    • 核心组件:Daemon(守护进程 rvd;路由守护进程为 rvrd)、Subject(主题)、Transport(通信端点)、Listener(侦听器)

  • 🛠️ 运行环境

    • 操作系统: Windows 11
    • 编程软件: Visual Studio 2022
    • .Net版本: .NET Framework 4.5.2
    • 依赖版本: TIBCO Rendezvous 8.4


  • 🖥️ 预览

    • 运行效果

    在这里插入图片描述

  • 💻 核心代码片段

    • 1. 连接与断开逻辑

    /// <summary>
    /// Connect / disconnect button (simplified excerpt; the full form also wires the helper's events).
    /// The button caption doubles as state: "连接" means "not connected yet, connect now".
    /// </summary>
    private void btn_Connect_Click(object sender, EventArgs e)
    {
        bool isStartConnect = btn_Connect.Text.Equals("连接", StringComparison.Ordinal);
        if (isStartConnect)
        {
            tibcoRVTHelper = new TibcoRVTHelper(tbx_Server.Text, tbx_Network.Text, tbx_Daemon.Text, tbx_ListenSubject.Text, tbx_TargetSubject.Text);
            tibcoRVTHelper.StartConnect();
        }
        else if (tibcoRVTHelper != null)
        {
            // Only tear down a session that was actually created.
            tibcoRVTHelper.DisConnected();
        }
    }
    
  • 2. 消息接收与显示

    /// <summary>
    /// Appends the first field of a received message, prefixed with the local time, to the receive box.
    /// </summary>
    public void OnMessageReceived(object sender, MessageReceivedEventArgs e)
    {
        TIBCO.Rendezvous.Message message = e.Message;
        string receiveData = message.GetFieldByIndex(0);
        // Build the line (and timestamp) now, before handing it to the UI thread.
        string line = $"{DateTime.Now}{receiveData}\n";
        // The helper raises this from its background dispatch task, so the control
        // must be updated on the UI thread rather than touched directly.
        rtbx_ReceiveData.BeginInvoke(new Action(() => rtbx_ReceiveData.AppendText(line)));
    }
    

  • 💻 代码

  • Form界面类主要代码

    using System;
    using System.Drawing;
    using System.Drawing.Drawing2D;
    using System.Windows.Forms;
    using TIBCO.Rendezvous;
    
    namespace Demo_TibcoRV
    {
        /// <summary>
        /// Demo window for TIBCO Rendezvous: connect/disconnect, listen, send and display messages.
        /// All RV work is delegated to <see cref="TibcoRVTHelper"/>; this class only drives the UI.
        /// </summary>
        public partial class Frm_TibcoRV : Form
        {
            // Helper for the current session; a new instance is created on every connect.
            TibcoRVTHelper tibcoRVTHelper;

            public Frm_TibcoRV()
            {
                InitializeComponent();
                this.MinimumSize = new System.Drawing.Size(750, 650);
            }

            private void Frm_TibcoRV_Load(object sender, EventArgs e)
            {
                // Both status lamps start grey (not connected / not listening).
                ControlStyleUpdata(picBx_ConnectStatu, Color.Gray);
                ControlStyleUpdata(picBx_ListenStatu, Color.Gray);
                // NOTE(review): nothing in this class re-enables this button; listening appears to be
                // started by the helper itself once the connection succeeds.
                btn_ListenOpen.Enabled = false;
            }

            #region Event handlers
            /// <summary>
            /// Connect button. Its caption selects the action: "连接" (connect) or "断开" (disconnect).
            /// </summary>
            private void btn_Connect_Click(object sender, EventArgs e)
            {
                bool isStartConnect = btn_Connect.Text.Equals("连接", StringComparison.Ordinal);
                if (isStartConnect)
                {
                    // Stop listening to any previous helper so a late callback from it (for example a
                    // connect that completes after its timeout) cannot update this form.
                    DetachHelper(tibcoRVTHelper);
                    tibcoRVTHelper = new TibcoRVTHelper(tbx_Server.Text, tbx_Network.Text, tbx_Daemon.Text, tbx_ListenSubject.Text, tbx_TargetSubject.Text);
                    tibcoRVTHelper.MessageField = tbx_MessageField.Text;
                    AttachHelper(tibcoRVTHelper);
                    tibcoRVTHelper.StartConnect();
                    return;
                }

                // Disconnect. There is no helper if a connect was never attempted.
                if (tibcoRVTHelper != null)
                {
                    tibcoRVTHelper.DisConnected();
                }
                ControlStyleUpdata(picBx_ConnectStatu, Color.Gray);
                ControlStyleUpdata(picBx_ListenStatu, Color.Gray);
                btn_Connect.Text = tibcoRVTHelper != null && tibcoRVTHelper.IsConnected ? "断开" : "连接";
                ControlUpdata(true);
            }

            /// <summary>
            /// Subscribes this form to every notification raised by <paramref name="helper"/>.
            /// </summary>
            private void AttachHelper(TibcoRVTHelper helper)
            {
                helper.messageReceivedHandler += OnMessageReceived;
                helper.ListenedStatusHandler += OnListened;
                helper.ConnectedStatusHandler += OnConnected;
                helper.ErrorMessageHandler += TibcoRVTHelper_ErrorMessageHandler;
            }

            /// <summary>
            /// Removes the subscriptions made by <see cref="AttachHelper"/>; does nothing for null.
            /// </summary>
            private void DetachHelper(TibcoRVTHelper helper)
            {
                if (helper == null)
                {
                    return;
                }
                helper.messageReceivedHandler -= OnMessageReceived;
                helper.ListenedStatusHandler -= OnListened;
                helper.ConnectedStatusHandler -= OnConnected;
                helper.ErrorMessageHandler -= TibcoRVTHelper_ErrorMessageHandler;
            }

            /// <summary>
            /// Shows helper progress/error text in the receive box.
            /// </summary>
            private void TibcoRVTHelper_ErrorMessageHandler(object sender, string message)
            {
                MessageShow(message);
            }

            /// <summary>
            /// Listen button. While connected it copies the subject text boxes into the helper; while
            /// disconnected it detaches this form's message handler. The caption and lamp then mirror
            /// the helper's IsListened state.
            /// </summary>
            /// <remarks>
            /// NOTE(review): changing the subjects here does not re-create the listener, so a new
            /// ListenSubject presumably only takes effect on the next connect — confirm.
            /// </remarks>
            private void btn_ListenOpen_Click(object sender, EventArgs e)
            {
                // No session has been created yet: nothing to listen on.
                if (tibcoRVTHelper == null)
                {
                    return;
                }
                if (tibcoRVTHelper.IsConnected)
                {
                    tibcoRVTHelper.TargetSubject = tbx_TargetSubject.Text;
                    tibcoRVTHelper.ListenSubject = tbx_ListenSubject.Text;
                }
                else
                {
                    // Not connected: listening is not allowed.
                    tibcoRVTHelper.messageReceivedHandler -= OnMessageReceived;
                }
                btn_ListenOpen.Text = tibcoRVTHelper.IsListened ? "停止" : "侦听";
                ControlStyleUpdata(picBx_ListenStatu, tibcoRVTHelper.IsListened ? Color.LimeGreen : Color.Gray);
            }

            /// <summary>
            /// Send button: publishes the send box text when a connection exists.
            /// </summary>
            private void btn_SendData_Click(object sender, EventArgs e)
            {
                if (tibcoRVTHelper != null && tibcoRVTHelper.IsConnected)
                {
                    tibcoRVTHelper.Send(rtbx_SendData.Text);
                }
            }

            /// <summary>
            /// Client/server toggle: swaps the listen and target subjects while not connected.
            /// </summary>
            private void checkBox_IsClient_CheckedChanged(object sender, EventArgs e)
            {
                if (tibcoRVTHelper == null || !tibcoRVTHelper.IsConnected)
                {
                    string temp = tbx_ListenSubject.Text;
                    tbx_ListenSubject.Text = tbx_TargetSubject.Text;
                    tbx_TargetSubject.Text = temp;
                }
            }
            #endregion

            #region Helper callbacks
            /// <summary>
            /// Callback for received RV messages: shows the send subject and the first field's name and value.
            /// </summary>
            public void OnMessageReceived(object sender, MessageReceivedEventArgs messageReceivedEventArgs)
            {
                TIBCO.Rendezvous.Message message = messageReceivedEventArgs.Message;
                try
                {
                    // Read the first field once and reuse it for both value and name.
                    var firstField = message.GetFieldByIndex(0);
                    string receiveData = firstField;
                    MessageShow($"send subject = {message.SendSubject}\r\n field name = {firstField.Name}\r\n{receiveData}");
                }
                catch (Exception ex)
                {
                    // Raised from the helper's dispatch loop; an escaping exception could end that loop.
                    MessageShow($"接收消息异常:{ex.Message}");
                }
            }

            /// <summary>
            /// Appends "{local time}{data}" plus a newline to the receive box, from any thread.
            /// </summary>
            public void MessageShow(string data)
            {
                // Take the timestamp now, not when the UI thread gets round to running the update.
                string line = $"{DateTime.Now}{data}{System.Environment.NewLine}";
                RunOnUiThread(() => rtbx_ReceiveData.AppendText(line));
            }

            /// <summary>
            /// Callback for listen state changes: updates the listen button caption and lamp.
            /// </summary>
            private void OnListened(object sender, bool listened)
            {
                RunOnUiThread(() =>
                {
                    btn_ListenOpen.Text = listened ? "停止" : "侦听";
                    ControlStyleUpdata(picBx_ListenStatu, listened ? Color.LimeGreen : Color.Gray);
                });
            }

            /// <summary>
            /// Callback for connection state changes. The UI is rebuilt from the helper's current
            /// IsConnected/IsListened state rather than from the event argument.
            /// </summary>
            private void OnConnected(object sender, bool connected)
            {
                RunOnUiThread(() =>
                {
                    bool isConnected = tibcoRVTHelper != null && tibcoRVTHelper.IsConnected;
                    btn_Connect.Text = isConnected ? "断开" : "连接";
                    ControlStyleUpdata(picBx_ConnectStatu, isConnected ? Color.LimeGreen : Color.Gray);
                    if (!isConnected)
                    {
                        // Connection failed or closed: reset the listen lamp and unlock the settings.
                        ControlStyleUpdata(picBx_ListenStatu, Color.Gray);
                        ControlUpdata(true);
                    }
                    else if (tibcoRVTHelper.IsListened)
                    {
                        // Connected and listening: lock the connection settings.
                        ControlStyleUpdata(picBx_ListenStatu, Color.LimeGreen);
                        ControlUpdata(false);
                    }
                });
            }

            /// <summary>
            /// Runs <paramref name="action"/> on the UI thread. Calls from background threads are queued
            /// with BeginInvoke so the caller never blocks waiting for the UI. Updates are dropped once
            /// the form is disposed or has no window handle.
            /// </summary>
            private void RunOnUiThread(Action action)
            {
                if (IsDisposed || !IsHandleCreated)
                {
                    return;
                }
                if (!InvokeRequired)
                {
                    action();
                    return;
                }
                try
                {
                    BeginInvoke(action);
                }
                catch (ObjectDisposedException)
                {
                    // The form was closed between the check above and this call; drop the update.
                }
                catch (InvalidOperationException)
                {
                    // The window handle was destroyed while closing; drop the update.
                }
            }
            #endregion


            #region Rounded controls
            #region Method 1: clip the control to an ellipse
            /// <summary>
            /// Clips <paramref name="control"/> to the ellipse inscribed in its client rectangle.
            /// </summary>
            public void ControlStyleUpdata(Control control)
            {
                using (GraphicsPath gp = new GraphicsPath())
                {
                    gp.AddEllipse(control.ClientRectangle);
                    // The control keeps using this Region (e.g. when its handle is recreated), so it
                    // must stay alive and is not disposed here.
                    control.Region = new Region(gp);
                }
            }

            /// <summary>
            /// Sets the back colour of <paramref name="control"/> and clips it to an ellipse.
            /// </summary>
            public void ControlStyleUpdata(Control control, Color bcColor)
            {
                control.BackColor = bcColor;
                ControlStyleUpdata(control);
            }
            #endregion

            #region Method 2: owner-drawn rounded button
            /// <summary>
            /// Paint handler for buttons: gradient-filled rounded rectangle with the caption in white.
            /// </summary>
            private void button_Paint(object sender, PaintEventArgs e)
            {
                Button button = (Button)sender;
                // Paint the whole client area; e.ClipRectangle is only the invalidated part.
                // base.OnPaint is deliberately not called: it would raise the form's own Paint event
                // with this button's graphics.
                Draw(button.ClientRectangle, e.Graphics, 12, false, Color.FromArgb(0, 122, 204), Color.FromArgb(8, 39, 57));
                using (Font font = new Font("微软雅黑", 12, FontStyle.Regular))
                using (SolidBrush brush = new SolidBrush(Color.White))
                {
                    e.Graphics.DrawString(button.Text, font, brush, new PointF(15, 0));
                }
            }

            /// <summary>
            /// Fills a rounded rectangle inside <paramref name="rectangle"/> with a vertical gradient from
            /// <paramref name="begin_color"/> to <paramref name="end_color"/>. When <paramref name="cusp"/>
            /// is true a small pointer triangle is drawn at the right edge.
            /// </summary>
            private void Draw(Rectangle rectangle, Graphics g, int _radius, bool cusp, Color begin_color, Color end_color)
            {
                // LinearGradientBrush rejects an empty rectangle, and there is nothing to paint anyway.
                if (rectangle.Width <= 0 || rectangle.Height <= 0)
                {
                    return;
                }
                int span = 2;
                g.SmoothingMode = SmoothingMode.AntiAlias;
                using (LinearGradientBrush brush = new LinearGradientBrush(rectangle, begin_color, end_color, LinearGradientMode.Vertical))
                {
                    if (cusp)
                    {
                        // Leave room on the right for the pointer triangle.
                        span = 10;
                        PointF[] ptsArray =
                        {
                            new PointF(rectangle.Right - 12, rectangle.Y + 10),
                            new PointF(rectangle.Right - 12, rectangle.Y + 30),
                            new PointF(rectangle.Right, rectangle.Y + 20),
                        };
                        g.FillPolygon(brush, ptsArray);
                    }
                    using (GraphicsPath path = DrawRoundRect(rectangle.X, rectangle.Y, rectangle.Width - span, rectangle.Height - 1, _radius))
                    {
                        g.FillPath(brush, path);
                    }
                }
            }

            /// <summary>
            /// Builds a closed rounded-rectangle path covering (x, y, width, height).
            /// </summary>
            /// <param name="radius">Width and height of each corner arc's bounding box.</param>
            /// <returns>A new path; the caller owns and must dispose it.</returns>
            public static GraphicsPath DrawRoundRect(int x, int y, int width, int height, int radius)
            {
                int right = x + width;
                int bottom = y + height;
                GraphicsPath gp = new GraphicsPath();
                gp.AddArc(x, y, radius, radius, 180, 90);
                gp.AddArc(right - radius, y, radius, radius, 270, 90);
                gp.AddArc(right - radius, bottom - radius, radius, radius, 0, 90);
                gp.AddArc(x, bottom - radius, radius, radius, 90, 90);
                gp.CloseAllFigures();
                return gp;
            }
            #endregion
            #endregion

            #region Other helpers
            /// <summary>
            /// Enables or disables every TextBox directly inside panel_Top (the connection settings).
            /// </summary>
            public void ControlUpdata(bool controlStatus)
            {
                foreach (Control control in panel_Top.Controls)
                {
                    if (control is TextBox)
                    {
                        control.Enabled = controlStatus;
                    }
                }
            }
            #endregion

            private void btn_ClearReceiveData_Click(object sender, EventArgs e)
            {
                rtbx_ReceiveData.Text = string.Empty;
            }

            private void btn_ClearSendData_Click(object sender, EventArgs e)
            {
                rtbx_SendData.Text = string.Empty;
            }
        }
    }
    
    
  • 辅助类代码

    using System;
    using System.Collections;
    using System.Threading.Tasks;
    using System.Windows.Forms;
    using TIBCO.Rendezvous;
    namespace Demo_TibcoRV
    {
        /*
         * 设计思路:
         *      连接成功时,触发连接事件:在连接状态事件中执行侦听事件。
         *      侦听成功时,触发侦听事件
         */
        /// <summary>
        /// Wraps one TIBCO Rendezvous session. It opens the RV environment, creates a
        /// <see cref="NetTransport"/>, listens on <see cref="ListenSubject"/> and sends to
        /// <see cref="TargetSubject"/>. Progress and errors are reported through
        /// <see cref="ErrorMessageHandler"/>. Connect/listen state changes are reported through
        /// <see cref="ConnectedStatusHandler"/> and <see cref="ListenedStatusHandler"/>.
        /// </summary>
        /// <remarks>
        /// Flow: a successful connect raises <see cref="ConnectedStatusHandler"/>. Every constructor
        /// subscribes <see cref="OnConnectCallBack"/> to that event, so listening starts from there.
        /// A successful listen raises <see cref="ListenedStatusHandler"/>.
        /// </remarks>
        public class TibcoRVTHelper
        {
            #region Fields and properties
            /// <summary>How long <see cref="StartConnect"/> waits for the transport before reporting a timeout.</summary>
            private const int ConnectTimeoutMilliseconds = 3000;

            private string _service;                // RV service
            private string _network;                // RV network
            private string _daemon;                 // RV daemon
            private string _messageField;           // field name used by Send(string)
            private string _listenSubject;          // subject to listen on
            private string _targetSubject;          // subject to send to
            private bool _isOpen = false;           // RV environment opened by this helper
            private bool _isConnected = false;      // transport created
            private bool _isListen = false;         // listener created and dispatch loop running
            private string _cmName;                 // "My Name" (not used inside this class)
            private Task task = null;               // dispatch loop task

            private TIBCO.Rendezvous.NetTransport _transport;       // transport
            private TIBCO.Rendezvous.Listener _listener;            // listener
            private TIBCO.Rendezvous.Queue _queue;                  // queue the listener dispatches on

            public bool IsListened { get { return _isListen; } set { _isListen = value; } }
            public bool IsOpen { get { return _isOpen; } set { _isOpen = value; } }
            public bool IsConnected { get { return _isConnected; } set { _isConnected = value; } }
            public string Service { get { return _service; } set { _service = value; } }
            public string Network { get { return _network; } set { _network = value; } }
            public string Daemon { get { return _daemon; } set { _daemon = value; } }
            public string MessageField { get => _messageField; set => _messageField = value; }
            public string ListenSubject { get { return _listenSubject; } set { _listenSubject = value; } }
            public string TargetSubject { get { return _targetSubject; } set { _targetSubject = value; } }
            public string CmName { get => _cmName; set => _cmName = value; }
            public TIBCO.Rendezvous.NetTransport Transport { get => _transport; set => _transport = value; }
            public TIBCO.Rendezvous.Listener Listener { get => _listener; set => _listener = value; }
            public TIBCO.Rendezvous.Queue Queue { get => _queue; set => _queue = value; }
            #endregion

            #region Constructors
            public TibcoRVTHelper()
            {
                // Drive the listen step off the connect notification, whichever constructor is used.
                this.ConnectedStatusHandler += OnConnectCallBack;
            }

            public TibcoRVTHelper(string server, string network, string daemon) : this()
            {
                this.Service = server;
                this.Network = network;
                this.Daemon = daemon;
            }

            public TibcoRVTHelper(string server, string network, string daemon, string listenSubject, string targetSubject)
                : this(server, network, daemon)
            {
                this.ListenSubject = listenSubject;
                this.TargetSubject = targetSubject;
            }
            #endregion

            #region Connect
            /// <summary>
            /// Opens the TIBCO RV environment.
            /// </summary>
            /// <returns>
            /// <c>true</c> if the environment was opened. Failures are reported through
            /// <see cref="ErrorMessageHandler"/> instead of being thrown.
            /// </returns>
            public bool Open()
            {
                try
                {
                    TIBCO.Rendezvous.Environment.Open();
                }
                catch (Exception ex)
                {
                    IsOpen = false;
                    RaiseMessage($"打开环境失败:{ex.Message}");
                    return false;
                }
                IsOpen = true;
                RaiseMessage($"打开环境成功!");
                return true;
            }

            /// <summary>
            /// Starts connecting. The result is reported asynchronously through
            /// <see cref="ConnectedStatusHandler"/>.
            /// </summary>
            public void StartConnect()
            {
                Connected();
            }

            /// <summary>
            /// Opens the environment if needed, then starts the timed transport creation.
            /// </summary>
            /// <returns>
            /// <c>false</c> if the environment could not be opened. Otherwise it returns the current
            /// <see cref="IsConnected"/> value, which does not yet reflect the connect that was just started.
            /// </returns>
            private bool Connected()
            {
                // RV requires an open environment before a transport can be created.
                if (!IsOpen && !Open())
                {
                    SetConnected(false);
                    return false;
                }
                _ = TryCreateConnectAsync();
                return IsConnected;
            }

            /// <summary>
            /// Creates the transport on a worker thread and reports a timeout after
            /// <see cref="ConnectTimeoutMilliseconds"/>. Raises exactly one final status for the
            /// attempt, plus a late success if creation finishes after the timeout.
            /// </summary>
            private async Task TryCreateConnectAsync()
            {
                try
                {
                    Task<bool> createTask = Task.Run(() => CreateConnect());
                    Task timeoutTask = Task.Delay(ConnectTimeoutMilliseconds);
                    if (await Task.WhenAny(createTask, timeoutTask) == timeoutTask)
                    {
                        RaiseMessage($"\r\n连接超时...\r\nDaemon = {Daemon},Network =  {Network} ,Service = {Service}");
                        SetConnected(false);
                        // Transport creation keeps running after the timeout; honour a late success.
                        if (await createTask)
                        {
                            SetConnected(true);
                        }
                        return;
                    }
                    // Report what CreateConnect actually achieved, not unconditional success.
                    SetConnected(await createTask);
                }
                catch (Exception)
                {
                    RaiseMessage($"\r\n连接异常...\r\nDaemon = {Daemon},Network =  {Network} ,Service =  {Service}");
                    SetConnected(false);
                }
            }

            /// <summary>
            /// Creates the <see cref="NetTransport"/> from Service/Network/Daemon.
            /// </summary>
            /// <returns><c>true</c> if the transport was created; failures are reported, not thrown.</returns>
            private bool CreateConnect()
            {
                RaiseMessage($"正在连接...: Daemon = {Daemon},Network = {Network},Service = {Service}");
                try
                {
                    Transport = new NetTransport(Service, Network, Daemon);
                }
                catch (Exception ex)
                {
                    RaiseMessage($"连接失败...:{ex.Message}");
                    return false;
                }
                RaiseMessage($"连接成功...");
                return true;
            }

            /// <summary>
            /// Tears down the session. It stops the dispatch loop, destroys the listener, queue and
            /// transport, closes the RV environment if this helper opened it, and then reports the
            /// disconnected state.
            /// </summary>
            /// <remarks>
            /// Objects that were never created are skipped, so this is safe after a failed or timed-out
            /// connect. The state flags and notifications are always reset/raised, even if a teardown
            /// step throws; the error itself is shown in a message box.
            /// </remarks>
            public void DisConnected()
            {
                string msg = $"断开连接... : Daemon = {Daemon} ,Network = {Network},Service = {Service}";
                try
                {
                    // Clear the flag first so the dispatch loop exits instead of dispatching again.
                    IsListened = false;
                    task = null;
                    if (Listener != null)
                    {
                        Listener.MessageReceived -= OnMessageReceivedCallBack;
                        Listener.Destroy();
                        Listener = null;
                    }
                    // NOTE(review): destroying the queue is expected to release a thread blocked in
                    // Queue.Dispatch — confirm against the TIBCO .NET reference.
                    if (Queue != null)
                    {
                        Queue.Destroy();
                        Queue = null;
                    }
                    if (Transport != null)
                    {
                        Transport.Destroy();
                        Transport = null;
                    }
                    // Match the Environment.Open done by Open() with exactly one Close.
                    if (IsOpen)
                    {
                        TIBCO.Rendezvous.Environment.Close();
                    }
                }
                catch (Exception ex)
                {
                    MessageBox.Show(ex.Message);
                }
                finally
                {
                    IsListened = false;
                    IsConnected = false;
                    IsOpen = false;
                }
                RaiseMessage(msg);
                RaiseConnected(false);
                RaiseListened(false);
            }
            #endregion

            #region Listen
            /// <summary>
            /// Creates the queue and listener on <see cref="ListenSubject"/> when connected and not
            /// already listening, then reports the result through <see cref="ListenedStatusHandler"/>.
            /// </summary>
            private void TryCreateListen()
            {
                try
                {
                    if (!IsConnected || IsListened)
                    {
                        return;
                    }
                    if (Transport == null)
                    {
                        RaiseMessage($"transport 为空,未连接,请先连接!!!");
                        RaiseListened(false);
                        IsListened = false;
                        return;
                    }
                    Queue = new TIBCO.Rendezvous.Queue();
                    Listener = new Listener(Queue, Transport, ListenSubject, null);
                    this.Listener.MessageReceived += OnMessageReceivedCallBack;
                    IsListened = true;
                    RaiseListened(true);
                }
                catch (Exception ex)
                {
                    IsListened = false;
                    RaiseMessage($"侦听异常:{ex.Message}");
                    RaiseListened(false);
                }
            }

            /// <summary>
            /// Starts listening (if not already) and runs the dispatch loop on its own thread.
            /// </summary>
            private void Listen()
            {
                try
                {
                    RaiseMessage($"开始侦听...");
                    if (this.IsListened)
                    {
                        return;
                    }
                    TryCreateListen();
                    if (this.IsListened)
                    {
                        // Queue.Dispatch blocks until an event arrives, so give the loop a dedicated
                        // thread instead of holding a thread-pool worker for the whole session.
                        task = Task.Factory.StartNew(() => DispatchLoop(), TaskCreationOptions.LongRunning);
                        RaiseMessage($"侦听成功!!!");
                    }
                }
                catch (Exception ex)
                {
                    this.IsListened = false;
                    RaiseMessage($"侦听异常:{ex.Message}");
                }
            }

            /// <summary>
            /// Dispatches queued RV events while <see cref="IsListened"/> is true.
            /// </summary>
            private void DispatchLoop()
            {
                try
                {
                    while (this.IsListened)
                    {
                        TIBCO.Rendezvous.Queue queue = Queue;
                        if (queue == null)
                        {
                            break;
                        }
                        queue.Dispatch();
                    }
                }
                catch (Exception ex)
                {
                    // NOTE(review): Dispatch presumably throws once DisConnected destroys the queue or
                    // environment. That case has already cleared IsListened and is not reported; any
                    // other failure ends listening and is reported.
                    if (this.IsListened)
                    {
                        this.IsListened = false;
                        RaiseMessage($"侦听异常:{ex.Message}");
                        RaiseListened(false);
                    }
                }
            }
            #endregion

            #region Send
            /// <summary>
            /// Sends <paramref name="data"/> to <see cref="TargetSubject"/> in a single field named
            /// <see cref="MessageField"/>.
            /// </summary>
            public void Send(string data)
            {
                Send(MessageField, data);
            }

            /// <summary>
            /// Sends <paramref name="data"/> to <see cref="TargetSubject"/> in a single field named
            /// <paramref name="field"/>. Reports an error instead of throwing when there is no transport.
            /// </summary>
            public void Send(string field, string data)
            {
                if (Transport == null)
                {
                    RaiseMessage($"transport 为空,未连接,请先连接!!!");
                    return;
                }
                TIBCO.Rendezvous.Message message = new TIBCO.Rendezvous.Message();
                message.SendSubject = TargetSubject;
                message.AddField(field, data);
                Transport.Send(message);
            }
            #endregion

            /// <summary>
            /// Stores a listen status. Public callback; not subscribed inside this class.
            /// </summary>
            public void OnListenCallBack(object sender, bool listenStatu)
            {
                IsListened = listenStatu;
            }

            /// <summary>
            /// Self-subscribed connect callback: records the status and starts listening on success.
            /// </summary>
            public void OnConnectCallBack(object sender, bool connectStatu)
            {
                IsConnected = connectStatu;
                if (IsConnected)
                {
                    Listen();
                }
            }

            /// <summary>
            /// Forwards a listener message to <see cref="messageReceivedHandler"/>, if anyone is subscribed.
            /// </summary>
            public void OnMessageReceivedCallBack(object sender, MessageReceivedEventArgs messageReceivedEventArgs)
            {
                messageReceivedHandler?.Invoke(sender, messageReceivedEventArgs);
            }

            #region Notification helpers
            // Events may have no subscribers, so every raise is null-conditional.
            private void RaiseMessage(string msg) => ErrorMessageHandler?.Invoke(this, msg);
            private void RaiseConnected(bool connected) => ConnectedStatusHandler?.Invoke(this, connected);
            private void RaiseListened(bool listened) => ListenedStatusHandler?.Invoke(this, listened);

            /// <summary>Records the connection state, then notifies subscribers.</summary>
            private void SetConnected(bool connected)
            {
                IsConnected = connected;
                RaiseConnected(connected);
            }
            #endregion

            /// <summary>
            /// Signature of the callback that receives forwarded RV messages.
            /// </summary>
            public delegate void MessageReceivedHandler(object sender, MessageReceivedEventArgs messageReceivedEventArgs);
            public MessageReceivedHandler messageReceivedHandler;
            public event EventHandler<string> ErrorMessageHandler;
            public event EventHandler<bool> ConnectedStatusHandler;
            public event EventHandler<bool> ListenedStatusHandler;
        }
    }
    


  • 📝 细节步骤补充说明

  • TibcoRV 安装完成测试

    • 1、创建数据存储(解压安装包后,在bin目录下运行命令启动服务:)

      	rvrd -store datastore
      
    • 如下图所示,在安装目录的 bin 下执行该命令,会创建指定的存储文件(例如安装在 D:\tibco\tibrv\8.4 时,执行:D:\tibco\tibrv\8.4\bin> rvrd -store datastore)
      在这里插入图片描述

  • 2、端口冲突问题:7500

    • 报错内容:

      binding connection listen socket to TCP port 7500 failed: 10048 (Specified address is in use).
      2025-01-08 17:10:06 D:\tibco\tibrv\8.4\bin\rvrd.exe: OpenSSL 0.9.8o-fips 01 Jun 2010
      2025-01-08 17:10:06 D:\tibco\tibrv\8.4\bin\rvrd.exe: startup aborted: Initialization failed.
      
    • 报错图片
      在这里插入图片描述

  • 3、终止进程占用

    • 使用以下命令终止占用进程:
      	taskkill /PID <进程ID> /F
      
    • (1)按 Win + R 打开 cmd 窗口,在 cmd 窗口中输入以下命令,查找占用 7500 端口的进程:

      netstat -ano |findstr 7500
      
    • (2)查看提示

      TCP    0.0.0.0:7500           0.0.0.0:0              LISTENING       20048
      
    • (3) 输入命令终止进程

      taskkill /PID 20048 /F
      
    • (4)重新输入 rvrd -store datastore 启用 rvrd 服务

      rvrd -store datastore 
      

  • 4、测试侦听与发送(CMD实现)

    在这里插入图片描述


  • 5、开始本地测试

    • 1)创建侦听(打开CMD窗口Win+R 输入如下命令)

      tibrvlisten -service 7500 -network 192.168.1.100 TEST.SUBJECT
      
      • tibrvlisten: 为Tibco自带的exe程序 。
      • service: 服务(本地端口): 7500
      • network: 网络(本地IP) :192.168.1.100
      • 主题名称(自定义): TEST.SUBJECT

      在这里插入图片描述

    • 2)创建发送|发布

      • Win+R 再打开一个 CMD窗口
        tibrvsend -service 7500 -network 192.168.1.100 TEST.SUBJECT  "This is a test message."
        
      • 命令基本一致:将 tibrvlisten 替换为 tibrvsend,并在命令末尾加一个空格,再写上要发送的消息 "This is a test message."。 在这里插入图片描述
    • 3)创建其他侦听格式

      • 本地IP:192.168.1.5
        tibrvlisten -service 9800 -network ";225.9.9.8" -daemon "tcp:192.168.1.5:7500" "Target.Subject"
        
      • 创建本地发送

        • tibrvsend -service 9800 -network ";225.9.9.8" -daemon "tcp:192.168.1.5:7500" "Target.Subject" "Hello"(注意:参数须使用英文直引号,中文弯引号会导致命令行解析失败) 在这里插入图片描述

  • 结语

    • 项目中用到了TibcoRv中间件作通信。在此过程中,碰到了很多坑,也是自己本身水平一般。
    • 在此分享一下TibcoRV通信的Demo。算是作为备份吧。万一哪天忘记了呢。

  • 最后

    • 如果你觉得这篇文章对你有帮助,不妨点个赞支持一下!
    • 如果有疑问,欢迎评论区留言。
    • 也可以关注微信公众号 [编程笔记in] ,一起交流学习!
    • 项目地址: gitee.com/incodenotes/cshape-demos

# Note that accompanied this script (translated from Chinese): "OK, I will modify your `UNet` model so
# that it has only one encoder layer and one decoder layer, keeping the other parts unchanged. Here is
# the complete modified code:"
import argparse
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm


class UNet(nn.Module):
    """Single-stage encoder/decoder for 3-channel images.

    Encoder: 3->64 conv (3x3, padding 1), ReLU, 2x2 max-pool (halves H and W).
    Decoder: 64->3 transposed conv (kernel 2, stride 2, doubles H and W), sigmoid.
    """

    def __init__(self):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


class AttentionBlock(nn.Module):
    """Additive attention gate: scales skip features ``x`` by a [0, 1] map computed from ``g`` and ``x``."""

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi


class AttentionUNet(nn.Module):
    """One-level U-Net with an attention-gated skip connection (not used by ``main``)."""

    def __init__(self):
        super(AttentionUNet, self).__init__()
        self.encoder1 = self.conv_block(3, 64)
        self.bottleneck = self.conv_block(64, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.decoder1 = self.conv_block(128, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv, BatchNorm, ReLU) layers."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Encoding
        e1 = self.encoder1(x)
        b = self.bottleneck(F.max_pool2d(e1, 2))
        # Decoding + attention gate
        d1 = self.upconv1(b)
        e1 = self.att1(g=d1, x=e1)
        d1 = torch.cat((e1, d1), dim=1)
        d1 = self.decoder1(d1)
        out = self.final_conv(d1)
        out = self.sigmoid(out)
        return out


class ColorblindDataset(Dataset):
    """Image pairs read from ``<image_dir>/<mode>/{origin_image,recolor_image,correct_image}``.

    For each position i in the three sorted file lists, two pairs are produced:
    (origin[i], recolor[i]) and (correct[i], origin[i]). Pairing relies on the sorted lists
    lining up by position.
    """

    def __init__(self, image_dir, mode='train', transform=None):
        self.image_dir = image_dir
        self.mode = mode
        self.transform = transform
        self.normal_images = sorted(glob.glob(image_dir + '/' + mode + '/' + 'origin_image' + '/*'))
        self.recolor_images = sorted(glob.glob(image_dir + '/' + mode + '/' + 'recolor_image' + '/' + '*Protanopia*'))
        self.correct_images = sorted(glob.glob(image_dir + '/' + mode + '/' + 'correct_image' + '/*'))
        self.image_pair = []
        # zip stops at the shortest folder instead of raising IndexError when the counts differ.
        for normal, recolor, correct in zip(self.normal_images, self.recolor_images, self.correct_images):
            self.image_pair.append([normal, recolor])
            self.image_pair.append([correct, normal])

    def __len__(self):
        # Two pairs per image; counting only normal_images made half of the pairs unreachable.
        return len(self.image_pair)

    def __getitem__(self, idx):
        normal_path, recolor_path = self.image_pair[idx]
        normal_image = Image.open(normal_path).convert('RGB')
        recolor_image = Image.open(recolor_path).convert('RGB')
        if self.transform:
            normal_image = self.transform(normal_image)
            recolor_image = self.transform(recolor_image)
        return normal_image, recolor_image


def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(dataloader)
    return epoch_loss


def validate(model, dataloader, criterion, device):
    """Return the mean batch loss over ``dataloader`` without updating the model."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
    val_loss /= len(dataloader)
    return val_loss


def visualize_results(model, dataloader, device, num_images=10):
    """Plot inputs, targets and outputs for the first batch (at most ``num_images`` columns)."""
    model.eval()
    inputs, targets = next(iter(dataloader))
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    outputs = outputs.cpu().numpy()
    inputs = inputs.cpu().numpy()
    targets = targets.cpu().numpy()
    # A batch can hold fewer than num_images samples.
    num_images = min(num_images, inputs.shape[0])
    plt.figure(figsize=(15, 10))
    for i in range(num_images):
        plt.subplot(3, num_images, i + 1)
        plt.imshow(inputs[i].transpose(1, 2, 0))
        plt.title("Original")
        plt.axis('off')
        plt.subplot(3, num_images, i + 1 + num_images)
        plt.imshow(targets[i].transpose(1, 2, 0))
        plt.title("Colorblind")
        plt.axis('off')
        plt.subplot(3, num_images, i + 1 + 2 * num_images)
        plt.imshow(outputs[i].transpose(1, 2, 0))
        plt.title("Reconstructed")
        plt.axis('off')
    plt.show()


def plot_and_save_losses(train_losses, val_losses, epoch, path='./loss_plots'):
    """Save a train/val loss curve for epochs 1..epoch+1 to ``<path>/loss_epoch_<epoch+1>.png``."""
    os.makedirs(path, exist_ok=True)
    epochs = np.arange(1, epoch + 2)
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()
    plt.savefig(f'{path}/loss_epoch_{epoch+1}.png')
    plt.close()


def main(args):
    # Data transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
    ])

    # Datasets and dataloaders
    train_dataset = ColorblindDataset(args.dataset_dir, mode='train', transform=transform)
    val_dataset = ColorblindDataset(args.dataset_dir, mode='val', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False)

    # Model, loss, optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)  # simplified UNet
    # The default pretrained path is the same file as the save path, so it is missing on a first run.
    if args.model_pretrained_path and os.path.exists(args.model_pretrained_path):
        # map_location lets weights saved on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(args.model_pretrained_path, map_location=device))
        print("Successfully load past pretrained weights!!")
    elif args.model_pretrained_path:
        print(f"No pretrained weights at {args.model_pretrained_path}, training from scratch.")

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # Create the output folder up front so a missing directory fails before training, not after it.
    save_dir = os.path.dirname(args.model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    # Training and validation loop
    for epoch in range(args.num_epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss = validate(model, val_loader, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f'Epoch {epoch + 1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
        plot_and_save_losses(train_losses, val_losses, epoch)

    visualize_results(model, val_loader, device)

    # Save the model
    torch.save(model.state_dict(), args.model_save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="UNet Colorblind Image Reconstruction")
    parser.add_argument('--dataset_dir', type=str, default='./dataset', help='Path to the dataset directory')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training and validation')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate for the optimizer')
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--model_save_path', type=str, default='./model_weights/color_blind_model.pth', help='Path to save the trained model')
    parser.add_argument('--model_pretrained_path', type=str, default='./model_weights/color_blind_model.pth', help='训练好的色盲模拟器模型路径')
    args = parser.parse_args()
    main(args)

# Note that followed this script (translated from Chinese): "In this version the `UNet` class is
# simplified to only one encoder and decoder layer. Everything else - dataset loading, the training
# loop, the validation function and so on - stays unchanged. I hope this meets your needs; let me know
# if you have any further questions!"
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

编程笔记in

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值