Deconstructing Kaplan-Meier

This is an attempt to calculate the Kaplan-Meier (KM) estimate “by hand” to better understand the underlying math. I am starting with this tutorial.

Code
library(tidyverse)

Recreating the data as it is presented

Code
# Twenty participants. Each one has exactly one of the two year columns
# filled in: the year of death, or the year they were last seen.
id <- seq_len(20)

yr_of_death <- c(
  NA,  3, NA, NA, NA, NA, 14, NA, NA, NA,
  NA, NA, NA,  1, NA, 23, NA,  5, NA, 17
)

yr_last <- c(
  24, NA, 11, 19, 24, 13, NA,  2, 18, 17,
  24, 21, 12, NA, 10, NA,  6, NA,  9, NA
)

# Column names are taken from the variable names.
df <- data.frame(id, yr_of_death, yr_last)

# Show the first six participants as a formatted table.
df |>
  head() |>
  kableExtra::kable()
id yr_of_death yr_last
1 NA 24
2 3 NA
3 NA 11
4 NA 19
5 NA 24
6 NA 13

Reshaping the data into the structure needed for further analysis: a year-by-year timeline for each participant.

Code
# Expand the one-row-per-person data into one row per person per study year.
# The follow-up horizon is 24 years (see `remain_time` and the `cut()` breaks).

# End point of each participant: the year of death when `yr_last` is missing,
# otherwise the year of last follow-up. `remain_time` is the number of study
# years left after that end point.
uncounted <- df |>
  mutate(
    end_time = ifelse(is.na(yr_last), yr_of_death, yr_last),
    end_event = ifelse(is.na(yr_last), "death", "f/u"),
    remain_time = 24 - end_time
  )

# Years 1..end_time: the participant is still observed (status "in").
# The event ("death" or "f/u") is recorded on the final observed year only;
# every earlier year is coded "0".
until_event <- uncounted |>
  uncount(end_time, .remove = FALSE) |>
  mutate(
    yr = row_number(),
    event = ifelse(yr == end_time, end_event, "0"),
    n = n(),
    status = "in",
    .by = id
  )

# Years end_time + 1 .. 24: the participant is no longer observed. These rows
# carry no event and a status of "lost-death" or "lost-f/u". Years are
# generated from 24 downwards; `event_table` sorts them afterwards.
after_event <- uncounted |>
  uncount(remain_time, .remove = FALSE) |>
  mutate(
    yr = end_time + remain_time - (row_number() - 1),
    event = "0",
    n = n(),
    status = paste0("lost-", end_event),
    .by = id
  )

# Full person-year table, with each year also assigned to a five-year
# interval (used later for the actuarial table).
event_table <- bind_rows(until_event, after_event) |>
  arrange(id, yr) |>
  select(id, yr, status, event) |>
  mutate(
    yr_group = cut(yr, breaks = c(0, 4, 9, 14, 19, 24)),
    yr_group_label = case_when(
      yr <= 4 ~ "0-4",
      yr <= 9 ~ "5-9",
      yr <= 14 ~ "10-14",
      yr <= 19 ~ "15-19",
      yr <= 24 ~ "20-24",
      .default = "Other"
    )
  )

Kaplan-Meier

Summarizes the data for the KM plot, giving survival probability per time point

Method 1

Code
# Method 1: Kaplan-Meier estimate built from the person-year table.

# Number at risk per year = number of participants whose status is still "in".
km_no_at_risk <- event_table |>
  summarise(n_atrisk = n(), .by = c(yr, status)) |>
  arrange(yr) |>
  filter(status == "in")

# Deaths and losses to follow-up per year, one column per event type.
# `values_fill = 0` turns "no such event this year" into 0 at pivot time, and
# the `x0` column (person-years without an event) is dropped.
km_no_events <- event_table |>
  summarise(n = n(), .by = c(yr, event)) |>
  arrange(yr) |>
  pivot_wider(
    id_cols = yr,
    names_from = event,
    values_from = n,
    values_fill = 0
  ) |>
  janitor::clean_names() |>
  select(-x0)

# Add a baseline row for year 0 with the whole cohort at risk and no events,
# so the curve starts at 1 without needing to patch a 0/0 division.
# `idk` is the conditional probability of surviving the year given being at
# risk; the survival estimate is its running product.
km_data <- km_no_at_risk |>
  left_join(km_no_events, by = "yr") |>
  add_row(
    yr = 0,
    status = "in",
    n_atrisk = n_distinct(event_table$id),
    death = 0,
    f_u = 0,
    .before = 1
  ) |>
  mutate(
    idk = (n_atrisk - death) / n_atrisk,
    surv_prop = cumprod(idk)
  )


