手写持向量机(SVM)实现

佚名 / 2024-10-19 / 原文

下面是一个简单的支持向量机(SVM)实现,用于解决线性可分问题。

这个实现不使用任何机器学习库,只使用NumPy进行矩阵运算。

请注意,这个实现主要用于教学目的,实际应用中推荐使用成熟的库,如scikit-learn。

import numpy as np

class SVM:
    def __init__(self, learning_rate=0.001, lambda_param=0.01, n_iterations=1000):
        """
        初始化SVM分类器。
        
        参数:
        learning_rate (float): 学习率。
        lambda_param (float): 正则化参数。
        n_iterations (int): 迭代次数。
        """
        self.lr = learning_rate
        self.lambda_param = lambda_param
        self.n_iterations = n_iterations
        self.w = None
        self.b = None

    def fit(self, X, y):
        """
        训练SVM模型。
        
        参数:
        X (numpy.array): 特征矩阵。
        y (numpy.array): 标签向量。
        """
        n_samples, n_features = X.shape
        self.w = np.zeros(n_features)
        self.b = 0

        # 转换标签为1和-1
        y_ = np.where(y <= 0, -1, 1)

        for _ in range(self.n_iterations):
            for idx, x_i in enumerate(X):
                # 计算条件,检查是否满足SVM的间隔条件
                condition = y_[idx] * (np.dot(x_i, self.w) - self.b) >= 1
                if condition:
                    # 如果满足条件,执行梯度下降更新权重(带正则化)
                    self.w -= self.lr * (2 * self.lambda_param * self.w)
                else:
                    # 如果不满足条件,执行梯度下降更新权重和偏置项
                    self.w -= self.lr * (2 * self.lambda_param * self.w - np.dot(x_i, y_[idx]))
                    self.b -= self.lr * y_[idx]

    def predict(self, X):
        """
        使用训练好的SVM模型进行预测。
        
        参数:
        X (numpy.array): 特征矩阵。
        
        返回:
        predictions (numpy.array): 预测标签。
        """
        linear_output = np.dot(X, self.w) - self.b
        return np.sign(linear_output)

# 生成一些合成数据
X = np.array([[5, 5], [3, 5], [4, 3], [2, 3], [5, 3], [5, 4], [3, 5], [4, 4], [3, 3], [4, 2], [3, 2], [2, 4]])
y = np.array([-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1])

# 创建SVM模型实例
svm = SVM(learning_rate=0.001, lambda_param=0.01, n_iterations=1000)

# 训练模型
svm.fit(X, y)

# 进行预测
predictions = svm.predict(X)

# 打印预测值和真实值
print("Predictions:", predictions)
print("Real values:", y)

代码解释:

  1. 初始化:在__init__方法中,我们初始化了学习率、正则化参数、迭代次数、权重向量w和偏置项b

  2. 训练:在fit方法中,我们首先将标签转换为-1和1,然后进行迭代,对每个样本进行梯度下降。如果样本满足间隔条件(即y_i * (w^T x_i + b) >= 1),我们只更新权重向量w,否则我们同时更新权重向量w和偏置项b

  3. 预测:在predict方法中,我们计算线性输出,然后使用np.sign函数将输出转换为预测标签。

请注意,这个简单的SVM实现没有包含一些高级特性,如核技巧、软间隔或更复杂的优化算法。对于非线性可分的数据集或需要更高性能的应用场景,建议使用成熟的库,如scikit-learn中的SVM实现。