损失函数 的 硬截断 和 平滑衰减
flyfish
在逐样本损失计算完成、取平均之前,对损失过高的样本做权重压制,不删除样本,只削弱它们对梯度的贡献,属于软降权——既保留了样本的监督信号,又避免极端难样本/疑似错标样本带偏整个模型。
损失硬截断
损失硬截断是给单样本损失设置一个上限,超过这个阈值的损失,直接按阈值计算。相当于一刀切,超过上限的样本梯度不再放大。
代码实现
classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma=2,alpha=None,smoothing=0.0,num_classes=2,max_loss=None):""" :param max_loss: 单样本损失上限,None表示不开启截断;设置数值后,单样本损失不会超过该值 """super().__init__()self.gamma=gamma self.alpha=torch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothing=smoothing self.num_classes=num_classes self.max_loss=max_loss# 损失截断阈值defforward(self,inputs,targets):targets_one_hot=torch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targets=targets_one_hot*(1-self.smoothing)+self.smoothing/self.num_classes log_probs=torch.nn.functional.log_softmax(inputs,dim=1)probs=torch.exp(log_probs)p_t=(probs*targets_one_hot).sum(dim=1,keepdim=True)focal_weight=(1-p_t)**self.gamma ce_loss=(-soft_targets*log_probs).sum(dim=1)loss=focal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t=(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim=1)loss=loss*alpha_t# ========== 损失截断 ==========ifself.max_lossisnotNone:loss=torch.clamp(loss,max=self.max_loss)returnloss.mean()使用方式
在训练函数里初始化损失时,多加一个max_loss参数即可:
# 示例:单样本损失最高不超过2.0,超过的全部按2.0计算criterion=FocalLossWithSmoothing(gamma=FOCAL_GAMMA,alpha=FOCAL_ALPHA,smoothing=LABEL_SMOOTHING,num_classes=NUM_CLASSES,max_loss=2.0# 开启截断,阈值可按需调整)平滑衰减降权
硬截断是一刀切:损失超过阈值,直接砍平,损失值瞬间不再增长,像台阶一样突变;
平滑衰减是越涨越慢:损失低于阈值时正常计算,超过阈值后还能继续涨,但增长速度会越来越慢,过渡是顺滑的曲线,没有突变台阶。
它的目的:既保留损失越高、权重越大的相对顺序,又不让极端高损失样本无限放大梯度带偏模型,同时保证训练过程梯度平稳,不会出现跳变。
代码实现 只需要把截断部分替换成平滑衰减逻辑即可:
classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma=2,alpha=None,smoothing=0.0,num_classes=3,loss_threshold=1.8):super().__init__()self.gamma=gamma self.alpha=torch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothing=smoothing self.num_classes=num_classes self.loss_threshold=loss_threshold# 平滑衰减阈值defforward(self,inputs,targets):targets_one_hot=torch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targets=targets_one_hot*(1-self.smoothing)+self.smoothing/self.num_classes log_probs=torch.nn.functional.log_softmax(inputs,dim=1)probs=torch.exp(log_probs)p_t=(probs*targets_one_hot).sum(dim=1,keepdim=True)focal_weight=(1-p_t)**self.gamma ce_loss=(-soft_targets*log_probs).sum(dim=1)loss=focal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t=(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim=1)loss=loss*alpha_t# 平滑衰减降权:压制极端高损失样本ifself.loss_thresholdisnotNone:high_loss_mask=loss>self.loss_threshold loss[high_loss_mask]=self.loss_threshold+torch.log(1+loss[high_loss_mask]-self.loss_threshold)returnloss.mean()假设设置阈值 = 1.5,看不同原始损失对应的处理结果:
| 原始单样本损失 | 硬截断后损失 | 变化特点 |
|---|---|---|
| 1.0(正常样本) | 1.0 | 低于阈值,完全不变 |
| 1.4(较难样本) | 1.4 | 低于阈值,完全不变 |
| 1.5(阈值点) | 1.5 | 刚好等于阈值 |
| 1.6(难样本) | 1.5 | 超过一点点,直接被砍成1.5,瞬间停止增长 |
| 3.0(极难/错标样本) | 1.5 | 不管多高,全砍成1.5,和1.6的样本权重完全一样 |
硬截断的问题:
- 阈值点处损失突变,梯度也会突变,训练过程容易出现震荡;
- 所有超过阈值的样本,损失都一样,丢失了难分程度的差异信息——3.0的极难样本和1.6的轻微难样本,对模型的贡献变得完全相同,有点矫枉过正。
平滑衰减的逻辑:两段式 + 对数压缩
代码里用的是阈值以下正常计算,阈值以上对数压缩的两段式策略,公式是:
处理后损失={原始损失原始损失≤阈值阈值+log(1+原始损失−阈值)原始损失>阈值 \text{处理后损失} = \begin{cases} \text{原始损失} & \text{原始损失} \le 阈值 \\ 阈值 + \log(1 + \text{原始损失} - 阈值) & \text{原始损失} > 阈值 \end{cases}处理后损失={原始损失阈值+log(1+原始损失−阈值)原始损失≤阈值原始损失>阈值
为什么用 log(对数)函数?
对数函数有两个完美匹配需求的特性:
- 单调递增:原始损失越大,处理后的损失也一定越大,不会改变谁更难、谁损失更高的排序,样本的相对权重关系保留了;
- 增速递减:x 越大,log(x) 涨得越慢。原始损失越高,压缩力度越强,正好符合极端样本降权更多的需求。
直观对比效果
还是设阈值 = 1.5,算一组真实数值,一眼就能看出区别:
| 原始单样本损失 | 硬截断后 | 平滑衰减后 | 直观感受 |
|---|---|---|---|
| 1.0 | 1.0 | 1.00 | 低于阈值,两者完全一样 |
| 1.4 | 1.4 | 1.40 | 低于阈值,两者完全一样 |
| 1.5 | 1.5 | 1.50 | 阈值点,两者对齐 |
| 1.6 | 1.5 | 1.595 | 只超了一点点,压缩很轻微,几乎和原值差不多 |
| 2.0 | 1.5 | 1.693 | 超了0.5,增长明显放缓,不再是直线涨 |
| 3.0 | 1.5 | 1.946 | 超了1.5,涨幅被大幅压缩,不会涨到3.0 |
| 5.0 | 1.5 | 2.208 | 超了3.5,增速进一步变慢,和3.0的差距被缩小 |
可以明显看到:
刚超过阈值时,损失几乎不受影响,过渡非常顺滑;
损失越高,被压缩得越厉害,但始终保持越高越重的排序;
不会像硬截断那样,所有高损失全变成同一个值。
对应代码
loss[high_loss_mask]=self.loss_threshold+torch.log(1+loss[high_loss_mask]-self.loss_threshold)拆解开:
loss[high_loss_mask] - self.loss_threshold:算出损失超出阈值的部分(增量);1 + 增量:加1保证对数的输入大于0,避免出现负数报错;torch.log(...):对超出的增量做对数压缩,让增量涨得变慢;self.loss_threshold + 压缩后的增量:把基准阈值加回来,保证阈值点处数值连续、没有台阶。
什么时候用硬截断,什么时候用平滑衰减?
| 方案 | 场景 | 特点 |
|---|---|---|
| 硬截断 | 确定有大量标注错误,想直接屏蔽极端错标的影响 | 简单粗暴,可控性强,调试方便 |
| 平滑衰减 | 样本大多是标注正确的难样本(比如小目标、低对比度),只想削弱、不想完全屏蔽 | 更温和,梯度平稳,训练更稳定,保留难样本的相对差异信息 |