## Old loop - replaced by cumprod() above
# output <- c()           
# input <- kmtable1$idk
# output[1] <- input[1] * 1
# 
# for (i in 2:length(input)) {
#   output[i] <- input[i] * output[i - 1]
# }
# 
# 
# km_data <- cbind(kmtable1, surv_prop = output) 

# Keep only the first year at which each survival value appears, i.e. the
# points where the curve steps down.
km_data |>
  distinct(surv_prop, .keep_all = TRUE) |>
  select(yr, surv_prop)
  yr surv_prop
1  0 1.0000000
2  1 0.9500000
3  3 0.8972222
4  5 0.8444444
5 14 0.7600000
6 17 0.6755556
7 23 0.5066667

Method 2

This does the same as Method 1, but without breaking the data down into one row per person per year.

Code
# Method 2: Kaplan-Meier estimate straight from the one-row-per-person data.
km_data2 <- df |>
  # `yr` is the participant's end point; the two original year columns are
  # then overwritten with 0/1 indicators (1 = this kind of exit happened).
  mutate(
    yr = ifelse(is.na(yr_last), yr_of_death, yr_last),
    event = ifelse(is.na(yr_last), "death", "lost f_u"),
    across(yr_of_death:yr_last, \(x) ifelse(is.na(x), 0, 1))
  ) |>
  arrange(yr) |>
  # Deaths and losses occurring in each year.
  mutate(
    t_death = sum(yr_of_death),
    t_lost = sum(yr_last),
    .by = yr
  ) |>
  # Number at risk = cohort size minus everyone who left before this row.
  # `default = 0` makes the first row the full cohort whatever the earliest
  # event time is (the previous version only worked when it was year 1).
  mutate(
    cum_death = cumsum(yr_of_death),
    cum_lost = cumsum(yr_last),
    cum_loss = cum_death + cum_lost,
    at_risk = nrow(df) - lag(cum_loss, default = 0)
  ) |>
  # When several people leave in the same year, keep the row with the largest
  # at-risk count, i.e. the count before any of them left.
  arrange(yr, -at_risk) |>
  distinct(yr, t_death, t_lost, .keep_all = TRUE) |>
  select(yr, t_death, t_lost, at_risk) |>
  # Conditional survival per year and its running product.
  mutate(
    prt1 = (at_risk - t_death) / at_risk,
    surv_prop = cumprod(prt1)
  )


## Old loop - replaced by cumprod() above
# output <- c()           
# input <- kmtable2$prt1
# output[1] <- input[1] * 1
# 
# for (i in 2:length(input)) {
#   output[i] <- input[i] * output[i - 1]
# }
# 
# 
# km_data2 <- cbind(kmtable2, surv_prop = output) 

# Keep only the first year at which each survival value appears.
km_data2 |>
  distinct(surv_prop, .keep_all = TRUE) |>
  select(yr, surv_prop)
  yr surv_prop
1  1 0.9500000
2  3 0.8972222
3  5 0.8444444
4 14 0.7600000
5 17 0.6755556
6 23 0.5066667

It is the same as using the functions from survival

Code
library(survival)


# Cross-check against survival::survfit().
# Keep one row per participant (the year their event was recorded) and code
# the status as 2 = death, 1 = lost to follow-up.
# `do()` supplies the data frame as `.`, so the fit is stored in a list
# column from which the single survfit object is extracted and summarised.
event_table |>
  filter(event != "0") |>
  mutate(event2 = if_else(event == "death", 2, 1)) |>
  do(surv = survfit(Surv(yr, event2) ~ 1, data = ., conf.type = "plain")) |>
  pluck("surv", 1) |>
  summary()
Call: survfit(formula = Surv(yr, event2) ~ 1, data = ., conf.type = "plain")

 time n.risk n.event survival std.err lower 95% CI upper 95% CI
    1     20       1    0.950  0.0487        0.854        1.000
    3     18       1    0.897  0.0689        0.762        1.000
    5     17       1    0.844  0.0826        0.682        1.000
   14     10       1    0.760  0.1093        0.546        0.974
   17      9       1    0.676  0.1256        0.429        0.922
   23      4       1    0.507  0.1740        0.166        0.848

Gets the figure for the KM

