source: trunk/yat/regression/Cox.cc @ 4198

Last change on this file since 4198 was 4198, checked in by Peter, 8 months ago

add classes doing Cox regression

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 8.1 KB
Line 
1// $Id: Cox.cc 4198 2022-08-19 06:26:14Z peter $
2
3/*
4  Copyright (C) 2022 Peter Johansson
5
6  This file is part of the yat library, https://dev.thep.lu.se/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 3 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with yat. If not, see <https://www.gnu.org/licenses/>.
20*/
21
22#include <config.h>
23
24#include "Cox.h"
25
26#include "detail/Cox.h"
27
28#include <yat/utility/VectorBase.h>
29
30#include <gsl/gsl_cdf.h>
31
32#include <boost/math/tools/roots.hpp>
33
34#include <algorithm>
35#include <memory>
36
37namespace theplu {
38namespace yat {
39namespace regression {
40
41  class Cox::Impl : public cox::Implementation<double>
42  {
43  public:
44    using cox::Implementation<double>::add;
45    void add(const yat::utility::VectorBase& x,
46             const yat::utility::VectorBase& time,
47             const std::vector<char>& event)
48    {
49      assert(x.size() == time.size());
50      assert(x.size() == event.size());
51      for (size_t i=0; i<x.size(); ++i)
52        add(x(i), time(i), event[i]);
53    }
54
55    double b(void) const { return beta_ ; }
56    double hazard_ratio(void) const
57    { return exp(beta_); }
58
59    double hazard_ratio_lower_CI(double alpha) const
60    { return exp(beta_ - hazard_ratio_CI(alpha)); }
61
62    double hazard_ratio_upper_CI(double alpha) const
63    { return exp(beta_ + hazard_ratio_CI(alpha)); }
64
65    double p(void) const
66    { return 2 * gsl_cdf_ugaussian_Q(std::abs(z())); }
67
68    void train(void);
69
70    double z(void) const { return beta_ / beta_std_error_; }
71
72  private:
73    double hazard_ratio_CI(double alpha) const
74    {
75      double z = gsl_cdf_ugaussian_Qinv(0.5 * (1.0 - alpha));
76      return z * beta_std_error_;
77    }
78    double beta_;
79    double beta_std_error_;
80
81    class logL
82    {
83    public:
84      logL(const std::vector<TimePoint>& times);
85      std::pair<double, double> operator()(double beta) const;
86
87      double hessian(double beta) const;
88    private:
89      const std::vector<TimePoint>& times_;
90    };
91  };
92
93
94  void Cox::Impl::train(void)
95  {
96    if (data_.empty())
97      return;
98    beta_ = 0;
99
100    prepare_times();
101    logL func(times_);
102    using boost::math::tools::newton_raphson_iterate;
103    beta_ = newton_raphson_iterate(func, beta_, -1e42, 1e42, 30);
104    if (std::isnan(beta_))
105      throw std::runtime_error("beta is NaN");
106    // Calculate 2nd deriviate at beta_;
107    double hessian = func.hessian(beta_);
108    beta_std_error_ = 1.0 / std::sqrt(hessian);
109  }
110
111
112  Cox::Impl::logL::logL(const std::vector<TimePoint>& times)
113    : times_(times)
114  {
115  }
116
117
118  std::pair<double, double> Cox::Impl::logL::operator()(double beta) const
119  {
120    // Using Efron's method:
121    //
122    // sort data wrt time and denote unique time t_i such that t_i <
123    // t_j iff i < j
124    // Let denote H_j the indices of events at time t_j, i.e.,
125    // Y_i = t_j and event_i = true; n_j = |H_j|
126    //
127    // log partial likelihood
128    // logL =
129    //  sum_j (sum_i x_i*beta - sum_k log{sum_i theta_i - k/n_j sum_i theta_i})
130    // where j runs over all unique times
131    // i runs over all events in H_j
132    // k runs over all events in H_j
133    // 1st i sum runs : Y_i >= t_j
134    // 2nd i sum runs over H_j
135    //
136    // and the derivative is
137    // deriv =
138    //  sum_j (sum_i x_i - sum_k {[sum_i theta_i*x_i - k/n_j sum_i theta_i*x_i] / [sum_i theta_i - k/n_j sum_i theta_i]})
139    // 1st i sum runs : Y_i >= t_j
140    // 2nd i sum runs over H_j
141    // 3rd i sum runs : Y_i >= t_j
142    // 4th i sum runs over H_j
143
144
145    /*
146      We handle tied ties using Efron's method. Let denote H_j the
147      indices of events at time t_j
148
149      logL =
150      \sum_j (sum_i(theta_i) - \sum_k(log(sum_i(theta_i) - k/m_j sum_i(theta_i))
151
152      where theta_i = beta * x_i
153      where j runs over all unique time points
154      1st i runs over H_j, i.e., all events at time t_j.
155      k runs over H_j
156      2nd i runs over all i: Y_i > t_j
157      3rd i runs over H_j
158    */
159
160    double logL = 0;
161    double deriv = 0;
162
163    double theta_Q = 0;
164    double thetaX_Q = 0;
165    for (auto time = times_.rbegin(); time!=times_.rend(); ++time) {
166      double sum_event_theta = 0;
167      double sum_event_thetaX = 0;
168      for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
169        double theta = it->theta(beta);
170        sum_event_theta += theta;
171        sum_event_thetaX += theta * it->x;
172      }
173      theta_Q += sum_event_theta;
174      thetaX_Q += sum_event_thetaX;
175
176      for (auto it = time->censored_begin(); it!=time->censored_end(); ++it) {
177        double theta = it->theta(beta);
178        theta_Q += theta;
179        thetaX_Q += theta * it->x;
180      }
181
182
183      // loop over events at time point t
184      for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
185        const size_t k = it - time->events_begin();
186        double r = static_cast<double>(k) / time->size();
187
188        logL += it->x * beta;
189        assert(theta_Q > 0.0);
190        assert(theta_Q > r * sum_event_theta);
191        logL -= std::log(theta_Q - r * sum_event_theta);
192
193        deriv += it->x;
194        deriv -= (thetaX_Q - r * sum_event_thetaX) /
195          (theta_Q - r * sum_event_theta);
196      }
197
198    }
199
200    assert(!std::isnan(logL));
201    assert(!std::isnan(deriv));
202    return std::make_pair(-logL, -deriv);
203  }
204
205
206  double Cox::Impl::logL::hessian(double beta) const
207  {
208    // The second derivative of the log-likelihood evaluated at the
209    // maximum likelihood estimates (MLE) is the observed Fisher
210    // information
211
212    double hessian = 0;
213
214    double sum_theta = 0;
215    double sum_thetaX = 0;
216    double sum_thetaXX = 0;
217    for (const auto& t : times_) {
218      for (auto it = t.begin; it!=t.end; ++it) {
219        double theta = it->theta(beta);
220        sum_theta += theta;
221        sum_thetaX += theta * it->x;
222        sum_thetaXX += theta * it->x * it->x;
223      }
224    }
225
226    // loop over unique time points
227    for (const auto& t : times_) {
228
229      // sum over all events in H_j
230      double part_sum_theta = 0;
231      double part_sum_thetaX = 0;
232      double part_sum_thetaXX = 0;
233      for (auto it = t.events_begin(); it!=t.events_end(); ++it) {
234        double theta = it->theta(beta);
235        part_sum_theta += theta;
236        part_sum_thetaX += theta * it->x;
237        part_sum_thetaXX += theta * it->x * it->x;
238      }
239
240      // loop over events at time point t
241      for (auto it = t.events_begin(); it!=t.events_end(); ++it) {
242        const size_t k = it - t.events_begin();
243        double r = static_cast<double>(k) / t.size();
244
245        double S_thetaXX = sum_thetaXX - r * part_sum_thetaXX;
246        double S_thetaX  = sum_thetaX  - r * part_sum_thetaX;
247        double S_theta   = sum_theta   - r * part_sum_theta;
248        hessian += S_thetaXX/S_theta;
249        hessian -= std::pow(S_thetaX/S_theta, 2);
250      }
251
252      // update the cumulative sums
253      for (auto it = t.begin; it!=t.end; ++it) {
254        double theta = it->theta(beta);
255        sum_theta -= theta;
256        sum_thetaX -= theta * it->x;
257        sum_thetaXX -= theta * it->x * it->x;
258      }
259    }
260    return hessian;
261  }
262
263
264  // class Cox
265
266  Cox::Cox(void)
267    : pimpl_(new Impl)
268  {
269  }
270
271
272  Cox::Cox(const Cox& other)
273    : pimpl_(new Impl(*other.pimpl_))
274  {
275  }
276
277
278  Cox::Cox(Cox&& other)
279  {
280    std::swap(pimpl_, other.pimpl_);
281  }
282
283
284  Cox::~Cox(void)
285  {
286  }
287
288
289  Cox& Cox::operator=(const Cox& other)
290  {
291    assert(other.pimpl_);
292    pimpl_.reset(new Impl(*other.pimpl_));
293    return *this;
294  }
295
296
297  Cox& Cox::operator=(Cox&& other)
298  {
299    std::swap(pimpl_, other.pimpl_);
300    return *this;
301  }
302
303
304  void Cox::Cox::add(double x, double time, bool event)
305  {
306    pimpl_->add(x, time, event);
307  }
308
309
310  void Cox::add(const yat::utility::VectorBase& x,
311                const yat::utility::VectorBase& time,
312                const std::vector<char>& event)
313  {
314    pimpl_->add(x, time, event);
315  }
316
317
318  double Cox::b(void) const
319  {
320    return pimpl_->b();
321  }
322
323
324  void Cox::clear(void)
325  {
326    pimpl_->clear();
327  }
328
329
330  double Cox::hazard_ratio(void) const
331  {
332    return pimpl_->hazard_ratio();
333  }
334
335
336  double Cox::hazard_ratio_lower_CI(double alpha) const
337  {
338    return pimpl_->hazard_ratio_lower_CI(alpha);
339  }
340
341
342  double Cox::hazard_ratio_upper_CI(double alpha) const
343  {
344    return pimpl_->hazard_ratio_upper_CI(alpha);
345  }
346
347
348  double Cox::p(void) const
349  {
350    return pimpl_->p();
351  }
352
353
354  void Cox::train(void)
355  {
356    pimpl_->train();
357  }
358
359
360  double Cox::z(void) const
361  {
362    return pimpl_->z();
363  }
364
365}}}
Note: See TracBrowser for help on using the repository browser.