Using Statistics to Make Statistics Faster
While working on my side project lifelines, I noticed a surprising behaviour. In lifelines, there are two classes that implement the Cox proportional hazards model. The first class, CoxTimeVaryingFitter, is used for time-varying datasets. Time-varying datasets require a more complicated algorithm, one that works by iterating over all unique times and "pulling out" the relevant rows associated with that time. It's a slow algorithm, as it requires lots of Python/NumPy indexing, which gets worse as the dataset size grows. Call this algorithm batch. The second implemented class, CoxPHFitter, is used for static datasets. CoxPHFitter uses a simpler algorithm that iterates over every row once only, and requires minimal indexing. Call this algorithm single¹.
The strange behaviour I noticed was that, for my benchmark static dataset, the more complicated CoxTimeVaryingFitter was actually faster than my simpler CoxPHFitter model: almost twice as fast.
The thought occurred to me that iterating over all unique times has an advantage versus iterating over all rows when the cardinality of times is small relative to the number of rows. That is, when there are lots of ties in the dataset, our batch algorithm should perform faster. In one extreme limit, when there is no variation in times, the batch algorithm will perform a single loop and finish. However, in the other extreme limit where all times are unique, our batch algorithm does expensive indexing too often.
This means that the two algorithms will perform differently on different datasets (though the returned results should still be identical). The difference in runtime can be a factor of 2, if not more. But given a dataset at runtime, how can I know which algorithm to choose? It would be unwise to let the untrained user decide; this decision should be abstracted away from them.
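To make the two loop structures concrete, here's a toy sketch in plain Python. It is not the lifelines code: the function names are made up, and the "work" is just a sum per time. It only shows that the single version loops once per row, while the batch version loops once per unique time and pulls out the tied rows inside each pass.

```python
from collections import defaultdict


def sum_by_time_single(times, values):
    """Visit every row once, updating a running total for that row's time."""
    totals = defaultdict(float)
    n_iterations = 0
    for t, v in zip(times, values):
        totals[t] += v
        n_iterations += 1
    return dict(totals), n_iterations


def sum_by_time_batch(times, values):
    """Loop over unique times and pull out all rows tied at that time."""
    totals = {}
    n_iterations = 0
    for t in sorted(set(times)):
        # This selection is the stand-in for the expensive indexing step.
        tied_rows = [v for s, v in zip(times, values) if s == t]
        totals[t] = float(sum(tied_rows))
        n_iterations += 1
    return totals, n_iterations


times = [1, 1, 2, 2, 2, 3]
values = [0.5, 1.5, 1.0, 1.0, 1.0, 2.0]

single_totals, single_loops = sum_by_time_single(times, values)
batch_totals, batch_loops = sum_by_time_batch(times, values)

# Same answer either way; only the number of outer-loop passes differs.
assert single_totals == batch_totals
print(single_loops, batch_loops)
```

With many ties the batch loop runs far fewer passes, but each pass pays for the row selection. In pure Python that selection is itself a scan, so this toy says nothing about real timings.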
Generating artificial datasets
As a first step, I wanted to know how often the batch algorithm outperformed the single algorithm. To do this, I needed to create datasets with different sizes and varying fractions of tied times. For the latter characteristic, I defined a (very naive) statistic, frac, as:
$$n_{\text{unique times}} = \text{frac} \cdot n_{\text{rows}}$$
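As a rough illustration of that statistic, here's one way to compute frac for a list of times, plus a small generator that produces a column of times with a target frac. Both tie_fraction and make_times are hypothetical helpers written for this example, and the generator is only one of many ways to do it.

```python
import random


def tie_fraction(times):
    """frac = (number of unique times) / (number of rows)."""
    return len(set(times)) / len(times)


def make_times(n_rows, frac, seed=0):
    """Build n_rows times drawn from about frac * n_rows distinct values."""
    rng = random.Random(seed)
    n_unique = max(1, round(frac * n_rows))
    distinct = list(range(1, n_unique + 1))
    # Use every distinct value once, then fill the remaining rows with ties.
    times = distinct + [rng.choice(distinct) for _ in range(n_rows - n_unique)]
    rng.shuffle(times)
    return times


times = make_times(n_rows=432, frac=0.25)
print(len(times), tie_fraction(times))
```

A frac near 0 means almost every row shares its time with other rows; a frac of 1 means every time is unique.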
Once dataset creation was done, I generated many combinations and timed the performance of both the batch algorithm (now natively ported to CoxPHFitter) and the single algorithm. A sample of the output is below (the time units are seconds).
| N | frac | batch (s) | single (s) |
|---|---|---|---|
| 432 | 0.010 | 0.175 | 0.249 |
| 432 | 0.099 | 0.139 | 0.189 |
| 432 | 0.189 | 0.187 | 0.198 |
| ... | ... | ... | ... |
So, for 432 rows and a very high number of ties (i.e. a low count of unique times), the batch algorithm takes 0.18 seconds versus 0.25 seconds for the single algorithm. Since I'm only comparing two algorithms and I'm interested in the faster one, the ratio of batch performance to single performance is just as meaningful (a ratio below 1 means batch was faster). Here's all the raw data.
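For the three sample rows shown in the table, the ratio works out as in this small sketch. The timings are copied from the table; the variable names are only for illustration.

```python
# (N, frac, batch seconds, single seconds), copied from the sample rows above.
rows = [
    (432, 0.010, 0.175, 0.249),
    (432, 0.099, 0.139, 0.189),
    (432, 0.189, 0.187, 0.198),
]

# ratio = batch time / single time; below 1.0 means batch was faster.
ratios = [batch / single for _, _, batch, single in rows]

for (n, frac, _, _), ratio in zip(rows, ratios):
    faster = "batch" if ratio < 1.0 else "single"
    print(f"N={n} frac={frac:.3f} ratio={ratio:.3f} faster={faster}")
```

In all three rows the ratio is below 1, and it creeps toward 1 as frac grows.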
Looking at the raw data, it's not clear what the relationship is between N, frac, and ratio. Let's plot the data instead:
What we can see is that the ratio variable increases almost linearly with N and frac. Almost linearly. At this point, one idea is to compute both statistics (N, frac) for a given dataset (at runtime, it's cheap), and predict what its ratio is going to be. If that value is above 1, use the single algorithm; otherwise use the batch algorithm.
We can run a simple linear regression against the dataset of runtimes, and we get the following:
import statsmodels.api as sm
X = results[["N", "frac"]]
X = sm.add_constant(X)
Y = results["ratio"]
model = sm.OLS(Y, X).fit()
print(model.summary())
                            OLS Regression Results
==============================================================================
Dep. Variable:                  ratio   R-squared:                       0.933
Model:                            OLS   Adj. R-squared:                  0.931
Method:                 Least Squares   F-statistic:                     725.8
Date:                Thu, 03 Jan 2019   Prob (F-statistic):           3.34e-62
Time:                        13:11:37   Log-Likelihood:                 68.819
No. Observations:                 108   AIC:                            -131.6
Df Residuals:                     105   BIC:                            -123.6
Df Model:                           2
Covariance Type:            nonrobust
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const          0.3396      0.030     11.367      0.000       0.280       0.399
N           4.135e-05   4.63e-06      8.922      0.000    3.22e-05    5.05e-05
frac           1.5038      0.041     37.039      0.000       1.423       1.584
==============================================================================
Omnibus:                        3.371   Durbin-Watson:                   1.064
Prob(Omnibus):                  0.185   Jarque-Bera (JB):                3.284
Skew:                          -0.166   Prob(JB):                        0.194
Kurtosis:                       3.787   Cond. No.                     1.77e+04
==============================================================================
Plotting this plane-of-best-fit:
We can see that the fit is pretty good, but there are some non-linearities in the second figure that we aren't capturing. We should expect non-linearities too: in the batch algorithm, the number of batches (one per unique time) is N * frac, so this interaction should be a factor in the batch algorithm's performance. Let's include that interaction in our regression:
import statsmodels.api as sm
results["N * frac"] = results["N"] * results["frac"]
X = results[["N", "frac", "N * frac"]]
X = sm.add_constant(X)
Y = results["ratio"]
model = sm.OLS(Y, X).fit()
print(model.summary())
                            OLS Regression Results
==============================================================================
Dep. Variable:                  ratio   R-squared:                       0.965
Model:                            OLS   Adj. R-squared:                  0.964
Method:                 Least Squares   F-statistic:                     944.4
Date:                Thu, 03 Jan 2019   Prob (F-statistic):           2.89e-75
Time:                        13:16:48   Log-Likelihood:                 103.62
No. Observations:                 108   AIC:                            -199.2
Df Residuals:                     104   BIC:                            -188.5
Df Model:                           3
Covariance Type:            nonrobust
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const          0.5465      0.030     17.941      0.000       0.486       0.607
N          -1.187e-05   6.44e-06     -1.843      0.068   -2.46e-05    9.05e-07
frac           1.0899      0.052     21.003      0.000       0.987       1.193
N * frac       0.0001    1.1e-05      9.702      0.000    8.47e-05       0.000
==============================================================================
Omnibus:                       10.775   Durbin-Watson:                   1.541
Prob(Omnibus):                  0.005   Jarque-Bera (JB):               21.809
Skew:                          -0.305   Prob(JB):                     1.84e-05
Kurtosis:                       5.115   Cond. No.                     3.43e+04
==============================================================================
Looks like we capture more of the variance (R² rises from 0.933 to 0.965), and the visual fit looks better too! So where are we?
- Given a dataset, I can compute its statistics, N and frac, at runtime.
- I enter these into my linear model (with an interaction term), and it predicts the ratio of batch performance to single performance.
- If this prediction is greater than 1.0, I choose single; otherwise batch.
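The three steps above can be sketched as follows. This is a minimal illustration and not the code that ships in lifelines: the function names are invented, and the coefficients are copied from the printed summary of the second regression, where the N * frac coefficient is shown rounded to 0.0001.

```python
# Coefficients as printed in the regression summary with the interaction term.
# The "N * frac" value is rounded in that printout, so treat it as approximate.
INTERCEPT = 0.5465
COEF_N = -1.187e-05
COEF_FRAC = 1.0899
COEF_N_FRAC = 0.0001


def predict_ratio(n_rows, frac):
    """Predicted (batch time / single time) for a dataset of this shape."""
    return (
        INTERCEPT
        + COEF_N * n_rows
        + COEF_FRAC * frac
        + COEF_N_FRAC * n_rows * frac
    )


def choose_algorithm(times):
    """Pick "single" when batch is predicted to be slower, else "batch"."""
    n_rows = len(times)
    frac = len(set(times)) / n_rows
    return "single" if predict_ratio(n_rows, frac) > 1.0 else "batch"


heavily_tied = [t % 5 for t in range(1000)]  # 5 unique times in 1000 rows
all_unique = list(range(1000))               # every time is distinct

print(choose_algorithm(heavily_tied), choose_algorithm(all_unique))
```

Computing N and frac costs one pass over the times column, which is negligible next to fitting the model.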
This idea was recently implemented in lifelines, and with some other optimizations to the batch algorithm, we see a 60% speedup on some datasets!
Notes and extensions
- This is a binary problem, batch vs single, so why not use logistic regression? Frank Harrell, one of the greatest statisticians, would not be happy with that. He advocates against false dichotomies in statistics (e.g. rate > some arbitrary threshold, unnecessary binning of continuous variables, etc.). This is a case of that: we should be modelling the ratio itself and not merely whether it falls above or below 1. In doing so, I retain maximum information for the algorithm to use.
- More than 2 algorithms to choose from? Instead of modelling the ratio, I can model the performance per algorithm, and choose the algorithm that is associated with the smallest prediction.
- L1, L2-Penalizers? Cross-validation? IDGAF. These give me minimal gains. Adding the interaction term was borderline going overboard.
- There is an asymmetric cost of being wrong I would like to model. Choosing the batch algorithm incorrectly could have much worse consequences on performance than if I chose the single algorithm incorrectly. To model this, instead of using a linear model with squared loss, I could use a quantile regression model with asymmetric loss.
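To illustrate the asymmetric-loss idea in the last note, here's a sketch of the quantile (pinball) loss. Nothing here comes from lifelines, and q = 0.75 is an arbitrary value picked for the example. Since the target is the batch-to-single ratio, predicting it too low is what leads to wrongly choosing batch, so a q above 0.5 makes that kind of error cost more.

```python
def pinball_loss(y_true, y_pred, q):
    """Quantile loss: under-predictions cost q per unit, over-predictions 1 - q."""
    error = y_true - y_pred
    return q * error if error >= 0 else (q - 1) * error


# q = 0.75 is an illustrative choice, not a value from the post.
q = 0.75

# True ratio 1.2 but predicted 1.0: the prediction is too low.
under = pinball_loss(y_true=1.2, y_pred=1.0, q=q)

# True ratio 1.0 but predicted 1.2: the prediction is too high by the same amount.
over = pinball_loss(y_true=1.0, y_pred=1.2, q=q)

print(under, over)
```

With q = 0.5 the two errors would cost the same, which is the symmetric case of this loss.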
¹ ugh, naming is hard.