Ceres Solver for logistic growth Curve Fit

180 Views Asked by At

I attempt to test how well ceres solver autodiff in fitting a simple logistic growth curve against actual observed data , to my surprise the solver seemed unable to provide solution . Using other c++ solvers , result can be easily obtained with ease , k=9643.61,c=84.61 and b=3.8121. I am not sure is the code having issue or just the ceres solver autodiff are not that well built? Any advice pls?

below is the sample

     #include "ceres/ceres.h"
#include "glog/logging.h"
#include <cmath>
#include <iostream>
#include <stdio.h>
using ceres::AutoDiffCostFunction;
using ceres::CauchyLoss;
using ceres::CostFunction;
using ceres::Problem;
using ceres::Solve;
using ceres::Solver;


struct ExponentialResidual
{
  ExponentialResidual(double x, double y)
      : x_(x), y_(y) {}

  template <typename T>
  bool operator()(const T *const k,
                  const T *const c,
                  const T *const b,
                  //const T* const g,

                  T *residual) const
  {

    residual[0] = y_ - (k[0] / (1.0 + (pow((x_ / c[0]), b[0]))));

    return true;
  }

private:
  const double x_;
  const double y_;
};
const int kNumObservations = 247;
double data[] = {
   0,3,
1,4,
2,4,
3,4,
4,7,
5,8,
6,8,
7,8,
8,8,
9,8,
10,10,
11,12,
12,12,
13,12,
14,16,
15,16,
16,18,
17,18,
18,18,
19,19,
20,19,
21,22,
22,22,
23,22,
24,22,
25,22,
26,22,
27,22,
28,22,
29,22,
30,22,
31,22,
32,22,
33,23,
34,23,
35,25,
36,29,
37,32,
38,36,
39,50,
40,55,
41,83,
42,93,
43,99,
44,117,
45,129,
46,149,
47,158,
48,197,
49,238,
50,428,
51,553,
52,673,
53,790,
54,900,
55,1030,
56,1183,
57,1306,
58,1518,
59,1624,
60,1796,
61,2031,
62,2161,
63,2320,
64,2470,
65,2626,
66,2766,
67,2908,
68,3116,
69,3333,
70,3483,
71,3662,
72,3793,
73,3963,
74,4119,
75,4228,
76,4346,
77,4530,
78,4683,
79,4817,
80,4987,
81,5072,
82,5182,
83,5251,
84,5305,
85,5389,
86,5425,
87,5482,
88,5532,
89,5603,
90,5691,
91,5742,
92,5780,
93,5820,
94,5851,
95,5945,
96,6002,
97,6071,
98,6176,
99,6298,
100,6353,
101,6383,
102,6428,
103,6467,
104,6535,
105,6589,
106,6656,
107,6726,
108,6742,
109,6779,
110,6819,
111,6855,
112,6872,
113,6894,
114,6941,
115,6978,
116,7009,
117,7059,
118,7137,
119,7185,
120,7245,
121,7417,
122,7604,
123,7619,
124,7629,
125,7732,
126,7762,
127,7819,
128,7857,
129,7877,
130,7970,
131,8247,
132,8266,
133,8303,
134,8322,
135,8329,
136,8336,
137,8338,
138,8369,
139,8402,
140,8445,
141,8453,
142,8494,
143,8505,
144,8515,
145,8529,
146,8535,
147,8556,
148,8572,
149,8587,
150,8590,
151,8596,
152,8600,
153,8606,
154,8616,
155,8634,
156,8637,
157,8639,
158,8640,
159,8643,
160,8648,
161,8658,
162,8663,
163,8668,
164,8674,
165,8677,
166,8683,
167,8696,
168,8704,
169,8718,
170,8725,
171,8729,
172,8734,
173,8737,
174,8755,
175,8764,
176,8779,
177,8800,
178,8815,
179,8831,
180,8840,
181,8861,
182,8884,
183,8897,
184,8904,
185,8943,
186,8956,
187,8964,
188,8976,
189,8985,
190,8999,
191,9001,
192,9002,
193,9023,
194,9038,
195,9063,
196,9070,
197,9083,
198,9094,
199,9103,
200,9114,
201,9129,
202,9149,
203,9175,
204,9200,
205,9212,
206,9219,
207,9235,
208,9240,
209,9249,
210,9257,
211,9267,
212,9274,
213,9285,
214,9291,
215,9296,
216,9306,
217,9317,
218,9334,
219,9340,
220,9354,
221,9360,
222,9374,
223,9385,
224,9391,
225,9397,
226,9459,
227,9559,
228,9583,
229,9628,
230,9810,
231,9868,
232,9915,
233,9946,
234,9969,
235,10031,
236,10052,
237,10147,
238,10167,
239,10219,
240,10276,
241,10358,
242,10505,
243,10576,
244,10687,
245,10769,
246,10919,


};

int main(int argc, char const *argv[])
{

  google::InitGoogleLogging(argv[0]);

  double k = 20000.0;
  //double c=0.5;
  double c = kNumObservations / 2.0;
  double b = 0.5;

  double g = 1.0;

  Problem problem;
  for (int i = 0; i < kNumObservations; i++)
  {

    problem.AddResidualBlock(
        new AutoDiffCostFunction<ExponentialResidual, 1, 1, 1, 1>(

            new ExponentialResidual(data[2 * i]*1.00, data[2 * i + 1]*1.00)),
        new CauchyLoss(0.5), &k, &c, &b);
  }
  Solver::Options options;
  options.max_num_iterations = 1000;
  options.linear_solver_type = ceres::DENSE_QR;
  //options.trust_region_strategy_type=ceres::DOGLEG;
  //options.gradient_tolerance=1e-8;
  //options.parameter_tolerance=1e-10;
  //options.function_tolerance=1e-8;
  options.minimizer_progress_to_stdout = true;
  Solver::Summary summary;
  Solve(options, &problem, &summary);
  //std::cout<<summary.BriefReport()<<std::endl;
  std::cout << summary.FullReport() << std::endl;
    
  std::cout << "Final k: " << k << " c: " << c << " b: " << b << " g:  " << g << "\n";
  /* code */
  return 0;
}
1

There are 1 best solutions below

0
On

look like if I ignore the first observed data 0,3 , ceres solver is able to obtain result !