Code
# Draw the KM curve as a step function by hand: every year contributes two
# points, the previous survival value (`surv2`) and the current one, so the
# connecting line goes flat and then drops vertically.
km_data |>
  mutate(surv2 = lag(surv_prop)) |>
  pivot_longer(
    c(surv_prop, surv2),
    names_to = "type",
    values_to = "surv"
  ) |>
  # Within a year, plot the higher (earlier) value first.
  arrange(yr, desc(surv)) |>
  # The very first lagged value has no predecessor; the curve starts at 1.
  mutate(surv = coalesce(surv, 1)) |>
  ggplot(mapping = aes(yr, surv)) +
  geom_point() +
  geom_line() +
  theme_minimal()

Confidence intervals

Code
# "Plain" (linear) 95% confidence intervals using Greenwood's formula:
#   se(S(t)) = S(t) * sqrt( sum over event times of d / (n * (n - d)) )
# where d = deaths and n = number at risk at each time.
km_data2 |>
  mutate(
    step1 = t_death / (at_risk * (at_risk - t_death)),
    step2 = cumsum(step1),
    se = sqrt(step2) * surv_prop,
    # Survival is a probability, so the bounds are clamped to [0, 1]; this
    # also matches what survfit(conf.type = "plain") reports.
    ci_l = pmax(surv_prop - 1.96 * se, 0),
    ci_u = pmin(surv_prop + 1.96 * se, 1)
  ) |>
  select(-step1, -step2, -prt1) |>
  # Keep only the rows where the estimate changes (the death times).
  distinct(surv_prop, se, ci_l, ci_u, .keep_all = TRUE)
  yr t_death t_lost at_risk surv_prop         se      ci_l      ci_u
1  1       1      0      20 0.9500000 0.04873397 0.8544814 1.0455186
2  3       1      0      18 0.8972222 0.06891433 0.7621501 1.0322943
3  5       1      0      17 0.8444444 0.08263493 0.6824800 1.0064089
4 14       1      0      10 0.7600000 0.10931097 0.5457505 0.9742495
5 17       1      1       9 0.6755556 0.12561705 0.4293461 0.9217650
6 23       1      0       4 0.5066667 0.17397885 0.1656681 0.8476652

- Actuarial Table

Gives life table (actuarial table) by sorting into time intervals

Code
### Actuarial table -----------------------------------------------

# Deaths and losses to follow-up per interval, one column per event type.
# `event` is a character column, so it is compared with the string "0".
# `values_fill = 0` keeps intervals without a death or a loss as 0, not NA.
prt1 <- event_table |>
  summarise(n = n(), .by = c(yr_group, event)) |>
  arrange(yr_group) |>
  filter(event != "0") |>
  pivot_wider(
    id_cols = yr_group,
    names_from = event,
    values_from = n,
    values_fill = 0
  ) |>
  janitor::clean_names()

# Number still under observation ("in") in the first year of each interval.
prt2 <- event_table |>
  summarise(n_alive = n(), .by = c(yr, yr_group, status)) |>
  filter(status == "in") |>
  slice_min(yr, by = yr_group) |>
  select(-status)

act_table <- prt1 |>
  left_join(prt2, by = "yr_group")

# Actuarial adjustment: participants lost during an interval count as being
# at risk for half of it.
# NOTE(review): prop_death is rounded to 3 decimals before it feeds the
# cumulative survival, presumably to match the tutorial's table - confirm.
act_table2 <- act_table |>
  mutate(
    at_risk = n_alive - (f_u / 2),
    prop_death = ncar::Round(death / at_risk, 3),
    prop_surv_atrisk = 1 - prop_death
  )


# Cumulative survival = running product of the per-interval survival
# proportions. cumprod() replaces the hand-written loop, which grew `output`
# element by element and iterated over 2:length(input), a sequence that
# counts backwards (2, 1) when there is only one interval.
input <- act_table2$prop_surv_atrisk
output <- cumprod(input)

output
[1] 0.8970000 0.8404890 0.7707284 0.6682215 0.4457038
Code
# Attach the cumulative survival to the actuarial table as a plain data frame.
data.frame(act_table2, surv_prop = output, check.names = FALSE)
  yr_group death f_u yr n_alive at_risk prop_death prop_surv_atrisk surv_prop
1    (0,4]     2   1  1      20    19.5      0.103            0.897 0.8970000
2    (4,9]     1   2  5      17    16.0      0.063            0.937 0.8404890
3   (9,14]     1   4 10      14    12.0      0.083            0.917 0.7707284
4  (14,19]     1   3 15       9     7.5      0.133            0.867 0.6682215
5  (19,24]     1   4 20       5     3.0      0.333            0.667 0.4457038

