0. 前言本文介绍了HEWLHigh-Frequency Enhanced Wavelet Layer高频增强小波层其通过浅层高频细节引导深层语义特征重建与交替通道注意力机制首次在深度特征解码阶段实现对目标边缘纹理的精准恢复与语义语义模糊的有效破解从根本上解决了深层网络中连续池化操作导致的高频信息丢失与轮廓退化难题。将其作为即插即用模块轻松助力CNN、YOLO、Transformer等深度学习模型精准恢复目标边缘结构、增强深度语义表征让模型在面对红外弱目标、低对比度场景、密集小目标或复杂背景干扰等挑战性场景时依然能够保持清晰的边界感知与稳定的检测精度。专栏链接即插即用系列专栏链接可点击跳转免费订阅目录0. 前言1. HEWL模块简介2. HEWL模块原理与创新点 HEWL模块基本原理 HEWL模块创新点3. 适用范围与模块效果适用范围⚡模块效果4. HEWL模块代码实现1. HEWL模块简介红外小目标检测IRSTD对军事安全应用具有关键意义。尽管U型架构提升了基准性能现有方法仍存在两大核心局限1对微小目标的空间感知能力不足导致目标定位丢失2深度特征重建过程中存在边缘退化与语义模糊问题。为解决这些问题我们提出PQGNet架构其中设计了最大池化-小波混合层MWHL与高频增强小波层HEWL利用离散小波变换特性通过浅层高频细节强化深度语义表征。区别于现有仅替换池化层的小波方法HEWL通过显式保留并增强深度重建过程中的高频边缘细节有效解决低-频退化问题显著提升红外小目标的边缘重建质量与检测精度。原始论文https://doi.org/10.1109/TGRS.2026.3654433原始代码https://github.com/PepperCS/PQGNet2. HEWL模块原理与创新点 HEWL模块基本原理HEWL的核心设计思想源于对U型网络中深度特征重建过程的深入分析。在传统的编码器-解码器结构中随着网络层数的加深连续的池化下采样与双线性上采样操作会导致目标边缘信息不可逆地丢失产生边缘模糊问题。即便引入跳跃连接Skip Connection深层特征图中的高频细节信息仍然严重衰减难以恢复目标的精细轮廓。HEWL通过引入离散小波变换DWT与逆离散小波变换IDWT构建了一条高频信息增强通路使浅层提取的边缘细节能够有效注入深层语义特征图中。HEWL的实现过程可分为以下三个关键步骤1高频分量提取与噪声抑制HEWL首先对输入特征进行离散小波变换分离出高频分量LH、HL、HH与低频分量LL。与传统小波方法不同HEWL没有直接丢弃低频分量或简单相加而是通过一个交替通道注意力模块ALCA对原始特征进行通道维度的增强然后利用生成的空间注意力权重对高频分量进行加权滤波从而在增强关键边缘响应的同时抑制背景噪声干扰。2浅层高频信息引导的深度重建HEWL的核心创新在于将浅层特征中提取的高频细节信息与深层语义特征进行融合。具体而言浅层特征经过ALCA模块增强后其高频分量通过IDWT与深层特征进行重构。这种设计使深层特征图能够“看到”浅层保留的边缘纹理信息从而在语义抽象的深层特征中仍然保持对目标边界的感知能力。3残差式特征融合为避免高频注入过程中原始信息的丢失HEWL采用了残差连接策略将重建后的特征与原始输入特征沿通道维度拼接再通过残差模块进行特征压缩与增强。这种设计确保了模块在注入高频信息的同时不会破坏原始的深层语义特征实现了语义信息与细节信息的高效互补。 HEWL模块创新点高频信息显式增强区别于传统小波方法仅替代池化层HEWL通过浅层高频分量引导深层特征重建显式解决深度网络中的低频退化问题。交替通道注意力机制ALCA融合全局平均池化、全局最大池化与全局标准差池化三种统计信息构建更全面的通道描述子提升通道表征能力。空间注意力引导的高频滤波在将高频分量注入深层特征前通过空间注意力进行噪声抑制避免背景噪声被同步放大。残差式特征融合架构通过拼接与残差块的组合设计实现语义信息与高频细节的高效互补避免信息丢失。3. 适用范围与模块效果适用范围HEWL模块适用于通用视觉领域特别是有高频细节敏感与边缘重建需求的视觉任务。适用原因解析红外弱小目标检测目标往往仅占几个像素且边缘模糊需要保留高频细节辅助定位与判别。低对比度场景传统方法易受背景噪声干扰HEWL的空间注意力滤波与高频增强机制可提升信噪比。密集小目标检测多次下采样易导致小目标特征丢失HEWL的联合重建机制能有效保留小目标的细节信息。边缘部署场景HEWL在不过多增加计算量的前提下显著提升检测精度适合计算资源受限的边缘设备。⚡模块效果模块效果性能与视觉效果SOTA。消融实验第三行和第四行对比添加HEWL后性能提升说明HEWL的有效性。总结HEWL模块的消融实验与泛化实验均证实其能够有效提升各类检测模型对红外小目标的边缘重建能力与检测精度且具有良好的模型无关性与即插即用特性。4. HEWL模块代码实现以下为HEWL模块的官方pytorch实现代码 HEWL: High-frequency Enhanced Wavelet Layer 高频增强小波层 核心设计基于“小波域分解→高频增强→通道-空间联合注意力→逆小波重建”的流程 通过小波变换将特征解耦为低频和高频子带利用 ALCA 注意力增强高频细节 提升特征的结构表达能力与细节还原能力 参考 CGTA 的接口设计支持 1. 直接输入 2D 特征图 [B, C, H, W] 2. 输入序列格式 [B, N, C] H, W 参数 import torch import torch.nn as nn from torch.nn import functional as F import math import pywt from torch.autograd import Function # 可微分小波变换模块 class DWT_Function(Function): 可微DWT前向/反向传播 staticmethod def forward(ctx, x, w_ll, w_lh, w_hl, w_hh): x x.contiguous() ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh) ctx.shape x.shape dim x.shape[1] x_ll torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride2, groupsdim) x_lh torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride2, groupsdim) x_hl torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride2, groupsdim) x_hh torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride2, groupsdim) x torch.cat([x_ll, x_lh, x_hl, x_hh], dim1) return x staticmethod def backward(ctx, dx): if ctx.needs_input_grad[0]: w_ll, w_lh, w_hl, w_hh ctx.saved_tensors B, C, H, W ctx.shape dx dx.view(B, 4, -1, H // 2, W // 2) dx dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2) filters torch.cat([w_ll, w_lh, w_hl, w_hh], dim0).repeat(C, 1, 1, 1) dx torch.nn.functional.conv_transpose2d(dx, filters, stride2, groupsC) return dx, None, None, None, None class IDWT_Function(Function): 可微IDWT实现 staticmethod def forward(ctx, x, filters): ctx.save_for_backward(filters) ctx.shape x.shape B, _, H, W x.shape x x.view(B, 4, -1, H, W).transpose(1, 2) C x.shape[1] x x.reshape(B, -1, H, W) filters filters.repeat(C, 1, 1, 1) x torch.nn.functional.conv_transpose2d(x, filters, stride2, groupsC) return x staticmethod def backward(ctx, dx): if ctx.needs_input_grad[0]: filters ctx.saved_tensors[0] B, C, H, W ctx.shape C C // 4 dx dx.contiguous() w_ll, w_lh, w_hl, w_hh torch.unbind(filters, dim0) x_ll torch.nn.functional.conv2d(dx, w_ll.unsqueeze(1).expand(C, -1, -1, -1), stride2, groupsC) x_lh torch.nn.functional.conv2d(dx, w_lh.unsqueeze(1).expand(C, -1, -1, -1), stride2, groupsC) x_hl torch.nn.functional.conv2d(dx, w_hl.unsqueeze(1).expand(C, -1, -1, -1), stride2, groupsC) x_hh torch.nn.functional.conv2d(dx, w_hh.unsqueeze(1).expand(C, -1, -1, -1), stride2, groupsC) dx torch.cat([x_ll, x_lh, x_hl, x_hh], dim1) return dx, None class DWT_2D(nn.Module): 2D离散小波变换封装 def __init__(self, wavehaar): super(DWT_2D, self).__init__() # 确保 wave 是字符串类型 if isinstance(wave, int): wave haar # 默认使用 haar 小波 elif not isinstance(wave, str): wave str(wave) # 支持的小波类型映射 wavelet_map { haar: haar, db1: db1, db2: db2, bior1.1: bior1.1, bior2.2: bior2.2, 1: haar, # 整数1映射为haar 0: haar, } wave_str wavelet_map.get(wave, wave if isinstance(wave, str) else haar) try: w pywt.Wavelet(wave_str) except: # 如果小波不可用使用默认的 haar w pywt.Wavelet(haar) dec_hi torch.Tensor(w.dec_hi[::-1]) dec_lo torch.Tensor(w.dec_lo[::-1]) w_ll dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1) w_lh dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1) w_hl dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1) w_hh dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1) self.register_buffer(w_ll, w_ll.unsqueeze(0).unsqueeze(0)) self.register_buffer(w_lh, w_lh.unsqueeze(0).unsqueeze(0)) self.register_buffer(w_hl, w_hl.unsqueeze(0).unsqueeze(0)) self.register_buffer(w_hh, w_hh.unsqueeze(0).unsqueeze(0)) self.w_ll self.w_ll.to(dtypetorch.float32) self.w_lh self.w_lh.to(dtypetorch.float32) self.w_hl self.w_hl.to(dtypetorch.float32) self.w_hh self.w_hh.to(dtypetorch.float32) def forward(self, x): return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh) class IDWT_2D(nn.Module): 2D逆小波变换封装 def __init__(self, wavehaar): super(IDWT_2D, self).__init__() # 确保 wave 是字符串类型 if isinstance(wave, int): wave haar elif not isinstance(wave, str): wave str(wave) wavelet_map { haar: haar, db1: db1, db2: db2, bior1.1: bior1.1, bior2.2: bior2.2, 1: haar, 0: haar, } wave_str wavelet_map.get(wave, wave if isinstance(wave, str) else haar) try: w pywt.Wavelet(wave_str) except: w pywt.Wavelet(haar) rec_hi torch.Tensor(w.rec_hi) rec_lo torch.Tensor(w.rec_lo) w_ll rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1) w_lh rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1) w_hl rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1) w_hh rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1) w_ll w_ll.unsqueeze(0).unsqueeze(1) w_lh w_lh.unsqueeze(0).unsqueeze(1) w_hl w_hl.unsqueeze(0).unsqueeze(1) w_hh w_hh.unsqueeze(0).unsqueeze(1) filters torch.cat([w_ll, w_lh, w_hl, w_hh], dim0) self.register_buffer(filters, filters) self.filters self.filters.to(dtypetorch.float32) def forward(self, x): return IDWT_Function.apply(x, self.filters) # 辅助模块 class ResidualLeakBlock(nn.Module): 带LeakyReLU的残差卷积块 def __init__(self, in_channels, out_channels, stride1): super().__init__() self.proj nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size3, stridestride, padding1, biasTrue), nn.BatchNorm2d(out_channels), nn.LeakyReLU() ) self.body nn.Sequential( nn.Conv2d(out_channels, out_channels, kernel_size5, stridestride, padding2, biasTrue), nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2, inplaceTrue), nn.Conv2d(out_channels, out_channels, kernel_size3, stride1, padding1, biasTrue), nn.BatchNorm2d(out_channels), ) self.relu nn.LeakyReLU(0.2, inplaceTrue) def forward(self, x): x self.proj(x) residual x x self.body(x) out self.relu(x residual) return out class AlternateCat(nn.Module): 交替通道拼接模块 def __init__(self, dim1, num3): super().__init__() self.dim dim self.num num def forward(self, x_list): assert len(x_list) self.num, f输入张量数量错误 for i in range(self.num): assert x_list[0].shape x_list[i].shape, f第{i}个输入张量尺寸不匹配 size x_list[0].size(self.dim) x_list_slices [] for i in range(self.num): x_list_slices.append(torch.split(x_list[i], 1, dimself.dim)) interleaved_slices [] for i in range(size): for j in range(self.num): interleaved_slices.append(x_list_slices[j][i]) concatenated torch.cat(interleaved_slices, dimself.dim) return concatenated class AlCattention(nn.Module): ALCA: 交替通道-空间联合注意力 def __init__(self, dim): super(AlCattention, self).__init__() self.dim dim self.avg_pool nn.AdaptiveAvgPool2d(1) self.max_pool nn.AdaptiveMaxPool2d(1) self.alcat AlternateCat(dim1, num3) self.share_mlp nn.Sequential( nn.Conv2d(dim * 3, dim, kernel_size1, stride1, padding0, biasTrue, groupsself.dim), nn.ReLU(inplaceTrue), nn.Conv2d(dim, dim, kernel_size1, stride1, padding0, biasTrue), nn.Sigmoid() ) self.spconv nn.Sequential( nn.Conv2d(2, 1, kernel_size3, stride1, padding1, biasTrue), nn.Sigmoid() ) def forward(self, x, highf): # 通道注意力 x_avg self.avg_pool(x) x_max self.max_pool(x) std torch.std(x, dim(2, 3), keepdimTrue) x_ams self.alcat([x_avg, x_max, std]) channel_weights self.share_mlp(x_ams) x_1 x * channel_weights x # 空间注意力 avg_out torch.mean(x_1, dim1, keepdimTrue) max_out, _ torch.max(x_1, dim1, keepdimTrue) spatial_features torch.cat([avg_out, max_out], dim1) spatial_weights self.spconv(spatial_features) # 高频子带增强 dwt [x_1] for i in range(len(highf)): if spatial_weights.shape[-2:] ! highf[i].shape[-2:]: spatial_weights_resized F.interpolate( spatial_weights, sizehighf[i].shape[-2:], modebilinear, align_cornersFalse ) else: spatial_weights_resized spatial_weights temp highf[i] * spatial_weights_resized.expand_as(highf[i]) highf[i] dwt.append(temp) dwt torch.cat(dwt, dim1) return dwt # HEWL 核心模块 class HEWL(nn.Module): HEWL: High-frequency Enhanced Wavelet Layer 高频增强小波层 核心设计小波域高频增强 通道-空间联合注意力 逆小波重建 通过小波变换解耦特征高低频利用 ALCA 注意力增强高频细节 提升特征的结构表达能力与细节还原能力 Args: dim: 输入/输出特征通道数 wave: 小波基类型可选 haar, db1, bior1.1 等默认 haar use_skip: 是否使用跳跃连接残差结构默认 True def __init__(self, dim, wavehaar, use_skipTrue): super().__init__() self.dim dim self.use_skip use_skip # 确保 wave 是字符串 if isinstance(wave, int): wave haar elif wave is None: wave haar # 小波变换模块 self.dwt DWT_2D(wave) self.idwt IDWT_2D(wave) # ALCA 交替通道-空间注意力模块 self.alca AlCattention(dimdim) # 重建特征卷积增强 self.conv nn.Sequential( nn.Conv2d(dim, dim, kernel_size3, stride1, padding1, biasTrue), nn.BatchNorm2d(dim), nn.LeakyReLU(inplaceTrue) ) # 最终残差增强块 self.out ResidualLeakBlock(dim * 2, dim) # 跳跃连接投影如果输入输出维度不匹配 if use_skip: self.skip_proj nn.Identity() # 默认恒等映射 def forward(self, x, HNone, WNone): 前向传播支持两种输入模式 模式1 - 2D 特征图模式 x: [B, C, H, W] 返回: [B, C, H, W] 增强后的特征图 模式2 - 序列模式 x: [B, N, C] 序列特征 H, W: 特征图的高度和宽度 返回: [B, N, C] 增强后的序列特征 # 模式1: 直接输入 2D 特征图 [B, C, H, W] if H is None and W is None and len(x.shape) 4: B, C, H, W x.shape out self.forward_2d(x, H, W) return out # 模式2: 序列模式 [B, N, C] H, W elif H is not None and W is not None and len(x.shape) 3: return self.forward_seq(x, H, W) # Ultralytics 框架可能传入 [B, C, H, W] 但 H,W 不为 None elif len(x.shape) 4: B, C, H, W x.shape out self.forward_2d(x, H, W) return out else: raise ValueError(fUnsupported input format: x.shape{x.shape}, H{H}, W{W}) def forward_2d(self, x, H, W): 2D 特征图前向传播 B, C, H, W x.shape # 转换为序列格式处理 x_seq x.flatten(2).transpose(1, 2) # [B, H*W, C] out_seq self.forward_seq(x_seq, H, W) # 转换回 2D 格式 out out_seq.transpose(1, 2).view(B, C, H, W).contiguous() return out def forward_seq(self, x, H, W): 序列格式前向传播 x: [B, N, C] 序列特征 H, W: 特征图尺寸 B, N, C x.shape # 转换为 2D 格式用于小波变换 x_2d x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() # 保存输入用于残差连接 identity x_2d # 步骤1: 小波变换分解 x_dwt self.dwt(x_2d) # [B, 4*C, H/2, W/2] x_ll, x_lh, x_hl, x_hh x_dwt.split(C, 1) highf_subbands [x_lh, x_hl, x_hh] # 步骤2: ALCA 注意力增强 x_alca self.alca(x_ll, highf_subbands) # [B, 4*C, H/2, W/2] # 步骤3: 逆小波变换重建 x_recon self.idwt(x_alca) # [B, C, H, W] # 步骤4: 重建特征增强 x_enhanced self.conv(x_recon) # 步骤5: 残差融合 x_cat torch.cat((identity, x_enhanced), dim1) # [B, 2*C, H, W] x_out self.out(x_cat) # [B, C, H, W] # 如果使用跳跃连接添加残差 if self.use_skip: x_out x_out identity # 转换回序列格式 out x_out.flatten(2).transpose(1, 2) # [B, N, C] return out # HEWL2D 简化包装器 class HEWL2D(nn.Module): HEWL 的 2D 简化包装器 专门用于需要直接输入输出 2D 特征图的场景 Args: dim: 特征通道数 wave: 小波基类型 use_skip: 是否使用跳跃连接 def __init__(self, dim, wavehaar, use_skipTrue): super().__init__() # 确保 wave 是字符串 if isinstance(wave, int): wave haar self.hewl HEWL(dimdim, wavewave, use_skipuse_skip) self.dim dim def forward(self, x): x: [B, C, H, W] 2D 特征图 返回: [B, C, H, W] 增强后的特征图 return self.hewl(x) # 测试代码 if __name__ __main__: device torch.device(cuda:0 if torch.cuda.is_available() else cpu) print( * 60) print(HEWL 模块测试) print( * 60) # 测试模式1: 序列格式输入 print(\n[测试1] 序列格式输入 [B, N, C] H, W) B, H, W, C 1, 32, 32, 64 N H * W x_seq torch.randn(B, N, C).to(device) model HEWL(dimC, wavehaar).to(device) y_seq model(x_seq, H, W) print(f输入序列维度: {x_seq.shape}) print(f输出序列维度: {y_seq.shape}) print(f输入输出形状匹配: {x_seq.shape y_seq.shape}) # 测试模式2: 2D 特征图输入 print(\n[测试2] 2D 特征图输入 [B, C, H, W]) x_2d torch.randn(B, C, H, W).to(device) y_2d model(x_2d) print(f输入2D维度: {x_2d.shape}) print(f输出2D维度: {y_2d.shape}) print(f输入输出形状匹配: {x_2d.shape y_2d.shape}) # 测试整数 wave 参数 print(\n[测试3] 整数 wave 参数测试) model_int HEWL(dimC, wave1).to(device) y_int model_int(x_2d) print(fwave1 输出维度: {y_int.shape}) print(f测试通过) # 测试 HEWL2D 包装器 print(\n[测试4] HEWL2D 包装器) model_2d HEWL2D(dimC, wavehaar).to(device) y_2d_wrapped model_2d(x_2d) print(f输入维度: {x_2d.shape}) print(f输出维度: {y_2d_wrapped.shape}) # 参数统计 print(\n * 60) print(模型参数统计) print( * 60) total_params sum(p.numel() for p in model.parameters()) trainable_params sum(p.numel() for p in model.parameters() if p.requires_grad) print(f总参数量: {total_params:,}) print(f可训练参数量: {trainable_params:,}) print(\n * 60) print(所有测试通过) print( * 60)结合自己的思路可将其即插即用至任何模型做结构创新设计该模块博主已成功嵌入至YOLO26模型中可订阅博主YOLO系列算法改进或YOLO26自研改进专栏YOLO系列算法改进专栏链接、YOLO26自研改进系列专栏