【ML-BTC复现】多标记学习领域算法复现(多标签二叉分类树)

【ML-BTC复现】多标记学习领域算法复现(多标签二叉分类树)
小包子复现代码在https://github.com/Warma10032/ML-BTC
任务简介
领域背景
多标签分类(multi-label classification )指的是一个输入的样本可以同时拥有几个类别标签,比如一首歌的标签可以是流行、轻快,一部电影的标签可以是动作、喜剧、搞笑,一本书的标签可以是经典、文学等,这都是多标签分类的情况。多标签分类的一个重要特点是样本的所有标签是不具有排他性的。
处理多标签数据时,复杂的决策空间是面临的最主要的问题。与单标签分类不同,多标签分类需学习更复杂决策边界,现有多标签分类技术存在许多不足。
典型算法
复杂单分类器
通过一个复杂的单一分类器来处理多标签分类任务。分类器需要学习所有标签的决策边界。
- 优点:可以一次性考虑所有标签的关系,适用于特定规模的数据。
- 缺点:训练复杂且耗时,尤其在标签数量或数据量很大的情况下。对模型复杂度提出要求。
集成分类器
使用多个分类器的组合来处理多标签问题。
- 优点:灵活,可扩展,易于实现。可以利用现有的单标签分类器。
- 缺点:往往忽略了标签间的相关性(如集成一对多分类器),或需要显式编码标签间关系(如分类器链),增加了复杂度。
论文的算法
论文中提出的多标签二叉树分类器[1],是一个基于决策树的分层多标签分类器,利用标签相关性,将数据按类标签分割,用于训练多个小分类器。ML-BTC采取的是一种层次聚类的方法,它将训练集根据标签集的汉明距离进行层次二分聚类,从而将大数据集分为多个小的分类任务;同时考虑了标签相关性。根据每次分类的数据特点,模型会选择对应的分裂策略和分类器。用合适的分类器,解决不同问题;可以有效缓解数据分布不平衡的影响,同时早停策略可以避免过拟合
ML-BTC 的设计目标是通过一种层次化、结构化的方法来解决标签相关性、类别不平衡、决策空间复杂性等多标签分类中的核心挑战,同时减少训练和推理的计算成本。这种方法在多种规模的数据集上都表现出了较好的分类效果,尤其在处理类别不平衡和高维数据方面具有显著优势。
算法介绍
多标签二叉树分类器(Multi-Label Binary Tree Classifier, MLBTC)是一种基于树结构的多标签分类算法。该算法通过递归地将标签空间划分为两个子集,构建一个二叉树结构来处理多标签分类问题。在每个内部节点,算法使用基于汉明距离的二分聚类——将当前节点的标签集分为两个子集,选择汉明距离最大的两个标签作为聚类中心点,其余标签根据到这两个中心点的汉明距离进行分配。
在训练过程中,对于每个内部节点,算法会训练一个二元分类器来预测样本应该被分到左子树还是右子树。这种分层的方式不仅考虑了标签之间的相关性,还通过树结构降低了问题的复杂度。对于每个叶子节点,会根据叶子节点中的标签树,选择是否再训练一个分类器和选择训练何种分类器。
在预测阶段,测试样本从根节点开始,根据每个内部节点的分类器决策结果,逐步向下传递直到达到叶节点,在叶子节点中获得完整的标签预测结果。
该算法的主要优势在于:
- 通过树结构将复杂的多标签分类问题分解为一系列更简单的二分类问题
- 考虑了标签之间的相关性,特别是通过汉明距离来度量标签的相似度,用标签的相似度来进行分裂。
- 模型会选择对应的分裂策略和分类器。用合适的分类器,可以有效缓解数据分布不平衡的影响。
- 模型的早停策略不强求分裂叶子节点直到其中只剩一个类别,这种方法避免模型过于复杂,减轻了过拟合风险,降低了计算量。
其中的H,C分别指的是节点中的多标签熵和样本基数,
分别表示的是节点中样本标签的多样性和样本个数。
具体解释训练流程:
H>H阈值 且 C>C阈值
子集之间大小不平衡:使用 k-NN 分类器处理此节点。
子集之间大小相对平衡:使用 SVM 分类器处理此节点。
H
C阈值 在该节点训练一个多层感知机(MLP)分类器
MLP 适合处理相对集中的样本,能够有效学习不同类别之间的决策边界。
C<C阈值
在该节点使用 ML-kNN 分类器
ML-kNN 能够在小样本条件下有效分类,并保留多标签的决策能力。
H=0
表示该节点中仅有一个标签,将它分配给抵达的节点即可。
具体解释推理流程:
根节点
未知样本从根节点出发,使用根节点中训练好的二分类器被分类到其中一个子节点中。
中间节点
每个中间节点包含一个二分类器,样本经过每个中间节点会被分配到左或右子节点。
叶子节点
根据叶子节点中的标签类别或分类器,确定最终的类别。
算法复现
数据集
本次实验的数据集从Multi-Label Classification Dataset Repository 下载。
数据集中所有属性都已经进行了数值化,所以在这个网址下载的数据集存储大小都很小。
我们选取了论文中相同的几个数据集进行复现,按论文中的大小分类,small,medium,large三类都进行了复现。
Dataset | 数据类型 | 样本条数 | 属性个数 | 标签维度 |
---|---|---|---|---|
Cal500 | Music | 502 | 68 | 174 |
Emotions | Music | 593 | 72 | 6 |
Flags | Image | 194 | 19 | 7 |
Enron | Text | 1702 | 1001 | 53 |
Yeast | Biology | 2417 | 103 | 14 |
Bibtex | Text | 7395 | 1836 | 159 |
Delicious | Text | 16110 | 500 | 983 |
Yelp | Text | 10810 | 671 | 5 |
实验设置
实验评估设置
实验评估策略:5折交叉验证
实验评估指标:
Hamming Loss (HL)
预测标记集与真实标记集的不匹配率
计算每个样本预测标记集与真实标记集的对称差集大小,再除以标记集总数,最后取平均
Subset Accuracy (SA)
完全匹配的准确率
只有预测标记集与真实标记集完全相同才算正确
Macro-F1 (MacF1)
先对每个标记计算F1值,然后取平均,对所有类别的处理权重相同
不考虑样本分布不均衡
Micro-F1 (MicF1)
将所有标记的预测结果混在一起计算整体的F1值
考虑了样本分布,更关注高频标记的表现
F-Measure (FM)
基于样本级别的F1度量
对每个样本计算预测标记集与真实标记集的F1值,然后取平均
Accuracy (Acc)
预测正确的标记占总标记数的比例
计算预测标记集与真实标记集的交集大小除以标记集总数
Geometric Mean (GM)
计算每个标记正类和负类的召回率
计算两个召回率的几何平均
这种GM计算方式特别关注模型在正类和负类上的平衡性能,能够有效处理类别不平衡问题
模型与参数设置
H和C的阈值设置:
1
2self.h_threshold = self.compute_ml_entropy(y) / self.h_threshold_den
self.c_threshold = len(X) / self.c_threshold_den在实际代码中,我们使用的是根节点(初始训练集)的多标签熵和样本总数的1/10来作为叶子节点不进行分裂的阈值。
复现代码说明
树是怎么递归构建的(build_tree函数)
通过对节点判断是否进行分裂,如果进行分裂,执行二分数据集代码,并选择合适的分类器。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39def build_tree(self, X, y):
"""构建分类器树"""
node = MLBTCNode()
# 检查停止条件
current_entropy = self.compute_ml_entropy(y)
if current_entropy < self.h_threshold or len(X) < self.c_threshold:
node.is_leaf = True
if len(np.unique(y, axis=0)) == 1:
node.label_set = y[0]
else:
node.ml_classifier = MLPClassifier(
hidden_layer_sizes=(100,), max_iter=500
)
node.ml_classifier.fit(X, y)
return node
# 划分数据
clusters = self.split_data(X, y)
# 如果无法划分,创建叶节点
if len(np.unique(clusters)) == 1:
node.is_leaf = True
node.ml_classifier = MLPClassifier(hidden_layer_sizes=(100,), max_iter=500)
node.ml_classifier.fit(X, y)
return node
# 选择并训练分类器
node.classifier = self.select_classifier(X, clusters)
node.classifier.fit(X, clusters)
# 递归构建子树
left_mask = clusters == 0
right_mask = clusters == 1
node.left = self.build_tree(X[left_mask], y[left_mask])
node.right = self.build_tree(X[right_mask], y[right_mask])
return node树的二分裂中数据聚类算法是怎样实现的(split_data函数)
- 计算所有标签对之间的汉明距离
- 找到汉明距离最大的两个标签作为中心点
- 其他标签根据到两个中心点的距离进行分配
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56def hamming_distance(a, b):
"""计算两个标签向量之间的汉明距离"""
return np.sum(a != b)
def split_data(self, X, y):
"""
基于最大汉明距离的标签空间划分
Args:
X: 特征矩阵
y: 标签矩阵 (n_samples, n_labels)
Returns:
clusters: 划分结果,0和1表示两个子集
"""
if len(X) <= 1:
return np.array([0] * len(X))
n_samples = len(y)
# 1. 计算所有标签对之间的汉明距离
distances = np.zeros((n_samples, n_samples))
for i in range(n_samples):
for j in range(i + 1, n_samples):
dist = hamming_distance(y[i], y[j])
distances[i][j] = dist
distances[j][i] = dist
# 2. 找到距离最大的两个标签索引
max_dist = 0
center1_idx = 0
center2_idx = 0
for i in range(n_samples):
for j in range(i + 1, n_samples):
if distances[i][j] > max_dist:
max_dist = distances[i][j]
center1_idx = i
center2_idx = j
# 3. 将其他标签分配到最近的中心点
clusters = np.zeros(n_samples, dtype=int)
clusters[center2_idx] = 1 # 第二个中心点标记为1
# 对每个样本,计算到两个中心点的距离,分配到较近的中心点
for i in range(n_samples):
if i != center1_idx and i != center2_idx: # 跳过中心点
dist_to_center1 = hamming_distance(y[i], y[center1_idx])
dist_to_center2 = hamming_distance(y[i], y[center2_idx])
clusters[i] = 0 if dist_to_center1 <= dist_to_center2 else 1
# 打印一些信息以便调试
print(f"Selected centers: {center1_idx} and {center2_idx}")
print(f"Maximum hamming distance between centers: {max_dist}")
print(f"Cluster sizes: {np.sum(clusters == 0)} and {np.sum(clusters == 1)}")
return clusters如何选择分类器的(select_classifier函数)
根据聚类形成的两个子集决定,如果两个子集之间不平衡或者样本个数较少,就使用KNN,反之使用SVM。
1
2
3
4
5
6
7
8
9
10
11
12def select_classifier(self, X, y):
"""根据数据特征选择合适的分类器"""
n_samples = len(X)
class_counts = np.sum(y, axis=0)
# 检查类别不平衡情况
imbalance_ratio = np.max(class_counts) / np.min(class_counts)
if imbalance_ratio > 10 or n_samples < 50:
return KNeighborsClassifier(n_neighbors=1)
else:
return SVC(kernel="rbf", probability=True)
复现结果与分析
复现结果
使用我们复现的结果与论文在同数据集上进行对比,以下是对比的结果:
- Cal500数据集(small):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.1931 | 0.2052 | 0.1343 | 0.2458 | 0.3639 | 0.2241 | 0.5089 |
复现 | 0.1740 | 0.0000 | 0.1383 | 0.3533 | 0.3496 | 0.8260 | 0.1968 |
- Emotions数据集(small):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.2126 | 0.6384 | 0.6352 | 0.6601 | 0.6704 | 0.5518 | 0.7018 |
复现 | 0.2029 | 0.3322 | 0.6453 | 0.6727 | 0.6455 | 0.7971 | 0.7258 |
- Flags数据集(small):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.3438 | 0.5922 | 0.5085 | 0.6114 | 0.6697 | 0.5194 | 0.6115 |
复现 | 0.3197 | 0.1441 | 0.5770 | 0.6661 | 0.6396 | 0.6803 | 0.5232 |
- Enron数据集(medium):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.0889 | 0.3622 | 0.1679 | 0.3724 | 0.4557 | 0.3119 | 0.5193 |
复现 | 0.0621 | 0.3854 | 0.6090 | 0.6082 | 0.2159 | 0.9379 | 0.7811 |
- Yeast数据集(medium):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.2342 | 0.5275 | 0.3443 | 0.5605 | 0.6339 | 0.4319 | 0.6425 |
复现 | 0.2098 | 0.2362 | 0.4214 | 0.6482 | 0.6314 | 0.7902 | 0.4148 |
- Bibtex数据集(large):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.0299 | 0.2919 | 0.1586 | 0.2597 | 0.3014 | 0.2391 | 0.3911 |
复现 | 0.0181 | 0.1596 | 0.1417 | 0.2778 | 0.3004 | 0.9819 | 0.2440 |
- Delicious数据集(large):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.0381 | 0.2152 | 0.0824 | 0.2488 | 0.2439 | 0.1609 | 0.4012 |
复现 | 0.0242 | 0.0063 | 0.1243 | 0.2818 | 0.2738 | 0.9758 | 0.2277 |
- Yelp数据集(large):
HL↓ | SA↑ | MacF1↑ | MicF1↑ | FM↑ | Acc↑ | GM↑ | |
---|---|---|---|---|---|---|---|
论文 | 0.2167 | 0.6436 | 0.5944 | 0.6602 | 0.7104 | 0.5619 | 0.6872 |
复现 | 0.0901 | 0.6760 | 0.6802 | 0.7173 | 0.4859 | 0.9099 | 0.7470 |
结果分析
在以上评估指标中论文和复现进行对比:
- 可以看出HL和F1普遍比原论文更好。
- 在某些数据集的SA会明显低于原论文,例如Cal500,Delicious数据集,这些数据集都是典型的属性维度大于标签维度,如Cal500的属性维度为68,标签维度却有174,要实现SA即完全预测准确,我觉得这是很难的一件事,但是原始论文还是取得了0.2的得分,不清楚原始论文是否运用了额外的属性来进行预测,因为Cal500数据集其实是有多种属性(MFCC, MEL等),我们使用的是基础的属性。
- Acc的验证可能有问题(按理说Acc应该和HL加和为1,不清楚原始论文是如何计算Acc的)。
- GM较原论文波动较大。
参考文献
<a id="reference1">
[1]</a>
A. Law and A. Ghosh, “Multi-Label Classification Using Binary Tree of Classifiers,” in IEEE Transactions on Emerging Topics in Computational Intelligence, vol. 6, no. 3, pp. 677-689, June 2022, doi: 10.1109/TETCI.2021.3075717.