LDPC码(PEG构造校验矩阵)生成器 - 用于神经网络译码研究(交互式版本)

        这是一款以 PEG (Progressive Edge Growth) 算法 为核心的 LDPC 校验矩阵生成器。PEG 通过“渐进式边增长”策略,在构造稀疏矩阵的同时有效抑制 4 环等短环的产生,提升码图性能,因而被广泛用于现代通信与机器学习译码研究。本工具将 PEG 算法封装为易用的交互界面,帮助你在几分钟内生成符合实验需求的 规则或非规则 LDPC 码,并自动导出用于神经网络 / GNN 训练的完整数据集。

        除了矩阵生成,本项目还自带 矩阵特性分析、可视化、稀疏格式保存、Tanner 图导出 等一站式功能,开箱即用,适合科研、课程实验及工程验证场景。

PEG 算法(Progressive Edge Growth)是什么?

PEG 是 Hu-Eleftheriou-Arnold 在 2005 年提出的 逐边生长 构造法,用来生成 LDPC 校验矩阵对应的 Tanner 图。它一边添加边,一边最大化当前变量节点周围的最小环长(girth),因此能在 有限码长 下显著减少 4 环 / 6 环的数量,迭代译码时更稳定、更易收敛。

核心思路——“先均衡度数,再拉大环长”

  1. 输入

    • 变量节点数 n、校验节点数 m

    • 变量节点度序列(规则码即常数 dᵥ;非规则码可自定义)

  2. 外层循环(按变量节点)
    变量节点按度数升序遍历;低度节点先放边,可把宝贵的“长环位置”优先分给它们。

  3. 内层循环(按该节点第 k 条边)

    • 第 1 条边:直接连到 当前校验度最小 的校验节点,保证整体行权重均衡。

    • 后续边

      1. 从该变量节点出发做一轮 BFS 扩展,逐层向外“长大”树;

      2. 找到 尚未出现在搜索树中的校验节点集合(即与此节点“最远”的校验节点);

      3. 在这个集合里再选 “度数最小” 的校验节点接边。
        这样就把新边放在“离自己最远且行度最低”的位置,局部 girth ≥ 2·depth+2。

  4. 更新度数:edge 接完即刻更新行、列度,为下一条边提供最新信息。

为什么 PEG 效果好?

特性原因
大 girth“远节点”避免 4/6 环
均衡行权重第 1 条边总连度最小校验节点
规则 / 非规则都支持度序列可任意给定
可扩展到 QC、ACE、LPEG 等只需在选边规则上加约束或打表
一句话总结

PEG 算法= “逐边生长 + BFS 选远邻 + 校验度最小”。它让你在纯软件脚本里就能快速得到大 girth、迭代译码友好的 LDPC 码图,是现代 LDPC 构造的黄金起点。

代码部分

运行示例:

python ldpc_generator.py

运行后选项:

================================================
         LDPC码生成器 - 交互式版本
