Consider the neuroblastoma data, which has 3418 labeled examples. For subsets of increasing size, how long does it take to compute the AUM and its directional derivatives? Below we compare the aum package with `penaltyLearning::ROChange`.
# Timing benchmark on the neuroblastoma data: for increasing numbers of
# labeled examples N.pred, time penaltyLearning::ROChange() versus
# aum::aum(), with every prediction set to 0.
data(neuroblastomaProcessed, package="penaltyLearning")
library(data.table)
nb.err <- data.table(neuroblastomaProcessed$errors)
# Example id = "<profile.id>.<chromosome>", used to match rows of nb.err
# with the example names in pred.dt below.
nb.err[, example := paste(profile.id, chromosome, sep=".")]
nb.X <- neuroblastomaProcessed$feature.mat
# Sizes go up to 10^3.5 only in interactive sessions, 10^3 otherwise.
max.log <- if (interactive()) 3.5 else 3
(N.pred.vec <- as.integer(10^seq(from=1, to=max.log, by=0.5)))
#> [1] 10 31 100 316 1000
timing.dt.list <- list()
for (N.pred in N.pred.vec) {
  # The first N.pred examples, named by the rows of the feature matrix.
  N.pred.names <- rownames(nb.X)[seq_len(N.pred)]
  N.diffs.dt <- aum::aum_diffs_penalty(nb.err, N.pred.names)
  pred.dt <- data.table(example=N.pred.names, pred.log.lambda=0)
  timing.df <- microbenchmark::microbenchmark(
    penaltyLearning={
      roc.list <- penaltyLearning::ROChange(nb.err, pred.dt, "example")
    },
    aum={
      aum.list <- aum::aum(N.diffs.dt, pred.dt$pred.log.lambda)
    },
    times=10)
  # microbenchmark reports nanoseconds; store seconds.
  timing.dt.list[[as.character(N.pred)]] <- data.table(
    package=timing.df$expr,
    N.pred=N.pred,
    seconds=timing.df$time/1e9)
}
(timing.dt <- do.call(rbind, timing.dt.list))
#> package N.pred seconds
#> <fctr> <int> <num>
#> 1: aum 10 0.000097301
#> 2: penaltyLearning 10 0.509706500
#> 3: aum 10 0.000072701
#> 4: aum 10 0.000030701
#> 5: penaltyLearning 10 0.498848601
#> 6: penaltyLearning 10 0.605855601
#> 7: penaltyLearning 10 0.525148501
#> 8: penaltyLearning 10 1.007690101
#> 9: penaltyLearning 10 1.902207401
#> 10: aum 10 0.000083001
#> 11: aum 10 0.000042702
#> 12: penaltyLearning 10 3.013572501
#> 13: aum 10 0.000061801
#> 14: penaltyLearning 10 1.327742202
#> 15: penaltyLearning 10 0.670776101
#> 16: aum 10 0.000077101
#> 17: aum 10 0.000023202
#> 18: aum 10 0.000015601
#> 19: penaltyLearning 10 0.532483001
#> 20: aum 10 0.000067102
#> 21: penaltyLearning 31 0.715787702
#> 22: penaltyLearning 31 0.779073901
#> 23: penaltyLearning 31 1.532825702
#> 24: aum 31 0.000066401
#> 25: aum 31 0.000023201
#> 26: aum 31 0.000018201
#> 27: aum 31 0.000016700
#> 28: aum 31 0.000015201
#> 29: penaltyLearning 31 1.923791901
#> 30: penaltyLearning 31 0.484168101
#> 31: aum 31 0.000066101
#> 32: aum 31 0.000024601
#> 33: penaltyLearning 31 0.431930701
#> 34: penaltyLearning 31 0.541138100
#> 35: aum 31 0.000076000
#> 36: penaltyLearning 31 0.549841300
#> 37: penaltyLearning 31 0.751044400
#> 38: aum 31 0.000075600
#> 39: penaltyLearning 31 0.467962400
#> 40: aum 31 0.000074801
#> 41: aum 100 0.000101400
#> 42: penaltyLearning 100 0.739468902
#> 43: aum 100 0.000085401
#> 44: penaltyLearning 100 0.742329401
#> 45: penaltyLearning 100 0.792718902
#> 46: penaltyLearning 100 0.765943401
#> 47: aum 100 0.000081800
#> 48: penaltyLearning 100 0.671504001
#> 49: aum 100 0.000080301
#> 50: penaltyLearning 100 1.365807801
#> 51: penaltyLearning 100 1.154336400
#> 52: aum 100 0.000086300
#> 53: aum 100 0.000045001
#> 54: aum 100 0.000031501
#> 55: penaltyLearning 100 0.851427900
#> 56: aum 100 0.000093801
#> 57: penaltyLearning 100 0.831670201
#> 58: aum 100 0.000085201
#> 59: penaltyLearning 100 0.736335900
#> 60: aum 100 0.000080301
#> 61: penaltyLearning 316 1.153899301
#> 62: aum 316 0.000109401
#> 63: penaltyLearning 316 1.089968301
#> 64: penaltyLearning 316 1.275377001
#> 65: aum 316 0.000113800
#> 66: penaltyLearning 316 1.313045701
#> 67: penaltyLearning 316 0.882887001
#> 68: aum 316 0.000100200
#> 69: aum 316 0.000048601
#> 70: penaltyLearning 316 1.169340700
#> 71: aum 316 0.000105602
#> 72: penaltyLearning 316 0.699940901
#> 73: aum 316 0.000113301
#> 74: aum 316 0.000052801
#> 75: aum 316 0.000044600
#> 76: aum 316 0.000041801
#> 77: aum 316 0.000044300
#> 78: penaltyLearning 316 1.028359901
#> 79: penaltyLearning 316 1.045927102
#> 80: penaltyLearning 316 0.888910101
#> 81: penaltyLearning 1000 2.285941701
#> 82: penaltyLearning 1000 2.011495601
#> 83: aum 1000 0.000178101
#> 84: aum 1000 0.000120001
#> 85: aum 1000 0.000112201
#> 86: aum 1000 0.000106101
#> 87: penaltyLearning 1000 1.959409600
#> 88: aum 1000 0.000153801
#> 89: aum 1000 0.000107101
#> 90: aum 1000 0.000085100
#> 91: penaltyLearning 1000 2.345941000
#> 92: penaltyLearning 1000 2.092365800
#> 93: aum 1000 0.000179900
#> 94: aum 1000 0.000112001
#> 95: penaltyLearning 1000 2.053903501
#> 96: penaltyLearning 1000 2.491802701
#> 97: penaltyLearning 1000 2.234580400
#> 98: penaltyLearning 1000 2.153815601
#> 99: aum 1000 0.000191501
#> 100: penaltyLearning 1000 2.016691202
#> package N.pred seconds
Below we summarize these timings (median and quartiles over the 10 repetitions) and plot them on log-log axes.
# Median and inter-quartile range (25% and 75% quantiles) of the timings,
# for each package and number of predictions.
stats.dt <- timing.dt[, .(
  q25=quantile(seconds, 0.25),
  median=median(seconds),
  q75=quantile(seconds, 0.75)
), by=.(package, N.pred)]
library(ggplot2)
# Median time as a line, with a semi-transparent ribbon for the
# inter-quartile range. Both axes are on a log scale. The axes get
# explicit names because the default y name ("median") would not
# describe the ribbon, and neither default name gives the units.
gg <- ggplot()+
  geom_line(aes(
    N.pred, median, color=package),
    data=stats.dt)+
  geom_ribbon(aes(
    N.pred, ymin=q25, ymax=q75, fill=package),
    data=stats.dt,
    alpha=0.5)+
  # The x axis extends to 5 times the largest N.pred, leaving room on the
  # right for the direct labels.
  scale_x_log10(
    name="Number of predictions (N.pred)",
    limits=stats.dt[, c(min(N.pred), max(N.pred)*5)])+
  scale_y_log10(name="Computation time (seconds)")
directlabels::direct.label(gg, "right.polygons")
From the plot above we can see that both packages have similar asymptotic time complexity. However, aum is faster by about four orders of magnitude (speedups, computed as ratios of median timings, are shown below).
# Reshape to one row per N.pred, with one column of median seconds per
# package.
stats.wide <- data.table::dcast(
  stats.dt,
  formula=N.pred ~ package,
  value.var="median")
# Speedup = how many times larger the median penaltyLearning time is
# than the median aum time.
stats.wide[, speedup := penaltyLearning/aum]
# A trailing [] is needed to print a data.table right after :=.
stats.wide[]
#> Key: <N.pred>
#> N.pred penaltyLearning aum speedup
#> <int> <num> <num> <num>
#> 1: 10 0.6383159 6.44515e-05 9903.817
#> 2: 31 0.6328145 4.53510e-05 13953.706
#> 3: 100 0.7793312 8.35005e-05 9333.251
#> 4: 316 1.0679477 7.65005e-05 13960.009
#> 5: 1000 2.1230907 1.16101e-04 18286.584
In this section we show a base R implementation of aum, and compare its speed with the C++ implementation in the aum package on a small toy data set.
# Toy input with one row per breakpoint. Each row holds:
# - the 0-based example id;
# - a breakpoint position `pred`;
# - the change in false positives (fp_diff) and false negatives
#   (fn_diff) at that breakpoint.
diffs.df <- data.frame(
  example = c(0, 1, 1, 2, 3),
  pred    = c(0, 0, 1, 0, 0),
  fp_diff = c(1, 1, 1, 0, 0),
  fn_diff = c(0, 0, 0, -1, -1)
)
# One predicted value per example; example id k uses element k + 1.
pred.log.lambda <- c(0, 1, -1, 0)
# Benchmark the C++ aum::aum() against a pure base R re-implementation of
# the same computation, on the toy diffs.df / pred.log.lambda inputs.
microbenchmark::microbenchmark("C++"={
aum::aum(diffs.df, pred.log.lambda)
}, R={
# Threshold of each breakpoint row = its pred value minus the predicted
# value of its example (example ids are 0-based, hence the +1).
thresh.vec <- with(diffs.df, pred-pred.log.lambda[example+1])
s.vec <- order(thresh.vec)
# Breakpoint rows plus their threshold, sorted by increasing threshold.
sort.diffs <- data.frame(diffs.df, thresh.vec)[s.vec,]
# FP totals come from a cumulative sum in increasing threshold order.
# FN totals come from a cumulative sum in reverse order (largest
# threshold first).
for(fp.or.fn in c("fp","fn")){
ord.fun <- if(fp.or.fn=="fp")identity else rev
fwd.or.rev <- sort.diffs[ord.fun(1:nrow(sort.diffs)),]
fp.or.fn.diff <- fwd.or.rev[[paste0(fp.or.fn,"_diff")]]
# TRUE on the last row of each run of exactly equal thresholds, so tied
# breakpoints are combined into a single step.
last.in.run <- c(diff(fwd.or.rev$thresh.vec) != 0, TRUE)
# Cumulative total at the end of each run. The sign is flipped for fn,
# which is accumulated in reverse threshold order.
after.or.before <-
ifelse(fp.or.fn=="fp",1,-1)*cumsum(fp.or.fn.diff)[last.in.run]
# Copy each per-run value to every row of that run, by matching
# thresholds converted to character strings.
# NOTE(review): relies on as.character() keeping distinct thresholds
# distinct.
distribute <- function(values)with(fwd.or.rev, structure(
values,
names=thresh.vec[last.in.run]
)[paste(thresh.vec)])
# In traversal order:
#   "before" = total from the previous run (0 for the first run);
#   "after"  = total at the end of this run.
out.df <- data.frame(
before=distribute(c(0, after.or.before[-length(after.or.before)])),
after=distribute(after.or.before))
# For fp, store out.df columns as fp_before/fp_after. For fn, the names
# are swapped (fn_after/fn_before) and the reversal is undone, so both
# end up in increasing threshold order.
sort.diffs[
paste0(fp.or.fn,"_",ord.fun(c("before","after")))
] <- as.list(out.df[ord.fun(1:nrow(out.df)),])
}
# Area under min(FP, FN): width of each interval between consecutive
# sorted thresholds, times min(fp_before, fn_before) of the row that
# ends the interval.
AUM.vec <- with(sort.diffs, diff(thresh.vec)*pmin(fp_before,fn_before)[-1])
list(
aum=sum(AUM.vec),
# One column per direction, named "after" and "before". For each
# breakpoint row, compute s*(min(FP+s*fp_diff, FN+s*fn_diff) - min(FP, FN)):
#   - "before": s = 1, FP/FN = fp_before/fn_before;
#   - "after": s = -1, FP/FN = fp_after/fn_after.
# The results are summed per example by aggregate().
# NOTE(review): examples with no breakpoint rows would get no row here.
# Presumably comparable to deriv_mat from aum::aum(); confirm.
deriv_mat=sapply(c("after","before"),function(b.or.a){
s <- if(b.or.a=="before")1 else -1
# f("p") reads column fp_<b.or.a>; f("p","diff") reads fp_diff.
f <- function(p.or.n,suffix=b.or.a){
sort.diffs[[paste0("f",p.or.n,"_",suffix)]]
}
fp <- f("p")
fn <- f("n")
aggregate(
s*(pmin(fp+s*f("p","diff"),fn+s*f("n","diff"))-pmin(fp, fn)),
list(sort.diffs$example),
sum)$x
}))
}, times=10)
#> Unit: microseconds
#> expr min lq mean median uq max neval
#> C++ 10.0 14.201 36.6109 34.801 58.502 65.301 10
#> R 33817.3 34736.901 41731.8009 38217.101 48950.300 58765.100 10
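It is clear that the C++ implementation is several orders of magnitude faster: the median time is 34.8 microseconds for C++, versus about 38.2 milliseconds for R. Next, we compare the time taken by the sorting step of aum with the sorting functions of base R and data.table, on simulated binary data with up to one million examples.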
# Sorting benchmark: compare the sort done inside aum with base R and
# data.table sorting, on prefixes of one simulated data set.
library(data.table)
max.N <- 1e6
(N.pred.vec <- as.integer(10^seq(1, log10(max.N), by=0.5)))
#> [1] 10 31 100 316 1000 3162 10000 31622 100000
#> [10] 316227 1000000
# Simulated labels alternating 0/1, and standard normal predictions, for
# the largest size. Smaller sizes use the first N.pred of each.
# length.out is spelled out rather than relying on partial matching of l=.
max.y.vec <- rep(c(0,1), length.out=max.N)
max.diffs.dt <- aum::aum_diffs_binary(max.y.vec)
set.seed(1)
max.pred.vec <- rnorm(max.N)
timing.dt.list <- list()
for(N.pred in N.pred.vec){
  print(N.pred) # progress message
  # Assumes aum_diffs_binary gives one diff row per example, so the first
  # N.pred rows correspond to the first N.pred predictions.
  N.diffs.dt <- max.diffs.dt[seq_len(N.pred)]
  # Use a separate name for the prediction subset. Assigning it to
  # N.pred.vec (as before) overwrote the vector of sizes being looped
  # over, leaving a length-1e6 numeric vector in its place after the loop.
  N.pred.values <- max.pred.vec[seq_len(N.pred)]
  timing.df <- microbenchmark::microbenchmark(dt_sort={
    N.diffs.dt[order(N.pred.values)]
  }, R_sort_radix={
    sort(N.pred.values, method="radix")
  }, R_sort_quick={
    sort(N.pred.values, method="quick")
  }, aum_sort={
    aum.list <- aum:::aum_sort_interface(N.diffs.dt, N.pred.values)
  }, times=10)
  # microbenchmark reports nanoseconds; store seconds.
  timing.dt.list[[paste(N.pred)]] <- with(timing.df, data.table(
    package=expr, N.pred, seconds=time/1e9))
}
#> [1] 10
#> [1] 31
#> [1] 100
#> [1] 316
#> [1] 1000
#> [1] 3162
#> [1] 10000
#> [1] 31622
#> [1] 100000
#> [1] 316227
#> [1] 1000000
(timing.dt <- do.call(rbind, timing.dt.list))
#> package N.pred seconds
#> <fctr> <int> <num>
#> 1: aum_sort 10 0.000056900
#> 2: R_sort_radix 10 0.000169800
#> 3: aum_sort 10 0.000026800
#> 4: R_sort_quick 10 0.000061301
#> 5: aum_sort 10 0.000013301
#> ---
#> 436: aum_sort 1000000 0.529329601
#> 437: dt_sort 1000000 0.159101501
#> 438: aum_sort 1000000 0.531869401
#> 439: R_sort_radix 1000000 0.132736401
#> 440: R_sort_quick 1000000 0.099426701
Below we summarize these sorting timings (median and quartiles over the 10 repetitions) and plot them on log-log axes.
# Median and inter-quartile range (25% and 75% quantiles) of the timings,
# for each sorting method and number of predictions.
stats.dt <- timing.dt[, .(
  q25=quantile(seconds, 0.25),
  median=median(seconds),
  q75=quantile(seconds, 0.75)
), by=.(package, N.pred)]
library(ggplot2)
# Median time as a line, with a semi-transparent ribbon for the
# inter-quartile range. Both axes are on a log scale. The axes get
# explicit names because the default y name ("median") would not
# describe the ribbon, and neither default name gives the units.
gg <- ggplot()+
  geom_line(aes(
    N.pred, median, color=package),
    data=stats.dt)+
  geom_ribbon(aes(
    N.pred, ymin=q25, ymax=q75, fill=package),
    data=stats.dt,
    alpha=0.5)+
  # The x axis extends to 5 times the largest N.pred, leaving room on the
  # right for the direct labels.
  scale_x_log10(
    name="Number of predictions (N.pred)",
    limits=stats.dt[, c(min(N.pred), max(N.pred)*5)])+
  scale_y_log10(name="Computation time (seconds)")
directlabels::direct.label(gg, "right.polygons")