栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

SUNDIALS的C++使用例子

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

SUNDIALS的C++使用例子

SUNDIALS的C++使用例子

本例子基于 SUNDIALS 自带的 CVODE 示例 cvAdvDiff_bnd.c 改造，可以作为 SUNDIALS 的 C++ 使用模板。

本例子使用 CMake 编译，已在 Ubuntu 18 下测试通过。

特点
  1. 回调函数是类的静态成员函数。不能直接使用普通（非静态）成员函数，因为 CVODE 接口需要的是普通函数指针，而非静态成员函数必须绑定到某个对象上；因此通过 CVodeSetUserData 把对象指针作为 user_data 传入，再由静态函数转发给对应的成员函数。
  2. 也支持不提供雅可比函数（Jac）：不调用 CVodeSetJacFn 即可，此时 CVODE 会使用内部的差商近似雅可比矩阵。
  3. 起始时间可以任意设定：修改 init_environment_1() 中 T0 的值，并相应修改 main.cpp 中循环的第一个输出时间 T1 即可。
  4. 测试的结果和原始代码一致

## 头文件

#ifndef CV_ADV_DIFF_BND_H_
#define CV_ADV_DIFF_BND_H_

#include 
#include 
#include 

#include               
#include    
#include  
#include  
#include   
#include    

class CvAdvDiffBnd
{

public:
    CvAdvDiffBnd();
    virtual ~CvAdvDiffBnd();

 

public:
    int Init();
    int Run(double tout);
    void Finish();
    void Release();

public:
    
    int init_environment_1();

    
    int define_problem_length_2();

    
    int set_vector_initial_val_3();

    
    int create_vode_obj_4();

    
    int init_cvode_solver_5();

    
    int specify_integration_tolerances_6();

    
    int set_optional_inputs_7();

    
    int create_matrix_ob_8();

    
    int create_linear_solver_obj_9();

    
    int set_linear_solver_optional_inputs_10();

    
    int attach_linear_solver_module_11();

    
    int set_linear_solver_interface_optional_inputs_12();

    
    int specify_rootfinding_problem_13();

    
    int advance_solution_in_time_14();

    //----------------------------------------------------------------
    
    int get_optional_output_15();

    
    int deallocate_mem_16();

    
    int free_solver_mem_17();

    
    int free_linear_solver_and_matrix_mem_18();


public:
    static int f_sub(realtype t, N_Vector u, N_Vector udot, void *user_data);

    static int Jac_sub(realtype t, N_Vector u, N_Vector fu,
        SUNMatrix J, void *user_data,
        N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);


public:
    

    int f_func(realtype t, N_Vector u, N_Vector udot);
    int Jac_func(realtype t, N_Vector u, N_Vector fu,
            SUNMatrix J,
            N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);

public:
    
    void SetIC_func(N_Vector u);
    void PrintHeader_func(realtype reltol, realtype abstol, realtype umax);
    void PrintOutput_func(realtype t, realtype umax, long int nst);
    void PrintFinalStats_func(void *cvode_mem);

    
    int check_retval(void *returnvalue, const char *funcname, int opt);

 
private:
    double XMAX;  // RConST(2.0)    
    double YMAX;  //RConST(1.0)
    int MX;       //    10             
    int MY;       //   5
    int NEQ;      //  MX*MY          
    double ATOL;  //  RConST(1.0e-5) 
    double T0;    //  RConST(0.0)    
    double T1;    //   RConST(0.1)    
    double DTOUT; // RConST(0.1)    
    int NOUT;     //  10             

    double ZERO; // RConST(0.0)
    double HALF; // RConST(0.5)
    double ONE;  //  RConST(1.0)
    double TWO;  // RConST(2.0)
    double FIVE; // RConST(5.0)

private:
    realtype dx, dy, reltol, abstol, t, tout, umax;
    N_Vector u;
    SUNMatrix A;
    SUNLinearSolver LS;
    void *cvode_mem;
    int iout, retval;
    long int nst;

private:
    //user data
    realtype m_dx, m_dy, m_hdcoef, m_hacoef, m_vdcoef;
};

#endif


## cpp 文件
#include "CvAdvDiffBnd.h"