Dataset 2

Using a different dataset, taken from: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1065034/

I was too lazy to type in each row so I copied in html text for parsing

Code
# Raw HTML of the table body: 15 <tr> rows, each with 5 <td> cells
# (id, survival time, outcome, treatment, age - named further below).
# The literal contains a line break inside one of the tags.
char <- "<tbody><tr><td align='left' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>75</td></tr><tr><td align='left' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>79</td></tr><tr><td align='left' rowspan='1' colspan='1'>3</td><td align='center' rowspan='1' colspan='1'>4</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>85</td></tr><tr><td align='left' rowspan='1' colspan='1'>4</td><td align='center' rowspan='1' colspan='1'>5</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>76</td></tr><tr><td align='left' rowspan='1' colspan='1'>5</td><td align='center' rowspan='1' colspan='1'>6</td><td align='center' rowspan='1' colspan='1'>Unknown</td><td align='center' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>66</td></tr><tr><td align='left' rowspan='1' colspan='1'>6</td><td align='center' rowspan='1' colspan='1'>8</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>75</td></tr><tr><td align='left' rowspan='1' colspan='1'>7</td><td align='center' rowspan='1' colspan='1'>9</td><td align='center' rowspan='1' colspan='1'>Survived</td><td align='center' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>72</td></tr><tr><td align='left' rowspan='1' colspan='1'>8</td><td align='center' rowspan='1' colspan='1'>9</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' 
colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>70</td></tr><tr><td align='left' rowspan='1' colspan='1'>9</td><td align='center' rowspan='1' colspan='1'>12</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>71</td></tr><tr><td align='left' rowspan='1' colspan='1'>10</td><td align='center' rowspan='1' colspan='1'>15</td><td align='center' rowspan='1' colspan='1'>Unknown</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>73</td></tr><tr><td align='left' rowspan='1' colspan='1'>11</td><td align='center' rowspan='1' colspan='1'>22</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>2</td><td align='center' rowspan='1' colspan='1'>66</td></tr><tr><td align='left' rowspan='1' colspan='1'>12</td><td align='center' rowspan='1' colspan='1'>25</td><td align='center' rowspan='1' colspan='1'>Survived</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>73</td></tr><tr><td align='left' rowspan='1' colspan='1'>13</td><td align='center' rowspan='1' colspan='1'>37</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>68</td></tr><tr><td align='left' rowspan='1' colspan='1'>14</td><td align='center' rowspan='1' colspan='1'>55</td><td align='center' rowspan='1' colspan='1'>Died</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>59</td></tr><tr><td align='left' rowspan='1' colspan='1'>15</td><td align='center' rowspan='1' colspan='1'>72</td><td align='center' rowspan='1' colspan='1'>Survived</td><td align='center' rowspan='1' colspan='1'>1</td><td align='center' rowspan='1' colspan='1'>61</td></tr></tbody>"


# Parse the HTML into a nested list.
char2 <- char |>
  rvest::read_html() |>
  xml2::as_list()

# Flatten the list into one vector of cell texts and refill it row by row
# into a five-column table. Every column is character at this point.
df2 <- char2 |>
  unlist() |>
  matrix(ncol = 5, byrow = TRUE) |>
  as.data.frame() |>
  setNames(c("id", "days_surv", "outcome", "treatment", "age"))

# Show the first six rows as a formatted table.
df2 |>
  head() |>
  kableExtra::kable()
id days_surv outcome treatment age
1 1 Died 2 75
2 1 Died 2 79
3 4 Died 2 85
4 5 Died 2 76
5 6 Unknown 2 66
6 8 Died 1 75
Code
# Kaplan-Meier table per treatment arm.
kmtable3 <- df2 |>
  mutate(dayz = as.numeric(days_surv)) |>
  arrange(treatment, dayz) |>
  # 0/1 indicators: died, or left the study for any other reason.
  mutate(
    yr_of_death = as.numeric(outcome == "Died"),
    yr_last = as.numeric(outcome != "Died")
  ) |>
  # One row per day and arm. t_loss counts everyone leaving that day,
  # which equals t_death + t_lost.
  summarise(
    t_loss = n(),
    t_death = sum(yr_of_death),
    t_lost = sum(yr_last),
    .by = c(dayz, treatment)
  ) |>
  # At risk = arm size minus everyone who left on an earlier day; then the
  # conditional survival per day and its running product within the arm.
  mutate(
    at_risk = sum(t_loss) - lag(cumsum(t_loss), default = 0L),
    prt1 = (at_risk - t_death) / at_risk,
    surv_prop = accumulate(prt1, \(running, p) running * p),
    .by = treatment
  )



