栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PLA-Pocket算法实现实现

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PLA-Pocket算法实现实现

原理参考:PLA算法和Pocket算法原理及Python实现

python实现:

import numpy as np
import random
 
#感知机模型
class Pocket:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.w = np.zeros(x.shape[1])  #初始化权重,w1,w2均为0
        self.best_w = np.zeros(x.shape[1])  #最好
        self.b = 0
        self.best_b = 0
 
    def sign(self, w, b, x):
        y = np.dot(x, w) + b
        return int(y)

    def classify(self,w,b):
        mistakes = []
        for i in range(self.x.shape[0]):
            tmpY = self.sign(w, b, self.x[i, :])
            if tmpY * self.y[i] <= 0:  # 如果是误分类
                mistakes.append(i)
        return mistakes

    def update(self, label_i, data_i):
        tmp = label_i * data_i
        tmpw = tmp + self.w
        tmpb = self.b + label_i
        if(len(self.classify(self.best_w,self.best_b))>=(len(self.classify(tmpw,tmpb)))):
            self.best_w = tmp + self.w
            self.best_b = self.b + label_i
        self.w = tmp + self.w
        self.b = self.b + label_i
 
    def train(self,max_iters):
        iters = 0
        isFind = False
        while not isFind:
            mistakes = self.classify(self.w,self.b)
            if(len(mistakes) == 0):
                return self.best_w, self.best_b
            n = mistakes[random.randint(0,len(mistakes)-1)]
            self.update(self.y[n], self.x[n, :])
            iters += 1
            if iters == max_iters:
                isFind = True
        return self.best_w, self.best_b


if __name__ == '__main__':
    x = np.array([[3,-3],[4,-3],[1,1],[1,2]])
    y = np.array([-1, -1, 1, 1])
    myPocket_PLA = Pocket(x, y)
    w, b = myPocket_PLA.train(50)
    print('最终训练得到的w和b为:', w, b)

C++实现:

#include 
#include 
#include 

//创建数据集
void  createdata(std::vector>& x, std::vector& y)
{
	x = { { 3, -3 },{ 4, -3 },{ 1, 1 },{ 1, 2 } };
	y = { -1, -1, 1, 1 };
}

//感知机模型
class Pocket
{
public:
	Pocket(std::vector> x, std::vector y)
	{
		m_x = x;
		m_y = y;
		m_w.resize(m_x[0].size(), 0);
		m_best_w.resize(m_x[0].size(), 0);
		m_b = 0;
		m_best_b = 0;
	}

	float sign(std::vector w, float b, std::vector x)
	{
		float y = b;
		for (size_t i = 0; i < w.size(); i++)
		{
			y += w[i] * x[i];
		}
		return y;
	}

	std::vector classify(std::vector w, float b)
	{
		std::vector mistakes;
		for (size_t i = 0; i < m_x.size(); i++)
		{
			float tmp_y = sign(w, b, m_x[i]);
			if (tmp_y*m_y[i] <= 0) //如果误分类
			{
				mistakes.push_back(i);
			}
		}
		return mistakes;
	}

	void update(float label_i, std::vector data_i)
	{
		std::vector tmp_w(m_w.size());
		for (size_t i = 0; i < m_w.size(); i++)
		{
			tmp_w[i] += label_i * data_i[i] + m_w[i];
		}
		float tmp_b = label_i + m_b;
		if (classify(m_best_w, m_best_b).size() >= classify(tmp_w, tmp_b).size())
		{
			for (size_t i = 0; i < m_w.size(); i++)
			{
				m_best_w[i] += label_i * data_i[i] + m_w[i];
			}
			m_best_b = label_i + m_b;
		}
		for (size_t i = 0; i < m_w.size(); i++)
		{
			m_w[i] = label_i * data_i[i] + m_w[i];
		}
		m_b = label_i + m_b;
	}

	void train(int max_iters)
	{
		int iters = 0;
		bool isFind = false;
		while (!isFind)
		{
			std::vector mistakes = classify(m_w, m_b);
			if (mistakes.size() == 0)
			{
				std::cout << "最终训练得到的w为:";
				for (auto i : m_w)	std::cout << i << " ";
				std::cout << "n最终训练得到的b为:";
				std::cout << m_b << "n";
				break;
			}
			srand((int)time(0));
			int n = mistakes[rand() % (mistakes.size())];
			update(m_y[n], m_x[n]);
			++iters;
			if (iters == max_iters)
			{
				std::cout << "最终训练得到的w为:";
				for (auto i : m_w)	std::cout << i << " ";
				std::cout << "n最终训练得到的b为:";
				std::cout << m_b << "n";
				bool isFind = true;
			}
		}
	}

private:
	std::vector> m_x;
	std::vector m_y;
	std::vector m_w;
	std::vector m_best_w;
	float m_b;
	float m_best_b;
};


int main(int argc, char** argv)
{
	std::vector> x;
	std::vector y;

	createdata(x, y);

	Pocket mypocket = Pocket(x, y);
	mypocket.train(50);

	system("pause");
	return EXIT_SUCCESS;
}
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/846699.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号