source: trunk/yat/regression/Cox.cc @ 4252

Last change on this file since 4252 was 4252, checked in by Peter, 4 months ago

merge 0.20 release into trunk

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 7.8 KB
Line 
1// $Id: Cox.cc 4252 2022-11-18 02:54:04Z peter $
2
3/*
4  Copyright (C) 2022 Peter Johansson
5
6  This file is part of the yat library, https://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <https://www.gnu.org/licenses/>.
20*/
21
#include <config.h>

#include "Cox.h"

#include "detail/Cox.h"

#include <yat/utility/Steffenson.h>
#include <yat/utility/VectorBase.h>

#include <gsl/gsl_cdf.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
35
36namespace theplu {
37namespace yat {
38namespace regression {
39
40  class Cox::Impl : public cox::Implementation<double>
41  {
42  public:
43    using cox::Implementation<double>::add;
44    void add(const yat::utility::VectorBase& x,
45             const yat::utility::VectorBase& time,
46             const std::vector<char>& event)
47    {
48      assert(x.size() == time.size());
49      assert(x.size() == event.size());
50      for (size_t i=0; i<x.size(); ++i)
51        add(x(i), time(i), event[i]);
52    }
53
54    double b(void) const { return beta_ ; }
55    double hazard_ratio(void) const
56    { return exp(beta_); }
57
58    double hazard_ratio_lower_CI(double alpha) const
59    { return exp(beta_ - hazard_ratio_CI(alpha)); }
60
61    double hazard_ratio_upper_CI(double alpha) const
62    { return exp(beta_ + hazard_ratio_CI(alpha)); }
63
64    double p(void) const
65    { return 2 * gsl_cdf_ugaussian_Q(std::abs(z())); }
66
67    void train(void);
68
69    double z(void) const { return beta_ / beta_std_error_; }
70
71  private:
72    double hazard_ratio_CI(double alpha) const
73    {
74      double z = gsl_cdf_ugaussian_Qinv(0.5 * (1.0 - alpha));
75      return z * beta_std_error_;
76    }
77    double beta_;
78    double beta_std_error_;
79
80    class Score
81    {
82    public:
83      Score(const std::vector<TimePoint>& times);
84      double operator()(double beta) const;
85      double derivative(double beta) const;
86    private:
87      const std::vector<TimePoint>& times_;
88    };
89  };
90
91
92  void Cox::Impl::train(void)
93  {
94    if (data_.empty())
95      return;
96
97    prepare_times();
98    // score is derivative of logL
99    Score score(times_);
100    for (double b=0; b<1.1; b+=0.1)
101      score(b);
102    utility::Steffenson solver;
103    beta_ = solver(score, 0.0,
104                   utility::RootFinderDerivative::Delta(0.0, 1e-5));
105    if (std::isnan(beta_))
106      throw std::runtime_error("beta is NaN");
107    // 2nd derivative of logL is 1st derivative of score
108    double hessian = score.derivative(beta_);
109    beta_std_error_ = 1.0 / std::sqrt(-hessian);
110  }
111
112
/*
  Without ties the log partial likelihood is

    logL = sum_i (beta * x_i - log sum_j exp(beta * x_j))

  where i runs over all events and j runs over all data points
  such that t_j >= t_i, i.e., the risk set at t_i.

  Setting theta = exp(beta * x) we have

    logL = sum_i (beta * x_i - log sum_j theta_j)
         = sum_i (beta * x_i - log theta_Q_i)

  where theta_Q_i is the sum of theta over the risk set at t_i, and

    dtheta/dbeta = x * theta

  We handle ties using Efron's method. Let H_i denote the indices of
  the events at time t_i and m_i the number of data points at t_i
  (TimePoint::size() in the code). Then

    logL = sum_i (beta * x_H_i
                  - sum_{k=0}^{|H_i|-1} log(theta_Q_i - k/m_i theta_H_i))

  where x_H_i and theta_H_i are sums of x and theta, respectively,
  running over H_i. Efron's method uses the number of tied events,
  |H_i|, for m_i. NOTE(review): confirm that TimePoint::size() counts
  only the events at t_i, not censored data points.

  The derivative (wrt beta) is

    l' = sum_i (x_H_i - sum_{k=0}^{|H_i|-1}
                (thetaX_Q_i - k/m_i thetaX_H_i) / (theta_Q_i - k/m_i theta_H_i))

  where thetaX denotes a sum of theta * x over the same index set.
*/
136  Cox::Impl::Score::Score(const std::vector<TimePoint>& times)
137    : times_(times)
138  {
139  }
140
141
142  double Cox::Impl::Score::operator()(double beta) const
143  {
144    double score = 0;
145    // variables with suffix _Q denote sums running over data points
146    // (including events and censored data points) at current time and
147    // future
148    double theta_Q = 0;
149    double thetaX_Q = 0;
150
151    for (auto time = times_.rbegin(); time!=times_.rend(); ++time) {
152      // variables with suffix _H denote sums running over events at
153      // the current time.
154      double theta_H = 0;
155      double thetaX_H = 0;
156      for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
157        double theta = it->theta(beta);
158        theta_H += theta;
159        thetaX_H += theta * it->x;
160      }
161      theta_Q += theta_H;
162      thetaX_Q += thetaX_H;
163
164      for (auto it = time->censored_begin(); it!=time->censored_end(); ++it) {
165        double theta = it->theta(beta);
166        theta_Q += theta;
167        thetaX_Q += theta * it->x;
168      }
169
170      // loop over events at time point t
171      for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
172        score += it->x;
173
174        const size_t k = it - time->events_begin();
175        double r = static_cast<double>(k) / time->size();
176
177        assert(theta_Q > r * theta_H);
178
179        score -= (thetaX_Q - r * thetaX_H) / (theta_Q - r * theta_H);
180      }
181    }
182
183    assert(!std::isnan(score));
184    return score;
185  }
186
187
188  double Cox::Impl::Score::derivative(double beta) const
189  {
190    double deriv = 0;
191    // variables with suffix _Q denote sums running over data points
192    // (including events and censored data points) at current time and
193    // future
194    double theta_Q = 0;
195    double thetaX_Q = 0;
196    double thetaXX_Q = 0;
197    double XX_Q = 0;
198
199    for (auto time = times_.rbegin(); time!=times_.rend(); ++time) {
200      // variables with suffix _H denote sums running over events at
201      // the current time.
202      double theta_H = 0;
203      double thetaX_H = 0;
204      double thetaXX_H = 0;
205      double XX_H = 0;
206      for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
207        double theta = it->theta(beta);
208        theta_H += theta;
209        thetaX_H += theta * it->x;
210        thetaXX_H += theta * it->x * it->x;
211        XX_H += it->x * it->x;
212      }
213      theta_Q += theta_H;
214      thetaX_Q += thetaX_H;
215      thetaXX_Q += thetaXX_H;
216      XX_Q += XX_H;
217
218      for (auto it = time->censored_begin(); it!=time->censored_end(); ++it) {
219        double theta = it->theta(beta);
220        theta_Q += theta;
221        thetaX_Q += theta * it->x;
222        thetaXX_Q += theta * it->x * it->x;
223        XX_Q += it->x * it->x;
224      }
225
226      // f = g/h
227      // g = - (thetaX_Q - r*thetaX_H)
228      // g'= - (thetaXX_Q - r*thetaXX_H)
229      //
230      // h = theta_Q - r*theta_H
231      // h'= thetaX_Q - r*thetaX_H =
232
233      // f' = (g/h)' = (g'h - gh')/h^2 =
234
235      // loop over events at time point t
236      for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
237        const size_t k = it - time->events_begin();
238        double r = static_cast<double>(k) / time->size();
239        double g = - (thetaX_Q - r*thetaX_H);
240        double dg = - (thetaXX_Q - r * thetaXX_H);
241
242        double h = (theta_Q - r*theta_H);
243        double dh= (thetaX_Q - r*thetaX_H);
244        deriv += (dg*h - g*dh) / (h*h);
245      }
246    }
247
248    assert(!std::isnan(deriv));
249    return deriv;
250  }
251
252  // class Cox
253
254  Cox::Cox(void)
255    : pimpl_(new Impl)
256  {
257  }
258
259
260  Cox::Cox(const Cox& other)
261    : pimpl_(new Impl(*other.pimpl_))
262  {
263  }
264
265
266  Cox::Cox(Cox&& other)
267  {
268    std::swap(pimpl_, other.pimpl_);
269  }
270
271
272  Cox::~Cox(void)
273  {
274  }
275
276
277  Cox& Cox::operator=(const Cox& other)
278  {
279    assert(other.pimpl_);
280    pimpl_.reset(new Impl(*other.pimpl_));
281    return *this;
282  }
283
284
285  Cox& Cox::operator=(Cox&& other)
286  {
287    std::swap(pimpl_, other.pimpl_);
288    return *this;
289  }
290
291
292  void Cox::Cox::add(double x, double time, bool event)
293  {
294    pimpl_->add(x, time, event);
295  }
296
297
298  void Cox::add(const yat::utility::VectorBase& x,
299                const yat::utility::VectorBase& time,
300                const std::vector<char>& event)
301  {
302    pimpl_->add(x, time, event);
303  }
304
305
306  double Cox::b(void) const
307  {
308    return pimpl_->b();
309  }
310
311
312  void Cox::clear(void)
313  {
314    pimpl_->clear();
315  }
316
317
318  double Cox::hazard_ratio(void) const
319  {
320    return pimpl_->hazard_ratio();
321  }
322
323
324  double Cox::hazard_ratio_lower_CI(double alpha) const
325  {
326    return pimpl_->hazard_ratio_lower_CI(alpha);
327  }
328
329
330  double Cox::hazard_ratio_upper_CI(double alpha) const
331  {
332    return pimpl_->hazard_ratio_upper_CI(alpha);
333  }
334
335
336  double Cox::p(void) const
337  {
338    return pimpl_->p();
339  }
340
341
342  void Cox::train(void)
343  {
344    pimpl_->train();
345  }
346
347
348  double Cox::z(void) const
349  {
350    return pimpl_->z();
351  }
352
353}}}
Note: See TracBrowser for help on using the repository browser.