import os, re, glob, math, random, argparse, warnings
from pathlib import Path
import numpy as np
import scipy.io as sio
import scipy.signal as sig
from scipy.stats import kurtosis
import yaml
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from collections import defaultdict
warnings.filterwarnings("ignore")
# ========== 基础工具 ==========
def load_cfg(p):
with open(p, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
def ensure_1d(x):
if x is None: return None
x = np.asarray(x)
if x.ndim == 1:
return x.astype(np.float64)
return x.reshape(-1).astype(np.float64)
def detrend(x):
return sig.detrend(x, type="linear")
def bandpass(x, fs, f1, f2, order=4):
f2 = min(f2, 0.45*fs)
if f1 >= f2:
f1, f2 = max(5.0, min(f1, 0.45*fs-10)), max(50.0, min(f2, 0.45*fs-5))
sos = sig.butter(order, [f1, f2], btype="bandpass", fs=fs, output="sos")
return sig.sosfiltfilt(sos, x)
def envelope(x):
return np.abs(sig.hilbert(x))
def zscore(x, eps=1e-8):
std = x.std()
if not np.isfinite(std) or std < eps:
return np.zeros_like(x)
return (x - x.mean()) / (std + eps)
def resample_to_uniform(x, fs_src, fs_tgt):
x = np.asarray(x, dtype=np.float64)
if x.size == 0 or fs_src <= 0 or fs_tgt <= 0 or abs(fs_src - fs_tgt) < 1e-9:
return x
n_out = max(1, int(round(len(x) * float(fs_tgt) / float(fs_src))))
t_src = np.linspace(0.0, (len(x)-1)/fs_src, num=len(x), endpoint=True)
t_new = np.linspace(0.0, (len(x)-1)/fs_src, num=n_out, endpoint=True)
return np.interp(t_new, t_src, x).astype(np.float64)
def windowing(x, fs, win_sec=1.0, overlap=0.5, drop_edges_sec=0.0):
x = x[int(drop_edges_sec*fs):]
step = int(win_sec*fs*(1-overlap))
L = int(win_sec*fs)
out = []
for s in range(0, max(1, len(x)-L+1), step if step>0 else L):
seg = x[s:s+L]
if len(seg)==L: out.append(seg)
if not out and len(x)>0:
out = [np.pad(x, (0, L-len(x)))[:L]]
return out
def spectral_kurtosis_band(x, fs, nfft=2048, hop=None, fmin=50, fmax=None, topk=1, bw_frac=0.15):
x = np.asarray(x, dtype=np.float64)
if hop is None: hop = nfft // 4
if fmax is None: fmax = fs*0.45
if len(x) < nfft:
return [], None, None
win = np.hanning(nfft)
specs = []
for i in range(0, len(x)-nfft+1, hop):
seg = x[i:i+nfft] * win
X = np.fft.rfft(seg)
P = np.abs(X)**2
specs.append(P)
if len(specs) == 0:
return [], None, None
S = np.stack(specs, 0)
freqs = np.fft.rfftfreq(nfft, 1.0/fs)
m = (freqs >= fmin) & (freqs <= fmax)
S = S[:, m]; freqs = freqs[m]
if S.shape[1] < 10:
return [], None, None
mu1 = S.mean(axis=0)
var = ((S - mu1)**2).mean(axis=0) + 1e-12
mu4 = ((S - mu1)**4).mean(axis=0)
sk = mu4 / (var**2) - 3.0
idx = np.argsort(sk)[::-1][:topk]
bands = []
for i in idx:
fc = freqs[i]
bw = max(20.0, fc * bw_frac)
f1 = max(freqs[0], fc - bw/2.0)
f2 = min(freqs[-1], fc + bw/2.0)
if f2 > f1 + 5.0:
bands.append((float(f1), float(f2)))
return bands, freqs, sk
def order_resample(x, fs, rpm, spr=200):
if rpm is None or rpm <= 0:
return x, fs
fr = rpm / 60.0
T = len(x) / fs
n_rev = fr * T
if n_rev < 1e-3:
return x, fs
N = int(max(4*spr, n_rev * spr))
t = np.arange(len(x)) / fs
theta = 2*np.pi*fr*t
theta_target = np.linspace(theta[0], theta[-1], N)
x_res = np.interp(theta_target, theta, x).astype(np.float64)
fs_equiv = spr * fr
return x_res, fs_equiv
def parse_rpm_from_name(fn):
m = re.search(r"\((\d+)\s*rpm\)", fn.replace("RPM","rpm"))
if m: return float(m.group(1))
return None
def read_mat_any(path):
mat = sio.loadmat(path, squeeze_me=True, struct_as_record=False)
keys_map = {k.lower(): k for k in mat.keys()}
def pick_any(*cands):
for c in cands:
lc = c.lower()
if lc in keys_map:
return mat[keys_map[lc]]
return None
de = pick_any('DE','DE_value','X_DE','DE_time','x_de','x_de_time')
fe = pick_any('FE','FE_value','X_FE','FE_time','x_fe','x_fe_time')
ba = pick_any('BA','BA_value','X_BA','BA_time','x_ba','x_ba_time')
time = pick_any('time','t','time_series','DE_time','FE_time','BA_time','X_DE_time','X_FE_time','X_BA_time')
rpm = pick_any('rpm','x118rpm','speed','rot_rpm')
# 兜底:挑最长 ndarray 为 DE
if all(v is None for v in [de,fe,ba]):
cands = [v.reshape(-1) for v in mat.values() if isinstance(v, np.ndarray) and v.size>100]
if cands: de = max(cands, key=lambda a: a.size)
return {"DE":ensure_1d(de), "FE":ensure_1d(fe), "BA":ensure_1d(ba), "time":ensure_1d(time), "RPM":rpm}
def bearing_freqs(rpm, n, d, D, theta_deg=0.0):
fr = float(rpm) / 60.0
c = (d / D) * math.cos(math.radians(theta_deg))
bpfo = 0.5 * n * fr * (1 - c)
bpfi = 0.5 * n * fr * (1 + c)
bsf = (D/(2*d)) * fr * (1 - c**2)
ftf = 0.5 * fr * (1 - c)
return fr, bpfo, bpfi, bsf, ftf
def is_under_dir(child_path, parent_dir):
try:
child = Path(child_path).resolve()
parent = Path(parent_dir).resolve()
return parent in child.parents or child == parent
except Exception:
return False
# ========== 标签推断 ==========
def infer_label_from_path(p):
p_norm = p.replace("\\","/").lower()
if "/b/" in p_norm or os.sep+"b"+os.sep in p_norm: return "B"
if "/ir/" in p_norm or os.sep+"ir"+os.sep in p_norm: return "IR"
if "/or/" in p_norm or os.sep+"or"+os.sep in p_norm: return "OR"
if "normal" in p_norm or "n_" in os.path.basename(p_norm): return "N"
if "normal_data" in p_norm: return "N"
return "UNK"
# ========== 预处理主流程 ==========
def choose_rpm(rec, fpath, cfg, rpm_pref=None):
# 1) mat 变量
rr = rec.get("RPM")
if rr is not None and np.isscalar(rr):
try:
val = float(rr)
if val > 0: return val
except:
pass
# 2) 文件名
val = parse_rpm_from_name(Path(fpath).name)
if val is not None and val > 0: return val
# 3) 外部优先提示(目标域 rpm_approx 等)
if rpm_pref is not None and rpm_pref > 0:
return float(rpm_pref)
# 4) 源域候选
cand = cfg.get("conditions",{}).get("rpm_source_candidates", [1797,1772,1750,1730])
return float(cand[0]) if len(cand)>0 else None
def preprocess_one_file(fpath, cfg, out_dir, fs_out=32000, use_order=True, rpm_pref=None, verbose=False):
rec = read_mat_any(fpath)
# 选通道(优先 DE,其次 FE,再 BA)
x = rec["DE"]; used_ch = "DE"
if x is None:
x = rec["FE"]; used_ch = "FE"
if x is None:
x = rec["BA"]; used_ch = "BA"
if x is None or x.size < 10:
return None
# 源采样率估计:time -> fs 或文件名猜测(48kHz/12kHz)
fs_src = None
if rec["time"] is not None and len(rec["time"])>1:
dt = float(np.median(np.diff(rec["time"])))
if dt > 0: fs_src = 1.0/dt
if fs_src is None:
fs_src = 48000.0 if "48khz" in fpath.lower() else 12000.0
# 统一重采样到 fs_out
x = resample_to_uniform(x, fs_src, fs_out)
fs_eff = fs_out
# 去趋势
if cfg["preprocess_default"].get("detrend", True):
x = detrend(x)
# 阶次分析(常速/近似常速),并回采样到 fs_out
rpm_val = choose_rpm(rec, fpath, cfg, rpm_pref=rpm_pref)
if use_order and (rpm_val is not None) and (rpm_val>0):
x_ord, fs_ord = order_resample(x, fs_eff, rpm_val, spr=int(cfg["preprocess_default"].get("order_spr",200)))
x = resample_to_uniform(x_ord, fs_ord, fs_out) # 回到统一 fs
fs_eff = fs_out
# 谱峭度 -> 自适应带通
pp = cfg["preprocess_default"]
use_sk = pp.get("use_spectral_kurtosis", True)
if use_sk:
bands, freqs, sk = spectral_kurtosis_band(
x, fs_eff,
nfft=pp.get("sk_nfft", 2048),
fmin=pp.get("sk_fmin", 50),
fmax=pp.get("sk_fmax", int(0.45*fs_eff)),
topk=pp.get("sk_topk", 1),
bw_frac=pp.get("sk_bw_frac", 0.15)
)
if len(bands)>0:
f1, f2 = bands[0]
else:
f1, f2 = pp.get("bandpass_hz",[500,10000])
else:
f1, f2 = pp.get("bandpass_hz",[500,10000])
x_bp = bandpass(x, fs_eff, f1, f2)
x_env = envelope(x_bp)
x_norm = zscore(x_env) if pp.get("normalize","zscore")=="zscore" else x_env
# 切片
seg = pp.get("segment", {"win_sec":1.0,"overlap":0.5,"drop_edges_sec":0.0})
wins = windowing(x_norm, fs_eff, seg.get("win_sec",1.0), seg.get("overlap",0.5), seg.get("drop_edges_sec",0.0))
if not wins:
return None
# 保存 .npy(每个窗口一个文件),生成索引
os.makedirs(out_dir, exist_ok=True)
label = infer_label_from_path(fpath) # "B"/"IR"/"OR"/"N"/"UNK"
base = Path(fpath).stem
rows = []
for i,w in enumerate(wins):
outp = os.path.join(out_dir, f"{base}__{i:04d}.npy")
np.save(outp, w.astype(np.float32))
rows.append((outp, label, used_ch, f1, f2, rpm_val, fs_eff))
# 返回可视化需要的中间结果
return {
"rows": rows,
"raw": x, "bp": x_bp, "env": x_env, "norm": x_norm,
"fs": fs_eff, "band": (f1,f2),
"sk": (freqs, sk) if use_sk else (None, None),
"label": label, "used_ch": used_ch
}
# ========== 可视化 ==========
def plot_one_sample(figdir, recinfo, cfg, title_hint=""):
os.makedirs(figdir, exist_ok=True)
x, x_bp, x_env, fs = recinfo["raw"], recinfo["bp"], recinfo["env"], recinfo["fs"]
label, band = recinfo["label"], recinfo["band"]
freqs, sk = recinfo["sk"]
# 1) 原始 vs 预处理后(时域)
T = np.arange(len(x))/fs
plt.figure(figsize=(10,3))
plt.plot(T, x, lw=0.6)
plt.title(f"Raw waveform ({title_hint})")
plt.xlabel("Time (s)"); plt.ylabel("Amplitude"); plt.tight_layout()
plt.savefig(os.path.join(figdir, f"{title_hint}_raw.png")); plt.close()
T2 = np.arange(len(x_bp))/fs
plt.figure(figsize=(10,3))
plt.plot(T2, x_bp, lw=0.6)
plt.title(f"Bandpassed waveform {band} Hz ({title_hint})")
plt.xlabel("Time (s)"); plt.ylabel("Amplitude"); plt.tight_layout()
plt.savefig(os.path.join(figdir, f"{title_hint}_bandpassed.png")); plt.close()
plt.figure(figsize=(10,3))
plt.plot(T2, x_env, lw=0.6)
plt.title(f"Envelope (Hilbert) ({title_hint})")
plt.xlabel("Time (s)"); plt.ylabel("Amplitude"); plt.tight_layout()
plt.savefig(os.path.join(figdir, f"{title_hint}_envelope.png")); plt.close()
# 2) PSD
f, Pxx = sig.welch(x_bp, fs=fs, nperseg=4096)
plt.figure(figsize=(6,4))
plt.semilogy(f, Pxx + 1e-12)
plt.title(f"PSD (bandpassed) ({title_hint})"); plt.xlabel("Hz"); plt.ylabel("PSD")
plt.tight_layout(); plt.savefig(os.path.join(figdir, f"{title_hint}_psd.png")); plt.close()
# 3) 包络谱
N = 1<<int(np.ceil(np.log2(len(x_env))))
Ef = np.fft.rfftfreq(N, 1.0/fs)
Es = np.abs(np.fft.rfft(x_env, N))
plt.figure(figsize=(8,4))
plt.plot(Ef, Es, lw=0.7)
plt.xlim(0, min(5000, fs*0.45))
plt.title(f"Envelope spectrum ({title_hint})"); plt.xlabel("Hz"); plt.ylabel("|E(f)|")
# 叠加轴承故障特征频率(用 DE 几何)
geom = cfg["bearings"]["DE"]
rpm_approx = recinfo["rows"][0][5] or cfg["conditions"]["rpm_source_candidates"][0]
_, bpfo, bpfi, bsf, ftf = bearing_freqs(rpm_approx, geom["n"], geom["d_in"], geom["D_in"], geom["theta_deg"])
for name,base in [("BPFI",bpfi),("BPFO",bpfo),("BSF",bsf),("FTF",ftf)]:
for k in [1,2]:
f0 = k*base
if f0 < fs*0.45:
plt.axvline(f0, ls="--", lw=0.7, color="r")
plt.text(f0, max(Es)*0.05, f"{name}{k}", rotation=90, va="bottom", ha="right", fontsize=8)
plt.tight_layout(); plt.savefig(os.path.join(figdir, f"{title_hint}_envelope_spectrum.png")); plt.close()
# 4) 谱峭度(若开启)
if freqs is not None and sk is not None:
plt.figure(figsize=(8,3))
plt.plot(freqs, sk, lw=0.8)
plt.title(f"Spectral kurtosis ({title_hint})")
plt.xlabel("Hz"); plt.ylabel("SK")
if band:
plt.axvspan(band[0], band[1], color="orange", alpha=0.25, label=f"Band {band[0]:.0f}-{band[1]:.0f} Hz")
plt.legend()
plt.tight_layout(); plt.savefig(os.path.join(figdir, f"{title_hint}_spectral_kurtosis.png")); plt.close()
def tsne_overview(emb_list, lab_list, out_png):
if len(emb_list)==0:
return
X = np.vstack(emb_list)
L = np.array(lab_list)
tsne = TSNE(n_components=2, perplexity=30, learning_rate="auto", init="pca")
Z = tsne.fit_transform(X)
plt.figure(figsize=(6,5))
for cls, col in zip(["OR","IR","B","N","UNK"], ["tab:blue","tab:orange","tab:green","tab:red","tab:gray"]):
m = (L==cls)
if m.sum()>0:
plt.scatter(Z[m,0], Z[m,1], s=10, alpha=0.7, label=cls, c=col)
plt.legend(); plt.title("t-SNE on log-magnitude spectrum features")
plt.tight_layout(); plt.savefig(out_png); plt.close()
# ========== 文件收集(source / target / all)==========
def collect_source_files(cfg):
root = cfg["paths"]["source_dir"]
files = []
layout = cfg["source_domain"]["folder_layout"]
for block in layout:
bpath = os.path.join(root, block["path"])
if "classes" in block:
for c in block["classes"]:
files += glob.glob(os.path.join(bpath, c, "**", "*.mat"), recursive=True)
if "files" in block:
files += [os.path.join(bpath, f) for f in block["files"]]
return sorted(files)
def collect_target_files(cfg):
root = cfg["paths"]["target_dir"]
pat = cfg.get("target_domain", {}).get("files_pattern", "[A-P].mat")
files = glob.glob(os.path.join(root, pat))
if not files:
files = glob.glob(os.path.join(root, "*.mat"))
return sorted(files)
def collect_all_files(cfg):
# 递归 root_dir 下所有 .mat(包含源域 + 目标域)
root = cfg["paths"]["root_dir"]
return sorted(glob.glob(os.path.join(root, "**", "*.mat"), recursive=True))
# ========== 主流程 ==========
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--cfg", type=str, required=True)
ap.add_argument("--out", type=str, default="./preprocessed")
ap.add_argument("--max_per_class", type=int, default=999999)
ap.add_argument("--fs_out", type=float, default=32000)
ap.add_argument("--seed", type=int, default=0)
ap.add_argument("--scope", type=str, choices=["source","target","all"], default="source",
help="选择处理范围:source=仅源域;target=仅目标域;all=根目录下所有 .mat")
args = ap.parse_args()
random.seed(args.seed); np.random.seed(args.seed)
cfg = load_cfg(args.cfg)
fs_out = int(cfg["target_domain"]["fs_hz"]) if args.fs_out is None else int(args.fs_out)
# 根据 scope 选择文件
if args.scope == "source":
mat_files = collect_source_files(cfg)
elif args.scope == "target":
mat_files = collect_target_files(cfg)
else:
mat_files = collect_all_files(cfg)
print(f"[INFO] found mat files: {len(mat_files)} (scope={args.scope})")
os.makedirs(args.out, exist_ok=True)
index_rows = ["file,window,label,channel,f1,f2,rpm,fs"]
by_label_counter = defaultdict(int)
tsne_feats, tsne_labels = [], []
per_class_limit = args.max_per_class
random.shuffle(mat_files)
viz_keep = defaultdict(list)
target_dir = cfg["paths"]["target_dir"]
target_rpm_approx = cfg.get("target_domain", {}).get("rpm_approx", None)
for fp in mat_files:
lab = infer_label_from_path(fp)
if lab not in ["OR","IR","B","N","UNK"]:
lab = "UNK"
if by_label_counter[lab] >= per_class_limit:
continue
# 目标域优先使用 rpm_approx;否则走通用选择
rpm_pref = None
if args.scope in ("target", "all"):
if is_under_dir(fp, target_dir) and target_rpm_approx is not None:
try:
rpm_pref = float(target_rpm_approx)
except:
rpm_pref = None
info = preprocess_one_file(
fp, cfg, out_dir=args.out,
fs_out=fs_out,
use_order=cfg["preprocess_default"].get("use_order_tracking", True),
rpm_pref=rpm_pref
)
if info is None:
continue
for (wpath, label, ch, f1, f2, rpm, fs) in info["rows"]:
index_rows.append(f"{wpath},{label},{ch},{f1:.2f},{f2:.2f},{rpm},{fs}")
by_label_counter[lab] += 1
# 取少量窗口做 t-SNE 的简易频谱特征
for r in info["rows"][:3]:
w = np.load(r[0])
N = 1<<int(np.ceil(np.log2(len(w))))
Wf = np.fft.rfft(w, N)
feat = np.log1p(np.abs(Wf))[:2048]
tsne_feats.append(feat.astype(np.float32))
tsne_labels.append(lab)
if len(viz_keep[lab]) < 2:
viz_keep[lab].append((fp, info))
with open(os.path.join(args.out, "index.csv"), "w", encoding="utf-8") as f:
f.write("\n".join(index_rows))
print(f"[INFO] saved index.csv with {len(index_rows)-1} rows")
figdir = os.path.join(args.out, "figs")
for lab, lst in viz_keep.items():
for _, (fp, info) in enumerate(lst):
title = f"{lab}_{Path(fp).stem}"
plot_one_sample(figdir, info, cfg, title_hint=title)
if len(tsne_feats) > 10:
tsne_overview(tsne_feats, tsne_labels, os.path.join(args.out, "tsne_overview.png"))
print("[INFO] saved tsne_overview.png")
print("[COUNT] per class (processed files):", dict(by_label_counter))
print(f"[DONE] preprocessed windows saved to: {args.out}")
if __name__ == "__main__":
main()
这段python代码有错误:usage: preprocess_and_viz.py [-h] --cfg CFG [--out OUT]
[--max_per_class MAX_PER_CLASS] [--fs_out FS_OUT]
[--seed SEED] [--scope {source,target,all}]
preprocess_and_viz.py: error: the following arguments are required: --cfg
怎么修改
最新发布