栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

文本数据挖掘实验:文本分类

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

文本数据挖掘实验:文本分类

文本数据挖掘:实现文本分类
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author: 廖文龙
# datetime: 2021/10/20 13:06 
# ide: PyCharm
# Copyright © 2021 WellonLeo.All rights reserved.
import os

import torch
import torch.nn as nn
# Project-local helpers: dataset splitting and text vectorization.
from partition import partition
from vectorize import vectorize
import numpy as np
# Workaround for the duplicate-OpenMP-runtime abort ("OMP: Error #15")
# that commonly occurs when torch and numpy/MKL both load libiomp
# -- presumably added for a Windows/Anaconda environment; confirm still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
#firstly implements softmax function
def softmax(x):
    """Numerically stable softmax.

    Accepts a 1-D vector (normalized as a whole) or a 2-D matrix
    (normalized row-wise, one distribution per row).

    Args:
        x: array-like of scores.

    Returns:
        A new float ndarray of the same shape whose entries are
        non-negative and sum to 1 (per row for 2-D input).

    Note:
        The original version updated ``x`` in place, silently mutating the
        caller's array and failing on integer input; this version works on
        a float copy and leaves the argument untouched.
    """
    x = np.array(x, dtype=float)  # always a fresh float copy
    if x.ndim > 1:
        # Matrix: subtract each row's max before exponentiating to avoid
        # overflow, then normalize each row.
        x -= np.max(x, axis=1, keepdims=True)
        np.exp(x, out=x)
        x /= np.sum(x, axis=1, keepdims=True)
    else:
        # Vector: same stabilization on the whole array.
        x -= np.max(x)
        np.exp(x, out=x)
        x /= np.sum(x)
    return x

#--------------------------------------------------divide-------------------------------------------------
def divide(srcFilePath, trainingData, testData):
    """Split a source data file into a shuffled 70/30 train/test pair.

    Args:
        srcFilePath: path of the input file, one instance per line.
        trainingData: output path receiving the 70% training split.
        testData: output path receiving the 30% test split.
    """
    # Context managers guarantee the handles are flushed and closed
    # (the original leaked all three file objects).
    with open(srcFilePath, 'r') as src:
        lines = src.readlines()
    # partition() is a project helper; it shuffles and splits the
    # instances according to the given proportions.
    parts = partition(instances=lines, proportion=[0.7, 0.3], shuffle=True)
    with open(trainingData, 'w') as train_out:
        train_out.writelines(parts[0])
    with open(testData, 'w') as test_out:
        test_out.writelines(parts[1])

#--------------------------------------------------vectorize-------------------------------------------------
def vectorize_dataset(trainingData, vectorizedTrainingData, dictPath):
    """Convert a raw training file into its vectorized form on disk.

    Thin wrapper over the project-local vectorize() helper.

    Args:
        trainingData: path of the raw training file.
        vectorizedTrainingData: output path for the vectorized file.
        dictPath: path of the word-index dictionary file.

    NOTE(review): the dictionary path is passed as vectorize()'s second
    argument and the output path third -- confirm against vectorize's
    signature, which is not visible here.
    """
    vectorize(trainingData,dictPath,vectorizedTrainingData)

#--------------------------------------------------train-------------------------------------------------

def collect_data(vectorizedTrainingData, words_dict_length):
    """Load a vectorized data file into label indices and a 0/1 feature matrix.

    Each line of the file is ``label<TAB>idx idx idx ...``; every listed
    index marks a word present in that sample, so the sample's feature
    vector gets a 1 at that position.

    Args:
        vectorizedTrainingData: path of the tab-separated vectorized file.
        words_dict_length: dictionary size, i.e. the feature dimension.

    Returns:
        (labels, features): ``labels`` is a list of int class indices and
        ``features`` a (num_samples, words_dict_length) ndarray of 0/1.
    """
    # Close the file deterministically (the original leaked the handle).
    with open(vectorizedTrainingData, 'r') as f:
        lines = f.readlines()

    y_all = []
    x_all = []
    for line in lines:
        # Bug fix: the original split on the literal letter 't' instead of
        # the tab separator, mangling both the label and the index list
        # (most likely a backslash lost in transcription).
        label, _, indices = line.partition('\t')
        y_all.append(label)
        vec = np.zeros(words_dict_length)
        for i in indices.strip().split(' '):
            vec[int(i)] = 1
        x_all.append(vec)

    # Sort the label set so the class -> index mapping is deterministic;
    # iterating a raw set gives an arbitrary, run-dependent order.
    cats_list = sorted(set(y_all))
    print(len(cats_list), cats_list)
    y_all_hated = [cats_list.index(label) for label in y_all]
    return y_all_hated, np.array(x_all)

