一、基本介绍
torch.softmax 是一个常用于深度学习中的激活函数,作用是对模型输出进行归一化处理,使得每个输出值都在 0~1 之间并且总和为1。
具体地,对于输入的向量 x=[x1,x2,x3,...,xn] ,softmax 函数的作用是把每个元素 xi 转化为 e^xi 的指数形式,再除以所有元素指数的和,最终得到概率分布向量 p=[p1,p2,p3,...,pn]。其中,pi = e^xi / (e^x1+e^x2+...+e^xn)。
对于深度学习中的分类问题,通常把模型最后一层输出做 softmax 处理,使得输出结果可以被解释为各个类别的概率分布。
二、常见应用
1. 分类问题
对于分类问题,softmax 函数通常被用来把神经网络的最后一层输出转化为各个类别的概率分布。在训练时,对于每个样本,根据真实标签和预测概率分布计算交叉熵损失,然后用反向传播算法进行模型参数的优化。
import torch
# 定义模型输出(假设三分类问题)
outputs = torch.randn(2, 3)
# 对模型输出进行 softmax 处理
probs = torch.softmax(outputs, dim=1)
# 输出每个样本对应的概率分布
print(probs)
2. 温度控制
softmax 函数的计算中,指数的值越大,其占比就越大。因此,可以通过调整指数值的大小来调整概率的分布。在某些场景下,需要调整输出结果的"温度",使得各个输出值之间的差异更明显。
def temperature_control(probs, t=1.0):
"""
温度控制函数
Args:
probs (tensor): 模型的概率分布向量
t (float): 温度值,越大则差异越明显
Returns:
tensor: 调整后的概率分布向量
"""
return torch.softmax(probs/t, dim=0)
# 示例
outputs = torch.randn(1, 3)
probs = torch.softmax(outputs, dim=1)
# 默认温度
print("原始概率分布:", probs)
print("默认温度调整:", temperature_control(probs))
# 提高温度
print("提高温度调整:", temperature_control(probs, t=2.0))
# 降低温度
print("降低温度调整:", temperature_control(probs, t=0.5))
三、常见问题
1. 梯度消失
由于 softmax 函数将指数的值按照比例转化为概率值,指数值较大的元素对应的梯度也相对较大,当所有元素的指数值非常大时,较小元素对应的梯度会因为数值溢出而被消失。这种现象被称之为 softmax 梯度消失问题。
解决 softmax 梯度消失问题的方法之一是使用数值稳定的计算函数,比如 PyTorch 中的 torch.log_softmax 和 torch.nan_to_num 函数。
import torch
# 定义模型输出(假设三分类问题)
outputs = torch.randn(2, 3)
# 使用 torch.softmax
probs = torch.softmax(outputs, dim=1)
loss = torch.nn.functional.cross_entropy(probs, torch.LongTensor([0, 2]))
loss.backward()
print("使用 torch.softmax 反向传播的梯度:")
print(outputs.grad)
# 使用 torch.log_softmax 和 torch.exp
log_probs = torch.log_softmax(outputs, dim=1)
exp_probs = torch.exp(log_probs)
loss = torch.nn.functional.cross_entropy(exp_probs, torch.LongTensor([0, 2]))
loss.backward()
print("使用 torch.log_softmax 反向传播的梯度:")
print(outputs.grad)
2. 多分类问题
softmax 函数可以处理具有多个类别的分类问题,在此情况下,softmax 函数的输入是一个形状为 (-1, num_classes) 的张量,输出是形状相同的张量,其中每行都是代表不同样本的概率分布向量。在训练过程中,通常使用交叉熵损失函数对模型进行优化。
import torch
# 定义模型输出(假设五分类问题)
outputs = torch.randn(2, 5)
# 对模型输出进行 softmax 处理
probs = torch.softmax(outputs, dim=1)
# 随机生成真实标签
targets = torch.randint(0, 5, (2,))
# 使用交叉熵计算损失函数
loss = torch.nn.functional.cross_entropy(outputs, targets)
# 反向传播误差信号
loss.backward()
四、总结
本文介绍了 softmax 函数在深度学习中的常见应用,并介绍了一些可能遇到的问题和解决方案。对于分类问题,通过 softmax 函数的作用可以将模型最后一层输出转化为各个类别的概率分布,然后根据真实标签和预测概率分布计算交叉熵损失函数,并用反向传播算法进行优化。如果出现梯度消失问题,可以使用 torch.log_softmax 和 torch.exp 函数代替 torch.softmax 函数进行计算。