================================================
请选择要生成的LDPC码类型:
1. 小规模规则LDPC码 (64, 32, dv=3)
2. 中等规模规则LDPC码 (1000, 500, dv=3)
3. 5G-like LDPC码 (1944, 972, dv=3)
4. 非规则LDPC码 (1000, 500)
5. 自定义规则LDPC码
6. 自定义非规则LDPC码
7. 批量生成预设码
0. 退出
------------------------------------------------
请选择操作 (0-7): 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LDPC码生成器 - 用于神经网络译码研究 (交互式版本, 修正版)
生成LDPC校验矩阵和相关数据集
"""

import numpy as np
import json
import os
from datetime import datetime
import matplotlib.pyplot as plt
import copy
import matplotlib

matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']  # 使用黑体
matplotlib.rcParams['axes.unicode_minus'] = False  # 正常显示负号


def find_smallest(array):
    """返回 array 最小值的索引(稳定实现)"""
    if len(array) == 0:
        return 0
    if len(array) == 1:
        return 0
    elif len(array) == 2:
        return 0 if array[0] <= array[1] else 1
    else:
        mid = len(array) // 2
        arrayA = array[:mid]
        arrayB = array[mid:]
        smallA = find_smallest(arrayA)
        smallB = find_smallest(arrayB)
        if arrayA[smallA] <= arrayB[smallB]:
            return smallA
        else:
            return mid + smallB


class PEG:
    """Progressive Edge Growth算法的Python 3实现(左规则,右侧不强制)"""

    def __init__(self, nvar, nchk, degree_sequence):
        self.degree_sequence = degree_sequence
        self.nvar = nvar
        self.nchk = nchk
        self.H = np.zeros((nchk, nvar), dtype=np.uint8)
        self.sym_degrees = np.zeros(nvar, dtype=np.int32)
        self.chk_degrees = np.zeros(nchk, dtype=np.int32)

    def grow_edge(self, var, chk):
        """在 var 与 chk 之间加一条边"""
        if self.H[chk, var] == 1:
            return  # 已有边则不重复
        self.H[chk, var] = 1
        self.sym_degrees[var] += 1
        self.chk_degrees[chk] += 1

    def bfs(self, var):
        """
        以 var 为根进行层次遍历,返回一个"最小度且不在当前覆盖层"的校验节点。
        退化时返回全局最小度校验节点(避免固定偏置到0)。
        """
        var_list = np.zeros(self.nvar, dtype=np.uint8)
        var_list[var] = 1
        cur_chk_list = [0] * self.nchk
        new_chk_list = [0] * self.nchk
        chk_Q = []
        var_Q = [var]

        while True:
            # 从已知变量节点扩展到校验节点
            for _vars in var_Q:
                # 遍历所有校验节点(可优化:用非零索引)
                for i in range(self.nchk):
                    if self.H[i, _vars] == 1 and cur_chk_list[i] == 0:
                        new_chk_list[i] = 1
                        chk_Q.append(i)
            var_Q = []

            # 从新发现的校验节点扩展到变量节点
            for _chks in chk_Q:
                for j in range(self.nvar):
                    if self.H[_chks, j] == 1 and var_list[j] == 0:
                        var_list[j] = 1
                        var_Q.append(j)
            chk_Q = []

            # 若所有校验节点都已在覆盖层中,或层数不再扩展,则选择一个最小度的"未覆盖校验";
            # 若不存在未覆盖校验,则退化为全局最小度校验节点。
            if new_chk_list.count(1) == self.nchk or new_chk_list == cur_chk_list:
                chosen = self.find_smallest_chk(cur_chk_list)
                return chosen
            else:
                cur_chk_list = copy.copy(new_chk_list)

    def find_smallest_chk(self, cur_chk_list):
        """
        在"未覆盖校验节点"中选最小度的一个;若没有未覆盖者,则退化为全局最小度,避免系统性偏置到0号。
        """
        candidates = [i for i, v in enumerate(cur_chk_list) if v == 0]
        if candidates:
            degs = self.chk_degrees[candidates]
            return int(candidates[int(np.argmin(degs))])
        # 退化选择:全局最小度校验节点
        return int(np.argmin(self.chk_degrees))

    def progressive_edge_growth(self, verbose=True):
        for var in range(self.nvar):
            if verbose and var % 50 == 0:
                print(f"Edge growth at var {var}/{self.nvar}")
            for k in range(self.degree_sequence[var]):
                if k == 0:
                    # 第1条边连到"当前最小度的校验节点"
                    smallest_degree_chk = int(np.argmin(self.chk_degrees))
                    self.grow_edge(var, smallest_degree_chk)
                else:
                    chk = self.bfs(var)
                    self.grow_edge(var, chk)


class LDPCMatrixGenerator:
    """LDPC校验矩阵生成器 - 生成LDPC校验矩阵、系统化与生成矩阵"""

    def __init__(self):
        self.H = None
        self.G = None  # 生成矩阵(k×n)
        self.n = None  # 码长
        self.k = None  # 信息位长度
        self.m = None  # 校验位行数(行数)
        self.P = None  # 系统形式 HΠ = [P | I_r] 中的 P(r×k),r为rank(H)
        self.rank = None  # rank(H)

    def generate_regular_ldpc(self, n, m, dv, dc=None):
        """
        生成左规则LDPC码(变量节点度为 dv)。不强制右规则。
        n: 码长(变量节点数)
        m: 校验节点数(H 的行数)
        dv: 变量节点度数
        dc: 目标校验节点平均度(仅显示),实际不强制
        """
        if dc is None:
            # 平均值(不强制)
            dc = (n * dv) // m

        if (n * dv) % m != 0:
            print(f"提示:n*dv={n*dv} 不能被 m={m} 整除,右侧不会严格(平均度={dc})。")

        print(f"生成规则({dv},{dc}) LDPC码: n={n}, m={m}")

        degree_sequence = [dv] * n

        peg = PEG(n, m, degree_sequence)
        peg.progressive_edge_growth()

        self.H = peg.H.astype(np.uint8)
        self.n = n
        self.m = m

        # 验证度
        actual_dv = float(np.mean(np.sum(self.H, axis=0)))
        actual_dc = float(np.mean(np.sum(self.H, axis=1)))
        print(f"实际平均度数: dv={actual_dv:.2f}, dc={actual_dc:.2f}")

        # 计算生成矩阵
        self._compute_generator_matrix()

        return self.H

    # ---------- GF(2) 工具函数 ----------
    @staticmethod
    def _gf2_rref(H):
        """
        计算GF(2)行最简阶梯形(RREF),返回:
          H_rref:行变换后的矩阵
          pivots:主元列索引列表
        """
        H = (H % 2).astype(np.uint8)
        m, n = H.shape
        row = 0
        pivots = []
        for col in range(n):
            # 找主元行
            pivot = None
            for r in range(row, m):
                if H[r, col]:
                    pivot = r
                    break
            if pivot is None:
                continue
            # 交换到当前行
            if pivot != row:
                H[[row, pivot]] = H[[pivot, row]]
            # 将该列除当前行外清零(GF(2))
            for r in range(m):
                if r != row and H[r, col]:
                    H[r, :] ^= H[row, :]
            pivots.append(col)
            row += 1
            if row == m:
                break
        return H, pivots

    def _convert_to_systematic_form(self, H):
        """
        将H矩阵转换为系统形式(存在列置换 Π 使得 HΠ = [P | I_r],r=rank(H))
        返回: (H_sys, P, perm, rank)
          - H_sys: 行变换并按列重排后的矩阵(使主元列在最右边)
          - P: 形状 (r, k) 的P子块
          - perm: 长度为 n 的列置换(orig->sys 列顺序)
          - rank: r = rank(H)
        """
        H = (H % 2).astype(np.uint8)
        m, n = H.shape

        # 先做 RREF,得到主元列
        H_rref, pivots = self._gf2_rref(H.copy())
        r = len(pivots)  # rank(H)

        # 将主元列移到最右侧,非主元列在左侧
        non_piv = [c for c in range(n) if c not in pivots]
        new_order = non_piv + pivots  # 前k列=非主元列,后r列=主元列
        perm = np.array(new_order, dtype=int)

        H_sys = H_rref[:, perm]
        # 取出 P 块(r×k),注意:RREF 的后 r 列是单位阵(在前 r 行),其余行全0
        k = n - r
        if r == 0:
            P = np.zeros((0, k), dtype=np.uint8)
        else:
            P = H_sys[:r, :k]  # 只取前 r 行(有主元的行)

        return H_sys, P, perm, r

    def _compute_generator_matrix(self):
        """
        计算系统形式的生成矩阵G。
        通用做法:对任意满列秩/欠秩 H,求 rank r,令 k = n - r,
        存在列置换 Π 使 HΠ = [P | I_r](r×r单位阵位于右侧),
        则系统形式生成矩阵 G_sys = [I_k | P^T](k×n,处于"系统列顺序"),
        最后将列顺序还原回原顺序:G[:, perm] = G_sys。
        """
        try:
            print("正在计算生成矩阵G...")

            H_sys, P, perm, r = self._convert_to_systematic_form(self.H.copy())
            self.rank = r
            self.k = int(self.n - r)
            self.P = P  # (r, k)

            if self.k < 0:
                raise ValueError(f"计算到非法维度:k={self.k}")

            # 处理极端情况:r=0 => H全零矩阵;此时任意G都满足 H G^T = 0
            if r == 0:
                self.G = np.eye(self.n, dtype=np.uint8)[:self.k, :]  # 退化:取前k行为单位阵
                ok = np.all(((self.H @ self.G.T) % 2) == 0)
                print("✓ 生成矩阵验证通过: H * G^T = 0" if ok else "⚠ 警告: 生成矩阵验证失败")
                return

            # 在系统列顺序下构造 G_sys = [I_k | P^T]
            I_k = np.eye(self.k, dtype=np.uint8)
            G_sys = np.hstack([I_k, (P.T % 2)]).astype(np.uint8)  # k × (k+r) = k × n

            # 将列顺序还原到原始:perm 是 orig->sys 的新序,因此需要逆置换
            G = np.zeros((self.k, self.n), dtype=np.uint8)
            G[:, perm] = G_sys
            self.G = G

            # 验证 H * G^T = 0 (GF(2))
            ok = np.all(((self.H @ self.G.T) % 2) == 0)
            print("✓ 生成矩阵验证通过: H * G^T = 0" if ok else "⚠ 警告: 生成矩阵验证失败")

        except Exception as e:
            print(f"⚠ 生成矩阵计算失败: {e}")
            print("将使用全零码字编码(仅用于占位,建议修复问题后再使用)")
            self.G = None
            self.P = None

    def _verify_generator_matrix(self):
        """验证生成矩阵的正确性(保留接口)"""
        if self.G is None or self.H is None:
            return False
        result = (self.H @ self.G.T) % 2
        return np.all(result == 0)

    def systematic_encode(self, messages):
        """
        使用生成矩阵进行系统编码
        messages: (batch_size, k) 信息位矩阵
        返回: (batch_size, n) 码字矩阵
        """
        if self.G is None:
            print("警告: 生成矩阵不可用,返回全零码字(占位)。")
            return np.zeros((messages.shape[0], self.n), dtype=np.uint8)

        messages = (messages % 2).astype(np.uint8)
        codewords = (messages @ self.G) % 2
        return codewords.astype(np.uint8)

    def generate_irregular_ldpc(self, n, m, dv_distribution):
        """
        生成非规则LDPC码
        dv_distribution: 变量节点度分布字典 {度数: 节点数}
        """
        degree_sequence = []
        for degree, count in dv_distribution.items():
            if degree <= 0 or count < 0:
                raise ValueError("度数必须>0且节点数>=0")
            degree_sequence.extend([degree] * count)

        if len(degree_sequence) != n:
            raise ValueError(f"度分布总和({len(degree_sequence)})不等于n({n})")

        print(f"生成非规则LDPC码: n={n}, m={m}")
        print(f"度分布: {dv_distribution}")

        peg = PEG(n, m, degree_sequence)
        peg.progressive_edge_growth()

        self.H = peg.H.astype(np.uint8)
        self.n = n
        self.m = m

        # 计算生成矩阵
        self._compute_generator_matrix()

        return self.H

    def analyze_matrix(self):
        """分析生成的LDPC矩阵特性"""
        if self.H is None:
            raise ValueError("请先生成LDPC矩阵")

        var_degrees = np.sum(self.H, axis=0)
        chk_degrees = np.sum(self.H, axis=1)
        sH = int(np.sum(self.H))
        m, n = self.H.shape
        r = self.rank if self.rank is not None else int(min(m, n))  # 兜底
        k = self.k if self.k is not None else int(n - r)

        analysis = {
            "dimensions": f"{m}x{n}",
            "code_rate": float(k / n),
            "density": float(sH / (m * n)),
            "avg_variable_degree": float(np.mean(var_degrees)),
            "avg_check_degree": float(np.mean(chk_degrees)),
            "min_variable_degree": int(np.min(var_degrees)),
            "max_variable_degree": int(np.max(var_degrees)),
            "min_check_degree": int(np.min(chk_degrees)),
            "max_check_degree": int(np.max(chk_degrees)),
            "rank_H": int(r),
            "k": int(k)
        }

        # 估计girth(最小环长)- 简化版本
        analysis["estimated_girth"] = self._estimate_girth()

        return analysis

    def _estimate_girth(self, max_check=100):
        """估计矩阵的girth(最小环长,简化:仅检测4环)"""
        m = self.m
        upper = min(max_check, m)
        for i in range(upper):
            for j in range(i + 1, upper):
                common = int(np.sum(self.H[i] & self.H[j]))
                if common >= 2:
                    return 4
        return ">4"

    def visualize_matrix(self, save_path=None):
        """可视化LDPC矩阵"""
        if self.H is None:
            raise ValueError("请先生成LDPC矩阵")

        plt.figure(figsize=(12, 8))

        # 子图1:矩阵稀疏模式
        plt.subplot(2, 2, 1)
        plt.spy(self.H, markersize=1)
        plt.title(f'LDPC矩阵稀疏模式 ({self.m}x{self.n})')
        plt.xlabel('变量节点')
        plt.ylabel('校验节点')

        # 子图2:变量节点度分布
        plt.subplot(2, 2, 2)
        var_degrees = np.sum(self.H, axis=0)
        plt.hist(var_degrees, bins=np.arange(var_degrees.min(), var_degrees.max() + 2) - 0.5)
        plt.title('变量节点度分布')
        plt.xlabel('度数')
        plt.ylabel('节点数')

        # 子图3:校验节点度分布
        plt.subplot(2, 2, 3)
        chk_degrees = np.sum(self.H, axis=1)
        plt.hist(chk_degrees, bins=np.arange(chk_degrees.min(), chk_degrees.max() + 2) - 0.5)
        plt.title('校验节点度分布')
        plt.xlabel('度数')
        plt.ylabel('节点数')

        # 子图4:统计信息
        plt.subplot(2, 2, 4)
        r = self.rank if self.rank is not None else 0
        k = self.k if self.k is not None else self.n - r
        plt.text(0.1, 0.9, f'码率: {k / self.n:.3f}', transform=plt.gca().transAxes)
        plt.text(0.1, 0.8, f'密度: {np.sum(self.H) / (self.m * self.n):.4f}', transform=plt.gca().transAxes)
        plt.text(0.1, 0.7, f'平均变量度: {np.mean(var_degrees):.2f}', transform=plt.gca().transAxes)
        plt.text(0.1, 0.6, f'平均校验度: {np.mean(chk_degrees):.2f}', transform=plt.gca().transAxes)
        plt.axis('off')
        plt.title('LDPC码统计信息')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"图像已保存到: {save_path}")
        else:
            plt.show()

    def save_for_nn_training(self, save_dir, dataset_name):
        """保存数据用于神经网络训练"""
        if self.H is None:
            raise ValueError("请先生成LDPC矩阵")

        os.makedirs(save_dir, exist_ok=True)

        # 保存校验矩阵
        np.save(os.path.join(save_dir, f'{dataset_name}_H.npy'), self.H)

        # 保存生成矩阵(如果可用)
        if self.G is not None:
            np.save(os.path.join(save_dir, f'{dataset_name}_G.npy'), self.G)
            print(f"- {dataset_name}_G.npy: 生成矩阵")

        # 保存为稀疏格式(节省空间)
        saved_sparse = False
        try:
            from scipy.sparse import csr_matrix, save_npz
            H_sparse = csr_matrix(self.H)
            save_npz(os.path.join(save_dir, f'{dataset_name}_H_sparse.npz'), H_sparse)
            saved_sparse = True
        except ImportError:
            print("警告:未安装scipy,跳过稀疏矩阵保存")

        # 保存参数
        params = {
            'n': int(self.n),
            'k': int(self.k) if self.k is not None else None,
            'm': int(self.m),
            'rank_H': int(self.rank) if self.rank is not None else None,
            'code_rate': float((self.k / self.n) if self.k is not None else None),
            'has_generator_matrix': self.G is not None,
            'analysis': self.analyze_matrix(),
            'timestamp': datetime.now().isoformat()
        }

        with open(os.path.join(save_dir, f'{dataset_name}_params.json'), 'w', encoding='utf-8') as f:
            json.dump(params, f, indent=2, ensure_ascii=False)

        # 生成Tanner图的邻接表示(用于GNN)
        self._save_tanner_graph(save_dir, dataset_name)

        print(f"数据已保存到: {save_dir}")
        print(f"- {dataset_name}_H.npy: 校验矩阵")
        if saved_sparse:
            print(f"- {dataset_name}_H_sparse.npz: 稀疏格式校验矩阵")
        print(f"- {dataset_name}_params.json: 参数文件")
        print(f"- {dataset_name}_tanner_*.npy: Tanner图表示")

    def _save_tanner_graph(self, save_dir, dataset_name):
        """保存Tanner图表示(用于图神经网络)"""
        edges_v2c = []
        edges_c2v = []

        for i in range(self.m):
            row = self.H[i]
            nz = np.nonzero(row)[0]
            for j in nz:
                edges_v2c.append([j, i])  # 变量节点j到校验节点i
                edges_c2v.append([i, j])  # 校验节点i到变量节点j

        edges_v2c = np.array(edges_v2c, dtype=np.int32)
        edges_c2v = np.array(edges_c2v, dtype=np.int32)

        np.save(os.path.join(save_dir, f'{dataset_name}_tanner_v2c.npy'), edges_v2c)
        np.save(os.path.join(save_dir, f'{dataset_name}_tanner_c2v.npy'), edges_c2v)


# ------------------ 菜单与交互 ------------------

def display_menu():
    """显示主菜单"""
    print("\n" + "=" * 70)
    print("         LDPC码生成器 - 交互式版本(修正版)")
    print("=" * 70)
    print("请选择要生成的LDPC码类型:")
    print("1. 小规模规则LDPC码 (64, 32, dv=3) - 适合快速测试")
    print("2. 中等规模规则LDPC码 (1000, 500, dv=3) - 标准测试")
    print("3. 5G-like LDPC码 (1944, 972, dv=3) - 实际应用")
    print("4. (3,6)规模规则LDPC码 (600, 300, dv=3, dc≈6) - 经典规模(右侧不强制)")
    print("5. 码长128规则LDPC码 (128, 64, 码率≈0.5)")
    print("6. 码长32规则LDPC码 (32, 16, 码率≈0.5)")
    print("7. 码长8规则LDPC码 (8, 4, 码率≈0.5) - 微规模验证")
    print("8. 码长648规则LDPC码 (648, 324, 码率≈0.5) - WiFi标准")
    print("9. 非规则LDPC码 (1000, 500) - 高级研究")
    print("10. 自定义规则LDPC码 - 自定义参数")
    print("11. 自定义非规则LDPC码 - 自定义度分布")
    print("12. 批量生成预设码 - 一次生成多种")
    print("0. 退出")
    print("-" * 70)


def generate_small_regular():
    """生成小规模规则LDPC码"""
    print("\n生成小规模规则LDPC码...")
    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=64, m=32, dv=3)

    # 分析并显示结果
    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    # 询问是否保存
    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'small_regular_64_32_3')

        # 询问是否生成可视化
        viz = input("是否生成可视化图像? (y/n): ").lower().strip()
        if viz == 'y':
            generator.visualize_matrix('ldpc_small_visualization.png')

    return generator


def generate_medium_regular():
    """生成中等规模规则LDPC码"""
    print("\n生成中等规模规则LDPC码...")
    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=1000, m=500, dv=3)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'medium_regular_1000_500_3')

    return generator


def generate_5g_like():
    """生成5G-like LDPC码"""
    print("\n生成5G-like LDPC码...")
    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=1944, m=972, dv=3)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', '5g_like_1944_972')

    return generator


def generate_36_regular():
    """生成(3,6)规模规则LDPC码(右侧不强制,仅平均≈6)"""
    print("\n生成(3,6)规模规则LDPC码...")
    print("该码的特点:变量节点度数为3,校验节点平均度数≈6(未强制每个校验节点都等于6)")

    n = 600  # 码长
    m = 300  # 校验节点数:n*3/m=6,码率≈0.5
    dv = 3

    print(f"参数设置: n={n}, m={m}, dv={dv}, 期望dc≈{n*dv//m}, 码率≈{(n-m)/n:.1f}")

    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=n, m=m, dv=dv)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'regular_600_300_3_6')

        viz = input("是否生成可视化图像? (y/n): ").lower().strip()
        if viz == 'y':
            generator.visualize_matrix('ldpc_36_visualization.png')

    return generator


def generate_128_regular():
    """生成码长128规则LDPC码(码率≈0.5)"""
    print("\n生成码长128规则LDPC码...")
    print("参数:码长n=128, 校验节点m=64, 码率≈0.5")

    n = 128
    m = 64
    dv = 3

    print(f"参数设置: n={n}, m={m}, dv={dv}, 期望dc≈{n*dv//m}, 码率≈{(n-m)/n:.1f}")

    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=n, m=m, dv=dv)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'regular_128_64_3')

        viz = input("是否生成可视化图像? (y/n): ").lower().strip()
        if viz == 'y':
            generator.visualize_matrix('ldpc_128_visualization.png')

    return generator


def generate_32_regular():
    """生成码长32规则LDPC码(码率≈0.5)"""
    print("\n生成码长32规则LDPC码...")
    print("参数:码长n=32, 校验节点m=16, 码率≈0.5")

    n = 32
    m = 16
    dv = 3

    print(f"参数设置: n={n}, m={m}, dv={dv}, 期望dc≈{n*dv//m}, 码率≈{(n-m)/n:.1f}")

    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=n, m=m, dv=dv)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'regular_32_16_3')

        viz = input("是否生成可视化图像? (y/n): ").lower().strip()
        if viz == 'y':
            generator.visualize_matrix('ldpc_32_visualization.png')

    return generator


def generate_8_regular():
    """生成码长8规则LDPC码(码率≈0.5)"""
    print("\n生成码长8规则LDPC码...")
    print("参数:码长n=8, 校验节点m=4, 码率≈0.5")
    print("注意:这是一个极小规模的LDPC码,主要用于算法验证")

    n = 8
    m = 4
    dv = 3

    print(f"参数设置: n={n}, m={m}, dv={dv}, 期望dc≈{n*dv//m}, 码率≈{(n-m)/n:.1f}")

    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=n, m=m, dv=dv)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'regular_8_4_3')

        viz = input("是否生成可视化图像? (y/n): ").lower().strip()
        if viz == 'y':
            generator.visualize_matrix('ldpc_8_visualization.png')

    return generator


def generate_648_regular():
    """生成码长648规则LDPC码(码率≈0.5)"""
    print("\n生成码长648规则LDPC码...")
    print("参数:码长n=648, 校验节点m=324, 码率≈0.5")
    print("这是WiFi标准中常用的LDPC码长,适用于无线通信应用")

    n = 648
    m = 324
    dv = 3

    print(f"参数设置: n={n}, m={m}, dv={dv}, 期望dc≈{n*dv//m}, 码率≈{(n-m)/n:.1f}")

    generator = LDPCMatrixGenerator()
    generator.generate_regular_ldpc(n=n, m=m, dv=dv)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'regular_648_324_3')

        viz = input("是否生成可视化图像? (y/n): ").lower().strip()
        if viz == 'y':
            generator.visualize_matrix('ldpc_648_visualization.png')

    return generator


def generate_irregular():
    """生成非规则LDPC码"""
    print("\n生成非规则LDPC码...")
    generator = LDPCMatrixGenerator()

    # 度分布:大部分节点度为3,少数度为2和4(示例)
    dv_dist = {2: 100, 3: 700, 4: 200}
    generator.generate_irregular_ldpc(n=1000, m=500, dv_distribution=dv_dist)

    analysis = generator.analyze_matrix()
    print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

    save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
    if save == 'y':
        os.makedirs('ldpc_data', exist_ok=True)
        generator.save_for_nn_training('ldpc_data', 'irregular_1000_500')

    return generator


def generate_custom_regular():
    """生成自定义规则LDPC码"""
    print("\n自定义规则LDPC码生成")
    print("请输入参数(按Enter使用默认值):")

    try:
        n = input("码长 n (默认: 500): ").strip()
        n = int(n) if n else 500

        m = input(f"校验节点数 m (默认: {n // 2}): ").strip()
        m = int(m) if m else n // 2

        dv = input("变量节点度数 dv (默认: 3): ").strip()
        dv = int(dv) if dv else 3

        # 验证参数
        if n <= 0 or m <= 0 or dv <= 0:
            raise ValueError("参数必须为正数")
        if m >= n:
            raise ValueError("校验节点数必须小于码长")

        print(f"\n生成参数: n={n}, m={m}, dv={dv}")

        generator = LDPCMatrixGenerator()
        generator.generate_regular_ldpc(n=n, m=m, dv=dv)

        analysis = generator.analyze_matrix()
        print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

        save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
        if save == 'y':
            name = input("输入数据集名称 (默认: custom_regular): ").strip()
            name = name if name else "custom_regular"
            os.makedirs('ldpc_data', exist_ok=True)
            generator.save_for_nn_training('ldpc_data', name)

        return generator

    except ValueError as e:
        print(f"参数错误: {e}")
        return None


def generate_custom_irregular():
    """生成自定义非规则LDPC码"""
    print("\n自定义非规则LDPC码生成")

    try:
        n = input("码长 n (默认: 500): ").strip()
        n = int(n) if n else 500

        m = input(f"校验节点数 m (默认: {n // 2}): ").strip()
        m = int(m) if m else n // 2

        print(f"\n需要定义{n}个变量节点的度分布")
        print("格式:度数:节点数,用逗号分隔")
        print("例如:2:100,3:300,4:100 表示100个度为2的节点,300个度为3的节点,100个度为4的节点")

        dist_str = input("度分布: ").strip()
        if not dist_str:
            # 默认分布
            dv_dist = {2: n // 5, 3: n * 3 // 5, 4: n // 5}
        else:
            dv_dist = {}
            for pair in dist_str.split(','):
                degree, count = pair.split(':')
                dv_dist[int(degree)] = int(count)

        # 验证度分布
        total_nodes = sum(dv_dist.values())
        if total_nodes != n:
            print(f"警告:度分布总节点数({total_nodes})不等于n({n})")
            if total_nodes < n:
                # 自动补充度为3的节点
                dv_dist[3] = dv_dist.get(3, 0) + (n - total_nodes)
                print(f"自动调整后的度分布: {dv_dist}")
            else:
                raise ValueError("度分布总节点数超过n")

        print(f"\n生成参数: n={n}, m={m}, 度分布={dv_dist}")

        generator = LDPCMatrixGenerator()
        generator.generate_irregular_ldpc(n=n, m=m, dv_distribution=dv_dist)

        analysis = generator.analyze_matrix()
        print(f"分析结果: {json.dumps(analysis, indent=2, ensure_ascii=False)}")

        save = input("\n是否保存生成的LDPC码? (y/n): ").lower().strip()
        if save == 'y':
            name = input("输入数据集名称 (默认: custom_irregular): ").strip()
            name = name if name else "custom_irregular"
            os.makedirs('ldpc_data', exist_ok=True)
            generator.save_for_nn_training('ldpc_data', name)

        return generator

    except ValueError as e:
        print(f"参数错误: {e}")
        return None


def batch_generate():
    """批量生成预设的LDPC码"""
    print("\n批量生成模式")
    print("将生成以下预设的LDPC码:")
    print("1. 小规模规则LDPC码 (64, 32, dv=3)")
    print("2. 中等规模规则LDPC码 (1000, 500, dv=3)")
    print("3. 5G-like LDPC码 (1944, 972, dv=3)")
    print("4. (3,6)规模规则LDPC码 (600, 300, dv=3, dc≈6)")
    print("5. 码长128规则LDPC码 (128, 64, 码率≈0.5)")
    print("6. 码长32规则LDPC码 (32, 16, 码率≈0.5)")
    print("7. 码长8规则LDPC码 (8, 4, 码率≈0.5)")
    print("8. 码长648规则LDPC码 (648, 324, 码率≈0.5)")
    print("9. 非规则LDPC码 (1000, 500)")

    confirm = input("\n确认批量生成? (y/n): ").lower().strip()
    if confirm != 'y':
        return

    os.makedirs('ldpc_data', exist_ok=True)

    generators = []

    try:
        print("\n[1/9] 生成小规模规则LDPC码...")
        gen1 = LDPCMatrixGenerator()
        gen1.generate_regular_ldpc(n=64, m=32, dv=3)
        gen1.save_for_nn_training('ldpc_data', 'small_regular_64_32_3')
        generators.append(("小规模规则", gen1))

        print("\n[2/9] 生成中等规模规则LDPC码...")
        gen2 = LDPCMatrixGenerator()
        gen2.generate_regular_ldpc(n=1000, m=500, dv=3)
        gen2.save_for_nn_training('ldpc_data', 'medium_regular_1000_500_3')
        generators.append(("中等规模规则", gen2))

        print("\n[3/9] 生成5G-like LDPC码...")
        gen3 = LDPCMatrixGenerator()
        gen3.generate_regular_ldpc(n=1944, m=972, dv=3)
        gen3.save_for_nn_training('ldpc_data', '5g_like_1944_972')
        generators.append(("5G-like", gen3))

        print("\n[4/9] 生成(3,6)规模规则LDPC码...")
        gen4 = LDPCMatrixGenerator()
        gen4.generate_regular_ldpc(n=600, m=300, dv=3)
        gen4.save_for_nn_training('ldpc_data', 'regular_600_300_3_6')
        generators.append(("(3,6)规模规则", gen4))

        print("\n[5/9] 生成码长128规则LDPC码...")
        gen5 = LDPCMatrixGenerator()
        gen5.generate_regular_ldpc(n=128, m=64, dv=3)
        gen5.save_for_nn_training('ldpc_data', 'regular_128_64_3')
        generators.append(("码长128规则", gen5))

        print("\n[6/9] 生成码长32规则LDPC码...")
        gen6 = LDPCMatrixGenerator()
        gen6.generate_regular_ldpc(n=32, m=16, dv=3)
        gen6.save_for_nn_training('ldpc_data', 'regular_32_16_3')
        generators.append(("码长32规则", gen6))

        print("\n[7/9] 生成码长8规则LDPC码...")
        gen7 = LDPCMatrixGenerator()
        gen7.generate_regular_ldpc(n=8, m=4, dv=3)
        gen7.save_for_nn_training('ldpc_data', 'regular_8_4_3')
        generators.append(("码长8规则", gen7))

        print("\n[8/9] 生成码长648规则LDPC码...")
        gen8 = LDPCMatrixGenerator()
        gen8.generate_regular_ldpc(n=648, m=324, dv=3)
        gen8.save_for_nn_training('ldpc_data', 'regular_648_324_3')
        generators.append(("码长648规则", gen8))

        print("\n[9/9] 生成非规则LDPC码...")
        gen9 = LDPCMatrixGenerator()
        dv_dist = {2: 100, 3: 700, 4: 200}
        gen9.generate_irregular_ldpc(n=1000, m=500, dv_distribution=dv_dist)
        gen9.save_for_nn_training('ldpc_data', 'irregular_1000_500')
        generators.append(("非规则", gen9))

        print("\n" + "=" * 60)
        print("批量生成完成!生成的LDPC码统计:")
        for name, gen in generators:
            analysis = gen.analyze_matrix()
            print(f"\n{name} LDPC码:")
            print(f"  - 维度: {analysis['dimensions']}")
            print(f"  - 码率: {analysis['code_rate']:.3f}")
            print(f"  - 密度: {analysis['density']:.4f}")

    except Exception as e:
        print(f"\n批量生成过程中出现错误: {e}")
        import traceback
        traceback.print_exc()


def main():
    """主函数 - 交互式LDPC码生成"""

    print("LDPC码生成器启动...")

    while True:
        display_menu()

        try:
            choice = input("请选择操作 (0-12): ").strip()

            if choice == '0':
                print("感谢使用LDPC码生成器!")
                break
            elif choice == '1':
                generate_small_regular()
            elif choice == '2':
                generate_medium_regular()
            elif choice == '3':
                generate_5g_like()
            elif choice == '4':
                generate_36_regular()
            elif choice == '5':
                generate_128_regular()
            elif choice == '6':
                generate_32_regular()
            elif choice == '7':
                generate_8_regular()
            elif choice == '8':
                generate_648_regular()
            elif choice == '9':
                generate_irregular()
            elif choice == '10':
                generate_custom_regular()
            elif choice == '11':
                generate_custom_irregular()
            elif choice == '12':
                batch_generate()
            else:
                print("无效选择,请重新输入!")

        except KeyboardInterrupt:
            print("\n\n用户中断操作")
            break
        except Exception as e:
            print(f"\n操作过程中出现错误: {e}")
            import traceback
            traceback.print_exc()

        # 询问是否继续
        if choice != '0':
            continue_choice = input("\n是否继续使用生成器? (y/n): ").lower().strip()
            if continue_choice != 'y':
                print("感谢使用LDPC码生成器!")
                break


if __name__ == "__main__":
    main()

代码里的 PEG 实现细节

  • PEG.progressive_edge_growth():外层遍历变量节点,内层按度数 degree_sequence[var] 调用

    • k==0find_smallest() 找行度最小的校验节点

    • k>0bfs(var) 搜索可选校验集合,再选行度最小者

  • grow_edge() 同时更新 H、度计数和辅助索引表;

  • _estimate_girth() 用“是否存在双共边的行对”粗判 4 环,用作快速提示

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值