栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

python实现共轭梯度算法

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

python实现共轭梯度算法

西安交通大学《数值分析》第三章课后题3.2

import numpy as np
import math

def generate_matrix(n):
    # 使用对角矩阵相加得到三对角矩阵A
    array_a = np.diag([-2] * n)
    array = np.diag([1] * (n-1))
    a = np.zeros((n-1))
    b = np.zeros(n)
    array_b = np.insert(array, 0, values=a, axis=0)# 添加行
    array_b = np.insert(array_b, (n-1), values=b, axis=1)# 添加列
    array_c = np.insert(array, (n-1), values=a, axis=0)
    array_c = np.insert(array_c, 0, values=b, axis=1)
    matrix_A = array_a + array_b + array_c

    return matrix_A

def Norm(r):
    sum = 0
    for i in r:
        sum += i*i
    norm_r = '%.15f' % math.sqrt(sum)
    print('norm',norm_r)
    return norm_r

def Conjugate_Gradient(ri,di,A,xi,b):
    print('---------------------------------------------------------')
    global beta
    beta = (-1) * (np.dot(ri.T,np.dot(A,di)) / np.dot(di.T,np.dot(A,di)))
    print('beta',beta)
    print('d',di)
    print('r',ri)
    global d
    d = ri + beta*di
    print('d',d)
    global a
    a = np.dot(ri.T,d) / np.dot(d.T,np.dot(A,d))
    print('a',a)
    global x
    x = xi + a*d
    print('x',x)
    global r
    r = b - np.dot(A,x)
    print('r',r)
    # 计算r的2-范数
    return Norm(r)


if __name__ == '__main__':
    n = 200
    # 生成系数矩阵A
    A = generate_matrix(n)
    print('A',A)
    # 生成矩阵b
    b = np.zeros(n)
    b[0] = -1
    b[(n-1)] = -1
    print('b',b)
    # 设置初始值
    global x
    x = np.zeros(n)
    # 第一次迭代
    global r
    r = b - np.dot(A,x)
    print('r',r)
    global d
    d = r
    global a
    a = np.dot(r.T,d) / np.dot(d.T,np.dot(A,d))
    print('a',a)
    x = x + a*d
    print('x',x)
    r = b - np.dot(A,x)
    print('r',r)
    global norm_r
    norm_r = Norm(r)
    # 设置误差上限alpha
    alpha = 1e-10
    count = 1 # 迭代次数
    while float(norm_r) >= alpha:
        # 第2~N次迭代
        count += 1
        norm_r = Conjugate_Gradient(r,d,A,x,b)
        print(norm_r)
        print('r',r)
        print('d',d)
        print('x',x)
    print('x*',x)
    print('count',count)

结果输出:答案为全是1

A [[-2  1  0 ...  0  0  0]
 [ 1 -2  1 ...  0  0  0]
 [ 0  1 -2 ...  0  0  0]
 ...
 [ 0  0  0 ... -2  1  0]
 [ 0  0  0 ...  1 -2  1]
 [ 0  0  0 ...  0  1 -2]]
b [-1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  3.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  4.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  5.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  6.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  7.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  8.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  9.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  10. -1.]
  
无数次循环后
norm 0.000000000006482
x* [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.]
count 100 
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/286341.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号