// Accesses grid value (i, j), 1-based, of an array stored with j varying
// fastest: element (j - 1) + (i - 1) * MY. MY must be in scope where the macro
// is used (inside CvAdvDiffBnd members it is the member MY). Every argument is
// parenthesised so expressions such as `c ? a : b` or `p + k` index correctly.
#define IJth(vdata, i, j) ((vdata)[((j) - 1) + ((i) - 1) * MY])


// CVODE right-hand-side callback: recovers the owning object from user_data
// and forwards the call to its f_func().
int CvAdvDiffBnd::f_sub(realtype t, N_Vector u, N_Vector udot, void *user_data)
{
    CvAdvDiffBnd *const self = static_cast<CvAdvDiffBnd *>(user_data);
    const int status = self->f_func(t, u, udot);
    return status;
}

// CVODE band-Jacobian callback: recovers the owning object from user_data and
// forwards the call to its Jac_func().
// A leftover debug printf used to fire on every Jacobian evaluation (with a
// broken newline), polluting stdout; it has been removed.
int CvAdvDiffBnd::Jac_sub(realtype t, N_Vector u, N_Vector fu,
        SUNMatrix J, void *user_data,
        N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
{
    CvAdvDiffBnd *const self = static_cast<CvAdvDiffBnd *>(user_data);
    return self->Jac_func(t, u, fu, J, tmp1, tmp2, tmp3);
}




// Starts with every SUNDIALS handle null, so the destructor's Release() call
// does not act on indeterminate pointers when Init() was never called or
// failed before allocating everything.
CvAdvDiffBnd::CvAdvDiffBnd()
    : u(NULL), A(NULL), LS(NULL), cvode_mem(NULL), iout(0), retval(0), nst(0)
{
}

// Destruction releases whatever SUNDIALS objects the instance still holds.
CvAdvDiffBnd::~CvAdvDiffBnd()
{
    this->Release();
}

int CvAdvDiffBnd::Init()
{

    init_environment_1();
    
      init_environment_1();

    
      define_problem_length_2();

    
      set_vector_initial_val_3();

    
      create_vode_obj_4();

    
      init_cvode_solver_5();

    
      specify_integration_tolerances_6();

    
      set_optional_inputs_7();

    
      create_matrix_ob_8();

    
      create_linear_solver_obj_9();

    
      set_linear_solver_optional_inputs_10();

    
      attach_linear_solver_module_11();

    
      set_linear_solver_interface_optional_inputs_12();

    
      specify_rootfinding_problem_13();

    
      advance_solution_in_time_14();
}

/**
 * Advances the solution to `tout` (CV_NORMAL mode), then prints the reached
 * time, the max-norm of u and the cumulative step count.
 *
 * @return 0 on success, -1 if CVode reports an error.
 */
int CvAdvDiffBnd::Run(double tout)
{
    retval = CVode(cvode_mem, tout, u, &t, CV_NORMAL);
    const bool failed = check_retval(&retval, "CVode", 1) != 0;
    if (failed)
        return -1;

    umax = N_VMaxNorm(u);

    retval = CVodeGetNumSteps(cvode_mem, &nst);
    check_retval(&retval, "CVodeGetNumSteps", 1);

    PrintOutput_func(t, umax, nst);
    return 0;
}

void CvAdvDiffBnd::Finish()
{
    //----------------------------------------------------------------
    
    get_optional_output_15();
    PrintFinalStats_func(cvode_mem); 
}

// Teardown: runs steps 16, 17 and 18 in that order; their return codes are ignored.
void CvAdvDiffBnd::Release()
{
    (void)deallocate_mem_16();
    (void)free_solver_mem_17();
    (void)free_linear_solver_and_matrix_mem_18();
}

/**
 * Step 1: assigns the problem constants and nulls every SUNDIALS handle.
 *
 * The constants are `double` members, so plain double literals are used.
 * The previous `RConST(...)` spelling is not a SUNDIALS macro (the real one is
 * `RCONST`) and did not compile.
 *
 * @return always 0.
 */
int CvAdvDiffBnd::init_environment_1()
{
    XMAX = 2.0;  // x-extent of the domain
    YMAX = 1.0;  // y-extent of the domain
    MX = 10;     // interior grid points in x
    MY = 5;      // interior grid points in y

    ATOL = 1.0e-5; // absolute tolerance
    T0 = 0.5;      // initial time passed to CVodeInit
    T1 = 0.1;      // kept from the C example; main() uses its own output schedule
    DTOUT = 0.1;
    NOUT = 10;

    ZERO = 0.0;
    HALF = 0.5;
    ONE = 1.0;
    TWO = 2.0;
    FIVE = 5.0;

    u = NULL;
    A = NULL;
    LS = NULL;
    cvode_mem = NULL;
    return 0;
}


// Step 2: one unknown per interior grid point.
int CvAdvDiffBnd::define_problem_length_2()
{
    const int points = MX * MY;
    NEQ = points;
    return 0;
}


/**
 * Step 3: allocates the serial state vector u, derives the grid spacings and
 * the stencil coefficients read by f_func / Jac_func, and fills u with the
 * initial condition.
 *
 * @return 0 on success, 1 if N_VNew_Serial fails.
 */
int CvAdvDiffBnd::set_vector_initial_val_3()
{
    u = N_VNew_Serial(NEQ);
    if (check_retval(static_cast<void *>(u), "N_VNew_Serial", 0) != 0)
    {
        return 1;
    }

    // Spacing of the MX x MY interior grid.
    m_dx = XMAX / (MX + 1);
    dx = m_dx;
    m_dy = YMAX / (MY + 1);
    dy = m_dy;

    // Coefficients consumed by the right-hand side and the Jacobian.
    m_hdcoef = ONE / (dx * dx);
    m_hacoef = HALF / (TWO * dx);
    m_vdcoef = ONE / (dy * dy);

    SetIC_func(u);
    return 0;
}


// Step 4: creates the CVODE memory block with the CV_BDF method.
// Returns 0 on success, 1 if CVodeCreate returns NULL.
int CvAdvDiffBnd::create_vode_obj_4()
{
    cvode_mem = CVodeCreate(CV_BDF);
    return check_retval(cvode_mem, "CVodeCreate", 0) ? 1 : 0;
}


// Step 5: initialises CVODE with the static right-hand-side trampoline f_sub,
// the initial time T0 and the initial state u.
// Returns 0 on success, 1 if CVodeInit reports an error.
int CvAdvDiffBnd::init_cvode_solver_5()
{
    retval = CVodeInit(cvode_mem, &CvAdvDiffBnd::f_sub, T0, u);
    return check_retval(&retval, "CVodeInit", 1) ? 1 : 0;
}


// Step 6: scalar tolerances with reltol = 0 and abstol = ATOL.
// Returns 0 on success, 1 if CVodeSStolerances reports an error.
int CvAdvDiffBnd::specify_integration_tolerances_6()
{
    abstol = ATOL;
    reltol = ZERO;

    retval = CVodeSStolerances(cvode_mem, reltol, abstol);
    const int failed = check_retval(&retval, "CVodeSStolerances", 1);
    return failed ? 1 : 0;
}


// Step 7: hook for CVodeSet* optional inputs; none are configured.
int CvAdvDiffBnd::set_optional_inputs_7()
{
    const int kSuccess = 0;
    return kSuccess;
}


/**
 * Step 8: registers `this` as CVODE user data, which lets the static callbacks
 * reach the instance. It then creates the NEQ x NEQ band matrix with both
 * half-bandwidths equal to MY.
 *
 * @return 0 on success, 1 if either call fails.
 */
int CvAdvDiffBnd::create_matrix_ob_8()
{
    retval = CVodeSetUserData(cvode_mem, this);
    if (check_retval(&retval, "CVodeSetUserData", 1))
        return 1;

    A = SUNBandMatrix(NEQ, MY, MY);
    return check_retval(static_cast<void *>(A), "SUNBandMatrix", 0) ? 1 : 0;
}


// Step 9: creates the band linear solver for template vector u and matrix A.
// Returns 0 on success, 1 if SUNLinSol_Band returns NULL.
int CvAdvDiffBnd::create_linear_solver_obj_9()
{
    LS = SUNLinSol_Band(u, A);
    const bool ok = check_retval(static_cast<void *>(LS), "SUNLinSol_Band", 0) == 0;
    return ok ? 0 : 1;
}


// Step 10: hook for linear-solver optional inputs; none are configured.
int CvAdvDiffBnd::set_linear_solver_optional_inputs_10()
{
    const int kSuccess = 0;
    return kSuccess;
}


/**
 * Step 11: attaches LS and A to CVODE, then prints the problem header together
 * with the max-norm of the initial state.
 *
 * @return 0 on success, 1 if CVodeSetLinearSolver fails (no header printed).
 */
int CvAdvDiffBnd::attach_linear_solver_module_11()
{
    retval = CVodeSetLinearSolver(cvode_mem, LS, A);
    if (check_retval(&retval, "CVodeSetLinearSolver", 1) != 0)
    {
        return 1;
    }

    umax = N_VMaxNorm(u);
    PrintHeader_func(reltol, abstol, umax);
    return 0;
}


// Step 12: supplies the analytic band Jacobian through the static trampoline
// Jac_sub. To run without a user Jacobian, skip this call (per the CVODE docs
// the linear solver interface then uses a difference-quotient approximation —
// confirm for the SUNDIALS version in use).
// Returns 0 on success, 1 if CVodeSetJacFn reports an error.
int CvAdvDiffBnd::set_linear_solver_interface_optional_inputs_12()
{
    retval = CVodeSetJacFn(cvode_mem, &CvAdvDiffBnd::Jac_sub);
    return check_retval(&retval, "CVodeSetJacFn", 1) ? 1 : 0;
}


// Step 13: no root-finding functions are registered for this problem.
int CvAdvDiffBnd::specify_rootfinding_problem_13()
{
    const int kSuccess = 0;
    return kSuccess;
}


// Step 14: intentionally empty; time stepping is driven externally via Run().
int CvAdvDiffBnd::advance_solution_in_time_14()
{
    const int kSuccess = 0;
    return kSuccess;
}

//----------------------------------------------------------------

// Step 15: hook for extra optional outputs; the statistics are printed by
// PrintFinalStats_func instead.
int CvAdvDiffBnd::get_optional_output_15()
{
    const int kSuccess = 0;
    return kSuccess;
}


/**
 * Step 16: destroys the state vector and clears the handle.
 *
 * Clearing matters because Release() is public and the destructor calls it
 * too; without it an explicit Release() followed by destruction would destroy
 * the same N_Vector twice.
 *
 * @return always 0.
 */
int CvAdvDiffBnd::deallocate_mem_16()
{
    if (u != NULL)
    {
        N_VDestroy(u);
        u = NULL;
    }
    return 0;
}


/**
 * Step 17: frees the CVODE memory, the linear solver and the band matrix, and
 * clears each handle so a repeated Release() does not free them twice.
 *
 * @return always 0.
 */
int CvAdvDiffBnd::free_solver_mem_17()
{
    if (cvode_mem != NULL)
    {
        CVodeFree(&cvode_mem);
        cvode_mem = NULL;
    }
    if (LS != NULL)
    {
        SUNLinSolFree(LS);
        LS = NULL;
    }
    if (A != NULL)
    {
        SUNMatDestroy(A);
        A = NULL;
    }
    return 0;
}


// Step 18: nothing left to do; the linear solver and matrix are freed in
// free_solver_mem_17().
int CvAdvDiffBnd::free_linear_solver_and_matrix_mem_18()
{
    const int kSuccess = 0;
    return kSuccess;
}





/**
 * Right-hand side of the semi-discretised advection-diffusion equation.
 *
 * For every interior point (i, j), i = 1..MX, j = 1..MY:
 *   udot(i,j) = hdcoef * (u(i-1,j) - 2 u(i,j) + u(i+1,j))
 *             + hacoef * (u(i+1,j) - u(i-1,j))
 *             + vdcoef * (u(i,j+1) - 2 u(i,j) + u(i,j-1)),
 * with neighbours outside the interior grid taken as zero.
 *
 * @param t    current time (not used by the formula)
 * @param u    current state
 * @param udot receives the derivative
 * @return always 0
 */
int CvAdvDiffBnd::f_func(realtype t, N_Vector u, N_Vector udot)
{
    (void)t;

    const realtype *src = N_VGetArrayPointer(u);
    realtype *dst = N_VGetArrayPointer(udot);

    const realtype hordc = m_hdcoef;
    const realtype horac = m_hacoef;
    const realtype verdc = m_vdcoef;

    for (int jy = 1; jy <= MY; ++jy)
    {
        const bool bottomRow = (jy == 1);
        const bool topRow = (jy == MY);

        for (int ix = 1; ix <= MX; ++ix)
        {
            const realtype centre = IJth(src, ix, jy);
            const realtype below = bottomRow ? ZERO : IJth(src, ix, jy - 1);
            const realtype above = topRow ? ZERO : IJth(src, ix, jy + 1);
            const realtype left = (ix == 1) ? ZERO : IJth(src, ix - 1, jy);
            const realtype right = (ix == MX) ? ZERO : IJth(src, ix + 1, jy);

            const realtype hdiff = hordc * (left - TWO * centre + right);
            const realtype hadv = horac * (right - left);
            const realtype vdiff = verdc * (above - TWO * centre + below);
            IJth(dst, ix, jy) = hdiff + hadv + vdiff;
        }
    }

    return 0;
}



/**
 * Fills the band Jacobian of f_func.
 *
 * Column k = (j - 1) + (i - 1) * MY corresponds to grid point (i, j). It holds:
 * - the diagonal -2 (vdcoef + hdcoef);
 * - hdcoef + hacoef at row k - MY when i > 1;
 * - hdcoef - hacoef at row k + MY when i < MX;
 * - vdcoef at rows k - 1 (when j > 1) and k + 1 (when j < MY).
 *
 * The state, f value and work vectors are not read.
 *
 * @return always 0
 */
int CvAdvDiffBnd::Jac_func(realtype t, N_Vector u, N_Vector fu,
                           SUNMatrix J,
                           N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
{
    (void)t;
    (void)u;
    (void)fu;
    (void)tmp1;
    (void)tmp2;
    (void)tmp3;

    const realtype hordc = m_hdcoef;
    const realtype horac = m_hacoef;
    const realtype verdc = m_vdcoef;

    const realtype diagonal = -TWO * (verdc + hordc);
    const realtype prevInX = hordc + horac;
    const realtype nextInX = hordc - horac;

    for (sunindextype jy = 1; jy <= MY; ++jy)
    {
        for (sunindextype ix = 1; ix <= MX; ++ix)
        {
            const sunindextype k = (jy - 1) + (ix - 1) * MY;
            realtype *column = SUNBandMatrix_Column(J, k);

            SM_COLUMN_ELEMENT_B(column, k, k) = diagonal;
            if (ix > 1)
                SM_COLUMN_ELEMENT_B(column, k - MY, k) = prevInX;
            if (ix < MX)
                SM_COLUMN_ELEMENT_B(column, k + MY, k) = nextInX;
            if (jy > 1)
                SM_COLUMN_ELEMENT_B(column, k - 1, k) = verdc;
            if (jy < MY)
                SM_COLUMN_ELEMENT_B(column, k + 1, k) = verdc;
        }
    }

    return 0;
}





/**
 * Writes the initial condition into u.
 *
 * At each interior point x = i*dx, y = j*dy the value is
 *   x (XMAX - x) y (YMAX - y) exp(5 x y).
 */
void CvAdvDiffBnd::SetIC_func(N_Vector u)
{
    realtype *values = N_VGetArrayPointer(u);
    const realtype hx = m_dx;
    const realtype hy = m_dy;

    for (int jy = 1; jy <= MY; ++jy)
    {
        const realtype y = jy * hy;
        for (int ix = 1; ix <= MX; ++ix)
        {
            const realtype x = ix * hx;
            IJth(values, ix, jy) = x * (XMAX - x) * y * (YMAX - y) * SUNRexp(FIVE * x * y);
        }
    }
}



/**
 * Prints the problem description, the tolerances and the max-norm of the
 * initial state at t = T0.
 *
 * The format strings use real "\n" escapes. In extended precision, T0 is a
 * double member but is printed with %Lg, so it is converted to realtype
 * (long double) first.
 */
void CvAdvDiffBnd::PrintHeader_func(realtype reltol, realtype abstol, realtype umax)
{
    printf("\n2-D Advection-Diffusion Equation\n");
    printf("Mesh dimensions = %d X %d\n", MX, MY);
    printf("Total system size = %d\n", NEQ);
#if defined(SUNDIALS_EXTENDED_PRECISION)
    printf("Tolerance parameters: reltol = %Lg   abstol = %Lg\n\n", reltol, abstol);
    printf("At t = %Lg      max.norm(u) =%14.6Le \n", static_cast<realtype>(T0), umax);
#else
    // realtype is float or double here; printf promotes both to double.
    printf("Tolerance parameters: reltol = %g   abstol = %g\n\n", reltol, abstol);
    printf("At t = %g      max.norm(u) =%14.6e \n", T0, umax);
#endif
}



/**
 * Prints one output line: time, max-norm of the solution and step count.
 * The format strings use real "\n" escapes.
 */
void CvAdvDiffBnd::PrintOutput_func(realtype t, realtype umax, long int nst)
{
#if defined(SUNDIALS_EXTENDED_PRECISION)
    printf("At t = %4.2Lf   max.norm(u) =%14.6Le   nst = %4ld\n", t, umax, nst);
#else
    // realtype is float or double here; printf promotes both to double.
    printf("At t = %4.2f   max.norm(u) =%14.6e   nst = %4ld\n", t, umax, nst);
#endif
}



void CvAdvDiffBnd::PrintFinalStats_func(void *cvode_mem)
{
    int retval;
    long int nst, nfe, nsetups, netf, nni, ncfn, nje, nfeLS;

    retval = CVodeGetNumSteps(cvode_mem, &nst);
    check_retval(&retval, "CVodeGetNumSteps", 1);
    retval = CVodeGetNumRhsevals(cvode_mem, &nfe);
    check_retval(&retval, "CVodeGetNumRhsevals", 1);
    retval = CVodeGetNumLinSolvSetups(cvode_mem, &nsetups);
    check_retval(&retval, "CVodeGetNumLinSolvSetups", 1);
    retval = CVodeGetNumErrTestFails(cvode_mem, &netf);
    check_retval(&retval, "CVodeGetNumErrTestFails", 1);
    retval = CVodeGetNumNonlinSolvIters(cvode_mem, &nni);
    check_retval(&retval, "CVodeGetNumNonlinSolvIters", 1);
    retval = CVodeGetNumNonlinSolvConvFails(cvode_mem, &ncfn);
    check_retval(&retval, "CVodeGetNumNonlinSolvConvFails", 1);

    retval = CVodeGetNumJacevals(cvode_mem, &nje);
    check_retval(&retval, "CVodeGetNumJacevals", 1);
    retval = CVodeGetNumLinRhsevals(cvode_mem, &nfeLS);
    check_retval(&retval, "CVodeGetNumLinRhsevals", 1);

    printf("nFinal Statistics:n");
    printf("nst = %-6ld nfe  = %-6ld nsetups = %-6ld nfeLS = %-6ld nje = %ldn",
           nst, nfe, nsetups, nfeLS, nje);
    printf("nni = %-6ld ncfn = %-6ld netf = %ldn",
           nni, ncfn, netf);

    return;
}



/**
 * Reports a failed SUNDIALS call on stderr.
 *
 * @param returnvalue opt 0 or 2: the pointer returned by an allocating call;
 *                    opt 1: address of the int return flag.
 * @param funcname    name of the called function, used in the message.
 * @param opt         0 = SUNDIALS allocation (NULL is an error),
 *                    1 = integer flag (negative is an error),
 *                    2 = memory allocation (NULL is an error).
 * @return 1 if an error was detected, 0 otherwise.
 */
int CvAdvDiffBnd::check_retval(void *returnvalue, const char *funcname, int opt)
{
    if (opt == 0 && returnvalue == NULL)
    {
        fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
                funcname);
        return 1;
    }

    if (opt == 1)
    {
        const int flag = *static_cast<int *>(returnvalue);
        if (flag < 0)
        {
            fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n",
                    funcname, flag);
            return 1;
        }
        return 0;
    }

    if (opt == 2 && returnvalue == NULL)
    {
        fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
                funcname);
        return 1;
    }

    return 0;
}


## main 函数（main.cpp）
#include "CvAdvDiffBnd.h"



/**
 * Driver: sets up the problem, prints the solution at NOUT output times
 * starting at T1 and spaced DTOUT apart, then prints solver statistics.
 *
 * If Init() did not complete, the first Run() is expected to fail inside
 * CVode and end the loop early.
 *
 * @return 0 on success, 1 if a Run() call failed.
 */
int main()
{
    CvAdvDiffBnd test;
    test.Init();

    // Output schedule (separate from the class's own T1/DTOUT/NOUT members).
    const int NOUT = 10;
    const double T1 = 0.6;
    const double DTOUT = 0.1;

    int status = 0;
    double tout = T1;
    for (int iout = 1; iout <= NOUT; ++iout, tout += DTOUT)
    {
        if (test.Run(tout) != 0)
        {
            status = 1;
            break;
        }
    }

    test.Finish();
    return status;
}


## 测试的结果

输出结果与原始示例 cvAdvDiff_bnd.c 一致。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/384584.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号