Table of Contents
問題意識
- とある属性や処置が顧客の生存時間にどのように影響するのかを分析する機会は多い
- 生存分析に置いて, 分析データの制約より, 顧客がchurnする前に観測期間が終了時点を迎えたり, study dropoutという問題に直面することが多い
- いわゆるtime-to-event data, failure-time dataの分析手法を身につけたい
生存分析の特徴
- 連続時間をベースに分析, 時間は正の実数
- subjectに応じて, exactな生存時間がわかったり, 打ち切りを迎えたりする(partially right-censored data)
- 打ち切りデータ(= partially observed data)の存在のため, meanやvarianceといったmomentの計算が困難
censoringの種類
打ち切り(Censoring)の種類はおもに3つに区分されます:
- Fixed type I censoring
- Random type I censoring
- Type II censoring
Fixed type I censoring
分析期間が$C$時間経過後に, すべてのsubjectについて終了となる場合に発生するCensoringです.
故障分析を例にするならば, 100日を $C$ としたとき, 100日以内に故障を迎えた機会については故障フラグ=True
と故障までの日数 T = 78
と観察されるが, 100日以内に故障しなかったものについては, 故障フラグ=False
とT = 100
という形で打ち切りとなった生存時間しか観察できません.
Random type I censoring
分析期間が$C_i$時間経過後に終了という形でデザインされているが, $C_i$がsubject毎に異なる場合のCensoringのことです.
Type II censoring
観察期間が事前に定めたCensring time, $C$, 時間経過することなく打ち切りを迎えてしまった場合のことです. 実験プロトコルからの逸脱などが例としてあります.
Notation and Terminology
Survival Function
Survival Functionは以下のようによく表現されます
\[S(t) = Pr(T > t) = 1 - F(t)\]- $T$: event発生までの生存時間
- $S(t)$: $t$まで生存している確率, non-increasing w.r.t Time
- $F(t)$: $t$まで死亡している確率, increasing w.r.t Time
- $f(t)$: $t$時点での死亡についての確率密度関数
Hazard Function
t時間まで生存している条件のもとで瞬間的に死亡する確率密度を表現したものをHazard関数, $h(t)$, といいます. 80歳と100歳の老人を考えた時, 後者のほうが死亡リスクが高いですが, 平均寿命的には多くの人は80歳近傍で死を迎えるのでdensity function的には前者近傍のほうが大きい値が出ます. ただ, density functionが80歳近傍のほうが100歳近傍と比べ大きいため前者のほうが死亡リスクが高いというのは解釈がおかしくなってしまいます. そこでそれぞれの年齢における死亡リスクを表現した関数がHazard関数です.
\(\begin{align*} h(t) &= \lim_{\Delta t \to 0} \frac{Pr(t < T \leq t + \Delta t| T >t)}{\Delta t}\\ &= \frac{f(t)}{S(t)} \end{align*}\)
- なお$F(t) = 1 - \exp(-\lambda t)$という指数分布を仮定するとhazard関数は時間$t$には無関係に$h(t)=\lambda$となります.
hazard関数を積分してcumulative hazard関数を以下のように定義します:
\[H(t) = \int^t_0 h(u)du\]このとき, hazard関数の定義に注意すると
\[\begin{align*} H(t) &= \int^t_0 h(u)du\\ &= \int^t_0\frac{f(u)}{1 - F(u)}du\\ & = \left[-\log (1 - F(u))\right]_0^t \end{align*}\]よって, 以下の関係式を得ることができます
\[\begin{align*} F(t) &= 1 - \exp\left(-\int^t_0h(u)du\right)\\ f(t) &= h(t)\exp\left(-\int^t_0h(u)du\right) \end{align*}\]- 生存時間が指数分布従う場合, hazard関数が時間に無関係の定数であることから確率密度関数, 累積分布関数の導出が可能となります
Kaplan-Meier推定量
KM推定量はnonparametricな生存関数の推定量です. Parametricに指数分布やワイブル分布を仮定しMLEで推定する方法もありますが, Parametric推定量では用いる分布形状から要請される仮定(harzard関数はコンスタントなど)を現実のものと適合させる, or 仮定の妥当性について分析者間で合意する事が困難のため, 基本的にはnonparametric(or semiparametric)アプローチが生存分析では推奨されます.
Survival Dataの形式
- $T_i$: the eventを迎えた時の潜在的生存時間 for the $i$th subject, 分析者には見えない場合がある
- $C_i$: the censoring time for the $i$th subject
- $\delta_i$: the event indicator
- $Y_i$: 観察された生存時間
データ例:
sksurv.datasets
より
1
2
3
4
5
6
7
import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
df_x, array_y = load_veterans_lung_cancer()
df_y = pd.DataFrame(array_y)
df = pd.concat([df_y, df_x], axis=1)
df.head()
Then,
Status | Survival_in_days | Age_in_years | Celltype | Karnofsky_score | Months_from_Diagnosis | Prior_therapy | Treatment |
---|---|---|---|---|---|---|---|
True | 72.0 | 69.0 | squamous | 60.0 | 7.0 | no | standard |
True | 411.0 | 64.0 | squamous | 70.0 | 5.0 | yes | standard |
True | 228.0 | 38.0 | squamous | 60.0 | 3.0 | no | standard |
True | 126.0 | 63.0 | squamous | 60.0 | 9.0 | yes | standard |
True | 118.0 | 65.0 | squamous | 70.0 | 11.0 | yes | standard |
Status
: the eventを迎えた場合True
, $\delta_i$Survival_in_days
: 観察された生存時間, $Y_i$Age_in_years
: 年齢
データ型がnp.float64
以外のカラムのvalueを確認して見る場合は以下:
1
2
3
4
5
6
7
cat_cols = df.select_dtypes(exclude=np.float64).columns.tolist()
(pd.DataFrame(
df[cat_cols].melt(var_name='column',
value_name='value'
).value_counts())
.rename(columns={0: 'counts'})
.sort_values(by=['column', 'counts']))
Then,
column | value | counts |
---|---|---|
Celltype | adeno | 27 |
large | 27 | |
squamous | 35 | |
smallcell | 48 | |
Prior_therapy | yes | 40 |
no | 97 | |
Status | False | 9 |
True | 128 | |
Treatment | test | 68 |
standard | 69 |
Kaplan-Meier推定量
基本方針
- subjectが同じ生存分布関数に従う仮定の下でKaplan-Meier推定量というnonparametric推定量を用いることができます
- the empirical cumulative distribution function, $F_n(t)$を計算し $S(t) \equiv 1 - F_n(t)$で推定
\[\hat{S}(t) = \prod_{j \lt t} \frac{n_j - d_j}{n_j}\]推定量
- $n_j$: the total number of observations at time $j$, $j-1$期末時点で生存 & 今期censoredされていないsubject
- $d_j$: $j$期に死亡した人数
sksurv.nonparametric
を用いた推定
1
2
3
4
5
6
7
8
9
10
11
is_therapy = df['Prior_therapy'] == 'yes'
event = df['Status'].values
time = df['Survival_in_days'].values
x_total, y_total = kaplan_meier_estimator(event, time)
x_treated, y_treated = kaplan_meier_estimator(event[is_therapy], time[is_therapy])
x_control, y_control = kaplan_meier_estimator(event[~is_therapy], time[~is_therapy])
plt.step(x_treated, y_treated, where="post", label='treated')
plt.step(x_control, y_control, where="post", label='controled')
plt.ylim(0, 1)
plt.legend();
pandasとnumpyを用いて再現したい場合
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
## x_total, y_totalの再現のみ
km_idx, km_freq = np.unique(time, return_counts=True)
km_freq = np.cumsum(km_freq[::-1])[::-1]
death_idx, death_freq = np.unique(time[event], return_counts=True)
df_left = pd.DataFrame(np.array([km_idx, km_freq],dtype=np.int64).T,columns=['time', 'obs'])
df_right = pd.DataFrame(np.array([death_idx, death_freq],dtype=np.int64).T,columns=['time', 'death'])
df_merge = (pd.merge(df_left, df_right, how='left', on='time')).fillna(0)
df_merge['ratio'] = np.cumprod(1 - df_merge['death']/df_merge['obs'])
## 95% confidence intervalの計算
df_merge['std_err'] = np.sqrt(np.cumsum(df_merge['death']/(df_merge['obs'] * (df_merge['obs'] - df_merge['death']))))
df_merge['CI_lower'] = np.clip(df_merge['ratio'] * np.exp(-1.96 * df_merge['std_err']), 0, 1)
df_merge['CI_upper'] = np.clip(df_merge['ratio'] * np.exp(1.96 * df_merge['std_err']), 0, 1)
## plotの再現
plt.step(df_merge['time'].values, df_merge['ratio'].values,where="post", label='survival')
plt.fill_between(df_merge['time'].values, df_merge['CI_lower'].values,df_merge['CI_upper'].values, color = 'b', alpha = 0.2, step="post")
plt.legend();
Confidence Intervalの導出
Nonparametricに推定したとはいえ, 推定値がどれほどreliableか?という問題はデータ分析の解釈において重要となります. 推定値のreliablityを評価する際よく用いられるのが,推定値のstandard errorとConfidence Intervalです.
なお, Notationとして $\hat f(t)$はKM推定量から得られた確率密度関数とします.
\[\hat f(t) \equiv \frac{d_t}{n_t}\]Central Limit Theorem for $\hat f(t)$
$\hat f(t)$はbinomial likelihoodに基づき計算されるので,
\[\sqrt{n_t}(\hat f(t) - f(t)) \xrightarrow{d}N(0, f(t)(1 - f(t)))\]言い換えると, $\hat f(t)$はapproximately normal with mean $f(t)$ and variance $f(t)(1 - f(t))/n_t$なので
\[Var(\hat f(t)) = \frac{d_t(n_t - d_t)}{n_t^3}\]Log transformation & Delta Method
Log transformationしたSurvival Functionを考えます
\[\log S(t) = \sum_{j \leq t}\log(1 - f(j))\]$f(t)$が互いにindependentであるという仮定の下で$\log S(t)$のVarianceはDelta Methodを用いて以下のように計算することができます.
\[Var(log(1 - \hat f(t)) = \frac{Var(\hat f(t))}{(1 - \hat f(t))^2}\]従って, $Var(\log \hat S(t))$の推定量を$\hat s^2(t)$としたとき,
\[\hat s^2(t) = \sum_{j\leq t} \frac{d_j}{n_j(n_j - d_j)}\]よって, Confidence Intervalは
\[\begin{align*} CI &= \exp(\log \hat S(t) \pm 1.96 \hat s(t))\\ &= \hat S(t)\exp(\pm 1.96 \hat s(t)) \end{align*}\]どのような仮定下正当化されるのか?
- 確率変数としての生存日数が各$t_j$ポイントで評価される離散分布に従っている場合
- 各$t_j$ポイントで評価される$f(t_j)$が互いに独立
- 生存日数がcontinuousの場合は, 別の切り口が必要
- なおここでは紹介はしないがlog-log approachでのCI導出が良いとされている(Rでは上記のCIがデフォルトで出力される)
ネルソンアーレン推定量との関係
累積ハザード関数がその瞬間瞬間ごとのリスクの合計値に着目し, イベント発生者数/リスク集合の和で評価したものをネルソンアーレン推定量といいます:
\[\tilde H(t) = \sum_{j=1}^t\frac{d_j}{n_j}\]累積ハザード関数と生存関数の関係式が
\[S(t) = \exp(-H(t))\]で表せられることに着目し, 累積ハザード関数の推定量としてのネルソンアーレン推定量, $\tilde H(t)$, から生存関数を以下のように表現することができます
\[\begin{align*} \tilde H(t) &= \sum_{j=1}^t\frac{d_j}{n_j}\\ \tilde S(t) &=\prod_{j=1}^t\exp(-\frac{d_j}{n_j}) \end{align*}\]一般にこれは小サンプルに有効な推定量と言われていますが, KM推定量との関係性として以下が知られています:
\[\tilde S(t) \geq \hat S(t) \forall t\]Proof
まず$f(x) = \exp(-x) - (1 - x)$を考えます. 導関数を確認すればわかりますが, 任意の$x$において
\[\exp(-x) \geq (1 - x)\]が成立します. 従って,
\(\begin{align*} \tilde S(t) &= \prod_{j=1}^t\exp(-\frac{d_j}{n_j})\\ &\geq \prod_{j=1}^t \left(1 - \frac{d_j}{n_j}\right)\\ &= \hat S(t) \end{align*}\)
Standard errors for the Nelson-Aalen estimator
standard errorの計算のIntuitionは, NA推定量は毎期のイベント発生 or notのbinomial probabilityの合計を計算しているのと同じなので
\(\begin{align*} Var(\tilde H(t)) &= \sum Var\left(\frac{d_j}{n_j}\right)\\ &= \sum\left(\frac{d_j(n_j - d_j)}{n_j^3}\right) \end{align*}\)
Medianの評価
生存分析のデータの多くはcensored dataなので means を計算することができません(条件付き means なら計算できますが). 一方, median は計算することができ
\[\text{Median}(t) = \min(t) \text{ s.t. } \hat S(t) \leq 0.5\]The log-rank test
問題設定
- 2つのグループの生存時間分布に差が存在するか検定したいとする, i.e., 生存関数について$H_0: S_1 = S_2$
- $n_{1j}, d_{1j}$: the time, $t_j$におけるグループ1の観察人数と死亡人数
- $n_{2j}, d_{2j}$: the time, $t_j$におけるグループ2の観察人数と死亡人数
検定統計量の導出
the time, $t_j$, におけるContingency Table
Group 1 | Group 2 | Total | |
---|---|---|---|
Deaths | $d_{1j}$ | $d_{2j}$ | $d_{j}$ |
Survivors | $n_{1j} - d_{1j}$ | $n_{2j} - d_{2j}$ | $n_{j} - d_{j}$ |
Obs | $n_{1j}$ | $n_{2j}$ | $n_{j}$ |
chi-squared distributionの導出
The null hypothesis, $S_1 = S_2$の下では, 確率変数$D_{1j}$はthe hypergeometric distribution(超幾何分布)に従うので, meanとvarianceはそれぞれ
\[\begin{align*} e_{1j} &= n_{1j}\frac{d_j}{n_j}\\ v_{1j} &= \frac{d_j}{n_j}\frac{n_j - d_j}{n_j}\frac{n_{1j}n_{2j}}{n_j - 1} \end{align*}\]このとき, $w_j = d_{1j} − e_{1j}$はaprroximatelyに$N(0, v_{1j})$に従うと考えられるので, 条件付き独立の仮定の下,
\[W \sim N(0, V), \text{ where } W = \sum_{j=1}^T w_j, V = \sum_{j=1}^T v_j\]従って, 検定統計量は, 自由度1のカイ自乗分布に従うことから
\[\frac{W^2}{V} = \frac{(\sum w_j)^2}{\sum v_j} \sim \chi^2_1\]- これはthe log-rank test, または the Cochran-Mantel-Haenszel testと呼ばれます
Multiple groupsへの拡張
Reference Group1つ(便宜上group indexは0とする)を含む$K + 1$ groupsを考えます. このとき, 次のベクトルを考えます:
\[\bold w_j = (w_{1j}, w_{2j}, \cdots, w_{Kj})\]$\bold w_j$のconditional covariance matrix, $\bold V_j$は
\[\begin{align*} (V_j)_{ii} & = \frac{d_j}{n_j}\frac{n_j - d_j}{n_j}\frac{n_{1j}(n_{j}-n_{ij})}{n_j - 1}\\ (V_j)_{ik} & = - \frac{n_{ij}n_{kj}d_j(n_j - d_j)}{n_j^2(n_j-1)}\text{ where } i \neq k \end{align*}\]Then,
\[\bold w^T \bold V^{-1} \bold w \sim \chi^2_K\]Pythonで検定関数を書いてみる
lifelines.statistics.logrank_test
と出力結果が一致することは確認済み
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from scipy import stats
import numpy as np
import pandas as pd
from collections import OrderedDict
def survival_logrank_test(duration_array, event_array, group_array):
"""K-sample log-rank hypothesis test of identical survival functions.
Compares the pooled hazard rate with each group-specific
hazard rate. The alternative hypothesis is that the hazard
rate of at least one group differs from the others at some time.
Parameters
----------
duration_array : time of event or time of censoring array, shape = (n_samples,)
event_array : the binary event indicator array
1: event observed
0: censored
group_array : array-like, shape = (n_samples,)
Group membership of each sample.
Returns
-------
chisq : float
Test statistic.
pvalue : float
Two-sided p-value with respect to the null hypothesis
that the hazard rates across all groups are equal.
stats : pandas.DataFrame
Summary statistics for each group: number of samples,
observed number of events, expected number of events,
and test statistic.
Only provided if `return_stats` is True.
covariance : array, shape=(n_groups, n_groups)
Covariance matrix of the test statistic.
Only provided if `return_stats` is True.
References
----------
.. [1] Fleming, T. R. and Harrington, D. P.
A Class of Hypothesis Tests for One and Two Samples of Censored Survival Data.
Communications In Statistics 10 (1981): 763-794.
"""
n_samples = len(duration_array)
groups, group_counts = np.unique(group_array, return_counts=True)
n_groups = groups.shape[0]
# sort descending
o = np.argsort(-duration_array, kind="mergesort")
x = group_array[o]
event = event_array[o]
time = duration_array[o]
at_risk = np.zeros(n_groups, dtype=int)
observed = np.zeros(n_groups, dtype=int)
expected = np.zeros(n_groups, dtype=float)
covar = np.zeros((n_groups, n_groups), dtype=float)
covar_indices = np.diag_indices(n_groups)
k = 0
while k < n_samples:
ti = time[k]
total_events = 0
while k < n_samples and ti == time[k]:
idx = np.searchsorted(groups, x[k])
if event[k]:
observed[idx] += 1
total_events += 1
at_risk[idx] += 1
k += 1
if total_events != 0:
total_at_risk = k
expected += at_risk * (total_events / total_at_risk)
if total_at_risk > 1:
multiplier = total_events * (total_at_risk - total_events) / (total_at_risk * (total_at_risk - 1))
temp = at_risk * multiplier
covar[covar_indices] += temp
covar -= np.outer(temp, at_risk) / total_at_risk
chi_df = n_groups - 1
numerator = observed[:chi_df] - expected[:chi_df]
chisq = np.linalg.solve(covar[:chi_df, :chi_df], numerator) @ numerator
pval = stats.chi2.sf(chisq, chi_df)
table = OrderedDict()
table["counts"] = group_counts
table["observed"] = observed
table["expected"] = expected
table["statistic"] = observed - expected
table = pd.DataFrame.from_dict(table)
table.index = pd.Index(groups, name="group", dtype=groups.dtype)
return chisq, pval, table, covar, chi_df
# numpy example
G = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
T = np.array([5, 3, 9, 8, 7, 4, 4, 3, 2, 5, 6, 7])
E = np.array([1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0])
survival_logrank_test(duration_array=T, event_array=E, group_array=G)
References
ライブラリー
Lecture Notes
- BIOST 515, 2004, Lecture 15 Introduction to Survival Analysis
- Inference for the Kaplan-Meier Estimator
Stackoverflow
(注意:GitHub Accountが必要となります)