Skip to content

Commit f4f81e3

Browse files
committed
Refactor: Replace local input parameters with PARAM.inp in ESolver classes for consistency
1 parent 42c6cf2 commit f4f81e3

File tree

5 files changed

+21
-41
lines changed

5 files changed

+21
-41
lines changed

source/source_esolver/esolver_ks.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,7 @@ void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para
5757
classname = "ESolver_KS";
5858
basisname = "";
5959

60-
scf_thr = inp.scf_thr;
61-
scf_ene_thr = inp.scf_ene_thr;
62-
maxniter = inp.scf_nmax;
63-
niter = maxniter;
60+
niter = inp.scf_nmax;
6461
drho = 0.0;
6562

6663
std::string fft_device = inp.device;
@@ -235,9 +232,9 @@ void ESolver_KS<T, Device>::runner(UnitCell& ucell, const int istep)
235232
// 2) SCF iterations
236233
//----------------------------------------------------------------
237234
bool conv_esolver = false;
238-
this->niter = this->maxniter;
235+
this->niter = PARAM.inp.scf_nmax;
239236
this->diag_ethr = PARAM.inp.pw_diag_thr;
240-
for (int iter = 1; iter <= this->maxniter; ++iter)
237+
for (int iter = 1; iter <= PARAM.inp.scf_nmax; ++iter)
241238
{
242239
//----------------------------------------------------------------
243240
// 3) initialization of SCF iterations
@@ -398,10 +395,10 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
398395
}
399396
#endif
400397

401-
conv_esolver = (drho < this->scf_thr && not_restart_step && is_U_converged);
398+
conv_esolver = (drho < PARAM.inp.scf_thr && not_restart_step && is_U_converged);
402399

403400
// add energy threshold for SCF convergence
404-
if (this->scf_ene_thr > 0.0)
401+
if (PARAM.inp.scf_ene_thr > 0.0)
405402
{
406403
// calculate energy of output charge density
407404
this->update_pot(ucell, istep, iter, conv_esolver);
@@ -415,7 +412,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
415412
{
416413
// update the convergence flag
417414
conv_esolver
418-
= (std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < this->scf_ene_thr);
415+
= (std::abs(this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV) < PARAM.inp.scf_ene_thr);
419416
}
420417
}
421418

source/source_esolver/esolver_ks.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,8 @@ class ESolver_KS : public ESolver_FP
7676
std::string basisname; //! esolver_ks_lcao.cpp
7777
double esolver_KS_ne = 0.0; //! number of electrons
7878
double diag_ethr; //! the threshold for diagonalization
79-
double scf_thr; //! scf density threshold
80-
double scf_ene_thr; //! scf energy threshold
8179
double drho; //! the difference between rho_in (before HSolver) and rho_out (After HSolver)
8280
double hsolver_error; //! the error of HSolver
83-
int maxniter; //! maximum iter steps for scf
8481
int niter; //! iter steps actually used in scf
8582
bool oscillate_esolver = false; // whether esolver is oscillated
8683
};

source/source_esolver/esolver_of.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,6 @@ void ESolver_OF::before_all_runners(UnitCell& ucell, const Input_para& inp)
6060
ESolver_FP::before_all_runners(ucell, inp);
6161

6262
// save necessary parameters
63-
this->of_kinetic_ = inp.of_kinetic;
64-
this->of_method_ = inp.of_method;
65-
this->of_conv_ = inp.of_conv;
66-
this->of_tole_ = inp.of_tole;
67-
this->of_tolp_ = inp.of_tolp;
68-
this->max_iter_ = inp.scf_nmax;
6963
this->dV_ = ucell.omega / this->pw_rho->nxyz;
7064
this->bound_cal_potential_
7165
= std::bind(&ESolver_OF::cal_potential, this, std::placeholders::_1, std::placeholders::_2, std::ref(ucell));
@@ -422,7 +416,7 @@ void ESolver_OF::update_rho()
422416
}
423417