def gen_training_data(y_all_hated, x_all, iteration, batch_size):
    """Return the labels and features for one mini-batch.

    Args:
        y_all_hated: full sequence of integer class labels.
        x_all: full sequence/array of feature vectors, parallel to labels.
        iteration: zero-based batch number.
        batch_size: number of samples per batch.

    Returns:
        (batch_labels, batch_features) -- the slice for this iteration.
    """
    start = iteration * batch_size
    stop = start + batch_size
    return y_all_hated[start:stop], x_all[start:stop]

def forward(W, X, Y, ifCalcloss, ifBackprop, LR):
    """One forward (and optionally backward) step of a linear softmax classifier.

    Args:
        W: weight matrix, shape (num_classes, num_features).
        X: mini-batch of inputs, shape (batch, num_features).
        Y: integer class labels for the batch (length == batch size).
        ifCalcloss: if True, print the summed cross-entropy loss.
        ifBackprop: if True, take one gradient-descent step on W.
        LR: learning rate for the update.

    Returns:
        W, updated when ifBackprop is True, otherwise unchanged.
    """
    # Scores -> probabilities. The stable softmax is computed inline:
    # subtracting each row's max prevents exp() overflow.
    scores = X.dot(W.transpose())
    scores = scores - np.max(scores, axis=1, keepdims=True)
    exp_scores = np.exp(scores)
    Y_pred = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

    # Cross-entropy loss over the batch: -sum_m log p(true class of m).
    if ifCalcloss:
        loss = 0.0
        for i in range(len(Y)):
            loss -= np.log(Y_pred[i, int(Y[i])])
        print('Loss', loss)

    if ifBackprop:
        # Standard softmax cross-entropy gradient:
        #   dL/dW = (P - onehot(Y))^T X / batch_size.
        # Bug fix: the original built the indicator from argmax(Y_pred)
        # instead of the true label Y[m], and omitted the X[m,i] factor on
        # the indicator term, so it did not descend the loss.
        delta = Y_pred.copy()
        for m in range(len(Y)):
            delta[m, int(Y[m])] -= 1.0
        grads = delta.transpose().dot(X) / len(Y)
        W = W - LR * grads

    return W


#--------------------------------------------------test-------------------------------------------------
#waiting to be added


#--------------------------------------------------evaluate-------------------------------------------------
#waiting to be added

# ---------------------------------------------------------------------------
# Script entry: load the dictionary and training data, then run 500 SGD
# iterations of the softmax classifier.
# NOTE(review): this runs at import time; an `if __name__ == "__main__":`
# guard would make the module importable without side effects.
# ---------------------------------------------------------------------------
dictPath='indexed_text_mining_data.txt'
trainingData='training.txt'
testData='testData.txt'
srcFilePath='text_mining_data.txt'
# Dictionary size = number of lines in the index file (one word per line).
# NOTE(review): file handles f/f2 below are never closed -- use `with`.
f=open(dictPath,'r')
lines=f.readlines()
words_dict_length=len(lines)

# Count the distinct category labels in the training file.
# NOTE(review): split('t') splits on the letter 't'; this looks like a typo
# for split('\t') -- confirm the file is tab-separated.
f2=open(trainingData,'r')
lines=f2.readlines()
cats_list=[item.split('t')[0] for item in lines]
cats_num= len(set(cats_list))
print(cats_num,words_dict_length)
# Full list of category names; not referenced below (kept for reference).
all_cats=['科教文体广新', '城市管理', '城乡住房', '交通管理', '环境保护', '国土与拆迁安置', '民政', '市场监督', '公安政法', '劳动保障']
# NOTE(review): 22962 and 10 are hard-coded -- presumably equal to
# words_dict_length and cats_num computed above; use the variables instead.
y_all_hated,x_all=collect_data('dest.txt',22962)
W=np.random.randn(10,22962)
# 500 mini-batches of 10 samples each, learning rate 0.001.
for i in range(500):
    Y,X=gen_training_data(y_all_hated,x_all,i,10)
    W=forward(W,X,Y,True,True,0.001)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/343919.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号