# Survival estimate at every recorded day, per treatment arm.
kmtable3 |>
  select(dayz, treatment, surv_prop) |>
  distinct()
   dayz treatment surv_prop
1     8         1 0.8571429
2    12         1 0.7142857
3    15         1 0.7142857
4    25         1 0.7142857
5    37         1 0.4761905
6    55         1 0.2380952
7    72         1 0.2380952
8     1         2 0.7500000
9     4         2 0.6250000
10    5         2 0.5000000
11    6         2 0.5000000
12    9         2 0.3333333
13   22         2 0.0000000

Confidence intervals

Code
# "Plain" (linear) 95% confidence intervals per treatment arm, using
# Greenwood's formula:
#   se(S(t)) = S(t) * sqrt( sum over event times of d / (n * (n - d)) )
kmtable3 |>
  mutate(
    step1 = t_death / (at_risk * (at_risk - t_death)),
    step2 = cumsum(step1),
    se = sqrt(step2) * surv_prop,
    # Clamp both bounds to [0, 1], as survfit(conf.type = "plain") does.
    # When the last person at risk dies, step1 is 1/0 = Inf and the standard
    # error and both bounds come out as NaN.
    ci_l = pmax(surv_prop - 1.96 * se, 0),
    ci_u = pmin(surv_prop + 1.96 * se, 1),
    .by = treatment
  ) |>
  select(-step1, -step2, -prt1) |>
  # Keep only the rows where the estimate changes. `treatment` is part of the
  # key so identical values in the two arms are never merged.
  distinct(treatment, surv_prop, se, ci_l, ci_u, .keep_all = TRUE)
  dayz treatment t_loss t_death t_lost at_risk surv_prop        se       ci_l
1    8         1      1       1      0       7 0.8571429 0.1322600 0.59791323
2   12         1      1       1      0       6 0.7142857 0.1707469 0.37962170
3   37         1      1       1      0       3 0.4761905 0.2252786 0.03464437
4   55         1      1       1      0       2 0.2380952 0.2025643 0.00000000
5    1         2      2       2      0       8 0.7500000 0.1530931 0.44993751
6    4         2      1       1      0       6 0.6250000 0.1711633 0.28951993
7    5         2      1       1      0       5 0.5000000 0.1767767 0.15351768
8    9         2      2       1      1       3 0.3333333 0.1800206 0.00000000
9   22         2      1       1      0       1 0.0000000       NaN         NA
       ci_u    other_
1 1.1163725 1.1163725
2 1.0489497 1.0489497
3 0.9177366 0.9177366
4 0.6351212 0.6351212
5 1.0500625 1.0500625
6 0.9604801 0.9604801
7 0.8464823 0.8464823
8 0.6861737 0.6861737
9       NaN       NaN

Confirm with the survival package

Code
## Cross-check: the same estimate from survival::survfit(), one curve per
## treatment arm. event = 1 for "Died", 0 for every other outcome.
## `do()` supplies the data frame as `.`; the fitted object is pulled out of
## the resulting list column and summarised.
df2 |>
  mutate(
    days_surv = as.numeric(days_surv),
    event = as.numeric(outcome == "Died")
  ) |>
  do(surv = survfit(Surv(days_surv, event) ~ treatment, data = ., conf.type = "plain")) |>
  pluck("surv", 1) |>
  summary()
Call: survfit(formula = Surv(days_surv, event) ~ treatment, data = ., 
    conf.type = "plain")

                treatment=1 
 time n.risk n.event survival std.err lower 95% CI upper 95% CI
    8      7       1    0.857   0.132       0.5979        1.000
   12      6       1    0.714   0.171       0.3796        1.000
   37      3       1    0.476   0.225       0.0347        0.918
   55      2       1    0.238   0.203       0.0000        0.635

                treatment=2 
 time n.risk n.event survival std.err lower 95% CI upper 95% CI
    1      8       2    0.750   0.153        0.450        1.000
    4      6       1    0.625   0.171        0.290        0.960
    5      5       1    0.500   0.177        0.154        0.846
    9      3       1    0.333   0.180        0.000        0.686
   22      1       1    0.000     NaN          NaN          NaN