Cross Entropy in Classification

在分类问题中,交叉熵(Cross Entropy) 是一种用于衡量两个概率分布之间差异的度量方法。具体来说,它常用于评估模型预测的概率分布与实际标签的真实分布之间的差异。以下是详细解释:

1. 背景知识

概率分布

在分类问题中,我们通常有一个模型预测的概率分布 p\mathbf{p} 和一个真实的标签分布 q\mathbf{q}

  • p=[p1,p2,,pn]\mathbf{p} = [p_1, p_2, \ldots, p_n]:模型预测每个类别的概率,其中 pip_i 表示模型预测输入属于第 ii 类的概率。
  • q=[q1,q2,,qn]\mathbf{q} = [q_1, q_2, \ldots, q_n]:实际的类别分布。对于单标签分类(即每个样本只有一个正确的类别),真实分布通常是一个 one-hot 编码的向量。例如,如果某个样本属于第 ii 类,那么 qi=1q_i = 1,其余 qj=0q_j = 0jij \neq i)。

信息熵

信息熵(Entropy)是用来衡量不确定性或信息量的度量。在概率分布 p\mathbf{p} 中,熵定义为:

H(p)=i=1npilogpiH(\mathbf{p}) = -\sum_{i=1}^n p_i \log p_i

2. 交叉熵的定义

交叉熵度量的是两个概率分布之间的距离。对于两个概率分布 p\mathbf{p}q\mathbf{q},交叉熵定义为:

H(q,p)=i=1nqilogpiH(\mathbf{q}, \mathbf{p}) = -\sum_{i=1}^n q_i \log p_i

在分类问题中,如果我们用 one-hot 编码的真实标签分布 q\mathbf{q},则交叉熵可以简化为:

H(q,p)=logpyH(\mathbf{q}, \mathbf{p}) = -\log p_{y}

其中 yy 是真实类别的索引,因为在 one-hot 编码中,qy=1q_y = 1qi=0q_i = 0iyi \neq y)。

3. 交叉熵损失函数

在机器学习中,我们通常将交叉熵用作损失函数来训练分类模型。对于一个包含 mm 个样本的数据集,交叉熵损失函数定义为:

Loss=1mj=1mlogpyj\text{Loss} = -\frac{1}{m} \sum_{j=1}^m \log p_{y_j}

其中 pyjp_{y_j} 表示模型对第 jj 个样本正确类别 yjy_j 的预测概率。

4. 理解交叉熵

交叉熵在分类问题中是一个非常有效的损失函数,它通过量化模型预测分布与真实分布之间的差异来指导模型的优化和训练。交叉熵损失函数有以下几个重要的性质:

  • 非负性:交叉熵总是大于等于零,这是因为概率 pip_i[0,1][0, 1] 范围内,且 logpi\log p_i 是非正的。
  • 完美匹配时最小:当模型的预测概率 pp 与真实分布 qq 完全匹配时,交叉熵达到最小值零。
  • 惩罚错误:交叉熵对预测错误惩罚较大,尤其是当模型对错误类别预测的概率较高时。

5. 举例说明

假设有一个三分类问题,真实标签为类别 2,模型的预测概率分布如下:

p=[0.2,0.7,0.1]\mathbf{p} = [0.2, 0.7, 0.1]

真实的 one-hot 编码标签为:

q=[0,1,0]\mathbf{q} = [0, 1, 0]

交叉熵计算为:

H(q,p)=i=13qilogpi=(0log0.2+1log0.7+0log0.1)=log0.70.357H(\mathbf{q}, \mathbf{p}) = -\sum_{i=1}^3 q_i \log p_i = - (0 \cdot \log 0.2 + 1 \cdot \log 0.7 + 0 \cdot \log 0.1) = -\log 0.7 \approx 0.357