424418
/**
425-
* @brief Check convergence, return ture if converge or iter >= max_iter_,
419+
* @brief Check convergence, return ture if converge or iter >= PARAM.inp.scf_nmax,
426420
* and print the necessary information
427421
*
428422
* @return exit or not
@@ -434,7 +428,7 @@ bool ESolver_OF::check_exit(bool& conv_esolver)
434428
bool potHold = false; // if normdLdphi nearly remains unchanged
435429
bool energyConv = false;
436430

437-
if (this->normdLdphi_ < this->of_tolp_)
431+
if (this->normdLdphi_ < PARAM.inp.of_tolp)
438432
{
439433
potConv = true;
440434
}
@@ -444,23 +438,23 @@ bool ESolver_OF::check_exit(bool& conv_esolver)
444438
potHold = true;
445439
}
446440

447-
if (this->iter_ >= 3 && std::abs(this->energy_current_ - this->energy_last_) < this->of_tole_
448-
&& std::abs(this->energy_current_ - this->energy_llast_) < this->of_tole_)
441+
if (this->iter_ >= 3 && std::abs(this->energy_current_ - this->energy_last_) < PARAM.inp.of_tole
442+
&& std::abs(this->energy_current_ - this->energy_llast_) < PARAM.inp.of_tole)
449443
{
450444
energyConv = true;
451445
}
452446

453-
conv_esolver = (this->of_conv_ == "energy" && energyConv) || (this->of_conv_ == "potential" && potConv)
454-
|| (this->of_conv_ == "both" && potConv && energyConv);
447+
conv_esolver = (PARAM.inp.of_conv == "energy" && energyConv) || (PARAM.inp.of_conv == "potential" && potConv)
448+
|| (PARAM.inp.of_conv == "both" && potConv && energyConv);
455449

456450
this->print_info(conv_esolver);
457451

458-
if (conv_esolver || this->iter_ >= this->max_iter_)
452+
if (conv_esolver || this->iter_ >= PARAM.inp.scf_nmax)
459453
{
460454
return true;
461455
}
462456
// ============ temporary solution of potential convergence ===========
463-
else if (this->of_conv_ == "potential" && potHold)
457+
else if (PARAM.inp.of_conv == "potential" && potHold)
464458
{
465459
GlobalV::ofs_warning << "ESolver_OF WARNING: "
466460
<< "The convergence of potential has not been reached, but the norm of potential nearly "

source/source_esolver/esolver_of.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,6 @@ class ESolver_OF : public ESolver_FP
3838
ModuleBase::Opt_DCsrch* opt_dcsrch_ = nullptr;
3939
ModuleBase::Opt_CG* opt_cg_mag_ = nullptr; // for spin2 case, under testing
4040

41-
// ----------------- necessary parameters from INPUT ------------
42-
std::string of_kinetic_ = "wt"; // Kinetic energy functional, such as TF, VW, WT
43-
std::string of_method_ = "tn"; // optimization method, include cg1, cg2, tn (default), bfgs
44-
std::string of_conv_ = "energy"; // select the convergence criterion, potential, energy (default), or both
45-
double of_tole_ = 2e-6; // tolerance of the energy change (in Ry) for determining the convergence, default=2e-6 Ry
46-
double of_tolp_ = 1e-5; // tolerance of potential for determining the convergence, default=1e-5 in a.u.
47-
int max_iter_ = 50; // scf_nmax
48-
4941
// ------------------ parameters from other module --------------
5042
double dV_ = 0; // volume of one grid point in real space
5143
double* nelec_ = nullptr; // number of electrons with each spin

source/source_esolver/esolver_of_interface.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ void ESolver_OF::init_opt()
1616
this->opt_dcsrch_ = new ModuleBase::Opt_DCsrch();
1717
}
1818

19-
if (this->of_method_ == "tn")
19+
if (PARAM.inp.of_method == "tn")
2020
{
2121
if (this->opt_tn_ == nullptr)
2222
{
@@ -25,7 +25,7 @@ void ESolver_OF::init_opt()
2525
this->opt_tn_->allocate(this->pw_rho->nrxx);
2626
this->opt_tn_->set_para(this->dV_);
2727
}
28-
else if (this->of_method_ == "cg1" || this->of_method_ == "cg2")
28+
else if (PARAM.inp.of_method == "cg1" || PARAM.inp.of_method == "cg2")
2929
{
3030
if (this->opt_cg_ == nullptr)
3131
{
@@ -35,7 +35,7 @@ void ESolver_OF::init_opt()
3535
this->opt_cg_->set_para(this->dV_);
3636
this->opt_dcsrch_->set_paras(1e-4, 1e-2);
3737
}
38-
else if (this->of_method_ == "bfgs")
38+
else if (PARAM.inp.of_method == "bfgs")
3939
{
4040
ModuleBase::WARNING_QUIT("esolver_of", "BFGS is not supported now.");
4141
return;
@@ -62,7 +62,7 @@ void ESolver_OF::get_direction(UnitCell& ucell)
6262
{
6363
for (int is = 0; is < PARAM.inp.nspin; ++is)
6464
{
65-
if (this->of_method_ == "tn")
65+
if (PARAM.inp.of_method == "tn")
6666
{
6767
this->tn_spin_flag_ = is;
6868
opt_tn_->next_direct(this->pphi_[is],
@@ -72,15 +72,15 @@ void ESolver_OF::get_direction(UnitCell& ucell)
7272
this,
7373
&ESolver_OF::cal_potential_wrapper);
7474
}
75-
else if (this->of_method_ == "cg1")
75+
else if (PARAM.inp.of_method == "cg1")
7676
{
7777
opt_cg_->next_direct(this->pdLdphi_[is], 1, this->pdirect_[is]);
7878
}
79-
else if (this->of_method_ == "cg2")
79+
else if (PARAM.inp.of_method == "cg2")
8080
{
8181
opt_cg_->next_direct(this->pdLdphi_[is], 2, this->pdirect_[is]);
8282
}
83-
else if (this->of_method_ == "bfgs")
83+
else if (PARAM.inp.of_method == "bfgs")
8484
{
8585
return;
8686
}

0 commit comments

Comments
 (0)