AFTSurvivalRegression (Spark ML): can someone explain this behavior?


I took this example directly from the Spark ML documentation.

from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import AFTSurvivalRegression

training = spark.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", 
    "features"])
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
                            quantilesCol="quantiles")
#aft = AFTSurvivalRegression()
model = aft.fit(training)

# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)

The result is:

Coefficients: [-0.496304411053,0.198452172529]
Intercept: 2.6380898963056327
Scale: 1.5472363533632303
+-----+------+--------------+------------------+
|label|censor|features      |prediction        |
+-----+------+--------------+------------------+
|1.218|1.0   |[1.56,-0.605] |5.718985621018951 |
|2.949|0.0   |[0.346,2.158] |18.07678210850554 |
|3.627|0.0   |[1.38,0.231]  |7.381908879359964 |
|0.273|1.0   |[0.52,1.151]  |13.577717814884505|
|4.199|0.0   |[0.795,-0.226]|9.013087597344805 |
+-----+------+--------------+------------------+

But if we add 20 to every label:

training = spark.createDataFrame([
    (21.218, 1.0, Vectors.dense(1.560, -0.605)),
    (22.949, 0.0, Vectors.dense(0.346, 2.158)),
    (23.627, 0.0, Vectors.dense(1.380, 0.231)),
    (20.273, 1.0, Vectors.dense(0.520, 1.151)),
    (24.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", 
    "features"])
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
                             quantilesCol="quantiles")
#aft = AFTSurvivalRegression()
model = aft.fit(training)

# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)

the result changes to:

Coefficients: [23.9932020748,3.18105314757]
Intercept: 7.35052273751137
Scale: 7698609960.724161
+------+------+--------------+---------------------+---------+
|label |censor|features      |prediction           |quantiles|
+------+------+--------------+---------------------+---------+
|21.218|1.0   |[1.56,-0.605] |4.0912442688237169E18|[0.0,0.0]|
|22.949|0.0   |[0.346,2.158] |6.011158613411288E9  |[0.0,0.0]|
|23.627|0.0   |[1.38,0.231]  |7.7835948690311181E17|[0.0,0.0]|
|20.273|1.0   |[0.52,1.151]  |1.5880852723124176E10|[0.0,0.0]|
|24.199|0.0   |[0.795,-0.226]|1.4590190884193677E11|[0.0,0.0]|
+------+------+--------------+---------------------+---------+

Can someone please explain this exponential blow-up in the predictions? As I understand it, the prediction in AFT is the predicted time at which the failure event occurs, so I don't understand why it changes exponentially with the value of the label.

There is 1 best solution below

Below is the result I get when running your second example with Spark 2.1:

    Coefficients: [-0.065814695216,0.00326705958509]
    Intercept: 3.29140205698
    Scale: 0.109856123692
    +------+------+--------------+------------------+---------------------------------------+
    |label |censor|features      |prediction        |quantiles                              |
    +------+------+--------------+------------------+---------------------------------------+
    |21.218|1.0   |[1.56,-0.605] |24.20972861807431 |[21.61744311047112,23.97833624826161]  |
    |22.949|0.0   |[0.346,2.158] |26.461225875981274|[23.6278586196251,26.208314087493847]  |
    |23.627|0.0   |[1.38,0.231]  |24.565240805031486|[21.93488840685864,24.330450511651154] |
    |20.273|1.0   |[0.52,1.151]  |26.074003958175602|[23.282098949562453,25.82479316934075] |
    |24.199|0.0   |[0.795,-0.226]|25.491396901107066|[22.761875236582235,25.247754569057975]|
    +------+------+--------------+------------------+---------------------------------------+

The ParamMap of the estimator is:

aft.extractParamMap()

    {Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='censorCol', doc='censor column name. The value of this column could be 0 or 1. If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.'): 'censor',
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='maxIter', doc='max number of iterations (>= 0).'): 100,
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='fitIntercept', doc='whether to fit an intercept term.'): True,
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).'): 2,
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='labelCol', doc='label column name.'): 'label',
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='featuresCol', doc='features column name.'): 'features',
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='quantilesCol', doc='quantiles column name. This column will output quantiles of corresponding quantileProbabilities if it is set.'): 'quantiles',
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).'): 1e-06,
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='quantileProbabilities', doc='quantile probabilities array. Values of the quantile probabilities array should be in the range (0, 1) and the array should be non-empty.'): [0.3,
      0.6],
     Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='predictionCol', doc='prediction column name.'): 'prediction'}

Can you check the convergence tolerance (tol), the maximum number of iterations (maxIter), and whether an intercept term is fitted (fitIntercept) in your setup?
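
As a side note on why the numbers get so large: as far as I understand the Weibull AFT model that Spark fits, the prediction is exp(intercept + coefficients . features), and each quantile is the prediction multiplied by (-log(1 - p)) raised to the power of the scale. The sketch below is plain Python, not Spark code, and the helper names aft_predict and aft_quantiles are made up for illustration. It plugs in the printed parameters from both runs for the first training row.

```python
import math

def aft_predict(coefficients, intercept, features):
    """Predicted survival time: exp of the linear predictor."""
    linear = intercept + sum(c * x for c, x in zip(coefficients, features))
    return math.exp(linear)

def aft_quantiles(prediction, scale, probabilities):
    """Weibull quantiles: prediction * (-log(1 - p)) ** scale."""
    return [prediction * math.exp(scale * math.log(-math.log(1.0 - p)))
            for p in probabilities]

features = [1.56, -0.605]      # first row of the training data
probabilities = [0.3, 0.6]     # quantileProbabilities from the question

# Parameters printed by the Spark 2.1 run (labels + 20, sensible fit).
good = aft_predict([-0.065814695216, 0.00326705958509], 3.29140205698, features)
good_q = aft_quantiles(good, 0.109856123692, probabilities)

# Parameters printed in the question (labels + 20, blown-up fit).
bad = aft_predict([23.9932020748, 3.18105314757], 7.35052273751137, features)
bad_q = aft_quantiles(bad, 7698609960.724161, probabilities)

print("sensible fit:", good, good_q)
print("blown-up fit:", bad, bad_q)
```

If this reading of the model is right, the first line should match the first row of my table and the second line should match the first row of yours, including the [0.0, 0.0] quantiles. In other words, the blow-up is not in the prediction step: a linear predictor of about 43 and a scale of about 7.7e9 are already unreasonable, which points at the optimizer not converging rather than at the label shift itself.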